Accept connections asynchronously.

This commit is contained in:
akwizgran
2011-11-17 18:59:34 +00:00
parent 27a3f6e497
commit 2b45cf0dd1
10 changed files with 273 additions and 106 deletions

View File

@@ -10,9 +10,18 @@ import net.sf.briar.api.protocol.TransportId;
public interface ConnectionRecogniser { public interface ConnectionRecogniser {
/** /**
* Returns the connection's context if the connection should be accepted, * Asynchronously calls one of the callback's connectionAccepted(),
* or null if the connection should be rejected. * connectionRejected() or handleException() methods.
*/ */
ConnectionContext acceptConnection(TransportId t, byte[] encryptedIv) void acceptConnection(TransportId t, byte[] encryptedIv,
throws DbException; Callback c);
interface Callback {
void connectionAccepted(ConnectionContext ctx);
void connectionRejected();
void handleException(DbException e);
}
} }

View File

@@ -8,7 +8,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -51,24 +50,22 @@ class PluginManagerImpl implements PluginManager {
"net.sf.briar.plugins.socket.SimpleSocketPluginFactory" "net.sf.briar.plugins.socket.SimpleSocketPluginFactory"
}; };
private static final int THREAD_POOL_SIZE = 5;
private final DatabaseComponent db; private final DatabaseComponent db;
private final Executor executor;
private final Poller poller; private final Poller poller;
private final ConnectionDispatcher dispatcher; private final ConnectionDispatcher dispatcher;
private final UiCallback uiCallback; private final UiCallback uiCallback;
private final Executor executor;
private final List<BatchPlugin> batchPlugins; private final List<BatchPlugin> batchPlugins;
private final List<StreamPlugin> streamPlugins; private final List<StreamPlugin> streamPlugins;
@Inject @Inject
PluginManagerImpl(DatabaseComponent db, Poller poller, PluginManagerImpl(DatabaseComponent db, Executor executor, Poller poller,
ConnectionDispatcher dispatcher, UiCallback uiCallback) { ConnectionDispatcher dispatcher, UiCallback uiCallback) {
this.db = db; this.db = db;
this.executor = executor;
this.poller = poller; this.poller = poller;
this.dispatcher = dispatcher; this.dispatcher = dispatcher;
this.uiCallback = uiCallback; this.uiCallback = uiCallback;
executor = new ScheduledThreadPoolExecutor(THREAD_POOL_SIZE);
batchPlugins = new ArrayList<BatchPlugin>(); batchPlugins = new ArrayList<BatchPlugin>();
streamPlugins = new ArrayList<StreamPlugin>(); streamPlugins = new ArrayList<StreamPlugin>();
} }

View File

@@ -15,6 +15,7 @@ import net.sf.briar.api.transport.BatchTransportWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionDispatcher; import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRecogniser.Callback;
import net.sf.briar.api.transport.StreamConnectionFactory; import net.sf.briar.api.transport.StreamConnectionFactory;
import net.sf.briar.api.transport.StreamTransportConnection; import net.sf.briar.api.transport.StreamTransportConnection;
import net.sf.briar.api.transport.TransportConstants; import net.sf.briar.api.transport.TransportConstants;
@@ -39,9 +40,9 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
this.streamConnFactory = streamConnFactory; this.streamConnFactory = streamConnFactory;
} }
public void dispatchReader(TransportId t, BatchTransportReader r) { public void dispatchReader(TransportId t, final BatchTransportReader r) {
// Read the encrypted IV // Read the encrypted IV
byte[] encryptedIv; final byte[] encryptedIv;
try { try {
encryptedIv = readIv(r.getInputStream()); encryptedIv = readIv(r.getInputStream());
} catch(IOException e) { } catch(IOException e) {
@@ -49,20 +50,22 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
r.dispose(false); r.dispose(false);
return; return;
} }
// Get the connection context, or null if the IV wasn't expected // Get the connection context asynchronously
ConnectionContext ctx; recogniser.acceptConnection(t, encryptedIv, new Callback() {
try {
ctx = recogniser.acceptConnection(t, encryptedIv); public void connectionAccepted(ConnectionContext ctx) {
} catch(DbException e) { batchConnFactory.createIncomingConnection(ctx, r, encryptedIv);
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); }
r.dispose(false);
return; public void connectionRejected() {
} r.dispose(false);
if(ctx == null) { }
r.dispose(false);
return; public void handleException(DbException e) {
} if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
batchConnFactory.createIncomingConnection(ctx, r, encryptedIv); r.dispose(false);
}
});
} }
private byte[] readIv(InputStream in) throws IOException { private byte[] readIv(InputStream in) throws IOException {
@@ -82,9 +85,9 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
} }
public void dispatchIncomingConnection(TransportId t, public void dispatchIncomingConnection(TransportId t,
StreamTransportConnection s) { final StreamTransportConnection s) {
// Read the encrypted IV // Read the encrypted IV
byte[] encryptedIv; final byte[] encryptedIv;
try { try {
encryptedIv = readIv(s.getInputStream()); encryptedIv = readIv(s.getInputStream());
} catch(IOException e) { } catch(IOException e) {
@@ -92,20 +95,22 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
s.dispose(false); s.dispose(false);
return; return;
} }
// Get the connection context, or null if the IV wasn't expected // Get the connection context asynchronously
ConnectionContext ctx; recogniser.acceptConnection(t, encryptedIv, new Callback() {
try {
ctx = recogniser.acceptConnection(t, encryptedIv); public void connectionAccepted(ConnectionContext ctx) {
} catch(DbException e) { streamConnFactory.createIncomingConnection(ctx, s, encryptedIv);
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); }
s.dispose(false);
return; public void connectionRejected() {
} s.dispose(false);
if(ctx == null) { }
s.dispose(false);
return; public void handleException(DbException e) {
} if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
streamConnFactory.createIncomingConnection(ctx, s, encryptedIv); s.dispose(false);
}
});
} }
public void dispatchOutgoingConnection(ContactId c, TransportIndex i, public void dispatchOutgoingConnection(ContactId c, TransportIndex i,

View File

@@ -9,6 +9,7 @@ import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -31,7 +32,6 @@ import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
@@ -46,6 +46,7 @@ DatabaseListener {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final DatabaseComponent db; private final DatabaseComponent db;
private final Executor executor;
private final Cipher ivCipher; private final Cipher ivCipher;
private final Map<Bytes, Context> expected; private final Map<Bytes, Context> expected;
private final Collection<TransportId> localTransportIds; private final Collection<TransportId> localTransportIds;
@@ -53,9 +54,11 @@ DatabaseListener {
private boolean initialised = false; private boolean initialised = false;
@Inject @Inject
ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db) { ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db,
Executor executor) {
this.crypto = crypto; this.crypto = crypto;
this.db = db; this.db = db;
this.executor = executor;
ivCipher = crypto.getIvCipher(); ivCipher = crypto.getIvCipher();
expected = new HashMap<Bytes, Context>(); expected = new HashMap<Bytes, Context>();
localTransportIds = new ArrayList<TransportId>(); localTransportIds = new ArrayList<TransportId>();
@@ -63,6 +66,12 @@ DatabaseListener {
} }
private synchronized void initialise() throws DbException { private synchronized void initialise() throws DbException {
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override
public void run() {
eraseSecrets();
}
});
for(Transport t : db.getLocalTransports()) { for(Transport t : db.getLocalTransports()) {
localTransportIds.add(t.getId()); localTransportIds.add(t.getId());
} }
@@ -73,12 +82,6 @@ DatabaseListener {
// The contact was removed - clean up in eventOccurred() // The contact was removed - clean up in eventOccurred()
} }
} }
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override
public void run() {
eraseSecrets();
}
});
initialised = true; initialised = true;
} }
@@ -125,36 +128,53 @@ DatabaseListener {
} }
} }
public synchronized ConnectionContext acceptConnection(TransportId t, public void acceptConnection(final TransportId t, final byte[] encryptedIv,
byte[] encryptedIv) throws DbException { final Callback callback) {
if(encryptedIv.length != IV_LENGTH) executor.execute(new Runnable() {
throw new IllegalArgumentException(); public void run() {
if(!initialised) initialise(); acceptConnectionSync(t, encryptedIv, callback);
Bytes b = new Bytes(encryptedIv); }
Context ctx = expected.get(b); });
// If the IV was not expected by this transport, reject the connection }
if(ctx == null || !ctx.transportId.equals(t)) return null;
expected.remove(b); private synchronized void acceptConnectionSync(TransportId t,
ContactId c = ctx.contactId; byte[] encryptedIv, Callback callback) {
TransportIndex i = ctx.transportIndex;
long connection = ctx.connection;
ConnectionWindow w = ctx.window;
// Get the secret and update the connection window
byte[] secret = w.setSeen(connection);
try { try {
db.setConnectionWindow(c, i, w); if(encryptedIv.length != IV_LENGTH)
} catch(NoSuchContactException e) { throw new IllegalArgumentException();
// The contact was removed - clean up when we get the event if(!initialised) initialise();
Bytes b = new Bytes(encryptedIv);
Context ctx = expected.get(b);
if(ctx == null || !ctx.transportId.equals(t)) {
callback.connectionRejected();
return;
}
// The IV was expected
expected.remove(b);
ContactId c = ctx.contactId;
TransportIndex i = ctx.transportIndex;
long connection = ctx.connection;
ConnectionWindow w = ctx.window;
// Get the secret and update the connection window
byte[] secret = w.setSeen(connection);
try {
db.setConnectionWindow(c, i, w);
} catch(NoSuchContactException e) {
// The contact was removed - clean up in eventOccurred()
}
// Update the set of expected IVs
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) {
Context ctx1 = it.next();
if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i))
it.remove();
}
calculateIvs(c, t, i, w);
callback.connectionAccepted(new ConnectionContextImpl(c, i,
connection, secret));
} catch(DbException e) {
callback.handleException(e);
} }
// Update the set of expected IVs
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) {
Context ctx1 = it.next();
if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i))
it.remove();
}
calculateIvs(c, t, i, w);
return new ConnectionContextImpl(c, i, connection, secret);
} }
public void eventOccurred(DatabaseEvent e) { public void eventOccurred(DatabaseEvent e) {

View File

@@ -14,6 +14,8 @@ import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
@@ -62,8 +64,10 @@ import net.sf.briar.transport.stream.TransportStreamModule;
import org.bouncycastle.util.Arrays; import org.bouncycastle.util.Arrays;
import org.junit.Test; import org.junit.Test;
import com.google.inject.AbstractModule;
import com.google.inject.Guice; import com.google.inject.Guice;
import com.google.inject.Injector; import com.google.inject.Injector;
import com.google.inject.Module;
public class ProtocolIntegrationTest extends TestCase { public class ProtocolIntegrationTest extends TestCase {
@@ -90,7 +94,14 @@ public class ProtocolIntegrationTest extends TestCase {
public ProtocolIntegrationTest() throws Exception { public ProtocolIntegrationTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Module testModule = new AbstractModule() {
@Override
public void configure() {
bind(Executor.class).toInstance(
new ScheduledThreadPoolExecutor(5));
}
};
Injector i = Guice.createInjector(testModule, new CryptoModule(),
new DatabaseModule(), new ProtocolModule(), new DatabaseModule(), new ProtocolModule(),
new ProtocolWritersModule(), new SerialModule(), new ProtocolWritersModule(), new SerialModule(),
new TestDatabaseModule(), new TransportBatchModule(), new TestDatabaseModule(), new TransportBatchModule(),

View File

@@ -13,6 +13,8 @@ import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@@ -54,8 +56,10 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import com.google.inject.AbstractModule;
import com.google.inject.Guice; import com.google.inject.Guice;
import com.google.inject.Injector; import com.google.inject.Injector;
import com.google.inject.Module;
public class H2DatabaseTest extends TestCase { public class H2DatabaseTest extends TestCase {
@@ -93,7 +97,14 @@ public class H2DatabaseTest extends TestCase {
public H2DatabaseTest() throws Exception { public H2DatabaseTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Module testModule = new AbstractModule() {
@Override
public void configure() {
bind(Executor.class).toInstance(
new ScheduledThreadPoolExecutor(5));
}
};
Injector i = Guice.createInjector(testModule, new CryptoModule(),
new DatabaseModule(), new ProtocolModule(), new DatabaseModule(), new ProtocolModule(),
new ProtocolWritersModule(), new SerialModule(), new ProtocolWritersModule(), new SerialModule(),
new TransportBatchModule(), new TransportModule(), new TransportBatchModule(), new TransportModule(),

View File

@@ -1,5 +1,6 @@
package net.sf.briar.plugins; package net.sf.briar.plugins;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import junit.framework.TestCase; import junit.framework.TestCase;
@@ -36,9 +37,10 @@ public class PluginManagerImplTest extends TestCase {
allowing(db).setLocalProperties(with(any(TransportId.class)), allowing(db).setLocalProperties(with(any(TransportId.class)),
with(any(TransportProperties.class))); with(any(TransportProperties.class)));
}}); }});
Executor executor = new ImmediateExecutor();
Poller poller = new PollerImpl(); Poller poller = new PollerImpl();
PluginManagerImpl p = new PluginManagerImpl(db, poller, dispatcher, PluginManagerImpl p = new PluginManagerImpl(db, executor, poller,
uiCallback); dispatcher, uiCallback);
// The Bluetooth plugin will not start without a Bluetooth device, so // The Bluetooth plugin will not start without a Bluetooth device, so
// we expect two plugins to be started // we expect two plugins to be started
assertEquals(2, p.startPlugins()); assertEquals(2, p.startPlugins());

View File

@@ -6,6 +6,7 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.Executor;
import javax.crypto.Cipher; import javax.crypto.Cipher;
@@ -15,12 +16,16 @@ import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRecogniser.Callback;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.plugins.ImmediateExecutor;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.Mockery; import org.jmock.Mockery;
@@ -72,9 +77,22 @@ public class ConnectionRecogniserImplTest extends TestCase {
oneOf(db).getConnectionWindow(contactId, remoteIndex); oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(connectionWindow)); will(returnValue(connectionWindow));
}}); }});
final ConnectionRecogniserImpl c = Executor e = new ImmediateExecutor();
new ConnectionRecogniserImpl(crypto, db); ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db, e);
assertNull(c.acceptConnection(transportId, new byte[IV_LENGTH])); c.acceptConnection(transportId, new byte[IV_LENGTH], new Callback() {
public void connectionAccepted(ConnectionContext ctx) {
fail();
}
public void connectionRejected() {
// Expected
}
public void handleException(DbException e) {
fail();
}
});
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -109,19 +127,57 @@ public class ConnectionRecogniserImplTest extends TestCase {
oneOf(db).setConnectionWindow(contactId, remoteIndex, oneOf(db).setConnectionWindow(contactId, remoteIndex,
connectionWindow); connectionWindow);
}}); }});
final ConnectionRecogniserImpl c = Executor e = new ImmediateExecutor();
new ConnectionRecogniserImpl(crypto, db); ConnectionRecogniser c = new ConnectionRecogniserImpl(crypto, db, e);
// The IV should not be expected by the wrong transport // The IV should not be expected by the wrong transport
TransportId wrong = new TransportId(TestUtils.getRandomId()); TransportId wrong = new TransportId(TestUtils.getRandomId());
assertNull(c.acceptConnection(wrong, encryptedIv)); c.acceptConnection(wrong, encryptedIv, new Callback() {
public void connectionAccepted(ConnectionContext ctx) {
fail();
}
public void connectionRejected() {
// Expected
}
public void handleException(DbException e) {
fail();
}
});
// The IV should be expected by the right transport // The IV should be expected by the right transport
ConnectionContext ctx = c.acceptConnection(transportId, encryptedIv); c.acceptConnection(transportId, encryptedIv, new Callback() {
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId()); public void connectionAccepted(ConnectionContext ctx) {
assertEquals(remoteIndex, ctx.getTransportIndex()); assertNotNull(ctx);
assertEquals(3L, ctx.getConnectionNumber()); assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3L, ctx.getConnectionNumber());
}
public void connectionRejected() {
fail();
}
public void handleException(DbException e) {
fail();
}
});
// The IV should no longer be expected // The IV should no longer be expected
assertNull(c.acceptConnection(transportId, encryptedIv)); c.acceptConnection(transportId, encryptedIv, new Callback() {
public void connectionAccepted(ConnectionContext ctx) {
fail();
}
public void connectionRejected() {
// Expected
}
public void handleException(DbException e) {
fail();
}
});
// The window should have advanced // The window should have advanced
Map<Long, byte[]> unseen = connectionWindow.getUnseen(); Map<Long, byte[]> unseen = connectionWindow.getUnseen();
assertEquals(19, unseen.size()); assertEquals(19, unseen.size());

View File

@@ -5,6 +5,8 @@ import static net.sf.briar.api.transport.TransportConstants.MIN_CONNECTION_LENGT
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.Random; import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.TestDatabaseModule; import net.sf.briar.TestDatabaseModule;
@@ -24,8 +26,10 @@ import net.sf.briar.transport.stream.TransportStreamModule;
import org.junit.Test; import org.junit.Test;
import com.google.inject.AbstractModule;
import com.google.inject.Guice; import com.google.inject.Guice;
import com.google.inject.Injector; import com.google.inject.Injector;
import com.google.inject.Module;
public class ConnectionWriterTest extends TestCase { public class ConnectionWriterTest extends TestCase {
@@ -38,7 +42,14 @@ public class ConnectionWriterTest extends TestCase {
public ConnectionWriterTest() throws Exception { public ConnectionWriterTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Module testModule = new AbstractModule() {
@Override
public void configure() {
bind(Executor.class).toInstance(
new ScheduledThreadPoolExecutor(5));
}
};
Injector i = Guice.createInjector(testModule, new CryptoModule(),
new DatabaseModule(), new ProtocolModule(), new DatabaseModule(), new ProtocolModule(),
new ProtocolWritersModule(), new SerialModule(), new ProtocolWritersModule(), new SerialModule(),
new TestDatabaseModule(), new TransportBatchModule(), new TestDatabaseModule(), new TransportBatchModule(),

View File

@@ -10,12 +10,16 @@ import java.io.OutputStream;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Random; import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.TestDatabaseModule; import net.sf.briar.TestDatabaseModule;
import net.sf.briar.TestUtils; import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.event.DatabaseEvent; import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener; import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.db.event.MessagesAddedEvent; import net.sf.briar.api.db.event.MessagesAddedEvent;
@@ -32,6 +36,7 @@ import net.sf.briar.api.transport.BatchTransportWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRecogniser.Callback;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule; import net.sf.briar.db.DatabaseModule;
@@ -45,8 +50,10 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import com.google.inject.AbstractModule;
import com.google.inject.Guice; import com.google.inject.Guice;
import com.google.inject.Injector; import com.google.inject.Injector;
import com.google.inject.Module;
public class BatchConnectionReadWriteTest extends TestCase { public class BatchConnectionReadWriteTest extends TestCase {
@@ -75,17 +82,31 @@ public class BatchConnectionReadWriteTest extends TestCase {
public void setUp() { public void setUp() {
testDir.mkdirs(); testDir.mkdirs();
// Create Alice's injector // Create Alice's injector
alice = Guice.createInjector(new CryptoModule(), new DatabaseModule(), Module aliceTestModule = new AbstractModule() {
new ProtocolModule(), new ProtocolWritersModule(), @Override
new SerialModule(), new TestDatabaseModule(aliceDir), public void configure() {
new TransportBatchModule(), new TransportModule(), bind(Executor.class).toInstance(
new TransportStreamModule()); new ScheduledThreadPoolExecutor(5));
}
};
alice = Guice.createInjector(aliceTestModule, new CryptoModule(),
new DatabaseModule(), new ProtocolModule(),
new ProtocolWritersModule(), new SerialModule(),
new TestDatabaseModule(aliceDir), new TransportBatchModule(),
new TransportModule(), new TransportStreamModule());
// Create Bob's injector // Create Bob's injector
bob = Guice.createInjector(new CryptoModule(), new DatabaseModule(), Module bobTestModule = new AbstractModule() {
new ProtocolModule(), new ProtocolWritersModule(), @Override
new SerialModule(), new TestDatabaseModule(bobDir), public void configure() {
new TransportBatchModule(), new TransportModule(), bind(Executor.class).toInstance(
new TransportStreamModule()); new ScheduledThreadPoolExecutor(5));
}
};
bob = Guice.createInjector(bobTestModule, new CryptoModule(),
new DatabaseModule(), new ProtocolModule(),
new ProtocolWritersModule(), new SerialModule(),
new TestDatabaseModule(bobDir), new TransportBatchModule(),
new TransportModule(), new TransportStreamModule());
} }
@Test @Test
@@ -159,7 +180,10 @@ public class BatchConnectionReadWriteTest extends TestCase {
byte[] encryptedIv = new byte[IV_LENGTH]; byte[] encryptedIv = new byte[IV_LENGTH];
int read = in.read(encryptedIv); int read = in.read(encryptedIv);
assertEquals(encryptedIv.length, read); assertEquals(encryptedIv.length, read);
ConnectionContext ctx = rec.acceptConnection(transportId, encryptedIv); TestCallback callback = new TestCallback();
rec.acceptConnection(transportId, encryptedIv, callback);
callback.latch.await();
ConnectionContext ctx = callback.ctx;
assertNotNull(ctx); assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId()); assertEquals(contactId, ctx.getContactId());
assertEquals(transportIndex, ctx.getTransportIndex()); assertEquals(transportIndex, ctx.getTransportIndex());
@@ -234,4 +258,25 @@ public class BatchConnectionReadWriteTest extends TestCase {
assertTrue(success); assertTrue(success);
} }
} }
private static class TestCallback implements Callback {
private final CountDownLatch latch = new CountDownLatch(1);
private ConnectionContext ctx = null;
public void connectionAccepted(ConnectionContext ctx) {
this.ctx = ctx;
latch.countDown();
}
public void connectionRejected() {
fail();
latch.countDown();
}
public void handleException(DbException e) {
fail();
latch.countDown();
}
}
} }