diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java b/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java index cf45b4316..2ce40ca18 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java @@ -54,6 +54,38 @@ public interface CryptoComponent { KeyPair ourKeyPair, byte[]... inputs) throws GeneralSecurityException; + /** + * Derives a shared secret from two static and two ephemeral key pairs. + *

+ * Do not use this method for new protocols. The shared secret can be + * re-derived using the ephemeral public keys and both static private + * keys, so keys derived from the shared secret should not be used if + * forward secrecy is required. Use {@link #deriveSharedSecret(String, + * PublicKey, PublicKey, KeyPair, KeyPair, boolean, byte[]...)} instead. + *

+ * TODO: Remove this after a reasonable migration period (added 2023-03-10). + *

+ * + * @param label A namespaced label indicating the purpose of this shared + * secret, to prevent it from being repurposed or colliding with a shared + * secret derived for another purpose + * @param theirStaticPublicKey The static public key of the remote party + * @param theirEphemeralPublicKey The ephemeral public key of the remote + * party + * @param ourStaticKeyPair The static key pair of the local party + * @param ourEphemeralKeyPair The ephemeral key pair of the local party + * @param alice True if the local party is Alice + * @param inputs Additional inputs that will be included in the + * derivation of the shared secret + * @return The shared secret + */ + @Deprecated + SecretKey deriveSharedSecretBadly(String label, + PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey, + KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair, + boolean alice, byte[]... inputs) + throws GeneralSecurityException; + /** * Derives a shared secret from two static and two ephemeral key pairs. * diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java b/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java index ada8f1940..3076068c7 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java @@ -32,8 +32,15 @@ public interface RecordReader { * 'accept' or 'ignore' predicates */ @Nullable - Record readRecord(Predicate accept, Predicate ignore) + Record readRecord(RecordPredicate accept, RecordPredicate ignore) throws IOException; void close() throws IOException; + + /** + * Interface that reifies the generic interface {@code Predicate} + * for easier testing. + */ + interface RecordPredicate extends Predicate { + } } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java index 5040cf959..2875ad566 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java @@ -1,7 +1,6 @@ package org.briarproject.bramble.contact; import org.briarproject.bramble.api.FormatException; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.client.ClientHelper; import org.briarproject.bramble.api.contact.Contact; import org.briarproject.bramble.api.contact.ContactExchangeManager; @@ -24,6 +23,7 @@ import org.briarproject.bramble.api.properties.TransportProperties; import org.briarproject.bramble.api.properties.TransportPropertyManager; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriterFactory; @@ -61,12 +61,12 @@ class ContactExchangeManagerImpl implements ContactExchangeManager { getLogger(ContactExchangeManagerImpl.class.getName()); // Accept records with current protocol version, known record type - private static final Predicate ACCEPT = r -> + private static final RecordPredicate ACCEPT = r -> r.getProtocolVersion() == PROTOCOL_VERSION && isKnownRecordType(r.getRecordType()); // Ignore records with current protocol version, unknown record type - private static final Predicate IGNORE = r -> + private static final RecordPredicate IGNORE = r -> r.getProtocolVersion() == PROTOCOL_VERSION && !isKnownRecordType(r.getRecordType()); diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java index 37e7a41d3..45470f5b8 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java @@ -5,14 +5,31 @@ import static org.briarproject.bramble.api.crypto.CryptoConstants.MAC_BYTES; interface HandshakeConstants { /** - * The current version of the handshake protocol. + * The current major version of the handshake protocol. */ - byte PROTOCOL_VERSION = 0; + byte PROTOCOL_MAJOR_VERSION = 0; /** - * Label for deriving the master key. + * The current minor version of the handshake protocol. */ - String MASTER_KEY_LABEL = "org.briarproject.bramble.handshake/MASTER_KEY"; + byte PROTOCOL_MINOR_VERSION = 1; + + /** + * Label for deriving the master key when using the deprecated v0.0 key + * derivation method. + *

+ * TODO: Remove this after a reasonable migration period (added 2023-03-10). + */ + @Deprecated + String MASTER_KEY_LABEL_0_0 = + "org.briarproject.bramble.handshake/MASTER_KEY"; + + /** + * Label for deriving the master key when using the v0.1 key derivation + * method. + */ + String MASTER_KEY_LABEL_0_1 = + "org.briarproject.bramble.handshake/MASTER_KEY_0_1"; /** * Label for deriving Alice's proof of ownership from the master key. diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java index b7bf351ca..ce9a656eb 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java @@ -13,11 +13,26 @@ interface HandshakeCrypto { KeyPair generateEphemeralKeyPair(); /** - * Derives the master key from the given static and ephemeral keys. + * Derives the master key from the given static and ephemeral keys using + * the deprecated v0.0 key derivation method. + *

+ * TODO: Remove this after a reasonable migration period (added 2023-03-10). * * @param alice Whether the local peer is Alice */ - SecretKey deriveMasterKey(PublicKey theirStaticPublicKey, + @Deprecated + SecretKey deriveMasterKey_0_0(PublicKey theirStaticPublicKey, + PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair, + KeyPair ourEphemeralKeyPair, boolean alice) + throws GeneralSecurityException; + + /** + * Derives the master key from the given static and ephemeral keys using + * the v0.1 key derivation method. + * + * @param alice Whether the local peer is Alice + */ + SecretKey deriveMasterKey_0_1(PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair, boolean alice) throws GeneralSecurityException; diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java index f5ebb5e1d..7606ec938 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java @@ -13,7 +13,8 @@ import javax.inject.Inject; import static org.briarproject.bramble.contact.HandshakeConstants.ALICE_PROOF_LABEL; import static org.briarproject.bramble.contact.HandshakeConstants.BOB_PROOF_LABEL; -import static org.briarproject.bramble.contact.HandshakeConstants.MASTER_KEY_LABEL; +import static org.briarproject.bramble.contact.HandshakeConstants.MASTER_KEY_LABEL_0_0; +import static org.briarproject.bramble.contact.HandshakeConstants.MASTER_KEY_LABEL_0_1; @Immutable @NotNullByDefault @@ -32,7 +33,8 @@ class HandshakeCryptoImpl implements HandshakeCrypto { } @Override - public SecretKey deriveMasterKey(PublicKey theirStaticPublicKey, + @Deprecated + public SecretKey deriveMasterKey_0_0(PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair, boolean alice) throws GeneralSecurityException { @@ -46,9 +48,29 @@ class HandshakeCryptoImpl implements HandshakeCrypto { alice ? ourEphemeral : theirEphemeral, alice ? theirEphemeral : ourEphemeral }; - return crypto.deriveSharedSecret(MASTER_KEY_LABEL, theirStaticPublicKey, - theirEphemeralPublicKey, ourStaticKeyPair, ourEphemeralKeyPair, - alice, inputs); + return crypto.deriveSharedSecretBadly(MASTER_KEY_LABEL_0_0, + theirStaticPublicKey, theirEphemeralPublicKey, + ourStaticKeyPair, ourEphemeralKeyPair, alice, inputs); + } + + @Override + public SecretKey deriveMasterKey_0_1(PublicKey theirStaticPublicKey, + PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair, + KeyPair ourEphemeralKeyPair, boolean alice) throws + GeneralSecurityException { + byte[] theirStatic = theirStaticPublicKey.getEncoded(); + byte[] theirEphemeral = theirEphemeralPublicKey.getEncoded(); + byte[] ourStatic = ourStaticKeyPair.getPublic().getEncoded(); + byte[] ourEphemeral = ourEphemeralKeyPair.getPublic().getEncoded(); + byte[][] inputs = { + alice ? ourStatic : theirStatic, + alice ? theirStatic : ourStatic, + alice ? ourEphemeral : theirEphemeral, + alice ? theirEphemeral : ourEphemeral + }; + return crypto.deriveSharedSecret(MASTER_KEY_LABEL_0_1, + theirStaticPublicKey, theirEphemeralPublicKey, + ourStaticKeyPair, ourEphemeralKeyPair, alice, inputs); } @Override diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java index acd1c80e7..496fdc482 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java @@ -2,7 +2,6 @@ package org.briarproject.bramble.contact; import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.Pair; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.contact.ContactManager; import org.briarproject.bramble.api.contact.HandshakeManager; import org.briarproject.bramble.api.contact.PendingContact; @@ -12,12 +11,12 @@ import org.briarproject.bramble.api.crypto.KeyPair; import org.briarproject.bramble.api.crypto.PublicKey; import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.TransportCrypto; -import org.briarproject.bramble.api.db.DatabaseComponent; import org.briarproject.bramble.api.db.DbException; import org.briarproject.bramble.api.db.TransactionManager; import org.briarproject.bramble.api.identity.IdentityManager; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriterFactory; @@ -28,15 +27,20 @@ import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.security.GeneralSecurityException; +import java.util.List; import javax.annotation.concurrent.Immutable; import javax.inject.Inject; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; import static org.briarproject.bramble.api.crypto.CryptoConstants.MAX_AGREEMENT_PUBLIC_KEY_BYTES; import static org.briarproject.bramble.contact.HandshakeConstants.PROOF_BYTES; -import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_VERSION; -import static org.briarproject.bramble.contact.HandshakeRecordTypes.EPHEMERAL_PUBLIC_KEY; -import static org.briarproject.bramble.contact.HandshakeRecordTypes.PROOF_OF_OWNERSHIP; +import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MAJOR_VERSION; +import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MINOR_VERSION; +import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_EPHEMERAL_PUBLIC_KEY; +import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_MINOR_VERSION; +import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_PROOF_OF_OWNERSHIP; import static org.briarproject.bramble.util.ValidationUtils.checkLength; @Immutable @@ -44,12 +48,14 @@ import static org.briarproject.bramble.util.ValidationUtils.checkLength; class HandshakeManagerImpl implements HandshakeManager { // Ignore records with current protocol version, unknown record type - private static final Predicate IGNORE = r -> - r.getProtocolVersion() == PROTOCOL_VERSION && + private static final RecordPredicate IGNORE = r -> + r.getProtocolVersion() == PROTOCOL_MAJOR_VERSION && !isKnownRecordType(r.getRecordType()); private static boolean isKnownRecordType(byte type) { - return type == EPHEMERAL_PUBLIC_KEY || type == PROOF_OF_OWNERSHIP; + return type == RECORD_TYPE_EPHEMERAL_PUBLIC_KEY || + type == RECORD_TYPE_PROOF_OF_OWNERSHIP || + type == RECORD_TYPE_MINOR_VERSION; } private final TransactionManager db; @@ -61,7 +67,7 @@ class HandshakeManagerImpl implements HandshakeManager { private final RecordWriterFactory recordWriterFactory; @Inject - HandshakeManagerImpl(DatabaseComponent db, + HandshakeManagerImpl(TransactionManager db, IdentityManager identityManager, ContactManager contactManager, TransportCrypto transportCrypto, @@ -95,19 +101,31 @@ class HandshakeManagerImpl implements HandshakeManager { .createRecordWriter(out.getOutputStream()); KeyPair ourEphemeralKeyPair = handshakeCrypto.generateEphemeralKeyPair(); - PublicKey theirEphemeralPublicKey; + Pair theirMinorVersionAndKey; if (alice) { + sendMinorVersion(recordWriter); sendPublicKey(recordWriter, ourEphemeralKeyPair.getPublic()); - theirEphemeralPublicKey = receivePublicKey(recordReader); + theirMinorVersionAndKey = receiveMinorVersionAndKey(recordReader); } else { - theirEphemeralPublicKey = receivePublicKey(recordReader); + theirMinorVersionAndKey = receiveMinorVersionAndKey(recordReader); + sendMinorVersion(recordWriter); sendPublicKey(recordWriter, ourEphemeralKeyPair.getPublic()); } + byte theirMinorVersion = theirMinorVersionAndKey.getFirst(); + PublicKey theirEphemeralPublicKey = theirMinorVersionAndKey.getSecond(); SecretKey masterKey; try { - masterKey = handshakeCrypto.deriveMasterKey(theirStaticPublicKey, - theirEphemeralPublicKey, ourStaticKeyPair, - ourEphemeralKeyPair, alice); + if (theirMinorVersion > 0) { + masterKey = handshakeCrypto.deriveMasterKey_0_1( + theirStaticPublicKey, theirEphemeralPublicKey, + ourStaticKeyPair, ourEphemeralKeyPair, alice); + } else { + // TODO: Remove this branch after a reasonable migration + // period (added 2023-03-10). + masterKey = handshakeCrypto.deriveMasterKey_0_0( + theirStaticPublicKey, theirEphemeralPublicKey, + ourStaticKeyPair, ourEphemeralKeyPair, alice); + } } catch (GeneralSecurityException e) { throw new FormatException(); } @@ -128,34 +146,91 @@ class HandshakeManagerImpl implements HandshakeManager { } private void sendPublicKey(RecordWriter w, PublicKey k) throws IOException { - w.writeRecord(new Record(PROTOCOL_VERSION, EPHEMERAL_PUBLIC_KEY, - k.getEncoded())); + w.writeRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_EPHEMERAL_PUBLIC_KEY, k.getEncoded())); w.flush(); } - private PublicKey receivePublicKey(RecordReader r) throws IOException { - byte[] key = readRecord(r, EPHEMERAL_PUBLIC_KEY).getPayload(); + /** + * Receives the remote peer's protocol minor version and ephemeral public + * key. + *

+ * In version 0.1 of the protocol, each peer sends a minor version record + * followed by an ephemeral public key record. + *

+ * In version 0.0 of the protocol, each peer sends an ephemeral public key + * record without a preceding minor version record. + *

+ * Therefore the remote peer's minor version must be non-zero if a minor + * version record is received, and is assumed to be zero if no minor + * version record is received. + */ + private Pair receiveMinorVersionAndKey(RecordReader r) + throws IOException { + byte theirMinorVersion; + PublicKey theirEphemeralPublicKey; + // The first record can be either a minor version record or an + // ephemeral public key record + Record first = readRecord(r, asList(RECORD_TYPE_MINOR_VERSION, + RECORD_TYPE_EPHEMERAL_PUBLIC_KEY)); + if (first.getRecordType() == RECORD_TYPE_MINOR_VERSION) { + // The payload must be a single byte giving the remote peer's + // protocol minor version, which must be non-zero + byte[] payload = first.getPayload(); + checkLength(payload, 1); + theirMinorVersion = payload[0]; + if (theirMinorVersion == 0) throw new FormatException(); + // The second record must be an ephemeral public key record + Record second = readRecord(r, + singletonList(RECORD_TYPE_EPHEMERAL_PUBLIC_KEY)); + theirEphemeralPublicKey = parsePublicKey(second); + } else { + // The remote peer did not send a minor version record, so the + // remote peer's protocol minor version is assumed to be zero + // TODO: Remove this branch after a reasonable migration period + // (added 2023-03-10). + theirMinorVersion = 0; + theirEphemeralPublicKey = parsePublicKey(first); + } + return new Pair<>(theirMinorVersion, theirEphemeralPublicKey); + } + + private PublicKey parsePublicKey(Record rec) throws IOException { + if (rec.getRecordType() != RECORD_TYPE_EPHEMERAL_PUBLIC_KEY) { + throw new AssertionError(); + } + byte[] key = rec.getPayload(); checkLength(key, 1, MAX_AGREEMENT_PUBLIC_KEY_BYTES); return new AgreementPublicKey(key); } private void sendProof(RecordWriter w, byte[] proof) throws IOException { - w.writeRecord(new Record(PROTOCOL_VERSION, PROOF_OF_OWNERSHIP, proof)); + w.writeRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_PROOF_OF_OWNERSHIP, proof)); w.flush(); } private byte[] receiveProof(RecordReader r) throws IOException { - byte[] proof = readRecord(r, PROOF_OF_OWNERSHIP).getPayload(); + Record rec = readRecord(r, + singletonList(RECORD_TYPE_PROOF_OF_OWNERSHIP)); + byte[] proof = rec.getPayload(); checkLength(proof, PROOF_BYTES, PROOF_BYTES); return proof; } - private Record readRecord(RecordReader r, byte expectedType) + private void sendMinorVersion(RecordWriter w) throws IOException { + w.writeRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_MINOR_VERSION, + new byte[] {PROTOCOL_MINOR_VERSION})); + w.flush(); + } + + private Record readRecord(RecordReader r, List expectedTypes) throws IOException { - // Accept records with current protocol version, expected type only - Predicate accept = rec -> - rec.getProtocolVersion() == PROTOCOL_VERSION && - rec.getRecordType() == expectedType; + // Accept records with current protocol version, expected types only + RecordPredicate accept = rec -> + rec.getProtocolVersion() == PROTOCOL_MAJOR_VERSION && + expectedTypes.contains(rec.getRecordType()); Record rec = r.readRecord(accept, IGNORE); if (rec == null) throw new EOFException(); return rec; diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java index 7c38d9a84..9875c1a25 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java @@ -5,7 +5,9 @@ package org.briarproject.bramble.contact; */ interface HandshakeRecordTypes { - byte EPHEMERAL_PUBLIC_KEY = 0; + byte RECORD_TYPE_EPHEMERAL_PUBLIC_KEY = 0; - byte PROOF_OF_OWNERSHIP = 1; + byte RECORD_TYPE_PROOF_OF_OWNERSHIP = 1; + + byte RECORD_TYPE_MINOR_VERSION = 2; } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java index 06ce15f29..1bb7fe9f8 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java @@ -222,7 +222,8 @@ class CryptoComponentImpl implements CryptoComponent { } @Override - public SecretKey deriveSharedSecret(String label, + @Deprecated + public SecretKey deriveSharedSecretBadly(String label, PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair, boolean alice, byte[]... inputs) throws GeneralSecurityException { @@ -250,6 +251,35 @@ class CryptoComponentImpl implements CryptoComponent { return new SecretKey(hash); } + @Override + public SecretKey deriveSharedSecret(String label, + PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey, + KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair, + boolean alice, byte[]... inputs) throws GeneralSecurityException { + PrivateKey ourStaticPrivateKey = ourStaticKeyPair.getPrivate(); + PrivateKey ourEphemeralPrivateKey = ourEphemeralKeyPair.getPrivate(); + byte[][] hashInputs = new byte[inputs.length + 3][]; + // Alice ephemeral/Bob ephemeral + hashInputs[0] = performRawKeyAgreement(ourEphemeralPrivateKey, + theirEphemeralPublicKey); + // Alice static/Bob ephemeral, Bob static/Alice ephemeral + if (alice) { + hashInputs[1] = performRawKeyAgreement(ourStaticPrivateKey, + theirEphemeralPublicKey); + hashInputs[2] = performRawKeyAgreement(ourEphemeralPrivateKey, + theirStaticPublicKey); + } else { + hashInputs[1] = performRawKeyAgreement(ourEphemeralPrivateKey, + theirStaticPublicKey); + hashInputs[2] = performRawKeyAgreement(ourStaticPrivateKey, + theirEphemeralPublicKey); + } + arraycopy(inputs, 0, hashInputs, 3, inputs.length); + byte[] hash = hash(label, hashInputs); + if (hash.length != SecretKey.LENGTH) throw new IllegalStateException(); + return new SecretKey(hash); + } + @Override public byte[] sign(String label, byte[] toSign, PrivateKey privateKey) throws GeneralSecurityException { diff --git a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java index e5e1e2ca3..bd09ceb09 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java @@ -1,11 +1,11 @@ package org.briarproject.bramble.keyagreement; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriterFactory; @@ -34,12 +34,12 @@ class KeyAgreementTransport { Logger.getLogger(KeyAgreementTransport.class.getName()); // Accept records with current protocol version, known record type - private static final Predicate ACCEPT = r -> + private static final RecordPredicate ACCEPT = r -> r.getProtocolVersion() == PROTOCOL_VERSION && isKnownRecordType(r.getRecordType()); // Ignore records with current protocol version, unknown record type - private static final Predicate IGNORE = r -> + private static final RecordPredicate IGNORE = r -> r.getProtocolVersion() == PROTOCOL_VERSION && !isKnownRecordType(r.getRecordType()); diff --git a/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java index fb6abb151..872c937fe 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java @@ -1,7 +1,6 @@ package org.briarproject.bramble.record; import org.briarproject.bramble.api.FormatException; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; import org.briarproject.bramble.util.ByteUtils; @@ -45,7 +44,7 @@ class RecordReaderImpl implements RecordReader { @Nullable @Override - public Record readRecord(Predicate accept, Predicate ignore) + public Record readRecord(RecordPredicate accept, RecordPredicate ignore) throws IOException { while (true) { if (eof()) return null; diff --git a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java index 250f72953..ee247e4b6 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java @@ -1,10 +1,10 @@ package org.briarproject.bramble.sync; import org.briarproject.bramble.api.FormatException; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; import org.briarproject.bramble.api.sync.Ack; import org.briarproject.bramble.api.sync.Message; import org.briarproject.bramble.api.sync.MessageFactory; @@ -41,12 +41,12 @@ import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION; class SyncRecordReaderImpl implements SyncRecordReader { // Accept records with current protocol version, known record type - private static final Predicate ACCEPT = r -> + private static final RecordPredicate ACCEPT = r -> r.getProtocolVersion() == PROTOCOL_VERSION && isKnownRecordType(r.getRecordType()); // Ignore records with current protocol version, unknown record type - private static final Predicate IGNORE = r -> + private static final RecordPredicate IGNORE = r -> r.getProtocolVersion() == PROTOCOL_VERSION && !isKnownRecordType(r.getRecordType()); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/contact/HandshakeManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/contact/HandshakeManagerImplTest.java new file mode 100644 index 000000000..9597d486b --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/contact/HandshakeManagerImplTest.java @@ -0,0 +1,316 @@ +package org.briarproject.bramble.contact; + +import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.contact.ContactManager; +import org.briarproject.bramble.api.contact.HandshakeManager.HandshakeResult; +import org.briarproject.bramble.api.contact.PendingContact; +import org.briarproject.bramble.api.crypto.KeyPair; +import org.briarproject.bramble.api.crypto.PrivateKey; +import org.briarproject.bramble.api.crypto.PublicKey; +import org.briarproject.bramble.api.crypto.SecretKey; +import org.briarproject.bramble.api.crypto.TransportCrypto; +import org.briarproject.bramble.api.db.Transaction; +import org.briarproject.bramble.api.db.TransactionManager; +import org.briarproject.bramble.api.identity.IdentityManager; +import org.briarproject.bramble.api.record.Record; +import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; +import org.briarproject.bramble.api.record.RecordReaderFactory; +import org.briarproject.bramble.api.record.RecordWriter; +import org.briarproject.bramble.api.record.RecordWriterFactory; +import org.briarproject.bramble.api.transport.StreamWriter; +import org.briarproject.bramble.test.BrambleMockTestCase; +import org.briarproject.bramble.test.DbExpectations; +import org.briarproject.bramble.test.PredicateMatcher; +import org.jmock.Expectations; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; + +import static org.briarproject.bramble.contact.HandshakeConstants.PROOF_BYTES; +import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MAJOR_VERSION; +import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MINOR_VERSION; +import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_EPHEMERAL_PUBLIC_KEY; +import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_MINOR_VERSION; +import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_PROOF_OF_OWNERSHIP; +import static org.briarproject.bramble.test.TestUtils.getAgreementPrivateKey; +import static org.briarproject.bramble.test.TestUtils.getAgreementPublicKey; +import static org.briarproject.bramble.test.TestUtils.getPendingContact; +import static org.briarproject.bramble.test.TestUtils.getRandomBytes; +import static org.briarproject.bramble.test.TestUtils.getSecretKey; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class HandshakeManagerImplTest extends BrambleMockTestCase { + + private final TransactionManager db = + context.mock(TransactionManager.class); + private final IdentityManager identityManager = + context.mock(IdentityManager.class); + private final ContactManager contactManager = + context.mock(ContactManager.class); + private final TransportCrypto transportCrypto = + context.mock(TransportCrypto.class); + private final HandshakeCrypto handshakeCrypto = + context.mock(HandshakeCrypto.class); + private final RecordReaderFactory recordReaderFactory = + context.mock(RecordReaderFactory.class); + private final RecordWriterFactory recordWriterFactory = + context.mock(RecordWriterFactory.class); + private final RecordReader recordReader = context.mock(RecordReader.class); + private final RecordWriter recordWriter = context.mock(RecordWriter.class); + private final StreamWriter streamWriter = context.mock(StreamWriter.class); + + private final PendingContact pendingContact = getPendingContact(); + private final PublicKey theirStaticPublicKey = + pendingContact.getPublicKey(); + private final PublicKey ourStaticPublicKey = getAgreementPublicKey(); + private final PrivateKey ourStaticPrivateKey = getAgreementPrivateKey(); + private final KeyPair ourStaticKeyPair = + new KeyPair(ourStaticPublicKey, ourStaticPrivateKey); + private final PublicKey theirEphemeralPublicKey = getAgreementPublicKey(); + private final PublicKey ourEphemeralPublicKey = getAgreementPublicKey(); + private final PrivateKey ourEphemeralPrivateKey = getAgreementPrivateKey(); + private final KeyPair ourEphemeralKeyPair = + new KeyPair(ourEphemeralPublicKey, ourEphemeralPrivateKey); + private final SecretKey masterKey = getSecretKey(); + private final byte[] ourProof = getRandomBytes(PROOF_BYTES); + private final byte[] theirProof = getRandomBytes(PROOF_BYTES); + + private final InputStream in = new ByteArrayInputStream(new byte[0]); + private final OutputStream out = new ByteArrayOutputStream(0); + + private final HandshakeManagerImpl handshakeManager = + new HandshakeManagerImpl(db, identityManager, contactManager, + transportCrypto, handshakeCrypto, recordReaderFactory, + recordWriterFactory); + + @Test + public void testHandshakeAsAliceWithPeerVersion_0_1() throws Exception { + testHandshakeWithPeerVersion_0_1(true); + } + + @Test + public void testHandshakeAsBobWithPeerVersion_0_1() throws Exception { + testHandshakeWithPeerVersion_0_1(false); + } + + private void testHandshakeWithPeerVersion_0_1(boolean alice) + throws Exception { + expectPrepareForHandshake(alice); + expectSendMinorVersion(); + expectSendKey(); + // Remote peer sends minor version, so use new key derivation + expectReceiveMinorVersion(); + expectReceiveKey(); + expectDeriveMasterKey_0_1(alice); + expectDeriveProof(alice); + expectSendProof(); + expectReceiveProof(); + expectSendEof(); + expectReceiveEof(); + expectVerifyOwnership(alice, true); + + HandshakeResult result = handshakeManager.handshake( + pendingContact.getId(), in, streamWriter); + + assertArrayEquals(masterKey.getBytes(), + result.getMasterKey().getBytes()); + assertEquals(alice, result.isAlice()); + } + + @Test + public void testHandshakeAsAliceWithPeerVersion_0_0() throws Exception { + testHandshakeWithPeerVersion_0_0(true); + } + + @Test + public void testHandshakeAsBobWithPeerVersion_0_0() throws Exception { + testHandshakeWithPeerVersion_0_0(false); + } + + private void testHandshakeWithPeerVersion_0_0(boolean alice) + throws Exception { + expectPrepareForHandshake(alice); + expectSendMinorVersion(); + expectSendKey(); + // Remote peer does not send minor version, so use old key derivation + expectReceiveKey(); + expectDeriveMasterKey_0_0(alice); + expectDeriveProof(alice); + expectSendProof(); + expectReceiveProof(); + expectSendEof(); + expectReceiveEof(); + expectVerifyOwnership(alice, true); + + HandshakeResult result = handshakeManager.handshake( + pendingContact.getId(), in, streamWriter); + + assertArrayEquals(masterKey.getBytes(), + result.getMasterKey().getBytes()); + assertEquals(alice, result.isAlice()); + } + + @Test(expected = FormatException.class) + public void testProofOfOwnershipNotVerifiedAsAlice() throws Exception { + testProofOfOwnershipNotVerified(true); + } + + @Test(expected = FormatException.class) + public void testProofOfOwnershipNotVerifiedAsBob() throws Exception { + testProofOfOwnershipNotVerified(false); + } + + private void testProofOfOwnershipNotVerified(boolean alice) + throws Exception { + expectPrepareForHandshake(alice); + expectSendMinorVersion(); + expectSendKey(); + expectReceiveMinorVersion(); + expectReceiveKey(); + expectDeriveMasterKey_0_1(alice); + expectDeriveProof(alice); + expectSendProof(); + expectReceiveProof(); + expectSendEof(); + expectReceiveEof(); + expectVerifyOwnership(alice, false); + + handshakeManager.handshake(pendingContact.getId(), in, streamWriter); + } + + private void expectPrepareForHandshake(boolean alice) throws Exception { + Transaction txn = new Transaction(null, true); + + context.checking(new DbExpectations() {{ + oneOf(db).transactionWithResult(with(true), withDbCallable(txn)); + oneOf(contactManager).getPendingContact(txn, + pendingContact.getId()); + will(returnValue(pendingContact)); + oneOf(identityManager).getHandshakeKeys(txn); + will(returnValue(ourStaticKeyPair)); + oneOf(transportCrypto).isAlice(theirStaticPublicKey, + ourStaticKeyPair); + will(returnValue(alice)); + oneOf(recordReaderFactory).createRecordReader(in); + will(returnValue(recordReader)); + oneOf(streamWriter).getOutputStream(); + will(returnValue(out)); + oneOf(recordWriterFactory).createRecordWriter(out); + will(returnValue(recordWriter)); + oneOf(handshakeCrypto).generateEphemeralKeyPair(); + will(returnValue(ourEphemeralKeyPair)); + }}); + } + + private void expectSendMinorVersion() throws Exception { + expectWriteRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_MINOR_VERSION, + new byte[] {PROTOCOL_MINOR_VERSION})); + } + + private void expectReceiveMinorVersion() throws Exception { + expectReadRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_MINOR_VERSION, + new byte[] {PROTOCOL_MINOR_VERSION})); + } + + private void expectSendKey() throws Exception { + expectWriteRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_EPHEMERAL_PUBLIC_KEY, + ourEphemeralPublicKey.getEncoded())); + } + + private void expectReceiveKey() throws Exception { + expectReadRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_EPHEMERAL_PUBLIC_KEY, + theirEphemeralPublicKey.getEncoded())); + } + + private void expectDeriveMasterKey_0_1(boolean alice) throws Exception { + context.checking(new Expectations() {{ + oneOf(handshakeCrypto).deriveMasterKey_0_1(theirStaticPublicKey, + theirEphemeralPublicKey, ourStaticKeyPair, + ourEphemeralKeyPair, alice); + will(returnValue(masterKey)); + }}); + } + + private void expectDeriveMasterKey_0_0(boolean alice) throws Exception { + context.checking(new Expectations() {{ + oneOf(handshakeCrypto).deriveMasterKey_0_0(theirStaticPublicKey, + theirEphemeralPublicKey, ourStaticKeyPair, + ourEphemeralKeyPair, alice); + will(returnValue(masterKey)); + }}); + } + + private void expectDeriveProof(boolean alice) { + context.checking(new Expectations() {{ + oneOf(handshakeCrypto).proveOwnership(masterKey, alice); + will(returnValue(ourProof)); + }}); + } + + private void expectSendProof() throws Exception { + expectWriteRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_PROOF_OF_OWNERSHIP, ourProof)); + } + + private void expectReceiveProof() throws Exception { + expectReadRecord(new Record(PROTOCOL_MAJOR_VERSION, + RECORD_TYPE_PROOF_OF_OWNERSHIP, theirProof)); + } + + private void expectSendEof() throws Exception { + context.checking(new Expectations() {{ + oneOf(streamWriter).sendEndOfStream(); + }}); + } + + private void expectReceiveEof() throws Exception { + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(with(any(RecordPredicate.class)), + with(any(RecordPredicate.class))); + will(returnValue(null)); + }}); + } + + private void expectVerifyOwnership(boolean alice, boolean verified) { + context.checking(new Expectations() {{ + oneOf(handshakeCrypto).verifyOwnership(masterKey, !alice, + theirProof); + will(returnValue(verified)); + }}); + } + + private void expectWriteRecord(Record record) throws Exception { + context.checking(new Expectations() {{ + oneOf(recordWriter).writeRecord(with(new PredicateMatcher<>( + Record.class, r -> recordEquals(record, r)))); + oneOf(recordWriter).flush(); + }}); + } + + private boolean recordEquals(Record expected, Record actual) { + return expected.getProtocolVersion() == actual.getProtocolVersion() && + expected.getRecordType() == actual.getRecordType() && + Arrays.equals(expected.getPayload(), actual.getPayload()); + } + + private void expectReadRecord(Record record) throws Exception { + context.checking(new Expectations() {{ + // Test that the `accept` predicate passed to the reader would + // accept the expected record + oneOf(recordReader).readRecord(with(new PredicateMatcher<>( + RecordPredicate.class, rp -> rp.test(record))), + with(any(RecordPredicate.class))); + will(returnValue(record)); + }}); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java index 5612663ae..e179ac2b2 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java @@ -60,6 +60,22 @@ public class KeyAgreementTest extends BrambleTestCase { assertArrayEquals(aShared.getBytes(), bShared.getBytes()); } + @Test + public void testDerivesStaticEphemeralSharedSecretBadly() throws Exception { + String label = getRandomString(123); + KeyPair aStatic = crypto.generateAgreementKeyPair(); + KeyPair aEphemeral = crypto.generateAgreementKeyPair(); + KeyPair bStatic = crypto.generateAgreementKeyPair(); + KeyPair bEphemeral = crypto.generateAgreementKeyPair(); + SecretKey aShared = crypto.deriveSharedSecretBadly(label, + bStatic.getPublic(), bEphemeral.getPublic(), aStatic, + aEphemeral, true, inputs); + SecretKey bShared = crypto.deriveSharedSecretBadly(label, + aStatic.getPublic(), aEphemeral.getPublic(), bStatic, + bEphemeral, false, inputs); + assertArrayEquals(aShared.getBytes(), bShared.getBytes()); + } + @Test public void testDerivesStaticEphemeralSharedSecret() throws Exception { String label = getRandomString(123); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java index d75bf40f4..b1708e859 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java @@ -1,6 +1,5 @@ package org.briarproject.bramble.keyagreement; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; import org.briarproject.bramble.api.plugin.TransportConnectionReader; import org.briarproject.bramble.api.plugin.TransportConnectionWriter; @@ -8,11 +7,13 @@ import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; import org.briarproject.bramble.api.record.RecordReaderFactory; import org.briarproject.bramble.api.record.RecordWriter; import org.briarproject.bramble.api.record.RecordWriterFactory; import org.briarproject.bramble.test.BrambleMockTestCase; import org.briarproject.bramble.test.CaptureArgumentAction; +import org.briarproject.bramble.test.PredicateMatcher; import org.jmock.Expectations; import org.jmock.imposters.ByteBuddyClassImposteriser; import org.junit.Test; @@ -21,8 +22,6 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.concurrent.atomic.AtomicReference; -import javax.annotation.Nullable; - import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; @@ -119,7 +118,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { public void testReceiveKeyThrowsExceptionIfAtEndOfStream() throws Exception { setup(); - expectReadRecord(null); + expectReadEof(); kat.receiveKey(); } @@ -148,7 +147,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { public void testReceiveConfirmThrowsExceptionIfAtEndOfStream() throws Exception { setup(); - expectReadRecord(null); + expectReadEof(); kat.receiveConfirm(); } @@ -209,12 +208,22 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { assertArrayEquals(expectedPayload, actual.getPayload()); } - private void expectReadRecord(@Nullable Record record) throws Exception { + private void expectReadRecord(Record record) throws Exception { context.checking(new Expectations() {{ - //noinspection unchecked - oneOf(recordReader).readRecord(with(any(Predicate.class)), - with(any(Predicate.class))); + // Test that the `accept` predicate passed to the reader would + // accept the expected record + oneOf(recordReader).readRecord(with(new PredicateMatcher<>( + RecordPredicate.class, rp -> rp.test(record))), + with(any(RecordPredicate.class))); will(returnValue(record)); }}); } + + private void expectReadEof() throws Exception { + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(with(any(RecordPredicate.class)), + with(any(RecordPredicate.class))); + will(returnValue(null)); + }}); + } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java index c2cc4c55f..ed6f412ca 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java @@ -1,9 +1,9 @@ package org.briarproject.bramble.record; import org.briarproject.bramble.api.FormatException; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; import org.briarproject.bramble.test.BrambleTestCase; import org.briarproject.bramble.util.ByteUtils; import org.junit.Test; @@ -128,12 +128,12 @@ public class RecordReaderImplTest extends BrambleTestCase { RecordReader reader = new RecordReaderImpl(in); // Accept records with version 0, type 0 or 1 - Predicate accept = r -> { + RecordPredicate accept = r -> { byte version = r.getProtocolVersion(), type = r.getRecordType(); return version == 0 && (type == 0 || type == 1); }; // Ignore records with version 0, any other type - Predicate ignore = r -> { + RecordPredicate ignore = r -> { byte version = r.getProtocolVersion(), type = r.getRecordType(); return version == 0 && !(type == 0 || type == 1); }; @@ -183,12 +183,12 @@ public class RecordReaderImplTest extends BrambleTestCase { RecordReader reader = new RecordReaderImpl(in); // Accept records with version 0, type 0 or 1 - Predicate accept = r -> { + RecordPredicate accept = r -> { byte version = r.getProtocolVersion(), type = r.getRecordType(); return version == 0 && (type == 0 || type == 1); }; // Ignore records with version 0, any other type - Predicate ignore = r -> { + RecordPredicate ignore = r -> { byte version = r.getProtocolVersion(), type = r.getRecordType(); return version == 0 && !(type == 0 || type == 1); }; diff --git a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java index d4eb61ff5..d9400cbd5 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java @@ -1,10 +1,10 @@ package org.briarproject.bramble.sync; import org.briarproject.bramble.api.FormatException; -import org.briarproject.bramble.api.Predicate; import org.briarproject.bramble.api.UniqueId; import org.briarproject.bramble.api.record.Record; import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReader.RecordPredicate; import org.briarproject.bramble.api.sync.Ack; import org.briarproject.bramble.api.sync.GroupId; import org.briarproject.bramble.api.sync.Message; @@ -24,8 +24,6 @@ import org.junit.Test; import java.io.ByteArrayOutputStream; import java.util.List; -import javax.annotation.Nullable; - import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; import static org.briarproject.bramble.api.sync.RecordTypes.ACK; import static org.briarproject.bramble.api.sync.RecordTypes.MESSAGE; @@ -186,7 +184,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { @Test public void testEofReturnsTrueWhenAtEndOfStream() throws Exception { expectReadRecord(createAck()); - expectReadRecord(null); + expectReadEof(); SyncRecordReader reader = new SyncRecordReaderImpl(messageFactory, recordReader); @@ -212,15 +210,25 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase { }}); } - private void expectReadRecord(@Nullable Record record) throws Exception { + private void expectReadRecord(Record record) throws Exception { context.checking(new Expectations() {{ - //noinspection unchecked - oneOf(recordReader).readRecord(with(any(Predicate.class)), - with(any(Predicate.class))); + // Test that the `accept` predicate passed to the reader would + // accept the expected record + oneOf(recordReader).readRecord(with(new PredicateMatcher<>( + RecordPredicate.class, rp -> rp.test(record))), + with(any(RecordPredicate.class))); will(returnValue(record)); }}); } + private void expectReadEof() throws Exception { + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(with(any(RecordPredicate.class)), + with(any(RecordPredicate.class))); + will(returnValue(null)); + }}); + } + private Record createMessage(int payloadLength) { return new Record(PROTOCOL_VERSION, MESSAGE, new byte[payloadLength]); }