Add a flag to indicate whether outgoing keys are active.

This commit is contained in:
akwizgran
2018-03-28 11:41:50 +01:00
parent 57e6f2ea9c
commit 6787d29f11
16 changed files with 201 additions and 69 deletions

View File

@@ -36,7 +36,8 @@ class TransportCryptoImpl implements TransportCrypto {
@Override
public TransportKeys deriveTransportKeys(TransportId t,
SecretKey master, long rotationPeriod, boolean alice) {
SecretKey master, long rotationPeriod, boolean alice,
boolean active) {
// Keys for the previous period are derived from the master secret
SecretKey inTagPrev = deriveTagKey(master, t, !alice);
SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice);
@@ -57,7 +58,7 @@ class TransportCryptoImpl implements TransportCrypto {
IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext,
rotationPeriod + 1);
OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr,
rotationPeriod);
rotationPeriod, active);
// Collect and return the keys
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr);
}
@@ -71,6 +72,7 @@ class TransportCryptoImpl implements TransportCrypto {
IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getRotationPeriod();
boolean active = outCurr.isActive();
// Rotate the keys
for (long p = startPeriod + 1; p <= rotationPeriod; p++) {
inPrev = inCurr;
@@ -80,7 +82,7 @@ class TransportCryptoImpl implements TransportCrypto {
inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p, active);
}
// Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,

View File

@@ -640,6 +640,12 @@ interface Database<T> {
void setReorderingWindow(T txn, KeySetId k, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException;
/**
* Marks the given transport keys as usable for outgoing streams.
*/
void setTransportKeysActive(T txn, TransportId t, KeySetId k)
throws DbException;
/**
* Updates the transmission count and expiry time of the given message
* with respect to the given contact, using the latency of the transport

View File

@@ -890,6 +890,16 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap);
}
@Override
public void setTransportKeysActive(Transaction transaction, TransportId t,
KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.setTransportKeysActive(txn, t, k);
}
@Override
public void updateTransportKeys(Transaction transaction,
Collection<KeySet> keys) throws DbException {

View File

@@ -234,6 +234,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " active BOOLEAN NOT NULL,"
+ " PRIMARY KEY (transportId, keySetId),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
@@ -880,8 +881,8 @@ abstract class JdbcDatabase implements Database<Connection> {
try {
// Store the outgoing keys
String sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " rotationPeriod, tagKey, headerKey, stream)"
+ " VALUES (?, ?, ?, ?, ?, ?)";
+ " rotationPeriod, tagKey, headerKey, stream, active)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
if (c == null) ps.setNull(1, INTEGER);
else ps.setInt(1, c.getInt());
@@ -891,6 +892,7 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setBytes(4, outCurr.getTagKey().getBytes());
ps.setBytes(5, outCurr.getHeaderKey().getBytes());
ps.setLong(6, outCurr.getStreamCounter());
ps.setBoolean(7, outCurr.isActive());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
@@ -2157,7 +2159,7 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.close();
// Retrieve the outgoing keys in the same order
sql = "SELECT keySetId, contactId, rotationPeriod,"
+ " tagKey, headerKey, stream"
+ " tagKey, headerKey, stream, active"
+ " FROM outgoingKeys"
+ " WHERE transportId = ?"
+ " ORDER BY keySetId";
@@ -2175,8 +2177,9 @@ abstract class JdbcDatabase implements Database<Connection> {
SecretKey tagKey = new SecretKey(rs.getBytes(4));
SecretKey headerKey = new SecretKey(rs.getBytes(5));
long streamCounter = rs.getLong(6);
boolean active = rs.getBoolean(7);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
rotationPeriod, streamCounter);
rotationPeriod, streamCounter, active);
IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2);
@@ -2889,6 +2892,23 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
@Override
public void setTransportKeysActive(Connection txn, TransportId t,
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET active = true"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
@Override
public void updateExpiryTime(Connection txn, ContactId c, MessageId m,
int maxLatency) throws DbException {

View File

@@ -132,6 +132,20 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
}
}
@Override
public void activateKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
} else {
m.activateKeys(txn, e.getValue());
}
}
}
@Override
public void removeKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException {

View File

@@ -13,17 +13,19 @@ class MutableOutgoingKeys {
private final SecretKey tagKey, headerKey;
private final long rotationPeriod;
private long streamCounter;
private boolean active;
MutableOutgoingKeys(OutgoingKeys out) {
tagKey = out.getTagKey();
headerKey = out.getHeaderKey();
rotationPeriod = out.getRotationPeriod();
streamCounter = out.getStreamCounter();
active = out.isActive();
}
OutgoingKeys snapshot() {
return new OutgoingKeys(tagKey, headerKey, rotationPeriod,
streamCounter);
streamCounter, active);
}
SecretKey getTagKey() {
@@ -45,4 +47,12 @@ class MutableOutgoingKeys {
void incrementStreamCounter() {
streamCounter++;
}
boolean isActive() {
return active;
}
void activate() {
active = true;
}
}

View File

@@ -23,6 +23,8 @@ interface TransportKeyManager {
void bindKeys(Transaction txn, ContactId c, KeySetId k) throws DbException;
void activateKeys(Transaction txn, KeySetId k) throws DbException;
void removeKeys(Transaction txn, KeySetId k) throws DbException;
void removeContact(ContactId c);

View File

@@ -127,10 +127,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
encodeTags(keySetId, contactId, m.getPreviousIncomingKeys());
encodeTags(keySetId, contactId, m.getCurrentIncomingKeys());
encodeTags(keySetId, contactId, m.getNextIncomingKeys());
// Use the outgoing keys with the highest key set ID
MutableKeySet old = outContexts.get(contactId);
if (old == null || old.getKeySetId().getInt() < keySetId.getInt())
outContexts.put(contactId, ks);
considerReplacingOutgoingKeys(ks);
}
}
@@ -147,6 +144,17 @@ class TransportKeyManagerImpl implements TransportKeyManager {
}
}
// Locking: lock
private void considerReplacingOutgoingKeys(MutableKeySet ks) {
// Use the active outgoing keys with the highest key set ID
if (ks.getTransportKeys().getCurrentOutgoingKeys().isActive()) {
MutableKeySet old = outContexts.get(ks.getContactId());
if (old == null ||
old.getKeySetId().getInt() < ks.getKeySetId().getInt())
outContexts.put(ks.getContactId(), ks);
}
}
private void scheduleKeyRotation(long now) {
long delay = rotationPeriodLength - now % rotationPeriodLength;
scheduler.schedule((Runnable) this::rotateKeys, delay, MILLISECONDS);
@@ -171,17 +179,17 @@ class TransportKeyManagerImpl implements TransportKeyManager {
@Override
public void addContact(Transaction txn, ContactId c, SecretKey master,
long timestamp, boolean alice) throws DbException {
deriveAndAddKeys(txn, c, master, timestamp, alice);
deriveAndAddKeys(txn, c, master, timestamp, alice, true);
}
@Override
public KeySetId addUnboundKeys(Transaction txn, SecretKey master,
long timestamp, boolean alice) throws DbException {
return deriveAndAddKeys(txn, null, master, timestamp, alice);
return deriveAndAddKeys(txn, null, master, timestamp, alice, false);
}
private KeySetId deriveAndAddKeys(Transaction txn, @Nullable ContactId c,
SecretKey master, long timestamp, boolean alice)
SecretKey master, long timestamp, boolean alice, boolean active)
throws DbException {
lock.lock();
try {
@@ -189,7 +197,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
long rotationPeriod = timestamp / rotationPeriodLength;
// Derive the transport keys
TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
master, rotationPeriod, alice);
master, rotationPeriod, alice, active);
// Rotate the keys to the current rotation period if necessary
rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength;
k = transportCrypto.rotateTransportKeys(k, rotationPeriod);
@@ -210,6 +218,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
try {
MutableKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys haven't already been bound
if (ks.getContactId() != null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys();
addKeys(k, c, m);
@@ -219,12 +228,30 @@ class TransportKeyManagerImpl implements TransportKeyManager {
}
}
@Override
public void activateKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys have been bound
if (ks.getContactId() == null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys();
m.getCurrentOutgoingKeys().activate();
considerReplacingOutgoingKeys(ks);
db.setTransportKeysActive(txn, m.getTransportId(), k);
} finally {
lock.unlock();
}
}
@Override
public void removeKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.remove(k);
if (ks == null) throw new IllegalArgumentException();
// Check that the keys haven't been bound
if (ks.getContactId() != null) throw new IllegalArgumentException();
TransportId t = ks.getTransportKeys().getTransportId();
db.removeTransportKeys(txn, t, k);