Moved message verification and DB writes off the IO thread.

This commit is contained in:
akwizgran
2011-12-05 22:52:00 +00:00
parent ab722f9371
commit e24a3218ca
24 changed files with 468 additions and 169 deletions

View File

@@ -10,7 +10,7 @@ public interface ProtocolReader {
Ack readAck() throws IOException;
boolean hasBatch() throws IOException;
Batch readBatch() throws IOException;
UnverifiedBatch readBatch() throws IOException;
boolean hasOffer() throws IOException;
Offer readOffer() throws IOException;

View File

@@ -0,0 +1,8 @@
package net.sf.briar.api.protocol;
import java.security.GeneralSecurityException;
public interface UnverifiedBatch {
Batch verify() throws GeneralSecurityException;
}

View File

@@ -1,12 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
interface BatchFactory {
Batch createBatch(BatchId id, Collection<Message> messages);
}

View File

@@ -1,14 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
class BatchFactoryImpl implements BatchFactory {
public Batch createBatch(BatchId id, Collection<Message> messages) {
return new BatchImpl(id, messages);
}
}

View File

@@ -5,31 +5,31 @@ import java.util.List;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.DigestingConsumer;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class BatchReader implements ObjectReader<Batch> {
class BatchReader implements ObjectReader<UnverifiedBatch> {
private final MessageDigest messageDigest;
private final ObjectReader<Message> messageReader;
private final BatchFactory batchFactory;
private final ObjectReader<UnverifiedMessage> messageReader;
private final UnverifiedBatchFactory batchFactory;
BatchReader(CryptoComponent crypto, ObjectReader<Message> messageReader,
BatchFactory batchFactory) {
BatchReader(CryptoComponent crypto,
ObjectReader<UnverifiedMessage> messageReader,
UnverifiedBatchFactory batchFactory) {
messageDigest = crypto.getMessageDigest();
this.messageReader = messageReader;
this.batchFactory = batchFactory;
}
public Batch readObject(Reader r) throws IOException {
public UnverifiedBatch readObject(Reader r) throws IOException {
// Initialise the consumers
Consumer counting =
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
@@ -40,12 +40,12 @@ class BatchReader implements ObjectReader<Batch> {
r.addConsumer(digesting);
r.readStructId(Types.BATCH);
r.addObjectReader(Types.MESSAGE, messageReader);
List<Message> messages = r.readList(Message.class);
List<UnverifiedMessage> messages = r.readList(UnverifiedMessage.class);
r.removeObjectReader(Types.MESSAGE);
r.removeConsumer(digesting);
r.removeConsumer(counting);
// Build and return the batch
BatchId id = new BatchId(messageDigest.digest());
return batchFactory.createBatch(id, messages);
return batchFactory.createUnverifiedBatch(id, messages);
}
}

View File

@@ -1,19 +1,10 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.security.Signature;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
@@ -22,28 +13,21 @@ import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class MessageReader implements ObjectReader<Message> {
class MessageReader implements ObjectReader<UnverifiedMessage> {
private final ObjectReader<MessageId> messageIdReader;
private final ObjectReader<Group> groupReader;
private final ObjectReader<Author> authorReader;
private final KeyParser keyParser;
private final Signature signature;
private final MessageDigest messageDigest;
MessageReader(CryptoComponent crypto,
ObjectReader<MessageId> messageIdReader,
MessageReader(ObjectReader<MessageId> messageIdReader,
ObjectReader<Group> groupReader,
ObjectReader<Author> authorReader) {
this.messageIdReader = messageIdReader;
this.groupReader = groupReader;
this.authorReader = authorReader;
keyParser = crypto.getKeyParser();
signature = crypto.getSignature();
messageDigest = crypto.getMessageDigest();
}
public Message readObject(Reader r) throws IOException {
public UnverifiedMessage readObject(Reader r) throws IOException {
CopyingConsumer copying = new CopyingConsumer();
CountingConsumer counting =
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
@@ -106,35 +90,8 @@ class MessageReader implements ObjectReader<Message> {
r.removeConsumer(counting);
r.removeConsumer(copying);
byte[] raw = copying.getCopy();
// Verify the author's signature, if there is one
if(author != null) {
try {
PublicKey k = keyParser.parsePublicKey(author.getPublicKey());
signature.initVerify(k);
signature.update(raw, 0, signedByAuthor);
if(!signature.verify(authorSig)) throw new FormatException();
} catch(GeneralSecurityException e) {
throw new FormatException();
}
}
// Verify the group's signature, if there is one
if(group != null && group.getPublicKey() != null) {
try {
PublicKey k = keyParser.parsePublicKey(group.getPublicKey());
signature.initVerify(k);
signature.update(raw, 0, signedByGroup);
if(!signature.verify(groupSig)) throw new FormatException();
} catch(GeneralSecurityException e) {
throw new FormatException();
}
}
// Hash the message, including the signatures, to get the message ID
messageDigest.reset();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, parent, groupId, authorId, subject,
timestamp, raw, bodyStart, body.length);
return new UnverifiedMessageImpl(parent, group, author, subject,
timestamp, raw, authorSig, groupSig, bodyStart, body.length,
signedByAuthor, signedByGroup);
}
}

View File

@@ -4,10 +4,8 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorFactory;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
@@ -15,6 +13,7 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader;
import com.google.inject.AbstractModule;
@@ -26,14 +25,15 @@ public class ProtocolModule extends AbstractModule {
protected void configure() {
bind(AckFactory.class).to(AckFactoryImpl.class);
bind(AuthorFactory.class).to(AuthorFactoryImpl.class);
bind(BatchFactory.class).to(BatchFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class);
bind(MessageFactory.class).to(MessageFactoryImpl.class);
bind(OfferFactory.class).to(OfferFactoryImpl.class);
bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class);
bind(RequestFactory.class).to(RequestFactoryImpl.class);
bind(SubscriptionUpdateFactory.class).to(SubscriptionUpdateFactoryImpl.class);
bind(SubscriptionUpdateFactory.class).to(
SubscriptionUpdateFactoryImpl.class);
bind(TransportUpdateFactory.class).to(TransportUpdateFactoryImpl.class);
bind(UnverifiedBatchFactory.class).to(UnverifiedBatchFactoryImpl.class);
}
@Provides
@@ -48,8 +48,9 @@ public class ProtocolModule extends AbstractModule {
}
@Provides
ObjectReader<Batch> getBatchReader(CryptoComponent crypto,
ObjectReader<Message> messageReader, BatchFactory batchFactory) {
ObjectReader<UnverifiedBatch> getBatchReader(CryptoComponent crypto,
ObjectReader<UnverifiedMessage> messageReader,
UnverifiedBatchFactory batchFactory) {
return new BatchReader(crypto, messageReader, batchFactory);
}
@@ -65,12 +66,11 @@ public class ProtocolModule extends AbstractModule {
}
@Provides
ObjectReader<Message> getMessageReader(CryptoComponent crypto,
ObjectReader<UnverifiedMessage> getMessageReader(
ObjectReader<MessageId> messageIdReader,
ObjectReader<Group> groupReader,
ObjectReader<Author> authorReader) {
return new MessageReader(crypto, messageIdReader, groupReader,
authorReader);
return new MessageReader(messageIdReader, groupReader, authorReader);
}
@Provides

View File

@@ -3,13 +3,13 @@ package net.sf.briar.protocol;
import java.io.InputStream;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.ReaderFactory;
@@ -20,7 +20,7 @@ class ProtocolReaderFactoryImpl implements ProtocolReaderFactory {
private final ReaderFactory readerFactory;
private final Provider<ObjectReader<Ack>> ackProvider;
private final Provider<ObjectReader<Batch>> batchProvider;
private final Provider<ObjectReader<UnverifiedBatch>> batchProvider;
private final Provider<ObjectReader<Offer>> offerProvider;
private final Provider<ObjectReader<Request>> requestProvider;
private final Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider;
@@ -29,7 +29,7 @@ class ProtocolReaderFactoryImpl implements ProtocolReaderFactory {
@Inject
ProtocolReaderFactoryImpl(ReaderFactory readerFactory,
Provider<ObjectReader<Ack>> ackProvider,
Provider<ObjectReader<Batch>> batchProvider,
Provider<ObjectReader<UnverifiedBatch>> batchProvider,
Provider<ObjectReader<Offer>> offerProvider,
Provider<ObjectReader<Request>> requestProvider,
Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider,

View File

@@ -4,13 +4,13 @@ import java.io.IOException;
import java.io.InputStream;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
@@ -20,7 +20,8 @@ class ProtocolReaderImpl implements ProtocolReader {
private final Reader reader;
ProtocolReaderImpl(InputStream in, ReaderFactory readerFactory,
ObjectReader<Ack> ackReader, ObjectReader<Batch> batchReader,
ObjectReader<Ack> ackReader,
ObjectReader<UnverifiedBatch> batchReader,
ObjectReader<Offer> offerReader,
ObjectReader<Request> requestReader,
ObjectReader<SubscriptionUpdate> subscriptionReader,
@@ -50,8 +51,8 @@ class ProtocolReaderImpl implements ProtocolReader {
return reader.hasStruct(Types.BATCH);
}
public Batch readBatch() throws IOException {
return reader.readStruct(Types.BATCH, Batch.class);
public UnverifiedBatch readBatch() throws IOException {
return reader.readStruct(Types.BATCH, UnverifiedBatch.class);
}
public boolean hasOffer() throws IOException {

View File

@@ -0,0 +1,12 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.UnverifiedBatch;
interface UnverifiedBatchFactory {
UnverifiedBatch createUnverifiedBatch(BatchId id,
Collection<UnverifiedMessage> messages);
}

View File

@@ -0,0 +1,24 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.UnverifiedBatch;
import com.google.inject.Inject;
class UnverifiedBatchFactoryImpl implements UnverifiedBatchFactory {
private final CryptoComponent crypto;
@Inject
UnverifiedBatchFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public UnverifiedBatch createUnverifiedBatch(BatchId id,
Collection<UnverifiedMessage> messages) {
return new UnverifiedBatchImpl(crypto, id, messages);
}
}

View File

@@ -0,0 +1,83 @@
package net.sf.briar.protocol;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.security.Signature;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.UnverifiedBatch;
class UnverifiedBatchImpl implements UnverifiedBatch {
private final CryptoComponent crypto;
private final BatchId id;
private final Collection<UnverifiedMessage> messages;
// Initialise lazily - the batch may be empty or contain unsigned messages
private MessageDigest messageDigest = null;
private KeyParser keyParser = null;
private Signature signature = null;
UnverifiedBatchImpl(CryptoComponent crypto, BatchId id,
Collection<UnverifiedMessage> messages) {
this.crypto = crypto;
this.id = id;
this.messages = messages;
}
public Batch verify() throws GeneralSecurityException {
List<Message> verified = new ArrayList<Message>();
for(UnverifiedMessage m : messages) verified.add(verify(m));
return new BatchImpl(id, Collections.unmodifiableList(verified));
}
private Message verify(UnverifiedMessage m)
throws GeneralSecurityException {
// Hash the message, including the signatures, to get the message ID
byte[] raw = m.getRaw();
if(messageDigest == null) messageDigest = crypto.getMessageDigest();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
// Verify the author's signature, if there is one
Author author = m.getAuthor();
if(author != null) {
if(keyParser == null) keyParser = crypto.getKeyParser();
PublicKey k = keyParser.parsePublicKey(author.getPublicKey());
if(signature == null) signature = crypto.getSignature();
signature.initVerify(k);
signature.update(raw, 0, m.getLengthSignedByAuthor());
if(!signature.verify(m.getAuthorSignature()))
throw new GeneralSecurityException();
}
// Verify the group's signature, if there is one
Group group = m.getGroup();
if(group != null && group.getPublicKey() != null) {
if(keyParser == null) keyParser = crypto.getKeyParser();
PublicKey k = keyParser.parsePublicKey(group.getPublicKey());
if(signature == null) signature = crypto.getSignature();
signature.initVerify(k);
signature.update(raw, 0, m.getLengthSignedByGroup());
if(!signature.verify(m.getGroupSignature()))
throw new GeneralSecurityException();
}
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, m.getParent(), groupId, authorId,
m.getSubject(), m.getTimestamp(), raw, m.getBodyStart(),
m.getBodyLength());
}
}

View File

@@ -0,0 +1,32 @@
package net.sf.briar.protocol;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId;
interface UnverifiedMessage {
MessageId getParent();
Group getGroup();
Author getAuthor();
String getSubject();
long getTimestamp();
byte[] getRaw();
byte[] getAuthorSignature();
byte[] getGroupSignature();
int getBodyStart();
int getBodyLength();
int getLengthSignedByAuthor();
int getLengthSignedByGroup();
}

View File

@@ -0,0 +1,82 @@
package net.sf.briar.protocol;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId;
class UnverifiedMessageImpl implements UnverifiedMessage {
private final MessageId parent;
private final Group group;
private final Author author;
private final String subject;
private final long timestamp;
private final byte[] raw, authorSig, groupSig;
private final int bodyStart, bodyLength, signedByAuthor, signedByGroup;
UnverifiedMessageImpl(MessageId parent, Group group, Author author,
String subject, long timestamp, byte[] raw, byte[] authorSig,
byte[] groupSig, int bodyStart, int bodyLength, int signedByAuthor,
int signedByGroup) {
this.parent = parent;
this.group = group;
this.author = author;
this.subject = subject;
this.timestamp = timestamp;
this.raw = raw;
this.authorSig = authorSig;
this.groupSig = groupSig;
this.bodyStart = bodyStart;
this.bodyLength = bodyLength;
this.signedByAuthor = signedByAuthor;
this.signedByGroup = signedByGroup;
}
public MessageId getParent() {
return parent;
}
public Group getGroup() {
return group;
}
public Author getAuthor() {
return author;
}
public String getSubject() {
return subject;
}
public long getTimestamp() {
return timestamp;
}
public byte[] getRaw() {
return raw;
}
public byte[] getAuthorSignature() {
return authorSig;
}
public byte[] getGroupSignature() {
return groupSig;
}
public int getBodyStart() {
return bodyStart;
}
public int getBodyLength() {
return bodyLength;
}
public int getLengthSignedByAuthor() {
return signedByAuthor;
}
public int getLengthSignedByGroup() {
return signedByGroup;
}
}

View File

@@ -1,5 +1,7 @@
package net.sf.briar.transport.batch;
import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -16,6 +18,7 @@ import com.google.inject.Inject;
class BatchConnectionFactoryImpl implements BatchConnectionFactory {
private final Executor executor;
private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory;
private final DatabaseComponent db;
@@ -23,10 +26,12 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
private final ProtocolWriterFactory protoWriterFactory;
@Inject
BatchConnectionFactoryImpl(ConnectionReaderFactory connReaderFactory,
BatchConnectionFactoryImpl(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory) {
this.executor = executor;
this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory;
this.db = db;
@@ -37,7 +42,8 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
public void createIncomingConnection(ConnectionContext ctx,
BatchTransportReader r, byte[] tag) {
final IncomingBatchConnection conn = new IncomingBatchConnection(
connReaderFactory, db, protoReaderFactory, ctx, r, tag);
executor, connReaderFactory, db, protoReaderFactory, ctx, r,
tag);
Runnable read = new Runnable() {
public void run() {
conn.read();

View File

@@ -1,6 +1,8 @@
package net.sf.briar.transport.batch;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -9,11 +11,11 @@ import net.sf.briar.api.FormatException;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.transport.BatchTransportReader;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader;
@@ -24,6 +26,7 @@ class IncomingBatchConnection {
private static final Logger LOG =
Logger.getLogger(IncomingBatchConnection.class.getName());
private final Executor executor;
private final ConnectionReaderFactory connFactory;
private final DatabaseComponent db;
private final ProtocolReaderFactory protoFactory;
@@ -31,9 +34,11 @@ class IncomingBatchConnection {
private final BatchTransportReader reader;
private final byte[] tag;
IncomingBatchConnection(ConnectionReaderFactory connFactory,
IncomingBatchConnection(Executor executor,
ConnectionReaderFactory connFactory,
DatabaseComponent db, ProtocolReaderFactory protoFactory,
ConnectionContext ctx, BatchTransportReader reader, byte[] tag) {
this.executor = executor;
this.connFactory = connFactory;
this.db = db;
this.protoFactory = protoFactory;
@@ -48,28 +53,68 @@ class IncomingBatchConnection {
reader.getInputStream(), ctx.getSecret(), tag);
ProtocolReader proto = protoFactory.createProtocolReader(
conn.getInputStream());
ContactId c = ctx.getContactId();
final ContactId c = ctx.getContactId();
// Read packets until EOF
while(!proto.eof()) {
if(proto.hasAck()) {
Ack a = proto.readAck();
db.receiveAck(c, a);
final Ack a = proto.readAck();
// Store the ack on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveAck(c, a);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasBatch()) {
Batch b = proto.readBatch();
db.receiveBatch(c, b);
final UnverifiedBatch b = proto.readBatch();
// Verify and store the batch on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveBatch(c, b.verify());
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasSubscriptionUpdate()) {
SubscriptionUpdate s = proto.readSubscriptionUpdate();
db.receiveSubscriptionUpdate(c, s);
final SubscriptionUpdate s = proto.readSubscriptionUpdate();
// Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveSubscriptionUpdate(c, s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasTransportUpdate()) {
TransportUpdate t = proto.readTransportUpdate();
db.receiveTransportUpdate(c, t);
final TransportUpdate t = proto.readTransportUpdate();
// Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveTransportUpdate(c, t);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else {
throw new FormatException();
}
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
reader.dispose(false);
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
reader.dispose(false);

View File

@@ -1,6 +1,7 @@
package net.sf.briar.transport.stream;
import java.io.IOException;
import java.util.concurrent.Executor;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
@@ -18,14 +19,16 @@ class IncomingStreamConnection extends StreamConnection {
private final ConnectionContext ctx;
private final byte[] tag;
IncomingStreamConnection(ConnectionReaderFactory connReaderFactory,
IncomingStreamConnection(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory,
ConnectionContext ctx, StreamTransportConnection connection,
byte[] tag) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, ctx.getContactId(), connection);
super(executor, connReaderFactory, connWriterFactory, db,
protoReaderFactory, protoWriterFactory, ctx.getContactId(),
connection);
this.ctx = ctx;
this.tag = tag;
}

View File

@@ -1,6 +1,7 @@
package net.sf.briar.transport.stream;
import java.io.IOException;
import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent;
@@ -21,14 +22,15 @@ class OutgoingStreamConnection extends StreamConnection {
private ConnectionContext ctx = null; // Locking: this
OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory,
OutgoingStreamConnection(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId,
TransportIndex transportIndex,
StreamTransportConnection connection) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, contactId, connection);
super(executor, connReaderFactory, connWriterFactory, db,
protoReaderFactory, protoWriterFactory, contactId, connection);
this.transportIndex = transportIndex;
}

View File

@@ -3,12 +3,14 @@ package net.sf.briar.transport.stream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -24,7 +26,6 @@ import net.sf.briar.api.db.event.LocalTransportsUpdatedEvent;
import net.sf.briar.api.db.event.MessagesAddedEvent;
import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader;
@@ -32,6 +33,7 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
@@ -52,6 +54,7 @@ abstract class StreamConnection implements DatabaseListener {
private static final Logger LOG =
Logger.getLogger(StreamConnection.class.getName());
protected final Executor executor;
protected final ConnectionReaderFactory connReaderFactory;
protected final ConnectionWriterFactory connWriterFactory;
protected final DatabaseComponent db;
@@ -65,11 +68,13 @@ abstract class StreamConnection implements DatabaseListener {
private LinkedList<MessageId> requested = null; // Locking: this
private Offer incomingOffer = null; // Locking: this
StreamConnection(ConnectionReaderFactory connReaderFactory,
StreamConnection(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId,
StreamTransportConnection connection) {
this.executor = executor;
this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory;
this.db = db;
@@ -119,11 +124,34 @@ abstract class StreamConnection implements DatabaseListener {
ProtocolReader proto = protoReaderFactory.createProtocolReader(in);
while(!proto.eof()) {
if(proto.hasAck()) {
Ack a = proto.readAck();
db.receiveAck(contactId, a);
final Ack a = proto.readAck();
// Store the ack on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveAck(contactId, a);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasBatch()) {
Batch b = proto.readBatch();
db.receiveBatch(contactId, b);
final UnverifiedBatch b = proto.readBatch();
// Verify and store the batch on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveBatch(contactId, b.verify());
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasOffer()) {
Offer o = proto.readOffer();
// Store the incoming offer and notify the writer
@@ -151,8 +179,19 @@ abstract class StreamConnection implements DatabaseListener {
if(b.get(i++)) req.add(m);
else seen.add(m);
}
// Mark the unrequested messages as seen
db.setSeen(contactId, Collections.unmodifiableList(seen));
// Mark the unrequested messages as seen on another thread
final List<MessageId> l =
Collections.unmodifiableList(seen);
executor.execute(new Runnable() {
public void run() {
try {
db.setSeen(contactId, l);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
// Store the requested message IDs and notify the writer
synchronized(this) {
if(requested != null)
@@ -162,11 +201,31 @@ abstract class StreamConnection implements DatabaseListener {
notifyAll();
}
} else if(proto.hasSubscriptionUpdate()) {
SubscriptionUpdate s = proto.readSubscriptionUpdate();
db.receiveSubscriptionUpdate(contactId, s);
final SubscriptionUpdate s = proto.readSubscriptionUpdate();
// Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveSubscriptionUpdate(contactId, s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasTransportUpdate()) {
TransportUpdate t = proto.readTransportUpdate();
db.receiveTransportUpdate(contactId, t);
final TransportUpdate t = proto.readTransportUpdate();
// Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveTransportUpdate(contactId, t);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else {
throw new FormatException();
}

View File

@@ -1,5 +1,7 @@
package net.sf.briar.transport.stream;
import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -15,6 +17,7 @@ import com.google.inject.Inject;
class StreamConnectionFactoryImpl implements StreamConnectionFactory {
private final Executor executor;
private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory;
private final DatabaseComponent db;
@@ -22,10 +25,12 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
private final ProtocolWriterFactory protoWriterFactory;
@Inject
StreamConnectionFactoryImpl(ConnectionReaderFactory connReaderFactory,
StreamConnectionFactoryImpl(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory) {
this.executor = executor;
this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory;
this.db = db;
@@ -35,7 +40,7 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
public void createIncomingConnection(ConnectionContext ctx,
StreamTransportConnection s, byte[] tag) {
final StreamConnection conn = new IncomingStreamConnection(
final StreamConnection conn = new IncomingStreamConnection(executor,
connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, ctx, s, tag);
Runnable write = new Runnable() {
@@ -54,7 +59,7 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
public void createOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s) {
final StreamConnection conn = new OutgoingStreamConnection(
final StreamConnection conn = new OutgoingStreamConnection(executor,
connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, c, i, s);
Runnable write = new Runnable() {

View File

@@ -208,9 +208,9 @@ public class ProtocolIntegrationTest extends TestCase {
Ack a = protocolReader.readAck();
assertEquals(Collections.singletonList(ack), a.getBatchIds());
// Read the batch
// Read and verify the batch
assertTrue(protocolReader.hasBatch());
Batch b = protocolReader.readBatch();
Batch b = protocolReader.readBatch().verify();
Collection<Message> messages = b.getMessages();
assertEquals(4, messages.size());
Iterator<Message> it = messages.iterator();

View File

@@ -9,11 +9,10 @@ import junit.framework.TestCase;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
@@ -35,7 +34,8 @@ public class BatchReaderTest extends TestCase {
private final WriterFactory writerFactory;
private final CryptoComponent crypto;
private final Mockery context;
private final Message message;
private final UnverifiedMessage message;
private final ObjectReader<UnverifiedMessage> messageReader;
public BatchReaderTest() throws Exception {
super();
@@ -45,13 +45,14 @@ public class BatchReaderTest extends TestCase {
writerFactory = i.getInstance(WriterFactory.class);
crypto = i.getInstance(CryptoComponent.class);
context = new Mockery();
message = context.mock(Message.class);
message = context.mock(UnverifiedMessage.class);
messageReader = new TestMessageReader();
}
@Test
public void testFormatExceptionIfBatchIsTooLarge() throws Exception {
ObjectReader<Message> messageReader = new TestMessageReader();
BatchFactory batchFactory = context.mock(BatchFactory.class);
UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory);
@@ -61,7 +62,7 @@ public class BatchReaderTest extends TestCase {
reader.addObjectReader(Types.BATCH, batchReader);
try {
reader.readStruct(Types.BATCH, Batch.class);
reader.readStruct(Types.BATCH, UnverifiedBatch.class);
fail();
} catch(FormatException expected) {}
context.assertIsSatisfied();
@@ -69,13 +70,13 @@ public class BatchReaderTest extends TestCase {
@Test
public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception {
ObjectReader<Message> messageReader = new TestMessageReader();
final BatchFactory batchFactory = context.mock(BatchFactory.class);
final UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory);
final Batch batch = context.mock(Batch.class);
final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{
oneOf(batchFactory).createBatch(with(any(BatchId.class)),
oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)),
with(Collections.singletonList(message)));
will(returnValue(batch));
}});
@@ -85,7 +86,8 @@ public class BatchReaderTest extends TestCase {
Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Types.BATCH, batchReader);
assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class));
assertEquals(batch, reader.readStruct(Types.BATCH,
UnverifiedBatch.class));
context.assertIsSatisfied();
}
@@ -98,14 +100,14 @@ public class BatchReaderTest extends TestCase {
messageDigest.update(b);
final BatchId id = new BatchId(messageDigest.digest());
ObjectReader<Message> messageReader = new TestMessageReader();
final BatchFactory batchFactory = context.mock(BatchFactory.class);
final UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory);
final Batch batch = context.mock(Batch.class);
final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{
// Check that the batch ID matches the expected ID
oneOf(batchFactory).createBatch(with(id),
oneOf(batchFactory).createUnverifiedBatch(with(id),
with(Collections.singletonList(message)));
will(returnValue(batch));
}});
@@ -114,20 +116,21 @@ public class BatchReaderTest extends TestCase {
Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Types.BATCH, batchReader);
assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class));
assertEquals(batch, reader.readStruct(Types.BATCH,
UnverifiedBatch.class));
context.assertIsSatisfied();
}
@Test
public void testEmptyBatch() throws Exception {
ObjectReader<Message> messageReader = new TestMessageReader();
final BatchFactory batchFactory = context.mock(BatchFactory.class);
final UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory);
final Batch batch = context.mock(Batch.class);
final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{
oneOf(batchFactory).createBatch(with(any(BatchId.class)),
with(Collections.<Message>emptyList()));
oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)),
with(Collections.<UnverifiedMessage>emptyList()));
will(returnValue(batch));
}});
@@ -136,7 +139,8 @@ public class BatchReaderTest extends TestCase {
Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Types.BATCH, batchReader);
assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class));
assertEquals(batch, reader.readStruct(Types.BATCH,
UnverifiedBatch.class));
context.assertIsSatisfied();
}
@@ -163,9 +167,9 @@ public class BatchReaderTest extends TestCase {
return out.toByteArray();
}
private class TestMessageReader implements ObjectReader<Message> {
private class TestMessageReader implements ObjectReader<UnverifiedMessage> {
public Message readObject(Reader r) throws IOException {
public UnverifiedMessage readObject(Reader r) throws IOException {
r.readStructId(Types.MESSAGE);
r.readBytes();
return message;

View File

@@ -114,7 +114,7 @@ public class ProtocolReadWriteTest extends TestCase {
Ack ack = reader.readAck();
assertEquals(Collections.singletonList(batchId), ack.getBatchIds());
Batch batch = reader.readBatch();
Batch batch = reader.readBatch().verify();
assertEquals(Collections.singletonList(message), batch.getMessages());
Offer offer = reader.readOffer();

View File

@@ -41,6 +41,7 @@ import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule;
import net.sf.briar.lifecycle.LifecycleModule;
import net.sf.briar.plugins.ImmediateExecutor;
import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule;
@@ -187,7 +188,8 @@ public class BatchConnectionReadWriteTest extends TestCase {
bob.getInstance(ProtocolReaderFactory.class);
BatchTransportReader reader = new TestBatchTransportReader(in);
IncomingBatchConnection batchIn = new IncomingBatchConnection(
connFactory, db, protoFactory, ctx, reader, tag);
new ImmediateExecutor(), connFactory, db, protoFactory, ctx,
reader, tag);
// No messages should have been added yet
assertFalse(listener.messagesAdded);
// Read whatever needs to be read