Merge branch 'merge-handshake-and-transport-keys' into 'master'

Add support for handshake keys to KeyManager

See merge request briar/briar!1088
This commit is contained in:
Torsten Grote
2019-05-15 16:27:33 +00:00
38 changed files with 1043 additions and 1492 deletions

View File

@@ -4,7 +4,6 @@ import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.transport.HandshakeKeys;
import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeys;
@@ -42,7 +41,7 @@ class TransportCryptoImpl implements TransportCrypto {
}
@Override
public TransportKeys deriveTransportKeys(TransportId t,
public TransportKeys deriveRotationKeys(TransportId t,
SecretKey rootKey, long timePeriod, boolean weAreAlice,
boolean active) {
// Keys for the previous period are derived from the root key
@@ -70,31 +69,6 @@ class TransportCryptoImpl implements TransportCrypto {
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr);
}
@Override
public TransportKeys rotateTransportKeys(TransportKeys k, long timePeriod) {
if (k.getTimePeriod() >= timePeriod) return k;
IncomingKeys inPrev = k.getPreviousIncomingKeys();
IncomingKeys inCurr = k.getCurrentIncomingKeys();
IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getTimePeriod();
boolean active = outCurr.isActive();
// Rotate the keys
for (long p = startPeriod + 1; p <= timePeriod; p++) {
inPrev = inCurr;
inCurr = inNext;
SecretKey inNextTag = rotateKey(inNext.getTagKey(), p + 1);
SecretKey inNextHeader = rotateKey(inNext.getHeaderKey(), p + 1);
inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p, active);
}
// Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,
outCurr);
}
private SecretKey rotateKey(SecretKey k, long timePeriod) {
byte[] period = new byte[INT_64_BYTES];
writeUint64(timePeriod, period, 0);
@@ -117,7 +91,7 @@ class TransportCryptoImpl implements TransportCrypto {
}
@Override
public HandshakeKeys deriveHandshakeKeys(TransportId t, SecretKey rootKey,
public TransportKeys deriveHandshakeKeys(TransportId t, SecretKey rootKey,
long timePeriod, boolean weAreAlice) {
if (timePeriod < 1) throw new IllegalArgumentException();
IncomingKeys inPrev = deriveIncomingHandshakeKeys(t, rootKey,
@@ -128,7 +102,7 @@ class TransportCryptoImpl implements TransportCrypto {
weAreAlice, timePeriod + 1);
OutgoingKeys outCurr = deriveOutgoingHandshakeKeys(t, rootKey,
weAreAlice, timePeriod);
return new HandshakeKeys(t, inPrev, inCurr, inNext, outCurr, rootKey,
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr, rootKey,
weAreAlice);
}
@@ -171,7 +145,13 @@ class TransportCryptoImpl implements TransportCrypto {
}
@Override
public HandshakeKeys updateHandshakeKeys(HandshakeKeys k, long timePeriod) {
public TransportKeys updateTransportKeys(TransportKeys k, long timePeriod) {
if (k.isHandshakeMode()) return updateHandshakeKeys(k, timePeriod);
else return updateRotationKeys(k, timePeriod);
}
private TransportKeys updateHandshakeKeys(TransportKeys k,
long timePeriod) {
long elapsed = timePeriod - k.getTimePeriod();
TransportId t = k.getTransportId();
SecretKey rootKey = k.getRootKey();
@@ -188,7 +168,7 @@ class TransportCryptoImpl implements TransportCrypto {
weAreAlice, timePeriod + 1);
OutgoingKeys outCurr = deriveOutgoingHandshakeKeys(t, rootKey,
weAreAlice, timePeriod);
return new HandshakeKeys(t, inPrev, inCurr, inNext, outCurr,
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr,
rootKey, weAreAlice);
} else if (elapsed == 2) {
// The keys are two periods old - shift by two periods, keeping
@@ -200,7 +180,7 @@ class TransportCryptoImpl implements TransportCrypto {
weAreAlice, timePeriod + 1);
OutgoingKeys outCurr = deriveOutgoingHandshakeKeys(t, rootKey,
weAreAlice, timePeriod);
return new HandshakeKeys(t, inPrev, inCurr, inNext, outCurr,
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr,
rootKey, weAreAlice);
} else {
// The keys are more than two periods old - derive fresh keys
@@ -208,6 +188,30 @@ class TransportCryptoImpl implements TransportCrypto {
}
}
private TransportKeys updateRotationKeys(TransportKeys k, long timePeriod) {
if (k.getTimePeriod() >= timePeriod) return k;
IncomingKeys inPrev = k.getPreviousIncomingKeys();
IncomingKeys inCurr = k.getCurrentIncomingKeys();
IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getTimePeriod();
boolean active = outCurr.isActive();
// Rotate the keys
for (long p = startPeriod + 1; p <= timePeriod; p++) {
inPrev = inCurr;
inCurr = inNext;
SecretKey inNextTag = rotateKey(inNext.getTagKey(), p + 1);
SecretKey inNextHeader = rotateKey(inNext.getHeaderKey(), p + 1);
inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p, active);
}
// Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,
outCurr);
}
@Override
public void encodeTag(byte[] tag, SecretKey tagKey, int protocolVersion,
long streamNumber) {

View File

@@ -29,11 +29,8 @@ import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.validation.MessageState;
import org.briarproject.bramble.api.transport.HandshakeKeySet;
import org.briarproject.bramble.api.transport.HandshakeKeySetId;
import org.briarproject.bramble.api.transport.HandshakeKeys;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeySet;
import org.briarproject.bramble.api.transport.TransportKeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.Collection;
@@ -107,20 +104,6 @@ interface Database<T> {
void addGroupVisibility(T txn, ContactId c, GroupId g, boolean shared)
throws DbException;
/**
* Stores the given handshake keys for the given contact and returns a
* key set ID.
*/
HandshakeKeySetId addHandshakeKeys(T txn, ContactId c, HandshakeKeys k)
throws DbException;
/**
* Stores the given handshake keys for the given pending contact and
* returns a key set ID.
*/
HandshakeKeySetId addHandshakeKeys(T txn, PendingContactId p,
HandshakeKeys k) throws DbException;
/**
* Stores an identity.
*/
@@ -162,7 +145,14 @@ interface Database<T> {
* Stores the given transport keys for the given contact and returns a
* key set ID.
*/
TransportKeySetId addTransportKeys(T txn, ContactId c, TransportKeys k)
KeySetId addTransportKeys(T txn, ContactId c, TransportKeys k)
throws DbException;
/**
* Stores the given transport keys for the given pending contact and
* returns a key set ID.
*/
KeySetId addTransportKeys(T txn, PendingContactId p, TransportKeys k)
throws DbException;
/**
@@ -275,7 +265,7 @@ interface Database<T> {
* <p/>
* Read-only.
*/
Collection<ContactId> getContacts(T txn, AuthorId a) throws DbException;
Collection<ContactId> getContacts(T txn, AuthorId local) throws DbException;
/**
* Returns the group with the given ID.
@@ -317,14 +307,6 @@ interface Database<T> {
Map<ContactId, Boolean> getGroupVisibility(T txn, GroupId g)
throws DbException;
/**
* Returns all handshake keys for the given transport.
* <p/>
* Read-only.
*/
Collection<HandshakeKeySet> getHandshakeKeys(T txn, TransportId t)
throws DbException;
/**
* Returns the identity for local pseudonym with the given ID.
* <p/>
@@ -547,16 +529,10 @@ interface Database<T> {
Collection<TransportKeySet> getTransportKeys(T txn, TransportId t)
throws DbException;
/**
* Increments the outgoing stream counter for the given handshake keys.
*/
void incrementStreamCounter(T txn, TransportId t, HandshakeKeySetId k)
throws DbException;
/**
* Increments the outgoing stream counter for the given transport keys.
*/
void incrementStreamCounter(T txn, TransportId t, TransportKeySetId k)
void incrementStreamCounter(T txn, TransportId t, KeySetId k)
throws DbException;
/**
@@ -625,12 +601,6 @@ interface Database<T> {
void removeGroupVisibility(T txn, ContactId c, GroupId g)
throws DbException;
/**
* Removes the given handshake keys from the database.
*/
void removeHandshakeKeys(T txn, TransportId t, HandshakeKeySetId k)
throws DbException;
/**
* Removes an identity (and all associated state) from the database.
*/
@@ -661,8 +631,7 @@ interface Database<T> {
/**
* Removes the given transport keys from the database.
*/
void removeTransportKeys(T txn, TransportId t, TransportKeySetId k)
throws DbException;
void removeTransportKeys(T txn, TransportId t, KeySetId k) throws DbException;
/**
* Resets the transmission count and expiry time of the given message with
@@ -712,23 +681,16 @@ interface Database<T> {
PendingContactState state) throws DbException;
/**
* Sets the reordering window for the given transport key set in the given
* Sets the reordering window for the given transport keys in the given
* time period.
*/
void setReorderingWindow(T txn, TransportKeySetId k, TransportId t,
long timePeriod, long base, byte[] bitmap) throws DbException;
/**
* Sets the reordering window for the given handshake key set in the given
* time period.
*/
void setReorderingWindow(T txn, HandshakeKeySetId k, TransportId t,
void setReorderingWindow(T txn, KeySetId k, TransportId t,
long timePeriod, long base, byte[] bitmap) throws DbException;
/**
* Marks the given transport keys as usable for outgoing streams.
*/
void setTransportKeysActive(T txn, TransportId t, TransportKeySetId k)
void setTransportKeysActive(T txn, TransportId t, KeySetId k)
throws DbException;
/**
@@ -740,12 +702,7 @@ interface Database<T> {
throws DbException;
/**
* Updates the given handshake keys.
*/
void updateHandshakeKeys(T txn, HandshakeKeySet ks) throws DbException;
/**
* Updates the given transport keys following key rotation.
* Stores the given transport keys, deleting any keys they have replaced.
*/
void updateTransportKeys(T txn, TransportKeySet ks) throws DbException;
}

View File

@@ -64,11 +64,8 @@ import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
import org.briarproject.bramble.api.sync.validation.MessageState;
import org.briarproject.bramble.api.transport.HandshakeKeySet;
import org.briarproject.bramble.api.transport.HandshakeKeySetId;
import org.briarproject.bramble.api.transport.HandshakeKeys;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.TransportKeySet;
import org.briarproject.bramble.api.transport.TransportKeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.util.ArrayList;
@@ -260,30 +257,6 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
}
@Override
public HandshakeKeySetId addHandshakeKeys(Transaction transaction,
ContactId c, HandshakeKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException();
return db.addHandshakeKeys(txn, c, k);
}
@Override
public HandshakeKeySetId addHandshakeKeys(Transaction transaction,
PendingContactId p, HandshakeKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsPendingContact(txn, p))
throw new NoSuchPendingContactException();
if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException();
return db.addHandshakeKeys(txn, p, k);
}
@Override
public void addIdentity(Transaction transaction, Identity i)
throws DbException {
@@ -332,8 +305,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
@Override
public TransportKeySetId addTransportKeys(Transaction transaction,
ContactId c, TransportKeys k) throws DbException {
public KeySetId addTransportKeys(Transaction transaction, ContactId c,
TransportKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsContact(txn, c))
@@ -343,6 +316,18 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
return db.addTransportKeys(txn, c, k);
}
@Override
public KeySetId addTransportKeys(Transaction transaction,
PendingContactId p, TransportKeys k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsPendingContact(txn, p))
throw new NoSuchPendingContactException();
if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException();
return db.addTransportKeys(txn, p, k);
}
@Override
public boolean containsContact(Transaction transaction, AuthorId remote,
AuthorId local) throws DbException {
@@ -505,11 +490,11 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
@Override
public Collection<ContactId> getContacts(Transaction transaction,
AuthorId a) throws DbException {
AuthorId local) throws DbException {
T txn = unbox(transaction);
if (!db.containsIdentity(txn, a))
if (!db.containsIdentity(txn, local))
throw new NoSuchIdentityException();
return db.getContacts(txn, a);
return db.getContacts(txn, local);
}
@Override
@@ -546,15 +531,6 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
return db.getGroupVisibility(txn, c, g);
}
@Override
public Collection<HandshakeKeySet> getHandshakeKeys(Transaction transaction,
TransportId t) throws DbException {
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
return db.getHandshakeKeys(txn, t);
}
@Override
public Identity getIdentity(Transaction transaction, AuthorId a)
throws DbException {
@@ -737,17 +713,7 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
@Override
public void incrementStreamCounter(Transaction transaction, TransportId t,
HandshakeKeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.incrementStreamCounter(txn, t, k);
}
@Override
public void incrementStreamCounter(Transaction transaction, TransportId t,
TransportKeySetId k) throws DbException {
KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
@@ -896,16 +862,6 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
transaction.attach(new GroupVisibilityUpdatedEvent(affected));
}
@Override
public void removeHandshakeKeys(Transaction transaction,
TransportId t, HandshakeKeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.removeHandshakeKeys(txn, t, k);
}
@Override
public void removeIdentity(Transaction transaction, AuthorId a)
throws DbException {
@@ -949,8 +905,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
@Override
public void removeTransportKeys(Transaction transaction,
TransportId t, TransportKeySetId k) throws DbException {
public void removeTransportKeys(Transaction transaction, TransportId t,
KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
@@ -1048,20 +1004,9 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
}
@Override
public void setReorderingWindow(Transaction transaction,
TransportKeySetId k, TransportId t, long timePeriod, long base,
byte[] bitmap) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.setReorderingWindow(txn, k, t, timePeriod, base, bitmap);
}
@Override
public void setReorderingWindow(Transaction transaction,
HandshakeKeySetId k, TransportId t, long timePeriod, long base,
byte[] bitmap) throws DbException {
public void setReorderingWindow(Transaction transaction, KeySetId k,
TransportId t, long timePeriod, long base, byte[] bitmap)
throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
@@ -1071,7 +1016,7 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
@Override
public void setTransportKeysActive(Transaction transaction, TransportId t,
TransportKeySetId k) throws DbException {
KeySetId k) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
if (!db.containsTransport(txn, t))
@@ -1079,18 +1024,6 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
db.setTransportKeysActive(txn, t, k);
}
@Override
public void updateHandshakeKeys(Transaction transaction,
Collection<HandshakeKeySet> keys) throws DbException {
if (transaction.isReadOnly()) throw new IllegalArgumentException();
T txn = unbox(transaction);
for (HandshakeKeySet ks : keys) {
TransportId t = ks.getKeys().getTransportId();
if (db.containsTransport(txn, t))
db.updateHandshakeKeys(txn, ks);
}
}
@Override
public void updateTransportKeys(Transaction transaction,
Collection<TransportKeySet> keys) throws DbException {

View File

@@ -36,13 +36,10 @@ import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.MessageStatus;
import org.briarproject.bramble.api.sync.validation.MessageState;
import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.HandshakeKeySet;
import org.briarproject.bramble.api.transport.HandshakeKeySetId;
import org.briarproject.bramble.api.transport.HandshakeKeys;
import org.briarproject.bramble.api.transport.IncomingKeys;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.OutgoingKeys;
import org.briarproject.bramble.api.transport.TransportKeySet;
import org.briarproject.bramble.api.transport.TransportKeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import java.sql.Connection;
@@ -68,11 +65,13 @@ import java.util.logging.Logger;
import javax.annotation.Nullable;
import static java.sql.Types.BINARY;
import static java.sql.Types.BOOLEAN;
import static java.sql.Types.INTEGER;
import static java.sql.Types.VARCHAR;
import static java.util.Arrays.asList;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static java.util.logging.Logger.getLogger;
import static org.briarproject.bramble.api.db.Metadata.REMOVE;
import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED;
@@ -99,7 +98,7 @@ import static org.briarproject.bramble.util.LogUtils.now;
abstract class JdbcDatabase implements Database<Connection> {
// Package access for testing
static final int CODE_SCHEMA_VERSION = 43;
static final int CODE_SCHEMA_VERSION = 44;
// Time period offsets for incoming transport keys
private static final int OFFSET_PREV = -1;
@@ -260,16 +259,28 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " maxLatency INT NOT NULL,"
+ " PRIMARY KEY (transportId))";
private static final String CREATE_PENDING_CONTACTS =
"CREATE TABLE pendingContacts"
+ " (pendingContactId _HASH NOT NULL,"
+ " publicKey _BINARY NOT NULL,"
+ " alias _STRING NOT NULL,"
+ " state INT NOT NULL,"
+ " timestamp BIGINT NOT NULL,"
+ " PRIMARY KEY (pendingContactId))";
private static final String CREATE_OUTGOING_KEYS =
"CREATE TABLE outgoingKeys"
+ " (transportId _STRING NOT NULL,"
+ " keySetId _COUNTER,"
+ " timePeriod BIGINT NOT NULL,"
+ " contactId INT NOT NULL,"
+ " contactId INT," // Null if contact is pending
+ " pendingContactId _HASH," // Null if not pending
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " active BOOLEAN NOT NULL,"
+ " rootKey _SECRET," // Null for rotation keys
+ " alice BOOLEAN," // Null for rotation keys
+ " PRIMARY KEY (transportId, keySetId),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
@@ -277,6 +288,9 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " UNIQUE (keySetId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (pendingContactId)"
+ " REFERENCES pendingContacts (pendingContactId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_INCOMING_KEYS =
@@ -297,57 +311,6 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " REFERENCES outgoingKeys (keySetId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_PENDING_CONTACTS =
"CREATE TABLE pendingContacts"
+ " (pendingContactId _HASH NOT NULL,"
+ " publicKey _BINARY NOT NULL,"
+ " alias _STRING NOT NULL,"
+ " state INT NOT NULL,"
+ " timestamp BIGINT NOT NULL,"
+ " PRIMARY KEY (pendingContactId))";
private static final String CREATE_OUTGOING_HANDSHAKE_KEYS =
"CREATE TABLE outgoingHandshakeKeys"
+ " (transportId _STRING NOT NULL,"
+ " keySetId _COUNTER,"
+ " timePeriod BIGINT NOT NULL,"
+ " contactId INT," // Null if contact is pending
+ " pendingContactId _HASH," // Null if not pending
+ " rootKey _SECRET NOT NULL,"
+ " alice BOOLEAN NOT NULL,"
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " PRIMARY KEY (transportId, keySetId),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE,"
+ " UNIQUE (keySetId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (pendingContactId)"
+ " REFERENCES pendingContacts (pendingContactId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_INCOMING_HANDSHAKE_KEYS =
"CREATE TABLE incomingHandshakeKeys"
+ " (transportId _STRING NOT NULL,"
+ " keySetId INT NOT NULL,"
+ " timePeriod BIGINT NOT NULL,"
+ " tagKey _SECRET NOT NULL,"
+ " headerKey _SECRET NOT NULL,"
+ " base BIGINT NOT NULL,"
+ " bitmap _BINARY NOT NULL,"
+ " periodOffset INT NOT NULL,"
+ " PRIMARY KEY (transportId, keySetId, periodOffset),"
+ " FOREIGN KEY (transportId)"
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE,"
+ " FOREIGN KEY (keySetId)"
+ " REFERENCES outgoingHandshakeKeys (keySetId)"
+ " ON DELETE CASCADE)";
private static final String INDEX_CONTACTS_BY_AUTHOR_ID =
"CREATE INDEX IF NOT EXISTS contactsByAuthorId"
+ " ON contacts (authorId)";
@@ -373,7 +336,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " ON statuses (contactId, timestamp)";
private static final Logger LOG =
Logger.getLogger(JdbcDatabase.class.getName());
getLogger(JdbcDatabase.class.getName());
// Different database libraries use different names for certain types
private final MessageFactory messageFactory;
@@ -493,7 +456,8 @@ abstract class JdbcDatabase implements Database<Connection> {
new Migration39_40(),
new Migration40_41(dbTypes),
new Migration41_42(dbTypes),
new Migration42_43(dbTypes)
new Migration42_43(dbTypes),
new Migration43_44(dbTypes)
);
}
@@ -541,13 +505,9 @@ abstract class JdbcDatabase implements Database<Connection> {
s.executeUpdate(dbTypes.replaceTypes(CREATE_OFFERS));
s.executeUpdate(dbTypes.replaceTypes(CREATE_STATUSES));
s.executeUpdate(dbTypes.replaceTypes(CREATE_TRANSPORTS));
s.executeUpdate(dbTypes.replaceTypes(CREATE_PENDING_CONTACTS));
s.executeUpdate(dbTypes.replaceTypes(CREATE_OUTGOING_KEYS));
s.executeUpdate(dbTypes.replaceTypes(CREATE_INCOMING_KEYS));
s.executeUpdate(dbTypes.replaceTypes(CREATE_PENDING_CONTACTS));
s.executeUpdate(dbTypes.replaceTypes(
CREATE_OUTGOING_HANDSHAKE_KEYS));
s.executeUpdate(dbTypes.replaceTypes(
CREATE_INCOMING_HANDSHAKE_KEYS));
s.close();
} catch (SQLException e) {
tryToClose(s, LOG, WARNING);
@@ -783,103 +743,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
@Override
public HandshakeKeySetId addHandshakeKeys(Connection txn, ContactId c,
HandshakeKeys k) throws DbException {
return addHandshakeKeys(txn, c, null, k);
}
@Override
public HandshakeKeySetId addHandshakeKeys(Connection txn,
PendingContactId p, HandshakeKeys k) throws DbException {
return addHandshakeKeys(txn, null, p, k);
}
private HandshakeKeySetId addHandshakeKeys(Connection txn,
@Nullable ContactId c, @Nullable PendingContactId p,
HandshakeKeys k) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Store the outgoing keys
String sql = "INSERT INTO outgoingHandshakeKeys (contactId,"
+ " pendingContactId, transportId, rootKey, alice,"
+ " timePeriod, tagKey, headerKey, stream)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
if (c == null) ps.setNull(1, INTEGER);
else ps.setInt(1, c.getInt());
if (p == null) ps.setNull(2, BINARY);
else ps.setBytes(2, p.getBytes());
ps.setString(3, k.getTransportId().getString());
ps.setBytes(4, k.getRootKey().getBytes());
ps.setBoolean(5, k.isAlice());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(6, outCurr.getTimePeriod());
ps.setBytes(7, outCurr.getTagKey().getBytes());
ps.setBytes(8, outCurr.getHeaderKey().getBytes());
ps.setLong(9, outCurr.getStreamCounter());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
// Get the new (highest) key set ID
sql = "SELECT keySetId FROM outgoingHandshakeKeys"
+ " ORDER BY keySetId DESC LIMIT 1";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
if (!rs.next()) throw new DbStateException();
HandshakeKeySetId keySetId = new HandshakeKeySetId(rs.getInt(1));
if (rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Store the incoming keys
sql = "INSERT INTO incomingHandshakeKeys (keySetId, transportId,"
+ " timePeriod, tagKey, headerKey, base, bitmap,"
+ " periodOffset)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, keySetId.getInt());
ps.setString(2, k.getTransportId().getString());
// Previous time period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(3, inPrev.getTimePeriod());
ps.setBytes(4, inPrev.getTagKey().getBytes());
ps.setBytes(5, inPrev.getHeaderKey().getBytes());
ps.setLong(6, inPrev.getWindowBase());
ps.setBytes(7, inPrev.getWindowBitmap());
ps.setInt(8, OFFSET_PREV);
ps.addBatch();
// Current time period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(3, inCurr.getTimePeriod());
ps.setBytes(4, inCurr.getTagKey().getBytes());
ps.setBytes(5, inCurr.getHeaderKey().getBytes());
ps.setLong(6, inCurr.getWindowBase());
ps.setBytes(7, inCurr.getWindowBitmap());
ps.setInt(8, OFFSET_CURR);
ps.addBatch();
// Next time period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(3, inNext.getTimePeriod());
ps.setBytes(4, inNext.getTagKey().getBytes());
ps.setBytes(5, inNext.getHeaderKey().getBytes());
ps.setLong(6, inNext.getWindowBase());
ps.setBytes(7, inNext.getWindowBitmap());
ps.setInt(8, OFFSET_NEXT);
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows != 1) throw new DbStateException();
ps.close();
return keySetId;
} catch (SQLException e) {
tryToClose(rs, LOG, WARNING);
tryToClose(ps, LOG, WARNING);
throw new DbException(e);
}
}
@Override
public void addIdentity(Connection txn, Identity i) throws DbException {
PreparedStatement ps = null;
@@ -1107,24 +970,47 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public TransportKeySetId addTransportKeys(Connection txn, ContactId c,
public KeySetId addTransportKeys(Connection txn, ContactId c,
TransportKeys k) throws DbException {
return addTransportKeys(txn, c, null, k);
}
@Override
public KeySetId addTransportKeys(Connection txn,
PendingContactId p, TransportKeys k) throws DbException {
return addTransportKeys(txn, null, p, k);
}
private KeySetId addTransportKeys(Connection txn,
@Nullable ContactId c, @Nullable PendingContactId p,
TransportKeys k) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Store the outgoing keys
String sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+ " timePeriod, tagKey, headerKey, stream, active)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?)";
String sql = "INSERT INTO outgoingKeys (transportId, timePeriod,"
+ " contactId, pendingContactId, tagKey, headerKey,"
+ " stream, active, rootKey, alice)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString());
ps.setString(1, k.getTransportId().getString());
ps.setLong(2, k.getTimePeriod());
if (c == null) ps.setNull(3, INTEGER);
else ps.setInt(3, c.getInt());
if (p == null) ps.setNull(4, BINARY);
else ps.setBytes(4, p.getBytes());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(3, outCurr.getTimePeriod());
ps.setBytes(4, outCurr.getTagKey().getBytes());
ps.setBytes(5, outCurr.getHeaderKey().getBytes());
ps.setLong(6, outCurr.getStreamCounter());
ps.setBoolean(7, outCurr.isActive());
ps.setBytes(5, outCurr.getTagKey().getBytes());
ps.setBytes(6, outCurr.getHeaderKey().getBytes());
ps.setLong(7, outCurr.getStreamCounter());
ps.setBoolean(8, outCurr.isActive());
if (k.isHandshakeMode()) {
ps.setBytes(9, k.getRootKey().getBytes());
ps.setBoolean(10, k.isAlice());
} else {
ps.setNull(9, BINARY);
ps.setNull(10, BOOLEAN);
}
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
@@ -1134,18 +1020,18 @@ abstract class JdbcDatabase implements Database<Connection> {
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
if (!rs.next()) throw new DbStateException();
TransportKeySetId keySetId = new TransportKeySetId(rs.getInt(1));
KeySetId keySetId = new KeySetId(rs.getInt(1));
if (rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Store the incoming keys
sql = "INSERT INTO incomingKeys (keySetId, transportId,"
sql = "INSERT INTO incomingKeys (transportId, keySetId,"
+ " timePeriod, tagKey, headerKey, base, bitmap,"
+ " periodOffset)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, keySetId.getInt());
ps.setString(2, k.getTransportId().getString());
ps.setString(1, k.getTransportId().getString());
ps.setInt(2, keySetId.getInt());
// Previous time period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(3, inPrev.getTimePeriod());
@@ -1673,86 +1559,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
@Override
public Collection<HandshakeKeySet> getHandshakeKeys(Connection txn,
TransportId t) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Retrieve the incoming keys
String sql = "SELECT timePeriod, tagKey, headerKey, base, bitmap"
+ " FROM incomingHandshakeKeys"
+ " WHERE transportId = ?"
+ " ORDER BY keySetId, periodOffset";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
List<IncomingKeys> inKeys = new ArrayList<>();
while (rs.next()) {
long timePeriod = rs.getLong(1);
SecretKey tagKey = new SecretKey(rs.getBytes(2));
SecretKey headerKey = new SecretKey(rs.getBytes(3));
long windowBase = rs.getLong(4);
byte[] windowBitmap = rs.getBytes(5);
inKeys.add(new IncomingKeys(tagKey, headerKey, timePeriod,
windowBase, windowBitmap));
}
rs.close();
ps.close();
// Retrieve the outgoing keys in the same order
sql = "SELECT keySetId, contactId, pendingContactId, timePeriod,"
+ " tagKey, headerKey, rootKey, alice, stream"
+ " FROM outgoingHandshakeKeys"
+ " WHERE transportId = ?"
+ " ORDER BY keySetId";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
Collection<HandshakeKeySet> keys = new ArrayList<>();
for (int i = 0; rs.next(); i++) {
// There should be three times as many incoming keys
if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
HandshakeKeySetId keySetId =
new HandshakeKeySetId(rs.getInt(1));
ContactId contactId = null;
int cId = rs.getInt(2);
if (!rs.wasNull()) contactId = new ContactId(cId);
PendingContactId pendingContactId = null;
byte[] pId = rs.getBytes(3);
if (!rs.wasNull()) pendingContactId = new PendingContactId(pId);
long timePeriod = rs.getLong(4);
SecretKey tagKey = new SecretKey(rs.getBytes(5));
SecretKey headerKey = new SecretKey(rs.getBytes(6));
SecretKey rootKey = new SecretKey(rs.getBytes(7));
boolean alice = rs.getBoolean(8);
long streamCounter = rs.getLong(9);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
timePeriod, streamCounter, true);
IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2);
HandshakeKeys handshakeKeys = new HandshakeKeys(t, inPrev,
inCurr, inNext, outCurr, rootKey, alice);
if (contactId == null) {
if (pendingContactId == null) throw new DbStateException();
keys.add(new HandshakeKeySet(keySetId, pendingContactId,
handshakeKeys));
} else {
if (pendingContactId != null) throw new DbStateException();
keys.add(new HandshakeKeySet(keySetId, contactId,
handshakeKeys));
}
}
rs.close();
ps.close();
return keys;
} catch (SQLException e) {
tryToClose(rs, LOG, WARNING);
tryToClose(ps, LOG, WARNING);
throw new DbException(e);
}
}
@Override
public Identity getIdentity(Connection txn, AuthorId a) throws DbException {
PreparedStatement ps = null;
@@ -2522,8 +2328,8 @@ abstract class JdbcDatabase implements Database<Connection> {
rs.close();
ps.close();
// Retrieve the outgoing keys in the same order
sql = "SELECT keySetId, contactId, timePeriod,"
+ " tagKey, headerKey, stream, active"
sql = "SELECT keySetId, timePeriod, contactId, pendingContactId,"
+ " tagKey, headerKey, stream, active, rootKey, alice"
+ " FROM outgoingKeys"
+ " WHERE transportId = ?"
+ " ORDER BY keySetId";
@@ -2534,23 +2340,34 @@ abstract class JdbcDatabase implements Database<Connection> {
for (int i = 0; rs.next(); i++) {
// There should be three times as many incoming keys
if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
TransportKeySetId keySetId =
new TransportKeySetId(rs.getInt(1));
ContactId contactId = new ContactId(rs.getInt(2));
long timePeriod = rs.getLong(3);
SecretKey tagKey = new SecretKey(rs.getBytes(4));
SecretKey headerKey = new SecretKey(rs.getBytes(5));
long streamCounter = rs.getLong(6);
boolean active = rs.getBoolean(7);
KeySetId keySetId = new KeySetId(rs.getInt(1));
long timePeriod = rs.getLong(2);
int cId = rs.getInt(3);
ContactId contactId = rs.wasNull() ? null : new ContactId(cId);
byte[] pId = rs.getBytes(4);
PendingContactId pendingContactId = pId == null ?
null : new PendingContactId(pId);
SecretKey tagKey = new SecretKey(rs.getBytes(5));
SecretKey headerKey = new SecretKey(rs.getBytes(6));
long streamCounter = rs.getLong(7);
boolean active = rs.getBoolean(8);
byte[] rootKey = rs.getBytes(9);
boolean alice = rs.getBoolean(10);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
timePeriod, streamCounter, active);
IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2);
TransportKeys transportKeys = new TransportKeys(t, inPrev,
inCurr, inNext, outCurr);
TransportKeys transportKeys;
if (rootKey == null) {
transportKeys = new TransportKeys(t, inPrev, inCurr,
inNext, outCurr);
} else {
transportKeys = new TransportKeys(t, inPrev, inCurr,
inNext, outCurr, new SecretKey(rootKey), alice);
}
keys.add(new TransportKeySet(keySetId, contactId,
transportKeys));
pendingContactId, transportKeys));
}
rs.close();
ps.close();
@@ -2564,26 +2381,7 @@ abstract class JdbcDatabase implements Database<Connection> {
@Override
public void incrementStreamCounter(Connection txn, TransportId t,
HandshakeKeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingHandshakeKeys SET stream = stream + 1"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
ps.setInt(2, k.getInt());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps, LOG, WARNING);
throw new DbException(e);
}
}
@Override
public void incrementStreamCounter(Connection txn, TransportId t,
TransportKeySetId k) throws DbException {
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET stream = stream + 1"
@@ -2961,27 +2759,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
@Override
public void removeHandshakeKeys(Connection txn, TransportId t,
HandshakeKeySetId k) throws DbException {
PreparedStatement ps = null;
try {
// Delete any existing outgoing keys - this will also remove any
// incoming keys with the same key set ID
String sql = "DELETE FROM outgoingHandshakeKeys"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
ps.setInt(2, k.getInt());
int affected = ps.executeUpdate();
if (affected < 0) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps, LOG, WARNING);
throw new DbException(e);
}
}
@Override
public void removeIdentity(Connection txn, AuthorId a) throws DbException {
PreparedStatement ps = null;
@@ -3094,8 +2871,8 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public void removeTransportKeys(Connection txn, TransportId t,
TransportKeySetId k) throws DbException {
public void removeTransportKeys(Connection txn, TransportId t, KeySetId k)
throws DbException {
PreparedStatement ps = null;
try {
// Delete any existing outgoing keys - this will also remove any
@@ -3311,7 +3088,7 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setInt(1, state.getValue());
ps.setBytes(2, p.getBytes());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
if (affected < 0 || affected > 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps, LOG, WARNING);
@@ -3320,7 +3097,7 @@ abstract class JdbcDatabase implements Database<Connection> {
}
@Override
public void setReorderingWindow(Connection txn, TransportKeySetId k,
public void setReorderingWindow(Connection txn, KeySetId k,
TransportId t, long timePeriod, long base, byte[] bitmap)
throws DbException {
PreparedStatement ps = null;
@@ -3343,33 +3120,9 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
@Override
public void setReorderingWindow(Connection txn, HandshakeKeySetId k,
TransportId t, long timePeriod, long base, byte[] bitmap)
throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE incomingHandshakeKeys SET base = ?, bitmap = ?"
+ " WHERE transportId = ? AND keySetId = ?"
+ " AND timePeriod = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, base);
ps.setBytes(2, bitmap);
ps.setString(3, t.getString());
ps.setInt(4, k.getInt());
ps.setLong(5, timePeriod);
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps, LOG, WARNING);
throw new DbException(e);
}
}
@Override
public void setTransportKeysActive(Connection txn, TransportId t,
TransportKeySetId k) throws DbException {
KeySetId k) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE outgoingKeys SET active = true"
@@ -3489,71 +3242,4 @@ abstract class JdbcDatabase implements Database<Connection> {
throw new DbException(e);
}
}
@Override
public void updateHandshakeKeys(Connection txn, HandshakeKeySet ks)
throws DbException {
PreparedStatement ps = null;
try {
// Update the outgoing keys
String sql = "UPDATE outgoingHandshakeKeys SET timePeriod = ?,"
+ " tagKey = ?, headerKey = ?, stream = ?"
+ " WHERE transportId = ? AND keySetId = ?";
ps = txn.prepareStatement(sql);
HandshakeKeys k = ks.getKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(1, outCurr.getTimePeriod());
ps.setBytes(2, outCurr.getTagKey().getBytes());
ps.setBytes(3, outCurr.getHeaderKey().getBytes());
ps.setLong(4, outCurr.getStreamCounter());
ps.setString(5, k.getTransportId().getString());
ps.setInt(6, ks.getKeySetId().getInt());
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
ps.close();
// Update the incoming keys
sql = "UPDATE incomingHandshakeKeys SET timePeriod = ?,"
+ " tagKey = ?, headerKey = ?, base = ?, bitmap = ?"
+ " WHERE transportId = ? AND keySetId = ?"
+ " AND periodOffset = ?";
ps = txn.prepareStatement(sql);
ps.setString(6, k.getTransportId().getString());
ps.setInt(7, ks.getKeySetId().getInt());
// Previous time period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(1, inPrev.getTimePeriod());
ps.setBytes(2, inPrev.getTagKey().getBytes());
ps.setBytes(3, inPrev.getHeaderKey().getBytes());
ps.setLong(4, inPrev.getWindowBase());
ps.setBytes(5, inPrev.getWindowBitmap());
ps.setInt(8, OFFSET_PREV);
ps.addBatch();
// Current time period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(1, inCurr.getTimePeriod());
ps.setBytes(2, inCurr.getTagKey().getBytes());
ps.setBytes(3, inCurr.getHeaderKey().getBytes());
ps.setLong(4, inCurr.getWindowBase());
ps.setBytes(5, inCurr.getWindowBitmap());
ps.setInt(8, OFFSET_CURR);
ps.addBatch();
// Next time period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(1, inNext.getTimePeriod());
ps.setBytes(2, inNext.getTagKey().getBytes());
ps.setBytes(3, inNext.getHeaderKey().getBytes());
ps.setLong(4, inNext.getWindowBase());
ps.setBytes(5, inNext.getWindowBitmap());
ps.setInt(8, OFFSET_NEXT);
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int rows : batchAffected)
if (rows < 0 || rows > 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps, LOG, WARNING);
throw new DbException(e);
}
}
}

View File

@@ -0,0 +1,58 @@
package org.briarproject.bramble.db;
import org.briarproject.bramble.api.db.DbException;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.logging.Logger;
import static java.util.logging.Level.WARNING;
import static java.util.logging.Logger.getLogger;
import static org.briarproject.bramble.db.JdbcUtils.tryToClose;
class Migration43_44 implements Migration<Connection> {
private static final Logger LOG = getLogger(Migration43_44.class.getName());
private final DatabaseTypes dbTypes;
Migration43_44(DatabaseTypes dbTypes) {
this.dbTypes = dbTypes;
}
@Override
public int getStartVersion() {
return 43;
}
@Override
public int getEndVersion() {
return 44;
}
@Override
public void migrate(Connection txn) throws DbException {
Statement s = null;
try {
s = txn.createStatement();
s.execute("DROP TABLE outgoingHandshakeKeys");
s.execute("DROP TABLE incomingHandshakeKeys");
s.execute("ALTER TABLE outgoingKeys"
+ " ALTER COLUMN contactId DROP NOT NULL");
s.execute(dbTypes.replaceTypes("ALTER TABLE outgoingKeys"
+ " ADD COLUMN pendingContactId _HASH"));
s.execute("ALTER TABLE outgoingKeys"
+ " ADD FOREIGN KEY (pendingContactId)"
+ " REFERENCES pendingContacts (pendingContactId)"
+ " ON DELETE CASCADE");
s.execute(dbTypes.replaceTypes("ALTER TABLE outgoingKeys"
+ " ADD COLUMN rootKey _SECRET"));
s.execute("ALTER TABLE outgoingKeys"
+ " ADD COLUMN alice BOOLEAN");
} catch (SQLException e) {
tryToClose(s, LOG, WARNING);
throw new DbException(e);
}
}
}

View File

@@ -96,6 +96,7 @@ class ConnectionManagerImpl implements ConnectionManager {
TransportConnectionReader r) throws IOException {
InputStream streamReader = streamReaderFactory.createStreamReader(
r.getInputStream(), ctx);
// TODO: Pending contacts, handshake mode
return syncSessionFactory.createIncomingSession(ctx.getContactId(),
streamReader);
}
@@ -104,6 +105,7 @@ class ConnectionManagerImpl implements ConnectionManager {
TransportConnectionWriter w) throws IOException {
StreamWriter streamWriter = streamWriterFactory.createStreamWriter(
w.getOutputStream(), ctx);
// TODO: Pending contacts, handshake mode
return syncSessionFactory.createSimplexOutgoingSession(
ctx.getContactId(), w.getMaxLatency(), streamWriter);
}
@@ -112,6 +114,7 @@ class ConnectionManagerImpl implements ConnectionManager {
TransportConnectionWriter w) throws IOException {
StreamWriter streamWriter = streamWriterFactory.createStreamWriter(
w.getOutputStream(), ctx);
// TODO: Pending contacts, handshake mode
return syncSessionFactory.createDuplexOutgoingSession(
ctx.getContactId(), w.getMaxLatency(), w.getMaxIdleTime(),
streamWriter);
@@ -145,6 +148,7 @@ class ConnectionManagerImpl implements ConnectionManager {
disposeReader(false, false);
return;
}
// TODO: Pending contacts
ContactId contactId = ctx.getContactId();
connectionRegistry.registerConnection(contactId, transportId, true);
try {
@@ -388,7 +392,7 @@ class ConnectionManagerImpl implements ConnectionManager {
return;
}
// Check that the stream comes from the expected contact
if (!ctx.getContactId().equals(contactId)) {
if (!contactId.equals(ctx.getContactId())) {
LOG.warning("Wrong contact ID for returning stream");
disposeReader(true, true);
return;

View File

@@ -17,8 +17,8 @@ import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexPluginFactory;
import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory;
import org.briarproject.bramble.api.transport.KeyManager;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeySetId;
import java.util.HashMap;
import java.util.Map;
@@ -88,10 +88,10 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
}
@Override
public Map<TransportId, TransportKeySetId> addContact(Transaction txn,
public Map<TransportId, KeySetId> addContact(Transaction txn,
ContactId c, SecretKey rootKey, long timestamp, boolean alice,
boolean active) throws DbException {
Map<TransportId, TransportKeySetId> ids = new HashMap<>();
Map<TransportId, KeySetId> ids = new HashMap<>();
for (Entry<TransportId, TransportKeyManager> e : managers.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = e.getValue();
@@ -101,9 +101,9 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
}
@Override
public void activateKeys(Transaction txn, Map<TransportId,
TransportKeySetId> keys) throws DbException {
for (Entry<TransportId, TransportKeySetId> e : keys.entrySet()) {
public void activateKeys(Transaction txn, Map<TransportId, KeySetId> keys)
throws DbException {
for (Entry<TransportId, KeySetId> e : keys.entrySet()) {
TransportId t = e.getKey();
TransportKeyManager m = managers.get(t);
if (m == null) {

View File

@@ -1,30 +0,0 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.transport.TransportKeySetId;
class MutableKeySet {
private final TransportKeySetId keySetId;
private final ContactId contactId;
private final MutableTransportKeys transportKeys;
MutableKeySet(TransportKeySetId keySetId, ContactId contactId,
MutableTransportKeys transportKeys) {
this.keySetId = keySetId;
this.contactId = contactId;
this.transportKeys = transportKeys;
}
TransportKeySetId getKeySetId() {
return keySetId;
}
ContactId getContactId() {
return contactId;
}
MutableTransportKeys getTransportKeys() {
return transportKeys;
}
}

View File

@@ -0,0 +1,50 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.contact.PendingContactId;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.transport.KeySetId;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
@NotThreadSafe
@NotNullByDefault
class MutableTransportKeySet {
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
@Nullable
private final PendingContactId pendingContactId;
private final MutableTransportKeys keys;
MutableTransportKeySet(KeySetId keySetId, @Nullable ContactId contactId,
@Nullable PendingContactId pendingContactId,
MutableTransportKeys keys) {
if ((contactId == null) == (pendingContactId == null))
throw new IllegalArgumentException();
this.keySetId = keySetId;
this.contactId = contactId;
this.pendingContactId = pendingContactId;
this.keys = keys;
}
KeySetId getKeySetId() {
return keySetId;
}
@Nullable
ContactId getContactId() {
return contactId;
}
@Nullable
PendingContactId getPendingContactId() {
return pendingContactId;
}
MutableTransportKeys getKeys() {
return keys;
}
}

View File

@@ -1,9 +1,11 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.transport.TransportKeys;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
@NotThreadSafe
@@ -13,6 +15,9 @@ class MutableTransportKeys {
private final TransportId transportId;
private final MutableIncomingKeys inPrev, inCurr, inNext;
private final MutableOutgoingKeys outCurr;
@Nullable
private final SecretKey rootKey;
private final boolean alice;
MutableTransportKeys(TransportKeys k) {
transportId = k.getTransportId();
@@ -20,11 +25,24 @@ class MutableTransportKeys {
inCurr = new MutableIncomingKeys(k.getCurrentIncomingKeys());
inNext = new MutableIncomingKeys(k.getNextIncomingKeys());
outCurr = new MutableOutgoingKeys(k.getCurrentOutgoingKeys());
if (k.isHandshakeMode()) {
rootKey = k.getRootKey();
alice = k.isAlice();
} else {
rootKey = null;
alice = false;
}
}
TransportKeys snapshot() {
return new TransportKeys(transportId, inPrev.snapshot(),
inCurr.snapshot(), inNext.snapshot(), outCurr.snapshot());
if (rootKey == null) {
return new TransportKeys(transportId, inPrev.snapshot(),
inCurr.snapshot(), inNext.snapshot(), outCurr.snapshot());
} else {
return new TransportKeys(transportId, inPrev.snapshot(),
inCurr.snapshot(), inNext.snapshot(), outCurr.snapshot(),
rootKey, alice);
}
}
TransportId getTransportId() {
@@ -46,4 +64,18 @@ class MutableTransportKeys {
MutableOutgoingKeys getCurrentOutgoingKeys() {
return outCurr;
}
boolean isHandshakeMode() {
return rootKey != null;
}
SecretKey getRootKey() {
if (rootKey == null) throw new UnsupportedOperationException();
return rootKey;
}
boolean isAlice() {
if (rootKey == null) throw new UnsupportedOperationException();
return alice;
}
}

View File

@@ -5,8 +5,8 @@ import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeySetId;
import javax.annotation.Nullable;
@@ -15,11 +15,10 @@ interface TransportKeyManager {
void start(Transaction txn) throws DbException;
TransportKeySetId addContact(Transaction txn, ContactId c,
SecretKey rootKey, long timestamp, boolean alice, boolean active)
throws DbException;
KeySetId addContact(Transaction txn, ContactId c, SecretKey rootKey,
long timestamp, boolean alice, boolean active) throws DbException;
void activateKeys(Transaction txn, TransportKeySetId k) throws DbException;
void activateKeys(Transaction txn, KeySetId k) throws DbException;
void removeContact(ContactId c);

View File

@@ -2,6 +2,7 @@ package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.Bytes;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.contact.PendingContactId;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.db.DatabaseComponent;
@@ -11,9 +12,9 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.system.Scheduler;
import org.briarproject.bramble.api.transport.KeySetId;
import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.TransportKeySet;
import org.briarproject.bramble.api.transport.TransportKeySetId;
import org.briarproject.bramble.api.transport.TransportKeys;
import org.briarproject.bramble.transport.ReorderingWindow.Change;
@@ -28,10 +29,13 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.logging.Level.WARNING;
import static java.util.logging.Logger.getLogger;
import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH;
@@ -43,7 +47,7 @@ import static org.briarproject.bramble.util.LogUtils.logException;
class TransportKeyManagerImpl implements TransportKeyManager {
private static final Logger LOG =
Logger.getLogger(TransportKeyManagerImpl.class.getName());
getLogger(TransportKeyManagerImpl.class.getName());
private final DatabaseComponent db;
private final TransportCrypto transportCrypto;
@@ -55,10 +59,13 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private final AtomicBoolean used = new AtomicBoolean(false);
private final ReentrantLock lock = new ReentrantLock();
// The following are locking: lock
private final Map<TransportKeySetId, MutableKeySet> keys = new HashMap<>();
@GuardedBy("lock")
private final Map<KeySetId, MutableTransportKeySet> keys = new HashMap<>();
@GuardedBy("lock")
private final Map<Bytes, TagContext> inContexts = new HashMap<>();
private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>();
@GuardedBy("lock")
private final Map<ContactId, MutableTransportKeySet> outContexts =
new HashMap<>();
TransportKeyManagerImpl(DatabaseComponent db,
TransportCrypto transportCrypto, Executor dbExecutor,
@@ -82,62 +89,70 @@ class TransportKeyManagerImpl implements TransportKeyManager {
// Load the transport keys from the DB
Collection<TransportKeySet> loaded =
db.getTransportKeys(txn, transportId);
// Rotate the keys to the current time period
RotationResult rotationResult = rotateKeys(loaded, now);
// Update the keys to the current time period
UpdateResult updateResult = updateKeys(loaded, now);
// Initialise mutable state for all contacts
addKeys(rotationResult.current);
// Write any rotated keys back to the DB
if (!rotationResult.rotated.isEmpty())
db.updateTransportKeys(txn, rotationResult.rotated);
addKeys(updateResult.current);
// Write any updated keys back to the DB
if (!updateResult.updated.isEmpty())
db.updateTransportKeys(txn, updateResult.updated);
} finally {
lock.unlock();
}
// Schedule the next key rotation
scheduleKeyRotation(now);
// Schedule the next key update
scheduleKeyUpdate(now);
}
private RotationResult rotateKeys(Collection<TransportKeySet> keys,
private UpdateResult updateKeys(Collection<TransportKeySet> keys,
long now) {
RotationResult rotationResult = new RotationResult();
UpdateResult updateResult = new UpdateResult();
long timePeriod = now / timePeriodLength;
for (TransportKeySet ks : keys) {
TransportKeys k = ks.getKeys();
TransportKeys k1 = transportCrypto.rotateTransportKeys(k,
TransportKeys k1 = transportCrypto.updateTransportKeys(k,
timePeriod);
TransportKeySet ks1 = new TransportKeySet(ks.getKeySetId(),
ks.getContactId(), k1);
ks.getContactId(), null, k1);
if (k1.getTimePeriod() > k.getTimePeriod())
rotationResult.rotated.add(ks1);
rotationResult.current.add(ks1);
updateResult.updated.add(ks1);
updateResult.current.add(ks1);
}
return rotationResult;
return updateResult;
}
// Locking: lock
@GuardedBy("lock")
private void addKeys(Collection<TransportKeySet> keys) {
for (TransportKeySet ks : keys) {
addKeys(ks.getKeySetId(), ks.getContactId(),
ks.getPendingContactId(),
new MutableTransportKeys(ks.getKeys()));
}
}
// Locking: lock
private void addKeys(TransportKeySetId keySetId, ContactId contactId,
MutableTransportKeys m) {
MutableKeySet ks = new MutableKeySet(keySetId, contactId, m);
keys.put(keySetId, ks);
encodeTags(keySetId, contactId, m.getPreviousIncomingKeys());
encodeTags(keySetId, contactId, m.getCurrentIncomingKeys());
encodeTags(keySetId, contactId, m.getNextIncomingKeys());
@GuardedBy("lock")
private void addKeys(KeySetId keySetId, @Nullable ContactId contactId,
@Nullable PendingContactId pendingContactId,
MutableTransportKeys keys) {
MutableTransportKeySet ks = new MutableTransportKeySet(keySetId,
contactId, pendingContactId, keys);
this.keys.put(keySetId, ks);
boolean handshakeMode = keys.isHandshakeMode();
encodeTags(keySetId, contactId, pendingContactId,
keys.getPreviousIncomingKeys(), handshakeMode);
encodeTags(keySetId, contactId, pendingContactId,
keys.getCurrentIncomingKeys(), handshakeMode);
encodeTags(keySetId, contactId, pendingContactId,
keys.getNextIncomingKeys(), handshakeMode);
considerReplacingOutgoingKeys(ks);
}
// Locking: lock
private void encodeTags(TransportKeySetId keySetId, ContactId contactId,
MutableIncomingKeys inKeys) {
@GuardedBy("lock")
private void encodeTags(KeySetId keySetId, @Nullable ContactId contactId,
@Nullable PendingContactId pendingContactId,
MutableIncomingKeys inKeys, boolean handshakeMode) {
for (long streamNumber : inKeys.getWindow().getUnseen()) {
TagContext tagCtx =
new TagContext(keySetId, contactId, inKeys, streamNumber);
TagContext tagCtx = new TagContext(keySetId, contactId,
pendingContactId, inKeys, streamNumber, handshakeMode);
byte[] tag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
streamNumber);
@@ -145,27 +160,30 @@ class TransportKeyManagerImpl implements TransportKeyManager {
}
}
// Locking: lock
private void considerReplacingOutgoingKeys(MutableKeySet ks) {
@GuardedBy("lock")
private void considerReplacingOutgoingKeys(MutableTransportKeySet ks) {
// Use the active outgoing keys with the highest key set ID
if (ks.getTransportKeys().getCurrentOutgoingKeys().isActive()) {
MutableKeySet old = outContexts.get(ks.getContactId());
ContactId c = ks.getContactId();
if (c != null && ks.getKeys().getCurrentOutgoingKeys().isActive()) {
MutableTransportKeySet old = outContexts.get(c);
if (old == null ||
(old.getKeys().isHandshakeMode() &&
!ks.getKeys().isHandshakeMode()) ||
old.getKeySetId().getInt() < ks.getKeySetId().getInt()) {
outContexts.put(ks.getContactId(), ks);
outContexts.put(c, ks);
}
}
}
private void scheduleKeyRotation(long now) {
private void scheduleKeyUpdate(long now) {
long delay = timePeriodLength - now % timePeriodLength;
scheduler.schedule((Runnable) this::rotateKeys, delay, MILLISECONDS);
scheduler.schedule((Runnable) this::updateKeys, delay, MILLISECONDS);
}
private void rotateKeys() {
private void updateKeys() {
dbExecutor.execute(() -> {
try {
db.transaction(false, this::rotateKeys);
db.transaction(false, this::updateKeys);
} catch (DbException e) {
logException(LOG, WARNING, e);
}
@@ -173,23 +191,22 @@ class TransportKeyManagerImpl implements TransportKeyManager {
}
@Override
public TransportKeySetId addContact(Transaction txn, ContactId c,
SecretKey rootKey, long timestamp, boolean alice, boolean active)
throws DbException {
public KeySetId addContact(Transaction txn, ContactId c, SecretKey rootKey,
long timestamp, boolean alice, boolean active) throws DbException {
lock.lock();
try {
// Work out what time period the timestamp belongs to
long timePeriod = timestamp / timePeriodLength;
// Derive the transport keys
TransportKeys k = transportCrypto.deriveTransportKeys(transportId,
TransportKeys k = transportCrypto.deriveRotationKeys(transportId,
rootKey, timePeriod, alice, active);
// Rotate the keys to the current time period if necessary
// Update the keys to the current time period if necessary
timePeriod = clock.currentTimeMillis() / timePeriodLength;
k = transportCrypto.rotateTransportKeys(k, timePeriod);
k = transportCrypto.updateTransportKeys(k, timePeriod);
// Write the keys back to the DB
TransportKeySetId keySetId = db.addTransportKeys(txn, c, k);
KeySetId keySetId = db.addTransportKeys(txn, c, k);
// Initialise mutable state for the contact
addKeys(keySetId, c, new MutableTransportKeys(k));
addKeys(keySetId, c, null, new MutableTransportKeys(k));
return keySetId;
} finally {
lock.unlock();
@@ -197,13 +214,12 @@ class TransportKeyManagerImpl implements TransportKeyManager {
}
@Override
public void activateKeys(Transaction txn, TransportKeySetId k)
throws DbException {
public void activateKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock();
try {
MutableKeySet ks = keys.get(k);
MutableTransportKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys();
MutableTransportKeys m = ks.getKeys();
m.getCurrentOutgoingKeys().activate();
considerReplacingOutgoingKeys(ks);
db.setTransportKeysActive(txn, m.getTransportId(), k);
@@ -218,13 +234,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
try {
// Remove mutable state for the contact
Iterator<TagContext> it = inContexts.values().iterator();
while (it.hasNext()) if (it.next().contactId.equals(c)) it.remove();
while (it.hasNext()) if (c.equals(it.next().contactId)) it.remove();
outContexts.remove(c);
Iterator<MutableKeySet> it1 = keys.values().iterator();
while (it1.hasNext()) {
ContactId c1 = it1.next().getContactId();
if (c1 != null && c1.equals(c)) it1.remove();
}
Iterator<MutableTransportKeySet> it1 = keys.values().iterator();
while (it1.hasNext())
if (c.equals(it1.next().getContactId())) it1.remove();
} finally {
lock.unlock();
}
@@ -234,10 +248,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
public boolean canSendOutgoingStreams(ContactId c) {
lock.lock();
try {
MutableKeySet ks = outContexts.get(c);
MutableTransportKeySet ks = outContexts.get(c);
if (ks == null) return false;
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
ks.getKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError();
return outKeys.getStreamCounter() <= MAX_32_BIT_UNSIGNED;
} finally {
@@ -251,16 +265,16 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock();
try {
// Look up the outgoing keys for the contact
MutableKeySet ks = outContexts.get(c);
MutableTransportKeySet ks = outContexts.get(c);
if (ks == null) return null;
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
MutableTransportKeys keys = ks.getKeys();
MutableOutgoingKeys outKeys = keys.getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError();
if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
// Create a stream context
StreamContext ctx = new StreamContext(c, transportId,
StreamContext ctx = new StreamContext(c, null, transportId,
outKeys.getTagKey(), outKeys.getHeaderKey(),
outKeys.getStreamCounter());
outKeys.getStreamCounter(), keys.isHandshakeMode());
// Increment the stream counter and write it back to the DB
outKeys.incrementStreamCounter();
db.incrementStreamCounter(txn, transportId, ks.getKeySetId());
@@ -280,9 +294,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
if (tagCtx == null) return null;
MutableIncomingKeys inKeys = tagCtx.inKeys;
// Create a stream context
StreamContext ctx = new StreamContext(tagCtx.contactId, transportId,
StreamContext ctx = new StreamContext(tagCtx.contactId,
tagCtx.pendingContactId, transportId,
inKeys.getTagKey(), inKeys.getHeaderKey(),
tagCtx.streamNumber);
tagCtx.streamNumber, tagCtx.handshakeMode);
// Update the reordering window
ReorderingWindow window = inKeys.getWindow();
Change change = window.setSeen(tagCtx.streamNumber);
@@ -292,7 +307,8 @@ class TransportKeyManagerImpl implements TransportKeyManager {
transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
PROTOCOL_VERSION, streamNumber);
TagContext tagCtx1 = new TagContext(tagCtx.keySetId,
tagCtx.contactId, inKeys, streamNumber);
tagCtx.contactId, tagCtx.pendingContactId, inKeys,
streamNumber, tagCtx.handshakeMode);
inContexts.put(new Bytes(addTag), tagCtx1);
}
// Remove tags for any stream numbers removed from the window
@@ -308,9 +324,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
inKeys.getTimePeriod(), window.getBase(),
window.getBitmap());
// If the outgoing keys are inactive, activate them
MutableKeySet ks = keys.get(tagCtx.keySetId);
MutableTransportKeySet ks = keys.get(tagCtx.keySetId);
MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys();
ks.getKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) {
LOG.info("Activating outgoing keys");
outKeys.activate();
@@ -323,51 +339,60 @@ class TransportKeyManagerImpl implements TransportKeyManager {
}
}
private void rotateKeys(Transaction txn) throws DbException {
private void updateKeys(Transaction txn) throws DbException {
long now = clock.currentTimeMillis();
lock.lock();
try {
// Rotate the keys to the current time period
// Update the keys to the current time period
Collection<TransportKeySet> snapshot = new ArrayList<>(keys.size());
for (MutableKeySet ks : keys.values()) {
for (MutableTransportKeySet ks : keys.values()) {
snapshot.add(new TransportKeySet(ks.getKeySetId(),
ks.getContactId(), ks.getTransportKeys().snapshot()));
ks.getContactId(), ks.getPendingContactId(),
ks.getKeys().snapshot()));
}
RotationResult rotationResult = rotateKeys(snapshot, now);
UpdateResult updateResult = updateKeys(snapshot, now);
// Rebuild the mutable state for all contacts
inContexts.clear();
outContexts.clear();
keys.clear();
addKeys(rotationResult.current);
// Write any rotated keys back to the DB
if (!rotationResult.rotated.isEmpty())
db.updateTransportKeys(txn, rotationResult.rotated);
addKeys(updateResult.current);
// Write any updated keys back to the DB
if (!updateResult.updated.isEmpty())
db.updateTransportKeys(txn, updateResult.updated);
} finally {
lock.unlock();
}
// Schedule the next key rotation
scheduleKeyRotation(now);
// Schedule the next key update
scheduleKeyUpdate(now);
}
private static class TagContext {
private final TransportKeySetId keySetId;
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
@Nullable
private final PendingContactId pendingContactId;
private final MutableIncomingKeys inKeys;
private final long streamNumber;
private final boolean handshakeMode;
private TagContext(TransportKeySetId keySetId, ContactId contactId,
MutableIncomingKeys inKeys, long streamNumber) {
private TagContext(KeySetId keySetId, @Nullable ContactId contactId,
@Nullable PendingContactId pendingContactId,
MutableIncomingKeys inKeys, long streamNumber,
boolean handshakeMode) {
this.keySetId = keySetId;
this.contactId = contactId;
this.pendingContactId = pendingContactId;
this.inKeys = inKeys;
this.streamNumber = streamNumber;
this.handshakeMode = handshakeMode;
}
}
private static class RotationResult {
private static class UpdateResult {
private final Collection<TransportKeySet> current = new ArrayList<>();
private final Collection<TransportKeySet> rotated = new ArrayList<>();
private final Collection<TransportKeySet> updated = new ArrayList<>();
}
}