Added Consumer support to Writer, to avoid redundant copying.

This commit is contained in:
akwizgran
2011-09-28 18:47:24 +01:00
parent 9c2e3917bf
commit a8b96f11fd
12 changed files with 134 additions and 97 deletions

View File

@@ -6,7 +6,5 @@ public interface Consumer {
void write(byte b) throws IOException; void write(byte b) throws IOException;
void write(byte[] b) throws IOException;
void write(byte[] b, int off, int len) throws IOException; void write(byte[] b, int off, int len) throws IOException;
} }

View File

@@ -6,6 +6,9 @@ import java.util.Map;
public interface Writer { public interface Writer {
void addConsumer(Consumer c);
void removeConsumer(Consumer c);
void writeBoolean(boolean b) throws IOException; void writeBoolean(boolean b) throws IOException;
void writeUint7(byte b) throws IOException; void writeUint7(byte b) throws IOException;

View File

@@ -18,10 +18,6 @@ class CopyingConsumer implements Consumer {
out.write(b); out.write(b);
} }
public void write(byte[] b) throws IOException {
out.write(b);
}
public void write(byte[] b, int off, int len) throws IOException { public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len); out.write(b, off, len);
} }

View File

@@ -27,11 +27,6 @@ class CountingConsumer implements Consumer {
if(count > limit) throw new FormatException(); if(count > limit) throw new FormatException();
} }
public void write(byte[] b) throws IOException {
count += b.length;
if(count > limit) throw new FormatException();
}
public void write(byte[] b, int off, int len) throws IOException { public void write(byte[] b, int off, int len) throws IOException {
count += len; count += len;
if(count > limit) throw new FormatException(); if(count > limit) throw new FormatException();

View File

@@ -17,10 +17,6 @@ class DigestingConsumer implements Consumer {
messageDigest.update(b); messageDigest.update(b);
} }
public void write(byte[] b) {
messageDigest.update(b);
}
public void write(byte[] b, int off, int len) { public void write(byte[] b, int off, int len) {
messageDigest.update(b, off, len); messageDigest.update(b, off, len);
} }

View File

@@ -17,6 +17,7 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory; import net.sf.briar.api.serial.WriterFactory;
@@ -24,14 +25,15 @@ import com.google.inject.Inject;
class MessageEncoderImpl implements MessageEncoder { class MessageEncoderImpl implements MessageEncoder {
private final Signature signature; private final Signature authorSignature, groupSignature;
private final SecureRandom random; private final SecureRandom random;
private final MessageDigest messageDigest; private final MessageDigest messageDigest;
private final WriterFactory writerFactory; private final WriterFactory writerFactory;
@Inject @Inject
MessageEncoderImpl(CryptoComponent crypto, WriterFactory writerFactory) { MessageEncoderImpl(CryptoComponent crypto, WriterFactory writerFactory) {
signature = crypto.getSignature(); authorSignature = crypto.getSignature();
groupSignature = crypto.getSignature();
random = crypto.getSecureRandom(); random = crypto.getSecureRandom();
messageDigest = crypto.getMessageDigest(); messageDigest = crypto.getMessageDigest();
this.writerFactory = writerFactory; this.writerFactory = writerFactory;
@@ -71,9 +73,23 @@ class MessageEncoderImpl implements MessageEncoder {
if(body.length > Message.MAX_BODY_LENGTH) if(body.length > Message.MAX_BODY_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
long timestamp = System.currentTimeMillis();
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out); Writer w = writerFactory.createWriter(out);
// Initialise the consumers
Consumer digestingConsumer = new DigestingConsumer(messageDigest);
w.addConsumer(digestingConsumer);
Consumer authorConsumer = null;
if(authorKey != null) {
authorSignature.initSign(authorKey);
authorConsumer = new SigningConsumer(authorSignature);
w.addConsumer(authorConsumer);
}
Consumer groupConsumer = null;
if(groupKey != null) {
groupSignature.initSign(groupKey);
groupConsumer = new SigningConsumer(groupSignature);
w.addConsumer(groupConsumer);
}
// Write the message // Write the message
w.writeUserDefinedId(Types.MESSAGE); w.writeUserDefinedId(Types.MESSAGE);
if(parent == null) w.writeNull(); if(parent == null) w.writeNull();
@@ -82,6 +98,7 @@ class MessageEncoderImpl implements MessageEncoder {
else group.writeTo(w); else group.writeTo(w);
if(author == null) w.writeNull(); if(author == null) w.writeNull();
else author.writeTo(w); else author.writeTo(w);
long timestamp = System.currentTimeMillis();
w.writeInt64(timestamp); w.writeInt64(timestamp);
byte[] salt = new byte[Message.SALT_LENGTH]; byte[] salt = new byte[Message.SALT_LENGTH];
random.nextBytes(salt); random.nextBytes(salt);
@@ -91,9 +108,8 @@ class MessageEncoderImpl implements MessageEncoder {
if(authorKey == null) { if(authorKey == null) {
w.writeNull(); w.writeNull();
} else { } else {
signature.initSign(authorKey); w.removeConsumer(authorConsumer);
signature.update(out.toByteArray()); byte[] sig = authorSignature.sign();
byte[] sig = signature.sign();
if(sig.length > Message.MAX_SIGNATURE_LENGTH) if(sig.length > Message.MAX_SIGNATURE_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
w.writeBytes(sig); w.writeBytes(sig);
@@ -102,17 +118,15 @@ class MessageEncoderImpl implements MessageEncoder {
if(groupKey == null) { if(groupKey == null) {
w.writeNull(); w.writeNull();
} else { } else {
signature.initSign(groupKey); w.removeConsumer(groupConsumer);
signature.update(out.toByteArray()); byte[] sig = groupSignature.sign();
byte[] sig = signature.sign();
if(sig.length > Message.MAX_SIGNATURE_LENGTH) if(sig.length > Message.MAX_SIGNATURE_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
w.writeBytes(sig); w.writeBytes(sig);
} }
// Hash the message, including the signatures, to get the message ID // Hash the message, including the signatures, to get the message ID
w.removeConsumer(digestingConsumer);
byte[] raw = out.toByteArray(); byte[] raw = out.toByteArray();
messageDigest.reset();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest()); MessageId id = new MessageId(messageDigest.digest());
GroupId groupId = group == null ? null : group.getId(); GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId(); AuthorId authorId = author == null ? null : author.getId();

View File

@@ -0,0 +1,33 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.Signature;
import java.security.SignatureException;
import net.sf.briar.api.serial.Consumer;
/** A consumer that passes its input through a signature. */
class SigningConsumer implements Consumer {
private final Signature signature;
SigningConsumer(Signature signature) {
this.signature = signature;
}
public void write(byte b) throws IOException {
try {
signature.update(b);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
public void write(byte[] b, int off, int len) throws IOException {
try {
signature.update(b, off, len);
} catch(SignatureException e) {
throw new IOException(e.getMessage());
}
}
}

View File

@@ -19,8 +19,8 @@ class ReaderImpl implements Reader {
private static final byte[] EMPTY_BUFFER = new byte[] {}; private static final byte[] EMPTY_BUFFER = new byte[] {};
private final InputStream in; private final InputStream in;
private final List<Consumer> consumers = new ArrayList<Consumer>(0);
private Consumer[] consumers = new Consumer[] {};
private ObjectReader<?>[] objectReaders = new ObjectReader<?>[] {}; private ObjectReader<?>[] objectReaders = new ObjectReader<?>[] {};
private boolean hasLookahead = false, eof = false; private boolean hasLookahead = false, eof = false;
private byte next, nextNext; private byte next, nextNext;
@@ -89,24 +89,11 @@ class ReaderImpl implements Reader {
} }
public void addConsumer(Consumer c) { public void addConsumer(Consumer c) {
Consumer[] newConsumers = new Consumer[consumers.length + 1]; consumers.add(c);
System.arraycopy(consumers, 0, newConsumers, 0, consumers.length);
newConsumers[consumers.length] = c;
consumers = newConsumers;
} }
public void removeConsumer(Consumer c) { public void removeConsumer(Consumer c) {
if(consumers.length == 0) throw new IllegalArgumentException(); if(!consumers.remove(c)) throw new IllegalArgumentException();
Consumer[] newConsumers = new Consumer[consumers.length - 1];
boolean found = false;
for(int src = 0, dest = 0; src < consumers.length; src++, dest++) {
if(!found && consumers[src].equals(c)) {
found = true;
src++;
} else newConsumers[dest] = consumers[src];
}
if(found) consumers = newConsumers;
else throw new IllegalArgumentException();
} }
public void addObjectReader(int id, ObjectReader<?> o) { public void addObjectReader(int id, ObjectReader<?> o) {

View File

@@ -2,70 +2,81 @@ package net.sf.briar.serial;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import net.sf.briar.api.Bytes; import net.sf.briar.api.Bytes;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.Writable; import net.sf.briar.api.serial.Writable;
import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.Writer;
class WriterImpl implements Writer { class WriterImpl implements Writer {
private final OutputStream out; private final OutputStream out;
private final List<Consumer> consumers = new ArrayList<Consumer>(0);
WriterImpl(OutputStream out) { WriterImpl(OutputStream out) {
this.out = out; this.out = out;
} }
public void addConsumer(Consumer c) {
consumers.add(c);
}
public void removeConsumer(Consumer c) {
if(!consumers.remove(c)) throw new IllegalArgumentException();
}
public void writeBoolean(boolean b) throws IOException { public void writeBoolean(boolean b) throws IOException {
if(b) out.write(Tag.TRUE); if(b) write(Tag.TRUE);
else out.write(Tag.FALSE); else write(Tag.FALSE);
} }
public void writeUint7(byte b) throws IOException { public void writeUint7(byte b) throws IOException {
if(b < 0) throw new IllegalArgumentException(); if(b < 0) throw new IllegalArgumentException();
out.write(b); write(b);
} }
public void writeInt8(byte b) throws IOException { public void writeInt8(byte b) throws IOException {
out.write(Tag.INT8); write(Tag.INT8);
out.write(b); write(b);
} }
public void writeInt16(short s) throws IOException { public void writeInt16(short s) throws IOException {
out.write(Tag.INT16); write(Tag.INT16);
out.write((byte) (s >> 8)); write((byte) (s >> 8));
out.write((byte) ((s << 8) >> 8)); write((byte) ((s << 8) >> 8));
} }
public void writeInt32(int i) throws IOException { public void writeInt32(int i) throws IOException {
out.write(Tag.INT32); write(Tag.INT32);
writeInt32Bits(i); writeInt32Bits(i);
} }
private void writeInt32Bits(int i) throws IOException { private void writeInt32Bits(int i) throws IOException {
out.write((byte) (i >> 24)); write((byte) (i >> 24));
out.write((byte) ((i << 8) >> 24)); write((byte) ((i << 8) >> 24));
out.write((byte) ((i << 16) >> 24)); write((byte) ((i << 16) >> 24));
out.write((byte) ((i << 24) >> 24)); write((byte) ((i << 24) >> 24));
} }
public void writeInt64(long l) throws IOException { public void writeInt64(long l) throws IOException {
out.write(Tag.INT64); write(Tag.INT64);
writeInt64Bits(l); writeInt64Bits(l);
} }
private void writeInt64Bits(long l) throws IOException { private void writeInt64Bits(long l) throws IOException {
out.write((byte) (l >> 56)); write((byte) (l >> 56));
out.write((byte) ((l << 8) >> 56)); write((byte) ((l << 8) >> 56));
out.write((byte) ((l << 16) >> 56)); write((byte) ((l << 16) >> 56));
out.write((byte) ((l << 24) >> 56)); write((byte) ((l << 24) >> 56));
out.write((byte) ((l << 32) >> 56)); write((byte) ((l << 32) >> 56));
out.write((byte) ((l << 40) >> 56)); write((byte) ((l << 40) >> 56));
out.write((byte) ((l << 48) >> 56)); write((byte) ((l << 48) >> 56));
out.write((byte) ((l << 56) >> 56)); write((byte) ((l << 56) >> 56));
} }
public void writeIntAny(long l) throws IOException { public void writeIntAny(long l) throws IOException {
@@ -81,23 +92,23 @@ class WriterImpl implements Writer {
} }
public void writeFloat32(float f) throws IOException { public void writeFloat32(float f) throws IOException {
out.write(Tag.FLOAT32); write(Tag.FLOAT32);
writeInt32Bits(Float.floatToRawIntBits(f)); writeInt32Bits(Float.floatToRawIntBits(f));
} }
public void writeFloat64(double d) throws IOException { public void writeFloat64(double d) throws IOException {
out.write(Tag.FLOAT64); write(Tag.FLOAT64);
writeInt64Bits(Double.doubleToRawLongBits(d)); writeInt64Bits(Double.doubleToRawLongBits(d));
} }
public void writeString(String s) throws IOException { public void writeString(String s) throws IOException {
byte[] b = s.getBytes("UTF-8"); byte[] b = s.getBytes("UTF-8");
if(b.length < 16) out.write((byte) (Tag.SHORT_STRING | b.length)); if(b.length < 16) write((byte) (Tag.SHORT_STRING | b.length));
else { else {
out.write(Tag.STRING); write(Tag.STRING);
writeLength(b.length); writeLength(b.length);
} }
out.write(b); write(b);
} }
private void writeLength(int i) throws IOException { private void writeLength(int i) throws IOException {
@@ -109,19 +120,19 @@ class WriterImpl implements Writer {
} }
public void writeBytes(byte[] b) throws IOException { public void writeBytes(byte[] b) throws IOException {
if(b.length < 16) out.write((byte) (Tag.SHORT_BYTES | b.length)); if(b.length < 16) write((byte) (Tag.SHORT_BYTES | b.length));
else { else {
out.write(Tag.BYTES); write(Tag.BYTES);
writeLength(b.length); writeLength(b.length);
} }
out.write(b); write(b);
} }
public void writeList(Collection<?> c) throws IOException { public void writeList(Collection<?> c) throws IOException {
int length = c.size(); int length = c.size();
if(length < 16) out.write((byte) (Tag.SHORT_LIST | length)); if(length < 16) write((byte) (Tag.SHORT_LIST | length));
else { else {
out.write(Tag.LIST); write(Tag.LIST);
writeLength(length); writeLength(length);
} }
for(Object o : c) writeObject(o); for(Object o : c) writeObject(o);
@@ -145,18 +156,18 @@ class WriterImpl implements Writer {
} }
public void writeListStart() throws IOException { public void writeListStart() throws IOException {
out.write(Tag.LIST_START); write(Tag.LIST_START);
} }
public void writeListEnd() throws IOException { public void writeListEnd() throws IOException {
out.write(Tag.END); write(Tag.END);
} }
public void writeMap(Map<?, ?> m) throws IOException { public void writeMap(Map<?, ?> m) throws IOException {
int length = m.size(); int length = m.size();
if(length < 16) out.write((byte) (Tag.SHORT_MAP | length)); if(length < 16) write((byte) (Tag.SHORT_MAP | length));
else { else {
out.write(Tag.MAP); write(Tag.MAP);
writeLength(length); writeLength(length);
} }
for(Entry<?, ?> e : m.entrySet()) { for(Entry<?, ?> e : m.entrySet()) {
@@ -166,24 +177,34 @@ class WriterImpl implements Writer {
} }
public void writeMapStart() throws IOException { public void writeMapStart() throws IOException {
out.write(Tag.MAP_START); write(Tag.MAP_START);
} }
public void writeMapEnd() throws IOException { public void writeMapEnd() throws IOException {
out.write(Tag.END); write(Tag.END);
} }
public void writeNull() throws IOException { public void writeNull() throws IOException {
out.write(Tag.NULL); write(Tag.NULL);
} }
public void writeUserDefinedId(int id) throws IOException { public void writeUserDefinedId(int id) throws IOException {
if(id < 0 || id > 255) throw new IllegalArgumentException(); if(id < 0 || id > 255) throw new IllegalArgumentException();
if(id < 32) { if(id < 32) {
out.write((byte) (Tag.SHORT_USER | id)); write((byte) (Tag.SHORT_USER | id));
} else { } else {
out.write(Tag.USER); write(Tag.USER);
out.write((byte) id); write((byte) id);
} }
} }
private void write(byte b) throws IOException {
out.write(b);
for(Consumer c : consumers) c.write(b);
}
private void write(byte[] b) throws IOException {
out.write(b);
for(Consumer c : consumers) c.write(b, 0, b.length);
}
} }

View File

@@ -17,10 +17,6 @@ class MacConsumer implements Consumer {
mac.update(b); mac.update(b);
} }
public void write(byte[] b) {
mac.update(b);
}
public void write(byte[] b, int off, int len) { public void write(byte[] b, int off, int len) {
mac.update(b, off, len); mac.update(b, off, len);
} }

View File

@@ -130,7 +130,12 @@ public class FileReadWriteTest extends TestCase {
} }
@Test @Test
public void testWriteFile() throws Exception { public void testWriteAndRead() throws Exception {
write();
read();
}
private void write() throws Exception {
OutputStream out = new FileOutputStream(file); OutputStream out = new FileOutputStream(file);
// Use Alice's secret for writing // Use Alice's secret for writing
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
@@ -177,10 +182,7 @@ public class FileReadWriteTest extends TestCase {
assertTrue(file.length() > message.getSize()); assertTrue(file.length() > message.getSize());
} }
@Test private void read() throws Exception {
public void testWriteAndReadFile() throws Exception {
testWriteFile();
InputStream in = new FileInputStream(file); InputStream in = new FileInputStream(file);
byte[] iv = new byte[16]; byte[] iv = new byte[16];

View File

@@ -416,10 +416,6 @@ public class ReaderImplTest extends TestCase {
out.write(b); out.write(b);
} }
public void write(byte[] b) throws IOException {
out.write(b);
}
public void write(byte[] b, int off, int len) throws IOException { public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len); out.write(b, off, len);
} }