Rewrote StreamConnection to decouple the database from IO.

Runnables encapsulating database or IO tasks are passed to the
relevant threads. The IO thread's task queue is unbounded to avoid
deadlock, but its growth is indirectly limited by the progress of
database tasks.
This commit is contained in:
akwizgran
2011-12-07 20:52:04 +00:00
parent 5099979b9d
commit 2020f60ebf
7 changed files with 425 additions and 385 deletions

View File

@@ -1,6 +1,7 @@
package net.sf.briar.transport.batch; package net.sf.briar.transport.batch;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
@@ -32,62 +33,60 @@ class IncomingBatchConnection {
private final DatabaseComponent db; private final DatabaseComponent db;
private final ProtocolReaderFactory protoFactory; private final ProtocolReaderFactory protoFactory;
private final ConnectionContext ctx; private final ConnectionContext ctx;
private final BatchTransportReader reader; private final BatchTransportReader transport;
private final byte[] tag; private final byte[] tag;
private final ContactId contactId;
IncomingBatchConnection(@DatabaseExecutor Executor dbExecutor, IncomingBatchConnection(@DatabaseExecutor Executor dbExecutor,
DatabaseComponent db, ConnectionReaderFactory connFactory, DatabaseComponent db, ConnectionReaderFactory connFactory,
ProtocolReaderFactory protoFactory, ConnectionContext ctx, ProtocolReaderFactory protoFactory, ConnectionContext ctx,
BatchTransportReader reader, byte[] tag) { BatchTransportReader transport, byte[] tag) {
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.connFactory = connFactory; this.connFactory = connFactory;
this.db = db; this.db = db;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.ctx = ctx; this.ctx = ctx;
this.reader = reader; this.transport = transport;
this.tag = tag; this.tag = tag;
contactId = ctx.getContactId();
} }
void read() { void read() {
try { try {
ConnectionReader conn = connFactory.createConnectionReader( ConnectionReader conn = connFactory.createConnectionReader(
reader.getInputStream(), ctx.getSecret(), tag); transport.getInputStream(), ctx.getSecret(), tag);
ProtocolReader proto = protoFactory.createProtocolReader( InputStream in = conn.getInputStream();
conn.getInputStream()); ProtocolReader reader = protoFactory.createProtocolReader(in);
final ContactId c = ctx.getContactId();
// Read packets until EOF // Read packets until EOF
while(!proto.eof()) { while(!reader.eof()) {
if(proto.hasAck()) { if(reader.hasAck()) {
Ack a = proto.readAck(); Ack a = reader.readAck();
dbExecutor.execute(new ReceiveAck(c, a)); dbExecutor.execute(new ReceiveAck(a));
} else if(proto.hasBatch()) { } else if(reader.hasBatch()) {
UnverifiedBatch b = proto.readBatch(); UnverifiedBatch b = reader.readBatch();
dbExecutor.execute(new ReceiveBatch(c, b)); dbExecutor.execute(new ReceiveBatch(b));
} else if(proto.hasSubscriptionUpdate()) { } else if(reader.hasSubscriptionUpdate()) {
SubscriptionUpdate s = proto.readSubscriptionUpdate(); SubscriptionUpdate s = reader.readSubscriptionUpdate();
dbExecutor.execute(new ReceiveSubscriptionUpdate(c, s)); dbExecutor.execute(new ReceiveSubscriptionUpdate(s));
} else if(proto.hasTransportUpdate()) { } else if(reader.hasTransportUpdate()) {
TransportUpdate t = proto.readTransportUpdate(); TransportUpdate t = reader.readTransportUpdate();
dbExecutor.execute(new ReceiveTransportUpdate(c, t)); dbExecutor.execute(new ReceiveTransportUpdate(t));
} else { } else {
throw new FormatException(); throw new FormatException();
} }
} }
transport.dispose(true);
} catch(IOException e) { } catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
reader.dispose(false); transport.dispose(false);
} }
// Success
reader.dispose(true);
} }
private class ReceiveAck implements Runnable { private class ReceiveAck implements Runnable {
private final ContactId contactId;
private final Ack ack; private final Ack ack;
private ReceiveAck(ContactId contactId, Ack ack) { private ReceiveAck(Ack ack) {
this.contactId = contactId;
this.ack = ack; this.ack = ack;
} }
@@ -102,11 +101,9 @@ class IncomingBatchConnection {
private class ReceiveBatch implements Runnable { private class ReceiveBatch implements Runnable {
private final ContactId contactId;
private final UnverifiedBatch batch; private final UnverifiedBatch batch;
private ReceiveBatch(ContactId contactId, UnverifiedBatch batch) { private ReceiveBatch(UnverifiedBatch batch) {
this.contactId = contactId;
this.batch = batch; this.batch = batch;
} }
@@ -124,12 +121,9 @@ class IncomingBatchConnection {
private class ReceiveSubscriptionUpdate implements Runnable { private class ReceiveSubscriptionUpdate implements Runnable {
private final ContactId contactId;
private final SubscriptionUpdate update; private final SubscriptionUpdate update;
private ReceiveSubscriptionUpdate(ContactId contactId, private ReceiveSubscriptionUpdate(SubscriptionUpdate update) {
SubscriptionUpdate update) {
this.contactId = contactId;
this.update = update; this.update = update;
} }
@@ -144,12 +138,9 @@ class IncomingBatchConnection {
private class ReceiveTransportUpdate implements Runnable { private class ReceiveTransportUpdate implements Runnable {
private final ContactId contactId;
private final TransportUpdate update; private final TransportUpdate update;
private ReceiveTransportUpdate(ContactId contactId, private ReceiveTransportUpdate(TransportUpdate update) {
TransportUpdate update) {
this.contactId = contactId;
this.update = update; this.update = update;
} }

View File

@@ -54,37 +54,37 @@ class OutgoingBatchConnection {
transport.getOutputStream(), transport.getCapacity(), transport.getOutputStream(), transport.getCapacity(),
ctx.getSecret()); ctx.getSecret());
OutputStream out = conn.getOutputStream(); OutputStream out = conn.getOutputStream();
ProtocolWriter proto = protoFactory.createProtocolWriter(out); ProtocolWriter writer = protoFactory.createProtocolWriter(out);
// There should be enough space for a packet // There should be enough space for a packet
long capacity = conn.getRemainingCapacity(); long capacity = conn.getRemainingCapacity();
if(capacity < MAX_PACKET_LENGTH) throw new IOException(); if(capacity < MAX_PACKET_LENGTH) throw new IOException();
// Write a transport update // Write a transport update
TransportUpdate t = db.generateTransportUpdate(contactId); TransportUpdate t = db.generateTransportUpdate(contactId);
if(t != null) proto.writeTransportUpdate(t); if(t != null) writer.writeTransportUpdate(t);
// If there's space, write a subscription update // If there's space, write a subscription update
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
if(capacity >= MAX_PACKET_LENGTH) { if(capacity >= MAX_PACKET_LENGTH) {
SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId); SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId);
if(s != null) proto.writeSubscriptionUpdate(s); if(s != null) writer.writeSubscriptionUpdate(s);
} }
// Write acks until you can't write acks no more // Write acks until you can't write acks no more
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
int maxBatches = proto.getMaxBatchesForAck(capacity); int maxBatches = writer.getMaxBatchesForAck(capacity);
Ack a = db.generateAck(contactId, maxBatches); Ack a = db.generateAck(contactId, maxBatches);
while(a != null) { while(a != null) {
proto.writeAck(a); writer.writeAck(a);
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
maxBatches = proto.getMaxBatchesForAck(capacity); maxBatches = writer.getMaxBatchesForAck(capacity);
a = db.generateAck(contactId, maxBatches); a = db.generateAck(contactId, maxBatches);
} }
// Write batches until you can't write batches no more // Write batches until you can't write batches no more
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
capacity = proto.getMessageCapacityForBatch(capacity); capacity = writer.getMessageCapacityForBatch(capacity);
RawBatch b = db.generateBatch(contactId, (int) capacity); RawBatch b = db.generateBatch(contactId, (int) capacity);
while(b != null) { while(b != null) {
proto.writeBatch(b); writer.writeBatch(b);
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
capacity = proto.getMessageCapacityForBatch(capacity); capacity = writer.getMessageCapacityForBatch(capacity);
b = db.generateBatch(contactId, (int) capacity); b = db.generateBatch(contactId, (int) capacity);
} }
// Flush the output stream // Flush the output stream

View File

@@ -1,14 +0,0 @@
package net.sf.briar.transport.stream;
interface Flags {
// Flags raised by the database listener
static final int BATCH_RECEIVED = 1;
static final int CONTACT_REMOVED = 2;
static final int MESSAGES_ADDED = 4;
static final int SUBSCRIPTIONS_UPDATED = 8;
static final int TRANSPORTS_UPDATED = 16;
// Flags raised by the reading side of the connection
static final int OFFER_RECEIVED = 32;
static final int REQUEST_RECEIVED = 64;
}

View File

@@ -4,10 +4,9 @@ import java.io.IOException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
@@ -20,14 +19,14 @@ class IncomingStreamConnection extends StreamConnection {
private final ConnectionContext ctx; private final ConnectionContext ctx;
private final byte[] tag; private final byte[] tag;
IncomingStreamConnection(Executor executor, DatabaseComponent db, IncomingStreamConnection(@DatabaseExecutor Executor dbExecutor,
SerialComponent serial, ConnectionReaderFactory connReaderFactory, DatabaseComponent db, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ProtocolWriterFactory protoWriterFactory,
ConnectionContext ctx, StreamTransportConnection connection, ConnectionContext ctx, StreamTransportConnection connection,
byte[] tag) { byte[] tag) {
super(executor, db, serial, connReaderFactory, connWriterFactory, super(dbExecutor, db, connReaderFactory, connWriterFactory,
protoReaderFactory, protoWriterFactory, ctx.getContactId(), protoReaderFactory, protoWriterFactory, ctx.getContactId(),
connection); connection);
this.ctx = ctx; this.ctx = ctx;
@@ -35,15 +34,13 @@ class IncomingStreamConnection extends StreamConnection {
} }
@Override @Override
protected ConnectionReader createConnectionReader() throws DbException, protected ConnectionReader createConnectionReader() throws IOException {
IOException {
return connReaderFactory.createConnectionReader( return connReaderFactory.createConnectionReader(
connection.getInputStream(), ctx.getSecret(), tag); connection.getInputStream(), ctx.getSecret(), tag);
} }
@Override @Override
protected ConnectionWriter createConnectionWriter() throws DbException, protected ConnectionWriter createConnectionWriter() throws IOException {
IOException {
return connWriterFactory.createConnectionWriter( return connWriterFactory.createConnectionWriter(
connection.getOutputStream(), Long.MAX_VALUE, ctx.getSecret(), connection.getOutputStream(), Long.MAX_VALUE, ctx.getSecret(),
tag); tag);

View File

@@ -5,11 +5,11 @@ import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
@@ -23,14 +23,14 @@ class OutgoingStreamConnection extends StreamConnection {
private ConnectionContext ctx = null; // Locking: this private ConnectionContext ctx = null; // Locking: this
OutgoingStreamConnection(Executor executor, DatabaseComponent db, OutgoingStreamConnection(@DatabaseExecutor Executor dbExecutor,
SerialComponent serial, ConnectionReaderFactory connReaderFactory, DatabaseComponent db, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
TransportIndex transportIndex, TransportIndex transportIndex,
StreamTransportConnection connection) { StreamTransportConnection connection) {
super(executor, db, serial, connReaderFactory, connWriterFactory, super(dbExecutor, db, connReaderFactory, connWriterFactory,
protoReaderFactory, protoWriterFactory, contactId, connection); protoReaderFactory, protoWriterFactory, contactId, connection);
this.transportIndex = transportIndex; this.transportIndex = transportIndex;
} }

View File

@@ -11,6 +11,7 @@ import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -38,7 +39,6 @@ import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
@@ -47,14 +47,16 @@ import net.sf.briar.api.transport.StreamTransportConnection;
abstract class StreamConnection implements DatabaseListener { abstract class StreamConnection implements DatabaseListener {
private static enum State { SEND_OFFER, IDLE, AWAIT_REQUEST, SEND_BATCHES };
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(StreamConnection.class.getName()); Logger.getLogger(StreamConnection.class.getName());
// A canary indicating that the connection should be closed
private static final Runnable CLOSE_CONNECTION = new Runnable() {
public void run() {}
};
protected final Executor dbExecutor; protected final Executor dbExecutor;
protected final DatabaseComponent db; protected final DatabaseComponent db;
protected final SerialComponent serial;
protected final ConnectionReaderFactory connReaderFactory; protected final ConnectionReaderFactory connReaderFactory;
protected final ConnectionWriterFactory connWriterFactory; protected final ConnectionWriterFactory connWriterFactory;
protected final ProtocolReaderFactory protoReaderFactory; protected final ProtocolReaderFactory protoReaderFactory;
@@ -62,124 +64,105 @@ abstract class StreamConnection implements DatabaseListener {
protected final ContactId contactId; protected final ContactId contactId;
protected final StreamTransportConnection connection; protected final StreamTransportConnection connection;
private int writerFlags = 0; // Locking: this private final AtomicBoolean canSendOffer = new AtomicBoolean(false);
private final LinkedList<Runnable> writerTasks; // Locking: this
private Collection<MessageId> offered = null; // Locking: this private Collection<MessageId> offered = null; // Locking: this
private LinkedList<MessageId> requested = null; // Locking: this
private Offer incomingOffer = null; // Locking: this private volatile ProtocolWriter writer = null;
StreamConnection(@DatabaseExecutor Executor dbExecutor, StreamConnection(@DatabaseExecutor Executor dbExecutor,
DatabaseComponent db, SerialComponent serial, DatabaseComponent db, ConnectionReaderFactory connReaderFactory,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
StreamTransportConnection connection) { StreamTransportConnection connection) {
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.db = db; this.db = db;
this.serial = serial;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
this.contactId = contactId; this.contactId = contactId;
this.connection = connection; this.connection = connection;
writerTasks = new LinkedList<Runnable>();
} }
protected abstract ConnectionReader createConnectionReader() protected abstract ConnectionReader createConnectionReader()
throws DbException, IOException; throws DbException, IOException;
protected abstract ConnectionWriter createConnectionWriter() protected abstract ConnectionWriter createConnectionWriter()
throws DbException, IOException ; throws DbException, IOException;
public void eventOccurred(DatabaseEvent e) { public void eventOccurred(DatabaseEvent e) {
synchronized(this) { if(e instanceof BatchReceivedEvent) {
if(e instanceof BatchReceivedEvent) { dbExecutor.execute(new GenerateAcks());
writerFlags |= Flags.BATCH_RECEIVED; } else if(e instanceof ContactRemovedEvent) {
notifyAll(); ContactId c = ((ContactRemovedEvent) e).getContactId();
} else if(e instanceof ContactRemovedEvent) { if(contactId.equals(c)) {
ContactId c = ((ContactRemovedEvent) e).getContactId(); synchronized(this) {
if(contactId.equals(c)) { writerTasks.add(CLOSE_CONNECTION);
writerFlags |= Flags.CONTACT_REMOVED;
notifyAll(); notifyAll();
} }
} else if(e instanceof MessagesAddedEvent) {
writerFlags |= Flags.MESSAGES_ADDED;
notifyAll();
} else if(e instanceof SubscriptionsUpdatedEvent) {
Collection<ContactId> affected =
((SubscriptionsUpdatedEvent) e).getAffectedContacts();
if(affected.contains(contactId)) {
writerFlags |= Flags.SUBSCRIPTIONS_UPDATED;
notifyAll();
}
} else if(e instanceof LocalTransportsUpdatedEvent) {
writerFlags |= Flags.TRANSPORTS_UPDATED;
notifyAll();
} }
} else if(e instanceof MessagesAddedEvent) {
if(canSendOffer.getAndSet(false))
dbExecutor.execute(new GenerateOffer());
} else if(e instanceof SubscriptionsUpdatedEvent) {
Collection<ContactId> affected =
((SubscriptionsUpdatedEvent) e).getAffectedContacts();
if(affected.contains(contactId)) {
dbExecutor.execute(new GenerateSubscriptionUpdate());
}
} else if(e instanceof LocalTransportsUpdatedEvent) {
dbExecutor.execute(new GenerateTransportUpdate());
} }
} }
void read() { void read() {
try { try {
InputStream in = createConnectionReader().getInputStream(); InputStream in = createConnectionReader().getInputStream();
ProtocolReader proto = protoReaderFactory.createProtocolReader(in); ProtocolReader reader = protoReaderFactory.createProtocolReader(in);
while(!proto.eof()) { while(!reader.eof()) {
if(proto.hasAck()) { if(reader.hasAck()) {
Ack a = proto.readAck(); Ack a = reader.readAck();
dbExecutor.execute(new ReceiveAck(contactId, a)); dbExecutor.execute(new ReceiveAck(a));
} else if(proto.hasBatch()) { } else if(reader.hasBatch()) {
UnverifiedBatch b = proto.readBatch(); UnverifiedBatch b = reader.readBatch();
dbExecutor.execute(new ReceiveBatch(contactId, b)); dbExecutor.execute(new ReceiveBatch(b));
} else if(proto.hasOffer()) { } else if(reader.hasOffer()) {
Offer o = proto.readOffer(); Offer o = reader.readOffer();
// Store the incoming offer and notify the writer dbExecutor.execute(new ReceiveOffer(o));
synchronized(this) { } else if(reader.hasRequest()) {
writerFlags |= Flags.OFFER_RECEIVED; Request r = reader.readRequest();
incomingOffer = o;
notifyAll();
}
} else if(proto.hasRequest()) {
Request r = proto.readRequest();
// Retrieve the offered message IDs // Retrieve the offered message IDs
Collection<MessageId> off; Collection<MessageId> offered = getOfferedMessageIds();
synchronized(this) {
if(offered == null)
throw new IOException("Unexpected request packet");
off = offered;
offered = null;
}
// Work out which messages were requested // Work out which messages were requested
BitSet b = r.getBitmap(); BitSet b = r.getBitmap();
LinkedList<MessageId> req = new LinkedList<MessageId>(); List<MessageId> requested = new LinkedList<MessageId>();
List<MessageId> seen = new ArrayList<MessageId>(); List<MessageId> seen = new ArrayList<MessageId>();
int i = 0; int i = 0;
for(MessageId m : off) { for(MessageId m : offered) {
if(b.get(i++)) req.add(m); if(b.get(i++)) requested.add(m);
else seen.add(m); else seen.add(m);
} }
requested = Collections.synchronizedList(requested);
seen = Collections.unmodifiableList(seen); seen = Collections.unmodifiableList(seen);
// Mark the unrequested messages as seen // Mark the unrequested messages as seen
dbExecutor.execute(new SetSeen(contactId, seen)); dbExecutor.execute(new SetSeen(seen));
// Store the requested message IDs and notify the writer // Start sending the requested messages
synchronized(this) { dbExecutor.execute(new GenerateBatch(requested));
if(requested != null) } else if(reader.hasSubscriptionUpdate()) {
throw new IOException("Unexpected request packet"); SubscriptionUpdate s = reader.readSubscriptionUpdate();
requested = req; dbExecutor.execute(new ReceiveSubscriptionUpdate(s));
writerFlags |= Flags.REQUEST_RECEIVED; } else if(reader.hasTransportUpdate()) {
notifyAll(); TransportUpdate t = reader.readTransportUpdate();
} dbExecutor.execute(new ReceiveTransportUpdate(t));
} else if(proto.hasSubscriptionUpdate()) {
SubscriptionUpdate s = proto.readSubscriptionUpdate();
dbExecutor.execute(new ReceiveSubscriptionUpdate(
contactId, s));
} else if(proto.hasTransportUpdate()) {
TransportUpdate t = proto.readTransportUpdate();
dbExecutor.execute(new ReceiveTransportUpdate(
contactId, t));
} else { } else {
throw new FormatException(); throw new FormatException();
} }
} }
connection.dispose(true);
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false); connection.dispose(false);
@@ -187,142 +170,50 @@ abstract class StreamConnection implements DatabaseListener {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false); connection.dispose(false);
} }
// Success }
connection.dispose(true);
private synchronized Collection<MessageId> getOfferedMessageIds()
throws FormatException {
if(offered == null) throw new FormatException(); // Unexpected request
Collection<MessageId> ids = offered;
offered = null;
return ids;
}
// Locking: this
private void setOfferedMessageIds(Collection<MessageId> ids) {
assert offered == null;
offered = ids;
} }
void write() { void write() {
try { try {
OutputStream out = createConnectionWriter().getOutputStream(); OutputStream out = createConnectionWriter().getOutputStream();
ProtocolWriter proto = protoWriterFactory.createProtocolWriter(out); writer = protoWriterFactory.createProtocolWriter(out);
// Send the initial packets: transports, subs, any waiting acks // Start receiving database events
sendTransportUpdate(proto); db.addListener(this);
sendSubscriptionUpdate(proto); // Send the initial packets: transports, subs, acks, offer
sendAcks(proto); dbExecutor.execute(new GenerateTransportUpdate());
State state = State.SEND_OFFER; dbExecutor.execute(new GenerateSubscriptionUpdate());
dbExecutor.execute(new GenerateAcks());
dbExecutor.execute(new GenerateOffer());
// Main loop // Main loop
while(true) { while(true) {
int flags = 0; Runnable task = null;
switch(state) { synchronized(this) {
while(writerTasks.isEmpty()) {
case SEND_OFFER: try {
// Try to send an offer wait();
if(sendOffer(proto)) state = State.AWAIT_REQUEST; } catch(InterruptedException e) {
else state = State.IDLE; Thread.currentThread().interrupt();
break;
case IDLE:
// Wait for one or more flags to be raised
synchronized(this) {
while(writerFlags == 0) {
try {
wait();
} catch(InterruptedException e) {
Thread.currentThread().interrupt();
}
} }
flags = writerFlags;
writerFlags = 0;
} }
// Handle the flags in approximate order of urgency task = writerTasks.poll();
if((flags & Flags.CONTACT_REMOVED) != 0) {
connection.dispose(true);
return;
}
if((flags & Flags.TRANSPORTS_UPDATED) != 0) {
sendTransportUpdate(proto);
}
if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) {
sendSubscriptionUpdate(proto);
}
if((flags & Flags.BATCH_RECEIVED) != 0) {
sendAcks(proto);
}
if((flags & Flags.OFFER_RECEIVED) != 0) {
sendRequest(proto);
}
if((flags & Flags.REQUEST_RECEIVED) != 0) {
// Should only be received in state AWAIT_REQUEST
throw new IOException("Unexpected request packet");
}
if((flags & Flags.MESSAGES_ADDED) != 0) {
state = State.SEND_OFFER;
}
break;
case AWAIT_REQUEST:
// Wait for one or more flags to be raised
synchronized(this) {
while(writerFlags == 0) {
try {
wait();
} catch(InterruptedException e) {
Thread.currentThread().interrupt();
}
}
flags = writerFlags;
writerFlags = 0;
}
// Handle the flags in approximate order of urgency
if((flags & Flags.CONTACT_REMOVED) != 0) {
connection.dispose(true);
return;
}
if((flags & Flags.TRANSPORTS_UPDATED) != 0) {
sendTransportUpdate(proto);
}
if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) {
sendSubscriptionUpdate(proto);
}
if((flags & Flags.BATCH_RECEIVED) != 0) {
sendAcks(proto);
}
if((flags & Flags.OFFER_RECEIVED) != 0) {
sendRequest(proto);
}
if((flags & Flags.REQUEST_RECEIVED) != 0) {
state = State.SEND_BATCHES;
}
if((flags & Flags.MESSAGES_ADDED) != 0) {
// Ignored in this state
}
break;
case SEND_BATCHES:
// Check whether any flags have been raised
synchronized(this) {
flags = writerFlags;
writerFlags = 0;
}
// Handle the flags in approximate order of urgency
if((flags & Flags.CONTACT_REMOVED) != 0) {
connection.dispose(true);
return;
}
if((flags & Flags.TRANSPORTS_UPDATED) != 0) {
sendTransportUpdate(proto);
}
if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) {
sendSubscriptionUpdate(proto);
}
if((flags & Flags.BATCH_RECEIVED) != 0) {
sendAcks(proto);
}
if((flags & Flags.OFFER_RECEIVED) != 0) {
sendRequest(proto);
}
if((flags & Flags.REQUEST_RECEIVED) != 0) {
// Should only be received in state AWAIT_REQUEST
throw new IOException("Unexpected request packet");
}
if((flags & Flags.MESSAGES_ADDED) != 0) {
// Ignored in this state
}
// Try to send a batch
if(!sendBatch(proto)) state = State.SEND_OFFER;
break;
} }
if(task == CLOSE_CONNECTION) break;
task.run();
} }
connection.dispose(true);
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false); connection.dispose(false);
@@ -330,95 +221,14 @@ abstract class StreamConnection implements DatabaseListener {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false); connection.dispose(false);
} }
// Success
connection.dispose(true);
}
private void sendAcks(ProtocolWriter proto)
throws DbException, IOException {
int maxBatches = proto.getMaxBatchesForAck(Long.MAX_VALUE);
Ack a = db.generateAck(contactId, maxBatches);
while(a != null) {
proto.writeAck(a);
a = db.generateAck(contactId, maxBatches);
}
}
private boolean sendBatch(ProtocolWriter proto)
throws DbException, IOException {
Collection<MessageId> req;
// Retrieve the requested message IDs
synchronized(this) {
assert offered == null;
assert requested != null;
req = requested;
}
// Try to generate a batch, updating the collection of message IDs
int capacity = proto.getMessageCapacityForBatch(Long.MAX_VALUE);
RawBatch b = db.generateBatch(contactId, capacity, req);
if(b == null) {
// No more batches can be generated - discard the remaining IDs
synchronized(this) {
assert offered == null;
assert requested == req;
requested = null;
}
return false;
} else {
proto.writeBatch(b);
return true;
}
}
private boolean sendOffer(ProtocolWriter proto)
throws DbException, IOException {
// Generate an offer
int maxMessages = proto.getMaxMessagesForOffer(Long.MAX_VALUE);
Offer o = db.generateOffer(contactId, maxMessages);
if(o == null) return false;
proto.writeOffer(o);
// Store the offered message IDs
synchronized(this) {
assert offered == null;
assert requested == null;
offered = o.getMessageIds();
}
return true;
}
private void sendRequest(ProtocolWriter proto)
throws DbException, IOException {
Offer o;
// Retrieve the incoming offer
synchronized(this) {
assert incomingOffer != null;
o = incomingOffer;
incomingOffer = null;
}
// Process the offer and generate a request
Request r = db.receiveOffer(contactId, o);
proto.writeRequest(r);
}
private void sendTransportUpdate(ProtocolWriter proto)
throws DbException, IOException {
TransportUpdate t = db.generateTransportUpdate(contactId);
if(t != null) proto.writeTransportUpdate(t);
}
private void sendSubscriptionUpdate(ProtocolWriter proto)
throws DbException, IOException {
SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId);
if(s != null) proto.writeSubscriptionUpdate(s);
} }
// This task runs on a database thread
private class ReceiveAck implements Runnable { private class ReceiveAck implements Runnable {
private final ContactId contactId;
private final Ack ack; private final Ack ack;
private ReceiveAck(ContactId contactId, Ack ack) { private ReceiveAck(Ack ack) {
this.contactId = contactId;
this.ack = ack; this.ack = ack;
} }
@@ -431,13 +241,12 @@ abstract class StreamConnection implements DatabaseListener {
} }
} }
// This task runs on a database thread
private class ReceiveBatch implements Runnable { private class ReceiveBatch implements Runnable {
private final ContactId contactId;
private final UnverifiedBatch batch; private final UnverifiedBatch batch;
private ReceiveBatch(ContactId contactId, UnverifiedBatch batch) { private ReceiveBatch(UnverifiedBatch batch) {
this.contactId = contactId;
this.batch = batch; this.batch = batch;
} }
@@ -453,13 +262,54 @@ abstract class StreamConnection implements DatabaseListener {
} }
} }
// This task runs on a database thread
private class ReceiveOffer implements Runnable {
private final Offer offer;
private ReceiveOffer(Offer offer) {
this.offer = offer;
}
public void run() {
try {
Request r = db.receiveOffer(contactId, offer);
synchronized(StreamConnection.this) {
writerTasks.add(new WriteRequest(r));
StreamConnection.this.notifyAll();
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
}
}
}
// This task runs on the writer thread
private class WriteRequest implements Runnable {
private final Request request;
private WriteRequest(Request request) {
this.request = request;
}
public void run() {
assert writer != null;
try {
writer.writeRequest(request);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false);
}
}
}
// This task runs on a database thread
private class SetSeen implements Runnable { private class SetSeen implements Runnable {
private final ContactId contactId;
private final Collection<MessageId> seen; private final Collection<MessageId> seen;
private SetSeen(ContactId contactId, Collection<MessageId> seen) { private SetSeen(Collection<MessageId> seen) {
this.contactId = contactId;
this.seen = seen; this.seen = seen;
} }
@@ -472,14 +322,12 @@ abstract class StreamConnection implements DatabaseListener {
} }
} }
// This task runs on a database thread
private class ReceiveSubscriptionUpdate implements Runnable { private class ReceiveSubscriptionUpdate implements Runnable {
private final ContactId contactId;
private final SubscriptionUpdate update; private final SubscriptionUpdate update;
private ReceiveSubscriptionUpdate(ContactId contactId, private ReceiveSubscriptionUpdate(SubscriptionUpdate update) {
SubscriptionUpdate update) {
this.contactId = contactId;
this.update = update; this.update = update;
} }
@@ -492,14 +340,12 @@ abstract class StreamConnection implements DatabaseListener {
} }
} }
// This task runs on a database thread
private class ReceiveTransportUpdate implements Runnable { private class ReceiveTransportUpdate implements Runnable {
private final ContactId contactId;
private final TransportUpdate update; private final TransportUpdate update;
private ReceiveTransportUpdate(ContactId contactId, private ReceiveTransportUpdate(TransportUpdate update) {
TransportUpdate update) {
this.contactId = contactId;
this.update = update; this.update = update;
} }
@@ -511,4 +357,226 @@ abstract class StreamConnection implements DatabaseListener {
} }
} }
} }
// This task runs on a database thread
private class GenerateAcks implements Runnable {
public void run() {
assert writer != null;
int maxBatches = writer.getMaxBatchesForAck(Long.MAX_VALUE);
try {
Ack a = db.generateAck(contactId, maxBatches);
while(a != null) {
synchronized(StreamConnection.this) {
writerTasks.add(new WriteAck(a));
StreamConnection.this.notifyAll();
}
a = db.generateAck(contactId, maxBatches);
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
}
}
}
// This task runs on the writer thread
private class WriteAck implements Runnable {
private final Ack ack;
private WriteAck(Ack ack) {
this.ack = ack;
}
public void run() {
assert writer != null;
try {
writer.writeAck(ack);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false);
}
}
}
// This task runs on a database thred
private class GenerateBatch implements Runnable {
private final Collection<MessageId> requested;
private GenerateBatch(Collection<MessageId> requested) {
this.requested = requested;
}
public void run() {
assert writer != null;
int capacity = writer.getMessageCapacityForBatch(Long.MAX_VALUE);
try {
RawBatch b = db.generateBatch(contactId, capacity, requested);
if(b == null) {
// No batch to write - send another offer
new GenerateOffer().run();
} else {
// Write the batch
synchronized(StreamConnection.this) {
writerTasks.add(new WriteBatch(b, requested));
StreamConnection.this.notifyAll();
}
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
}
}
}
// This task runs on the writer thread
private class WriteBatch implements Runnable {
private final RawBatch batch;
private final Collection<MessageId> requested;
private WriteBatch(RawBatch batch, Collection<MessageId> requested) {
this.batch = batch;
this.requested = requested;
}
public void run() {
assert writer != null;
try {
writer.writeBatch(batch);
if(requested.isEmpty()) {
// No more batches to send - send another offer
dbExecutor.execute(new GenerateOffer());
} else {
// Send another batch
dbExecutor.execute(new GenerateBatch(requested));
}
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false);
}
}
}
// This task runs on a database thread
private class GenerateOffer implements Runnable {
public void run() {
assert writer != null;
int maxMessages = writer.getMaxMessagesForOffer(Long.MAX_VALUE);
try {
Offer o = db.generateOffer(contactId, maxMessages);
if(o == null) {
// No messages to offer - wait for some to be added
canSendOffer.set(true);
} else {
synchronized(StreamConnection.this) {
// Store the offered message IDs
setOfferedMessageIds(o.getMessageIds());
// Write the offer on the writer thread
writerTasks.add(new WriteOffer(o));
StreamConnection.this.notifyAll();
}
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
}
}
}
// This task runs on the writer thread
private class WriteOffer implements Runnable {
private final Offer offer;
private WriteOffer(Offer offer) {
this.offer = offer;
}
public void run() {
assert writer != null;
try {
writer.writeOffer(offer);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false);
}
}
}
// This task runs on a database thread
private class GenerateSubscriptionUpdate implements Runnable {
public void run() {
try {
SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId);
if(s != null) {
synchronized(StreamConnection.this) {
writerTasks.add(new WriteSubscriptionUpdate(s));
StreamConnection.this.notifyAll();
}
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
}
}
}
// This task runs on the writer thread
private class WriteSubscriptionUpdate implements Runnable {
private final SubscriptionUpdate update;
private WriteSubscriptionUpdate(SubscriptionUpdate update) {
this.update = update;
}
public void run() {
assert writer != null;
try {
writer.writeSubscriptionUpdate(update);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false);
}
}
}
// This task runs on a database thread
private class GenerateTransportUpdate implements Runnable {
public void run() {
try {
TransportUpdate t = db.generateTransportUpdate(contactId);
if(t != null) {
synchronized(StreamConnection.this) {
writerTasks.add(new WriteTransportUpdate(t));
StreamConnection.this.notifyAll();
}
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
}
}
}
// This task runs on the writer thread
private class WriteTransportUpdate implements Runnable {
private final TransportUpdate update;
private WriteTransportUpdate(TransportUpdate update) {
this.update = update;
}
public void run() {
assert writer != null;
try {
writer.writeTransportUpdate(update);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
connection.dispose(false);
}
}
}
} }

View File

@@ -4,10 +4,10 @@ import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
@@ -18,23 +18,21 @@ import com.google.inject.Inject;
class StreamConnectionFactoryImpl implements StreamConnectionFactory { class StreamConnectionFactoryImpl implements StreamConnectionFactory {
private final Executor executor; private final Executor dbExecutor;
private final DatabaseComponent db; private final DatabaseComponent db;
private final SerialComponent serial;
private final ConnectionReaderFactory connReaderFactory; private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory; private final ConnectionWriterFactory connWriterFactory;
private final ProtocolReaderFactory protoReaderFactory; private final ProtocolReaderFactory protoReaderFactory;
private final ProtocolWriterFactory protoWriterFactory; private final ProtocolWriterFactory protoWriterFactory;
@Inject @Inject
StreamConnectionFactoryImpl(Executor executor, DatabaseComponent db, StreamConnectionFactoryImpl(@DatabaseExecutor Executor dbExecutor,
SerialComponent serial, ConnectionReaderFactory connReaderFactory, DatabaseComponent db, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory) { ProtocolWriterFactory protoWriterFactory) {
this.executor = executor; this.dbExecutor = dbExecutor;
this.db = db; this.db = db;
this.serial = serial;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
@@ -43,9 +41,9 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
public void createIncomingConnection(ConnectionContext ctx, public void createIncomingConnection(ConnectionContext ctx,
StreamTransportConnection s, byte[] tag) { StreamTransportConnection s, byte[] tag) {
final StreamConnection conn = new IncomingStreamConnection(executor, db, final StreamConnection conn = new IncomingStreamConnection(dbExecutor,
serial, connReaderFactory, connWriterFactory, db, connReaderFactory, connWriterFactory, protoReaderFactory,
protoReaderFactory, protoWriterFactory, ctx, s, tag); protoWriterFactory, ctx, s, tag);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();
@@ -62,9 +60,9 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
public void createOutgoingConnection(ContactId c, TransportIndex i, public void createOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s) { StreamTransportConnection s) {
final StreamConnection conn = new OutgoingStreamConnection(executor, db, final StreamConnection conn = new OutgoingStreamConnection(dbExecutor,
serial, connReaderFactory, connWriterFactory, db, connReaderFactory, connWriterFactory, protoReaderFactory,
protoReaderFactory, protoWriterFactory, c, i, s); protoWriterFactory, c, i, s);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();