mirror of
https://code.briarproject.org/briar/briar.git
synced 2026-02-14 11:49:04 +01:00
Use predicates to specify records to accept or ignore.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package org.briarproject.bramble.keyagreement;
|
||||
|
||||
import org.briarproject.bramble.api.Predicate;
|
||||
import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
|
||||
import org.briarproject.bramble.api.plugin.TransportConnectionReader;
|
||||
import org.briarproject.bramble.api.plugin.TransportConnectionWriter;
|
||||
@@ -16,11 +17,12 @@ import org.jmock.Expectations;
|
||||
import org.jmock.lib.legacy.ClassImposteriser;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.EOFException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
|
||||
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
|
||||
@@ -70,7 +72,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
|
||||
kat.sendKey(key);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, KEY, key, written.get());
|
||||
assertRecordEquals(KEY, key, written.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -82,7 +84,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
|
||||
kat.sendConfirm(confirm);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, CONFIRM, confirm, written.get());
|
||||
assertRecordEquals(CONFIRM, confirm, written.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -96,7 +98,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
|
||||
kat.sendAbort(true);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
|
||||
assertRecordEquals(ABORT, new byte[0], written.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -110,32 +112,14 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
|
||||
kat.sendAbort(false);
|
||||
assertNotNull(written.get());
|
||||
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
|
||||
assertRecordEquals(ABORT, new byte[0], written.get());
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfAtEndOfStream()
|
||||
throws Exception {
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(throwException(new EOFException()));
|
||||
}});
|
||||
|
||||
kat.receiveKey();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised()
|
||||
throws Exception {
|
||||
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
|
||||
byte[] key = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(unknownVersion, KEY, key)));
|
||||
}});
|
||||
expectReadRecord(null);
|
||||
|
||||
kat.receiveKey();
|
||||
}
|
||||
@@ -144,10 +128,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
public void testReceiveKeyThrowsExceptionIfAbortIsReceived()
|
||||
throws Exception {
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
|
||||
}});
|
||||
expectReadRecord(new Record(PROTOCOL_VERSION, ABORT, new byte[0]));
|
||||
|
||||
kat.receiveKey();
|
||||
}
|
||||
@@ -158,61 +139,16 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
byte[] confirm = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, CONFIRM, confirm)));
|
||||
}});
|
||||
expectReadRecord(new Record(PROTOCOL_VERSION, CONFIRM, confirm));
|
||||
|
||||
kat.receiveKey();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception {
|
||||
byte type1 = (byte) (ABORT + 1);
|
||||
byte[] payload1 = getRandomBytes(123);
|
||||
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
|
||||
byte type2 = (byte) (ABORT + 2);
|
||||
byte[] payload2 = new byte[0];
|
||||
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
|
||||
byte[] key = getRandomBytes(123);
|
||||
Record keyRecord = new Record(PROTOCOL_VERSION, KEY, key);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord1));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord2));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(keyRecord));
|
||||
}});
|
||||
|
||||
assertArrayEquals(key, kat.receiveKey());
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveConfirmThrowsExceptionIfAtEndOfStream()
|
||||
throws Exception {
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(throwException(new EOFException()));
|
||||
}});
|
||||
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
|
||||
@Test(expected = AbortException.class)
|
||||
public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised()
|
||||
throws Exception {
|
||||
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
|
||||
byte[] confirm = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(unknownVersion, CONFIRM, confirm)));
|
||||
}});
|
||||
expectReadRecord(null);
|
||||
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
@@ -221,10 +157,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
public void testReceiveConfirmThrowsExceptionIfAbortIsReceived()
|
||||
throws Exception {
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
|
||||
}});
|
||||
expectReadRecord(new Record(PROTOCOL_VERSION, ABORT, new byte[0]));
|
||||
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
@@ -235,39 +168,11 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
byte[] key = getRandomBytes(123);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(new Record(PROTOCOL_VERSION, KEY, key)));
|
||||
}});
|
||||
expectReadRecord(new Record(PROTOCOL_VERSION, KEY, key));
|
||||
|
||||
kat.receiveConfirm();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReceiveConfirmSkipsUnrecognisedRecordTypes()
|
||||
throws Exception {
|
||||
byte type1 = (byte) (ABORT + 1);
|
||||
byte[] payload1 = getRandomBytes(123);
|
||||
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
|
||||
byte type2 = (byte) (ABORT + 2);
|
||||
byte[] payload2 = new byte[0];
|
||||
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
|
||||
byte[] confirm = getRandomBytes(123);
|
||||
Record confirmRecord = new Record(PROTOCOL_VERSION, CONFIRM, confirm);
|
||||
|
||||
setup();
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord1));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord2));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(confirmRecord));
|
||||
}});
|
||||
|
||||
assertArrayEquals(confirm, kat.receiveConfirm());
|
||||
}
|
||||
|
||||
private void setup() throws Exception {
|
||||
context.checking(new Expectations() {{
|
||||
allowing(duplexTransportConnection).getReader();
|
||||
@@ -297,10 +202,19 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
|
||||
return captured;
|
||||
}
|
||||
|
||||
private void assertRecordEquals(byte expectedVersion, byte expectedType,
|
||||
private void assertRecordEquals(byte expectedType,
|
||||
byte[] expectedPayload, Record actual) {
|
||||
assertEquals(expectedVersion, actual.getProtocolVersion());
|
||||
assertEquals(PROTOCOL_VERSION, actual.getProtocolVersion());
|
||||
assertEquals(expectedType, actual.getRecordType());
|
||||
assertArrayEquals(expectedPayload, actual.getPayload());
|
||||
}
|
||||
|
||||
private void expectReadRecord(@Nullable Record record) throws Exception {
|
||||
context.checking(new Expectations() {{
|
||||
//noinspection unchecked
|
||||
oneOf(recordReader).readRecord(with(any(Predicate.class)),
|
||||
with(any(Predicate.class)));
|
||||
will(returnValue(record));
|
||||
}});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.briarproject.bramble.record;
|
||||
|
||||
import org.briarproject.bramble.api.FormatException;
|
||||
import org.briarproject.bramble.api.Predicate;
|
||||
import org.briarproject.bramble.api.record.Record;
|
||||
import org.briarproject.bramble.api.record.RecordReader;
|
||||
import org.briarproject.bramble.test.BrambleTestCase;
|
||||
@@ -8,12 +9,17 @@ import org.briarproject.bramble.util.ByteUtils;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.EOFException;
|
||||
|
||||
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
|
||||
import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES;
|
||||
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
public class RecordReaderImplTest extends BrambleTestCase {
|
||||
|
||||
@@ -99,4 +105,109 @@ public class RecordReaderImplTest extends BrambleTestCase {
|
||||
RecordReader reader = new RecordReaderImpl(in);
|
||||
reader.readRecord();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAcceptsAndRejectsRecords() throws Exception {
|
||||
// Version 0, type 0, payload length 123
|
||||
byte[] header1 = new byte[] {0, 0, 0, 123};
|
||||
// Version 0, type 1, payload length 123
|
||||
byte[] header2 = new byte[] {0, 1, 0, 123};
|
||||
// Version 1, type 0, payload length 123
|
||||
byte[] header3 = new byte[] {1, 0, 0, 123};
|
||||
// Same payload for all records
|
||||
byte[] payload = getRandomBytes(123);
|
||||
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
out.write(header1);
|
||||
out.write(payload);
|
||||
out.write(header2);
|
||||
out.write(payload);
|
||||
out.write(header3);
|
||||
out.write(payload);
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
|
||||
RecordReader reader = new RecordReaderImpl(in);
|
||||
|
||||
// Accept records with version 0, type 0 or 1
|
||||
Predicate<Record> accept = r -> {
|
||||
byte version = r.getProtocolVersion(), type = r.getRecordType();
|
||||
return version == 0 && (type == 0 || type == 1);
|
||||
};
|
||||
// Ignore records with version 0, any other type
|
||||
Predicate<Record> ignore = r -> {
|
||||
byte version = r.getProtocolVersion(), type = r.getRecordType();
|
||||
return version == 0 && !(type == 0 || type == 1);
|
||||
};
|
||||
|
||||
// The first record should be accepted
|
||||
Record r = reader.readRecord(accept, ignore);
|
||||
assertNotNull(r);
|
||||
assertEquals(0, r.getProtocolVersion());
|
||||
assertEquals(0, r.getRecordType());
|
||||
assertArrayEquals(payload, r.getPayload());
|
||||
|
||||
// The second record should be accepted
|
||||
r = reader.readRecord(accept, ignore);
|
||||
assertNotNull(r);
|
||||
assertEquals(0, r.getProtocolVersion());
|
||||
assertEquals(1, r.getRecordType());
|
||||
assertArrayEquals(payload, r.getPayload());
|
||||
|
||||
// The third record should be rejected
|
||||
try {
|
||||
reader.readRecord(accept, ignore);
|
||||
fail();
|
||||
} catch (FormatException expected) {
|
||||
// Expected
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAcceptsAndIgnoresRecords() throws Exception {
|
||||
// Version 0, type 0, payload length 123
|
||||
byte[] header1 = new byte[] {0, 0, 0, 123};
|
||||
// Version 0, type 2, payload length 123
|
||||
byte[] header2 = new byte[] {0, 2, 0, 123};
|
||||
// Version 0, type 1, payload length 123
|
||||
byte[] header3 = new byte[] {0, 1, 0, 123};
|
||||
// Same payload for all records
|
||||
byte[] payload = getRandomBytes(123);
|
||||
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
out.write(header1);
|
||||
out.write(payload);
|
||||
out.write(header2);
|
||||
out.write(payload);
|
||||
out.write(header3);
|
||||
out.write(payload);
|
||||
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
|
||||
RecordReader reader = new RecordReaderImpl(in);
|
||||
|
||||
// Accept records with version 0, type 0 or 1
|
||||
Predicate<Record> accept = r -> {
|
||||
byte version = r.getProtocolVersion(), type = r.getRecordType();
|
||||
return version == 0 && (type == 0 || type == 1);
|
||||
};
|
||||
// Ignore records with version 0, any other type
|
||||
Predicate<Record> ignore = r -> {
|
||||
byte version = r.getProtocolVersion(), type = r.getRecordType();
|
||||
return version == 0 && !(type == 0 || type == 1);
|
||||
};
|
||||
|
||||
// The first record should be accepted
|
||||
Record r = reader.readRecord(accept, ignore);
|
||||
assertNotNull(r);
|
||||
assertEquals(0, r.getProtocolVersion());
|
||||
assertEquals(0, r.getRecordType());
|
||||
assertArrayEquals(payload, r.getPayload());
|
||||
|
||||
// The second record should be ignored, the third should be accepted
|
||||
r = reader.readRecord(accept, ignore);
|
||||
assertNotNull(r);
|
||||
assertEquals(0, r.getProtocolVersion());
|
||||
assertEquals(1, r.getRecordType());
|
||||
assertArrayEquals(payload, r.getPayload());
|
||||
|
||||
// The reader should have reached the end of the stream
|
||||
assertNull(reader.readRecord(accept, ignore));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.briarproject.bramble.sync;
|
||||
|
||||
import org.briarproject.bramble.api.FormatException;
|
||||
import org.briarproject.bramble.api.Predicate;
|
||||
import org.briarproject.bramble.api.UniqueId;
|
||||
import org.briarproject.bramble.api.record.Record;
|
||||
import org.briarproject.bramble.api.record.RecordReader;
|
||||
@@ -14,7 +15,8 @@ import org.jmock.Expectations;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.EOFException;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
|
||||
import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
|
||||
@@ -22,7 +24,6 @@ import static org.briarproject.bramble.api.sync.RecordTypes.OFFER;
|
||||
import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST;
|
||||
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS;
|
||||
import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
|
||||
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
|
||||
import static org.briarproject.bramble.test.TestUtils.getRandomId;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
@@ -93,70 +94,24 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
|
||||
|
||||
@Test
|
||||
public void testEofReturnsTrueWhenAtEndOfStream() throws Exception {
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(throwException(new EOFException()));
|
||||
}});
|
||||
|
||||
SyncRecordReader reader =
|
||||
new SyncRecordReaderImpl(messageFactory, recordReader);
|
||||
assertTrue(reader.eof());
|
||||
assertTrue(reader.eof());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEofReturnsFalseWhenNotAtEndOfStream() throws Exception {
|
||||
expectReadRecord(createAck());
|
||||
expectReadRecord(null);
|
||||
|
||||
SyncRecordReader reader =
|
||||
new SyncRecordReaderImpl(messageFactory, recordReader);
|
||||
assertFalse(reader.eof());
|
||||
assertFalse(reader.eof());
|
||||
}
|
||||
|
||||
@Test(expected = FormatException.class)
|
||||
public void testThrowsExceptionIfProtocolVersionIsUnrecognised()
|
||||
throws Exception {
|
||||
byte version = (byte) (PROTOCOL_VERSION + 1);
|
||||
byte[] payload = getRandomId();
|
||||
|
||||
expectReadRecord(new Record(version, ACK, payload));
|
||||
|
||||
SyncRecordReader reader =
|
||||
new SyncRecordReaderImpl(messageFactory, recordReader);
|
||||
reader.eof();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSkipsUnrecognisedRecordTypes() throws Exception {
|
||||
byte type1 = (byte) (REQUEST + 1);
|
||||
byte[] payload1 = getRandomBytes(123);
|
||||
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
|
||||
byte type2 = (byte) (REQUEST + 2);
|
||||
byte[] payload2 = new byte[0];
|
||||
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
|
||||
Record ackRecord = createAck();
|
||||
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord1));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(unknownRecord2));
|
||||
oneOf(recordReader).readRecord();
|
||||
will(returnValue(ackRecord));
|
||||
|
||||
}});
|
||||
|
||||
SyncRecordReader reader =
|
||||
new SyncRecordReaderImpl(messageFactory, recordReader);
|
||||
assertTrue(reader.hasAck());
|
||||
Ack a = reader.readAck();
|
||||
assertEquals(MAX_MESSAGE_IDS, a.getMessageIds().size());
|
||||
Ack ack = reader.readAck();
|
||||
assertEquals(MAX_MESSAGE_IDS, ack.getMessageIds().size());
|
||||
assertTrue(reader.eof());
|
||||
assertTrue(reader.eof());
|
||||
}
|
||||
|
||||
private void expectReadRecord(Record record) throws Exception {
|
||||
private void expectReadRecord(@Nullable Record record) throws Exception {
|
||||
context.checking(new Expectations() {{
|
||||
oneOf(recordReader).readRecord();
|
||||
//noinspection unchecked
|
||||
oneOf(recordReader).readRecord(with(any(Predicate.class)),
|
||||
with(any(Predicate.class)));
|
||||
will(returnValue(record));
|
||||
}});
|
||||
}
|
||||
@@ -165,7 +120,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
|
||||
return new Record(PROTOCOL_VERSION, ACK, createPayload());
|
||||
}
|
||||
|
||||
private Record createEmptyAck() throws Exception {
|
||||
private Record createEmptyAck() {
|
||||
return new Record(PROTOCOL_VERSION, ACK, new byte[0]);
|
||||
}
|
||||
|
||||
@@ -173,7 +128,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
|
||||
return new Record(PROTOCOL_VERSION, OFFER, createPayload());
|
||||
}
|
||||
|
||||
private Record createEmptyOffer() throws Exception {
|
||||
private Record createEmptyOffer() {
|
||||
return new Record(PROTOCOL_VERSION, OFFER, new byte[0]);
|
||||
}
|
||||
|
||||
@@ -181,7 +136,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
|
||||
return new Record(PROTOCOL_VERSION, REQUEST, createPayload());
|
||||
}
|
||||
|
||||
private Record createEmptyRequest() throws Exception {
|
||||
private Record createEmptyRequest() {
|
||||
return new Record(PROTOCOL_VERSION, REQUEST, new byte[0]);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user