mirror of
https://code.briarproject.org/briar/briar.git
synced 2026-02-12 18:59:06 +01:00
Use generic record reader/writer for key agreement.
This commit is contained in:
@@ -13,6 +13,7 @@ import org.briarproject.bramble.keyagreement.KeyAgreementModule;
|
||||
import org.briarproject.bramble.lifecycle.LifecycleModule;
|
||||
import org.briarproject.bramble.plugin.PluginModule;
|
||||
import org.briarproject.bramble.properties.PropertiesModule;
|
||||
import org.briarproject.bramble.record.RecordModule;
|
||||
import org.briarproject.bramble.reliability.ReliabilityModule;
|
||||
import org.briarproject.bramble.reporting.ReportingModule;
|
||||
import org.briarproject.bramble.settings.SettingsModule;
|
||||
@@ -38,6 +39,7 @@ import dagger.Module;
|
||||
LifecycleModule.class,
|
||||
PluginModule.class,
|
||||
PropertiesModule.class,
|
||||
RecordModule.class,
|
||||
ReliabilityModule.class,
|
||||
ReportingModule.class,
|
||||
SettingsModule.class,
|
||||
|
||||
@@ -13,6 +13,8 @@ import org.briarproject.bramble.api.plugin.PluginManager;
|
||||
import org.briarproject.bramble.api.plugin.TransportId;
|
||||
import org.briarproject.bramble.api.plugin.duplex.DuplexPlugin;
|
||||
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
|
||||
import org.briarproject.bramble.api.record.RecordReaderFactory;
|
||||
import org.briarproject.bramble.api.record.RecordWriterFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
@@ -44,6 +46,8 @@ class KeyAgreementConnector {
|
||||
private final KeyAgreementCrypto keyAgreementCrypto;
|
||||
private final PluginManager pluginManager;
|
||||
private final ConnectionChooser connectionChooser;
|
||||
private final RecordReaderFactory recordReaderFactory;
|
||||
private final RecordWriterFactory recordWriterFactory;
|
||||
|
||||
private final List<KeyAgreementListener> listeners =
|
||||
new CopyOnWriteArrayList<>();
|
||||
@@ -54,11 +58,15 @@ class KeyAgreementConnector {
|
||||
|
||||
KeyAgreementConnector(Callbacks callbacks,
|
||||
KeyAgreementCrypto keyAgreementCrypto, PluginManager pluginManager,
|
||||
ConnectionChooser connectionChooser) {
|
||||
ConnectionChooser connectionChooser,
|
||||
RecordReaderFactory recordReaderFactory,
|
||||
RecordWriterFactory recordWriterFactory) {
|
||||
this.callbacks = callbacks;
|
||||
this.keyAgreementCrypto = keyAgreementCrypto;
|
||||
this.pluginManager = pluginManager;
|
||||
this.connectionChooser = connectionChooser;
|
||||
this.recordReaderFactory = recordReaderFactory;
|
||||
this.recordWriterFactory = recordWriterFactory;
|
||||
}
|
||||
|
||||
Payload listen(KeyPair localKeyPair) {
|
||||
@@ -119,7 +127,8 @@ class KeyAgreementConnector {
|
||||
KeyAgreementConnection chosen =
|
||||
connectionChooser.poll(CONNECTION_TIMEOUT);
|
||||
if (chosen == null) return null;
|
||||
return new KeyAgreementTransport(chosen);
|
||||
return new KeyAgreementTransport(recordReaderFactory,
|
||||
recordWriterFactory, chosen);
|
||||
} catch (InterruptedException e) {
|
||||
LOG.info("Interrupted while waiting for connection");
|
||||
Thread.currentThread().interrupt();
|
||||
|
||||
@@ -19,6 +19,8 @@ import org.briarproject.bramble.api.keyagreement.event.KeyAgreementWaitingEvent;
|
||||
import org.briarproject.bramble.api.nullsafety.MethodsNotNullByDefault;
|
||||
import org.briarproject.bramble.api.nullsafety.ParametersNotNullByDefault;
|
||||
import org.briarproject.bramble.api.plugin.PluginManager;
|
||||
import org.briarproject.bramble.api.record.RecordReaderFactory;
|
||||
import org.briarproject.bramble.api.record.RecordWriterFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.logging.Logger;
|
||||
@@ -49,14 +51,17 @@ class KeyAgreementTaskImpl extends Thread implements KeyAgreementTask,
|
||||
KeyAgreementTaskImpl(CryptoComponent crypto,
|
||||
KeyAgreementCrypto keyAgreementCrypto, EventBus eventBus,
|
||||
PayloadEncoder payloadEncoder, PluginManager pluginManager,
|
||||
ConnectionChooser connectionChooser) {
|
||||
ConnectionChooser connectionChooser,
|
||||
RecordReaderFactory recordReaderFactory,
|
||||
RecordWriterFactory recordWriterFactory) {
|
||||
this.crypto = crypto;
|
||||
this.keyAgreementCrypto = keyAgreementCrypto;
|
||||
this.eventBus = eventBus;
|
||||
this.payloadEncoder = payloadEncoder;
|
||||
localKeyPair = crypto.generateAgreementKeyPair();
|
||||
connector = new KeyAgreementConnector(this, keyAgreementCrypto,
|
||||
pluginManager, connectionChooser);
|
||||
pluginManager, connectionChooser, recordReaderFactory,
|
||||
recordWriterFactory);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -4,9 +4,12 @@ import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
|
||||
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
|
||||
import org.briarproject.bramble.api.plugin.TransportId;
|
||||
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
|
||||
import org.briarproject.bramble.util.ByteUtils;
|
||||
import org.briarproject.bramble.api.record.Record;
|
||||
import org.briarproject.bramble.api.record.RecordReader;
|
||||
import org.briarproject.bramble.api.record.RecordReaderFactory;
|
||||
import org.briarproject.bramble.api.record.RecordWriter;
|
||||
import org.briarproject.bramble.api.record.RecordWriterFactory;
|
||||
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
@@ -14,8 +17,6 @@ import java.util.logging.Logger;
|
||||
|
||||
import static java.util.logging.Level.WARNING;
|
||||
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
|
||||
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH;
|
||||
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_PAYLOAD_LENGTH_OFFSET;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY;
|
||||
@@ -30,14 +31,17 @@ class KeyAgreementTransport {
|
||||
Logger.getLogger(KeyAgreementTransport.class.getName());
|
||||
|
||||
private final KeyAgreementConnection kac;
|
||||
private final InputStream in;
|
||||
private final OutputStream out;
|
||||
private final RecordReader reader;
|
||||
private final RecordWriter writer;
|
||||
|
||||
KeyAgreementTransport(KeyAgreementConnection kac)
|
||||
KeyAgreementTransport(RecordReaderFactory recordReaderFactory,
|
||||
RecordWriterFactory recordWriterFactory, KeyAgreementConnection kac)
|
||||
throws IOException {
|
||||
this.kac = kac;
|
||||
in = kac.getConnection().getReader().getInputStream();
|
||||
out = kac.getConnection().getWriter().getOutputStream();
|
||||
InputStream in = kac.getConnection().getReader().getInputStream();
|
||||
reader = recordReaderFactory.createRecordReader(in);
|
||||
OutputStream out = kac.getConnection().getWriter().getOutputStream();
|
||||
writer = recordWriterFactory.createRecordWriter(out);
|
||||
}
|
||||
|
||||
public DuplexTransportConnection getConnection() {
|
||||
@@ -74,9 +78,8 @@ class KeyAgreementTransport {
|
||||
tryToClose(exception);
|
||||
}
|
||||
|
||||
public void tryToClose(boolean exception) {
|
||||
private void tryToClose(boolean exception) {
|
||||
try {
|
||||
LOG.info("Closing connection");
|
||||
kac.getConnection().getReader().dispose(exception, true);
|
||||
kac.getConnection().getWriter().dispose(exception);
|
||||
} catch (IOException e) {
|
||||
@@ -85,59 +88,27 @@ class KeyAgreementTransport {
|
||||
}
|
||||
|
||||
private void writeRecord(byte type, byte[] payload) throws IOException {
|
||||
byte[] recordHeader = new byte[RECORD_HEADER_LENGTH];
|
||||
recordHeader[0] = PROTOCOL_VERSION;
|
||||
recordHeader[1] = type;
|
||||
ByteUtils.writeUint16(payload.length, recordHeader,
|
||||
RECORD_HEADER_PAYLOAD_LENGTH_OFFSET);
|
||||
out.write(recordHeader);
|
||||
out.write(payload);
|
||||
out.flush();
|
||||
writer.writeRecord(new Record(PROTOCOL_VERSION, type, payload));
|
||||
writer.flush();
|
||||
}
|
||||
|
||||
private byte[] readRecord(byte expectedType) throws AbortException {
|
||||
while (true) {
|
||||
byte[] header = readHeader();
|
||||
byte version = header[0], type = header[1];
|
||||
int len = ByteUtils.readUint16(header,
|
||||
RECORD_HEADER_PAYLOAD_LENGTH_OFFSET);
|
||||
// Reject unrecognised protocol version
|
||||
if (version != PROTOCOL_VERSION) throw new AbortException(false);
|
||||
if (type == ABORT) throw new AbortException(true);
|
||||
if (type == expectedType) {
|
||||
try {
|
||||
return readData(len);
|
||||
} catch (IOException e) {
|
||||
throw new AbortException(e);
|
||||
}
|
||||
}
|
||||
// Reject recognised but unexpected record type
|
||||
if (type == KEY || type == CONFIRM) throw new AbortException(false);
|
||||
// Skip unrecognised record type
|
||||
try {
|
||||
readData(len);
|
||||
Record record = reader.readRecord();
|
||||
// Reject unrecognised protocol version
|
||||
if (record.getProtocolVersion() != PROTOCOL_VERSION)
|
||||
throw new AbortException(false);
|
||||
byte type = record.getRecordType();
|
||||
if (type == ABORT) throw new AbortException(true);
|
||||
if (type == expectedType) return record.getPayload();
|
||||
// Reject recognised but unexpected record type
|
||||
if (type == KEY || type == CONFIRM)
|
||||
throw new AbortException(false);
|
||||
// Skip unrecognised record type
|
||||
} catch (IOException e) {
|
||||
throw new AbortException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] readHeader() throws AbortException {
|
||||
try {
|
||||
return readData(RECORD_HEADER_LENGTH);
|
||||
} catch (IOException e) {
|
||||
throw new AbortException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] readData(int len) throws IOException {
|
||||
byte[] data = new byte[len];
|
||||
int offset = 0;
|
||||
while (offset < data.length) {
|
||||
int read = in.read(data, offset, data.length - offset);
|
||||
if (read == -1) throw new EOFException();
|
||||
offset += read;
|
||||
}
|
||||
return data;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,23 +5,31 @@ import org.briarproject.bramble.api.plugin.TransportConnectionReader;
|
||||
import org.briarproject.bramble.api.plugin.TransportConnectionWriter;
|
||||
import org.briarproject.bramble.api.plugin.TransportId;
|
||||
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
|
||||
import org.briarproject.bramble.api.record.Record;
|
||||
import org.briarproject.bramble.api.record.RecordReader;
|
||||
import org.briarproject.bramble.api.record.RecordReaderFactory;
|
||||
import org.briarproject.bramble.api.record.RecordWriter;
|
||||
import org.briarproject.bramble.api.record.RecordWriterFactory;
|
||||
import org.briarproject.bramble.test.BrambleMockTestCase;
|
||||
import org.briarproject.bramble.test.TestUtils;
|
||||
import org.briarproject.bramble.util.ByteUtils;
|
||||
import org.briarproject.bramble.test.CaptureArgumentAction;
|
||||
import org.jmock.Expectations;
|
||||
import org.jmock.lib.legacy.ClassImposteriser;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.EOFException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
|
||||
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY;
|
||||
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
|
||||
import static org.briarproject.bramble.test.TestUtils.getTransportId;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
|
||||
public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
|
||||
@@ -31,222 +39,268 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
context.mock(TransportConnectionReader.class);
|
||||
private final TransportConnectionWriter transportConnectionWriter =
|
||||
context.mock(TransportConnectionWriter.class);
|
||||
private final RecordReaderFactory recordReaderFactory =
|
||||
context.mock(RecordReaderFactory.class);
|
||||
private final RecordWriterFactory recordWriterFactory =
|
||||
context.mock(RecordWriterFactory.class);
|
||||
private final RecordReader recordReader = context.mock(RecordReader.class);
|
||||
private final RecordWriter recordWriter = context.mock(RecordWriter.class);
|
||||
|
||||
private final TransportId transportId = getTransportId();
|
||||
private final KeyAgreementConnection keyAgreementConnection =
|
||||
new KeyAgreementConnection(duplexTransportConnection, transportId);
|
||||
|
||||
private ByteArrayInputStream inputStream;
|
||||
private ByteArrayOutputStream outputStream;
|
||||
private final InputStream inputStream;
|
||||
private final OutputStream outputStream;
|
||||
|
||||
private KeyAgreementTransport kat;
|
||||
|
||||
public KeyAgreementTransportTest() {
|
||||
context.setImposteriser(ClassImposteriser.INSTANCE);
|
||||
inputStream = context.mock(InputStream.class);
|
||||
outputStream = context.mock(OutputStream.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSendKey() throws Exception {
|
||||
setup(new byte[0]);
|
||||
byte[] key = TestUtils.getRandomBytes(123);
|
||||
byte[] key = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
AtomicReference<Record> written = expectWriteRecord();
|
||||
|
||||
kat.sendKey(key);
|
||||
assertRecordSent(KEY, key);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, KEY, key, written.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSendConfirm() throws Exception {
|
||||
setup(new byte[0]);
|
||||
byte[] confirm = TestUtils.getRandomBytes(123);
|
||||
byte[] confirm = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
AtomicReference<Record> written = expectWriteRecord();
|
||||
|
||||
kat.sendConfirm(confirm);
|
||||
assertRecordSent(CONFIRM, confirm);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, CONFIRM, confirm, written.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSendAbortWithException() throws Exception {
|
||||
setup(new byte[0]);
|
||||
setup();
|
||||
AtomicReference<Record> written = expectWriteRecord();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(transportConnectionReader).dispose(true, true);
|
||||
oneOf(transportConnectionWriter).dispose(true);
|
||||
}});
|
||||
|
||||
kat.sendAbort(true);
|
||||
assertRecordSent(ABORT, new byte[0]);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSendAbortWithoutException() throws Exception {
|
||||
setup(new byte[0]);
|
||||
setup();
|
||||
AtomicReference<Record> written = expectWriteRecord();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(transportConnectionReader).dispose(false, true);
|
||||
oneOf(transportConnectionWriter).dispose(false);
|
||||
}});
|
||||
|
||||
kat.sendAbort(false);
|
||||
assertRecordSent(ABORT, new byte[0]);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfAtEndOfStream()
|
||||
throws Exception {
|
||||
setup(new byte[0]);
|
||||
kat.receiveKey();
|
||||
}
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(throwException(new EOFException()));
|
||||
}});
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfHeaderIsTooShort()
|
||||
throws Exception {
|
||||
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
|
||||
input[0] = PROTOCOL_VERSION;
|
||||
input[1] = KEY;
|
||||
setup(input);
|
||||
kat.receiveKey();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfPayloadIsTooShort()
|
||||
throws Exception {
|
||||
int payloadLength = 123;
|
||||
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
|
||||
input[0] = PROTOCOL_VERSION;
|
||||
input[1] = KEY;
|
||||
ByteUtils.writeUint16(payloadLength, input, 2);
|
||||
setup(input);
|
||||
kat.receiveKey();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised()
|
||||
throws Exception {
|
||||
setup(createRecord((byte) (PROTOCOL_VERSION + 1), KEY, new byte[123]));
|
||||
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
|
||||
byte[] key = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(unknownVersion, KEY, key)));
|
||||
}});
|
||||
|
||||
kat.receiveKey();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfAbortIsReceived()
|
||||
throws Exception {
|
||||
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0]));
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
|
||||
}});
|
||||
|
||||
kat.receiveKey();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfConfirmIsReceived()
|
||||
throws Exception {
|
||||
setup(createRecord(PROTOCOL_VERSION, CONFIRM, new byte[123]));
|
||||
byte[] confirm = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, CONFIRM, confirm)));
|
||||
}});
|
||||
|
||||
kat.receiveKey();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception {
|
||||
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1),
|
||||
new byte[123]);
|
||||
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2),
|
||||
new byte[0]);
|
||||
byte[] payload = TestUtils.getRandomBytes(123);
|
||||
byte[] key = createRecord(PROTOCOL_VERSION, KEY, payload);
|
||||
ByteArrayOutputStream input = new ByteArrayOutputStream();
|
||||
input.write(skip1);
|
||||
input.write(skip2);
|
||||
input.write(key);
|
||||
setup(input.toByteArray());
|
||||
assertArrayEquals(payload, kat.receiveKey());
|
||||
byte type1 = (byte) (ABORT + 1);
|
||||
byte[] payload1 = getRandomBytes(123);
|
||||
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
|
||||
byte type2 = (byte) (ABORT + 2);
|
||||
byte[] payload2 = new byte[0];
|
||||
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
|
||||
byte[] key = getRandomBytes(123);
|
||||
Record keyRecord = new Record(PROTOCOL_VERSION, KEY, key);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord1));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord2));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(keyRecord));
|
||||
}});
|
||||
|
||||
assertArrayEquals(key, kat.receiveKey());
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveConfirmThrowsExceptionIfAtEndOfStream()
|
||||
throws Exception {
|
||||
setup(new byte[0]);
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(throwException(new EOFException()));
|
||||
}});
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveConfirmThrowsExceptionIfHeaderIsTooShort()
|
||||
throws Exception {
|
||||
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
|
||||
input[0] = PROTOCOL_VERSION;
|
||||
input[1] = CONFIRM;
|
||||
setup(input);
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveConfirmThrowsExceptionIfPayloadIsTooShort()
|
||||
throws Exception {
|
||||
int payloadLength = 123;
|
||||
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
|
||||
input[0] = PROTOCOL_VERSION;
|
||||
input[1] = CONFIRM;
|
||||
ByteUtils.writeUint16(payloadLength, input, 2);
|
||||
setup(input);
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised()
|
||||
throws Exception {
|
||||
setup(createRecord((byte) (PROTOCOL_VERSION + 1), CONFIRM,
|
||||
new byte[123]));
|
||||
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
|
||||
byte[] confirm = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(unknownVersion, CONFIRM, confirm)));
|
||||
}});
|
||||
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveConfirmThrowsExceptionIfAbortIsReceived()
|
||||
throws Exception {
|
||||
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0]));
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
|
||||
}});
|
||||
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfKeyIsReceived()
|
||||
public void testReceiveConfirmThrowsExceptionIfKeyIsReceived()
|
||||
throws Exception {
|
||||
setup(createRecord(PROTOCOL_VERSION, KEY, new byte[123]));
|
||||
byte[] key = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, KEY, key)));
|
||||
}});
|
||||
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReceiveConfirmSkipsUnrecognisedRecordTypes()
|
||||
throws Exception {
|
||||
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1),
|
||||
new byte[123]);
|
||||
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2),
|
||||
new byte[0]);
|
||||
byte[] payload = TestUtils.getRandomBytes(123);
|
||||
byte[] confirm = createRecord(PROTOCOL_VERSION, CONFIRM, payload);
|
||||
ByteArrayOutputStream input = new ByteArrayOutputStream();
|
||||
input.write(skip1);
|
||||
input.write(skip2);
|
||||
input.write(confirm);
|
||||
setup(input.toByteArray());
|
||||
assertArrayEquals(payload, kat.receiveConfirm());
|
||||
byte type1 = (byte) (ABORT + 1);
|
||||
byte[] payload1 = getRandomBytes(123);
|
||||
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
|
||||
byte type2 = (byte) (ABORT + 2);
|
||||
byte[] payload2 = new byte[0];
|
||||
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
|
||||
byte[] confirm = getRandomBytes(123);
|
||||
Record confirmRecord = new Record(PROTOCOL_VERSION, CONFIRM, confirm);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord1));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord2));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(confirmRecord));
|
||||
}});
|
||||
|
||||
assertArrayEquals(confirm, kat.receiveConfirm());
|
||||
}
|
||||
|
||||
private void setup(byte[] input) throws Exception {
|
||||
inputStream = new ByteArrayInputStream(input);
|
||||
outputStream = new ByteArrayOutputStream();
|
||||
private void setup() throws Exception {
|
||||
context.checking(new Expectations() {{
|
||||
allowing(duplexTransportConnection).getReader();
|
||||
will(returnValue(transportConnectionReader));
|
||||
allowing(transportConnectionReader).getInputStream();
|
||||
will(returnValue(inputStream));
|
||||
oneOf(recordReaderFactory).createRecordReader(inputStream);
|
||||
will(returnValue(recordReader));
|
||||
allowing(duplexTransportConnection).getWriter();
|
||||
will(returnValue(transportConnectionWriter));
|
||||
allowing(transportConnectionWriter).getOutputStream();
|
||||
will(returnValue(outputStream));
|
||||
oneOf(recordWriterFactory).createRecordWriter(outputStream);
|
||||
will(returnValue(recordWriter));
|
||||
}});
|
||||
kat = new KeyAgreementTransport(keyAgreementConnection);
|
||||
kat = new KeyAgreementTransport(recordReaderFactory,
|
||||
recordWriterFactory, keyAgreementConnection);
|
||||
}
|
||||
|
||||
private void assertRecordSent(byte expectedType, byte[] expectedPayload) {
|
||||
byte[] output = outputStream.toByteArray();
|
||||
assertEquals(RECORD_HEADER_LENGTH + expectedPayload.length,
|
||||
output.length);
|
||||
assertEquals(PROTOCOL_VERSION, output[0]);
|
||||
assertEquals(expectedType, output[1]);
|
||||
assertEquals(expectedPayload.length, ByteUtils.readUint16(output, 2));
|
||||
byte[] payload = new byte[output.length - RECORD_HEADER_LENGTH];
|
||||
System.arraycopy(output, RECORD_HEADER_LENGTH, payload, 0,
|
||||
payload.length);
|
||||
assertArrayEquals(expectedPayload, payload);
|
||||
private AtomicReference<Record> expectWriteRecord() throws Exception {
|
||||
AtomicReference<Record> captured = new AtomicReference<>();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordWriter).writeRecord(with(any(Record.class)));
|
||||
will(new CaptureArgumentAction<>(captured, Record.class, 0));
|
||||
oneOf(recordWriter).flush();
|
||||
}});
|
||||
return captured;
|
||||
}
|
||||
|
||||
private byte[] createRecord(byte version, byte type, byte[] payload) {
|
||||
byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length];
|
||||
b[0] = version;
|
||||
b[1] = type;
|
||||
ByteUtils.writeUint16(payload.length, b, 2);
|
||||
System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length);
|
||||
return b;
|
||||
private void assertRecordEquals(byte expectedVersion, byte expectedType,
|
||||
byte[] expectedPayload, Record actual) {
|
||||
assertEquals(expectedVersion, actual.getProtocolVersion());
|
||||
assertEquals(expectedType, actual.getRecordType());
|
||||
assertArrayEquals(expectedPayload, actual.getPayload());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user