Moved batch ID calculation off the IO thread.

This commit is contained in:
akwizgran
2011-12-08 12:51:34 +00:00
parent d91f96b5e2
commit ae87100c8f
14 changed files with 306 additions and 109 deletions

View File

@@ -1,5 +1,7 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@@ -10,7 +12,6 @@ import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
@@ -28,8 +29,7 @@ class AckReader implements ObjectReader<Ack> {
public Ack readObject(Reader r) throws IOException { public Ack readObject(Reader r) throws IOException {
// Initialise the consumer // Initialise the consumer
Consumer counting = Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
// Read the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.ACK); r.readStructId(Types.ACK);

View File

@@ -1,52 +1,41 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.CountingConsumer; import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.DigestingConsumer;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
class BatchReader implements ObjectReader<UnverifiedBatch> { class BatchReader implements ObjectReader<UnverifiedBatch> {
private final MessageDigest messageDigest;
private final ObjectReader<UnverifiedMessage> messageReader; private final ObjectReader<UnverifiedMessage> messageReader;
private final UnverifiedBatchFactory batchFactory; private final UnverifiedBatchFactory batchFactory;
BatchReader(CryptoComponent crypto, BatchReader(ObjectReader<UnverifiedMessage> messageReader,
ObjectReader<UnverifiedMessage> messageReader,
UnverifiedBatchFactory batchFactory) { UnverifiedBatchFactory batchFactory) {
messageDigest = crypto.getMessageDigest();
this.messageReader = messageReader; this.messageReader = messageReader;
this.batchFactory = batchFactory; this.batchFactory = batchFactory;
} }
public UnverifiedBatch readObject(Reader r) throws IOException { public UnverifiedBatch readObject(Reader r) throws IOException {
// Initialise the consumers // Initialise the consumer
Consumer counting = Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); // Read the data
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
// Read and digest the data
r.addConsumer(counting); r.addConsumer(counting);
r.addConsumer(digesting);
r.readStructId(Types.BATCH); r.readStructId(Types.BATCH);
r.addObjectReader(Types.MESSAGE, messageReader); r.addObjectReader(Types.MESSAGE, messageReader);
List<UnverifiedMessage> messages = r.readList(UnverifiedMessage.class); List<UnverifiedMessage> messages = r.readList(UnverifiedMessage.class);
r.removeObjectReader(Types.MESSAGE); r.removeObjectReader(Types.MESSAGE);
r.removeConsumer(digesting);
r.removeConsumer(counting); r.removeConsumer(counting);
if(messages.isEmpty()) throw new FormatException(); if(messages.isEmpty()) throw new FormatException();
// Build and return the batch // Build and return the batch
BatchId id = new BatchId(messageDigest.digest()); return batchFactory.createUnverifiedBatch( messages);
return batchFactory.createUnverifiedBatch(id, messages);
} }
} }

View File

@@ -1,12 +1,17 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_BODY_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SIGNATURE_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SUBJECT_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.SALT_LENGTH;
import java.io.IOException; import java.io.IOException;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.CopyingConsumer; import net.sf.briar.api.serial.CopyingConsumer;
@@ -27,8 +32,7 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
public UnverifiedMessage readObject(Reader r) throws IOException { public UnverifiedMessage readObject(Reader r) throws IOException {
CopyingConsumer copying = new CopyingConsumer(); CopyingConsumer copying = new CopyingConsumer();
CountingConsumer counting = CountingConsumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
r.addConsumer(copying); r.addConsumer(copying);
r.addConsumer(counting); r.addConsumer(counting);
// Read the initial tag // Read the initial tag
@@ -61,16 +65,15 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
r.removeObjectReader(Types.AUTHOR); r.removeObjectReader(Types.AUTHOR);
} }
// Read the subject // Read the subject
String subject = r.readString(ProtocolConstants.MAX_SUBJECT_LENGTH); String subject = r.readString(MAX_SUBJECT_LENGTH);
// Read the timestamp // Read the timestamp
long timestamp = r.readInt64(); long timestamp = r.readInt64();
if(timestamp < 0L) throw new FormatException(); if(timestamp < 0L) throw new FormatException();
// Read the salt // Read the salt
byte[] salt = r.readBytes(ProtocolConstants.SALT_LENGTH); byte[] salt = r.readBytes(SALT_LENGTH);
if(salt.length != ProtocolConstants.SALT_LENGTH) if(salt.length != SALT_LENGTH) throw new FormatException();
throw new FormatException();
// Read the message body // Read the message body
byte[] body = r.readBytes(ProtocolConstants.MAX_BODY_LENGTH); byte[] body = r.readBytes(MAX_BODY_LENGTH);
// Record the offset of the body within the message // Record the offset of the body within the message
int bodyStart = (int) counting.getCount() - body.length; int bodyStart = (int) counting.getCount() - body.length;
// Record the length of the data covered by the author's signature // Record the length of the data covered by the author's signature
@@ -78,13 +81,13 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
// Read the author's signature, if there is one // Read the author's signature, if there is one
byte[] authorSig = null; byte[] authorSig = null;
if(author == null) r.readNull(); if(author == null) r.readNull();
else authorSig = r.readBytes(ProtocolConstants.MAX_SIGNATURE_LENGTH); else authorSig = r.readBytes(MAX_SIGNATURE_LENGTH);
// Record the length of the data covered by the group's signature // Record the length of the data covered by the group's signature
int signedByGroup = (int) counting.getCount(); int signedByGroup = (int) counting.getCount();
// Read the group's signature, if there is one // Read the group's signature, if there is one
byte[] groupSig = null; byte[] groupSig = null;
if(group == null || group.getPublicKey() == null) r.readNull(); if(group == null || group.getPublicKey() == null) r.readNull();
else groupSig = r.readBytes(ProtocolConstants.MAX_SIGNATURE_LENGTH); else groupSig = r.readBytes(MAX_SIGNATURE_LENGTH);
// That's all, folks // That's all, folks
r.removeConsumer(counting); r.removeConsumer(counting);
r.removeConsumer(copying); r.removeConsumer(copying);

View File

@@ -1,5 +1,7 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@@ -10,7 +12,6 @@ import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
@@ -28,8 +29,7 @@ class OfferReader implements ObjectReader<Offer> {
public Offer readObject(Reader r) throws IOException { public Offer readObject(Reader r) throws IOException {
// Initialise the consumer // Initialise the consumer
Consumer counting = Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
// Read the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.OFFER); r.readStructId(Types.OFFER);

View File

@@ -51,10 +51,10 @@ public class ProtocolModule extends AbstractModule {
} }
@Provides @Provides
ObjectReader<UnverifiedBatch> getBatchReader(CryptoComponent crypto, ObjectReader<UnverifiedBatch> getBatchReader(
ObjectReader<UnverifiedMessage> messageReader, ObjectReader<UnverifiedMessage> messageReader,
UnverifiedBatchFactory batchFactory) { UnverifiedBatchFactory batchFactory) {
return new BatchReader(crypto, messageReader, batchFactory); return new BatchReader(messageReader, batchFactory);
} }
@Provides @Provides

View File

@@ -1,11 +1,12 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import java.io.IOException; import java.io.IOException;
import java.util.BitSet; import java.util.BitSet;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
@@ -23,14 +24,13 @@ class RequestReader implements ObjectReader<Request> {
public Request readObject(Reader r) throws IOException { public Request readObject(Reader r) throws IOException {
// Initialise the consumer // Initialise the consumer
Consumer counting = Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
// Read the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.REQUEST); r.readStructId(Types.REQUEST);
int padding = r.readUint7(); int padding = r.readUint7();
if(padding > 7) throw new FormatException(); if(padding > 7) throw new FormatException();
byte[] bitmap = r.readBytes(ProtocolConstants.MAX_PACKET_LENGTH); byte[] bitmap = r.readBytes(MAX_PACKET_LENGTH);
r.removeConsumer(counting); r.removeConsumer(counting);
// Convert the bitmap into a BitSet // Convert the bitmap into a BitSet
int length = bitmap.length * 8 - padding; int length = bitmap.length * 8 - padding;

View File

@@ -1,12 +1,13 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import java.io.IOException; import java.io.IOException;
import java.util.Map; import java.util.Map;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
@@ -27,8 +28,7 @@ class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> {
public SubscriptionUpdate readObject(Reader r) throws IOException { public SubscriptionUpdate readObject(Reader r) throws IOException {
// Initialise the consumer // Initialise the consumer
Consumer counting = Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
// Read the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.SUBSCRIPTION_UPDATE); r.readStructId(Types.SUBSCRIPTION_UPDATE);

View File

@@ -1,5 +1,10 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PROPERTIES_PER_TRANSPORT;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PROPERTY_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_TRANSPORTS;
import java.io.IOException; import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
@@ -8,7 +13,6 @@ import java.util.Set;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
@@ -32,16 +36,14 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> {
public TransportUpdate readObject(Reader r) throws IOException { public TransportUpdate readObject(Reader r) throws IOException {
// Initialise the consumer // Initialise the consumer
Consumer counting = Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
// Read the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.TRANSPORT_UPDATE); r.readStructId(Types.TRANSPORT_UPDATE);
r.addObjectReader(Types.TRANSPORT, transportReader); r.addObjectReader(Types.TRANSPORT, transportReader);
Collection<Transport> transports = r.readList(Transport.class); Collection<Transport> transports = r.readList(Transport.class);
r.removeObjectReader(Types.TRANSPORT); r.removeObjectReader(Types.TRANSPORT);
if(transports.size() > ProtocolConstants.MAX_TRANSPORTS) if(transports.size() > MAX_TRANSPORTS) throw new FormatException();
throw new FormatException();
long timestamp = r.readInt64(); long timestamp = r.readInt64();
r.removeConsumer(counting); r.removeConsumer(counting);
// Check for duplicate IDs or indices // Check for duplicate IDs or indices
@@ -65,14 +67,13 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> {
TransportId id = new TransportId(b); TransportId id = new TransportId(b);
// Read the index // Read the index
int i = r.readInt32(); int i = r.readInt32();
if(i < 0 || i >= ProtocolConstants.MAX_TRANSPORTS) if(i < 0 || i >= MAX_TRANSPORTS) throw new FormatException();
throw new FormatException();
TransportIndex index = new TransportIndex(i); TransportIndex index = new TransportIndex(i);
// Read the properties // Read the properties
r.setMaxStringLength(ProtocolConstants.MAX_PROPERTY_LENGTH); r.setMaxStringLength(MAX_PROPERTY_LENGTH);
Map<String, String> m = r.readMap(String.class, String.class); Map<String, String> m = r.readMap(String.class, String.class);
r.resetMaxStringLength(); r.resetMaxStringLength();
if(m.size() > ProtocolConstants.MAX_PROPERTIES_PER_TRANSPORT) if(m.size() > MAX_PROPERTIES_PER_TRANSPORT)
throw new FormatException(); throw new FormatException();
return new Transport(id, index, m); return new Transport(id, index, m);
} }

View File

@@ -2,11 +2,10 @@ package net.sf.briar.protocol;
import java.util.Collection; import java.util.Collection;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.UnverifiedBatch;
interface UnverifiedBatchFactory { interface UnverifiedBatchFactory {
UnverifiedBatch createUnverifiedBatch(BatchId id, UnverifiedBatch createUnverifiedBatch(
Collection<UnverifiedMessage> messages); Collection<UnverifiedMessage> messages);
} }

View File

@@ -3,7 +3,6 @@ package net.sf.briar.protocol;
import java.util.Collection; import java.util.Collection;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.UnverifiedBatch;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -17,8 +16,8 @@ class UnverifiedBatchFactoryImpl implements UnverifiedBatchFactory {
this.crypto = crypto; this.crypto = crypto;
} }
public UnverifiedBatch createUnverifiedBatch(BatchId id, public UnverifiedBatch createUnverifiedBatch(
Collection<UnverifiedMessage> messages) { Collection<UnverifiedMessage> messages) {
return new UnverifiedBatchImpl(crypto, id, messages); return new UnverifiedBatchImpl(crypto, messages);
} }
} }

View File

@@ -24,32 +24,34 @@ import net.sf.briar.api.protocol.UnverifiedBatch;
class UnverifiedBatchImpl implements UnverifiedBatch { class UnverifiedBatchImpl implements UnverifiedBatch {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final BatchId id;
private final Collection<UnverifiedMessage> messages; private final Collection<UnverifiedMessage> messages;
private final MessageDigest batchDigest, messageDigest;
// Initialise lazily - the batch may be empty or contain unsigned messages // Initialise lazily - the batch may contain unsigned messages
private MessageDigest messageDigest = null;
private KeyParser keyParser = null; private KeyParser keyParser = null;
private Signature signature = null; private Signature signature = null;
UnverifiedBatchImpl(CryptoComponent crypto, BatchId id, UnverifiedBatchImpl(CryptoComponent crypto,
Collection<UnverifiedMessage> messages) { Collection<UnverifiedMessage> messages) {
this.crypto = crypto; this.crypto = crypto;
this.id = id;
this.messages = messages; this.messages = messages;
batchDigest = crypto.getMessageDigest();
messageDigest = crypto.getMessageDigest();
} }
public Batch verify() throws GeneralSecurityException { public Batch verify() throws GeneralSecurityException {
List<Message> verified = new ArrayList<Message>(); List<Message> verified = new ArrayList<Message>();
for(UnverifiedMessage m : messages) verified.add(verify(m)); for(UnverifiedMessage m : messages) verified.add(verify(m));
BatchId id = new BatchId(batchDigest.digest());
return new BatchImpl(id, Collections.unmodifiableList(verified)); return new BatchImpl(id, Collections.unmodifiableList(verified));
} }
private Message verify(UnverifiedMessage m) private Message verify(UnverifiedMessage m)
throws GeneralSecurityException { throws GeneralSecurityException {
// Hash the message, including the signatures, to get the message ID // The batch ID is the hash of the concatenated messages
byte[] raw = m.getRaw(); byte[] raw = m.getRaw();
if(messageDigest == null) messageDigest = crypto.getMessageDigest(); batchDigest.update(raw);
// Hash the message, including the signatures, to get the message ID
messageDigest.update(raw); messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest()); MessageId id = new MessageId(messageDigest.digest());
// Verify the author's signature, if there is one // Verify the author's signature, if there is one

View File

@@ -43,6 +43,7 @@
<test name='net.sf.briar.protocol.ProtocolReadWriteTest'/> <test name='net.sf.briar.protocol.ProtocolReadWriteTest'/>
<test name='net.sf.briar.protocol.ProtocolWriterImplTest'/> <test name='net.sf.briar.protocol.ProtocolWriterImplTest'/>
<test name='net.sf.briar.protocol.RequestReaderTest'/> <test name='net.sf.briar.protocol.RequestReaderTest'/>
<test name='net.sf.briar.protocol.UnverifiedBatchImplTest'/>
<test name='net.sf.briar.serial.ReaderImplTest'/> <test name='net.sf.briar.serial.ReaderImplTest'/>
<test name='net.sf.briar.serial.WriterImplTest'/> <test name='net.sf.briar.serial.WriterImplTest'/>
<test name='net.sf.briar.setup.SetupWorkerTest'/> <test name='net.sf.briar.setup.SetupWorkerTest'/>

View File

@@ -7,9 +7,6 @@ import java.util.Collections;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.UnverifiedBatch;
@@ -18,7 +15,6 @@ import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory; import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import org.jmock.Expectations; import org.jmock.Expectations;
@@ -32,18 +28,15 @@ public class BatchReaderTest extends TestCase {
private final ReaderFactory readerFactory; private final ReaderFactory readerFactory;
private final WriterFactory writerFactory; private final WriterFactory writerFactory;
private final CryptoComponent crypto;
private final Mockery context; private final Mockery context;
private final UnverifiedMessage message; private final UnverifiedMessage message;
private final ObjectReader<UnverifiedMessage> messageReader; private final ObjectReader<UnverifiedMessage> messageReader;
public BatchReaderTest() throws Exception { public BatchReaderTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new SerialModule(), Injector i = Guice.createInjector(new SerialModule());
new CryptoModule());
readerFactory = i.getInstance(ReaderFactory.class); readerFactory = i.getInstance(ReaderFactory.class);
writerFactory = i.getInstance(WriterFactory.class); writerFactory = i.getInstance(WriterFactory.class);
crypto = i.getInstance(CryptoComponent.class);
context = new Mockery(); context = new Mockery();
message = context.mock(UnverifiedMessage.class); message = context.mock(UnverifiedMessage.class);
messageReader = new TestMessageReader(); messageReader = new TestMessageReader();
@@ -53,8 +46,7 @@ public class BatchReaderTest extends TestCase {
public void testFormatExceptionIfBatchIsTooLarge() throws Exception { public void testFormatExceptionIfBatchIsTooLarge() throws Exception {
UnverifiedBatchFactory batchFactory = UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class); context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader, BatchReader batchReader = new BatchReader(messageReader, batchFactory);
batchFactory);
byte[] b = createBatch(ProtocolConstants.MAX_PACKET_LENGTH + 1); byte[] b = createBatch(ProtocolConstants.MAX_PACKET_LENGTH + 1);
ByteArrayInputStream in = new ByteArrayInputStream(b); ByteArrayInputStream in = new ByteArrayInputStream(b);
@@ -72,12 +64,11 @@ public class BatchReaderTest extends TestCase {
public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception { public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception {
final UnverifiedBatchFactory batchFactory = final UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class); context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader, BatchReader batchReader = new BatchReader(messageReader, batchFactory);
batchFactory);
final UnverifiedBatch batch = context.mock(UnverifiedBatch.class); final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)), oneOf(batchFactory).createUnverifiedBatch(
with(Collections.singletonList(message))); Collections.singletonList(message));
will(returnValue(batch)); will(returnValue(batch));
}}); }});
@@ -91,41 +82,11 @@ public class BatchReaderTest extends TestCase {
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@Test
public void testBatchId() throws Exception {
byte[] b = createBatch(ProtocolConstants.MAX_PACKET_LENGTH);
// Calculate the expected batch ID
MessageDigest messageDigest = crypto.getMessageDigest();
messageDigest.update(b);
final BatchId id = new BatchId(messageDigest.digest());
final UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader,
batchFactory);
final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
context.checking(new Expectations() {{
// Check that the batch ID matches the expected ID
oneOf(batchFactory).createUnverifiedBatch(with(id),
with(Collections.singletonList(message)));
will(returnValue(batch));
}});
ByteArrayInputStream in = new ByteArrayInputStream(b);
Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Types.BATCH, batchReader);
assertEquals(batch, reader.readStruct(Types.BATCH,
UnverifiedBatch.class));
context.assertIsSatisfied();
}
@Test @Test
public void testEmptyBatch() throws Exception { public void testEmptyBatch() throws Exception {
final UnverifiedBatchFactory batchFactory = final UnverifiedBatchFactory batchFactory =
context.mock(UnverifiedBatchFactory.class); context.mock(UnverifiedBatchFactory.class);
BatchReader batchReader = new BatchReader(crypto, messageReader, BatchReader batchReader = new BatchReader(messageReader, batchFactory);
batchFactory);
byte[] b = createEmptyBatch(); byte[] b = createEmptyBatch();
ByteArrayInputStream in = new ByteArrayInputStream(b); ByteArrayInputStream in = new ByteArrayInputStream(b);

View File

@@ -0,0 +1,242 @@
package net.sf.briar.protocol;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.Signature;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Random;
import junit.framework.TestCase;
import net.sf.briar.TestUtils;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.crypto.CryptoModule;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class UnverifiedBatchImplTest extends TestCase {
private final CryptoComponent crypto;
private final byte[] raw, raw1;
private final String subject;
private final long timestamp;
public UnverifiedBatchImplTest() {
super();
Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class);
Random r = new Random();
raw = new byte[123];
r.nextBytes(raw);
raw1 = new byte[1234];
r.nextBytes(raw1);
subject = "Unit tests are exciting";
timestamp = System.currentTimeMillis();
}
@Test
public void testIds() throws Exception {
// Calculate the expected batch and message IDs
MessageDigest messageDigest = crypto.getMessageDigest();
messageDigest.update(raw);
messageDigest.update(raw1);
BatchId batchId = new BatchId(messageDigest.digest());
messageDigest.update(raw);
MessageId messageId = new MessageId(messageDigest.digest());
messageDigest.update(raw1);
MessageId messageId1 = new MessageId(messageDigest.digest());
// Verify the batch
Mockery context = new Mockery();
final UnverifiedMessage message =
context.mock(UnverifiedMessage.class, "message");
final UnverifiedMessage message1 =
context.mock(UnverifiedMessage.class, "message1");
context.checking(new Expectations() {{
// First message
oneOf(message).getRaw();
will(returnValue(raw));
oneOf(message).getAuthor();
will(returnValue(null));
oneOf(message).getGroup();
will(returnValue(null));
oneOf(message).getParent();
will(returnValue(null));
oneOf(message).getSubject();
will(returnValue(subject));
oneOf(message).getTimestamp();
will(returnValue(timestamp));
oneOf(message).getBodyStart();
will(returnValue(10));
oneOf(message).getBodyLength();
will(returnValue(100));
// Second message
oneOf(message1).getRaw();
will(returnValue(raw1));
oneOf(message1).getAuthor();
will(returnValue(null));
oneOf(message1).getGroup();
will(returnValue(null));
oneOf(message1).getParent();
will(returnValue(null));
oneOf(message1).getSubject();
will(returnValue(subject));
oneOf(message1).getTimestamp();
will(returnValue(timestamp));
oneOf(message1).getBodyStart();
will(returnValue(10));
oneOf(message1).getBodyLength();
will(returnValue(1000));
}});
Collection<UnverifiedMessage> messages =
Arrays.asList(new UnverifiedMessage[] {message, message1});
UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages);
Batch verifiedBatch = batch.verify();
// Check that the batch and message IDs match
assertEquals(batchId, verifiedBatch.getId());
Collection<Message> verifiedMessages = verifiedBatch.getMessages();
assertEquals(2, verifiedMessages.size());
Iterator<Message> it = verifiedMessages.iterator();
Message verifiedMessage = it.next();
assertEquals(messageId, verifiedMessage.getId());
Message verifiedMessage1 = it.next();
assertEquals(messageId1, verifiedMessage1.getId());
context.assertIsSatisfied();
}
@Test
public void testSignatures() throws Exception {
final int signedByAuthor = 100, signedByGroup = 110;
final KeyPair authorKeyPair = crypto.generateKeyPair();
final KeyPair groupKeyPair = crypto.generateKeyPair();
Signature signature = crypto.getSignature();
// Calculate the expected author and group signatures
signature.initSign(authorKeyPair.getPrivate());
signature.update(raw, 0, signedByAuthor);
final byte[] authorSignature = signature.sign();
signature.initSign(groupKeyPair.getPrivate());
signature.update(raw, 0, signedByGroup);
final byte[] groupSignature = signature.sign();
// Verify the batch
Mockery context = new Mockery();
final UnverifiedMessage message =
context.mock(UnverifiedMessage.class, "message");
final Author author = context.mock(Author.class);
final Group group = context.mock(Group.class);
final UnverifiedMessage message1 =
context.mock(UnverifiedMessage.class, "message1");
context.checking(new Expectations() {{
// First message
oneOf(message).getRaw();
will(returnValue(raw));
oneOf(message).getAuthor();
will(returnValue(author));
oneOf(author).getPublicKey();
will(returnValue(authorKeyPair.getPublic().getEncoded()));
oneOf(message).getLengthSignedByAuthor();
will(returnValue(signedByAuthor));
oneOf(message).getAuthorSignature();
will(returnValue(authorSignature));
oneOf(message).getGroup();
will(returnValue(group));
exactly(2).of(group).getPublicKey();
will(returnValue(groupKeyPair.getPublic().getEncoded()));
oneOf(message).getLengthSignedByGroup();
will(returnValue(signedByGroup));
oneOf(message).getGroupSignature();
will(returnValue(groupSignature));
oneOf(author).getId();
will(returnValue(new AuthorId(TestUtils.getRandomId())));
oneOf(group).getId();
will(returnValue(new GroupId(TestUtils.getRandomId())));
oneOf(message).getParent();
will(returnValue(null));
oneOf(message).getSubject();
will(returnValue(subject));
oneOf(message).getTimestamp();
will(returnValue(timestamp));
oneOf(message).getBodyStart();
will(returnValue(10));
oneOf(message).getBodyLength();
will(returnValue(100));
// Second message
oneOf(message1).getRaw();
will(returnValue(raw1));
oneOf(message1).getAuthor();
will(returnValue(null));
oneOf(message1).getGroup();
will(returnValue(null));
oneOf(message1).getParent();
will(returnValue(null));
oneOf(message1).getSubject();
will(returnValue(subject));
oneOf(message1).getTimestamp();
will(returnValue(timestamp));
oneOf(message1).getBodyStart();
will(returnValue(10));
oneOf(message1).getBodyLength();
will(returnValue(1000));
}});
Collection<UnverifiedMessage> messages =
Arrays.asList(new UnverifiedMessage[] {message, message1});
UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages);
batch.verify();
context.assertIsSatisfied();
}
@Test
public void testExceptionThrownIfMessageIsModified() throws Exception {
final int signedByAuthor = 100;
final KeyPair authorKeyPair = crypto.generateKeyPair();
Signature signature = crypto.getSignature();
// Calculate the expected author signature
signature.initSign(authorKeyPair.getPrivate());
signature.update(raw, 0, signedByAuthor);
final byte[] authorSignature = signature.sign();
// Modify the message
raw[signedByAuthor / 2] ^= 0xff;
// Verify the batch
Mockery context = new Mockery();
final UnverifiedMessage message =
context.mock(UnverifiedMessage.class, "message");
final Author author = context.mock(Author.class);
final UnverifiedMessage message1 =
context.mock(UnverifiedMessage.class, "message1");
context.checking(new Expectations() {{
// First message - verification will fail at the author's signature
oneOf(message).getRaw();
will(returnValue(raw));
oneOf(message).getAuthor();
will(returnValue(author));
oneOf(author).getPublicKey();
will(returnValue(authorKeyPair.getPublic().getEncoded()));
oneOf(message).getLengthSignedByAuthor();
will(returnValue(signedByAuthor));
oneOf(message).getAuthorSignature();
will(returnValue(authorSignature));
}});
Collection<UnverifiedMessage> messages =
Arrays.asList(new UnverifiedMessage[] {message, message1});
UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages);
try {
batch.verify();
fail();
} catch(GeneralSecurityException expected) {}
context.assertIsSatisfied();
}
}