Readers, writers and factories for subscription and transport updates.

This commit is contained in:
akwizgran
2011-07-23 21:46:47 +01:00
parent 30271c14ce
commit 941460e3bc
34 changed files with 423 additions and 53 deletions

View File

@@ -2,8 +2,10 @@ package net.sf.briar.api.protocol;
import java.security.PublicKey;
import net.sf.briar.api.serial.Writable;
/** A group to which users may subscribe. */
public interface Group {
public interface Group extends Writable {
/** Returns the group's unique identifier. */
GroupId getId();

View File

@@ -2,5 +2,6 @@ package net.sf.briar.api.protocol;
public interface GroupFactory {
Group createGroup(GroupId id, String name, boolean restricted, byte[] b);
Group createGroup(GroupId id, String name, boolean restricted,
byte[] saltOrKey);
}

View File

@@ -5,6 +5,12 @@ import java.util.Collection;
/** A packet updating the sender's subscriptions. */
public interface Subscriptions {
/**
* The maximum size of a serialized subscriptions update, excluding
* encryption and authentication.
*/
static final int MAX_SIZE = (1024 * 1024) - 100;
/** Returns the subscriptions contained in the update. */
Collection<Group> getSubscriptions();

View File

@@ -11,9 +11,10 @@ public interface Tags {
static final int AUTHOR_ID = 1;
static final int BATCH = 2;
static final int BATCH_ID = 3;
static final int GROUP_ID = 4;
static final int MESSAGE = 5;
static final int MESSAGE_ID = 6;
static final int SUBSCRIPTIONS = 7;
static final int TRANSPORTS = 8;
static final int GROUP = 4;
static final int GROUP_ID = 5;
static final int MESSAGE = 6;
static final int MESSAGE_ID = 7;
static final int SUBSCRIPTIONS = 8;
static final int TRANSPORTS = 9;
}

View File

@@ -5,6 +5,12 @@ import java.util.Map;
/** A packet updating the sender's transports. */
public interface Transports {
/**
* The maximum size of a serialised transports update, excluding
* encryption and authentication.
*/
static final int MAX_SIZE = (1024 * 1024) - 100;
/** Returns the transports contained in the update. */
Map<String, String> getTransports();

View File

@@ -11,7 +11,7 @@ public interface AckWriter {
* Attempts to add the given BatchId to the ack and returns true if it
* was added.
*/
boolean addBatchId(BatchId b) throws IOException;
boolean writeBatchId(BatchId b) throws IOException;
/** Finishes writing the ack. */
void finish() throws IOException;

View File

@@ -14,7 +14,7 @@ public interface BatchWriter {
* Attempts to add the given raw message to the batch and returns true if
* it was added.
*/
boolean addMessage(byte[] raw) throws IOException;
boolean writeMessage(byte[] raw) throws IOException;
/** Finishes writing the batch and returns its unique identifier. */
BatchId finish() throws IOException;

View File

@@ -1,12 +1,13 @@
package net.sf.briar.api.protocol.writers;
import java.io.IOException;
import java.util.Collection;
import net.sf.briar.api.protocol.Group;
/** An interface for creating a subscription update. */
public interface SubscriptionWriter {
/** Sets the contents of the update. */
void setSubscriptions(Iterable<Group> subs) throws IOException;
/** Writes the contents of the update. */
void writeSubscriptions(Collection<Group> subs) throws IOException;
}

View File

@@ -6,6 +6,6 @@ import java.util.Map;
/** An interface for creating a transports update. */
public interface TransportWriter {
/** Sets the contents of the update. */
void setTransports(Map<String, String> transports) throws IOException;
/** Writes the contents of the update. */
void writeTransports(Map<String, String> transports) throws IOException;
}

View File

@@ -259,7 +259,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
try {
Collection<BatchId> acks = db.getBatchesToAck(txn, c);
Collection<BatchId> sent = new ArrayList<BatchId>();
for(BatchId b : acks) if(a.addBatchId(b)) sent.add(b);
for(BatchId b : acks) if(a.writeBatchId(b)) sent.add(b);
a.finish();
db.removeBatchesToAck(txn, c, sent);
if(LOG.isLoggable(Level.FINE))
@@ -300,7 +300,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
while(it.hasNext()) {
MessageId m = it.next();
byte[] message = db.getMessage(txn, m);
if(!b.addMessage(message)) break;
if(!b.writeMessage(message)) break;
bytesSent += message.length;
sent.add(m);
}
@@ -349,7 +349,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
Txn txn = db.startTransaction();
try {
Collection<Group> subs = db.getSubscriptions(txn);
s.setSubscriptions(subs);
s.writeSubscriptions(subs);
if(LOG.isLoggable(Level.FINE))
LOG.fine("Added " + subs.size() + " subscriptions");
db.commitTransaction(txn);
@@ -378,7 +378,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
Txn txn = db.startTransaction();
try {
Map<String, String> transports = db.getTransports(txn);
t.setTransports(transports);
t.writeTransports(transports);
if(LOG.isLoggable(Level.FINE))
LOG.fine("Added " + transports.size() + " transports");
db.commitTransaction(txn);

View File

@@ -192,7 +192,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
try {
Collection<BatchId> acks = db.getBatchesToAck(txn, c);
Collection<BatchId> sent = new ArrayList<BatchId>();
for(BatchId b : acks) if(a.addBatchId(b)) sent.add(b);
for(BatchId b : acks) if(a.writeBatchId(b)) sent.add(b);
a.finish();
db.removeBatchesToAck(txn, c, sent);
if(LOG.isLoggable(Level.FINE))
@@ -225,7 +225,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
while(it.hasNext()) {
MessageId m = it.next();
byte[] message = db.getMessage(txn, m);
if(!b.addMessage(message)) break;
if(!b.writeMessage(message)) break;
bytesSent += message.length;
sent.add(m);
}
@@ -254,7 +254,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
Txn txn = db.startTransaction();
try {
Collection<Group> subs = db.getSubscriptions(txn);
s.setSubscriptions(subs);
s.writeSubscriptions(subs);
if(LOG.isLoggable(Level.FINE))
LOG.fine("Added " + subs.size() + " subscriptions");
db.commitTransaction(txn);
@@ -277,7 +277,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
Txn txn = db.startTransaction();
try {
Map<String, String> transports = db.getTransports(txn);
t.setTransports(transports);
t.writeTransports(transports);
if(LOG.isLoggable(Level.FINE))
LOG.fine("Added " + transports.size() + " transports");
db.commitTransaction(txn);

View File

@@ -11,9 +11,11 @@ import net.sf.briar.api.serial.Reader;
class AckReader implements ObjectReader<Ack> {
private final ObjectReader<BatchId> batchIdReader;
private final AckFactory ackFactory;
AckReader(AckFactory ackFactory) {
AckReader(ObjectReader<BatchId> batchIdReader, AckFactory ackFactory) {
this.batchIdReader = batchIdReader;
this.ackFactory = ackFactory;
}
@@ -23,7 +25,7 @@ class AckReader implements ObjectReader<Ack> {
// Read and digest the data
r.addConsumer(counting);
r.readUserDefinedTag(Tags.ACK);
r.addObjectReader(Tags.BATCH_ID, new BatchIdReader());
r.addObjectReader(Tags.BATCH_ID, batchIdReader);
Collection<BatchId> batches = r.readList(BatchId.class);
r.removeObjectReader(Tags.BATCH_ID);
r.removeConsumer(counting);

View File

@@ -20,14 +20,14 @@ class GroupFactoryImpl implements GroupFactory {
}
public Group createGroup(GroupId id, String name, boolean restricted,
byte[] b) {
byte[] saltOrKey) {
if(restricted) {
try {
PublicKey key = keyParser.parsePublicKey(b);
return new GroupImpl(id, name, null, key);
} catch (InvalidKeySpecException e) {
PublicKey key = keyParser.parsePublicKey(saltOrKey);
return new GroupImpl(id, name, key);
} catch(InvalidKeySpecException e) {
throw new IllegalArgumentException(e);
}
} else return new GroupImpl(id, name, b, null);
} else return new GroupImpl(id, name, saltOrKey);
}
}

View File

@@ -1,9 +1,12 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.PublicKey;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.Writer;
class GroupImpl implements Group {
@@ -12,12 +15,18 @@ class GroupImpl implements Group {
private final byte[] salt;
private final PublicKey publicKey;
GroupImpl(GroupId id, String name, byte[] salt, PublicKey publicKey) {
assert salt == null || publicKey == null;
GroupImpl(GroupId id, String name, byte[] salt) {
this.id = id;
this.name = name;
this.salt = salt;
publicKey = null;
}
GroupImpl(GroupId id, String name, PublicKey publicKey) {
this.id = id;
this.name = name;
this.publicKey = publicKey;
salt = null;
}
public GroupId getId() {
@@ -40,6 +49,14 @@ class GroupImpl implements Group {
return publicKey;
}
public void writeTo(Writer w) throws IOException {
w.writeUserDefinedTag(Tags.GROUP);
w.writeString(name);
w.writeBoolean(isRestricted());
if(salt == null) w.writeRaw(publicKey.getEncoded());
else w.writeRaw(salt);
}
@Override
public boolean equals(Object o) {
return o instanceof Group && id.equals(((Group) o).getId());

View File

@@ -0,0 +1,38 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.MessageDigest;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class GroupReader implements ObjectReader<Group> {
private final MessageDigest messageDigest;
private final GroupFactory groupFactory;
GroupReader(MessageDigest messageDigest, GroupFactory groupFactory) {
this.messageDigest = messageDigest;
this.groupFactory = groupFactory;
}
public Group readObject(Reader r) throws IOException {
// Initialise the consumer
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
messageDigest.reset();
// Read and digest the data
r.addConsumer(digesting);
r.readUserDefinedTag(Tags.GROUP);
String name = r.readString();
boolean restricted = r.readBoolean();
byte[] saltOrKey = r.readRaw();
r.removeConsumer(digesting);
// Build and return the group
GroupId id = new GroupId(messageDigest.digest());
return groupFactory.createGroup(id, name, restricted, saltOrKey);
}
}

View File

@@ -16,12 +16,15 @@ import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
import com.google.inject.Inject;
class MessageEncoderImpl implements MessageEncoder {
private final Signature signature;
private final MessageDigest messageDigest;
private final WriterFactory writerFactory;
@Inject
MessageEncoderImpl(Signature signature, MessageDigest messageDigest,
WriterFactory writerFactory) {
this.signature = signature;

View File

@@ -1,6 +1,7 @@
package net.sf.briar.protocol;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.MessageEncoder;
import com.google.inject.AbstractModule;
@@ -11,5 +12,6 @@ public class ProtocolModule extends AbstractModule {
bind(AckFactory.class).to(AckFactoryImpl.class);
bind(BatchFactory.class).to(BatchFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class);
bind(MessageEncoder.class).to(MessageEncoderImpl.class);
}
}

View File

@@ -0,0 +1,11 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.Subscriptions;
interface SubscriptionFactory {
Subscriptions createSubscriptions(Collection<Group> subs, long timestamp);
}

View File

@@ -0,0 +1,14 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.Subscriptions;
class SubscriptionFactoryImpl implements SubscriptionFactory {
public Subscriptions createSubscriptions(Collection<Group> subs,
long timestamp) {
return new SubscriptionsImpl(subs, timestamp);
}
}

View File

@@ -0,0 +1,38 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.util.Collection;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.Subscriptions;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class SubscriptionReader implements ObjectReader<Subscriptions> {
private final ObjectReader<Group> groupReader;
private final SubscriptionFactory subscriptionFactory;
SubscriptionReader(ObjectReader<Group> groupReader,
SubscriptionFactory subscriptionFactory) {
this.groupReader = groupReader;
this.subscriptionFactory = subscriptionFactory;
}
public Subscriptions readObject(Reader r) throws IOException {
// Initialise the consumer
CountingConsumer counting =
new CountingConsumer(Subscriptions.MAX_SIZE);
// Read the data
r.addConsumer(counting);
r.readUserDefinedTag(Tags.SUBSCRIPTIONS);
r.addObjectReader(Tags.GROUP, groupReader);
Collection<Group> subs = r.readList(Group.class);
r.removeObjectReader(Tags.GROUP);
long timestamp = r.readInt64();
r.removeConsumer(counting);
// Build and return the subscriptions update
return subscriptionFactory.createSubscriptions(subs, timestamp);
}
}

View File

@@ -0,0 +1,25 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.Subscriptions;
class SubscriptionsImpl implements Subscriptions {
private final Collection<Group> subs;
private final long timestamp;
SubscriptionsImpl(Collection<Group> subs, long timestamp) {
this.subs = subs;
this.timestamp = timestamp;
}
public Collection<Group> getSubscriptions() {
return subs;
}
public long getTimestamp() {
return timestamp;
}
}

View File

@@ -0,0 +1,10 @@
package net.sf.briar.protocol;
import java.util.Map;
import net.sf.briar.api.protocol.Transports;
interface TransportFactory {
Transports createTransports(Map<String, String> transports, long timestamp);
}

View File

@@ -0,0 +1,13 @@
package net.sf.briar.protocol;
import java.util.Map;
import net.sf.briar.api.protocol.Transports;
class TransportFactoryImpl implements TransportFactory {
public Transports createTransports(Map<String, String> transports,
long timestamp) {
return new TransportsImpl(transports, timestamp);
}
}

View File

@@ -0,0 +1,31 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.util.Map;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.Transports;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class TransportReader implements ObjectReader<Transports> {
private final TransportFactory transportFactory;
TransportReader(TransportFactory transportFactory) {
this.transportFactory = transportFactory;
}
public Transports readObject(Reader r) throws IOException {
// Initialise the consumer
CountingConsumer counting = new CountingConsumer(Transports.MAX_SIZE);
// Read the data
r.addConsumer(counting);
r.readUserDefinedTag(Tags.TRANSPORTS);
Map<String, String> transports = r.readMap(String.class, String.class);
long timestamp = r.readInt64();
r.removeConsumer(counting);
// Build and return the transports update
return transportFactory.createTransports(transports, timestamp);
}
}

View File

@@ -0,0 +1,24 @@
package net.sf.briar.protocol;
import java.util.Map;
import net.sf.briar.api.protocol.Transports;
class TransportsImpl implements Transports {
private final Map<String, String> transports;
private final long timestamp;
TransportsImpl(Map<String, String> transports, long timestamp) {
this.transports = transports;
this.timestamp = timestamp;
}
public Map<String, String> getTransports() {
return transports;
}
public long getTimestamp() {
return timestamp;
}
}

View File

@@ -19,10 +19,10 @@ class AckWriterImpl implements AckWriter {
AckWriterImpl(OutputStream out, WriterFactory writerFactory) {
this.out = out;
this.w = writerFactory.createWriter(out);
w = writerFactory.createWriter(out);
}
public boolean addBatchId(BatchId b) throws IOException {
public boolean writeBatchId(BatchId b) throws IOException {
if(finished) throw new IllegalStateException();
if(!started) {
w.writeUserDefinedTag(Tags.ACK);

View File

@@ -31,7 +31,7 @@ class BatchWriterImpl implements BatchWriter {
return Batch.MAX_SIZE - 3;
}
public boolean addMessage(byte[] message) throws IOException {
public boolean writeMessage(byte[] message) throws IOException {
if(finished) throw new IllegalStateException();
if(!started) {
messageDigest.reset();

View File

@@ -33,12 +33,10 @@ class PacketWriterFactoryImpl implements PacketWriterFactory {
}
public SubscriptionWriter createSubscriptionWriter(OutputStream out) {
// TODO Auto-generated method stub
return null;
return new SubscriptionWriterImpl(out, writerFactory);
}
public TransportWriter createTransportWriter(OutputStream out) {
// TODO Auto-generated method stub
return null;
return new TransportWriterImpl(out, writerFactory);
}
}

View File

@@ -0,0 +1,33 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.writers.SubscriptionWriter;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class SubscriptionWriterImpl implements SubscriptionWriter {
private final OutputStream out;
private final Writer w;
private boolean used = false;
SubscriptionWriterImpl(OutputStream out, WriterFactory writerFactory) {
this.out = out;
w = writerFactory.createWriter(out);
}
public void writeSubscriptions(Collection<Group> subs) throws IOException {
if(used) throw new IllegalStateException();
w.writeUserDefinedTag(Tags.SUBSCRIPTIONS);
w.writeList(subs);
w.writeInt64(System.currentTimeMillis());
out.flush();
used = true;
}
}

View File

@@ -0,0 +1,33 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Map;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.writers.TransportWriter;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class TransportWriterImpl implements TransportWriter {
private final OutputStream out;
private final Writer w;
private boolean used = false;
TransportWriterImpl(OutputStream out, WriterFactory writerFactory) {
this.out = out;
w = writerFactory.createWriter(out);
}
public void writeTransports(Map<String, String> transports)
throws IOException {
if(used) throw new IllegalStateException();
w.writeUserDefinedTag(Tags.TRANSPORTS);
w.writeMap(transports);
w.writeInt64(System.currentTimeMillis());
out.flush();
used = true;
}
}

View File

@@ -464,9 +464,9 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).getBatchesToAck(txn, contactId);
will(returnValue(twoAcks));
// Try to add both batches to the writer - only manage to add one
oneOf(ackWriter).addBatchId(batchId);
oneOf(ackWriter).writeBatchId(batchId);
will(returnValue(true));
oneOf(ackWriter).addBatchId(batchId1);
oneOf(ackWriter).writeBatchId(batchId1);
will(returnValue(false));
oneOf(ackWriter).finish();
// Record the batch that was acked

View File

@@ -28,6 +28,7 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.serial.SerialModule;
import org.apache.commons.io.FileSystemUtils;
import org.junit.After;
@@ -62,8 +63,8 @@ public class H2DatabaseTest extends TestCase {
public H2DatabaseTest() throws Exception {
super();
Injector i = Guice.createInjector(new ProtocolModule(),
new CryptoModule());
Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule());
groupFactory = i.getInstance(GroupFactory.class);
authorId = new AuthorId(TestUtils.getRandomId());
batchId = new BatchId(TestUtils.getRandomId());

View File

@@ -42,7 +42,7 @@ public class AckReaderTest extends TestCase {
@Test
public void testFormatExceptionIfAckIsTooLarge() throws Exception {
AckFactory ackFactory = context.mock(AckFactory.class);
AckReader ackReader = new AckReader(ackFactory);
AckReader ackReader = new AckReader(new BatchIdReader(), ackFactory);
byte[] b = createAck(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
@@ -60,7 +60,7 @@ public class AckReaderTest extends TestCase {
@SuppressWarnings("unchecked")
public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception {
final AckFactory ackFactory = context.mock(AckFactory.class);
AckReader ackReader = new AckReader(ackFactory);
AckReader ackReader = new AckReader(new BatchIdReader(), ackFactory);
final Ack ack = context.mock(Ack.class);
context.checking(new Expectations() {{
oneOf(ackFactory).createAck(with(any(Collection.class)));
@@ -79,7 +79,7 @@ public class AckReaderTest extends TestCase {
@Test
public void testEmptyAck() throws Exception {
final AckFactory ackFactory = context.mock(AckFactory.class);
AckReader ackReader = new AckReader(ackFactory);
AckReader ackReader = new AckReader(new BatchIdReader(), ackFactory);
final Ack ack = context.mock(Ack.class);
context.checking(new Expectations() {{
oneOf(ackFactory).createAck(

View File

@@ -1,5 +1,6 @@
package net.sf.briar.protocol;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
@@ -16,22 +17,31 @@ import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Subscriptions;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.Transports;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.PacketWriterFactory;
import net.sf.briar.api.protocol.writers.SubscriptionWriter;
import net.sf.briar.api.protocol.writers.TransportWriter;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.writers.WritersModule;
import net.sf.briar.serial.SerialModule;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -48,6 +58,7 @@ public class FileReadWriteTest extends TestCase {
private final GroupId sub = new GroupId(TestUtils.getRandomId());
private final String nick = "Foo Bar";
private final String messageBody = "This is the message body! Wooooooo!";
private final long start = System.currentTimeMillis();
private final ReaderFactory readerFactory;
private final WriterFactory writerFactory;
@@ -56,11 +67,13 @@ public class FileReadWriteTest extends TestCase {
private final MessageDigest messageDigest, batchDigest;
private final KeyParser keyParser;
private final Message message;
private final Group group;
public FileReadWriteTest() throws Exception {
super();
Injector i = Guice.createInjector(new SerialModule(),
new CryptoModule(), new WritersModule());
Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule(),
new WritersModule());
readerFactory = i.getInstance(ReaderFactory.class);
writerFactory = i.getInstance(WriterFactory.class);
packetWriterFactory = i.getInstance(PacketWriterFactory.class);
@@ -71,11 +84,23 @@ public class FileReadWriteTest extends TestCase {
assertEquals(messageDigest.getDigestLength(), UniqueId.LENGTH);
assertEquals(batchDigest.getDigestLength(), UniqueId.LENGTH);
// Create and encode a test message
MessageEncoder messageEncoder = new MessageEncoderImpl(signature,
messageDigest, writerFactory);
MessageEncoder messageEncoder = i.getInstance(MessageEncoder.class);
KeyPair keyPair = i.getInstance(KeyPair.class);
message = messageEncoder.encodeMessage(MessageId.NONE, sub, nick,
keyPair, messageBody.getBytes("UTF-8"));
// Create a test group, then write and read it to calculate its ID
GroupFactory groupFactory = i.getInstance(GroupFactory.class);
Group noId = groupFactory.createGroup(
new GroupId(new byte[UniqueId.LENGTH]), "Group name", false,
TestUtils.getRandomId());
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
noId.writeTo(w);
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
Reader r = readerFactory.createReader(in);
ObjectReader<Group> groupReader = new GroupReader(batchDigest,
groupFactory);
group = groupReader.readObject(r);
}
@Before
@@ -88,13 +113,20 @@ public class FileReadWriteTest extends TestCase {
FileOutputStream out = new FileOutputStream(file);
AckWriter a = packetWriterFactory.createAckWriter(out);
a.addBatchId(ack);
assertTrue(a.writeBatchId(ack));
a.finish();
BatchWriter b = packetWriterFactory.createBatchWriter(out);
b.addMessage(message.getBytes());
assertTrue(b.writeMessage(message.getBytes()));
b.finish();
SubscriptionWriter s =
packetWriterFactory.createSubscriptionWriter(out);
s.writeSubscriptions(Collections.singleton(group));
TransportWriter t = packetWriterFactory.createTransportWriter(out);
t.writeTransports(Collections.singletonMap("foo", "bar"));
out.close();
assertTrue(file.exists());
assertTrue(file.length() > message.getSize());
@@ -107,18 +139,30 @@ public class FileReadWriteTest extends TestCase {
MessageReader messageReader =
new MessageReader(keyParser, signature, messageDigest);
AckReader ackReader = new AckReader(new AckFactoryImpl());
BatchReader batchReader = new BatchReader(batchDigest, messageReader,
new BatchFactoryImpl());
ObjectReader<Ack> ackReader = new AckReader(new BatchIdReader(),
new AckFactoryImpl());
ObjectReader<Batch> batchReader = new BatchReader(batchDigest,
messageReader, new BatchFactoryImpl());
ObjectReader<Group> groupReader = new GroupReader(batchDigest,
new GroupFactoryImpl(keyParser));
ObjectReader<Subscriptions> subscriptionReader =
new SubscriptionReader(groupReader, new SubscriptionFactoryImpl());
ObjectReader<Transports> transportReader =
new TransportReader(new TransportFactoryImpl());
FileInputStream in = new FileInputStream(file);
Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Tags.ACK, ackReader);
reader.addObjectReader(Tags.BATCH, batchReader);
reader.addObjectReader(Tags.SUBSCRIPTIONS, subscriptionReader);
reader.addObjectReader(Tags.TRANSPORTS, transportReader);
// Read the ack
assertTrue(reader.hasUserDefined(Tags.ACK));
Ack a = reader.readUserDefined(Tags.ACK, Ack.class);
assertEquals(Collections.singletonList(ack), a.getBatches());
// Read the batch
assertTrue(reader.hasUserDefined(Tags.BATCH));
Batch b = reader.readUserDefined(Tags.BATCH, Batch.class);
Iterator<Message> i = b.getMessages().iterator();
@@ -132,6 +176,22 @@ public class FileReadWriteTest extends TestCase {
assertTrue(Arrays.equals(message.getBytes(), m.getBytes()));
assertFalse(i.hasNext());
// Read the subscriptions update
assertTrue(reader.hasUserDefined(Tags.SUBSCRIPTIONS));
Subscriptions s = reader.readUserDefined(Tags.SUBSCRIPTIONS,
Subscriptions.class);
assertEquals(Collections.singletonList(group), s.getSubscriptions());
assertTrue(s.getTimestamp() > start);
assertTrue(s.getTimestamp() <= System.currentTimeMillis());
// Read the transports update
assertTrue(reader.hasUserDefined(Tags.TRANSPORTS));
Transports t = reader.readUserDefined(Tags.TRANSPORTS,
Transports.class);
assertEquals(Collections.singletonMap("foo", "bar"), t.getTransports());
assertTrue(t.getTimestamp() > start);
assertTrue(t.getTimestamp() <= System.currentTimeMillis());
assertTrue(reader.eof());
}