Don't start transactions while holding locks. #272

This commit is contained in:
akwizgran
2016-03-29 15:21:46 +01:00
parent 685a864b43
commit e58ca00979
6 changed files with 200 additions and 180 deletions

View File

@@ -26,10 +26,6 @@ import java.util.Map;
/** /**
* Encapsulates the database implementation and exposes high-level operations * Encapsulates the database implementation and exposes high-level operations
* to other components. * to other components.
* <p/>
* This interface's methods are blocking, but they do not call out into other
* components except to broadcast {@link org.briarproject.api.event.Event
* Events}, so they can safely be called while holding locks.
*/ */
public interface DatabaseComponent { public interface DatabaseComponent {
@@ -45,6 +41,9 @@ public interface DatabaseComponent {
/** /**
* Starts a new transaction and returns an object representing it. * Starts a new transaction and returns an object representing it.
* <p/>
* This method acquires locks, so it must not be called while holding a
* lock.
* @param readOnly true if the transaction will only be used for reading. * @param readOnly true if the transaction will only be used for reading.
*/ */
Transaction startTransaction(boolean readOnly) throws DbException; Transaction startTransaction(boolean readOnly) throws DbException;

View File

@@ -26,12 +26,14 @@ public interface KeyManager {
* contact over the given transport, or null if an error occurs or the * contact over the given transport, or null if an error occurs or the
* contact does not support the transport. * contact does not support the transport.
*/ */
StreamContext getStreamContext(ContactId c, TransportId t); StreamContext getStreamContext(ContactId c, TransportId t)
throws DbException;
/** /**
* Looks up the given tag and returns a {@link StreamContext} for reading * Looks up the given tag and returns a {@link StreamContext} for reading
* from the corresponding stream, or null if an error occurs or the tag was * from the corresponding stream, or null if an error occurs or the tag was
* unexpected. * unexpected.
*/ */
StreamContext getStreamContext(TransportId t, byte[] tag); StreamContext getStreamContext(TransportId t, byte[] tag)
throws DbException;
} }

View File

@@ -2,6 +2,7 @@ package org.briarproject.plugins;
import org.briarproject.api.TransportId; import org.briarproject.api.TransportId;
import org.briarproject.api.contact.ContactId; import org.briarproject.api.contact.ContactId;
import org.briarproject.api.db.DbException;
import org.briarproject.api.lifecycle.IoExecutor; import org.briarproject.api.lifecycle.IoExecutor;
import org.briarproject.api.plugins.ConnectionManager; import org.briarproject.api.plugins.ConnectionManager;
import org.briarproject.api.plugins.ConnectionRegistry; import org.briarproject.api.plugins.ConnectionRegistry;
@@ -132,10 +133,14 @@ class ConnectionManagerImpl implements ConnectionManager {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, false); disposeReader(true, false);
return; return;
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, false);
return;
} }
if (ctx == null) { if (ctx == null) {
LOG.info("Unrecognised tag"); LOG.info("Unrecognised tag");
disposeReader(true, false); disposeReader(false, false);
return; return;
} }
ContactId contactId = ctx.getContactId(); ContactId contactId = ctx.getContactId();
@@ -176,11 +181,17 @@ class ConnectionManagerImpl implements ConnectionManager {
public void run() { public void run() {
// Allocate a stream context // Allocate a stream context
StreamContext ctx = keyManager.getStreamContext(contactId, StreamContext ctx;
transportId); try {
ctx = keyManager.getStreamContext(contactId, transportId);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeWriter(true);
return;
}
if (ctx == null) { if (ctx == null) {
LOG.warning("Could not allocate stream context"); LOG.warning("Could not allocate stream context");
disposeWriter(true); disposeWriter(false);
return; return;
} }
connectionRegistry.registerConnection(contactId, transportId); connectionRegistry.registerConnection(contactId, transportId);
@@ -232,10 +243,14 @@ class ConnectionManagerImpl implements ConnectionManager {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, false); disposeReader(true, false);
return; return;
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, false);
return;
} }
if (ctx == null) { if (ctx == null) {
LOG.info("Unrecognised tag"); LOG.info("Unrecognised tag");
disposeReader(true, false); disposeReader(false, false);
return; return;
} }
contactId = ctx.getContactId(); contactId = ctx.getContactId();
@@ -261,11 +276,17 @@ class ConnectionManagerImpl implements ConnectionManager {
private void runOutgoingSession() { private void runOutgoingSession() {
// Allocate a stream context // Allocate a stream context
StreamContext ctx = keyManager.getStreamContext(contactId, StreamContext ctx;
transportId); try {
ctx = keyManager.getStreamContext(contactId, transportId);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeWriter(true);
return;
}
if (ctx == null) { if (ctx == null) {
LOG.warning("Could not allocate stream context"); LOG.warning("Could not allocate stream context");
disposeWriter(true); disposeWriter(false);
return; return;
} }
try { try {
@@ -320,11 +341,17 @@ class ConnectionManagerImpl implements ConnectionManager {
public void run() { public void run() {
// Allocate a stream context // Allocate a stream context
StreamContext ctx = keyManager.getStreamContext(contactId, StreamContext ctx;
transportId); try {
ctx = keyManager.getStreamContext(contactId, transportId);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeWriter(true);
return;
}
if (ctx == null) { if (ctx == null) {
LOG.warning("Could not allocate stream context"); LOG.warning("Could not allocate stream context");
disposeWriter(true); disposeWriter(false);
return; return;
} }
connectionRegistry.registerConnection(contactId, transportId); connectionRegistry.registerConnection(contactId, transportId);
@@ -357,6 +384,10 @@ class ConnectionManagerImpl implements ConnectionManager {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, true); disposeReader(true, true);
return; return;
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, true);
return;
} }
// Unrecognised tags are suspicious in this case // Unrecognised tags are suspicious in this case
if (ctx == null) { if (ctx == null) {

View File

@@ -22,7 +22,6 @@ import org.briarproject.api.system.Timer;
import org.briarproject.api.transport.KeyManager; import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamContext; import org.briarproject.api.transport.StreamContext;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
@@ -32,6 +31,7 @@ import java.util.logging.Logger;
import javax.inject.Inject; import javax.inject.Inject;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING; import static java.util.logging.Level.WARNING;
class KeyManagerImpl implements KeyManager, Service, EventListener { class KeyManagerImpl implements KeyManager, Service, EventListener {
@@ -65,31 +65,29 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
@Override @Override
public boolean start() { public boolean start() {
Map<TransportId, Integer> latencies = Map<TransportId, Integer> transports =
new HashMap<TransportId, Integer>(); new HashMap<TransportId, Integer>();
for (SimplexPluginFactory f : pluginConfig.getSimplexFactories()) for (SimplexPluginFactory f : pluginConfig.getSimplexFactories())
latencies.put(f.getId(), f.getMaxLatency()); transports.put(f.getId(), f.getMaxLatency());
for (DuplexPluginFactory f : pluginConfig.getDuplexFactories()) for (DuplexPluginFactory f : pluginConfig.getDuplexFactories())
latencies.put(f.getId(), f.getMaxLatency()); transports.put(f.getId(), f.getMaxLatency());
try { try {
Collection<Contact> contacts;
Transaction txn = db.startTransaction(false); Transaction txn = db.startTransaction(false);
try { try {
contacts = db.getContacts(txn); for (Contact c : db.getContacts(txn))
for (Entry<TransportId, Integer> e : latencies.entrySet()) if (c.isActive()) activeContacts.put(c.getId(), true);
for (Entry<TransportId, Integer> e : transports.entrySet())
db.addTransport(txn, e.getKey(), e.getValue()); db.addTransport(txn, e.getKey(), e.getValue());
for (Entry<TransportId, Integer> e : transports.entrySet()) {
TransportKeyManager m = new TransportKeyManager(db, crypto,
timer, clock, e.getKey(), e.getValue());
managers.put(e.getKey(), m);
m.start(txn);
}
txn.setComplete(); txn.setComplete();
} finally { } finally {
db.endTransaction(txn); db.endTransaction(txn);
} }
for (Contact c : contacts)
if (c.isActive()) activeContacts.put(c.getId(), true);
for (Entry<TransportId, Integer> e : latencies.entrySet()) {
TransportKeyManager m = new TransportKeyManager(db, crypto,
timer, clock, e.getKey(), e.getValue());
managers.put(e.getKey(), m);
m.start();
}
} catch (DbException e) { } catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return false; return false;
@@ -108,32 +106,43 @@ class KeyManagerImpl implements KeyManager, Service, EventListener {
m.addContact(txn, c, master, timestamp, alice); m.addContact(txn, c, master, timestamp, alice);
} }
public StreamContext getStreamContext(ContactId c, TransportId t) { public StreamContext getStreamContext(ContactId c, TransportId t)
throws DbException {
// Don't allow outgoing streams to inactive contacts // Don't allow outgoing streams to inactive contacts
if (!activeContacts.containsKey(c)) return null; if (!activeContacts.containsKey(c)) return null;
TransportKeyManager m = managers.get(t); TransportKeyManager m = managers.get(t);
return m == null ? null : m.getStreamContext(c); if (m == null) {
if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
return null;
}
StreamContext ctx = null;
Transaction txn = db.startTransaction(false);
try {
ctx = m.getStreamContext(txn, c);
txn.setComplete();
} finally {
db.endTransaction(txn);
}
return ctx;
} }
public StreamContext getStreamContext(TransportId t, byte[] tag) { public StreamContext getStreamContext(TransportId t, byte[] tag)
throws DbException {
TransportKeyManager m = managers.get(t); TransportKeyManager m = managers.get(t);
if (m == null) return null; if (m == null) {
StreamContext ctx = m.getStreamContext(tag); if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t);
if (ctx == null) return null; return null;
// Activate the contact if not already active }
if (!activeContacts.containsKey(ctx.getContactId())) { StreamContext ctx = null;
try { Transaction txn = db.startTransaction(false);
Transaction txn = db.startTransaction(false); try {
try { ctx = m.getStreamContext(txn, tag);
db.setContactActive(txn, ctx.getContactId(), true); // Activate the contact if not already active
txn.setComplete(); if (ctx != null && !activeContacts.containsKey(ctx.getContactId()))
} finally { db.setContactActive(txn, ctx.getContactId(), true);
db.endTransaction(txn); txn.setComplete();
} } finally {
} catch (DbException e) { db.endTransaction(txn);
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return null;
}
} }
return ctx; return ctx;
} }

View File

@@ -60,46 +60,20 @@ class TransportKeyManager {
keys = new HashMap<ContactId, MutableTransportKeys>(); keys = new HashMap<ContactId, MutableTransportKeys>();
} }
void start() { void start(Transaction txn) throws DbException {
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
lock.lock(); lock.lock();
try { try {
// Load the transport keys from the DB // Load the transport keys from the DB
Map<ContactId, TransportKeys> loaded; Map<ContactId, TransportKeys> loaded =
try { db.getTransportKeys(txn, transportId);
Transaction txn = db.startTransaction(true);
try {
loaded = db.getTransportKeys(txn, transportId);
txn.setComplete();
} finally {
db.endTransaction(txn);
}
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return;
}
// Rotate the keys to the current rotation period // Rotate the keys to the current rotation period
Map<ContactId, TransportKeys> rotated = RotationResult rotationResult = rotateKeys(loaded, now);
new HashMap<ContactId, TransportKeys>();
Map<ContactId, TransportKeys> current =
new HashMap<ContactId, TransportKeys>();
long rotationPeriod = now / rotationPeriodLength;
for (Entry<ContactId, TransportKeys> e : loaded.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue();
TransportKeys k1 = crypto.rotateTransportKeys(k,
rotationPeriod);
if (k1.getRotationPeriod() > k.getRotationPeriod())
rotated.put(c, k1);
current.put(c, k1);
}
// Initialise mutable state for all contacts // Initialise mutable state for all contacts
for (Entry<ContactId, TransportKeys> e : current.entrySet()) addKeys(rotationResult.current);
addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
// Write any rotated keys back to the DB // Write any rotated keys back to the DB
updateTransportKeys(rotated); if (!rotationResult.rotated.isEmpty())
} catch (DbException e) { db.updateTransportKeys(txn, rotationResult.rotated);
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} finally { } finally {
lock.unlock(); lock.unlock();
} }
@@ -107,6 +81,27 @@ class TransportKeyManager {
scheduleKeyRotation(now); scheduleKeyRotation(now);
} }
private RotationResult rotateKeys(Map<ContactId, TransportKeys> keys,
long now) {
RotationResult rotationResult = new RotationResult();
long rotationPeriod = now / rotationPeriodLength;
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue();
TransportKeys k1 = crypto.rotateTransportKeys(k, rotationPeriod);
if (k1.getRotationPeriod() > k.getRotationPeriod())
rotationResult.rotated.put(c, k1);
rotationResult.current.put(c, k1);
}
return rotationResult;
}
// Locking: lock
private void addKeys(Map<ContactId, TransportKeys> m) {
for (Entry<ContactId, TransportKeys> e : m.entrySet())
addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
}
// Locking: lock // Locking: lock
private void addKeys(ContactId c, MutableTransportKeys m) { private void addKeys(ContactId c, MutableTransportKeys m) {
encodeTags(c, m.getPreviousIncomingKeys()); encodeTags(c, m.getPreviousIncomingKeys());
@@ -126,23 +121,21 @@ class TransportKeyManager {
} }
} }
private void updateTransportKeys(Map<ContactId, TransportKeys> rotated)
throws DbException {
if (!rotated.isEmpty()) {
Transaction txn = db.startTransaction(false);
try {
db.updateTransportKeys(txn, rotated);
txn.setComplete();
} finally {
db.endTransaction(txn);
}
}
}
private void scheduleKeyRotation(long now) { private void scheduleKeyRotation(long now) {
TimerTask task = new TimerTask() { TimerTask task = new TimerTask() {
public void run() { public void run() {
rotateKeys(); try {
Transaction txn = db.startTransaction(false);
try {
rotateKeys(txn);
txn.setComplete();
} finally {
db.endTransaction(txn);
}
} catch (DbException e) {
if (LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
}
} }
}; };
long delay = rotationPeriodLength - now % rotationPeriodLength; long delay = rotationPeriodLength - now % rotationPeriodLength;
@@ -185,7 +178,8 @@ class TransportKeyManager {
} }
} }
StreamContext getStreamContext(ContactId c) { StreamContext getStreamContext(Transaction txn, ContactId c)
throws DbException {
lock.lock(); lock.lock();
try { try {
// Look up the outgoing keys for the contact // Look up the outgoing keys for the contact
@@ -198,24 +192,16 @@ class TransportKeyManager {
outKeys.getStreamCounter()); outKeys.getStreamCounter());
// Increment the stream counter and write it back to the DB // Increment the stream counter and write it back to the DB
outKeys.incrementStreamCounter(); outKeys.incrementStreamCounter();
Transaction txn = db.startTransaction(false); db.incrementStreamCounter(txn, c, transportId,
try { outKeys.getRotationPeriod());
db.incrementStreamCounter(txn, c, transportId,
outKeys.getRotationPeriod());
txn.setComplete();
} finally {
db.endTransaction(txn);
}
return ctx; return ctx;
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return null;
} finally { } finally {
lock.unlock(); lock.unlock();
} }
} }
StreamContext getStreamContext(byte[] tag) { StreamContext getStreamContext(Transaction txn, byte[] tag)
throws DbException {
lock.lock(); lock.lock();
try { try {
// Look up the incoming keys for the tag // Look up the incoming keys for the tag
@@ -244,53 +230,33 @@ class TransportKeyManager {
inContexts.remove(new Bytes(removeTag)); inContexts.remove(new Bytes(removeTag));
} }
// Write the window back to the DB // Write the window back to the DB
Transaction txn = db.startTransaction(false); db.setReorderingWindow(txn, tagCtx.contactId, transportId,
try { inKeys.getRotationPeriod(), window.getBase(),
db.setReorderingWindow(txn, tagCtx.contactId, transportId, window.getBitmap());
inKeys.getRotationPeriod(), window.getBase(),
window.getBitmap());
txn.setComplete();
} finally {
db.endTransaction(txn);
}
return ctx; return ctx;
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return null;
} finally { } finally {
lock.unlock(); lock.unlock();
} }
} }
private void rotateKeys() { private void rotateKeys(Transaction txn) throws DbException {
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
lock.lock(); lock.lock();
try { try {
// Rotate the keys to the current rotation period // Rotate the keys to the current rotation period
Map<ContactId, TransportKeys> rotated = Map<ContactId, TransportKeys> snapshot =
new HashMap<ContactId, TransportKeys>(); new HashMap<ContactId, TransportKeys>();
Map<ContactId, TransportKeys> current = for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet())
new HashMap<ContactId, TransportKeys>(); snapshot.put(e.getKey(), e.getValue().snapshot());
long rotationPeriod = now / rotationPeriodLength; RotationResult rotationResult = rotateKeys(snapshot, now);
for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue().snapshot();
TransportKeys k1 = crypto.rotateTransportKeys(k,
rotationPeriod);
if (k1.getRotationPeriod() > k.getRotationPeriod())
rotated.put(c, k1);
current.put(c, k1);
}
// Rebuild the mutable state for all contacts // Rebuild the mutable state for all contacts
inContexts.clear(); inContexts.clear();
outContexts.clear(); outContexts.clear();
keys.clear(); keys.clear();
for (Entry<ContactId, TransportKeys> e : current.entrySet()) addKeys(rotationResult.current);
addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
// Write any rotated keys back to the DB // Write any rotated keys back to the DB
updateTransportKeys(rotated); if (!rotationResult.rotated.isEmpty())
} catch (DbException e) { db.updateTransportKeys(txn, rotationResult.rotated);
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} finally { } finally {
lock.unlock(); lock.unlock();
} }
@@ -311,4 +277,14 @@ class TransportKeyManager {
this.streamNumber = streamNumber; this.streamNumber = streamNumber;
} }
} }
private static class RotationResult {
private final Map<ContactId, TransportKeys> current, rotated;
private RotationResult() {
current = new HashMap<ContactId, TransportKeys>();
rotated = new HashMap<ContactId, TransportKeys>();
}
}
} }

View File

@@ -37,6 +37,7 @@ import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
public class TransportKeyManagerTest extends BriarTestCase { public class TransportKeyManagerTest extends BriarTestCase {
@@ -57,7 +58,7 @@ public class TransportKeyManagerTest extends BriarTestCase {
final CryptoComponent crypto = context.mock(CryptoComponent.class); final CryptoComponent crypto = context.mock(CryptoComponent.class);
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final Transaction txn = new Transaction(null, true);
final Map<ContactId, TransportKeys> loaded = final Map<ContactId, TransportKeys> loaded =
new LinkedHashMap<ContactId, TransportKeys>(); new LinkedHashMap<ContactId, TransportKeys>();
final TransportKeys shouldRotate = createTransportKeys(900, 0); final TransportKeys shouldRotate = createTransportKeys(900, 0);
@@ -65,17 +66,15 @@ public class TransportKeyManagerTest extends BriarTestCase {
loaded.put(contactId, shouldRotate); loaded.put(contactId, shouldRotate);
loaded.put(contactId1, shouldNotRotate); loaded.put(contactId1, shouldNotRotate);
final TransportKeys rotated = createTransportKeys(1000, 0); final TransportKeys rotated = createTransportKeys(1000, 0);
final Transaction txn1 = new Transaction(null, false); final Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// Get the current time (1 ms after start of rotation period 1000) // Get the current time (1 ms after start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000 + 1)); will(returnValue(rotationPeriodLength * 1000 + 1));
// Load the transport keys // Load the transport keys
oneOf(db).startTransaction(true);
will(returnValue(txn));
oneOf(db).getTransportKeys(txn, transportId); oneOf(db).getTransportKeys(txn, transportId);
will(returnValue(loaded)); will(returnValue(loaded));
oneOf(db).endTransaction(txn);
// Rotate the transport keys // Rotate the transport keys
oneOf(crypto).rotateTransportKeys(shouldRotate, 1000); oneOf(crypto).rotateTransportKeys(shouldRotate, 1000);
will(returnValue(rotated)); will(returnValue(rotated));
@@ -88,11 +87,8 @@ public class TransportKeyManagerTest extends BriarTestCase {
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Save the keys that were rotated // Save the keys that were rotated
oneOf(db).startTransaction(false); oneOf(db).updateTransportKeys(txn,
will(returnValue(txn1));
oneOf(db).updateTransportKeys(txn1,
Collections.singletonMap(contactId, rotated)); Collections.singletonMap(contactId, rotated));
oneOf(db).endTransaction(txn1);
// Schedule key rotation at the start of the next rotation period // Schedule key rotation at the start of the next rotation period
oneOf(timer).schedule(with(any(TimerTask.class)), oneOf(timer).schedule(with(any(TimerTask.class)),
with(rotationPeriodLength - 1)); with(rotationPeriodLength - 1));
@@ -100,7 +96,7 @@ public class TransportKeyManagerTest extends BriarTestCase {
TransportKeyManager transportKeyManager = new TransportKeyManager(db, TransportKeyManager transportKeyManager = new TransportKeyManager(db,
crypto, timer, clock, transportId, maxLatency); crypto, timer, clock, transportId, maxLatency);
transportKeyManager.start(); transportKeyManager.start(txn);
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -112,10 +108,12 @@ public class TransportKeyManagerTest extends BriarTestCase {
final CryptoComponent crypto = context.mock(CryptoComponent.class); final CryptoComponent crypto = context.mock(CryptoComponent.class);
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final boolean alice = true; final boolean alice = true;
final TransportKeys transportKeys = createTransportKeys(999, 0); final TransportKeys transportKeys = createTransportKeys(999, 0);
final TransportKeys rotated = createTransportKeys(1000, 0); final TransportKeys rotated = createTransportKeys(1000, 0);
final Transaction txn = new Transaction(null, false); final Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 999, oneOf(crypto).deriveTransportKeys(transportId, masterKey, 999,
alice); alice);
@@ -155,9 +153,11 @@ public class TransportKeyManagerTest extends BriarTestCase {
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final Transaction txn = new Transaction(null, false);
TransportKeyManager transportKeyManager = new TransportKeyManager(db, TransportKeyManager transportKeyManager = new TransportKeyManager(db,
crypto, timer, clock, transportId, maxLatency); crypto, timer, clock, transportId, maxLatency);
assertNull(transportKeyManager.getStreamContext(contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -170,11 +170,13 @@ public class TransportKeyManagerTest extends BriarTestCase {
final CryptoComponent crypto = context.mock(CryptoComponent.class); final CryptoComponent crypto = context.mock(CryptoComponent.class);
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final boolean alice = true; final boolean alice = true;
// The stream counter has been exhausted // The stream counter has been exhausted
final TransportKeys transportKeys = createTransportKeys(1000, final TransportKeys transportKeys = createTransportKeys(1000,
MAX_32_BIT_UNSIGNED + 1); MAX_32_BIT_UNSIGNED + 1);
final Transaction txn = new Transaction(null, false); final Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
alice); alice);
@@ -201,7 +203,7 @@ public class TransportKeyManagerTest extends BriarTestCase {
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertNull(transportKeyManager.getStreamContext(contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -213,12 +215,13 @@ public class TransportKeyManagerTest extends BriarTestCase {
final CryptoComponent crypto = context.mock(CryptoComponent.class); final CryptoComponent crypto = context.mock(CryptoComponent.class);
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final boolean alice = true; final boolean alice = true;
// The stream counter can be used one more time before being exhausted // The stream counter can be used one more time before being exhausted
final TransportKeys transportKeys = createTransportKeys(1000, final TransportKeys transportKeys = createTransportKeys(1000,
MAX_32_BIT_UNSIGNED); MAX_32_BIT_UNSIGNED);
final Transaction txn = new Transaction(null, false); final Transaction txn = new Transaction(null, false);
final Transaction txn1 = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
alice); alice);
@@ -238,11 +241,7 @@ public class TransportKeyManagerTest extends BriarTestCase {
// Save the keys // Save the keys
oneOf(db).addTransportKeys(txn, contactId, transportKeys); oneOf(db).addTransportKeys(txn, contactId, transportKeys);
// Increment the stream counter // Increment the stream counter
oneOf(db).startTransaction(false); oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000);
will(returnValue(txn1));
oneOf(db).incrementStreamCounter(txn1, contactId, transportId,
1000);
oneOf(db).endTransaction(txn1);
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManager(db, TransportKeyManager transportKeyManager = new TransportKeyManager(db,
@@ -252,7 +251,8 @@ public class TransportKeyManagerTest extends BriarTestCase {
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
// The first request should return a stream context // The first request should return a stream context
StreamContext ctx = transportKeyManager.getStreamContext(contactId); StreamContext ctx = transportKeyManager.getStreamContext(txn,
contactId);
assertNotNull(ctx); assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId()); assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId()); assertEquals(transportId, ctx.getTransportId());
@@ -260,7 +260,7 @@ public class TransportKeyManagerTest extends BriarTestCase {
assertEquals(headerKey, ctx.getHeaderKey()); assertEquals(headerKey, ctx.getHeaderKey());
assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber()); assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber());
// The second request should return null, the counter is exhausted // The second request should return null, the counter is exhausted
assertNull(transportKeyManager.getStreamContext(contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -273,9 +273,11 @@ public class TransportKeyManagerTest extends BriarTestCase {
final CryptoComponent crypto = context.mock(CryptoComponent.class); final CryptoComponent crypto = context.mock(CryptoComponent.class);
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final boolean alice = true; final boolean alice = true;
final TransportKeys transportKeys = createTransportKeys(1000, 0); final TransportKeys transportKeys = createTransportKeys(1000, 0);
final Transaction txn = new Transaction(null, false); final Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
alice); alice);
@@ -302,7 +304,8 @@ public class TransportKeyManagerTest extends BriarTestCase {
long timestamp = rotationPeriodLength * 1000; long timestamp = rotationPeriodLength * 1000;
transportKeyManager.addContact(txn, contactId, masterKey, timestamp, transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
alice); alice);
assertNull(transportKeyManager.getStreamContext(new byte[TAG_LENGTH])); assertNull(transportKeyManager.getStreamContext(txn,
new byte[TAG_LENGTH]));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -314,12 +317,13 @@ public class TransportKeyManagerTest extends BriarTestCase {
final CryptoComponent crypto = context.mock(CryptoComponent.class); final CryptoComponent crypto = context.mock(CryptoComponent.class);
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final boolean alice = true; final boolean alice = true;
final TransportKeys transportKeys = createTransportKeys(1000, 0); final TransportKeys transportKeys = createTransportKeys(1000, 0);
final Transaction txn = new Transaction(null, false);
final Transaction txn1 = new Transaction(null, false);
// Keep a copy of the tags // Keep a copy of the tags
final List<byte[]> tags = new ArrayList<byte[]>(); final List<byte[]> tags = new ArrayList<byte[]>();
final Transaction txn = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000, oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
alice); alice);
@@ -343,11 +347,8 @@ public class TransportKeyManagerTest extends BriarTestCase {
with(tagKey), with((long) REORDERING_WINDOW_SIZE)); with(tagKey), with((long) REORDERING_WINDOW_SIZE));
will(new EncodeTagAction(tags)); will(new EncodeTagAction(tags));
// Save the reordering window (previous rotation period, base 1) // Save the reordering window (previous rotation period, base 1)
oneOf(db).startTransaction(false); oneOf(db).setReorderingWindow(txn, contactId, transportId, 999,
will(returnValue(txn1));
oneOf(db).setReorderingWindow(txn1, contactId, transportId, 999,
1, new byte[REORDERING_WINDOW_SIZE / 8]); 1, new byte[REORDERING_WINDOW_SIZE / 8]);
oneOf(db).endTransaction(txn1);
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManager(db, TransportKeyManager transportKeyManager = new TransportKeyManager(db,
@@ -360,7 +361,7 @@ public class TransportKeyManagerTest extends BriarTestCase {
assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size()); assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size());
byte[] tag = tags.get(0); byte[] tag = tags.get(0);
// The first request should return a stream context // The first request should return a stream context
StreamContext ctx = transportKeyManager.getStreamContext(tag); StreamContext ctx = transportKeyManager.getStreamContext(txn, tag);
assertNotNull(ctx); assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId()); assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId()); assertEquals(transportId, ctx.getTransportId());
@@ -370,7 +371,7 @@ public class TransportKeyManagerTest extends BriarTestCase {
// Another tag should have been encoded // Another tag should have been encoded
assertEquals(REORDERING_WINDOW_SIZE * 3 + 1, tags.size()); assertEquals(REORDERING_WINDOW_SIZE * 3 + 1, tags.size());
// The second request should return null, the tag has already been used // The second request should return null, the tag has already been used
assertNull(transportKeyManager.getStreamContext(tag)); assertNull(transportKeyManager.getStreamContext(txn, tag));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -382,22 +383,21 @@ public class TransportKeyManagerTest extends BriarTestCase {
final CryptoComponent crypto = context.mock(CryptoComponent.class); final CryptoComponent crypto = context.mock(CryptoComponent.class);
final Timer timer = context.mock(Timer.class); final Timer timer = context.mock(Timer.class);
final Clock clock = context.mock(Clock.class); final Clock clock = context.mock(Clock.class);
final Transaction txn = new Transaction(null, true);
final TransportKeys transportKeys = createTransportKeys(1000, 0); final TransportKeys transportKeys = createTransportKeys(1000, 0);
final Map<ContactId, TransportKeys> loaded = final Map<ContactId, TransportKeys> loaded =
Collections.singletonMap(contactId, transportKeys); Collections.singletonMap(contactId, transportKeys);
final TransportKeys rotated = createTransportKeys(1001, 0); final TransportKeys rotated = createTransportKeys(1001, 0);
final Transaction txn = new Transaction(null, false);
final Transaction txn1 = new Transaction(null, false); final Transaction txn1 = new Transaction(null, false);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// Get the current time (the start of rotation period 1000) // Get the current time (the start of rotation period 1000)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1000)); will(returnValue(rotationPeriodLength * 1000));
// Load the transport keys // Load the transport keys
oneOf(db).startTransaction(true);
will(returnValue(txn));
oneOf(db).getTransportKeys(txn, transportId); oneOf(db).getTransportKeys(txn, transportId);
will(returnValue(loaded)); will(returnValue(loaded));
oneOf(db).endTransaction(txn);
// Rotate the transport keys (the keys are unaffected) // Rotate the transport keys (the keys are unaffected)
oneOf(crypto).rotateTransportKeys(transportKeys, 1000); oneOf(crypto).rotateTransportKeys(transportKeys, 1000);
will(returnValue(transportKeys)); will(returnValue(transportKeys));
@@ -411,6 +411,9 @@ public class TransportKeyManagerTest extends BriarTestCase {
oneOf(timer).schedule(with(any(TimerTask.class)), oneOf(timer).schedule(with(any(TimerTask.class)),
with(rotationPeriodLength)); with(rotationPeriodLength));
will(new RunTimerTaskAction()); will(new RunTimerTaskAction());
// Start a transaction for key rotation
oneOf(db).startTransaction(false);
will(returnValue(txn1));
// Get the current time (the start of rotation period 1001) // Get the current time (the start of rotation period 1001)
oneOf(clock).currentTimeMillis(); oneOf(clock).currentTimeMillis();
will(returnValue(rotationPeriodLength * 1001)); will(returnValue(rotationPeriodLength * 1001));
@@ -425,19 +428,19 @@ public class TransportKeyManagerTest extends BriarTestCase {
will(new EncodeTagAction()); will(new EncodeTagAction());
} }
// Save the keys that were rotated // Save the keys that were rotated
oneOf(db).startTransaction(false);
will(returnValue(txn1));
oneOf(db).updateTransportKeys(txn1, oneOf(db).updateTransportKeys(txn1,
Collections.singletonMap(contactId, rotated)); Collections.singletonMap(contactId, rotated));
oneOf(db).endTransaction(txn1);
// Schedule key rotation at the start of the next rotation period // Schedule key rotation at the start of the next rotation period
oneOf(timer).schedule(with(any(TimerTask.class)), oneOf(timer).schedule(with(any(TimerTask.class)),
with(rotationPeriodLength)); with(rotationPeriodLength));
// Commit the key rotation transaction
oneOf(db).endTransaction(txn1);
}}); }});
TransportKeyManager transportKeyManager = new TransportKeyManager(db, TransportKeyManager transportKeyManager = new TransportKeyManager(db,
crypto, timer, clock, transportId, maxLatency); crypto, timer, clock, transportId, maxLatency);
transportKeyManager.start(); transportKeyManager.start(txn);
assertTrue(txn1.isComplete());
context.assertIsSatisfied(); context.assertIsSatisfied();
} }