Request reader and unit test.

This commit is contained in:
akwizgran
2011-07-27 11:06:54 +01:00
parent 0933092295
commit b161e5ed1d
20 changed files with 305 additions and 48 deletions

View File

@@ -12,5 +12,5 @@ public interface Ack {
static final int MAX_SIZE = (1024 * 1024) - 100;
/** Returns the IDs of the acknowledged batches. */
Collection<BatchId> getBatches();
Collection<BatchId> getBatchIds();
}

View File

@@ -12,5 +12,5 @@ public interface Offer {
static final int MAX_SIZE = (1024 * 1024) - 100;
/** Returns the message IDs contained in the offer. */
Collection<MessageId> getMessages();
Collection<MessageId> getMessageIds();
}

View File

@@ -0,0 +1,19 @@
package net.sf.briar.api.protocol;
import java.util.BitSet;
/** A packet requesting some or all of the messages from an offer. */
public interface Request {
/**
* The maximum size of a serialised request, exlcuding encryption and
* authentication.
*/
static final int MAX_SIZE = (1024 * 1024) - 100;
/**
* Returns a sequence of bits corresponding to the sequence of messages in
* the offer, where the i^th bit is set if the i^th message should be sent.
*/
BitSet getBitmap();
}

View File

@@ -597,7 +597,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
try {
messageStatusLock.writeLock().lock();
try {
Collection<BatchId> acks = a.getBatches();
Collection<BatchId> acks = a.getBatchIds();
for(BatchId ack : acks) {
Txn txn = db.startTransaction();
try {
@@ -676,7 +676,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
try {
subscriptionLock.readLock().lock();
try {
Collection<MessageId> offered = o.getMessages();
Collection<MessageId> offered = o.getMessageIds();
BitSet request = new BitSet(offered.size());
Txn txn = db.startTransaction();
try {

View File

@@ -440,7 +440,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
if(!containsContact(c)) throw new NoSuchContactException();
synchronized(messageLock) {
synchronized(messageStatusLock) {
Collection<BatchId> acks = a.getBatches();
Collection<BatchId> acks = a.getBatchIds();
for(BatchId ack : acks) {
Txn txn = db.startTransaction();
try {
@@ -497,7 +497,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
synchronized(messageLock) {
synchronized(messageStatusLock) {
synchronized(subscriptionLock) {
Collection<MessageId> offered = o.getMessages();
Collection<MessageId> offered = o.getMessageIds();
BitSet request = new BitSet(offered.size());
Txn txn = db.startTransaction();
try {

View File

@@ -7,5 +7,5 @@ import net.sf.briar.api.protocol.BatchId;
interface AckFactory {
Ack createAck(Collection<BatchId> batches);
Ack createAck(Collection<BatchId> acked);
}

View File

@@ -7,7 +7,7 @@ import net.sf.briar.api.protocol.BatchId;
class AckFactoryImpl implements AckFactory {
public Ack createAck(Collection<BatchId> batches) {
return new AckImpl(batches);
public Ack createAck(Collection<BatchId> acked) {
return new AckImpl(acked);
}
}

View File

@@ -7,13 +7,13 @@ import net.sf.briar.api.protocol.BatchId;
class AckImpl implements Ack {
private final Collection<BatchId> batches;
private final Collection<BatchId> acked;
AckImpl(Collection<BatchId> batches) {
this.batches = batches;
AckImpl(Collection<BatchId> acked) {
this.acked = acked;
}
public Collection<BatchId> getBatches() {
return batches;
public Collection<BatchId> getBatchIds() {
return acked;
}
}

View File

@@ -7,5 +7,5 @@ import net.sf.briar.api.protocol.Offer;
interface OfferFactory {
Offer createOffer(Collection<MessageId> messages);
Offer createOffer(Collection<MessageId> offered);
}

View File

@@ -7,7 +7,7 @@ import net.sf.briar.api.protocol.Offer;
class OfferFactoryImpl implements OfferFactory {
public Offer createOffer(Collection<MessageId> messages) {
return new OfferImpl(messages);
public Offer createOffer(Collection<MessageId> offered) {
return new OfferImpl(offered);
}
}

View File

@@ -7,13 +7,13 @@ import net.sf.briar.api.protocol.Offer;
class OfferImpl implements Offer {
private final Collection<MessageId> messages;
private final Collection<MessageId> offered;
OfferImpl(Collection<MessageId> messages) {
this.messages = messages;
OfferImpl(Collection<MessageId> offered) {
this.offered = offered;
}
public Collection<MessageId> getMessages() {
return messages;
public Collection<MessageId> getMessageIds() {
return offered;
}
}

View File

@@ -23,19 +23,21 @@ public class ProtocolModule extends AbstractModule {
bind(BatchFactory.class).to(BatchFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class);
bind(OfferFactory.class).to(OfferFactoryImpl.class);
bind(RequestFactory.class).to(RequestFactoryImpl.class);
bind(SubscriptionFactory.class).to(SubscriptionFactoryImpl.class);
bind(TransportFactory.class).to(TransportFactoryImpl.class);
bind(MessageEncoder.class).to(MessageEncoderImpl.class);
}
@Provides
ObjectReader<BatchId> getBatchIdReader() {
return new BatchIdReader();
ObjectReader<Author> getAuthorReader(CryptoComponent crypto,
AuthorFactory authorFactory) {
return new AuthorReader(crypto, authorFactory);
}
@Provides
ObjectReader<MessageId> getMessageIdReader() {
return new MessageIdReader();
ObjectReader<BatchId> getBatchIdReader() {
return new BatchIdReader();
}
@Provides
@@ -45,9 +47,8 @@ public class ProtocolModule extends AbstractModule {
}
@Provides
ObjectReader<Author> getAuthorReader(CryptoComponent crypto,
AuthorFactory authorFactory) {
return new AuthorReader(crypto, authorFactory);
ObjectReader<MessageId> getMessageIdReader() {
return new MessageIdReader();
}
@Provides

View File

@@ -0,0 +1,10 @@
package net.sf.briar.protocol;
import java.util.BitSet;
import net.sf.briar.api.protocol.Request;
interface RequestFactory {
Request createRequest(BitSet requested);
}

View File

@@ -0,0 +1,12 @@
package net.sf.briar.protocol;
import java.util.BitSet;
import net.sf.briar.api.protocol.Request;
class RequestFactoryImpl implements RequestFactory {
public Request createRequest(BitSet requested) {
return new RequestImpl(requested);
}
}

View File

@@ -0,0 +1,18 @@
package net.sf.briar.protocol;
import java.util.BitSet;
import net.sf.briar.api.protocol.Request;
class RequestImpl implements Request {
private final BitSet requested;
RequestImpl(BitSet requested) {
this.requested = requested;
}
public BitSet getBitmap() {
return requested;
}
}

View File

@@ -0,0 +1,41 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.util.BitSet;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
import com.google.inject.Inject;
class RequestReader implements ObjectReader<Request> {
private final RequestFactory requestFactory;
@Inject
RequestReader(RequestFactory requestFactory) {
this.requestFactory = requestFactory;
}
public Request readObject(Reader r) throws IOException {
// Initialise the consumer
Consumer counting = new CountingConsumer(Request.MAX_SIZE);
// Read the data
r.addConsumer(counting);
r.readUserDefinedTag(Tags.REQUEST);
byte[] bitmap = r.readBytes();
r.removeConsumer(counting);
// Convert the bitmap into a BitSet
BitSet b = new BitSet(bitmap.length * 8);
for(int i = 0; i < bitmap.length; i++) {
for(int j = 0; j < 8; j++) {
byte bit = (byte) (128 >> j);
if((bitmap[i] & bit) != 0) b.set(i * 8 + j);
}
}
return requestFactory.createRequest(b);
}
}

View File

@@ -752,7 +752,7 @@ public abstract class DatabaseComponentTest extends TestCase {
allowing(database).containsContact(txn, contactId);
will(returnValue(true));
// Get the acked batches
oneOf(ack).getBatches();
oneOf(ack).getBatchIds();
will(returnValue(Collections.singletonList(batchId)));
oneOf(database).removeAckedBatch(txn, contactId, batchId);
}});
@@ -940,7 +940,7 @@ public abstract class DatabaseComponentTest extends TestCase {
allowing(database).containsContact(txn, contactId);
will(returnValue(true));
// Get the offered messages
oneOf(offer).getMessages();
oneOf(offer).getMessageIds();
will(returnValue(offered));
oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId);
will(returnValue(false)); // Not visible - request message # 0

View File

@@ -97,7 +97,7 @@ public class AckReaderTest extends TestCase {
}
private byte[] createAck(boolean tooBig) throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(Ack.MAX_SIZE);
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.ACK);
w.writeListStart();

View File

@@ -6,6 +6,7 @@ import java.io.FileOutputStream;
import java.security.KeyPair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
@@ -24,6 +25,7 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.Subscriptions;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.Transports;
@@ -32,6 +34,7 @@ import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter;
import net.sf.briar.api.protocol.writers.PacketWriterFactory;
import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionWriter;
import net.sf.briar.api.protocol.writers.TransportWriter;
import net.sf.briar.api.serial.Reader;
@@ -61,6 +64,7 @@ public class FileReadWriteTest extends TestCase {
private final AckReader ackReader;
private final BatchReader batchReader;
private final OfferReader offerReader;
private final RequestReader requestReader;
private final SubscriptionReader subscriptionReader;
private final TransportReader transportReader;
private final Author author;
@@ -82,6 +86,7 @@ public class FileReadWriteTest extends TestCase {
ackReader = i.getInstance(AckReader.class);
batchReader = i.getInstance(BatchReader.class);
offerReader = i.getInstance(OfferReader.class);
requestReader = i.getInstance(RequestReader.class);
subscriptionReader = i.getInstance(SubscriptionReader.class);
transportReader = i.getInstance(TransportReader.class);
// Create two groups: one restricted, one unrestricted
@@ -135,6 +140,12 @@ public class FileReadWriteTest extends TestCase {
assertTrue(o.writeMessageId(message3.getId()));
o.finish();
RequestWriter r = packetWriterFactory.createRequestWriter(out);
BitSet requested = new BitSet(4);
requested.set(1);
requested.set(3);
r.writeBitmap(requested, 4);
SubscriptionWriter s =
packetWriterFactory.createSubscriptionWriter(out);
Collection<Group> subs = new ArrayList<Group>();
@@ -160,35 +171,47 @@ public class FileReadWriteTest extends TestCase {
reader.addObjectReader(Tags.ACK, ackReader);
reader.addObjectReader(Tags.BATCH, batchReader);
reader.addObjectReader(Tags.OFFER, offerReader);
reader.addObjectReader(Tags.REQUEST, requestReader);
reader.addObjectReader(Tags.SUBSCRIPTIONS, subscriptionReader);
reader.addObjectReader(Tags.TRANSPORTS, transportReader);
// Read the ack
assertTrue(reader.hasUserDefined(Tags.ACK));
Ack a = reader.readUserDefined(Tags.ACK, Ack.class);
assertEquals(Collections.singletonList(ack), a.getBatches());
assertEquals(Collections.singletonList(ack), a.getBatchIds());
// Read the batch
assertTrue(reader.hasUserDefined(Tags.BATCH));
Batch b = reader.readUserDefined(Tags.BATCH, Batch.class);
Collection<Message> messages = b.getMessages();
assertEquals(4, messages.size());
Iterator<Message> i = messages.iterator();
checkMessageEquality(message, i.next());
checkMessageEquality(message1, i.next());
checkMessageEquality(message2, i.next());
checkMessageEquality(message3, i.next());
Iterator<Message> it = messages.iterator();
checkMessageEquality(message, it.next());
checkMessageEquality(message1, it.next());
checkMessageEquality(message2, it.next());
checkMessageEquality(message3, it.next());
// Read the offer
assertTrue(reader.hasUserDefined(Tags.OFFER));
Offer o = reader.readUserDefined(Tags.OFFER, Offer.class);
Collection<MessageId> ids = o.getMessages();
assertEquals(4, ids.size());
Iterator<MessageId> i1 = ids.iterator();
assertEquals(message.getId(), i1.next());
assertEquals(message1.getId(), i1.next());
assertEquals(message2.getId(), i1.next());
assertEquals(message3.getId(), i1.next());
Collection<MessageId> offered = o.getMessageIds();
assertEquals(4, offered.size());
Iterator<MessageId> it1 = offered.iterator();
assertEquals(message.getId(), it1.next());
assertEquals(message1.getId(), it1.next());
assertEquals(message2.getId(), it1.next());
assertEquals(message3.getId(), it1.next());
// Read the request
assertTrue(reader.hasUserDefined(Tags.REQUEST));
Request r = reader.readUserDefined(Tags.REQUEST, Request.class);
BitSet requested = r.getBitmap();
assertFalse(requested.get(0));
assertTrue(requested.get(1));
assertFalse(requested.get(2));
assertTrue(requested.get(3));
// If there are any padding bits, they should all be zero
for(int i = 4; i < requested.size(); i++) assertFalse(requested.get(i));
// Read the subscriptions update
assertTrue(reader.hasUserDefined(Tags.SUBSCRIPTIONS));
@@ -196,9 +219,9 @@ public class FileReadWriteTest extends TestCase {
Subscriptions.class);
Collection<Group> subs = s.getSubscriptions();
assertEquals(2, subs.size());
Iterator<Group> i2 = subs.iterator();
checkGroupEquality(group, i2.next());
checkGroupEquality(group1, i2.next());
Iterator<Group> it2 = subs.iterator();
checkGroupEquality(group, it2.next());
checkGroupEquality(group1, it2.next());
assertTrue(s.getTimestamp() > start);
assertTrue(s.getTimestamp() <= System.currentTimeMillis());

View File

@@ -0,0 +1,133 @@
package net.sf.briar.protocol;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.BitSet;
import junit.framework.TestCase;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.serial.SerialModule;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class RequestReaderTest extends TestCase {
private final ReaderFactory readerFactory;
private final WriterFactory writerFactory;
private final Mockery context;
public RequestReaderTest() throws Exception {
super();
Injector i = Guice.createInjector(new SerialModule());
readerFactory = i.getInstance(ReaderFactory.class);
writerFactory = i.getInstance(WriterFactory.class);
context = new Mockery();
}
@Test
public void testFormatExceptionIfRequestIsTooLarge() throws Exception {
RequestFactory requestFactory = context.mock(RequestFactory.class);
RequestReader requestReader = new RequestReader(requestFactory);
byte[] b = createRequest(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Tags.REQUEST, requestReader);
try {
reader.readUserDefined(Tags.REQUEST, Request.class);
assertTrue(false);
} catch(FormatException expected) {}
context.assertIsSatisfied();
}
@Test
public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception {
final RequestFactory requestFactory =
context.mock(RequestFactory.class);
RequestReader requestReader = new RequestReader(requestFactory);
final Request request = context.mock(Request.class);
context.checking(new Expectations() {{
oneOf(requestFactory).createRequest(with(any(BitSet.class)));
will(returnValue(request));
}});
byte[] b = createRequest(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
Reader reader = readerFactory.createReader(in);
reader.addObjectReader(Tags.REQUEST, requestReader);
assertEquals(request, reader.readUserDefined(Tags.REQUEST,
Request.class));
context.assertIsSatisfied();
}
@Test
public void testBitmapDecoding() throws Exception {
// Test sizes from 0 to 1000 bits
for(int i = 0; i < 1000; i++) {
// Create a BitSet of size i with one in ten bits set (on average)
BitSet requested = new BitSet(i);
for(int j = 0; j < i; j++) if(Math.random() < 0.1) requested.set(j);
// Encode the BitSet as a bitmap
int bytes = i % 8 == 0 ? i / 8 : i / 8 + 1;
byte[] bitmap = new byte[bytes];
for(int j = 0; j < i; j++) {
if(requested.get(j)) {
int offset = j / 8;
byte bit = (byte) (128 >> j % 8);
bitmap[offset] |= bit;
}
}
// Create a serialised request containing the bitmap
byte[] b = createRequest(bitmap);
// Deserialise the request
ByteArrayInputStream in = new ByteArrayInputStream(b);
Reader reader = readerFactory.createReader(in);
RequestReader requestReader =
new RequestReader(new RequestFactoryImpl());
reader.addObjectReader(Tags.REQUEST, requestReader);
Request r = reader.readUserDefined(Tags.REQUEST, Request.class);
BitSet decoded = r.getBitmap();
// Check that the decoded BitSet matches the original - we can't
// use equals() because of padding, but the first i bits should
// match and the cardinalities should be equal, indicating that no
// padding bits are set
for(int j = 0; j < i; j++) {
assertEquals(requested.get(j), decoded.get(j));
}
assertEquals(requested.cardinality(), decoded.cardinality());
}
}
private byte[] createRequest(boolean tooBig) throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.REQUEST);
// Allow one byte for the REQUEST tag, one byte for the BYTES tag, and
// five bytes for the length as an int32
if(tooBig) w.writeBytes(new byte[Request.MAX_SIZE - 6]);
else w.writeBytes(new byte[Request.MAX_SIZE - 7]);
assertEquals(tooBig, out.size() > Request.MAX_SIZE);
return out.toByteArray();
}
private byte[] createRequest(byte[] bitmap) throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.REQUEST);
w.writeBytes(bitmap);
return out.toByteArray();
}
}