Minor protocol refactoring.

This commit is contained in:
akwizgran
2011-11-18 11:27:34 +00:00
parent 30580f71ec
commit dacaa4566d
11 changed files with 96 additions and 108 deletions

View File

@@ -3,10 +3,12 @@ package net.sf.briar.protocol;
import java.io.IOException;
import java.util.Collection;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.ObjectReader;
@@ -14,12 +16,12 @@ import net.sf.briar.api.serial.Reader;
class AckReader implements ObjectReader<Ack> {
private final ObjectReader<BatchId> batchIdReader;
private final AckFactory ackFactory;
private final ObjectReader<BatchId> batchIdReader;
AckReader(ObjectReader<BatchId> batchIdReader, AckFactory ackFactory) {
this.batchIdReader = batchIdReader;
AckReader(AckFactory ackFactory) {
this.ackFactory = ackFactory;
batchIdReader = new BatchIdReader();
}
public Ack readObject(Reader r) throws IOException {
@@ -36,4 +38,14 @@ class AckReader implements ObjectReader<Ack> {
// Build and return the ack
return ackFactory.createAck(batches);
}
private static class BatchIdReader implements ObjectReader<BatchId> {
public BatchId readObject(Reader r) throws IOException {
r.readUserDefinedId(Types.BATCH_ID);
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH) throw new FormatException();
return new BatchId(b);
}
}
}

View File

@@ -1,20 +0,0 @@
package net.sf.briar.protocol;
import java.io.IOException;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class BatchIdReader implements ObjectReader<BatchId> {
public BatchId readObject(Reader r) throws IOException {
r.readUserDefinedId(Types.BATCH_ID);
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH) throw new FormatException();
return new BatchId(b);
}
}

View File

@@ -1,5 +1,11 @@
package net.sf.briar.protocol;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_BODY_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SIGNATURE_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SUBJECT_LENGTH;
import static net.sf.briar.api.protocol.ProtocolConstants.SALT_LENGTH;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.GeneralSecurityException;
@@ -14,9 +20,8 @@ import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.CountingConsumer;
@@ -27,7 +32,7 @@ import net.sf.briar.api.serial.WriterFactory;
import com.google.inject.Inject;
class MessageEncoderImpl implements MessageEncoder {
class MessageFactoryImpl implements MessageFactory {
private final Signature authorSignature, groupSignature;
private final SecureRandom random;
@@ -35,7 +40,7 @@ class MessageEncoderImpl implements MessageEncoder {
private final WriterFactory writerFactory;
@Inject
MessageEncoderImpl(CryptoComponent crypto, WriterFactory writerFactory) {
MessageFactoryImpl(CryptoComponent crypto, WriterFactory writerFactory) {
authorSignature = crypto.getSignature();
groupSignature = crypto.getSignature();
random = crypto.getSecureRandom();
@@ -43,50 +48,49 @@ class MessageEncoderImpl implements MessageEncoder {
this.writerFactory = writerFactory;
}
public Message encodeMessage(MessageId parent, String subject, byte[] body)
public Message createMessage(MessageId parent, String subject, byte[] body)
throws IOException, GeneralSecurityException {
return encodeMessage(parent, null, null, null, null, subject, body);
return createMessage(parent, null, null, null, null, subject, body);
}
public Message encodeMessage(MessageId parent, Group group, String subject,
public Message createMessage(MessageId parent, Group group, String subject,
byte[] body) throws IOException, GeneralSecurityException {
return encodeMessage(parent, group, null, null, null, subject, body);
return createMessage(parent, group, null, null, null, subject, body);
}
public Message encodeMessage(MessageId parent, Group group,
public Message createMessage(MessageId parent, Group group,
PrivateKey groupKey, String subject, byte[] body)
throws IOException, GeneralSecurityException {
return encodeMessage(parent, group, groupKey, null, null, subject,
return createMessage(parent, group, groupKey, null, null, subject,
body);
}
public Message encodeMessage(MessageId parent, Group group, Author author,
public Message createMessage(MessageId parent, Group group, Author author,
PrivateKey authorKey, String subject, byte[] body)
throws IOException, GeneralSecurityException {
return encodeMessage(parent, group, null, author, authorKey, subject,
return createMessage(parent, group, null, author, authorKey, subject,
body);
}
public Message encodeMessage(MessageId parent, Group group,
public Message createMessage(MessageId parent, Group group,
PrivateKey groupKey, Author author, PrivateKey authorKey,
String subject, byte[] body) throws IOException,
GeneralSecurityException {
if((author == null) != (authorKey == null))
throw new IllegalArgumentException();
if((group == null || group.getPublicKey() == null) !=
(groupKey == null))
if((group == null || group.getPublicKey() == null)
!= (groupKey == null))
throw new IllegalArgumentException();
if(subject.getBytes("UTF-8").length > ProtocolConstants.MAX_SUBJECT_LENGTH)
if(subject.getBytes("UTF-8").length > MAX_SUBJECT_LENGTH)
throw new IllegalArgumentException();
if(body.length > ProtocolConstants.MAX_BODY_LENGTH)
if(body.length > MAX_BODY_LENGTH)
throw new IllegalArgumentException();
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
// Initialise the consumers
CountingConsumer counting = new CountingConsumer(
ProtocolConstants.MAX_PACKET_LENGTH);
CountingConsumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
w.addConsumer(counting);
Consumer digestingConsumer = new DigestingConsumer(messageDigest);
w.addConsumer(digestingConsumer);
@@ -113,7 +117,7 @@ class MessageEncoderImpl implements MessageEncoder {
w.writeString(subject);
long timestamp = System.currentTimeMillis();
w.writeInt64(timestamp);
byte[] salt = new byte[ProtocolConstants.SALT_LENGTH];
byte[] salt = new byte[SALT_LENGTH];
random.nextBytes(salt);
w.writeBytes(salt);
w.writeBytes(body);
@@ -124,7 +128,7 @@ class MessageEncoderImpl implements MessageEncoder {
} else {
w.removeConsumer(authorConsumer);
byte[] sig = authorSignature.sign();
if(sig.length > ProtocolConstants.MAX_SIGNATURE_LENGTH)
if(sig.length > MAX_SIGNATURE_LENGTH)
throw new IllegalArgumentException();
w.writeBytes(sig);
}
@@ -134,7 +138,7 @@ class MessageEncoderImpl implements MessageEncoder {
} else {
w.removeConsumer(groupConsumer);
byte[] sig = groupSignature.sign();
if(sig.length > ProtocolConstants.MAX_SIGNATURE_LENGTH)
if(sig.length > MAX_SIGNATURE_LENGTH)
throw new IllegalArgumentException();
w.writeBytes(sig);
}

View File

@@ -5,11 +5,10 @@ import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Author;
import net.sf.briar.api.protocol.AuthorFactory;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageFactory;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -29,7 +28,7 @@ public class ProtocolModule extends AbstractModule {
bind(AuthorFactory.class).to(AuthorFactoryImpl.class);
bind(BatchFactory.class).to(BatchFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class);
bind(MessageEncoder.class).to(MessageEncoderImpl.class);
bind(MessageFactory.class).to(MessageFactoryImpl.class);
bind(OfferFactory.class).to(OfferFactoryImpl.class);
bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class);
bind(RequestFactory.class).to(RequestFactoryImpl.class);
@@ -38,9 +37,8 @@ public class ProtocolModule extends AbstractModule {
}
@Provides
ObjectReader<Ack> getAckReader(ObjectReader<BatchId> batchIdReader,
AckFactory ackFactory) {
return new AckReader(batchIdReader, ackFactory);
ObjectReader<Ack> getAckReader(AckFactory ackFactory) {
return new AckReader(ackFactory);
}
@Provides
@@ -55,11 +53,6 @@ public class ProtocolModule extends AbstractModule {
return new BatchReader(crypto, messageReader, batchFactory);
}
@Provides
ObjectReader<BatchId> getBatchIdReader() {
return new BatchIdReader();
}
@Provides
ObjectReader<Group> getGroupReader(CryptoComponent crypto,
GroupFactory groupFactory) {