Merge branch 'multiple-transport-keys' into 'master'

Support multiple sets of transport keys per contact

See merge request akwizgran/briar!745
This commit is contained in:
akwizgran
2018-04-17 14:02:45 +00:00
23 changed files with 1278 additions and 476 deletions

View File

@@ -6,7 +6,7 @@ import javax.annotation.concurrent.Immutable;
/** /**
* Type-safe wrapper for an integer that uniquely identifies a contact within * Type-safe wrapper for an integer that uniquely identifies a contact within
* the scope of a single node. * the scope of the local device.
*/ */
@Immutable @Immutable
@NotNullByDefault @NotNullByDefault

View File

@@ -23,13 +23,21 @@ public interface ContactManager {
void registerRemoveContactHook(RemoveContactHook hook); void registerRemoveContactHook(RemoveContactHook hook);
/** /**
* Stores a contact within the given transaction associated with the given * Stores a contact associated with the given local and remote pseudonyms,
* local and remote pseudonyms, and returns an ID for the contact. * derives and stores transport keys for each transport, and returns an ID
* for the contact.
*/ */
ContactId addContact(Transaction txn, Author remote, AuthorId local, ContactId addContact(Transaction txn, Author remote, AuthorId local,
SecretKey master, long timestamp, boolean alice, boolean verified, SecretKey master, long timestamp, boolean alice, boolean verified,
boolean active) throws DbException; boolean active) throws DbException;
/**
* Stores a contact associated with the given local and remote pseudonyms
* and returns an ID for the contact.
*/
ContactId addContact(Transaction txn, Author remote, AuthorId local,
boolean verified, boolean active) throws DbException;
/** /**
* Stores a contact associated with the given local and remote pseudonyms, * Stores a contact associated with the given local and remote pseudonyms,
* and returns an ID for the contact. * and returns an ID for the contact.

View File

@@ -14,9 +14,10 @@ public interface TransportCrypto {
* rotation period from the given master secret. * rotation period from the given master secret.
* *
* @param alice whether the keys are for use by Alice or Bob. * @param alice whether the keys are for use by Alice or Bob.
* @param active whether the keys are usable for outgoing streams.
*/ */
TransportKeys deriveTransportKeys(TransportId t, SecretKey master, TransportKeys deriveTransportKeys(TransportId t, SecretKey master,
long rotationPeriod, boolean alice); long rotationPeriod, boolean alice, boolean active);
/** /**
* Rotates the given transport keys to the given rotation period. If the * Rotates the given transport keys to the given rotation period. If the

View File

@@ -18,6 +18,8 @@ import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.Offer; import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.Request; import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection; import java.util.Collection;
@@ -102,10 +104,17 @@ public interface DatabaseComponent {
throws DbException; throws DbException;
/** /**
* Stores transport keys for a newly added contact. * Stores the given transport keys, optionally binding them to the given
* contact, and returns a key set ID.
*/ */
void addTransportKeys(Transaction txn, ContactId c, TransportKeys k) KeySetId addTransportKeys(Transaction txn, @Nullable ContactId c,
throws DbException; TransportKeys k) throws DbException;
/**
* Binds the given keys for the given transport to the given contact.
*/
void bindTransportKeys(Transaction txn, ContactId c, TransportId t,
KeySetId k) throws DbException;
/** /**
* Returns true if the database contains the given contact for the given * Returns true if the database contains the given contact for the given
@@ -394,15 +403,14 @@ public interface DatabaseComponent {
* <p/> * <p/>
* Read-only. * Read-only.
*/ */
Map<ContactId, TransportKeys> getTransportKeys(Transaction txn, Collection<KeySet> getTransportKeys(Transaction txn, TransportId t)
TransportId t) throws DbException; throws DbException;
/** /**
* Increments the outgoing stream counter for the given contact and * Increments the outgoing stream counter for the given transport keys.
* transport in the given rotation period .
*/ */
void incrementStreamCounter(Transaction txn, ContactId c, TransportId t, void incrementStreamCounter(Transaction txn, TransportId t, KeySetId k)
long rotationPeriod) throws DbException; throws DbException;
/** /**
* Merges the given metadata with the existing metadata for the given * Merges the given metadata with the existing metadata for the given
@@ -472,6 +480,12 @@ public interface DatabaseComponent {
*/ */
void removeTransport(Transaction txn, TransportId t) throws DbException; void removeTransport(Transaction txn, TransportId t) throws DbException;
/**
* Removes the given transport keys from the database.
*/
void removeTransportKeys(Transaction txn, TransportId t, KeySetId k)
throws DbException;
/** /**
* Marks the given contact as verified. * Marks the given contact as verified.
*/ */
@@ -507,15 +521,21 @@ public interface DatabaseComponent {
Collection<MessageId> dependencies) throws DbException; Collection<MessageId> dependencies) throws DbException;
/** /**
* Sets the reordering window for the given contact and transport in the * Sets the reordering window for the given key set and transport in the
* given rotation period. * given rotation period.
*/ */
void setReorderingWindow(Transaction txn, ContactId c, TransportId t, void setReorderingWindow(Transaction txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException; long rotationPeriod, long base, byte[] bitmap) throws DbException;
/**
* Marks the given transport keys as usable for outgoing streams.
*/
void setTransportKeysActive(Transaction txn, TransportId t, KeySetId k)
throws DbException;
/** /**
* Stores the given transport keys, deleting any keys they have replaced. * Stores the given transport keys, deleting any keys they have replaced.
*/ */
void updateTransportKeys(Transaction txn, void updateTransportKeys(Transaction txn, Collection<KeySet> keys)
Map<ContactId, TransportKeys> keys) throws DbException; throws DbException;
} }

View File

@@ -6,6 +6,8 @@ import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import java.util.Map;
import javax.annotation.Nullable; import javax.annotation.Nullable;
/** /**
@@ -16,13 +18,51 @@ public interface KeyManager {
/** /**
* Informs the key manager that a new contact has been added. Derives and * Informs the key manager that a new contact has been added. Derives and
* stores transport keys for communicating with the contact. * stores a set of transport keys for communicating with the contact over
* each transport.
* <p/>
* {@link StreamContext StreamContexts} for the contact can be created * {@link StreamContext StreamContexts} for the contact can be created
* after this method has returned. * after this method has returned.
*/ */
void addContact(Transaction txn, ContactId c, SecretKey master, void addContact(Transaction txn, ContactId c, SecretKey master,
long timestamp, boolean alice) throws DbException; long timestamp, boolean alice) throws DbException;
/**
* Derives and stores a set of unbound transport keys for each transport
* and returns the key set IDs.
* <p/>
* The keys must be bound before they can be used for incoming streams,
* and also activated before they can be used for outgoing streams.
*/
Map<TransportId, KeySetId> addUnboundKeys(Transaction txn, SecretKey master,
long timestamp, boolean alice) throws DbException;
/**
* Binds the given transport keys to the given contact.
*/
void bindKeys(Transaction txn, ContactId c, Map<TransportId, KeySetId> keys)
throws DbException;
/**
* Marks the given transport keys as usable for outgoing streams. Keys must
* be bound before they are activated.
*/
void activateKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException;
/**
* Removes the given transport keys, which must not have been bound, from
* the manager and the database.
*/
void removeKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException;
/**
* Returns true if we have keys that can be used for outgoing streams to
* the given contact over the given transport.
*/
boolean canSendOutgoingStreams(ContactId c, TransportId t);
/** /**
* Returns a {@link StreamContext} for sending a stream to the given * Returns a {@link StreamContext} for sending a stream to the given
* contact over the given transport, or null if an error occurs or the * contact over the given transport, or null if an error occurs or the

View File

@@ -0,0 +1,51 @@
package org.briarproject.bramble.api.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable;
/**
* A set of transport keys for communicating with a contact. If the keys have
* not yet been bound to a contact, {@link #getContactId()}} returns null.
*/
@Immutable
@NotNullByDefault
public class KeySet {
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
private final TransportKeys transportKeys;
public KeySet(KeySetId keySetId, @Nullable ContactId contactId,
TransportKeys transportKeys) {
this.keySetId = keySetId;
this.contactId = contactId;
this.transportKeys = transportKeys;
}
public KeySetId getKeySetId() {
return keySetId;
}
@Nullable
public ContactId getContactId() {
return contactId;
}
public TransportKeys getTransportKeys() {
return transportKeys;
}
@Override
public int hashCode() {
return keySetId.hashCode();
}
@Override
public boolean equals(Object o) {
return o instanceof KeySet && keySetId.equals(((KeySet) o).keySetId);
}
}

View File

@@ -0,0 +1,36 @@
package org.briarproject.bramble.api.transport;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import javax.annotation.concurrent.Immutable;
/**
* Type-safe wrapper for an integer that uniquely identifies a set of transport
* keys within the scope of the local device.
* <p/>
* Key sets created on a given device must have increasing identifiers.
*/
@Immutable
@NotNullByDefault
public class KeySetId {
private final int id;
public KeySetId(int id) {
this.id = id;
}
public int getInt() {
return id;
}
@Override
public int hashCode() {
return id;
}
@Override
public boolean equals(Object o) {
return o instanceof KeySetId && id == ((KeySetId) o).id;
}
}

View File

@@ -10,18 +10,20 @@ public class OutgoingKeys {
private final SecretKey tagKey, headerKey; private final SecretKey tagKey, headerKey;
private final long rotationPeriod, streamCounter; private final long rotationPeriod, streamCounter;
private final boolean active;
public OutgoingKeys(SecretKey tagKey, SecretKey headerKey, public OutgoingKeys(SecretKey tagKey, SecretKey headerKey,
long rotationPeriod) { long rotationPeriod, boolean active) {
this(tagKey, headerKey, rotationPeriod, 0); this(tagKey, headerKey, rotationPeriod, 0, active);
} }
public OutgoingKeys(SecretKey tagKey, SecretKey headerKey, public OutgoingKeys(SecretKey tagKey, SecretKey headerKey,
long rotationPeriod, long streamCounter) { long rotationPeriod, long streamCounter, boolean active) {
this.tagKey = tagKey; this.tagKey = tagKey;
this.headerKey = headerKey; this.headerKey = headerKey;
this.rotationPeriod = rotationPeriod; this.rotationPeriod = rotationPeriod;
this.streamCounter = streamCounter; this.streamCounter = streamCounter;
this.active = active;
} }
public SecretKey getTagKey() { public SecretKey getTagKey() {
@@ -39,4 +41,8 @@ public class OutgoingKeys {
public long getStreamCounter() { public long getStreamCounter() {
return streamCounter; return streamCounter;
} }
public boolean isActive() {
return active;
}
} }

View File

@@ -50,7 +50,7 @@ class ContactManagerImpl implements ContactManager {
@Override @Override
public ContactId addContact(Transaction txn, Author remote, AuthorId local, public ContactId addContact(Transaction txn, Author remote, AuthorId local,
SecretKey master,long timestamp, boolean alice, boolean verified, SecretKey master, long timestamp, boolean alice, boolean verified,
boolean active) throws DbException { boolean active) throws DbException {
ContactId c = db.addContact(txn, remote, local, verified, active); ContactId c = db.addContact(txn, remote, local, verified, active);
keyManager.addContact(txn, c, master, timestamp, alice); keyManager.addContact(txn, c, master, timestamp, alice);
@@ -60,6 +60,16 @@ class ContactManagerImpl implements ContactManager {
return c; return c;
} }
@Override
public ContactId addContact(Transaction txn, Author remote, AuthorId local,
boolean verified, boolean active) throws DbException {
ContactId c = db.addContact(txn, remote, local, verified, active);
Contact contact = db.getContact(txn, c);
for (AddContactHook hook : addHooks)
hook.addingContact(txn, contact);
return c;
}
@Override @Override
public ContactId addContact(Author remote, AuthorId local, SecretKey master, public ContactId addContact(Author remote, AuthorId local, SecretKey master,
long timestamp, boolean alice, boolean verified, boolean active) long timestamp, boolean alice, boolean verified, boolean active)

View File

@@ -36,7 +36,8 @@ class TransportCryptoImpl implements TransportCrypto {
@Override @Override
public TransportKeys deriveTransportKeys(TransportId t, public TransportKeys deriveTransportKeys(TransportId t,
SecretKey master, long rotationPeriod, boolean alice) { SecretKey master, long rotationPeriod, boolean alice,
boolean active) {
// Keys for the previous period are derived from the master secret // Keys for the previous period are derived from the master secret
SecretKey inTagPrev = deriveTagKey(master, t, !alice); SecretKey inTagPrev = deriveTagKey(master, t, !alice);
SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice); SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice);
@@ -57,7 +58,7 @@ class TransportCryptoImpl implements TransportCrypto {
IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext, IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext,
rotationPeriod + 1); rotationPeriod + 1);
OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr, OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr,
rotationPeriod); rotationPeriod, active);
// Collect and return the keys // Collect and return the keys
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr); return new TransportKeys(t, inPrev, inCurr, inNext, outCurr);
} }
@@ -71,6 +72,7 @@ class TransportCryptoImpl implements TransportCrypto {
IncomingKeys inNext = k.getNextIncomingKeys(); IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getRotationPeriod(); long startPeriod = outCurr.getRotationPeriod();
boolean active = outCurr.isActive();
// Rotate the keys // Rotate the keys
for (long p = startPeriod + 1; p <= rotationPeriod; p++) { for (long p = startPeriod + 1; p <= rotationPeriod; p++) {
inPrev = inCurr; inPrev = inCurr;
@@ -80,7 +82,7 @@ class TransportCryptoImpl implements TransportCrypto {
inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1); inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p); SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p); SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p); outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p, active);
} }
// Collect and return the keys // Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext, return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,

View File

@@ -21,6 +21,8 @@ import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection; import java.util.Collection;
@@ -123,9 +125,16 @@ interface Database<T> {
throws DbException; throws DbException;
/** /**
* Stores transport keys for a newly added contact. * Stores the given transport keys, optionally binding them to the given
* contact, and returns a key set ID.
*/ */
void addTransportKeys(T txn, ContactId c, TransportKeys k) KeySetId addTransportKeys(T txn, @Nullable ContactId c, TransportKeys k)
throws DbException;
/**
* Binds the given keys for the given transport to the given contact.
*/
void bindTransportKeys(T txn, ContactId c, TransportId t, KeySetId k)
throws DbException; throws DbException;
/** /**
@@ -486,15 +495,14 @@ interface Database<T> {
* <p/> * <p/>
* Read-only. * Read-only.
*/ */
Map<ContactId, TransportKeys> getTransportKeys(T txn, TransportId t) Collection<KeySet> getTransportKeys(T txn, TransportId t)
throws DbException; throws DbException;
/** /**
* Increments the outgoing stream counter for the given contact and * Increments the outgoing stream counter for the given transport keys.
* transport in the given rotation period.
*/ */
void incrementStreamCounter(T txn, ContactId c, TransportId t, void incrementStreamCounter(T txn, TransportId t, KeySetId k)
long rotationPeriod) throws DbException; throws DbException;
/** /**
* Marks the given messages as not needing to be acknowledged to the * Marks the given messages as not needing to be acknowledged to the
@@ -584,6 +592,12 @@ interface Database<T> {
*/ */
void removeTransport(T txn, TransportId t) throws DbException; void removeTransport(T txn, TransportId t) throws DbException;
/**
* Removes the given transport keys from the database.
*/
void removeTransportKeys(T txn, TransportId t, KeySetId k)
throws DbException;
/** /**
* Resets the transmission count and expiry time of the given message with * Resets the transmission count and expiry time of the given message with
* respect to the given contact. * respect to the given contact.
@@ -619,12 +633,18 @@ interface Database<T> {
void setMessageState(T txn, MessageId m, State state) throws DbException; void setMessageState(T txn, MessageId m, State state) throws DbException;
/** /**
* Sets the reordering window for the given contact and transport in the * Sets the reordering window for the given key set and transport in the
* given rotation period. * given rotation period.
*/ */
void setReorderingWindow(T txn, ContactId c, TransportId t, void setReorderingWindow(T txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException; long rotationPeriod, long base, byte[] bitmap) throws DbException;
/**
* Marks the given transport keys as usable for outgoing streams.
*/
void setTransportKeysActive(T txn, TransportId t, KeySetId k)
throws DbException;
/** /**
* Updates the transmission count and expiry time of the given message * Updates the transmission count and expiry time of the given message
* with respect to the given contact, using the latency of the transport * with respect to the given contact, using the latency of the transport
@@ -636,6 +656,5 @@ interface Database<T> {
/** /**
* Stores the given transport keys, deleting any keys they have replaced. * Stores the given transport keys, deleting any keys they have replaced.
*/ */
void updateTransportKeys(T txn, Map<ContactId, TransportKeys> keys) void updateTransportKeys(T txn, Collection<KeySet> keys) throws DbException;
throws DbException;
} }

View File

@@ -51,15 +51,15 @@ import org.briarproject.bramble.api.sync.event.MessageToAckEvent;
import org.briarproject.bramble.api.sync.event.MessageToRequestEvent; import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -234,15 +234,27 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
} }
@Override @Override
public void addTransportKeys(Transaction transaction, ContactId c, public KeySetId addTransportKeys(Transaction transaction,
TransportKeys k) throws DbException { @Nullable ContactId c, TransportKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (c != null && !db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException();
return db.addTransportKeys(txn, c, k);
}
@Override
public void bindTransportKeys(Transaction transaction, ContactId c,
TransportId t, KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException(); if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction); T txn = unbox(transaction);
if (!db.containsContact(txn, c)) if (!db.containsContact(txn, c))
throw new NoSuchContactException(); throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId())) if (!db.containsTransport(txn, t))
throw new NoSuchTransportException(); throw new NoSuchTransportException();
db.addTransportKeys(txn, c, k); db.bindTransportKeys(txn, c, t, k);
} }
@Override @Override
@@ -586,8 +598,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
} }
@Override @Override
public Map<ContactId, TransportKeys> getTransportKeys( public Collection<KeySet> getTransportKeys(Transaction transaction,
Transaction transaction, TransportId t) throws DbException { TransportId t) throws DbException {
T txn = unbox(transaction); T txn = unbox(transaction);
if (!db.containsTransport(txn, t)) if (!db.containsTransport(txn, t))
throw new NoSuchTransportException(); throw new NoSuchTransportException();
@@ -595,15 +607,13 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
} }
@Override @Override
public void incrementStreamCounter(Transaction transaction, ContactId c, public void incrementStreamCounter(Transaction transaction, TransportId t,
TransportId t, long rotationPeriod) throws DbException { KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException(); if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction); T txn = unbox(transaction);
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, t)) if (!db.containsTransport(txn, t))
throw new NoSuchTransportException(); throw new NoSuchTransportException();
db.incrementStreamCounter(txn, c, t, rotationPeriod); db.incrementStreamCounter(txn, t, k);
} }
@Override @Override
@@ -779,6 +789,16 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
db.removeTransport(txn, t); db.removeTransport(txn, t);
} }
@Override
public void removeTransportKeys(Transaction transaction,
TransportId t, KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.removeTransportKeys(txn, t, k);
}
@Override @Override
public void setContactVerified(Transaction transaction, ContactId c) public void setContactVerified(Transaction transaction, ContactId c)
throws DbException { throws DbException {
@@ -858,31 +878,35 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
} }
@Override @Override
public void setReorderingWindow(Transaction transaction, ContactId c, public void setReorderingWindow(Transaction transaction, KeySetId k,
TransportId t, long rotationPeriod, long base, byte[] bitmap) TransportId t, long rotationPeriod, long base, byte[] bitmap)
throws DbException { throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException(); if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction); T txn = unbox(transaction);
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, t)) if (!db.containsTransport(txn, t))
throw new NoSuchTransportException(); throw new NoSuchTransportException();
db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap); db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap);
}
@Override
public void setTransportKeysActive(Transaction transaction, TransportId t,
KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.setTransportKeysActive(txn, t, k);
} }
@Override @Override
public void updateTransportKeys(Transaction transaction, public void updateTransportKeys(Transaction transaction,
Map<ContactId, TransportKeys> keys) throws DbException { Collection<KeySet> keys) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException(); if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction); T txn = unbox(transaction);
Map<ContactId, TransportKeys> filtered = new HashMap<>(); Collection<KeySet> filtered = new ArrayList<>();
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { for (KeySet ks : keys) {
ContactId c = e.getKey(); TransportId t = ks.getTransportKeys().getTransportId();
TransportKeys k = e.getValue(); if (db.containsTransport(txn, t)) filtered.add(ks);
if (db.containsContact(txn, c)
&& db.containsTransport(txn, k.getTransportId())) {
filtered.put(c, k);
}
} }
db.updateTransportKeys(txn, filtered); db.updateTransportKeys(txn, filtered);
} }

View File

@@ -25,6 +25,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
@@ -223,37 +225,44 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " maxLatency INT NOT NULL," + " maxLatency INT NOT NULL,"
+ " PRIMARY KEY (transportId))"; + " PRIMARY KEY (transportId))";
private static final String CREATE_OUTGOING_KEYS =
"CREATE TABLE outgoingKeys"
+ " (transportId _STRING NOT NULL,"
+ " keySetId _COUNTER,"
+ " rotationPeriod BIGINT NOT NULL,"
+ " contactId INT," // Null if keys are not bound
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " active BOOLEAN NOT NULL,"
+ " PRIMARY KEY (transportId, keySetId),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE,"
+ " UNIQUE (keySetId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_INCOMING_KEYS = private static final String CREATE_INCOMING_KEYS =
"CREATE TABLE incomingKeys" "CREATE TABLE incomingKeys"
+ " (contactId INT NOT NULL," + " (transportId _STRING NOT NULL,"
+ " transportId _STRING NOT NULL," + " keySetId INT NOT NULL,"
+ " rotationPeriod BIGINT NOT NULL," + " rotationPeriod BIGINT NOT NULL,"
+ " contactId INT," // Null if keys are not bound
+ " tagKey _SECRET NOT NULL," + " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL," + " headerKey _SECRET NOT NULL,"
+ " base BIGINT NOT NULL," + " base BIGINT NOT NULL,"
+ " bitmap _BINARY NOT NULL," + " bitmap _BINARY NOT NULL,"
+ " PRIMARY KEY (contactId, transportId, rotationPeriod)," + " PRIMARY KEY (transportId, keySetId, rotationPeriod),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (transportId)" + " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)" + " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)"; + " ON DELETE CASCADE,"
+ " FOREIGN KEY (keySetId)"
private static final String CREATE_OUTGOING_KEYS = + " REFERENCES outgoingKeys (keySetId)"
"CREATE TABLE outgoingKeys" + " ON DELETE CASCADE,"
+ " (contactId INT NOT NULL,"
+ " transportId _STRING NOT NULL,"
+ " rotationPeriod BIGINT NOT NULL,"
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " PRIMARY KEY (contactId, transportId),"
+ " FOREIGN KEY (contactId)" + " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)" + " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)"; + " ON DELETE CASCADE)";
private static final String INDEX_CONTACTS_BY_AUTHOR_ID = private static final String INDEX_CONTACTS_BY_AUTHOR_ID =
@@ -415,8 +424,8 @@ abstract class JdbcDatabase implements Database<Connection> {
s.executeUpdate(insertTypeNames(CREATE_OFFERS)); s.executeUpdate(insertTypeNames(CREATE_OFFERS));
s.executeUpdate(insertTypeNames(CREATE_STATUSES)); s.executeUpdate(insertTypeNames(CREATE_STATUSES));
s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS)); s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS));
s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS)); s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.close(); s.close();
} catch (SQLException e) { } catch (SQLException e) {
tryToClose(s); tryToClose(s);
@@ -865,61 +874,105 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public void addTransportKeys(Connection txn, ContactId c, TransportKeys k) public KeySetId addTransportKeys(Connection txn, @Nullable ContactId c,
throws DbException { TransportKeys k) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null;
try { try {
// Store the incoming keys // Store the outgoing keys
String sql = "INSERT INTO incomingKeys (contactId, transportId," String sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, base, bitmap)" + " rotationPeriod, tagKey, headerKey, stream, active)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?)"; + " VALUES (?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); if (c == null) ps.setNull(1, INTEGER);
ps.setString(2, k.getTransportId().getString()); else ps.setInt(1, c.getInt());
// Previous rotation period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(3, inPrev.getRotationPeriod());
ps.setBytes(4, inPrev.getTagKey().getBytes());
ps.setBytes(5, inPrev.getHeaderKey().getBytes());
ps.setLong(6, inPrev.getWindowBase());
ps.setBytes(7, inPrev.getWindowBitmap());
ps.addBatch();
// Current rotation period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(3, inCurr.getRotationPeriod());
ps.setBytes(4, inCurr.getTagKey().getBytes());
ps.setBytes(5, inCurr.getHeaderKey().getBytes());
ps.setLong(6, inCurr.getWindowBase());
ps.setBytes(7, inCurr.getWindowBitmap());
ps.addBatch();
// Next rotation period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(3, inNext.getRotationPeriod());
ps.setBytes(4, inNext.getTagKey().getBytes());
ps.setBytes(5, inNext.getHeaderKey().getBytes());
ps.setLong(6, inNext.getWindowBase());
ps.setBytes(7, inNext.getWindowBitmap());
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows != 1) throw new DbStateException();
ps.close();
// Store the outgoing keys
sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, stream)"
+ " VALUES (?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString()); ps.setString(2, k.getTransportId().getString());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(3, outCurr.getRotationPeriod()); ps.setLong(3, outCurr.getRotationPeriod());
ps.setBytes(4, outCurr.getTagKey().getBytes()); ps.setBytes(4, outCurr.getTagKey().getBytes());
ps.setBytes(5, outCurr.getHeaderKey().getBytes()); ps.setBytes(5, outCurr.getHeaderKey().getBytes());
ps.setLong(6, outCurr.getStreamCounter()); ps.setLong(6, outCurr.getStreamCounter());
ps.setBoolean(7, outCurr.isActive());
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException(); if (affected != 1) throw new DbStateException();
ps.close(); ps.close();
// Get the new (highest) key set ID
sql = "SELECT keySetId FROM outgoingKeys"
+ " ORDER BY keySetId DESC LIMIT 1";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
if (!rs.next()) throw new DbStateException();
KeySetId keySetId = new KeySetId(rs.getInt(1));
if (rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Store the incoming keys
sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, base, bitmap)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, keySetId.getInt());
if (c == null) ps.setNull(2, INTEGER);
else ps.setInt(2, c.getInt());
ps.setString(3, k.getTransportId().getString());
// Previous rotation period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(4, inPrev.getRotationPeriod());
ps.setBytes(5, inPrev.getTagKey().getBytes());
ps.setBytes(6, inPrev.getHeaderKey().getBytes());
ps.setLong(7, inPrev.getWindowBase());
ps.setBytes(8, inPrev.getWindowBitmap());
ps.addBatch();
// Current rotation period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(4, inCurr.getRotationPeriod());
ps.setBytes(5, inCurr.getTagKey().getBytes());
ps.setBytes(6, inCurr.getHeaderKey().getBytes());
ps.setLong(7, inCurr.getWindowBase());
ps.setBytes(8, inCurr.getWindowBitmap());
ps.addBatch();
// Next rotation period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(4, inNext.getRotationPeriod());
ps.setBytes(5, inNext.getTagKey().getBytes());
ps.setBytes(6, inNext.getHeaderKey().getBytes());
ps.setLong(7, inNext.getWindowBase());
ps.setBytes(8, inNext.getWindowBitmap());
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows != 1) throw new DbStateException();
ps.close();
return keySetId;
} catch (SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
@Override
public void bindTransportKeys(Connection txn, ContactId c, TransportId t,
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET contactId = ?"
+ " WHERE keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(2, k.getInt());
int affected = ps.executeUpdate();
if (affected < 0) throw new DbStateException();
ps.close();
sql = "UPDATE incomingKeys SET contactId = ?"
+ " WHERE keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(2, k.getInt());
affected = ps.executeUpdate();
if (affected < 0) throw new DbStateException();
ps.close();
} catch (SQLException e) { } catch (SQLException e) {
tryToClose(ps); tryToClose(ps);
throw new DbException(e); throw new DbException(e);
@@ -2078,8 +2131,8 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public Map<ContactId, TransportKeys> getTransportKeys(Connection txn, public Collection<KeySet> getTransportKeys(Connection txn, TransportId t)
TransportId t) throws DbException { throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
@@ -2088,7 +2141,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " base, bitmap" + " base, bitmap"
+ " FROM incomingKeys" + " FROM incomingKeys"
+ " WHERE transportId = ?" + " WHERE transportId = ?"
+ " ORDER BY contactId, rotationPeriod"; + " ORDER BY keySetId, rotationPeriod";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setString(1, t.getString()); ps.setString(1, t.getString());
rs = ps.executeQuery(); rs = ps.executeQuery();
@@ -2105,29 +2158,34 @@ abstract class JdbcDatabase implements Database<Connection> {
rs.close(); rs.close();
ps.close(); ps.close();
// Retrieve the outgoing keys in the same order // Retrieve the outgoing keys in the same order
sql = "SELECT contactId, rotationPeriod, tagKey, headerKey, stream" sql = "SELECT keySetId, contactId, rotationPeriod,"
+ " tagKey, headerKey, stream, active"
+ " FROM outgoingKeys" + " FROM outgoingKeys"
+ " WHERE transportId = ?" + " WHERE transportId = ?"
+ " ORDER BY contactId, rotationPeriod"; + " ORDER BY keySetId";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setString(1, t.getString()); ps.setString(1, t.getString());
rs = ps.executeQuery(); rs = ps.executeQuery();
Map<ContactId, TransportKeys> keys = new HashMap<>(); Collection<KeySet> keys = new ArrayList<>();
for (int i = 0; rs.next(); i++) { for (int i = 0; rs.next(); i++) {
// There should be three times as many incoming keys // There should be three times as many incoming keys
if (inKeys.size() < (i + 1) * 3) throw new DbStateException(); if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
ContactId contactId = new ContactId(rs.getInt(1)); KeySetId keySetId = new KeySetId(rs.getInt(1));
long rotationPeriod = rs.getLong(2); ContactId contactId = new ContactId(rs.getInt(2));
SecretKey tagKey = new SecretKey(rs.getBytes(3)); if (rs.wasNull()) contactId = null;
SecretKey headerKey = new SecretKey(rs.getBytes(4)); long rotationPeriod = rs.getLong(3);
long streamCounter = rs.getLong(5); SecretKey tagKey = new SecretKey(rs.getBytes(4));
SecretKey headerKey = new SecretKey(rs.getBytes(5));
long streamCounter = rs.getLong(6);
boolean active = rs.getBoolean(7);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey, OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
rotationPeriod, streamCounter); rotationPeriod, streamCounter, active);
IncomingKeys inPrev = inKeys.get(i * 3); IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1); IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2); IncomingKeys inNext = inKeys.get(i * 3 + 2);
keys.put(contactId, new TransportKeys(t, inPrev, inCurr, TransportKeys transportKeys = new TransportKeys(t, inPrev,
inNext, outCurr)); inCurr, inNext, outCurr);
keys.add(new KeySet(keySetId, contactId, transportKeys));
} }
rs.close(); rs.close();
ps.close(); ps.close();
@@ -2140,17 +2198,15 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public void incrementStreamCounter(Connection txn, ContactId c, public void incrementStreamCounter(Connection txn, TransportId t,
TransportId t, long rotationPeriod) throws DbException { KeySetId k) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
try { try {
String sql = "UPDATE outgoingKeys SET stream = stream + 1" String sql = "UPDATE outgoingKeys SET stream = stream + 1"
+ " WHERE contactId = ? AND transportId = ?" + " WHERE transportId = ? AND keySetId = ?";
+ " AND rotationPeriod = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setString(1, t.getString());
ps.setString(2, t.getString()); ps.setInt(2, k.getInt());
ps.setLong(3, rotationPeriod);
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException(); if (affected != 1) throw new DbStateException();
ps.close(); ps.close();
@@ -2626,6 +2682,27 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
@Override
public void removeTransportKeys(Connection txn, TransportId t, KeySetId k)
throws DbException {
PreparedStatement ps = null;
try {
// Delete any existing outgoing keys - this will also remove any
// incoming keys with the same key set ID
String sql = "DELETE FROM outgoingKeys"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
ps.setInt(2, k.getInt());
int affected = ps.executeUpdate();
if (affected < 0) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
@Override @Override
public void resetExpiryTime(Connection txn, ContactId c, MessageId m) public void resetExpiryTime(Connection txn, ContactId c, MessageId m)
throws DbException { throws DbException {
@@ -2791,18 +2868,18 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public void setReorderingWindow(Connection txn, ContactId c, TransportId t, public void setReorderingWindow(Connection txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException { long rotationPeriod, long base, byte[] bitmap) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
try { try {
String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?" String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?"
+ " WHERE contactId = ? AND transportId = ?" + " WHERE transportId = ? AND keySetId = ?"
+ " AND rotationPeriod = ?"; + " AND rotationPeriod = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setLong(1, base); ps.setLong(1, base);
ps.setBytes(2, bitmap); ps.setBytes(2, bitmap);
ps.setInt(3, c.getInt()); ps.setString(3, t.getString());
ps.setString(4, t.getString()); ps.setInt(4, k.getInt());
ps.setLong(5, rotationPeriod); ps.setLong(5, rotationPeriod);
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException(); if (affected < 0 || affected > 1) throw new DbStateException();
@@ -2813,6 +2890,23 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
@Override
public void setTransportKeysActive(Connection txn, TransportId t,
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET active = true"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
@Override @Override
public void updateExpiryTime(Connection txn, ContactId c, MessageId m, public void updateExpiryTime(Connection txn, ContactId c, MessageId m,
int maxLatency) throws DbException { int maxLatency) throws DbException {
@@ -2848,45 +2942,12 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public void updateTransportKeys(Connection txn, public void updateTransportKeys(Connection txn, Collection<KeySet> keys)
Map<ContactId, TransportKeys> keys) throws DbException { throws DbException {
PreparedStatement ps = null; for (KeySet ks : keys) {
try { TransportKeys k = ks.getTransportKeys();
// Delete any existing incoming keys removeTransportKeys(txn, k.getTransportId(), ks.getKeySetId());
String sql = "DELETE FROM incomingKeys" addTransportKeys(txn, ks.getContactId(), k);
+ " WHERE contactId = ?"
+ " AND transportId = ?";
ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ps.setInt(1, e.getKey().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch();
}
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
ps.close();
// Delete any existing outgoing keys
sql = "DELETE FROM outgoingKeys"
+ " WHERE contactId = ?"
+ " AND transportId = ?";
ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ps.setInt(1, e.getKey().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch();
}
batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
// Store the new keys
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
addTransportKeys(txn, e.getKey(), e.getValue());
} }
} }
} }

View File

@@ -19,6 +19,7 @@ import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexPluginFactory; import org.briarproject.bramble.api.plugin.duplex.DuplexPluginFactory;
import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory; import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory;
import org.briarproject.bramble.api.transport.KeyManager; import org.briarproject.bramble.api.transport.KeyManager;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import java.util.HashMap; import java.util.HashMap;
@@ -104,6 +105,67 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
m.addContact(txn, c, master, timestamp, alice); m.addContact(txn, c, master, timestamp, alice);
} }
@Override
public Map<TransportId, KeySetId> addUnboundKeys(Transaction txn,
SecretKey master, long timestamp, boolean alice)
throws DbException {
Map<TransportId, KeySetId> ids = new HashMap<>();
for (Entry<TransportId, TransportKeyManager> e : managers.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = e.getValue();
ids.put(t, m.addUnboundKeys(txn, master, timestamp, alice));
}
return ids;
}
@Override
public void bindKeys(Transaction txn, ContactId c,
Map<TransportId, KeySetId> keys) throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
} else {
m.bindKeys(txn, c, e.getValue());
}
}
}
@Override
public void activateKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
} else {
m.activateKeys(txn, e.getValue());
}
}
}
@Override
public void removeKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
} else {
m.removeKeys(txn, e.getValue());
}
}
}
@Override
public boolean canSendOutgoingStreams(ContactId c, TransportId t) {
TransportKeyManager m = managers.get(t);
return m == null ? false : m.canSendOutgoingStreams(c);
}
@Override @Override
public StreamContext getStreamContext(ContactId c, TransportId t) public StreamContext getStreamContext(ContactId c, TransportId t)
throws DbException { throws DbException {
@@ -114,7 +176,7 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t); if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
return null; return null;
} }
StreamContext ctx = null; StreamContext ctx;
Transaction txn = db.startTransaction(false); Transaction txn = db.startTransaction(false);
try { try {
ctx = m.getStreamContext(txn, c); ctx = m.getStreamContext(txn, c);
@@ -133,7 +195,7 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t); if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
return null; return null;
} }
StreamContext ctx = null; StreamContext ctx;
Transaction txn = db.startTransaction(false); Transaction txn = db.startTransaction(false);
try { try {
ctx = m.getStreamContext(txn, tag); ctx = m.getStreamContext(txn, tag);

View File

@@ -0,0 +1,34 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.transport.KeySetId;
import javax.annotation.Nullable;
public class MutableKeySet {
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
private final MutableTransportKeys transportKeys;
public MutableKeySet(KeySetId keySetId, @Nullable ContactId contactId,
MutableTransportKeys transportKeys) {
this.keySetId = keySetId;
this.contactId = contactId;
this.transportKeys = transportKeys;
}
public KeySetId getKeySetId() {
return keySetId;
}
@Nullable
public ContactId getContactId() {
return contactId;
}
public MutableTransportKeys getTransportKeys() {
return transportKeys;
}
}

View File

@@ -13,17 +13,19 @@ class MutableOutgoingKeys {
private final SecretKey tagKey, headerKey; private final SecretKey tagKey, headerKey;
private final long rotationPeriod; private final long rotationPeriod;
private long streamCounter; private long streamCounter;
private boolean active;
MutableOutgoingKeys(OutgoingKeys out) { MutableOutgoingKeys(OutgoingKeys out) {
tagKey = out.getTagKey(); tagKey = out.getTagKey();
headerKey = out.getHeaderKey(); headerKey = out.getHeaderKey();
rotationPeriod = out.getRotationPeriod(); rotationPeriod = out.getRotationPeriod();
streamCounter = out.getStreamCounter(); streamCounter = out.getStreamCounter();
active = out.isActive();
} }
OutgoingKeys snapshot() { OutgoingKeys snapshot() {
return new OutgoingKeys(tagKey, headerKey, rotationPeriod, return new OutgoingKeys(tagKey, headerKey, rotationPeriod,
streamCounter); streamCounter, active);
} }
SecretKey getTagKey() { SecretKey getTagKey() {
@@ -45,4 +47,12 @@ class MutableOutgoingKeys {
void incrementStreamCounter() { void incrementStreamCounter() {
streamCounter++; streamCounter++;
} }
boolean isActive() {
return active;
}
void activate() {
active = true;
}
} }

View File

@@ -5,6 +5,7 @@ import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.db.DbException; import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@@ -17,8 +18,19 @@ interface TransportKeyManager {
void addContact(Transaction txn, ContactId c, SecretKey master, void addContact(Transaction txn, ContactId c, SecretKey master,
long timestamp, boolean alice) throws DbException; long timestamp, boolean alice) throws DbException;
KeySetId addUnboundKeys(Transaction txn, SecretKey master, long timestamp,
boolean alice) throws DbException;
void bindKeys(Transaction txn, ContactId c, KeySetId k) throws DbException;
void activateKeys(Transaction txn, KeySetId k) throws DbException;
void removeKeys(Transaction txn, KeySetId k) throws DbException;
void removeContact(ContactId c); void removeContact(ContactId c);
boolean canSendOutgoingStreams(ContactId c);
@Nullable @Nullable
StreamContext getStreamContext(Transaction txn, ContactId c) StreamContext getStreamContext(Transaction txn, ContactId c)
throws DbException; throws DbException;

View File

@@ -11,19 +11,24 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.system.Scheduler; import org.briarproject.bramble.api.system.Scheduler;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.transport.ReorderingWindow.Change; import org.briarproject.bramble.transport.ReorderingWindow.Change;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe; import javax.annotation.concurrent.ThreadSafe;
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
@@ -47,12 +52,13 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private final Clock clock; private final Clock clock;
private final TransportId transportId; private final TransportId transportId;
private final long rotationPeriodLength; private final long rotationPeriodLength;
private final ReentrantLock lock; private final AtomicBoolean used = new AtomicBoolean(false);
private final ReentrantLock lock = new ReentrantLock();
// The following are locking: lock // The following are locking: lock
private final Map<Bytes, TagContext> inContexts; private final Map<KeySetId, MutableKeySet> keys = new HashMap<>();
private final Map<ContactId, MutableOutgoingKeys> outContexts; private final Map<Bytes, TagContext> inContexts = new HashMap<>();
private final Map<ContactId, MutableTransportKeys> keys; private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>();
TransportKeyManagerImpl(DatabaseComponent db, TransportKeyManagerImpl(DatabaseComponent db,
TransportCrypto transportCrypto, Executor dbExecutor, TransportCrypto transportCrypto, Executor dbExecutor,
@@ -65,20 +71,16 @@ class TransportKeyManagerImpl implements TransportKeyManager {
this.clock = clock; this.clock = clock;
this.transportId = transportId; this.transportId = transportId;
rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
lock = new ReentrantLock();
inContexts = new HashMap<>();
outContexts = new HashMap<>();
keys = new HashMap<>();
} }
@Override @Override
public void start(Transaction txn) throws DbException { public void start(Transaction txn) throws DbException {
if (used.getAndSet(true)) throw new IllegalStateException();
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
lock.lock(); lock.lock();
try { try {
// Load the transport keys from the DB // Load the transport keys from the DB
Map<ContactId, TransportKeys> loaded = Collection<KeySet> loaded = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId);
// Rotate the keys to the current rotation period // Rotate the keys to the current rotation period
RotationResult rotationResult = rotateKeys(loaded, now); RotationResult rotationResult = rotateKeys(loaded, now);
// Initialise mutable state for all contacts // Initialise mutable state for all contacts
@@ -93,41 +95,48 @@ class TransportKeyManagerImpl implements TransportKeyManager {
scheduleKeyRotation(now); scheduleKeyRotation(now);
} }
private RotationResult rotateKeys(Map<ContactId, TransportKeys> keys, private RotationResult rotateKeys(Collection<KeySet> keys, long now) {
long now) {
RotationResult rotationResult = new RotationResult(); RotationResult rotationResult = new RotationResult();
long rotationPeriod = now / rotationPeriodLength; long rotationPeriod = now / rotationPeriodLength;
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { for (KeySet ks : keys) {
ContactId c = e.getKey(); TransportKeys k = ks.getTransportKeys();
TransportKeys k = e.getValue();
TransportKeys k1 = TransportKeys k1 =
transportCrypto.rotateTransportKeys(k, rotationPeriod); transportCrypto.rotateTransportKeys(k, rotationPeriod);
KeySet ks1 = new KeySet(ks.getKeySetId(), ks.getContactId(), k1);
if (k1.getRotationPeriod() > k.getRotationPeriod()) if (k1.getRotationPeriod() > k.getRotationPeriod())
rotationResult.rotated.put(c, k1); rotationResult.rotated.add(ks1);
rotationResult.current.put(c, k1); rotationResult.current.add(ks1);
} }
return rotationResult; return rotationResult;
} }
// Locking: lock // Locking: lock
private void addKeys(Map<ContactId, TransportKeys> m) { private void addKeys(Collection<KeySet> keys) {
for (Entry<ContactId, TransportKeys> e : m.entrySet()) for (KeySet ks : keys) {
addKeys(e.getKey(), new MutableTransportKeys(e.getValue())); addKeys(ks.getKeySetId(), ks.getContactId(),
new MutableTransportKeys(ks.getTransportKeys()));
}
} }
// Locking: lock // Locking: lock
private void addKeys(ContactId c, MutableTransportKeys m) { private void addKeys(KeySetId keySetId, @Nullable ContactId contactId,
encodeTags(c, m.getPreviousIncomingKeys()); MutableTransportKeys m) {
encodeTags(c, m.getCurrentIncomingKeys()); MutableKeySet ks = new MutableKeySet(keySetId, contactId, m);
encodeTags(c, m.getNextIncomingKeys()); keys.put(keySetId, ks);
outContexts.put(c, m.getCurrentOutgoingKeys()); if (contactId != null) {
keys.put(c, m); encodeTags(keySetId, contactId, m.getPreviousIncomingKeys());
encodeTags(keySetId, contactId, m.getCurrentIncomingKeys());
encodeTags(keySetId, contactId, m.getNextIncomingKeys());
considerReplacingOutgoingKeys(ks);
}
} }
// Locking: lock // Locking: lock
private void encodeTags(ContactId c, MutableIncomingKeys inKeys) { private void encodeTags(KeySetId keySetId, ContactId contactId,
MutableIncomingKeys inKeys) {
for (long streamNumber : inKeys.getWindow().getUnseen()) { for (long streamNumber : inKeys.getWindow().getUnseen()) {
TagContext tagCtx = new TagContext(c, inKeys, streamNumber); TagContext tagCtx =
new TagContext(keySetId, contactId, inKeys, streamNumber);
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION, transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
streamNumber); streamNumber);
@@ -135,6 +144,17 @@ class TransportKeyManagerImpl implements TransportKeyManager {
} }
} }
// Locking: lock
private void considerReplacingOutgoingKeys(MutableKeySet ks) {
// Use the active outgoing keys with the highest key set ID
if (ks.getTransportKeys().getCurrentOutgoingKeys().isActive()) {
MutableKeySet old = outContexts.get(ks.getContactId());
if (old == null ||
old.getKeySetId().getInt() < ks.getKeySetId().getInt())
outContexts.put(ks.getContactId(), ks);
}
}
private void scheduleKeyRotation(long now) { private void scheduleKeyRotation(long now) {
long delay = rotationPeriodLength - now % rotationPeriodLength; long delay = rotationPeriodLength - now % rotationPeriodLength;
scheduler.schedule((Runnable) this::rotateKeys, delay, MILLISECONDS); scheduler.schedule((Runnable) this::rotateKeys, delay, MILLISECONDS);
@@ -159,20 +179,82 @@ class TransportKeyManagerImpl implements TransportKeyManager {
@Override @Override
public void addContact(Transaction txn, ContactId c, SecretKey master, public void addContact(Transaction txn, ContactId c, SecretKey master,
long timestamp, boolean alice) throws DbException { long timestamp, boolean alice) throws DbException {
deriveAndAddKeys(txn, c, master, timestamp, alice, true);
}
@Override
public KeySetId addUnboundKeys(Transaction txn, SecretKey master,
long timestamp, boolean alice) throws DbException {
return deriveAndAddKeys(txn, null, master, timestamp, alice, false);
}
private KeySetId deriveAndAddKeys(Transaction txn, @Nullable ContactId c,
SecretKey master, long timestamp, boolean alice, boolean active)
throws DbException {
lock.lock(); lock.lock();
try { try {
// Work out what rotation period the timestamp belongs to // Work out what rotation period the timestamp belongs to
long rotationPeriod = timestamp / rotationPeriodLength; long rotationPeriod = timestamp / rotationPeriodLength;
// Derive the transport keys // Derive the transport keys
TransportKeys k = transportCrypto.deriveTransportKeys(transportId, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
master, rotationPeriod, alice); master, rotationPeriod, alice, active);
// Rotate the keys to the current rotation period if necessary // Rotate the keys to the current rotation period if necessary
rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength; rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength;
k = transportCrypto.rotateTransportKeys(k, rotationPeriod); k = transportCrypto.rotateTransportKeys(k, rotationPeriod);
// Initialise mutable state for the contact
addKeys(c, new MutableTransportKeys(k));
// Write the keys back to the DB // Write the keys back to the DB
db.addTransportKeys(txn, c, k); KeySetId keySetId = db.addTransportKeys(txn, c, k);
// Initialise mutable state for the contact
addKeys(keySetId, c, new MutableTransportKeys(k));
return keySetId;
} finally {
lock.unlock();
}
}
@Override
public void bindKeys(Transaction txn, ContactId c, KeySetId k)
throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys haven't already been bound
if (ks.getContactId() != null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys();
addKeys(k, c, m);
db.bindTransportKeys(txn, c, m.getTransportId(), k);
} finally {
lock.unlock();
}
}
@Override
public void activateKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys have been bound
if (ks.getContactId() == null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys();
m.getCurrentOutgoingKeys().activate();
considerReplacingOutgoingKeys(ks);
db.setTransportKeysActive(txn, m.getTransportId(), k);
} finally {
lock.unlock();
}
}
@Override
public void removeKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.remove(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys haven't been bound
if (ks.getContactId() != null) throw new IllegalArgumentException();
TransportId t = ks.getTransportKeys().getTransportId();
db.removeTransportKeys(txn, t, k);
} finally { } finally {
lock.unlock(); lock.unlock();
} }
@@ -183,12 +265,29 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock(); lock.lock();
try { try {
// Remove mutable state for the contact // Remove mutable state for the contact
Iterator<Entry<Bytes, TagContext>> it = Iterator<TagContext> it = inContexts.values().iterator();
inContexts.entrySet().iterator(); while (it.hasNext()) if (it.next().contactId.equals(c)) it.remove();
while (it.hasNext())
if (it.next().getValue().contactId.equals(c)) it.remove();
outContexts.remove(c); outContexts.remove(c);
keys.remove(c); Iterator<MutableKeySet> it1 = keys.values().iterator();
while (it1.hasNext()) {
ContactId c1 = it1.next().getContactId();
if (c1 != null && c1.equals(c)) it1.remove();
}
} finally {
lock.unlock();
}
}
@Override
public boolean canSendOutgoingStreams(ContactId c) {
lock.lock();
try {
MutableKeySet ks = outContexts.get(c);
if (ks == null) return false;
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError();
return outKeys.getStreamCounter() <= MAX_32_BIT_UNSIGNED;
} finally { } finally {
lock.unlock(); lock.unlock();
} }
@@ -200,8 +299,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock(); lock.lock();
try { try {
// Look up the outgoing keys for the contact // Look up the outgoing keys for the contact
MutableOutgoingKeys outKeys = outContexts.get(c); MutableKeySet ks = outContexts.get(c);
if (outKeys == null) return null; if (ks == null) return null;
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError();
if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null; if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
// Create a stream context // Create a stream context
StreamContext ctx = new StreamContext(c, transportId, StreamContext ctx = new StreamContext(c, transportId,
@@ -209,8 +311,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
outKeys.getStreamCounter()); outKeys.getStreamCounter());
// Increment the stream counter and write it back to the DB // Increment the stream counter and write it back to the DB
outKeys.incrementStreamCounter(); outKeys.incrementStreamCounter();
db.incrementStreamCounter(txn, c, transportId, db.incrementStreamCounter(txn, transportId, ks.getKeySetId());
outKeys.getRotationPeriod());
return ctx; return ctx;
} finally { } finally {
lock.unlock(); lock.unlock();
@@ -238,8 +339,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
byte[] addTag = new byte[TAG_LENGTH]; byte[] addTag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(addTag, inKeys.getTagKey(), transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
PROTOCOL_VERSION, streamNumber); PROTOCOL_VERSION, streamNumber);
inContexts.put(new Bytes(addTag), new TagContext( TagContext tagCtx1 = new TagContext(tagCtx.keySetId,
tagCtx.contactId, inKeys, streamNumber)); tagCtx.contactId, inKeys, streamNumber);
inContexts.put(new Bytes(addTag), tagCtx1);
} }
// Remove tags for any stream numbers removed from the window // Remove tags for any stream numbers removed from the window
for (long streamNumber : change.getRemoved()) { for (long streamNumber : change.getRemoved()) {
@@ -250,9 +352,19 @@ class TransportKeyManagerImpl implements TransportKeyManager {
inContexts.remove(new Bytes(removeTag)); inContexts.remove(new Bytes(removeTag));
} }
// Write the window back to the DB // Write the window back to the DB
db.setReorderingWindow(txn, tagCtx.contactId, transportId, db.setReorderingWindow(txn, tagCtx.keySetId, transportId,
inKeys.getRotationPeriod(), window.getBase(), inKeys.getRotationPeriod(), window.getBase(),
window.getBitmap()); window.getBitmap());
// If the outgoing keys are inactive, activate them
MutableKeySet ks = keys.get(tagCtx.keySetId);
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) {
LOG.info("Activating outgoing keys");
outKeys.activate();
considerReplacingOutgoingKeys(ks);
db.setTransportKeysActive(txn, transportId, tagCtx.keySetId);
}
return ctx; return ctx;
} finally { } finally {
lock.unlock(); lock.unlock();
@@ -264,9 +376,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock(); lock.lock();
try { try {
// Rotate the keys to the current rotation period // Rotate the keys to the current rotation period
Map<ContactId, TransportKeys> snapshot = new HashMap<>(); Collection<KeySet> snapshot = new ArrayList<>(keys.size());
for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet()) for (MutableKeySet ks : keys.values()) {
snapshot.put(e.getKey(), e.getValue().snapshot()); snapshot.add(new KeySet(ks.getKeySetId(), ks.getContactId(),
ks.getTransportKeys().snapshot()));
}
RotationResult rotationResult = rotateKeys(snapshot, now); RotationResult rotationResult = rotateKeys(snapshot, now);
// Rebuild the mutable state for all contacts // Rebuild the mutable state for all contacts
inContexts.clear(); inContexts.clear();
@@ -285,12 +399,14 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class TagContext { private static class TagContext {
private final KeySetId keySetId;
private final ContactId contactId; private final ContactId contactId;
private final MutableIncomingKeys inKeys; private final MutableIncomingKeys inKeys;
private final long streamNumber; private final long streamNumber;
private TagContext(ContactId contactId, MutableIncomingKeys inKeys, private TagContext(KeySetId keySetId, ContactId contactId,
long streamNumber) { MutableIncomingKeys inKeys, long streamNumber) {
this.keySetId = keySetId;
this.contactId = contactId; this.contactId = contactId;
this.inKeys = inKeys; this.inKeys = inKeys;
this.streamNumber = streamNumber; this.streamNumber = streamNumber;
@@ -299,11 +415,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class RotationResult { private static class RotationResult {
private final Map<ContactId, TransportKeys> current, rotated; private final Collection<KeySet> current = new ArrayList<>();
private final Collection<KeySet> rotated = new ArrayList<>();
private RotationResult() {
current = new HashMap<>();
rotated = new HashMap<>();
}
} }
} }

View File

@@ -33,7 +33,7 @@ public class KeyDerivationTest extends BrambleTestCase {
@Test @Test
public void testKeysAreDistinct() { public void testKeysAreDistinct() {
TransportKeys k = transportCrypto.deriveTransportKeys(transportId, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
master, 123, true); master, 123, true, true);
assertAllDifferent(k); assertAllDifferent(k);
} }
@@ -41,9 +41,9 @@ public class KeyDerivationTest extends BrambleTestCase {
public void testCurrentKeysMatchCurrentKeysOfContact() { public void testCurrentKeysMatchCurrentKeysOfContact() {
// Start in rotation period 123 // Start in rotation period 123
TransportKeys kA = transportCrypto.deriveTransportKeys(transportId, TransportKeys kA = transportCrypto.deriveTransportKeys(transportId,
master, 123, true); master, 123, true, true);
TransportKeys kB = transportCrypto.deriveTransportKeys(transportId, TransportKeys kB = transportCrypto.deriveTransportKeys(transportId,
master, 123, false); master, 123, false, true);
// Alice's incoming keys should equal Bob's outgoing keys // Alice's incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(), assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes()); kB.getCurrentOutgoingKeys().getTagKey().getBytes());
@@ -73,9 +73,9 @@ public class KeyDerivationTest extends BrambleTestCase {
public void testPreviousKeysMatchPreviousKeysOfContact() { public void testPreviousKeysMatchPreviousKeysOfContact() {
// Start in rotation period 123 // Start in rotation period 123
TransportKeys kA = transportCrypto.deriveTransportKeys(transportId, TransportKeys kA = transportCrypto.deriveTransportKeys(transportId,
master, 123, true); master, 123, true, true);
TransportKeys kB = transportCrypto.deriveTransportKeys(transportId, TransportKeys kB = transportCrypto.deriveTransportKeys(transportId,
master, 123, false); master, 123, false, true);
// Compare Alice's previous keys in period 456 with Bob's current keys // Compare Alice's previous keys in period 456 with Bob's current keys
// in period 455 // in period 455
kA = transportCrypto.rotateTransportKeys(kA, 456); kA = transportCrypto.rotateTransportKeys(kA, 456);
@@ -100,9 +100,9 @@ public class KeyDerivationTest extends BrambleTestCase {
public void testNextKeysMatchNextKeysOfContact() { public void testNextKeysMatchNextKeysOfContact() {
// Start in rotation period 123 // Start in rotation period 123
TransportKeys kA = transportCrypto.deriveTransportKeys(transportId, TransportKeys kA = transportCrypto.deriveTransportKeys(transportId,
master, 123, true); master, 123, true, true);
TransportKeys kB = transportCrypto.deriveTransportKeys(transportId, TransportKeys kB = transportCrypto.deriveTransportKeys(transportId,
master, 123, false); master, 123, false, true);
// Compare Alice's current keys in period 456 with Bob's next keys in // Compare Alice's current keys in period 456 with Bob's next keys in
// period 455 // period 455
kA = transportCrypto.rotateTransportKeys(kA, 456); kA = transportCrypto.rotateTransportKeys(kA, 456);
@@ -127,9 +127,9 @@ public class KeyDerivationTest extends BrambleTestCase {
SecretKey master1 = getSecretKey(); SecretKey master1 = getSecretKey();
assertFalse(Arrays.equals(master.getBytes(), master1.getBytes())); assertFalse(Arrays.equals(master.getBytes(), master1.getBytes()));
TransportKeys k = transportCrypto.deriveTransportKeys(transportId, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
master, 123, true); master, 123, true, true);
TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId, TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId,
master1, 123, true); master1, 123, true, true);
assertAllDifferent(k, k1); assertAllDifferent(k, k1);
} }
@@ -138,9 +138,9 @@ public class KeyDerivationTest extends BrambleTestCase {
TransportId transportId1 = new TransportId("id1"); TransportId transportId1 = new TransportId("id1");
assertFalse(transportId.getString().equals(transportId1.getString())); assertFalse(transportId.getString().equals(transportId1.getString()));
TransportKeys k = transportCrypto.deriveTransportKeys(transportId, TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
master, 123, true); master, 123, true, true);
TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId1, TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId1,
master, 123, true); master, 123, true, true);
assertAllDifferent(k, k1); assertAllDifferent(k, k1);
} }

View File

@@ -44,6 +44,8 @@ import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.test.BrambleMockTestCase; import org.briarproject.bramble.test.BrambleMockTestCase;
@@ -55,12 +57,10 @@ import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED; import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED;
import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE;
@@ -71,6 +71,7 @@ import static org.briarproject.bramble.api.transport.TransportConstants.REORDERI
import static org.briarproject.bramble.db.DatabaseConstants.MAX_OFFERED_MESSAGES; import static org.briarproject.bramble.db.DatabaseConstants.MAX_OFFERED_MESSAGES;
import static org.briarproject.bramble.test.TestUtils.getAuthor; import static org.briarproject.bramble.test.TestUtils.getAuthor;
import static org.briarproject.bramble.test.TestUtils.getLocalAuthor; import static org.briarproject.bramble.test.TestUtils.getLocalAuthor;
import static org.briarproject.bramble.test.TestUtils.getSecretKey;
import static org.briarproject.bramble.util.StringUtils.getRandomString; import static org.briarproject.bramble.util.StringUtils.getRandomString;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@@ -100,6 +101,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
private final int maxLatency; private final int maxLatency;
private final ContactId contactId; private final ContactId contactId;
private final Contact contact; private final Contact contact;
private final KeySetId keySetId;
public DatabaseComponentImplTest() { public DatabaseComponentImplTest() {
clientId = new ClientId(getRandomString(123)); clientId = new ClientId(getRandomString(123));
@@ -121,6 +123,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
contactId = new ContactId(234); contactId = new ContactId(234);
contact = new Contact(contactId, author, localAuthor.getId(), contact = new Contact(contactId, author, localAuthor.getId(),
true, true); true, true);
keySetId = new KeySetId(345);
} }
private DatabaseComponent createDatabaseComponent(Database<Object> database, private DatabaseComponent createDatabaseComponent(Database<Object> database,
@@ -282,11 +285,11 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
throws Exception { throws Exception {
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// Check whether the contact is in the DB (which it's not) // Check whether the contact is in the DB (which it's not)
exactly(18).of(database).startTransaction(); exactly(17).of(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
exactly(18).of(database).containsContact(txn, contactId); exactly(17).of(database).containsContact(txn, contactId);
will(returnValue(false)); will(returnValue(false));
exactly(18).of(database).abortTransaction(txn); exactly(17).of(database).abortTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, eventBus, DatabaseComponent db = createDatabaseComponent(database, eventBus,
shutdown); shutdown);
@@ -301,6 +304,16 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
db.endTransaction(transaction); db.endTransaction(transaction);
} }
transaction = db.startTransaction(false);
try {
db.bindTransportKeys(transaction, contactId, transportId, keySetId);
fail();
} catch (NoSuchContactException expected) {
// Expected
} finally {
db.endTransaction(transaction);
}
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.generateAck(transaction, contactId, 123); db.generateAck(transaction, contactId, 123);
@@ -371,16 +384,6 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
db.endTransaction(transaction); db.endTransaction(transaction);
} }
transaction = db.startTransaction(false);
try {
db.incrementStreamCounter(transaction, contactId, transportId, 0);
fail();
} catch (NoSuchContactException expected) {
// Expected
} finally {
db.endTransaction(transaction);
}
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.getGroupVisibility(transaction, contactId, groupId); db.getGroupVisibility(transaction, contactId, groupId);
@@ -454,17 +457,6 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
db.endTransaction(transaction); db.endTransaction(transaction);
} }
transaction = db.startTransaction(false);
try {
db.setReorderingWindow(transaction, contactId, transportId, 0, 0,
new byte[REORDERING_WINDOW_SIZE / 8]);
fail();
} catch (NoSuchContactException expected) {
// Expected
} finally {
db.endTransaction(transaction);
}
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.setGroupVisibility(transaction, contactId, groupId, SHARED); db.setGroupVisibility(transaction, contactId, groupId, SHARED);
@@ -777,13 +769,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
// endTransaction() // endTransaction()
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
// Check whether the transport is in the DB (which it's not) // Check whether the transport is in the DB (which it's not)
exactly(4).of(database).startTransaction(); exactly(6).of(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
exactly(2).of(database).containsContact(txn, contactId); oneOf(database).containsContact(txn, contactId);
will(returnValue(true)); will(returnValue(true));
exactly(4).of(database).containsTransport(txn, transportId); exactly(6).of(database).containsTransport(txn, transportId);
will(returnValue(false)); will(returnValue(false));
exactly(4).of(database).abortTransaction(txn); exactly(6).of(database).abortTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, eventBus, DatabaseComponent db = createDatabaseComponent(database, eventBus,
shutdown); shutdown);
@@ -798,6 +790,16 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
db.endTransaction(transaction); db.endTransaction(transaction);
} }
transaction = db.startTransaction(false);
try {
db.bindTransportKeys(transaction, contactId, transportId, keySetId);
fail();
} catch (NoSuchTransportException expected) {
// Expected
} finally {
db.endTransaction(transaction);
}
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.getTransportKeys(transaction, transportId); db.getTransportKeys(transaction, transportId);
@@ -810,7 +812,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.incrementStreamCounter(transaction, contactId, transportId, 0); db.incrementStreamCounter(transaction, transportId, keySetId);
fail(); fail();
} catch (NoSuchTransportException expected) { } catch (NoSuchTransportException expected) {
// Expected // Expected
@@ -830,7 +832,17 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.setReorderingWindow(transaction, contactId, transportId, 0, 0, db.removeTransportKeys(transaction, transportId, keySetId);
fail();
} catch (NoSuchTransportException expected) {
// Expected
} finally {
db.endTransaction(transaction);
}
transaction = db.startTransaction(false);
try {
db.setReorderingWindow(transaction, keySetId, transportId, 0, 0,
new byte[REORDERING_WINDOW_SIZE / 8]); new byte[REORDERING_WINDOW_SIZE / 8]);
fail(); fail();
} catch (NoSuchTransportException expected) { } catch (NoSuchTransportException expected) {
@@ -1303,15 +1315,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
@Test @Test
public void testTransportKeys() throws Exception { public void testTransportKeys() throws Exception {
TransportKeys transportKeys = createTransportKeys(); TransportKeys transportKeys = createTransportKeys();
Map<ContactId, TransportKeys> keys = Collection<KeySet> keys =
singletonMap(contactId, transportKeys); singletonList(new KeySet(keySetId, contactId, transportKeys));
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// startTransaction() // startTransaction()
oneOf(database).startTransaction(); oneOf(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
// updateTransportKeys() // updateTransportKeys()
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).containsTransport(txn, transportId); oneOf(database).containsTransport(txn, transportId);
will(returnValue(true)); will(returnValue(true));
oneOf(database).updateTransportKeys(txn, keys); oneOf(database).updateTransportKeys(txn, keys);
@@ -1337,22 +1347,22 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
} }
private TransportKeys createTransportKeys() { private TransportKeys createTransportKeys() {
SecretKey inPrevTagKey = TestUtils.getSecretKey(); SecretKey inPrevTagKey = getSecretKey();
SecretKey inPrevHeaderKey = TestUtils.getSecretKey(); SecretKey inPrevHeaderKey = getSecretKey();
IncomingKeys inPrev = new IncomingKeys(inPrevTagKey, inPrevHeaderKey, IncomingKeys inPrev = new IncomingKeys(inPrevTagKey, inPrevHeaderKey,
1, 123, new byte[4]); 1, 123, new byte[4]);
SecretKey inCurrTagKey = TestUtils.getSecretKey(); SecretKey inCurrTagKey = getSecretKey();
SecretKey inCurrHeaderKey = TestUtils.getSecretKey(); SecretKey inCurrHeaderKey = getSecretKey();
IncomingKeys inCurr = new IncomingKeys(inCurrTagKey, inCurrHeaderKey, IncomingKeys inCurr = new IncomingKeys(inCurrTagKey, inCurrHeaderKey,
2, 234, new byte[4]); 2, 234, new byte[4]);
SecretKey inNextTagKey = TestUtils.getSecretKey(); SecretKey inNextTagKey = getSecretKey();
SecretKey inNextHeaderKey = TestUtils.getSecretKey(); SecretKey inNextHeaderKey = getSecretKey();
IncomingKeys inNext = new IncomingKeys(inNextTagKey, inNextHeaderKey, IncomingKeys inNext = new IncomingKeys(inNextTagKey, inNextHeaderKey,
3, 345, new byte[4]); 3, 345, new byte[4]);
SecretKey outCurrTagKey = TestUtils.getSecretKey(); SecretKey outCurrTagKey = getSecretKey();
SecretKey outCurrHeaderKey = TestUtils.getSecretKey(); SecretKey outCurrHeaderKey = getSecretKey();
OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey, OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey,
2, 456); 2, 456, true);
return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr); return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr);
} }

View File

@@ -19,6 +19,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.system.SystemClock; import org.briarproject.bramble.system.SystemClock;
@@ -34,7 +36,6 @@ import java.sql.Connection;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
@@ -42,6 +43,10 @@ import java.util.Random;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.concurrent.TimeUnit.SECONDS;
import static org.briarproject.bramble.api.db.Metadata.REMOVE; import static org.briarproject.bramble.api.db.Metadata.REMOVE;
import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
@@ -86,6 +91,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
private final Message message; private final Message message;
private final TransportId transportId; private final TransportId transportId;
private final ContactId contactId; private final ContactId contactId;
private final KeySetId keySetId, keySetId1;
JdbcDatabaseTest() throws Exception { JdbcDatabaseTest() throws Exception {
groupId = new GroupId(getRandomId()); groupId = new GroupId(getRandomId());
@@ -101,6 +107,8 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
message = new Message(messageId, groupId, timestamp, raw); message = new Message(messageId, groupId, timestamp, raw);
transportId = new TransportId("id"); transportId = new TransportId("id");
contactId = new ContactId(1); contactId = new ContactId(1);
keySetId = new KeySetId(1);
keySetId1 = new KeySetId(2);
} }
protected abstract JdbcDatabase createDatabase(DatabaseConfig config, protected abstract JdbcDatabase createDatabase(DatabaseConfig config,
@@ -190,9 +198,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The contact has not seen the message, so it should be sendable // The contact has not seen the message, so it should be sendable
Collection<MessageId> ids = Collection<MessageId> ids =
db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// Changing the status to seen = true should make the message unsendable // Changing the status to seen = true should make the message unsendable
db.raiseSeenFlag(txn, contactId, messageId); db.raiseSeenFlag(txn, contactId, messageId);
@@ -228,9 +236,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Marking the message delivered should make it sendable // Marking the message delivered should make it sendable
db.setMessageState(txn, messageId, DELIVERED); db.setMessageState(txn, messageId, DELIVERED);
ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// Marking the message invalid should make it unsendable // Marking the message invalid should make it unsendable
db.setMessageState(txn, messageId, INVALID); db.setMessageState(txn, messageId, INVALID);
@@ -279,9 +287,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Sharing the group should make the message sendable // Sharing the group should make the message sendable
db.setGroupVisibility(txn, contactId, groupId, true); db.setGroupVisibility(txn, contactId, groupId, true);
ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// Unsharing the group should make the message unsendable // Unsharing the group should make the message unsendable
db.setGroupVisibility(txn, contactId, groupId, false); db.setGroupVisibility(txn, contactId, groupId, false);
@@ -324,9 +332,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Sharing the message should make it sendable // Sharing the message should make it sendable
db.setMessageShared(txn, messageId); db.setMessageShared(txn, messageId);
ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -352,7 +360,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The message is just the right size to send // The message is just the right size to send
ids = db.getMessagesToSend(txn, contactId, size); ids = db.getMessagesToSend(txn, contactId, size);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -384,7 +392,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
db.lowerAckFlag(txn, contactId, Arrays.asList(messageId, messageId1)); db.lowerAckFlag(txn, contactId, Arrays.asList(messageId, messageId1));
// Both message IDs should have been removed // Both message IDs should have been removed
assertEquals(Collections.emptyList(), db.getMessagesToAck(txn, assertEquals(emptyList(), db.getMessagesToAck(txn,
contactId, 1234)); contactId, 1234));
// Raise the ack flag again // Raise the ack flag again
@@ -415,7 +423,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Retrieve the message from the database and mark it as sent // Retrieve the message from the database and mark it as sent
Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, Collection<MessageId> ids = db.getMessagesToSend(txn, contactId,
ONE_MEGABYTE); ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
db.updateExpiryTime(txn, contactId, messageId, Integer.MAX_VALUE); db.updateExpiryTime(txn, contactId, messageId, Integer.MAX_VALUE);
// The message should no longer be sendable // The message should no longer be sendable
@@ -626,31 +634,31 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The group should not be visible to the contact // The group should not be visible to the contact
assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.emptyMap(), assertEquals(emptyMap(),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Make the group visible to the contact // Make the group visible to the contact
db.addGroupVisibility(txn, contactId, groupId, false); db.addGroupVisibility(txn, contactId, groupId, false);
assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.singletonMap(contactId, false), assertEquals(singletonMap(contactId, false),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Share the group with the contact // Share the group with the contact
db.setGroupVisibility(txn, contactId, groupId, true); db.setGroupVisibility(txn, contactId, groupId, true);
assertEquals(SHARED, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(SHARED, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.singletonMap(contactId, true), assertEquals(singletonMap(contactId, true),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Unshare the group with the contact // Unshare the group with the contact
db.setGroupVisibility(txn, contactId, groupId, false); db.setGroupVisibility(txn, contactId, groupId, false);
assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.singletonMap(contactId, false), assertEquals(singletonMap(contactId, false),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Make the group invisible again // Make the group invisible again
db.removeGroupVisibility(txn, contactId, groupId); db.removeGroupVisibility(txn, contactId, groupId);
assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.emptyMap(), assertEquals(emptyMap(),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -660,48 +668,125 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
@Test @Test
public void testTransportKeys() throws Exception { public void testTransportKeys() throws Exception {
TransportKeys keys = createTransportKeys(); TransportKeys keys = createTransportKeys();
TransportKeys keys1 = createTransportKeys();
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Initially there should be no transport keys in the database // Initially there should be no transport keys in the database
assertEquals(Collections.emptyMap(), assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
db.getTransportKeys(txn, transportId));
// Add the contact, the transport and the transport keys // Add the contact, the transport and the transport keys
db.addLocalAuthor(txn, localAuthor); db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
db.addTransport(txn, transportId, 123); db.addTransport(txn, transportId, 123);
db.addTransportKeys(txn, contactId, keys); assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys));
assertEquals(keySetId1, db.addTransportKeys(txn, contactId, keys1));
// Retrieve the transport keys // Retrieve the transport keys
Map<ContactId, TransportKeys> newKeys = Collection<KeySet> allKeys = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId); assertEquals(2, allKeys.size());
assertEquals(1, newKeys.size()); for (KeySet ks : allKeys) {
Entry<ContactId, TransportKeys> e = assertEquals(contactId, ks.getContactId());
newKeys.entrySet().iterator().next(); if (ks.getKeySetId().equals(keySetId)) {
assertEquals(contactId, e.getKey()); assertKeysEquals(keys, ks.getTransportKeys());
TransportKeys k = e.getValue(); } else {
assertEquals(transportId, k.getTransportId()); assertEquals(keySetId1, ks.getKeySetId());
assertKeysEquals(keys.getPreviousIncomingKeys(), assertKeysEquals(keys1, ks.getTransportKeys());
k.getPreviousIncomingKeys()); }
assertKeysEquals(keys.getCurrentIncomingKeys(), }
k.getCurrentIncomingKeys());
assertKeysEquals(keys.getNextIncomingKeys(),
k.getNextIncomingKeys());
assertKeysEquals(keys.getCurrentOutgoingKeys(),
k.getCurrentOutgoingKeys());
// Removing the contact should remove the transport keys // Removing the contact should remove the transport keys
db.removeContact(txn, contactId); db.removeContact(txn, contactId);
assertEquals(Collections.emptyMap(), assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
db.getTransportKeys(txn, transportId));
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
} }
@Test
public void testUnboundTransportKeys() throws Exception {
TransportKeys keys = createTransportKeys();
TransportKeys keys1 = createTransportKeys();
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Initially there should be no transport keys in the database
assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
// Add the contact, the transport and the unbound transport keys
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true));
db.addTransport(txn, transportId, 123);
assertEquals(keySetId, db.addTransportKeys(txn, null, keys));
assertEquals(keySetId1, db.addTransportKeys(txn, null, keys1));
// Retrieve the transport keys
Collection<KeySet> allKeys = db.getTransportKeys(txn, transportId);
assertEquals(2, allKeys.size());
for (KeySet ks : allKeys) {
assertNull(ks.getContactId());
if (ks.getKeySetId().equals(keySetId)) {
assertKeysEquals(keys, ks.getTransportKeys());
} else {
assertEquals(keySetId1, ks.getKeySetId());
assertKeysEquals(keys1, ks.getTransportKeys());
}
}
// Bind the first set of transport keys
db.bindTransportKeys(txn, contactId, transportId, keySetId);
// Retrieve the keys again - the first set should be bound
allKeys = db.getTransportKeys(txn, transportId);
assertEquals(2, allKeys.size());
for (KeySet ks : allKeys) {
if (ks.getKeySetId().equals(keySetId)) {
assertEquals(contactId, ks.getContactId());
assertKeysEquals(keys, ks.getTransportKeys());
} else {
assertEquals(keySetId1, ks.getKeySetId());
assertNull(ks.getContactId());
assertKeysEquals(keys1, ks.getTransportKeys());
}
}
// Remove the unbound transport keys
db.removeTransportKeys(txn, transportId, keySetId1);
// Retrieve the keys again - the second set should be gone
allKeys = db.getTransportKeys(txn, transportId);
assertEquals(1, allKeys.size());
KeySet ks = allKeys.iterator().next();
assertEquals(keySetId, ks.getKeySetId());
assertEquals(contactId, ks.getContactId());
assertKeysEquals(keys, ks.getTransportKeys());
// Removing the transport should remove the remaining transport keys
db.removeTransport(txn, transportId);
assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
db.commitTransaction(txn);
db.close();
}
private void assertKeysEquals(TransportKeys expected,
TransportKeys actual) {
assertEquals(expected.getTransportId(), actual.getTransportId());
assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod());
assertKeysEquals(expected.getPreviousIncomingKeys(),
actual.getPreviousIncomingKeys());
assertKeysEquals(expected.getCurrentIncomingKeys(),
actual.getCurrentIncomingKeys());
assertKeysEquals(expected.getNextIncomingKeys(),
actual.getNextIncomingKeys());
assertKeysEquals(expected.getCurrentOutgoingKeys(),
actual.getCurrentOutgoingKeys());
}
private void assertKeysEquals(IncomingKeys expected, IncomingKeys actual) { private void assertKeysEquals(IncomingKeys expected, IncomingKeys actual) {
assertArrayEquals(expected.getTagKey().getBytes(), assertArrayEquals(expected.getTagKey().getBytes(),
actual.getTagKey().getBytes()); actual.getTagKey().getBytes());
@@ -719,6 +804,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
actual.getHeaderKey().getBytes()); actual.getHeaderKey().getBytes());
assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod()); assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod());
assertEquals(expected.getStreamCounter(), actual.getStreamCounter()); assertEquals(expected.getStreamCounter(), actual.getStreamCounter());
assertEquals(expected.isActive(), actual.isActive());
} }
@Test @Test
@@ -735,18 +821,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
db.addTransport(txn, transportId, 123); db.addTransport(txn, transportId, 123);
db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); db.updateTransportKeys(txn,
singletonList(new KeySet(keySetId, contactId, keys)));
// Increment the stream counter twice and retrieve the transport keys // Increment the stream counter twice and retrieve the transport keys
db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); db.incrementStreamCounter(txn, transportId, keySetId);
db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); db.incrementStreamCounter(txn, transportId, keySetId);
Map<ContactId, TransportKeys> newKeys = Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size()); assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e = KeySet ks = newKeys.iterator().next();
newKeys.entrySet().iterator().next(); assertEquals(keySetId, ks.getKeySetId());
assertEquals(contactId, e.getKey()); assertEquals(contactId, ks.getContactId());
TransportKeys k = e.getValue(); TransportKeys k = ks.getTransportKeys();
assertEquals(transportId, k.getTransportId()); assertEquals(transportId, k.getTransportId());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
assertEquals(rotationPeriod, outCurr.getRotationPeriod()); assertEquals(rotationPeriod, outCurr.getRotationPeriod());
@@ -771,19 +857,19 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
db.addTransport(txn, transportId, 123); db.addTransport(txn, transportId, 123);
db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); db.updateTransportKeys(txn,
singletonList(new KeySet(keySetId, contactId, keys)));
// Update the reordering window and retrieve the transport keys // Update the reordering window and retrieve the transport keys
new Random().nextBytes(bitmap); new Random().nextBytes(bitmap);
db.setReorderingWindow(txn, contactId, transportId, rotationPeriod, db.setReorderingWindow(txn, keySetId, transportId, rotationPeriod,
base + 1, bitmap); base + 1, bitmap);
Map<ContactId, TransportKeys> newKeys = Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size()); assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e = KeySet ks = newKeys.iterator().next();
newKeys.entrySet().iterator().next(); assertEquals(keySetId, ks.getKeySetId());
assertEquals(contactId, e.getKey()); assertEquals(contactId, ks.getContactId());
TransportKeys k = e.getValue(); TransportKeys k = ks.getTransportKeys();
assertEquals(transportId, k.getTransportId()); assertEquals(transportId, k.getTransportId());
IncomingKeys inCurr = k.getCurrentIncomingKeys(); IncomingKeys inCurr = k.getCurrentIncomingKeys();
assertEquals(rotationPeriod, inCurr.getRotationPeriod()); assertEquals(rotationPeriod, inCurr.getRotationPeriod());
@@ -830,18 +916,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
db.addLocalAuthor(txn, localAuthor); db.addLocalAuthor(txn, localAuthor);
Collection<ContactId> contacts = Collection<ContactId> contacts =
db.getContacts(txn, localAuthor.getId()); db.getContacts(txn, localAuthor.getId());
assertEquals(Collections.emptyList(), contacts); assertEquals(emptyList(), contacts);
// Add a contact associated with the local author // Add a contact associated with the local author
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
contacts = db.getContacts(txn, localAuthor.getId()); contacts = db.getContacts(txn, localAuthor.getId());
assertEquals(Collections.singletonList(contactId), contacts); assertEquals(singletonList(contactId), contacts);
// Remove the local author - the contact should be removed // Remove the local author - the contact should be removed
db.removeLocalAuthor(txn, localAuthor.getId()); db.removeLocalAuthor(txn, localAuthor.getId());
contacts = db.getContacts(txn, localAuthor.getId()); contacts = db.getContacts(txn, localAuthor.getId());
assertEquals(Collections.emptyList(), contacts); assertEquals(emptyList(), contacts);
assertFalse(db.containsContact(txn, contactId)); assertFalse(db.containsContact(txn, contactId));
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -1560,9 +1646,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The message should be sendable // The message should be sendable
Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, Collection<MessageId> ids = db.getMessagesToSend(txn, contactId,
ONE_MEGABYTE); ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// The raw message should not be null // The raw message should not be null
assertNotNull(db.getRawMessage(txn, messageId)); assertNotNull(db.getRawMessage(txn, messageId));
@@ -1735,7 +1821,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
SecretKey outCurrTagKey = getSecretKey(); SecretKey outCurrTagKey = getSecretKey();
SecretKey outCurrHeaderKey = getSecretKey(); SecretKey outCurrHeaderKey = getSecretKey();
OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey, OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey,
2, 456); 2, 456, true);
return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr); return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr);
} }

View File

@@ -12,19 +12,20 @@ import org.briarproject.bramble.api.identity.AuthorId;
import org.briarproject.bramble.api.plugin.PluginConfig; import org.briarproject.bramble.api.plugin.PluginConfig;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory; import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.test.BrambleTestCase; import org.briarproject.bramble.test.BrambleMockTestCase;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.Mockery;
import org.jmock.lib.concurrent.DeterministicExecutor; import org.jmock.lib.concurrent.DeterministicExecutor;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.Random; import java.util.Random;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH; import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.bramble.test.TestUtils.getAuthor; import static org.briarproject.bramble.test.TestUtils.getAuthor;
import static org.briarproject.bramble.test.TestUtils.getRandomBytes; import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
@@ -32,31 +33,29 @@ import static org.briarproject.bramble.test.TestUtils.getRandomId;
import static org.briarproject.bramble.test.TestUtils.getSecretKey; import static org.briarproject.bramble.test.TestUtils.getSecretKey;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
public class KeyManagerImplTest extends BrambleTestCase { public class KeyManagerImplTest extends BrambleMockTestCase {
private final Mockery context = new Mockery();
private final KeyManagerImpl keyManager;
private final DatabaseComponent db = context.mock(DatabaseComponent.class); private final DatabaseComponent db = context.mock(DatabaseComponent.class);
private final PluginConfig pluginConfig = context.mock(PluginConfig.class); private final PluginConfig pluginConfig = context.mock(PluginConfig.class);
private final TransportKeyManagerFactory transportKeyManagerFactory = private final TransportKeyManagerFactory transportKeyManagerFactory =
context.mock(TransportKeyManagerFactory.class); context.mock(TransportKeyManagerFactory.class);
private final TransportKeyManager transportKeyManager = private final TransportKeyManager transportKeyManager =
context.mock(TransportKeyManager.class); context.mock(TransportKeyManager.class);
private final DeterministicExecutor executor = new DeterministicExecutor(); private final DeterministicExecutor executor = new DeterministicExecutor();
private final Transaction txn = new Transaction(null, false); private final Transaction txn = new Transaction(null, false);
private final ContactId contactId = new ContactId(42); private final ContactId contactId = new ContactId(123);
private final ContactId inactiveContactId = new ContactId(43); private final ContactId inactiveContactId = new ContactId(234);
private final TransportId transportId = new TransportId("tId"); private final KeySetId keySetId = new KeySetId(345);
private final TransportId unknownTransportId = new TransportId("id"); private final TransportId transportId = new TransportId("known");
private final TransportId unknownTransportId = new TransportId("unknown");
private final StreamContext streamContext = private final StreamContext streamContext =
new StreamContext(contactId, transportId, getSecretKey(), new StreamContext(contactId, transportId, getSecretKey(),
getSecretKey(), 1); getSecretKey(), 1);
private final byte[] tag = getRandomBytes(TAG_LENGTH); private final byte[] tag = getRandomBytes(TAG_LENGTH);
public KeyManagerImplTest() { private final KeyManagerImpl keyManager = new KeyManagerImpl(db, executor,
keyManager = new KeyManagerImpl(db, executor, pluginConfig, pluginConfig, transportKeyManagerFactory);
transportKeyManagerFactory);
}
@Before @Before
public void testStartService() throws Exception { public void testStartService() throws Exception {
@@ -70,8 +69,8 @@ public class KeyManagerImplTest extends BrambleTestCase {
true, false)); true, false));
SimplexPluginFactory pluginFactory = SimplexPluginFactory pluginFactory =
context.mock(SimplexPluginFactory.class); context.mock(SimplexPluginFactory.class);
Collection<SimplexPluginFactory> factories = Collections Collection<SimplexPluginFactory> factories =
.singletonList(pluginFactory); singletonList(pluginFactory);
int maxLatency = 1337; int maxLatency = 1337;
context.checking(new Expectations() {{ context.checking(new Expectations() {{
@@ -110,7 +109,22 @@ public class KeyManagerImplTest extends BrambleTestCase {
}}); }});
keyManager.addContact(txn, contactId, secretKey, timestamp, alice); keyManager.addContact(txn, contactId, secretKey, timestamp, alice);
context.assertIsSatisfied(); }
@Test
public void testAddUnboundKeys() throws Exception {
SecretKey secretKey = getSecretKey();
long timestamp = System.currentTimeMillis();
boolean alice = new Random().nextBoolean();
context.checking(new Expectations() {{
oneOf(transportKeyManager).addUnboundKeys(txn, secretKey,
timestamp, alice);
will(returnValue(keySetId));
}});
assertEquals(singletonMap(transportId, keySetId),
keyManager.addUnboundKeys(txn, secretKey, timestamp, alice));
} }
@Test @Test
@@ -138,7 +152,6 @@ public class KeyManagerImplTest extends BrambleTestCase {
assertEquals(streamContext, assertEquals(streamContext,
keyManager.getStreamContext(contactId, transportId)); keyManager.getStreamContext(contactId, transportId));
context.assertIsSatisfied();
} }
@Test @Test
@@ -161,7 +174,6 @@ public class KeyManagerImplTest extends BrambleTestCase {
assertEquals(streamContext, assertEquals(streamContext,
keyManager.getStreamContext(transportId, tag)); keyManager.getStreamContext(transportId, tag));
context.assertIsSatisfied();
} }
@Test @Test
@@ -175,8 +187,6 @@ public class KeyManagerImplTest extends BrambleTestCase {
keyManager.eventOccurred(event); keyManager.eventOccurred(event);
executor.runUntilIdle(); executor.runUntilIdle();
assertEquals(null, keyManager.getStreamContext(contactId, transportId)); assertEquals(null, keyManager.getStreamContext(contactId, transportId));
context.assertIsSatisfied();
} }
@Test @Test
@@ -196,8 +206,5 @@ public class KeyManagerImplTest extends BrambleTestCase {
keyManager.eventOccurred(event); keyManager.eventOccurred(event);
assertEquals(streamContext, assertEquals(streamContext,
keyManager.getStreamContext(inactiveContactId, transportId)); keyManager.getStreamContext(inactiveContactId, transportId));
context.assertIsSatisfied();
} }
} }

View File

@@ -8,6 +8,8 @@ import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
@@ -22,14 +24,13 @@ import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION;
@@ -37,8 +38,10 @@ import static org.briarproject.bramble.api.transport.TransportConstants.REORDERI
import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH; import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.bramble.util.ByteUtils.MAX_32_BIT_UNSIGNED; import static org.briarproject.bramble.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
public class TransportKeyManagerImplTest extends BrambleMockTestCase { public class TransportKeyManagerImplTest extends BrambleMockTestCase {
@@ -55,6 +58,9 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
private final ContactId contactId = new ContactId(123); private final ContactId contactId = new ContactId(123);
private final ContactId contactId1 = new ContactId(234); private final ContactId contactId1 = new ContactId(234);
private final KeySetId keySetId = new KeySetId(345);
private final KeySetId keySetId1 = new KeySetId(456);
private final KeySetId keySetId2 = new KeySetId(567);
private final SecretKey tagKey = TestUtils.getSecretKey(); private final SecretKey tagKey = TestUtils.getSecretKey();
private final SecretKey headerKey = TestUtils.getSecretKey(); private final SecretKey headerKey = TestUtils.getSecretKey();
private final SecretKey masterKey = TestUtils.getSecretKey(); private final SecretKey masterKey = TestUtils.getSecretKey();
@@ -62,12 +68,16 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
@Test @Test
public void testKeysAreRotatedAtStartup() throws Exception { public void testKeysAreRotatedAtStartup() throws Exception {
Map<ContactId, TransportKeys> loaded = new LinkedHashMap<>(); TransportKeys shouldRotate = createTransportKeys(900, 0, true);
TransportKeys shouldRotate = createTransportKeys(900, 0); TransportKeys shouldNotRotate = createTransportKeys(1000, 0, true);
TransportKeys shouldNotRotate = createTransportKeys(1000, 0); TransportKeys shouldRotate1 = createTransportKeys(999, 0, false);
loaded.put(contactId, shouldRotate); Collection<KeySet> loaded = asList(
loaded.put(contactId1, shouldNotRotate); new KeySet(keySetId, contactId, shouldRotate),
TransportKeys rotated = createTransportKeys(1000, 0); new KeySet(keySetId1, contactId1, shouldNotRotate),
new KeySet(keySetId2, null, shouldRotate1)
);
TransportKeys rotated = createTransportKeys(1000, 0, true);
TransportKeys rotated1 = createTransportKeys(1000, 0, false);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
@@ -82,6 +92,8 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
will(returnValue(rotated)); will(returnValue(rotated));
oneOf(transportCrypto).rotateTransportKeys(shouldNotRotate, 1000); oneOf(transportCrypto).rotateTransportKeys(shouldNotRotate, 1000);
will(returnValue(shouldNotRotate)); will(returnValue(shouldNotRotate));
oneOf(transportCrypto).rotateTransportKeys(shouldRotate1, 1000);
will(returnValue(rotated1));
// Encode the tags (3 sets per contact) // Encode the tags (3 sets per contact)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(6).of(transportCrypto).encodeTag( exactly(6).of(transportCrypto).encodeTag(
@@ -90,8 +102,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Save the keys that were rotated // Save the keys that were rotated
oneOf(db).updateTransportKeys(txn, oneOf(db).updateTransportKeys(txn, asList(
Collections.singletonMap(contactId, rotated)); new KeySet(keySetId, contactId, rotated),
new KeySet(keySetId2, null, rotated1))
);
// Schedule key rotation at the start of the next rotation period // Schedule key rotation at the start of the next rotation period
oneOf(scheduler).schedule(with(any(Runnable.class)), oneOf(scheduler).schedule(with(any(Runnable.class)),
with(rotationPeriodLength - 1), with(MILLISECONDS)); with(rotationPeriodLength - 1), with(MILLISECONDS));
@@ -101,18 +115,19 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
db, transportCrypto, dbExecutor, scheduler, clock, transportId, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
maxLatency); maxLatency);
transportKeyManager.start(txn); transportKeyManager.start(txn);
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
} }
@Test @Test
public void testKeysAreRotatedWhenAddingContact() throws Exception { public void testKeysAreRotatedWhenAddingContact() throws Exception {
boolean alice = random.nextBoolean(); boolean alice = random.nextBoolean();
TransportKeys transportKeys = createTransportKeys(999, 0); TransportKeys transportKeys = createTransportKeys(999, 0, true);
TransportKeys rotated = createTransportKeys(1000, 0); TransportKeys rotated = createTransportKeys(1000, 0, true);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
999, alice); 999, alice, true);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Get the current time (1 ms after start of rotation period 1000) // Get the current time (1 ms after start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
@@ -129,6 +144,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
} }
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, rotated); oneOf(db).addTransportKeys(txn, contactId, rotated);
will(returnValue(keySetId));
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -138,6 +154,39 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
long timestamp = rotationPeriodLength * 1000 - 1; long timestamp = rotationPeriodLength * 1000 - 1;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
}
@Test
public void testKeysAreRotatedWhenAddingUnboundKeys() throws Exception {
boolean alice = random.nextBoolean();
TransportKeys transportKeys = createTransportKeys(999, 0, false);
TransportKeys rotated = createTransportKeys(1000, 0, false);
Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
999, alice, false);
will(returnValue(transportKeys));
// Get the current time (1 ms after start of rotation period 1000)
oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000 + 1));
// Rotate the transport keys
oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(rotated));
// Save the keys
oneOf(db).addTransportKeys(txn, null, rotated);
will(returnValue(keySetId));
}});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
db, transportCrypto, dbExecutor, scheduler, clock, transportId,
maxLatency);
// The timestamp is 1 ms before the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000 - 1;
assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn,
masterKey, timestamp, alice));
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
} }
@Test @Test
@@ -149,6 +198,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
db, transportCrypto, dbExecutor, scheduler, clock, transportId, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
maxLatency); maxLatency);
assertNull(transportKeyManager.getStreamContext(txn, contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
} }
@Test @Test
@@ -157,29 +207,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
boolean alice = random.nextBoolean(); boolean alice = random.nextBoolean();
// The stream counter has been exhausted // The stream counter has been exhausted
TransportKeys transportKeys = createTransportKeys(1000, TransportKeys transportKeys = createTransportKeys(1000,
MAX_32_BIT_UNSIGNED + 1); MAX_32_BIT_UNSIGNED + 1, true);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ expectAddContactNoRotation(alice, transportKeys, txn);
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
1000, alice);
will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(transportCrypto).encodeTag(
with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction());
}
// Rotate the transport keys (the keys are unaffected)
oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys));
// Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys);
}});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
db, transportCrypto, dbExecutor, scheduler, clock, transportId, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
@@ -188,6 +219,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
assertNull(transportKeyManager.getStreamContext(txn, contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
} }
@@ -196,30 +228,14 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
boolean alice = random.nextBoolean(); boolean alice = random.nextBoolean();
// The stream counter can be used one more time before being exhausted // The stream counter can be used one more time before being exhausted
TransportKeys transportKeys = createTransportKeys(1000, TransportKeys transportKeys = createTransportKeys(1000,
MAX_32_BIT_UNSIGNED); MAX_32_BIT_UNSIGNED, true);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
expectAddContactNoRotation(alice, transportKeys, txn);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
1000, alice);
will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(transportCrypto).encodeTag(
with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction());
}
// Rotate the transport keys (the keys are unaffected)
oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys));
// Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys);
// Increment the stream counter // Increment the stream counter
oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000); oneOf(db).incrementStreamCounter(txn, transportId, keySetId);
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -230,6 +246,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
// The first request should return a stream context // The first request should return a stream context
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
StreamContext ctx = transportKeyManager.getStreamContext(txn, StreamContext ctx = transportKeyManager.getStreamContext(txn,
contactId); contactId);
assertNotNull(ctx); assertNotNull(ctx);
@@ -239,6 +256,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
assertEquals(headerKey, ctx.getHeaderKey()); assertEquals(headerKey, ctx.getHeaderKey());
assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber()); assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber());
// The second request should return null, the counter is exhausted // The second request should return null, the counter is exhausted
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
assertNull(transportKeyManager.getStreamContext(txn, contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
} }
@@ -246,29 +264,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
public void testIncomingStreamContextIsNullIfTagIsNotFound() public void testIncomingStreamContextIsNullIfTagIsNotFound()
throws Exception { throws Exception {
boolean alice = random.nextBoolean(); boolean alice = random.nextBoolean();
TransportKeys transportKeys = createTransportKeys(1000, 0); TransportKeys transportKeys = createTransportKeys(1000, 0, true);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ expectAddContactNoRotation(alice, transportKeys, txn);
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
1000, alice);
will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(transportCrypto).encodeTag(
with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction());
}
// Rotate the transport keys (the keys are unaffected)
oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys));
// Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys);
}});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
db, transportCrypto, dbExecutor, scheduler, clock, transportId, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
@@ -277,6 +276,8 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
// The tag should not be recognised
assertNull(transportKeyManager.getStreamContext(txn, assertNull(transportKeyManager.getStreamContext(txn,
new byte[TAG_LENGTH])); new byte[TAG_LENGTH]));
} }
@@ -284,14 +285,15 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
@Test @Test
public void testTagIsNotRecognisedTwice() throws Exception { public void testTagIsNotRecognisedTwice() throws Exception {
boolean alice = random.nextBoolean(); boolean alice = random.nextBoolean();
TransportKeys transportKeys = createTransportKeys(1000, 0); TransportKeys transportKeys = createTransportKeys(1000, 0, true);
Transaction txn = new Transaction(null, false);
// Keep a copy of the tags // Keep a copy of the tags
List<byte[]> tags = new ArrayList<>(); List<byte[]> tags = new ArrayList<>();
Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
1000, alice); 1000, alice, true);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000) // Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
@@ -308,13 +310,14 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
will(returnValue(keySetId));
// Encode a new tag after sliding the window // Encode a new tag after sliding the window
oneOf(transportCrypto).encodeTag(with(any(byte[].class)), oneOf(transportCrypto).encodeTag(with(any(byte[].class)),
with(tagKey), with(PROTOCOL_VERSION), with(tagKey), with(PROTOCOL_VERSION),
with((long) REORDERING_WINDOW_SIZE)); with((long) REORDERING_WINDOW_SIZE));
will(new EncodeTagAction(tags)); will(new EncodeTagAction(tags));
// Save the reordering window (previous rotation period, base 1) // Save the reordering window (previous rotation period, base 1)
oneOf(db).setReorderingWindow(txn, contactId, transportId, 999, oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999,
1, new byte[REORDERING_WINDOW_SIZE / 8]); 1, new byte[REORDERING_WINDOW_SIZE / 8]);
}}); }});
@@ -325,6 +328,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
// Use the first tag (previous rotation period, stream number 0) // Use the first tag (previous rotation period, stream number 0)
assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size()); assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size());
byte[] tag = tags.get(0); byte[] tag = tags.get(0);
@@ -344,10 +348,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
@Test @Test
public void testKeysAreRotatedToCurrentPeriod() throws Exception { public void testKeysAreRotatedToCurrentPeriod() throws Exception {
TransportKeys transportKeys = createTransportKeys(1000, 0); TransportKeys transportKeys = createTransportKeys(1000, 0, true);
Map<ContactId, TransportKeys> loaded = Collection<KeySet> loaded =
Collections.singletonMap(contactId, transportKeys); singletonList(new KeySet(keySetId, contactId, transportKeys));
TransportKeys rotated = createTransportKeys(1001, 0); TransportKeys rotated = createTransportKeys(1001, 0, true);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
Transaction txn1 = new Transaction(null, false); Transaction txn1 = new Transaction(null, false);
@@ -393,7 +397,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
} }
// Save the keys that were rotated // Save the keys that were rotated
oneOf(db).updateTransportKeys(txn1, oneOf(db).updateTransportKeys(txn1,
Collections.singletonMap(contactId, rotated)); singletonList(new KeySet(keySetId, contactId, rotated)));
// Schedule key rotation at the start of the next rotation period // Schedule key rotation at the start of the next rotation period
oneOf(scheduler).schedule(with(any(Runnable.class)), oneOf(scheduler).schedule(with(any(Runnable.class)),
with(rotationPeriodLength), with(MILLISECONDS)); with(rotationPeriodLength), with(MILLISECONDS));
@@ -406,10 +410,197 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
db, transportCrypto, dbExecutor, scheduler, clock, transportId, db, transportCrypto, dbExecutor, scheduler, clock, transportId,
maxLatency); maxLatency);
transportKeyManager.start(txn); transportKeyManager.start(txn);
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
}
@Test
public void testBindingAndActivatingKeys() throws Exception {
boolean alice = random.nextBoolean();
TransportKeys transportKeys = createTransportKeys(1000, 0, false);
Transaction txn = new Transaction(null, false);
expectAddUnboundKeysNoRotation(alice, transportKeys, txn);
context.checking(new Expectations() {{
// When the keys are bound, encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(transportCrypto).encodeTag(
with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction());
}
// Save the key binding
oneOf(db).bindTransportKeys(txn, contactId, transportId, keySetId);
// Activate the keys
oneOf(db).setTransportKeysActive(txn, transportId, keySetId);
// Increment the stream counter
oneOf(db).incrementStreamCounter(txn, transportId, keySetId);
}});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
db, transportCrypto, dbExecutor, scheduler, clock, transportId,
maxLatency);
// The timestamp is at the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000;
assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn,
masterKey, timestamp, alice));
// The keys are unbound so no stream context should be returned
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
assertNull(transportKeyManager.getStreamContext(txn, contactId));
transportKeyManager.bindKeys(txn, contactId, keySetId);
// The keys are inactive so no stream context should be returned
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
assertNull(transportKeyManager.getStreamContext(txn, contactId));
transportKeyManager.activateKeys(txn, keySetId);
// The keys are active so a stream context should be returned
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
StreamContext ctx = transportKeyManager.getStreamContext(txn,
contactId);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertEquals(tagKey, ctx.getTagKey());
assertEquals(headerKey, ctx.getHeaderKey());
assertEquals(0L, ctx.getStreamNumber());
}
@Test
public void testRecognisingTagActivatesOutgoingKeys() throws Exception {
boolean alice = random.nextBoolean();
TransportKeys transportKeys = createTransportKeys(1000, 0, false);
Transaction txn = new Transaction(null, false);
// Keep a copy of the tags
List<byte[]> tags = new ArrayList<>();
expectAddUnboundKeysNoRotation(alice, transportKeys, txn);
context.checking(new Expectations() {{
// When the keys are bound, encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(transportCrypto).encodeTag(
with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction(tags));
}
// Save the key binding
oneOf(db).bindTransportKeys(txn, contactId, transportId, keySetId);
// Encode a new tag after sliding the window
oneOf(transportCrypto).encodeTag(with(any(byte[].class)),
with(tagKey), with(PROTOCOL_VERSION),
with((long) REORDERING_WINDOW_SIZE));
will(new EncodeTagAction(tags));
// Save the reordering window (previous rotation period, base 1)
oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999,
1, new byte[REORDERING_WINDOW_SIZE / 8]);
// Activate the keys
oneOf(db).setTransportKeysActive(txn, transportId, keySetId);
// Increment the stream counter
oneOf(db).incrementStreamCounter(txn, transportId, keySetId);
}});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
db, transportCrypto, dbExecutor, scheduler, clock, transportId,
maxLatency);
// The timestamp is at the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000;
assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn,
masterKey, timestamp, alice));
transportKeyManager.bindKeys(txn, contactId, keySetId);
// The keys are inactive so no stream context should be returned
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
assertNull(transportKeyManager.getStreamContext(txn, contactId));
// Recognising an incoming tag should activate the outgoing keys
assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size());
byte[] tag = tags.get(0);
StreamContext ctx = transportKeyManager.getStreamContext(txn, tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertEquals(tagKey, ctx.getTagKey());
assertEquals(headerKey, ctx.getHeaderKey());
assertEquals(0L, ctx.getStreamNumber());
// The keys are active so a stream context should be returned
assertTrue(transportKeyManager.canSendOutgoingStreams(contactId));
ctx = transportKeyManager.getStreamContext(txn, contactId);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertEquals(tagKey, ctx.getTagKey());
assertEquals(headerKey, ctx.getHeaderKey());
assertEquals(0L, ctx.getStreamNumber());
}
@Test
public void testRemovingUnboundKeys() throws Exception {
boolean alice = random.nextBoolean();
TransportKeys transportKeys = createTransportKeys(1000, 0, false);
Transaction txn = new Transaction(null, false);
expectAddUnboundKeysNoRotation(alice, transportKeys, txn);
context.checking(new Expectations() {{
// Remove the unbound keys
oneOf(db).removeTransportKeys(txn, transportId, keySetId);
}});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
db, transportCrypto, dbExecutor, scheduler, clock, transportId,
maxLatency);
// The timestamp is at the start of rotation period 1000
long timestamp = rotationPeriodLength * 1000;
assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn,
masterKey, timestamp, alice));
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
transportKeyManager.removeKeys(txn, keySetId);
assertFalse(transportKeyManager.canSendOutgoingStreams(contactId));
}
private void expectAddContactNoRotation(boolean alice,
TransportKeys transportKeys, Transaction txn) throws Exception {
context.checking(new Expectations() {{
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
1000, alice, true);
will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000));
// Encode the tags (3 sets)
for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
exactly(3).of(transportCrypto).encodeTag(
with(any(byte[].class)), with(tagKey),
with(PROTOCOL_VERSION), with(i));
will(new EncodeTagAction());
}
// Rotate the transport keys (the keys are unaffected)
oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys));
// Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys);
will(returnValue(keySetId));
}});
}
private void expectAddUnboundKeysNoRotation(boolean alice,
TransportKeys transportKeys, Transaction txn) throws Exception {
context.checking(new Expectations() {{
oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
1000, alice, false);
will(returnValue(transportKeys));
// Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000));
// Rotate the transport keys (the keys are unaffected)
oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys));
// Save the unbound keys
oneOf(db).addTransportKeys(txn, null, transportKeys);
will(returnValue(keySetId));
}});
} }
private TransportKeys createTransportKeys(long rotationPeriod, private TransportKeys createTransportKeys(long rotationPeriod,
long streamCounter) { long streamCounter, boolean active) {
IncomingKeys inPrev = new IncomingKeys(tagKey, headerKey, IncomingKeys inPrev = new IncomingKeys(tagKey, headerKey,
rotationPeriod - 1); rotationPeriod - 1);
IncomingKeys inCurr = new IncomingKeys(tagKey, headerKey, IncomingKeys inCurr = new IncomingKeys(tagKey, headerKey,
@@ -417,7 +608,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
IncomingKeys inNext = new IncomingKeys(tagKey, headerKey, IncomingKeys inNext = new IncomingKeys(tagKey, headerKey,
rotationPeriod + 1); rotationPeriod + 1);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey, OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
rotationPeriod, streamCounter); rotationPeriod, streamCounter, active);
return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr); return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr);
} }