BATCH_ID and MESSAGE_ID don't need to be structs.

This commit is contained in:
akwizgran
2011-12-07 00:38:14 +00:00
parent b7c3224618
commit 495baf8c70
11 changed files with 65 additions and 95 deletions

View File

@@ -3,17 +3,14 @@ package net.sf.briar.api.protocol;
/** Struct identifiers for encoding and decoding protocol objects. */ /** Struct identifiers for encoding and decoding protocol objects. */
public interface Types { public interface Types {
// FIXME: Batch ID, message ID don't need to be structs
static final int ACK = 0; static final int ACK = 0;
static final int AUTHOR = 1; static final int AUTHOR = 1;
static final int BATCH = 2; static final int BATCH = 2;
static final int BATCH_ID = 3; static final int GROUP = 3;
static final int GROUP = 4; static final int MESSAGE = 4;
static final int MESSAGE = 5; static final int OFFER = 5;
static final int MESSAGE_ID = 6; static final int REQUEST = 6;
static final int OFFER = 7; static final int SUBSCRIPTION_UPDATE = 7;
static final int REQUEST = 8; static final int TRANSPORT = 8;
static final int SUBSCRIPTION_UPDATE = 9; static final int TRANSPORT_UPDATE = 9;
static final int TRANSPORT = 10;
static final int TRANSPORT_UPDATE = 11;
} }

View File

@@ -8,5 +8,5 @@ public interface SerialComponent {
int getSerialisedStructIdLength(int id); int getSerialisedStructIdLength(int id);
int getSerialisedUniqueIdLength(int id); int getSerialisedUniqueIdLength();
} }

View File

@@ -1,8 +1,12 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
@@ -18,35 +22,30 @@ import net.sf.briar.api.serial.Reader;
class AckReader implements ObjectReader<Ack> { class AckReader implements ObjectReader<Ack> {
private final PacketFactory packetFactory; private final PacketFactory packetFactory;
private final ObjectReader<BatchId> batchIdReader;
AckReader(PacketFactory packetFactory) { AckReader(PacketFactory packetFactory) {
this.packetFactory = packetFactory; this.packetFactory = packetFactory;
batchIdReader = new BatchIdReader();
} }
public Ack readObject(Reader r) throws IOException { public Ack readObject(Reader r) throws IOException {
// Initialise the consumer // Initialise the consumer
Consumer counting = Consumer counting =
new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
// Read and digest the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.ACK); r.readStructId(Types.ACK);
r.addObjectReader(Types.BATCH_ID, batchIdReader); r.setMaxBytesLength(UniqueId.LENGTH);
Collection<BatchId> batches = r.readList(BatchId.class); Collection<Bytes> raw = r.readList(Bytes.class);
r.removeObjectReader(Types.BATCH_ID); r.resetMaxBytesLength();
r.removeConsumer(counting); r.removeConsumer(counting);
// Build and return the ack // Convert the byte arrays to batch IDs
return packetFactory.createAck(batches); List<BatchId> batches = new ArrayList<BatchId>();
} for(Bytes b : raw) {
if(b.getBytes().length != UniqueId.LENGTH)
private static class BatchIdReader implements ObjectReader<BatchId> { throw new FormatException();
batches.add(new BatchId(b.getBytes()));
public BatchId readObject(Reader r) throws IOException {
r.readStructId(Types.BATCH_ID);
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH) throw new FormatException();
return new BatchId(b);
} }
// Build and return the ack
return packetFactory.createAck(Collections.unmodifiableList(batches));
} }
} }

View File

@@ -1,20 +0,0 @@
package net.sf.briar.protocol;
import java.io.IOException;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class MessageIdReader implements ObjectReader<MessageId> {
public MessageId readObject(Reader r) throws IOException {
r.readStructId(Types.MESSAGE_ID);
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH) throw new FormatException();
return new MessageId(b);
}
}

View File

@@ -8,6 +8,7 @@ import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.CopyingConsumer; import net.sf.briar.api.serial.CopyingConsumer;
import net.sf.briar.api.serial.CountingConsumer; import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
@@ -15,14 +16,11 @@ import net.sf.briar.api.serial.Reader;
class MessageReader implements ObjectReader<UnverifiedMessage> { class MessageReader implements ObjectReader<UnverifiedMessage> {
private final ObjectReader<MessageId> messageIdReader;
private final ObjectReader<Group> groupReader; private final ObjectReader<Group> groupReader;
private final ObjectReader<Author> authorReader; private final ObjectReader<Author> authorReader;
MessageReader(ObjectReader<MessageId> messageIdReader, MessageReader(ObjectReader<Group> groupReader,
ObjectReader<Group> groupReader,
ObjectReader<Author> authorReader) { ObjectReader<Author> authorReader) {
this.messageIdReader = messageIdReader;
this.groupReader = groupReader; this.groupReader = groupReader;
this.authorReader = authorReader; this.authorReader = authorReader;
} }
@@ -40,9 +38,9 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
if(r.hasNull()) { if(r.hasNull()) {
r.readNull(); r.readNull();
} else { } else {
r.addObjectReader(Types.MESSAGE_ID, messageIdReader); byte[] b = r.readBytes(UniqueId.LENGTH);
parent = r.readStruct(Types.MESSAGE_ID, MessageId.class); if(b.length != UniqueId.LENGTH) throw new FormatException();
r.removeObjectReader(Types.MESSAGE_ID); parent = new MessageId(b);
} }
// Read the group, if there is one // Read the group, if there is one
Group group = null; Group group = null;
@@ -69,7 +67,8 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
if(timestamp < 0L) throw new FormatException(); if(timestamp < 0L) throw new FormatException();
// Read the salt // Read the salt
byte[] salt = r.readBytes(ProtocolConstants.SALT_LENGTH); byte[] salt = r.readBytes(ProtocolConstants.SALT_LENGTH);
if(salt.length != ProtocolConstants.SALT_LENGTH) throw new FormatException(); if(salt.length != ProtocolConstants.SALT_LENGTH)
throw new FormatException();
// Read the message body // Read the message body
byte[] body = r.readBytes(ProtocolConstants.MAX_BODY_LENGTH); byte[] body = r.readBytes(ProtocolConstants.MAX_BODY_LENGTH);
// Record the offset of the body within the message // Record the offset of the body within the message

View File

@@ -1,13 +1,19 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.CountingConsumer; import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
@@ -15,12 +21,9 @@ import net.sf.briar.api.serial.Reader;
class OfferReader implements ObjectReader<Offer> { class OfferReader implements ObjectReader<Offer> {
private final ObjectReader<MessageId> messageIdReader;
private final PacketFactory packetFactory; private final PacketFactory packetFactory;
OfferReader(ObjectReader<MessageId> messageIdReader, OfferReader(PacketFactory packetFactory) {
PacketFactory packetFactory) {
this.messageIdReader = messageIdReader;
this.packetFactory = packetFactory; this.packetFactory = packetFactory;
} }
@@ -31,11 +34,19 @@ class OfferReader implements ObjectReader<Offer> {
// Read the data // Read the data
r.addConsumer(counting); r.addConsumer(counting);
r.readStructId(Types.OFFER); r.readStructId(Types.OFFER);
r.addObjectReader(Types.MESSAGE_ID, messageIdReader); r.setMaxBytesLength(UniqueId.LENGTH);
Collection<MessageId> messages = r.readList(MessageId.class); Collection<Bytes> raw = r.readList(Bytes.class);
r.removeObjectReader(Types.MESSAGE_ID); r.resetMaxBytesLength();
r.removeConsumer(counting); r.removeConsumer(counting);
// Convert the byte arrays to message IDs
List<MessageId> messages = new ArrayList<MessageId>();
for(Bytes b : raw) {
if(b.getBytes().length != UniqueId.LENGTH)
throw new FormatException();
messages.add(new MessageId(b.getBytes()));
}
// Build and return the offer // Build and return the offer
return packetFactory.createOffer(messages); return packetFactory.createOffer(Collections.unmodifiableList(
messages));
} }
} }

View File

@@ -7,7 +7,6 @@ import net.sf.briar.api.protocol.AuthorFactory;
import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -58,23 +57,16 @@ public class ProtocolModule extends AbstractModule {
return new GroupReader(crypto, groupFactory); return new GroupReader(crypto, groupFactory);
} }
@Provides
ObjectReader<MessageId> getMessageIdReader() {
return new MessageIdReader();
}
@Provides @Provides
ObjectReader<UnverifiedMessage> getMessageReader( ObjectReader<UnverifiedMessage> getMessageReader(
ObjectReader<MessageId> messageIdReader,
ObjectReader<Group> groupReader, ObjectReader<Group> groupReader,
ObjectReader<Author> authorReader) { ObjectReader<Author> authorReader) {
return new MessageReader(messageIdReader, groupReader, authorReader); return new MessageReader(groupReader, authorReader);
} }
@Provides @Provides
ObjectReader<Offer> getOfferReader(ObjectReader<MessageId> messageIdReader, ObjectReader<Offer> getOfferReader(PacketFactory packetFactory) {
PacketFactory packetFactory) { return new OfferReader(packetFactory);
return new OfferReader(messageIdReader, packetFactory);
} }
@Provides @Provides

View File

@@ -41,7 +41,7 @@ class ProtocolWriterImpl implements ProtocolWriter {
int overhead = serial.getSerialisedStructIdLength(Types.ACK) int overhead = serial.getSerialisedStructIdLength(Types.ACK)
+ serial.getSerialisedListStartLength() + serial.getSerialisedListStartLength()
+ serial.getSerialisedListEndLength(); + serial.getSerialisedListEndLength();
int idLength = serial.getSerialisedUniqueIdLength(Types.BATCH_ID); int idLength = serial.getSerialisedUniqueIdLength();
return (packet - overhead) / idLength; return (packet - overhead) / idLength;
} }
@@ -50,7 +50,7 @@ class ProtocolWriterImpl implements ProtocolWriter {
int overhead = serial.getSerialisedStructIdLength(Types.OFFER) int overhead = serial.getSerialisedStructIdLength(Types.OFFER)
+ serial.getSerialisedListStartLength() + serial.getSerialisedListStartLength()
+ serial.getSerialisedListEndLength(); + serial.getSerialisedListEndLength();
int idLength = serial.getSerialisedUniqueIdLength(Types.MESSAGE_ID); int idLength = serial.getSerialisedUniqueIdLength();
return (packet - overhead) / idLength; return (packet - overhead) / idLength;
} }
@@ -65,10 +65,7 @@ class ProtocolWriterImpl implements ProtocolWriter {
public void writeAck(Ack a) throws IOException { public void writeAck(Ack a) throws IOException {
w.writeStructId(Types.ACK); w.writeStructId(Types.ACK);
w.writeListStart(); w.writeListStart();
for(BatchId b : a.getBatchIds()) { for(BatchId b : a.getBatchIds()) w.writeBytes(b.getBytes());
w.writeStructId(Types.BATCH_ID);
w.writeBytes(b.getBytes());
}
w.writeListEnd(); w.writeListEnd();
} }
@@ -82,10 +79,7 @@ class ProtocolWriterImpl implements ProtocolWriter {
public void writeOffer(Offer o) throws IOException { public void writeOffer(Offer o) throws IOException {
w.writeStructId(Types.OFFER); w.writeStructId(Types.OFFER);
w.writeListStart(); w.writeListStart();
for(MessageId m : o.getMessageIds()) { for(MessageId m : o.getMessageIds()) w.writeBytes(m.getBytes());
w.writeStructId(Types.MESSAGE_ID);
w.writeBytes(m.getBytes());
}
w.writeListEnd(); w.writeListEnd();
} }

View File

@@ -20,10 +20,10 @@ class SerialComponentImpl implements SerialComponent {
return id < 32 ? 1 : 2; return id < 32 ? 1 : 2;
} }
public int getSerialisedUniqueIdLength(int id) { public int getSerialisedUniqueIdLength() {
// Struct ID, BYTES tag, length spec, bytes // BYTES tag, length spec, bytes
return getSerialisedStructIdLength(id) + 1 return 1 + getSerialisedLengthSpecLength(UniqueId.LENGTH)
+ getSerialisedLengthSpecLength(UniqueId.LENGTH) + UniqueId.LENGTH; + UniqueId.LENGTH;
} }
private int getSerialisedLengthSpecLength(int length) { private int getSerialisedLengthSpecLength(int length) {

View File

@@ -107,12 +107,10 @@ public class AckReaderTest extends TestCase {
Random random = new Random(); Random random = new Random();
while(out.size() + BatchId.LENGTH + 3 while(out.size() + BatchId.LENGTH + 3
< ProtocolConstants.MAX_PACKET_LENGTH) { < ProtocolConstants.MAX_PACKET_LENGTH) {
w.writeStructId(Types.BATCH_ID);
random.nextBytes(b); random.nextBytes(b);
w.writeBytes(b); w.writeBytes(b);
} }
if(tooBig) { if(tooBig) {
w.writeStructId(Types.BATCH_ID);
random.nextBytes(b); random.nextBytes(b);
w.writeBytes(b); w.writeBytes(b);
} }

View File

@@ -52,9 +52,9 @@ public class ProtocolWriterImplTest extends TestCase {
b.set(15); b.set(15);
Request r = packetFactory.createRequest(b, 16); Request r = packetFactory.createRequest(b, 16);
w.writeRequest(r); w.writeRequest(r);
// Short user tag 8, 0 as uint7, short bytes with length 2, 0xD959 // Short user tag 6, 0 as uint7, short bytes with length 2, 0xD959
byte[] output = out.toByteArray(); byte[] output = out.toByteArray();
assertEquals("C8" + "00" + "92" + "D959", assertEquals("C6" + "00" + "92" + "D959",
StringUtils.toHexString(output)); StringUtils.toHexString(output));
} }
@@ -75,9 +75,9 @@ public class ProtocolWriterImplTest extends TestCase {
b.set(12); b.set(12);
Request r = packetFactory.createRequest(b, 13); Request r = packetFactory.createRequest(b, 13);
w.writeRequest(r); w.writeRequest(r);
// Short user tag 8, 3 as uint7, short bytes with length 2, 0x59D8 // Short user tag 6, 3 as uint7, short bytes with length 2, 0x59D8
byte[] output = out.toByteArray(); byte[] output = out.toByteArray();
assertEquals("C8" + "03" + "92" + "59D8", assertEquals("C6" + "03" + "92" + "59D8",
StringUtils.toHexString(output)); StringUtils.toHexString(output));
} }
} }