Fixed the locking in ConnectionRecogniserImpl. Again.

Database calls are made outside the lock, with the exception of
{get,set}ConnectionWindow(), which seems to be unavoidable if we want
to ensure replay protection within and across sessions.
This commit is contained in:
akwizgran
2011-11-18 14:16:51 +00:00
parent dacaa4566d
commit a349a3f1ea

View File

@@ -8,7 +8,6 @@ import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level; import java.util.logging.Level;
@@ -33,6 +32,7 @@ import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
@@ -48,8 +48,8 @@ DatabaseListener {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final DatabaseComponent db; private final DatabaseComponent db;
private final Executor executor; private final Executor executor;
private final Cipher ivCipher; // Locking: this
private final Map<Bytes, Context> expected; // Locking: this private final Map<Bytes, Context> expected; // Locking: this
private final Collection<TransportId> localTransportIds; // Locking: this
private final AtomicBoolean initialised = new AtomicBoolean(false); private final AtomicBoolean initialised = new AtomicBoolean(false);
@Inject @Inject
@@ -58,91 +58,90 @@ DatabaseListener {
this.crypto = crypto; this.crypto = crypto;
this.db = db; this.db = db;
this.executor = executor; this.executor = executor;
ivCipher = crypto.getIvCipher();
expected = new HashMap<Bytes, Context>(); expected = new HashMap<Bytes, Context>();
localTransportIds = new ArrayList<TransportId>();
db.addListener(this); db.addListener(this);
}
private void initialise() throws DbException {
Runtime.getRuntime().addShutdownHook(new Thread() { Runtime.getRuntime().addShutdownHook(new Thread() {
@Override @Override
public void run() { public void run() {
eraseSecrets(); eraseSecrets();
} }
}); });
Collection<TransportId> ids = new ArrayList<TransportId>();
for(Transport t : db.getLocalTransports()) ids.add(t.getId());
synchronized(this) {
localTransportIds.addAll(ids);
}
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
for(ContactId c : db.getContacts()) {
try {
ivs.putAll(calculateIvs(c));
} catch(NoSuchContactException e) {
// The contact was removed - clean up in eventOccurred()
}
}
synchronized(this) {
expected.putAll(ivs);
}
} }
private synchronized void eraseSecrets() { private synchronized void eraseSecrets() {
for(Context c : expected.values()) { for(Context c : expected.values()) {
synchronized(c.window) { for(byte[] b : c.window.getUnseen().values()) ByteUtils.erase(b);
for(byte[] b : c.window.getUnseen().values()) {
ByteUtils.erase(b);
}
}
} }
} }
private Map<Bytes, Context> calculateIvs(ContactId c) throws DbException { private void initialise() throws DbException {
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); // Fill in the contexts as far as possible outside the lock
Collection<TransportId> ids; Collection<Context> partial = new ArrayList<Context>();
synchronized(this) { Collection<Transport> transports = db.getLocalTransports();
ids = new ArrayList<TransportId>(localTransportIds); for(ContactId c : db.getContacts()) {
for(Transport transport : transports) {
getPartialContexts(c, transport.getId(), partial);
}
} }
for(TransportId t : ids) { synchronized(this) {
// Complete the contexts and calculate the expected IVs
calculateIvs(completeContexts(partial));
}
}
private void getPartialContexts(ContactId c, TransportId t,
Collection<Context> partial) throws DbException {
try {
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i != null) { if(i != null) {
ConnectionWindow w = db.getConnectionWindow(c, i); // Acquire the lock to avoid getting stale data
ivs.putAll(calculateIvs(c, t, i, w)); synchronized(this) {
ConnectionWindow w = db.getConnectionWindow(c, i);
partial.add(new Context(c, t, i, -1, w));
}
} }
} catch(NoSuchContactException e) {
// The contact was removed - we'll handle the event later
} }
return ivs;
} }
private Map<Bytes, Context> calculateIvs(ContactId c, TransportId t, // Locking: this
TransportIndex i, ConnectionWindow w) throws DbException { private Collection<Context> completeContexts(Collection<Context> partial) {
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); Collection<Context> contexts = new ArrayList<Context>();
synchronized(w) { for(Context ctx : partial) {
for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) { for(long unseen : ctx.window.getUnseen().keySet()) {
long connection = e.getKey(); contexts.add(new Context(ctx.contactId, ctx.transportId,
byte[] secret = e.getValue(); ctx.transportIndex, unseen, ctx.window));
Bytes iv = new Bytes(encryptIv(i, connection, secret));
ivs.put(iv, new Context(c, t, i, connection, w));
} }
} }
return ivs; return contexts;
} }
// Locking: this
private void calculateIvs(Collection<Context> contexts) {
for(Context ctx : contexts) {
byte[] secret = ctx.window.getUnseen().get(ctx.connection);
byte[] iv = encryptIv(ctx.transportIndex, ctx.connection, secret);
expected.put(new Bytes(iv), ctx);
}
}
// Locking: this
private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) { private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) {
byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection); byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection);
ErasableKey ivKey = crypto.deriveIvKey(secret, true); ErasableKey ivKey = crypto.deriveIvKey(secret, true);
try { try {
Cipher ivCipher = crypto.getIvCipher();
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] encryptedIv = ivCipher.doFinal(iv); return ivCipher.doFinal(iv);
ivKey.erase();
return encryptedIv;
} catch(BadPaddingException badCipher) { } catch(BadPaddingException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} catch(IllegalBlockSizeException badCipher) { } catch(IllegalBlockSizeException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} catch(InvalidKeyException badKey) { } catch(InvalidKeyException badKey) {
throw new RuntimeException(badKey); throw new RuntimeException(badKey);
} finally {
ivKey.erase();
} }
} }
@@ -150,92 +149,82 @@ DatabaseListener {
final Callback callback) { final Callback callback) {
executor.execute(new Runnable() { executor.execute(new Runnable() {
public void run() { public void run() {
acceptConnectionSync(t, encryptedIv, callback); try {
ConnectionContext ctx = acceptConnection(t, encryptedIv);
if(ctx == null) callback.connectionRejected();
else callback.connectionAccepted(ctx);
} catch(DbException e) {
callback.handleException(e);
}
} }
}); });
} }
private void acceptConnectionSync(TransportId t, byte[] encryptedIv, private ConnectionContext acceptConnection(TransportId t,
Callback callback) { byte[] encryptedIv) throws DbException {
try { if(encryptedIv.length != IV_LENGTH)
if(encryptedIv.length != IV_LENGTH) throw new IllegalArgumentException();
throw new IllegalArgumentException(); if(!initialised.getAndSet(true)) initialise();
if(!initialised.getAndSet(true)) initialise(); synchronized(this) {
Context ctx; Bytes b = new Bytes(encryptedIv);
synchronized(this) { Context ctx = expected.get(b);
Bytes b = new Bytes(encryptedIv); if(ctx == null || !ctx.transportId.equals(t)) return null;
ctx = expected.get(b);
if(ctx == null || !ctx.transportId.equals(t)) {
callback.connectionRejected();
return;
}
expected.remove(b);
}
// The IV was expected // The IV was expected
expected.remove(b);
ContactId c = ctx.contactId; ContactId c = ctx.contactId;
TransportIndex i = ctx.transportIndex; TransportIndex i = ctx.transportIndex;
long connection = ctx.connection; long connection = ctx.connection;
ConnectionWindow w = ctx.window; ConnectionWindow w = ctx.window;
byte[] secret; byte[] secret;
synchronized(w) { // Get the secret and update the connection window
// Get the secret and update the connection window try {
secret = w.setSeen(connection); db.setConnectionWindow(c, i, w);
try { } catch(NoSuchContactException e) {
db.setConnectionWindow(c, i, w); // The contact was removed - we'll handle the event later
} catch(NoSuchContactException e) {
// The contact was removed - clean up in eventOccurred()
}
} }
// Update the set of expected IVs secret = w.setSeen(connection);
Map<Bytes, Context> ivs = calculateIvs(c, t, i, w); // Update the connection window's expected IVs
synchronized(this) { Iterator<Context> it = expected.values().iterator();
Iterator<Context> it = expected.values().iterator(); while(it.hasNext()) {
while(it.hasNext()) { Context ctx1 = it.next();
Context ctx1 = it.next(); if(ctx1.contactId.equals(c)
if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i)) it.remove();
&& ctx1.transportIndex.equals(i)) it.remove();
}
expected.putAll(ivs);
} }
callback.connectionAccepted(new ConnectionContextImpl(c, i, Collection<Context> contexts = new ArrayList<Context>();
connection, secret)); for(long unseen : w.getUnseen().keySet()) {
} catch(DbException e) { contexts.add(new Context(c, t, i, unseen, w));
callback.handleException(e); }
calculateIvs(contexts);
return new ConnectionContextImpl(c, i, connection, secret);
} }
} }
public void eventOccurred(DatabaseEvent e) { public void eventOccurred(DatabaseEvent e) {
if(e instanceof ContactRemovedEvent) { if(e instanceof ContactRemovedEvent) {
// Remove the expected IVs for the ex-contact // Remove the expected IVs for the ex-contact
removeIvs(((ContactRemovedEvent) e).getContactId()); final ContactId c = ((ContactRemovedEvent) e).getContactId();
} else if(e instanceof TransportAddedEvent) { executor.execute(new Runnable() {
// Calculate the expected IVs for the new transport public void run() {
TransportId t = ((TransportAddedEvent) e).getTransportId();
try {
if(!initialised.getAndSet(true)) initialise();
Map<Bytes, Context> ivs = calculateIvs(t);
synchronized(this) {
localTransportIds.add(t);
expected.putAll(ivs);
}
} catch(DbException e1) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e1.getMessage());
}
} else if(e instanceof RemoteTransportsUpdatedEvent) {
// Remove and recalculate the expected IVs for the contact
ContactId c = ((RemoteTransportsUpdatedEvent) e).getContactId();
try {
if(!initialised.getAndSet(true)) initialise();
Map<Bytes, Context> ivs = calculateIvs(c);
synchronized(this) {
removeIvs(c); removeIvs(c);
expected.putAll(ivs);
} }
} catch(DbException e1) { });
if(LOG.isLoggable(Level.WARNING)) } else if(e instanceof TransportAddedEvent) {
LOG.warning(e1.getMessage()); // Add the expected IVs for the new transport
} final TransportId t = ((TransportAddedEvent) e).getTransportId();
executor.execute(new Runnable() {
public void run() {
addTransport(t);
}
});
} else if(e instanceof RemoteTransportsUpdatedEvent) {
// Recalculate the expected IVs for the contact
final ContactId c =
((RemoteTransportsUpdatedEvent) e).getContactId();
executor.execute(new Runnable() {
public void run() {
updateContact(c);
}
});
} }
} }
@@ -244,20 +233,42 @@ DatabaseListener {
while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove(); while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
} }
private Map<Bytes, Context> calculateIvs(TransportId t) throws DbException { private void addTransport(TransportId t) {
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>(); // Fill in the contexts as far as possible outside the lock
for(ContactId c : db.getContacts()) { Collection<Context> partial = new ArrayList<Context>();
try { try {
TransportIndex i = db.getRemoteIndex(c, t); for(ContactId c : db.getContacts()) {
if(i != null) { getPartialContexts(c, t, partial);
ConnectionWindow w = db.getConnectionWindow(c, i);
ivs.putAll(calculateIvs(c, t, i, w));
}
} catch(NoSuchContactException e) {
// The contact was removed - clean up in eventOccurred()
} }
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
return;
}
synchronized(this) {
// Complete the contexts and calculate the expected IVs
calculateIvs(completeContexts(partial));
}
}
private void updateContact(ContactId c) {
// Fill in the contexts as far as possible outside the lock
Collection<Context> partial = new ArrayList<Context>();
try {
Collection<Transport> transports = db.getLocalTransports();
for(Transport transport : transports) {
getPartialContexts(c, transport.getId(), partial);
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
return;
}
synchronized(this) {
// Clear the contact's existing IVs
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
// Complete the contexts and calculate the expected IVs
calculateIvs(completeContexts(partial));
} }
return ivs;
} }
private static class Context { private static class Context {
@@ -266,6 +277,7 @@ DatabaseListener {
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex transportIndex; private final TransportIndex transportIndex;
private final long connection; private final long connection;
// Locking: ConnectionRecogniser.this
private final ConnectionWindow window; private final ConnectionWindow window;
private Context(ContactId contactId, TransportId transportId, private Context(ContactId contactId, TransportId transportId,