Factored out header reading and batch reading into separate classes

for easier testing.
This commit is contained in:
akwizgran
2011-07-19 19:13:27 +01:00
parent fb528a85ad
commit 6b61cfa1bc
9 changed files with 217 additions and 135 deletions

View File

@@ -0,0 +1,62 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.List;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
public class BatchReader implements ObjectReader<Batch> {
private final PublicKey publicKey;
private final Signature signature;
private final MessageDigest messageDigest;
private final ObjectReader<Message> messageReader;
private final BatchFactory batchFactory;
BatchReader(PublicKey publicKey, Signature signature,
MessageDigest messageDigest, ObjectReader<Message> messageReader,
BatchFactory batchFactory) {
this.publicKey = publicKey;
this.signature = signature;
this.messageDigest = messageDigest;
this.messageReader = messageReader;
this.batchFactory = batchFactory;
}
public Batch readObject(Reader reader) throws IOException,
GeneralSecurityException {
// Initialise the input stream
CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE);
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
messageDigest.reset();
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the signed data
reader.addConsumer(counting);
reader.addConsumer(digesting);
reader.addConsumer(signing);
reader.addObjectReader(Tags.MESSAGE, messageReader);
List<Message> messages = reader.readList(Message.class);
reader.removeObjectReader(Tags.MESSAGE);
reader.removeConsumer(signing);
// Read and verify the signature
reader.readUserDefinedTag(Tags.SIGNATURE);
byte[] sig = reader.readRaw();
reader.removeConsumer(digesting);
reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Build and return the batch
BatchId id = new BatchId(messageDigest.digest());
return batchFactory.createBatch(id, messages);
}
}

View File

@@ -2,20 +2,10 @@ package net.sf.briar.protocol;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.BundleReader;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Header;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
@@ -26,102 +16,44 @@ class BundleReaderImpl implements BundleReader {
private static enum State { START, FIRST_BATCH, MORE_BATCHES, END };
private final Reader reader;
private final PublicKey publicKey;
private final Signature signature;
private final MessageDigest messageDigest;
private final ObjectReader<Message> messageReader;
private final HeaderFactory headerFactory;
private final BatchFactory batchFactory;
private final ObjectReader<Header> headerReader;
private final ObjectReader<Batch> batchReader;
private State state = State.START;
BundleReaderImpl(Reader reader, PublicKey publicKey, Signature signature,
MessageDigest messageDigest, ObjectReader<Message> messageReader,
HeaderFactory headerFactory, BatchFactory batchFactory) {
BundleReaderImpl(Reader reader, ObjectReader<Header> headerReader,
ObjectReader<Batch> batchReader) {
this.reader = reader;
this.publicKey = publicKey;
this.signature = signature;
this.messageDigest = messageDigest;
this.messageReader = messageReader;
this.headerFactory = headerFactory;
this.batchFactory = batchFactory;
this.headerReader = headerReader;
this.batchReader = batchReader;
}
public Header getHeader() throws IOException, GeneralSecurityException {
if(state != State.START) throw new IllegalStateException();
state = State.FIRST_BATCH;
// Initialise the input stream
CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE);
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the initial tag
reader.addObjectReader(Tags.HEADER, headerReader);
reader.readUserDefinedTag(Tags.HEADER);
// Read the signed data
reader.addConsumer(counting);
reader.addConsumer(signing);
reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader());
reader.addObjectReader(Tags.GROUP_ID, new GroupIdReader());
// Acks
Collection<BatchId> acks = reader.readList(BatchId.class);
// Subs
Collection<GroupId> subs = reader.readList(GroupId.class);
// Transports
reader.readUserDefinedTag(Tags.TRANSPORTS);
Map<String, String> transports =
reader.readMap(String.class, String.class);
// Timestamp
reader.readUserDefinedTag(Tags.TIMESTAMP);
long timestamp = reader.readInt64();
if(timestamp < 0L) throw new FormatException();
reader.removeObjectReader(Tags.GROUP_ID);
reader.removeObjectReader(Tags.BATCH_ID);
reader.removeConsumer(signing);
// Read and verify the signature
reader.readUserDefinedTag(Tags.SIGNATURE);
byte[] sig = reader.readRaw();
reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Build and return the header
return headerFactory.createHeader(acks, subs, transports, timestamp);
Header h = reader.readUserDefinedObject(Tags.HEADER);
reader.removeObjectReader(Tags.HEADER);
state = State.FIRST_BATCH;
return h;
}
public Batch getNextBatch() throws IOException, GeneralSecurityException {
if(state == State.FIRST_BATCH) {
reader.readListStart();
reader.addObjectReader(Tags.BATCH, batchReader);
state = State.MORE_BATCHES;
}
if(state != State.MORE_BATCHES) throw new IllegalStateException();
if(reader.hasListEnd()) {
reader.removeObjectReader(Tags.BATCH);
reader.readListEnd();
// That should be all
if(!reader.eof()) throw new FormatException();
state = State.END;
return null;
}
// Initialise the input stream
CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE);
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
messageDigest.reset();
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the initial tag
reader.readUserDefinedTag(Tags.BATCH);
// Read the signed data
reader.addConsumer(counting);
reader.addConsumer(digesting);
reader.addConsumer(signing);
reader.addObjectReader(Tags.MESSAGE, messageReader);
List<Message> messages = reader.readList(Message.class);
reader.removeObjectReader(Tags.MESSAGE);
reader.removeConsumer(signing);
// Read and verify the signature
reader.readUserDefinedTag(Tags.SIGNATURE);
byte[] sig = reader.readRaw();
reader.removeConsumer(digesting);
reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Build and return the batch
BatchId id = new BatchId(messageDigest.digest());
return batchFactory.createBatch(id, messages);
return reader.readUserDefinedObject(Tags.BATCH);
}
public void finish() throws IOException {

View File

@@ -0,0 +1,66 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.Collection;
import java.util.Map;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Header;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class HeaderReader implements ObjectReader<Header> {
private final PublicKey publicKey;
private final Signature signature;
private final HeaderFactory headerFactory;
HeaderReader(PublicKey publicKey, Signature signature,
HeaderFactory headerFactory) {
this.publicKey = publicKey;
this.signature = signature;
this.headerFactory = headerFactory;
}
public Header readObject(Reader reader) throws IOException,
GeneralSecurityException {
// Initialise the input stream
CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE);
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the signed data
reader.addConsumer(counting);
reader.addConsumer(signing);
// Acks
reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader());
Collection<BatchId> acks = reader.readList(BatchId.class);
reader.removeObjectReader(Tags.BATCH_ID);
// Subs
reader.addObjectReader(Tags.GROUP_ID, new GroupIdReader());
Collection<GroupId> subs = reader.readList(GroupId.class);
reader.removeObjectReader(Tags.GROUP_ID);
// Transports
reader.readUserDefinedTag(Tags.TRANSPORTS);
Map<String, String> transports =
reader.readMap(String.class, String.class);
// Timestamp
reader.readUserDefinedTag(Tags.TIMESTAMP);
long timestamp = reader.readInt64();
if(timestamp < 0L) throw new FormatException();
reader.removeConsumer(signing);
// Read and verify the signature
reader.readUserDefinedTag(Tags.SIGNATURE);
byte[] sig = reader.readRaw();
reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Build and return the header
return headerFactory.createHeader(acks, subs, transports, timestamp);
}
}

View File

@@ -2,6 +2,7 @@ package net.sf.briar.serial;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -315,11 +316,13 @@ class ReaderImpl implements Reader {
|| (next & Tag.SHORT_MASK) == Tag.SHORT_LIST;
}
public List<Object> readList() throws IOException {
public List<Object> readList() throws IOException,
GeneralSecurityException {
return readList(Object.class);
}
public <E> List<E> readList(Class<E> e) throws IOException {
public <E> List<E> readList(Class<E> e) throws IOException,
GeneralSecurityException {
if(!hasList()) throw new FormatException();
if(next == Tag.LIST) {
readNext(false);
@@ -337,7 +340,8 @@ class ReaderImpl implements Reader {
}
}
private <E> List<E> readList(Class<E> e, int length) throws IOException {
private <E> List<E> readList(Class<E> e, int length) throws IOException,
GeneralSecurityException {
assert length >= 0;
List<E> list = new ArrayList<E>();
for(int i = 0; i < length; i++) list.add(readObject(e));
@@ -356,7 +360,7 @@ class ReaderImpl implements Reader {
readNext(true);
}
private Object readObject() throws IOException {
private Object readObject() throws IOException, GeneralSecurityException {
if(!started) throw new IllegalStateException();
if(hasUserDefinedTag()) {
ObjectReader<?> o = objectReaders.get(readUserDefinedTag());
@@ -383,7 +387,8 @@ class ReaderImpl implements Reader {
}
@SuppressWarnings("unchecked")
private <T> T readObject(Class<T> t) throws IOException {
private <T> T readObject(Class<T> t) throws IOException,
GeneralSecurityException {
try {
return (T) readObject();
} catch(ClassCastException e) {
@@ -417,11 +422,13 @@ class ReaderImpl implements Reader {
|| (next & Tag.SHORT_MASK) == Tag.SHORT_MAP;
}
public Map<Object, Object> readMap() throws IOException {
public Map<Object, Object> readMap() throws IOException,
GeneralSecurityException {
return readMap(Object.class, Object.class);
}
public <K, V> Map<K, V> readMap(Class<K> k, Class<V> v) throws IOException {
public <K, V> Map<K, V> readMap(Class<K> k, Class<V> v) throws IOException,
GeneralSecurityException {
if(!hasMap()) throw new FormatException();
if(next == Tag.MAP) {
readNext(false);
@@ -440,7 +447,7 @@ class ReaderImpl implements Reader {
}
private <K, V> Map<K, V> readMap(Class<K> k, Class<V> v, int size)
throws IOException {
throws IOException, GeneralSecurityException {
assert size >= 0;
Map<K, V> m = new HashMap<K, V>();
for(int i = 0; i < size; i++) m.put(readObject(k), readObject(v));
@@ -500,7 +507,8 @@ class ReaderImpl implements Reader {
if(readUserDefinedTag() != tag) throw new FormatException();
}
public <T> T readUserDefinedObject(int tag) throws IOException {
public <T> T readUserDefinedObject(int tag) throws IOException,
GeneralSecurityException {
ObjectReader<?> o = objectReaders.get(tag);
if(o == null) throw new FormatException();
try {