Use generic record reader/writer for key agreement.

This commit is contained in:
akwizgran
2018-04-19 11:55:58 +01:00
parent cc2791c37f
commit 6fa6ceb5ee
5 changed files with 216 additions and 175 deletions

View File

@@ -5,23 +5,31 @@ import org.briarproject.bramble.api.plugin.TransportConnectionReader;
import org.briarproject.bramble.api.plugin.TransportConnectionWriter;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.TestUtils;
import org.briarproject.bramble.util.ByteUtils;
import org.briarproject.bramble.test.CaptureArgumentAction;
import org.jmock.Expectations;
import org.jmock.lib.legacy.ClassImposteriser;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicReference;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY;
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
import static org.briarproject.bramble.test.TestUtils.getTransportId;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class KeyAgreementTransportTest extends BrambleMockTestCase {
@@ -31,222 +39,268 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
context.mock(TransportConnectionReader.class);
private final TransportConnectionWriter transportConnectionWriter =
context.mock(TransportConnectionWriter.class);
private final RecordReaderFactory recordReaderFactory =
context.mock(RecordReaderFactory.class);
private final RecordWriterFactory recordWriterFactory =
context.mock(RecordWriterFactory.class);
private final RecordReader recordReader = context.mock(RecordReader.class);
private final RecordWriter recordWriter = context.mock(RecordWriter.class);
private final TransportId transportId = getTransportId();
private final KeyAgreementConnection keyAgreementConnection =
new KeyAgreementConnection(duplexTransportConnection, transportId);
private ByteArrayInputStream inputStream;
private ByteArrayOutputStream outputStream;
private final InputStream inputStream;
private final OutputStream outputStream;
private KeyAgreementTransport kat;
public KeyAgreementTransportTest() {
context.setImposteriser(ClassImposteriser.INSTANCE);
inputStream = context.mock(InputStream.class);
outputStream = context.mock(OutputStream.class);
}
@Test
public void testSendKey() throws Exception {
setup(new byte[0]);
byte[] key = TestUtils.getRandomBytes(123);
byte[] key = getRandomBytes(123);
setup();
AtomicReference<Record> written = expectWriteRecord();
kat.sendKey(key);
assertRecordSent(KEY, key);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, KEY, key, written.get());
}
@Test
public void testSendConfirm() throws Exception {
setup(new byte[0]);
byte[] confirm = TestUtils.getRandomBytes(123);
byte[] confirm = getRandomBytes(123);
setup();
AtomicReference<Record> written = expectWriteRecord();
kat.sendConfirm(confirm);
assertRecordSent(CONFIRM, confirm);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, CONFIRM, confirm, written.get());
}
@Test
public void testSendAbortWithException() throws Exception {
setup(new byte[0]);
setup();
AtomicReference<Record> written = expectWriteRecord();
context.checking(new Expectations() {{
oneOf(transportConnectionReader).dispose(true, true);
oneOf(transportConnectionWriter).dispose(true);
}});
kat.sendAbort(true);
assertRecordSent(ABORT, new byte[0]);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
}
@Test
public void testSendAbortWithoutException() throws Exception {
setup(new byte[0]);
setup();
AtomicReference<Record> written = expectWriteRecord();
context.checking(new Expectations() {{
oneOf(transportConnectionReader).dispose(false, true);
oneOf(transportConnectionWriter).dispose(false);
}});
kat.sendAbort(false);
assertRecordSent(ABORT, new byte[0]);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfAtEndOfStream()
throws Exception {
setup(new byte[0]);
kat.receiveKey();
}
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(throwException(new EOFException()));
}});
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfHeaderIsTooShort()
throws Exception {
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
input[0] = PROTOCOL_VERSION;
input[1] = KEY;
setup(input);
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfPayloadIsTooShort()
throws Exception {
int payloadLength = 123;
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
input[0] = PROTOCOL_VERSION;
input[1] = KEY;
ByteUtils.writeUint16(payloadLength, input, 2);
setup(input);
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception {
setup(createRecord((byte) (PROTOCOL_VERSION + 1), KEY, new byte[123]));
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
byte[] key = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(unknownVersion, KEY, key)));
}});
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfAbortIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0]));
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
}});
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfConfirmIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, CONFIRM, new byte[123]));
byte[] confirm = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, CONFIRM, confirm)));
}});
kat.receiveKey();
}
@Test
public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1),
new byte[123]);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2),
new byte[0]);
byte[] payload = TestUtils.getRandomBytes(123);
byte[] key = createRecord(PROTOCOL_VERSION, KEY, payload);
ByteArrayOutputStream input = new ByteArrayOutputStream();
input.write(skip1);
input.write(skip2);
input.write(key);
setup(input.toByteArray());
assertArrayEquals(payload, kat.receiveKey());
byte type1 = (byte) (ABORT + 1);
byte[] payload1 = getRandomBytes(123);
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
byte type2 = (byte) (ABORT + 2);
byte[] payload2 = new byte[0];
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
byte[] key = getRandomBytes(123);
Record keyRecord = new Record(PROTOCOL_VERSION, KEY, key);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord1));
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord2));
oneOf(recordReader).readRecord();
will(returnValue(keyRecord));
}});
assertArrayEquals(key, kat.receiveKey());
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfAtEndOfStream()
throws Exception {
setup(new byte[0]);
kat.receiveConfirm();
}
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(throwException(new EOFException()));
}});
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfHeaderIsTooShort()
throws Exception {
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
input[0] = PROTOCOL_VERSION;
input[1] = CONFIRM;
setup(input);
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfPayloadIsTooShort()
throws Exception {
int payloadLength = 123;
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
input[0] = PROTOCOL_VERSION;
input[1] = CONFIRM;
ByteUtils.writeUint16(payloadLength, input, 2);
setup(input);
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception {
setup(createRecord((byte) (PROTOCOL_VERSION + 1), CONFIRM,
new byte[123]));
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
byte[] confirm = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(unknownVersion, CONFIRM, confirm)));
}});
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfAbortIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0]));
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
}});
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfKeyIsReceived()
public void testReceiveConfirmThrowsExceptionIfKeyIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, KEY, new byte[123]));
byte[] key = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, KEY, key)));
}});
kat.receiveConfirm();
}
@Test
public void testReceiveConfirmSkipsUnrecognisedRecordTypes()
throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1),
new byte[123]);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2),
new byte[0]);
byte[] payload = TestUtils.getRandomBytes(123);
byte[] confirm = createRecord(PROTOCOL_VERSION, CONFIRM, payload);
ByteArrayOutputStream input = new ByteArrayOutputStream();
input.write(skip1);
input.write(skip2);
input.write(confirm);
setup(input.toByteArray());
assertArrayEquals(payload, kat.receiveConfirm());
byte type1 = (byte) (ABORT + 1);
byte[] payload1 = getRandomBytes(123);
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
byte type2 = (byte) (ABORT + 2);
byte[] payload2 = new byte[0];
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
byte[] confirm = getRandomBytes(123);
Record confirmRecord = new Record(PROTOCOL_VERSION, CONFIRM, confirm);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord1));
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord2));
oneOf(recordReader).readRecord();
will(returnValue(confirmRecord));
}});
assertArrayEquals(confirm, kat.receiveConfirm());
}
private void setup(byte[] input) throws Exception {
inputStream = new ByteArrayInputStream(input);
outputStream = new ByteArrayOutputStream();
private void setup() throws Exception {
context.checking(new Expectations() {{
allowing(duplexTransportConnection).getReader();
will(returnValue(transportConnectionReader));
allowing(transportConnectionReader).getInputStream();
will(returnValue(inputStream));
oneOf(recordReaderFactory).createRecordReader(inputStream);
will(returnValue(recordReader));
allowing(duplexTransportConnection).getWriter();
will(returnValue(transportConnectionWriter));
allowing(transportConnectionWriter).getOutputStream();
will(returnValue(outputStream));
oneOf(recordWriterFactory).createRecordWriter(outputStream);
will(returnValue(recordWriter));
}});
kat = new KeyAgreementTransport(keyAgreementConnection);
kat = new KeyAgreementTransport(recordReaderFactory,
recordWriterFactory, keyAgreementConnection);
}
private void assertRecordSent(byte expectedType, byte[] expectedPayload) {
byte[] output = outputStream.toByteArray();
assertEquals(RECORD_HEADER_LENGTH + expectedPayload.length,
output.length);
assertEquals(PROTOCOL_VERSION, output[0]);
assertEquals(expectedType, output[1]);
assertEquals(expectedPayload.length, ByteUtils.readUint16(output, 2));
byte[] payload = new byte[output.length - RECORD_HEADER_LENGTH];
System.arraycopy(output, RECORD_HEADER_LENGTH, payload, 0,
payload.length);
assertArrayEquals(expectedPayload, payload);
private AtomicReference<Record> expectWriteRecord() throws Exception {
AtomicReference<Record> captured = new AtomicReference<>();
context.checking(new Expectations() {{
oneOf(recordWriter).writeRecord(with(any(Record.class)));
will(new CaptureArgumentAction<>(captured, Record.class, 0));
oneOf(recordWriter).flush();
}});
return captured;
}
private byte[] createRecord(byte version, byte type, byte[] payload) {
byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length];
b[0] = version;
b[1] = type;
ByteUtils.writeUint16(payload.length, b, 2);
System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length);
return b;
private void assertRecordEquals(byte expectedVersion, byte expectedType,
byte[] expectedPayload, Record actual) {
assertEquals(expectedVersion, actual.getProtocolVersion());
assertEquals(expectedType, actual.getRecordType());
assertArrayEquals(expectedPayload, actual.getPayload());
}
}