Nested user-defined objects (and collections of them) can now be read

by registering ObjectReaders with the Reader.
This commit is contained in:
akwizgran
2011-07-19 17:17:45 +01:00
parent a9e7cbd05c
commit fb528a85ad
22 changed files with 414 additions and 177 deletions

View File

@@ -2,6 +2,7 @@ package net.sf.briar.api.protocol;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Collection;
import java.util.Map;
import net.sf.briar.api.serial.Raw;
@@ -16,12 +17,12 @@ public interface BundleWriter {
long getRemainingCapacity() throws IOException;
/** Adds a header to the bundle. */
void addHeader(Iterable<BatchId> acks, Iterable<GroupId> subs,
void addHeader(Collection<BatchId> acks, Collection<GroupId> subs,
Map<String, String> transports) throws IOException,
GeneralSecurityException;
/** Adds a batch of messages to the bundle and returns its identifier. */
BatchId addBatch(Iterable<Raw> messages) throws IOException,
BatchId addBatch(Collection<Raw> messages) throws IOException,
GeneralSecurityException;
/** Finishes writing the bundle. */

View File

@@ -2,15 +2,17 @@ package net.sf.briar.api.protocol;
public interface Tags {
static final int HEADER = 0;
static final int BATCH_ID = 1;
static final int GROUP_ID = 2;
static final int TIMESTAMP = 3;
static final int SIGNATURE = 4;
static final int BATCH = 5;
static final int AUTHOR_ID = 1;
static final int BATCH = 2;
static final int BATCH_ID = 3;
static final int GROUP_ID = 4;
static final int HEADER = 5;
static final int MESSAGE = 6;
static final int MESSAGE_ID = 7;
static final int AUTHOR = 8;
static final int MESSAGE_BODY = 9;
static final int AUTHOR_ID = 10;
static final int MESSAGE_BODY = 7;
static final int MESSAGE_ID = 8;
static final int NICKNAME = 9;
static final int PUBLIC_KEY = 10;
static final int SIGNATURE = 12;
static final int TIMESTAMP = 13;
static final int TRANSPORTS = 14;
}

View File

@@ -0,0 +1,8 @@
package net.sf.briar.api.serial;
import java.io.IOException;
public interface ObjectReader<T> {
T readObject(Reader r) throws IOException;
}

View File

@@ -12,6 +12,9 @@ public interface Reader {
void addConsumer(Consumer c);
void removeConsumer(Consumer c);
void addObjectReader(int tag, ObjectReader<?> o);
void removeObjectReader(int tag);
boolean hasBoolean() throws IOException;
boolean readBoolean() throws IOException;
@@ -60,4 +63,5 @@ public interface Reader {
boolean hasUserDefinedTag() throws IOException;
int readUserDefinedTag() throws IOException;
void readUserDefinedTag(int tag) throws IOException;
<T> T readUserDefinedObject(int tag) throws IOException;
}

View File

@@ -1,7 +1,7 @@
package net.sf.briar.api.serial;
import java.io.IOException;
import java.util.List;
import java.util.Collection;
import java.util.Map;
public interface Writer {
@@ -25,7 +25,7 @@ public interface Writer {
void writeRaw(byte[] b) throws IOException;
void writeRaw(Raw r) throws IOException;
void writeList(List<?> l) throws IOException;
void writeList(Collection<?> c) throws IOException;
void writeListStart() throws IOException;
void writeListEnd() throws IOException;

View File

@@ -0,0 +1,18 @@
package net.sf.briar.protocol;
import java.io.IOException;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
public class BatchIdReader implements ObjectReader<BatchId> {
public BatchId readObject(Reader r) throws IOException {
byte[] b = r.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
return new BatchId(b);
}
}

View File

@@ -6,11 +6,9 @@ import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
@@ -19,8 +17,8 @@ import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Header;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
class BundleReaderImpl implements BundleReader {
@@ -31,19 +29,19 @@ class BundleReaderImpl implements BundleReader {
private final PublicKey publicKey;
private final Signature signature;
private final MessageDigest messageDigest;
private final MessageReader messageReader;
private final ObjectReader<Message> messageReader;
private final HeaderFactory headerFactory;
private final BatchFactory batchFactory;
private State state = State.START;
BundleReaderImpl(Reader reader, PublicKey publicKey, Signature signature,
MessageDigest messageDigest, MessageReader messageParser,
MessageDigest messageDigest, ObjectReader<Message> messageReader,
HeaderFactory headerFactory, BatchFactory batchFactory) {
this.reader = reader;
this.publicKey = publicKey;
this.signature = signature;
this.messageDigest = messageDigest;
this.messageReader = messageParser;
this.messageReader = messageReader;
this.headerFactory = headerFactory;
this.batchFactory = batchFactory;
}
@@ -55,37 +53,27 @@ class BundleReaderImpl implements BundleReader {
CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE);
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the initial tag
reader.readUserDefinedTag(Tags.HEADER);
// Read the signed data
reader.addConsumer(counting);
reader.addConsumer(signing);
reader.readUserDefinedTag(Tags.HEADER);
reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader());
reader.addObjectReader(Tags.GROUP_ID, new GroupIdReader());
// Acks
Set<BatchId> acks = new HashSet<BatchId>();
reader.readListStart();
while(!reader.hasListEnd()) {
reader.readUserDefinedTag(Tags.BATCH_ID);
byte[] b = reader.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
acks.add(new BatchId(b));
}
reader.readListEnd();
Collection<BatchId> acks = reader.readList(BatchId.class);
// Subs
Set<GroupId> subs = new HashSet<GroupId>();
reader.readListStart();
while(!reader.hasListEnd()) {
reader.readUserDefinedTag(Tags.GROUP_ID);
byte[] b = reader.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
subs.add(new GroupId(b));
}
reader.readListEnd();
Collection<GroupId> subs = reader.readList(GroupId.class);
// Transports
reader.readUserDefinedTag(Tags.TRANSPORTS);
Map<String, String> transports =
reader.readMap(String.class, String.class);
// Timestamp
reader.readUserDefinedTag(Tags.TIMESTAMP);
long timestamp = reader.readInt64();
if(timestamp < 0L) throw new FormatException();
reader.removeObjectReader(Tags.GROUP_ID);
reader.removeObjectReader(Tags.BATCH_ID);
reader.removeConsumer(signing);
// Read and verify the signature
reader.readUserDefinedTag(Tags.SIGNATURE);
@@ -115,17 +103,15 @@ class BundleReaderImpl implements BundleReader {
messageDigest.reset();
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the initial tag
reader.readUserDefinedTag(Tags.BATCH);
// Read the signed data
reader.addConsumer(counting);
reader.addConsumer(digesting);
reader.addConsumer(signing);
reader.readUserDefinedTag(Tags.BATCH);
List<Message> messages = new ArrayList<Message>();
reader.readListStart();
while(!reader.hasListEnd()) {
messages.add(messageReader.readMessage(reader));
}
reader.readListEnd();
reader.addObjectReader(Tags.MESSAGE, messageReader);
List<Message> messages = reader.readList(Message.class);
reader.removeObjectReader(Tags.MESSAGE);
reader.removeConsumer(signing);
// Read and verify the signature
reader.readUserDefinedTag(Tags.SIGNATURE);

View File

@@ -6,6 +6,7 @@ import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.PrivateKey;
import java.security.Signature;
import java.util.Collection;
import java.util.Map;
import net.sf.briar.api.protocol.BatchId;
@@ -44,24 +45,22 @@ class BundleWriterImpl implements BundleWriter {
return capacity - writer.getBytesWritten();
}
public void addHeader(Iterable<BatchId> acks, Iterable<GroupId> subs,
public void addHeader(Collection<BatchId> acks, Collection<GroupId> subs,
Map<String, String> transports) throws IOException,
GeneralSecurityException {
if(state != State.START) throw new IllegalStateException();
// Initialise the output stream
signature.initSign(privateKey);
// Write the initial tag
writer.writeUserDefinedTag(Tags.HEADER);
// Write the data to be signed
out.setSigning(true);
writer.writeUserDefinedTag(Tags.HEADER);
// Acks
writer.writeListStart();
for(BatchId ack : acks) ack.writeTo(writer);
writer.writeListEnd();
writer.writeList(acks);
// Subs
writer.writeListStart();
for(GroupId sub : subs) sub.writeTo(writer);
writer.writeListEnd();
writer.writeList(subs);
// Transports
writer.writeUserDefinedTag(Tags.TRANSPORTS);
writer.writeMap(transports);
// Timestamp
writer.writeUserDefinedTag(Tags.TIMESTAMP);
@@ -75,7 +74,7 @@ class BundleWriterImpl implements BundleWriter {
state = State.FIRST_BATCH;
}
public BatchId addBatch(Iterable<Raw> messages) throws IOException,
public BatchId addBatch(Collection<Raw> messages) throws IOException,
GeneralSecurityException {
if(state == State.FIRST_BATCH) {
writer.writeListStart();
@@ -85,13 +84,17 @@ class BundleWriterImpl implements BundleWriter {
// Initialise the output stream
signature.initSign(privateKey);
messageDigest.reset();
// Write the initial tag
writer.writeUserDefinedTag(Tags.BATCH);
// Write the data to be signed
out.setDigesting(true);
out.setSigning(true);
writer.writeUserDefinedTag(Tags.BATCH);
writer.writeListStart();
// Bypass the writer and write the raw messages directly
for(Raw message : messages) out.write(message.getBytes());
for(Raw message : messages) {
writer.writeUserDefinedTag(Tags.MESSAGE);
out.write(message.getBytes());
}
writer.writeListEnd();
out.setSigning(false);
// Create and write the signature

View File

@@ -0,0 +1,18 @@
package net.sf.briar.protocol;
import java.io.IOException;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
public class GroupIdReader implements ObjectReader<GroupId> {
public GroupId readObject(Reader r) throws IOException {
byte[] b = r.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
return new GroupId(b);
}
}

View File

@@ -1,7 +1,7 @@
package net.sf.briar.protocol;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.GroupId;
@@ -9,6 +9,6 @@ import net.sf.briar.api.protocol.Header;
interface HeaderFactory {
Header createHeader(Set<BatchId> acks, Set<GroupId> subs,
Header createHeader(Collection<BatchId> acks, Collection<GroupId> subs,
Map<String, String> transports, long timestamp);
}

View File

@@ -1,5 +1,7 @@
package net.sf.briar.protocol;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
@@ -9,8 +11,11 @@ import net.sf.briar.api.protocol.Header;
class HeaderFactoryImpl implements HeaderFactory {
public Header createHeader(Set<BatchId> acks, Set<GroupId> subs,
Map<String, String> transports, long timestamp) {
return new HeaderImpl(acks, subs, transports, timestamp);
public Header createHeader(Collection<BatchId> acks,
Collection<GroupId> subs, Map<String, String> transports,
long timestamp) {
Set<BatchId> ackSet = new HashSet<BatchId>(acks);
Set<GroupId> subSet = new HashSet<GroupId>(subs);
return new HeaderImpl(ackSet, subSet, transports, timestamp);
}
}

View File

@@ -33,23 +33,26 @@ class MessageEncoderImpl implements MessageEncoder {
KeyPair keyPair, byte[] body) throws IOException,
GeneralSecurityException {
long timestamp = System.currentTimeMillis();
byte[] encodedKey = keyPair.getPublic().getEncoded();
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.MESSAGE);
// Write the message
parent.writeTo(w);
group.writeTo(w);
w.writeUserDefinedTag(Tags.TIMESTAMP);
w.writeInt64(timestamp);
w.writeUserDefinedTag(Tags.AUTHOR);
w.writeUserDefinedTag(Tags.NICKNAME);
w.writeString(nick);
w.writeRaw(encodedKey);
w.writeUserDefinedTag(Tags.PUBLIC_KEY);
w.writeRaw(keyPair.getPublic().getEncoded());
w.writeUserDefinedTag(Tags.MESSAGE_BODY);
w.writeRaw(body);
// Sign the message
byte[] signable = out.toByteArray();
signature.initSign(keyPair.getPrivate());
signature.update(signable);
byte[] sig = signature.sign();
signable = null;
// Write the signature
w.writeUserDefinedTag(Tags.SIGNATURE);
w.writeRaw(sig);
byte[] raw = out.toByteArray();
@@ -61,13 +64,14 @@ class MessageEncoderImpl implements MessageEncoder {
// The author ID is the hash of the author's nick and public key
out.reset();
w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.AUTHOR);
w.writeUserDefinedTag(Tags.NICKNAME);
w.writeString(nick);
w.writeRaw(encodedKey);
w.writeUserDefinedTag(Tags.PUBLIC_KEY);
w.writeRaw(keyPair.getPublic().getEncoded());
w.close();
messageDigest.reset();
messageDigest.update(out.toByteArray());
AuthorId author = new AuthorId(messageDigest.digest());
return new MessageImpl(id, parent, group, author, timestamp, raw);
AuthorId authorId = new AuthorId(messageDigest.digest());
return new MessageImpl(id, parent, group, authorId, timestamp, raw);
}
}

View File

@@ -2,12 +2,94 @@ package net.sf.briar.protocol;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.security.spec.InvalidKeySpecException;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
interface MessageReader {
class MessageReader implements ObjectReader<Message> {
Message readMessage(Reader r) throws IOException,
GeneralSecurityException;
private final KeyParser keyParser;
private final Signature signature;
private final MessageDigest messageDigest;
MessageReader(KeyParser keyParser, Signature signature,
MessageDigest messageDigest) {
this.keyParser = keyParser;
this.signature = signature;
this.messageDigest = messageDigest;
}
public Message readObject(Reader reader) throws IOException {
CopyingConsumer copying = new CopyingConsumer();
CountingConsumer counting = new CountingConsumer(Message.MAX_SIZE);
reader.addConsumer(copying);
reader.addConsumer(counting);
// Read the parent's message ID
reader.readUserDefinedTag(Tags.MESSAGE_ID);
byte[] b = reader.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
MessageId parent = new MessageId(b);
// Read the group ID
reader.readUserDefinedTag(Tags.GROUP_ID);
b = reader.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
GroupId group = new GroupId(b);
// Read the timestamp
reader.readUserDefinedTag(Tags.TIMESTAMP);
long timestamp = reader.readInt64();
if(timestamp < 0L) throw new FormatException();
// Hash the author's nick and public key to get the author ID
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
messageDigest.reset();
reader.addConsumer(digesting);
reader.readUserDefinedTag(Tags.NICKNAME);
reader.readString();
reader.readUserDefinedTag(Tags.PUBLIC_KEY);
byte[] encodedKey = reader.readRaw();
reader.removeConsumer(digesting);
AuthorId author = new AuthorId(messageDigest.digest());
// Skip the message body
reader.readUserDefinedTag(Tags.MESSAGE_BODY);
reader.readRaw();
// Record the length of the signed data
int messageLength = (int) counting.getCount();
// Read the signature
reader.readUserDefinedTag(Tags.SIGNATURE);
byte[] sig = reader.readRaw();
reader.removeConsumer(counting);
reader.removeConsumer(copying);
// Verify the signature
PublicKey publicKey;
try {
publicKey = keyParser.parsePublicKey(encodedKey);
} catch(InvalidKeySpecException e) {
throw new FormatException();
}
byte[] raw = copying.getCopy();
try {
signature.initVerify(publicKey);
signature.update(raw, 0, messageLength);
if(!signature.verify(sig)) throw new SignatureException();
} catch(GeneralSecurityException e) {
throw new FormatException();
}
// Hash the message, including the signature, to get the message ID
messageDigest.reset();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
return new MessageImpl(id, parent, group, author, timestamp, raw);
}
}

View File

@@ -1,92 +0,0 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.security.spec.InvalidKeySpecException;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.Reader;
class MessageReaderImpl implements MessageReader {
private final KeyParser keyParser;
private final Signature signature;
private final MessageDigest messageDigest;
MessageReaderImpl(KeyParser keyParser, Signature signature,
MessageDigest messageDigest) {
this.keyParser = keyParser;
this.signature = signature;
this.messageDigest = messageDigest;
}
public Message readMessage(Reader reader) throws IOException,
GeneralSecurityException {
CopyingConsumer copying = new CopyingConsumer();
CountingConsumer counting = new CountingConsumer(Message.MAX_SIZE);
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
messageDigest.reset();
reader.addConsumer(copying);
reader.addConsumer(counting);
// Read the initial tag
reader.readUserDefinedTag(Tags.MESSAGE);
// Read the parent's message ID
reader.readUserDefinedTag(Tags.MESSAGE_ID);
byte[] b = reader.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
MessageId parent = new MessageId(b);
// Read the group ID
reader.readUserDefinedTag(Tags.GROUP_ID);
b = reader.readRaw();
if(b.length != UniqueId.LENGTH) throw new FormatException();
GroupId group = new GroupId(b);
// Read the timestamp
reader.readUserDefinedTag(Tags.TIMESTAMP);
long timestamp = reader.readInt64();
if(timestamp < 0L) throw new FormatException();
// Hash the author's nick and public key to get the author ID
reader.addConsumer(digesting);
reader.readUserDefinedTag(Tags.AUTHOR);
reader.readString();
byte[] encodedKey = reader.readRaw();
reader.removeConsumer(digesting);
AuthorId author = new AuthorId(messageDigest.digest());
// Skip the message body
reader.readUserDefinedTag(Tags.MESSAGE_BODY);
reader.readRaw();
// Record the length of the signed data
int messageLength = (int) counting.getCount();
// Read the signature
reader.readUserDefinedTag(Tags.SIGNATURE);
byte[] sig = reader.readRaw();
reader.removeConsumer(counting);
reader.removeConsumer(copying);
// Verify the signature
PublicKey publicKey;
try {
publicKey = keyParser.parsePublicKey(encodedKey);
} catch(InvalidKeySpecException e) {
throw new FormatException();
}
byte[] raw = copying.getCopy();
signature.initVerify(publicKey);
signature.update(raw, 0, messageLength);
if(!signature.verify(sig)) throw new SignatureException();
// Hash the message, including the signature, to get the message ID
messageDigest.reset();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
return new MessageImpl(id, parent, group, author, timestamp, raw);
}
}

View File

@@ -2,7 +2,6 @@ package net.sf.briar.protocol;
import net.sf.briar.api.protocol.BundleReader;
import net.sf.briar.api.protocol.BundleWriter;
import net.sf.briar.api.protocol.MessageEncoder;
import com.google.inject.AbstractModule;
@@ -14,7 +13,5 @@ public class ProtocolModule extends AbstractModule {
bind(BundleReader.class).to(BundleReaderImpl.class);
bind(BundleWriter.class).to(BundleWriterImpl.class);
bind(HeaderFactory.class).to(HeaderFactoryImpl.class);
bind(MessageEncoder.class).to(MessageEncoderImpl.class);
bind(MessageReader.class).to(MessageReaderImpl.class);
}
}

View File

@@ -9,6 +9,7 @@ import java.util.Map;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.RawByteArray;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.Tag;
@@ -18,6 +19,9 @@ class ReaderImpl implements Reader {
private static final byte[] EMPTY_BUFFER = new byte[] {};
private final InputStream in;
private final Map<Integer, ObjectReader<?>> objectReaders =
new HashMap<Integer, ObjectReader<?>>();
private Consumer[] consumers = new Consumer[] {};
private boolean started = false, eof = false;
private byte next;
@@ -72,6 +76,14 @@ class ReaderImpl implements Reader {
else throw new IllegalArgumentException();
}
public void addObjectReader(int tag, ObjectReader<?> o) {
objectReaders.put(tag, o);
}
public void removeObjectReader(int tag) {
objectReaders.remove(tag);
}
public boolean hasBoolean() throws IOException {
if(!started) readNext(true);
if(eof) return false;
@@ -346,6 +358,11 @@ class ReaderImpl implements Reader {
private Object readObject() throws IOException {
if(!started) throw new IllegalStateException();
if(hasUserDefinedTag()) {
ObjectReader<?> o = objectReaders.get(readUserDefinedTag());
if(o == null) throw new FormatException();
return o.readObject(this);
}
if(hasBoolean()) return Boolean.valueOf(readBoolean());
if(hasUint7()) return Byte.valueOf(readUint7());
if(hasInt8()) return Byte.valueOf(readInt8());
@@ -482,4 +499,16 @@ class ReaderImpl implements Reader {
public void readUserDefinedTag(int tag) throws IOException {
if(readUserDefinedTag() != tag) throw new FormatException();
}
public <T> T readUserDefinedObject(int tag) throws IOException {
ObjectReader<?> o = objectReaders.get(tag);
if(o == null) throw new FormatException();
try {
@SuppressWarnings("unchecked")
ObjectReader<T> cast = (ObjectReader<T>) o;
return cast.readObject(this);
} catch(ClassCastException e) {
throw new FormatException();
}
}
}

View File

@@ -2,12 +2,14 @@ package net.sf.briar.serial;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import net.sf.briar.api.serial.Raw;
import net.sf.briar.api.serial.Tag;
import net.sf.briar.api.serial.Writable;
import net.sf.briar.api.serial.Writer;
class WriterImpl implements Writer {
@@ -146,19 +148,20 @@ class WriterImpl implements Writer {
writeRaw(r.getBytes());
}
public void writeList(List<?> l) throws IOException {
int length = l.size();
public void writeList(Collection<?> c) throws IOException {
int length = c.size();
if(length < 16) out.write(intToByte(Tag.SHORT_LIST | length));
else {
out.write(Tag.LIST);
writeLength(length);
}
for(Object o : l) writeObject(o);
for(Object o : c) writeObject(o);
bytesWritten++;
}
private void writeObject(Object o) throws IOException {
if(o instanceof Boolean) writeBoolean((Boolean) o);
if(o instanceof Writable) ((Writable) o).writeTo(this);
else if(o instanceof Boolean) writeBoolean((Boolean) o);
else if(o instanceof Byte) writeIntAny((Byte) o);
else if(o instanceof Short) writeIntAny((Short) o);
else if(o instanceof Integer) writeIntAny((Integer) o);

View File

@@ -22,6 +22,7 @@
<test name='net.sf.briar.invitation.InvitationWorkerTest'/>
<test name='net.sf.briar.protocol.BundleReadWriteTest'/>
<test name='net.sf.briar.protocol.ConsumersTest'/>
<test name='net.sf.briar.protocol.SigningDigestingOutputStreamTest'/>
<test name='net.sf.briar.serial.ReaderImplTest'/>
<test name='net.sf.briar.serial.WriterImplTest'/>
<test name='net.sf.briar.setup.SetupWorkerTest'/>

View File

@@ -125,7 +125,7 @@ public class BundleReadWriteTest extends TestCase {
testWriteBundle();
MessageReader messageReader =
new MessageReaderImpl(keyParser, sig1, dig1);
new MessageReader(keyParser, sig1, dig1);
FileInputStream in = new FileInputStream(bundle);
Reader reader = rf.createReader(in);
BundleReader r = new BundleReaderImpl(reader, keyPair.getPublic(), sig,
@@ -158,14 +158,14 @@ public class BundleReadWriteTest extends TestCase {
testWriteBundle();
RandomAccessFile f = new RandomAccessFile(bundle, "rw");
f.seek(bundle.length() - 150);
f.seek(bundle.length() - 100);
byte b = f.readByte();
f.seek(bundle.length() - 150);
f.seek(bundle.length() - 100);
f.writeByte(b + 1);
f.close();
MessageReader messageReader =
new MessageReaderImpl(keyParser, sig1, dig1);
new MessageReader(keyParser, sig1, dig1);
FileInputStream in = new FileInputStream(bundle);
Reader reader = rf.createReader(in);
BundleReader r = new BundleReaderImpl(reader, keyPair.getPublic(), sig,

View File

@@ -0,0 +1,81 @@
package net.sf.briar.protocol;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.MessageDigest;
import java.security.Signature;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import org.junit.Before;
import org.junit.Test;
public class SigningDigestingOutputStreamTest extends TestCase {
private static final String SIGNATURE_ALGO = "SHA256withRSA";
private static final String KEY_PAIR_ALGO = "RSA";
private static final String DIGEST_ALGO = "SHA-256";
private KeyPair keyPair = null;
private Signature sig = null;
private MessageDigest dig = null;
@Before
public void setUp() throws Exception {
keyPair = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair();
sig = Signature.getInstance(SIGNATURE_ALGO);
dig = MessageDigest.getInstance(DIGEST_ALGO);
}
@Test
public void testStopAndStart() throws Exception {
byte[] input = new byte[1024];
new Random().nextBytes(input);
ByteArrayOutputStream out = new ByteArrayOutputStream(input.length);
SigningDigestingOutputStream s =
new SigningDigestingOutputStream(out, sig, dig);
sig.initSign(keyPair.getPrivate());
dig.reset();
// Sign the first 256 bytes, digest all but the last 256 bytes
s.setDigesting(true);
s.setSigning(true);
s.write(input, 0, 256);
s.setSigning(false);
s.write(input, 256, 512);
s.setDigesting(false);
s.write(input, 768, 256);
s.close();
// Get the signature and the digest
byte[] signature = sig.sign();
byte[] digest = dig.digest();
// Check that the output matches the input
assertTrue(Arrays.equals(input, out.toByteArray()));
// Check that the signature matches a signature over the first 256 bytes
sig.initSign(keyPair.getPrivate());
sig.update(input, 0, 256);
byte[] directSignature = sig.sign();
assertTrue(Arrays.equals(directSignature, signature));
// Check that the digest matches a digest over all but the last 256
// bytes
dig.reset();
dig.update(input, 0, 768);
byte[] directDigest = dig.digest();
assertTrue(Arrays.equals(directDigest, digest));
}
@Test
public void testSignatureExceptionThrowsIOException() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
SigningDigestingOutputStream s =
new SigningDigestingOutputStream(out, sig, dig);
s.setSigning(true); // Signature hasn't been initialised yet
try {
s.write((byte) 0);
assertTrue(false);
} catch(IOException expected) {};
}
}

View File

@@ -8,8 +8,10 @@ import java.util.Map;
import java.util.Map.Entry;
import junit.framework.TestCase;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Raw;
import net.sf.briar.api.serial.RawByteArray;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.util.StringUtils;
import org.junit.Test;
@@ -326,6 +328,56 @@ public class ReaderImplTest extends TestCase {
assertTrue(r.eof());
}
@Test
public void testReadUserDefinedObject() throws IOException {
setContents("C0" + "83666F6F");
// Add an object reader for a user-defined type
r.addObjectReader(0, new ObjectReader<Foo>() {
public Foo readObject(Reader r) throws IOException {
return new Foo(r.readString());
}
});
assertEquals(0, r.readUserDefinedTag());
assertEquals("foo", r.<Foo>readUserDefinedObject(0).s);
}
@Test
public void testReadListUsingObjectReader() throws IOException {
setContents("A" + "1" + "C0" + "83666F6F");
// Add an object reader for a user-defined type
r.addObjectReader(0, new ObjectReader<Foo>() {
public Foo readObject(Reader r) throws IOException {
return new Foo(r.readString());
}
});
// Check that the object reader is used for lists
List<Foo> l = r.readList(Foo.class);
assertEquals(1, l.size());
assertEquals("foo", l.get(0).s);
}
@Test
public void testReadMapUsingObjectReader() throws IOException {
setContents("B" + "1" + "C0" + "83666F6F" + "C1" + "83626172");
// Add object readers for two user-defined types
r.addObjectReader(0, new ObjectReader<Foo>() {
public Foo readObject(Reader r) throws IOException {
return new Foo(r.readString());
}
});
r.addObjectReader(1, new ObjectReader<Bar>() {
public Bar readObject(Reader r) throws IOException {
return new Bar(r.readString());
}
});
// Check that the object readers are used for maps
Map<Foo, Bar> m = r.readMap(Foo.class, Bar.class);
assertEquals(1, m.size());
Entry<Foo, Bar> e = m.entrySet().iterator().next();
assertEquals("foo", e.getKey().s);
assertEquals("bar", e.getValue().s);
}
@Test
public void testReadEmptyInput() throws IOException {
setContents("");
@@ -336,4 +388,22 @@ public class ReaderImplTest extends TestCase {
in = new ByteArrayInputStream(StringUtils.fromHexString(hex));
r = new ReaderImpl(in);
}
private static class Foo {
private final String s;
private Foo(String s) {
this.s = s;
}
}
private static class Bar {
private final String s;
private Bar(String s) {
this.s = s;
}
}
}

View File

@@ -4,12 +4,15 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import junit.framework.TestCase;
import net.sf.briar.api.serial.RawByteArray;
import net.sf.briar.api.serial.Writable;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.util.StringUtils;
import org.junit.Before;
@@ -300,6 +303,20 @@ public class WriterImplTest extends TestCase {
checkContents("EF" + "20" + "EF" + "FB7FFFFFFF");
}
@Test
public void testWriteCollectionOfWritables() throws IOException {
Writable writable = new Writable() {
public void writeTo(Writer w) throws IOException {
w.writeUserDefinedTag(0);
w.writeString("foo");
}
};
w.writeList(Collections.singleton(writable));
// SHORT_LIST tag, length 1, SHORT_USER tag (3 bits), 0 (5 bits),
// "foo" as short string
checkContents("A" + "1" + "C0" + "83666F6F");
}
private void checkContents(String hex) throws IOException {
out.flush();
out.close();