Cleaned up serial and protocol packages in preparation for user-defined types.

This commit is contained in:
akwizgran
2011-07-18 14:33:41 +01:00
parent 308a7017be
commit 0bc8a31749
20 changed files with 378 additions and 495 deletions

View File

@@ -0,0 +1,12 @@
package net.sf.briar.api.serial;
import java.io.IOException;
public interface Consumer {
void write(byte b) throws IOException;
void write(byte[] b) throws IOException;
void write(byte[] b, int off, int len) throws IOException;
}

View File

@@ -7,11 +7,11 @@ import java.util.Map;
public interface Reader {
boolean eof() throws IOException;
void setReadLimit(long limit);
void resetReadLimit();
long getRawBytesRead();
void close() throws IOException;
void addConsumer(Consumer c);
void removeConsumer(Consumer c);
boolean hasBoolean() throws IOException;
boolean readBoolean() throws IOException;
@@ -33,8 +33,8 @@ public interface Reader {
boolean hasFloat64() throws IOException;
double readFloat64() throws IOException;
boolean hasUtf8() throws IOException;
String readUtf8() throws IOException;
boolean hasString() throws IOException;
String readString() throws IOException;
boolean hasRaw() throws IOException;
byte[] readRaw() throws IOException;

View File

@@ -2,11 +2,29 @@ package net.sf.briar.api.serial;
public interface Tag {
public static final byte FALSE = -1, TRUE = -2;
public static final byte INT8 = -3, INT16 = -4, INT32 = -5, INT64 = -6;
public static final byte FLOAT32 = -7, FLOAT64 = -8;
public static final byte UTF8 = -9, RAW = -10;
public static final byte LIST_DEF = -11, MAP_DEF = -12;
public static final byte LIST_INDEF = -13, MAP_INDEF = -14, END = -15;
public static final byte NULL = -16;
public static final byte FALSE = -1; // 1111 1111
public static final byte TRUE = -2; // 1111 1110
public static final byte INT8 = -3; // 1111 1101
public static final byte INT16 = -4; // 1111 1100
public static final byte INT32 = -5; // 1111 1011
public static final byte INT64 = -6; // 1111 1010
public static final byte FLOAT32 = -7; // 1111 1001
public static final byte FLOAT64 = -8; // 1111 1000
public static final byte STRING = -9; // 1111 0111
public static final byte RAW = -10; // 1111 0110
public static final byte LIST = -11; // 1111 0101
public static final byte MAP = -12; // 1111 0100
public static final byte LIST_START = -13; // 1111 0011
public static final byte MAP_START = -14; // 1111 0010
public static final byte END = -15; // 1111 0001
public static final byte NULL = -16; // 1111 0000
public static final int SHORT_MASK = 0xF0; // Match first four bits
public static final int SHORT_STRING = 0x80; // 1000 xxxx
public static final int SHORT_RAW = 0x90; // 1001 xxxx
public static final int SHORT_LIST = 0xA0; // 1010 xxxx
public static final int SHORT_MAP = 0xB0; // 1011 xxxx
public static final int USER_MASK = 0xE0; // Match first three bits
public static final int USER = 0xC0; // 110x xxxx
public static final byte USER_EXT = -32; // 1110 0000
}

View File

@@ -6,7 +6,7 @@ import java.util.Map;
public interface Writer {
long getRawBytesWritten();
long getBytesWritten();
void close() throws IOException;
void writeBoolean(boolean b) throws IOException;
@@ -21,7 +21,7 @@ public interface Writer {
void writeFloat32(float f) throws IOException;
void writeFloat64(double d) throws IOException;
void writeUtf8(String s) throws IOException;
void writeString(String s) throws IOException;
void writeRaw(byte[] b) throws IOException;
void writeRaw(Raw r) throws IOException;

View File

@@ -1,7 +1,6 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.PublicKey;
@@ -24,16 +23,12 @@ import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.Raw;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
import com.google.inject.Inject;
class BundleReaderImpl implements BundleReader {
private static enum State { START, FIRST_BATCH, MORE_BATCHES, END };
private final SigningDigestingInputStream in;
private final Reader r;
private final Reader reader;
private final PublicKey publicKey;
private final Signature signature;
private final MessageDigest messageDigest;
@@ -42,13 +37,10 @@ class BundleReaderImpl implements BundleReader {
private final BatchFactory batchFactory;
private State state = State.START;
@Inject
BundleReaderImpl(InputStream in, ReaderFactory readerFactory,
PublicKey publicKey, Signature signature,
BundleReaderImpl(Reader reader, PublicKey publicKey, Signature signature,
MessageDigest messageDigest, MessageParser messageParser,
HeaderFactory headerFactory, BatchFactory batchFactory) {
this.in = new SigningDigestingInputStream(in, signature, messageDigest);
r = readerFactory.createReader(this.in);
this.reader = reader;
this.publicKey = publicKey;
this.signature = signature;
this.messageDigest = messageDigest;
@@ -61,29 +53,31 @@ class BundleReaderImpl implements BundleReader {
if(state != State.START) throw new IllegalStateException();
state = State.FIRST_BATCH;
// Initialise the input stream
CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE);
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
messageDigest.reset();
// Read the signed data
in.setSigning(true);
r.setReadLimit(Header.MAX_SIZE);
reader.addConsumer(counting);
reader.addConsumer(signing);
Set<BatchId> acks = new HashSet<BatchId>();
for(Raw raw : r.readList(Raw.class)) {
for(Raw raw : reader.readList(Raw.class)) {
byte[] b = raw.getBytes();
if(b.length != UniqueId.LENGTH) throw new FormatException();
acks.add(new BatchId(b));
}
Set<GroupId> subs = new HashSet<GroupId>();
for(Raw raw : r.readList(Raw.class)) {
for(Raw raw : reader.readList(Raw.class)) {
byte[] b = raw.getBytes();
if(b.length != UniqueId.LENGTH) throw new FormatException();
subs.add(new GroupId(b));
}
Map<String, String> transports =
r.readMap(String.class, String.class);
long timestamp = r.readInt64();
in.setSigning(false);
reader.readMap(String.class, String.class);
long timestamp = reader.readInt64();
reader.removeConsumer(signing);
// Read and verify the signature
byte[] sig = r.readRaw();
byte[] sig = reader.readRaw();
reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Build and return the header
return headerFactory.createHeader(acks, subs, transports, timestamp);
@@ -91,29 +85,33 @@ class BundleReaderImpl implements BundleReader {
public Batch getNextBatch() throws IOException, GeneralSecurityException {
if(state == State.FIRST_BATCH) {
r.readListStart();
reader.readListStart();
state = State.MORE_BATCHES;
}
if(state != State.MORE_BATCHES) throw new IllegalStateException();
if(r.hasListEnd()) {
r.readListEnd();
if(reader.hasListEnd()) {
reader.readListEnd();
// That should be all
if(!r.eof()) throw new FormatException();
if(!reader.eof()) throw new FormatException();
state = State.END;
return null;
}
// Initialise the input stream
signature.initVerify(publicKey);
CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE);
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
messageDigest.reset();
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the signed data
in.setDigesting(true);
in.setSigning(true);
r.setReadLimit(Batch.MAX_SIZE);
List<Raw> rawMessages = r.readList(Raw.class);
in.setSigning(false);
reader.addConsumer(counting);
reader.addConsumer(digesting);
reader.addConsumer(signing);
List<Raw> rawMessages = reader.readList(Raw.class);
reader.removeConsumer(signing);
// Read and verify the signature
byte[] sig = r.readRaw();
in.setDigesting(false);
byte[] sig = reader.readRaw();
reader.removeConsumer(digesting);
reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Parse the messages
List<Message> messages = new ArrayList<Message>(rawMessages.size());
@@ -127,6 +125,6 @@ class BundleReaderImpl implements BundleReader {
}
public void finish() throws IOException {
r.close();
reader.close();
}
}

View File

@@ -40,7 +40,7 @@ class BundleWriterImpl implements BundleWriter {
}
public long getRemainingCapacity() {
return capacity - w.getRawBytesWritten();
return capacity - w.getBytesWritten();
}
public void addHeader(Iterable<BatchId> acks, Iterable<GroupId> subs,

View File

@@ -0,0 +1,39 @@
package net.sf.briar.protocol;
import java.io.IOException;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.FormatException;
/**
* A consumer that counts the number of bytes consumed and throws a
* FormatException if the count exceeds a given limit.
*/
class CountingConsumer implements Consumer {
private final long limit;
private long count = 0L;
CountingConsumer(long limit) {
this.limit = limit;
}
long getCount() {
return count;
}
public void write(byte b) throws IOException {
count++;
if(count > limit) throw new FormatException();
}
public void write(byte[] b) throws IOException {
count += b.length;
if(count > limit) throw new FormatException();
}
public void write(byte[] b, int off, int len) throws IOException {
count += len;
if(count > limit) throw new FormatException();
}
}

View File

@@ -0,0 +1,28 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.MessageDigest;
import net.sf.briar.api.serial.Consumer;
/** A consumer that passes its input through a message digest. */
class DigestingConsumer implements Consumer {
private final MessageDigest messageDigest;
DigestingConsumer(MessageDigest messageDigest) {
this.messageDigest = messageDigest;
}
public void write(byte b) throws IOException {
messageDigest.update(b);
}
public void write(byte[] b) throws IOException {
messageDigest.update(b);
}
public void write(byte[] b, int off, int len) throws IOException {
messageDigest.update(b, off, len);
}
}

View File

@@ -38,7 +38,7 @@ class MessageEncoderImpl implements MessageEncoder {
w.writeRaw(parent);
w.writeRaw(group);
w.writeInt64(timestamp);
w.writeUtf8(nick);
w.writeString(nick);
w.writeRaw(encodedKey);
w.writeRaw(body);
byte[] signable = out.toByteArray();

View File

@@ -40,6 +40,8 @@ class MessageParserImpl implements MessageParser {
if(raw.length > Message.MAX_SIZE) throw new FormatException();
ByteArrayInputStream in = new ByteArrayInputStream(raw);
Reader r = readerFactory.createReader(in);
CountingConsumer counting = new CountingConsumer(Message.MAX_SIZE);
r.addConsumer(counting);
// Read the parent message ID
byte[] idBytes = r.readRaw();
if(idBytes.length != UniqueId.LENGTH) throw new FormatException();
@@ -51,7 +53,7 @@ class MessageParserImpl implements MessageParser {
// Read the timestamp
long timestamp = r.readInt64();
// Hash the author's nick and public key to get the author ID
String nick = r.readUtf8();
String nick = r.readString();
byte[] encodedKey = r.readRaw();
messageDigest.reset();
messageDigest.update(nick.getBytes("UTF-8"));
@@ -61,7 +63,8 @@ class MessageParserImpl implements MessageParser {
// Skip the message body
r.readRaw();
// Record the length of the signed data
int messageLength = (int) r.getRawBytesRead();
int messageLength = (int) counting.getCount();
r.removeConsumer(counting);
// Read the signature
byte[] sig = r.readRaw();
// That should be all

View File

@@ -0,0 +1,41 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.Signature;
import java.security.SignatureException;
import net.sf.briar.api.serial.Consumer;
/** A consumer that passes its input through a signature. */
class SigningConsumer implements Consumer {
private final Signature signature;
SigningConsumer(Signature signature) {
this.signature = signature;
}
public void write(byte b) throws IOException {
try {
signature.update(b);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
public void write(byte[] b) throws IOException {
try {
signature.update(b);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
public void write(byte[] b, int off, int len) throws IOException {
try {
signature.update(b, off, len);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
}

View File

@@ -1,116 +0,0 @@
package net.sf.briar.protocol;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.MessageDigest;
import java.security.Signature;
import java.security.SignatureException;
/**
* An input stream that passes its input through a signature and a message
* digest. The signature and message digest lag behind the input by one byte
* until the end of the input is reached, to allow users of this class to
* maintain one byte of lookahead without affecting the signature or digest.
*/
class SigningDigestingInputStream extends FilterInputStream {
private final Signature signature;
private final MessageDigest messageDigest;
private byte nextByte = 0;
private boolean started = false, eof = false;
private boolean signing = false, digesting = false;
protected SigningDigestingInputStream(InputStream in, Signature signature,
MessageDigest messageDigest) {
super(in);
this.signature = signature;
this.messageDigest = messageDigest;
}
public void setSigning(boolean signing) {
this.signing = signing;
}
public void setDigesting(boolean digesting) {
this.digesting = digesting;
}
private void write(byte b) throws IOException {
if(signing) {
try {
signature.update(b);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
if(digesting) messageDigest.update(b);
}
private void write(byte[] b, int off, int len) throws IOException {
if(signing) {
try {
signature.update(b, off, len);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
if(digesting) messageDigest.update(b, off, len);
}
@Override
public void mark(int readLimit) {
throw new UnsupportedOperationException();
}
@Override
public boolean markSupported() {
return false;
}
@Override
public int read() throws IOException {
if(eof) return -1;
if(started) write(nextByte);
started = true;
int i = in.read();
if(i == -1) {
eof = true;
return -1;
}
nextByte = (byte) (i > 127 ? i - 256 : i);
return i;
}
@Override
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
if(eof) return -1;
if(started) write(nextByte);
started = true;
int read = in.read(b, off, len);
if(read == -1) {
eof = true;
return -1;
}
if(read > 0) {
write(b, off, read - 1);
nextByte = b[off + read - 1];
}
return read;
}
@Override
public void reset() {
throw new UnsupportedOperationException();
}
@Override
public long skip(long n) {
throw new UnsupportedOperationException();
}
}

View File

@@ -7,6 +7,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.RawByteArray;
import net.sf.briar.api.serial.Reader;
@@ -14,12 +15,10 @@ import net.sf.briar.api.serial.Tag;
class ReaderImpl implements Reader {
private static final int TOO_LARGE_TO_KEEP = 4096;
private final InputStream in;
private boolean started = false, eof = false, readLimited = false;
private Consumer[] consumers = new Consumer[] {};
private boolean started = false, eof = false;
private byte next;
private long rawBytesRead = 0L, readLimit = 0L;
private byte[] buf = null;
ReaderImpl(InputStream in) {
@@ -32,38 +31,45 @@ class ReaderImpl implements Reader {
}
private byte readNext(boolean eofAcceptable) throws IOException {
assert !eof;
if(started) for(Consumer c : consumers) c.write(next);
started = true;
int i = in.read();
if(i == -1) {
eof = true;
if(!eofAcceptable) throw new FormatException();
} else rawBytesRead++;
started = true;
}
if(i > 127) i -= 256;
next = (byte) i;
return next;
}
public void setReadLimit(long limit) {
assert limit >= 0L && limit < Long.MAX_VALUE;
readLimited = true;
readLimit = limit;
}
public void resetReadLimit() {
readLimited = false;
readLimit = 0L;
}
public long getRawBytesRead() {
if(eof) return rawBytesRead;
else if(started) return rawBytesRead - 1L; // Exclude lookahead byte
else return 0L;
}
public void close() throws IOException {
buf = null;
in.close();
}
public void addConsumer(Consumer c) {
Consumer[] newConsumers = new Consumer[consumers.length + 1];
System.arraycopy(consumers, 0, newConsumers, 0, consumers.length);
newConsumers[consumers.length] = c;
consumers = newConsumers;
}
public void removeConsumer(Consumer c) {
if(consumers.length == 0) throw new IllegalArgumentException();
Consumer[] newConsumers = new Consumer[consumers.length - 1];
boolean found = false;
for(int src = 0, dest = 0; src < consumers.length; src++, dest++) {
if(!found && consumers[src].equals(c)) {
found = true;
src++;
} else newConsumers[dest] = consumers[src];
}
if(found) consumers = newConsumers;
else throw new IllegalArgumentException();
}
public boolean hasBoolean() throws IOException {
if(!started) readNext(true);
if(eof) return false;
@@ -140,15 +146,20 @@ class ReaderImpl implements Reader {
assert length > 0;
if(buf == null || buf.length < length) buf = new byte[length];
buf[0] = next;
int offset = 1, read = 0;
int offset = 1;
while(offset < length) {
read = in.read(buf, offset, length - offset);
int read = in.read(buf, offset, length - offset);
if(read == -1) break;
offset += read;
rawBytesRead += read;
}
if(offset < length) throw new FormatException();
readNext(true);
// Feed the hungry mouths
for(Consumer c : consumers) c.write(buf, 0, length);
// Read the lookahead byte
int i = in.read();
if(i == -1) eof = true;
if(i > 127) i -= 256;
next = (byte) i;
}
public boolean hasInt64() throws IOException {
@@ -212,31 +223,40 @@ class ReaderImpl implements Reader {
return Double.longBitsToDouble(readInt64Bits());
}
public boolean hasUtf8() throws IOException {
public boolean hasString() throws IOException {
if(!started) readNext(true);
if(eof) return false;
return next == Tag.UTF8;
return next == Tag.STRING;
}
public String readUtf8() throws IOException {
if(!hasUtf8()) throw new FormatException();
public String readString() throws IOException {
if(!hasString()) throw new FormatException();
readNext(false);
long l = readIntAny();
if(l < 0 || l > Integer.MAX_VALUE) throw new FormatException();
int length = (int) l;
int length = readLength();
if(length == 0) return "";
checkLimit(length);
readIntoBuffer(length);
String s = new String(buf, 0, length, "UTF-8");
if(length >= TOO_LARGE_TO_KEEP) buf = null;
return s;
return new String(buf, 0, length, "UTF-8");
}
private boolean hasLength() throws IOException {
if(!started) readNext(true);
if(eof) return false;
return next >= 0 || next == Tag.INT8 || next == Tag.INT16
|| next == Tag.INT32;
}
private int readLength() throws IOException {
if(!hasLength()) throw new FormatException();
if(next >= 0) return readUint7();
if(next == Tag.INT8) return readInt8();
if(next == Tag.INT16) return readInt16();
if(next == Tag.INT32) return readInt32();
throw new IllegalStateException();
}
private void checkLimit(long bytes) throws FormatException {
if(readLimited) {
if(bytes > readLimit) throw new FormatException();
readLimit -= bytes;
}
// FIXME
}
public boolean hasRaw() throws IOException {
@@ -248,9 +268,7 @@ class ReaderImpl implements Reader {
public byte[] readRaw() throws IOException {
if(!hasRaw()) throw new FormatException();
readNext(false);
long l = readIntAny();
if(l < 0 || l > Integer.MAX_VALUE) throw new FormatException();
int length = (int) l;
int length = readLength();
if(length == 0) return new byte[] {};
checkLimit(length);
readIntoBuffer(length);
@@ -262,7 +280,7 @@ class ReaderImpl implements Reader {
public boolean hasList() throws IOException {
if(!started) readNext(true);
if(eof) return false;
return next == Tag.LIST_DEF || next == Tag.LIST_INDEF;
return next == Tag.LIST || next == Tag.LIST_START;
}
public List<Object> readList() throws IOException {
@@ -271,13 +289,11 @@ class ReaderImpl implements Reader {
public <E> List<E> readList(Class<E> e) throws IOException {
if(!hasList()) throw new FormatException();
boolean definite = next == Tag.LIST_DEF;
boolean definite = next == Tag.LIST;
readNext(false);
List<E> list = new ArrayList<E>();
if(definite) {
long l = readIntAny();
if(l < 0 || l > Integer.MAX_VALUE) throw new FormatException();
int length = (int) l;
int length = readLength();
for(int i = 0; i < length; i++) list.add(readObject(e));
} else {
while(!hasEnd()) list.add(readObject(e));
@@ -299,6 +315,7 @@ class ReaderImpl implements Reader {
}
private Object readObject() throws IOException {
// FIXME: Use a switch statement
if(!started) throw new IllegalStateException();
if(hasBoolean()) return Boolean.valueOf(readBoolean());
if(hasUint7()) return Byte.valueOf(readUint7());
@@ -308,7 +325,7 @@ class ReaderImpl implements Reader {
if(hasInt64()) return Long.valueOf(readInt64());
if(hasFloat32()) return Float.valueOf(readFloat32());
if(hasFloat64()) return Double.valueOf(readFloat64());
if(hasUtf8()) return readUtf8();
if(hasString()) return readString();
if(hasRaw()) return new RawByteArray(readRaw());
if(hasList()) return readList();
if(hasMap()) return readMap();
@@ -331,7 +348,7 @@ class ReaderImpl implements Reader {
public boolean hasListStart() throws IOException {
if(!started) readNext(true);
if(eof) return false;
return next == Tag.LIST_INDEF;
return next == Tag.LIST_START;
}
public void readListStart() throws IOException {
@@ -350,7 +367,7 @@ class ReaderImpl implements Reader {
public boolean hasMap() throws IOException {
if(!started) readNext(true);
if(eof) return false;
return next == Tag.MAP_DEF || next == Tag.MAP_INDEF;
return next == Tag.MAP || next == Tag.MAP_START;
}
public Map<Object, Object> readMap() throws IOException {
@@ -359,13 +376,11 @@ class ReaderImpl implements Reader {
public <K, V> Map<K, V> readMap(Class<K> k, Class<V> v) throws IOException {
if(!hasMap()) throw new FormatException();
boolean definite = next == Tag.MAP_DEF;
boolean definite = next == Tag.MAP;
readNext(false);
Map<K, V> m = new HashMap<K, V>();
if(definite) {
long l = readIntAny();
if(l < 0 || l > Integer.MAX_VALUE) throw new FormatException();
int length = (int) l;
int length = readLength();
for(int i = 0; i < length; i++) m.put(readObject(k), readObject(v));
} else {
while(!hasEnd()) m.put(readObject(k), readObject(v));
@@ -377,7 +392,7 @@ class ReaderImpl implements Reader {
public boolean hasMapStart() throws IOException {
if(!started) readNext(true);
if(eof) return false;
return next == Tag.MAP_INDEF;
return next == Tag.MAP_START;
}
public void readMapStart() throws IOException {

View File

@@ -13,14 +13,14 @@ import net.sf.briar.api.serial.Writer;
class WriterImpl implements Writer {
private final OutputStream out;
private long rawBytesWritten = 0L;
private long bytesWritten = 0L;
WriterImpl(OutputStream out) {
this.out = out;
}
public long getRawBytesWritten() {
return rawBytesWritten;
public long getBytesWritten() {
return bytesWritten;
}
public void close() throws IOException {
@@ -31,32 +31,32 @@ class WriterImpl implements Writer {
public void writeBoolean(boolean b) throws IOException {
if(b) out.write(Tag.TRUE);
else out.write(Tag.FALSE);
rawBytesWritten++;
bytesWritten++;
}
public void writeUint7(byte b) throws IOException {
if(b < 0) throw new IllegalArgumentException();
out.write(b);
rawBytesWritten++;
bytesWritten++;
}
public void writeInt8(byte b) throws IOException {
out.write(Tag.INT8);
out.write(b);
rawBytesWritten += 2;
bytesWritten += 2;
}
public void writeInt16(short s) throws IOException {
out.write(Tag.INT16);
out.write((byte) (s >> 8));
out.write((byte) ((s << 8) >> 8));
rawBytesWritten += 3;
bytesWritten += 3;
}
public void writeInt32(int i) throws IOException {
out.write(Tag.INT32);
writeInt32Bits(i);
rawBytesWritten += 5;
bytesWritten += 5;
}
private void writeInt32Bits(int i) throws IOException {
@@ -69,7 +69,7 @@ class WriterImpl implements Writer {
public void writeInt64(long l) throws IOException {
out.write(Tag.INT64);
writeInt64Bits(l);
rawBytesWritten += 9;
bytesWritten += 9;
}
private void writeInt64Bits(long l) throws IOException {
@@ -98,28 +98,28 @@ class WriterImpl implements Writer {
public void writeFloat32(float f) throws IOException {
out.write(Tag.FLOAT32);
writeInt32Bits(Float.floatToRawIntBits(f));
rawBytesWritten += 5;
bytesWritten += 5;
}
public void writeFloat64(double d) throws IOException {
out.write(Tag.FLOAT64);
writeInt64Bits(Double.doubleToRawLongBits(d));
rawBytesWritten += 9;
bytesWritten += 9;
}
public void writeUtf8(String s) throws IOException {
out.write(Tag.UTF8);
public void writeString(String s) throws IOException {
out.write(Tag.STRING);
byte[] b = s.getBytes("UTF-8");
writeIntAny(b.length);
out.write(b);
rawBytesWritten += b.length + 1;
bytesWritten += b.length + 1;
}
public void writeRaw(byte[] b) throws IOException {
out.write(Tag.RAW);
writeIntAny(b.length);
out.write(b);
rawBytesWritten += b.length + 1;
bytesWritten += b.length + 1;
}
public void writeRaw(Raw r) throws IOException {
@@ -127,8 +127,8 @@ class WriterImpl implements Writer {
}
public void writeList(List<?> l) throws IOException {
out.write(Tag.LIST_DEF);
rawBytesWritten++;
out.write(Tag.LIST);
bytesWritten++;
writeIntAny(l.size());
for(Object o : l) writeObject(o);
}
@@ -141,7 +141,7 @@ class WriterImpl implements Writer {
else if(o instanceof Long) writeIntAny((Long) o);
else if(o instanceof Float) writeFloat32((Float) o);
else if(o instanceof Double) writeFloat64((Double) o);
else if(o instanceof String) writeUtf8((String) o);
else if(o instanceof String) writeString((String) o);
else if(o instanceof Raw) writeRaw((Raw) o);
else if(o instanceof List) writeList((List<?>) o);
else if(o instanceof Map) writeMap((Map<?, ?>) o);
@@ -150,18 +150,18 @@ class WriterImpl implements Writer {
}
public void writeListStart() throws IOException {
out.write(Tag.LIST_INDEF);
rawBytesWritten++;
out.write(Tag.LIST_START);
bytesWritten++;
}
public void writeListEnd() throws IOException {
out.write(Tag.END);
rawBytesWritten++;
bytesWritten++;
}
public void writeMap(Map<?, ?> m) throws IOException {
out.write(Tag.MAP_DEF);
rawBytesWritten++;
out.write(Tag.MAP);
bytesWritten++;
writeIntAny(m.size());
for(Entry<?, ?> e : m.entrySet()) {
writeObject(e.getKey());
@@ -170,17 +170,17 @@ class WriterImpl implements Writer {
}
public void writeMapStart() throws IOException {
out.write(Tag.MAP_INDEF);
rawBytesWritten++;
out.write(Tag.MAP_START);
bytesWritten++;
}
public void writeMapEnd() throws IOException {
out.write(Tag.END);
rawBytesWritten++;
bytesWritten++;
}
public void writeNull() throws IOException {
out.write(Tag.NULL);
rawBytesWritten++;
bytesWritten++;
}
}

View File

@@ -21,7 +21,7 @@
<test name='net.sf.briar.i18n.I18nTest'/>
<test name='net.sf.briar.invitation.InvitationWorkerTest'/>
<test name='net.sf.briar.protocol.BundleReadWriteTest'/>
<test name='net.sf.briar.protocol.SigningStreamTest'/>
<test name='net.sf.briar.protocol.ConsumersTest'/>
<test name='net.sf.briar.serial.ReaderImplTest'/>
<test name='net.sf.briar.serial.WriterImplTest'/>
<test name='net.sf.briar.setup.SetupWorkerTest'/>

View File

@@ -36,6 +36,7 @@ import net.sf.briar.api.protocol.MessageParser;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.Raw;
import net.sf.briar.api.serial.RawByteArray;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.serial.SerialModule;
@@ -125,7 +126,8 @@ public class BundleReadWriteTest extends TestCase {
MessageParser messageParser =
new MessageParserImpl(keyParser, sig, dig, rf);
FileInputStream in = new FileInputStream(bundle);
BundleReader r = new BundleReaderImpl(in, rf, keyPair.getPublic(), sig,
Reader reader = rf.createReader(in);
BundleReader r = new BundleReaderImpl(reader, keyPair.getPublic(), sig,
dig, messageParser, new HeaderFactoryImpl(),
new BatchFactoryImpl());
@@ -164,7 +166,8 @@ public class BundleReadWriteTest extends TestCase {
MessageParser messageParser =
new MessageParserImpl(keyParser, sig, dig, rf);
FileInputStream in = new FileInputStream(bundle);
BundleReader r = new BundleReaderImpl(in, rf, keyPair.getPublic(), sig,
Reader reader = rf.createReader(in);
BundleReader r = new BundleReaderImpl(reader, keyPair.getPublic(), sig,
dig, messageParser, new HeaderFactoryImpl(),
new BatchFactoryImpl());

View File

@@ -0,0 +1,73 @@
package net.sf.briar.protocol;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.MessageDigest;
import java.security.Signature;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import net.sf.briar.api.serial.FormatException;
import org.junit.Test;
public class ConsumersTest extends TestCase {
private static final String SIGNATURE_ALGO = "SHA256withRSA";
private static final String KEY_PAIR_ALGO = "RSA";
private static final String DIGEST_ALGO = "SHA-256";
@Test
public void testSigningConsumer() throws Exception {
Signature s = Signature.getInstance(SIGNATURE_ALGO);
KeyPair k = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).genKeyPair();
byte[] data = new byte[1234];
// Generate some random data and sign it
new Random().nextBytes(data);
s.initSign(k.getPrivate());
s.update(data);
byte[] sig = s.sign();
// Check that feeding a SigningConsumer generates the same signature
s.initSign(k.getPrivate());
SigningConsumer sc = new SigningConsumer(s);
sc.write(data[0]);
sc.write(data, 1, data.length - 2);
sc.write(data[data.length - 1]);
byte[] sig1 = s.sign();
assertTrue(Arrays.equals(sig, sig1));
}
@Test
public void testDigestingConsumer() throws Exception {
MessageDigest m = MessageDigest.getInstance(DIGEST_ALGO);
byte[] data = new byte[1234];
// Generate some random data and digest it
new Random().nextBytes(data);
m.reset();
m.update(data);
byte[] dig = m.digest();
// Check that feeding a DigestingConsumer generates the same digest
m.reset();
DigestingConsumer dc = new DigestingConsumer(m);
dc.write(data[0]);
dc.write(data, 1, data.length - 2);
dc.write(data[data.length - 1]);
byte[] dig1 = m.digest();
assertTrue(Arrays.equals(dig, dig1));
}
@Test
public void testCountingConsumer() throws Exception {
byte[] data = new byte[1234];
CountingConsumer cc = new CountingConsumer(data.length);
cc.write(data[0]);
cc.write(data, 1, data.length - 2);
cc.write(data[data.length - 1]);
assertEquals(data.length, cc.getCount());
try {
cc.write((byte) 0);
assertTrue(false);
} catch(FormatException expected) {}
}
}

View File

@@ -1,160 +0,0 @@
package net.sf.briar.protocol;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.MessageDigest;
import java.security.Signature;
import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import org.junit.Test;
public class SigningStreamTest extends TestCase {
private static final String SIGNATURE_ALGO = "SHA256withRSA";
private static final String KEY_PAIR_ALGO = "RSA";
private static final String DIGEST_ALGO = "SHA-256";
private final KeyPair keyPair;
private final Signature sig;
private final MessageDigest dig;
private final Random random;
public SigningStreamTest() throws Exception {
super();
keyPair = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair();
sig = Signature.getInstance(SIGNATURE_ALGO);
dig = MessageDigest.getInstance(DIGEST_ALGO);
random = new Random();
}
@Test
public void testOutputStreamOutputMatchesInput() throws Exception {
byte[] input = new byte[1000];
random.nextBytes(input);
ByteArrayOutputStream out = new ByteArrayOutputStream();
SigningDigestingOutputStream signOut =
new SigningDigestingOutputStream(out, sig, dig);
sig.initSign(keyPair.getPrivate());
signOut.setSigning(true);
signOut.write(input, 0, 500);
signOut.setSigning(false);
signOut.write(input, 500, 250);
signOut.setSigning(true);
signOut.write(input, 750, 250);
byte[] output = out.toByteArray();
assertTrue(Arrays.equals(input, output));
}
@Test
public void testInputStreamOutputMatchesInput() throws Exception {
byte[] input = new byte[1000];
random.nextBytes(input);
ByteArrayInputStream in = new ByteArrayInputStream(input);
SigningDigestingInputStream signIn =
new SigningDigestingInputStream(in, sig, dig);
sig.initVerify(keyPair.getPublic());
byte[] output = new byte[1000];
signIn.setSigning(true);
assertEquals(500, signIn.read(output, 0, 500));
signIn.setSigning(false);
assertEquals(250, signIn.read(output, 500, 250));
signIn.setSigning(true);
assertEquals(250, signIn.read(output, 750, 250));
assertTrue(Arrays.equals(input, output));
}
@Test
public void testVerificationLagsByOneByte() throws Exception {
byte[] input = new byte[1000];
random.nextBytes(input);
ByteArrayOutputStream out = new ByteArrayOutputStream();
SigningDigestingOutputStream signOut =
new SigningDigestingOutputStream(out, sig, dig);
sig.initSign(keyPair.getPrivate());
// Sign bytes 0-499, skip bytes 500-749, sign bytes 750-999
signOut.setSigning(true);
signOut.write(input, 0, 500);
signOut.setSigning(false);
signOut.write(input, 500, 250);
signOut.setSigning(true);
signOut.write(input, 750, 250);
byte[] signature = sig.sign();
ByteArrayInputStream in = new ByteArrayInputStream(input);
SigningDigestingInputStream signIn =
new SigningDigestingInputStream(in, sig, dig);
sig.initVerify(keyPair.getPublic());
byte[] output = new byte[1000];
// Consume a lookahead byte
assertEquals(1, signIn.read(output, 0, 1));
// All the offsets are increased by 1 because of the lookahead byte
signIn.setSigning(true);
assertEquals(500, signIn.read(output, 1, 500));
signIn.setSigning(false);
assertEquals(250, signIn.read(output, 501, 250));
signIn.setSigning(true);
assertEquals(249, signIn.read(output, 751, 249));
// Have to reach EOF for the lookahead byte to be processed
assertEquals(-1, signIn.read());
assertTrue(Arrays.equals(input, output));
assertTrue(sig.verify(signature));
}
@Test
public void testDigestionLagsByOneByte() throws Exception {
byte[] input = new byte[1000];
random.nextBytes(input);
ByteArrayOutputStream out = new ByteArrayOutputStream();
SigningDigestingOutputStream signOut =
new SigningDigestingOutputStream(out, sig, dig);
dig.reset();
// Digest bytes 0-499, skip bytes 500-749, digest bytes 750-999
signOut.setDigesting(true);
signOut.write(input, 0, 500);
signOut.setDigesting(false);
signOut.write(input, 500, 250);
signOut.setDigesting(true);
signOut.write(input, 750, 250);
byte[] hash = dig.digest();
ByteArrayInputStream in = new ByteArrayInputStream(input);
SigningDigestingInputStream signIn =
new SigningDigestingInputStream(in, sig, dig);
dig.reset();
byte[] output = new byte[1000];
// Consume a lookahead byte
assertEquals(1, signIn.read(output, 0, 1));
// All the offsets are increased by 1 because of the lookahead byte
signIn.setDigesting(true);
assertEquals(500, signIn.read(output, 1, 500));
signIn.setDigesting(false);
assertEquals(250, signIn.read(output, 501, 250));
signIn.setDigesting(true);
assertEquals(249, signIn.read(output, 751, 249));
// Have to reach EOF for the lookahead byte to be processed
assertEquals(-1, signIn.read());
assertTrue(Arrays.equals(input, output));
assertTrue(Arrays.equals(hash, dig.digest()));
}
}

View File

@@ -3,13 +3,11 @@ package net.sf.briar.serial;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import junit.framework.TestCase;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.Raw;
import net.sf.briar.api.serial.RawByteArray;
import net.sf.briar.util.StringUtils;
@@ -121,38 +119,11 @@ public class ReaderImplTest extends TestCase {
}
@Test
public void testReadUtf8() throws IOException {
public void testReadString() throws IOException {
setContents("F703666F6F" + "F703666F6F" + "F700");
assertEquals("foo", r.readUtf8());
assertEquals("foo", r.readUtf8());
assertEquals("", r.readUtf8());
assertTrue(r.eof());
}
@Test
public void testReadUtf8LimitNotExceeded() throws IOException {
setContents("F703666F6F");
r.setReadLimit(3);
assertEquals("foo", r.readUtf8());
assertTrue(r.eof());
}
@Test
public void testReadUtf8LimitExceeded() throws IOException {
setContents("F703666F6F");
r.setReadLimit(2);
try {
r.readUtf8();
assertTrue(false);
} catch(FormatException expected) {}
}
@Test
public void testReadUtf8LimitReset() throws IOException {
setContents("F703666F6F");
r.setReadLimit(2);
r.resetReadLimit();
assertEquals("foo", r.readUtf8());
assertEquals("foo", r.readString());
assertEquals("foo", r.readString());
assertEquals("", r.readString());
assertTrue(r.eof());
}
@@ -165,33 +136,6 @@ public class ReaderImplTest extends TestCase {
assertTrue(r.eof());
}
@Test
public void testReadRawLimitNotExceeded() throws IOException {
setContents("F603010203");
r.setReadLimit(3);
assertTrue(Arrays.equals(new byte[] {1, 2, 3}, r.readRaw()));
assertTrue(r.eof());
}
@Test
public void testReadRawMaxLengthExceeded() throws IOException {
setContents("F603010203");
r.setReadLimit(2);
try {
r.readRaw();
assertTrue(false);
} catch(FormatException expected) {}
}
@Test
public void testReadRawLimitReset() throws IOException {
setContents("F603010203");
r.setReadLimit(2);
r.resetReadLimit();
assertTrue(Arrays.equals(new byte[] {1, 2, 3}, r.readRaw()));
assertTrue(r.eof());
}
@Test
public void testReadDefiniteList() throws IOException {
setContents("F5" + "03" + "01" + "F703666F6F" + "FC0080");
@@ -261,7 +205,7 @@ public class ReaderImplTest extends TestCase {
assertFalse(r.hasListEnd());
assertEquals((byte) 1, r.readIntAny());
assertFalse(r.hasListEnd());
assertEquals("foo", r.readUtf8());
assertEquals("foo", r.readString());
assertFalse(r.hasListEnd());
assertEquals((short) 128, r.readIntAny());
assertTrue(r.hasListEnd());
@@ -300,7 +244,7 @@ public class ReaderImplTest extends TestCase {
assertTrue(r.hasMapStart());
r.readMapStart();
assertFalse(r.hasMapEnd());
assertEquals("foo", r.readUtf8());
assertEquals("foo", r.readString());
assertFalse(r.hasMapEnd());
assertEquals((byte) 123, r.readIntAny());
assertFalse(r.hasMapEnd());
@@ -345,25 +289,10 @@ public class ReaderImplTest extends TestCase {
assertTrue(r.eof());
}
@Test
public void testGetRawBytesRead() throws IOException {
setContents("F4" + "00" + "F4" + "00");
assertEquals(0L, r.getRawBytesRead());
Map<Object, Object> m = r.readMap(Object.class, Object.class);
assertEquals(2L, r.getRawBytesRead());
assertEquals(Collections.emptyMap(), m);
m = r.readMap(Object.class, Object.class);
assertEquals(4L, r.getRawBytesRead());
assertEquals(Collections.emptyMap(), m);
assertTrue(r.eof());
assertEquals(4L, r.getRawBytesRead());
}
@Test
public void testReadEmptyInput() throws IOException {
setContents("");
assertTrue(r.eof());
assertEquals(0L, r.getRawBytesRead());
}
private void setContents(String hex) {

View File

@@ -129,7 +129,7 @@ public class WriterImplTest extends TestCase {
@Test
public void testWriteUtf8() throws IOException {
w.writeUtf8("foo");
w.writeString("foo");
// UTF-8 tag, length as uint7, UTF-8 bytes
checkContents("F7" + "03" + "666F6F");
}
@@ -170,7 +170,7 @@ public class WriterImplTest extends TestCase {
public void testWriteIndefiniteList() throws IOException {
w.writeListStart();
w.writeIntAny((byte) 1); // Written as uint7
w.writeUtf8("foo");
w.writeString("foo");
w.writeIntAny(128L); // Written as an int16
w.writeListEnd();
checkContents("F3" + "01" + "F703666F6F" + "FC0080" + "F1");
@@ -179,7 +179,7 @@ public class WriterImplTest extends TestCase {
@Test
public void testWriteIndefiniteMap() throws IOException {
w.writeMapStart();
w.writeUtf8("foo");
w.writeString("foo");
w.writeIntAny(123); // Written as a uint7
w.writeRaw(new byte[] {});
w.writeNull();
@@ -212,6 +212,6 @@ public class WriterImplTest extends TestCase {
byte[] expected = StringUtils.fromHexString(hex);
assertTrue(StringUtils.toHexString(out.toByteArray()),
Arrays.equals(expected, out.toByteArray()));
assertEquals(expected.length, w.getRawBytesWritten());
assertEquals(expected.length, w.getBytesWritten());
}
}