diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java
index 9e11e5557..21293599f 100644
--- a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java
@@ -6,7 +6,7 @@ import javax.annotation.concurrent.Immutable;
/**
* Type-safe wrapper for an integer that uniquely identifies a contact within
- * the scope of a single node.
+ * the scope of the local device.
*/
@Immutable
@NotNullByDefault
diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java
index 08fbe4540..45f721f22 100644
--- a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java
@@ -18,6 +18,8 @@ import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.Request;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection;
@@ -102,10 +104,11 @@ public interface DatabaseComponent {
throws DbException;
/**
- * Stores transport keys for a newly added contact.
+ * Stores the given transport keys, optionally binding them to the given
+ * contact, and returns a key set ID.
*/
- void addTransportKeys(Transaction txn, ContactId c, TransportKeys k)
- throws DbException;
+ KeySetId addTransportKeys(Transaction txn, @Nullable ContactId c,
+ TransportKeys k) throws DbException;
/**
* Returns true if the database contains the given contact for the given
@@ -394,8 +397,8 @@ public interface DatabaseComponent {
*
* Read-only.
*/
- Map getTransportKeys(Transaction txn,
- TransportId t) throws DbException;
+ Collection getTransportKeys(Transaction txn, TransportId t)
+ throws DbException;
/**
* Increments the outgoing stream counter for the given contact and
@@ -507,15 +510,15 @@ public interface DatabaseComponent {
Collection dependencies) throws DbException;
/**
- * Sets the reordering window for the given contact and transport in the
+ * Sets the reordering window for the given key set and transport in the
* given rotation period.
*/
- void setReorderingWindow(Transaction txn, ContactId c, TransportId t,
+ void setReorderingWindow(Transaction txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException;
/**
* Stores the given transport keys, deleting any keys they have replaced.
*/
- void updateTransportKeys(Transaction txn,
- Map keys) throws DbException;
+ void updateTransportKeys(Transaction txn, Collection keys)
+ throws DbException;
}
diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java
new file mode 100644
index 000000000..9cc8f63c2
--- /dev/null
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java
@@ -0,0 +1,51 @@
+package org.briarproject.bramble.api.transport;
+
+import org.briarproject.bramble.api.contact.ContactId;
+import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
+
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.Immutable;
+
+/**
+ * A set of transport keys for communicating with a contact. If the keys have
+ * not yet been bound to a contact, {@link #getContactId()}} returns null.
+ */
+@Immutable
+@NotNullByDefault
+public class KeySet {
+
+ private final KeySetId keySetId;
+ @Nullable
+ private final ContactId contactId;
+ private final TransportKeys transportKeys;
+
+ public KeySet(KeySetId keySetId, @Nullable ContactId contactId,
+ TransportKeys transportKeys) {
+ this.keySetId = keySetId;
+ this.contactId = contactId;
+ this.transportKeys = transportKeys;
+ }
+
+ public KeySetId getKeySetId() {
+ return keySetId;
+ }
+
+ @Nullable
+ public ContactId getContactId() {
+ return contactId;
+ }
+
+ public TransportKeys getTransportKeys() {
+ return transportKeys;
+ }
+
+ @Override
+ public int hashCode() {
+ return keySetId.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return o instanceof KeySet && keySetId.equals(((KeySet) o).keySetId);
+ }
+}
diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java
new file mode 100644
index 000000000..1f872e72a
--- /dev/null
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java
@@ -0,0 +1,36 @@
+package org.briarproject.bramble.api.transport;
+
+import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
+
+import javax.annotation.concurrent.Immutable;
+
+/**
+ * Type-safe wrapper for an integer that uniquely identifies a set of transport
+ * keys within the scope of the local device.
+ *
+ * Key sets created on a given device must have increasing identifiers.
+ */
+@Immutable
+@NotNullByDefault
+public class KeySetId {
+
+ private final int id;
+
+ public KeySetId(int id) {
+ this.id = id;
+ }
+
+ public int getInt() {
+ return id;
+ }
+
+ @Override
+ public int hashCode() {
+ return id;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return o instanceof KeySetId && id == ((KeySetId) o).id;
+ }
+}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java
index bc3166d56..e8597eaba 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java
@@ -21,6 +21,8 @@ import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection;
@@ -123,9 +125,10 @@ interface Database {
throws DbException;
/**
- * Stores transport keys for a newly added contact.
+ * Stores the given transport keys, optionally binding them to the given
+ * contact, and returns a key set ID.
*/
- void addTransportKeys(T txn, ContactId c, TransportKeys k)
+ KeySetId addTransportKeys(T txn, @Nullable ContactId c, TransportKeys k)
throws DbException;
/**
@@ -486,7 +489,7 @@ interface Database {
*
* Read-only.
*/
- Map getTransportKeys(T txn, TransportId t)
+ Collection getTransportKeys(T txn, TransportId t)
throws DbException;
/**
@@ -619,10 +622,10 @@ interface Database {
void setMessageState(T txn, MessageId m, State state) throws DbException;
/**
- * Sets the reordering window for the given contact and transport in the
+ * Sets the reordering window for the given key set and transport in the
* given rotation period.
*/
- void setReorderingWindow(T txn, ContactId c, TransportId t,
+ void setReorderingWindow(T txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException;
/**
@@ -636,6 +639,5 @@ interface Database {
/**
* Stores the given transport keys, deleting any keys they have replaced.
*/
- void updateTransportKeys(T txn, Map keys)
- throws DbException;
+ void updateTransportKeys(T txn, Collection keys) throws DbException;
}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java
index 013233f39..f90f9e3ff 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java
@@ -51,15 +51,15 @@ import org.briarproject.bramble.api.sync.event.MessageToAckEvent;
import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Logger;
@@ -234,15 +234,15 @@ class DatabaseComponentImpl implements DatabaseComponent {
}
@Override
- public void addTransportKeys(Transaction transaction, ContactId c,
- TransportKeys k) throws DbException {
+ public KeySetId addTransportKeys(Transaction transaction,
+ @Nullable ContactId c, TransportKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
- if (!db.containsContact(txn, c))
+ if (c != null && !db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException();
- db.addTransportKeys(txn, c, k);
+ return db.addTransportKeys(txn, c, k);
}
@Override
@@ -586,8 +586,8 @@ class DatabaseComponentImpl implements DatabaseComponent {
}
@Override
- public Map getTransportKeys(
- Transaction transaction, TransportId t) throws DbException {
+ public Collection getTransportKeys(Transaction transaction,
+ TransportId t) throws DbException {
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
@@ -858,31 +858,25 @@ class DatabaseComponentImpl implements DatabaseComponent {
}
@Override
- public void setReorderingWindow(Transaction transaction, ContactId c,
+ public void setReorderingWindow(Transaction transaction, KeySetId k,
TransportId t, long rotationPeriod, long base, byte[] bitmap)
throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
- if (!db.containsContact(txn, c))
- throw new NoSuchContactException();
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
- db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap);
+ db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap);
}
@Override
public void updateTransportKeys(Transaction transaction,
- Map keys) throws DbException {
+ Collection keys) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
- Map filtered = new HashMap<>();
- for (Entry e : keys.entrySet()) {
- ContactId c = e.getKey();
- TransportKeys k = e.getValue();
- if (db.containsContact(txn, c)
- && db.containsTransport(txn, k.getTransportId())) {
- filtered.put(c, k);
- }
+ Collection filtered = new ArrayList<>();
+ for (KeySet ks : keys) {
+ TransportId t = ks.getTransportKeys().getTransportId();
+ if (db.containsTransport(txn, t)) filtered.add(ks);
}
db.updateTransportKeys(txn, filtered);
}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
index 51fa9ae30..327203d63 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
@@ -25,6 +25,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys;
@@ -223,37 +225,43 @@ abstract class JdbcDatabase implements Database {
+ " maxLatency INT NOT NULL,"
+ " PRIMARY KEY (transportId))";
+ private static final String CREATE_OUTGOING_KEYS =
+ "CREATE TABLE outgoingKeys"
+ + " (transportId _STRING NOT NULL,"
+ + " keySetId _COUNTER,"
+ + " rotationPeriod BIGINT NOT NULL,"
+ + " contactId INT," // Null if keys are not bound
+ + " tagKey _SECRET NOT NULL,"
+ + " headerKey _SECRET NOT NULL,"
+ + " stream BIGINT NOT NULL,"
+ + " PRIMARY KEY (transportId, keySetId),"
+ + " FOREIGN KEY (transportId)"
+ + " REFERENCES transports (transportId)"
+ + " ON DELETE CASCADE,"
+ + " UNIQUE (keySetId),"
+ + " FOREIGN KEY (contactId)"
+ + " REFERENCES contacts (contactId)"
+ + " ON DELETE CASCADE)";
+
private static final String CREATE_INCOMING_KEYS =
"CREATE TABLE incomingKeys"
- + " (contactId INT NOT NULL,"
- + " transportId _STRING NOT NULL,"
+ + " (transportId _STRING NOT NULL,"
+ + " keySetId INT NOT NULL,"
+ " rotationPeriod BIGINT NOT NULL,"
+ + " contactId INT," // Null if keys are not bound
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " base BIGINT NOT NULL,"
+ " bitmap _BINARY NOT NULL,"
- + " PRIMARY KEY (contactId, transportId, rotationPeriod),"
- + " FOREIGN KEY (contactId)"
- + " REFERENCES contacts (contactId)"
- + " ON DELETE CASCADE,"
+ + " PRIMARY KEY (transportId, keySetId, rotationPeriod),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
- + " ON DELETE CASCADE)";
-
- private static final String CREATE_OUTGOING_KEYS =
- "CREATE TABLE outgoingKeys"
- + " (contactId INT NOT NULL,"
- + " transportId _STRING NOT NULL,"
- + " rotationPeriod BIGINT NOT NULL,"
- + " tagKey _SECRET NOT NULL,"
- + " headerKey _SECRET NOT NULL,"
- + " stream BIGINT NOT NULL,"
- + " PRIMARY KEY (contactId, transportId),"
+ + " ON DELETE CASCADE,"
+ + " FOREIGN KEY (keySetId)"
+ + " REFERENCES outgoingKeys (keySetId)"
+ + " ON DELETE CASCADE,"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
- + " ON DELETE CASCADE,"
- + " FOREIGN KEY (transportId)"
- + " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)";
private static final String INDEX_CONTACTS_BY_AUTHOR_ID =
@@ -415,8 +423,8 @@ abstract class JdbcDatabase implements Database {
s.executeUpdate(insertTypeNames(CREATE_OFFERS));
s.executeUpdate(insertTypeNames(CREATE_STATUSES));
s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS));
- s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS));
+ s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.close();
} catch (SQLException e) {
tryToClose(s);
@@ -865,52 +873,18 @@ abstract class JdbcDatabase implements Database {
}
@Override
- public void addTransportKeys(Connection txn, ContactId c, TransportKeys k)
- throws DbException {
+ public KeySetId addTransportKeys(Connection txn, @Nullable ContactId c,
+ TransportKeys k) throws DbException {
PreparedStatement ps = null;
+ ResultSet rs = null;
try {
- // Store the incoming keys
- String sql = "INSERT INTO incomingKeys (contactId, transportId,"
- + " rotationPeriod, tagKey, headerKey, base, bitmap)"
- + " VALUES (?, ?, ?, ?, ?, ?, ?)";
- ps = txn.prepareStatement(sql);
- ps.setInt(1, c.getInt());
- ps.setString(2, k.getTransportId().getString());
- // Previous rotation period
- IncomingKeys inPrev = k.getPreviousIncomingKeys();
- ps.setLong(3, inPrev.getRotationPeriod());
- ps.setBytes(4, inPrev.getTagKey().getBytes());
- ps.setBytes(5, inPrev.getHeaderKey().getBytes());
- ps.setLong(6, inPrev.getWindowBase());
- ps.setBytes(7, inPrev.getWindowBitmap());
- ps.addBatch();
- // Current rotation period
- IncomingKeys inCurr = k.getCurrentIncomingKeys();
- ps.setLong(3, inCurr.getRotationPeriod());
- ps.setBytes(4, inCurr.getTagKey().getBytes());
- ps.setBytes(5, inCurr.getHeaderKey().getBytes());
- ps.setLong(6, inCurr.getWindowBase());
- ps.setBytes(7, inCurr.getWindowBitmap());
- ps.addBatch();
- // Next rotation period
- IncomingKeys inNext = k.getNextIncomingKeys();
- ps.setLong(3, inNext.getRotationPeriod());
- ps.setBytes(4, inNext.getTagKey().getBytes());
- ps.setBytes(5, inNext.getHeaderKey().getBytes());
- ps.setLong(6, inNext.getWindowBase());
- ps.setBytes(7, inNext.getWindowBitmap());
- ps.addBatch();
- int[] batchAffected = ps.executeBatch();
- if (batchAffected.length != 3) throw new DbStateException();
- for (int rows : batchAffected)
- if (rows != 1) throw new DbStateException();
- ps.close();
// Store the outgoing keys
- sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ String sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, stream)"
+ " VALUES (?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
- ps.setInt(1, c.getInt());
+ if (c == null) ps.setNull(1, INTEGER);
+ else ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(3, outCurr.getRotationPeriod());
@@ -920,7 +894,57 @@ abstract class JdbcDatabase implements Database {
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
+ // Get the new (highest) key set ID
+ sql = "SELECT keySetId FROM outgoingKeys"
+ + " ORDER BY keySetId DESC LIMIT 1";
+ ps = txn.prepareStatement(sql);
+ rs = ps.executeQuery();
+ if (!rs.next()) throw new DbStateException();
+ KeySetId keySetId = new KeySetId(rs.getInt(1));
+ if (rs.next()) throw new DbStateException();
+ rs.close();
+ ps.close();
+ // Store the incoming keys
+ sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId,"
+ + " rotationPeriod, tagKey, headerKey, base, bitmap)"
+ + " VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
+ ps = txn.prepareStatement(sql);
+ ps.setInt(1, keySetId.getInt());
+ if (c == null) ps.setNull(2, INTEGER);
+ else ps.setInt(2, c.getInt());
+ ps.setString(3, k.getTransportId().getString());
+ // Previous rotation period
+ IncomingKeys inPrev = k.getPreviousIncomingKeys();
+ ps.setLong(4, inPrev.getRotationPeriod());
+ ps.setBytes(5, inPrev.getTagKey().getBytes());
+ ps.setBytes(6, inPrev.getHeaderKey().getBytes());
+ ps.setLong(7, inPrev.getWindowBase());
+ ps.setBytes(8, inPrev.getWindowBitmap());
+ ps.addBatch();
+ // Current rotation period
+ IncomingKeys inCurr = k.getCurrentIncomingKeys();
+ ps.setLong(4, inCurr.getRotationPeriod());
+ ps.setBytes(5, inCurr.getTagKey().getBytes());
+ ps.setBytes(6, inCurr.getHeaderKey().getBytes());
+ ps.setLong(7, inCurr.getWindowBase());
+ ps.setBytes(8, inCurr.getWindowBitmap());
+ ps.addBatch();
+ // Next rotation period
+ IncomingKeys inNext = k.getNextIncomingKeys();
+ ps.setLong(4, inNext.getRotationPeriod());
+ ps.setBytes(5, inNext.getTagKey().getBytes());
+ ps.setBytes(6, inNext.getHeaderKey().getBytes());
+ ps.setLong(7, inNext.getWindowBase());
+ ps.setBytes(8, inNext.getWindowBitmap());
+ ps.addBatch();
+ int[] batchAffected = ps.executeBatch();
+ if (batchAffected.length != 3) throw new DbStateException();
+ for (int rows : batchAffected)
+ if (rows != 1) throw new DbStateException();
+ ps.close();
+ return keySetId;
} catch (SQLException e) {
+ tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
@@ -2078,8 +2102,8 @@ abstract class JdbcDatabase implements Database {
}
@Override
- public Map getTransportKeys(Connection txn,
- TransportId t) throws DbException {
+ public Collection getTransportKeys(Connection txn, TransportId t)
+ throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
@@ -2088,7 +2112,7 @@ abstract class JdbcDatabase implements Database {
+ " base, bitmap"
+ " FROM incomingKeys"
+ " WHERE transportId = ?"
- + " ORDER BY contactId, rotationPeriod";
+ + " ORDER BY keySetId, rotationPeriod";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
@@ -2105,29 +2129,33 @@ abstract class JdbcDatabase implements Database {
rs.close();
ps.close();
// Retrieve the outgoing keys in the same order
- sql = "SELECT contactId, rotationPeriod, tagKey, headerKey, stream"
+ sql = "SELECT keySetId, contactId, rotationPeriod,"
+ + " tagKey, headerKey, stream"
+ " FROM outgoingKeys"
+ " WHERE transportId = ?"
- + " ORDER BY contactId, rotationPeriod";
+ + " ORDER BY keySetId";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
- Map keys = new HashMap<>();
+ Collection keys = new ArrayList<>();
for (int i = 0; rs.next(); i++) {
// There should be three times as many incoming keys
if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
- ContactId contactId = new ContactId(rs.getInt(1));
- long rotationPeriod = rs.getLong(2);
- SecretKey tagKey = new SecretKey(rs.getBytes(3));
- SecretKey headerKey = new SecretKey(rs.getBytes(4));
- long streamCounter = rs.getLong(5);
+ KeySetId keySetId = new KeySetId(rs.getInt(1));
+ ContactId contactId = new ContactId(rs.getInt(2));
+ if (rs.wasNull()) contactId = null;
+ long rotationPeriod = rs.getLong(3);
+ SecretKey tagKey = new SecretKey(rs.getBytes(4));
+ SecretKey headerKey = new SecretKey(rs.getBytes(5));
+ long streamCounter = rs.getLong(6);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
rotationPeriod, streamCounter);
IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2);
- keys.put(contactId, new TransportKeys(t, inPrev, inCurr,
- inNext, outCurr));
+ TransportKeys transportKeys = new TransportKeys(t, inPrev,
+ inCurr, inNext, outCurr);
+ keys.add(new KeySet(keySetId, contactId, transportKeys));
}
rs.close();
ps.close();
@@ -2791,18 +2819,18 @@ abstract class JdbcDatabase implements Database {
}
@Override
- public void setReorderingWindow(Connection txn, ContactId c, TransportId t,
+ public void setReorderingWindow(Connection txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?"
- + " WHERE contactId = ? AND transportId = ?"
+ + " WHERE transportId = ? AND keySetId = ?"
+ " AND rotationPeriod = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, base);
ps.setBytes(2, bitmap);
- ps.setInt(3, c.getInt());
- ps.setString(4, t.getString());
+ ps.setString(3, t.getString());
+ ps.setInt(4, k.getInt());
ps.setLong(5, rotationPeriod);
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
@@ -2848,45 +2876,31 @@ abstract class JdbcDatabase implements Database {
}
@Override
- public void updateTransportKeys(Connection txn,
- Map keys) throws DbException {
+ public void updateTransportKeys(Connection txn, Collection keys)
+ throws DbException {
PreparedStatement ps = null;
try {
- // Delete any existing incoming keys
- String sql = "DELETE FROM incomingKeys"
- + " WHERE contactId = ?"
- + " AND transportId = ?";
+ // Delete any existing outgoing keys - this will also remove any
+ // incoming keys with the same key set ID
+ String sql = "DELETE FROM outgoingKeys WHERE keySetId = ?";
ps = txn.prepareStatement(sql);
- for (Entry e : keys.entrySet()) {
- ps.setInt(1, e.getKey().getInt());
- ps.setString(2, e.getValue().getTransportId().getString());
+ for (KeySet ks : keys) {
+ ps.setInt(1, ks.getKeySetId().getInt());
ps.addBatch();
}
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
- ps.close();
- // Delete any existing outgoing keys
- sql = "DELETE FROM outgoingKeys"
- + " WHERE contactId = ?"
- + " AND transportId = ?";
- ps = txn.prepareStatement(sql);
- for (Entry e : keys.entrySet()) {
- ps.setInt(1, e.getKey().getInt());
- ps.setString(2, e.getValue().getTransportId().getString());
- ps.addBatch();
- }
- batchAffected = ps.executeBatch();
- if (batchAffected.length != keys.size())
- throw new DbStateException();
+ for (int rows: batchAffected)
+ if (rows < 0) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
// Store the new keys
- for (Entry e : keys.entrySet()) {
- addTransportKeys(txn, e.getKey(), e.getValue());
+ for (KeySet ks : keys) {
+ addTransportKeys(txn, ks.getContactId(), ks.getTransportKeys());
}
}
}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java
new file mode 100644
index 000000000..b55c5aef4
--- /dev/null
+++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java
@@ -0,0 +1,34 @@
+package org.briarproject.bramble.transport;
+
+import org.briarproject.bramble.api.contact.ContactId;
+import org.briarproject.bramble.api.transport.KeySetId;
+
+import javax.annotation.Nullable;
+
+public class MutableKeySet {
+
+ private final KeySetId keySetId;
+ @Nullable
+ private final ContactId contactId;
+ private final MutableTransportKeys transportKeys;
+
+ public MutableKeySet(KeySetId keySetId, @Nullable ContactId contactId,
+ MutableTransportKeys transportKeys) {
+ this.keySetId = keySetId;
+ this.contactId = contactId;
+ this.transportKeys = transportKeys;
+ }
+
+ public KeySetId getKeySetId() {
+ return keySetId;
+ }
+
+ @Nullable
+ public ContactId getContactId() {
+ return contactId;
+ }
+
+ public MutableTransportKeys getTransportKeys() {
+ return transportKeys;
+ }
+}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java
index 60b48427f..1220beee5 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java
@@ -11,19 +11,25 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.system.Scheduler;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.transport.ReorderingWindow.Change;
+import java.util.ArrayList;
+import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
+import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
@@ -47,12 +53,13 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private final Clock clock;
private final TransportId transportId;
private final long rotationPeriodLength;
- private final ReentrantLock lock;
+ private final AtomicBoolean used = new AtomicBoolean(false);
+ private final ReentrantLock lock = new ReentrantLock();
// The following are locking: lock
- private final Map inContexts;
- private final Map outContexts;
- private final Map keys;
+ private final Collection keys = new ArrayList<>();
+ private final Map inContexts = new HashMap<>();
+ private final Map outContexts = new HashMap<>();
TransportKeyManagerImpl(DatabaseComponent db,
TransportCrypto transportCrypto, Executor dbExecutor,
@@ -65,20 +72,16 @@ class TransportKeyManagerImpl implements TransportKeyManager {
this.clock = clock;
this.transportId = transportId;
rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
- lock = new ReentrantLock();
- inContexts = new HashMap<>();
- outContexts = new HashMap<>();
- keys = new HashMap<>();
}
@Override
public void start(Transaction txn) throws DbException {
+ if (used.getAndSet(true)) throw new IllegalStateException();
long now = clock.currentTimeMillis();
lock.lock();
try {
// Load the transport keys from the DB
- Map loaded =
- db.getTransportKeys(txn, transportId);
+ Collection loaded = db.getTransportKeys(txn, transportId);
// Rotate the keys to the current rotation period
RotationResult rotationResult = rotateKeys(loaded, now);
// Initialise mutable state for all contacts
@@ -93,41 +96,51 @@ class TransportKeyManagerImpl implements TransportKeyManager {
scheduleKeyRotation(now);
}
- private RotationResult rotateKeys(Map keys,
- long now) {
+ private RotationResult rotateKeys(Collection keys, long now) {
RotationResult rotationResult = new RotationResult();
long rotationPeriod = now / rotationPeriodLength;
- for (Entry e : keys.entrySet()) {
- ContactId c = e.getKey();
- TransportKeys k = e.getValue();
+ for (KeySet ks : keys) {
+ TransportKeys k = ks.getTransportKeys();
TransportKeys k1 =
transportCrypto.rotateTransportKeys(k, rotationPeriod);
+ KeySet ks1 = new KeySet(ks.getKeySetId(), ks.getContactId(), k1);
if (k1.getRotationPeriod() > k.getRotationPeriod())
- rotationResult.rotated.put(c, k1);
- rotationResult.current.put(c, k1);
+ rotationResult.rotated.add(ks1);
+ rotationResult.current.add(ks1);
}
return rotationResult;
}
// Locking: lock
- private void addKeys(Map m) {
- for (Entry e : m.entrySet())
- addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
+ private void addKeys(Collection keys) {
+ for (KeySet ks : keys) {
+ addKeys(ks.getKeySetId(), ks.getContactId(),
+ new MutableTransportKeys(ks.getTransportKeys()));
+ }
}
// Locking: lock
- private void addKeys(ContactId c, MutableTransportKeys m) {
- encodeTags(c, m.getPreviousIncomingKeys());
- encodeTags(c, m.getCurrentIncomingKeys());
- encodeTags(c, m.getNextIncomingKeys());
- outContexts.put(c, m.getCurrentOutgoingKeys());
- keys.put(c, m);
+ private void addKeys(KeySetId keySetId, @Nullable ContactId contactId,
+ MutableTransportKeys m) {
+ MutableKeySet ks = new MutableKeySet(keySetId, contactId, m);
+ keys.add(ks);
+ if (contactId != null) {
+ encodeTags(keySetId, contactId, m.getPreviousIncomingKeys());
+ encodeTags(keySetId, contactId, m.getCurrentIncomingKeys());
+ encodeTags(keySetId, contactId, m.getNextIncomingKeys());
+ // Use the outgoing keys with the highest key set ID
+ MutableKeySet old = outContexts.get(contactId);
+ if (old == null || old.getKeySetId().getInt() < keySetId.getInt())
+ outContexts.put(contactId, ks);
+ }
}
// Locking: lock
- private void encodeTags(ContactId c, MutableIncomingKeys inKeys) {
+ private void encodeTags(KeySetId keySetId, ContactId contactId,
+ MutableIncomingKeys inKeys) {
for (long streamNumber : inKeys.getWindow().getUnseen()) {
- TagContext tagCtx = new TagContext(c, inKeys, streamNumber);
+ TagContext tagCtx =
+ new TagContext(keySetId, contactId, inKeys, streamNumber);
byte[] tag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
streamNumber);
@@ -169,10 +182,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
// Rotate the keys to the current rotation period if necessary
rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength;
k = transportCrypto.rotateTransportKeys(k, rotationPeriod);
- // Initialise mutable state for the contact
- addKeys(c, new MutableTransportKeys(k));
// Write the keys back to the DB
- db.addTransportKeys(txn, c, k);
+ KeySetId keySetId = db.addTransportKeys(txn, c, k);
+ // Initialise mutable state for the contact
+ addKeys(keySetId, c, new MutableTransportKeys(k));
} finally {
lock.unlock();
}
@@ -183,12 +196,18 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock();
try {
// Remove mutable state for the contact
- Iterator> it =
+ Iterator> inContextsIter =
inContexts.entrySet().iterator();
- while (it.hasNext())
- if (it.next().getValue().contactId.equals(c)) it.remove();
+ while (inContextsIter.hasNext()) {
+ ContactId c1 = inContextsIter.next().getValue().contactId;
+ if (c1.equals(c)) inContextsIter.remove();
+ }
outContexts.remove(c);
- keys.remove(c);
+ Iterator keysIter = keys.iterator();
+ while (keysIter.hasNext()) {
+ ContactId c1 = keysIter.next().getContactId();
+ if (c1 != null && c1.equals(c)) keysIter.remove();
+ }
} finally {
lock.unlock();
}
@@ -200,8 +219,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock();
try {
// Look up the outgoing keys for the contact
- MutableOutgoingKeys outKeys = outContexts.get(c);
- if (outKeys == null) return null;
+ MutableKeySet ks = outContexts.get(c);
+ if (ks == null) return null;
+ MutableOutgoingKeys outKeys =
+ ks.getTransportKeys().getCurrentOutgoingKeys();
if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
// Create a stream context
StreamContext ctx = new StreamContext(c, transportId,
@@ -238,8 +259,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
byte[] addTag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
PROTOCOL_VERSION, streamNumber);
- inContexts.put(new Bytes(addTag), new TagContext(
- tagCtx.contactId, inKeys, streamNumber));
+ TagContext tagCtx1 = new TagContext(tagCtx.keySetId,
+ tagCtx.contactId, inKeys, streamNumber);
+ inContexts.put(new Bytes(addTag), tagCtx1);
}
// Remove tags for any stream numbers removed from the window
for (long streamNumber : change.getRemoved()) {
@@ -250,7 +272,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
inContexts.remove(new Bytes(removeTag));
}
// Write the window back to the DB
- db.setReorderingWindow(txn, tagCtx.contactId, transportId,
+ db.setReorderingWindow(txn, tagCtx.keySetId, transportId,
inKeys.getRotationPeriod(), window.getBase(),
window.getBitmap());
return ctx;
@@ -264,9 +286,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock();
try {
// Rotate the keys to the current rotation period
- Map snapshot = new HashMap<>();
- for (Entry e : keys.entrySet())
- snapshot.put(e.getKey(), e.getValue().snapshot());
+ Collection snapshot = new ArrayList<>(keys.size());
+ for (MutableKeySet ks : keys) {
+ snapshot.add(new KeySet(ks.getKeySetId(), ks.getContactId(),
+ ks.getTransportKeys().snapshot()));
+ }
RotationResult rotationResult = rotateKeys(snapshot, now);
// Rebuild the mutable state for all contacts
inContexts.clear();
@@ -285,12 +309,14 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class TagContext {
+ private final KeySetId keySetId;
private final ContactId contactId;
private final MutableIncomingKeys inKeys;
private final long streamNumber;
- private TagContext(ContactId contactId, MutableIncomingKeys inKeys,
- long streamNumber) {
+ private TagContext(KeySetId keySetId, ContactId contactId,
+ MutableIncomingKeys inKeys, long streamNumber) {
+ this.keySetId = keySetId;
this.contactId = contactId;
this.inKeys = inKeys;
this.streamNumber = streamNumber;
@@ -299,11 +325,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class RotationResult {
- private final Map current, rotated;
-
- private RotationResult() {
- current = new HashMap<>();
- rotated = new HashMap<>();
- }
+ private final Collection current = new ArrayList<>();
+ private final Collection rotated = new ArrayList<>();
}
}
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
index 0ed457461..eb31247ce 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
@@ -44,6 +44,8 @@ import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
import org.briarproject.bramble.api.transport.IncomingKeys;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.test.BrambleMockTestCase;
@@ -55,12 +57,10 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
-import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;
-import static java.util.Collections.singletonMap;
import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED;
import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE;
@@ -100,6 +100,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
private final int maxLatency;
private final ContactId contactId;
private final Contact contact;
+ private final KeySetId keySetId;
public DatabaseComponentImplTest() {
clientId = new ClientId(getRandomString(123));
@@ -121,6 +122,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
contactId = new ContactId(234);
contact = new Contact(contactId, author, localAuthor.getId(),
true, true);
+ keySetId = new KeySetId(345);
}
private DatabaseComponent createDatabaseComponent(Database