Reify RecordPredicate for easier testing.

This commit is contained in:
akwizgran
2023-03-10 15:15:29 +00:00
parent f1ae57b213
commit 8d20c5d8b8
9 changed files with 62 additions and 40 deletions

View File

@@ -32,8 +32,15 @@ public interface RecordReader {
* 'accept' or 'ignore' predicates * 'accept' or 'ignore' predicates
*/ */
@Nullable @Nullable
Record readRecord(Predicate<Record> accept, Predicate<Record> ignore) Record readRecord(RecordPredicate accept, RecordPredicate ignore)
throws IOException; throws IOException;
void close() throws IOException; void close() throws IOException;
/**
* Interface that reifies the generic interface {@code Predicate<Record>}
* for easier testing.
*/
interface RecordPredicate extends Predicate<Record> {
}
} }

View File

@@ -1,7 +1,6 @@
package org.briarproject.bramble.contact; package org.briarproject.bramble.contact;
import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.client.ClientHelper; import org.briarproject.bramble.api.client.ClientHelper;
import org.briarproject.bramble.api.contact.Contact; import org.briarproject.bramble.api.contact.Contact;
import org.briarproject.bramble.api.contact.ContactExchangeManager; import org.briarproject.bramble.api.contact.ContactExchangeManager;
@@ -24,6 +23,7 @@ import org.briarproject.bramble.api.properties.TransportProperties;
import org.briarproject.bramble.api.properties.TransportPropertyManager; import org.briarproject.bramble.api.properties.TransportPropertyManager;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory; import org.briarproject.bramble.api.record.RecordWriterFactory;
@@ -61,12 +61,12 @@ class ContactExchangeManagerImpl implements ContactExchangeManager {
getLogger(ContactExchangeManagerImpl.class.getName()); getLogger(ContactExchangeManagerImpl.class.getName());
// Accept records with current protocol version, known record type // Accept records with current protocol version, known record type
private static final Predicate<Record> ACCEPT = r -> private static final RecordPredicate ACCEPT = r ->
r.getProtocolVersion() == PROTOCOL_VERSION && r.getProtocolVersion() == PROTOCOL_VERSION &&
isKnownRecordType(r.getRecordType()); isKnownRecordType(r.getRecordType());
// Ignore records with current protocol version, unknown record type // Ignore records with current protocol version, unknown record type
private static final Predicate<Record> IGNORE = r -> private static final RecordPredicate IGNORE = r ->
r.getProtocolVersion() == PROTOCOL_VERSION && r.getProtocolVersion() == PROTOCOL_VERSION &&
!isKnownRecordType(r.getRecordType()); !isKnownRecordType(r.getRecordType());

View File

@@ -2,7 +2,6 @@ package org.briarproject.bramble.contact;
import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.Pair; import org.briarproject.bramble.api.Pair;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.contact.ContactManager; import org.briarproject.bramble.api.contact.ContactManager;
import org.briarproject.bramble.api.contact.HandshakeManager; import org.briarproject.bramble.api.contact.HandshakeManager;
import org.briarproject.bramble.api.contact.PendingContact; import org.briarproject.bramble.api.contact.PendingContact;
@@ -12,12 +11,12 @@ import org.briarproject.bramble.api.crypto.KeyPair;
import org.briarproject.bramble.api.crypto.PublicKey; import org.briarproject.bramble.api.crypto.PublicKey;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto; import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.db.DatabaseComponent;
import org.briarproject.bramble.api.db.DbException; import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.TransactionManager; import org.briarproject.bramble.api.db.TransactionManager;
import org.briarproject.bramble.api.identity.IdentityManager; import org.briarproject.bramble.api.identity.IdentityManager;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory; import org.briarproject.bramble.api.record.RecordWriterFactory;
@@ -44,7 +43,7 @@ import static org.briarproject.bramble.util.ValidationUtils.checkLength;
class HandshakeManagerImpl implements HandshakeManager { class HandshakeManagerImpl implements HandshakeManager {
// Ignore records with current protocol version, unknown record type // Ignore records with current protocol version, unknown record type
private static final Predicate<Record> IGNORE = r -> private static final RecordPredicate IGNORE = r ->
r.getProtocolVersion() == PROTOCOL_VERSION && r.getProtocolVersion() == PROTOCOL_VERSION &&
!isKnownRecordType(r.getRecordType()); !isKnownRecordType(r.getRecordType());
@@ -61,7 +60,7 @@ class HandshakeManagerImpl implements HandshakeManager {
private final RecordWriterFactory recordWriterFactory; private final RecordWriterFactory recordWriterFactory;
@Inject @Inject
HandshakeManagerImpl(DatabaseComponent db, HandshakeManagerImpl(TransactionManager db,
IdentityManager identityManager, IdentityManager identityManager,
ContactManager contactManager, ContactManager contactManager,
TransportCrypto transportCrypto, TransportCrypto transportCrypto,
@@ -152,8 +151,8 @@ class HandshakeManagerImpl implements HandshakeManager {
private Record readRecord(RecordReader r, byte expectedType) private Record readRecord(RecordReader r, byte expectedType)
throws IOException { throws IOException {
// Accept records with current protocol version, expected type only // Accept records with current protocol version, expected types only
Predicate<Record> accept = rec -> RecordPredicate accept = rec ->
rec.getProtocolVersion() == PROTOCOL_VERSION && rec.getProtocolVersion() == PROTOCOL_VERSION &&
rec.getRecordType() == expectedType; rec.getRecordType() == expectedType;
Record rec = r.readRecord(accept, IGNORE); Record rec = r.readRecord(accept, IGNORE);

View File

@@ -1,11 +1,11 @@
package org.briarproject.bramble.keyagreement; package org.briarproject.bramble.keyagreement;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory; import org.briarproject.bramble.api.record.RecordWriterFactory;
@@ -34,12 +34,12 @@ class KeyAgreementTransport {
Logger.getLogger(KeyAgreementTransport.class.getName()); Logger.getLogger(KeyAgreementTransport.class.getName());
// Accept records with current protocol version, known record type // Accept records with current protocol version, known record type
private static final Predicate<Record> ACCEPT = r -> private static final RecordPredicate ACCEPT = r ->
r.getProtocolVersion() == PROTOCOL_VERSION && r.getProtocolVersion() == PROTOCOL_VERSION &&
isKnownRecordType(r.getRecordType()); isKnownRecordType(r.getRecordType());
// Ignore records with current protocol version, unknown record type // Ignore records with current protocol version, unknown record type
private static final Predicate<Record> IGNORE = r -> private static final RecordPredicate IGNORE = r ->
r.getProtocolVersion() == PROTOCOL_VERSION && r.getProtocolVersion() == PROTOCOL_VERSION &&
!isKnownRecordType(r.getRecordType()); !isKnownRecordType(r.getRecordType());

View File

@@ -1,7 +1,6 @@
package org.briarproject.bramble.record; package org.briarproject.bramble.record;
import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.util.ByteUtils; import org.briarproject.bramble.util.ByteUtils;
@@ -45,7 +44,7 @@ class RecordReaderImpl implements RecordReader {
@Nullable @Nullable
@Override @Override
public Record readRecord(Predicate<Record> accept, Predicate<Record> ignore) public Record readRecord(RecordPredicate accept, RecordPredicate ignore)
throws IOException { throws IOException {
while (true) { while (true) {
if (eof()) return null; if (eof()) return null;

View File

@@ -1,10 +1,10 @@
package org.briarproject.bramble.sync; package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.sync.Ack; import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.Message; import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageFactory; import org.briarproject.bramble.api.sync.MessageFactory;
@@ -41,12 +41,12 @@ import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
class SyncRecordReaderImpl implements SyncRecordReader { class SyncRecordReaderImpl implements SyncRecordReader {
// Accept records with current protocol version, known record type // Accept records with current protocol version, known record type
private static final Predicate<Record> ACCEPT = r -> private static final RecordPredicate ACCEPT = r ->
r.getProtocolVersion() == PROTOCOL_VERSION && r.getProtocolVersion() == PROTOCOL_VERSION &&
isKnownRecordType(r.getRecordType()); isKnownRecordType(r.getRecordType());
// Ignore records with current protocol version, unknown record type // Ignore records with current protocol version, unknown record type
private static final Predicate<Record> IGNORE = r -> private static final RecordPredicate IGNORE = r ->
r.getProtocolVersion() == PROTOCOL_VERSION && r.getProtocolVersion() == PROTOCOL_VERSION &&
!isKnownRecordType(r.getRecordType()); !isKnownRecordType(r.getRecordType());

View File

@@ -1,6 +1,5 @@
package org.briarproject.bramble.keyagreement; package org.briarproject.bramble.keyagreement;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
import org.briarproject.bramble.api.plugin.TransportConnectionReader; import org.briarproject.bramble.api.plugin.TransportConnectionReader;
import org.briarproject.bramble.api.plugin.TransportConnectionWriter; import org.briarproject.bramble.api.plugin.TransportConnectionWriter;
@@ -8,11 +7,13 @@ import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory; import org.briarproject.bramble.api.record.RecordWriterFactory;
import org.briarproject.bramble.test.BrambleMockTestCase; import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.CaptureArgumentAction; import org.briarproject.bramble.test.CaptureArgumentAction;
import org.briarproject.bramble.test.PredicateMatcher;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.imposters.ByteBuddyClassImposteriser; import org.jmock.imposters.ByteBuddyClassImposteriser;
import org.junit.Test; import org.junit.Test;
@@ -21,8 +22,6 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
@@ -119,7 +118,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
public void testReceiveKeyThrowsExceptionIfAtEndOfStream() public void testReceiveKeyThrowsExceptionIfAtEndOfStream()
throws Exception { throws Exception {
setup(); setup();
expectReadRecord(null); expectReadEof();
kat.receiveKey(); kat.receiveKey();
} }
@@ -148,7 +147,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
public void testReceiveConfirmThrowsExceptionIfAtEndOfStream() public void testReceiveConfirmThrowsExceptionIfAtEndOfStream()
throws Exception { throws Exception {
setup(); setup();
expectReadRecord(null); expectReadEof();
kat.receiveConfirm(); kat.receiveConfirm();
} }
@@ -209,12 +208,22 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
assertArrayEquals(expectedPayload, actual.getPayload()); assertArrayEquals(expectedPayload, actual.getPayload());
} }
private void expectReadRecord(@Nullable Record record) throws Exception { private void expectReadRecord(Record record) throws Exception {
context.checking(new Expectations() {{ context.checking(new Expectations() {{
//noinspection unchecked // Test that the `accept` predicate passed to the reader would
oneOf(recordReader).readRecord(with(any(Predicate.class)), // accept the expected record
with(any(Predicate.class))); oneOf(recordReader).readRecord(with(new PredicateMatcher<>(
RecordPredicate.class, rp -> rp.test(record))),
with(any(RecordPredicate.class)));
will(returnValue(record)); will(returnValue(record));
}}); }});
} }
private void expectReadEof() throws Exception {
context.checking(new Expectations() {{
oneOf(recordReader).readRecord(with(any(RecordPredicate.class)),
with(any(RecordPredicate.class)));
will(returnValue(null));
}});
}
} }

View File

@@ -1,9 +1,9 @@
package org.briarproject.bramble.record; package org.briarproject.bramble.record;
import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.test.BrambleTestCase; import org.briarproject.bramble.test.BrambleTestCase;
import org.briarproject.bramble.util.ByteUtils; import org.briarproject.bramble.util.ByteUtils;
import org.junit.Test; import org.junit.Test;
@@ -128,12 +128,12 @@ public class RecordReaderImplTest extends BrambleTestCase {
RecordReader reader = new RecordReaderImpl(in); RecordReader reader = new RecordReaderImpl(in);
// Accept records with version 0, type 0 or 1 // Accept records with version 0, type 0 or 1
Predicate<Record> accept = r -> { RecordPredicate accept = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType(); byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && (type == 0 || type == 1); return version == 0 && (type == 0 || type == 1);
}; };
// Ignore records with version 0, any other type // Ignore records with version 0, any other type
Predicate<Record> ignore = r -> { RecordPredicate ignore = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType(); byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && !(type == 0 || type == 1); return version == 0 && !(type == 0 || type == 1);
}; };
@@ -183,12 +183,12 @@ public class RecordReaderImplTest extends BrambleTestCase {
RecordReader reader = new RecordReaderImpl(in); RecordReader reader = new RecordReaderImpl(in);
// Accept records with version 0, type 0 or 1 // Accept records with version 0, type 0 or 1
Predicate<Record> accept = r -> { RecordPredicate accept = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType(); byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && (type == 0 || type == 1); return version == 0 && (type == 0 || type == 1);
}; };
// Ignore records with version 0, any other type // Ignore records with version 0, any other type
Predicate<Record> ignore = r -> { RecordPredicate ignore = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType(); byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && !(type == 0 || type == 1); return version == 0 && !(type == 0 || type == 1);
}; };

View File

@@ -1,10 +1,10 @@
package org.briarproject.bramble.sync; package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.sync.Ack; import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.GroupId; import org.briarproject.bramble.api.sync.GroupId;
import org.briarproject.bramble.api.sync.Message; import org.briarproject.bramble.api.sync.Message;
@@ -24,8 +24,6 @@ import org.junit.Test;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.List; import java.util.List;
import javax.annotation.Nullable;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.sync.RecordTypes.ACK; import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
import static org.briarproject.bramble.api.sync.RecordTypes.MESSAGE; import static org.briarproject.bramble.api.sync.RecordTypes.MESSAGE;
@@ -186,7 +184,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
@Test @Test
public void testEofReturnsTrueWhenAtEndOfStream() throws Exception { public void testEofReturnsTrueWhenAtEndOfStream() throws Exception {
expectReadRecord(createAck()); expectReadRecord(createAck());
expectReadRecord(null); expectReadEof();
SyncRecordReader reader = SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader); new SyncRecordReaderImpl(messageFactory, recordReader);
@@ -212,15 +210,25 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
}}); }});
} }
private void expectReadRecord(@Nullable Record record) throws Exception { private void expectReadRecord(Record record) throws Exception {
context.checking(new Expectations() {{ context.checking(new Expectations() {{
//noinspection unchecked // Test that the `accept` predicate passed to the reader would
oneOf(recordReader).readRecord(with(any(Predicate.class)), // accept the expected record
with(any(Predicate.class))); oneOf(recordReader).readRecord(with(new PredicateMatcher<>(
RecordPredicate.class, rp -> rp.test(record))),
with(any(RecordPredicate.class)));
will(returnValue(record)); will(returnValue(record));
}}); }});
} }
private void expectReadEof() throws Exception {
context.checking(new Expectations() {{
oneOf(recordReader).readRecord(with(any(RecordPredicate.class)),
with(any(RecordPredicate.class)));
will(returnValue(null));
}});
}
private Record createMessage(int payloadLength) { private Record createMessage(int payloadLength) {
return new Record(PROTOCOL_VERSION, MESSAGE, new byte[payloadLength]); return new Record(PROTOCOL_VERSION, MESSAGE, new byte[payloadLength]);
} }