diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/Predicate.java b/bramble-api/src/main/java/org/briarproject/bramble/api/Predicate.java new file mode 100644 index 000000000..0bd0520e4 --- /dev/null +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/Predicate.java @@ -0,0 +1,9 @@ +package org.briarproject.bramble.api; + +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; + +@NotNullByDefault +public interface Predicate { + + boolean test(T t); +} diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java b/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java index 371dead20..14756860c 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java @@ -1,10 +1,14 @@ package org.briarproject.bramble.api.record; +import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import java.io.EOFException; import java.io.IOException; +import javax.annotation.Nullable; + @NotNullByDefault public interface RecordReader { @@ -16,5 +20,20 @@ public interface RecordReader { */ Record readRecord() throws IOException; + /** + * Reads and returns the next record matching the 'accept' predicate, + * skipping any records that match the 'ignore' predicate. Returns null if + * no record matching the 'accept' predicate is found before the end of the + * stream. + * + * @throws EOFException If the end of the stream is reached without + * reading a complete record + * @throws FormatException If a record is read that does not match the + * 'accept' or 'ignore' predicates + */ + @Nullable + Record readRecord(Predicate accept, Predicate ignore) + throws IOException; + void close() throws IOException; } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeTaskImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeTaskImpl.java index 061a43a48..a6c90699e 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeTaskImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeTaskImpl.java @@ -1,6 +1,7 @@ package org.briarproject.bramble.contact; import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.client.ClientHelper; import org.briarproject.bramble.api.contact.ContactExchangeTask; import org.briarproject.bramble.api.contact.ContactId; @@ -61,6 +62,20 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask { private static final String SIGNING_LABEL_EXCHANGE = "org.briarproject.briar.contact/EXCHANGE"; + // Accept records with current protocol version, known record type + private static final Predicate ACCEPT = r -> + r.getProtocolVersion() == PROTOCOL_VERSION && + isKnownRecordType(r.getRecordType()); + + // Ignore records with current protocol version, unknown record type + private static final Predicate IGNORE = r -> + r.getProtocolVersion() == PROTOCOL_VERSION && + !isKnownRecordType(r.getRecordType()); + + private static boolean isKnownRecordType(byte type) { + return type == CONTACT_INFO; + } + private final DatabaseComponent db; private final ClientHelper clientHelper; private final RecordReaderFactory recordReaderFactory; @@ -191,11 +206,7 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask { // Send EOF on the outgoing stream streamWriter.sendEndOfStream(); // Skip any remaining records from the incoming stream - try { - while (true) recordReader.readRecord(); - } catch (EOFException expected) { - LOG.info("End of stream"); - } + recordReader.readRecord(r -> false, IGNORE); } catch (IOException e) { logException(LOG, WARNING, e); eventBus.broadcast(new ContactExchangeFailedEvent()); @@ -268,12 +279,8 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask { private ContactInfo receiveContactInfo(RecordReader recordReader) throws IOException { - Record record; - do { - record = recordReader.readRecord(); - if (record.getProtocolVersion() != PROTOCOL_VERSION) - throw new FormatException(); - } while (record.getRecordType() != CONTACT_INFO); + Record record = recordReader.readRecord(ACCEPT, IGNORE); + if (record == null) throw new EOFException(); LOG.info("Received contact info"); BdfList payload = clientHelper.toList(record.getPayload()); checkSize(payload, 4); diff --git a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java index 3dddec421..5785575ee 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java @@ -1,5 +1,6 @@ package org.briarproject.bramble.keyagreement; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.plugin.TransportId; @@ -10,6 +11,7 @@ import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriterFactory; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -31,6 +33,20 @@ class KeyAgreementTransport { private static final Logger LOG = Logger.getLogger(KeyAgreementTransport.class.getName()); + // Accept records with current protocol version, known record type + private static Predicate ACCEPT = r -> + r.getProtocolVersion() == PROTOCOL_VERSION && + isKnownRecordType(r.getRecordType()); + + // Ignore records with current protocol version, unknown record type + private static Predicate IGNORE = r -> + r.getProtocolVersion() == PROTOCOL_VERSION && + !isKnownRecordType(r.getRecordType()); + + private static boolean isKnownRecordType(byte type) { + return type == KEY || type == CONFIRM || type == ABORT; + } + private final KeyAgreementConnection kac; private final RecordReader reader; private final RecordWriter writer; @@ -94,22 +110,15 @@ class KeyAgreementTransport { } private byte[] readRecord(byte expectedType) throws AbortException { - while (true) { - try { - Record record = reader.readRecord(); - // Reject unrecognised protocol version - if (record.getProtocolVersion() != PROTOCOL_VERSION) - throw new AbortException(false); - byte type = record.getRecordType(); - if (type == ABORT) throw new AbortException(true); - if (type == expectedType) return record.getPayload(); - // Reject recognised but unexpected record type - if (type == KEY || type == CONFIRM) - throw new AbortException(false); - // Skip unrecognised record type - } catch (IOException e) { - throw new AbortException(e); - } + try { + Record record = reader.readRecord(ACCEPT, IGNORE); + if (record == null) throw new AbortException(new EOFException()); + byte type = record.getRecordType(); + if (type == ABORT) throw new AbortException(true); + if (type != expectedType) throw new AbortException(false); + return record.getPayload(); + } catch (IOException e) { + throw new AbortException(e); } } } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java index 12aac514c..70d909fb2 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java @@ -1,15 +1,18 @@ package org.briarproject.bramble.record; import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.util.ByteUtils; +import java.io.BufferedInputStream; import java.io.DataInputStream; import java.io.IOException; import java.io.InputStream; +import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; @@ -23,6 +26,7 @@ class RecordReaderImpl implements RecordReader { private final byte[] header = new byte[RECORD_HEADER_BYTES]; RecordReaderImpl(InputStream in) { + if (!in.markSupported()) in = new BufferedInputStream(in, 1); this.in = new DataInputStream(in); } @@ -39,8 +43,27 @@ class RecordReaderImpl implements RecordReader { return new Record(protocolVersion, recordType, payload); } + @Nullable + @Override + public Record readRecord(Predicate accept, Predicate ignore) + throws IOException { + while (true) { + if (eof()) return null; + Record r = readRecord(); + if (accept.test(r)) return r; + if (!ignore.test(r)) throw new FormatException(); + } + } + @Override public void close() throws IOException { in.close(); } + + private boolean eof() throws IOException { + in.mark(1); + int next = in.read(); + in.reset(); + return next == -1; + } } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java index c06f7876a..02b3a4154 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java @@ -1,6 +1,7 @@ package org.briarproject.bramble.sync; import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.record.Record; @@ -14,7 +15,6 @@ import org.briarproject.bramble.api.sync.Request; import org.briarproject.bramble.api.sync.SyncRecordReader; import org.briarproject.bramble.util.ByteUtils; -import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -33,6 +33,21 @@ import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION; @NotNullByDefault class SyncRecordReaderImpl implements SyncRecordReader { + // Accept records with current protocol version, known record type + private static final Predicate ACCEPT = r -> + r.getProtocolVersion() == PROTOCOL_VERSION && + isKnownRecordType(r.getRecordType()); + + // Ignore records with current protocol version, unknown record type + private static final Predicate IGNORE = r -> + r.getProtocolVersion() == PROTOCOL_VERSION && + !isKnownRecordType(r.getRecordType()); + + private static boolean isKnownRecordType(byte type) { + return type == ACK || type == MESSAGE || type == OFFER || + type == REQUEST; + } + private final MessageFactory messageFactory; private final RecordReader reader; @@ -45,22 +60,6 @@ class SyncRecordReaderImpl implements SyncRecordReader { this.reader = reader; } - private void readRecord() throws IOException { - if (nextRecord != null) throw new AssertionError(); - while (true) { - nextRecord = reader.readRecord(); - // Check the protocol version - byte version = nextRecord.getProtocolVersion(); - if (version != PROTOCOL_VERSION) throw new FormatException(); - byte type = nextRecord.getRecordType(); - // Return if this is a known record type, otherwise continue - if (type == ACK || type == MESSAGE || type == OFFER || - type == REQUEST) { - return; - } - } - } - private byte getNextRecordType() { if (nextRecord == null) throw new AssertionError(); return nextRecord.getRecordType(); @@ -78,14 +77,9 @@ class SyncRecordReaderImpl implements SyncRecordReader { public boolean eof() throws IOException { if (nextRecord != null) return false; if (eof) return true; - try { - readRecord(); - return false; - } catch (EOFException e) { - nextRecord = null; - eof = true; - return true; - } + nextRecord = reader.readRecord(ACCEPT, IGNORE); + if (nextRecord == null) eof = true; + return eof; } @Override @@ -154,5 +148,4 @@ class SyncRecordReaderImpl implements SyncRecordReader { if (!hasRequest()) throw new FormatException(); return new Request(readMessageIds()); } - } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java index cddb5b40e..c5c111c77 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java @@ -1,5 +1,6 @@ package org.briarproject.bramble.keyagreement; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; import org.briarproject.bramble.api.plugin.TransportConnectionReader; import org.briarproject.bramble.api.plugin.TransportConnectionWriter; @@ -16,11 +17,12 @@ import org.jmock.Expectations; import org.jmock.lib.legacy.ClassImposteriser; import org.junit.Test; -import java.io.EOFException; import java.io.InputStream; import java.io.OutputStream; import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; + import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; @@ -70,7 +72,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { kat.sendKey(key); assertNotNull(written.get()); - assertRecordEquals(PROTOCOL_VERSION, KEY, key, written.get()); + assertRecordEquals(KEY, key, written.get()); } @Test @@ -82,7 +84,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { kat.sendConfirm(confirm); assertNotNull(written.get()); - assertRecordEquals(PROTOCOL_VERSION, CONFIRM, confirm, written.get()); + assertRecordEquals(CONFIRM, confirm, written.get()); } @Test @@ -96,7 +98,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { kat.sendAbort(true); assertNotNull(written.get()); - assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get()); + assertRecordEquals(ABORT, new byte[0], written.get()); } @Test @@ -110,32 +112,14 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { kat.sendAbort(false); assertNotNull(written.get()); - assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get()); + assertRecordEquals(ABORT, new byte[0], written.get()); } @Test(expected = AbortException.class) public void testReceiveKeyThrowsExceptionIfAtEndOfStream() throws Exception { setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(throwException(new EOFException())); - }}); - - kat.receiveKey(); - } - - @Test(expected = AbortException.class) - public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised() - throws Exception { - byte unknownVersion = (byte) (PROTOCOL_VERSION + 1); - byte[] key = getRandomBytes(123); - - setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(new Record(unknownVersion, KEY, key))); - }}); + expectReadRecord(null); kat.receiveKey(); } @@ -144,10 +128,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { public void testReceiveKeyThrowsExceptionIfAbortIsReceived() throws Exception { setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0]))); - }}); + expectReadRecord(new Record(PROTOCOL_VERSION, ABORT, new byte[0])); kat.receiveKey(); } @@ -158,61 +139,16 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { byte[] confirm = getRandomBytes(123); setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(new Record(PROTOCOL_VERSION, CONFIRM, confirm))); - }}); + expectReadRecord(new Record(PROTOCOL_VERSION, CONFIRM, confirm)); kat.receiveKey(); } - @Test - public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception { - byte type1 = (byte) (ABORT + 1); - byte[] payload1 = getRandomBytes(123); - Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1); - byte type2 = (byte) (ABORT + 2); - byte[] payload2 = new byte[0]; - Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2); - byte[] key = getRandomBytes(123); - Record keyRecord = new Record(PROTOCOL_VERSION, KEY, key); - - setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(unknownRecord1)); - oneOf(recordReader).readRecord(); - will(returnValue(unknownRecord2)); - oneOf(recordReader).readRecord(); - will(returnValue(keyRecord)); - }}); - - assertArrayEquals(key, kat.receiveKey()); - } - @Test(expected = AbortException.class) public void testReceiveConfirmThrowsExceptionIfAtEndOfStream() throws Exception { setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(throwException(new EOFException())); - }}); - - kat.receiveConfirm(); - } - - @Test(expected = AbortException.class) - public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised() - throws Exception { - byte unknownVersion = (byte) (PROTOCOL_VERSION + 1); - byte[] confirm = getRandomBytes(123); - - setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(new Record(unknownVersion, CONFIRM, confirm))); - }}); + expectReadRecord(null); kat.receiveConfirm(); } @@ -221,10 +157,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { public void testReceiveConfirmThrowsExceptionIfAbortIsReceived() throws Exception { setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0]))); - }}); + expectReadRecord(new Record(PROTOCOL_VERSION, ABORT, new byte[0])); kat.receiveConfirm(); } @@ -235,39 +168,11 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { byte[] key = getRandomBytes(123); setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(new Record(PROTOCOL_VERSION, KEY, key))); - }}); + expectReadRecord(new Record(PROTOCOL_VERSION, KEY, key)); kat.receiveConfirm(); } - @Test - public void testReceiveConfirmSkipsUnrecognisedRecordTypes() - throws Exception { - byte type1 = (byte) (ABORT + 1); - byte[] payload1 = getRandomBytes(123); - Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1); - byte type2 = (byte) (ABORT + 2); - byte[] payload2 = new byte[0]; - Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2); - byte[] confirm = getRandomBytes(123); - Record confirmRecord = new Record(PROTOCOL_VERSION, CONFIRM, confirm); - - setup(); - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(unknownRecord1)); - oneOf(recordReader).readRecord(); - will(returnValue(unknownRecord2)); - oneOf(recordReader).readRecord(); - will(returnValue(confirmRecord)); - }}); - - assertArrayEquals(confirm, kat.receiveConfirm()); - } - private void setup() throws Exception { context.checking(new Expectations() {{ allowing(duplexTransportConnection).getReader(); @@ -297,10 +202,19 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { return captured; } - private void assertRecordEquals(byte expectedVersion, byte expectedType, + private void assertRecordEquals(byte expectedType, byte[] expectedPayload, Record actual) { - assertEquals(expectedVersion, actual.getProtocolVersion()); + assertEquals(PROTOCOL_VERSION, actual.getProtocolVersion()); assertEquals(expectedType, actual.getRecordType()); assertArrayEquals(expectedPayload, actual.getPayload()); } + + private void expectReadRecord(@Nullable Record record) throws Exception { + context.checking(new Expectations() {{ + //noinspection unchecked + oneOf(recordReader).readRecord(with(any(Predicate.class)), + with(any(Predicate.class))); + will(returnValue(record)); + }}); + } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java index 26ef89c8e..c2cc4c55f 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java @@ -1,6 +1,7 @@ package org.briarproject.bramble.record; import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.test.BrambleTestCase; @@ -8,12 +9,17 @@ import org.briarproject.bramble.util.ByteUtils; import org.junit.Test; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.EOFException; import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES; +import static org.briarproject.bramble.test.TestUtils.getRandomBytes; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; public class RecordReaderImplTest extends BrambleTestCase { @@ -99,4 +105,109 @@ public class RecordReaderImplTest extends BrambleTestCase { RecordReader reader = new RecordReaderImpl(in); reader.readRecord(); } + + @Test + public void testAcceptsAndRejectsRecords() throws Exception { + // Version 0, type 0, payload length 123 + byte[] header1 = new byte[] {0, 0, 0, 123}; + // Version 0, type 1, payload length 123 + byte[] header2 = new byte[] {0, 1, 0, 123}; + // Version 1, type 0, payload length 123 + byte[] header3 = new byte[] {1, 0, 0, 123}; + // Same payload for all records + byte[] payload = getRandomBytes(123); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(header1); + out.write(payload); + out.write(header2); + out.write(payload); + out.write(header3); + out.write(payload); + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + RecordReader reader = new RecordReaderImpl(in); + + // Accept records with version 0, type 0 or 1 + Predicate accept = r -> { + byte version = r.getProtocolVersion(), type = r.getRecordType(); + return version == 0 && (type == 0 || type == 1); + }; + // Ignore records with version 0, any other type + Predicate ignore = r -> { + byte version = r.getProtocolVersion(), type = r.getRecordType(); + return version == 0 && !(type == 0 || type == 1); + }; + + // The first record should be accepted + Record r = reader.readRecord(accept, ignore); + assertNotNull(r); + assertEquals(0, r.getProtocolVersion()); + assertEquals(0, r.getRecordType()); + assertArrayEquals(payload, r.getPayload()); + + // The second record should be accepted + r = reader.readRecord(accept, ignore); + assertNotNull(r); + assertEquals(0, r.getProtocolVersion()); + assertEquals(1, r.getRecordType()); + assertArrayEquals(payload, r.getPayload()); + + // The third record should be rejected + try { + reader.readRecord(accept, ignore); + fail(); + } catch (FormatException expected) { + // Expected + } + } + + @Test + public void testAcceptsAndIgnoresRecords() throws Exception { + // Version 0, type 0, payload length 123 + byte[] header1 = new byte[] {0, 0, 0, 123}; + // Version 0, type 2, payload length 123 + byte[] header2 = new byte[] {0, 2, 0, 123}; + // Version 0, type 1, payload length 123 + byte[] header3 = new byte[] {0, 1, 0, 123}; + // Same payload for all records + byte[] payload = getRandomBytes(123); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(header1); + out.write(payload); + out.write(header2); + out.write(payload); + out.write(header3); + out.write(payload); + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + RecordReader reader = new RecordReaderImpl(in); + + // Accept records with version 0, type 0 or 1 + Predicate accept = r -> { + byte version = r.getProtocolVersion(), type = r.getRecordType(); + return version == 0 && (type == 0 || type == 1); + }; + // Ignore records with version 0, any other type + Predicate ignore = r -> { + byte version = r.getProtocolVersion(), type = r.getRecordType(); + return version == 0 && !(type == 0 || type == 1); + }; + + // The first record should be accepted + Record r = reader.readRecord(accept, ignore); + assertNotNull(r); + assertEquals(0, r.getProtocolVersion()); + assertEquals(0, r.getRecordType()); + assertArrayEquals(payload, r.getPayload()); + + // The second record should be ignored, the third should be accepted + r = reader.readRecord(accept, ignore); + assertNotNull(r); + assertEquals(0, r.getProtocolVersion()); + assertEquals(1, r.getRecordType()); + assertArrayEquals(payload, r.getPayload()); + + // The reader should have reached the end of the stream + assertNull(reader.readRecord(accept, ignore)); + } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java index ccd1207ff..ae52af979 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java @@ -1,6 +1,7 @@ package org.briarproject.bramble.sync; import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; @@ -14,7 +15,8 @@ import org.jmock.Expectations; import org.junit.Test; import java.io.ByteArrayOutputStream; -import java.io.EOFException; + +import javax.annotation.Nullable; import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; import static org.briarproject.bramble.api.sync.RecordTypes.ACK; @@ -22,7 +24,6 @@ import static org.briarproject.bramble.api.sync.RecordTypes.OFFER; import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST; import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS; import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION; -import static org.briarproject.bramble.test.TestUtils.getRandomBytes; import static org.briarproject.bramble.test.TestUtils.getRandomId; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -93,70 +94,24 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { @Test public void testEofReturnsTrueWhenAtEndOfStream() throws Exception { - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(throwException(new EOFException())); - }}); - - SyncRecordReader reader = - new SyncRecordReaderImpl(messageFactory, recordReader); - assertTrue(reader.eof()); - assertTrue(reader.eof()); - } - - @Test - public void testEofReturnsFalseWhenNotAtEndOfStream() throws Exception { expectReadRecord(createAck()); + expectReadRecord(null); SyncRecordReader reader = new SyncRecordReaderImpl(messageFactory, recordReader); assertFalse(reader.eof()); - assertFalse(reader.eof()); - } - - @Test(expected = FormatException.class) - public void testThrowsExceptionIfProtocolVersionIsUnrecognised() - throws Exception { - byte version = (byte) (PROTOCOL_VERSION + 1); - byte[] payload = getRandomId(); - - expectReadRecord(new Record(version, ACK, payload)); - - SyncRecordReader reader = - new SyncRecordReaderImpl(messageFactory, recordReader); - reader.eof(); - } - - @Test - public void testSkipsUnrecognisedRecordTypes() throws Exception { - byte type1 = (byte) (REQUEST + 1); - byte[] payload1 = getRandomBytes(123); - Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1); - byte type2 = (byte) (REQUEST + 2); - byte[] payload2 = new byte[0]; - Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2); - Record ackRecord = createAck(); - - context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); - will(returnValue(unknownRecord1)); - oneOf(recordReader).readRecord(); - will(returnValue(unknownRecord2)); - oneOf(recordReader).readRecord(); - will(returnValue(ackRecord)); - - }}); - - SyncRecordReader reader = - new SyncRecordReaderImpl(messageFactory, recordReader); assertTrue(reader.hasAck()); - Ack a = reader.readAck(); - assertEquals(MAX_MESSAGE_IDS, a.getMessageIds().size()); + Ack ack = reader.readAck(); + assertEquals(MAX_MESSAGE_IDS, ack.getMessageIds().size()); + assertTrue(reader.eof()); + assertTrue(reader.eof()); } - private void expectReadRecord(Record record) throws Exception { + private void expectReadRecord(@Nullable Record record) throws Exception { context.checking(new Expectations() {{ - oneOf(recordReader).readRecord(); + //noinspection unchecked + oneOf(recordReader).readRecord(with(any(Predicate.class)), + with(any(Predicate.class))); will(returnValue(record)); }}); } @@ -165,7 +120,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { return new Record(PROTOCOL_VERSION, ACK, createPayload()); } - private Record createEmptyAck() throws Exception { + private Record createEmptyAck() { return new Record(PROTOCOL_VERSION, ACK, new byte[0]); } @@ -173,7 +128,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { return new Record(PROTOCOL_VERSION, OFFER, createPayload()); } - private Record createEmptyOffer() throws Exception { + private Record createEmptyOffer() { return new Record(PROTOCOL_VERSION, OFFER, new byte[0]); } @@ -181,7 +136,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { return new Record(PROTOCOL_VERSION, REQUEST, createPayload()); } - private Record createEmptyRequest() throws Exception { + private Record createEmptyRequest() { return new Record(PROTOCOL_VERSION, REQUEST, new byte[0]); }