Read the tag on a connection recogniser thread, don't block the

plugin.
This commit is contained in:
akwizgran
2011-12-08 16:33:48 +00:00
parent 6e080bb35d
commit e3242ebb06
6 changed files with 83 additions and 131 deletions

View File

@@ -1,6 +0,0 @@
package net.sf.briar.api;
public interface ExceptionHandler<E extends Exception> {
void handleException(E exception);
}

View File

@@ -1,6 +1,5 @@
package net.sf.briar.api.transport; package net.sf.briar.api.transport;
import net.sf.briar.api.ExceptionHandler;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
@@ -11,15 +10,9 @@ import net.sf.briar.api.protocol.TransportId;
public interface ConnectionRecogniser { public interface ConnectionRecogniser {
/** /**
* Asynchronously calls one of the callback's connectionAccepted(), * Returns the context for the given connection if the connection was
* connectionRejected() or handleException() methods. * expected, or null if the connection was not expected.
*/ */
void acceptConnection(TransportId t, byte[] tag, Callback c); ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException;
interface Callback extends ExceptionHandler<DbException> {
void connectionAccepted(ConnectionContext ctx);
void connectionRejected();
}
} }

View File

@@ -1,7 +1,9 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -16,7 +18,7 @@ import net.sf.briar.api.transport.BatchTransportWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionDispatcher; import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRecogniser.Callback; import net.sf.briar.api.transport.ConnectionRecogniserExecutor;
import net.sf.briar.api.transport.StreamTransportConnection; import net.sf.briar.api.transport.StreamTransportConnection;
import net.sf.briar.api.transport.TransportConstants; import net.sf.briar.api.transport.TransportConstants;
@@ -27,56 +29,24 @@ class ConnectionDispatcherImpl implements ConnectionDispatcher {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(ConnectionDispatcherImpl.class.getName()); Logger.getLogger(ConnectionDispatcherImpl.class.getName());
private final Executor executor;
private final ConnectionRecogniser recogniser; private final ConnectionRecogniser recogniser;
private final BatchConnectionFactory batchConnFactory; private final BatchConnectionFactory batchConnFactory;
private final StreamConnectionFactory streamConnFactory; private final StreamConnectionFactory streamConnFactory;
@Inject @Inject
ConnectionDispatcherImpl(ConnectionRecogniser recogniser, ConnectionDispatcherImpl(@ConnectionRecogniserExecutor Executor executor,
ConnectionRecogniser recogniser,
BatchConnectionFactory batchConnFactory, BatchConnectionFactory batchConnFactory,
StreamConnectionFactory streamConnFactory) { StreamConnectionFactory streamConnFactory) {
this.executor = executor;
this.recogniser = recogniser; this.recogniser = recogniser;
this.batchConnFactory = batchConnFactory; this.batchConnFactory = batchConnFactory;
this.streamConnFactory = streamConnFactory; this.streamConnFactory = streamConnFactory;
} }
public void dispatchReader(TransportId t, final BatchTransportReader r) { public void dispatchReader(TransportId t, BatchTransportReader r) {
// Read the tag executor.execute(new DispatchBatchConnection(t, r));
final byte[] tag;
try {
tag = readTag(r.getInputStream());
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
r.dispose(false);
return;
}
// Get the connection context asynchronously
recogniser.acceptConnection(t, tag, new Callback() {
public void connectionAccepted(ConnectionContext ctx) {
batchConnFactory.createIncomingConnection(ctx, r, tag);
}
public void connectionRejected() {
r.dispose(true);
}
public void handleException(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
r.dispose(false);
}
});
}
private byte[] readTag(InputStream in) throws IOException {
byte[] b = new byte[TransportConstants.TAG_LENGTH];
int offset = 0;
while(offset < b.length) {
int read = in.read(b, offset, b.length - offset);
if(read == -1) throw new IOException();
offset += read;
}
return b;
} }
public void dispatchWriter(ContactId c, TransportIndex i, public void dispatchWriter(ContactId c, TransportIndex i,
@@ -85,36 +55,76 @@ class ConnectionDispatcherImpl implements ConnectionDispatcher {
} }
public void dispatchIncomingConnection(TransportId t, public void dispatchIncomingConnection(TransportId t,
final StreamTransportConnection s) { StreamTransportConnection s) {
// Read the tag executor.execute(new DispatchStreamConnection(t, s));
final byte[] tag;
try {
tag = readTag(s.getInputStream());
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
s.dispose(false);
return;
}
// Get the connection context asynchronously
recogniser.acceptConnection(t, tag, new Callback() {
public void connectionAccepted(ConnectionContext ctx) {
streamConnFactory.createIncomingConnection(ctx, s, tag);
}
public void connectionRejected() {
s.dispose(true);
}
public void handleException(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
s.dispose(false);
}
});
} }
public void dispatchOutgoingConnection(ContactId c, TransportIndex i, public void dispatchOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s) { StreamTransportConnection s) {
streamConnFactory.createOutgoingConnection(c, i, s); streamConnFactory.createOutgoingConnection(c, i, s);
} }
private byte[] readTag(InputStream in) throws IOException {
byte[] b = new byte[TransportConstants.TAG_LENGTH];
int offset = 0;
while(offset < b.length) {
int read = in.read(b, offset, b.length - offset);
if(read == -1) throw new EOFException();
offset += read;
}
return b;
}
private class DispatchBatchConnection implements Runnable {
private final TransportId t;
private final BatchTransportReader r;
private DispatchBatchConnection(TransportId t, BatchTransportReader r) {
this.t = t;
this.r = r;
}
public void run() {
try {
byte[] tag = readTag(r.getInputStream());
ConnectionContext ctx = recogniser.acceptConnection(t, tag);
if(ctx == null) r.dispose(true);
else batchConnFactory.createIncomingConnection(ctx, r, tag);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
r.dispose(false);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
r.dispose(false);
}
}
}
private class DispatchStreamConnection implements Runnable {
private final TransportId t;
private final StreamTransportConnection s;
private DispatchStreamConnection(TransportId t,
StreamTransportConnection s) {
this.t = t;
this.s = s;
}
public void run() {
try {
byte[] tag = readTag(s.getInputStream());
ConnectionContext ctx = recogniser.acceptConnection(t, tag);
if(ctx == null) s.dispose(true);
else streamConnFactory.createIncomingConnection(ctx, s, tag);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
s.dispose(false);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
s.dispose(false);
}
}
}
} }

View File

@@ -107,23 +107,7 @@ DatabaseListener {
return new Bytes(tag); return new Bytes(tag);
} }
public void acceptConnection(final TransportId t, final byte[] tag, public ConnectionContext acceptConnection(TransportId t, byte[] tag)
final Callback callback) {
executor.execute(new Runnable() {
public void run() {
try {
ConnectionContext ctx = acceptConnection(t, tag);
if(ctx == null) callback.connectionRejected();
else callback.connectionAccepted(ctx);
} catch(DbException e) {
callback.handleException(e);
}
}
});
}
// Package access for testing
ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException { throws DbException {
if(tag.length != TAG_LENGTH) if(tag.length != TAG_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();

View File

@@ -12,7 +12,6 @@ import net.sf.briar.api.transport.ConnectionWindowFactory;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import com.google.inject.AbstractModule; import com.google.inject.AbstractModule;
import com.google.inject.Singleton;
public class TransportModule extends AbstractModule { public class TransportModule extends AbstractModule {
@@ -23,8 +22,7 @@ public class TransportModule extends AbstractModule {
bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class); bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class);
bind(ConnectionReaderFactory.class).to( bind(ConnectionReaderFactory.class).to(
ConnectionReaderFactoryImpl.class); ConnectionReaderFactoryImpl.class);
bind(ConnectionRecogniser.class).to(ConnectionRecogniserImpl.class).in( bind(ConnectionRecogniser.class).to(ConnectionRecogniserImpl.class);
Singleton.class);
bind(ConnectionWindowFactory.class).to( bind(ConnectionWindowFactory.class).to(
ConnectionWindowFactoryImpl.class); ConnectionWindowFactoryImpl.class);
bind(ConnectionWriterFactory.class).to( bind(ConnectionWriterFactory.class).to(

View File

@@ -8,14 +8,12 @@ import java.io.File;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Random; import java.util.Random;
import java.util.concurrent.CountDownLatch;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.TestDatabaseModule; import net.sf.briar.TestDatabaseModule;
import net.sf.briar.TestUtils; import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.event.DatabaseEvent; import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener; import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.db.event.MessagesAddedEvent; import net.sf.briar.api.db.event.MessagesAddedEvent;
@@ -30,7 +28,6 @@ import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRecogniser.Callback;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule; import net.sf.briar.db.DatabaseModule;
@@ -158,10 +155,7 @@ public class BatchConnectionReadWriteTest extends TestCase {
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
int read = in.read(tag); int read = in.read(tag);
assertEquals(tag.length, read); assertEquals(tag.length, read);
TestCallback callback = new TestCallback(); ConnectionContext ctx = rec.acceptConnection(transportId, tag);
rec.acceptConnection(transportId, tag, callback);
callback.latch.await();
ConnectionContext ctx = callback.ctx;
assertNotNull(ctx); assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId()); assertEquals(contactId, ctx.getContactId());
assertEquals(transportIndex, ctx.getTransportIndex()); assertEquals(transportIndex, ctx.getTransportIndex());
@@ -198,25 +192,4 @@ public class BatchConnectionReadWriteTest extends TestCase {
if(e instanceof MessagesAddedEvent) messagesAdded = true; if(e instanceof MessagesAddedEvent) messagesAdded = true;
} }
} }
private static class TestCallback implements Callback {
private final CountDownLatch latch = new CountDownLatch(1);
private ConnectionContext ctx = null;
public void connectionAccepted(ConnectionContext ctx) {
this.ctx = ctx;
latch.countDown();
}
public void connectionRejected() {
fail();
latch.countDown();
}
public void handleException(DbException e) {
fail();
latch.countDown();
}
}
} }