Merge branch 'multiple-transport-keys' into 'master'

Support multiple sets of transport keys per contact

See merge request akwizgran/briar!745
This commit is contained in:
akwizgran
2018-04-17 14:02:45 +00:00
23 changed files with 1278 additions and 476 deletions

View File

@@ -50,7 +50,7 @@ class ContactManagerImpl implements ContactManager {
@Override
public ContactId addContact(Transaction txn, Author remote, AuthorId local,
SecretKey master,long timestamp, boolean alice, boolean verified,
SecretKey master, long timestamp, boolean alice, boolean verified,
boolean active) throws DbException {
ContactId c = db.addContact(txn, remote, local, verified, active);
keyManager.addContact(txn, c, master, timestamp, alice);
@@ -60,6 +60,16 @@ class ContactManagerImpl implements ContactManager {
return c;
}
@Override
public ContactId addContact(Transaction txn, Author remote, AuthorId local,
boolean verified, boolean active) throws DbException {
ContactId c = db.addContact(txn, remote, local, verified, active);
Contact contact = db.getContact(txn, c);
for (AddContactHook hook : addHooks)
hook.addingContact(txn, contact);
return c;
}
@Override
public ContactId addContact(Author remote, AuthorId local, SecretKey master,
long timestamp, boolean alice, boolean verified, boolean active)

View File

@@ -36,7 +36,8 @@ class TransportCryptoImpl implements TransportCrypto {
@Override
public TransportKeys deriveTransportKeys(TransportId t,
SecretKey master, long rotationPeriod, boolean alice) {
SecretKey master, long rotationPeriod, boolean alice,
boolean active) {
// Keys for the previous period are derived from the master secret
SecretKey inTagPrev = deriveTagKey(master, t, !alice);
SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice);
@@ -57,7 +58,7 @@ class TransportCryptoImpl implements TransportCrypto {
IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext,
rotationPeriod + 1);
OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr,
rotationPeriod);
rotationPeriod, active);
// Collect and return the keys
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr);
}
@@ -71,6 +72,7 @@ class TransportCryptoImpl implements TransportCrypto {
IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getRotationPeriod();
boolean active = outCurr.isActive();
// Rotate the keys
for (long p = startPeriod + 1; p <= rotationPeriod; p++) {
inPrev = inCurr;
@@ -80,7 +82,7 @@ class TransportCryptoImpl implements TransportCrypto {
inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p, active);
}
// Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,

View File

@@ -21,6 +21,8 @@ import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection;
@@ -123,9 +125,16 @@ interface Database<T> {
throws DbException;
/**
* Stores transport keys for a newly added contact.
* Stores the given transport keys, optionally binding them to the given
* contact, and returns a key set ID.
*/
void addTransportKeys(T txn, ContactId c, TransportKeys k)
KeySetId addTransportKeys(T txn, @Nullable ContactId c, TransportKeys k)
throws DbException;
/**
* Binds the given keys for the given transport to the given contact.
*/
void bindTransportKeys(T txn, ContactId c, TransportId t, KeySetId k)
throws DbException;
/**
@@ -486,15 +495,14 @@ interface Database<T> {
* <p/>
* Read-only.
*/
Map<ContactId, TransportKeys> getTransportKeys(T txn, TransportId t)
Collection<KeySet> getTransportKeys(T txn, TransportId t)
throws DbException;
/**
* Increments the outgoing stream counter for the given contact and
* transport in the given rotation period.
* Increments the outgoing stream counter for the given transport keys.
*/
void incrementStreamCounter(T txn, ContactId c, TransportId t,
long rotationPeriod) throws DbException;
void incrementStreamCounter(T txn, TransportId t, KeySetId k)
throws DbException;
/**
* Marks the given messages as not needing to be acknowledged to the
@@ -584,6 +592,12 @@ interface Database<T> {
*/
void removeTransport(T txn, TransportId t) throws DbException;
/**
* Removes the given transport keys from the database.
*/
void removeTransportKeys(T txn, TransportId t, KeySetId k)
throws DbException;
/**
* Resets the transmission count and expiry time of the given message with
* respect to the given contact.
@@ -619,12 +633,18 @@ interface Database<T> {
void setMessageState(T txn, MessageId m, State state) throws DbException;
/**
* Sets the reordering window for the given contact and transport in the
* Sets the reordering window for the given key set and transport in the
* given rotation period.
*/
void setReorderingWindow(T txn, ContactId c, TransportId t,
void setReorderingWindow(T txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException;
/**
* Marks the given transport keys as usable for outgoing streams.
*/
void setTransportKeysActive(T txn, TransportId t, KeySetId k)
throws DbException;
/**
* Updates the transmission count and expiry time of the given message
* with respect to the given contact, using the latency of the transport
@@ -636,6 +656,5 @@ interface Database<T> {
/**
* Stores the given transport keys, deleting any keys they have replaced.
*/
void updateTransportKeys(T txn, Map<ContactId, TransportKeys> keys)
throws DbException;
void updateTransportKeys(T txn, Collection<KeySet> keys) throws DbException;
}

View File

@@ -51,15 +51,15 @@ import org.briarproject.bramble.api.sync.event.MessageToAckEvent;
import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Logger;
@@ -234,15 +234,27 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
@Override
public void addTransportKeys(Transaction transaction, ContactId c,
TransportKeys k) throws DbException {
public KeySetId addTransportKeys(Transaction transaction,
@Nullable ContactId c, TransportKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (c != null && !db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException();
return db.addTransportKeys(txn, c, k);
}
@Override
public void bindTransportKeys(Transaction transaction, ContactId c,
TransportId t, KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId()))
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.addTransportKeys(txn, c, k);
db.bindTransportKeys(txn, c, t, k);
}
@Override
@@ -586,8 +598,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
@Override
public Map<ContactId, TransportKeys> getTransportKeys(
Transaction transaction, TransportId t) throws DbException {
public Collection<KeySet> getTransportKeys(Transaction transaction,
TransportId t) throws DbException {
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
@@ -595,15 +607,13 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
@Override
public void incrementStreamCounter(Transaction transaction, ContactId c,
TransportId t, long rotationPeriod) throws DbException {
public void incrementStreamCounter(Transaction transaction, TransportId t,
KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.incrementStreamCounter(txn, c, t, rotationPeriod);
db.incrementStreamCounter(txn, t, k);
}
@Override
@@ -779,6 +789,16 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
db.removeTransport(txn, t);
}
@Override
public void removeTransportKeys(Transaction transaction,
TransportId t, KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.removeTransportKeys(txn, t, k);
}
@Override
public void setContactVerified(Transaction transaction, ContactId c)
throws DbException {
@@ -858,31 +878,35 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
@Override
public void setReorderingWindow(Transaction transaction, ContactId c,
public void setReorderingWindow(Transaction transaction, KeySetId k,
TransportId t, long rotationPeriod, long base, byte[] bitmap)
throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap);
db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap);
}
@Override
public void setTransportKeysActive(Transaction transaction, TransportId t,
KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.setTransportKeysActive(txn, t, k);
}
@Override
public void updateTransportKeys(Transaction transaction,
Map<ContactId, TransportKeys> keys) throws DbException {
Collection<KeySet> keys) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
Map<ContactId, TransportKeys> filtered = new HashMap<>();
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue();
if (db.containsContact(txn, c)
&& db.containsTransport(txn, k.getTransportId())) {
filtered.put(c, k);
}
Collection<KeySet> filtered = new ArrayList<>();
for (KeySet ks : keys) {
TransportId t = ks.getTransportKeys().getTransportId();
if (db.containsTransport(txn, t)) filtered.add(ks);
}
db.updateTransportKeys(txn, filtered);
}

View File

@@ -25,6 +25,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.ValidationManager.State;
import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys;
@@ -223,37 +225,44 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " maxLatency INT NOT NULL,"
+ " PRIMARY KEY (transportId))";
private static final String CREATE_OUTGOING_KEYS =
"CREATE TABLE outgoingKeys"
+ " (transportId _STRING NOT NULL,"
+ " keySetId _COUNTER,"
+ " rotationPeriod BIGINT NOT NULL,"
+ " contactId INT," // Null if keys are not bound
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " active BOOLEAN NOT NULL,"
+ " PRIMARY KEY (transportId, keySetId),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE,"
+ " UNIQUE (keySetId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_INCOMING_KEYS =
"CREATE TABLE incomingKeys"
+ " (contactId INT NOT NULL,"
+ " transportId _STRING NOT NULL,"
+ " (transportId _STRING NOT NULL,"
+ " keySetId INT NOT NULL,"
+ " rotationPeriod BIGINT NOT NULL,"
+ " contactId INT," // Null if keys are not bound
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " base BIGINT NOT NULL,"
+ " bitmap _BINARY NOT NULL,"
+ " PRIMARY KEY (contactId, transportId, rotationPeriod),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " PRIMARY KEY (transportId, keySetId, rotationPeriod),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_OUTGOING_KEYS =
"CREATE TABLE outgoingKeys"
+ " (contactId INT NOT NULL,"
+ " transportId _STRING NOT NULL,"
+ " rotationPeriod BIGINT NOT NULL,"
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " PRIMARY KEY (contactId, transportId),"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (keySetId)"
+ " REFERENCES outgoingKeys (keySetId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)";
private static final String INDEX_CONTACTS_BY_AUTHOR_ID =
@@ -415,8 +424,8 @@ abstract class JdbcDatabase implements Database<Connection> {
s.executeUpdate(insertTypeNames(CREATE_OFFERS));
s.executeUpdate(insertTypeNames(CREATE_STATUSES));
s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS));
s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.close();
} catch (SQLException e) {
tryToClose(s);
@@ -865,61 +874,105 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public void addTransportKeys(Connection txn, ContactId c, TransportKeys k)
throws DbException {
public KeySetId addTransportKeys(Connection txn, @Nullable ContactId c,
TransportKeys k) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Store the incoming keys
String sql = "INSERT INTO incomingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, base, bitmap)"
// Store the outgoing keys
String sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, stream, active)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString());
// Previous rotation period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(3, inPrev.getRotationPeriod());
ps.setBytes(4, inPrev.getTagKey().getBytes());
ps.setBytes(5, inPrev.getHeaderKey().getBytes());
ps.setLong(6, inPrev.getWindowBase());
ps.setBytes(7, inPrev.getWindowBitmap());
ps.addBatch();
// Current rotation period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(3, inCurr.getRotationPeriod());
ps.setBytes(4, inCurr.getTagKey().getBytes());
ps.setBytes(5, inCurr.getHeaderKey().getBytes());
ps.setLong(6, inCurr.getWindowBase());
ps.setBytes(7, inCurr.getWindowBitmap());
ps.addBatch();
// Next rotation period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(3, inNext.getRotationPeriod());
ps.setBytes(4, inNext.getTagKey().getBytes());
ps.setBytes(5, inNext.getHeaderKey().getBytes());
ps.setLong(6, inNext.getWindowBase());
ps.setBytes(7, inNext.getWindowBitmap());
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows != 1) throw new DbStateException();
ps.close();
// Store the outgoing keys
sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, stream)"
+ " VALUES (?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
if (c == null) ps.setNull(1, INTEGER);
else ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(3, outCurr.getRotationPeriod());
ps.setBytes(4, outCurr.getTagKey().getBytes());
ps.setBytes(5, outCurr.getHeaderKey().getBytes());
ps.setLong(6, outCurr.getStreamCounter());
ps.setBoolean(7, outCurr.isActive());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
// Get the new (highest) key set ID
sql = "SELECT keySetId FROM outgoingKeys"
+ " ORDER BY keySetId DESC LIMIT 1";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
if (!rs.next()) throw new DbStateException();
KeySetId keySetId = new KeySetId(rs.getInt(1));
if (rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Store the incoming keys
sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, base, bitmap)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, keySetId.getInt());
if (c == null) ps.setNull(2, INTEGER);
else ps.setInt(2, c.getInt());
ps.setString(3, k.getTransportId().getString());
// Previous rotation period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(4, inPrev.getRotationPeriod());
ps.setBytes(5, inPrev.getTagKey().getBytes());
ps.setBytes(6, inPrev.getHeaderKey().getBytes());
ps.setLong(7, inPrev.getWindowBase());
ps.setBytes(8, inPrev.getWindowBitmap());
ps.addBatch();
// Current rotation period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(4, inCurr.getRotationPeriod());
ps.setBytes(5, inCurr.getTagKey().getBytes());
ps.setBytes(6, inCurr.getHeaderKey().getBytes());
ps.setLong(7, inCurr.getWindowBase());
ps.setBytes(8, inCurr.getWindowBitmap());
ps.addBatch();
// Next rotation period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(4, inNext.getRotationPeriod());
ps.setBytes(5, inNext.getTagKey().getBytes());
ps.setBytes(6, inNext.getHeaderKey().getBytes());
ps.setLong(7, inNext.getWindowBase());
ps.setBytes(8, inNext.getWindowBitmap());
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows != 1) throw new DbStateException();
ps.close();
return keySetId;
} catch (SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
@Override
public void bindTransportKeys(Connection txn, ContactId c, TransportId t,
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET contactId = ?"
+ " WHERE keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(2, k.getInt());
int affected = ps.executeUpdate();
if (affected < 0) throw new DbStateException();
ps.close();
sql = "UPDATE incomingKeys SET contactId = ?"
+ " WHERE keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(2, k.getInt());
affected = ps.executeUpdate();
if (affected < 0) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
@@ -2078,8 +2131,8 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public Map<ContactId, TransportKeys> getTransportKeys(Connection txn,
TransportId t) throws DbException {
public Collection<KeySet> getTransportKeys(Connection txn, TransportId t)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
@@ -2088,7 +2141,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " base, bitmap"
+ " FROM incomingKeys"
+ " WHERE transportId = ?"
+ " ORDER BY contactId, rotationPeriod";
+ " ORDER BY keySetId, rotationPeriod";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
@@ -2105,29 +2158,34 @@ abstract class JdbcDatabase implements Database<Connection> {
rs.close();
ps.close();
// Retrieve the outgoing keys in the same order
sql = "SELECT contactId, rotationPeriod, tagKey, headerKey, stream"
sql = "SELECT keySetId, contactId, rotationPeriod,"
+ " tagKey, headerKey, stream, active"
+ " FROM outgoingKeys"
+ " WHERE transportId = ?"
+ " ORDER BY contactId, rotationPeriod";
+ " ORDER BY keySetId";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
Map<ContactId, TransportKeys> keys = new HashMap<>();
Collection<KeySet> keys = new ArrayList<>();
for (int i = 0; rs.next(); i++) {
// There should be three times as many incoming keys
if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
ContactId contactId = new ContactId(rs.getInt(1));
long rotationPeriod = rs.getLong(2);
SecretKey tagKey = new SecretKey(rs.getBytes(3));
SecretKey headerKey = new SecretKey(rs.getBytes(4));
long streamCounter = rs.getLong(5);
KeySetId keySetId = new KeySetId(rs.getInt(1));
ContactId contactId = new ContactId(rs.getInt(2));
if (rs.wasNull()) contactId = null;
long rotationPeriod = rs.getLong(3);
SecretKey tagKey = new SecretKey(rs.getBytes(4));
SecretKey headerKey = new SecretKey(rs.getBytes(5));
long streamCounter = rs.getLong(6);
boolean active = rs.getBoolean(7);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
rotationPeriod, streamCounter);
rotationPeriod, streamCounter, active);
IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2);
keys.put(contactId, new TransportKeys(t, inPrev, inCurr,
inNext, outCurr));
TransportKeys transportKeys = new TransportKeys(t, inPrev,
inCurr, inNext, outCurr);
keys.add(new KeySet(keySetId, contactId, transportKeys));
}
rs.close();
ps.close();
@@ -2140,17 +2198,15 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public void incrementStreamCounter(Connection txn, ContactId c,
TransportId t, long rotationPeriod) throws DbException {
public void incrementStreamCounter(Connection txn, TransportId t,
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET stream = stream + 1"
+ " WHERE contactId = ? AND transportId = ?"
+ " AND rotationPeriod = ?";
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, t.getString());
ps.setLong(3, rotationPeriod);
ps.setString(1, t.getString());
ps.setInt(2, k.getInt());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
@@ -2626,6 +2682,27 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
@Override
public void removeTransportKeys(Connection txn, TransportId t, KeySetId k)
throws DbException {
PreparedStatement ps = null;
try {
// Delete any existing outgoing keys - this will also remove any
// incoming keys with the same key set ID
String sql = "DELETE FROM outgoingKeys"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
ps.setInt(2, k.getInt());
int affected = ps.executeUpdate();
if (affected < 0) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
@Override
public void resetExpiryTime(Connection txn, ContactId c, MessageId m)
throws DbException {
@@ -2791,18 +2868,18 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public void setReorderingWindow(Connection txn, ContactId c, TransportId t,
public void setReorderingWindow(Connection txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?"
+ " WHERE contactId = ? AND transportId = ?"
+ " WHERE transportId = ? AND keySetId = ?"
+ " AND rotationPeriod = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, base);
ps.setBytes(2, bitmap);
ps.setInt(3, c.getInt());
ps.setString(4, t.getString());
ps.setString(3, t.getString());
ps.setInt(4, k.getInt());
ps.setLong(5, rotationPeriod);
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
@@ -2813,6 +2890,23 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
@Override
public void setTransportKeysActive(Connection txn, TransportId t,
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET active = true"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
@Override
public void updateExpiryTime(Connection txn, ContactId c, MessageId m,
int maxLatency) throws DbException {
@@ -2848,45 +2942,12 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public void updateTransportKeys(Connection txn,
Map<ContactId, TransportKeys> keys) throws DbException {
PreparedStatement ps = null;
try {
// Delete any existing incoming keys
String sql = "DELETE FROM incomingKeys"
+ " WHERE contactId = ?"
+ " AND transportId = ?";
ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ps.setInt(1, e.getKey().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch();
}
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
ps.close();
// Delete any existing outgoing keys
sql = "DELETE FROM outgoingKeys"
+ " WHERE contactId = ?"
+ " AND transportId = ?";
ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ps.setInt(1, e.getKey().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch();
}
batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
// Store the new keys
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
addTransportKeys(txn, e.getKey(), e.getValue());
public void updateTransportKeys(Connection txn, Collection<KeySet> keys)
throws DbException {
for (KeySet ks : keys) {
TransportKeys k = ks.getTransportKeys();
removeTransportKeys(txn, k.getTransportId(), ks.getKeySetId());
addTransportKeys(txn, ks.getContactId(), k);
}
}
}

View File

@@ -19,6 +19,7 @@ import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexPluginFactory;
import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory;
import org.briarproject.bramble.api.transport.KeyManager;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext;
import java.util.HashMap;
@@ -104,6 +105,67 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
m.addContact(txn, c, master, timestamp, alice);
}
@Override
public Map<TransportId, KeySetId> addUnboundKeys(Transaction txn,
SecretKey master, long timestamp, boolean alice)
throws DbException {
Map<TransportId, KeySetId> ids = new HashMap<>();
for (Entry<TransportId, TransportKeyManager> e : managers.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = e.getValue();
ids.put(t, m.addUnboundKeys(txn, master, timestamp, alice));
}
return ids;
}
@Override
public void bindKeys(Transaction txn, ContactId c,
Map<TransportId, KeySetId> keys) throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
} else {
m.bindKeys(txn, c, e.getValue());
}
}
}
@Override
public void activateKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
} else {
m.activateKeys(txn, e.getValue());
}
}
}
@Override
public void removeKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
} else {
m.removeKeys(txn, e.getValue());
}
}
}
@Override
public boolean canSendOutgoingStreams(ContactId c, TransportId t) {
TransportKeyManager m = managers.get(t);
return m == null ? false : m.canSendOutgoingStreams(c);
}
@Override
public StreamContext getStreamContext(ContactId c, TransportId t)
throws DbException {
@@ -114,7 +176,7 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
return null;
}
StreamContext ctx = null;
StreamContext ctx;
Transaction txn = db.startTransaction(false);
try {
ctx = m.getStreamContext(txn, c);
@@ -133,7 +195,7 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
return null;
}
StreamContext ctx = null;
StreamContext ctx;
Transaction txn = db.startTransaction(false);
try {
ctx = m.getStreamContext(txn, tag);

View File

@@ -0,0 +1,34 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.transport.KeySetId;
import javax.annotation.Nullable;
public class MutableKeySet {
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
private final MutableTransportKeys transportKeys;
public MutableKeySet(KeySetId keySetId, @Nullable ContactId contactId,
MutableTransportKeys transportKeys) {
this.keySetId = keySetId;
this.contactId = contactId;
this.transportKeys = transportKeys;
}
public KeySetId getKeySetId() {
return keySetId;
}
@Nullable
public ContactId getContactId() {
return contactId;
}
public MutableTransportKeys getTransportKeys() {
return transportKeys;
}
}

View File

@@ -13,17 +13,19 @@ class MutableOutgoingKeys {
private final SecretKey tagKey, headerKey;
private final long rotationPeriod;
private long streamCounter;
private boolean active;
MutableOutgoingKeys(OutgoingKeys out) {
tagKey = out.getTagKey();
headerKey = out.getHeaderKey();
rotationPeriod = out.getRotationPeriod();
streamCounter = out.getStreamCounter();
active = out.isActive();
}
OutgoingKeys snapshot() {
return new OutgoingKeys(tagKey, headerKey, rotationPeriod,
streamCounter);
streamCounter, active);
}
SecretKey getTagKey() {
@@ -45,4 +47,12 @@ class MutableOutgoingKeys {
void incrementStreamCounter() {
streamCounter++;
}
boolean isActive() {
return active;
}
void activate() {
active = true;
}
}

View File

@@ -5,6 +5,7 @@ import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext;
import javax.annotation.Nullable;
@@ -17,8 +18,19 @@ interface TransportKeyManager {
void addContact(Transaction txn, ContactId c, SecretKey master,
long timestamp, boolean alice) throws DbException;
KeySetId addUnboundKeys(Transaction txn, SecretKey master, long timestamp,
boolean alice) throws DbException;
void bindKeys(Transaction txn, ContactId c, KeySetId k) throws DbException;
void activateKeys(Transaction txn, KeySetId k) throws DbException;
void removeKeys(Transaction txn, KeySetId k) throws DbException;
void removeContact(ContactId c);
boolean canSendOutgoingStreams(ContactId c);
@Nullable
StreamContext getStreamContext(Transaction txn, ContactId c)
throws DbException;

View File

@@ -11,19 +11,24 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.system.Scheduler;
import org.briarproject.bramble.api.transport.KeySet;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.transport.ReorderingWindow.Change;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
@@ -47,12 +52,13 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private final Clock clock;
private final TransportId transportId;
private final long rotationPeriodLength;
private final ReentrantLock lock;
private final AtomicBoolean used = new AtomicBoolean(false);
private final ReentrantLock lock = new ReentrantLock();
// The following are locking: lock
private final Map<Bytes, TagContext> inContexts;
private final Map<ContactId, MutableOutgoingKeys> outContexts;
private final Map<ContactId, MutableTransportKeys> keys;
private final Map<KeySetId, MutableKeySet> keys = new HashMap<>();
private final Map<Bytes, TagContext> inContexts = new HashMap<>();
private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>();
TransportKeyManagerImpl(DatabaseComponent db,
TransportCrypto transportCrypto, Executor dbExecutor,
@@ -65,20 +71,16 @@ class TransportKeyManagerImpl implements TransportKeyManager {
this.clock = clock;
this.transportId = transportId;
rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
lock = new ReentrantLock();
inContexts = new HashMap<>();
outContexts = new HashMap<>();
keys = new HashMap<>();
}
@Override
public void start(Transaction txn) throws DbException {
if (used.getAndSet(true)) throw new IllegalStateException();
long now = clock.currentTimeMillis();
lock.lock();
try {
// Load the transport keys from the DB
Map<ContactId, TransportKeys> loaded =
db.getTransportKeys(txn, transportId);
Collection<KeySet> loaded = db.getTransportKeys(txn, transportId);
// Rotate the keys to the current rotation period
RotationResult rotationResult = rotateKeys(loaded, now);
// Initialise mutable state for all contacts
@@ -93,41 +95,48 @@ class TransportKeyManagerImpl implements TransportKeyManager {
scheduleKeyRotation(now);
}
private RotationResult rotateKeys(Map<ContactId, TransportKeys> keys,
long now) {
private RotationResult rotateKeys(Collection<KeySet> keys, long now) {
RotationResult rotationResult = new RotationResult();
long rotationPeriod = now / rotationPeriodLength;
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue();
for (KeySet ks : keys) {
TransportKeys k = ks.getTransportKeys();
TransportKeys k1 =
transportCrypto.rotateTransportKeys(k, rotationPeriod);
KeySet ks1 = new KeySet(ks.getKeySetId(), ks.getContactId(), k1);
if (k1.getRotationPeriod() > k.getRotationPeriod())
rotationResult.rotated.put(c, k1);
rotationResult.current.put(c, k1);
rotationResult.rotated.add(ks1);
rotationResult.current.add(ks1);
}
return rotationResult;
}
// Locking: lock
private void addKeys(Map<ContactId, TransportKeys> m) {
for (Entry<ContactId, TransportKeys> e : m.entrySet())
addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
private void addKeys(Collection<KeySet> keys) {
for (KeySet ks : keys) {
addKeys(ks.getKeySetId(), ks.getContactId(),
new MutableTransportKeys(ks.getTransportKeys()));
}
}
// Locking: lock
private void addKeys(ContactId c, MutableTransportKeys m) {
encodeTags(c, m.getPreviousIncomingKeys());
encodeTags(c, m.getCurrentIncomingKeys());
encodeTags(c, m.getNextIncomingKeys());
outContexts.put(c, m.getCurrentOutgoingKeys());
keys.put(c, m);
private void addKeys(KeySetId keySetId, @Nullable ContactId contactId,
MutableTransportKeys m) {
MutableKeySet ks = new MutableKeySet(keySetId, contactId, m);
keys.put(keySetId, ks);
if (contactId != null) {
encodeTags(keySetId, contactId, m.getPreviousIncomingKeys());
encodeTags(keySetId, contactId, m.getCurrentIncomingKeys());
encodeTags(keySetId, contactId, m.getNextIncomingKeys());
considerReplacingOutgoingKeys(ks);
}
}
// Locking: lock
private void encodeTags(ContactId c, MutableIncomingKeys inKeys) {
private void encodeTags(KeySetId keySetId, ContactId contactId,
MutableIncomingKeys inKeys) {
for (long streamNumber : inKeys.getWindow().getUnseen()) {
TagContext tagCtx = new TagContext(c, inKeys, streamNumber);
TagContext tagCtx =
new TagContext(keySetId, contactId, inKeys, streamNumber);
byte[] tag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
streamNumber);
@@ -135,6 +144,17 @@ class TransportKeyManagerImpl implements TransportKeyManager {
}
}
// Locking: lock
private void considerReplacingOutgoingKeys(MutableKeySet ks) {
// Use the active outgoing keys with the highest key set ID
if (ks.getTransportKeys().getCurrentOutgoingKeys().isActive()) {
MutableKeySet old = outContexts.get(ks.getContactId());
if (old == null ||
old.getKeySetId().getInt() < ks.getKeySetId().getInt())
outContexts.put(ks.getContactId(), ks);
}
}
private void scheduleKeyRotation(long now) {
long delay = rotationPeriodLength - now % rotationPeriodLength;
scheduler.schedule((Runnable) this::rotateKeys, delay, MILLISECONDS);
@@ -159,20 +179,82 @@ class TransportKeyManagerImpl implements TransportKeyManager {
@Override
public void addContact(Transaction txn, ContactId c, SecretKey master,
long timestamp, boolean alice) throws DbException {
deriveAndAddKeys(txn, c, master, timestamp, alice, true);
}
@Override
public KeySetId addUnboundKeys(Transaction txn, SecretKey master,
long timestamp, boolean alice) throws DbException {
return deriveAndAddKeys(txn, null, master, timestamp, alice, false);
}
private KeySetId deriveAndAddKeys(Transaction txn, @Nullable ContactId c,
SecretKey master, long timestamp, boolean alice, boolean active)
throws DbException {
lock.lock();
try {
// Work out what rotation period the timestamp belongs to
long rotationPeriod = timestamp / rotationPeriodLength;
// Derive the transport keys
TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
master, rotationPeriod, alice);
master, rotationPeriod, alice, active);
// Rotate the keys to the current rotation period if necessary
rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength;
k = transportCrypto.rotateTransportKeys(k, rotationPeriod);
// Initialise mutable state for the contact
addKeys(c, new MutableTransportKeys(k));
// Write the keys back to the DB
db.addTransportKeys(txn, c, k);
KeySetId keySetId = db.addTransportKeys(txn, c, k);
// Initialise mutable state for the contact
addKeys(keySetId, c, new MutableTransportKeys(k));
return keySetId;
} finally {
lock.unlock();
}
}
@Override
public void bindKeys(Transaction txn, ContactId c, KeySetId k)
throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys haven't already been bound
if (ks.getContactId() != null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys();
addKeys(k, c, m);
db.bindTransportKeys(txn, c, m.getTransportId(), k);
} finally {
lock.unlock();
}
}
@Override
public void activateKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys have been bound
if (ks.getContactId() == null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys();
m.getCurrentOutgoingKeys().activate();
considerReplacingOutgoingKeys(ks);
db.setTransportKeysActive(txn, m.getTransportId(), k);
} finally {
lock.unlock();
}
}
@Override
public void removeKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.remove(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys haven't been bound
if (ks.getContactId() != null) throw new IllegalArgumentException();
TransportId t = ks.getTransportKeys().getTransportId();
db.removeTransportKeys(txn, t, k);
} finally {
lock.unlock();
}
@@ -183,12 +265,29 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock();
try {
// Remove mutable state for the contact
Iterator<Entry<Bytes, TagContext>> it =
inContexts.entrySet().iterator();
while (it.hasNext())
if (it.next().getValue().contactId.equals(c)) it.remove();
Iterator<TagContext> it = inContexts.values().iterator();
while (it.hasNext()) if (it.next().contactId.equals(c)) it.remove();
outContexts.remove(c);
keys.remove(c);
Iterator<MutableKeySet> it1 = keys.values().iterator();
while (it1.hasNext()) {
ContactId c1 = it1.next().getContactId();
if (c1 != null && c1.equals(c)) it1.remove();
}
} finally {
lock.unlock();
}
}
@Override
public boolean canSendOutgoingStreams(ContactId c) {
lock.lock();
try {
MutableKeySet ks = outContexts.get(c);
if (ks == null) return false;
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError();
return outKeys.getStreamCounter() <= MAX_32_BIT_UNSIGNED;
} finally {
lock.unlock();
}
@@ -200,8 +299,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock();
try {
// Look up the outgoing keys for the contact
MutableOutgoingKeys outKeys = outContexts.get(c);
if (outKeys == null) return null;
MutableKeySet ks = outContexts.get(c);
if (ks == null) return null;
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError();
if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
// Create a stream context
StreamContext ctx = new StreamContext(c, transportId,
@@ -209,8 +311,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
outKeys.getStreamCounter());
// Increment the stream counter and write it back to the DB
outKeys.incrementStreamCounter();
db.incrementStreamCounter(txn, c, transportId,
outKeys.getRotationPeriod());
db.incrementStreamCounter(txn, transportId, ks.getKeySetId());
return ctx;
} finally {
lock.unlock();
@@ -238,8 +339,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
byte[] addTag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
PROTOCOL_VERSION, streamNumber);
inContexts.put(new Bytes(addTag), new TagContext(
tagCtx.contactId, inKeys, streamNumber));
TagContext tagCtx1 = new TagContext(tagCtx.keySetId,
tagCtx.contactId, inKeys, streamNumber);
inContexts.put(new Bytes(addTag), tagCtx1);
}
// Remove tags for any stream numbers removed from the window
for (long streamNumber : change.getRemoved()) {
@@ -250,9 +352,19 @@ class TransportKeyManagerImpl implements TransportKeyManager {
inContexts.remove(new Bytes(removeTag));
}
// Write the window back to the DB
db.setReorderingWindow(txn, tagCtx.contactId, transportId,
db.setReorderingWindow(txn, tagCtx.keySetId, transportId,
inKeys.getRotationPeriod(), window.getBase(),
window.getBitmap());
// If the outgoing keys are inactive, activate them
MutableKeySet ks = keys.get(tagCtx.keySetId);
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) {
LOG.info("Activating outgoing keys");
outKeys.activate();
considerReplacingOutgoingKeys(ks);
db.setTransportKeysActive(txn, transportId, tagCtx.keySetId);
}
return ctx;
} finally {
lock.unlock();
@@ -264,9 +376,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock();
try {
// Rotate the keys to the current rotation period
Map<ContactId, TransportKeys> snapshot = new HashMap<>();
for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet())
snapshot.put(e.getKey(), e.getValue().snapshot());
Collection<KeySet> snapshot = new ArrayList<>(keys.size());
for (MutableKeySet ks : keys.values()) {
snapshot.add(new KeySet(ks.getKeySetId(), ks.getContactId(),
ks.getTransportKeys().snapshot()));
}
RotationResult rotationResult = rotateKeys(snapshot, now);
// Rebuild the mutable state for all contacts
inContexts.clear();
@@ -285,12 +399,14 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class TagContext {
private final KeySetId keySetId;
private final ContactId contactId;
private final MutableIncomingKeys inKeys;
private final long streamNumber;
private TagContext(ContactId contactId, MutableIncomingKeys inKeys,
long streamNumber) {
private TagContext(KeySetId keySetId, ContactId contactId,
MutableIncomingKeys inKeys, long streamNumber) {
this.keySetId = keySetId;
this.contactId = contactId;
this.inKeys = inKeys;
this.streamNumber = streamNumber;
@@ -299,11 +415,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private static class RotationResult {
private final Map<ContactId, TransportKeys> current, rotated;
private RotationResult() {
current = new HashMap<>();
rotated = new HashMap<>();
}
private final Collection<KeySet> current = new ArrayList<>();
private final Collection<KeySet> rotated = new ArrayList<>();
}
}