Added Consumer support to Writer, to avoid redundant copying.

This commit is contained in:
akwizgran
2011-09-28 18:47:24 +01:00
parent 9c2e3917bf
commit a8b96f11fd
12 changed files with 134 additions and 97 deletions

View File

@@ -6,7 +6,5 @@ public interface Consumer {
void write(byte b) throws IOException;
void write(byte[] b) throws IOException;
void write(byte[] b, int off, int len) throws IOException;
}

View File

@@ -6,6 +6,9 @@ import java.util.Map;
public interface Writer {
void addConsumer(Consumer c);
void removeConsumer(Consumer c);
void writeBoolean(boolean b) throws IOException;
void writeUint7(byte b) throws IOException;

View File

@@ -18,10 +18,6 @@ class CopyingConsumer implements Consumer {
out.write(b);
}
public void write(byte[] b) throws IOException {
out.write(b);
}
public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len);
}

View File

@@ -27,11 +27,6 @@ class CountingConsumer implements Consumer {
if(count > limit) throw new FormatException();
}
public void write(byte[] b) throws IOException {
count += b.length;
if(count > limit) throw new FormatException();
}
public void write(byte[] b, int off, int len) throws IOException {
count += len;
if(count > limit) throw new FormatException();

View File

@@ -17,10 +17,6 @@ class DigestingConsumer implements Consumer {
messageDigest.update(b);
}
public void write(byte[] b) {
messageDigest.update(b);
}
public void write(byte[] b, int off, int len) {
messageDigest.update(b, off, len);
}

View File

@@ -17,6 +17,7 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
@@ -24,14 +25,15 @@ import com.google.inject.Inject;
class MessageEncoderImpl implements MessageEncoder {
private final Signature signature;
private final Signature authorSignature, groupSignature;
private final SecureRandom random;
private final MessageDigest messageDigest;
private final WriterFactory writerFactory;
@Inject
MessageEncoderImpl(CryptoComponent crypto, WriterFactory writerFactory) {
signature = crypto.getSignature();
authorSignature = crypto.getSignature();
groupSignature = crypto.getSignature();
random = crypto.getSecureRandom();
messageDigest = crypto.getMessageDigest();
this.writerFactory = writerFactory;
@@ -71,9 +73,23 @@ class MessageEncoderImpl implements MessageEncoder {
if(body.length > Message.MAX_BODY_LENGTH)
throw new IllegalArgumentException();
long timestamp = System.currentTimeMillis();
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
// Initialise the consumers
Consumer digestingConsumer = new DigestingConsumer(messageDigest);
w.addConsumer(digestingConsumer);
Consumer authorConsumer = null;
if(authorKey != null) {
authorSignature.initSign(authorKey);
authorConsumer = new SigningConsumer(authorSignature);
w.addConsumer(authorConsumer);
}
Consumer groupConsumer = null;
if(groupKey != null) {
groupSignature.initSign(groupKey);
groupConsumer = new SigningConsumer(groupSignature);
w.addConsumer(groupConsumer);
}
// Write the message
w.writeUserDefinedId(Types.MESSAGE);
if(parent == null) w.writeNull();
@@ -82,6 +98,7 @@ class MessageEncoderImpl implements MessageEncoder {
else group.writeTo(w);
if(author == null) w.writeNull();
else author.writeTo(w);
long timestamp = System.currentTimeMillis();
w.writeInt64(timestamp);
byte[] salt = new byte[Message.SALT_LENGTH];
random.nextBytes(salt);
@@ -91,9 +108,8 @@ class MessageEncoderImpl implements MessageEncoder {
if(authorKey == null) {
w.writeNull();
} else {
signature.initSign(authorKey);
signature.update(out.toByteArray());
byte[] sig = signature.sign();
w.removeConsumer(authorConsumer);
byte[] sig = authorSignature.sign();
if(sig.length > Message.MAX_SIGNATURE_LENGTH)
throw new IllegalArgumentException();
w.writeBytes(sig);
@@ -102,17 +118,15 @@ class MessageEncoderImpl implements MessageEncoder {
if(groupKey == null) {
w.writeNull();
} else {
signature.initSign(groupKey);
signature.update(out.toByteArray());
byte[] sig = signature.sign();
w.removeConsumer(groupConsumer);
byte[] sig = groupSignature.sign();
if(sig.length > Message.MAX_SIGNATURE_LENGTH)
throw new IllegalArgumentException();
w.writeBytes(sig);
}
// Hash the message, including the signatures, to get the message ID
w.removeConsumer(digestingConsumer);
byte[] raw = out.toByteArray();
messageDigest.reset();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();

View File

@@ -0,0 +1,33 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.Signature;
import java.security.SignatureException;
import net.sf.briar.api.serial.Consumer;
/** A consumer that passes its input through a signature. */
class SigningConsumer implements Consumer {
private final Signature signature;
SigningConsumer(Signature signature) {
this.signature = signature;
}
public void write(byte b) throws IOException {
try {
signature.update(b);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
public void write(byte[] b, int off, int len) throws IOException {
try {
signature.update(b, off, len);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
}

View File

@@ -19,8 +19,8 @@ class ReaderImpl implements Reader {
private static final byte[] EMPTY_BUFFER = new byte[] {};
private final InputStream in;
private final List<Consumer> consumers = new ArrayList<Consumer>(0);
private Consumer[] consumers = new Consumer[] {};
private ObjectReader<?>[] objectReaders = new ObjectReader<?>[] {};
private boolean hasLookahead = false, eof = false;
private byte next, nextNext;
@@ -89,24 +89,11 @@ class ReaderImpl implements Reader {
}
public void addConsumer(Consumer c) {
Consumer[] newConsumers = new Consumer[consumers.length + 1];
System.arraycopy(consumers, 0, newConsumers, 0, consumers.length);
newConsumers[consumers.length] = c;
consumers = newConsumers;
consumers.add(c);
}
public void removeConsumer(Consumer c) {
if(consumers.length == 0) throw new IllegalArgumentException();
Consumer[] newConsumers = new Consumer[consumers.length - 1];
boolean found = false;
for(int src = 0, dest = 0; src < consumers.length; src++, dest++) {
if(!found && consumers[src].equals(c)) {
found = true;
src++;
} else newConsumers[dest] = consumers[src];
}
if(found) consumers = newConsumers;
else throw new IllegalArgumentException();
if(!consumers.remove(c)) throw new IllegalArgumentException();
}
public void addObjectReader(int id, ObjectReader<?> o) {

View File

@@ -2,70 +2,81 @@ package net.sf.briar.serial;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.Writable;
import net.sf.briar.api.serial.Writer;
class WriterImpl implements Writer {
private final OutputStream out;
private final List<Consumer> consumers = new ArrayList<Consumer>(0);
WriterImpl(OutputStream out) {
this.out = out;
}
public void addConsumer(Consumer c) {
consumers.add(c);
}
public void removeConsumer(Consumer c) {
if(!consumers.remove(c)) throw new IllegalArgumentException();
}
public void writeBoolean(boolean b) throws IOException {
if(b) out.write(Tag.TRUE);
else out.write(Tag.FALSE);
if(b) write(Tag.TRUE);
else write(Tag.FALSE);
}
public void writeUint7(byte b) throws IOException {
if(b < 0) throw new IllegalArgumentException();
out.write(b);
write(b);
}
public void writeInt8(byte b) throws IOException {
out.write(Tag.INT8);
out.write(b);
write(Tag.INT8);
write(b);
}
public void writeInt16(short s) throws IOException {
out.write(Tag.INT16);
out.write((byte) (s >> 8));
out.write((byte) ((s << 8) >> 8));
write(Tag.INT16);
write((byte) (s >> 8));
write((byte) ((s << 8) >> 8));
}
public void writeInt32(int i) throws IOException {
out.write(Tag.INT32);
write(Tag.INT32);
writeInt32Bits(i);
}
private void writeInt32Bits(int i) throws IOException {
out.write((byte) (i >> 24));
out.write((byte) ((i << 8) >> 24));
out.write((byte) ((i << 16) >> 24));
out.write((byte) ((i << 24) >> 24));
write((byte) (i >> 24));
write((byte) ((i << 8) >> 24));
write((byte) ((i << 16) >> 24));
write((byte) ((i << 24) >> 24));
}
public void writeInt64(long l) throws IOException {
out.write(Tag.INT64);
write(Tag.INT64);
writeInt64Bits(l);
}
private void writeInt64Bits(long l) throws IOException {
out.write((byte) (l >> 56));
out.write((byte) ((l << 8) >> 56));
out.write((byte) ((l << 16) >> 56));
out.write((byte) ((l << 24) >> 56));
out.write((byte) ((l << 32) >> 56));
out.write((byte) ((l << 40) >> 56));
out.write((byte) ((l << 48) >> 56));
out.write((byte) ((l << 56) >> 56));
write((byte) (l >> 56));
write((byte) ((l << 8) >> 56));
write((byte) ((l << 16) >> 56));
write((byte) ((l << 24) >> 56));
write((byte) ((l << 32) >> 56));
write((byte) ((l << 40) >> 56));
write((byte) ((l << 48) >> 56));
write((byte) ((l << 56) >> 56));
}
public void writeIntAny(long l) throws IOException {
@@ -81,23 +92,23 @@ class WriterImpl implements Writer {
}
public void writeFloat32(float f) throws IOException {
out.write(Tag.FLOAT32);
write(Tag.FLOAT32);
writeInt32Bits(Float.floatToRawIntBits(f));
}
public void writeFloat64(double d) throws IOException {
out.write(Tag.FLOAT64);
write(Tag.FLOAT64);
writeInt64Bits(Double.doubleToRawLongBits(d));
}
public void writeString(String s) throws IOException {
byte[] b = s.getBytes("UTF-8");
if(b.length < 16) out.write((byte) (Tag.SHORT_STRING | b.length));
if(b.length < 16) write((byte) (Tag.SHORT_STRING | b.length));
else {
out.write(Tag.STRING);
write(Tag.STRING);
writeLength(b.length);
}
out.write(b);
write(b);
}
private void writeLength(int i) throws IOException {
@@ -109,19 +120,19 @@ class WriterImpl implements Writer {
}
public void writeBytes(byte[] b) throws IOException {
if(b.length < 16) out.write((byte) (Tag.SHORT_BYTES | b.length));
if(b.length < 16) write((byte) (Tag.SHORT_BYTES | b.length));
else {
out.write(Tag.BYTES);
write(Tag.BYTES);
writeLength(b.length);
}
out.write(b);
write(b);
}
public void writeList(Collection<?> c) throws IOException {
int length = c.size();
if(length < 16) out.write((byte) (Tag.SHORT_LIST | length));
if(length < 16) write((byte) (Tag.SHORT_LIST | length));
else {
out.write(Tag.LIST);
write(Tag.LIST);
writeLength(length);
}
for(Object o : c) writeObject(o);
@@ -145,18 +156,18 @@ class WriterImpl implements Writer {
}
public void writeListStart() throws IOException {
out.write(Tag.LIST_START);
write(Tag.LIST_START);
}
public void writeListEnd() throws IOException {
out.write(Tag.END);
write(Tag.END);
}
public void writeMap(Map<?, ?> m) throws IOException {
int length = m.size();
if(length < 16) out.write((byte) (Tag.SHORT_MAP | length));
if(length < 16) write((byte) (Tag.SHORT_MAP | length));
else {
out.write(Tag.MAP);
write(Tag.MAP);
writeLength(length);
}
for(Entry<?, ?> e : m.entrySet()) {
@@ -166,24 +177,34 @@ class WriterImpl implements Writer {
}
public void writeMapStart() throws IOException {
out.write(Tag.MAP_START);
write(Tag.MAP_START);
}
public void writeMapEnd() throws IOException {
out.write(Tag.END);
write(Tag.END);
}
public void writeNull() throws IOException {
out.write(Tag.NULL);
write(Tag.NULL);
}
public void writeUserDefinedId(int id) throws IOException {
if(id < 0 || id > 255) throw new IllegalArgumentException();
if(id < 32) {
out.write((byte) (Tag.SHORT_USER | id));
write((byte) (Tag.SHORT_USER | id));
} else {
out.write(Tag.USER);
out.write((byte) id);
write(Tag.USER);
write((byte) id);
}
}
private void write(byte b) throws IOException {
out.write(b);
for(Consumer c : consumers) c.write(b);
}
private void write(byte[] b) throws IOException {
out.write(b);
for(Consumer c : consumers) c.write(b, 0, b.length);
}
}

View File

@@ -17,10 +17,6 @@ class MacConsumer implements Consumer {
mac.update(b);
}
public void write(byte[] b) {
mac.update(b);
}
public void write(byte[] b, int off, int len) {
mac.update(b, off, len);
}

View File

@@ -130,7 +130,12 @@ public class FileReadWriteTest extends TestCase {
}
@Test
public void testWriteFile() throws Exception {
public void testWriteAndRead() throws Exception {
write();
read();
}
private void write() throws Exception {
OutputStream out = new FileOutputStream(file);
// Use Alice's secret for writing
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
@@ -177,10 +182,7 @@ public class FileReadWriteTest extends TestCase {
assertTrue(file.length() > message.getSize());
}
@Test
public void testWriteAndReadFile() throws Exception {
testWriteFile();
private void read() throws Exception {
InputStream in = new FileInputStream(file);
byte[] iv = new byte[16];

View File

@@ -416,10 +416,6 @@ public class ReaderImplTest extends TestCase {
out.write(b);
}
public void write(byte[] b) throws IOException {
out.write(b);
}
public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len);
}