Decoupled ProtocolReader (which belongs in the protocol component)

from PacketReader (which belongs in the transport component).
This commit is contained in:
akwizgran
2011-08-13 14:18:16 +02:00
parent 5b6fecfb43
commit 9d25a819d1
17 changed files with 591 additions and 221 deletions

View File

@@ -3,6 +3,7 @@ package net.sf.briar;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.KeyPair;
import java.util.Arrays;
@@ -26,6 +27,8 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Offer;
import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate;
@@ -64,6 +67,7 @@ public class FileReadWriteTest extends TestCase {
private final PacketReaderFactory packetReaderFactory;
private final PacketWriterFactory packetWriterFactory;
private final ProtocolReaderFactory protocolReaderFactory;
private final ProtocolWriterFactory protocolWriterFactory;
private final CryptoComponent crypto;
private final byte[] secret = new byte[45];
@@ -83,6 +87,7 @@ public class FileReadWriteTest extends TestCase {
new WritersModule());
packetReaderFactory = i.getInstance(PacketReaderFactory.class);
packetWriterFactory = i.getInstance(PacketWriterFactory.class);
protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class);
protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class);
crypto = i.getInstance(CryptoComponent.class);
assertEquals(crypto.getMessageDigest().getDigestLength(),
@@ -121,14 +126,14 @@ public class FileReadWriteTest extends TestCase {
@Test
public void testWriteFile() throws Exception {
OutputStream out = new FileOutputStream(file);
PacketWriter p = packetWriterFactory.createPacketWriter(out,
PacketWriter packetWriter = packetWriterFactory.createPacketWriter(out,
transportId, connection, secret);
out = p.getOutputStream();
out = packetWriter.getOutputStream();
AckWriter a = protocolWriterFactory.createAckWriter(out);
assertTrue(a.writeBatchId(ack));
a.finish();
p.nextPacket();
packetWriter.finishPacket();
BatchWriter b = protocolWriterFactory.createBatchWriter(out);
assertTrue(b.writeMessage(message.getBytes()));
@@ -136,7 +141,7 @@ public class FileReadWriteTest extends TestCase {
assertTrue(b.writeMessage(message2.getBytes()));
assertTrue(b.writeMessage(message3.getBytes()));
b.finish();
p.nextPacket();
packetWriter.finishPacket();
OfferWriter o = protocolWriterFactory.createOfferWriter(out);
assertTrue(o.writeMessageId(message.getId()));
@@ -144,14 +149,14 @@ public class FileReadWriteTest extends TestCase {
assertTrue(o.writeMessageId(message2.getId()));
assertTrue(o.writeMessageId(message3.getId()));
o.finish();
p.nextPacket();
packetWriter.finishPacket();
RequestWriter r = protocolWriterFactory.createRequestWriter(out);
BitSet requested = new BitSet(4);
requested.set(1);
requested.set(3);
r.writeBitmap(requested, 4);
p.nextPacket();
packetWriter.finishPacket();
SubscriptionWriter s =
protocolWriterFactory.createSubscriptionWriter(out);
@@ -160,11 +165,11 @@ public class FileReadWriteTest extends TestCase {
subs.put(group, 0L);
subs.put(group1, 0L);
s.writeSubscriptions(subs);
p.nextPacket();
packetWriter.finishPacket();
TransportWriter t = protocolWriterFactory.createTransportWriter(out);
t.writeTransports(transports);
p.nextPacket();
packetWriter.finishPacket();
out.flush();
out.close();
@@ -177,7 +182,7 @@ public class FileReadWriteTest extends TestCase {
testWriteFile();
FileInputStream in = new FileInputStream(file);
InputStream in = new FileInputStream(file);
byte[] firstTag = new byte[16];
int offset = 0;
while(offset < 16) {
@@ -186,17 +191,22 @@ public class FileReadWriteTest extends TestCase {
offset += read;
}
assertEquals(16, offset);
PacketReader p = packetReaderFactory.createPacketReader(firstTag, in,
transportId, connection, secret);
PacketReader packetReader = packetReaderFactory.createPacketReader(
firstTag, in, transportId, connection, secret);
in = packetReader.getInputStream();
ProtocolReader protocolReader =
protocolReaderFactory.createProtocolReader(in);
// Read the ack
assertTrue(p.hasAck());
Ack a = p.readAck();
assertTrue(protocolReader.hasAck());
Ack a = protocolReader.readAck();
packetReader.finishPacket();
assertEquals(Collections.singletonList(ack), a.getBatchIds());
// Read the batch
assertTrue(p.hasBatch());
Batch b = p.readBatch();
assertTrue(protocolReader.hasBatch());
Batch b = protocolReader.readBatch();
packetReader.finishPacket();
Collection<Message> messages = b.getMessages();
assertEquals(4, messages.size());
Iterator<Message> it = messages.iterator();
@@ -206,8 +216,9 @@ public class FileReadWriteTest extends TestCase {
checkMessageEquality(message3, it.next());
// Read the offer
assertTrue(p.hasOffer());
Offer o = p.readOffer();
assertTrue(protocolReader.hasOffer());
Offer o = protocolReader.readOffer();
packetReader.finishPacket();
Collection<MessageId> offered = o.getMessageIds();
assertEquals(4, offered.size());
Iterator<MessageId> it1 = offered.iterator();
@@ -217,8 +228,9 @@ public class FileReadWriteTest extends TestCase {
assertEquals(message3.getId(), it1.next());
// Read the request
assertTrue(p.hasRequest());
Request r = p.readRequest();
assertTrue(protocolReader.hasRequest());
Request r = protocolReader.readRequest();
packetReader.finishPacket();
BitSet requested = r.getBitmap();
assertFalse(requested.get(0));
assertTrue(requested.get(1));
@@ -228,8 +240,9 @@ public class FileReadWriteTest extends TestCase {
assertEquals(2, requested.cardinality());
// Read the subscription update
assertTrue(p.hasSubscriptionUpdate());
SubscriptionUpdate s = p.readSubscriptionUpdate();
assertTrue(protocolReader.hasSubscriptionUpdate());
SubscriptionUpdate s = protocolReader.readSubscriptionUpdate();
packetReader.finishPacket();
Map<Group, Long> subs = s.getSubscriptions();
assertEquals(2, subs.size());
assertEquals(Long.valueOf(0L), subs.get(group));
@@ -238,11 +251,14 @@ public class FileReadWriteTest extends TestCase {
assertTrue(s.getTimestamp() <= System.currentTimeMillis());
// Read the transport update
assertTrue(p.hasTransportUpdate());
TransportUpdate t = p.readTransportUpdate();
assertTrue(protocolReader.hasTransportUpdate());
TransportUpdate t = protocolReader.readTransportUpdate();
packetReader.finishPacket();
assertEquals(transports, t.getTransports());
assertTrue(t.getTimestamp() > start);
assertTrue(t.getTimestamp() <= System.currentTimeMillis());
in.close();
}
@After

View File

@@ -0,0 +1,98 @@
package net.sf.briar.transport;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.Random;
import javax.crypto.Cipher;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import junit.framework.TestCase;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.transport.PacketReader;
import net.sf.briar.api.transport.PacketWriter;
import net.sf.briar.crypto.CryptoModule;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class PacketReadWriteTest extends TestCase {
private final CryptoComponent crypto;
private final Cipher tagCipher, packetCipher;
private final SecretKey macKey, tagKey, packetKey;
private final Mac mac;
private final Random random;
private final byte[] secret = new byte[100];
private final int transportId = 999;
private final long connection = 1234L;
public PacketReadWriteTest() {
super();
Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class);
tagCipher = crypto.getTagCipher();
packetCipher = crypto.getPacketCipher();
macKey = crypto.deriveMacKey(secret);
tagKey = crypto.deriveTagKey(secret);
packetKey = crypto.derivePacketKey(secret);
mac = crypto.getMac();
random = new Random();
}
@Test
public void testWriteAndRead() throws Exception {
// Generate two random packets
byte[] packet = new byte[12345];
random.nextBytes(packet);
byte[] packet1 = new byte[321];
random.nextBytes(packet1);
// Write the packets
ByteArrayOutputStream out = new ByteArrayOutputStream();
PacketEncrypter encrypter = new PacketEncrypterImpl(out, tagCipher,
packetCipher, tagKey, packetKey);
mac.init(macKey);
PacketWriter writer = new PacketWriterImpl(encrypter, mac, transportId,
connection);
OutputStream out1 = writer.getOutputStream();
out1.write(packet);
writer.finishPacket();
out1.write(packet1);
writer.finishPacket();
// Read the packets back
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
byte[] firstTag = new byte[Constants.TAG_BYTES];
assertEquals(Constants.TAG_BYTES, in.read(firstTag));
PacketDecrypter decrypter = new PacketDecrypterImpl(firstTag, in,
tagCipher, packetCipher, tagKey, packetKey);
PacketReader reader = new PacketReaderImpl(decrypter, mac, transportId,
connection);
InputStream in1 = reader.getInputStream();
byte[] recovered = new byte[packet.length];
int offset = 0;
while(offset < recovered.length) {
int read = in1.read(recovered, offset, recovered.length - offset);
if(read == -1) break;
offset += read;
}
assertEquals(recovered.length, offset);
reader.finishPacket();
assertTrue(Arrays.equals(packet, recovered));
byte[] recovered1 = new byte[packet1.length];
offset = 0;
while(offset < recovered1.length) {
int read = in1.read(recovered1, offset, recovered1.length - offset);
if(read == -1) break;
offset += read;
}
assertEquals(recovered1.length, offset);
reader.finishPacket();
assertTrue(Arrays.equals(packet1, recovered1));
}
}

View File

@@ -0,0 +1,187 @@
package net.sf.briar.transport;
import java.io.ByteArrayInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.crypto.Mac;
import junit.framework.TestCase;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.transport.PacketReader;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.util.StringUtils;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class PacketReaderImplTest extends TestCase {
private final Mac mac;
public PacketReaderImplTest() throws Exception {
super();
Injector i = Guice.createInjector(new CryptoModule());
CryptoComponent crypto = i.getInstance(CryptoComponent.class);
mac = crypto.getMac();
mac.init(crypto.generateSecretKey());
}
@Test
public void testFirstReadTriggersTag() throws Exception {
// TAG_BYTES for the tag, 1 byte for the packet
byte[] b = new byte[Constants.TAG_BYTES + 1];
ByteArrayInputStream in = new ByteArrayInputStream(b);
PacketDecrypter d = new NullPacketDecrypter(in);
PacketReader p = new PacketReaderImpl(d, mac, 0, 0L);
// There should be one byte available before EOF
assertEquals(0, p.getInputStream().read());
assertEquals(-1, p.getInputStream().read());
}
@Test
public void testFinishPacketAfterReadTriggersMac() throws Exception {
// TAG_BYTES for the tag, 1 byte for the packet
byte[] b = new byte[Constants.TAG_BYTES + 1];
// Calculate the MAC and append it to the packet
mac.update(b);
byte[] macBytes = mac.doFinal();
byte[] b1 = Arrays.copyOf(b, b.length + macBytes.length);
System.arraycopy(macBytes, 0, b1, b.length, macBytes.length);
// Check that the PacketReader reads and verifies the MAC
ByteArrayInputStream in = new ByteArrayInputStream(b1);
PacketDecrypter d = new NullPacketDecrypter(in);
PacketReader p = new PacketReaderImpl(d, mac, 0, 0L);
assertEquals(0, p.getInputStream().read());
p.finishPacket();
// Reading the MAC should take us to EOF
assertEquals(-1, p.getInputStream().read());
}
@Test
public void testModifyingPacketInvalidatesMac() throws Exception {
// TAG_BYTES for the tag, 1 byte for the packet
byte[] b = new byte[Constants.TAG_BYTES + 1];
// Calculate the MAC and append it to the packet
mac.update(b);
byte[] macBytes = mac.doFinal();
byte[] b1 = Arrays.copyOf(b, b.length + macBytes.length);
System.arraycopy(macBytes, 0, b1, b.length, macBytes.length);
// Modify the packet
b1[Constants.TAG_BYTES] = (byte) 1;
// Check that the PacketReader reads and fails to verify the MAC
ByteArrayInputStream in = new ByteArrayInputStream(b1);
PacketDecrypter d = new NullPacketDecrypter(in);
PacketReader p = new PacketReaderImpl(d, mac, 0, 0L);
assertEquals(1, p.getInputStream().read());
try {
p.finishPacket();
fail();
} catch(GeneralSecurityException expected) {}
}
@Test
public void testExtraCallsToFinishPacketDoNothing() throws Exception {
// TAG_BYTES for the tag, 1 byte for the packet
byte[] b = new byte[Constants.TAG_BYTES + 1];
// Calculate the MAC and append it to the packet
mac.update(b);
byte[] macBytes = mac.doFinal();
byte[] b1 = Arrays.copyOf(b, b.length + macBytes.length);
System.arraycopy(macBytes, 0, b1, b.length, macBytes.length);
// Check that the PacketReader reads and verifies the MAC
ByteArrayInputStream in = new ByteArrayInputStream(b1);
PacketDecrypter d = new NullPacketDecrypter(in);
PacketReader p = new PacketReaderImpl(d, mac, 0, 0L);
// Initial calls to finishPacket() should have no effect
p.finishPacket();
p.finishPacket();
p.finishPacket();
assertEquals(0, p.getInputStream().read());
p.finishPacket();
// Extra calls to finishPacket() should have no effect
p.finishPacket();
p.finishPacket();
p.finishPacket();
// Reading the MAC should take us to EOF
assertEquals(-1, p.getInputStream().read());
}
@Test
public void testPacketNumberIsIncremented() throws Exception {
byte[] tag = StringUtils.fromHexString(
"0000" // 16 bits reserved
+ "F00D" // 16 bits for the transport ID
+ "DEADBEEF" // 32 bits for the connection number
+ "00000000" // 32 bits for the packet number
+ "00000000" // 32 bits for the block number
);
assertEquals(Constants.TAG_BYTES, tag.length);
byte[] tag1 = StringUtils.fromHexString(
"0000" // 16 bits reserved
+ "F00D" // 16 bits for the transport ID
+ "DEADBEEF" // 32 bits for the connection number
+ "00000001" // 32 bits for the packet number
+ "00000000" // 32 bits for the block number
);
assertEquals(Constants.TAG_BYTES, tag1.length);
// Calculate the MAC on the first packet and append it to the packet
mac.update(tag);
mac.update((byte) 0);
byte[] macBytes = mac.doFinal();
byte[] b = Arrays.copyOf(tag, tag.length + 1 + macBytes.length);
System.arraycopy(macBytes, 0, b, tag.length + 1, macBytes.length);
// Calculate the MAC on the second packet and append it to the packet
mac.update(tag1);
mac.update((byte) 0);
byte[] macBytes1 = mac.doFinal();
byte[] b1 = Arrays.copyOf(tag1, tag1.length + 1 + macBytes1.length);
System.arraycopy(macBytes1, 0, b1, tag.length + 1, macBytes1.length);
// Check that the PacketReader accepts the correct tags and MACs
byte[] b2 = Arrays.copyOf(b, b.length + b1.length);
System.arraycopy(b1, 0, b2, b.length, b1.length);
ByteArrayInputStream in = new ByteArrayInputStream(b2);
PacketDecrypter d = new NullPacketDecrypter(in);
PacketReader p = new PacketReaderImpl(d, mac, 0xF00D, 0xDEADBEEFL);
// Packet one
assertEquals(0, p.getInputStream().read());
p.finishPacket();
// Packet two
assertEquals(0, p.getInputStream().read());
p.finishPacket();
// We should be at EOF
assertEquals(-1, p.getInputStream().read());
}
/** A PacketDecrypter that performs no decryption. */
private static class NullPacketDecrypter implements PacketDecrypter {
private final InputStream in;
private NullPacketDecrypter(InputStream in) {
this.in = in;
}
public InputStream getInputStream() {
return in;
}
public byte[] readTag() throws IOException {
byte[] tag = new byte[Constants.TAG_BYTES];
int offset = 0;
while(offset < tag.length) {
int read = in.read(tag, offset, tag.length - offset);
if(read == -1) break;
offset += read;
}
if(offset == 0) return null; // EOF between packets is acceptable
if(offset < tag.length) throw new EOFException();
return tag;
}
}
}

View File

@@ -36,52 +36,56 @@ public class PacketWriterImplTest extends TestCase {
PacketEncrypter e = new NullPacketEncrypter(out);
PacketWriter p = new PacketWriterImpl(e, mac, 0, 0L);
p.getOutputStream().write(0);
// There should be TAG_BYTES bytes for the tag, 1 byte for the write
// There should be TAG_BYTES bytes for the tag, 1 byte for the packet
assertTrue(Arrays.equals(new byte[Constants.TAG_BYTES + 1],
out.toByteArray()));
}
@Test
public void testNextPacketAfterWriteTriggersMac() throws Exception {
public void testFinishPacketAfterWriteTriggersMac() throws Exception {
// Calculate what the MAC should be
mac.update(new byte[17]);
mac.update(new byte[Constants.TAG_BYTES + 1]);
byte[] expectedMac = mac.doFinal();
// Check that the PacketWriter calculates and writes the correct MAC
ByteArrayOutputStream out = new ByteArrayOutputStream();
PacketEncrypter e = new NullPacketEncrypter(out);
PacketWriter p = new PacketWriterImpl(e, mac, 0, 0L);
p.getOutputStream().write(0);
p.nextPacket();
p.finishPacket();
byte[] written = out.toByteArray();
assertEquals(17 + expectedMac.length, written.length);
assertEquals(Constants.TAG_BYTES + 1 + expectedMac.length,
written.length);
byte[] actualMac = new byte[expectedMac.length];
System.arraycopy(written, 17, actualMac, 0, actualMac.length);
System.arraycopy(written, Constants.TAG_BYTES + 1, actualMac, 0,
actualMac.length);
assertTrue(Arrays.equals(expectedMac, actualMac));
}
@Test
public void testExtraCallsToNextPacketDoNothing() throws Exception {
public void testExtraCallsToFinishPacketDoNothing() throws Exception {
// Calculate what the MAC should be
mac.update(new byte[17]);
mac.update(new byte[Constants.TAG_BYTES + 1]);
byte[] expectedMac = mac.doFinal();
// Check that the PacketWriter calculates and writes the correct MAC
ByteArrayOutputStream out = new ByteArrayOutputStream();
PacketEncrypter e = new NullPacketEncrypter(out);
PacketWriter p = new PacketWriterImpl(e, mac, 0, 0L);
// Initial calls to nextPacket() should have no effect
p.nextPacket();
p.nextPacket();
p.nextPacket();
// Initial calls to finishPacket() should have no effect
p.finishPacket();
p.finishPacket();
p.finishPacket();
p.getOutputStream().write(0);
p.nextPacket();
// Extra calls to nextPacket() should have no effect
p.nextPacket();
p.nextPacket();
p.nextPacket();
p.finishPacket();
// Extra calls to finishPacket() should have no effect
p.finishPacket();
p.finishPacket();
p.finishPacket();
byte[] written = out.toByteArray();
assertEquals(17 + expectedMac.length, written.length);
assertEquals(Constants.TAG_BYTES + 1 + expectedMac.length,
written.length);
byte[] actualMac = new byte[expectedMac.length];
System.arraycopy(written, 17, actualMac, 0, actualMac.length);
System.arraycopy(written, Constants.TAG_BYTES + 1, actualMac, 0,
actualMac.length);
assertTrue(Arrays.equals(expectedMac, actualMac));
}
@@ -117,10 +121,10 @@ public class PacketWriterImplTest extends TestCase {
PacketWriter p = new PacketWriterImpl(e, mac, 0xF00D, 0xDEADBEEFL);
// Packet one
p.getOutputStream().write(0);
p.nextPacket();
p.finishPacket();
// Packet two
p.getOutputStream().write(0);
p.nextPacket();
p.finishPacket();
byte[] written = out.toByteArray();
assertEquals(Constants.TAG_BYTES + 1 + expectedMac.length
+ Constants.TAG_BYTES + 1 + expectedMac1.length,
@@ -146,6 +150,7 @@ public class PacketWriterImplTest extends TestCase {
assertTrue(Arrays.equals(expectedMac1, actualMac1));
}
/** A PacketEncrypter that performs no encryption. */
private static class NullPacketEncrypter implements PacketEncrypter {
private final OutputStream out;