diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java b/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java
index cf45b4316..2ce40ca18 100644
--- a/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/CryptoComponent.java
@@ -54,6 +54,38 @@ public interface CryptoComponent {
KeyPair ourKeyPair, byte[]... inputs)
throws GeneralSecurityException;
+ /**
+ * Derives a shared secret from two static and two ephemeral key pairs.
+ *
+ * Do not use this method for new protocols. The shared secret can be
+ * re-derived using the ephemeral public keys and both static private
+ * keys, so keys derived from the shared secret should not be used if
+ * forward secrecy is required. Use {@link #deriveSharedSecret(String,
+ * PublicKey, PublicKey, KeyPair, KeyPair, boolean, byte[]...)} instead.
+ *
+ * TODO: Remove this after a reasonable migration period (added 2023-03-10).
+ *
+ *
+ * @param label A namespaced label indicating the purpose of this shared
+ * secret, to prevent it from being repurposed or colliding with a shared
+ * secret derived for another purpose
+ * @param theirStaticPublicKey The static public key of the remote party
+ * @param theirEphemeralPublicKey The ephemeral public key of the remote
+ * party
+ * @param ourStaticKeyPair The static key pair of the local party
+ * @param ourEphemeralKeyPair The ephemeral key pair of the local party
+ * @param alice True if the local party is Alice
+ * @param inputs Additional inputs that will be included in the
+ * derivation of the shared secret
+ * @return The shared secret
+ */
+ @Deprecated
+ SecretKey deriveSharedSecretBadly(String label,
+ PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey,
+ KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair,
+ boolean alice, byte[]... inputs)
+ throws GeneralSecurityException;
+
/**
* Derives a shared secret from two static and two ephemeral key pairs.
*
diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java b/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java
index ada8f1940..3076068c7 100644
--- a/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/record/RecordReader.java
@@ -32,8 +32,15 @@ public interface RecordReader {
* 'accept' or 'ignore' predicates
*/
@Nullable
- Record readRecord(Predicate accept, Predicate ignore)
+ Record readRecord(RecordPredicate accept, RecordPredicate ignore)
throws IOException;
void close() throws IOException;
+
+ /**
+ * Interface that reifies the generic interface {@code Predicate}
+ * for easier testing.
+ */
+ interface RecordPredicate extends Predicate {
+ }
}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java
index 5040cf959..2875ad566 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactExchangeManagerImpl.java
@@ -1,7 +1,6 @@
package org.briarproject.bramble.contact;
import org.briarproject.bramble.api.FormatException;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.client.ClientHelper;
import org.briarproject.bramble.api.contact.Contact;
import org.briarproject.bramble.api.contact.ContactExchangeManager;
@@ -24,6 +23,7 @@ import org.briarproject.bramble.api.properties.TransportProperties;
import org.briarproject.bramble.api.properties.TransportPropertyManager;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
@@ -61,12 +61,12 @@ class ContactExchangeManagerImpl implements ContactExchangeManager {
getLogger(ContactExchangeManagerImpl.class.getName());
// Accept records with current protocol version, known record type
- private static final Predicate ACCEPT = r ->
+ private static final RecordPredicate ACCEPT = r ->
r.getProtocolVersion() == PROTOCOL_VERSION &&
isKnownRecordType(r.getRecordType());
// Ignore records with current protocol version, unknown record type
- private static final Predicate IGNORE = r ->
+ private static final RecordPredicate IGNORE = r ->
r.getProtocolVersion() == PROTOCOL_VERSION &&
!isKnownRecordType(r.getRecordType());
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java
index 37e7a41d3..45470f5b8 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeConstants.java
@@ -5,14 +5,31 @@ import static org.briarproject.bramble.api.crypto.CryptoConstants.MAC_BYTES;
interface HandshakeConstants {
/**
- * The current version of the handshake protocol.
+ * The current major version of the handshake protocol.
*/
- byte PROTOCOL_VERSION = 0;
+ byte PROTOCOL_MAJOR_VERSION = 0;
/**
- * Label for deriving the master key.
+ * The current minor version of the handshake protocol.
*/
- String MASTER_KEY_LABEL = "org.briarproject.bramble.handshake/MASTER_KEY";
+ byte PROTOCOL_MINOR_VERSION = 1;
+
+ /**
+ * Label for deriving the master key when using the deprecated v0.0 key
+ * derivation method.
+ *
+ * TODO: Remove this after a reasonable migration period (added 2023-03-10).
+ */
+ @Deprecated
+ String MASTER_KEY_LABEL_0_0 =
+ "org.briarproject.bramble.handshake/MASTER_KEY";
+
+ /**
+ * Label for deriving the master key when using the v0.1 key derivation
+ * method.
+ */
+ String MASTER_KEY_LABEL_0_1 =
+ "org.briarproject.bramble.handshake/MASTER_KEY_0_1";
/**
* Label for deriving Alice's proof of ownership from the master key.
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java
index b7bf351ca..ce9a656eb 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCrypto.java
@@ -13,11 +13,26 @@ interface HandshakeCrypto {
KeyPair generateEphemeralKeyPair();
/**
- * Derives the master key from the given static and ephemeral keys.
+ * Derives the master key from the given static and ephemeral keys using
+ * the deprecated v0.0 key derivation method.
+ *
+ * TODO: Remove this after a reasonable migration period (added 2023-03-10).
*
* @param alice Whether the local peer is Alice
*/
- SecretKey deriveMasterKey(PublicKey theirStaticPublicKey,
+ @Deprecated
+ SecretKey deriveMasterKey_0_0(PublicKey theirStaticPublicKey,
+ PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair,
+ KeyPair ourEphemeralKeyPair, boolean alice)
+ throws GeneralSecurityException;
+
+ /**
+ * Derives the master key from the given static and ephemeral keys using
+ * the v0.1 key derivation method.
+ *
+ * @param alice Whether the local peer is Alice
+ */
+ SecretKey deriveMasterKey_0_1(PublicKey theirStaticPublicKey,
PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair,
KeyPair ourEphemeralKeyPair, boolean alice)
throws GeneralSecurityException;
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java
index f5ebb5e1d..7606ec938 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeCryptoImpl.java
@@ -13,7 +13,8 @@ import javax.inject.Inject;
import static org.briarproject.bramble.contact.HandshakeConstants.ALICE_PROOF_LABEL;
import static org.briarproject.bramble.contact.HandshakeConstants.BOB_PROOF_LABEL;
-import static org.briarproject.bramble.contact.HandshakeConstants.MASTER_KEY_LABEL;
+import static org.briarproject.bramble.contact.HandshakeConstants.MASTER_KEY_LABEL_0_0;
+import static org.briarproject.bramble.contact.HandshakeConstants.MASTER_KEY_LABEL_0_1;
@Immutable
@NotNullByDefault
@@ -32,7 +33,8 @@ class HandshakeCryptoImpl implements HandshakeCrypto {
}
@Override
- public SecretKey deriveMasterKey(PublicKey theirStaticPublicKey,
+ @Deprecated
+ public SecretKey deriveMasterKey_0_0(PublicKey theirStaticPublicKey,
PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair,
KeyPair ourEphemeralKeyPair, boolean alice) throws
GeneralSecurityException {
@@ -46,9 +48,29 @@ class HandshakeCryptoImpl implements HandshakeCrypto {
alice ? ourEphemeral : theirEphemeral,
alice ? theirEphemeral : ourEphemeral
};
- return crypto.deriveSharedSecret(MASTER_KEY_LABEL, theirStaticPublicKey,
- theirEphemeralPublicKey, ourStaticKeyPair, ourEphemeralKeyPair,
- alice, inputs);
+ return crypto.deriveSharedSecretBadly(MASTER_KEY_LABEL_0_0,
+ theirStaticPublicKey, theirEphemeralPublicKey,
+ ourStaticKeyPair, ourEphemeralKeyPair, alice, inputs);
+ }
+
+ @Override
+ public SecretKey deriveMasterKey_0_1(PublicKey theirStaticPublicKey,
+ PublicKey theirEphemeralPublicKey, KeyPair ourStaticKeyPair,
+ KeyPair ourEphemeralKeyPair, boolean alice) throws
+ GeneralSecurityException {
+ byte[] theirStatic = theirStaticPublicKey.getEncoded();
+ byte[] theirEphemeral = theirEphemeralPublicKey.getEncoded();
+ byte[] ourStatic = ourStaticKeyPair.getPublic().getEncoded();
+ byte[] ourEphemeral = ourEphemeralKeyPair.getPublic().getEncoded();
+ byte[][] inputs = {
+ alice ? ourStatic : theirStatic,
+ alice ? theirStatic : ourStatic,
+ alice ? ourEphemeral : theirEphemeral,
+ alice ? theirEphemeral : ourEphemeral
+ };
+ return crypto.deriveSharedSecret(MASTER_KEY_LABEL_0_1,
+ theirStaticPublicKey, theirEphemeralPublicKey,
+ ourStaticKeyPair, ourEphemeralKeyPair, alice, inputs);
}
@Override
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java
index acd1c80e7..496fdc482 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeManagerImpl.java
@@ -2,7 +2,6 @@ package org.briarproject.bramble.contact;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.Pair;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.contact.ContactManager;
import org.briarproject.bramble.api.contact.HandshakeManager;
import org.briarproject.bramble.api.contact.PendingContact;
@@ -12,12 +11,12 @@ import org.briarproject.bramble.api.crypto.KeyPair;
import org.briarproject.bramble.api.crypto.PublicKey;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
-import org.briarproject.bramble.api.db.DatabaseComponent;
import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.TransactionManager;
import org.briarproject.bramble.api.identity.IdentityManager;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
@@ -28,15 +27,20 @@ import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
+import java.util.List;
import javax.annotation.concurrent.Immutable;
import javax.inject.Inject;
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
import static org.briarproject.bramble.api.crypto.CryptoConstants.MAX_AGREEMENT_PUBLIC_KEY_BYTES;
import static org.briarproject.bramble.contact.HandshakeConstants.PROOF_BYTES;
-import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_VERSION;
-import static org.briarproject.bramble.contact.HandshakeRecordTypes.EPHEMERAL_PUBLIC_KEY;
-import static org.briarproject.bramble.contact.HandshakeRecordTypes.PROOF_OF_OWNERSHIP;
+import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MAJOR_VERSION;
+import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MINOR_VERSION;
+import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_EPHEMERAL_PUBLIC_KEY;
+import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_MINOR_VERSION;
+import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_PROOF_OF_OWNERSHIP;
import static org.briarproject.bramble.util.ValidationUtils.checkLength;
@Immutable
@@ -44,12 +48,14 @@ import static org.briarproject.bramble.util.ValidationUtils.checkLength;
class HandshakeManagerImpl implements HandshakeManager {
// Ignore records with current protocol version, unknown record type
- private static final Predicate IGNORE = r ->
- r.getProtocolVersion() == PROTOCOL_VERSION &&
+ private static final RecordPredicate IGNORE = r ->
+ r.getProtocolVersion() == PROTOCOL_MAJOR_VERSION &&
!isKnownRecordType(r.getRecordType());
private static boolean isKnownRecordType(byte type) {
- return type == EPHEMERAL_PUBLIC_KEY || type == PROOF_OF_OWNERSHIP;
+ return type == RECORD_TYPE_EPHEMERAL_PUBLIC_KEY ||
+ type == RECORD_TYPE_PROOF_OF_OWNERSHIP ||
+ type == RECORD_TYPE_MINOR_VERSION;
}
private final TransactionManager db;
@@ -61,7 +67,7 @@ class HandshakeManagerImpl implements HandshakeManager {
private final RecordWriterFactory recordWriterFactory;
@Inject
- HandshakeManagerImpl(DatabaseComponent db,
+ HandshakeManagerImpl(TransactionManager db,
IdentityManager identityManager,
ContactManager contactManager,
TransportCrypto transportCrypto,
@@ -95,19 +101,31 @@ class HandshakeManagerImpl implements HandshakeManager {
.createRecordWriter(out.getOutputStream());
KeyPair ourEphemeralKeyPair =
handshakeCrypto.generateEphemeralKeyPair();
- PublicKey theirEphemeralPublicKey;
+ Pair theirMinorVersionAndKey;
if (alice) {
+ sendMinorVersion(recordWriter);
sendPublicKey(recordWriter, ourEphemeralKeyPair.getPublic());
- theirEphemeralPublicKey = receivePublicKey(recordReader);
+ theirMinorVersionAndKey = receiveMinorVersionAndKey(recordReader);
} else {
- theirEphemeralPublicKey = receivePublicKey(recordReader);
+ theirMinorVersionAndKey = receiveMinorVersionAndKey(recordReader);
+ sendMinorVersion(recordWriter);
sendPublicKey(recordWriter, ourEphemeralKeyPair.getPublic());
}
+ byte theirMinorVersion = theirMinorVersionAndKey.getFirst();
+ PublicKey theirEphemeralPublicKey = theirMinorVersionAndKey.getSecond();
SecretKey masterKey;
try {
- masterKey = handshakeCrypto.deriveMasterKey(theirStaticPublicKey,
- theirEphemeralPublicKey, ourStaticKeyPair,
- ourEphemeralKeyPair, alice);
+ if (theirMinorVersion > 0) {
+ masterKey = handshakeCrypto.deriveMasterKey_0_1(
+ theirStaticPublicKey, theirEphemeralPublicKey,
+ ourStaticKeyPair, ourEphemeralKeyPair, alice);
+ } else {
+ // TODO: Remove this branch after a reasonable migration
+ // period (added 2023-03-10).
+ masterKey = handshakeCrypto.deriveMasterKey_0_0(
+ theirStaticPublicKey, theirEphemeralPublicKey,
+ ourStaticKeyPair, ourEphemeralKeyPair, alice);
+ }
} catch (GeneralSecurityException e) {
throw new FormatException();
}
@@ -128,34 +146,91 @@ class HandshakeManagerImpl implements HandshakeManager {
}
private void sendPublicKey(RecordWriter w, PublicKey k) throws IOException {
- w.writeRecord(new Record(PROTOCOL_VERSION, EPHEMERAL_PUBLIC_KEY,
- k.getEncoded()));
+ w.writeRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_EPHEMERAL_PUBLIC_KEY, k.getEncoded()));
w.flush();
}
- private PublicKey receivePublicKey(RecordReader r) throws IOException {
- byte[] key = readRecord(r, EPHEMERAL_PUBLIC_KEY).getPayload();
+ /**
+ * Receives the remote peer's protocol minor version and ephemeral public
+ * key.
+ *
+ * In version 0.1 of the protocol, each peer sends a minor version record
+ * followed by an ephemeral public key record.
+ *
+ * In version 0.0 of the protocol, each peer sends an ephemeral public key
+ * record without a preceding minor version record.
+ *
+ * Therefore the remote peer's minor version must be non-zero if a minor
+ * version record is received, and is assumed to be zero if no minor
+ * version record is received.
+ */
+ private Pair receiveMinorVersionAndKey(RecordReader r)
+ throws IOException {
+ byte theirMinorVersion;
+ PublicKey theirEphemeralPublicKey;
+ // The first record can be either a minor version record or an
+ // ephemeral public key record
+ Record first = readRecord(r, asList(RECORD_TYPE_MINOR_VERSION,
+ RECORD_TYPE_EPHEMERAL_PUBLIC_KEY));
+ if (first.getRecordType() == RECORD_TYPE_MINOR_VERSION) {
+ // The payload must be a single byte giving the remote peer's
+ // protocol minor version, which must be non-zero
+ byte[] payload = first.getPayload();
+ checkLength(payload, 1);
+ theirMinorVersion = payload[0];
+ if (theirMinorVersion == 0) throw new FormatException();
+ // The second record must be an ephemeral public key record
+ Record second = readRecord(r,
+ singletonList(RECORD_TYPE_EPHEMERAL_PUBLIC_KEY));
+ theirEphemeralPublicKey = parsePublicKey(second);
+ } else {
+ // The remote peer did not send a minor version record, so the
+ // remote peer's protocol minor version is assumed to be zero
+ // TODO: Remove this branch after a reasonable migration period
+ // (added 2023-03-10).
+ theirMinorVersion = 0;
+ theirEphemeralPublicKey = parsePublicKey(first);
+ }
+ return new Pair<>(theirMinorVersion, theirEphemeralPublicKey);
+ }
+
+ private PublicKey parsePublicKey(Record rec) throws IOException {
+ if (rec.getRecordType() != RECORD_TYPE_EPHEMERAL_PUBLIC_KEY) {
+ throw new AssertionError();
+ }
+ byte[] key = rec.getPayload();
checkLength(key, 1, MAX_AGREEMENT_PUBLIC_KEY_BYTES);
return new AgreementPublicKey(key);
}
private void sendProof(RecordWriter w, byte[] proof) throws IOException {
- w.writeRecord(new Record(PROTOCOL_VERSION, PROOF_OF_OWNERSHIP, proof));
+ w.writeRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_PROOF_OF_OWNERSHIP, proof));
w.flush();
}
private byte[] receiveProof(RecordReader r) throws IOException {
- byte[] proof = readRecord(r, PROOF_OF_OWNERSHIP).getPayload();
+ Record rec = readRecord(r,
+ singletonList(RECORD_TYPE_PROOF_OF_OWNERSHIP));
+ byte[] proof = rec.getPayload();
checkLength(proof, PROOF_BYTES, PROOF_BYTES);
return proof;
}
- private Record readRecord(RecordReader r, byte expectedType)
+ private void sendMinorVersion(RecordWriter w) throws IOException {
+ w.writeRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_MINOR_VERSION,
+ new byte[] {PROTOCOL_MINOR_VERSION}));
+ w.flush();
+ }
+
+ private Record readRecord(RecordReader r, List expectedTypes)
throws IOException {
- // Accept records with current protocol version, expected type only
- Predicate accept = rec ->
- rec.getProtocolVersion() == PROTOCOL_VERSION &&
- rec.getRecordType() == expectedType;
+ // Accept records with current protocol version, expected types only
+ RecordPredicate accept = rec ->
+ rec.getProtocolVersion() == PROTOCOL_MAJOR_VERSION &&
+ expectedTypes.contains(rec.getRecordType());
Record rec = r.readRecord(accept, IGNORE);
if (rec == null) throw new EOFException();
return rec;
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java
index 7c38d9a84..9875c1a25 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/HandshakeRecordTypes.java
@@ -5,7 +5,9 @@ package org.briarproject.bramble.contact;
*/
interface HandshakeRecordTypes {
- byte EPHEMERAL_PUBLIC_KEY = 0;
+ byte RECORD_TYPE_EPHEMERAL_PUBLIC_KEY = 0;
- byte PROOF_OF_OWNERSHIP = 1;
+ byte RECORD_TYPE_PROOF_OF_OWNERSHIP = 1;
+
+ byte RECORD_TYPE_MINOR_VERSION = 2;
}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java
index 06ce15f29..1bb7fe9f8 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/crypto/CryptoComponentImpl.java
@@ -222,7 +222,8 @@ class CryptoComponentImpl implements CryptoComponent {
}
@Override
- public SecretKey deriveSharedSecret(String label,
+ @Deprecated
+ public SecretKey deriveSharedSecretBadly(String label,
PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey,
KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair,
boolean alice, byte[]... inputs) throws GeneralSecurityException {
@@ -250,6 +251,35 @@ class CryptoComponentImpl implements CryptoComponent {
return new SecretKey(hash);
}
+ @Override
+ public SecretKey deriveSharedSecret(String label,
+ PublicKey theirStaticPublicKey, PublicKey theirEphemeralPublicKey,
+ KeyPair ourStaticKeyPair, KeyPair ourEphemeralKeyPair,
+ boolean alice, byte[]... inputs) throws GeneralSecurityException {
+ PrivateKey ourStaticPrivateKey = ourStaticKeyPair.getPrivate();
+ PrivateKey ourEphemeralPrivateKey = ourEphemeralKeyPair.getPrivate();
+ byte[][] hashInputs = new byte[inputs.length + 3][];
+ // Alice ephemeral/Bob ephemeral
+ hashInputs[0] = performRawKeyAgreement(ourEphemeralPrivateKey,
+ theirEphemeralPublicKey);
+ // Alice static/Bob ephemeral, Bob static/Alice ephemeral
+ if (alice) {
+ hashInputs[1] = performRawKeyAgreement(ourStaticPrivateKey,
+ theirEphemeralPublicKey);
+ hashInputs[2] = performRawKeyAgreement(ourEphemeralPrivateKey,
+ theirStaticPublicKey);
+ } else {
+ hashInputs[1] = performRawKeyAgreement(ourEphemeralPrivateKey,
+ theirStaticPublicKey);
+ hashInputs[2] = performRawKeyAgreement(ourStaticPrivateKey,
+ theirEphemeralPublicKey);
+ }
+ arraycopy(inputs, 0, hashInputs, 3, inputs.length);
+ byte[] hash = hash(label, hashInputs);
+ if (hash.length != SecretKey.LENGTH) throw new IllegalStateException();
+ return new SecretKey(hash);
+ }
+
@Override
public byte[] sign(String label, byte[] toSign, PrivateKey privateKey)
throws GeneralSecurityException {
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java
index e5e1e2ca3..bd09ceb09 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java
@@ -1,11 +1,11 @@
package org.briarproject.bramble.keyagreement;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
@@ -34,12 +34,12 @@ class KeyAgreementTransport {
Logger.getLogger(KeyAgreementTransport.class.getName());
// Accept records with current protocol version, known record type
- private static final Predicate ACCEPT = r ->
+ private static final RecordPredicate ACCEPT = r ->
r.getProtocolVersion() == PROTOCOL_VERSION &&
isKnownRecordType(r.getRecordType());
// Ignore records with current protocol version, unknown record type
- private static final Predicate IGNORE = r ->
+ private static final RecordPredicate IGNORE = r ->
r.getProtocolVersion() == PROTOCOL_VERSION &&
!isKnownRecordType(r.getRecordType());
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java
index fb6abb151..872c937fe 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/record/RecordReaderImpl.java
@@ -1,7 +1,6 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.FormatException;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.util.ByteUtils;
@@ -45,7 +44,7 @@ class RecordReaderImpl implements RecordReader {
@Nullable
@Override
- public Record readRecord(Predicate accept, Predicate ignore)
+ public Record readRecord(RecordPredicate accept, RecordPredicate ignore)
throws IOException {
while (true) {
if (eof()) return null;
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java
index 250f72953..ee247e4b6 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/sync/SyncRecordReaderImpl.java
@@ -1,10 +1,10 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageFactory;
@@ -41,12 +41,12 @@ import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
class SyncRecordReaderImpl implements SyncRecordReader {
// Accept records with current protocol version, known record type
- private static final Predicate ACCEPT = r ->
+ private static final RecordPredicate ACCEPT = r ->
r.getProtocolVersion() == PROTOCOL_VERSION &&
isKnownRecordType(r.getRecordType());
// Ignore records with current protocol version, unknown record type
- private static final Predicate IGNORE = r ->
+ private static final RecordPredicate IGNORE = r ->
r.getProtocolVersion() == PROTOCOL_VERSION &&
!isKnownRecordType(r.getRecordType());
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/contact/HandshakeManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/contact/HandshakeManagerImplTest.java
new file mode 100644
index 000000000..9597d486b
--- /dev/null
+++ b/bramble-core/src/test/java/org/briarproject/bramble/contact/HandshakeManagerImplTest.java
@@ -0,0 +1,316 @@
+package org.briarproject.bramble.contact;
+
+import org.briarproject.bramble.api.FormatException;
+import org.briarproject.bramble.api.contact.ContactManager;
+import org.briarproject.bramble.api.contact.HandshakeManager.HandshakeResult;
+import org.briarproject.bramble.api.contact.PendingContact;
+import org.briarproject.bramble.api.crypto.KeyPair;
+import org.briarproject.bramble.api.crypto.PrivateKey;
+import org.briarproject.bramble.api.crypto.PublicKey;
+import org.briarproject.bramble.api.crypto.SecretKey;
+import org.briarproject.bramble.api.crypto.TransportCrypto;
+import org.briarproject.bramble.api.db.Transaction;
+import org.briarproject.bramble.api.db.TransactionManager;
+import org.briarproject.bramble.api.identity.IdentityManager;
+import org.briarproject.bramble.api.record.Record;
+import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
+import org.briarproject.bramble.api.record.RecordReaderFactory;
+import org.briarproject.bramble.api.record.RecordWriter;
+import org.briarproject.bramble.api.record.RecordWriterFactory;
+import org.briarproject.bramble.api.transport.StreamWriter;
+import org.briarproject.bramble.test.BrambleMockTestCase;
+import org.briarproject.bramble.test.DbExpectations;
+import org.briarproject.bramble.test.PredicateMatcher;
+import org.jmock.Expectations;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Arrays;
+
+import static org.briarproject.bramble.contact.HandshakeConstants.PROOF_BYTES;
+import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MAJOR_VERSION;
+import static org.briarproject.bramble.contact.HandshakeConstants.PROTOCOL_MINOR_VERSION;
+import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_EPHEMERAL_PUBLIC_KEY;
+import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_MINOR_VERSION;
+import static org.briarproject.bramble.contact.HandshakeRecordTypes.RECORD_TYPE_PROOF_OF_OWNERSHIP;
+import static org.briarproject.bramble.test.TestUtils.getAgreementPrivateKey;
+import static org.briarproject.bramble.test.TestUtils.getAgreementPublicKey;
+import static org.briarproject.bramble.test.TestUtils.getPendingContact;
+import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
+import static org.briarproject.bramble.test.TestUtils.getSecretKey;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+public class HandshakeManagerImplTest extends BrambleMockTestCase {
+
+ private final TransactionManager db =
+ context.mock(TransactionManager.class);
+ private final IdentityManager identityManager =
+ context.mock(IdentityManager.class);
+ private final ContactManager contactManager =
+ context.mock(ContactManager.class);
+ private final TransportCrypto transportCrypto =
+ context.mock(TransportCrypto.class);
+ private final HandshakeCrypto handshakeCrypto =
+ context.mock(HandshakeCrypto.class);
+ private final RecordReaderFactory recordReaderFactory =
+ context.mock(RecordReaderFactory.class);
+ private final RecordWriterFactory recordWriterFactory =
+ context.mock(RecordWriterFactory.class);
+ private final RecordReader recordReader = context.mock(RecordReader.class);
+ private final RecordWriter recordWriter = context.mock(RecordWriter.class);
+ private final StreamWriter streamWriter = context.mock(StreamWriter.class);
+
+ private final PendingContact pendingContact = getPendingContact();
+ private final PublicKey theirStaticPublicKey =
+ pendingContact.getPublicKey();
+ private final PublicKey ourStaticPublicKey = getAgreementPublicKey();
+ private final PrivateKey ourStaticPrivateKey = getAgreementPrivateKey();
+ private final KeyPair ourStaticKeyPair =
+ new KeyPair(ourStaticPublicKey, ourStaticPrivateKey);
+ private final PublicKey theirEphemeralPublicKey = getAgreementPublicKey();
+ private final PublicKey ourEphemeralPublicKey = getAgreementPublicKey();
+ private final PrivateKey ourEphemeralPrivateKey = getAgreementPrivateKey();
+ private final KeyPair ourEphemeralKeyPair =
+ new KeyPair(ourEphemeralPublicKey, ourEphemeralPrivateKey);
+ private final SecretKey masterKey = getSecretKey();
+ private final byte[] ourProof = getRandomBytes(PROOF_BYTES);
+ private final byte[] theirProof = getRandomBytes(PROOF_BYTES);
+
+ private final InputStream in = new ByteArrayInputStream(new byte[0]);
+ private final OutputStream out = new ByteArrayOutputStream(0);
+
+ private final HandshakeManagerImpl handshakeManager =
+ new HandshakeManagerImpl(db, identityManager, contactManager,
+ transportCrypto, handshakeCrypto, recordReaderFactory,
+ recordWriterFactory);
+
+ @Test
+ public void testHandshakeAsAliceWithPeerVersion_0_1() throws Exception {
+ testHandshakeWithPeerVersion_0_1(true);
+ }
+
+ @Test
+ public void testHandshakeAsBobWithPeerVersion_0_1() throws Exception {
+ testHandshakeWithPeerVersion_0_1(false);
+ }
+
+ private void testHandshakeWithPeerVersion_0_1(boolean alice)
+ throws Exception {
+ expectPrepareForHandshake(alice);
+ expectSendMinorVersion();
+ expectSendKey();
+ // Remote peer sends minor version, so use new key derivation
+ expectReceiveMinorVersion();
+ expectReceiveKey();
+ expectDeriveMasterKey_0_1(alice);
+ expectDeriveProof(alice);
+ expectSendProof();
+ expectReceiveProof();
+ expectSendEof();
+ expectReceiveEof();
+ expectVerifyOwnership(alice, true);
+
+ HandshakeResult result = handshakeManager.handshake(
+ pendingContact.getId(), in, streamWriter);
+
+ assertArrayEquals(masterKey.getBytes(),
+ result.getMasterKey().getBytes());
+ assertEquals(alice, result.isAlice());
+ }
+
+ @Test
+ public void testHandshakeAsAliceWithPeerVersion_0_0() throws Exception {
+ testHandshakeWithPeerVersion_0_0(true);
+ }
+
+ @Test
+ public void testHandshakeAsBobWithPeerVersion_0_0() throws Exception {
+ testHandshakeWithPeerVersion_0_0(false);
+ }
+
+ private void testHandshakeWithPeerVersion_0_0(boolean alice)
+ throws Exception {
+ expectPrepareForHandshake(alice);
+ expectSendMinorVersion();
+ expectSendKey();
+ // Remote peer does not send minor version, so use old key derivation
+ expectReceiveKey();
+ expectDeriveMasterKey_0_0(alice);
+ expectDeriveProof(alice);
+ expectSendProof();
+ expectReceiveProof();
+ expectSendEof();
+ expectReceiveEof();
+ expectVerifyOwnership(alice, true);
+
+ HandshakeResult result = handshakeManager.handshake(
+ pendingContact.getId(), in, streamWriter);
+
+ assertArrayEquals(masterKey.getBytes(),
+ result.getMasterKey().getBytes());
+ assertEquals(alice, result.isAlice());
+ }
+
+ @Test(expected = FormatException.class)
+ public void testProofOfOwnershipNotVerifiedAsAlice() throws Exception {
+ testProofOfOwnershipNotVerified(true);
+ }
+
+ @Test(expected = FormatException.class)
+ public void testProofOfOwnershipNotVerifiedAsBob() throws Exception {
+ testProofOfOwnershipNotVerified(false);
+ }
+
+ private void testProofOfOwnershipNotVerified(boolean alice)
+ throws Exception {
+ expectPrepareForHandshake(alice);
+ expectSendMinorVersion();
+ expectSendKey();
+ expectReceiveMinorVersion();
+ expectReceiveKey();
+ expectDeriveMasterKey_0_1(alice);
+ expectDeriveProof(alice);
+ expectSendProof();
+ expectReceiveProof();
+ expectSendEof();
+ expectReceiveEof();
+ expectVerifyOwnership(alice, false);
+
+ handshakeManager.handshake(pendingContact.getId(), in, streamWriter);
+ }
+
+ private void expectPrepareForHandshake(boolean alice) throws Exception {
+ Transaction txn = new Transaction(null, true);
+
+ context.checking(new DbExpectations() {{
+ oneOf(db).transactionWithResult(with(true), withDbCallable(txn));
+ oneOf(contactManager).getPendingContact(txn,
+ pendingContact.getId());
+ will(returnValue(pendingContact));
+ oneOf(identityManager).getHandshakeKeys(txn);
+ will(returnValue(ourStaticKeyPair));
+ oneOf(transportCrypto).isAlice(theirStaticPublicKey,
+ ourStaticKeyPair);
+ will(returnValue(alice));
+ oneOf(recordReaderFactory).createRecordReader(in);
+ will(returnValue(recordReader));
+ oneOf(streamWriter).getOutputStream();
+ will(returnValue(out));
+ oneOf(recordWriterFactory).createRecordWriter(out);
+ will(returnValue(recordWriter));
+ oneOf(handshakeCrypto).generateEphemeralKeyPair();
+ will(returnValue(ourEphemeralKeyPair));
+ }});
+ }
+
+ private void expectSendMinorVersion() throws Exception {
+ expectWriteRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_MINOR_VERSION,
+ new byte[] {PROTOCOL_MINOR_VERSION}));
+ }
+
+ private void expectReceiveMinorVersion() throws Exception {
+ expectReadRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_MINOR_VERSION,
+ new byte[] {PROTOCOL_MINOR_VERSION}));
+ }
+
+ private void expectSendKey() throws Exception {
+ expectWriteRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_EPHEMERAL_PUBLIC_KEY,
+ ourEphemeralPublicKey.getEncoded()));
+ }
+
+ private void expectReceiveKey() throws Exception {
+ expectReadRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_EPHEMERAL_PUBLIC_KEY,
+ theirEphemeralPublicKey.getEncoded()));
+ }
+
+ private void expectDeriveMasterKey_0_1(boolean alice) throws Exception {
+ context.checking(new Expectations() {{
+ oneOf(handshakeCrypto).deriveMasterKey_0_1(theirStaticPublicKey,
+ theirEphemeralPublicKey, ourStaticKeyPair,
+ ourEphemeralKeyPair, alice);
+ will(returnValue(masterKey));
+ }});
+ }
+
+ private void expectDeriveMasterKey_0_0(boolean alice) throws Exception {
+ context.checking(new Expectations() {{
+ oneOf(handshakeCrypto).deriveMasterKey_0_0(theirStaticPublicKey,
+ theirEphemeralPublicKey, ourStaticKeyPair,
+ ourEphemeralKeyPair, alice);
+ will(returnValue(masterKey));
+ }});
+ }
+
+ private void expectDeriveProof(boolean alice) {
+ context.checking(new Expectations() {{
+ oneOf(handshakeCrypto).proveOwnership(masterKey, alice);
+ will(returnValue(ourProof));
+ }});
+ }
+
+ private void expectSendProof() throws Exception {
+ expectWriteRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_PROOF_OF_OWNERSHIP, ourProof));
+ }
+
+ private void expectReceiveProof() throws Exception {
+ expectReadRecord(new Record(PROTOCOL_MAJOR_VERSION,
+ RECORD_TYPE_PROOF_OF_OWNERSHIP, theirProof));
+ }
+
+ private void expectSendEof() throws Exception {
+ context.checking(new Expectations() {{
+ oneOf(streamWriter).sendEndOfStream();
+ }});
+ }
+
+ private void expectReceiveEof() throws Exception {
+ context.checking(new Expectations() {{
+ oneOf(recordReader).readRecord(with(any(RecordPredicate.class)),
+ with(any(RecordPredicate.class)));
+ will(returnValue(null));
+ }});
+ }
+
+ private void expectVerifyOwnership(boolean alice, boolean verified) {
+ context.checking(new Expectations() {{
+ oneOf(handshakeCrypto).verifyOwnership(masterKey, !alice,
+ theirProof);
+ will(returnValue(verified));
+ }});
+ }
+
+ private void expectWriteRecord(Record record) throws Exception {
+ context.checking(new Expectations() {{
+ oneOf(recordWriter).writeRecord(with(new PredicateMatcher<>(
+ Record.class, r -> recordEquals(record, r))));
+ oneOf(recordWriter).flush();
+ }});
+ }
+
+ private boolean recordEquals(Record expected, Record actual) {
+ return expected.getProtocolVersion() == actual.getProtocolVersion() &&
+ expected.getRecordType() == actual.getRecordType() &&
+ Arrays.equals(expected.getPayload(), actual.getPayload());
+ }
+
+ private void expectReadRecord(Record record) throws Exception {
+ context.checking(new Expectations() {{
+ // Test that the `accept` predicate passed to the reader would
+ // accept the expected record
+ oneOf(recordReader).readRecord(with(new PredicateMatcher<>(
+ RecordPredicate.class, rp -> rp.test(record))),
+ with(any(RecordPredicate.class)));
+ will(returnValue(record));
+ }});
+ }
+}
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java
index 5612663ae..e179ac2b2 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java
@@ -60,6 +60,22 @@ public class KeyAgreementTest extends BrambleTestCase {
assertArrayEquals(aShared.getBytes(), bShared.getBytes());
}
+ @Test
+ public void testDerivesStaticEphemeralSharedSecretBadly() throws Exception {
+ String label = getRandomString(123);
+ KeyPair aStatic = crypto.generateAgreementKeyPair();
+ KeyPair aEphemeral = crypto.generateAgreementKeyPair();
+ KeyPair bStatic = crypto.generateAgreementKeyPair();
+ KeyPair bEphemeral = crypto.generateAgreementKeyPair();
+ SecretKey aShared = crypto.deriveSharedSecretBadly(label,
+ bStatic.getPublic(), bEphemeral.getPublic(), aStatic,
+ aEphemeral, true, inputs);
+ SecretKey bShared = crypto.deriveSharedSecretBadly(label,
+ aStatic.getPublic(), aEphemeral.getPublic(), bStatic,
+ bEphemeral, false, inputs);
+ assertArrayEquals(aShared.getBytes(), bShared.getBytes());
+ }
+
@Test
public void testDerivesStaticEphemeralSharedSecret() throws Exception {
String label = getRandomString(123);
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java
index d75bf40f4..b1708e859 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java
@@ -1,6 +1,5 @@
package org.briarproject.bramble.keyagreement;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
import org.briarproject.bramble.api.plugin.TransportConnectionReader;
import org.briarproject.bramble.api.plugin.TransportConnectionWriter;
@@ -8,11 +7,13 @@ import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.CaptureArgumentAction;
+import org.briarproject.bramble.test.PredicateMatcher;
import org.jmock.Expectations;
import org.jmock.imposters.ByteBuddyClassImposteriser;
import org.junit.Test;
@@ -21,8 +22,6 @@ import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicReference;
-import javax.annotation.Nullable;
-
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
@@ -119,7 +118,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
public void testReceiveKeyThrowsExceptionIfAtEndOfStream()
throws Exception {
setup();
- expectReadRecord(null);
+ expectReadEof();
kat.receiveKey();
}
@@ -148,7 +147,7 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
public void testReceiveConfirmThrowsExceptionIfAtEndOfStream()
throws Exception {
setup();
- expectReadRecord(null);
+ expectReadEof();
kat.receiveConfirm();
}
@@ -209,12 +208,22 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
assertArrayEquals(expectedPayload, actual.getPayload());
}
- private void expectReadRecord(@Nullable Record record) throws Exception {
+ private void expectReadRecord(Record record) throws Exception {
context.checking(new Expectations() {{
- //noinspection unchecked
- oneOf(recordReader).readRecord(with(any(Predicate.class)),
- with(any(Predicate.class)));
+ // Test that the `accept` predicate passed to the reader would
+ // accept the expected record
+ oneOf(recordReader).readRecord(with(new PredicateMatcher<>(
+ RecordPredicate.class, rp -> rp.test(record))),
+ with(any(RecordPredicate.class)));
will(returnValue(record));
}});
}
+
+ private void expectReadEof() throws Exception {
+ context.checking(new Expectations() {{
+ oneOf(recordReader).readRecord(with(any(RecordPredicate.class)),
+ with(any(RecordPredicate.class)));
+ will(returnValue(null));
+ }});
+ }
}
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java
index c2cc4c55f..ed6f412ca 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java
@@ -1,9 +1,9 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.FormatException;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.test.BrambleTestCase;
import org.briarproject.bramble.util.ByteUtils;
import org.junit.Test;
@@ -128,12 +128,12 @@ public class RecordReaderImplTest extends BrambleTestCase {
RecordReader reader = new RecordReaderImpl(in);
// Accept records with version 0, type 0 or 1
- Predicate accept = r -> {
+ RecordPredicate accept = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && (type == 0 || type == 1);
};
// Ignore records with version 0, any other type
- Predicate ignore = r -> {
+ RecordPredicate ignore = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && !(type == 0 || type == 1);
};
@@ -183,12 +183,12 @@ public class RecordReaderImplTest extends BrambleTestCase {
RecordReader reader = new RecordReaderImpl(in);
// Accept records with version 0, type 0 or 1
- Predicate accept = r -> {
+ RecordPredicate accept = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && (type == 0 || type == 1);
};
// Ignore records with version 0, any other type
- Predicate ignore = r -> {
+ RecordPredicate ignore = r -> {
byte version = r.getProtocolVersion(), type = r.getRecordType();
return version == 0 && !(type == 0 || type == 1);
};
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java
index d4eb61ff5..d9400cbd5 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/sync/SyncRecordReaderImplTest.java
@@ -1,10 +1,10 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException;
-import org.briarproject.bramble.api.Predicate;
import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
+import org.briarproject.bramble.api.record.RecordReader.RecordPredicate;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.GroupId;
import org.briarproject.bramble.api.sync.Message;
@@ -24,8 +24,6 @@ import org.junit.Test;
import java.io.ByteArrayOutputStream;
import java.util.List;
-import javax.annotation.Nullable;
-
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
import static org.briarproject.bramble.api.sync.RecordTypes.MESSAGE;
@@ -186,7 +184,7 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
@Test
public void testEofReturnsTrueWhenAtEndOfStream() throws Exception {
expectReadRecord(createAck());
- expectReadRecord(null);
+ expectReadEof();
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
@@ -212,15 +210,25 @@ public class SyncRecordReaderImplTest extends BrambleMockTestCase {
}});
}
- private void expectReadRecord(@Nullable Record record) throws Exception {
+ private void expectReadRecord(Record record) throws Exception {
context.checking(new Expectations() {{
- //noinspection unchecked
- oneOf(recordReader).readRecord(with(any(Predicate.class)),
- with(any(Predicate.class)));
+ // Test that the `accept` predicate passed to the reader would
+ // accept the expected record
+ oneOf(recordReader).readRecord(with(new PredicateMatcher<>(
+ RecordPredicate.class, rp -> rp.test(record))),
+ with(any(RecordPredicate.class)));
will(returnValue(record));
}});
}
+ private void expectReadEof() throws Exception {
+ context.checking(new Expectations() {{
+ oneOf(recordReader).readRecord(with(any(RecordPredicate.class)),
+ with(any(RecordPredicate.class)));
+ will(returnValue(null));
+ }});
+ }
+
private Record createMessage(int payloadLength) {
return new Record(PROTOCOL_VERSION, MESSAGE, new byte[payloadLength]);
}