Use generic record reader/writer for key agreement.

This commit is contained in:
akwizgran
2018-04-19 11:55:58 +01:00
parent cc2791c37f
commit 6fa6ceb5ee
5 changed files with 216 additions and 175 deletions

View File

@@ -13,6 +13,7 @@ import org.briarproject.bramble.keyagreement.KeyAgreementModule;
import org.briarproject.bramble.lifecycle.LifecycleModule; import org.briarproject.bramble.lifecycle.LifecycleModule;
import org.briarproject.bramble.plugin.PluginModule; import org.briarproject.bramble.plugin.PluginModule;
import org.briarproject.bramble.properties.PropertiesModule; import org.briarproject.bramble.properties.PropertiesModule;
import org.briarproject.bramble.record.RecordModule;
import org.briarproject.bramble.reliability.ReliabilityModule; import org.briarproject.bramble.reliability.ReliabilityModule;
import org.briarproject.bramble.reporting.ReportingModule; import org.briarproject.bramble.reporting.ReportingModule;
import org.briarproject.bramble.settings.SettingsModule; import org.briarproject.bramble.settings.SettingsModule;
@@ -38,6 +39,7 @@ import dagger.Module;
LifecycleModule.class, LifecycleModule.class,
PluginModule.class, PluginModule.class,
PropertiesModule.class, PropertiesModule.class,
RecordModule.class,
ReliabilityModule.class, ReliabilityModule.class,
ReportingModule.class, ReportingModule.class,
SettingsModule.class, SettingsModule.class,

View File

@@ -13,6 +13,8 @@ import org.briarproject.bramble.api.plugin.PluginManager;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexPlugin; import org.briarproject.bramble.api.plugin.duplex.DuplexPlugin;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@@ -44,6 +46,8 @@ class KeyAgreementConnector {
private final KeyAgreementCrypto keyAgreementCrypto; private final KeyAgreementCrypto keyAgreementCrypto;
private final PluginManager pluginManager; private final PluginManager pluginManager;
private final ConnectionChooser connectionChooser; private final ConnectionChooser connectionChooser;
private final RecordReaderFactory recordReaderFactory;
private final RecordWriterFactory recordWriterFactory;
private final List<KeyAgreementListener> listeners = private final List<KeyAgreementListener> listeners =
new CopyOnWriteArrayList<>(); new CopyOnWriteArrayList<>();
@@ -54,11 +58,15 @@ class KeyAgreementConnector {
KeyAgreementConnector(Callbacks callbacks, KeyAgreementConnector(Callbacks callbacks,
KeyAgreementCrypto keyAgreementCrypto, PluginManager pluginManager, KeyAgreementCrypto keyAgreementCrypto, PluginManager pluginManager,
ConnectionChooser connectionChooser) { ConnectionChooser connectionChooser,
RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory) {
this.callbacks = callbacks; this.callbacks = callbacks;
this.keyAgreementCrypto = keyAgreementCrypto; this.keyAgreementCrypto = keyAgreementCrypto;
this.pluginManager = pluginManager; this.pluginManager = pluginManager;
this.connectionChooser = connectionChooser; this.connectionChooser = connectionChooser;
this.recordReaderFactory = recordReaderFactory;
this.recordWriterFactory = recordWriterFactory;
} }
Payload listen(KeyPair localKeyPair) { Payload listen(KeyPair localKeyPair) {
@@ -119,7 +127,8 @@ class KeyAgreementConnector {
KeyAgreementConnection chosen = KeyAgreementConnection chosen =
connectionChooser.poll(CONNECTION_TIMEOUT); connectionChooser.poll(CONNECTION_TIMEOUT);
if (chosen == null) return null; if (chosen == null) return null;
return new KeyAgreementTransport(chosen); return new KeyAgreementTransport(recordReaderFactory,
recordWriterFactory, chosen);
} catch (InterruptedException e) { } catch (InterruptedException e) {
LOG.info("Interrupted while waiting for connection"); LOG.info("Interrupted while waiting for connection");
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();

View File

@@ -19,6 +19,8 @@ import org.briarproject.bramble.api.keyagreement.event.KeyAgreementWaitingEvent;
import org.briarproject.bramble.api.nullsafety.MethodsNotNullByDefault; import org.briarproject.bramble.api.nullsafety.MethodsNotNullByDefault;
import org.briarproject.bramble.api.nullsafety.ParametersNotNullByDefault; import org.briarproject.bramble.api.nullsafety.ParametersNotNullByDefault;
import org.briarproject.bramble.api.plugin.PluginManager; import org.briarproject.bramble.api.plugin.PluginManager;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import java.io.IOException; import java.io.IOException;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -49,14 +51,17 @@ class KeyAgreementTaskImpl extends Thread implements KeyAgreementTask,
KeyAgreementTaskImpl(CryptoComponent crypto, KeyAgreementTaskImpl(CryptoComponent crypto,
KeyAgreementCrypto keyAgreementCrypto, EventBus eventBus, KeyAgreementCrypto keyAgreementCrypto, EventBus eventBus,
PayloadEncoder payloadEncoder, PluginManager pluginManager, PayloadEncoder payloadEncoder, PluginManager pluginManager,
ConnectionChooser connectionChooser) { ConnectionChooser connectionChooser,
RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory) {
this.crypto = crypto; this.crypto = crypto;
this.keyAgreementCrypto = keyAgreementCrypto; this.keyAgreementCrypto = keyAgreementCrypto;
this.eventBus = eventBus; this.eventBus = eventBus;
this.payloadEncoder = payloadEncoder; this.payloadEncoder = payloadEncoder;
localKeyPair = crypto.generateAgreementKeyPair(); localKeyPair = crypto.generateAgreementKeyPair();
connector = new KeyAgreementConnector(this, keyAgreementCrypto, connector = new KeyAgreementConnector(this, keyAgreementCrypto,
pluginManager, connectionChooser); pluginManager, connectionChooser, recordReaderFactory,
recordWriterFactory);
} }
@Override @Override

View File

@@ -4,9 +4,12 @@ import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.util.ByteUtils; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@@ -14,8 +17,6 @@ import java.util.logging.Logger;
import static java.util.logging.Level.WARNING; import static java.util.logging.Level.WARNING;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_PAYLOAD_LENGTH_OFFSET;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY; import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY;
@@ -30,14 +31,17 @@ class KeyAgreementTransport {
Logger.getLogger(KeyAgreementTransport.class.getName()); Logger.getLogger(KeyAgreementTransport.class.getName());
private final KeyAgreementConnection kac; private final KeyAgreementConnection kac;
private final InputStream in; private final RecordReader reader;
private final OutputStream out; private final RecordWriter writer;
KeyAgreementTransport(KeyAgreementConnection kac) KeyAgreementTransport(RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory, KeyAgreementConnection kac)
throws IOException { throws IOException {
this.kac = kac; this.kac = kac;
in = kac.getConnection().getReader().getInputStream(); InputStream in = kac.getConnection().getReader().getInputStream();
out = kac.getConnection().getWriter().getOutputStream(); reader = recordReaderFactory.createRecordReader(in);
OutputStream out = kac.getConnection().getWriter().getOutputStream();
writer = recordWriterFactory.createRecordWriter(out);
} }
public DuplexTransportConnection getConnection() { public DuplexTransportConnection getConnection() {
@@ -74,9 +78,8 @@ class KeyAgreementTransport {
tryToClose(exception); tryToClose(exception);
} }
public void tryToClose(boolean exception) { private void tryToClose(boolean exception) {
try { try {
LOG.info("Closing connection");
kac.getConnection().getReader().dispose(exception, true); kac.getConnection().getReader().dispose(exception, true);
kac.getConnection().getWriter().dispose(exception); kac.getConnection().getWriter().dispose(exception);
} catch (IOException e) { } catch (IOException e) {
@@ -85,59 +88,27 @@ class KeyAgreementTransport {
} }
private void writeRecord(byte type, byte[] payload) throws IOException { private void writeRecord(byte type, byte[] payload) throws IOException {
byte[] recordHeader = new byte[RECORD_HEADER_LENGTH]; writer.writeRecord(new Record(PROTOCOL_VERSION, type, payload));
recordHeader[0] = PROTOCOL_VERSION; writer.flush();
recordHeader[1] = type;
ByteUtils.writeUint16(payload.length, recordHeader,
RECORD_HEADER_PAYLOAD_LENGTH_OFFSET);
out.write(recordHeader);
out.write(payload);
out.flush();
} }
private byte[] readRecord(byte expectedType) throws AbortException { private byte[] readRecord(byte expectedType) throws AbortException {
while (true) { while (true) {
byte[] header = readHeader();
byte version = header[0], type = header[1];
int len = ByteUtils.readUint16(header,
RECORD_HEADER_PAYLOAD_LENGTH_OFFSET);
// Reject unrecognised protocol version
if (version != PROTOCOL_VERSION) throw new AbortException(false);
if (type == ABORT) throw new AbortException(true);
if (type == expectedType) {
try {
return readData(len);
} catch (IOException e) {
throw new AbortException(e);
}
}
// Reject recognised but unexpected record type
if (type == KEY || type == CONFIRM) throw new AbortException(false);
// Skip unrecognised record type
try { try {
readData(len); Record record = reader.readRecord();
// Reject unrecognised protocol version
if (record.getProtocolVersion() != PROTOCOL_VERSION)
throw new AbortException(false);
byte type = record.getRecordType();
if (type == ABORT) throw new AbortException(true);
if (type == expectedType) return record.getPayload();
// Reject recognised but unexpected record type
if (type == KEY || type == CONFIRM)
throw new AbortException(false);
// Skip unrecognised record type
} catch (IOException e) { } catch (IOException e) {
throw new AbortException(e); throw new AbortException(e);
} }
} }
} }
private byte[] readHeader() throws AbortException {
try {
return readData(RECORD_HEADER_LENGTH);
} catch (IOException e) {
throw new AbortException(e);
}
}
private byte[] readData(int len) throws IOException {
byte[] data = new byte[len];
int offset = 0;
while (offset < data.length) {
int read = in.read(data, offset, data.length - offset);
if (read == -1) throw new EOFException();
offset += read;
}
return data;
}
} }

View File

@@ -5,23 +5,31 @@ import org.briarproject.bramble.api.plugin.TransportConnectionReader;
import org.briarproject.bramble.api.plugin.TransportConnectionWriter; import org.briarproject.bramble.api.plugin.TransportConnectionWriter;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import org.briarproject.bramble.test.BrambleMockTestCase; import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.TestUtils; import org.briarproject.bramble.test.CaptureArgumentAction;
import org.briarproject.bramble.util.ByteUtils;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.lib.legacy.ClassImposteriser;
import org.junit.Test; import org.junit.Test;
import java.io.ByteArrayInputStream; import java.io.EOFException;
import java.io.ByteArrayOutputStream; import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicReference;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY; import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY;
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
import static org.briarproject.bramble.test.TestUtils.getTransportId; import static org.briarproject.bramble.test.TestUtils.getTransportId;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class KeyAgreementTransportTest extends BrambleMockTestCase { public class KeyAgreementTransportTest extends BrambleMockTestCase {
@@ -31,222 +39,268 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
context.mock(TransportConnectionReader.class); context.mock(TransportConnectionReader.class);
private final TransportConnectionWriter transportConnectionWriter = private final TransportConnectionWriter transportConnectionWriter =
context.mock(TransportConnectionWriter.class); context.mock(TransportConnectionWriter.class);
private final RecordReaderFactory recordReaderFactory =
context.mock(RecordReaderFactory.class);
private final RecordWriterFactory recordWriterFactory =
context.mock(RecordWriterFactory.class);
private final RecordReader recordReader = context.mock(RecordReader.class);
private final RecordWriter recordWriter = context.mock(RecordWriter.class);
private final TransportId transportId = getTransportId(); private final TransportId transportId = getTransportId();
private final KeyAgreementConnection keyAgreementConnection = private final KeyAgreementConnection keyAgreementConnection =
new KeyAgreementConnection(duplexTransportConnection, transportId); new KeyAgreementConnection(duplexTransportConnection, transportId);
private ByteArrayInputStream inputStream; private final InputStream inputStream;
private ByteArrayOutputStream outputStream; private final OutputStream outputStream;
private KeyAgreementTransport kat; private KeyAgreementTransport kat;
public KeyAgreementTransportTest() {
context.setImposteriser(ClassImposteriser.INSTANCE);
inputStream = context.mock(InputStream.class);
outputStream = context.mock(OutputStream.class);
}
@Test @Test
public void testSendKey() throws Exception { public void testSendKey() throws Exception {
setup(new byte[0]); byte[] key = getRandomBytes(123);
byte[] key = TestUtils.getRandomBytes(123);
setup();
AtomicReference<Record> written = expectWriteRecord();
kat.sendKey(key); kat.sendKey(key);
assertRecordSent(KEY, key); assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, KEY, key, written.get());
} }
@Test @Test
public void testSendConfirm() throws Exception { public void testSendConfirm() throws Exception {
setup(new byte[0]); byte[] confirm = getRandomBytes(123);
byte[] confirm = TestUtils.getRandomBytes(123);
setup();
AtomicReference<Record> written = expectWriteRecord();
kat.sendConfirm(confirm); kat.sendConfirm(confirm);
assertRecordSent(CONFIRM, confirm); assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, CONFIRM, confirm, written.get());
} }
@Test @Test
public void testSendAbortWithException() throws Exception { public void testSendAbortWithException() throws Exception {
setup(new byte[0]); setup();
AtomicReference<Record> written = expectWriteRecord();
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(transportConnectionReader).dispose(true, true); oneOf(transportConnectionReader).dispose(true, true);
oneOf(transportConnectionWriter).dispose(true); oneOf(transportConnectionWriter).dispose(true);
}}); }});
kat.sendAbort(true); kat.sendAbort(true);
assertRecordSent(ABORT, new byte[0]); assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
} }
@Test @Test
public void testSendAbortWithoutException() throws Exception { public void testSendAbortWithoutException() throws Exception {
setup(new byte[0]); setup();
AtomicReference<Record> written = expectWriteRecord();
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(transportConnectionReader).dispose(false, true); oneOf(transportConnectionReader).dispose(false, true);
oneOf(transportConnectionWriter).dispose(false); oneOf(transportConnectionWriter).dispose(false);
}}); }});
kat.sendAbort(false); kat.sendAbort(false);
assertRecordSent(ABORT, new byte[0]); assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfAtEndOfStream() public void testReceiveKeyThrowsExceptionIfAtEndOfStream()
throws Exception { throws Exception {
setup(new byte[0]); setup();
kat.receiveKey(); context.checking(new Expectations() {{
} oneOf(recordReader).readRecord();
will(throwException(new EOFException()));
}});
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfHeaderIsTooShort()
throws Exception {
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
input[0] = PROTOCOL_VERSION;
input[1] = KEY;
setup(input);
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfPayloadIsTooShort()
throws Exception {
int payloadLength = 123;
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
input[0] = PROTOCOL_VERSION;
input[1] = KEY;
ByteUtils.writeUint16(payloadLength, input, 2);
setup(input);
kat.receiveKey(); kat.receiveKey();
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised() public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception { throws Exception {
setup(createRecord((byte) (PROTOCOL_VERSION + 1), KEY, new byte[123])); byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
byte[] key = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(unknownVersion, KEY, key)));
}});
kat.receiveKey(); kat.receiveKey();
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfAbortIsReceived() public void testReceiveKeyThrowsExceptionIfAbortIsReceived()
throws Exception { throws Exception {
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0])); setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
}});
kat.receiveKey(); kat.receiveKey();
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfConfirmIsReceived() public void testReceiveKeyThrowsExceptionIfConfirmIsReceived()
throws Exception { throws Exception {
setup(createRecord(PROTOCOL_VERSION, CONFIRM, new byte[123])); byte[] confirm = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, CONFIRM, confirm)));
}});
kat.receiveKey(); kat.receiveKey();
} }
@Test @Test
public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception { public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1), byte type1 = (byte) (ABORT + 1);
new byte[123]); byte[] payload1 = getRandomBytes(123);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2), Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
new byte[0]); byte type2 = (byte) (ABORT + 2);
byte[] payload = TestUtils.getRandomBytes(123); byte[] payload2 = new byte[0];
byte[] key = createRecord(PROTOCOL_VERSION, KEY, payload); Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
ByteArrayOutputStream input = new ByteArrayOutputStream(); byte[] key = getRandomBytes(123);
input.write(skip1); Record keyRecord = new Record(PROTOCOL_VERSION, KEY, key);
input.write(skip2);
input.write(key); setup();
setup(input.toByteArray()); context.checking(new Expectations() {{
assertArrayEquals(payload, kat.receiveKey()); oneOf(recordReader).readRecord();
will(returnValue(unknownRecord1));
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord2));
oneOf(recordReader).readRecord();
will(returnValue(keyRecord));
}});
assertArrayEquals(key, kat.receiveKey());
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfAtEndOfStream() public void testReceiveConfirmThrowsExceptionIfAtEndOfStream()
throws Exception { throws Exception {
setup(new byte[0]); setup();
kat.receiveConfirm(); context.checking(new Expectations() {{
} oneOf(recordReader).readRecord();
will(throwException(new EOFException()));
}});
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfHeaderIsTooShort()
throws Exception {
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
input[0] = PROTOCOL_VERSION;
input[1] = CONFIRM;
setup(input);
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfPayloadIsTooShort()
throws Exception {
int payloadLength = 123;
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
input[0] = PROTOCOL_VERSION;
input[1] = CONFIRM;
ByteUtils.writeUint16(payloadLength, input, 2);
setup(input);
kat.receiveConfirm(); kat.receiveConfirm();
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised() public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception { throws Exception {
setup(createRecord((byte) (PROTOCOL_VERSION + 1), CONFIRM, byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
new byte[123])); byte[] confirm = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(unknownVersion, CONFIRM, confirm)));
}});
kat.receiveConfirm(); kat.receiveConfirm();
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfAbortIsReceived() public void testReceiveConfirmThrowsExceptionIfAbortIsReceived()
throws Exception { throws Exception {
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0])); setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
}});
kat.receiveConfirm(); kat.receiveConfirm();
} }
@Test(expected = AbortException.class) @Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfKeyIsReceived() public void testReceiveConfirmThrowsExceptionIfKeyIsReceived()
throws Exception { throws Exception {
setup(createRecord(PROTOCOL_VERSION, KEY, new byte[123])); byte[] key = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, KEY, key)));
}});
kat.receiveConfirm(); kat.receiveConfirm();
} }
@Test @Test
public void testReceiveConfirmSkipsUnrecognisedRecordTypes() public void testReceiveConfirmSkipsUnrecognisedRecordTypes()
throws Exception { throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1), byte type1 = (byte) (ABORT + 1);
new byte[123]); byte[] payload1 = getRandomBytes(123);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2), Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
new byte[0]); byte type2 = (byte) (ABORT + 2);
byte[] payload = TestUtils.getRandomBytes(123); byte[] payload2 = new byte[0];
byte[] confirm = createRecord(PROTOCOL_VERSION, CONFIRM, payload); Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
ByteArrayOutputStream input = new ByteArrayOutputStream(); byte[] confirm = getRandomBytes(123);
input.write(skip1); Record confirmRecord = new Record(PROTOCOL_VERSION, CONFIRM, confirm);
input.write(skip2);
input.write(confirm); setup();
setup(input.toByteArray()); context.checking(new Expectations() {{
assertArrayEquals(payload, kat.receiveConfirm()); oneOf(recordReader).readRecord();
will(returnValue(unknownRecord1));
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord2));
oneOf(recordReader).readRecord();
will(returnValue(confirmRecord));
}});
assertArrayEquals(confirm, kat.receiveConfirm());
} }
private void setup(byte[] input) throws Exception { private void setup() throws Exception {
inputStream = new ByteArrayInputStream(input);
outputStream = new ByteArrayOutputStream();
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(duplexTransportConnection).getReader(); allowing(duplexTransportConnection).getReader();
will(returnValue(transportConnectionReader)); will(returnValue(transportConnectionReader));
allowing(transportConnectionReader).getInputStream(); allowing(transportConnectionReader).getInputStream();
will(returnValue(inputStream)); will(returnValue(inputStream));
oneOf(recordReaderFactory).createRecordReader(inputStream);
will(returnValue(recordReader));
allowing(duplexTransportConnection).getWriter(); allowing(duplexTransportConnection).getWriter();
will(returnValue(transportConnectionWriter)); will(returnValue(transportConnectionWriter));
allowing(transportConnectionWriter).getOutputStream(); allowing(transportConnectionWriter).getOutputStream();
will(returnValue(outputStream)); will(returnValue(outputStream));
oneOf(recordWriterFactory).createRecordWriter(outputStream);
will(returnValue(recordWriter));
}}); }});
kat = new KeyAgreementTransport(keyAgreementConnection); kat = new KeyAgreementTransport(recordReaderFactory,
recordWriterFactory, keyAgreementConnection);
} }
private void assertRecordSent(byte expectedType, byte[] expectedPayload) { private AtomicReference<Record> expectWriteRecord() throws Exception {
byte[] output = outputStream.toByteArray(); AtomicReference<Record> captured = new AtomicReference<>();
assertEquals(RECORD_HEADER_LENGTH + expectedPayload.length, context.checking(new Expectations() {{
output.length); oneOf(recordWriter).writeRecord(with(any(Record.class)));
assertEquals(PROTOCOL_VERSION, output[0]); will(new CaptureArgumentAction<>(captured, Record.class, 0));
assertEquals(expectedType, output[1]); oneOf(recordWriter).flush();
assertEquals(expectedPayload.length, ByteUtils.readUint16(output, 2)); }});
byte[] payload = new byte[output.length - RECORD_HEADER_LENGTH]; return captured;
System.arraycopy(output, RECORD_HEADER_LENGTH, payload, 0,
payload.length);
assertArrayEquals(expectedPayload, payload);
} }
private byte[] createRecord(byte version, byte type, byte[] payload) { private void assertRecordEquals(byte expectedVersion, byte expectedType,
byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length]; byte[] expectedPayload, Record actual) {
b[0] = version; assertEquals(expectedVersion, actual.getProtocolVersion());
b[1] = type; assertEquals(expectedType, actual.getRecordType());
ByteUtils.writeUint16(payload.length, b, 2); assertArrayEquals(expectedPayload, actual.getPayload());
System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length);
return b;
} }
} }