Finer-grained locking in the connection recogniser.

This commit is contained in:
akwizgran
2011-11-17 20:06:19 +00:00
parent 2b45cf0dd1
commit 6fada9f243

View File

@@ -10,6 +10,7 @@ import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -47,11 +48,9 @@ DatabaseListener {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final DatabaseComponent db; private final DatabaseComponent db;
private final Executor executor; private final Executor executor;
private final Cipher ivCipher; private final Map<Bytes, Context> expected; // Locking: this
private final Map<Bytes, Context> expected; private final Collection<TransportId> localTransportIds; // Locking: this
private final Collection<TransportId> localTransportIds; private final AtomicBoolean initialised = new AtomicBoolean(false);
private boolean initialised = false;
@Inject @Inject
ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db, ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db,
@@ -59,60 +58,85 @@ DatabaseListener {
this.crypto = crypto; this.crypto = crypto;
this.db = db; this.db = db;
this.executor = executor; this.executor = executor;
ivCipher = crypto.getIvCipher();
expected = new HashMap<Bytes, Context>(); expected = new HashMap<Bytes, Context>();
localTransportIds = new ArrayList<TransportId>(); localTransportIds = new ArrayList<TransportId>();
db.addListener(this); db.addListener(this);
} }
private synchronized void initialise() throws DbException { private void initialise() throws DbException {
Runtime.getRuntime().addShutdownHook(new Thread() { Runtime.getRuntime().addShutdownHook(new Thread() {
@Override @Override
public void run() { public void run() {
eraseSecrets(); eraseSecrets();
} }
}); });
for(Transport t : db.getLocalTransports()) { Collection<TransportId> ids = new ArrayList<TransportId>();
localTransportIds.add(t.getId()); for(Transport t : db.getLocalTransports()) ids.add(t.getId());
synchronized(this) {
localTransportIds.addAll(ids);
} }
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
for(ContactId c : db.getContacts()) { for(ContactId c : db.getContacts()) {
try { try {
calculateIvs(c); ivs.putAll(calculateIvs(c));
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - clean up in eventOccurred() // The contact was removed - clean up in eventOccurred()
} }
} }
initialised = true; synchronized(this) {
expected.putAll(ivs);
}
} }
private synchronized void calculateIvs(ContactId c) throws DbException { private synchronized void eraseSecrets() {
for(TransportId t : localTransportIds) { for(Context c : expected.values()) {
TransportIndex i = db.getRemoteIndex(c, t); synchronized(c.window) {
if(i != null) { for(byte[] b : c.window.getUnseen().values()) {
ConnectionWindow w = db.getConnectionWindow(c, i); ByteUtils.erase(b);
calculateIvs(c, t, i, w); }
} }
} }
} }
private synchronized void calculateIvs(ContactId c, TransportId t, private Map<Bytes, Context> calculateIvs(ContactId c) throws DbException {
TransportIndex i, ConnectionWindow w) throws DbException { Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { Collection<TransportId> ids;
long connection = e.getKey(); synchronized(this) {
byte[] secret = e.getValue(); ids = new ArrayList<TransportId>(localTransportIds);
ErasableKey ivKey = crypto.deriveIvKey(secret, true);
Bytes iv = new Bytes(encryptIv(i, connection, ivKey));
ivKey.erase();
expected.put(iv, new Context(c, t, i, connection, w));
} }
for(TransportId t : ids) {
TransportIndex i = db.getRemoteIndex(c, t);
if(i != null) {
ConnectionWindow w = db.getConnectionWindow(c, i);
ivs.putAll(calculateIvs(c, t, i, w));
}
}
return ivs;
} }
private synchronized byte[] encryptIv(TransportIndex i, long connection, private Map<Bytes, Context> calculateIvs(ContactId c, TransportId t,
ErasableKey ivKey) { TransportIndex i, ConnectionWindow w) throws DbException {
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
synchronized(w) {
for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
long connection = e.getKey();
byte[] secret = e.getValue();
Bytes iv = new Bytes(encryptIv(i, connection, secret));
ivs.put(iv, new Context(c, t, i, connection, w));
}
}
return ivs;
}
private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) {
byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection); byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection);
ErasableKey ivKey = crypto.deriveIvKey(secret, true);
try { try {
Cipher ivCipher = crypto.getIvCipher();
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
return ivCipher.doFinal(iv); byte[] encryptedIv = ivCipher.doFinal(iv);
ivKey.erase();
return encryptedIv;
} catch(BadPaddingException badCipher) { } catch(BadPaddingException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} catch(IllegalBlockSizeException badCipher) { } catch(IllegalBlockSizeException badCipher) {
@@ -122,12 +146,6 @@ DatabaseListener {
} }
} }
private synchronized void eraseSecrets() {
for(Context c : expected.values()) {
for(byte[] b : c.window.getUnseen().values()) ByteUtils.erase(b);
}
}
public void acceptConnection(final TransportId t, final byte[] encryptedIv, public void acceptConnection(final TransportId t, final byte[] encryptedIv,
final Callback callback) { final Callback callback) {
executor.execute(new Runnable() { executor.execute(new Runnable() {
@@ -137,39 +155,48 @@ DatabaseListener {
}); });
} }
private synchronized void acceptConnectionSync(TransportId t, private void acceptConnectionSync(TransportId t, byte[] encryptedIv,
byte[] encryptedIv, Callback callback) { Callback callback) {
try { try {
if(encryptedIv.length != IV_LENGTH) if(encryptedIv.length != IV_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(!initialised) initialise(); if(!initialised.getAndSet(true)) initialise();
Bytes b = new Bytes(encryptedIv); Context ctx;
Context ctx = expected.get(b); synchronized(this) {
if(ctx == null || !ctx.transportId.equals(t)) { Bytes b = new Bytes(encryptedIv);
callback.connectionRejected(); ctx = expected.get(b);
return; if(ctx == null || !ctx.transportId.equals(t)) {
callback.connectionRejected();
return;
}
expected.remove(b);
} }
// The IV was expected // The IV was expected
expected.remove(b);
ContactId c = ctx.contactId; ContactId c = ctx.contactId;
TransportIndex i = ctx.transportIndex; TransportIndex i = ctx.transportIndex;
long connection = ctx.connection; long connection = ctx.connection;
ConnectionWindow w = ctx.window; ConnectionWindow w = ctx.window;
// Get the secret and update the connection window byte[] secret;
byte[] secret = w.setSeen(connection); synchronized(w) {
try { // Get the secret and update the connection window
db.setConnectionWindow(c, i, w); secret = w.setSeen(connection);
} catch(NoSuchContactException e) { try {
// The contact was removed - clean up in eventOccurred() db.setConnectionWindow(c, i, w);
} catch(NoSuchContactException e) {
// The contact was removed - clean up in eventOccurred()
}
} }
// Update the set of expected IVs // Update the set of expected IVs
Iterator<Context> it = expected.values().iterator(); Map<Bytes, Context> ivs = calculateIvs(c, t, i, w);
while(it.hasNext()) { synchronized(this) {
Context ctx1 = it.next(); Iterator<Context> it = expected.values().iterator();
if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i)) while(it.hasNext()) {
it.remove(); Context ctx1 = it.next();
if(ctx1.contactId.equals(c)
&& ctx1.transportIndex.equals(i)) it.remove();
}
expected.putAll(ivs);
} }
calculateIvs(c, t, i, w);
callback.connectionAccepted(new ConnectionContextImpl(c, i, callback.connectionAccepted(new ConnectionContextImpl(c, i,
connection, secret)); connection, secret));
} catch(DbException e) { } catch(DbException e) {
@@ -184,28 +211,30 @@ DatabaseListener {
} else if(e instanceof TransportAddedEvent) { } else if(e instanceof TransportAddedEvent) {
// Calculate the expected IVs for the new transport // Calculate the expected IVs for the new transport
TransportId t = ((TransportAddedEvent) e).getTransportId(); TransportId t = ((TransportAddedEvent) e).getTransportId();
synchronized(this) { try {
if(!initialised) return; if(!initialised.getAndSet(true)) initialise();
try { Map<Bytes, Context> ivs = calculateIvs(t);
synchronized(this) {
localTransportIds.add(t); localTransportIds.add(t);
calculateIvs(t); expected.putAll(ivs);
} catch(DbException e1) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e1.getMessage());
} }
} catch(DbException e1) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e1.getMessage());
} }
} else if(e instanceof RemoteTransportsUpdatedEvent) { } else if(e instanceof RemoteTransportsUpdatedEvent) {
// Remove and recalculate the expected IVs for the contact // Remove and recalculate the expected IVs for the contact
ContactId c = ((RemoteTransportsUpdatedEvent) e).getContactId(); ContactId c = ((RemoteTransportsUpdatedEvent) e).getContactId();
synchronized(this) { try {
if(!initialised) return; if(!initialised.getAndSet(true)) initialise();
removeIvs(c); Map<Bytes, Context> ivs = calculateIvs(c);
try { synchronized(this) {
calculateIvs(c); removeIvs(c);
} catch(DbException e1) { expected.putAll(ivs);
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e1.getMessage());
} }
} catch(DbException e1) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e1.getMessage());
} }
} }
} }
@@ -215,18 +244,20 @@ DatabaseListener {
while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove(); while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
} }
private synchronized void calculateIvs(TransportId t) throws DbException { private Map<Bytes, Context> calculateIvs(TransportId t) throws DbException {
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
for(ContactId c : db.getContacts()) { for(ContactId c : db.getContacts()) {
try { try {
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i != null) { if(i != null) {
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
calculateIvs(c, t, i, w); ivs.putAll(calculateIvs(c, t, i, w));
} }
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - clean up when we get the event // The contact was removed - clean up in eventOccurred()
} }
} }
return ivs;
} }
private static class Context { private static class Context {