Factor out transport crypto from CryptoComponent.

This commit is contained in:
akwizgran
2017-11-27 16:24:58 +00:00
parent 9f7021acd3
commit 1843aea2a7
12 changed files with 356 additions and 377 deletions

View File

@@ -1,8 +1,5 @@
package org.briarproject.bramble.api.crypto; package org.briarproject.bramble.api.crypto;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.SecureRandom; import java.security.SecureRandom;
@@ -89,33 +86,6 @@ public interface CryptoComponent {
PublicKey theirPublicKey, KeyPair ourKeyPair, PublicKey theirPublicKey, KeyPair ourKeyPair,
boolean alice, boolean aliceRecord); boolean alice, boolean aliceRecord);
/**
* Derives initial transport keys for the given transport in the given
* rotation period from the given master secret.
* <p/>
* Used by the transport security protocol.
*
* @param alice whether the keys are for use by Alice or Bob.
*/
TransportKeys deriveTransportKeys(TransportId t, SecretKey master,
long rotationPeriod, boolean alice);
/**
* Rotates the given transport keys to the given rotation period. If the
* keys are for a future rotation period they are not rotated.
* <p/>
* Used by the transport security protocol.
*/
TransportKeys rotateTransportKeys(TransportKeys k, long rotationPeriod);
/**
* Encodes the pseudo-random tag that is used to recognise a stream.
* <p/>
* Used by the transport security protocol.
*/
void encodeTag(byte[] tag, SecretKey tagKey, int protocolVersion,
long streamNumber);
/** /**
* Signs the given byte[] with the given ECDSA private key. * Signs the given byte[] with the given ECDSA private key.
* *

View File

@@ -0,0 +1,34 @@
package org.briarproject.bramble.api.crypto;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.transport.TransportKeys;
public interface TransportCrypto {
/**
* Derives initial transport keys for the given transport in the given
* rotation period from the given master secret.
* <p/>
* Used by the transport security protocol.
*
* @param alice whether the keys are for use by Alice or Bob.
*/
TransportKeys deriveTransportKeys(TransportId t, SecretKey master,
long rotationPeriod, boolean alice);
/**
* Rotates the given transport keys to the given rotation period. If the
* keys are for a future rotation period they are not rotated.
* <p/>
* Used by the transport security protocol.
*/
TransportKeys rotateTransportKeys(TransportKeys k, long rotationPeriod);
/**
* Encodes the pseudo-random tag that is used to recognise a stream.
* <p/>
* Used by the transport security protocol.
*/
void encodeTag(byte[] tag, SecretKey tagKey, int protocolVersion,
long streamNumber);
}

View File

@@ -10,11 +10,7 @@ import org.briarproject.bramble.api.crypto.KeyParser;
import org.briarproject.bramble.api.crypto.PrivateKey; import org.briarproject.bramble.api.crypto.PrivateKey;
import org.briarproject.bramble.api.crypto.PublicKey; import org.briarproject.bramble.api.crypto.PublicKey;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.SecureRandomProvider; import org.briarproject.bramble.api.system.SecureRandomProvider;
import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.util.ByteUtils; import org.briarproject.bramble.util.ByteUtils;
import org.briarproject.bramble.util.StringUtils; import org.briarproject.bramble.util.StringUtils;
import org.spongycastle.crypto.AsymmetricCipherKeyPair; import org.spongycastle.crypto.AsymmetricCipherKeyPair;
@@ -44,13 +40,8 @@ import javax.inject.Inject;
import static java.util.logging.Level.INFO; import static java.util.logging.Level.INFO;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.COMMIT_LENGTH; import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.COMMIT_LENGTH;
import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.bramble.crypto.EllipticCurveConstants.PARAMETERS; import static org.briarproject.bramble.crypto.EllipticCurveConstants.PARAMETERS;
import static org.briarproject.bramble.util.ByteUtils.INT_16_BYTES;
import static org.briarproject.bramble.util.ByteUtils.INT_32_BYTES; import static org.briarproject.bramble.util.ByteUtils.INT_32_BYTES;
import static org.briarproject.bramble.util.ByteUtils.INT_64_BYTES;
import static org.briarproject.bramble.util.ByteUtils.MAX_16_BIT_UNSIGNED;
import static org.briarproject.bramble.util.ByteUtils.MAX_32_BIT_UNSIGNED;
class CryptoComponentImpl implements CryptoComponent { class CryptoComponentImpl implements CryptoComponent {
@@ -75,15 +66,6 @@ class CryptoComponentImpl implements CryptoComponent {
private static final String CONFIRMATION_MAC_LABEL = private static final String CONFIRMATION_MAC_LABEL =
"org.briarproject.bramble.keyagreement/CONFIRMATION_MAC"; "org.briarproject.bramble.keyagreement/CONFIRMATION_MAC";
// KDF labels for tag key derivation
private static final String A_TAG = "ALICE_TAG_KEY";
private static final String B_TAG = "BOB_TAG_KEY";
// KDF labels for header key derivation
private static final String A_HEADER = "ALICE_HEADER_KEY";
private static final String B_HEADER = "BOB_HEADER_KEY";
// KDF label for key rotation
private static final String ROTATE = "ROTATE";
private final SecureRandom secureRandom; private final SecureRandom secureRandom;
private final ECKeyPairGenerator agreementKeyPairGenerator; private final ECKeyPairGenerator agreementKeyPairGenerator;
private final ECKeyPairGenerator signatureKeyPairGenerator; private final ECKeyPairGenerator signatureKeyPairGenerator;
@@ -309,104 +291,6 @@ class CryptoComponentImpl implements CryptoComponent {
} }
} }
@Override
public TransportKeys deriveTransportKeys(TransportId t,
SecretKey master, long rotationPeriod, boolean alice) {
// Keys for the previous period are derived from the master secret
SecretKey inTagPrev = deriveTagKey(master, t, !alice);
SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice);
SecretKey outTagPrev = deriveTagKey(master, t, alice);
SecretKey outHeaderPrev = deriveHeaderKey(master, t, alice);
// Derive the keys for the current and next periods
SecretKey inTagCurr = rotateKey(inTagPrev, rotationPeriod);
SecretKey inHeaderCurr = rotateKey(inHeaderPrev, rotationPeriod);
SecretKey inTagNext = rotateKey(inTagCurr, rotationPeriod + 1);
SecretKey inHeaderNext = rotateKey(inHeaderCurr, rotationPeriod + 1);
SecretKey outTagCurr = rotateKey(outTagPrev, rotationPeriod);
SecretKey outHeaderCurr = rotateKey(outHeaderPrev, rotationPeriod);
// Initialise the reordering windows and stream counters
IncomingKeys inPrev = new IncomingKeys(inTagPrev, inHeaderPrev,
rotationPeriod - 1);
IncomingKeys inCurr = new IncomingKeys(inTagCurr, inHeaderCurr,
rotationPeriod);
IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext,
rotationPeriod + 1);
OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr,
rotationPeriod);
// Collect and return the keys
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr);
}
@Override
public TransportKeys rotateTransportKeys(TransportKeys k,
long rotationPeriod) {
if (k.getRotationPeriod() >= rotationPeriod) return k;
IncomingKeys inPrev = k.getPreviousIncomingKeys();
IncomingKeys inCurr = k.getCurrentIncomingKeys();
IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getRotationPeriod();
// Rotate the keys
for (long p = startPeriod + 1; p <= rotationPeriod; p++) {
inPrev = inCurr;
inCurr = inNext;
SecretKey inNextTag = rotateKey(inNext.getTagKey(), p + 1);
SecretKey inNextHeader = rotateKey(inNext.getHeaderKey(), p + 1);
inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p);
}
// Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,
outCurr);
}
private SecretKey rotateKey(SecretKey k, long rotationPeriod) {
byte[] period = new byte[INT_64_BYTES];
ByteUtils.writeUint64(rotationPeriod, period, 0);
return deriveKey(ROTATE, k, period);
}
private SecretKey deriveTagKey(SecretKey master, TransportId t,
boolean alice) {
byte[] id = StringUtils.toUtf8(t.getString());
return deriveKey(alice ? A_TAG : B_TAG, master, id);
}
private SecretKey deriveHeaderKey(SecretKey master, TransportId t,
boolean alice) {
byte[] id = StringUtils.toUtf8(t.getString());
return deriveKey(alice ? A_HEADER : B_HEADER, master, id);
}
@Override
public void encodeTag(byte[] tag, SecretKey tagKey, int protocolVersion,
long streamNumber) {
if (tag.length < TAG_LENGTH) throw new IllegalArgumentException();
if (protocolVersion < 0 || protocolVersion > MAX_16_BIT_UNSIGNED)
throw new IllegalArgumentException();
if (streamNumber < 0 || streamNumber > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
// Initialise the PRF
Digest prf = new Blake2sDigest(tagKey.getBytes());
// The output of the PRF must be long enough to use as a tag
int macLength = prf.getDigestSize();
if (macLength < TAG_LENGTH) throw new IllegalStateException();
// The input is the protocol version as a 16-bit integer, followed by
// the stream number as a 64-bit integer
byte[] protocolVersionBytes = new byte[INT_16_BYTES];
ByteUtils.writeUint16(protocolVersion, protocolVersionBytes, 0);
prf.update(protocolVersionBytes, 0, protocolVersionBytes.length);
byte[] streamNumberBytes = new byte[INT_64_BYTES];
ByteUtils.writeUint64(streamNumber, streamNumberBytes, 0);
prf.update(streamNumberBytes, 0, streamNumberBytes.length);
byte[] mac = new byte[macLength];
prf.doFinal(mac, 0);
// The output is the first TAG_LENGTH bytes of the MAC
System.arraycopy(mac, 0, tag, 0, TAG_LENGTH);
}
@Override @Override
public byte[] sign(String label, byte[] toSign, byte[] privateKey) public byte[] sign(String label, byte[] toSign, byte[] privateKey)
throws GeneralSecurityException { throws GeneralSecurityException {

View File

@@ -6,6 +6,7 @@ import org.briarproject.bramble.api.crypto.CryptoExecutor;
import org.briarproject.bramble.api.crypto.PasswordStrengthEstimator; import org.briarproject.bramble.api.crypto.PasswordStrengthEstimator;
import org.briarproject.bramble.api.crypto.StreamDecrypterFactory; import org.briarproject.bramble.api.crypto.StreamDecrypterFactory;
import org.briarproject.bramble.api.crypto.StreamEncrypterFactory; import org.briarproject.bramble.api.crypto.StreamEncrypterFactory;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.lifecycle.LifecycleManager; import org.briarproject.bramble.api.lifecycle.LifecycleManager;
import org.briarproject.bramble.api.system.SecureRandomProvider; import org.briarproject.bramble.api.system.SecureRandomProvider;
@@ -74,6 +75,12 @@ public class CryptoModule {
return new PasswordStrengthEstimatorImpl(); return new PasswordStrengthEstimatorImpl();
} }
@Provides
TransportCrypto provideTransportCrypto(
TransportCryptoImpl transportCrypto) {
return transportCrypto;
}
@Provides @Provides
StreamDecrypterFactory provideStreamDecrypterFactory( StreamDecrypterFactory provideStreamDecrypterFactory(
Provider<AuthenticatedCipher> cipherProvider) { Provider<AuthenticatedCipher> cipherProvider) {
@@ -81,9 +88,11 @@ public class CryptoModule {
} }
@Provides @Provides
StreamEncrypterFactory provideStreamEncrypterFactory(CryptoComponent crypto, StreamEncrypterFactory provideStreamEncrypterFactory(
CryptoComponent crypto, TransportCrypto transportCrypto,
Provider<AuthenticatedCipher> cipherProvider) { Provider<AuthenticatedCipher> cipherProvider) {
return new StreamEncrypterFactoryImpl(crypto, cipherProvider); return new StreamEncrypterFactoryImpl(crypto, transportCrypto,
cipherProvider);
} }
@Provides @Provides

View File

@@ -4,6 +4,7 @@ import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.StreamEncrypter; import org.briarproject.bramble.api.crypto.StreamEncrypter;
import org.briarproject.bramble.api.crypto.StreamEncrypterFactory; import org.briarproject.bramble.api.crypto.StreamEncrypterFactory;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
@@ -22,12 +23,15 @@ import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENG
class StreamEncrypterFactoryImpl implements StreamEncrypterFactory { class StreamEncrypterFactoryImpl implements StreamEncrypterFactory {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final TransportCrypto transportCrypto;
private final Provider<AuthenticatedCipher> cipherProvider; private final Provider<AuthenticatedCipher> cipherProvider;
@Inject @Inject
StreamEncrypterFactoryImpl(CryptoComponent crypto, StreamEncrypterFactoryImpl(CryptoComponent crypto,
TransportCrypto transportCrypto,
Provider<AuthenticatedCipher> cipherProvider) { Provider<AuthenticatedCipher> cipherProvider) {
this.crypto = crypto; this.crypto = crypto;
this.transportCrypto = transportCrypto;
this.cipherProvider = cipherProvider; this.cipherProvider = cipherProvider;
} }
@@ -37,7 +41,8 @@ class StreamEncrypterFactoryImpl implements StreamEncrypterFactory {
AuthenticatedCipher cipher = cipherProvider.get(); AuthenticatedCipher cipher = cipherProvider.get();
long streamNumber = ctx.getStreamNumber(); long streamNumber = ctx.getStreamNumber();
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, ctx.getTagKey(), PROTOCOL_VERSION, streamNumber); transportCrypto.encodeTag(tag, ctx.getTagKey(), PROTOCOL_VERSION,
streamNumber);
byte[] streamHeaderNonce = new byte[STREAM_HEADER_NONCE_LENGTH]; byte[] streamHeaderNonce = new byte[STREAM_HEADER_NONCE_LENGTH];
crypto.getSecureRandom().nextBytes(streamHeaderNonce); crypto.getSecureRandom().nextBytes(streamHeaderNonce);
SecretKey frameKey = crypto.generateSecretKey(); SecretKey frameKey = crypto.generateSecretKey();

View File

@@ -0,0 +1,137 @@
package org.briarproject.bramble.crypto;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.util.ByteUtils;
import org.briarproject.bramble.util.StringUtils;
import org.spongycastle.crypto.Digest;
import javax.inject.Inject;
import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.bramble.util.ByteUtils.INT_16_BYTES;
import static org.briarproject.bramble.util.ByteUtils.INT_64_BYTES;
import static org.briarproject.bramble.util.ByteUtils.MAX_16_BIT_UNSIGNED;
import static org.briarproject.bramble.util.ByteUtils.MAX_32_BIT_UNSIGNED;
class TransportCryptoImpl implements TransportCrypto {
// KDF labels for tag key derivation
private static final String A_TAG = "ALICE_TAG_KEY";
private static final String B_TAG = "BOB_TAG_KEY";
// KDF labels for header key derivation
private static final String A_HEADER = "ALICE_HEADER_KEY";
private static final String B_HEADER = "BOB_HEADER_KEY";
// KDF label for key rotation
private static final String ROTATE = "ROTATE";
private final CryptoComponent crypto;
@Inject
TransportCryptoImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
@Override
public TransportKeys deriveTransportKeys(TransportId t,
SecretKey master, long rotationPeriod, boolean alice) {
// Keys for the previous period are derived from the master secret
SecretKey inTagPrev = deriveTagKey(master, t, !alice);
SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice);
SecretKey outTagPrev = deriveTagKey(master, t, alice);
SecretKey outHeaderPrev = deriveHeaderKey(master, t, alice);
// Derive the keys for the current and next periods
SecretKey inTagCurr = rotateKey(inTagPrev, rotationPeriod);
SecretKey inHeaderCurr = rotateKey(inHeaderPrev, rotationPeriod);
SecretKey inTagNext = rotateKey(inTagCurr, rotationPeriod + 1);
SecretKey inHeaderNext = rotateKey(inHeaderCurr, rotationPeriod + 1);
SecretKey outTagCurr = rotateKey(outTagPrev, rotationPeriod);
SecretKey outHeaderCurr = rotateKey(outHeaderPrev, rotationPeriod);
// Initialise the reordering windows and stream counters
IncomingKeys inPrev = new IncomingKeys(inTagPrev, inHeaderPrev,
rotationPeriod - 1);
IncomingKeys inCurr = new IncomingKeys(inTagCurr, inHeaderCurr,
rotationPeriod);
IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext,
rotationPeriod + 1);
OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr,
rotationPeriod);
// Collect and return the keys
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr);
}
@Override
public TransportKeys rotateTransportKeys(TransportKeys k,
long rotationPeriod) {
if (k.getRotationPeriod() >= rotationPeriod) return k;
IncomingKeys inPrev = k.getPreviousIncomingKeys();
IncomingKeys inCurr = k.getCurrentIncomingKeys();
IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getRotationPeriod();
// Rotate the keys
for (long p = startPeriod + 1; p <= rotationPeriod; p++) {
inPrev = inCurr;
inCurr = inNext;
SecretKey inNextTag = rotateKey(inNext.getTagKey(), p + 1);
SecretKey inNextHeader = rotateKey(inNext.getHeaderKey(), p + 1);
inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p);
}
// Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,
outCurr);
}
private SecretKey rotateKey(SecretKey k, long rotationPeriod) {
byte[] period = new byte[INT_64_BYTES];
ByteUtils.writeUint64(rotationPeriod, period, 0);
return crypto.deriveKey(ROTATE, k, period);
}
private SecretKey deriveTagKey(SecretKey master, TransportId t,
boolean alice) {
byte[] id = StringUtils.toUtf8(t.getString());
return crypto.deriveKey(alice ? A_TAG : B_TAG, master, id);
}
private SecretKey deriveHeaderKey(SecretKey master, TransportId t,
boolean alice) {
byte[] id = StringUtils.toUtf8(t.getString());
return crypto.deriveKey(alice ? A_HEADER : B_HEADER, master, id);
}
@Override
public void encodeTag(byte[] tag, SecretKey tagKey, int protocolVersion,
long streamNumber) {
if (tag.length < TAG_LENGTH) throw new IllegalArgumentException();
if (protocolVersion < 0 || protocolVersion > MAX_16_BIT_UNSIGNED)
throw new IllegalArgumentException();
if (streamNumber < 0 || streamNumber > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
// Initialise the PRF
Digest prf = new Blake2sDigest(tagKey.getBytes());
// The output of the PRF must be long enough to use as a tag
int macLength = prf.getDigestSize();
if (macLength < TAG_LENGTH) throw new IllegalStateException();
// The input is the protocol version as a 16-bit integer, followed by
// the stream number as a 64-bit integer
byte[] protocolVersionBytes = new byte[INT_16_BYTES];
ByteUtils.writeUint16(protocolVersion, protocolVersionBytes, 0);
prf.update(protocolVersionBytes, 0, protocolVersionBytes.length);
byte[] streamNumberBytes = new byte[INT_64_BYTES];
ByteUtils.writeUint64(streamNumber, streamNumberBytes, 0);
prf.update(streamNumberBytes, 0, streamNumberBytes.length);
byte[] mac = new byte[macLength];
prf.doFinal(mac, 0);
// The output is the first TAG_LENGTH bytes of the MAC
System.arraycopy(mac, 0, tag, 0, TAG_LENGTH);
}
}

View File

@@ -1,6 +1,6 @@
package org.briarproject.bramble.transport; package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.crypto.CryptoComponent; import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.db.DatabaseComponent; import org.briarproject.bramble.api.db.DatabaseComponent;
import org.briarproject.bramble.api.db.DatabaseExecutor; import org.briarproject.bramble.api.db.DatabaseExecutor;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
@@ -20,17 +20,18 @@ class TransportKeyManagerFactoryImpl implements
TransportKeyManagerFactory { TransportKeyManagerFactory {
private final DatabaseComponent db; private final DatabaseComponent db;
private final CryptoComponent crypto; private final TransportCrypto transportCrypto;
private final Executor dbExecutor; private final Executor dbExecutor;
private final ScheduledExecutorService scheduler; private final ScheduledExecutorService scheduler;
private final Clock clock; private final Clock clock;
@Inject @Inject
TransportKeyManagerFactoryImpl(DatabaseComponent db, CryptoComponent crypto, TransportKeyManagerFactoryImpl(DatabaseComponent db,
TransportCrypto transportCrypto,
@DatabaseExecutor Executor dbExecutor, @DatabaseExecutor Executor dbExecutor,
@Scheduler ScheduledExecutorService scheduler, Clock clock) { @Scheduler ScheduledExecutorService scheduler, Clock clock) {
this.db = db; this.db = db;
this.crypto = crypto; this.transportCrypto = transportCrypto;
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.scheduler = scheduler; this.scheduler = scheduler;
this.clock = clock; this.clock = clock;
@@ -39,8 +40,8 @@ class TransportKeyManagerFactoryImpl implements
@Override @Override
public TransportKeyManager createTransportKeyManager( public TransportKeyManager createTransportKeyManager(
TransportId transportId, long maxLatency) { TransportId transportId, long maxLatency) {
return new TransportKeyManagerImpl(db, crypto, dbExecutor, scheduler, return new TransportKeyManagerImpl(db, transportCrypto, dbExecutor,
clock, transportId, maxLatency); scheduler, clock, transportId, maxLatency);
} }
} }

View File

@@ -2,8 +2,8 @@ package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.Bytes; import org.briarproject.bramble.api.Bytes;
import org.briarproject.bramble.api.contact.ContactId; import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.db.DatabaseComponent; import org.briarproject.bramble.api.db.DatabaseComponent;
import org.briarproject.bramble.api.db.DbException; import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.db.Transaction;
@@ -41,7 +41,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
Logger.getLogger(TransportKeyManagerImpl.class.getName()); Logger.getLogger(TransportKeyManagerImpl.class.getName());
private final DatabaseComponent db; private final DatabaseComponent db;
private final CryptoComponent crypto; private final TransportCrypto transportCrypto;
private final Executor dbExecutor; private final Executor dbExecutor;
private final ScheduledExecutorService scheduler; private final ScheduledExecutorService scheduler;
private final Clock clock; private final Clock clock;
@@ -54,11 +54,12 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private final Map<ContactId, MutableOutgoingKeys> outContexts; private final Map<ContactId, MutableOutgoingKeys> outContexts;
private final Map<ContactId, MutableTransportKeys> keys; private final Map<ContactId, MutableTransportKeys> keys;
TransportKeyManagerImpl(DatabaseComponent db, CryptoComponent crypto, TransportKeyManagerImpl(DatabaseComponent db,
Executor dbExecutor, @Scheduler ScheduledExecutorService scheduler, TransportCrypto transportCrypto, Executor dbExecutor,
Clock clock, TransportId transportId, long maxLatency) { @Scheduler ScheduledExecutorService scheduler, Clock clock,
TransportId transportId, long maxLatency) {
this.db = db; this.db = db;
this.crypto = crypto; this.transportCrypto = transportCrypto;
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.scheduler = scheduler; this.scheduler = scheduler;
this.clock = clock; this.clock = clock;
@@ -99,7 +100,8 @@ class TransportKeyManagerImpl implements TransportKeyManager {
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ContactId c = e.getKey(); ContactId c = e.getKey();
TransportKeys k = e.getValue(); TransportKeys k = e.getValue();
TransportKeys k1 = crypto.rotateTransportKeys(k, rotationPeriod); TransportKeys k1 =
transportCrypto.rotateTransportKeys(k, rotationPeriod);
if (k1.getRotationPeriod() > k.getRotationPeriod()) if (k1.getRotationPeriod() > k.getRotationPeriod())
rotationResult.rotated.put(c, k1); rotationResult.rotated.put(c, k1);
rotationResult.current.put(c, k1); rotationResult.current.put(c, k1);
@@ -127,7 +129,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
for (long streamNumber : inKeys.getWindow().getUnseen()) { for (long streamNumber : inKeys.getWindow().getUnseen()) {
TagContext tagCtx = new TagContext(c, inKeys, streamNumber); TagContext tagCtx = new TagContext(c, inKeys, streamNumber);
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION, transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
streamNumber); streamNumber);
inContexts.put(new Bytes(tag), tagCtx); inContexts.put(new Bytes(tag), tagCtx);
} }
@@ -162,11 +164,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
// Work out what rotation period the timestamp belongs to // Work out what rotation period the timestamp belongs to
long rotationPeriod = timestamp / rotationPeriodLength; long rotationPeriod = timestamp / rotationPeriodLength;
// Derive the transport keys // Derive the transport keys
TransportKeys k = crypto.deriveTransportKeys(transportId, master, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
rotationPeriod, alice); master, rotationPeriod, alice);
// Rotate the keys to the current rotation period if necessary // Rotate the keys to the current rotation period if necessary
rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength; rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength;
k = crypto.rotateTransportKeys(k, rotationPeriod); k = transportCrypto.rotateTransportKeys(k, rotationPeriod);
// Initialise mutable state for the contact // Initialise mutable state for the contact
addKeys(c, new MutableTransportKeys(k)); addKeys(c, new MutableTransportKeys(k));
// Write the keys back to the DB // Write the keys back to the DB
@@ -234,8 +236,8 @@ class TransportKeyManagerImpl implements TransportKeyManager {
// Add tags for any stream numbers added to the window // Add tags for any stream numbers added to the window
for (long streamNumber : change.getAdded()) { for (long streamNumber : change.getAdded()) {
byte[] addTag = new byte[TAG_LENGTH]; byte[] addTag = new byte[TAG_LENGTH];
crypto.encodeTag(addTag, inKeys.getTagKey(), PROTOCOL_VERSION, transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
streamNumber); PROTOCOL_VERSION, streamNumber);
inContexts.put(new Bytes(addTag), new TagContext( inContexts.put(new Bytes(addTag), new TagContext(
tagCtx.contactId, inKeys, streamNumber)); tagCtx.contactId, inKeys, streamNumber));
} }
@@ -243,7 +245,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
for (long streamNumber : change.getRemoved()) { for (long streamNumber : change.getRemoved()) {
if (streamNumber == tagCtx.streamNumber) continue; if (streamNumber == tagCtx.streamNumber) continue;
byte[] removeTag = new byte[TAG_LENGTH]; byte[] removeTag = new byte[TAG_LENGTH];
crypto.encodeTag(removeTag, inKeys.getTagKey(), transportCrypto.encodeTag(removeTag, inKeys.getTagKey(),
PROTOCOL_VERSION, streamNumber); PROTOCOL_VERSION, streamNumber);
inContexts.remove(new Bytes(removeTag)); inContexts.remove(new Bytes(removeTag));
} }

View File

@@ -3,11 +3,11 @@ package org.briarproject.bramble.crypto;
import org.briarproject.bramble.api.Bytes; import org.briarproject.bramble.api.Bytes;
import org.briarproject.bramble.api.crypto.CryptoComponent; import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.test.BrambleTestCase; import org.briarproject.bramble.test.BrambleTestCase;
import org.briarproject.bramble.test.TestSecureRandomProvider; import org.briarproject.bramble.test.TestSecureRandomProvider;
import org.briarproject.bramble.test.TestUtils;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
@@ -16,35 +16,34 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import static org.briarproject.bramble.test.TestUtils.getSecretKey;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class KeyDerivationTest extends BrambleTestCase { public class KeyDerivationTest extends BrambleTestCase {
private final CryptoComponent crypto =
new CryptoComponentImpl(new TestSecureRandomProvider());
private final TransportCrypto transportCrypto =
new TransportCryptoImpl(crypto);
private final TransportId transportId = new TransportId("id"); private final TransportId transportId = new TransportId("id");
private final CryptoComponent crypto; private final SecretKey master = getSecretKey();
private final SecretKey master;
public KeyDerivationTest() {
crypto = new CryptoComponentImpl(new TestSecureRandomProvider());
master = TestUtils.getSecretKey();
}
@Test @Test
public void testKeysAreDistinct() { public void testKeysAreDistinct() {
TransportKeys k = crypto.deriveTransportKeys(transportId, master, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
123, true); master, 123, true);
assertAllDifferent(k); assertAllDifferent(k);
} }
@Test @Test
public void testCurrentKeysMatchCurrentKeysOfContact() { public void testCurrentKeysMatchCurrentKeysOfContact() {
// Start in rotation period 123 // Start in rotation period 123
TransportKeys kA = crypto.deriveTransportKeys(transportId, master, TransportKeys kA = transportCrypto.deriveTransportKeys(transportId,
123, true); master, 123, true);
TransportKeys kB = crypto.deriveTransportKeys(transportId, master, TransportKeys kB = transportCrypto.deriveTransportKeys(transportId,
123, false); master, 123, false);
// Alice's incoming keys should equal Bob's outgoing keys // Alice's incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(), assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes()); kB.getCurrentOutgoingKeys().getTagKey().getBytes());
@@ -56,8 +55,8 @@ public class KeyDerivationTest extends BrambleTestCase {
assertArrayEquals(kA.getCurrentOutgoingKeys().getHeaderKey().getBytes(), assertArrayEquals(kA.getCurrentOutgoingKeys().getHeaderKey().getBytes(),
kB.getCurrentIncomingKeys().getHeaderKey().getBytes()); kB.getCurrentIncomingKeys().getHeaderKey().getBytes());
// Rotate into the future // Rotate into the future
kA = crypto.rotateTransportKeys(kA, 456); kA = transportCrypto.rotateTransportKeys(kA, 456);
kB = crypto.rotateTransportKeys(kB, 456); kB = transportCrypto.rotateTransportKeys(kB, 456);
// Alice's incoming keys should equal Bob's outgoing keys // Alice's incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(), assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes()); kB.getCurrentOutgoingKeys().getTagKey().getBytes());
@@ -73,22 +72,23 @@ public class KeyDerivationTest extends BrambleTestCase {
@Test @Test
public void testPreviousKeysMatchPreviousKeysOfContact() { public void testPreviousKeysMatchPreviousKeysOfContact() {
// Start in rotation period 123 // Start in rotation period 123
TransportKeys kA = crypto.deriveTransportKeys(transportId, master, TransportKeys kA = transportCrypto.deriveTransportKeys(transportId,
123, true); master, 123, true);
TransportKeys kB = crypto.deriveTransportKeys(transportId, master, TransportKeys kB = transportCrypto.deriveTransportKeys(transportId,
123, false); master, 123, false);
// Compare Alice's previous keys in period 456 with Bob's current keys // Compare Alice's previous keys in period 456 with Bob's current keys
// in period 455 // in period 455
kA = crypto.rotateTransportKeys(kA, 456); kA = transportCrypto.rotateTransportKeys(kA, 456);
kB = crypto.rotateTransportKeys(kB, 455); kB = transportCrypto.rotateTransportKeys(kB, 455);
// Alice's previous incoming keys should equal Bob's outgoing keys // Alice's previous incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getPreviousIncomingKeys().getTagKey().getBytes(), assertArrayEquals(kA.getPreviousIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes()); kB.getCurrentOutgoingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getPreviousIncomingKeys().getHeaderKey().getBytes(), assertArrayEquals(
kA.getPreviousIncomingKeys().getHeaderKey().getBytes(),
kB.getCurrentOutgoingKeys().getHeaderKey().getBytes()); kB.getCurrentOutgoingKeys().getHeaderKey().getBytes());
// Compare Alice's current keys in period 456 with Bob's previous keys // Compare Alice's current keys in period 456 with Bob's previous keys
// in period 457 // in period 457
kB = crypto.rotateTransportKeys(kB, 457); kB = transportCrypto.rotateTransportKeys(kB, 457);
// Alice's outgoing keys should equal Bob's previous incoming keys // Alice's outgoing keys should equal Bob's previous incoming keys
assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(), assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(),
kB.getPreviousIncomingKeys().getTagKey().getBytes()); kB.getPreviousIncomingKeys().getTagKey().getBytes());
@@ -99,14 +99,14 @@ public class KeyDerivationTest extends BrambleTestCase {
@Test @Test
public void testNextKeysMatchNextKeysOfContact() { public void testNextKeysMatchNextKeysOfContact() {
// Start in rotation period 123 // Start in rotation period 123
TransportKeys kA = crypto.deriveTransportKeys(transportId, master, TransportKeys kA = transportCrypto.deriveTransportKeys(transportId,
123, true); master, 123, true);
TransportKeys kB = crypto.deriveTransportKeys(transportId, master, TransportKeys kB = transportCrypto.deriveTransportKeys(transportId,
123, false); master, 123, false);
// Compare Alice's current keys in period 456 with Bob's next keys in // Compare Alice's current keys in period 456 with Bob's next keys in
// period 455 // period 455
kA = crypto.rotateTransportKeys(kA, 456); kA = transportCrypto.rotateTransportKeys(kA, 456);
kB = crypto.rotateTransportKeys(kB, 455); kB = transportCrypto.rotateTransportKeys(kB, 455);
// Alice's outgoing keys should equal Bob's next incoming keys // Alice's outgoing keys should equal Bob's next incoming keys
assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(), assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(),
kB.getNextIncomingKeys().getTagKey().getBytes()); kB.getNextIncomingKeys().getTagKey().getBytes());
@@ -114,7 +114,7 @@ public class KeyDerivationTest extends BrambleTestCase {
kB.getNextIncomingKeys().getHeaderKey().getBytes()); kB.getNextIncomingKeys().getHeaderKey().getBytes());
// Compare Alice's next keys in period 456 with Bob's current keys // Compare Alice's next keys in period 456 with Bob's current keys
// in period 457 // in period 457
kB = crypto.rotateTransportKeys(kB, 457); kB = transportCrypto.rotateTransportKeys(kB, 457);
// Alice's next incoming keys should equal Bob's outgoing keys // Alice's next incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getNextIncomingKeys().getTagKey().getBytes(), assertArrayEquals(kA.getNextIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes()); kB.getCurrentOutgoingKeys().getTagKey().getBytes());
@@ -124,12 +124,12 @@ public class KeyDerivationTest extends BrambleTestCase {
@Test @Test
public void testMasterKeyAffectsOutput() { public void testMasterKeyAffectsOutput() {
SecretKey master1 = TestUtils.getSecretKey(); SecretKey master1 = getSecretKey();
assertFalse(Arrays.equals(master.getBytes(), master1.getBytes())); assertFalse(Arrays.equals(master.getBytes(), master1.getBytes()));
TransportKeys k = crypto.deriveTransportKeys(transportId, master, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
123, true); master, 123, true);
TransportKeys k1 = crypto.deriveTransportKeys(transportId, master1, TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId,
123, true); master1, 123, true);
assertAllDifferent(k, k1); assertAllDifferent(k, k1);
} }
@@ -137,10 +137,10 @@ public class KeyDerivationTest extends BrambleTestCase {
public void testTransportIdAffectsOutput() { public void testTransportIdAffectsOutput() {
TransportId transportId1 = new TransportId("id1"); TransportId transportId1 = new TransportId("id1");
assertFalse(transportId.getString().equals(transportId1.getString())); assertFalse(transportId.getString().equals(transportId1.getString()));
TransportKeys k = crypto.deriveTransportKeys(transportId, master, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
123, true); master, 123, true);
TransportKeys k1 = crypto.deriveTransportKeys(transportId1, master, TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId1,
123, true); master, 123, true);
assertAllDifferent(k, k1); assertAllDifferent(k, k1);
} }

View File

@@ -3,9 +3,8 @@ package org.briarproject.bramble.crypto;
import org.briarproject.bramble.api.Bytes; import org.briarproject.bramble.api.Bytes;
import org.briarproject.bramble.api.crypto.CryptoComponent; import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.test.BrambleTestCase; import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.test.TestSecureRandomProvider; import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.TestUtils;
import org.junit.Test; import org.junit.Test;
import java.util.HashSet; import java.util.HashSet;
@@ -14,25 +13,25 @@ import java.util.Set;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH; import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.bramble.test.TestUtils.getSecretKey;
public class TagEncodingTest extends BrambleTestCase { public class TagEncodingTest extends BrambleMockTestCase {
private final CryptoComponent crypto; private final CryptoComponent crypto = context.mock(CryptoComponent.class);
private final SecretKey tagKey;
private final TransportCrypto transportCrypto =
new TransportCryptoImpl(crypto);
private final SecretKey tagKey = getSecretKey();
private final long streamNumber = 1234567890; private final long streamNumber = 1234567890;
public TagEncodingTest() {
crypto = new CryptoComponentImpl(new TestSecureRandomProvider());
tagKey = TestUtils.getSecretKey();
}
@Test @Test
public void testKeyAffectsTag() throws Exception { public void testKeyAffectsTag() throws Exception {
Set<Bytes> set = new HashSet<>(); Set<Bytes> set = new HashSet<>();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
SecretKey tagKey = TestUtils.getSecretKey(); SecretKey tagKey = getSecretKey();
crypto.encodeTag(tag, tagKey, PROTOCOL_VERSION, streamNumber); transportCrypto.encodeTag(tag, tagKey, PROTOCOL_VERSION,
streamNumber);
assertTrue(set.add(new Bytes(tag))); assertTrue(set.add(new Bytes(tag)));
} }
} }
@@ -42,7 +41,8 @@ public class TagEncodingTest extends BrambleTestCase {
Set<Bytes> set = new HashSet<>(); Set<Bytes> set = new HashSet<>();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, tagKey, PROTOCOL_VERSION + i, streamNumber); transportCrypto.encodeTag(tag, tagKey, PROTOCOL_VERSION + i,
streamNumber);
assertTrue(set.add(new Bytes(tag))); assertTrue(set.add(new Bytes(tag)));
} }
} }
@@ -52,7 +52,8 @@ public class TagEncodingTest extends BrambleTestCase {
Set<Bytes> set = new HashSet<>(); Set<Bytes> set = new HashSet<>();
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, tagKey, PROTOCOL_VERSION, streamNumber + i); transportCrypto.encodeTag(tag, tagKey, PROTOCOL_VERSION,
streamNumber + i);
assertTrue(set.add(new Bytes(tag))); assertTrue(set.add(new Bytes(tag)));
} }
} }

View File

@@ -1,8 +1,8 @@
package org.briarproject.bramble.sync; package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.contact.ContactId; import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.sync.Ack; import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.ClientId; import org.briarproject.bramble.api.sync.ClientId;
@@ -57,7 +57,7 @@ public class SyncIntegrationTest extends BrambleTestCase {
@Inject @Inject
RecordWriterFactory recordWriterFactory; RecordWriterFactory recordWriterFactory;
@Inject @Inject
CryptoComponent crypto; TransportCrypto transportCrypto;
private final ContactId contactId; private final ContactId contactId;
private final TransportId transportId; private final TransportId transportId;
@@ -117,7 +117,8 @@ public class SyncIntegrationTest extends BrambleTestCase {
private void read(byte[] connectionData) throws Exception { private void read(byte[] connectionData) throws Exception {
// Calculate the expected tag // Calculate the expected tag
byte[] expectedTag = new byte[TAG_LENGTH]; byte[] expectedTag = new byte[TAG_LENGTH];
crypto.encodeTag(expectedTag, tagKey, PROTOCOL_VERSION, streamNumber); transportCrypto.encodeTag(expectedTag, tagKey, PROTOCOL_VERSION,
streamNumber);
// Read the tag // Read the tag
InputStream in = new ByteArrayInputStream(connectionData); InputStream in = new ByteArrayInputStream(connectionData);

View File

@@ -1,8 +1,8 @@
package org.briarproject.bramble.transport; package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId; import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.db.DatabaseComponent; import org.briarproject.bramble.api.db.DatabaseComponent;
import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
@@ -11,12 +11,11 @@ import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.test.BrambleTestCase; import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.RunAction; import org.briarproject.bramble.test.RunAction;
import org.briarproject.bramble.test.TestUtils; import org.briarproject.bramble.test.TestUtils;
import org.hamcrest.Description; import org.hamcrest.Description;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.Mockery;
import org.jmock.api.Action; import org.jmock.api.Action;
import org.jmock.api.Invocation; import org.jmock.api.Invocation;
import org.junit.Test; import org.junit.Test;
@@ -41,7 +40,15 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
public class TransportKeyManagerImplTest extends BrambleTestCase { public class TransportKeyManagerImplTest extends BrambleMockTestCase {
private final DatabaseComponent db = context.mock(DatabaseComponent.class);
private final TransportCrypto transportCrypto =
context.mock(TransportCrypto.class);
private final Executor dbExecutor = context.mock(Executor.class);
private final ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
private final Clock clock = context.mock(Clock.class);
private final TransportId transportId = new TransportId("id"); private final TransportId transportId = new TransportId("id");
private final long maxLatency = 30 * 1000; // 30 seconds private final long maxLatency = 30 * 1000; // 30 seconds
@@ -55,14 +62,6 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
@Test @Test
public void testKeysAreRotatedAtStartup() throws Exception { public void testKeysAreRotatedAtStartup() throws Exception {
Mockery context = new Mockery();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
Map<ContactId, TransportKeys> loaded = new LinkedHashMap<>(); Map<ContactId, TransportKeys> loaded = new LinkedHashMap<>();
TransportKeys shouldRotate = createTransportKeys(900, 0); TransportKeys shouldRotate = createTransportKeys(900, 0);
TransportKeys shouldNotRotate = createTransportKeys(1000, 0); TransportKeys shouldNotRotate = createTransportKeys(1000, 0);
@@ -79,14 +78,15 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
oneOf(db).getTransportKeys(txn, transportId); oneOf(db).getTransportKeys(txn, transportId);
will(returnValue(loaded)); will(returnValue(loaded));
// Rotate the transport keys // Rotate the transport keys
oneOf(crypto).rotateTransportKeys(shouldRotate, 1000); oneOf(transportCrypto).rotateTransportKeys(shouldRotate, 1000);
will(returnValue(rotated)); will(returnValue(rotated));
oneOf(crypto).rotateTransportKeys(shouldNotRotate, 1000); oneOf(transportCrypto).rotateTransportKeys(shouldNotRotate, 1000);
will(returnValue(shouldNotRotate)); will(returnValue(shouldNotRotate));
// Encode the tags (3 sets per contact) // Encode the tags (3 sets per contact)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(6).of(crypto).encodeTag(with(any(byte[].class)), exactly(6).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Save the keys that were rotated // Save the keys that were rotated
@@ -97,161 +97,124 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
with(rotationPeriodLength - 1), with(MILLISECONDS)); with(rotationPeriodLength - 1), with(MILLISECONDS));
}}); }});
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
transportKeyManager.start(txn); transportKeyManager.start(txn);
context.assertIsSatisfied();
} }
@Test @Test
public void testKeysAreRotatedWhenAddingContact() throws Exception { public void testKeysAreRotatedWhenAddingContact() throws Exception {
Mockery context = new Mockery(); boolean alice = random.nextBoolean();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
boolean alice = true;
TransportKeys transportKeys = createTransportKeys(999, 0); TransportKeys transportKeys = createTransportKeys(999, 0);
TransportKeys rotated = createTransportKeys(1000, 0); TransportKeys rotated = createTransportKeys(1000, 0);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 999, oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
alice); 999, alice);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Get the current time (1 ms after start of rotation period 1000) // Get the current time (1 ms after start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000 + 1)); will(returnValue(rotationPeriodLength * 1000 + 1));
// Rotate the transport keys // Rotate the transport keys
oneOf(crypto).rotateTransportKeys(transportKeys, 1000); oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(rotated)); will(returnValue(rotated));
// Encode the tags (3 sets) // Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(crypto).encodeTag(with(any(byte[].class)), exactly(3).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, rotated); oneOf(db).addTransportKeys(txn, contactId, rotated);
}}); }});
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
// The timestamp is 1 ms before the start of rotation period 1000 // The timestamp is 1 ms before the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000 - 1; long timestamp = rotationPeriodLength * 1000 - 1;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
context.assertIsSatisfied();
} }
@Test @Test
public void testOutgoingStreamContextIsNullIfContactIsNotFound() public void testOutgoingStreamContextIsNullIfContactIsNotFound()
throws Exception { throws Exception {
Mockery context = new Mockery();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
assertNull(transportKeyManager.getStreamContext(txn, contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
context.assertIsSatisfied();
} }
@Test @Test
public void testOutgoingStreamContextIsNullIfStreamCounterIsExhausted() public void testOutgoingStreamContextIsNullIfStreamCounterIsExhausted()
throws Exception { throws Exception {
Mockery context = new Mockery(); boolean alice = random.nextBoolean();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
boolean alice = true;
// The stream counter has been exhausted // The stream counter has been exhausted
TransportKeys transportKeys = createTransportKeys(1000, TransportKeys transportKeys = createTransportKeys(1000,
MAX_32_BIT_UNSIGNED + 1); MAX_32_BIT_UNSIGNED + 1);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
alice); 1000, alice);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000) // Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000)); will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets) // Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(crypto).encodeTag(with(any(byte[].class)), exactly(3).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Rotate the transport keys (the keys are unaffected) // Rotate the transport keys (the keys are unaffected)
oneOf(crypto).rotateTransportKeys(transportKeys, 1000); oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
}}); }});
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
// The timestamp is at the start of rotation period 1000 // The timestamp is at the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertNull(transportKeyManager.getStreamContext(txn, contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
context.assertIsSatisfied();
} }
@Test @Test
public void testOutgoingStreamCounterIsIncremented() throws Exception { public void testOutgoingStreamCounterIsIncremented() throws Exception {
Mockery context = new Mockery(); boolean alice = random.nextBoolean();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
boolean alice = true;
// The stream counter can be used one more time before being exhausted // The stream counter can be used one more time before being exhausted
TransportKeys transportKeys = createTransportKeys(1000, TransportKeys transportKeys = createTransportKeys(1000,
MAX_32_BIT_UNSIGNED); MAX_32_BIT_UNSIGNED);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
alice); 1000, alice);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000) // Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000)); will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets) // Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(crypto).encodeTag(with(any(byte[].class)), exactly(3).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Rotate the transport keys (the keys are unaffected) // Rotate the transport keys (the keys are unaffected)
oneOf(crypto).rotateTransportKeys(transportKeys, 1000); oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
@@ -259,9 +222,9 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000); oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000);
}}); }});
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
// The timestamp is at the start of rotation period 1000 // The timestamp is at the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
@@ -277,94 +240,76 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber()); assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber());
// The second request should return null, the counter is exhausted // The second request should return null, the counter is exhausted
assertNull(transportKeyManager.getStreamContext(txn, contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
context.assertIsSatisfied();
} }
@Test @Test
public void testIncomingStreamContextIsNullIfTagIsNotFound() public void testIncomingStreamContextIsNullIfTagIsNotFound()
throws Exception { throws Exception {
Mockery context = new Mockery(); boolean alice = random.nextBoolean();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
boolean alice = true;
TransportKeys transportKeys = createTransportKeys(1000, 0); TransportKeys transportKeys = createTransportKeys(1000, 0);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
alice); 1000, alice);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000) // Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000)); will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets) // Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(crypto).encodeTag(with(any(byte[].class)), exactly(3).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Rotate the transport keys (the keys are unaffected) // Rotate the transport keys (the keys are unaffected)
oneOf(crypto).rotateTransportKeys(transportKeys, 1000); oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
}}); }});
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
// The timestamp is at the start of rotation period 1000 // The timestamp is at the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertNull(transportKeyManager.getStreamContext(txn, assertNull(transportKeyManager.getStreamContext(txn,
new byte[TAG_LENGTH])); new byte[TAG_LENGTH]));
context.assertIsSatisfied();
} }
@Test @Test
public void testTagIsNotRecognisedTwice() throws Exception { public void testTagIsNotRecognisedTwice() throws Exception {
Mockery context = new Mockery(); boolean alice = random.nextBoolean();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
boolean alice = true;
TransportKeys transportKeys = createTransportKeys(1000, 0); TransportKeys transportKeys = createTransportKeys(1000, 0);
// Keep a copy of the tags // Keep a copy of the tags
List<byte[]> tags = new ArrayList<>(); List<byte[]> tags = new ArrayList<>();
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
alice); 1000, alice);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000) // Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000)); will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets) // Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(crypto).encodeTag(with(any(byte[].class)), exactly(3).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction(tags)); will(new EncodeTagAction(tags));
} }
// Rotate the transport keys (the keys are unaffected) // Rotate the transport keys (the keys are unaffected)
oneOf(crypto).rotateTransportKeys(transportKeys, 1000); oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
// Encode a new tag after sliding the window // Encode a new tag after sliding the window
oneOf(crypto).encodeTag(with(any(byte[].class)), oneOf(transportCrypto).encodeTag(with(any(byte[].class)),
with(tagKey), with(PROTOCOL_VERSION), with(tagKey), with(PROTOCOL_VERSION),
with((long) REORDERING_WINDOW_SIZE)); with((long) REORDERING_WINDOW_SIZE));
will(new EncodeTagAction(tags)); will(new EncodeTagAction(tags));
@@ -373,9 +318,9 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
1, new byte[REORDERING_WINDOW_SIZE / 8]); 1, new byte[REORDERING_WINDOW_SIZE / 8]);
}}); }});
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
// The timestamp is at the start of rotation period 1000 // The timestamp is at the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
@@ -395,20 +340,10 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
assertEquals(REORDERING_WINDOW_SIZE * 3 + 1, tags.size()); assertEquals(REORDERING_WINDOW_SIZE * 3 + 1, tags.size());
// The second request should return null, the tag has already been used // The second request should return null, the tag has already been used
assertNull(transportKeyManager.getStreamContext(txn, tag)); assertNull(transportKeyManager.getStreamContext(txn, tag));
context.assertIsSatisfied();
} }
@Test @Test
public void testKeysAreRotatedToCurrentPeriod() throws Exception { public void testKeysAreRotatedToCurrentPeriod() throws Exception {
Mockery context = new Mockery();
DatabaseComponent db = context.mock(DatabaseComponent.class);
CryptoComponent crypto = context.mock(CryptoComponent.class);
Executor dbExecutor = context.mock(Executor.class);
ScheduledExecutorService scheduler =
context.mock(ScheduledExecutorService.class);
Clock clock = context.mock(Clock.class);
TransportKeys transportKeys = createTransportKeys(1000, 0); TransportKeys transportKeys = createTransportKeys(1000, 0);
Map<ContactId, TransportKeys> loaded = Map<ContactId, TransportKeys> loaded =
Collections.singletonMap(contactId, transportKeys); Collections.singletonMap(contactId, transportKeys);
@@ -424,12 +359,13 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
oneOf(db).getTransportKeys(txn, transportId); oneOf(db).getTransportKeys(txn, transportId);
will(returnValue(loaded)); will(returnValue(loaded));
// Rotate the transport keys (the keys are unaffected) // Rotate the transport keys (the keys are unaffected)
oneOf(crypto).rotateTransportKeys(transportKeys, 1000); oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Encode the tags (3 sets) // Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(crypto).encodeTag(with(any(byte[].class)), exactly(3).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Schedule key rotation at the start of the next rotation period // Schedule key rotation at the start of the next rotation period
@@ -445,13 +381,14 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1001)); will(returnValue(rotationPeriodLength * 1001));
// Rotate the transport keys // Rotate the transport keys
oneOf(crypto).rotateTransportKeys(with(any(TransportKeys.class)), oneOf(transportCrypto).rotateTransportKeys(
with(1001L)); with(any(TransportKeys.class)), with(1001L));
will(returnValue(rotated)); will(returnValue(rotated));
// Encode the tags (3 sets) // Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(crypto).encodeTag(with(any(byte[].class)), exactly(3).of(transportCrypto).encodeTag(
with(tagKey), with(PROTOCOL_VERSION), with(i)); with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Save the keys that were rotated // Save the keys that were rotated
@@ -465,12 +402,10 @@ public class TransportKeyManagerImplTest extends BrambleTestCase {
oneOf(db).endTransaction(txn1); oneOf(db).endTransaction(txn1);
}}); }});
TransportKeyManager TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
transportKeyManager = new TransportKeyManagerImpl(db, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
crypto, dbExecutor, scheduler, clock, transportId, maxLatency); maxLatency);
transportKeyManager.start(txn); transportKeyManager.start(txn);
context.assertIsSatisfied();
} }
private TransportKeys createTransportKeys(long rotationPeriod, private TransportKeys createTransportKeys(long rotationPeriod,