Don't keep connection windows in memory.

This commit is contained in:
akwizgran
2011-11-24 13:56:58 +00:00
parent 98148085b6
commit 8068fa0d38
2 changed files with 78 additions and 80 deletions

View File

@@ -6,8 +6,11 @@ import java.security.InvalidKeyException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -28,13 +31,13 @@ import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener; import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent; import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent;
import net.sf.briar.api.db.event.TransportAddedEvent; import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -47,7 +50,6 @@ DatabaseListener {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final DatabaseComponent db; private final DatabaseComponent db;
private final Executor executor; private final Executor executor;
private final ShutdownManager shutdown;
private final Cipher ivCipher; // Locking: this private final Cipher ivCipher; // Locking: this
private final Map<Bytes, Context> expected; // Locking: this private final Map<Bytes, Context> expected; // Locking: this
@@ -55,65 +57,50 @@ DatabaseListener {
@Inject @Inject
ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db, ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db,
Executor executor, ShutdownManager shutdown) { Executor executor) {
this.crypto = crypto; this.crypto = crypto;
this.db = db; this.db = db;
this.executor = executor; this.executor = executor;
this.shutdown = shutdown;
ivCipher = crypto.getIvCipher(); ivCipher = crypto.getIvCipher();
expected = new HashMap<Bytes, Context>(); expected = new HashMap<Bytes, Context>();
db.addListener(this);
} }
// Locking: this // Locking: this
private void initialise() throws DbException { private void initialise() throws DbException {
assert !initialised; assert !initialised;
shutdown.addShutdownHook(new Runnable() { db.addListener(this);
public void run() { Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
eraseSecrets();
}
});
Collection<TransportId> transports = new ArrayList<TransportId>(); Collection<TransportId> transports = new ArrayList<TransportId>();
for(Transport t : db.getLocalTransports()) transports.add(t.getId()); for(Transport t : db.getLocalTransports()) transports.add(t.getId());
for(ContactId c : db.getContacts()) { for(ContactId c : db.getContacts()) {
Collection<Context> contexts = new ArrayList<Context>();
try { try {
for(TransportId t : transports) { for(TransportId t : transports) {
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i == null) continue; if(i == null) continue;
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
contexts.add(new Context(c, t, i, unseen, w)); Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateIv(ctx, e.getValue()), ctx);
} }
w.erase();
} }
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - don't add the IVs // The contact was removed - clean up in removeContact()
for(Context ctx : contexts) ctx.window.erase();
continue; continue;
} }
for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx);
} }
expected.putAll(ivs);
initialised = true; initialised = true;
} }
private synchronized void eraseSecrets() {
for(Context c : expected.values()) c.window.erase();
}
// Locking: this // Locking: this
private Bytes calculateIv(Context ctx) { private Bytes calculateIv(Context ctx, byte[] secret) {
byte[] secret = ctx.window.getUnseen().get(ctx.connection); byte[] iv = IvEncoder.encodeIv(true, ctx.transportIndex.getInt(),
byte[] iv = encryptIv(ctx.transportIndex, ctx.connection, secret); ctx.connection);
return new Bytes(iv);
}
// Locking: this
private byte[] encryptIv(TransportIndex i, long connection, byte[] secret) {
byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection);
ErasableKey ivKey = crypto.deriveIvKey(secret, true); ErasableKey ivKey = crypto.deriveIvKey(secret, true);
try { try {
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
return ivCipher.doFinal(iv); return new Bytes(ivCipher.doFinal(iv));
} catch(BadPaddingException badCipher) { } catch(BadPaddingException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} catch(IllegalBlockSizeException badCipher) { } catch(IllegalBlockSizeException badCipher) {
@@ -154,15 +141,17 @@ DatabaseListener {
ContactId c = ctx.contactId; ContactId c = ctx.contactId;
TransportIndex i = ctx.transportIndex; TransportIndex i = ctx.transportIndex;
long connection = ctx.connection; long connection = ctx.connection;
ConnectionWindow w = ctx.window; ConnectionWindow w = null;
byte[] secret = null;
// Get the secret and update the connection window // Get the secret and update the connection window
byte[] secret = w.setSeen(connection);
try { try {
w = db.getConnectionWindow(c, i);
secret = w.setSeen(connection);
db.setConnectionWindow(c, i, w); db.setConnectionWindow(c, i, w);
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - reject the connection // The contact was removed - reject the connection
removeContact(c); if(w != null) w.erase();
w.erase(); if(secret != null) ByteUtils.erase(secret);
return null; return null;
} }
// Update the connection window's expected IVs // Update the connection window's expected IVs
@@ -172,26 +161,15 @@ DatabaseListener {
if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i)) if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i))
it.remove(); it.remove();
} }
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
Context ctx1 = new Context(c, t, i, unseen, w); Context ctx1 = new Context(c, t, i, e.getKey());
expected.put(calculateIv(ctx1), ctx1); expected.put(calculateIv(ctx1, e.getValue()), ctx1);
} }
w.erase();
return new ConnectionContextImpl(c, i, connection, secret); return new ConnectionContextImpl(c, i, connection, secret);
} }
} }
private synchronized void removeContact(ContactId c) {
if(!initialised) return;
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) {
Context ctx = it.next();
if(ctx.contactId.equals(c)) {
ctx.window.erase();
it.remove();
}
}
}
public void eventOccurred(DatabaseEvent e) { public void eventOccurred(DatabaseEvent e) {
if(e instanceof ContactRemovedEvent) { if(e instanceof ContactRemovedEvent) {
// Remove the expected IVs for the ex-contact // Remove the expected IVs for the ex-contact
@@ -210,7 +188,7 @@ DatabaseListener {
} }
}); });
} else if(e instanceof RemoteTransportsUpdatedEvent) { } else if(e instanceof RemoteTransportsUpdatedEvent) {
// Recalculate the expected IVs for the contact // Update the expected IVs for the contact
final ContactId c = final ContactId c =
((RemoteTransportsUpdatedEvent) e).getContactId(); ((RemoteTransportsUpdatedEvent) e).getContactId();
executor.execute(new Runnable() { executor.execute(new Runnable() {
@@ -221,52 +199,79 @@ DatabaseListener {
} }
} }
private synchronized void removeContact(ContactId c) {
if(!initialised) return;
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
}
private synchronized void addTransport(TransportId t) { private synchronized void addTransport(TransportId t) {
if(!initialised) return; if(!initialised) return;
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
try { try {
for(ContactId c : db.getContacts()) { for(ContactId c : db.getContacts()) {
Collection<Context> contexts = new ArrayList<Context>();
try { try {
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i == null) continue; if(i == null) continue;
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
contexts.add(new Context(c, t, i, unseen, w)); Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateIv(ctx, e.getValue()), ctx);
} }
w.erase();
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - don't add the IVs // The contact was removed - clean up in removeContact()
for(Context ctx : contexts) ctx.window.erase();
continue; continue;
} }
for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx);
} }
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
return;
} }
expected.putAll(ivs);
} }
private synchronized void updateContact(ContactId c) { private synchronized void updateContact(ContactId c) {
if(!initialised) return; if(!initialised) return;
removeContact(c); // Don't recalculate IVs for transports that are already known
Set<TransportIndex> known = new HashSet<TransportIndex>();
for(Context ctx : expected.values()) {
if(ctx.contactId.equals(c)) known.add(ctx.transportIndex);
}
Set<TransportIndex> current = new HashSet<TransportIndex>();
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
try { try {
Collection<Context> contexts = new ArrayList<Context>(); for(Transport transport : db.getLocalTransports()) {
try { TransportId t = transport.getId();
for(Transport transport : db.getLocalTransports()) { TransportIndex i = db.getRemoteIndex(c, t);
TransportId t = transport.getId(); if(i == null) continue;
TransportIndex i = db.getRemoteIndex(c, t); current.add(i);
// If the transport is not already known, calculate the IVs
if(!known.contains(i)) {
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
for(long unseen : w.getUnseen().keySet()) { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
contexts.add(new Context(c, t, i, unseen, w)); Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateIv(ctx, e.getValue()), ctx);
} }
w.erase();
} }
} catch(NoSuchContactException e) {
// The contact was removed - don't add the IVs
return;
} }
for(Context ctx : contexts) expected.put(calculateIv(ctx), ctx); } catch(NoSuchContactException e) {
// The contact was removed - clean up in removeContact()
return;
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
return;
} }
// Remove any IVs that are no longer current
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) {
Context ctx = it.next();
if(ctx.contactId.equals(c) && !current.contains(ctx.transportIndex))
it.remove();
}
// Add any IVs that were not previously known
expected.putAll(ivs);
} }
private static class Context { private static class Context {
@@ -275,17 +280,13 @@ DatabaseListener {
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex transportIndex; private final TransportIndex transportIndex;
private final long connection; private final long connection;
// Locking: ConnectionRecogniser.this
private final ConnectionWindow window;
private Context(ContactId contactId, TransportId transportId, private Context(ContactId contactId, TransportId transportId,
TransportIndex transportIndex, long connection, TransportIndex transportIndex, long connection) {
ConnectionWindow window) {
this.contactId = contactId; this.contactId = contactId;
this.transportId = transportId; this.transportId = transportId;
this.transportIndex = transportIndex; this.transportIndex = transportIndex;
this.connection = connection; this.connection = connection;
this.window = window;
} }
} }
} }

View File

@@ -17,7 +17,6 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
@@ -66,11 +65,9 @@ public class ConnectionRecogniserImplTest extends TestCase {
public void testUnexpectedIv() throws Exception { public void testUnexpectedIv() throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class); final DatabaseComponent db = context.mock(DatabaseComponent.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class))); oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
// Initialise // Initialise
oneOf(shutdown).addShutdownHook(with(any(Runnable.class)));
oneOf(db).getLocalTransports(); oneOf(db).getLocalTransports();
will(returnValue(transports)); will(returnValue(transports));
oneOf(db).getContacts(); oneOf(db).getContacts();
@@ -82,7 +79,7 @@ public class ConnectionRecogniserImplTest extends TestCase {
}}); }});
Executor executor = new ImmediateExecutor(); Executor executor = new ImmediateExecutor();
ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db, ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db,
executor, shutdown); executor);
c.acceptConnection(transportId, new byte[IV_LENGTH], new Callback() { c.acceptConnection(transportId, new byte[IV_LENGTH], new Callback() {
public void connectionAccepted(ConnectionContext ctx) { public void connectionAccepted(ConnectionContext ctx) {
@@ -116,11 +113,9 @@ public class ConnectionRecogniserImplTest extends TestCase {
Mockery context = new Mockery(); Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class); final DatabaseComponent db = context.mock(DatabaseComponent.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class))); oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
// Initialise // Initialise
oneOf(shutdown).addShutdownHook(with(any(Runnable.class)));
oneOf(db).getLocalTransports(); oneOf(db).getLocalTransports();
will(returnValue(transports)); will(returnValue(transports));
oneOf(db).getContacts(); oneOf(db).getContacts();
@@ -130,12 +125,14 @@ public class ConnectionRecogniserImplTest extends TestCase {
oneOf(db).getConnectionWindow(contactId, remoteIndex); oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(connectionWindow)); will(returnValue(connectionWindow));
// Update the window // Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(connectionWindow));
oneOf(db).setConnectionWindow(contactId, remoteIndex, oneOf(db).setConnectionWindow(contactId, remoteIndex,
connectionWindow); connectionWindow);
}}); }});
Executor executor = new ImmediateExecutor(); Executor executor = new ImmediateExecutor();
ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db, ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db,
executor, shutdown); executor);
// The IV should not be expected by the wrong transport // The IV should not be expected by the wrong transport
TransportId wrong = new TransportId(TestUtils.getRandomId()); TransportId wrong = new TransportId(TestUtils.getRandomId());
c.acceptConnection(wrong, encryptedIv, new Callback() { c.acceptConnection(wrong, encryptedIv, new Callback() {