Implemented KeyManager (untested).

A test is failing due to key derivation errors - must be fixed!
This commit is contained in:
akwizgran
2012-10-24 18:16:17 +01:00
parent cc6e9d53ad
commit 27e50b8495
25 changed files with 540 additions and 306 deletions

View File

@@ -1,24 +0,0 @@
package net.sf.briar.crypto;
import net.sf.briar.api.db.DbException;
interface KeyRotator {
/**
* Starts a new thread to rotate keys periodically. The rotator will pause
* for the given number of milliseconds between rotations.
*/
void startRotating(Callback callback, long msBetweenRotations);
/** Tells the rotator thread to exit. */
void stopRotating();
interface Callback {
/**
* Rotates keys, replacing and destroying any keys that have passed the
* ends of their respective retention periods.
*/
void rotateKeys() throws DbException;
}
}

View File

@@ -1,41 +0,0 @@
package net.sf.briar.crypto;
import java.util.Timer;
import java.util.TimerTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.briar.api.db.DbException;
class KeyRotatorImpl extends TimerTask implements KeyRotator {
private static final Logger LOG =
Logger.getLogger(KeyRotatorImpl.class.getName());
private volatile Callback callback = null;
private volatile Timer timer = null;
public void startRotating(Callback callback, long msBetweenRotations) {
this.callback = callback;
timer = new Timer();
timer.scheduleAtFixedRate(this, 0L, msBetweenRotations);
}
public void stopRotating() {
if(timer == null) throw new IllegalStateException();
timer.cancel();
}
public void run() {
if(callback == null) throw new IllegalStateException();
try {
callback.rotateKeys();
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
throw new Error(e); // Kill the application
} catch(RuntimeException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
throw new Error(e); // Kill the application
}
}
}

View File

@@ -468,11 +468,11 @@ interface Database<T> {
/**
* Increments the outgoing connection counter for the given contact
* transport in the given rotation period.
* transport in the given rotation period and returns the old value;
* <p>
* Locking: contact read, window write.
*/
void incrementConnectionCounter(T txn, ContactId c, TransportId t,
long incrementConnectionCounter(T txn, ContactId c, TransportId t,
long period) throws DbException;
/**

View File

@@ -758,30 +758,6 @@ DatabaseCleaner.Callback {
}
}
public Collection<ContactTransport> getContactTransports()
throws DbException {
contactLock.readLock().lock();
try {
windowLock.readLock().lock();
try {
T txn = db.startTransaction();
try {
Collection<ContactTransport> contactTransports =
db.getContactTransports(txn);
db.commitTransaction(txn);
return contactTransports;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
windowLock.readLock().unlock();
}
} finally {
contactLock.readLock().unlock();
}
}
public TransportProperties getLocalProperties(TransportId t)
throws DbException {
transportLock.readLock().lock();
@@ -1005,7 +981,7 @@ DatabaseCleaner.Callback {
}
}
public void incrementConnectionCounter(ContactId c, TransportId t,
public long incrementConnectionCounter(ContactId c, TransportId t,
long period) throws DbException {
contactLock.readLock().lock();
try {
@@ -1015,8 +991,9 @@ DatabaseCleaner.Callback {
try {
if(!db.containsContactTransport(txn, c, t))
throw new NoSuchContactTransportException();
db.incrementConnectionCounter(txn, c, t, period);
long l = db.incrementConnectionCounter(txn, c, t, period);
db.commitTransaction(txn);
return l;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;

View File

@@ -1557,22 +1557,30 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT contactId, transportId, period, secret,"
+ " outgoing, centre, bitmap"
+ " FROM secrets";
String sql = "SELECT ct.contactId, ct.transportId, epoch,"
+ " clockDiff, latency, alice, period, secret, outgoing,"
+ " centre, bitmap"
+ " FROM contactTransports AS ct"
+ " JOIN secrets AS s"
+ " ON ct.contactId = s.contactId"
+ " AND ct.transportId = s.transportId";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
List<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
while(rs.next()) {
ContactId c = new ContactId(rs.getInt(1));
TransportId t = new TransportId(rs.getBytes(2));
long period = rs.getLong(3);
byte[] secret = rs.getBytes(4);
long outgoing = rs.getLong(5);
long centre = rs.getLong(6);
byte[] bitmap = rs.getBytes(7);
secrets.add(new TemporarySecret(c, t, period, secret, outgoing,
centre, bitmap));
long epoch = rs.getLong(3);
long clockDiff = rs.getLong(4);
long latency = rs.getLong(5);
boolean alice = rs.getBoolean(6);
long period = rs.getLong(7);
byte[] secret = rs.getBytes(8);
long outgoing = rs.getLong(9);
long centre = rs.getLong(10);
byte[] bitmap = rs.getBytes(11);
secrets.add(new TemporarySecret(c, t, epoch, clockDiff, latency,
alice, period, secret, outgoing, centre, bitmap));
}
rs.close();
ps.close();
@@ -2021,11 +2029,26 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public void incrementConnectionCounter(Connection txn, ContactId c,
public long incrementConnectionCounter(Connection txn, ContactId c,
TransportId t, long period) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "UPDATE secrets SET outgoing = outgoing + 1"
// Get the current connection counter
String sql = "SELECT outgoing FROM secrets"
+ " WHERE contactId = ? AND transportId = ? AND period + ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setBytes(2, t.getBytes());
ps.setLong(3, period);
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
long connection = rs.getLong(1);
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Increment the connection counter
sql = "UPDATE secrets SET outgoing = outgoing + 1"
+ " WHERE contactId = ? AND transportId = ? AND period = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
@@ -2034,8 +2057,10 @@ abstract class JdbcDatabase implements Database<Connection> {
int affected = ps.executeUpdate();
if(affected > 1) throw new DbStateException();
ps.close();
return connection;
} catch(SQLException e) {
tryToClose(ps);
tryToClose(rs);
throw new DbException(e);
}
}

View File

@@ -7,6 +7,7 @@ import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
@@ -37,9 +38,8 @@ class ConnectionRecogniserImpl implements ConnectionRecogniser {
return r.acceptConnection(tag);
}
public void addWindow(ContactId c, TransportId t, long period,
boolean alice, byte[] secret, long centre, byte[] bitmap)
throws DbException {
public void addSecret(TemporarySecret s) throws DbException {
TransportId t = s.getTransportId();
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
@@ -48,20 +48,24 @@ class ConnectionRecogniserImpl implements ConnectionRecogniser {
recognisers.put(t, r);
}
}
r.addWindow(c, period, alice, secret, centre, bitmap);
r.addSecret(s);
}
public void removeWindow(ContactId c, TransportId t, long period) {
public void removeSecret(ContactId c, TransportId t, long period) {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r != null) r.removeWindow(c, period);
if(r != null) r.removeSecret(c, period);
}
public synchronized void removeWindows(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values()) {
r.removeWindows(c);
}
public synchronized void removeSecrets(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values())
r.removeSecrets(c);
}
public synchronized void removeSecrets() {
for(TransportConnectionRecogniser r : recognisers.values())
r.removeSecrets();
}
}

View File

@@ -13,6 +13,7 @@ import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory;
import com.google.inject.Inject;
class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {

View File

@@ -6,6 +6,7 @@ import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import java.io.IOException;
import java.io.OutputStream;
import net.sf.briar.api.transport.ConnectionWriter;
/**

View File

@@ -0,0 +1,305 @@
package net.sf.briar.transport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.db.event.ContactRemovedEvent;
import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject;
class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener {
private static final int MS_BETWEEN_CHECKS = 60 * 1000;
private static final Logger LOG =
Logger.getLogger(KeyManagerImpl.class.getName());
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final ConnectionRecogniser recogniser;
private final Timer timer;
// Locking: this
private final Map<ContactTransportKey, TemporarySecret> outgoing;
// Locking: this
private final Map<ContactTransportKey, TemporarySecret> incomingOld;
// Locking: this
private final Map<ContactTransportKey, TemporarySecret> incomingNew;
// Locking: this
private boolean running = false;
@Inject
public KeyManagerImpl(CryptoComponent crypto, DatabaseComponent db,
ConnectionRecogniser recogniser) {
this.crypto = crypto;
this.db = db;
this.recogniser = recogniser;
timer = new Timer();
outgoing = new HashMap<ContactTransportKey, TemporarySecret>();
incomingOld = new HashMap<ContactTransportKey, TemporarySecret>();
incomingNew = new HashMap<ContactTransportKey, TemporarySecret>();
}
public synchronized boolean start() {
if(running) return false;
Collection<TemporarySecret> secrets;
try {
secrets = db.getSecrets();
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return false;
}
// Work out what phase of its lifecycle each secret is in
long now = System.currentTimeMillis();
Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets);
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
try {
// Store any secrets that have been created
if(!created.isEmpty()) db.addSecrets(created);
// Pass the current incoming secrets to the connection recogniser
// FIXME: This uses a separate database transaction for each secret
for(TemporarySecret s : incomingOld.values())
recogniser.addSecret(s);
for(TemporarySecret s : incomingNew.values())
recogniser.addSecret(s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return false;
}
// Schedule periodic key rotation
timer.scheduleAtFixedRate(this, MS_BETWEEN_CHECKS, MS_BETWEEN_CHECKS);
running = true;
return true;
}
// Assigns secrets to the appropriate maps and returns any dead secrets
// FIXME: Check there are no duplicate keys when updating maps
private Collection<TemporarySecret> assignSecretsToMaps(long now,
Collection<TemporarySecret> secrets) {
Collection<TemporarySecret> dead = new ArrayList<TemporarySecret>();
for(TemporarySecret s : secrets) {
ContactId c = s.getContactId();
TransportId t = s.getTransportId();
ContactTransportKey k = new ContactTransportKey(c, t);
long rotationPeriod = getRotationPeriod(s);
long creationTime = getCreationTime(s);
long activationTime = creationTime + s.getClockDifference();
long successorCreationTime = creationTime + rotationPeriod;
long deactivationTime = activationTime + rotationPeriod;
long destructionTime = successorCreationTime + rotationPeriod;
if(now >= destructionTime) {
dead.add(s);
} else if(now >= deactivationTime) {
incomingOld.put(k, s);
} else if(now >= successorCreationTime) {
incomingOld.put(k, s);
outgoing.put(k, s);
} else if(now >= activationTime) {
incomingNew.put(k, s);
outgoing.put(k, s);
} else if(now >= creationTime) {
incomingNew.put(k, s);
} else {
// FIXME: What should we do if the clock moves backwards?
throw new IllegalStateException();
}
}
return dead;
}
// Replaces and erases the given secrets and returns any secrets created
private Collection<TemporarySecret> replaceDeadSecrets(long now,
Collection<TemporarySecret> dead) {
Collection<TemporarySecret> created = new ArrayList<TemporarySecret>();
for(TemporarySecret s : dead) {
ContactId c = s.getContactId();
TransportId t = s.getTransportId();
ContactTransportKey k = new ContactTransportKey(c, t);
if(incomingNew.containsKey(k)) throw new IllegalStateException();
byte[] secret = s.getSecret();
long period = s.getPeriod();
if(incomingOld.containsKey(k)) {
// The dead secret's successor is still alive
byte[] secret1 = crypto.deriveNextSecret(secret, period + 1);
TemporarySecret s1 = new TemporarySecret(s, period + 1,
secret1);
created.add(s1);
incomingNew.put(k, s1);
long creationTime = getCreationTime(s1);
long activationTime = creationTime + s1.getClockDifference();
if(now >= activationTime) outgoing.put(k, s1);
} else {
// The dead secret has no living successor
long rotationPeriod = getRotationPeriod(s);
long elapsed = now - s.getEpoch();
long currentPeriod = elapsed / rotationPeriod;
if(currentPeriod <= period) throw new IllegalStateException();
// Derive the two current incoming secrets
byte[] secret1, secret2;
secret1 = secret;
for(long l = period; l < currentPeriod; l++) {
byte[] temp = crypto.deriveNextSecret(secret1, l);
ByteUtils.erase(secret1);
secret1 = temp;
}
secret2 = crypto.deriveNextSecret(secret1, currentPeriod);
// One of the incoming secrets is the current outgoing secret
TemporarySecret s1, s2;
s1 = new TemporarySecret(s, currentPeriod - 1, secret1);
created.add(s1);
incomingOld.put(k, s1);
s2 = new TemporarySecret(s, currentPeriod, secret2);
created.add(s2);
incomingNew.put(k, s2);
if(elapsed % rotationPeriod < s.getClockDifference()) {
// The outgoing secret is the newer incoming secret
outgoing.put(k, s2);
} else {
// The outgoing secret is the older incoming secret
outgoing.put(k, s1);
}
}
// Erase the dead secret
ByteUtils.erase(secret);
}
return created;
}
private long getRotationPeriod(TemporarySecret s) {
return 2 * s.getClockDifference() + s.getLatency();
}
private long getCreationTime(TemporarySecret s) {
long rotationPeriod = getRotationPeriod(s);
return s.getEpoch() + rotationPeriod * s.getPeriod();
}
public synchronized void stop() {
if(!running) return;
recogniser.removeSecrets();
removeAndEraseSecrets(outgoing);
removeAndEraseSecrets(incomingOld);
removeAndEraseSecrets(incomingNew);
running = false;
}
// Locking: this
private void removeAndEraseSecrets(Map<?, TemporarySecret> m) {
for(TemporarySecret s : m.values()) ByteUtils.erase(s.getSecret());
m.clear();
}
public synchronized ConnectionContext getConnectionContext(ContactId c,
TransportId t) {
TemporarySecret s = outgoing.get(new ContactTransportKey(c, t));
if(s == null) return null;
long connection;
try {
connection = db.incrementConnectionCounter(c, t, s.getPeriod());
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return null;
}
byte[] secret = s.getSecret().clone();
return new ConnectionContext(c, t, secret, connection, s.getAlice());
}
@Override
public synchronized void run() {
// Rebuild the maps because we may be running a whole period late
Collection<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
secrets.addAll(incomingOld.values());
secrets.addAll(incomingNew.values());
outgoing.clear();
incomingOld.clear();
incomingNew.clear();
// Work out what phase of its lifecycle each secret is in
long now = System.currentTimeMillis();
Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets);
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
try {
// Store any secrets that have been created
if(!created.isEmpty()) db.addSecrets(created);
// Pass the current incoming secrets to the connection recogniser
// FIXME: This uses a separate database transaction for each secret
for(TemporarySecret s : incomingOld.values())
recogniser.addSecret(s);
for(TemporarySecret s : incomingNew.values())
recogniser.addSecret(s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
}
}
public void eventOccurred(DatabaseEvent e) {
if(e instanceof ContactRemovedEvent) {
ContactId c = ((ContactRemovedEvent) e).getContactId();
recogniser.removeSecrets(c);
synchronized(this) {
removeAndEraseSecrets(c, outgoing);
removeAndEraseSecrets(c, incomingOld);
removeAndEraseSecrets(c, incomingNew);
}
}
}
// Locking: this
private void removeAndEraseSecrets(ContactId c, Map<?, TemporarySecret> m) {
Iterator<TemporarySecret> it = m.values().iterator();
while(it.hasNext()) {
TemporarySecret s = it.next();
if(s.getContactId().equals(c)) {
ByteUtils.erase(s.getSecret());
it.remove();
}
}
}
private static class ContactTransportKey {
private final ContactId contactId;
private final TransportId transportId;
private ContactTransportKey(ContactId contactId,
TransportId transportId) {
this.contactId = contactId;
this.transportId = transportId;
}
@Override
public int hashCode() {
return contactId.hashCode() + transportId.hashCode();
}
@Override
public boolean equals(Object o) {
if(o instanceof ContactTransportKey) {
ContactTransportKey k = (ContactTransportKey) o;
return contactId.equals(k.contactId) &&
transportId.equals(k.transportId);
}
return false;
}
}
}

View File

@@ -11,6 +11,7 @@ import java.io.IOException;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey;

View File

@@ -15,6 +15,7 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.util.ByteUtils;
@@ -74,8 +75,13 @@ class TransportConnectionRecogniser {
return ctx;
}
synchronized void addWindow(ContactId contactId, long period, boolean alice,
byte[] secret, long centre, byte[] bitmap) throws DbException {
synchronized void addSecret(TemporarySecret s) throws DbException {
ContactId contactId = s.getContactId();
long period = s.getPeriod();
byte[] secret = s.getSecret();
boolean alice = s.getAlice();
long centre = s.getWindowCentre();
byte[] bitmap = s.getWindowBitmap();
// Create the connection window and the expected tags
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(secret, alice);
@@ -96,10 +102,15 @@ class TransportConnectionRecogniser {
removalMap.put(new RemovalKey(contactId, period), rctx);
}
synchronized void removeWindow(ContactId contactId, long period) {
synchronized void removeSecret(ContactId contactId, long period) {
RemovalKey rk = new RemovalKey(contactId, period);
RemovalContext rctx = removalMap.remove(rk);
if(rctx == null) throw new IllegalArgumentException();
removeSecret(rctx);
}
// Locking: this
private void removeSecret(RemovalContext rctx) {
// Remove the expected tags
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(rctx.secret, rctx.alice);
@@ -114,12 +125,18 @@ class TransportConnectionRecogniser {
ByteUtils.erase(rctx.secret);
}
synchronized void removeWindows(ContactId c) {
synchronized void removeSecrets(ContactId c) {
Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
for(RemovalKey k : removalMap.keySet()) {
if(k.contactId.equals(c)) keysToRemove.add(k);
}
for(RemovalKey k : keysToRemove) removeWindow(k.contactId, k.period);
for(RemovalKey k : keysToRemove) removeSecret(k.contactId, k.period);
}
synchronized void removeSecrets() {
for(RemovalContext rctx : removalMap.values()) removeSecret(rctx);
assert tagMap.isEmpty();
removalMap.clear();
}
private static class WindowContext {
@@ -148,7 +165,7 @@ class TransportConnectionRecogniser {
@Override
public int hashCode() {
return contactId.hashCode()+ (int) period;
return contactId.hashCode() + (int) period;
}
@Override

View File

@@ -3,13 +3,16 @@ package net.sf.briar.transport;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRegistry;
import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.IncomingConnectionExecutor;
import com.google.inject.AbstractModule;
import com.google.inject.Singleton;
public class TransportModule extends AbstractModule {
@@ -18,6 +21,8 @@ public class TransportModule extends AbstractModule {
bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class);
bind(ConnectionReaderFactory.class).to(
ConnectionReaderFactoryImpl.class);
bind(ConnectionRecogniser.class).to(ConnectionRecogniserImpl.class).in(
Singleton.class);
bind(ConnectionRegistry.class).toInstance(new ConnectionRegistryImpl());
bind(ConnectionWriterFactory.class).to(
ConnectionWriterFactoryImpl.class);
@@ -25,5 +30,6 @@ public class TransportModule extends AbstractModule {
bind(Executor.class).annotatedWith(
IncomingConnectionExecutor.class).toInstance(
Executors.newCachedThreadPool());
bind(KeyManager.class).to(KeyManagerImpl.class).in(Singleton.class);
}
}