Support multiple sets of transport keys per contact.

This commit is contained in:
akwizgran
2018-03-13 17:20:46 +00:00
parent 309c7a4668
commit 78f2d48bc4
12 changed files with 430 additions and 273 deletions

View File

@@ -6,7 +6,7 @@ import javax.annotation.concurrent.Immutable;
/** /**
* Type-safe wrapper for an integer that uniquely identifies a contact within * Type-safe wrapper for an integer that uniquely identifies a contact within
* the scope of a single node. * the scope of the local device.
*/ */
@Immutable @Immutable
@NotNullByDefault @NotNullByDefault

View File

@@ -18,6 +18,8 @@ import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.Offer; import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.Request; import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection; import java.util.Collection;
@@ -102,10 +104,11 @@ public interface DatabaseComponent {
throws DbException; throws DbException;
/** /**
* Stores transport keys for a newly added contact. * Stores the given transport keys, optionally binding them to the given
* contact, and returns a key set ID.
*/ */
void addTransportKeys(Transaction txn, ContactId c, TransportKeys k) KeySetId addTransportKeys(Transaction txn, @Nullable ContactId c,
throws DbException; TransportKeys k) throws DbException;
/** /**
* Returns true if the database contains the given contact for the given * Returns true if the database contains the given contact for the given
@@ -394,8 +397,8 @@ public interface DatabaseComponent {
* <p/> * <p/>
* Read-only. * Read-only.
*/ */
Map<ContactId, TransportKeys> getTransportKeys(Transaction txn, Collection<KeySet> getTransportKeys(Transaction txn, TransportId t)
TransportId t) throws DbException; throws DbException;
/** /**
* Increments the outgoing stream counter for the given contact and * Increments the outgoing stream counter for the given contact and
@@ -507,15 +510,15 @@ public interface DatabaseComponent {
Collection<MessageId> dependencies) throws DbException; Collection<MessageId> dependencies) throws DbException;
/** /**
* Sets the reordering window for the given contact and transport in the * Sets the reordering window for the given key set and transport in the
* given rotation period. * given rotation period.
*/ */
void setReorderingWindow(Transaction txn, ContactId c, TransportId t, void setReorderingWindow(Transaction txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException; long rotationPeriod, long base, byte[] bitmap) throws DbException;
/** /**
* Stores the given transport keys, deleting any keys they have replaced. * Stores the given transport keys, deleting any keys they have replaced.
*/ */
void updateTransportKeys(Transaction txn, void updateTransportKeys(Transaction txn, Collection<KeySet> keys)
Map<ContactId, TransportKeys> keys) throws DbException; throws DbException;
} }

View File

@@ -0,0 +1,51 @@
package org.briarproject.bramble.api.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable;
/**
* A set of transport keys for communicating with a contact. If the keys have
* not yet been bound to a contact, {@link #getContactId()}} returns null.
*/
@Immutable
@NotNullByDefault
public class KeySet {
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
private final TransportKeys transportKeys;
public KeySet(KeySetId keySetId, @Nullable ContactId contactId,
TransportKeys transportKeys) {
this.keySetId = keySetId;
this.contactId = contactId;
this.transportKeys = transportKeys;
}
public KeySetId getKeySetId() {
return keySetId;
}
@Nullable
public ContactId getContactId() {
return contactId;
}
public TransportKeys getTransportKeys() {
return transportKeys;
}
@Override
public int hashCode() {
return keySetId.hashCode();
}
@Override
public boolean equals(Object o) {
return o instanceof KeySet && keySetId.equals(((KeySet) o).keySetId);
}
}

View File

@@ -0,0 +1,36 @@
package org.briarproject.bramble.api.transport;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import javax.annotation.concurrent.Immutable;
/**
* Type-safe wrapper for an integer that uniquely identifies a set of transport
* keys within the scope of the local device.
* <p/>
* Key sets created on a given device must have increasing identifiers.
*/
@Immutable
@NotNullByDefault
public class KeySetId {
private final int id;
public KeySetId(int id) {
this.id = id;
}
public int getInt() {
return id;
}
@Override
public int hashCode() {
return id;
}
@Override
public boolean equals(Object o) {
return o instanceof KeySetId && id == ((KeySetId) o).id;
}
}

View File

@@ -21,6 +21,8 @@ import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection; import java.util.Collection;
@@ -123,9 +125,10 @@ interface Database<T> {
throws DbException; throws DbException;
/** /**
* Stores transport keys for a newly added contact. * Stores the given transport keys, optionally binding them to the given
* contact, and returns a key set ID.
*/ */
void addTransportKeys(T txn, ContactId c, TransportKeys k) KeySetId addTransportKeys(T txn, @Nullable ContactId c, TransportKeys k)
throws DbException; throws DbException;
/** /**
@@ -486,7 +489,7 @@ interface Database<T> {
* <p/> * <p/>
* Read-only. * Read-only.
*/ */
Map<ContactId, TransportKeys> getTransportKeys(T txn, TransportId t) Collection<KeySet> getTransportKeys(T txn, TransportId t)
throws DbException; throws DbException;
/** /**
@@ -619,10 +622,10 @@ interface Database<T> {
void setMessageState(T txn, MessageId m, State state) throws DbException; void setMessageState(T txn, MessageId m, State state) throws DbException;
/** /**
* Sets the reordering window for the given contact and transport in the * Sets the reordering window for the given key set and transport in the
* given rotation period. * given rotation period.
*/ */
void setReorderingWindow(T txn, ContactId c, TransportId t, void setReorderingWindow(T txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException; long rotationPeriod, long base, byte[] bitmap) throws DbException;
/** /**
@@ -636,6 +639,5 @@ interface Database<T> {
/** /**
* Stores the given transport keys, deleting any keys they have replaced. * Stores the given transport keys, deleting any keys they have replaced.
*/ */
void updateTransportKeys(T txn, Map<ContactId, TransportKeys> keys) void updateTransportKeys(T txn, Collection<KeySet> keys) throws DbException;
throws DbException;
} }

View File

@@ -51,15 +51,15 @@ import org.briarproject.bramble.api.sync.event.MessageToAckEvent;
import org.briarproject.bramble.api.sync.event.MessageToRequestEvent; import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -234,15 +234,15 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
} }
@Override @Override
public void addTransportKeys(Transaction transaction, ContactId c, public KeySetId addTransportKeys(Transaction transaction,
TransportKeys k) throws DbException { @Nullable ContactId c, TransportKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException(); if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction); T txn = unbox(transaction);
if (!db.containsContact(txn, c)) if (c != null && !db.containsContact(txn, c))
throw new NoSuchContactException(); throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId())) if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException(); throw new NoSuchTransportException();
db.addTransportKeys(txn, c, k); return db.addTransportKeys(txn, c, k);
} }
@Override @Override
@@ -586,8 +586,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
} }
@Override @Override
public Map<ContactId, TransportKeys> getTransportKeys( public Collection<KeySet> getTransportKeys(Transaction transaction,
Transaction transaction, TransportId t) throws DbException { TransportId t) throws DbException {
T txn = unbox(transaction); T txn = unbox(transaction);
if (!db.containsTransport(txn, t)) if (!db.containsTransport(txn, t))
throw new NoSuchTransportException(); throw new NoSuchTransportException();
@@ -858,31 +858,25 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
} }
@Override @Override
public void setReorderingWindow(Transaction transaction, ContactId c, public void setReorderingWindow(Transaction transaction, KeySetId k,
TransportId t, long rotationPeriod, long base, byte[] bitmap) TransportId t, long rotationPeriod, long base, byte[] bitmap)
throws DbException { throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException(); if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction); T txn = unbox(transaction);
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, t)) if (!db.containsTransport(txn, t))
throw new NoSuchTransportException(); throw new NoSuchTransportException();
db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap); db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap);
} }
@Override @Override
public void updateTransportKeys(Transaction transaction, public void updateTransportKeys(Transaction transaction,
Map<ContactId, TransportKeys> keys) throws DbException { Collection<KeySet> keys) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException(); if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction); T txn = unbox(transaction);
Map<ContactId, TransportKeys> filtered = new HashMap<>(); Collection<KeySet> filtered = new ArrayList<>();
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { for (KeySet ks : keys) {
ContactId c = e.getKey(); TransportId t = ks.getTransportKeys().getTransportId();
TransportKeys k = e.getValue(); if (db.containsTransport(txn, t)) filtered.add(ks);
if (db.containsContact(txn, c)
&& db.containsTransport(txn, k.getTransportId())) {
filtered.put(c, k);
}
} }
db.updateTransportKeys(txn, filtered); db.updateTransportKeys(txn, filtered);
} }

View File

@@ -25,6 +25,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
@@ -223,37 +225,43 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " maxLatency INT NOT NULL," + " maxLatency INT NOT NULL,"
+ " PRIMARY KEY (transportId))"; + " PRIMARY KEY (transportId))";
private static final String CREATE_OUTGOING_KEYS =
"CREATE TABLE outgoingKeys"
+ " (transportId _STRING NOT NULL,"
+ " keySetId _COUNTER,"
+ " rotationPeriod BIGINT NOT NULL,"
+ " contactId INT," // Null if keys are not bound
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " PRIMARY KEY (transportId, keySetId),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE,"
+ " UNIQUE (keySetId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_INCOMING_KEYS = private static final String CREATE_INCOMING_KEYS =
"CREATE TABLE incomingKeys" "CREATE TABLE incomingKeys"
+ " (contactId INT NOT NULL," + " (transportId _STRING NOT NULL,"
+ " transportId _STRING NOT NULL," + " keySetId INT NOT NULL,"
+ " rotationPeriod BIGINT NOT NULL," + " rotationPeriod BIGINT NOT NULL,"
+ " contactId INT," // Null if keys are not bound
+ " tagKey _SECRET NOT NULL," + " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL," + " headerKey _SECRET NOT NULL,"
+ " base BIGINT NOT NULL," + " base BIGINT NOT NULL,"
+ " bitmap _BINARY NOT NULL," + " bitmap _BINARY NOT NULL,"
+ " PRIMARY KEY (contactId, transportId, rotationPeriod)," + " PRIMARY KEY (transportId, keySetId, rotationPeriod),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (transportId)" + " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)" + " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)"; + " ON DELETE CASCADE,"
+ " FOREIGN KEY (keySetId)"
private static final String CREATE_OUTGOING_KEYS = + " REFERENCES outgoingKeys (keySetId)"
"CREATE TABLE outgoingKeys" + " ON DELETE CASCADE,"
+ " (contactId INT NOT NULL,"
+ " transportId _STRING NOT NULL,"
+ " rotationPeriod BIGINT NOT NULL,"
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " PRIMARY KEY (contactId, transportId),"
+ " FOREIGN KEY (contactId)" + " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)" + " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)"; + " ON DELETE CASCADE)";
private static final String INDEX_CONTACTS_BY_AUTHOR_ID = private static final String INDEX_CONTACTS_BY_AUTHOR_ID =
@@ -415,8 +423,8 @@ abstract class JdbcDatabase implements Database<Connection> {
s.executeUpdate(insertTypeNames(CREATE_OFFERS)); s.executeUpdate(insertTypeNames(CREATE_OFFERS));
s.executeUpdate(insertTypeNames(CREATE_STATUSES)); s.executeUpdate(insertTypeNames(CREATE_STATUSES));
s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS)); s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS));
s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS)); s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.close(); s.close();
} catch (SQLException e) { } catch (SQLException e) {
tryToClose(s); tryToClose(s);
@@ -865,52 +873,18 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public void addTransportKeys(Connection txn, ContactId c, TransportKeys k) public KeySetId addTransportKeys(Connection txn, @Nullable ContactId c,
throws DbException { TransportKeys k) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null;
try { try {
// Store the incoming keys
String sql = "INSERT INTO incomingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, base, bitmap)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString());
// Previous rotation period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(3, inPrev.getRotationPeriod());
ps.setBytes(4, inPrev.getTagKey().getBytes());
ps.setBytes(5, inPrev.getHeaderKey().getBytes());
ps.setLong(6, inPrev.getWindowBase());
ps.setBytes(7, inPrev.getWindowBitmap());
ps.addBatch();
// Current rotation period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(3, inCurr.getRotationPeriod());
ps.setBytes(4, inCurr.getTagKey().getBytes());
ps.setBytes(5, inCurr.getHeaderKey().getBytes());
ps.setLong(6, inCurr.getWindowBase());
ps.setBytes(7, inCurr.getWindowBitmap());
ps.addBatch();
// Next rotation period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(3, inNext.getRotationPeriod());
ps.setBytes(4, inNext.getTagKey().getBytes());
ps.setBytes(5, inNext.getHeaderKey().getBytes());
ps.setLong(6, inNext.getWindowBase());
ps.setBytes(7, inNext.getWindowBitmap());
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows != 1) throw new DbStateException();
ps.close();
// Store the outgoing keys // Store the outgoing keys
sql = "INSERT INTO outgoingKeys (contactId, transportId," String sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, stream)" + " rotationPeriod, tagKey, headerKey, stream)"
+ " VALUES (?, ?, ?, ?, ?, ?)"; + " VALUES (?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); if (c == null) ps.setNull(1, INTEGER);
else ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString()); ps.setString(2, k.getTransportId().getString());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(3, outCurr.getRotationPeriod()); ps.setLong(3, outCurr.getRotationPeriod());
@@ -920,7 +894,57 @@ abstract class JdbcDatabase implements Database<Connection> {
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException(); if (affected != 1) throw new DbStateException();
ps.close(); ps.close();
// Get the new (highest) key set ID
sql = "SELECT keySetId FROM outgoingKeys"
+ " ORDER BY keySetId DESC LIMIT 1";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
if (!rs.next()) throw new DbStateException();
KeySetId keySetId = new KeySetId(rs.getInt(1));
if (rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Store the incoming keys
sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, base, bitmap)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, keySetId.getInt());
if (c == null) ps.setNull(2, INTEGER);
else ps.setInt(2, c.getInt());
ps.setString(3, k.getTransportId().getString());
// Previous rotation period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(4, inPrev.getRotationPeriod());
ps.setBytes(5, inPrev.getTagKey().getBytes());
ps.setBytes(6, inPrev.getHeaderKey().getBytes());
ps.setLong(7, inPrev.getWindowBase());
ps.setBytes(8, inPrev.getWindowBitmap());
ps.addBatch();
// Current rotation period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(4, inCurr.getRotationPeriod());
ps.setBytes(5, inCurr.getTagKey().getBytes());
ps.setBytes(6, inCurr.getHeaderKey().getBytes());
ps.setLong(7, inCurr.getWindowBase());
ps.setBytes(8, inCurr.getWindowBitmap());
ps.addBatch();
// Next rotation period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(4, inNext.getRotationPeriod());
ps.setBytes(5, inNext.getTagKey().getBytes());
ps.setBytes(6, inNext.getHeaderKey().getBytes());
ps.setLong(7, inNext.getWindowBase());
ps.setBytes(8, inNext.getWindowBitmap());
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows != 1) throw new DbStateException();
ps.close();
return keySetId;
} catch (SQLException e) { } catch (SQLException e) {
tryToClose(rs);
tryToClose(ps); tryToClose(ps);
throw new DbException(e); throw new DbException(e);
} }
@@ -2078,8 +2102,8 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public Map<ContactId, TransportKeys> getTransportKeys(Connection txn, public Collection<KeySet> getTransportKeys(Connection txn, TransportId t)
TransportId t) throws DbException { throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
@@ -2088,7 +2112,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " base, bitmap" + " base, bitmap"
+ " FROM incomingKeys" + " FROM incomingKeys"
+ " WHERE transportId = ?" + " WHERE transportId = ?"
+ " ORDER BY contactId, rotationPeriod"; + " ORDER BY keySetId, rotationPeriod";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setString(1, t.getString()); ps.setString(1, t.getString());
rs = ps.executeQuery(); rs = ps.executeQuery();
@@ -2105,29 +2129,33 @@ abstract class JdbcDatabase implements Database<Connection> {
rs.close(); rs.close();
ps.close(); ps.close();
// Retrieve the outgoing keys in the same order // Retrieve the outgoing keys in the same order
sql = "SELECT contactId, rotationPeriod, tagKey, headerKey, stream" sql = "SELECT keySetId, contactId, rotationPeriod,"
+ " tagKey, headerKey, stream"
+ " FROM outgoingKeys" + " FROM outgoingKeys"
+ " WHERE transportId = ?" + " WHERE transportId = ?"
+ " ORDER BY contactId, rotationPeriod"; + " ORDER BY keySetId";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setString(1, t.getString()); ps.setString(1, t.getString());
rs = ps.executeQuery(); rs = ps.executeQuery();
Map<ContactId, TransportKeys> keys = new HashMap<>(); Collection<KeySet> keys = new ArrayList<>();
for (int i = 0; rs.next(); i++) { for (int i = 0; rs.next(); i++) {
// There should be three times as many incoming keys // There should be three times as many incoming keys
if (inKeys.size() < (i + 1) * 3) throw new DbStateException(); if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
ContactId contactId = new ContactId(rs.getInt(1)); KeySetId keySetId = new KeySetId(rs.getInt(1));
long rotationPeriod = rs.getLong(2); ContactId contactId = new ContactId(rs.getInt(2));
SecretKey tagKey = new SecretKey(rs.getBytes(3)); if (rs.wasNull()) contactId = null;
SecretKey headerKey = new SecretKey(rs.getBytes(4)); long rotationPeriod = rs.getLong(3);
long streamCounter = rs.getLong(5); SecretKey tagKey = new SecretKey(rs.getBytes(4));
SecretKey headerKey = new SecretKey(rs.getBytes(5));
long streamCounter = rs.getLong(6);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey, OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
rotationPeriod, streamCounter); rotationPeriod, streamCounter);
IncomingKeys inPrev = inKeys.get(i * 3); IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1); IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2); IncomingKeys inNext = inKeys.get(i * 3 + 2);
keys.put(contactId, new TransportKeys(t, inPrev, inCurr, TransportKeys transportKeys = new TransportKeys(t, inPrev,
inNext, outCurr)); inCurr, inNext, outCurr);
keys.add(new KeySet(keySetId, contactId, transportKeys));
} }
rs.close(); rs.close();
ps.close(); ps.close();
@@ -2791,18 +2819,18 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public void setReorderingWindow(Connection txn, ContactId c, TransportId t, public void setReorderingWindow(Connection txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException { long rotationPeriod, long base, byte[] bitmap) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
try { try {
String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?" String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?"
+ " WHERE contactId = ? AND transportId = ?" + " WHERE transportId = ? AND keySetId = ?"
+ " AND rotationPeriod = ?"; + " AND rotationPeriod = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setLong(1, base); ps.setLong(1, base);
ps.setBytes(2, bitmap); ps.setBytes(2, bitmap);
ps.setInt(3, c.getInt()); ps.setString(3, t.getString());
ps.setString(4, t.getString()); ps.setInt(4, k.getInt());
ps.setLong(5, rotationPeriod); ps.setLong(5, rotationPeriod);
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException(); if (affected < 0 || affected > 1) throw new DbStateException();
@@ -2848,45 +2876,31 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
@Override @Override
public void updateTransportKeys(Connection txn, public void updateTransportKeys(Connection txn, Collection<KeySet> keys)
Map<ContactId, TransportKeys> keys) throws DbException { throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
try { try {
// Delete any existing incoming keys // Delete any existing outgoing keys - this will also remove any
String sql = "DELETE FROM incomingKeys" // incoming keys with the same key set ID
+ " WHERE contactId = ?" String sql = "DELETE FROM outgoingKeys WHERE keySetId = ?";
+ " AND transportId = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { for (KeySet ks : keys) {
ps.setInt(1, e.getKey().getInt()); ps.setInt(1, ks.getKeySetId().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch(); ps.addBatch();
} }
int[] batchAffected = ps.executeBatch(); int[] batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size()) if (batchAffected.length != keys.size())
throw new DbStateException(); throw new DbStateException();
ps.close(); for (int rows: batchAffected)
// Delete any existing outgoing keys if (rows < 0) throw new DbStateException();
sql = "DELETE FROM outgoingKeys"
+ " WHERE contactId = ?"
+ " AND transportId = ?";
ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ps.setInt(1, e.getKey().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch();
}
batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
ps.close(); ps.close();
} catch (SQLException e) { } catch (SQLException e) {
tryToClose(ps); tryToClose(ps);
throw new DbException(e); throw new DbException(e);
} }
// Store the new keys // Store the new keys
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { for (KeySet ks : keys) {
addTransportKeys(txn, e.getKey(), e.getValue()); addTransportKeys(txn, ks.getContactId(), ks.getTransportKeys());
} }
} }
} }

View File

@@ -0,0 +1,34 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.transport.KeySetId;
import javax.annotation.Nullable;
public class MutableKeySet {
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
private final MutableTransportKeys transportKeys;
public MutableKeySet(KeySetId keySetId, @Nullable ContactId contactId,
MutableTransportKeys transportKeys) {
this.keySetId = keySetId;
this.contactId = contactId;
this.transportKeys = transportKeys;
}
public KeySetId getKeySetId() {
return keySetId;
}
@Nullable
public ContactId getContactId() {
return contactId;
}
public MutableTransportKeys getTransportKeys() {
return transportKeys;
}
}

View File

@@ -11,19 +11,25 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.system.Scheduler; import org.briarproject.bramble.api.system.Scheduler;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.transport.ReorderingWindow.Change; import org.briarproject.bramble.transport.ReorderingWindow.Change;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe; import javax.annotation.concurrent.ThreadSafe;
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
@@ -47,12 +53,13 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private final Clock clock; private final Clock clock;
private final TransportId transportId; private final TransportId transportId;
private final long rotationPeriodLength; private final long rotationPeriodLength;
private final ReentrantLock lock; private final AtomicBoolean used = new AtomicBoolean(false);
private final ReentrantLock lock = new ReentrantLock();
// The following are locking: lock // The following are locking: lock
private final Map<Bytes, TagContext> inContexts; private final Collection<MutableKeySet> keys = new ArrayList<>();
private final Map<ContactId, MutableOutgoingKeys> outContexts; private final Map<Bytes, TagContext> inContexts = new HashMap<>();
private final Map<ContactId, MutableTransportKeys> keys; private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>();
TransportKeyManagerImpl(DatabaseComponent db, TransportKeyManagerImpl(DatabaseComponent db,
TransportCrypto transportCrypto, Executor dbExecutor, TransportCrypto transportCrypto, Executor dbExecutor,
@@ -65,20 +72,16 @@ class TransportKeyManagerImpl implements TransportKeyManager {
this.clock = clock; this.clock = clock;
this.transportId = transportId; this.transportId = transportId;
rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
lock = new ReentrantLock();
inContexts = new HashMap<>();
outContexts = new HashMap<>();
keys = new HashMap<>();
} }
@Override @Override
public void start(Transaction txn) throws DbException { public void start(Transaction txn) throws DbException {
if (used.getAndSet(true)) throw new IllegalStateException();
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
lock.lock(); lock.lock();
try { try {
// Load the transport keys from the DB // Load the transport keys from the DB
Map<ContactId, TransportKeys> loaded = Collection<KeySet> loaded = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId);
// Rotate the keys to the current rotation period // Rotate the keys to the current rotation period
RotationResult rotationResult = rotateKeys(loaded, now); RotationResult rotationResult = rotateKeys(loaded, now);
// Initialise mutable state for all contacts // Initialise mutable state for all contacts
@@ -93,41 +96,51 @@ class TransportKeyManagerImpl implements TransportKeyManager {
scheduleKeyRotation(now); scheduleKeyRotation(now);
} }
private RotationResult rotateKeys(Map<ContactId, TransportKeys> keys, private RotationResult rotateKeys(Collection<KeySet> keys, long now) {
long now) {
RotationResult rotationResult = new RotationResult(); RotationResult rotationResult = new RotationResult();
long rotationPeriod = now / rotationPeriodLength; long rotationPeriod = now / rotationPeriodLength;
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { for (KeySet ks : keys) {
ContactId c = e.getKey(); TransportKeys k = ks.getTransportKeys();
TransportKeys k = e.getValue();
TransportKeys k1 = TransportKeys k1 =
transportCrypto.rotateTransportKeys(k, rotationPeriod); transportCrypto.rotateTransportKeys(k, rotationPeriod);
KeySet ks1 = new KeySet(ks.getKeySetId(), ks.getContactId(), k1);
if (k1.getRotationPeriod() > k.getRotationPeriod()) if (k1.getRotationPeriod() > k.getRotationPeriod())
rotationResult.rotated.put(c, k1); rotationResult.rotated.add(ks1);
rotationResult.current.put(c, k1); rotationResult.current.add(ks1);
} }
return rotationResult; return rotationResult;
} }
// Locking: lock // Locking: lock
private void addKeys(Map<ContactId, TransportKeys> m) { private void addKeys(Collection<KeySet> keys) {
for (Entry<ContactId, TransportKeys> e : m.entrySet()) for (KeySet ks : keys) {
addKeys(e.getKey(), new MutableTransportKeys(e.getValue())); addKeys(ks.getKeySetId(), ks.getContactId(),
new MutableTransportKeys(ks.getTransportKeys()));
}
} }
// Locking: lock // Locking: lock
private void addKeys(ContactId c, MutableTransportKeys m) { private void addKeys(KeySetId keySetId, @Nullable ContactId contactId,
encodeTags(c, m.getPreviousIncomingKeys()); MutableTransportKeys m) {
encodeTags(c, m.getCurrentIncomingKeys()); MutableKeySet ks = new MutableKeySet(keySetId, contactId, m);
encodeTags(c, m.getNextIncomingKeys()); keys.add(ks);
outContexts.put(c, m.getCurrentOutgoingKeys()); if (contactId != null) {
keys.put(c, m); encodeTags(keySetId, contactId, m.getPreviousIncomingKeys());
encodeTags(keySetId, contactId, m.getCurrentIncomingKeys());
encodeTags(keySetId, contactId, m.getNextIncomingKeys());
// Use the outgoing keys with the highest key set ID
MutableKeySet old = outContexts.get(contactId);
if (old == null || old.getKeySetId().getInt() < keySetId.getInt())
outContexts.put(contactId, ks);
}
} }
// Locking: lock // Locking: lock
private void encodeTags(ContactId c, MutableIncomingKeys inKeys) { private void encodeTags(KeySetId keySetId, ContactId contactId,
MutableIncomingKeys inKeys) {
for (long streamNumber : inKeys.getWindow().getUnseen()) { for (long streamNumber : inKeys.getWindow().getUnseen()) {
TagContext tagCtx = new TagContext(c, inKeys, streamNumber); TagContext tagCtx =
new TagContext(keySetId, contactId, inKeys, streamNumber);
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION, transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
streamNumber); streamNumber);
@@ -169,10 +182,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
// Rotate the keys to the current rotation period if necessary // Rotate the keys to the current rotation period if necessary
rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength; rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength;
k = transportCrypto.rotateTransportKeys(k, rotationPeriod); k = transportCrypto.rotateTransportKeys(k, rotationPeriod);
// Initialise mutable state for the contact
addKeys(c, new MutableTransportKeys(k));
// Write the keys back to the DB // Write the keys back to the DB
db.addTransportKeys(txn, c, k); KeySetId keySetId = db.addTransportKeys(txn, c, k);
// Initialise mutable state for the contact
addKeys(keySetId, c, new MutableTransportKeys(k));
} finally { } finally {
lock.unlock(); lock.unlock();
} }
@@ -183,12 +196,18 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock(); lock.lock();
try { try {
// Remove mutable state for the contact // Remove mutable state for the contact
Iterator<Entry<Bytes, TagContext>> it = Iterator<Entry<Bytes, TagContext>> inContextsIter =
inContexts.entrySet().iterator(); inContexts.entrySet().iterator();
while (it.hasNext()) while (inContextsIter.hasNext()) {
if (it.next().getValue().contactId.equals(c)) it.remove(); ContactId c1 = inContextsIter.next().getValue().contactId;
if (c1.equals(c)) inContextsIter.remove();
}
outContexts.remove(c); outContexts.remove(c);
keys.remove(c); Iterator<MutableKeySet> keysIter = keys.iterator();
while (keysIter.hasNext()) {
ContactId c1 = keysIter.next().getContactId();
if (c1 != null && c1.equals(c)) keysIter.remove();
}
} finally { } finally {
lock.unlock(); lock.unlock();
} }
@@ -200,8 +219,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock(); lock.lock();
try { try {
// Look up the outgoing keys for the contact // Look up the outgoing keys for the contact
MutableOutgoingKeys outKeys = outContexts.get(c); MutableKeySet ks = outContexts.get(c);
if (outKeys == null) return null; if (ks == null) return null;
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null; if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
// Create a stream context // Create a stream context
StreamContext ctx = new StreamContext(c, transportId, StreamContext ctx = new StreamContext(c, transportId,
@@ -238,8 +259,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
byte[] addTag = new byte[TAG_LENGTH]; byte[] addTag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(addTag, inKeys.getTagKey(), transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
PROTOCOL_VERSION, streamNumber); PROTOCOL_VERSION, streamNumber);
inContexts.put(new Bytes(addTag), new TagContext( TagContext tagCtx1 = new TagContext(tagCtx.keySetId,
tagCtx.contactId, inKeys, streamNumber)); tagCtx.contactId, inKeys, streamNumber);
inContexts.put(new Bytes(addTag), tagCtx1);
} }
// Remove tags for any stream numbers removed from the window // Remove tags for any stream numbers removed from the window
for (long streamNumber : change.getRemoved()) { for (long streamNumber : change.getRemoved()) {
@@ -250,7 +272,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
inContexts.remove(new Bytes(removeTag)); inContexts.remove(new Bytes(removeTag));
} }
// Write the window back to the DB // Write the window back to the DB
db.setReorderingWindow(txn, tagCtx.contactId, transportId, db.setReorderingWindow(txn, tagCtx.keySetId, transportId,
inKeys.getRotationPeriod(), window.getBase(), inKeys.getRotationPeriod(), window.getBase(),
window.getBitmap()); window.getBitmap());
return ctx; return ctx;
@@ -264,9 +286,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock(); lock.lock();
try { try {
// Rotate the keys to the current rotation period // Rotate the keys to the current rotation period
Map<ContactId, TransportKeys> snapshot = new HashMap<>(); Collection<KeySet> snapshot = new ArrayList<>(keys.size());
for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet()) for (MutableKeySet ks : keys) {
snapshot.put(e.getKey(), e.getValue().snapshot()); snapshot.add(new KeySet(ks.getKeySetId(), ks.getContactId(),
ks.getTransportKeys().snapshot()));
}
RotationResult rotationResult = rotateKeys(snapshot, now); RotationResult rotationResult = rotateKeys(snapshot, now);
// Rebuild the mutable state for all contacts // Rebuild the mutable state for all contacts
inContexts.clear(); inContexts.clear();
@@ -285,12 +309,14 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class TagContext { private static class TagContext {
private final KeySetId keySetId;
private final ContactId contactId; private final ContactId contactId;
private final MutableIncomingKeys inKeys; private final MutableIncomingKeys inKeys;
private final long streamNumber; private final long streamNumber;
private TagContext(ContactId contactId, MutableIncomingKeys inKeys, private TagContext(KeySetId keySetId, ContactId contactId,
long streamNumber) { MutableIncomingKeys inKeys, long streamNumber) {
this.keySetId = keySetId;
this.contactId = contactId; this.contactId = contactId;
this.inKeys = inKeys; this.inKeys = inKeys;
this.streamNumber = streamNumber; this.streamNumber = streamNumber;
@@ -299,11 +325,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class RotationResult { private static class RotationResult {
private final Map<ContactId, TransportKeys> current, rotated; private final Collection<KeySet> current = new ArrayList<>();
private final Collection<KeySet> rotated = new ArrayList<>();
private RotationResult() {
current = new HashMap<>();
rotated = new HashMap<>();
}
} }
} }

View File

@@ -44,6 +44,8 @@ import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.test.BrambleMockTestCase; import org.briarproject.bramble.test.BrambleMockTestCase;
@@ -55,12 +57,10 @@ import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED; import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED;
import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE;
@@ -100,6 +100,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
private final int maxLatency; private final int maxLatency;
private final ContactId contactId; private final ContactId contactId;
private final Contact contact; private final Contact contact;
private final KeySetId keySetId;
public DatabaseComponentImplTest() { public DatabaseComponentImplTest() {
clientId = new ClientId(getRandomString(123)); clientId = new ClientId(getRandomString(123));
@@ -121,6 +122,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
contactId = new ContactId(234); contactId = new ContactId(234);
contact = new Contact(contactId, author, localAuthor.getId(), contact = new Contact(contactId, author, localAuthor.getId(),
true, true); true, true);
keySetId = new KeySetId(345);
} }
private DatabaseComponent createDatabaseComponent(Database<Object> database, private DatabaseComponent createDatabaseComponent(Database<Object> database,
@@ -282,11 +284,11 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
throws Exception { throws Exception {
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// Check whether the contact is in the DB (which it's not) // Check whether the contact is in the DB (which it's not)
exactly(18).of(database).startTransaction(); exactly(17).of(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
exactly(18).of(database).containsContact(txn, contactId); exactly(17).of(database).containsContact(txn, contactId);
will(returnValue(false)); will(returnValue(false));
exactly(18).of(database).abortTransaction(txn); exactly(17).of(database).abortTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, eventBus, DatabaseComponent db = createDatabaseComponent(database, eventBus,
shutdown); shutdown);
@@ -454,17 +456,6 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
db.endTransaction(transaction); db.endTransaction(transaction);
} }
transaction = db.startTransaction(false);
try {
db.setReorderingWindow(transaction, contactId, transportId, 0, 0,
new byte[REORDERING_WINDOW_SIZE / 8]);
fail();
} catch (NoSuchContactException expected) {
// Expected
} finally {
db.endTransaction(transaction);
}
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.setGroupVisibility(transaction, contactId, groupId, SHARED); db.setGroupVisibility(transaction, contactId, groupId, SHARED);
@@ -779,7 +770,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
// Check whether the transport is in the DB (which it's not) // Check whether the transport is in the DB (which it's not)
exactly(4).of(database).startTransaction(); exactly(4).of(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
exactly(2).of(database).containsContact(txn, contactId); oneOf(database).containsContact(txn, contactId);
will(returnValue(true)); will(returnValue(true));
exactly(4).of(database).containsTransport(txn, transportId); exactly(4).of(database).containsTransport(txn, transportId);
will(returnValue(false)); will(returnValue(false));
@@ -830,7 +821,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
transaction = db.startTransaction(false); transaction = db.startTransaction(false);
try { try {
db.setReorderingWindow(transaction, contactId, transportId, 0, 0, db.setReorderingWindow(transaction, keySetId, transportId, 0, 0,
new byte[REORDERING_WINDOW_SIZE / 8]); new byte[REORDERING_WINDOW_SIZE / 8]);
fail(); fail();
} catch (NoSuchTransportException expected) { } catch (NoSuchTransportException expected) {
@@ -1303,15 +1294,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
@Test @Test
public void testTransportKeys() throws Exception { public void testTransportKeys() throws Exception {
TransportKeys transportKeys = createTransportKeys(); TransportKeys transportKeys = createTransportKeys();
Map<ContactId, TransportKeys> keys = Collection<KeySet> keys =
singletonMap(contactId, transportKeys); singletonList(new KeySet(keySetId, contactId, transportKeys));
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// startTransaction() // startTransaction()
oneOf(database).startTransaction(); oneOf(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
// updateTransportKeys() // updateTransportKeys()
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).containsTransport(txn, transportId); oneOf(database).containsTransport(txn, transportId);
will(returnValue(true)); will(returnValue(true));
oneOf(database).updateTransportKeys(txn, keys); oneOf(database).updateTransportKeys(txn, keys);

View File

@@ -19,6 +19,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.system.SystemClock; import org.briarproject.bramble.system.SystemClock;
@@ -34,7 +36,6 @@ import java.sql.Connection;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
@@ -42,6 +43,10 @@ import java.util.Random;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.concurrent.TimeUnit.SECONDS;
import static org.briarproject.bramble.api.db.Metadata.REMOVE; import static org.briarproject.bramble.api.db.Metadata.REMOVE;
import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
@@ -86,6 +91,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
private final Message message; private final Message message;
private final TransportId transportId; private final TransportId transportId;
private final ContactId contactId; private final ContactId contactId;
private final KeySetId keySetId;
JdbcDatabaseTest() throws Exception { JdbcDatabaseTest() throws Exception {
groupId = new GroupId(getRandomId()); groupId = new GroupId(getRandomId());
@@ -101,6 +107,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
message = new Message(messageId, groupId, timestamp, raw); message = new Message(messageId, groupId, timestamp, raw);
transportId = new TransportId("id"); transportId = new TransportId("id");
contactId = new ContactId(1); contactId = new ContactId(1);
keySetId = new KeySetId(1);
} }
protected abstract JdbcDatabase createDatabase(DatabaseConfig config, protected abstract JdbcDatabase createDatabase(DatabaseConfig config,
@@ -190,9 +197,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The contact has not seen the message, so it should be sendable // The contact has not seen the message, so it should be sendable
Collection<MessageId> ids = Collection<MessageId> ids =
db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// Changing the status to seen = true should make the message unsendable // Changing the status to seen = true should make the message unsendable
db.raiseSeenFlag(txn, contactId, messageId); db.raiseSeenFlag(txn, contactId, messageId);
@@ -228,9 +235,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Marking the message delivered should make it sendable // Marking the message delivered should make it sendable
db.setMessageState(txn, messageId, DELIVERED); db.setMessageState(txn, messageId, DELIVERED);
ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// Marking the message invalid should make it unsendable // Marking the message invalid should make it unsendable
db.setMessageState(txn, messageId, INVALID); db.setMessageState(txn, messageId, INVALID);
@@ -279,9 +286,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Sharing the group should make the message sendable // Sharing the group should make the message sendable
db.setGroupVisibility(txn, contactId, groupId, true); db.setGroupVisibility(txn, contactId, groupId, true);
ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// Unsharing the group should make the message unsendable // Unsharing the group should make the message unsendable
db.setGroupVisibility(txn, contactId, groupId, false); db.setGroupVisibility(txn, contactId, groupId, false);
@@ -324,9 +331,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Sharing the message should make it sendable // Sharing the message should make it sendable
db.setMessageShared(txn, messageId); db.setMessageShared(txn, messageId);
ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -352,7 +359,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The message is just the right size to send // The message is just the right size to send
ids = db.getMessagesToSend(txn, contactId, size); ids = db.getMessagesToSend(txn, contactId, size);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -384,7 +391,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
db.lowerAckFlag(txn, contactId, Arrays.asList(messageId, messageId1)); db.lowerAckFlag(txn, contactId, Arrays.asList(messageId, messageId1));
// Both message IDs should have been removed // Both message IDs should have been removed
assertEquals(Collections.emptyList(), db.getMessagesToAck(txn, assertEquals(emptyList(), db.getMessagesToAck(txn,
contactId, 1234)); contactId, 1234));
// Raise the ack flag again // Raise the ack flag again
@@ -415,7 +422,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Retrieve the message from the database and mark it as sent // Retrieve the message from the database and mark it as sent
Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, Collection<MessageId> ids = db.getMessagesToSend(txn, contactId,
ONE_MEGABYTE); ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
db.updateExpiryTime(txn, contactId, messageId, Integer.MAX_VALUE); db.updateExpiryTime(txn, contactId, messageId, Integer.MAX_VALUE);
// The message should no longer be sendable // The message should no longer be sendable
@@ -626,31 +633,31 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The group should not be visible to the contact // The group should not be visible to the contact
assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.emptyMap(), assertEquals(emptyMap(),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Make the group visible to the contact // Make the group visible to the contact
db.addGroupVisibility(txn, contactId, groupId, false); db.addGroupVisibility(txn, contactId, groupId, false);
assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.singletonMap(contactId, false), assertEquals(singletonMap(contactId, false),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Share the group with the contact // Share the group with the contact
db.setGroupVisibility(txn, contactId, groupId, true); db.setGroupVisibility(txn, contactId, groupId, true);
assertEquals(SHARED, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(SHARED, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.singletonMap(contactId, true), assertEquals(singletonMap(contactId, true),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Unshare the group with the contact // Unshare the group with the contact
db.setGroupVisibility(txn, contactId, groupId, false); db.setGroupVisibility(txn, contactId, groupId, false);
assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.singletonMap(contactId, false), assertEquals(singletonMap(contactId, false),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
// Make the group invisible again // Make the group invisible again
db.removeGroupVisibility(txn, contactId, groupId); db.removeGroupVisibility(txn, contactId, groupId);
assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId));
assertEquals(Collections.emptyMap(), assertEquals(emptyMap(),
db.getGroupVisibility(txn, groupId)); db.getGroupVisibility(txn, groupId));
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -665,24 +672,22 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Initially there should be no transport keys in the database // Initially there should be no transport keys in the database
assertEquals(Collections.emptyMap(), assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
db.getTransportKeys(txn, transportId));
// Add the contact, the transport and the transport keys // Add the contact, the transport and the transport keys
db.addLocalAuthor(txn, localAuthor); db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
db.addTransport(txn, transportId, 123); db.addTransport(txn, transportId, 123);
db.addTransportKeys(txn, contactId, keys); assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys));
// Retrieve the transport keys // Retrieve the transport keys
Map<ContactId, TransportKeys> newKeys = Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size()); assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e = KeySet ks = newKeys.iterator().next();
newKeys.entrySet().iterator().next(); assertEquals(keySetId, ks.getKeySetId());
assertEquals(contactId, e.getKey()); assertEquals(contactId, ks.getContactId());
TransportKeys k = e.getValue(); TransportKeys k = ks.getTransportKeys();
assertEquals(transportId, k.getTransportId()); assertEquals(transportId, k.getTransportId());
assertKeysEquals(keys.getPreviousIncomingKeys(), assertKeysEquals(keys.getPreviousIncomingKeys(),
k.getPreviousIncomingKeys()); k.getPreviousIncomingKeys());
@@ -695,8 +700,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// Removing the contact should remove the transport keys // Removing the contact should remove the transport keys
db.removeContact(txn, contactId); db.removeContact(txn, contactId);
assertEquals(Collections.emptyMap(), assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
db.getTransportKeys(txn, transportId));
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -735,18 +739,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
db.addTransport(txn, transportId, 123); db.addTransport(txn, transportId, 123);
db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); db.updateTransportKeys(txn,
singletonList(new KeySet(keySetId, contactId, keys)));
// Increment the stream counter twice and retrieve the transport keys // Increment the stream counter twice and retrieve the transport keys
db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod);
db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod);
Map<ContactId, TransportKeys> newKeys = Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size()); assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e = KeySet ks = newKeys.iterator().next();
newKeys.entrySet().iterator().next(); assertEquals(keySetId, ks.getKeySetId());
assertEquals(contactId, e.getKey()); assertEquals(contactId, ks.getContactId());
TransportKeys k = e.getValue(); TransportKeys k = ks.getTransportKeys();
assertEquals(transportId, k.getTransportId()); assertEquals(transportId, k.getTransportId());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
assertEquals(rotationPeriod, outCurr.getRotationPeriod()); assertEquals(rotationPeriod, outCurr.getRotationPeriod());
@@ -771,19 +775,19 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
db.addTransport(txn, transportId, 123); db.addTransport(txn, transportId, 123);
db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); db.updateTransportKeys(txn,
singletonList(new KeySet(keySetId, contactId, keys)));
// Update the reordering window and retrieve the transport keys // Update the reordering window and retrieve the transport keys
new Random().nextBytes(bitmap); new Random().nextBytes(bitmap);
db.setReorderingWindow(txn, contactId, transportId, rotationPeriod, db.setReorderingWindow(txn, keySetId, transportId, rotationPeriod,
base + 1, bitmap); base + 1, bitmap);
Map<ContactId, TransportKeys> newKeys = Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size()); assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e = KeySet ks = newKeys.iterator().next();
newKeys.entrySet().iterator().next(); assertEquals(keySetId, ks.getKeySetId());
assertEquals(contactId, e.getKey()); assertEquals(contactId, ks.getContactId());
TransportKeys k = e.getValue(); TransportKeys k = ks.getTransportKeys();
assertEquals(transportId, k.getTransportId()); assertEquals(transportId, k.getTransportId());
IncomingKeys inCurr = k.getCurrentIncomingKeys(); IncomingKeys inCurr = k.getCurrentIncomingKeys();
assertEquals(rotationPeriod, inCurr.getRotationPeriod()); assertEquals(rotationPeriod, inCurr.getRotationPeriod());
@@ -830,18 +834,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
db.addLocalAuthor(txn, localAuthor); db.addLocalAuthor(txn, localAuthor);
Collection<ContactId> contacts = Collection<ContactId> contacts =
db.getContacts(txn, localAuthor.getId()); db.getContacts(txn, localAuthor.getId());
assertEquals(Collections.emptyList(), contacts); assertEquals(emptyList(), contacts);
// Add a contact associated with the local author // Add a contact associated with the local author
assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
true, true)); true, true));
contacts = db.getContacts(txn, localAuthor.getId()); contacts = db.getContacts(txn, localAuthor.getId());
assertEquals(Collections.singletonList(contactId), contacts); assertEquals(singletonList(contactId), contacts);
// Remove the local author - the contact should be removed // Remove the local author - the contact should be removed
db.removeLocalAuthor(txn, localAuthor.getId()); db.removeLocalAuthor(txn, localAuthor.getId());
contacts = db.getContacts(txn, localAuthor.getId()); contacts = db.getContacts(txn, localAuthor.getId());
assertEquals(Collections.emptyList(), contacts); assertEquals(emptyList(), contacts);
assertFalse(db.containsContact(txn, contactId)); assertFalse(db.containsContact(txn, contactId));
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -1560,9 +1564,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
// The message should be sendable // The message should be sendable
Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, Collection<MessageId> ids = db.getMessagesToSend(txn, contactId,
ONE_MEGABYTE); ONE_MEGABYTE);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
ids = db.getMessagesToOffer(txn, contactId, 100); ids = db.getMessagesToOffer(txn, contactId, 100);
assertEquals(Collections.singletonList(messageId), ids); assertEquals(singletonList(messageId), ids);
// The raw message should not be null // The raw message should not be null
assertNotNull(db.getRawMessage(txn, messageId)); assertNotNull(db.getRawMessage(txn, messageId));

View File

@@ -8,6 +8,8 @@ import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys; import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.api.transport.TransportKeys;
@@ -22,14 +24,13 @@ import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION;
@@ -55,6 +56,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
private final ContactId contactId = new ContactId(123); private final ContactId contactId = new ContactId(123);
private final ContactId contactId1 = new ContactId(234); private final ContactId contactId1 = new ContactId(234);
private final KeySetId keySetId = new KeySetId(345);
private final SecretKey tagKey = TestUtils.getSecretKey(); private final SecretKey tagKey = TestUtils.getSecretKey();
private final SecretKey headerKey = TestUtils.getSecretKey(); private final SecretKey headerKey = TestUtils.getSecretKey();
private final SecretKey masterKey = TestUtils.getSecretKey(); private final SecretKey masterKey = TestUtils.getSecretKey();
@@ -62,11 +64,12 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
@Test @Test
public void testKeysAreRotatedAtStartup() throws Exception { public void testKeysAreRotatedAtStartup() throws Exception {
Map<ContactId, TransportKeys> loaded = new LinkedHashMap<>();
TransportKeys shouldRotate = createTransportKeys(900, 0); TransportKeys shouldRotate = createTransportKeys(900, 0);
TransportKeys shouldNotRotate = createTransportKeys(1000, 0); TransportKeys shouldNotRotate = createTransportKeys(1000, 0);
loaded.put(contactId, shouldRotate); Collection<KeySet> loaded = asList(
loaded.put(contactId1, shouldNotRotate); new KeySet(keySetId, contactId, shouldRotate),
new KeySet(keySetId, contactId1, shouldNotRotate)
);
TransportKeys rotated = createTransportKeys(1000, 0); TransportKeys rotated = createTransportKeys(1000, 0);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
@@ -91,7 +94,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
} }
// Save the keys that were rotated // Save the keys that were rotated
oneOf(db).updateTransportKeys(txn, oneOf(db).updateTransportKeys(txn,
Collections.singletonMap(contactId, rotated)); singletonList(new KeySet(keySetId, contactId, rotated)));
// Schedule key rotation at the start of the next rotation period // Schedule key rotation at the start of the next rotation period
oneOf(scheduler).schedule(with(any(Runnable.class)), oneOf(scheduler).schedule(with(any(Runnable.class)),
with(rotationPeriodLength - 1), with(MILLISECONDS)); with(rotationPeriodLength - 1), with(MILLISECONDS));
@@ -129,6 +132,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
} }
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, rotated); oneOf(db).addTransportKeys(txn, contactId, rotated);
will(returnValue(keySetId));
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -179,6 +183,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
will(returnValue(keySetId));
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -218,6 +223,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
will(returnValue(keySetId));
// Increment the stream counter // Increment the stream counter
oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000); oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000);
}}); }});
@@ -268,6 +274,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
will(returnValue(keySetId));
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -308,13 +315,14 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
will(returnValue(transportKeys)); will(returnValue(transportKeys));
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
will(returnValue(keySetId));
// Encode a new tag after sliding the window // Encode a new tag after sliding the window
oneOf(transportCrypto).encodeTag(with(any(byte[].class)), oneOf(transportCrypto).encodeTag(with(any(byte[].class)),
with(tagKey), with(PROTOCOL_VERSION), with(tagKey), with(PROTOCOL_VERSION),
with((long) REORDERING_WINDOW_SIZE)); with((long) REORDERING_WINDOW_SIZE));
will(new EncodeTagAction(tags)); will(new EncodeTagAction(tags));
// Save the reordering window (previous rotation period, base 1) // Save the reordering window (previous rotation period, base 1)
oneOf(db).setReorderingWindow(txn, contactId, transportId, 999, oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999,
1, new byte[REORDERING_WINDOW_SIZE / 8]); 1, new byte[REORDERING_WINDOW_SIZE / 8]);
}}); }});
@@ -345,8 +353,8 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
@Test @Test
public void testKeysAreRotatedToCurrentPeriod() throws Exception { public void testKeysAreRotatedToCurrentPeriod() throws Exception {
TransportKeys transportKeys = createTransportKeys(1000, 0); TransportKeys transportKeys = createTransportKeys(1000, 0);
Map<ContactId, TransportKeys> loaded = Collection<KeySet> loaded =
Collections.singletonMap(contactId, transportKeys); singletonList(new KeySet(keySetId, contactId, transportKeys));
TransportKeys rotated = createTransportKeys(1001, 0); TransportKeys rotated = createTransportKeys(1001, 0);
Transaction txn = new Transaction(null, false); Transaction txn = new Transaction(null, false);
Transaction txn1 = new Transaction(null, false); Transaction txn1 = new Transaction(null, false);
@@ -393,7 +401,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
} }
// Save the keys that were rotated // Save the keys that were rotated
oneOf(db).updateTransportKeys(txn1, oneOf(db).updateTransportKeys(txn1,
Collections.singletonMap(contactId, rotated)); singletonList(new KeySet(keySetId, contactId, rotated)));
// Schedule key rotation at the start of the next rotation period // Schedule key rotation at the start of the next rotation period
oneOf(scheduler).schedule(with(any(Runnable.class)), oneOf(scheduler).schedule(with(any(Runnable.class)),
with(rotationPeriodLength), with(MILLISECONDS)); with(rotationPeriodLength), with(MILLISECONDS));