Decouple the database from IO.

This will enable asynchronous access to the database for IO threads.
This commit is contained in:
akwizgran
2011-12-07 00:23:35 +00:00
parent 45a51b4926
commit b7c3224618
73 changed files with 1120 additions and 1398 deletions

View File

@@ -17,17 +17,13 @@ import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
@@ -72,43 +68,46 @@ public interface DatabaseComponent {
TransportIndex addTransport(TransportId t) throws DbException; TransportIndex addTransport(TransportId t) throws DbException;
/** /**
* Generates an acknowledgement for the given contact. * Generates an acknowledgement for the given contact. Returns null if
* @return True if any batch IDs were added to the acknowledgement. * there are no batches to acknowledge.
*/ */
boolean generateAck(ContactId c, AckWriter a) throws DbException, Ack generateAck(ContactId c, int maxBatches) throws DbException;
IOException;
/** /**
* Generates a batch of messages for the given contact. * Generates a batch of messages for the given contact. Returns null if
* @return True if any messages were added to tbe batch. * there are no sendable messages that fit in the given capacity.
*/ */
boolean generateBatch(ContactId c, BatchWriter b) throws DbException, RawBatch generateBatch(ContactId c, int capacity) throws DbException;
IOException;
/** /**
* Generates a batch of messages for the given contact from the given * Generates a batch of messages for the given contact from the given
* collection of requested messages. Any messages that were either added to * collection of requested messages. Any messages that were either added to
* the batch, or were considered but are no longer sendable to the contact, * the batch, or were considered but are no longer sendable to the contact,
* are removed from the collection of requested messages before returning. * are removed from the collection of requested messages before returning.
* @return True if any messages were added to the batch. * Returns null if there are no sendable messages that fit in the given
* capacity.
*/ */
boolean generateBatch(ContactId c, BatchWriter b, RawBatch generateBatch(ContactId c, int capacity,
Collection<MessageId> requested) throws DbException, IOException; Collection<MessageId> requested) throws DbException;
/** /**
* Generates an offer for the given contact and returns the offered * Generates an offer for the given contact. Returns null if there are no
* message IDs. * messages to offer.
*/ */
Collection<MessageId> generateOffer(ContactId c, OfferWriter o) Offer generateOffer(ContactId c, int maxMessages) throws DbException;
throws DbException, IOException;
/** Generates a subscription update for the given contact. */ /**
void generateSubscriptionUpdate(ContactId c, SubscriptionUpdateWriter s) * Generates a subscription update for the given contact. Returns null if
throws DbException, IOException; * an update is not due.
*/
SubscriptionUpdate generateSubscriptionUpdate(ContactId c)
throws DbException;
/** Generates a transport update for the given contact. */ /**
void generateTransportUpdate(ContactId c, TransportUpdateWriter t) * Generates a transport update for the given contact. Returns null if an
throws DbException, IOException; * update is not due.
*/
TransportUpdate generateTransportUpdate(ContactId c) throws DbException;
/** Returns the configuration for the given transport. */ /** Returns the configuration for the given transport. */
TransportConfig getConfig(TransportId t) throws DbException; TransportConfig getConfig(TransportId t) throws DbException;
@@ -185,8 +184,7 @@ public interface DatabaseComponent {
* to the contact are requested just as though they were not present in the * to the contact are requested just as though they were not present in the
* database. * database.
*/ */
void receiveOffer(ContactId c, Offer o, RequestWriter r) throws DbException, Request receiveOffer(ContactId c, Offer o) throws DbException;
IOException;
/** Processes a subscription update from the given contact. */ /** Processes a subscription update from the given contact. */
void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s) void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s)

View File

@@ -2,7 +2,7 @@ package net.sf.briar.api.protocol;
import java.util.Collection; import java.util.Collection;
/** A packet containing messages. */ /** An incoming packet containing messages. */
public interface Batch { public interface Batch {
/** Returns the batch's unique identifier. */ /** Returns the batch's unique identifier. */

View File

@@ -23,9 +23,6 @@ public interface Message {
/** Returns the timestamp created by the message's author. */ /** Returns the timestamp created by the message's author. */
long getTimestamp(); long getTimestamp();
/** Returns the length of the serialised message in bytes. */
int getLength();
/** Returns the serialised message. */ /** Returns the serialised message. */
byte[] getSerialised(); byte[] getSerialised();

View File

@@ -0,0 +1,22 @@
package net.sf.briar.api.protocol;
import java.util.BitSet;
import java.util.Collection;
import java.util.Map;
public interface PacketFactory {
Ack createAck(Collection<BatchId> acked);
RawBatch createBatch(Collection<byte[]> messages);
Offer createOffer(Collection<MessageId> offered);
Request createRequest(BitSet requested, int length);
SubscriptionUpdate createSubscriptionUpdate(Map<Group, Long> subs,
long timestamp);
TransportUpdate createTransportUpdate(Collection<Transport> transports,
long timestamp);
}

View File

@@ -0,0 +1,24 @@
package net.sf.briar.api.protocol;
import java.io.IOException;
public interface ProtocolWriter {
int getMaxBatchesForAck(long capacity);
int getMaxMessagesForOffer(long capacity);
int getMessageCapacityForBatch(long capacity);
void writeAck(Ack a) throws IOException;
void writeBatch(RawBatch b) throws IOException;
void writeOffer(Offer o) throws IOException;
void writeRequest(Request r) throws IOException;
void writeSubscriptionUpdate(SubscriptionUpdate s) throws IOException;
void writeTransportUpdate(TransportUpdate t) throws IOException;
}

View File

@@ -0,0 +1,8 @@
package net.sf.briar.api.protocol;
import java.io.OutputStream;
public interface ProtocolWriterFactory {
ProtocolWriter createProtocolWriter(OutputStream out);
}

View File

@@ -0,0 +1,13 @@
package net.sf.briar.api.protocol;
import java.util.Collection;
/** An outgoing packet containing messages. */
public interface RawBatch {
/** Returns the batch's unique identifier. */
BatchId getId();
/** Returns the serialised messages contained in the batch. */
Collection<byte[]> getMessages();
}

View File

@@ -10,4 +10,7 @@ public interface Request {
* the offer, where the i^th bit is set if the i^th message should be sent. * the offer, where the i^th bit is set if the i^th message should be sent.
*/ */
BitSet getBitmap(); BitSet getBitmap();
/** Returns the length of the bitmap in bits. */
int getLength();
} }

View File

@@ -3,6 +3,7 @@ package net.sf.briar.api.protocol;
/** Struct identifiers for encoding and decoding protocol objects. */ /** Struct identifiers for encoding and decoding protocol objects. */
public interface Types { public interface Types {
// FIXME: Batch ID, message ID don't need to be structs
static final int ACK = 0; static final int ACK = 0;
static final int AUTHOR = 1; static final int AUTHOR = 1;
static final int BATCH = 2; static final int BATCH = 2;

View File

@@ -1,24 +0,0 @@
package net.sf.briar.api.protocol.writers;
import java.io.IOException;
import net.sf.briar.api.protocol.BatchId;
/** An interface for creating an ack packet. */
public interface AckWriter {
/**
* Sets the maximum length of the serialised ack. If this method is not
* called, the default is ProtocolConstants.MAX_PACKET_LENGTH;
*/
void setMaxPacketLength(int length);
/**
* Attempts to add the given BatchId to the ack and returns true if it
* was added.
*/
boolean writeBatchId(BatchId b) throws IOException;
/** Finishes writing the ack. */
void finish() throws IOException;
}

View File

@@ -1,27 +0,0 @@
package net.sf.briar.api.protocol.writers;
import java.io.IOException;
import net.sf.briar.api.protocol.BatchId;
/** An interface for creating a batch packet. */
public interface BatchWriter {
/** Returns the capacity of the batch in bytes. */
int getCapacity();
/**
* Sets the maximum length of the serialised batch; the default is
* ProtocolConstants.MAX_PACKET_LENGTH;
*/
void setMaxPacketLength(int length);
/**
* Attempts to add the given raw message to the batch and returns true if
* it was added.
*/
boolean writeMessage(byte[] raw) throws IOException;
/** Finishes writing the batch and returns its unique identifier. */
BatchId finish() throws IOException;
}

View File

@@ -1,24 +0,0 @@
package net.sf.briar.api.protocol.writers;
import java.io.IOException;
import net.sf.briar.api.protocol.MessageId;
/** An interface for creating an offer packet. */
public interface OfferWriter {
/**
* Sets the maximum length of the serialised offer. If this method is not
* called, the default is ProtocolConstants.MAX_PACKET_LENGTH;
*/
void setMaxPacketLength(int length);
/**
* Attempts to add the given message ID to the offer and returns true if it
* was added.
*/
boolean writeMessageId(MessageId m) throws IOException;
/** Finishes writing the offer. */
void finish() throws IOException;
}

View File

@@ -1,18 +0,0 @@
package net.sf.briar.api.protocol.writers;
import java.io.OutputStream;
public interface ProtocolWriterFactory {
AckWriter createAckWriter(OutputStream out);
BatchWriter createBatchWriter(OutputStream out);
OfferWriter createOfferWriter(OutputStream out);
RequestWriter createRequestWriter(OutputStream out);
SubscriptionUpdateWriter createSubscriptionUpdateWriter(OutputStream out);
TransportUpdateWriter createTransportUpdateWriter(OutputStream out);
}

View File

@@ -1,11 +0,0 @@
package net.sf.briar.api.protocol.writers;
import java.io.IOException;
import java.util.BitSet;
/** An interface for creating a request packet. */
public interface RequestWriter {
/** Writes the contents of the request. */
void writeRequest(BitSet b, int length) throws IOException;
}

View File

@@ -1,14 +0,0 @@
package net.sf.briar.api.protocol.writers;
import java.io.IOException;
import java.util.Map;
import net.sf.briar.api.protocol.Group;
/** An interface for creating a subscription update. */
public interface SubscriptionUpdateWriter {
/** Writes the contents of the update. */
void writeSubscriptions(Map<Group, Long> subs, long timestamp)
throws IOException;
}

View File

@@ -1,14 +0,0 @@
package net.sf.briar.api.protocol.writers;
import java.io.IOException;
import java.util.Collection;
import net.sf.briar.api.protocol.Transport;
/** An interface for creating a transport update. */
public interface TransportUpdateWriter {
/** Writes the contents of the update. */
void writeTransports(Collection<Transport> transports, long timestamp)
throws IOException;
}

View File

@@ -6,7 +6,7 @@ public interface SerialComponent {
int getSerialisedListStartLength(); int getSerialisedListStartLength();
int getSerialisedUniqueIdLength(int id);
int getSerialisedStructIdLength(int id); int getSerialisedStructIdLength(int id);
int getSerialisedUniqueIdLength(int id);
} }

View File

@@ -178,7 +178,8 @@ interface Database<T> {
* <p> * <p>
* Locking: contact read, messageStatus read. * Locking: contact read, messageStatus read.
*/ */
Collection<BatchId> getBatchesToAck(T txn, ContactId c) throws DbException; Collection<BatchId> getBatchesToAck(T txn, ContactId c, int maxBatches)
throws DbException;
/** /**
* Returns the configuration for the given transport. * Returns the configuration for the given transport.
@@ -315,6 +316,16 @@ interface Database<T> {
*/ */
int getNumberOfSendableChildren(T txn, MessageId m) throws DbException; int getNumberOfSendableChildren(T txn, MessageId m) throws DbException;
/**
* Returns the IDs of some messages that are eligible to be sent to the
* given contact, up to the given number of messages.
* <p>
* Locking: contact read, message read, messageStatus read,
* subscription read.
*/
Collection<MessageId> getOfferableMessages(T txn, ContactId c,
int maxMessages) throws DbException;
/** /**
* Returns the IDs of the oldest messages in the database, with a total * Returns the IDs of the oldest messages in the database, with a total
* size less than or equal to the given size. * size less than or equal to the given size.
@@ -361,16 +372,6 @@ interface Database<T> {
*/ */
int getSendability(T txn, MessageId m) throws DbException; int getSendability(T txn, MessageId m) throws DbException;
/**
* Returns the IDs of some messages that are eligible to be sent to the
* given contact.
* <p>
* Locking: contact read, message read, messageStatus read,
* subscription read.
*/
Collection<MessageId> getSendableMessages(T txn, ContactId c)
throws DbException;
/** /**
* Returns the IDs of some messages that are eligible to be sent to the * Returns the IDs of some messages that are eligible to be sent to the
* given contact, with a total size less than or equal to the given size. * given contact, with a total size less than or equal to the given size.

View File

@@ -20,7 +20,6 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.Rating; import net.sf.briar.api.Rating;
import net.sf.briar.api.TransportConfig; import net.sf.briar.api.TransportConfig;
@@ -51,17 +50,14 @@ import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
@@ -105,6 +101,7 @@ DatabaseCleaner.Callback {
private final Database<T> db; private final Database<T> db;
private final DatabaseCleaner cleaner; private final DatabaseCleaner cleaner;
private final ShutdownManager shutdown; private final ShutdownManager shutdown;
private final PacketFactory packetFactory;
private final Collection<DatabaseListener> listeners = private final Collection<DatabaseListener> listeners =
new CopyOnWriteArrayList<DatabaseListener>(); new CopyOnWriteArrayList<DatabaseListener>();
@@ -119,10 +116,11 @@ DatabaseCleaner.Callback {
@Inject @Inject
DatabaseComponentImpl(Database<T> db, DatabaseCleaner cleaner, DatabaseComponentImpl(Database<T> db, DatabaseCleaner cleaner,
ShutdownManager shutdown) { ShutdownManager shutdown, PacketFactory packetFactory) {
this.db = db; this.db = db;
this.cleaner = cleaner; this.cleaner = cleaner;
this.shutdown = shutdown; this.shutdown = shutdown;
this.packetFactory = packetFactory;
} }
public void open(boolean resume) throws DbException, IOException { public void open(boolean resume) throws DbException, IOException {
@@ -265,7 +263,7 @@ DatabaseCleaner.Callback {
if(sendability > 0) updateAncestorSendability(txn, id, true); if(sendability > 0) updateAncestorSendability(txn, id, true);
// Count the bytes stored // Count the bytes stored
synchronized(spaceLock) { synchronized(spaceLock) {
bytesStoredSinceLastCheck += m.getLength(); bytesStoredSinceLastCheck += m.getSerialised().length;
} }
} }
return stored; return stored;
@@ -373,7 +371,7 @@ DatabaseCleaner.Callback {
else db.setStatus(txn, c, id, Status.NEW); else db.setStatus(txn, c, id, Status.NEW);
// Count the bytes stored // Count the bytes stored
synchronized(spaceLock) { synchronized(spaceLock) {
bytesStoredSinceLastCheck += m.getLength(); bytesStoredSinceLastCheck += m.getSerialised().length;
} }
return true; return true;
} }
@@ -415,17 +413,16 @@ DatabaseCleaner.Callback {
return i; return i;
} }
public boolean generateAck(ContactId c, AckWriter a) throws DbException, public Ack generateAck(ContactId c, int maxBatches) throws DbException {
IOException { Collection<BatchId> acked;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException(); if(!containsContact(c)) throw new NoSuchContactException();
Collection<BatchId> acks, sent = new ArrayList<BatchId>();
messageStatusLock.readLock().lock(); messageStatusLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
acks = db.getBatchesToAck(txn, c); acked = db.getBatchesToAck(txn, c, maxBatches);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
@@ -434,20 +431,14 @@ DatabaseCleaner.Callback {
} finally { } finally {
messageStatusLock.readLock().unlock(); messageStatusLock.readLock().unlock();
} }
for(BatchId b : acks) { if(acked.isEmpty()) return null;
if(!a.writeBatchId(b)) break; // Record the contents of the ack
sent.add(b);
}
// Record the contents of the ack, unless it's empty
if(sent.isEmpty()) return false;
a.finish();
messageStatusLock.writeLock().lock(); messageStatusLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
db.removeBatchesToAck(txn, c, sent); db.removeBatchesToAck(txn, c, acked);
db.commitTransaction(txn); db.commitTransaction(txn);
return true;
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
throw e; throw e;
@@ -458,12 +449,14 @@ DatabaseCleaner.Callback {
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
} }
return packetFactory.createAck(acked);
} }
public boolean generateBatch(ContactId c, BatchWriter b) throws DbException, public RawBatch generateBatch(ContactId c, int capacity)
IOException { throws DbException {
Collection<MessageId> ids; Collection<MessageId> ids;
Collection<Bytes> messages = new ArrayList<Bytes>(); List<byte[]> messages = new ArrayList<byte[]>();
RawBatch b;
// Get some sendable messages from the database // Get some sendable messages from the database
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
@@ -476,10 +469,9 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
int capacity = b.getCapacity();
ids = db.getSendableMessages(txn, c, capacity); ids = db.getSendableMessages(txn, c, capacity);
for(MessageId m : ids) { for(MessageId m : ids) {
messages.add(new Bytes(db.getMessage(txn, m))); messages.add(db.getMessage(txn, m));
} }
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
@@ -492,40 +484,14 @@ DatabaseCleaner.Callback {
} finally { } finally {
messageStatusLock.readLock().unlock(); messageStatusLock.readLock().unlock();
} }
} finally { if(messages.isEmpty()) return null;
messageLock.readLock().unlock(); messages = Collections.unmodifiableList(messages);
} b = packetFactory.createBatch(messages);
} finally {
contactLock.readLock().unlock();
}
if(ids.isEmpty()) return false;
writeAndRecordBatch(c, b, ids, messages);
return true;
}
private void writeAndRecordBatch(ContactId c, BatchWriter b,
Collection<MessageId> ids, Collection<Bytes> messages)
throws DbException, IOException {
assert !ids.isEmpty();
assert !messages.isEmpty();
assert ids.size() == messages.size();
// Add the messages to the batch
for(Bytes raw : messages) {
boolean written = b.writeMessage(raw.getBytes());
assert written;
}
BatchId id = b.finish();
// Record the contents of the batch
contactLock.readLock().lock();
try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock();
try {
messageStatusLock.writeLock().lock(); messageStatusLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
db.addOutstandingBatch(txn, c, id, ids); db.addOutstandingBatch(txn, c, b.getId(), ids);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
@@ -540,12 +506,14 @@ DatabaseCleaner.Callback {
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
} }
return b;
} }
public boolean generateBatch(ContactId c, BatchWriter b, public RawBatch generateBatch(ContactId c, int capacity,
Collection<MessageId> requested) throws DbException, IOException { Collection<MessageId> requested) throws DbException {
Collection<MessageId> ids = new ArrayList<MessageId>(); Collection<MessageId> ids = new ArrayList<MessageId>();
Collection<Bytes> messages = new ArrayList<Bytes>(); List<byte[]> messages = new ArrayList<byte[]>();
RawBatch b;
// Get some sendable messages from the database // Get some sendable messages from the database
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
@@ -558,15 +526,15 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
int capacity = b.getCapacity();
Iterator<MessageId> it = requested.iterator(); Iterator<MessageId> it = requested.iterator();
while(it.hasNext()) { while(it.hasNext()) {
MessageId m = it.next(); MessageId m = it.next();
byte[] raw = db.getMessageIfSendable(txn, c, m); byte[] raw = db.getMessageIfSendable(txn, c, m);
if(raw != null) { if(raw != null) {
if(raw.length > capacity) break; if(raw.length > capacity) break;
messages.add(raw);
ids.add(m); ids.add(m);
messages.add(new Bytes(raw)); capacity -= raw.length;
} }
it.remove(); it.remove();
} }
@@ -581,21 +549,34 @@ DatabaseCleaner.Callback {
} finally { } finally {
messageStatusLock.readLock().unlock(); messageStatusLock.readLock().unlock();
} }
if(messages.isEmpty()) return null;
messages = Collections.unmodifiableList(messages);
b = packetFactory.createBatch(messages);
messageStatusLock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
db.addOutstandingBatch(txn, c, b.getId(), ids);
db.commitTransaction(txn);
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
messageStatusLock.writeLock().unlock();
}
} finally { } finally {
messageLock.readLock().unlock(); messageLock.readLock().unlock();
} }
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
} }
if(ids.isEmpty()) return false; return b;
writeAndRecordBatch(c, b, ids, messages);
return true;
} }
public Collection<MessageId> generateOffer(ContactId c, OfferWriter o) public Offer generateOffer(ContactId c, int maxMessages)
throws DbException, IOException { throws DbException {
Collection<MessageId> sendable; Collection<MessageId> offered;
List<MessageId> sent = new ArrayList<MessageId>();
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException(); if(!containsContact(c)) throw new NoSuchContactException();
@@ -605,7 +586,7 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
sendable = db.getSendableMessages(txn, c); offered = db.getOfferableMessages(txn, c, maxMessages);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
@@ -620,33 +601,41 @@ DatabaseCleaner.Callback {
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
} }
for(MessageId m : sendable) { return packetFactory.createOffer(offered);
if(!o.writeMessageId(m)) break;
sent.add(m);
}
if(!sent.isEmpty()) o.finish();
return Collections.unmodifiableList(sent);
} }
public void generateSubscriptionUpdate(ContactId c, public SubscriptionUpdate generateSubscriptionUpdate(ContactId c)
SubscriptionUpdateWriter s) throws DbException, IOException { throws DbException {
Map<Group, Long> subs = null; boolean due;
long timestamp = 0L; Map<Group, Long> subs;
long timestamp;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException(); if(!containsContact(c)) throw new NoSuchContactException();
subscriptionLock.writeLock().lock(); subscriptionLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
// Work out whether an update is due // Work out whether an update is due
long modified = db.getSubscriptionsModified(txn, c); long modified = db.getSubscriptionsModified(txn, c);
long sent = db.getSubscriptionsSent(txn, c); long sent = db.getSubscriptionsSent(txn, c);
if(modified >= sent || updateIsDue(sent)) { due = modified >= sent || updateIsDue(sent);
subs = db.getVisibleSubscriptions(txn, c); db.commitTransaction(txn);
timestamp = System.currentTimeMillis(); } catch(DbException e) {
db.setSubscriptionsSent(txn, c, timestamp); db.abortTransaction(txn);
} throw e;
}
} finally {
subscriptionLock.readLock().unlock();
}
if(!due) return null;
subscriptionLock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
subs = db.getVisibleSubscriptions(txn, c);
timestamp = System.currentTimeMillis();
db.setSubscriptionsSent(txn, c, timestamp);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
@@ -658,7 +647,7 @@ DatabaseCleaner.Callback {
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
} }
if(subs != null) s.writeSubscriptions(subs, timestamp); return packetFactory.createSubscriptionUpdate(subs, timestamp);
} }
private boolean updateIsDue(long sent) { private boolean updateIsDue(long sent) {
@@ -666,25 +655,38 @@ DatabaseCleaner.Callback {
return now - sent >= DatabaseConstants.MAX_UPDATE_INTERVAL; return now - sent >= DatabaseConstants.MAX_UPDATE_INTERVAL;
} }
public void generateTransportUpdate(ContactId c, TransportUpdateWriter t) public TransportUpdate generateTransportUpdate(ContactId c)
throws DbException, IOException { throws DbException {
Collection<Transport> transports = null; boolean due;
long timestamp = 0L; Collection<Transport> transports;
long timestamp;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException(); if(!containsContact(c)) throw new NoSuchContactException();
transportLock.writeLock().lock(); transportLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
// Work out whether an update is due // Work out whether an update is due
long modified = db.getTransportsModified(txn); long modified = db.getTransportsModified(txn);
long sent = db.getTransportsSent(txn, c); long sent = db.getTransportsSent(txn, c);
if(modified >= sent || updateIsDue(sent)) { due = modified >= sent || updateIsDue(sent);
transports = db.getLocalTransports(txn); db.commitTransaction(txn);
timestamp = System.currentTimeMillis(); } catch(DbException e) {
db.setTransportsSent(txn, c, timestamp); db.abortTransaction(txn);
} throw e;
}
} finally {
transportLock.readLock().unlock();
}
if(!due) return null;
transportLock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
transports = db.getLocalTransports(txn);
timestamp = System.currentTimeMillis();
db.setTransportsSent(txn, c, timestamp);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
@@ -696,7 +698,7 @@ DatabaseCleaner.Callback {
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
} }
if(transports != null) t.writeTransports(transports, timestamp); return packetFactory.createTransportUpdate(transports, timestamp);
} }
public TransportConfig getConfig(TransportId t) throws DbException { public TransportConfig getConfig(TransportId t) throws DbException {
@@ -1119,8 +1121,7 @@ DatabaseCleaner.Callback {
return anyStored; return anyStored;
} }
public void receiveOffer(ContactId c, Offer o, RequestWriter r) public Request receiveOffer(ContactId c, Offer o) throws DbException {
throws DbException, IOException {
Collection<MessageId> offered; Collection<MessageId> offered;
BitSet request; BitSet request;
contactLock.readLock().lock(); contactLock.readLock().lock();
@@ -1161,7 +1162,7 @@ DatabaseCleaner.Callback {
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
} }
r.writeRequest(request, offered.size()); return packetFactory.createRequest(request, offered.size());
} }
public void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s) public void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s)

View File

@@ -10,6 +10,7 @@ import net.sf.briar.api.db.DatabaseMaxSize;
import net.sf.briar.api.db.DatabasePassword; import net.sf.briar.api.db.DatabasePassword;
import net.sf.briar.api.lifecycle.ShutdownManager; import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.transport.ConnectionContextFactory; import net.sf.briar.api.transport.ConnectionContextFactory;
import net.sf.briar.api.transport.ConnectionWindowFactory; import net.sf.briar.api.transport.ConnectionWindowFactory;
@@ -36,7 +37,9 @@ public class DatabaseModule extends AbstractModule {
@Provides @Singleton @Provides @Singleton
DatabaseComponent getDatabaseComponent(Database<Connection> db, DatabaseComponent getDatabaseComponent(Database<Connection> db,
DatabaseCleaner cleaner, ShutdownManager shutdown) { DatabaseCleaner cleaner, ShutdownManager shutdown,
return new DatabaseComponentImpl<Connection>(db, cleaner, shutdown); PacketFactory packetFactory) {
return new DatabaseComponentImpl<Connection>(db, cleaner, shutdown,
packetFactory);
} }
} }

View File

@@ -612,10 +612,11 @@ abstract class JdbcDatabase implements Database<Connection> {
else ps.setBytes(4, m.getAuthor().getBytes()); else ps.setBytes(4, m.getAuthor().getBytes());
ps.setString(5, m.getSubject()); ps.setString(5, m.getSubject());
ps.setLong(6, m.getTimestamp()); ps.setLong(6, m.getTimestamp());
ps.setInt(7, m.getLength()); byte[] raw = m.getSerialised();
ps.setInt(7, raw.length);
ps.setInt(8, m.getBodyStart()); ps.setInt(8, m.getBodyStart());
ps.setInt(9, m.getBodyLength()); ps.setInt(9, m.getBodyLength());
ps.setBytes(10, m.getSerialised()); ps.setBytes(10, raw);
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException(); if(affected != 1) throw new DbStateException();
ps.close(); ps.close();
@@ -700,10 +701,11 @@ abstract class JdbcDatabase implements Database<Connection> {
else ps.setBytes(2, m.getParent().getBytes()); else ps.setBytes(2, m.getParent().getBytes());
ps.setString(3, m.getSubject()); ps.setString(3, m.getSubject());
ps.setLong(4, m.getTimestamp()); ps.setLong(4, m.getTimestamp());
ps.setInt(5, m.getLength()); byte[] raw = m.getSerialised();
ps.setInt(5, raw.length);
ps.setInt(6, m.getBodyStart()); ps.setInt(6, m.getBodyStart());
ps.setInt(7, m.getBodyLength()); ps.setInt(7, m.getBodyLength());
ps.setBytes(8, m.getSerialised()); ps.setBytes(8, raw);
ps.setInt(9, c.getInt()); ps.setInt(9, c.getInt());
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException(); if(affected != 1) throw new DbStateException();
@@ -889,15 +891,17 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public Collection<BatchId> getBatchesToAck(Connection txn, ContactId c) public Collection<BatchId> getBatchesToAck(Connection txn, ContactId c,
throws DbException { int maxBatches) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
String sql = "SELECT batchId FROM batchesToAck" String sql = "SELECT batchId FROM batchesToAck"
+ " WHERE contactId = ?"; + " WHERE contactId = ?"
+ " LIMIT ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
ps.setInt(2, maxBatches);
rs = ps.executeQuery(); rs = ps.executeQuery();
List<BatchId> ids = new ArrayList<BatchId>(); List<BatchId> ids = new ArrayList<BatchId>();
while(rs.next()) ids.add(new BatchId(rs.getBytes(1))); while(rs.next()) ids.add(new BatchId(rs.getBytes(1)));
@@ -1517,8 +1521,8 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public Collection<MessageId> getSendableMessages(Connection txn, public Collection<MessageId> getOfferableMessages(Connection txn,
ContactId c) throws DbException { ContactId c, int maxMessages) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
@@ -1526,15 +1530,19 @@ abstract class JdbcDatabase implements Database<Connection> {
String sql = "SELECT messages.messageId FROM messages" String sql = "SELECT messages.messageId FROM messages"
+ " JOIN statuses ON messages.messageId = statuses.messageId" + " JOIN statuses ON messages.messageId = statuses.messageId"
+ " WHERE messages.contactId = ? AND status = ?" + " WHERE messages.contactId = ? AND status = ?"
+ " ORDER BY timestamp"; + " ORDER BY timestamp"
+ " LIMIT ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
ps.setShort(2, (short) Status.NEW.ordinal()); ps.setShort(2, (short) Status.NEW.ordinal());
ps.setInt(3, maxMessages);
rs = ps.executeQuery(); rs = ps.executeQuery();
List<MessageId> ids = new ArrayList<MessageId>(); List<MessageId> ids = new ArrayList<MessageId>();
while(rs.next()) ids.add(new MessageId(rs.getBytes(2))); while(rs.next()) ids.add(new MessageId(rs.getBytes(2)));
rs.close(); rs.close();
ps.close(); ps.close();
if(ids.size() == maxMessages)
return Collections.unmodifiableList(ids);
// Do we have any sendable group messages? // Do we have any sendable group messages?
sql = "SELECT m.messageId FROM messages AS m" sql = "SELECT m.messageId FROM messages AS m"
+ " JOIN contactSubscriptions AS cs" + " JOIN contactSubscriptions AS cs"
@@ -1547,10 +1555,12 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " AND timestamp >= start" + " AND timestamp >= start"
+ " AND status = ?" + " AND status = ?"
+ " AND sendability > ZERO()" + " AND sendability > ZERO()"
+ " ORDER BY timestamp"; + " ORDER BY timestamp"
+ " LIMIT ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
ps.setShort(2, (short) Status.NEW.ordinal()); ps.setShort(2, (short) Status.NEW.ordinal());
ps.setInt(3, maxMessages - ids.size());
rs = ps.executeQuery(); rs = ps.executeQuery();
while(rs.next()) ids.add(new MessageId(rs.getBytes(2))); while(rs.next()) ids.add(new MessageId(rs.getBytes(2)));
rs.close(); rs.close();

View File

@@ -1,11 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId;
interface AckFactory {
Ack createAck(Collection<BatchId> acked);
}

View File

@@ -1,13 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId;
class AckFactoryImpl implements AckFactory {
public Ack createAck(Collection<BatchId> acked) {
return new AckImpl(acked);
}
}

View File

@@ -6,6 +6,7 @@ import java.util.Collection;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
@@ -16,11 +17,11 @@ import net.sf.briar.api.serial.Reader;
class AckReader implements ObjectReader<Ack> { class AckReader implements ObjectReader<Ack> {
private final AckFactory ackFactory; private final PacketFactory packetFactory;
private final ObjectReader<BatchId> batchIdReader; private final ObjectReader<BatchId> batchIdReader;
AckReader(AckFactory ackFactory) { AckReader(PacketFactory packetFactory) {
this.ackFactory = ackFactory; this.packetFactory = packetFactory;
batchIdReader = new BatchIdReader(); batchIdReader = new BatchIdReader();
} }
@@ -36,7 +37,7 @@ class AckReader implements ObjectReader<Ack> {
r.removeObjectReader(Types.BATCH_ID); r.removeObjectReader(Types.BATCH_ID);
r.removeConsumer(counting); r.removeConsumer(counting);
// Build and return the ack // Build and return the ack
return ackFactory.createAck(batches); return packetFactory.createAck(batches);
} }
private static class BatchIdReader implements ObjectReader<BatchId> { private static class BatchIdReader implements ObjectReader<BatchId> {

View File

@@ -59,10 +59,6 @@ class MessageImpl implements Message {
return timestamp; return timestamp;
} }
public int getLength() {
return raw.length;
}
public byte[] getSerialised() { public byte[] getSerialised() {
return raw; return raw;
} }

View File

@@ -1,11 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
interface OfferFactory {
Offer createOffer(Collection<MessageId> offered);
}

View File

@@ -1,13 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
class OfferFactoryImpl implements OfferFactory {
public Offer createOffer(Collection<MessageId> offered) {
return new OfferImpl(offered);
}
}

View File

@@ -5,6 +5,7 @@ import java.util.Collection;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
@@ -15,12 +16,12 @@ import net.sf.briar.api.serial.Reader;
class OfferReader implements ObjectReader<Offer> { class OfferReader implements ObjectReader<Offer> {
private final ObjectReader<MessageId> messageIdReader; private final ObjectReader<MessageId> messageIdReader;
private final OfferFactory offerFactory; private final PacketFactory packetFactory;
OfferReader(ObjectReader<MessageId> messageIdReader, OfferReader(ObjectReader<MessageId> messageIdReader,
OfferFactory offerFactory) { PacketFactory packetFactory) {
this.messageIdReader = messageIdReader; this.messageIdReader = messageIdReader;
this.offerFactory = offerFactory; this.packetFactory = packetFactory;
} }
public Offer readObject(Reader r) throws IOException { public Offer readObject(Reader r) throws IOException {
@@ -35,6 +36,6 @@ class OfferReader implements ObjectReader<Offer> {
r.removeObjectReader(Types.MESSAGE_ID); r.removeObjectReader(Types.MESSAGE_ID);
r.removeConsumer(counting); r.removeConsumer(counting);
// Build and return the offer // Build and return the offer
return offerFactory.createOffer(messages); return packetFactory.createOffer(messages);
} }
} }

View File

@@ -0,0 +1,59 @@
package net.sf.briar.protocol;
import java.util.BitSet;
import java.util.Collection;
import java.util.Map;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportUpdate;
import com.google.inject.Inject;
class PacketFactoryImpl implements PacketFactory {
private final CryptoComponent crypto;
@Inject
PacketFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public Ack createAck(Collection<BatchId> acked) {
return new AckImpl(acked);
}
public RawBatch createBatch(Collection<byte[]> messages) {
MessageDigest messageDigest = crypto.getMessageDigest();
for(byte[] raw : messages) messageDigest.update(raw);
return new RawBatchImpl(new BatchId(messageDigest.digest()), messages);
}
public Offer createOffer(Collection<MessageId> offered) {
return new OfferImpl(offered);
}
public Request createRequest(BitSet requested, int length) {
return new RequestImpl(requested, length);
}
public SubscriptionUpdate createSubscriptionUpdate(Map<Group, Long> subs,
long timestamp) {
return new SubscriptionUpdateImpl(subs, timestamp);
}
public TransportUpdate createTransportUpdate(
Collection<Transport> transports, long timestamp) {
return new TransportUpdateImpl(transports, timestamp);
}
}

View File

@@ -9,7 +9,9 @@ import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
@@ -23,21 +25,17 @@ public class ProtocolModule extends AbstractModule {
@Override @Override
protected void configure() { protected void configure() {
bind(AckFactory.class).to(AckFactoryImpl.class);
bind(AuthorFactory.class).to(AuthorFactoryImpl.class); bind(AuthorFactory.class).to(AuthorFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class); bind(GroupFactory.class).to(GroupFactoryImpl.class);
bind(MessageFactory.class).to(MessageFactoryImpl.class); bind(MessageFactory.class).to(MessageFactoryImpl.class);
bind(OfferFactory.class).to(OfferFactoryImpl.class); bind(PacketFactory.class).to(PacketFactoryImpl.class);
bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class); bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class);
bind(RequestFactory.class).to(RequestFactoryImpl.class); bind(ProtocolWriterFactory.class).to(ProtocolWriterFactoryImpl.class);
bind(SubscriptionUpdateFactory.class).to(
SubscriptionUpdateFactoryImpl.class);
bind(TransportUpdateFactory.class).to(TransportUpdateFactoryImpl.class);
bind(UnverifiedBatchFactory.class).to(UnverifiedBatchFactoryImpl.class); bind(UnverifiedBatchFactory.class).to(UnverifiedBatchFactoryImpl.class);
} }
@Provides @Provides
ObjectReader<Ack> getAckReader(AckFactory ackFactory) { ObjectReader<Ack> getAckReader(PacketFactory ackFactory) {
return new AckReader(ackFactory); return new AckReader(ackFactory);
} }
@@ -75,25 +73,24 @@ public class ProtocolModule extends AbstractModule {
@Provides @Provides
ObjectReader<Offer> getOfferReader(ObjectReader<MessageId> messageIdReader, ObjectReader<Offer> getOfferReader(ObjectReader<MessageId> messageIdReader,
OfferFactory offerFactory) { PacketFactory packetFactory) {
return new OfferReader(messageIdReader, offerFactory); return new OfferReader(messageIdReader, packetFactory);
} }
@Provides @Provides
ObjectReader<Request> getRequestReader(RequestFactory requestFactory) { ObjectReader<Request> getRequestReader(PacketFactory packetFactory) {
return new RequestReader(requestFactory); return new RequestReader(packetFactory);
} }
@Provides @Provides
ObjectReader<SubscriptionUpdate> getSubscriptionReader( ObjectReader<SubscriptionUpdate> getSubscriptionReader(
ObjectReader<Group> groupReader, ObjectReader<Group> groupReader, PacketFactory packetFactory) {
SubscriptionUpdateFactory subscriptionFactory) { return new SubscriptionUpdateReader(groupReader, packetFactory);
return new SubscriptionUpdateReader(groupReader, subscriptionFactory);
} }
@Provides @Provides
ObjectReader<TransportUpdate> getTransportReader( ObjectReader<TransportUpdate> getTransportReader(
TransportUpdateFactory transportFactory) { PacketFactory packetFactory) {
return new TransportUpdateReader(transportFactory); return new TransportUpdateReader(packetFactory);
} }
} }

View File

@@ -0,0 +1,27 @@
package net.sf.briar.protocol;
import java.io.OutputStream;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.WriterFactory;
import com.google.inject.Inject;
class ProtocolWriterFactoryImpl implements ProtocolWriterFactory {
private final SerialComponent serial;
private final WriterFactory writerFactory;
@Inject
ProtocolWriterFactoryImpl(SerialComponent serial,
WriterFactory writerFactory) {
this.serial = serial;
this.writerFactory = writerFactory;
}
public ProtocolWriter createProtocolWriter(OutputStream out) {
return new ProtocolWriterImpl(serial, writerFactory, out);
}
}

View File

@@ -0,0 +1,143 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.io.OutputStream;
import java.util.BitSet;
import java.util.Map.Entry;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
// This class is not thread-safe
class ProtocolWriterImpl implements ProtocolWriter {
private final SerialComponent serial;
private final OutputStream out;
private final Writer w;
ProtocolWriterImpl(SerialComponent serial, WriterFactory writerFactory,
OutputStream out) {
this.serial = serial;
this.out = out;
w = writerFactory.createWriter(out);
}
public int getMaxBatchesForAck(long capacity) {
int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH);
int overhead = serial.getSerialisedStructIdLength(Types.ACK)
+ serial.getSerialisedListStartLength()
+ serial.getSerialisedListEndLength();
int idLength = serial.getSerialisedUniqueIdLength(Types.BATCH_ID);
return (packet - overhead) / idLength;
}
public int getMaxMessagesForOffer(long capacity) {
int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH);
int overhead = serial.getSerialisedStructIdLength(Types.OFFER)
+ serial.getSerialisedListStartLength()
+ serial.getSerialisedListEndLength();
int idLength = serial.getSerialisedUniqueIdLength(Types.MESSAGE_ID);
return (packet - overhead) / idLength;
}
public int getMessageCapacityForBatch(long capacity) {
int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH);
int overhead = serial.getSerialisedStructIdLength(Types.BATCH)
+ serial.getSerialisedListStartLength()
+ serial.getSerialisedListEndLength();
return packet - overhead;
}
public void writeAck(Ack a) throws IOException {
w.writeStructId(Types.ACK);
w.writeListStart();
for(BatchId b : a.getBatchIds()) {
w.writeStructId(Types.BATCH_ID);
w.writeBytes(b.getBytes());
}
w.writeListEnd();
}
public void writeBatch(RawBatch b) throws IOException {
w.writeStructId(Types.BATCH);
w.writeListStart();
for(byte[] raw : b.getMessages()) out.write(raw);
w.writeListEnd();
}
public void writeOffer(Offer o) throws IOException {
w.writeStructId(Types.OFFER);
w.writeListStart();
for(MessageId m : o.getMessageIds()) {
w.writeStructId(Types.MESSAGE_ID);
w.writeBytes(m.getBytes());
}
w.writeListEnd();
}
public void writeRequest(Request r) throws IOException {
BitSet b = r.getBitmap();
int length = r.getLength();
// If the number of bits isn't a multiple of 8, round up to a byte
int bytes = length % 8 == 0 ? length / 8 : length / 8 + 1;
byte[] bitmap = new byte[bytes];
// I'm kind of surprised BitSet doesn't have a method for this
for(int i = 0; i < length; i++) {
if(b.get(i)) {
int offset = i / 8;
byte bit = (byte) (128 >> i % 8);
bitmap[offset] |= bit;
}
}
w.writeStructId(Types.REQUEST);
w.writeUint7((byte) (bytes * 8 - length));
w.writeBytes(bitmap);
}
public void writeSubscriptionUpdate(SubscriptionUpdate s)
throws IOException {
w.writeStructId(Types.SUBSCRIPTION_UPDATE);
w.writeMapStart();
for(Entry<Group, Long> e : s.getSubscriptions().entrySet()) {
writeGroup(w, e.getKey());
w.writeInt64(e.getValue());
}
w.writeMapEnd();
w.writeInt64(s.getTimestamp());
}
private void writeGroup(Writer w, Group g) throws IOException {
w.writeStructId(Types.GROUP);
w.writeString(g.getName());
byte[] publicKey = g.getPublicKey();
if(publicKey == null) w.writeNull();
else w.writeBytes(publicKey);
}
public void writeTransportUpdate(TransportUpdate t) throws IOException {
w.writeStructId(Types.TRANSPORT_UPDATE);
w.writeListStart();
for(Transport p : t.getTransports()) {
w.writeStructId(Types.TRANSPORT);
w.writeBytes(p.getId().getBytes());
w.writeInt32(p.getIndex().getInt());
w.writeMap(p);
}
w.writeListEnd();
w.writeInt64(t.getTimestamp());
}
}

View File

@@ -0,0 +1,25 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.RawBatch;
class RawBatchImpl implements RawBatch {
private final BatchId id;
private final Collection<byte[]> messages;
RawBatchImpl(BatchId id, Collection<byte[]> messages) {
this.id = id;
this.messages = messages;
}
public BatchId getId() {
return id;
}
public Collection<byte[]> getMessages() {
return messages;
}
}

View File

@@ -1,10 +0,0 @@
package net.sf.briar.protocol;
import java.util.BitSet;
import net.sf.briar.api.protocol.Request;
interface RequestFactory {
Request createRequest(BitSet requested);
}

View File

@@ -1,12 +0,0 @@
package net.sf.briar.protocol;
import java.util.BitSet;
import net.sf.briar.api.protocol.Request;
class RequestFactoryImpl implements RequestFactory {
public Request createRequest(BitSet requested) {
return new RequestImpl(requested);
}
}

View File

@@ -7,12 +7,18 @@ import net.sf.briar.api.protocol.Request;
class RequestImpl implements Request { class RequestImpl implements Request {
private final BitSet requested; private final BitSet requested;
private final int length;
RequestImpl(BitSet requested) { RequestImpl(BitSet requested, int length) {
this.requested = requested; this.requested = requested;
this.length = length;
} }
public BitSet getBitmap() { public BitSet getBitmap() {
return requested; return requested;
} }
public int getLength() {
return length;
}
} }

View File

@@ -3,6 +3,8 @@ package net.sf.briar.protocol;
import java.io.IOException; import java.io.IOException;
import java.util.BitSet; import java.util.BitSet;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
@@ -13,10 +15,10 @@ import net.sf.briar.api.serial.Reader;
class RequestReader implements ObjectReader<Request> { class RequestReader implements ObjectReader<Request> {
private final RequestFactory requestFactory; private final PacketFactory packetFactory;
RequestReader(RequestFactory requestFactory) { RequestReader(PacketFactory packetFactory) {
this.requestFactory = requestFactory; this.packetFactory = packetFactory;
} }
public Request readObject(Reader r) throws IOException { public Request readObject(Reader r) throws IOException {
@@ -26,16 +28,19 @@ class RequestReader implements ObjectReader<Request> {
// Read the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.REQUEST); r.readStructId(Types.REQUEST);
int padding = r.readUint7();
if(padding > 7) throw new FormatException();
byte[] bitmap = r.readBytes(ProtocolConstants.MAX_PACKET_LENGTH); byte[] bitmap = r.readBytes(ProtocolConstants.MAX_PACKET_LENGTH);
r.removeConsumer(counting); r.removeConsumer(counting);
// Convert the bitmap into a BitSet // Convert the bitmap into a BitSet
BitSet b = new BitSet(bitmap.length * 8); int length = bitmap.length * 8 - padding;
BitSet b = new BitSet(length);
for(int i = 0; i < bitmap.length; i++) { for(int i = 0; i < bitmap.length; i++) {
for(int j = 0; j < 8; j++) { for(int j = 0; j < 8 && i * 8 + j < length; j++) {
byte bit = (byte) (128 >> j); byte bit = (byte) (128 >> j);
if((bitmap[i] & bit) != 0) b.set(i * 8 + j); if((bitmap[i] & bit) != 0) b.set(i * 8 + j);
} }
} }
return requestFactory.createRequest(b); return packetFactory.createRequest(b, length);
} }
} }

View File

@@ -1,12 +0,0 @@
package net.sf.briar.protocol;
import java.util.Map;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.SubscriptionUpdate;
interface SubscriptionUpdateFactory {
SubscriptionUpdate createSubscriptions(Map<Group, Long> subs,
long timestamp);
}

View File

@@ -1,14 +0,0 @@
package net.sf.briar.protocol;
import java.util.Map;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.SubscriptionUpdate;
class SubscriptionUpdateFactoryImpl implements SubscriptionUpdateFactory {
public SubscriptionUpdate createSubscriptions(Map<Group, Long> subs,
long timestamp) {
return new SubscriptionUpdateImpl(subs, timestamp);
}
}

View File

@@ -5,6 +5,7 @@ import java.util.Map;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
@@ -16,12 +17,12 @@ import net.sf.briar.api.serial.Reader;
class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> { class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> {
private final ObjectReader<Group> groupReader; private final ObjectReader<Group> groupReader;
private final SubscriptionUpdateFactory subscriptionFactory; private final PacketFactory packetFactory;
SubscriptionUpdateReader(ObjectReader<Group> groupReader, SubscriptionUpdateReader(ObjectReader<Group> groupReader,
SubscriptionUpdateFactory subscriptionFactory) { PacketFactory packetFactory) {
this.groupReader = groupReader; this.groupReader = groupReader;
this.subscriptionFactory = subscriptionFactory; this.packetFactory = packetFactory;
} }
public SubscriptionUpdate readObject(Reader r) throws IOException { public SubscriptionUpdate readObject(Reader r) throws IOException {
@@ -38,6 +39,6 @@ class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> {
if(timestamp < 0L) throw new FormatException(); if(timestamp < 0L) throw new FormatException();
r.removeConsumer(counting); r.removeConsumer(counting);
// Build and return the subscription update // Build and return the subscription update
return subscriptionFactory.createSubscriptions(subs, timestamp); return packetFactory.createSubscriptionUpdate(subs, timestamp);
} }
} }

View File

@@ -1,12 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportUpdate;
interface TransportUpdateFactory {
TransportUpdate createTransportUpdate(Collection<Transport> transports,
long timestamp);
}

View File

@@ -1,14 +0,0 @@
package net.sf.briar.protocol;
import java.util.Collection;
import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportUpdate;
class TransportUpdateFactoryImpl implements TransportUpdateFactory {
public TransportUpdate createTransportUpdate(
Collection<Transport> transports, long timestamp) {
return new TransportUpdateImpl(transports, timestamp);
}
}

View File

@@ -7,6 +7,7 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
@@ -21,11 +22,11 @@ import net.sf.briar.api.serial.Reader;
class TransportUpdateReader implements ObjectReader<TransportUpdate> { class TransportUpdateReader implements ObjectReader<TransportUpdate> {
private final TransportUpdateFactory transportUpdateFactory; private final PacketFactory packetFactory;
private final ObjectReader<Transport> transportReader; private final ObjectReader<Transport> transportReader;
TransportUpdateReader(TransportUpdateFactory transportFactory) { TransportUpdateReader(PacketFactory packetFactory) {
this.transportUpdateFactory = transportFactory; this.packetFactory = packetFactory;
transportReader = new TransportReader(); transportReader = new TransportReader();
} }
@@ -51,8 +52,7 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> {
if(!indices.add(t.getIndex())) throw new FormatException(); if(!indices.add(t.getIndex())) throw new FormatException();
} }
// Build and return the transport update // Build and return the transport update
return transportUpdateFactory.createTransportUpdate(transports, return packetFactory.createTransportUpdate(transports, timestamp);
timestamp);
} }
private static class TransportReader implements ObjectReader<Transport> { private static class TransportReader implements ObjectReader<Transport> {

View File

@@ -1,64 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class AckWriterImpl implements AckWriter {
private final OutputStream out;
private final int headerLength, idLength, footerLength;
private final Writer w;
private boolean started = false;
private int capacity = ProtocolConstants.MAX_PACKET_LENGTH;
AckWriterImpl(OutputStream out, SerialComponent serial,
WriterFactory writerFactory) {
this.out = out;
headerLength = serial.getSerialisedStructIdLength(Types.ACK)
+ serial.getSerialisedListStartLength();
idLength = serial.getSerialisedUniqueIdLength(Types.BATCH_ID);
footerLength = serial.getSerialisedListEndLength();
w = writerFactory.createWriter(out);
}
public void setMaxPacketLength(int length) {
if(started) throw new IllegalStateException();
if(length < 0 || length > ProtocolConstants.MAX_PACKET_LENGTH)
throw new IllegalArgumentException();
capacity = length;
}
public boolean writeBatchId(BatchId b) throws IOException {
int overhead = started ? footerLength : headerLength + footerLength;
if(capacity < idLength + overhead) return false;
if(!started) start();
w.writeStructId(Types.BATCH_ID);
w.writeBytes(b.getBytes());
capacity -= idLength;
return true;
}
public void finish() throws IOException {
if(!started) start();
w.writeListEnd();
out.flush();
capacity = ProtocolConstants.MAX_PACKET_LENGTH;
started = false;
}
private void start() throws IOException {
w.writeStructId(Types.ACK);
w.writeListStart();
capacity -= headerLength;
started = true;
}
}

View File

@@ -1,78 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.serial.DigestingConsumer;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class BatchWriterImpl implements BatchWriter {
private final OutputStream out;
private final int headerLength, footerLength;
private final Writer w;
private final MessageDigest messageDigest;
private final DigestingConsumer digestingConsumer;
private boolean started = false;
private int capacity = ProtocolConstants.MAX_PACKET_LENGTH;
private int remaining = capacity;
BatchWriterImpl(OutputStream out, SerialComponent serial,
WriterFactory writerFactory, MessageDigest messageDigest) {
this.out = out;
headerLength = serial.getSerialisedStructIdLength(Types.BATCH)
+ serial.getSerialisedListStartLength();
footerLength = serial.getSerialisedListEndLength();
w = writerFactory.createWriter(this.out);
this.messageDigest = messageDigest;
digestingConsumer = new DigestingConsumer(messageDigest);
}
public int getCapacity() {
return capacity - headerLength - footerLength;
}
public void setMaxPacketLength(int length) {
if(started) throw new IllegalStateException();
if(length < 0 || length > ProtocolConstants.MAX_PACKET_LENGTH)
throw new IllegalArgumentException();
remaining = capacity = length;
}
public boolean writeMessage(byte[] message) throws IOException {
int overhead = started ? footerLength : headerLength + footerLength;
if(remaining < message.length + overhead) return false;
if(!started) start();
// Bypass the writer and write the raw message directly
out.write(message);
remaining -= message.length;
return true;
}
public BatchId finish() throws IOException {
if(!started) start();
w.writeListEnd();
w.removeConsumer(digestingConsumer);
out.flush();
remaining = capacity = ProtocolConstants.MAX_PACKET_LENGTH;
started = false;
return new BatchId(messageDigest.digest());
}
private void start() throws IOException {
messageDigest.reset();
w.addConsumer(digestingConsumer);
w.writeStructId(Types.BATCH);
w.writeListStart();
remaining -= headerLength;
started = true;
}
}

View File

@@ -1,64 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class OfferWriterImpl implements OfferWriter {
private final OutputStream out;
private final int headerLength, idLength, footerLength;
private final Writer w;
private boolean started = false;
private int capacity = ProtocolConstants.MAX_PACKET_LENGTH;
OfferWriterImpl(OutputStream out, SerialComponent serial,
WriterFactory writerFactory) {
this.out = out;
headerLength = serial.getSerialisedStructIdLength(Types.OFFER)
+ serial.getSerialisedListStartLength();
idLength = serial.getSerialisedUniqueIdLength(Types.MESSAGE_ID);
footerLength = serial.getSerialisedListEndLength();
w = writerFactory.createWriter(out);
}
public void setMaxPacketLength(int length) {
if(started) throw new IllegalStateException();
if(length < 0 || length > ProtocolConstants.MAX_PACKET_LENGTH)
throw new IllegalArgumentException();
capacity = length;
}
public boolean writeMessageId(MessageId m) throws IOException {
int overhead = started ? footerLength : headerLength + footerLength;
if(capacity < idLength + overhead) return false;
if(!started) start();
w.writeStructId(Types.MESSAGE_ID);
w.writeBytes(m.getBytes());
capacity -= idLength;
return true;
}
public void finish() throws IOException {
if(!started) start();
w.writeListEnd();
out.flush();
capacity = ProtocolConstants.MAX_PACKET_LENGTH;
started = false;
}
private void start() throws IOException {
w.writeStructId(Types.OFFER);
w.writeListStart();
capacity -= headerLength;
started = true;
}
}

View File

@@ -1,57 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.OutputStream;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.WriterFactory;
import com.google.inject.Inject;
class ProtocolWriterFactoryImpl implements ProtocolWriterFactory {
private final MessageDigest messageDigest;
private final SerialComponent serial;
private final WriterFactory writerFactory;
@Inject
ProtocolWriterFactoryImpl(CryptoComponent crypto,
SerialComponent serial, WriterFactory writerFactory) {
messageDigest = crypto.getMessageDigest();
this.serial = serial;
this.writerFactory = writerFactory;
}
public AckWriter createAckWriter(OutputStream out) {
return new AckWriterImpl(out, serial, writerFactory);
}
public BatchWriter createBatchWriter(OutputStream out) {
return new BatchWriterImpl(out, serial, writerFactory, messageDigest);
}
public OfferWriter createOfferWriter(OutputStream out) {
return new OfferWriterImpl(out, serial, writerFactory);
}
public RequestWriter createRequestWriter(OutputStream out) {
return new RequestWriterImpl(out, writerFactory);
}
public SubscriptionUpdateWriter createSubscriptionUpdateWriter(
OutputStream out) {
return new SubscriptionUpdateWriterImpl(out, writerFactory);
}
public TransportUpdateWriter createTransportUpdateWriter(OutputStream out) {
return new TransportUpdateWriterImpl(out, writerFactory);
}
}

View File

@@ -1,13 +0,0 @@
package net.sf.briar.protocol.writers;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import com.google.inject.AbstractModule;
public class ProtocolWritersModule extends AbstractModule {
@Override
protected void configure() {
bind(ProtocolWriterFactory.class).to(ProtocolWriterFactoryImpl.class);
}
}

View File

@@ -1,39 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import java.util.BitSet;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class RequestWriterImpl implements RequestWriter {
private final OutputStream out;
private final Writer w;
RequestWriterImpl(OutputStream out, WriterFactory writerFactory) {
this.out = out;
w = writerFactory.createWriter(out);
}
public void writeRequest(BitSet b, int length)
throws IOException {
w.writeStructId(Types.REQUEST);
// If the number of bits isn't a multiple of 8, round up to a byte
int bytes = length % 8 == 0 ? length / 8 : length / 8 + 1;
byte[] bitmap = new byte[bytes];
// I'm kind of surprised BitSet doesn't have a method for this
for(int i = 0; i < length; i++) {
if(b.get(i)) {
int offset = i / 8;
byte bit = (byte) (128 >> i % 8);
bitmap[offset] |= bit;
}
}
w.writeBytes(bitmap);
out.flush();
}
}

View File

@@ -1,45 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Map;
import java.util.Map.Entry;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class SubscriptionUpdateWriterImpl implements SubscriptionUpdateWriter {
private final OutputStream out;
private final Writer w;
SubscriptionUpdateWriterImpl(OutputStream out,
WriterFactory writerFactory) {
this.out = out;
w = writerFactory.createWriter(out);
}
public void writeSubscriptions(Map<Group, Long> subs, long timestamp)
throws IOException {
w.writeStructId(Types.SUBSCRIPTION_UPDATE);
w.writeMapStart();
for(Entry<Group, Long> e : subs.entrySet()) {
writeGroup(w, e.getKey());
w.writeInt64(e.getValue());
}
w.writeMapEnd();
w.writeInt64(timestamp);
out.flush();
}
private void writeGroup(Writer w, Group g) throws IOException {
w.writeStructId(Types.GROUP);
w.writeString(g.getName());
byte[] publicKey = g.getPublicKey();
if(publicKey == null) w.writeNull();
else w.writeBytes(publicKey);
}
}

View File

@@ -1,37 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
class TransportUpdateWriterImpl implements TransportUpdateWriter {
private final OutputStream out;
private final Writer w;
TransportUpdateWriterImpl(OutputStream out, WriterFactory writerFactory) {
this.out = out;
w = writerFactory.createWriter(out);
}
public void writeTransports(Collection<Transport> transports,
long timestamp) throws IOException {
w.writeStructId(Types.TRANSPORT_UPDATE);
w.writeListStart();
for(Transport p : transports) {
w.writeStructId(Types.TRANSPORT);
w.writeBytes(p.getId().getBytes());
w.writeInt32(p.getIndex().getInt());
w.writeMap(p);
}
w.writeListEnd();
w.writeInt64(timestamp);
out.flush();
}
}

View File

@@ -15,6 +15,11 @@ class SerialComponentImpl implements SerialComponent {
return 1; return 1;
} }
public int getSerialisedStructIdLength(int id) {
if(id < 0 || id > 255) throw new IllegalArgumentException();
return id < 32 ? 1 : 2;
}
public int getSerialisedUniqueIdLength(int id) { public int getSerialisedUniqueIdLength(int id) {
// Struct ID, BYTES tag, length spec, bytes // Struct ID, BYTES tag, length spec, bytes
return getSerialisedStructIdLength(id) + 1 return getSerialisedStructIdLength(id) + 1
@@ -22,14 +27,9 @@ class SerialComponentImpl implements SerialComponent {
} }
private int getSerialisedLengthSpecLength(int length) { private int getSerialisedLengthSpecLength(int length) {
assert length >= 0; if(length < 0) throw new IllegalArgumentException();
if(length < 128) return 1; // Uint7 if(length < 128) return 1; // Uint7
if(length < Short.MAX_VALUE) return 3; // Int16 if(length < Short.MAX_VALUE) return 3; // Int16
return 5; // Int32 return 5; // Int32
} }
public int getSerialisedStructIdLength(int id) {
assert id >= 0 && id <= 255;
return id < 32 ? 1 : 2;
}
} }

View File

@@ -5,8 +5,8 @@ import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.transport.BatchConnectionFactory; import net.sf.briar.api.transport.BatchConnectionFactory;
import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportReader;
import net.sf.briar.api.transport.BatchTransportWriter; import net.sf.briar.api.transport.BatchTransportWriter;
@@ -19,22 +19,22 @@ import com.google.inject.Inject;
class BatchConnectionFactoryImpl implements BatchConnectionFactory { class BatchConnectionFactoryImpl implements BatchConnectionFactory {
private final Executor executor; private final Executor executor;
private final DatabaseComponent db;
private final ConnectionReaderFactory connReaderFactory; private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory; private final ConnectionWriterFactory connWriterFactory;
private final DatabaseComponent db;
private final ProtocolReaderFactory protoReaderFactory; private final ProtocolReaderFactory protoReaderFactory;
private final ProtocolWriterFactory protoWriterFactory; private final ProtocolWriterFactory protoWriterFactory;
@Inject @Inject
BatchConnectionFactoryImpl(Executor executor, BatchConnectionFactoryImpl(Executor executor, DatabaseComponent db,
ConnectionReaderFactory connReaderFactory, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory) { ProtocolWriterFactory protoWriterFactory) {
this.executor = executor; this.executor = executor;
this.db = db;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
} }
@@ -42,7 +42,7 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
public void createIncomingConnection(ConnectionContext ctx, public void createIncomingConnection(ConnectionContext ctx,
BatchTransportReader r, byte[] tag) { BatchTransportReader r, byte[] tag) {
final IncomingBatchConnection conn = new IncomingBatchConnection( final IncomingBatchConnection conn = new IncomingBatchConnection(
executor, connReaderFactory, db, protoReaderFactory, ctx, r, executor, db, connReaderFactory, protoReaderFactory, ctx, r,
tag); tag);
Runnable read = new Runnable() { Runnable read = new Runnable() {
public void run() { public void run() {
@@ -54,8 +54,8 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
public void createOutgoingConnection(ContactId c, TransportIndex i, public void createOutgoingConnection(ContactId c, TransportIndex i,
BatchTransportWriter w) { BatchTransportWriter w) {
final OutgoingBatchConnection conn = new OutgoingBatchConnection( final OutgoingBatchConnection conn = new OutgoingBatchConnection(db,
connWriterFactory, db, protoWriterFactory, c, i, w); connWriterFactory, protoWriterFactory, c, i, w);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();

View File

@@ -39,8 +39,8 @@ class IncomingBatchConnection {
private final Semaphore semaphore; private final Semaphore semaphore;
IncomingBatchConnection(Executor executor, IncomingBatchConnection(Executor executor,
ConnectionReaderFactory connFactory, DatabaseComponent db,
DatabaseComponent db, ProtocolReaderFactory protoFactory, ConnectionReaderFactory connFactory, ProtocolReaderFactory protoFactory,
ConnectionContext ctx, BatchTransportReader reader, byte[] tag) { ConnectionContext ctx, BatchTransportReader reader, byte[] tag) {
this.executor = executor; this.executor = executor;
this.connFactory = connFactory; this.connFactory = connFactory;

View File

@@ -10,12 +10,13 @@ import java.util.logging.Logger;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.BatchTransportWriter; import net.sf.briar.api.transport.BatchTransportWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
@@ -26,23 +27,23 @@ class OutgoingBatchConnection {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(OutgoingBatchConnection.class.getName()); Logger.getLogger(OutgoingBatchConnection.class.getName());
private final ConnectionWriterFactory connFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
private final ConnectionWriterFactory connFactory;
private final ProtocolWriterFactory protoFactory; private final ProtocolWriterFactory protoFactory;
private final ContactId contactId; private final ContactId contactId;
private final TransportIndex transportIndex; private final TransportIndex transportIndex;
private final BatchTransportWriter writer; private final BatchTransportWriter transport;
OutgoingBatchConnection(ConnectionWriterFactory connFactory, OutgoingBatchConnection(DatabaseComponent db,
DatabaseComponent db, ProtocolWriterFactory protoFactory, ConnectionWriterFactory connFactory,
ContactId contactId, TransportIndex transportIndex, ProtocolWriterFactory protoFactory, ContactId contactId,
BatchTransportWriter writer) { TransportIndex transportIndex, BatchTransportWriter transport) {
this.connFactory = connFactory;
this.db = db; this.db = db;
this.connFactory = connFactory;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.contactId = contactId; this.contactId = contactId;
this.transportIndex = transportIndex; this.transportIndex = transportIndex;
this.writer = writer; this.transport = transport;
} }
void write() { void write() {
@@ -50,45 +51,52 @@ class OutgoingBatchConnection {
ConnectionContext ctx = db.getConnectionContext(contactId, ConnectionContext ctx = db.getConnectionContext(contactId,
transportIndex); transportIndex);
ConnectionWriter conn = connFactory.createConnectionWriter( ConnectionWriter conn = connFactory.createConnectionWriter(
writer.getOutputStream(), writer.getCapacity(), transport.getOutputStream(), transport.getCapacity(),
ctx.getSecret()); ctx.getSecret());
OutputStream out = conn.getOutputStream(); OutputStream out = conn.getOutputStream();
ProtocolWriter proto = protoFactory.createProtocolWriter(out);
// There should be enough space for a packet // There should be enough space for a packet
long capacity = conn.getRemainingCapacity(); long capacity = conn.getRemainingCapacity();
if(capacity < MAX_PACKET_LENGTH) throw new IOException(); if(capacity < MAX_PACKET_LENGTH) throw new IOException();
// Write a transport update // Write a transport update
TransportUpdateWriter t = TransportUpdate t = db.generateTransportUpdate(contactId);
protoFactory.createTransportUpdateWriter(out); if(t != null) proto.writeTransportUpdate(t);
db.generateTransportUpdate(contactId, t);
// If there's space, write a subscription update // If there's space, write a subscription update
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
if(capacity >= MAX_PACKET_LENGTH) { if(capacity >= MAX_PACKET_LENGTH) {
SubscriptionUpdateWriter s = SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId);
protoFactory.createSubscriptionUpdateWriter(out); if(s != null) proto.writeSubscriptionUpdate(s);
db.generateSubscriptionUpdate(contactId, s);
} }
// Write acks until you can't write acks no more // Write acks until you can't write acks no more
AckWriter a = protoFactory.createAckWriter(out); capacity = conn.getRemainingCapacity();
do { int maxBatches = proto.getMaxBatchesForAck(capacity);
Ack a = db.generateAck(contactId, maxBatches);
while(a != null) {
proto.writeAck(a);
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); maxBatches = proto.getMaxBatchesForAck(capacity);
a.setMaxPacketLength(max); a = db.generateAck(contactId, maxBatches);
} while(db.generateAck(contactId, a)); }
// Write batches until you can't write batches no more // Write batches until you can't write batches no more
BatchWriter b = protoFactory.createBatchWriter(out); capacity = conn.getRemainingCapacity();
do { capacity = proto.getMessageCapacityForBatch(capacity);
RawBatch b = db.generateBatch(contactId, (int) capacity);
while(b != null) {
proto.writeBatch(b);
capacity = conn.getRemainingCapacity(); capacity = conn.getRemainingCapacity();
int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); capacity = proto.getMessageCapacityForBatch(capacity);
b.setMaxPacketLength(max); b = db.generateBatch(contactId, (int) capacity);
} while(db.generateBatch(contactId, b)); }
// Flush the output stream
out.flush();
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
writer.dispose(false); transport.dispose(false);
} catch(IOException e) { } catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
writer.dispose(false); transport.dispose(false);
} }
// Success // Success
writer.dispose(true); transport.dispose(true);
} }
} }

View File

@@ -6,7 +6,8 @@ import java.util.concurrent.Executor;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
@@ -19,14 +20,14 @@ class IncomingStreamConnection extends StreamConnection {
private final ConnectionContext ctx; private final ConnectionContext ctx;
private final byte[] tag; private final byte[] tag;
IncomingStreamConnection(Executor executor, IncomingStreamConnection(Executor executor, DatabaseComponent db,
ConnectionReaderFactory connReaderFactory, SerialComponent serial, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ProtocolWriterFactory protoWriterFactory,
ConnectionContext ctx, StreamTransportConnection connection, ConnectionContext ctx, StreamTransportConnection connection,
byte[] tag) { byte[] tag) {
super(executor, connReaderFactory, connWriterFactory, db, super(executor, db, serial, connReaderFactory, connWriterFactory,
protoReaderFactory, protoWriterFactory, ctx.getContactId(), protoReaderFactory, protoWriterFactory, ctx.getContactId(),
connection); connection);
this.ctx = ctx; this.ctx = ctx;

View File

@@ -7,8 +7,9 @@ import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
@@ -22,14 +23,14 @@ class OutgoingStreamConnection extends StreamConnection {
private ConnectionContext ctx = null; // Locking: this private ConnectionContext ctx = null; // Locking: this
OutgoingStreamConnection(Executor executor, OutgoingStreamConnection(Executor executor, DatabaseComponent db,
ConnectionReaderFactory connReaderFactory, SerialComponent serial, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
TransportIndex transportIndex, TransportIndex transportIndex,
StreamTransportConnection connection) { StreamTransportConnection connection) {
super(executor, connReaderFactory, connWriterFactory, db, super(executor, db, serial, connReaderFactory, connWriterFactory,
protoReaderFactory, protoWriterFactory, contactId, connection); protoReaderFactory, protoWriterFactory, contactId, connection);
this.transportIndex = transportIndex; this.transportIndex = transportIndex;
} }

View File

@@ -31,17 +31,14 @@ import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
@@ -58,9 +55,10 @@ abstract class StreamConnection implements DatabaseListener {
Logger.getLogger(StreamConnection.class.getName()); Logger.getLogger(StreamConnection.class.getName());
protected final Executor executor; protected final Executor executor;
protected final DatabaseComponent db;
protected final SerialComponent serial;
protected final ConnectionReaderFactory connReaderFactory; protected final ConnectionReaderFactory connReaderFactory;
protected final ConnectionWriterFactory connWriterFactory; protected final ConnectionWriterFactory connWriterFactory;
protected final DatabaseComponent db;
protected final ProtocolReaderFactory protoReaderFactory; protected final ProtocolReaderFactory protoReaderFactory;
protected final ProtocolWriterFactory protoWriterFactory; protected final ProtocolWriterFactory protoWriterFactory;
protected final ContactId contactId; protected final ContactId contactId;
@@ -73,16 +71,17 @@ abstract class StreamConnection implements DatabaseListener {
private LinkedList<MessageId> requested = null; // Locking: this private LinkedList<MessageId> requested = null; // Locking: this
private Offer incomingOffer = null; // Locking: this private Offer incomingOffer = null; // Locking: this
StreamConnection(Executor executor, StreamConnection(Executor executor, DatabaseComponent db,
ConnectionReaderFactory connReaderFactory, SerialComponent serial, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
StreamTransportConnection connection) { StreamTransportConnection connection) {
this.executor = executor; this.executor = executor;
this.db = db;
this.serial = serial;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
this.contactId = contactId; this.contactId = contactId;
@@ -267,20 +266,11 @@ abstract class StreamConnection implements DatabaseListener {
void write() { void write() {
try { try {
OutputStream out = createConnectionWriter().getOutputStream(); OutputStream out = createConnectionWriter().getOutputStream();
// Create the packet writers ProtocolWriter proto = protoWriterFactory.createProtocolWriter(out);
AckWriter ackWriter = protoWriterFactory.createAckWriter(out);
BatchWriter batchWriter = protoWriterFactory.createBatchWriter(out);
OfferWriter offerWriter = protoWriterFactory.createOfferWriter(out);
RequestWriter requestWriter =
protoWriterFactory.createRequestWriter(out);
SubscriptionUpdateWriter subscriptionUpdateWriter =
protoWriterFactory.createSubscriptionUpdateWriter(out);
TransportUpdateWriter transportUpdateWriter =
protoWriterFactory.createTransportUpdateWriter(out);
// Send the initial packets: transports, subs, any waiting acks // Send the initial packets: transports, subs, any waiting acks
sendTransportUpdate(transportUpdateWriter); sendTransportUpdate(proto);
sendSubscriptionUpdate(subscriptionUpdateWriter); sendSubscriptionUpdate(proto);
sendAcks(ackWriter); sendAcks(proto);
State state = State.SEND_OFFER; State state = State.SEND_OFFER;
// Main loop // Main loop
while(true) { while(true) {
@@ -289,7 +279,7 @@ abstract class StreamConnection implements DatabaseListener {
case SEND_OFFER: case SEND_OFFER:
// Try to send an offer // Try to send an offer
if(sendOffer(offerWriter)) state = State.AWAIT_REQUEST; if(sendOffer(proto)) state = State.AWAIT_REQUEST;
else state = State.IDLE; else state = State.IDLE;
break; break;
@@ -312,16 +302,16 @@ abstract class StreamConnection implements DatabaseListener {
return; return;
} }
if((flags & Flags.TRANSPORTS_UPDATED) != 0) { if((flags & Flags.TRANSPORTS_UPDATED) != 0) {
sendTransportUpdate(transportUpdateWriter); sendTransportUpdate(proto);
} }
if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) { if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) {
sendSubscriptionUpdate(subscriptionUpdateWriter); sendSubscriptionUpdate(proto);
} }
if((flags & Flags.BATCH_RECEIVED) != 0) { if((flags & Flags.BATCH_RECEIVED) != 0) {
sendAcks(ackWriter); sendAcks(proto);
} }
if((flags & Flags.OFFER_RECEIVED) != 0) { if((flags & Flags.OFFER_RECEIVED) != 0) {
sendRequest(requestWriter); sendRequest(proto);
} }
if((flags & Flags.REQUEST_RECEIVED) != 0) { if((flags & Flags.REQUEST_RECEIVED) != 0) {
// Should only be received in state AWAIT_REQUEST // Should only be received in state AWAIT_REQUEST
@@ -351,16 +341,16 @@ abstract class StreamConnection implements DatabaseListener {
return; return;
} }
if((flags & Flags.TRANSPORTS_UPDATED) != 0) { if((flags & Flags.TRANSPORTS_UPDATED) != 0) {
sendTransportUpdate(transportUpdateWriter); sendTransportUpdate(proto);
} }
if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) { if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) {
sendSubscriptionUpdate(subscriptionUpdateWriter); sendSubscriptionUpdate(proto);
} }
if((flags & Flags.BATCH_RECEIVED) != 0) { if((flags & Flags.BATCH_RECEIVED) != 0) {
sendAcks(ackWriter); sendAcks(proto);
} }
if((flags & Flags.OFFER_RECEIVED) != 0) { if((flags & Flags.OFFER_RECEIVED) != 0) {
sendRequest(requestWriter); sendRequest(proto);
} }
if((flags & Flags.REQUEST_RECEIVED) != 0) { if((flags & Flags.REQUEST_RECEIVED) != 0) {
state = State.SEND_BATCHES; state = State.SEND_BATCHES;
@@ -382,16 +372,16 @@ abstract class StreamConnection implements DatabaseListener {
return; return;
} }
if((flags & Flags.TRANSPORTS_UPDATED) != 0) { if((flags & Flags.TRANSPORTS_UPDATED) != 0) {
sendTransportUpdate(transportUpdateWriter); sendTransportUpdate(proto);
} }
if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) { if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) {
sendSubscriptionUpdate(subscriptionUpdateWriter); sendSubscriptionUpdate(proto);
} }
if((flags & Flags.BATCH_RECEIVED) != 0) { if((flags & Flags.BATCH_RECEIVED) != 0) {
sendAcks(ackWriter); sendAcks(proto);
} }
if((flags & Flags.OFFER_RECEIVED) != 0) { if((flags & Flags.OFFER_RECEIVED) != 0) {
sendRequest(requestWriter); sendRequest(proto);
} }
if((flags & Flags.REQUEST_RECEIVED) != 0) { if((flags & Flags.REQUEST_RECEIVED) != 0) {
// Should only be received in state AWAIT_REQUEST // Should only be received in state AWAIT_REQUEST
@@ -401,7 +391,7 @@ abstract class StreamConnection implements DatabaseListener {
// Ignored in this state // Ignored in this state
} }
// Try to send a batch // Try to send a batch
if(!sendBatch(batchWriter)) state = State.SEND_OFFER; if(!sendBatch(proto)) state = State.SEND_OFFER;
break; break;
} }
} }
@@ -416,11 +406,18 @@ abstract class StreamConnection implements DatabaseListener {
connection.dispose(true); connection.dispose(true);
} }
private void sendAcks(AckWriter a) throws DbException, IOException { private void sendAcks(ProtocolWriter proto)
while(db.generateAck(contactId, a)); throws DbException, IOException {
int maxBatches = proto.getMaxBatchesForAck(Long.MAX_VALUE);
Ack a = db.generateAck(contactId, maxBatches);
while(a != null) {
proto.writeAck(a);
a = db.generateAck(contactId, maxBatches);
}
} }
private boolean sendBatch(BatchWriter b) throws DbException, IOException { private boolean sendBatch(ProtocolWriter proto)
throws DbException, IOException {
Collection<MessageId> req; Collection<MessageId> req;
// Retrieve the requested message IDs // Retrieve the requested message IDs
synchronized(this) { synchronized(this) {
@@ -429,31 +426,40 @@ abstract class StreamConnection implements DatabaseListener {
req = requested; req = requested;
} }
// Try to generate a batch, updating the collection of message IDs // Try to generate a batch, updating the collection of message IDs
boolean anyAdded = db.generateBatch(contactId, b, req); int capacity = proto.getMessageCapacityForBatch(Long.MAX_VALUE);
// If no more batches can be generated, discard the remaining IDs RawBatch b = db.generateBatch(contactId, capacity, req);
if(!anyAdded) { if(b == null) {
// No more batches can be generated - discard the remaining IDs
synchronized(this) { synchronized(this) {
assert offered == null; assert offered == null;
assert requested == req; assert requested == req;
requested = null; requested = null;
} }
return false;
} else {
proto.writeBatch(b);
return true;
} }
return anyAdded;
} }
private boolean sendOffer(OfferWriter o) throws DbException, IOException { private boolean sendOffer(ProtocolWriter proto)
throws DbException, IOException {
// Generate an offer // Generate an offer
Collection<MessageId> off = db.generateOffer(contactId, o); int maxMessages = proto.getMaxMessagesForOffer(Long.MAX_VALUE);
Offer o = db.generateOffer(contactId, maxMessages);
if(o == null) return false;
proto.writeOffer(o);
// Store the offered message IDs // Store the offered message IDs
synchronized(this) { synchronized(this) {
assert offered == null; assert offered == null;
assert requested == null; assert requested == null;
offered = off; offered = o.getMessageIds();
} }
return !off.isEmpty(); return true;
} }
private void sendRequest(RequestWriter r) throws DbException, IOException { private void sendRequest(ProtocolWriter proto)
throws DbException, IOException {
Offer o; Offer o;
// Retrieve the incoming offer // Retrieve the incoming offer
synchronized(this) { synchronized(this) {
@@ -462,16 +468,19 @@ abstract class StreamConnection implements DatabaseListener {
incomingOffer = null; incomingOffer = null;
} }
// Process the offer and generate a request // Process the offer and generate a request
db.receiveOffer(contactId, o, r); Request r = db.receiveOffer(contactId, o);
proto.writeRequest(r);
} }
private void sendTransportUpdate(TransportUpdateWriter t) private void sendTransportUpdate(ProtocolWriter proto)
throws DbException, IOException { throws DbException, IOException {
db.generateTransportUpdate(contactId, t); TransportUpdate t = db.generateTransportUpdate(contactId);
if(t != null) proto.writeTransportUpdate(t);
} }
private void sendSubscriptionUpdate(SubscriptionUpdateWriter s) private void sendSubscriptionUpdate(ProtocolWriter proto)
throws DbException, IOException { throws DbException, IOException {
db.generateSubscriptionUpdate(contactId, s); SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId);
if(s != null) proto.writeSubscriptionUpdate(s);
} }
} }

View File

@@ -5,8 +5,9 @@ import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
@@ -18,31 +19,33 @@ import com.google.inject.Inject;
class StreamConnectionFactoryImpl implements StreamConnectionFactory { class StreamConnectionFactoryImpl implements StreamConnectionFactory {
private final Executor executor; private final Executor executor;
private final DatabaseComponent db;
private final SerialComponent serial;
private final ConnectionReaderFactory connReaderFactory; private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory; private final ConnectionWriterFactory connWriterFactory;
private final DatabaseComponent db;
private final ProtocolReaderFactory protoReaderFactory; private final ProtocolReaderFactory protoReaderFactory;
private final ProtocolWriterFactory protoWriterFactory; private final ProtocolWriterFactory protoWriterFactory;
@Inject @Inject
StreamConnectionFactoryImpl(Executor executor, StreamConnectionFactoryImpl(Executor executor, DatabaseComponent db,
ConnectionReaderFactory connReaderFactory, SerialComponent serial, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory) { ProtocolWriterFactory protoWriterFactory) {
this.executor = executor; this.executor = executor;
this.db = db;
this.serial = serial;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
} }
public void createIncomingConnection(ConnectionContext ctx, public void createIncomingConnection(ConnectionContext ctx,
StreamTransportConnection s, byte[] tag) { StreamTransportConnection s, byte[] tag) {
final StreamConnection conn = new IncomingStreamConnection(executor, final StreamConnection conn = new IncomingStreamConnection(executor, db,
connReaderFactory, connWriterFactory, db, protoReaderFactory, serial, connReaderFactory, connWriterFactory,
protoWriterFactory, ctx, s, tag); protoReaderFactory, protoWriterFactory, ctx, s, tag);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();
@@ -59,9 +62,9 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory {
public void createOutgoingConnection(ContactId c, TransportIndex i, public void createOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s) { StreamTransportConnection s) {
final StreamConnection conn = new OutgoingStreamConnection(executor, final StreamConnection conn = new OutgoingStreamConnection(executor, db,
connReaderFactory, connWriterFactory, db, protoReaderFactory, serial, connReaderFactory, connWriterFactory,
protoWriterFactory, c, i, s); protoReaderFactory, protoWriterFactory, c, i, s);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();

View File

@@ -37,11 +37,11 @@
<test name='net.sf.briar.plugins.socket.SimpleSocketPluginTest'/> <test name='net.sf.briar.plugins.socket.SimpleSocketPluginTest'/>
<test name='net.sf.briar.protocol.AckReaderTest'/> <test name='net.sf.briar.protocol.AckReaderTest'/>
<test name='net.sf.briar.protocol.BatchReaderTest'/> <test name='net.sf.briar.protocol.BatchReaderTest'/>
<test name='net.sf.briar.protocol.ConstantsTest'/>
<test name='net.sf.briar.protocol.ConsumersTest'/> <test name='net.sf.briar.protocol.ConsumersTest'/>
<test name='net.sf.briar.protocol.ProtocolReadWriteTest'/> <test name='net.sf.briar.protocol.ProtocolReadWriteTest'/>
<test name='net.sf.briar.protocol.ProtocolWriterImplTest'/>
<test name='net.sf.briar.protocol.RequestReaderTest'/> <test name='net.sf.briar.protocol.RequestReaderTest'/>
<test name='net.sf.briar.protocol.writers.ConstantsTest'/>
<test name='net.sf.briar.protocol.writers.RequestWriterImplTest'/>
<test name='net.sf.briar.serial.ReaderImplTest'/> <test name='net.sf.briar.serial.ReaderImplTest'/>
<test name='net.sf.briar.serial.WriterImplTest'/> <test name='net.sf.briar.serial.WriterImplTest'/>
<test name='net.sf.briar.setup.SetupWorkerTest'/> <test name='net.sf.briar.setup.SetupWorkerTest'/>

View File

@@ -8,6 +8,7 @@ import java.io.ByteArrayOutputStream;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.KeyPair; import java.security.KeyPair;
import java.util.Arrays;
import java.util.BitSet; import java.util.BitSet;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@@ -16,7 +17,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.Executors;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
@@ -31,21 +32,18 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
@@ -54,7 +52,6 @@ import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule; import net.sf.briar.db.DatabaseModule;
import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.lifecycle.LifecycleModule;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import net.sf.briar.transport.TransportModule; import net.sf.briar.transport.TransportModule;
import net.sf.briar.transport.batch.TransportBatchModule; import net.sf.briar.transport.batch.TransportBatchModule;
@@ -76,6 +73,7 @@ public class ProtocolIntegrationTest extends TestCase {
private final ConnectionWriterFactory connectionWriterFactory; private final ConnectionWriterFactory connectionWriterFactory;
private final ProtocolReaderFactory protocolReaderFactory; private final ProtocolReaderFactory protocolReaderFactory;
private final ProtocolWriterFactory protocolWriterFactory; private final ProtocolWriterFactory protocolWriterFactory;
private final PacketFactory packetFactory;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final byte[] secret; private final byte[] secret;
private final TransportIndex transportIndex = new TransportIndex(13); private final TransportIndex transportIndex = new TransportIndex(13);
@@ -93,19 +91,19 @@ public class ProtocolIntegrationTest extends TestCase {
@Override @Override
public void configure() { public void configure() {
bind(Executor.class).toInstance( bind(Executor.class).toInstance(
new ScheduledThreadPoolExecutor(5)); Executors.newCachedThreadPool());
} }
}; };
Injector i = Guice.createInjector(testModule, new CryptoModule(), Injector i = Guice.createInjector(testModule, new CryptoModule(),
new DatabaseModule(), new LifecycleModule(), new DatabaseModule(), new LifecycleModule(),
new ProtocolModule(), new ProtocolWritersModule(), new ProtocolModule(), new SerialModule(),
new SerialModule(), new TestDatabaseModule(), new TestDatabaseModule(), new TransportBatchModule(),
new TransportBatchModule(), new TransportModule(), new TransportModule(), new TransportStreamModule());
new TransportStreamModule());
connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class); connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class);
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class); protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class);
protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class);
packetFactory = i.getInstance(PacketFactory.class);
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
// Create a shared secret // Create a shared secret
Random r = new Random(); Random r = new Random();
@@ -149,47 +147,51 @@ public class ProtocolIntegrationTest extends TestCase {
private byte[] write() throws Exception { private byte[] write() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter conn = connectionWriterFactory.createConnectionWriter(
Long.MAX_VALUE, secret.clone()); out, Long.MAX_VALUE, secret.clone());
OutputStream out1 = w.getOutputStream(); OutputStream out1 = conn.getOutputStream();
ProtocolWriter proto = protocolWriterFactory.createProtocolWriter(out1);
AckWriter a = protocolWriterFactory.createAckWriter(out1); Ack a = packetFactory.createAck(Collections.singletonList(ack));
assertTrue(a.writeBatchId(ack)); proto.writeAck(a);
a.finish();
BatchWriter b = protocolWriterFactory.createBatchWriter(out1); Collection<byte[]> batch = Arrays.asList(new byte[][] {
assertTrue(b.writeMessage(message.getSerialised())); message.getSerialised(),
assertTrue(b.writeMessage(message1.getSerialised())); message1.getSerialised(),
assertTrue(b.writeMessage(message2.getSerialised())); message2.getSerialised(),
assertTrue(b.writeMessage(message3.getSerialised())); message3.getSerialised()
b.finish(); });
RawBatch b = packetFactory.createBatch(batch);
proto.writeBatch(b);
OfferWriter o = protocolWriterFactory.createOfferWriter(out1); Collection<MessageId> offer = Arrays.asList(new MessageId[] {
assertTrue(o.writeMessageId(message.getId())); message.getId(),
assertTrue(o.writeMessageId(message1.getId())); message1.getId(),
assertTrue(o.writeMessageId(message2.getId())); message2.getId(),
assertTrue(o.writeMessageId(message3.getId())); message3.getId()
o.finish(); });
Offer o = packetFactory.createOffer(offer);
proto.writeOffer(o);
RequestWriter r = protocolWriterFactory.createRequestWriter(out1);
BitSet requested = new BitSet(4); BitSet requested = new BitSet(4);
requested.set(1); requested.set(1);
requested.set(3); requested.set(3);
r.writeRequest(requested, 4); Request r = packetFactory.createRequest(requested, 4);
proto.writeRequest(r);
SubscriptionUpdateWriter s =
protocolWriterFactory.createSubscriptionUpdateWriter(out1);
// Use a LinkedHashMap for predictable iteration order // Use a LinkedHashMap for predictable iteration order
Map<Group, Long> subs = new LinkedHashMap<Group, Long>(); Map<Group, Long> subs = new LinkedHashMap<Group, Long>();
subs.put(group, 0L); subs.put(group, 0L);
subs.put(group1, 0L); subs.put(group1, 0L);
s.writeSubscriptions(subs, timestamp); SubscriptionUpdate s = packetFactory.createSubscriptionUpdate(subs,
timestamp);
proto.writeSubscriptionUpdate(s);
TransportUpdateWriter t = TransportUpdate t = packetFactory.createTransportUpdate(transports,
protocolWriterFactory.createTransportUpdateWriter(out1); timestamp);
t.writeTransports(transports, timestamp); proto.writeTransportUpdate(t);
out1.close(); out1.flush();
return out.toByteArray(); return out.toByteArray();
} }

View File

@@ -8,6 +8,7 @@ import java.util.Collections;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.lifecycle.ShutdownManager; import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.db.DatabaseCleaner.Callback; import net.sf.briar.db.DatabaseCleaner.Callback;
import org.jmock.Expectations; import org.jmock.Expectations;
@@ -27,11 +28,13 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(database).getFreeSpace(); oneOf(database).getFreeSpace();
will(returnValue(MIN_FREE_SPACE)); will(returnValue(MIN_FREE_SPACE));
}}); }});
Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); Callback db = createDatabaseComponentImpl(database, cleaner, shutdown,
packetFactory);
db.checkFreeSpaceAndClean(); db.checkFreeSpaceAndClean();
@@ -45,6 +48,7 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(database).getFreeSpace(); oneOf(database).getFreeSpace();
will(returnValue(MIN_FREE_SPACE - 1)); will(returnValue(MIN_FREE_SPACE - 1));
@@ -57,7 +61,8 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
oneOf(database).getFreeSpace(); oneOf(database).getFreeSpace();
will(returnValue(MIN_FREE_SPACE)); will(returnValue(MIN_FREE_SPACE));
}}); }});
Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); Callback db = createDatabaseComponentImpl(database, cleaner, shutdown,
packetFactory);
db.checkFreeSpaceAndClean(); db.checkFreeSpaceAndClean();
@@ -72,6 +77,7 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(database).getFreeSpace(); oneOf(database).getFreeSpace();
will(returnValue(MIN_FREE_SPACE - 1)); will(returnValue(MIN_FREE_SPACE - 1));
@@ -86,7 +92,8 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
oneOf(database).getFreeSpace(); oneOf(database).getFreeSpace();
will(returnValue(MIN_FREE_SPACE)); will(returnValue(MIN_FREE_SPACE));
}}); }});
Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); Callback db = createDatabaseComponentImpl(database, cleaner, shutdown,
packetFactory);
db.checkFreeSpaceAndClean(); db.checkFreeSpaceAndClean();
@@ -101,6 +108,7 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(database).getFreeSpace(); oneOf(database).getFreeSpace();
will(returnValue(MIN_FREE_SPACE - 1)); will(returnValue(MIN_FREE_SPACE - 1));
@@ -117,7 +125,8 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
oneOf(database).getFreeSpace(); oneOf(database).getFreeSpace();
will(returnValue(MIN_FREE_SPACE)); will(returnValue(MIN_FREE_SPACE));
}}); }});
Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); Callback db = createDatabaseComponentImpl(database, cleaner, shutdown,
packetFactory);
db.checkFreeSpaceAndClean(); db.checkFreeSpaceAndClean();
@@ -127,13 +136,15 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest {
@Override @Override
protected <T> DatabaseComponent createDatabaseComponent( protected <T> DatabaseComponent createDatabaseComponent(
Database<T> database, DatabaseCleaner cleaner, Database<T> database, DatabaseCleaner cleaner,
ShutdownManager shutdown) { ShutdownManager shutdown, PacketFactory packetFactory) {
return createDatabaseComponentImpl(database, cleaner, shutdown); return createDatabaseComponentImpl(database, cleaner, shutdown,
packetFactory);
} }
private <T> DatabaseComponentImpl<T> createDatabaseComponentImpl( private <T> DatabaseComponentImpl<T> createDatabaseComponentImpl(
Database<T> database, DatabaseCleaner cleaner, Database<T> database, DatabaseCleaner cleaner,
ShutdownManager shutdown) { ShutdownManager shutdown, PacketFactory packetFactory) {
return new DatabaseComponentImpl<T>(database, cleaner, shutdown); return new DatabaseComponentImpl<T>(database, cleaner, shutdown,
packetFactory);
} }
} }

View File

@@ -1,6 +1,7 @@
package net.sf.briar.db; package net.sf.briar.db;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet; import java.util.BitSet;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@@ -32,18 +33,14 @@ import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import org.jmock.Expectations; import org.jmock.Expectations;
@@ -105,7 +102,7 @@ public abstract class DatabaseComponentTest extends TestCase {
protected abstract <T> DatabaseComponent createDatabaseComponent( protected abstract <T> DatabaseComponent createDatabaseComponent(
Database<T> database, DatabaseCleaner cleaner, Database<T> database, DatabaseCleaner cleaner,
ShutdownManager shutdown); ShutdownManager shutdown, PacketFactory packetFactory);
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@@ -115,6 +112,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final ConnectionWindow connectionWindow = final ConnectionWindow connectionWindow =
context.mock(ConnectionWindow.class); context.mock(ConnectionWindow.class);
final Group group = context.mock(Group.class); final Group group = context.mock(Group.class);
@@ -200,7 +198,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).close(); oneOf(database).close();
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.open(false); db.open(false);
db.addListener(listener); db.addListener(listener);
@@ -233,6 +231,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// setRating(authorId, Rating.GOOD) // setRating(authorId, Rating.GOOD)
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -251,7 +250,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.setRating(authorId, Rating.GOOD); db.setRating(authorId, Rating.GOOD);
@@ -265,6 +264,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// setRating(authorId, Rating.GOOD) // setRating(authorId, Rating.GOOD)
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -287,7 +287,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.setRating(authorId, Rating.GOOD); db.setRating(authorId, Rating.GOOD);
@@ -302,6 +302,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// setRating(authorId, Rating.GOOD) // setRating(authorId, Rating.GOOD)
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -327,7 +328,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.setRating(authorId, Rating.GOOD); db.setRating(authorId, Rating.GOOD);
@@ -342,6 +343,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// addLocalGroupMessage(message) // addLocalGroupMessage(message)
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -351,7 +353,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addLocalGroupMessage(message); db.addLocalGroupMessage(message);
@@ -365,6 +367,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// addLocalGroupMessage(message) // addLocalGroupMessage(message)
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -376,7 +379,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addLocalGroupMessage(message); db.addLocalGroupMessage(message);
@@ -390,6 +393,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// addLocalGroupMessage(message) // addLocalGroupMessage(message)
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -410,7 +414,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addLocalGroupMessage(message); db.addLocalGroupMessage(message);
@@ -425,6 +429,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// addLocalGroupMessage(message) // addLocalGroupMessage(message)
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -448,7 +453,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addLocalGroupMessage(message); db.addLocalGroupMessage(message);
@@ -462,6 +467,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -473,7 +479,7 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(false)); will(returnValue(false));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addLocalPrivateMessage(privateMessage, contactId); db.addLocalPrivateMessage(privateMessage, contactId);
@@ -487,6 +493,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -499,7 +506,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).setStatus(txn, contactId, messageId, Status.NEW); oneOf(database).setStatus(txn, contactId, messageId, Status.NEW);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addLocalPrivateMessage(privateMessage, contactId); db.addLocalPrivateMessage(privateMessage, contactId);
@@ -514,17 +521,10 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final AckWriter ackWriter = context.mock(AckWriter.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final BatchWriter batchWriter = context.mock(BatchWriter.class);
final OfferWriter offerWriter = context.mock(OfferWriter.class);
final SubscriptionUpdateWriter subscriptionUpdateWriter =
context.mock(SubscriptionUpdateWriter.class);
final TransportUpdateWriter transportUpdateWriter =
context.mock(TransportUpdateWriter.class);
final Ack ack = context.mock(Ack.class); final Ack ack = context.mock(Ack.class);
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
final Offer offer = context.mock(Offer.class); final Offer offer = context.mock(Offer.class);
final RequestWriter requestWriter = context.mock(RequestWriter.class);
final SubscriptionUpdate subscriptionUpdate = final SubscriptionUpdate subscriptionUpdate =
context.mock(SubscriptionUpdate.class); context.mock(SubscriptionUpdate.class);
final TransportUpdate transportUpdate = final TransportUpdate transportUpdate =
@@ -538,7 +538,7 @@ public abstract class DatabaseComponentTest extends TestCase {
exactly(19).of(database).commitTransaction(txn); exactly(19).of(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
try { try {
db.addLocalPrivateMessage(privateMessage, contactId); db.addLocalPrivateMessage(privateMessage, contactId);
@@ -546,33 +546,33 @@ public abstract class DatabaseComponentTest extends TestCase {
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try { try {
db.generateAck(contactId, ackWriter); db.generateAck(contactId, 123);
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try { try {
db.generateBatch(contactId, batchWriter); db.generateBatch(contactId, 123);
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try { try {
db.generateBatch(contactId, batchWriter, db.generateBatch(contactId, 123,
Collections.<MessageId>emptyList()); Collections.<MessageId>emptyList());
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try { try {
db.generateOffer(contactId, offerWriter); db.generateOffer(contactId, 123);
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try { try {
db.generateSubscriptionUpdate(contactId, subscriptionUpdateWriter); db.generateSubscriptionUpdate(contactId);
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try { try {
db.generateTransportUpdate(contactId, transportUpdateWriter); db.generateTransportUpdate(contactId);
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
@@ -607,7 +607,7 @@ public abstract class DatabaseComponentTest extends TestCase {
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try { try {
db.receiveOffer(contactId, offer, requestWriter); db.receiveOffer(contactId, offer);
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
@@ -650,7 +650,8 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final AckWriter ackWriter = context.mock(AckWriter.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Ack ack = context.mock(Ack.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -658,22 +659,18 @@ public abstract class DatabaseComponentTest extends TestCase {
allowing(database).containsContact(txn, contactId); allowing(database).containsContact(txn, contactId);
will(returnValue(true)); will(returnValue(true));
// Get the batches to ack // Get the batches to ack
oneOf(database).getBatchesToAck(txn, contactId); oneOf(database).getBatchesToAck(txn, contactId, 123);
will(returnValue(batchesToAck)); will(returnValue(batchesToAck));
// Try to add both batches to the writer - only manage to add one // Create the packet
oneOf(ackWriter).writeBatchId(batchId); oneOf(packetFactory).createAck(batchesToAck);
will(returnValue(true)); will(returnValue(ack));
oneOf(ackWriter).writeBatchId(batchId1); // Record the batches that were acked
will(returnValue(false)); oneOf(database).removeBatchesToAck(txn, contactId, batchesToAck);
oneOf(ackWriter).finish();
// Record the batch that was acked
oneOf(database).removeBatchesToAck(txn, contactId,
Collections.singletonList(batchId));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.generateAck(contactId, ackWriter); assertEquals(ack, db.generateAck(contactId, 123));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -682,47 +679,47 @@ public abstract class DatabaseComponentTest extends TestCase {
public void testGenerateBatch() throws Exception { public void testGenerateBatch() throws Exception {
final MessageId messageId1 = new MessageId(TestUtils.getRandomId()); final MessageId messageId1 = new MessageId(TestUtils.getRandomId());
final byte[] raw1 = new byte[size]; final byte[] raw1 = new byte[size];
final Collection<MessageId> sendable = new ArrayList<MessageId>(); final Collection<MessageId> sendable = Arrays.asList(new MessageId[] {
sendable.add(messageId); messageId,
sendable.add(messageId1); messageId1
});
final Collection<byte[]> messages = Arrays.asList(new byte[][] {
raw,
raw1
});
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final BatchWriter batchWriter = context.mock(BatchWriter.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final RawBatch batch = context.mock(RawBatch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
allowing(database).commitTransaction(txn); allowing(database).commitTransaction(txn);
allowing(database).containsContact(txn, contactId); allowing(database).containsContact(txn, contactId);
will(returnValue(true)); will(returnValue(true));
// Find out how much space we've got
oneOf(batchWriter).getCapacity();
will(returnValue(ProtocolConstants.MAX_PACKET_LENGTH));
// Get the sendable messages // Get the sendable messages
oneOf(database).getSendableMessages(txn, contactId, oneOf(database).getSendableMessages(txn, contactId, size * 2);
ProtocolConstants.MAX_PACKET_LENGTH);
will(returnValue(sendable)); will(returnValue(sendable));
oneOf(database).getMessage(txn, messageId); oneOf(database).getMessage(txn, messageId);
will(returnValue(raw)); will(returnValue(raw));
oneOf(database).getMessage(txn, messageId1); oneOf(database).getMessage(txn, messageId1);
will(returnValue(raw1)); will(returnValue(raw1));
// Add the sendable messages to the batch // Create the packet
oneOf(batchWriter).writeMessage(raw); oneOf(packetFactory).createBatch(messages);
will(returnValue(true)); will(returnValue(batch));
oneOf(batchWriter).writeMessage(raw1); // Record the outstanding batch
will(returnValue(true)); oneOf(batch).getId();
oneOf(batchWriter).finish();
will(returnValue(batchId)); will(returnValue(batchId));
// Record the message that was sent
oneOf(database).addOutstandingBatch(txn, contactId, batchId, oneOf(database).addOutstandingBatch(txn, contactId, batchId,
sendable); sendable);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.generateBatch(contactId, batchWriter); assertEquals(batch, db.generateBatch(contactId, size * 2));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -736,21 +733,22 @@ public abstract class DatabaseComponentTest extends TestCase {
requested.add(messageId); requested.add(messageId);
requested.add(messageId1); requested.add(messageId1);
requested.add(messageId2); requested.add(messageId2);
final Collection<byte[]> msgs = Arrays.asList(new byte[][] {
raw1
});
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final BatchWriter batchWriter = context.mock(BatchWriter.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final RawBatch batch = context.mock(RawBatch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
allowing(database).commitTransaction(txn); allowing(database).commitTransaction(txn);
allowing(database).containsContact(txn, contactId); allowing(database).containsContact(txn, contactId);
will(returnValue(true)); will(returnValue(true));
// Find out how much space we've got
oneOf(batchWriter).getCapacity();
will(returnValue(ProtocolConstants.MAX_PACKET_LENGTH));
// Try to get the requested messages // Try to get the requested messages
oneOf(database).getMessageIfSendable(txn, contactId, messageId); oneOf(database).getMessageIfSendable(txn, contactId, messageId);
will(returnValue(null)); // Message is not sendable will(returnValue(null)); // Message is not sendable
@@ -758,19 +756,19 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(raw1)); // Message is sendable will(returnValue(raw1)); // Message is sendable
oneOf(database).getMessageIfSendable(txn, contactId, messageId2); oneOf(database).getMessageIfSendable(txn, contactId, messageId2);
will(returnValue(null)); // Message is not sendable will(returnValue(null)); // Message is not sendable
// Add the sendable message to the batch // Create the packet
oneOf(batchWriter).writeMessage(raw1); oneOf(packetFactory).createBatch(msgs);
will(returnValue(true)); will(returnValue(batch));
oneOf(batchWriter).finish(); // Record the outstanding batch
oneOf(batch).getId();
will(returnValue(batchId)); will(returnValue(batchId));
// Record the message that was sent
oneOf(database).addOutstandingBatch(txn, contactId, batchId, oneOf(database).addOutstandingBatch(txn, contactId, batchId,
Collections.singletonList(messageId1)); Collections.singletonList(messageId1));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.generateBatch(contactId, batchWriter, requested); assertEquals(batch, db.generateBatch(contactId, size * 3, requested));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -778,15 +776,16 @@ public abstract class DatabaseComponentTest extends TestCase {
@Test @Test
public void testGenerateOffer() throws Exception { public void testGenerateOffer() throws Exception {
final MessageId messageId1 = new MessageId(TestUtils.getRandomId()); final MessageId messageId1 = new MessageId(TestUtils.getRandomId());
final Collection<MessageId> sendable = new ArrayList<MessageId>(); final Collection<MessageId> offerable = new ArrayList<MessageId>();
sendable.add(messageId); offerable.add(messageId);
sendable.add(messageId1); offerable.add(messageId1);
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final OfferWriter offerWriter = context.mock(OfferWriter.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Offer offer = context.mock(Offer.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -794,20 +793,16 @@ public abstract class DatabaseComponentTest extends TestCase {
allowing(database).containsContact(txn, contactId); allowing(database).containsContact(txn, contactId);
will(returnValue(true)); will(returnValue(true));
// Get the sendable message IDs // Get the sendable message IDs
oneOf(database).getSendableMessages(txn, contactId); oneOf(database).getOfferableMessages(txn, contactId, 123);
will(returnValue(sendable)); will(returnValue(offerable));
// Try to add both IDs to the writer - only manage to add one // Create the packet
oneOf(offerWriter).writeMessageId(messageId); oneOf(packetFactory).createOffer(offerable);
will(returnValue(true)); will(returnValue(offer));
oneOf(offerWriter).writeMessageId(messageId1);
will(returnValue(false));
oneOf(offerWriter).finish();
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
assertEquals(Collections.singletonList(messageId), assertEquals(offer, db.generateOffer(contactId, 123));
db.generateOffer(contactId, offerWriter));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -820,8 +815,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final SubscriptionUpdateWriter subscriptionUpdateWriter = final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.mock(SubscriptionUpdateWriter.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -835,26 +829,23 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(now)); will(returnValue(now));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.generateSubscriptionUpdate(contactId, subscriptionUpdateWriter); assertNull(db.generateSubscriptionUpdate(contactId));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@Test @Test
public void testGenerateSubscriptionUpdate() throws Exception { public void testGenerateSubscriptionUpdate() throws Exception {
final MessageId messageId1 = new MessageId(TestUtils.getRandomId());
final Collection<MessageId> sendable = new ArrayList<MessageId>();
sendable.add(messageId);
sendable.add(messageId1);
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final SubscriptionUpdateWriter subscriptionUpdateWriter = final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.mock(SubscriptionUpdateWriter.class); final SubscriptionUpdate subscriptionUpdate =
context.mock(SubscriptionUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -871,15 +862,17 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(Collections.singletonMap(group, 0L))); will(returnValue(Collections.singletonMap(group, 0L)));
oneOf(database).setSubscriptionsSent(with(txn), with(contactId), oneOf(database).setSubscriptionsSent(with(txn), with(contactId),
with(any(long.class))); with(any(long.class)));
// Add the subscriptions to the writer // Create the packet
oneOf(subscriptionUpdateWriter).writeSubscriptions( oneOf(packetFactory).createSubscriptionUpdate(
with(Collections.singletonMap(group, 0L)), with(Collections.singletonMap(group, 0L)),
with(any(long.class))); with(any(long.class)));
will(returnValue(subscriptionUpdate));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.generateSubscriptionUpdate(contactId, subscriptionUpdateWriter); assertEquals(subscriptionUpdate,
db.generateSubscriptionUpdate(contactId));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -892,8 +885,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final TransportUpdateWriter transportUpdateWriter = final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.mock(TransportUpdateWriter.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -907,26 +899,23 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(now)); will(returnValue(now));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.generateTransportUpdate(contactId, transportUpdateWriter); assertNull(db.generateTransportUpdate(contactId));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@Test @Test
public void testGenerateTransportUpdate() throws Exception { public void testGenerateTransportUpdate() throws Exception {
final MessageId messageId1 = new MessageId(TestUtils.getRandomId());
final Collection<MessageId> sendable = new ArrayList<MessageId>();
sendable.add(messageId);
sendable.add(messageId1);
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final TransportUpdateWriter transportUpdateWriter = final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.mock(TransportUpdateWriter.class); final TransportUpdate transportUpdate =
context.mock(TransportUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -943,14 +932,15 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(transports)); will(returnValue(transports));
oneOf(database).setTransportsSent(with(txn), with(contactId), oneOf(database).setTransportsSent(with(txn), with(contactId),
with(any(long.class))); with(any(long.class)));
// Add the properties to the writer // Create the packet
oneOf(transportUpdateWriter).writeTransports(with(transports), oneOf(packetFactory).createTransportUpdate(with(transports),
with(any(long.class))); with(any(long.class)));
will(returnValue(transportUpdate));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.generateTransportUpdate(contactId, transportUpdateWriter); assertEquals(transportUpdate, db.generateTransportUpdate(contactId));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -963,6 +953,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Ack ack = context.mock(Ack.class); final Ack ack = context.mock(Ack.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -980,7 +971,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).removeLostBatch(txn, contactId, batchId1); oneOf(database).removeLostBatch(txn, contactId, batchId1);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveAck(contactId, ack); db.receiveAck(contactId, ack);
@@ -994,6 +985,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1013,7 +1005,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).addBatchToAck(txn, contactId, batchId); oneOf(database).addBatchToAck(txn, contactId, batchId);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveBatch(contactId, batch); db.receiveBatch(contactId, batch);
@@ -1027,6 +1019,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1045,7 +1038,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).addBatchToAck(txn, contactId, batchId); oneOf(database).addBatchToAck(txn, contactId, batchId);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveBatch(contactId, batch); db.receiveBatch(contactId, batch);
@@ -1060,6 +1053,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1079,7 +1073,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).addBatchToAck(txn, contactId, batchId); oneOf(database).addBatchToAck(txn, contactId, batchId);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveBatch(contactId, batch); db.receiveBatch(contactId, batch);
@@ -1094,6 +1088,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1117,7 +1112,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).addBatchToAck(txn, contactId, batchId); oneOf(database).addBatchToAck(txn, contactId, batchId);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveBatch(contactId, batch); db.receiveBatch(contactId, batch);
@@ -1131,6 +1126,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1163,7 +1159,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).addBatchToAck(txn, contactId, batchId); oneOf(database).addBatchToAck(txn, contactId, batchId);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveBatch(contactId, batch); db.receiveBatch(contactId, batch);
@@ -1177,6 +1173,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1211,7 +1208,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).addBatchToAck(txn, contactId, batchId); oneOf(database).addBatchToAck(txn, contactId, batchId);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveBatch(contactId, batch); db.receiveBatch(contactId, batch);
@@ -1234,8 +1231,9 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final Offer offer = context.mock(Offer.class); final Offer offer = context.mock(Offer.class);
final RequestWriter requestWriter = context.mock(RequestWriter.class); final Request request = context.mock(Request.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -1251,12 +1249,14 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(true)); // Visible - do not request message # 1 will(returnValue(true)); // Visible - do not request message # 1
oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId2); oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId2);
will(returnValue(false)); // Not visible - request message # 2 will(returnValue(false)); // Not visible - request message # 2
oneOf(requestWriter).writeRequest(expectedRequest, 3); // Create the packet
oneOf(packetFactory).createRequest(expectedRequest, 3);
will(returnValue(request));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveOffer(contactId, offer, requestWriter); assertEquals(request, db.receiveOffer(contactId, offer));
context.assertIsSatisfied(); context.assertIsSatisfied();
} }
@@ -1269,6 +1269,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final SubscriptionUpdate subscriptionUpdate = final SubscriptionUpdate subscriptionUpdate =
context.mock(SubscriptionUpdate.class); context.mock(SubscriptionUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
@@ -1286,7 +1287,7 @@ public abstract class DatabaseComponentTest extends TestCase {
Collections.singletonMap(group, 0L), timestamp); Collections.singletonMap(group, 0L), timestamp);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveSubscriptionUpdate(contactId, subscriptionUpdate); db.receiveSubscriptionUpdate(contactId, subscriptionUpdate);
@@ -1301,6 +1302,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final TransportUpdate transportUpdate = final TransportUpdate transportUpdate =
context.mock(TransportUpdate.class); context.mock(TransportUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
@@ -1318,7 +1320,7 @@ public abstract class DatabaseComponentTest extends TestCase {
timestamp); timestamp);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.receiveTransportUpdate(contactId, transportUpdate); db.receiveTransportUpdate(contactId, transportUpdate);
@@ -1332,6 +1334,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final DatabaseListener listener = context.mock(DatabaseListener.class); final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// addLocalGroupMessage(message) // addLocalGroupMessage(message)
@@ -1354,7 +1357,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(listener).eventOccurred(with(any(MessagesAddedEvent.class))); oneOf(listener).eventOccurred(with(any(MessagesAddedEvent.class)));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addListener(listener); db.addListener(listener);
db.addLocalGroupMessage(message); db.addLocalGroupMessage(message);
@@ -1369,6 +1372,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final DatabaseListener listener = context.mock(DatabaseListener.class); final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1384,7 +1388,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(listener).eventOccurred(with(any(MessagesAddedEvent.class))); oneOf(listener).eventOccurred(with(any(MessagesAddedEvent.class)));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addListener(listener); db.addListener(listener);
db.addLocalPrivateMessage(privateMessage, contactId); db.addLocalPrivateMessage(privateMessage, contactId);
@@ -1400,6 +1404,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final DatabaseListener listener = context.mock(DatabaseListener.class); final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// addLocalGroupMessage(message) // addLocalGroupMessage(message)
@@ -1413,7 +1418,7 @@ public abstract class DatabaseComponentTest extends TestCase {
// The message was not added, so the listener should not be called // The message was not added, so the listener should not be called
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addListener(listener); db.addListener(listener);
db.addLocalGroupMessage(message); db.addLocalGroupMessage(message);
@@ -1429,6 +1434,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final DatabaseListener listener = context.mock(DatabaseListener.class); final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
@@ -1442,7 +1448,7 @@ public abstract class DatabaseComponentTest extends TestCase {
// The message was not added, so the listener should not be called // The message was not added, so the listener should not be called
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addListener(listener); db.addListener(listener);
db.addLocalPrivateMessage(privateMessage, contactId); db.addLocalPrivateMessage(privateMessage, contactId);
@@ -1460,6 +1466,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final DatabaseListener listener = context.mock(DatabaseListener.class); final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -1474,7 +1481,7 @@ public abstract class DatabaseComponentTest extends TestCase {
TransportAddedEvent.class))); TransportAddedEvent.class)));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addListener(listener); db.addListener(listener);
db.setLocalProperties(transportId, properties); db.setLocalProperties(transportId, properties);
@@ -1492,6 +1499,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
final DatabaseListener listener = context.mock(DatabaseListener.class); final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(database).startTransaction(); oneOf(database).startTransaction();
@@ -1501,7 +1509,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.addListener(listener); db.addListener(listener);
db.setLocalProperties(transportId, properties); db.setLocalProperties(transportId, properties);
@@ -1516,6 +1524,7 @@ public abstract class DatabaseComponentTest extends TestCase {
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -1526,7 +1535,7 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId); oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown); shutdown, packetFactory);
db.setSeen(contactId, Collections.singletonList(messageId)); db.setSeen(contactId, Collections.singletonList(messageId));

View File

@@ -46,7 +46,6 @@ import net.sf.briar.api.transport.ConnectionWindowFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.lifecycle.LifecycleModule;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import net.sf.briar.transport.TransportModule; import net.sf.briar.transport.TransportModule;
import net.sf.briar.transport.batch.TransportBatchModule; import net.sf.briar.transport.batch.TransportBatchModule;
@@ -107,10 +106,9 @@ public class H2DatabaseTest extends TestCase {
}; };
Injector i = Guice.createInjector(testModule, new CryptoModule(), Injector i = Guice.createInjector(testModule, new CryptoModule(),
new DatabaseModule(), new LifecycleModule(), new DatabaseModule(), new LifecycleModule(),
new ProtocolModule(), new ProtocolWritersModule(), new ProtocolModule(), new SerialModule(),
new SerialModule(), new TransportBatchModule(), new TransportBatchModule(), new TransportModule(),
new TransportModule(), new TransportStreamModule(), new TransportStreamModule(), new TestDatabaseModule(testDir));
new TestDatabaseModule(testDir));
connectionContextFactory = connectionContextFactory =
i.getInstance(ConnectionContextFactory.class); i.getInstance(ConnectionContextFactory.class);
connectionWindowFactory = i.getInstance(ConnectionWindowFactory.class); connectionWindowFactory = i.getInstance(ConnectionWindowFactory.class);
@@ -588,7 +586,7 @@ public class H2DatabaseTest extends TestCase {
db.addBatchToAck(txn, contactId, batchId1); db.addBatchToAck(txn, contactId, batchId1);
// Both batch IDs should be returned // Both batch IDs should be returned
Collection<BatchId> acks = db.getBatchesToAck(txn, contactId); Collection<BatchId> acks = db.getBatchesToAck(txn, contactId, 1234);
assertEquals(2, acks.size()); assertEquals(2, acks.size());
assertTrue(acks.contains(batchId)); assertTrue(acks.contains(batchId));
assertTrue(acks.contains(batchId1)); assertTrue(acks.contains(batchId1));
@@ -597,7 +595,7 @@ public class H2DatabaseTest extends TestCase {
db.removeBatchesToAck(txn, contactId, acks); db.removeBatchesToAck(txn, contactId, acks);
// Both batch IDs should have been removed // Both batch IDs should have been removed
acks = db.getBatchesToAck(txn, contactId); acks = db.getBatchesToAck(txn, contactId, 1234);
assertEquals(0, acks.size()); assertEquals(0, acks.size());
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -615,7 +613,7 @@ public class H2DatabaseTest extends TestCase {
db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId);
// The batch ID should only be returned once // The batch ID should only be returned once
Collection<BatchId> acks = db.getBatchesToAck(txn, contactId); Collection<BatchId> acks = db.getBatchesToAck(txn, contactId, 1234);
assertEquals(1, acks.size()); assertEquals(1, acks.size());
assertTrue(acks.contains(batchId)); assertTrue(acks.contains(batchId));
@@ -623,7 +621,7 @@ public class H2DatabaseTest extends TestCase {
db.removeBatchesToAck(txn, contactId, acks); db.removeBatchesToAck(txn, contactId, acks);
// The batch ID should have been removed // The batch ID should have been removed
acks = db.getBatchesToAck(txn, contactId); acks = db.getBatchesToAck(txn, contactId, 1234);
assertEquals(0, acks.size()); assertEquals(0, acks.size());
db.commitTransaction(txn); db.commitTransaction(txn);

View File

@@ -61,10 +61,6 @@ class TestMessage implements Message {
return timestamp; return timestamp;
} }
public int getLength() {
return raw.length;
}
public byte[] getSerialised() { public byte[] getSerialised() {
return raw; return raw;
} }

View File

@@ -10,6 +10,7 @@ import junit.framework.TestCase;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
@@ -42,8 +43,8 @@ public class AckReaderTest extends TestCase {
@Test @Test
public void testFormatExceptionIfAckIsTooLarge() throws Exception { public void testFormatExceptionIfAckIsTooLarge() throws Exception {
AckFactory ackFactory = context.mock(AckFactory.class); PacketFactory packetFactory = context.mock(PacketFactory.class);
AckReader ackReader = new AckReader(ackFactory); AckReader ackReader = new AckReader(packetFactory);
byte[] b = createAck(true); byte[] b = createAck(true);
ByteArrayInputStream in = new ByteArrayInputStream(b); ByteArrayInputStream in = new ByteArrayInputStream(b);
@@ -60,11 +61,11 @@ public class AckReaderTest extends TestCase {
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception { public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception {
final AckFactory ackFactory = context.mock(AckFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
AckReader ackReader = new AckReader(ackFactory); AckReader ackReader = new AckReader(packetFactory);
final Ack ack = context.mock(Ack.class); final Ack ack = context.mock(Ack.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(ackFactory).createAck(with(any(Collection.class))); oneOf(packetFactory).createAck(with(any(Collection.class)));
will(returnValue(ack)); will(returnValue(ack));
}}); }});
@@ -79,11 +80,11 @@ public class AckReaderTest extends TestCase {
@Test @Test
public void testEmptyAck() throws Exception { public void testEmptyAck() throws Exception {
final AckFactory ackFactory = context.mock(AckFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
AckReader ackReader = new AckReader(ackFactory); AckReader ackReader = new AckReader(packetFactory);
final Ack ack = context.mock(Ack.class); final Ack ack = context.mock(Ack.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(ackFactory).createAck( oneOf(packetFactory).createAck(
with(Collections.<BatchId>emptyList())); with(Collections.<BatchId>emptyList()));
will(returnValue(ack)); will(returnValue(ack));
}}); }});

View File

@@ -1,4 +1,4 @@
package net.sf.briar.protocol.writers; package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_AUTHOR_NAME_LENGTH; import static net.sf.briar.api.protocol.ProtocolConstants.MAX_AUTHOR_NAME_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_BODY_LENGTH; import static net.sf.briar.api.protocol.ProtocolConstants.MAX_BODY_LENGTH;
@@ -15,12 +15,14 @@ import java.io.ByteArrayOutputStream;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.TestUtils; import net.sf.briar.TestUtils;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorFactory; import net.sf.briar.api.protocol.AuthorFactory;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
@@ -29,19 +31,18 @@ import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import org.junit.Test; import org.junit.Test;
@@ -51,23 +52,23 @@ import com.google.inject.Injector;
public class ConstantsTest extends TestCase { public class ConstantsTest extends TestCase {
private final WriterFactory writerFactory;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final SerialComponent serial;
private final GroupFactory groupFactory; private final GroupFactory groupFactory;
private final AuthorFactory authorFactory; private final AuthorFactory authorFactory;
private final MessageFactory messageFactory; private final MessageFactory messageFactory;
private final PacketFactory packetFactory;
private final ProtocolWriterFactory protocolWriterFactory;
public ConstantsTest() throws Exception { public ConstantsTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule()); new ProtocolModule(), new SerialModule());
writerFactory = i.getInstance(WriterFactory.class);
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
serial = i.getInstance(SerialComponent.class);
groupFactory = i.getInstance(GroupFactory.class); groupFactory = i.getInstance(GroupFactory.class);
authorFactory = i.getInstance(AuthorFactory.class); authorFactory = i.getInstance(AuthorFactory.class);
messageFactory = i.getInstance(MessageFactory.class); messageFactory = i.getInstance(MessageFactory.class);
packetFactory = i.getInstance(PacketFactory.class);
protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class);
} }
@Test @Test
@@ -83,25 +84,18 @@ public class ConstantsTest extends TestCase {
private void testBatchesFitIntoAck(int length) throws Exception { private void testBatchesFitIntoAck(int length) throws Exception {
// Create an ack with as many batch IDs as possible // Create an ack with as many batch IDs as possible
ByteArrayOutputStream out = new ByteArrayOutputStream(length); ByteArrayOutputStream out = new ByteArrayOutputStream(length);
AckWriter a = new AckWriterImpl(out, serial, writerFactory); ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out);
a.setMaxPacketLength(length); int maxBatches = writer.getMaxBatchesForAck(length);
while(a.writeBatchId(new BatchId(TestUtils.getRandomId()))); Collection<BatchId> acked = new ArrayList<BatchId>();
a.finish(); for(int i = 0; i < maxBatches; i++) {
acked.add(new BatchId(TestUtils.getRandomId()));
}
Ack a = packetFactory.createAck(acked);
writer.writeAck(a);
// Check the size of the serialised ack // Check the size of the serialised ack
assertTrue(out.size() <= length); assertTrue(out.size() <= length);
} }
@Test
public void testEmptyAck() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
AckWriter a = new AckWriterImpl(out, serial, writerFactory);
// There's not enough room for a batch ID
a.setMaxPacketLength(4);
assertFalse(a.writeBatchId(new BatchId(TestUtils.getRandomId())));
// Check that nothing was written
assertEquals(0, out.size());
}
@Test @Test
public void testMessageFitsIntoBatch() throws Exception { public void testMessageFitsIntoBatch() throws Exception {
// Create a maximum-length group // Create a maximum-length group
@@ -122,10 +116,10 @@ public class ConstantsTest extends TestCase {
// Add the message to a batch // Add the message to a batch
ByteArrayOutputStream out = ByteArrayOutputStream out =
new ByteArrayOutputStream(MAX_PACKET_LENGTH); new ByteArrayOutputStream(MAX_PACKET_LENGTH);
BatchWriter b = new BatchWriterImpl(out, serial, writerFactory, ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out);
crypto.getMessageDigest()); RawBatch b = packetFactory.createBatch(Collections.singletonList(
assertTrue(b.writeMessage(message.getSerialised())); message.getSerialised()));
b.finish(); writer.writeBatch(b);
// Check the size of the serialised batch // Check the size of the serialised batch
assertTrue(out.size() > UniqueId.LENGTH + MAX_GROUP_NAME_LENGTH assertTrue(out.size() > UniqueId.LENGTH + MAX_GROUP_NAME_LENGTH
+ MAX_PUBLIC_KEY_LENGTH + MAX_AUTHOR_NAME_LENGTH + MAX_PUBLIC_KEY_LENGTH + MAX_AUTHOR_NAME_LENGTH
@@ -133,18 +127,6 @@ public class ConstantsTest extends TestCase {
assertTrue(out.size() <= MAX_PACKET_LENGTH); assertTrue(out.size() <= MAX_PACKET_LENGTH);
} }
@Test
public void testEmptyBatch() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
BatchWriter b = new BatchWriterImpl(out, serial, writerFactory,
crypto.getMessageDigest());
// There's not enough room for a message
b.setMaxPacketLength(4);
assertFalse(b.writeMessage(new byte[4]));
// Check that nothing was written
assertEquals(0, out.size());
}
@Test @Test
public void testMessagesFitIntoLargeOffer() throws Exception { public void testMessagesFitIntoLargeOffer() throws Exception {
testMessagesFitIntoOffer(MAX_PACKET_LENGTH); testMessagesFitIntoOffer(MAX_PACKET_LENGTH);
@@ -158,25 +140,18 @@ public class ConstantsTest extends TestCase {
private void testMessagesFitIntoOffer(int length) throws Exception { private void testMessagesFitIntoOffer(int length) throws Exception {
// Create an offer with as many message IDs as possible // Create an offer with as many message IDs as possible
ByteArrayOutputStream out = new ByteArrayOutputStream(length); ByteArrayOutputStream out = new ByteArrayOutputStream(length);
OfferWriter o = new OfferWriterImpl(out, serial, writerFactory); ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out);
o.setMaxPacketLength(length); int maxMessages = writer.getMaxMessagesForOffer(length);
while(o.writeMessageId(new MessageId(TestUtils.getRandomId()))); Collection<MessageId> offered = new ArrayList<MessageId>();
o.finish(); for(int i = 0; i < maxMessages; i++) {
offered.add(new MessageId(TestUtils.getRandomId()));
}
Offer o = packetFactory.createOffer(offered);
writer.writeOffer(o);
// Check the size of the serialised offer // Check the size of the serialised offer
assertTrue(out.size() <= length); assertTrue(out.size() <= length);
} }
@Test
public void testEmptyOffer() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
OfferWriter o = new OfferWriterImpl(out, serial, writerFactory);
// There's not enough room for a message ID
o.setMaxPacketLength(4);
assertFalse(o.writeMessageId(new MessageId(TestUtils.getRandomId())));
// Check that nothing was written
assertEquals(0, out.size());
}
@Test @Test
public void testSubscriptionsFitIntoUpdate() throws Exception { public void testSubscriptionsFitIntoUpdate() throws Exception {
// Create the maximum number of maximum-length subscriptions // Create the maximum number of maximum-length subscriptions
@@ -190,9 +165,10 @@ public class ConstantsTest extends TestCase {
// Add the subscriptions to an update // Add the subscriptions to an update
ByteArrayOutputStream out = ByteArrayOutputStream out =
new ByteArrayOutputStream(MAX_PACKET_LENGTH); new ByteArrayOutputStream(MAX_PACKET_LENGTH);
SubscriptionUpdateWriter s = ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out);
new SubscriptionUpdateWriterImpl(out, writerFactory); SubscriptionUpdate s = packetFactory.createSubscriptionUpdate(subs,
s.writeSubscriptions(subs, Long.MAX_VALUE); Long.MAX_VALUE);
writer.writeSubscriptionUpdate(s);
// Check the size of the serialised update // Check the size of the serialised update
assertTrue(out.size() > MAX_GROUPS * assertTrue(out.size() > MAX_GROUPS *
(MAX_GROUP_NAME_LENGTH + MAX_PUBLIC_KEY_LENGTH + 8) + 8); (MAX_GROUP_NAME_LENGTH + MAX_PUBLIC_KEY_LENGTH + 8) + 8);
@@ -218,9 +194,10 @@ public class ConstantsTest extends TestCase {
// Add the transports to an update // Add the transports to an update
ByteArrayOutputStream out = ByteArrayOutputStream out =
new ByteArrayOutputStream(MAX_PACKET_LENGTH); new ByteArrayOutputStream(MAX_PACKET_LENGTH);
TransportUpdateWriter t = ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out);
new TransportUpdateWriterImpl(out, writerFactory); TransportUpdate t = packetFactory.createTransportUpdate(transports,
t.writeTransports(transports, Long.MAX_VALUE); Long.MAX_VALUE);
writer.writeTransportUpdate(t);
// Check the size of the serialised update // Check the size of the serialised update
assertTrue(out.size() > MAX_TRANSPORTS * (UniqueId.LENGTH + 4 assertTrue(out.size() > MAX_TRANSPORTS * (UniqueId.LENGTH + 4
+ (MAX_PROPERTIES_PER_TRANSPORT * MAX_PROPERTY_LENGTH * 2)) + (MAX_PROPERTIES_PER_TRANSPORT * MAX_PROPERTY_LENGTH * 2))

View File

@@ -17,23 +17,19 @@ import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import org.junit.Test; import org.junit.Test;
@@ -45,6 +41,7 @@ public class ProtocolReadWriteTest extends TestCase {
private final ProtocolReaderFactory readerFactory; private final ProtocolReaderFactory readerFactory;
private final ProtocolWriterFactory writerFactory; private final ProtocolWriterFactory writerFactory;
private final PacketFactory packetFactory;
private final BatchId batchId; private final BatchId batchId;
private final Group group; private final Group group;
private final Message message; private final Message message;
@@ -58,10 +55,10 @@ public class ProtocolReadWriteTest extends TestCase {
public ProtocolReadWriteTest() throws Exception { public ProtocolReadWriteTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new ProtocolWritersModule(), new ProtocolModule(), new SerialModule());
new SerialModule());
readerFactory = i.getInstance(ProtocolReaderFactory.class); readerFactory = i.getInstance(ProtocolReaderFactory.class);
writerFactory = i.getInstance(ProtocolWriterFactory.class); writerFactory = i.getInstance(ProtocolWriterFactory.class);
packetFactory = i.getInstance(PacketFactory.class);
batchId = new BatchId(TestUtils.getRandomId()); batchId = new BatchId(TestUtils.getRandomId());
GroupFactory groupFactory = i.getInstance(GroupFactory.class); GroupFactory groupFactory = i.getInstance(GroupFactory.class);
group = groupFactory.createGroup("Unrestricted group", null); group = groupFactory.createGroup("Unrestricted group", null);
@@ -83,53 +80,54 @@ public class ProtocolReadWriteTest extends TestCase {
public void testWriteAndRead() throws Exception { public void testWriteAndRead() throws Exception {
// Write // Write
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ProtocolWriter writer = writerFactory.createProtocolWriter(out);
AckWriter a = writerFactory.createAckWriter(out); Ack a = packetFactory.createAck(Collections.singletonList(batchId));
a.writeBatchId(batchId); writer.writeAck(a);
a.finish();
BatchWriter b = writerFactory.createBatchWriter(out); RawBatch b = packetFactory.createBatch(Collections.singletonList(
b.writeMessage(message.getSerialised()); message.getSerialised()));
b.finish(); writer.writeBatch(b);
OfferWriter o = writerFactory.createOfferWriter(out); Offer o = packetFactory.createOffer(Collections.singletonList(
o.writeMessageId(message.getId()); message.getId()));
o.finish(); writer.writeOffer(o);
RequestWriter r = writerFactory.createRequestWriter(out); Request r = packetFactory.createRequest(bitSet, 10);
r.writeRequest(bitSet, 10); writer.writeRequest(r);
SubscriptionUpdateWriter s = SubscriptionUpdate s = packetFactory.createSubscriptionUpdate(
writerFactory.createSubscriptionUpdateWriter(out); subscriptions, timestamp);
s.writeSubscriptions(subscriptions, timestamp); writer.writeSubscriptionUpdate(s);
TransportUpdateWriter t = TransportUpdate t = packetFactory.createTransportUpdate(transports,
writerFactory.createTransportUpdateWriter(out); timestamp);
t.writeTransports(transports, timestamp); writer.writeTransportUpdate(t);
// Read // Read
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ProtocolReader reader = readerFactory.createProtocolReader(in); ProtocolReader reader = readerFactory.createProtocolReader(in);
Ack ack = reader.readAck(); a = reader.readAck();
assertEquals(Collections.singletonList(batchId), ack.getBatchIds()); assertEquals(Collections.singletonList(batchId), a.getBatchIds());
Batch batch = reader.readBatch().verify(); Batch b1 = reader.readBatch().verify();
assertEquals(Collections.singletonList(message), batch.getMessages()); assertEquals(Collections.singletonList(message), b1.getMessages());
Offer offer = reader.readOffer(); o = reader.readOffer();
assertEquals(Collections.singletonList(message.getId()), assertEquals(Collections.singletonList(message.getId()),
offer.getMessageIds()); o.getMessageIds());
Request request = reader.readRequest(); r = reader.readRequest();
assertEquals(bitSet, request.getBitmap()); assertEquals(bitSet, r.getBitmap());
assertEquals(10, r.getLength());
SubscriptionUpdate subscriptionUpdate = reader.readSubscriptionUpdate(); s = reader.readSubscriptionUpdate();
assertEquals(subscriptions, subscriptionUpdate.getSubscriptions()); assertEquals(subscriptions, s.getSubscriptions());
assertTrue(subscriptionUpdate.getTimestamp() == timestamp); assertEquals(timestamp, s.getTimestamp());
TransportUpdate transportUpdate = reader.readTransportUpdate(); t = reader.readTransportUpdate();
assertEquals(transports, transportUpdate.getTransports()); assertEquals(transports, t.getTransports());
assertTrue(transportUpdate.getTimestamp() == timestamp); assertEquals(timestamp, t.getTimestamp());
} }
} }

View File

@@ -0,0 +1,83 @@
package net.sf.briar.protocol;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.BitSet;
import junit.framework.TestCase;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolWriter;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.serial.SerialComponent;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.serial.SerialModule;
import net.sf.briar.util.StringUtils;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class ProtocolWriterImplTest extends TestCase {
private final PacketFactory packetFactory;
private final SerialComponent serial;
private final WriterFactory writerFactory;
public ProtocolWriterImplTest() {
super();
Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule());
packetFactory = i.getInstance(PacketFactory.class);
serial = i.getInstance(SerialComponent.class);
writerFactory = i.getInstance(WriterFactory.class);
}
@Test
public void testWriteBitmapNoPadding() throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ProtocolWriter w = new ProtocolWriterImpl(serial, writerFactory, out);
BitSet b = new BitSet();
// 11011001 = 0xD9
b.set(0);
b.set(1);
b.set(3);
b.set(4);
b.set(7);
// 01011001 = 0x59
b.set(9);
b.set(11);
b.set(12);
b.set(15);
Request r = packetFactory.createRequest(b, 16);
w.writeRequest(r);
// Short user tag 8, 0 as uint7, short bytes with length 2, 0xD959
byte[] output = out.toByteArray();
assertEquals("C8" + "00" + "92" + "D959",
StringUtils.toHexString(output));
}
@Test
public void testWriteBitmapWithPadding() throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ProtocolWriter w = new ProtocolWriterImpl(serial, writerFactory, out);
BitSet b = new BitSet();
// 01011001 = 0x59
b.set(1);
b.set(3);
b.set(4);
b.set(7);
// 11011xxx = 0xD8, after padding
b.set(8);
b.set(9);
b.set(11);
b.set(12);
Request r = packetFactory.createRequest(b, 13);
w.writeRequest(r);
// Short user tag 8, 3 as uint7, short bytes with length 2, 0x59D8
byte[] output = out.toByteArray();
assertEquals("C8" + "03" + "92" + "59D8",
StringUtils.toHexString(output));
}
}

View File

@@ -6,6 +6,7 @@ import java.util.BitSet;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
@@ -13,6 +14,7 @@ import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory; import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import org.jmock.Expectations; import org.jmock.Expectations;
@@ -26,20 +28,23 @@ public class RequestReaderTest extends TestCase {
private final ReaderFactory readerFactory; private final ReaderFactory readerFactory;
private final WriterFactory writerFactory; private final WriterFactory writerFactory;
private final PacketFactory packetFactory;
private final Mockery context; private final Mockery context;
public RequestReaderTest() throws Exception { public RequestReaderTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new SerialModule()); Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule());
readerFactory = i.getInstance(ReaderFactory.class); readerFactory = i.getInstance(ReaderFactory.class);
writerFactory = i.getInstance(WriterFactory.class); writerFactory = i.getInstance(WriterFactory.class);
packetFactory = i.getInstance(PacketFactory.class);
context = new Mockery(); context = new Mockery();
} }
@Test @Test
public void testFormatExceptionIfRequestIsTooLarge() throws Exception { public void testFormatExceptionIfRequestIsTooLarge() throws Exception {
RequestFactory requestFactory = context.mock(RequestFactory.class); PacketFactory packetFactory = context.mock(PacketFactory.class);
RequestReader requestReader = new RequestReader(requestFactory); RequestReader requestReader = new RequestReader(packetFactory);
byte[] b = createRequest(true); byte[] b = createRequest(true);
ByteArrayInputStream in = new ByteArrayInputStream(b); ByteArrayInputStream in = new ByteArrayInputStream(b);
@@ -55,12 +60,12 @@ public class RequestReaderTest extends TestCase {
@Test @Test
public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception { public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception {
final RequestFactory requestFactory = final PacketFactory packetFactory = context.mock(PacketFactory.class);
context.mock(RequestFactory.class); RequestReader requestReader = new RequestReader(packetFactory);
RequestReader requestReader = new RequestReader(requestFactory);
final Request request = context.mock(Request.class); final Request request = context.mock(Request.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(requestFactory).createRequest(with(any(BitSet.class))); oneOf(packetFactory).createRequest(with(any(BitSet.class)),
with(any(int.class)));
will(returnValue(request)); will(returnValue(request));
}}); }});
@@ -96,8 +101,7 @@ public class RequestReaderTest extends TestCase {
// Deserialise the request // Deserialise the request
ByteArrayInputStream in = new ByteArrayInputStream(b); ByteArrayInputStream in = new ByteArrayInputStream(b);
Reader reader = readerFactory.createReader(in); Reader reader = readerFactory.createReader(in);
RequestReader requestReader = RequestReader requestReader = new RequestReader(packetFactory);
new RequestReader(new RequestFactoryImpl());
reader.addObjectReader(Types.REQUEST, requestReader); reader.addObjectReader(Types.REQUEST, requestReader);
Request r = reader.readStruct(Types.REQUEST, Request.class); Request r = reader.readStruct(Types.REQUEST, Request.class);
BitSet decoded = r.getBitmap(); BitSet decoded = r.getBitmap();
@@ -116,10 +120,13 @@ public class RequestReaderTest extends TestCase {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out); Writer w = writerFactory.createWriter(out);
w.writeStructId(Types.REQUEST); w.writeStructId(Types.REQUEST);
// Allow one byte for the REQUEST tag, one byte for the BYTES tag, // Allow one byte for the REQUEST tag, one byte for the padding length
// and five bytes for the length as an int32 // as a uint7, one byte for the BYTES tag, and five bytes for the
int size = ProtocolConstants.MAX_PACKET_LENGTH - 7; // length of the byte array as an int32
int size = ProtocolConstants.MAX_PACKET_LENGTH - 8;
if(tooBig) size++; if(tooBig) size++;
assertTrue(size > Short.MAX_VALUE);
w.writeUint7((byte) 0);
w.writeBytes(new byte[size]); w.writeBytes(new byte[size]);
assertEquals(tooBig, out.size() > ProtocolConstants.MAX_PACKET_LENGTH); assertEquals(tooBig, out.size() > ProtocolConstants.MAX_PACKET_LENGTH);
return out.toByteArray(); return out.toByteArray();
@@ -129,6 +136,7 @@ public class RequestReaderTest extends TestCase {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out); Writer w = writerFactory.createWriter(out);
w.writeStructId(Types.REQUEST); w.writeStructId(Types.REQUEST);
w.writeUint7((byte) 0);
w.writeBytes(bitmap); w.writeBytes(bitmap);
return out.toByteArray(); return out.toByteArray();
} }

View File

@@ -1,70 +0,0 @@
package net.sf.briar.protocol.writers;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.BitSet;
import junit.framework.TestCase;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.serial.SerialModule;
import net.sf.briar.util.StringUtils;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class RequestWriterImplTest extends TestCase {
private final WriterFactory writerFactory;
public RequestWriterImplTest() {
super();
Injector i = Guice.createInjector(new SerialModule());
writerFactory = i.getInstance(WriterFactory.class);
}
@Test
public void testWriteBitmapNoPadding() throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
RequestWriter r = new RequestWriterImpl(out, writerFactory);
BitSet b = new BitSet();
// 11011001 = 0xD9
b.set(0);
b.set(1);
b.set(3);
b.set(4);
b.set(7);
// 01011001 = 0x59
b.set(9);
b.set(11);
b.set(12);
b.set(15);
r.writeRequest(b, 16);
// Short user tag 8, short bytes with length 2, 0xD959
byte[] output = out.toByteArray();
assertEquals("C8" + "92" + "D959", StringUtils.toHexString(output));
}
@Test
public void testWriteBitmapWithPadding() throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
RequestWriter r = new RequestWriterImpl(out, writerFactory);
BitSet b = new BitSet();
// 01011001 = 0x59
b.set(1);
b.set(3);
b.set(4);
b.set(7);
// 11011xxx = 0xD8, after padding
b.set(8);
b.set(9);
b.set(11);
b.set(12);
r.writeRequest(b, 13);
// Short user tag 8, short bytes with length 2, 0x59D8
byte[] output = out.toByteArray();
assertEquals("C8" + "92" + "59D8", StringUtils.toHexString(output));
}
}

View File

@@ -16,7 +16,6 @@ import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule; import net.sf.briar.db.DatabaseModule;
import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.lifecycle.LifecycleModule;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import net.sf.briar.transport.batch.TransportBatchModule; import net.sf.briar.transport.batch.TransportBatchModule;
import net.sf.briar.transport.stream.TransportStreamModule; import net.sf.briar.transport.stream.TransportStreamModule;
@@ -44,10 +43,9 @@ public class ConnectionWriterTest extends TestCase {
}; };
Injector i = Guice.createInjector(testModule, new CryptoModule(), Injector i = Guice.createInjector(testModule, new CryptoModule(),
new DatabaseModule(), new LifecycleModule(), new DatabaseModule(), new LifecycleModule(),
new ProtocolModule(), new ProtocolWritersModule(), new ProtocolModule(), new SerialModule(),
new SerialModule(), new TestDatabaseModule(), new TestDatabaseModule(), new TransportBatchModule(),
new TransportBatchModule(), new TransportModule(), new TransportModule(), new TransportStreamModule());
new TransportStreamModule());
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
secret = new byte[32]; secret = new byte[32];
new Random().nextBytes(secret); new Random().nextBytes(secret);

View File

@@ -26,11 +26,11 @@ import net.sf.briar.api.db.event.MessagesAddedEvent;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportReader;
import net.sf.briar.api.transport.BatchTransportWriter; import net.sf.briar.api.transport.BatchTransportWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
@@ -43,7 +43,6 @@ import net.sf.briar.db.DatabaseModule;
import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.lifecycle.LifecycleModule;
import net.sf.briar.plugins.ImmediateExecutor; import net.sf.briar.plugins.ImmediateExecutor;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import net.sf.briar.transport.TransportModule; import net.sf.briar.transport.TransportModule;
import net.sf.briar.transport.stream.TransportStreamModule; import net.sf.briar.transport.stream.TransportStreamModule;
@@ -97,10 +96,9 @@ public class BatchConnectionReadWriteTest extends TestCase {
}; };
return Guice.createInjector(testModule, new CryptoModule(), return Guice.createInjector(testModule, new CryptoModule(),
new DatabaseModule(), new LifecycleModule(), new DatabaseModule(), new LifecycleModule(),
new ProtocolModule(), new ProtocolWritersModule(), new ProtocolModule(), new SerialModule(),
new SerialModule(), new TestDatabaseModule(dir), new TestDatabaseModule(dir), new TransportBatchModule(),
new TransportBatchModule(), new TransportModule(), new TransportModule(), new TransportStreamModule());
new TransportStreamModule());
} }
@Test @Test
@@ -132,10 +130,10 @@ public class BatchConnectionReadWriteTest extends TestCase {
alice.getInstance(ConnectionWriterFactory.class); alice.getInstance(ConnectionWriterFactory.class);
ProtocolWriterFactory protoFactory = ProtocolWriterFactory protoFactory =
alice.getInstance(ProtocolWriterFactory.class); alice.getInstance(ProtocolWriterFactory.class);
BatchTransportWriter writer = new TestBatchTransportWriter(out); BatchTransportWriter transport = new TestBatchTransportWriter(out);
OutgoingBatchConnection batchOut = new OutgoingBatchConnection( OutgoingBatchConnection batchOut = new OutgoingBatchConnection(db,
connFactory, db, protoFactory, contactId, transportIndex, connFactory, protoFactory, contactId, transportIndex,
writer); transport);
// Write whatever needs to be written // Write whatever needs to be written
batchOut.write(); batchOut.write();
// Close Alice's database // Close Alice's database
@@ -188,7 +186,7 @@ public class BatchConnectionReadWriteTest extends TestCase {
bob.getInstance(ProtocolReaderFactory.class); bob.getInstance(ProtocolReaderFactory.class);
BatchTransportReader reader = new TestBatchTransportReader(in); BatchTransportReader reader = new TestBatchTransportReader(in);
IncomingBatchConnection batchIn = new IncomingBatchConnection( IncomingBatchConnection batchIn = new IncomingBatchConnection(
new ImmediateExecutor(), connFactory, db, protoFactory, ctx, new ImmediateExecutor(), db, connFactory, protoFactory, ctx,
reader, tag); reader, tag);
// No messages should have been added yet // No messages should have been added yet
assertFalse(listener.messagesAdded); assertFalse(listener.messagesAdded);