Implemented KeyManager (untested).

A test is failing due to key derivation errors - must be fixed!
This commit is contained in:
akwizgran
2012-10-24 18:16:17 +01:00
parent cc6e9d53ad
commit 27e50b8495
25 changed files with 540 additions and 306 deletions

View File

@@ -6,10 +6,20 @@ import net.sf.briar.api.transport.ConnectionContext;
public interface KeyManager { public interface KeyManager {
/**
* Starts the key manager and returns true if the manager started
* successfully. This method must be called after the database has been
* opened.
*/
boolean start();
/** Stops the key manager. */
void stop();
/** /**
* Returns a connection context for connecting to the given contact over * Returns a connection context for connecting to the given contact over
* the given transport, or null if the contact does not support the * the given transport, or null if an error occurs or the contact does not
* transport. * support the transport.
*/ */
ConnectionContext getConnectionContext(ContactId c, TransportId t); ConnectionContext getConnectionContext(ContactId c, TransportId t);
} }

View File

@@ -114,9 +114,6 @@ public interface DatabaseComponent {
/** Returns the IDs of all contacts. */ /** Returns the IDs of all contacts. */
Collection<ContactId> getContacts() throws DbException; Collection<ContactId> getContacts() throws DbException;
/** Returns all contact transports. */
Collection<ContactTransport> getContactTransports() throws DbException;
/** Returns the local transport properties for the given transport. */ /** Returns the local transport properties for the given transport. */
TransportProperties getLocalProperties(TransportId t) throws DbException; TransportProperties getLocalProperties(TransportId t) throws DbException;
@@ -150,9 +147,9 @@ public interface DatabaseComponent {
/** /**
* Increments the outgoing connection counter for the given contact * Increments the outgoing connection counter for the given contact
* transport in the given rotation period. * transport in the given rotation period and returns the old value.
*/ */
void incrementConnectionCounter(ContactId c, TransportId t, long period) long incrementConnectionCounter(ContactId c, TransportId t, long period)
throws DbException; throws DbException;
/** Processes an acknowledgement from the given contact. */ /** Processes an acknowledgement from the given contact. */

View File

@@ -1,20 +1,19 @@
package net.sf.briar.api.db; package net.sf.briar.api.db;
import static net.sf.briar.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
public class TemporarySecret { public class TemporarySecret extends ContactTransport {
private final ContactId contactId;
private final TransportId transportId;
private final long period, outgoing, centre; private final long period, outgoing, centre;
private final byte[] secret, bitmap; private final byte[] secret, bitmap;
public TemporarySecret(ContactId contactId, TransportId transportId, public TemporarySecret(ContactId contactId, TransportId transportId,
long epoch, long clockDiff, long latency, boolean alice,
long period, byte[] secret, long outgoing, long centre, long period, byte[] secret, long outgoing, long centre,
byte[] bitmap) { byte[] bitmap) {
this.contactId = contactId; super(contactId, transportId, epoch, clockDiff, latency, alice);
this.transportId = transportId;
this.period = period; this.period = period;
this.secret = secret; this.secret = secret;
this.outgoing = outgoing; this.outgoing = outgoing;
@@ -22,12 +21,14 @@ public class TemporarySecret {
this.bitmap = bitmap; this.bitmap = bitmap;
} }
public ContactId getContactId() { public TemporarySecret(TemporarySecret old, long period, byte[] secret) {
return contactId; super(old.getContactId(), old.getTransportId(), old.getEpoch(),
} old.getClockDifference(), old.getLatency(), old.getAlice());
this.period = period;
public TransportId getTransportId() { this.secret = secret;
return transportId; outgoing = 0L;
centre = 0L;
bitmap = new byte[CONNECTION_WINDOW_SIZE / 8];
} }
public long getPeriod() { public long getPeriod() {

View File

@@ -2,6 +2,7 @@ package net.sf.briar.api.transport;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
/** /**
@@ -17,10 +18,11 @@ public interface ConnectionRecogniser {
ConnectionContext acceptConnection(TransportId t, byte[] tag) ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException; throws DbException;
void addWindow(ContactId c, TransportId t, long period, boolean alice, void addSecret(TemporarySecret s) throws DbException;
byte[] secret, long centre, byte[] bitmap) throws DbException;
void removeWindow(ContactId c, TransportId t, long period); void removeSecret(ContactId c, TransportId t, long period);
void removeWindows(ContactId c); void removeSecrets(ContactId c);
void removeSecrets();
} }

View File

@@ -1,24 +0,0 @@
package net.sf.briar.crypto;
import net.sf.briar.api.db.DbException;
interface KeyRotator {
/**
* Starts a new thread to rotate keys periodically. The rotator will pause
* for the given number of milliseconds between rotations.
*/
void startRotating(Callback callback, long msBetweenRotations);
/** Tells the rotator thread to exit. */
void stopRotating();
interface Callback {
/**
* Rotates keys, replacing and destroying any keys that have passed the
* ends of their respective retention periods.
*/
void rotateKeys() throws DbException;
}
}

View File

@@ -1,41 +0,0 @@
package net.sf.briar.crypto;
import java.util.Timer;
import java.util.TimerTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.briar.api.db.DbException;
class KeyRotatorImpl extends TimerTask implements KeyRotator {
private static final Logger LOG =
Logger.getLogger(KeyRotatorImpl.class.getName());
private volatile Callback callback = null;
private volatile Timer timer = null;
public void startRotating(Callback callback, long msBetweenRotations) {
this.callback = callback;
timer = new Timer();
timer.scheduleAtFixedRate(this, 0L, msBetweenRotations);
}
public void stopRotating() {
if(timer == null) throw new IllegalStateException();
timer.cancel();
}
public void run() {
if(callback == null) throw new IllegalStateException();
try {
callback.rotateKeys();
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
throw new Error(e); // Kill the application
} catch(RuntimeException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
throw new Error(e); // Kill the application
}
}
}

View File

@@ -468,11 +468,11 @@ interface Database<T> {
/** /**
* Increments the outgoing connection counter for the given contact * Increments the outgoing connection counter for the given contact
* transport in the given rotation period. * transport in the given rotation period and returns the old value;
* <p> * <p>
* Locking: contact read, window write. * Locking: contact read, window write.
*/ */
void incrementConnectionCounter(T txn, ContactId c, TransportId t, long incrementConnectionCounter(T txn, ContactId c, TransportId t,
long period) throws DbException; long period) throws DbException;
/** /**

View File

@@ -758,30 +758,6 @@ DatabaseCleaner.Callback {
} }
} }
public Collection<ContactTransport> getContactTransports()
throws DbException {
contactLock.readLock().lock();
try {
windowLock.readLock().lock();
try {
T txn = db.startTransaction();
try {
Collection<ContactTransport> contactTransports =
db.getContactTransports(txn);
db.commitTransaction(txn);
return contactTransports;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
windowLock.readLock().unlock();
}
} finally {
contactLock.readLock().unlock();
}
}
public TransportProperties getLocalProperties(TransportId t) public TransportProperties getLocalProperties(TransportId t)
throws DbException { throws DbException {
transportLock.readLock().lock(); transportLock.readLock().lock();
@@ -1005,7 +981,7 @@ DatabaseCleaner.Callback {
} }
} }
public void incrementConnectionCounter(ContactId c, TransportId t, public long incrementConnectionCounter(ContactId c, TransportId t,
long period) throws DbException { long period) throws DbException {
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
@@ -1015,8 +991,9 @@ DatabaseCleaner.Callback {
try { try {
if(!db.containsContactTransport(txn, c, t)) if(!db.containsContactTransport(txn, c, t))
throw new NoSuchContactTransportException(); throw new NoSuchContactTransportException();
db.incrementConnectionCounter(txn, c, t, period); long l = db.incrementConnectionCounter(txn, c, t, period);
db.commitTransaction(txn); db.commitTransaction(txn);
return l;
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
throw e; throw e;

View File

@@ -1557,22 +1557,30 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
String sql = "SELECT contactId, transportId, period, secret," String sql = "SELECT ct.contactId, ct.transportId, epoch,"
+ " outgoing, centre, bitmap" + " clockDiff, latency, alice, period, secret, outgoing,"
+ " FROM secrets"; + " centre, bitmap"
+ " FROM contactTransports AS ct"
+ " JOIN secrets AS s"
+ " ON ct.contactId = s.contactId"
+ " AND ct.transportId = s.transportId";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
rs = ps.executeQuery(); rs = ps.executeQuery();
List<TemporarySecret> secrets = new ArrayList<TemporarySecret>(); List<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
while(rs.next()) { while(rs.next()) {
ContactId c = new ContactId(rs.getInt(1)); ContactId c = new ContactId(rs.getInt(1));
TransportId t = new TransportId(rs.getBytes(2)); TransportId t = new TransportId(rs.getBytes(2));
long period = rs.getLong(3); long epoch = rs.getLong(3);
byte[] secret = rs.getBytes(4); long clockDiff = rs.getLong(4);
long outgoing = rs.getLong(5); long latency = rs.getLong(5);
long centre = rs.getLong(6); boolean alice = rs.getBoolean(6);
byte[] bitmap = rs.getBytes(7); long period = rs.getLong(7);
secrets.add(new TemporarySecret(c, t, period, secret, outgoing, byte[] secret = rs.getBytes(8);
centre, bitmap)); long outgoing = rs.getLong(9);
long centre = rs.getLong(10);
byte[] bitmap = rs.getBytes(11);
secrets.add(new TemporarySecret(c, t, epoch, clockDiff, latency,
alice, period, secret, outgoing, centre, bitmap));
} }
rs.close(); rs.close();
ps.close(); ps.close();
@@ -2021,11 +2029,26 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public void incrementConnectionCounter(Connection txn, ContactId c, public long incrementConnectionCounter(Connection txn, ContactId c,
TransportId t, long period) throws DbException { TransportId t, long period) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null;
try { try {
String sql = "UPDATE secrets SET outgoing = outgoing + 1" // Get the current connection counter
String sql = "SELECT outgoing FROM secrets"
+ " WHERE contactId = ? AND transportId = ? AND period + ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setBytes(2, t.getBytes());
ps.setLong(3, period);
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
long connection = rs.getLong(1);
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Increment the connection counter
sql = "UPDATE secrets SET outgoing = outgoing + 1"
+ " WHERE contactId = ? AND transportId = ? AND period = ?"; + " WHERE contactId = ? AND transportId = ? AND period = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
@@ -2034,8 +2057,10 @@ abstract class JdbcDatabase implements Database<Connection> {
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if(affected > 1) throw new DbStateException(); if(affected > 1) throw new DbStateException();
ps.close(); ps.close();
return connection;
} catch(SQLException e) { } catch(SQLException e) {
tryToClose(ps); tryToClose(ps);
tryToClose(rs);
throw new DbException(e); throw new DbException(e);
} }
} }

View File

@@ -7,6 +7,7 @@ import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
@@ -37,9 +38,8 @@ class ConnectionRecogniserImpl implements ConnectionRecogniser {
return r.acceptConnection(tag); return r.acceptConnection(tag);
} }
public void addWindow(ContactId c, TransportId t, long period, public void addSecret(TemporarySecret s) throws DbException {
boolean alice, byte[] secret, long centre, byte[] bitmap) TransportId t = s.getTransportId();
throws DbException {
TransportConnectionRecogniser r; TransportConnectionRecogniser r;
synchronized(this) { synchronized(this) {
r = recognisers.get(t); r = recognisers.get(t);
@@ -48,20 +48,24 @@ class ConnectionRecogniserImpl implements ConnectionRecogniser {
recognisers.put(t, r); recognisers.put(t, r);
} }
} }
r.addWindow(c, period, alice, secret, centre, bitmap); r.addSecret(s);
} }
public void removeWindow(ContactId c, TransportId t, long period) { public void removeSecret(ContactId c, TransportId t, long period) {
TransportConnectionRecogniser r; TransportConnectionRecogniser r;
synchronized(this) { synchronized(this) {
r = recognisers.get(t); r = recognisers.get(t);
} }
if(r != null) r.removeWindow(c, period); if(r != null) r.removeSecret(c, period);
} }
public synchronized void removeWindows(ContactId c) { public synchronized void removeSecrets(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values()) { for(TransportConnectionRecogniser r : recognisers.values())
r.removeWindows(c); r.removeSecrets(c);
} }
public synchronized void removeSecrets() {
for(TransportConnectionRecogniser r : recognisers.values())
r.removeSecrets();
} }
} }

View File

@@ -13,6 +13,7 @@ import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import com.google.inject.Inject; import com.google.inject.Inject;
class ConnectionWriterFactoryImpl implements ConnectionWriterFactory { class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {

View File

@@ -6,6 +6,7 @@ import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
/** /**

View File

@@ -0,0 +1,305 @@
package net.sf.briar.transport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.db.event.ContactRemovedEvent;
import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject;
class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener {
private static final int MS_BETWEEN_CHECKS = 60 * 1000;
private static final Logger LOG =
Logger.getLogger(KeyManagerImpl.class.getName());
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final ConnectionRecogniser recogniser;
private final Timer timer;
// Locking: this
private final Map<ContactTransportKey, TemporarySecret> outgoing;
// Locking: this
private final Map<ContactTransportKey, TemporarySecret> incomingOld;
// Locking: this
private final Map<ContactTransportKey, TemporarySecret> incomingNew;
// Locking: this
private boolean running = false;
@Inject
public KeyManagerImpl(CryptoComponent crypto, DatabaseComponent db,
ConnectionRecogniser recogniser) {
this.crypto = crypto;
this.db = db;
this.recogniser = recogniser;
timer = new Timer();
outgoing = new HashMap<ContactTransportKey, TemporarySecret>();
incomingOld = new HashMap<ContactTransportKey, TemporarySecret>();
incomingNew = new HashMap<ContactTransportKey, TemporarySecret>();
}
public synchronized boolean start() {
if(running) return false;
Collection<TemporarySecret> secrets;
try {
secrets = db.getSecrets();
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return false;
}
// Work out what phase of its lifecycle each secret is in
long now = System.currentTimeMillis();
Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets);
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
try {
// Store any secrets that have been created
if(!created.isEmpty()) db.addSecrets(created);
// Pass the current incoming secrets to the connection recogniser
// FIXME: This uses a separate database transaction for each secret
for(TemporarySecret s : incomingOld.values())
recogniser.addSecret(s);
for(TemporarySecret s : incomingNew.values())
recogniser.addSecret(s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return false;
}
// Schedule periodic key rotation
timer.scheduleAtFixedRate(this, MS_BETWEEN_CHECKS, MS_BETWEEN_CHECKS);
running = true;
return true;
}
// Assigns secrets to the appropriate maps and returns any dead secrets
// FIXME: Check there are no duplicate keys when updating maps
private Collection<TemporarySecret> assignSecretsToMaps(long now,
Collection<TemporarySecret> secrets) {
Collection<TemporarySecret> dead = new ArrayList<TemporarySecret>();
for(TemporarySecret s : secrets) {
ContactId c = s.getContactId();
TransportId t = s.getTransportId();
ContactTransportKey k = new ContactTransportKey(c, t);
long rotationPeriod = getRotationPeriod(s);
long creationTime = getCreationTime(s);
long activationTime = creationTime + s.getClockDifference();
long successorCreationTime = creationTime + rotationPeriod;
long deactivationTime = activationTime + rotationPeriod;
long destructionTime = successorCreationTime + rotationPeriod;
if(now >= destructionTime) {
dead.add(s);
} else if(now >= deactivationTime) {
incomingOld.put(k, s);
} else if(now >= successorCreationTime) {
incomingOld.put(k, s);
outgoing.put(k, s);
} else if(now >= activationTime) {
incomingNew.put(k, s);
outgoing.put(k, s);
} else if(now >= creationTime) {
incomingNew.put(k, s);
} else {
// FIXME: What should we do if the clock moves backwards?
throw new IllegalStateException();
}
}
return dead;
}
// Replaces and erases the given secrets and returns any secrets created
private Collection<TemporarySecret> replaceDeadSecrets(long now,
Collection<TemporarySecret> dead) {
Collection<TemporarySecret> created = new ArrayList<TemporarySecret>();
for(TemporarySecret s : dead) {
ContactId c = s.getContactId();
TransportId t = s.getTransportId();
ContactTransportKey k = new ContactTransportKey(c, t);
if(incomingNew.containsKey(k)) throw new IllegalStateException();
byte[] secret = s.getSecret();
long period = s.getPeriod();
if(incomingOld.containsKey(k)) {
// The dead secret's successor is still alive
byte[] secret1 = crypto.deriveNextSecret(secret, period + 1);
TemporarySecret s1 = new TemporarySecret(s, period + 1,
secret1);
created.add(s1);
incomingNew.put(k, s1);
long creationTime = getCreationTime(s1);
long activationTime = creationTime + s1.getClockDifference();
if(now >= activationTime) outgoing.put(k, s1);
} else {
// The dead secret has no living successor
long rotationPeriod = getRotationPeriod(s);
long elapsed = now - s.getEpoch();
long currentPeriod = elapsed / rotationPeriod;
if(currentPeriod <= period) throw new IllegalStateException();
// Derive the two current incoming secrets
byte[] secret1, secret2;
secret1 = secret;
for(long l = period; l < currentPeriod; l++) {
byte[] temp = crypto.deriveNextSecret(secret1, l);
ByteUtils.erase(secret1);
secret1 = temp;
}
secret2 = crypto.deriveNextSecret(secret1, currentPeriod);
// One of the incoming secrets is the current outgoing secret
TemporarySecret s1, s2;
s1 = new TemporarySecret(s, currentPeriod - 1, secret1);
created.add(s1);
incomingOld.put(k, s1);
s2 = new TemporarySecret(s, currentPeriod, secret2);
created.add(s2);
incomingNew.put(k, s2);
if(elapsed % rotationPeriod < s.getClockDifference()) {
// The outgoing secret is the newer incoming secret
outgoing.put(k, s2);
} else {
// The outgoing secret is the older incoming secret
outgoing.put(k, s1);
}
}
// Erase the dead secret
ByteUtils.erase(secret);
}
return created;
}
private long getRotationPeriod(TemporarySecret s) {
return 2 * s.getClockDifference() + s.getLatency();
}
private long getCreationTime(TemporarySecret s) {
long rotationPeriod = getRotationPeriod(s);
return s.getEpoch() + rotationPeriod * s.getPeriod();
}
public synchronized void stop() {
if(!running) return;
recogniser.removeSecrets();
removeAndEraseSecrets(outgoing);
removeAndEraseSecrets(incomingOld);
removeAndEraseSecrets(incomingNew);
running = false;
}
// Locking: this
private void removeAndEraseSecrets(Map<?, TemporarySecret> m) {
for(TemporarySecret s : m.values()) ByteUtils.erase(s.getSecret());
m.clear();
}
public synchronized ConnectionContext getConnectionContext(ContactId c,
TransportId t) {
TemporarySecret s = outgoing.get(new ContactTransportKey(c, t));
if(s == null) return null;
long connection;
try {
connection = db.incrementConnectionCounter(c, t, s.getPeriod());
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return null;
}
byte[] secret = s.getSecret().clone();
return new ConnectionContext(c, t, secret, connection, s.getAlice());
}
@Override
public synchronized void run() {
// Rebuild the maps because we may be running a whole period late
Collection<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
secrets.addAll(incomingOld.values());
secrets.addAll(incomingNew.values());
outgoing.clear();
incomingOld.clear();
incomingNew.clear();
// Work out what phase of its lifecycle each secret is in
long now = System.currentTimeMillis();
Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets);
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
try {
// Store any secrets that have been created
if(!created.isEmpty()) db.addSecrets(created);
// Pass the current incoming secrets to the connection recogniser
// FIXME: This uses a separate database transaction for each secret
for(TemporarySecret s : incomingOld.values())
recogniser.addSecret(s);
for(TemporarySecret s : incomingNew.values())
recogniser.addSecret(s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
}
}
public void eventOccurred(DatabaseEvent e) {
if(e instanceof ContactRemovedEvent) {
ContactId c = ((ContactRemovedEvent) e).getContactId();
recogniser.removeSecrets(c);
synchronized(this) {
removeAndEraseSecrets(c, outgoing);
removeAndEraseSecrets(c, incomingOld);
removeAndEraseSecrets(c, incomingNew);
}
}
}
// Locking: this
private void removeAndEraseSecrets(ContactId c, Map<?, TemporarySecret> m) {
Iterator<TemporarySecret> it = m.values().iterator();
while(it.hasNext()) {
TemporarySecret s = it.next();
if(s.getContactId().equals(c)) {
ByteUtils.erase(s.getSecret());
it.remove();
}
}
}
private static class ContactTransportKey {
private final ContactId contactId;
private final TransportId transportId;
private ContactTransportKey(ContactId contactId,
TransportId transportId) {
this.contactId = contactId;
this.transportId = transportId;
}
@Override
public int hashCode() {
return contactId.hashCode() + transportId.hashCode();
}
@Override
public boolean equals(Object o) {
if(o instanceof ContactTransportKey) {
ContactTransportKey k = (ContactTransportKey) o;
return contactId.equals(k.contactId) &&
transportId.equals(k.transportId);
}
return false;
}
}
}

View File

@@ -11,6 +11,7 @@ import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import net.sf.briar.api.crypto.AuthenticatedCipher; import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;

View File

@@ -15,6 +15,7 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
@@ -74,8 +75,13 @@ class TransportConnectionRecogniser {
return ctx; return ctx;
} }
synchronized void addWindow(ContactId contactId, long period, boolean alice, synchronized void addSecret(TemporarySecret s) throws DbException {
byte[] secret, long centre, byte[] bitmap) throws DbException { ContactId contactId = s.getContactId();
long period = s.getPeriod();
byte[] secret = s.getSecret();
boolean alice = s.getAlice();
long centre = s.getWindowCentre();
byte[] bitmap = s.getWindowBitmap();
// Create the connection window and the expected tags // Create the connection window and the expected tags
Cipher cipher = crypto.getTagCipher(); Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(secret, alice); ErasableKey key = crypto.deriveTagKey(secret, alice);
@@ -96,10 +102,15 @@ class TransportConnectionRecogniser {
removalMap.put(new RemovalKey(contactId, period), rctx); removalMap.put(new RemovalKey(contactId, period), rctx);
} }
synchronized void removeWindow(ContactId contactId, long period) { synchronized void removeSecret(ContactId contactId, long period) {
RemovalKey rk = new RemovalKey(contactId, period); RemovalKey rk = new RemovalKey(contactId, period);
RemovalContext rctx = removalMap.remove(rk); RemovalContext rctx = removalMap.remove(rk);
if(rctx == null) throw new IllegalArgumentException(); if(rctx == null) throw new IllegalArgumentException();
removeSecret(rctx);
}
// Locking: this
private void removeSecret(RemovalContext rctx) {
// Remove the expected tags // Remove the expected tags
Cipher cipher = crypto.getTagCipher(); Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(rctx.secret, rctx.alice); ErasableKey key = crypto.deriveTagKey(rctx.secret, rctx.alice);
@@ -114,12 +125,18 @@ class TransportConnectionRecogniser {
ByteUtils.erase(rctx.secret); ByteUtils.erase(rctx.secret);
} }
synchronized void removeWindows(ContactId c) { synchronized void removeSecrets(ContactId c) {
Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>(); Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
for(RemovalKey k : removalMap.keySet()) { for(RemovalKey k : removalMap.keySet()) {
if(k.contactId.equals(c)) keysToRemove.add(k); if(k.contactId.equals(c)) keysToRemove.add(k);
} }
for(RemovalKey k : keysToRemove) removeWindow(k.contactId, k.period); for(RemovalKey k : keysToRemove) removeSecret(k.contactId, k.period);
}
synchronized void removeSecrets() {
for(RemovalContext rctx : removalMap.values()) removeSecret(rctx);
assert tagMap.isEmpty();
removalMap.clear();
} }
private static class WindowContext { private static class WindowContext {
@@ -148,7 +165,7 @@ class TransportConnectionRecogniser {
@Override @Override
public int hashCode() { public int hashCode() {
return contactId.hashCode()+ (int) period; return contactId.hashCode() + (int) period;
} }
@Override @Override

View File

@@ -3,13 +3,16 @@ package net.sf.briar.transport;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.transport.ConnectionDispatcher; import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRegistry; import net.sf.briar.api.transport.ConnectionRegistry;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.IncomingConnectionExecutor; import net.sf.briar.api.transport.IncomingConnectionExecutor;
import com.google.inject.AbstractModule; import com.google.inject.AbstractModule;
import com.google.inject.Singleton;
public class TransportModule extends AbstractModule { public class TransportModule extends AbstractModule {
@@ -18,6 +21,8 @@ public class TransportModule extends AbstractModule {
bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class); bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class);
bind(ConnectionReaderFactory.class).to( bind(ConnectionReaderFactory.class).to(
ConnectionReaderFactoryImpl.class); ConnectionReaderFactoryImpl.class);
bind(ConnectionRecogniser.class).to(ConnectionRecogniserImpl.class).in(
Singleton.class);
bind(ConnectionRegistry.class).toInstance(new ConnectionRegistryImpl()); bind(ConnectionRegistry.class).toInstance(new ConnectionRegistryImpl());
bind(ConnectionWriterFactory.class).to( bind(ConnectionWriterFactory.class).to(
ConnectionWriterFactoryImpl.class); ConnectionWriterFactoryImpl.class);
@@ -25,5 +30,6 @@ public class TransportModule extends AbstractModule {
bind(Executor.class).annotatedWith( bind(Executor.class).annotatedWith(
IncomingConnectionExecutor.class).toInstance( IncomingConnectionExecutor.class).toInstance(
Executors.newCachedThreadPool()); Executors.newCachedThreadPool());
bind(KeyManager.class).to(KeyManagerImpl.class).in(Singleton.class);
} }
} }

View File

@@ -20,7 +20,6 @@
<test name='net.sf.briar.crypto.ErasableKeyTest'/> <test name='net.sf.briar.crypto.ErasableKeyTest'/>
<test name='net.sf.briar.crypto.KeyAgreementTest'/> <test name='net.sf.briar.crypto.KeyAgreementTest'/>
<test name='net.sf.briar.crypto.KeyDerivationTest'/> <test name='net.sf.briar.crypto.KeyDerivationTest'/>
<test name='net.sf.briar.crypto.KeyRotatorImplTest'/>
<test name='net.sf.briar.db.BasicH2Test'/> <test name='net.sf.briar.db.BasicH2Test'/>
<test name='net.sf.briar.db.DatabaseCleanerImplTest'/> <test name='net.sf.briar.db.DatabaseCleanerImplTest'/>
<test name='net.sf.briar.db.DatabaseComponentImplTest'/> <test name='net.sf.briar.db.DatabaseComponentImplTest'/>

View File

@@ -189,7 +189,7 @@ public class ProtocolIntegrationTest extends BriarTestCase {
InputStream in = new ByteArrayInputStream(connectionData); InputStream in = new ByteArrayInputStream(connectionData);
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
assertEquals(TAG_LENGTH, in.read(tag, 0, TAG_LENGTH)); assertEquals(TAG_LENGTH, in.read(tag, 0, TAG_LENGTH));
assertArrayEquals(new byte[TAG_LENGTH], tag); // FIXME: Check that the expected tag was received
ConnectionContext ctx = new ConnectionContext(contactId, transportId, ConnectionContext ctx = new ConnectionContext(contactId, transportId,
secret.clone(), 0L, true); secret.clone(), 0L, true);
ConnectionReader conn = connectionReaderFactory.createConnectionReader( ConnectionReader conn = connectionReaderFactory.createConnectionReader(

View File

@@ -63,7 +63,7 @@ public class KeyDerivationTest extends BriarTestCase {
public void testConnectionNumberAffectsDerivation() { public void testConnectionNumberAffectsDerivation() {
List<byte[]> secrets = new ArrayList<byte[]>(); List<byte[]> secrets = new ArrayList<byte[]>();
for(int i = 0; i < 20; i++) { for(int i = 0; i < 20; i++) {
secrets.add(crypto.deriveNextSecret(secret, i)); secrets.add(crypto.deriveNextSecret(secret.clone(), i));
} }
for(int i = 0; i < 20; i++) { for(int i = 0; i < 20; i++) {
byte[] secretI = secrets.get(i); byte[] secretI = secrets.get(i);

View File

@@ -1,54 +0,0 @@
package net.sf.briar.crypto;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import net.sf.briar.BriarTestCase;
import net.sf.briar.api.db.DbException;
import net.sf.briar.crypto.KeyRotatorImpl;
import net.sf.briar.crypto.KeyRotator.Callback;
import org.junit.Test;
public class KeyRotatorImplTest extends BriarTestCase {
@Test
public void testCleanerRunsPeriodically() throws Exception {
final CountDownLatch latch = new CountDownLatch(5);
Callback callback = new Callback() {
public void rotateKeys() throws DbException {
latch.countDown();
}
};
KeyRotatorImpl cleaner = new KeyRotatorImpl();
// Start the rotator
cleaner.startRotating(callback, 10L);
// The keys should be rotated five times (allow 5 secs for system load)
assertTrue(latch.await(5, TimeUnit.SECONDS));
// Stop the rotator
cleaner.stopRotating();
}
@Test
public void testStoppingCleanerWakesItUp() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
Callback callback = new Callback() {
public void rotateKeys() throws DbException {
latch.countDown();
}
};
KeyRotatorImpl cleaner = new KeyRotatorImpl();
long start = System.currentTimeMillis();
// Start the rotator
cleaner.startRotating(callback, 10L * 1000L);
// The keys should be rotated once at startup
assertTrue(latch.await(5, TimeUnit.SECONDS));
// Stop the rotator (it should be waiting between rotations)
cleaner.stopRotating();
long end = System.currentTimeMillis();
// Check that much less than 10 seconds expired
assertTrue(end - start < 10L * 1000L);
}
}

View File

@@ -88,8 +88,8 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
transports = Collections.singletonList(transport); transports = Collections.singletonList(transport);
contactTransport = new ContactTransport(contactId, transportId, 123L, contactTransport = new ContactTransport(contactId, transportId, 123L,
234L, 345L, true); 234L, 345L, true);
temporarySecret = new TemporarySecret(contactId, transportId, 0L, temporarySecret = new TemporarySecret(contactId, transportId, 1L, 2L,
new byte[32], 0L, 0L, new byte[4]); 3L, false, 4L, new byte[32], 5L, 6L, new byte[4]);
} }
protected abstract <T> DatabaseComponent createDatabaseComponent( protected abstract <T> DatabaseComponent createDatabaseComponent(

View File

@@ -1728,18 +1728,18 @@ public class H2DatabaseTest extends BriarTestCase {
byte[] secret1 = new byte[32], bitmap1 = new byte[4]; byte[] secret1 = new byte[32], bitmap1 = new byte[4];
random.nextBytes(secret1); random.nextBytes(secret1);
random.nextBytes(bitmap1); random.nextBytes(bitmap1);
TemporarySecret ts1 = new TemporarySecret(contactId, transportId, 0L, TemporarySecret s1 = new TemporarySecret(contactId, transportId, 123L,
secret1, 123L, 234L, bitmap1); 234L, 345L, false, 0L, secret1, 456L, 567L, bitmap1);
byte[] secret2 = new byte[32], bitmap2 = new byte[4]; byte[] secret2 = new byte[32], bitmap2 = new byte[4];
random.nextBytes(secret2); random.nextBytes(secret2);
random.nextBytes(bitmap2); random.nextBytes(bitmap2);
TemporarySecret ts2 = new TemporarySecret(contactId, transportId, 1L, TemporarySecret s2 = new TemporarySecret(contactId, transportId, 1234L,
secret2, 1234L, 2345L, bitmap2); 2345L, 3456L, false, 1L, secret2, 4567L, 5678L, bitmap2);
byte[] secret3 = new byte[32], bitmap3 = new byte[4]; byte[] secret3 = new byte[32], bitmap3 = new byte[4];
random.nextBytes(secret3); random.nextBytes(secret3);
random.nextBytes(bitmap3); random.nextBytes(bitmap3);
TemporarySecret ts3 = new TemporarySecret(contactId, transportId, 2L, TemporarySecret s3 = new TemporarySecret(contactId, transportId, 12345L,
secret3, 12345L, 23456L, bitmap3); 23456L, 34567L, false, 0L, secret3, 45678L, 56789L, bitmap3);
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
@@ -1749,26 +1749,36 @@ public class H2DatabaseTest extends BriarTestCase {
// Add a contact and the first two secrets // Add a contact and the first two secrets
assertEquals(contactId, db.addContact(txn)); assertEquals(contactId, db.addContact(txn));
db.addSecrets(txn, Arrays.asList(ts1, ts2)); db.addSecrets(txn, Arrays.asList(s1, s2));
// Retrieve the first two secrets // Retrieve the first two secrets
Collection<TemporarySecret> secrets = db.getSecrets(txn); Collection<TemporarySecret> secrets = db.getSecrets(txn);
assertEquals(2, secrets.size()); assertEquals(2, secrets.size());
boolean foundFirst = false, foundSecond = false; boolean foundFirst = false, foundSecond = false;
for(TemporarySecret ts : secrets) { for(TemporarySecret s : secrets) {
assertEquals(contactId, ts.getContactId()); assertEquals(contactId, s.getContactId());
assertEquals(transportId, ts.getTransportId()); assertEquals(transportId, s.getTransportId());
if(ts.getPeriod() == 0L) { if(s.getPeriod() == 0L) {
assertArrayEquals(secret1, ts.getSecret()); assertEquals(s1.getEpoch(), s.getEpoch());
assertEquals(123L, ts.getOutgoingConnectionCounter()); assertEquals(s1.getClockDifference(), s.getClockDifference());
assertEquals(234L, ts.getWindowCentre()); assertEquals(s1.getLatency(), s.getLatency());
assertArrayEquals(bitmap1, ts.getWindowBitmap()); assertEquals(s1.getAlice(), s.getAlice());
assertArrayEquals(s1.getSecret(), s.getSecret());
assertEquals(s1.getOutgoingConnectionCounter(),
s.getOutgoingConnectionCounter());
assertEquals(s1.getWindowCentre(), s.getWindowCentre());
assertArrayEquals(s1.getWindowBitmap(), s.getWindowBitmap());
foundFirst = true; foundFirst = true;
} else if(ts.getPeriod() == 1L) { } else if(s.getPeriod() == 1L) {
assertArrayEquals(secret2, ts.getSecret()); assertEquals(s2.getEpoch(), s.getEpoch());
assertEquals(1234L, ts.getOutgoingConnectionCounter()); assertEquals(s2.getClockDifference(), s.getClockDifference());
assertEquals(2345L, ts.getWindowCentre()); assertEquals(s2.getLatency(), s.getLatency());
assertArrayEquals(bitmap2, ts.getWindowBitmap()); assertEquals(s2.getAlice(), s.getAlice());
assertArrayEquals(s2.getSecret(), s.getSecret());
assertEquals(s2.getOutgoingConnectionCounter(),
s.getOutgoingConnectionCounter());
assertEquals(s2.getWindowCentre(), s.getWindowCentre());
assertArrayEquals(s2.getWindowBitmap(), s.getWindowBitmap());
foundSecond = true; foundSecond = true;
} else { } else {
fail(); fail();
@@ -1778,25 +1788,35 @@ public class H2DatabaseTest extends BriarTestCase {
assertTrue(foundSecond); assertTrue(foundSecond);
// Adding the third secret (period 2) should delete the first (period 0) // Adding the third secret (period 2) should delete the first (period 0)
db.addSecrets(txn, Arrays.asList(ts3)); db.addSecrets(txn, Arrays.asList(s3));
secrets = db.getSecrets(txn); secrets = db.getSecrets(txn);
assertEquals(2, secrets.size()); assertEquals(2, secrets.size());
foundSecond = false; foundSecond = false;
boolean foundThird = false; boolean foundThird = false;
for(TemporarySecret ts : secrets) { for(TemporarySecret s : secrets) {
assertEquals(contactId, ts.getContactId()); assertEquals(contactId, s.getContactId());
assertEquals(transportId, ts.getTransportId()); assertEquals(transportId, s.getTransportId());
if(ts.getPeriod() == 1L) { if(s.getPeriod() == 1L) {
assertArrayEquals(secret2, ts.getSecret()); assertEquals(s2.getEpoch(), s.getEpoch());
assertEquals(1234L, ts.getOutgoingConnectionCounter()); assertEquals(s2.getClockDifference(), s.getClockDifference());
assertEquals(2345L, ts.getWindowCentre()); assertEquals(s2.getLatency(), s.getLatency());
assertArrayEquals(bitmap2, ts.getWindowBitmap()); assertEquals(s2.getAlice(), s.getAlice());
assertArrayEquals(s2.getSecret(), s.getSecret());
assertEquals(s2.getOutgoingConnectionCounter(),
s.getOutgoingConnectionCounter());
assertEquals(s2.getWindowCentre(), s.getWindowCentre());
assertArrayEquals(s2.getWindowBitmap(), s.getWindowBitmap());
foundSecond = true; foundSecond = true;
} else if(ts.getPeriod() == 2L) { } else if(s.getPeriod() == 2L) {
assertArrayEquals(secret3, ts.getSecret()); assertEquals(s3.getEpoch(), s.getEpoch());
assertEquals(12345L, ts.getOutgoingConnectionCounter()); assertEquals(s3.getClockDifference(), s.getClockDifference());
assertEquals(23456L, ts.getWindowCentre()); assertEquals(s3.getLatency(), s.getLatency());
assertArrayEquals(bitmap3, ts.getWindowBitmap()); assertEquals(s3.getAlice(), s.getAlice());
assertArrayEquals(s3.getSecret(), s.getSecret());
assertEquals(s3.getOutgoingConnectionCounter(),
s.getOutgoingConnectionCounter());
assertEquals(s3.getWindowCentre(), s.getWindowCentre());
assertArrayEquals(s3.getWindowBitmap(), s.getWindowBitmap());
foundThird = true; foundThird = true;
} else { } else {
fail(); fail();
@@ -1819,55 +1839,43 @@ public class H2DatabaseTest extends BriarTestCase {
Random random = new Random(); Random random = new Random();
byte[] secret = new byte[32], bitmap = new byte[4]; byte[] secret = new byte[32], bitmap = new byte[4];
random.nextBytes(secret); random.nextBytes(secret);
TemporarySecret ts = new TemporarySecret(contactId, transportId, 0L, TemporarySecret s = new TemporarySecret(contactId, transportId, 0L,
secret, 0L, 0L, bitmap); 0L, 0L, false, 0L, secret, 0L, 0L, bitmap);
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and the temporary secret // Add a contact and the temporary secret
assertEquals(contactId, db.addContact(txn)); assertEquals(contactId, db.addContact(txn));
db.addSecrets(txn, Arrays.asList(ts)); db.addSecrets(txn, Arrays.asList(s));
// Retrieve the secret // Retrieve the secret
Collection<TemporarySecret> secrets = db.getSecrets(txn); Collection<TemporarySecret> secrets = db.getSecrets(txn);
assertEquals(1, secrets.size()); assertEquals(1, secrets.size());
ts = secrets.iterator().next(); s = secrets.iterator().next();
assertEquals(contactId, ts.getContactId()); assertEquals(contactId, s.getContactId());
assertEquals(transportId, ts.getTransportId()); assertEquals(transportId, s.getTransportId());
assertEquals(0L, ts.getPeriod()); assertEquals(0L, s.getPeriod());
assertArrayEquals(secret, ts.getSecret()); assertArrayEquals(secret, s.getSecret());
assertEquals(0L, ts.getOutgoingConnectionCounter()); assertEquals(0L, s.getOutgoingConnectionCounter());
assertEquals(0L, ts.getWindowCentre()); assertEquals(0L, s.getWindowCentre());
assertArrayEquals(bitmap, ts.getWindowBitmap()); assertArrayEquals(bitmap, s.getWindowBitmap());
// Increment the connection counter twice and retrieve the secret again // Increment the connection counter twice and retrieve the secret again
db.incrementConnectionCounter(txn, contactId, transportId, 0L); assertEquals(0L, db.incrementConnectionCounter(txn, contactId,
db.incrementConnectionCounter(txn, contactId, transportId, 0L); transportId, 0L));
assertEquals(1L, db.incrementConnectionCounter(txn, contactId,
transportId, 0L));
secrets = db.getSecrets(txn); secrets = db.getSecrets(txn);
assertEquals(1, secrets.size()); assertEquals(1, secrets.size());
ts = secrets.iterator().next(); s = secrets.iterator().next();
assertEquals(contactId, ts.getContactId()); assertEquals(contactId, s.getContactId());
assertEquals(transportId, ts.getTransportId()); assertEquals(transportId, s.getTransportId());
assertEquals(0L, ts.getPeriod()); assertEquals(0L, s.getPeriod());
assertArrayEquals(secret, ts.getSecret()); assertArrayEquals(secret, s.getSecret());
assertEquals(2L, ts.getOutgoingConnectionCounter()); assertEquals(2L, s.getOutgoingConnectionCounter());
assertEquals(0L, ts.getWindowCentre()); assertEquals(0L, s.getWindowCentre());
assertArrayEquals(bitmap, ts.getWindowBitmap()); assertArrayEquals(bitmap, s.getWindowBitmap());
// Incrementing a nonexistent counter should not throw an exception
db.incrementConnectionCounter(txn, contactId, transportId, 1L);
// The nonexistent counter should not have been created
secrets = db.getSecrets(txn);
assertEquals(1, secrets.size());
ts = secrets.iterator().next();
assertEquals(contactId, ts.getContactId());
assertEquals(transportId, ts.getTransportId());
assertEquals(0L, ts.getPeriod());
assertArrayEquals(secret, ts.getSecret());
assertEquals(2L, ts.getOutgoingConnectionCounter());
assertEquals(0L, ts.getWindowCentre());
assertArrayEquals(bitmap, ts.getWindowBitmap());
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -1879,27 +1887,27 @@ public class H2DatabaseTest extends BriarTestCase {
Random random = new Random(); Random random = new Random();
byte[] secret = new byte[32], bitmap = new byte[4]; byte[] secret = new byte[32], bitmap = new byte[4];
random.nextBytes(secret); random.nextBytes(secret);
TemporarySecret ts = new TemporarySecret(contactId, transportId, 0L, TemporarySecret s = new TemporarySecret(contactId, transportId, 0L,
secret, 0L, 0L, bitmap); 0L, 0L, false, 0L, secret, 0L, 0L, bitmap);
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and the temporary secret // Add a contact and the temporary secret
assertEquals(contactId, db.addContact(txn)); assertEquals(contactId, db.addContact(txn));
db.addSecrets(txn, Arrays.asList(ts)); db.addSecrets(txn, Arrays.asList(s));
// Retrieve the secret // Retrieve the secret
Collection<TemporarySecret> secrets = db.getSecrets(txn); Collection<TemporarySecret> secrets = db.getSecrets(txn);
assertEquals(1, secrets.size()); assertEquals(1, secrets.size());
ts = secrets.iterator().next(); s = secrets.iterator().next();
assertEquals(contactId, ts.getContactId()); assertEquals(contactId, s.getContactId());
assertEquals(transportId, ts.getTransportId()); assertEquals(transportId, s.getTransportId());
assertEquals(0L, ts.getPeriod()); assertEquals(0L, s.getPeriod());
assertArrayEquals(secret, ts.getSecret()); assertArrayEquals(secret, s.getSecret());
assertEquals(0L, ts.getOutgoingConnectionCounter()); assertEquals(0L, s.getOutgoingConnectionCounter());
assertEquals(0L, ts.getWindowCentre()); assertEquals(0L, s.getWindowCentre());
assertArrayEquals(bitmap, ts.getWindowBitmap()); assertArrayEquals(bitmap, s.getWindowBitmap());
// Update the connection window and retrieve the secret again // Update the connection window and retrieve the secret again
db.setConnectionWindow(txn, contactId, transportId, 0L, 1L, bitmap); db.setConnectionWindow(txn, contactId, transportId, 0L, 1L, bitmap);
@@ -1907,28 +1915,28 @@ public class H2DatabaseTest extends BriarTestCase {
db.setConnectionWindow(txn, contactId, transportId, 0L, 1L, bitmap); db.setConnectionWindow(txn, contactId, transportId, 0L, 1L, bitmap);
secrets = db.getSecrets(txn); secrets = db.getSecrets(txn);
assertEquals(1, secrets.size()); assertEquals(1, secrets.size());
ts = secrets.iterator().next(); s = secrets.iterator().next();
assertEquals(contactId, ts.getContactId()); assertEquals(contactId, s.getContactId());
assertEquals(transportId, ts.getTransportId()); assertEquals(transportId, s.getTransportId());
assertEquals(0L, ts.getPeriod()); assertEquals(0L, s.getPeriod());
assertArrayEquals(secret, ts.getSecret()); assertArrayEquals(secret, s.getSecret());
assertEquals(0L, ts.getOutgoingConnectionCounter()); assertEquals(0L, s.getOutgoingConnectionCounter());
assertEquals(1L, ts.getWindowCentre()); assertEquals(1L, s.getWindowCentre());
assertArrayEquals(bitmap, ts.getWindowBitmap()); assertArrayEquals(bitmap, s.getWindowBitmap());
// Updating a nonexistent window should not throw an exception // Updating a nonexistent window should not throw an exception
db.setConnectionWindow(txn, contactId, transportId, 1L, 1L, bitmap); db.setConnectionWindow(txn, contactId, transportId, 1L, 1L, bitmap);
// The nonexistent window should not have been created // The nonexistent window should not have been created
secrets = db.getSecrets(txn); secrets = db.getSecrets(txn);
assertEquals(1, secrets.size()); assertEquals(1, secrets.size());
ts = secrets.iterator().next(); s = secrets.iterator().next();
assertEquals(contactId, ts.getContactId()); assertEquals(contactId, s.getContactId());
assertEquals(transportId, ts.getTransportId()); assertEquals(transportId, s.getTransportId());
assertEquals(0L, ts.getPeriod()); assertEquals(0L, s.getPeriod());
assertArrayEquals(secret, ts.getSecret()); assertArrayEquals(secret, s.getSecret());
assertEquals(0L, ts.getOutgoingConnectionCounter()); assertEquals(0L, s.getOutgoingConnectionCounter());
assertEquals(1L, ts.getWindowCentre()); assertEquals(1L, s.getWindowCentre());
assertArrayEquals(bitmap, ts.getWindowBitmap()); assertArrayEquals(bitmap, s.getWindowBitmap());
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();

View File

@@ -14,7 +14,6 @@ import java.util.concurrent.Executors;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.TestUtils; import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
@@ -24,7 +23,6 @@ import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRegistry; import net.sf.briar.api.transport.ConnectionRegistry;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
@@ -48,8 +46,6 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
private final Mockery context; private final Mockery context;
private final DatabaseComponent db; private final DatabaseComponent db;
private final KeyManager keyManager;
private final ConnectionRecogniser connRecogniser;
private final ConnectionRegistry connRegistry; private final ConnectionRegistry connRegistry;
private final ConnectionWriterFactory connFactory; private final ConnectionWriterFactory connFactory;
private final ProtocolWriterFactory protoFactory; private final ProtocolWriterFactory protoFactory;
@@ -61,14 +57,10 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
super(); super();
context = new Mockery(); context = new Mockery();
db = context.mock(DatabaseComponent.class); db = context.mock(DatabaseComponent.class);
keyManager = context.mock(KeyManager.class);
connRecogniser = context.mock(ConnectionRecogniser.class);
Module testModule = new AbstractModule() { Module testModule = new AbstractModule() {
@Override @Override
public void configure() { public void configure() {
bind(DatabaseComponent.class).toInstance(db); bind(DatabaseComponent.class).toInstance(db);
bind(KeyManager.class).toInstance(keyManager);
bind(ConnectionRecogniser.class).toInstance(connRecogniser);
bind(Executor.class).annotatedWith( bind(Executor.class).annotatedWith(
DatabaseExecutor.class).toInstance( DatabaseExecutor.class).toInstance(
Executors.newCachedThreadPool()); Executors.newCachedThreadPool());

View File

@@ -8,6 +8,7 @@ import org.jmock.Expectations;
import org.jmock.Mockery; import org.jmock.Mockery;
import org.junit.Test; import org.junit.Test;
public class ConnectionWriterImplTest extends BriarTestCase { public class ConnectionWriterImplTest extends BriarTestCase {
private static final int FRAME_LENGTH = 1024; private static final int FRAME_LENGTH = 1024;

View File

@@ -22,9 +22,15 @@ import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.transport.ConnectionReaderImpl;
import net.sf.briar.transport.ConnectionWriterFactoryImpl;
import net.sf.briar.transport.ConnectionWriterImpl;
import net.sf.briar.transport.IncomingEncryptionLayer;
import net.sf.briar.transport.OutgoingEncryptionLayer;
import org.junit.Test; import org.junit.Test;
import com.google.inject.AbstractModule; import com.google.inject.AbstractModule;
import com.google.inject.Guice; import com.google.inject.Guice;
import com.google.inject.Injector; import com.google.inject.Injector;