Moved message verification and DB writes off the IO thread.

This commit is contained in:
akwizgran
2011-12-05 22:52:00 +00:00
parent ab722f9371
commit e24a3218ca
24 changed files with 468 additions and 169 deletions

View File

@@ -10,7 +10,7 @@ public interface ProtocolReader {
Ack readAck() throws IOException; Ack readAck() throws IOException;
boolean hasBatch() throws IOException; boolean hasBatch() throws IOException;
Batch readBatch() throws IOException; UnverifiedBatch readBatch() throws IOException;
boolean hasOffer() throws IOException; boolean hasOffer() throws IOException;
Offer readOffer() throws IOException; Offer readOffer() throws IOException;

View File

@@ -0,0 +1,8 @@
package net.sf.briar.api.protocol;
import java.security.GeneralSecurityException;
public interface UnverifiedBatch {
Batch verify() throws GeneralSecurityException;
}

View File

@@ -1,12 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
interface BatchFactory {
Batch createBatch(BatchId id, Collection<Message> messages);
}

View File

@@ -1,14 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
class BatchFactoryImpl implements BatchFactory {
public Batch createBatch(BatchId id, Collection<Message> messages) {
return new BatchImpl(id, messages);
}
}

View File

@@ -5,31 +5,31 @@ import java.util.List;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest; import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.CountingConsumer; import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.DigestingConsumer; import net.sf.briar.api.serial.DigestingConsumer;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
class BatchReader implements ObjectReader<Batch> { class BatchReader implements ObjectReader<UnverifiedBatch> {
private final MessageDigest messageDigest; private final MessageDigest messageDigest;
private final ObjectReader<Message> messageReader; private final ObjectReader<UnverifiedMessage> messageReader;
private final BatchFactory batchFactory; private final UnverifiedBatchFactory batchFactory;
BatchReader(CryptoComponent crypto, ObjectReader<Message> messageReader, BatchReader(CryptoComponent crypto,
BatchFactory batchFactory) { ObjectReader<UnverifiedMessage> messageReader,
UnverifiedBatchFactory batchFactory) {
messageDigest = crypto.getMessageDigest(); messageDigest = crypto.getMessageDigest();
this.messageReader = messageReader; this.messageReader = messageReader;
this.batchFactory = batchFactory; this.batchFactory = batchFactory;
} }
public Batch readObject(Reader r) throws IOException { public UnverifiedBatch readObject(Reader r) throws IOException {
// Initialise the consumers // Initialise the consumers
Consumer counting = Consumer counting =
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
@@ -40,12 +40,12 @@ class BatchReader implements ObjectReader<Batch> {
r.addConsumer(digesting); r.addConsumer(digesting);
r.readStructId(Types.BATCH); r.readStructId(Types.BATCH);
r.addObjectReader(Types.MESSAGE, messageReader); r.addObjectReader(Types.MESSAGE, messageReader);
List<Message> messages = r.readList(Message.class); List<UnverifiedMessage> messages = r.readList(UnverifiedMessage.class);
r.removeObjectReader(Types.MESSAGE); r.removeObjectReader(Types.MESSAGE);
r.removeConsumer(digesting); r.removeConsumer(digesting);
r.removeConsumer(counting); r.removeConsumer(counting);
// Build and return the batch // Build and return the batch
BatchId id = new BatchId(messageDigest.digest()); BatchId id = new BatchId(messageDigest.digest());
return batchFactory.createBatch(id, messages); return batchFactory.createUnverifiedBatch(id, messages);
} }
} }

View File

@@ -1,19 +1,10 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import java.io.IOException; import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.security.Signature;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
@@ -22,28 +13,21 @@ import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
class MessageReader implements ObjectReader<Message> { class MessageReader implements ObjectReader<UnverifiedMessage> {
private final ObjectReader<MessageId> messageIdReader; private final ObjectReader<MessageId> messageIdReader;
private final ObjectReader<Group> groupReader; private final ObjectReader<Group> groupReader;
private final ObjectReader<Author> authorReader; private final ObjectReader<Author> authorReader;
private final KeyParser keyParser;
private final Signature signature;
private final MessageDigest messageDigest;
MessageReader(CryptoComponent crypto, MessageReader(ObjectReader<MessageId> messageIdReader,
ObjectReader<MessageId> messageIdReader,
ObjectReader<Group> groupReader, ObjectReader<Group> groupReader,
ObjectReader<Author> authorReader) { ObjectReader<Author> authorReader) {
this.messageIdReader = messageIdReader; this.messageIdReader = messageIdReader;
this.groupReader = groupReader; this.groupReader = groupReader;
this.authorReader = authorReader; this.authorReader = authorReader;
keyParser = crypto.getKeyParser();
signature = crypto.getSignature();
messageDigest = crypto.getMessageDigest();
} }
public Message readObject(Reader r) throws IOException { public UnverifiedMessage readObject(Reader r) throws IOException {
CopyingConsumer copying = new CopyingConsumer(); CopyingConsumer copying = new CopyingConsumer();
CountingConsumer counting = CountingConsumer counting =
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
@@ -106,35 +90,8 @@ class MessageReader implements ObjectReader<Message> {
r.removeConsumer(counting); r.removeConsumer(counting);
r.removeConsumer(copying); r.removeConsumer(copying);
byte[] raw = copying.getCopy(); byte[] raw = copying.getCopy();
// Verify the author's signature, if there is one return new UnverifiedMessageImpl(parent, group, author, subject,
if(author != null) { timestamp, raw, authorSig, groupSig, bodyStart, body.length,
try { signedByAuthor, signedByGroup);
PublicKey k = keyParser.parsePublicKey(author.getPublicKey());
signature.initVerify(k);
signature.update(raw, 0, signedByAuthor);
if(!signature.verify(authorSig)) throw new FormatException();
} catch(GeneralSecurityException e) {
throw new FormatException();
}
}
// Verify the group's signature, if there is one
if(group != null && group.getPublicKey() != null) {
try {
PublicKey k = keyParser.parsePublicKey(group.getPublicKey());
signature.initVerify(k);
signature.update(raw, 0, signedByGroup);
if(!signature.verify(groupSig)) throw new FormatException();
} catch(GeneralSecurityException e) {
throw new FormatException();
}
}
// Hash the message, including the signatures, to get the message ID
messageDigest.reset();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, parent, groupId, authorId, subject,
timestamp, raw, bodyStart, body.length);
} }
} }

View File

@@ -4,10 +4,8 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorFactory; import net.sf.briar.api.protocol.AuthorFactory;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
@@ -15,6 +13,7 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import com.google.inject.AbstractModule; import com.google.inject.AbstractModule;
@@ -26,14 +25,15 @@ public class ProtocolModule extends AbstractModule {
protected void configure() { protected void configure() {
bind(AckFactory.class).to(AckFactoryImpl.class); bind(AckFactory.class).to(AckFactoryImpl.class);
bind(AuthorFactory.class).to(AuthorFactoryImpl.class); bind(AuthorFactory.class).to(AuthorFactoryImpl.class);
bind(BatchFactory.class).to(BatchFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class); bind(GroupFactory.class).to(GroupFactoryImpl.class);
bind(MessageFactory.class).to(MessageFactoryImpl.class); bind(MessageFactory.class).to(MessageFactoryImpl.class);
bind(OfferFactory.class).to(OfferFactoryImpl.class); bind(OfferFactory.class).to(OfferFactoryImpl.class);
bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class); bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class);
bind(RequestFactory.class).to(RequestFactoryImpl.class); bind(RequestFactory.class).to(RequestFactoryImpl.class);
bind(SubscriptionUpdateFactory.class).to(SubscriptionUpdateFactoryImpl.class); bind(SubscriptionUpdateFactory.class).to(
SubscriptionUpdateFactoryImpl.class);
bind(TransportUpdateFactory.class).to(TransportUpdateFactoryImpl.class); bind(TransportUpdateFactory.class).to(TransportUpdateFactoryImpl.class);
bind(UnverifiedBatchFactory.class).to(UnverifiedBatchFactoryImpl.class);
} }
@Provides @Provides
@@ -48,8 +48,9 @@ public class ProtocolModule extends AbstractModule {
} }
@Provides @Provides
ObjectReader<Batch> getBatchReader(CryptoComponent crypto, ObjectReader<UnverifiedBatch> getBatchReader(CryptoComponent crypto,
ObjectReader<Message> messageReader, BatchFactory batchFactory) { ObjectReader<UnverifiedMessage> messageReader,
UnverifiedBatchFactory batchFactory) {
return new BatchReader(crypto, messageReader, batchFactory); return new BatchReader(crypto, messageReader, batchFactory);
} }
@@ -65,12 +66,11 @@ public class ProtocolModule extends AbstractModule {
} }
@Provides @Provides
ObjectReader<Message> getMessageReader(CryptoComponent crypto, ObjectReader<UnverifiedMessage> getMessageReader(
ObjectReader<MessageId> messageIdReader, ObjectReader<MessageId> messageIdReader,
ObjectReader<Group> groupReader, ObjectReader<Group> groupReader,
ObjectReader<Author> authorReader) { ObjectReader<Author> authorReader) {
return new MessageReader(crypto, messageIdReader, groupReader, return new MessageReader(messageIdReader, groupReader, authorReader);
authorReader);
} }
@Provides @Provides

View File

@@ -3,13 +3,13 @@ package net.sf.briar.protocol;
import java.io.InputStream; import java.io.InputStream;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
@@ -20,7 +20,7 @@ class ProtocolReaderFactoryImpl implements ProtocolReaderFactory {
private final ReaderFactory readerFactory; private final ReaderFactory readerFactory;
private final Provider<ObjectReader<Ack>> ackProvider; private final Provider<ObjectReader<Ack>> ackProvider;
private final Provider<ObjectReader<Batch>> batchProvider; private final Provider<ObjectReader<UnverifiedBatch>> batchProvider;
private final Provider<ObjectReader<Offer>> offerProvider; private final Provider<ObjectReader<Offer>> offerProvider;
private final Provider<ObjectReader<Request>> requestProvider; private final Provider<ObjectReader<Request>> requestProvider;
private final Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider; private final Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider;
@@ -29,7 +29,7 @@ class ProtocolReaderFactoryImpl implements ProtocolReaderFactory {
@Inject @Inject
ProtocolReaderFactoryImpl(ReaderFactory readerFactory, ProtocolReaderFactoryImpl(ReaderFactory readerFactory,
Provider<ObjectReader<Ack>> ackProvider, Provider<ObjectReader<Ack>> ackProvider,
Provider<ObjectReader<Batch>> batchProvider, Provider<ObjectReader<UnverifiedBatch>> batchProvider,
Provider<ObjectReader<Offer>> offerProvider, Provider<ObjectReader<Offer>> offerProvider,
Provider<ObjectReader<Request>> requestProvider, Provider<ObjectReader<Request>> requestProvider,
Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider, Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider,

View File

@@ -4,13 +4,13 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
@@ -20,7 +20,8 @@ class ProtocolReaderImpl implements ProtocolReader {
private final Reader reader; private final Reader reader;
ProtocolReaderImpl(InputStream in, ReaderFactory readerFactory, ProtocolReaderImpl(InputStream in, ReaderFactory readerFactory,
ObjectReader<Ack> ackReader, ObjectReader<Batch> batchReader, ObjectReader<Ack> ackReader,
ObjectReader<UnverifiedBatch> batchReader,
ObjectReader<Offer> offerReader, ObjectReader<Offer> offerReader,
ObjectReader<Request> requestReader, ObjectReader<Request> requestReader,
ObjectReader<SubscriptionUpdate> subscriptionReader, ObjectReader<SubscriptionUpdate> subscriptionReader,
@@ -50,8 +51,8 @@ class ProtocolReaderImpl implements ProtocolReader {
return reader.hasStruct(Types.BATCH); return reader.hasStruct(Types.BATCH);
} }
public Batch readBatch() throws IOException { public UnverifiedBatch readBatch() throws IOException {
return reader.readStruct(Types.BATCH, Batch.class); return reader.readStruct(Types.BATCH, UnverifiedBatch.class);
} }
public boolean hasOffer() throws IOException { public boolean hasOffer() throws IOException {

View File

@@ -0,0 +1,12 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.UnverifiedBatch;
interface UnverifiedBatchFactory {
UnverifiedBatch createUnverifiedBatch(BatchId id,
Collection<UnverifiedMessage> messages);
}

View File

@@ -0,0 +1,24 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.UnverifiedBatch;
import com.google.inject.Inject;
class UnverifiedBatchFactoryImpl implements UnverifiedBatchFactory {
private final CryptoComponent crypto;
@Inject
UnverifiedBatchFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public UnverifiedBatch createUnverifiedBatch(BatchId id,
Collection<UnverifiedMessage> messages) {
return new UnverifiedBatchImpl(crypto, id, messages);
}
}

View File

@@ -0,0 +1,83 @@
package net.sf.briar.protocol;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.security.Signature;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.UnverifiedBatch;
class UnverifiedBatchImpl implements UnverifiedBatch {
private final CryptoComponent crypto;
private final BatchId id;
private final Collection<UnverifiedMessage> messages;
// Initialise lazily - the batch may be empty or contain unsigned messages
private MessageDigest messageDigest = null;
private KeyParser keyParser = null;
private Signature signature = null;
UnverifiedBatchImpl(CryptoComponent crypto, BatchId id,
Collection<UnverifiedMessage> messages) {
this.crypto = crypto;
this.id = id;
this.messages = messages;
}
public Batch verify() throws GeneralSecurityException {
List<Message> verified = new ArrayList<Message>();
for(UnverifiedMessage m : messages) verified.add(verify(m));
return new BatchImpl(id, Collections.unmodifiableList(verified));
}
private Message verify(UnverifiedMessage m)
throws GeneralSecurityException {
// Hash the message, including the signatures, to get the message ID
byte[] raw = m.getRaw();
if(messageDigest == null) messageDigest = crypto.getMessageDigest();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
// Verify the author's signature, if there is one
Author author = m.getAuthor();
if(author != null) {
if(keyParser == null) keyParser = crypto.getKeyParser();
PublicKey k = keyParser.parsePublicKey(author.getPublicKey());
if(signature == null) signature = crypto.getSignature();
signature.initVerify(k);
signature.update(raw, 0, m.getLengthSignedByAuthor());
if(!signature.verify(m.getAuthorSignature()))
throw new GeneralSecurityException();
}
// Verify the group's signature, if there is one
Group group = m.getGroup();
if(group != null && group.getPublicKey() != null) {
if(keyParser == null) keyParser = crypto.getKeyParser();
PublicKey k = keyParser.parsePublicKey(group.getPublicKey());
if(signature == null) signature = crypto.getSignature();
signature.initVerify(k);
signature.update(raw, 0, m.getLengthSignedByGroup());
if(!signature.verify(m.getGroupSignature()))
throw new GeneralSecurityException();
}
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, m.getParent(), groupId, authorId,
m.getSubject(), m.getTimestamp(), raw, m.getBodyStart(),
m.getBodyLength());
}
}

View File

@@ -0,0 +1,32 @@
package net.sf.briar.protocol;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId;
interface UnverifiedMessage {
MessageId getParent();
Group getGroup();
Author getAuthor();
String getSubject();
long getTimestamp();
byte[] getRaw();
byte[] getAuthorSignature();
byte[] getGroupSignature();
int getBodyStart();
int getBodyLength();
int getLengthSignedByAuthor();
int getLengthSignedByGroup();
}

View File

@@ -0,0 +1,82 @@
package net.sf.briar.protocol;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId;
class UnverifiedMessageImpl implements UnverifiedMessage {
private final MessageId parent;
private final Group group;
private final Author author;
private final String subject;
private final long timestamp;
private final byte[] raw, authorSig, groupSig;
private final int bodyStart, bodyLength, signedByAuthor, signedByGroup;
UnverifiedMessageImpl(MessageId parent, Group group, Author author,
String subject, long timestamp, byte[] raw, byte[] authorSig,
byte[] groupSig, int bodyStart, int bodyLength, int signedByAuthor,
int signedByGroup) {
this.parent = parent;
this.group = group;
this.author = author;
this.subject = subject;
this.timestamp = timestamp;
this.raw = raw;
this.authorSig = authorSig;
this.groupSig = groupSig;
this.bodyStart = bodyStart;
this.bodyLength = bodyLength;
this.signedByAuthor = signedByAuthor;
this.signedByGroup = signedByGroup;
}
public MessageId getParent() {
return parent;
}
public Group getGroup() {
return group;
}
public Author getAuthor() {
return author;
}
public String getSubject() {
return subject;
}
public long getTimestamp() {
return timestamp;
}
public byte[] getRaw() {
return raw;
}
public byte[] getAuthorSignature() {
return authorSig;
}
public byte[] getGroupSignature() {
return groupSig;
}
public int getBodyStart() {
return bodyStart;
}
public int getBodyLength() {
return bodyLength;
}
public int getLengthSignedByAuthor() {
return signedByAuthor;
}
public int getLengthSignedByGroup() {
return signedByGroup;
}
}

View File

@@ -1,5 +1,7 @@
package net.sf.briar.transport.batch; package net.sf.briar.transport.batch;
import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -16,6 +18,7 @@ import com.google.inject.Inject;
class BatchConnectionFactoryImpl implements BatchConnectionFactory { class BatchConnectionFactoryImpl implements BatchConnectionFactory {
private final Executor executor;
private final ConnectionReaderFactory connReaderFactory; private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory; private final ConnectionWriterFactory connWriterFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
@@ -23,10 +26,12 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
private final ProtocolWriterFactory protoWriterFactory; private final ProtocolWriterFactory protoWriterFactory;
@Inject @Inject
BatchConnectionFactoryImpl(ConnectionReaderFactory connReaderFactory, BatchConnectionFactoryImpl(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory) { ProtocolWriterFactory protoWriterFactory) {
this.executor = executor;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db; this.db = db;
@@ -37,7 +42,8 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
public void createIncomingConnection(ConnectionContext ctx, public void createIncomingConnection(ConnectionContext ctx,
BatchTransportReader r, byte[] tag) { BatchTransportReader r, byte[] tag) {
final IncomingBatchConnection conn = new IncomingBatchConnection( final IncomingBatchConnection conn = new IncomingBatchConnection(
connReaderFactory, db, protoReaderFactory, ctx, r, tag); executor, connReaderFactory, db, protoReaderFactory, ctx, r,
tag);
Runnable read = new Runnable() { Runnable read = new Runnable() {
public void run() { public void run() {
conn.read(); conn.read();

View File

@@ -1,6 +1,8 @@
package net.sf.briar.transport.batch; package net.sf.briar.transport.batch;
import java.io.IOException; import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -9,11 +11,11 @@ import net.sf.briar.api.FormatException;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportReader;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
@@ -24,6 +26,7 @@ class IncomingBatchConnection {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(IncomingBatchConnection.class.getName()); Logger.getLogger(IncomingBatchConnection.class.getName());
private final Executor executor;
private final ConnectionReaderFactory connFactory; private final ConnectionReaderFactory connFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
private final ProtocolReaderFactory protoFactory; private final ProtocolReaderFactory protoFactory;
@@ -31,9 +34,11 @@ class IncomingBatchConnection {
private final BatchTransportReader reader; private final BatchTransportReader reader;
private final byte[] tag; private final byte[] tag;
IncomingBatchConnection(ConnectionReaderFactory connFactory, IncomingBatchConnection(Executor executor,
ConnectionReaderFactory connFactory,
DatabaseComponent db, ProtocolReaderFactory protoFactory, DatabaseComponent db, ProtocolReaderFactory protoFactory,
ConnectionContext ctx, BatchTransportReader reader, byte[] tag) { ConnectionContext ctx, BatchTransportReader reader, byte[] tag) {
this.executor = executor;
this.connFactory = connFactory; this.connFactory = connFactory;
this.db = db; this.db = db;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
@@ -48,28 +53,68 @@ class IncomingBatchConnection {
reader.getInputStream(), ctx.getSecret(), tag); reader.getInputStream(), ctx.getSecret(), tag);
ProtocolReader proto = protoFactory.createProtocolReader( ProtocolReader proto = protoFactory.createProtocolReader(
conn.getInputStream()); conn.getInputStream());
ContactId c = ctx.getContactId(); final ContactId c = ctx.getContactId();
// Read packets until EOF // Read packets until EOF
while(!proto.eof()) { while(!proto.eof()) {
if(proto.hasAck()) { if(proto.hasAck()) {
Ack a = proto.readAck(); final Ack a = proto.readAck();
db.receiveAck(c, a); // Store the ack on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveAck(c, a);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasBatch()) { } else if(proto.hasBatch()) {
Batch b = proto.readBatch(); final UnverifiedBatch b = proto.readBatch();
db.receiveBatch(c, b); // Verify and store the batch on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveBatch(c, b.verify());
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasSubscriptionUpdate()) { } else if(proto.hasSubscriptionUpdate()) {
SubscriptionUpdate s = proto.readSubscriptionUpdate(); final SubscriptionUpdate s = proto.readSubscriptionUpdate();
db.receiveSubscriptionUpdate(c, s); // Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveSubscriptionUpdate(c, s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasTransportUpdate()) { } else if(proto.hasTransportUpdate()) {
TransportUpdate t = proto.readTransportUpdate(); final TransportUpdate t = proto.readTransportUpdate();
db.receiveTransportUpdate(c, t); // Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveTransportUpdate(c, t);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else { } else {
throw new FormatException(); throw new FormatException();
} }
} }
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
reader.dispose(false);
} catch(IOException e) { } catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
reader.dispose(false); reader.dispose(false);

View File

@@ -1,6 +1,7 @@
package net.sf.briar.transport.stream; package net.sf.briar.transport.stream;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executor;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
@@ -18,14 +19,16 @@ class IncomingStreamConnection extends StreamConnection {
private final ConnectionContext ctx; private final ConnectionContext ctx;
private final byte[] tag; private final byte[] tag;
IncomingStreamConnection(ConnectionReaderFactory connReaderFactory, IncomingStreamConnection(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ProtocolWriterFactory protoWriterFactory,
ConnectionContext ctx, StreamTransportConnection connection, ConnectionContext ctx, StreamTransportConnection connection,
byte[] tag) { byte[] tag) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory, super(executor, connReaderFactory, connWriterFactory, db,
protoWriterFactory, ctx.getContactId(), connection); protoReaderFactory, protoWriterFactory, ctx.getContactId(),
connection);
this.ctx = ctx; this.ctx = ctx;
this.tag = tag; this.tag = tag;
} }

View File

@@ -1,6 +1,7 @@
package net.sf.briar.transport.stream; package net.sf.briar.transport.stream;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
@@ -21,14 +22,15 @@ class OutgoingStreamConnection extends StreamConnection {
private ConnectionContext ctx = null; // Locking: this private ConnectionContext ctx = null; // Locking: this
OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory, OutgoingStreamConnection(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
TransportIndex transportIndex, TransportIndex transportIndex,
StreamTransportConnection connection) { StreamTransportConnection connection) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory, super(executor, connReaderFactory, connWriterFactory, db,
protoWriterFactory, contactId, connection); protoReaderFactory, protoWriterFactory, contactId, connection);
this.transportIndex = transportIndex; this.transportIndex = transportIndex;
} }

View File

@@ -3,12 +3,14 @@ package net.sf.briar.transport.stream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet; import java.util.BitSet;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@@ -24,7 +26,6 @@ import net.sf.briar.api.db.event.LocalTransportsUpdatedEvent;
import net.sf.briar.api.db.event.MessagesAddedEvent; import net.sf.briar.api.db.event.MessagesAddedEvent;
import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent; import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
@@ -32,6 +33,7 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter; import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter; import net.sf.briar.api.protocol.writers.OfferWriter;
@@ -52,6 +54,7 @@ abstract class StreamConnection implements DatabaseListener {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(StreamConnection.class.getName()); Logger.getLogger(StreamConnection.class.getName());
protected final Executor executor;
protected final ConnectionReaderFactory connReaderFactory; protected final ConnectionReaderFactory connReaderFactory;
protected final ConnectionWriterFactory connWriterFactory; protected final ConnectionWriterFactory connWriterFactory;
protected final DatabaseComponent db; protected final DatabaseComponent db;
@@ -65,11 +68,13 @@ abstract class StreamConnection implements DatabaseListener {
private LinkedList<MessageId> requested = null; // Locking: this private LinkedList<MessageId> requested = null; // Locking: this
private Offer incomingOffer = null; // Locking: this private Offer incomingOffer = null; // Locking: this
StreamConnection(ConnectionReaderFactory connReaderFactory, StreamConnection(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
StreamTransportConnection connection) { StreamTransportConnection connection) {
this.executor = executor;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db; this.db = db;
@@ -119,11 +124,34 @@ abstract class StreamConnection implements DatabaseListener {
ProtocolReader proto = protoReaderFactory.createProtocolReader(in); ProtocolReader proto = protoReaderFactory.createProtocolReader(in);
while(!proto.eof()) { while(!proto.eof()) {
if(proto.hasAck()) { if(proto.hasAck()) {
Ack a = proto.readAck(); final Ack a = proto.readAck();
db.receiveAck(contactId, a); // Store the ack on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveAck(contactId, a);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasBatch()) { } else if(proto.hasBatch()) {
Batch b = proto.readBatch(); final UnverifiedBatch b = proto.readBatch();
db.receiveBatch(contactId, b); // Verify and store the batch on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveBatch(contactId, b.verify());
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasOffer()) { } else if(proto.hasOffer()) {
Offer o = proto.readOffer(); Offer o = proto.readOffer();
// Store the incoming offer and notify the writer // Store the incoming offer and notify the writer
@@ -151,8 +179,19 @@ abstract class StreamConnection implements DatabaseListener {
if(b.get(i++)) req.add(m); if(b.get(i++)) req.add(m);
else seen.add(m); else seen.add(m);
} }
// Mark the unrequested messages as seen // Mark the unrequested messages as seen on another thread
db.setSeen(contactId, Collections.unmodifiableList(seen)); final List<MessageId> l =
Collections.unmodifiableList(seen);
executor.execute(new Runnable() {
public void run() {
try {
db.setSeen(contactId, l);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
// Store the requested message IDs and notify the writer // Store the requested message IDs and notify the writer
synchronized(this) { synchronized(this) {
if(requested != null) if(requested != null)
@@ -162,11 +201,31 @@ abstract class StreamConnection implements DatabaseListener {
notifyAll(); notifyAll();
} }
} else if(proto.hasSubscriptionUpdate()) { } else if(proto.hasSubscriptionUpdate()) {
SubscriptionUpdate s = proto.readSubscriptionUpdate(); final SubscriptionUpdate s = proto.readSubscriptionUpdate();
db.receiveSubscriptionUpdate(contactId, s); // Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveSubscriptionUpdate(contactId, s);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else if(proto.hasTransportUpdate()) { } else if(proto.hasTransportUpdate()) {
TransportUpdate t = proto.readTransportUpdate(); final TransportUpdate t = proto.readTransportUpdate();
db.receiveTransportUpdate(contactId, t); // Store the update on another thread
executor.execute(new Runnable() {
public void run() {
try {
db.receiveTransportUpdate(contactId, t);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning(e.getMessage());
}
}
});
} else { } else {
throw new FormatException(); throw new FormatException();
} }

View File

@@ -1,5 +1,7 @@
package net.sf.briar.transport.stream; package net.sf.briar.transport.stream;
import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -15,6 +17,7 @@ import com.google.inject.Inject;
class StreamConnectionFactoryImpl implements StreamConnectionFactory { class StreamConnectionFactoryImpl implements StreamConnectionFactory {
private final Executor executor;
private final ConnectionReaderFactory connReaderFactory; private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory; private final ConnectionWriterFactory connWriterFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
@@ -22,10 +25,12 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
private final ProtocolWriterFactory protoWriterFactory; private final ProtocolWriterFactory protoWriterFactory;
@Inject @Inject
StreamConnectionFactoryImpl(ConnectionReaderFactory connReaderFactory, StreamConnectionFactoryImpl(Executor executor,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory) { ProtocolWriterFactory protoWriterFactory) {
this.executor = executor;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db; this.db = db;
@@ -35,7 +40,7 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
public void createIncomingConnection(ConnectionContext ctx, public void createIncomingConnection(ConnectionContext ctx,
StreamTransportConnection s, byte[] tag) { StreamTransportConnection s, byte[] tag) {
final StreamConnection conn = new IncomingStreamConnection( final StreamConnection conn = new IncomingStreamConnection(executor,
connReaderFactory, connWriterFactory, db, protoReaderFactory, connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, ctx, s, tag); protoWriterFactory, ctx, s, tag);
Runnable write = new Runnable() { Runnable write = new Runnable() {
@@ -54,7 +59,7 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
public void createOutgoingConnection(ContactId c, TransportIndex i, public void createOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s) { StreamTransportConnection s) {
final StreamConnection conn = new OutgoingStreamConnection( final StreamConnection conn = new OutgoingStreamConnection(executor,
connReaderFactory, connWriterFactory, db, protoReaderFactory, connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, c, i, s); protoWriterFactory, c, i, s);
Runnable write = new Runnable() { Runnable write = new Runnable() {

View File

@@ -208,9 +208,9 @@ public class ProtocolIntegrationTest extends TestCase {
Ack a = protocolReader.readAck(); Ack a = protocolReader.readAck();
assertEquals(Collections.singletonList(ack), a.getBatchIds()); assertEquals(Collections.singletonList(ack), a.getBatchIds());
// Read the batch // Read and verify the batch
assertTrue(protocolReader.hasBatch()); assertTrue(protocolReader.hasBatch());
Batch b = protocolReader.readBatch(); Batch b = protocolReader.readBatch().verify();
Collection<Message> messages = b.getMessages(); Collection<Message> messages = b.getMessages();
assertEquals(4, messages.size()); assertEquals(4, messages.size());
Iterator<Message> it = messages.iterator(); Iterator<Message> it = messages.iterator();

View File

@@ -9,11 +9,10 @@ import junit.framework.TestCase;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest; import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
@@ -35,7 +34,8 @@ public class BatchReaderTest extends TestCase {
private final WriterFactory writerFactory; private final WriterFactory writerFactory;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final Mockery context; private final Mockery context;
private final Message message; private final UnverifiedMessage message;
private final ObjectReader<UnverifiedMessage> messageReader;
public BatchReaderTest() throws Exception { public BatchReaderTest() throws Exception {
super(); super();
@@ -45,13 +45,14 @@ public class BatchReaderTest extends TestCase {
writerFactory = i.getInstance(WriterFactory.class); writerFactory = i.getInstance(WriterFactory.class);
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
context = new Mockery(); context = new Mockery();
message = context.mock(Message.class); message = context.mock(UnverifiedMessage.class);
messageReader = new TestMessageReader();
} }
@Test @Test
public void testFormatExceptionIfBatchIsTooLarge() throws Exception { public void testFormatExceptionIfBatchIsTooLarge() throws Exception {
ObjectReader<Message> messageReader = new TestMessageReader(); UnverifiedBatchFactory batchFactory =
BatchFactory batchFactory = context.mock(BatchFactory.class); context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader, BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory); batchFactory);
@@ -61,7 +62,7 @@ public class BatchReaderTest extends TestCase {
reader.addObjectReader(Types.BATCH, batchReader); reader.addObjectReader(Types.BATCH, batchReader);
try { try {
reader.readStruct(Types.BATCH, Batch.class); reader.readStruct(Types.BATCH, UnverifiedBatch.class);
fail(); fail();
} catch(FormatException expected) {} } catch(FormatException expected) {}
context.assertIsSatisfied(); context.assertIsSatisfied();
@@ -69,13 +70,13 @@ public class BatchReaderTest extends TestCase {
@Test @Test
public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception { public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception {
ObjectReader<Message> messageReader = new TestMessageReader(); final UnverifiedBatchFactory batchFactory =
final BatchFactory batchFactory = context.mock(BatchFactory.class); context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader, BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory); batchFactory);
final Batch batch = context.mock(Batch.class); final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(batchFactory).createBatch(with(any(BatchId.class)), oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)),
with(Collections.singletonList(message))); with(Collections.singletonList(message)));
will(returnValue(batch)); will(returnValue(batch));
}}); }});
@@ -85,7 +86,8 @@ public class BatchReaderTest extends TestCase {
Reader reader = readerFactory.createReader(in); Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Types.BATCH, batchReader); reader.addObjectReader(Types.BATCH, batchReader);
assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class)); assertEquals(batch, reader.readStruct(Types.BATCH,
UnverifiedBatch.class));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -98,14 +100,14 @@ public class BatchReaderTest extends TestCase {
messageDigest.update(b); messageDigest.update(b);
final BatchId id = new BatchId(messageDigest.digest()); final BatchId id = new BatchId(messageDigest.digest());
ObjectReader<Message> messageReader = new TestMessageReader(); final UnverifiedBatchFactory batchFactory =
final BatchFactory batchFactory = context.mock(BatchFactory.class); context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader, BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory); batchFactory);
final Batch batch = context.mock(Batch.class); final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// Check that the batch ID matches the expected ID // Check that the batch ID matches the expected ID
oneOf(batchFactory).createBatch(with(id), oneOf(batchFactory).createUnverifiedBatch(with(id),
with(Collections.singletonList(message))); with(Collections.singletonList(message)));
will(returnValue(batch)); will(returnValue(batch));
}}); }});
@@ -114,20 +116,21 @@ public class BatchReaderTest extends TestCase {
Reader reader = readerFactory.createReader(in); Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Types.BATCH, batchReader); reader.addObjectReader(Types.BATCH, batchReader);
assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class)); assertEquals(batch, reader.readStruct(Types.BATCH,
UnverifiedBatch.class));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@Test @Test
public void testEmptyBatch() throws Exception { public void testEmptyBatch() throws Exception {
ObjectReader<Message> messageReader = new TestMessageReader(); final UnverifiedBatchFactory batchFactory =
final BatchFactory batchFactory = context.mock(BatchFactory.class); context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader, BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory); batchFactory);
final Batch batch = context.mock(Batch.class); final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(batchFactory).createBatch(with(any(BatchId.class)), oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)),
with(Collections.<Message>emptyList())); with(Collections.<UnverifiedMessage>emptyList()));
will(returnValue(batch)); will(returnValue(batch));
}}); }});
@@ -136,7 +139,8 @@ public class BatchReaderTest extends TestCase {
Reader reader = readerFactory.createReader(in); Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Types.BATCH, batchReader); reader.addObjectReader(Types.BATCH, batchReader);
assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class)); assertEquals(batch, reader.readStruct(Types.BATCH,
UnverifiedBatch.class));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -163,9 +167,9 @@ public class BatchReaderTest extends TestCase {
return out.toByteArray(); return out.toByteArray();
} }
private class TestMessageReader implements ObjectReader<Message> { private class TestMessageReader implements ObjectReader<UnverifiedMessage> {
public Message readObject(Reader r) throws IOException { public UnverifiedMessage readObject(Reader r) throws IOException {
r.readStructId(Types.MESSAGE); r.readStructId(Types.MESSAGE);
r.readBytes(); r.readBytes();
return message; return message;

View File

@@ -114,7 +114,7 @@ public class ProtocolReadWriteTest extends TestCase {
Ack ack = reader.readAck(); Ack ack = reader.readAck();
assertEquals(Collections.singletonList(batchId), ack.getBatchIds()); assertEquals(Collections.singletonList(batchId), ack.getBatchIds());
Batch batch = reader.readBatch(); Batch batch = reader.readBatch().verify();
assertEquals(Collections.singletonList(message), batch.getMessages()); assertEquals(Collections.singletonList(message), batch.getMessages());
Offer offer = reader.readOffer(); Offer offer = reader.readOffer();

View File

@@ -41,6 +41,7 @@ import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule; import net.sf.briar.db.DatabaseModule;
import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.lifecycle.LifecycleModule;
import net.sf.briar.plugins.ImmediateExecutor;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
@@ -187,7 +188,8 @@ public class BatchConnectionReadWriteTest extends TestCase {
bob.getInstance(ProtocolReaderFactory.class); bob.getInstance(ProtocolReaderFactory.class);
BatchTransportReader reader = new TestBatchTransportReader(in); BatchTransportReader reader = new TestBatchTransportReader(in);
IncomingBatchConnection batchIn = new IncomingBatchConnection( IncomingBatchConnection batchIn = new IncomingBatchConnection(
connFactory, db, protoFactory, ctx, reader, tag); new ImmediateExecutor(), connFactory, db, protoFactory, ctx,
reader, tag);
// No messages should have been added yet // No messages should have been added yet
assertFalse(listener.messagesAdded); assertFalse(listener.messagesAdded);
// Read whatever needs to be read // Read whatever needs to be read