Add handshake keys to TransportKeyManagerImpl.

This commit is contained in:
akwizgran
2019-05-07 10:33:26 +01:00
parent 3f51ad6c07
commit 5adc9d8dbd
8 changed files with 180 additions and 111 deletions

View File

@@ -1,34 +1,52 @@
package org.briarproject.bramble.api.transport; package org.briarproject.bramble.api.transport;
import org.briarproject.bramble.api.contact.ContactId; import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.contact.PendingContactId;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.TransportId;
import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable; import javax.annotation.concurrent.Immutable;
@Immutable @Immutable
@NotNullByDefault @NotNullByDefault
public class StreamContext { public class StreamContext {
@Nullable
private final ContactId contactId; private final ContactId contactId;
@Nullable
private final PendingContactId pendingContactId;
private final TransportId transportId; private final TransportId transportId;
private final SecretKey tagKey, headerKey; private final SecretKey tagKey, headerKey;
private final long streamNumber; private final long streamNumber;
private final boolean handshakeMode;
public StreamContext(ContactId contactId, TransportId transportId, public StreamContext(@Nullable ContactId contactId,
SecretKey tagKey, SecretKey headerKey, long streamNumber) { @Nullable PendingContactId pendingContactId,
TransportId transportId, SecretKey tagKey, SecretKey headerKey,
long streamNumber, boolean handshakeMode) {
if ((contactId == null) == (pendingContactId == null))
throw new IllegalArgumentException();
this.contactId = contactId; this.contactId = contactId;
this.pendingContactId = pendingContactId;
this.transportId = transportId; this.transportId = transportId;
this.tagKey = tagKey; this.tagKey = tagKey;
this.headerKey = headerKey; this.headerKey = headerKey;
this.streamNumber = streamNumber; this.streamNumber = streamNumber;
this.handshakeMode = handshakeMode;
} }
@Nullable
public ContactId getContactId() { public ContactId getContactId() {
return contactId; return contactId;
} }
@Nullable
public PendingContactId getPendingContactId() {
return pendingContactId;
}
public TransportId getTransportId() { public TransportId getTransportId() {
return transportId; return transportId;
} }
@@ -44,4 +62,8 @@ public class StreamContext {
public long getStreamNumber() { public long getStreamNumber() {
return streamNumber; return streamNumber;
} }
public boolean isHandshakeMode() {
return handshakeMode;
}
} }

View File

@@ -24,6 +24,8 @@ public class TransportKeySet {
public TransportKeySet(KeySetId keySetId, @Nullable ContactId contactId, public TransportKeySet(KeySetId keySetId, @Nullable ContactId contactId,
@Nullable PendingContactId pendingContactId, TransportKeys keys) { @Nullable PendingContactId pendingContactId, TransportKeys keys) {
if ((contactId == null) == (pendingContactId == null))
throw new IllegalArgumentException();
this.keySetId = keySetId; this.keySetId = keySetId;
this.contactId = contactId; this.contactId = contactId;
this.pendingContactId = pendingContactId; this.pendingContactId = pendingContactId;

View File

@@ -96,6 +96,7 @@ class ConnectionManagerImpl implements ConnectionManager {
TransportConnectionReader r) throws IOException { TransportConnectionReader r) throws IOException {
InputStream streamReader = streamReaderFactory.createStreamReader( InputStream streamReader = streamReaderFactory.createStreamReader(
r.getInputStream(), ctx); r.getInputStream(), ctx);
// TODO: Pending contacts, handshake mode
return syncSessionFactory.createIncomingSession(ctx.getContactId(), return syncSessionFactory.createIncomingSession(ctx.getContactId(),
streamReader); streamReader);
} }
@@ -104,6 +105,7 @@ class ConnectionManagerImpl implements ConnectionManager {
TransportConnectionWriter w) throws IOException { TransportConnectionWriter w) throws IOException {
StreamWriter streamWriter = streamWriterFactory.createStreamWriter( StreamWriter streamWriter = streamWriterFactory.createStreamWriter(
w.getOutputStream(), ctx); w.getOutputStream(), ctx);
// TODO: Pending contacts, handshake mode
return syncSessionFactory.createSimplexOutgoingSession( return syncSessionFactory.createSimplexOutgoingSession(
ctx.getContactId(), w.getMaxLatency(), streamWriter); ctx.getContactId(), w.getMaxLatency(), streamWriter);
} }
@@ -112,6 +114,7 @@ class ConnectionManagerImpl implements ConnectionManager {
TransportConnectionWriter w) throws IOException { TransportConnectionWriter w) throws IOException {
StreamWriter streamWriter = streamWriterFactory.createStreamWriter( StreamWriter streamWriter = streamWriterFactory.createStreamWriter(
w.getOutputStream(), ctx); w.getOutputStream(), ctx);
// TODO: Pending contacts, handshake mode
return syncSessionFactory.createDuplexOutgoingSession( return syncSessionFactory.createDuplexOutgoingSession(
ctx.getContactId(), w.getMaxLatency(), w.getMaxIdleTime(), ctx.getContactId(), w.getMaxLatency(), w.getMaxIdleTime(),
streamWriter); streamWriter);
@@ -145,6 +148,7 @@ class ConnectionManagerImpl implements ConnectionManager {
disposeReader(false, false); disposeReader(false, false);
return; return;
} }
// TODO: Pending contacts
ContactId contactId = ctx.getContactId(); ContactId contactId = ctx.getContactId();
connectionRegistry.registerConnection(contactId, transportId, true); connectionRegistry.registerConnection(contactId, transportId, true);
try { try {
@@ -388,7 +392,7 @@ class ConnectionManagerImpl implements ConnectionManager {
return; return;
} }
// Check that the stream comes from the expected contact // Check that the stream comes from the expected contact
if (!ctx.getContactId().equals(contactId)) { if (!contactId.equals(ctx.getContactId())) {
LOG.warning("Wrong contact ID for returning stream"); LOG.warning("Wrong contact ID for returning stream");
disposeReader(true, true); disposeReader(true, true);
return; return;

View File

@@ -1,30 +0,0 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.transport.KeySetId;
class MutableKeySet {
private final KeySetId keySetId;
private final ContactId contactId;
private final MutableTransportKeys transportKeys;
MutableKeySet(KeySetId keySetId, ContactId contactId,
MutableTransportKeys transportKeys) {
this.keySetId = keySetId;
this.contactId = contactId;
this.transportKeys = transportKeys;
}
KeySetId getKeySetId() {
return keySetId;
}
ContactId getContactId() {
return contactId;
}
MutableTransportKeys getTransportKeys() {
return transportKeys;
}
}

View File

@@ -0,0 +1,50 @@
package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.contact.PendingContactId;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.transport.KeySetId;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
@NotThreadSafe
@NotNullByDefault
class MutableTransportKeySet {
private final KeySetId keySetId;
@Nullable
private final ContactId contactId;
@Nullable
private final PendingContactId pendingContactId;
private final MutableTransportKeys keys;
MutableTransportKeySet(KeySetId keySetId, @Nullable ContactId contactId,
@Nullable PendingContactId pendingContactId,
MutableTransportKeys keys) {
if ((contactId == null) == (pendingContactId == null))
throw new IllegalArgumentException();
this.keySetId = keySetId;
this.contactId = contactId;
this.pendingContactId = pendingContactId;
this.keys = keys;
}
KeySetId getKeySetId() {
return keySetId;
}
@Nullable
ContactId getContactId() {
return contactId;
}
@Nullable
PendingContactId getPendingContactId() {
return pendingContactId;
}
MutableTransportKeys getKeys() {
return keys;
}
}

View File

@@ -2,6 +2,7 @@ package org.briarproject.bramble.transport;
import org.briarproject.bramble.api.Bytes; import org.briarproject.bramble.api.Bytes;
import org.briarproject.bramble.api.contact.ContactId; import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.contact.PendingContactId;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.crypto.TransportCrypto; import org.briarproject.bramble.api.crypto.TransportCrypto;
import org.briarproject.bramble.api.db.DatabaseComponent; import org.briarproject.bramble.api.db.DatabaseComponent;
@@ -28,6 +29,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe; import javax.annotation.concurrent.ThreadSafe;
@@ -58,11 +60,12 @@ class TransportKeyManagerImpl implements TransportKeyManager {
private final ReentrantLock lock = new ReentrantLock(); private final ReentrantLock lock = new ReentrantLock();
@GuardedBy("lock") @GuardedBy("lock")
private final Map<KeySetId, MutableKeySet> keys = new HashMap<>(); private final Map<KeySetId, MutableTransportKeySet> keys = new HashMap<>();
@GuardedBy("lock") @GuardedBy("lock")
private final Map<Bytes, TagContext> inContexts = new HashMap<>(); private final Map<Bytes, TagContext> inContexts = new HashMap<>();
@GuardedBy("lock") @GuardedBy("lock")
private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>(); private final Map<ContactId, MutableTransportKeySet> outContexts =
new HashMap<>();
TransportKeyManagerImpl(DatabaseComponent db, TransportKeyManagerImpl(DatabaseComponent db,
TransportCrypto transportCrypto, Executor dbExecutor, TransportCrypto transportCrypto, Executor dbExecutor,
@@ -86,23 +89,23 @@ class TransportKeyManagerImpl implements TransportKeyManager {
// Load the transport keys from the DB // Load the transport keys from the DB
Collection<TransportKeySet> loaded = Collection<TransportKeySet> loaded =
db.getTransportKeys(txn, transportId); db.getTransportKeys(txn, transportId);
// Rotate the keys to the current time period // Update the keys to the current time period
RotationResult rotationResult = rotateKeys(loaded, now); UpdateResult updateResult = updateKeys(loaded, now);
// Initialise mutable state for all contacts // Initialise mutable state for all contacts
addKeys(rotationResult.current); addKeys(updateResult.current);
// Write any rotated keys back to the DB // Write any updated keys back to the DB
if (!rotationResult.rotated.isEmpty()) if (!updateResult.updated.isEmpty())
db.updateTransportKeys(txn, rotationResult.rotated); db.updateTransportKeys(txn, updateResult.updated);
} finally { } finally {
lock.unlock(); lock.unlock();
} }
// Schedule the next key rotation // Schedule the next key update
scheduleKeyRotation(now); scheduleKeyUpdate(now);
} }
private RotationResult rotateKeys(Collection<TransportKeySet> keys, private UpdateResult updateKeys(Collection<TransportKeySet> keys,
long now) { long now) {
RotationResult rotationResult = new RotationResult(); UpdateResult updateResult = new UpdateResult();
long timePeriod = now / timePeriodLength; long timePeriod = now / timePeriodLength;
for (TransportKeySet ks : keys) { for (TransportKeySet ks : keys) {
TransportKeys k = ks.getKeys(); TransportKeys k = ks.getKeys();
@@ -111,38 +114,45 @@ class TransportKeyManagerImpl implements TransportKeyManager {
TransportKeySet ks1 = new TransportKeySet(ks.getKeySetId(), TransportKeySet ks1 = new TransportKeySet(ks.getKeySetId(),
ks.getContactId(), null, k1); ks.getContactId(), null, k1);
if (k1.getTimePeriod() > k.getTimePeriod()) if (k1.getTimePeriod() > k.getTimePeriod())
rotationResult.rotated.add(ks1); updateResult.updated.add(ks1);
rotationResult.current.add(ks1); updateResult.current.add(ks1);
} }
return rotationResult; return updateResult;
} }
@GuardedBy("lock") @GuardedBy("lock")
private void addKeys(Collection<TransportKeySet> keys) { private void addKeys(Collection<TransportKeySet> keys) {
for (TransportKeySet ks : keys) { for (TransportKeySet ks : keys) {
// TODO: Keys may be for a pending contact
addKeys(ks.getKeySetId(), ks.getContactId(), addKeys(ks.getKeySetId(), ks.getContactId(),
ks.getPendingContactId(),
new MutableTransportKeys(ks.getKeys())); new MutableTransportKeys(ks.getKeys()));
} }
} }
@GuardedBy("lock") @GuardedBy("lock")
private void addKeys(KeySetId keySetId, ContactId contactId, private void addKeys(KeySetId keySetId, @Nullable ContactId contactId,
MutableTransportKeys m) { @Nullable PendingContactId pendingContactId,
MutableKeySet ks = new MutableKeySet(keySetId, contactId, m); MutableTransportKeys keys) {
keys.put(keySetId, ks); MutableTransportKeySet ks = new MutableTransportKeySet(keySetId,
encodeTags(keySetId, contactId, m.getPreviousIncomingKeys()); contactId, pendingContactId, keys);
encodeTags(keySetId, contactId, m.getCurrentIncomingKeys()); this.keys.put(keySetId, ks);
encodeTags(keySetId, contactId, m.getNextIncomingKeys()); boolean handshakeMode = keys.isHandshakeMode();
encodeTags(keySetId, contactId, pendingContactId,
keys.getPreviousIncomingKeys(), handshakeMode);
encodeTags(keySetId, contactId, pendingContactId,
keys.getCurrentIncomingKeys(), handshakeMode);
encodeTags(keySetId, contactId, pendingContactId,
keys.getNextIncomingKeys(), handshakeMode);
considerReplacingOutgoingKeys(ks); considerReplacingOutgoingKeys(ks);
} }
@GuardedBy("lock") @GuardedBy("lock")
private void encodeTags(KeySetId keySetId, ContactId contactId, private void encodeTags(KeySetId keySetId, @Nullable ContactId contactId,
MutableIncomingKeys inKeys) { @Nullable PendingContactId pendingContactId,
MutableIncomingKeys inKeys, boolean handshakeMode) {
for (long streamNumber : inKeys.getWindow().getUnseen()) { for (long streamNumber : inKeys.getWindow().getUnseen()) {
TagContext tagCtx = TagContext tagCtx = new TagContext(keySetId, contactId,
new TagContext(keySetId, contactId, inKeys, streamNumber); pendingContactId, inKeys, streamNumber, handshakeMode);
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION, transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
streamNumber); streamNumber);
@@ -151,26 +161,29 @@ class TransportKeyManagerImpl implements TransportKeyManager {
} }
@GuardedBy("lock") @GuardedBy("lock")
private void considerReplacingOutgoingKeys(MutableKeySet ks) { private void considerReplacingOutgoingKeys(MutableTransportKeySet ks) {
// Use the active outgoing keys with the highest key set ID // Use the active outgoing keys with the highest key set ID
if (ks.getTransportKeys().getCurrentOutgoingKeys().isActive()) { ContactId c = ks.getContactId();
MutableKeySet old = outContexts.get(ks.getContactId()); if (c != null && ks.getKeys().getCurrentOutgoingKeys().isActive()) {
MutableTransportKeySet old = outContexts.get(c);
if (old == null || if (old == null ||
(old.getKeys().isHandshakeMode() &&
!ks.getKeys().isHandshakeMode()) ||
old.getKeySetId().getInt() < ks.getKeySetId().getInt()) { old.getKeySetId().getInt() < ks.getKeySetId().getInt()) {
outContexts.put(ks.getContactId(), ks); outContexts.put(c, ks);
} }
} }
} }
private void scheduleKeyRotation(long now) { private void scheduleKeyUpdate(long now) {
long delay = timePeriodLength - now % timePeriodLength; long delay = timePeriodLength - now % timePeriodLength;
scheduler.schedule((Runnable) this::rotateKeys, delay, MILLISECONDS); scheduler.schedule((Runnable) this::updateKeys, delay, MILLISECONDS);
} }
private void rotateKeys() { private void updateKeys() {
dbExecutor.execute(() -> { dbExecutor.execute(() -> {
try { try {
db.transaction(false, this::rotateKeys); db.transaction(false, this::updateKeys);
} catch (DbException e) { } catch (DbException e) {
logException(LOG, WARNING, e); logException(LOG, WARNING, e);
} }
@@ -187,13 +200,13 @@ class TransportKeyManagerImpl implements TransportKeyManager {
// Derive the transport keys // Derive the transport keys
TransportKeys k = transportCrypto.deriveRotationKeys(transportId, TransportKeys k = transportCrypto.deriveRotationKeys(transportId,
rootKey, timePeriod, alice, active); rootKey, timePeriod, alice, active);
// Rotate the keys to the current time period if necessary // Update the keys to the current time period if necessary
timePeriod = clock.currentTimeMillis() / timePeriodLength; timePeriod = clock.currentTimeMillis() / timePeriodLength;
k = transportCrypto.updateTransportKeys(k, timePeriod); k = transportCrypto.updateTransportKeys(k, timePeriod);
// Write the keys back to the DB // Write the keys back to the DB
KeySetId keySetId = db.addTransportKeys(txn, c, k); KeySetId keySetId = db.addTransportKeys(txn, c, k);
// Initialise mutable state for the contact // Initialise mutable state for the contact
addKeys(keySetId, c, new MutableTransportKeys(k)); addKeys(keySetId, c, null, new MutableTransportKeys(k));
return keySetId; return keySetId;
} finally { } finally {
lock.unlock(); lock.unlock();
@@ -204,9 +217,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
public void activateKeys(Transaction txn, KeySetId k) throws DbException { public void activateKeys(Transaction txn, KeySetId k) throws DbException {
lock.lock(); lock.lock();
try { try {
MutableKeySet ks = keys.get(k); MutableTransportKeySet ks = keys.get(k);
if (ks == null) throw new IllegalArgumentException(); if (ks == null) throw new IllegalArgumentException();
MutableTransportKeys m = ks.getTransportKeys(); MutableTransportKeys m = ks.getKeys();
m.getCurrentOutgoingKeys().activate(); m.getCurrentOutgoingKeys().activate();
considerReplacingOutgoingKeys(ks); considerReplacingOutgoingKeys(ks);
db.setTransportKeysActive(txn, m.getTransportId(), k); db.setTransportKeysActive(txn, m.getTransportId(), k);
@@ -221,13 +234,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
try { try {
// Remove mutable state for the contact // Remove mutable state for the contact
Iterator<TagContext> it = inContexts.values().iterator(); Iterator<TagContext> it = inContexts.values().iterator();
while (it.hasNext()) if (it.next().contactId.equals(c)) it.remove(); while (it.hasNext()) if (c.equals(it.next().contactId)) it.remove();
outContexts.remove(c); outContexts.remove(c);
Iterator<MutableKeySet> it1 = keys.values().iterator(); Iterator<MutableTransportKeySet> it1 = keys.values().iterator();
while (it1.hasNext()) { while (it1.hasNext())
ContactId c1 = it1.next().getContactId(); if (c.equals(it1.next().getContactId())) it1.remove();
if (c1 != null && c1.equals(c)) it1.remove();
}
} finally { } finally {
lock.unlock(); lock.unlock();
} }
@@ -237,10 +248,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
public boolean canSendOutgoingStreams(ContactId c) { public boolean canSendOutgoingStreams(ContactId c) {
lock.lock(); lock.lock();
try { try {
MutableKeySet ks = outContexts.get(c); MutableTransportKeySet ks = outContexts.get(c);
if (ks == null) return false; if (ks == null) return false;
MutableOutgoingKeys outKeys = MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys(); ks.getKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError(); if (!outKeys.isActive()) throw new AssertionError();
return outKeys.getStreamCounter() <= MAX_32_BIT_UNSIGNED; return outKeys.getStreamCounter() <= MAX_32_BIT_UNSIGNED;
} finally { } finally {
@@ -254,16 +265,16 @@ class TransportKeyManagerImpl implements TransportKeyManager {
lock.lock(); lock.lock();
try { try {
// Look up the outgoing keys for the contact // Look up the outgoing keys for the contact
MutableKeySet ks = outContexts.get(c); MutableTransportKeySet ks = outContexts.get(c);
if (ks == null) return null; if (ks == null) return null;
MutableOutgoingKeys outKeys = MutableTransportKeys keys = ks.getKeys();
ks.getTransportKeys().getCurrentOutgoingKeys(); MutableOutgoingKeys outKeys = keys.getCurrentOutgoingKeys();
if (!outKeys.isActive()) throw new AssertionError(); if (!outKeys.isActive()) throw new AssertionError();
if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null; if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
// Create a stream context // Create a stream context
StreamContext ctx = new StreamContext(c, transportId, StreamContext ctx = new StreamContext(c, null, transportId,
outKeys.getTagKey(), outKeys.getHeaderKey(), outKeys.getTagKey(), outKeys.getHeaderKey(),
outKeys.getStreamCounter()); outKeys.getStreamCounter(), keys.isHandshakeMode());
// Increment the stream counter and write it back to the DB // Increment the stream counter and write it back to the DB
outKeys.incrementStreamCounter(); outKeys.incrementStreamCounter();
db.incrementStreamCounter(txn, transportId, ks.getKeySetId()); db.incrementStreamCounter(txn, transportId, ks.getKeySetId());
@@ -283,9 +294,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
if (tagCtx == null) return null; if (tagCtx == null) return null;
MutableIncomingKeys inKeys = tagCtx.inKeys; MutableIncomingKeys inKeys = tagCtx.inKeys;
// Create a stream context // Create a stream context
StreamContext ctx = new StreamContext(tagCtx.contactId, transportId, StreamContext ctx = new StreamContext(tagCtx.contactId,
tagCtx.pendingContactId, transportId,
inKeys.getTagKey(), inKeys.getHeaderKey(), inKeys.getTagKey(), inKeys.getHeaderKey(),
tagCtx.streamNumber); tagCtx.streamNumber, tagCtx.handshakeMode);
// Update the reordering window // Update the reordering window
ReorderingWindow window = inKeys.getWindow(); ReorderingWindow window = inKeys.getWindow();
Change change = window.setSeen(tagCtx.streamNumber); Change change = window.setSeen(tagCtx.streamNumber);
@@ -295,7 +307,8 @@ class TransportKeyManagerImpl implements TransportKeyManager {
transportCrypto.encodeTag(addTag, inKeys.getTagKey(), transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
PROTOCOL_VERSION, streamNumber); PROTOCOL_VERSION, streamNumber);
TagContext tagCtx1 = new TagContext(tagCtx.keySetId, TagContext tagCtx1 = new TagContext(tagCtx.keySetId,
tagCtx.contactId, inKeys, streamNumber); tagCtx.contactId, tagCtx.pendingContactId, inKeys,
streamNumber, tagCtx.handshakeMode);
inContexts.put(new Bytes(addTag), tagCtx1); inContexts.put(new Bytes(addTag), tagCtx1);
} }
// Remove tags for any stream numbers removed from the window // Remove tags for any stream numbers removed from the window
@@ -311,9 +324,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
inKeys.getTimePeriod(), window.getBase(), inKeys.getTimePeriod(), window.getBase(),
window.getBitmap()); window.getBitmap());
// If the outgoing keys are inactive, activate them // If the outgoing keys are inactive, activate them
MutableKeySet ks = keys.get(tagCtx.keySetId); MutableTransportKeySet ks = keys.get(tagCtx.keySetId);
MutableOutgoingKeys outKeys = MutableOutgoingKeys outKeys =
ks.getTransportKeys().getCurrentOutgoingKeys(); ks.getKeys().getCurrentOutgoingKeys();
if (!outKeys.isActive()) { if (!outKeys.isActive()) {
LOG.info("Activating outgoing keys"); LOG.info("Activating outgoing keys");
outKeys.activate(); outKeys.activate();
@@ -326,52 +339,60 @@ class TransportKeyManagerImpl implements TransportKeyManager {
} }
} }
private void rotateKeys(Transaction txn) throws DbException { private void updateKeys(Transaction txn) throws DbException {
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
lock.lock(); lock.lock();
try { try {
// Rotate the keys to the current time period // Update the keys to the current time period
Collection<TransportKeySet> snapshot = new ArrayList<>(keys.size()); Collection<TransportKeySet> snapshot = new ArrayList<>(keys.size());
for (MutableKeySet ks : keys.values()) { for (MutableTransportKeySet ks : keys.values()) {
snapshot.add(new TransportKeySet(ks.getKeySetId(), snapshot.add(new TransportKeySet(ks.getKeySetId(),
ks.getContactId(), null, ks.getContactId(), ks.getPendingContactId(),
ks.getTransportKeys().snapshot())); ks.getKeys().snapshot()));
} }
RotationResult rotationResult = rotateKeys(snapshot, now); UpdateResult updateResult = updateKeys(snapshot, now);
// Rebuild the mutable state for all contacts // Rebuild the mutable state for all contacts
inContexts.clear(); inContexts.clear();
outContexts.clear(); outContexts.clear();
keys.clear(); keys.clear();
addKeys(rotationResult.current); addKeys(updateResult.current);
// Write any rotated keys back to the DB // Write any updated keys back to the DB
if (!rotationResult.rotated.isEmpty()) if (!updateResult.updated.isEmpty())
db.updateTransportKeys(txn, rotationResult.rotated); db.updateTransportKeys(txn, updateResult.updated);
} finally { } finally {
lock.unlock(); lock.unlock();
} }
// Schedule the next key rotation // Schedule the next key update
scheduleKeyRotation(now); scheduleKeyUpdate(now);
} }
private static class TagContext { private static class TagContext {
private final KeySetId keySetId; private final KeySetId keySetId;
@Nullable
private final ContactId contactId; private final ContactId contactId;
@Nullable
private final PendingContactId pendingContactId;
private final MutableIncomingKeys inKeys; private final MutableIncomingKeys inKeys;
private final long streamNumber; private final long streamNumber;
private final boolean handshakeMode;
private TagContext(KeySetId keySetId, ContactId contactId, private TagContext(KeySetId keySetId, @Nullable ContactId contactId,
MutableIncomingKeys inKeys, long streamNumber) { @Nullable PendingContactId pendingContactId,
MutableIncomingKeys inKeys, long streamNumber,
boolean handshakeMode) {
this.keySetId = keySetId; this.keySetId = keySetId;
this.contactId = contactId; this.contactId = contactId;
this.pendingContactId = pendingContactId;
this.inKeys = inKeys; this.inKeys = inKeys;
this.streamNumber = streamNumber; this.streamNumber = streamNumber;
this.handshakeMode = handshakeMode;
} }
} }
private static class RotationResult { private static class UpdateResult {
private final Collection<TransportKeySet> current = new ArrayList<>(); private final Collection<TransportKeySet> current = new ArrayList<>();
private final Collection<TransportKeySet> rotated = new ArrayList<>(); private final Collection<TransportKeySet> updated = new ArrayList<>();
} }
} }

View File

@@ -101,8 +101,8 @@ public class SyncIntegrationTest extends BrambleTestCase {
private byte[] write() throws Exception { private byte[] write() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
StreamContext ctx = new StreamContext(contactId, transportId, tagKey, StreamContext ctx = new StreamContext(contactId, null, transportId,
headerKey, streamNumber); tagKey, headerKey, streamNumber, false);
StreamWriter streamWriter = streamWriterFactory.createStreamWriter(out, StreamWriter streamWriter = streamWriterFactory.createStreamWriter(out,
ctx); ctx);
SyncRecordWriter recordWriter = recordWriterFactory.createRecordWriter( SyncRecordWriter recordWriter = recordWriterFactory.createRecordWriter(
@@ -131,8 +131,8 @@ public class SyncIntegrationTest extends BrambleTestCase {
assertArrayEquals(expectedTag, tag); assertArrayEquals(expectedTag, tag);
// Create the readers // Create the readers
StreamContext ctx = new StreamContext(contactId, transportId, tagKey, StreamContext ctx = new StreamContext(contactId, null, transportId,
headerKey, streamNumber); tagKey, headerKey, streamNumber, false);
InputStream streamReader = streamReaderFactory.createStreamReader(in, InputStream streamReader = streamReaderFactory.createStreamReader(in,
ctx); ctx);
SyncRecordReader recordReader = recordReaderFactory.createRecordReader( SyncRecordReader recordReader = recordReaderFactory.createRecordReader(

View File

@@ -47,7 +47,7 @@ public class KeyManagerImplTest extends BrambleMockTestCase {
private final TransportId transportId = getTransportId(); private final TransportId transportId = getTransportId();
private final TransportId unknownTransportId = getTransportId(); private final TransportId unknownTransportId = getTransportId();
private final StreamContext streamContext = new StreamContext(contactId, private final StreamContext streamContext = new StreamContext(contactId,
transportId, getSecretKey(), getSecretKey(), 1); null, transportId, getSecretKey(), getSecretKey(), 1, false);
private final byte[] tag = getRandomBytes(TAG_LENGTH); private final byte[] tag = getRandomBytes(TAG_LENGTH);
private final Random random = new Random(); private final Random random = new Random();