diff --git a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java index 2329999ce..250f72953 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java @@ -30,6 +30,7 @@ import static org.briarproject.bramble.api.sync.RecordTypes.OFFER; import static org.briarproject.bramble.api.sync.RecordTypes.PRIORITY; import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST; import static org.briarproject.bramble.api.sync.RecordTypes.VERSIONS; +import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_LENGTH; import static org.briarproject.bramble.api.sync.SyncConstants.MAX_SUPPORTED_VERSIONS; import static org.briarproject.bramble.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; import static org.briarproject.bramble.api.sync.SyncConstants.PRIORITY_NONCE_BYTES; @@ -126,6 +127,8 @@ class SyncRecordReaderImpl implements SyncRecordReader { byte[] payload = nextRecord.getPayload(); if (payload.length <= MESSAGE_HEADER_LENGTH) throw new FormatException(); + if (payload.length > MAX_MESSAGE_LENGTH) + throw new FormatException(); // Validate timestamp long timestamp = ByteUtils.readUint64(payload, UniqueId.LENGTH); if (timestamp < 0) throw new FormatException(); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java index 2c6b0002b..d4eb61ff5 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java @@ -6,13 +6,18 @@ import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.sync.Ack; +import org.briarproject.bramble.api.sync.GroupId; +import org.briarproject.bramble.api.sync.Message; import org.briarproject.bramble.api.sync.MessageFactory; +import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.api.sync.Offer; import org.briarproject.bramble.api.sync.Priority; import org.briarproject.bramble.api.sync.Request; import org.briarproject.bramble.api.sync.SyncRecordReader; import org.briarproject.bramble.api.sync.Versions; import org.briarproject.bramble.test.BrambleMockTestCase; +import org.briarproject.bramble.test.PredicateMatcher; +import org.hamcrest.Matcher; import org.jmock.Expectations; import org.junit.Test; @@ -23,12 +28,15 @@ import javax.annotation.Nullable; import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; import static org.briarproject.bramble.api.sync.RecordTypes.ACK; +import static org.briarproject.bramble.api.sync.RecordTypes.MESSAGE; import static org.briarproject.bramble.api.sync.RecordTypes.OFFER; import static org.briarproject.bramble.api.sync.RecordTypes.PRIORITY; import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST; import static org.briarproject.bramble.api.sync.RecordTypes.VERSIONS; +import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_BODY_LENGTH; import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS; import static org.briarproject.bramble.api.sync.SyncConstants.MAX_SUPPORTED_VERSIONS; +import static org.briarproject.bramble.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; import static org.briarproject.bramble.api.sync.SyncConstants.PRIORITY_NONCE_BYTES; import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.test.TestUtils.getRandomBytes; @@ -46,6 +54,38 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { private final SyncRecordReader reader = new SyncRecordReaderImpl(messageFactory, recordReader); + @Test + public void testNoFormatExceptionIfMessageIsMinimumSize() throws Exception { + expectReadRecord(createMessage(MESSAGE_HEADER_LENGTH + 1)); + expectCreateMessage(1); + + reader.readMessage(); + } + + @Test(expected = FormatException.class) + public void testFormatExceptionIfMessageIsTooSmall() throws Exception { + expectReadRecord(createMessage(MESSAGE_HEADER_LENGTH)); + + reader.readMessage(); + } + + @Test + public void testNoFormatExceptionIfMessageIsMaximumSize() throws Exception { + expectReadRecord(createMessage(MESSAGE_HEADER_LENGTH + + MAX_MESSAGE_BODY_LENGTH)); + expectCreateMessage(MAX_MESSAGE_BODY_LENGTH); + + reader.readMessage(); + } + + @Test(expected = FormatException.class) + public void testFormatExceptionIfMessageIsTooLarge() throws Exception { + expectReadRecord(createMessage(MESSAGE_HEADER_LENGTH + + MAX_MESSAGE_BODY_LENGTH + 1)); + + reader.readMessage(); + } + @Test public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception { expectReadRecord(createAck()); @@ -158,6 +198,20 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { assertTrue(reader.eof()); } + private void expectCreateMessage(int bodyLength) { + MessageId messageId = new MessageId(getRandomId()); + GroupId groupId = new GroupId(getRandomId()); + long timestamp = System.currentTimeMillis(); + + context.checking(new Expectations() {{ + Matcher matcher = new PredicateMatcher<>(byte[].class, + b -> b.length == MESSAGE_HEADER_LENGTH + bodyLength); + oneOf(messageFactory).createMessage(with(matcher)); + will(returnValue(new Message(messageId, groupId, timestamp, + new byte[bodyLength]))); + }}); + } + private void expectReadRecord(@Nullable Record record) throws Exception { context.checking(new Expectations() {{ //noinspection unchecked @@ -167,6 +221,10 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { }}); } + private Record createMessage(int payloadLength) { + return new Record(PROTOCOL_VERSION, MESSAGE, new byte[payloadLength]); + } + private Record createAck() throws Exception { return new Record(PROTOCOL_VERSION, ACK, createPayload()); }