More refactoring to connect ConnectionRecogniser to ConnectionReader.

Added TestDatabaseModule so tests can specify their own DB
configuration. The modules are currently too tightly coupled - see
whether any dependencies can be removed.
This commit is contained in:
akwizgran
2011-09-28 14:21:38 +01:00
parent 4aff0c4f88
commit a1b664b639
30 changed files with 260 additions and 191 deletions

View File

@@ -7,7 +7,7 @@ import net.sf.briar.api.protocol.BatchId;
/** An interface for creating a batch packet. */ /** An interface for creating a batch packet. */
public interface BatchWriter { public interface BatchWriter {
/** Returns the capacity of the batch. */ /** Returns the capacity of the batch in bytes. */
int getCapacity(); int getCapacity();
/** /**

View File

@@ -4,6 +4,6 @@ import java.io.InputStream;
public interface ConnectionReaderFactory { public interface ConnectionReaderFactory {
ConnectionReader createConnectionReader(InputStream in, boolean initiator, ConnectionReader createConnectionReader(InputStream in, byte[] encryptedIv,
int transportId, long connection, byte[] secret); byte[] secret);
} }

View File

@@ -0,0 +1,6 @@
package net.sf.briar.api.transport;
public interface ConnectionRecogniserFactory {
ConnectionRecogniser createConnectionRecogniser(int transportId);
}

View File

@@ -12,5 +12,5 @@ public interface ConnectionWriter {
OutputStream getOutputStream(); OutputStream getOutputStream();
/** Returns the maximum number of bytes that can be written. */ /** Returns the maximum number of bytes that can be written. */
long getCapacity(); long getRemainingCapacity();
} }

View File

@@ -9,7 +9,7 @@ import java.io.OutputStream;
*/ */
public interface BatchTransportWriter { public interface BatchTransportWriter {
/** Returns the maximum number of bytes that can be written. */ /** Returns the capacity of the transport in bytes. */
long getCapacity(); long getCapacity();
/** Returns an output stream for writing to the transport. */ /** Returns an output stream for writing to the transport. */

View File

@@ -22,19 +22,30 @@ implements ConnectionDecrypter {
private final Cipher frameCipher; private final Cipher frameCipher;
private final SecretKey frameKey; private final SecretKey frameKey;
private final byte[] buf, iv; private final byte[] iv, buf;
private int bufOff = 0, bufLen = 0; private int bufOff = 0, bufLen = 0;
private long frame = 0L; private long frame = 0L;
private boolean betweenFrames = true; private boolean betweenFrames = true;
ConnectionDecrypterImpl(InputStream in, boolean initiator, int transportId, ConnectionDecrypterImpl(InputStream in, byte[] encryptedIv, Cipher ivCipher,
long connection, Cipher frameCipher, SecretKey frameKey) { Cipher frameCipher, SecretKey ivKey, SecretKey frameKey) {
super(in); super(in);
this.frameCipher = frameCipher; this.frameCipher = frameCipher;
this.frameKey = frameKey; this.frameKey = frameKey;
// Decrypt the IV
try {
ivCipher.init(Cipher.DECRYPT_MODE, ivKey);
iv = ivCipher.doFinal(encryptedIv);
} catch(BadPaddingException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(IllegalBlockSizeException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(InvalidKeyException badKey) {
throw new IllegalArgumentException(badKey);
}
if(iv.length != IV_LENGTH) throw new IllegalArgumentException();
buf = new byte[IV_LENGTH]; buf = new byte[IV_LENGTH];
iv = IvEncoder.encodeIv(initiator, transportId, connection);
} }
public InputStream getInputStream() { public InputStream getInputStream() {

View File

@@ -13,5 +13,5 @@ interface ConnectionEncrypter {
void writeMac(byte[] mac) throws IOException; void writeMac(byte[] mac) throws IOException;
/** Returns the maximum number of bytes that can be written. */ /** Returns the maximum number of bytes that can be written. */
long getCapacity(); long getRemainingCapacity();
} }

View File

@@ -18,29 +18,34 @@ import javax.crypto.spec.IvParameterSpec;
class ConnectionEncrypterImpl extends FilterOutputStream class ConnectionEncrypterImpl extends FilterOutputStream
implements ConnectionEncrypter { implements ConnectionEncrypter {
private final Cipher ivCipher, frameCipher; private final Cipher frameCipher;
private final SecretKey frameKey; private final SecretKey frameKey;
private final byte[] iv; private final byte[] iv, encryptedIv;
private long capacity, frame = 0L; private long capacity, frame = 0L;
private boolean ivWritten = false, betweenFrames = false; private boolean ivWritten = false, betweenFrames = false;
ConnectionEncrypterImpl(OutputStream out, long capacity, boolean initiator, ConnectionEncrypterImpl(OutputStream out, long capacity, byte[] iv,
int transportId, long connection, Cipher ivCipher, Cipher ivCipher, Cipher frameCipher, SecretKey ivKey,
Cipher frameCipher, SecretKey ivKey, SecretKey frameKey) { SecretKey frameKey) {
super(out); super(out);
this.ivCipher = ivCipher; this.capacity = capacity;
this.iv = iv;
this.frameCipher = frameCipher; this.frameCipher = frameCipher;
this.frameKey = frameKey; this.frameKey = frameKey;
iv = IvEncoder.encodeIv(initiator, transportId, connection); // Encrypt the IV
try { try {
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
encryptedIv = ivCipher.doFinal(iv);
} catch(BadPaddingException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(IllegalBlockSizeException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(InvalidKeyException badKey) { } catch(InvalidKeyException badKey) {
throw new IllegalArgumentException(badKey); throw new IllegalArgumentException(badKey);
} }
if(ivCipher.getOutputSize(IV_LENGTH) != IV_LENGTH) if(encryptedIv.length != IV_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
this.capacity = capacity;
} }
public OutputStream getOutputStream() { public OutputStream getOutputStream() {
@@ -60,7 +65,7 @@ implements ConnectionEncrypter {
betweenFrames = true; betweenFrames = true;
} }
public long getCapacity() { public long getRemainingCapacity() {
return capacity; return capacity;
} }
@@ -90,14 +95,8 @@ implements ConnectionEncrypter {
private void writeIv() throws IOException { private void writeIv() throws IOException {
assert !ivWritten; assert !ivWritten;
assert !betweenFrames; assert !betweenFrames;
try { out.write(encryptedIv);
out.write(ivCipher.doFinal(iv)); capacity -= encryptedIv.length;
} catch(BadPaddingException badCipher) {
throw new RuntimeException(badCipher);
} catch(IllegalBlockSizeException badCipher) {
throw new RuntimeException(badCipher);
}
capacity -= iv.length;
ivWritten = true; ivWritten = true;
betweenFrames = true; betweenFrames = true;
} }

View File

@@ -1,7 +1,6 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.io.InputStream; import java.io.InputStream;
import java.security.InvalidKeyException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.Mac; import javax.crypto.Mac;
@@ -23,19 +22,17 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
} }
public ConnectionReader createConnectionReader(InputStream in, public ConnectionReader createConnectionReader(InputStream in,
boolean initiator, int transportId, long connection, byte[] encryptedIv, byte[] secret) {
byte[] secret) { // Create the decrypter
SecretKey macKey = crypto.deriveIncomingMacKey(secret); Cipher ivCipher = crypto.getIvCipher();
SecretKey frameKey = crypto.deriveIncomingFrameKey(secret);
Cipher frameCipher = crypto.getFrameCipher(); Cipher frameCipher = crypto.getFrameCipher();
Mac mac = crypto.getMac(); SecretKey ivKey = crypto.deriveIncomingIvKey(secret);
try { SecretKey frameKey = crypto.deriveIncomingFrameKey(secret);
mac.init(macKey);
} catch(InvalidKeyException e) {
throw new IllegalArgumentException(e);
}
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in,
initiator, transportId, connection, frameCipher, frameKey); encryptedIv, ivCipher, frameCipher, ivKey, frameKey);
return new ConnectionReaderImpl(decrypter, mac); // Create the reader
Mac mac = crypto.getMac();
SecretKey macKey = crypto.deriveIncomingMacKey(secret);
return new ConnectionReaderImpl(decrypter, mac, macKey);
} }
} }

View File

@@ -7,9 +7,11 @@ import java.io.EOFException;
import java.io.FilterInputStream; import java.io.FilterInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.security.InvalidKeyException;
import java.util.Arrays; import java.util.Arrays;
import javax.crypto.Mac; import javax.crypto.Mac;
import javax.crypto.SecretKey;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
@@ -27,10 +29,17 @@ implements ConnectionReader {
private int payloadOff = 0, payloadLen = 0; private int payloadOff = 0, payloadLen = 0;
private boolean betweenFrames = true; private boolean betweenFrames = true;
ConnectionReaderImpl(ConnectionDecrypter decrypter, Mac mac) { ConnectionReaderImpl(ConnectionDecrypter decrypter, Mac mac,
SecretKey macKey) {
super(decrypter.getInputStream()); super(decrypter.getInputStream());
this.decrypter = decrypter; this.decrypter = decrypter;
this.mac = mac; this.mac = mac;
// Initialise the MAC
try {
mac.init(macKey);
} catch(InvalidKeyException e) {
throw new IllegalArgumentException(e);
}
maxPayloadLength = MAX_FRAME_LENGTH - 4 - mac.getMacLength(); maxPayloadLength = MAX_FRAME_LENGTH - 4 - mac.getMacLength();
header = new byte[4]; header = new byte[4];
payload = new byte[maxPayloadLength]; payload = new byte[maxPayloadLength];

View File

@@ -0,0 +1,25 @@
package net.sf.briar.transport;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRecogniserFactory;
import com.google.inject.Inject;
class ConnectionRecogniserFactoryImpl implements ConnectionRecogniserFactory {
private final CryptoComponent crypto;
private final DatabaseComponent db;
@Inject
ConnectionRecogniserFactoryImpl(CryptoComponent crypto,
DatabaseComponent db) {
this.crypto = crypto;
this.db = db;
}
public ConnectionRecogniser createConnectionRecogniser(int transportId) {
return new ConnectionRecogniserImpl(transportId, crypto, db);
}
}

View File

@@ -1,7 +1,6 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.InvalidKeyException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.Mac; import javax.crypto.Mac;
@@ -25,20 +24,17 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
public ConnectionWriter createConnectionWriter(OutputStream out, public ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, boolean initiator, int transportId, long connection, long capacity, boolean initiator, int transportId, long connection,
byte[] secret) { byte[] secret) {
SecretKey macKey = crypto.deriveOutgoingMacKey(secret); // Create the encrypter
SecretKey ivKey = crypto.deriveOutgoingIvKey(secret);
SecretKey frameKey = crypto.deriveOutgoingFrameKey(secret);
Cipher ivCipher = crypto.getIvCipher(); Cipher ivCipher = crypto.getIvCipher();
Cipher frameCipher = crypto.getFrameCipher(); Cipher frameCipher = crypto.getFrameCipher();
Mac mac = crypto.getMac(); SecretKey ivKey = crypto.deriveOutgoingIvKey(secret);
try { SecretKey frameKey = crypto.deriveOutgoingFrameKey(secret);
mac.init(macKey); byte[] iv = IvEncoder.encodeIv(initiator, transportId, connection);
} catch(InvalidKeyException badKey) {
throw new IllegalArgumentException(badKey);
}
ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out, ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out,
capacity, initiator, transportId, connection, ivCipher, capacity, iv, ivCipher, frameCipher, ivKey, frameKey);
frameCipher, ivKey, frameKey); // Create the writer
return new ConnectionWriterImpl(encrypter, mac); Mac mac = crypto.getMac();
SecretKey macKey = crypto.deriveOutgoingMacKey(secret);
return new ConnectionWriterImpl(encrypter, mac, macKey);
} }
} }

View File

@@ -7,8 +7,10 @@ import java.io.ByteArrayOutputStream;
import java.io.FilterOutputStream; import java.io.FilterOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.InvalidKeyException;
import javax.crypto.Mac; import javax.crypto.Mac;
import javax.crypto.SecretKey;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
@@ -28,10 +30,17 @@ implements ConnectionWriter {
protected long frame = 0L; protected long frame = 0L;
ConnectionWriterImpl(ConnectionEncrypter encrypter, Mac mac) { ConnectionWriterImpl(ConnectionEncrypter encrypter, Mac mac,
SecretKey macKey) {
super(encrypter.getOutputStream()); super(encrypter.getOutputStream());
this.encrypter = encrypter; this.encrypter = encrypter;
this.mac = mac; this.mac = mac;
// Initialise the MAC
try {
mac.init(macKey);
} catch(InvalidKeyException badKey) {
throw new IllegalArgumentException(badKey);
}
maxPayloadLength = MAX_FRAME_LENGTH - 4 - mac.getMacLength(); maxPayloadLength = MAX_FRAME_LENGTH - 4 - mac.getMacLength();
buf = new ByteArrayOutputStream(maxPayloadLength); buf = new ByteArrayOutputStream(maxPayloadLength);
header = new byte[4]; header = new byte[4];
@@ -41,8 +50,8 @@ implements ConnectionWriter {
return this; return this;
} }
public long getCapacity() { public long getRemainingCapacity() {
long capacity = encrypter.getCapacity(); long capacity = encrypter.getRemainingCapacity();
// If there's any data buffered, subtract it and its auth overhead // If there's any data buffered, subtract it and its auth overhead
int overheadPerFrame = header.length + mac.getMacLength(); int overheadPerFrame = header.length + mac.getMacLength();
if(buf.size() > 0) capacity -= buf.size() + overheadPerFrame; if(buf.size() > 0) capacity -= buf.size() + overheadPerFrame;

View File

@@ -5,6 +5,7 @@ import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.io.IOException; import java.io.IOException;
import javax.crypto.Mac; import javax.crypto.Mac;
import javax.crypto.SecretKey;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
@@ -21,8 +22,9 @@ class PaddedConnectionWriter extends ConnectionWriterImpl {
private boolean closed = false; private boolean closed = false;
private IOException exception = null; private IOException exception = null;
PaddedConnectionWriter(ConnectionEncrypter encrypter, Mac mac) { PaddedConnectionWriter(ConnectionEncrypter encrypter, Mac mac,
super(encrypter, mac); SecretKey macKey) {
super(encrypter, mac, macKey);
padding = new byte[maxPayloadLength]; padding = new byte[maxPayloadLength];
} }

View File

@@ -1,6 +1,7 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRecogniserFactory;
import net.sf.briar.api.transport.ConnectionWindowFactory; import net.sf.briar.api.transport.ConnectionWindowFactory;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
@@ -12,6 +13,8 @@ public class TransportModule extends AbstractModule {
protected void configure() { protected void configure() {
bind(ConnectionReaderFactory.class).to( bind(ConnectionReaderFactory.class).to(
ConnectionReaderFactoryImpl.class); ConnectionReaderFactoryImpl.class);
bind(ConnectionRecogniserFactory.class).to(
ConnectionRecogniserFactoryImpl.class);
bind(ConnectionWindowFactory.class).to( bind(ConnectionWindowFactory.class).to(
ConnectionWindowFactoryImpl.class); ConnectionWindowFactoryImpl.class);
bind(ConnectionWriterFactory.class).to( bind(ConnectionWriterFactory.class).to(

View File

@@ -33,13 +33,13 @@ class OutgoingBatchConnection {
void write() throws DbException, IOException { void write() throws DbException, IOException {
OutputStream out = conn.getOutputStream(); OutputStream out = conn.getOutputStream();
// There should be enough space for a packet // There should be enough space for a packet
long capacity = conn.getCapacity(); long capacity = conn.getRemainingCapacity();
if(capacity < MAX_PACKET_LENGTH) throw new IOException(); if(capacity < MAX_PACKET_LENGTH) throw new IOException();
// Write a transport update // Write a transport update
TransportWriter t = protoFactory.createTransportWriter(out); TransportWriter t = protoFactory.createTransportWriter(out);
db.generateTransportUpdate(contactId, t); db.generateTransportUpdate(contactId, t);
// If there's space, write a subscription update // If there's space, write a subscription update
capacity = conn.getCapacity(); capacity = conn.getRemainingCapacity();
if(capacity >= MAX_PACKET_LENGTH) { if(capacity >= MAX_PACKET_LENGTH) {
SubscriptionWriter s = protoFactory.createSubscriptionWriter(out); SubscriptionWriter s = protoFactory.createSubscriptionWriter(out);
db.generateSubscriptionUpdate(contactId, s); db.generateSubscriptionUpdate(contactId, s);
@@ -47,14 +47,14 @@ class OutgoingBatchConnection {
// Write acks until you can't write acks no more // Write acks until you can't write acks no more
AckWriter a = protoFactory.createAckWriter(out); AckWriter a = protoFactory.createAckWriter(out);
do { do {
capacity = conn.getCapacity(); capacity = conn.getRemainingCapacity();
int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); int max = (int) Math.min(MAX_PACKET_LENGTH, capacity);
a.setMaxPacketLength(max); a.setMaxPacketLength(max);
} while(db.generateAck(contactId, a)); } while(db.generateAck(contactId, a));
// Write batches until you can't write batches no more // Write batches until you can't write batches no more
BatchWriter b = protoFactory.createBatchWriter(out); BatchWriter b = protoFactory.createBatchWriter(out);
do { do {
capacity = conn.getCapacity(); capacity = conn.getRemainingCapacity();
int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); int max = (int) Math.min(MAX_PACKET_LENGTH, capacity);
b.setMaxPacketLength(max); b.setMaxPacketLength(max);
} while(db.generateBatch(contactId, b)); } while(db.generateBatch(contactId, b));

View File

@@ -45,6 +45,7 @@ import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
@@ -83,8 +84,9 @@ public class FileReadWriteTest extends TestCase {
public FileReadWriteTest() throws Exception { public FileReadWriteTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new ProtocolWritersModule(), new DatabaseModule(), new ProtocolModule(),
new SerialModule(), new TransportModule()); new ProtocolWritersModule(), new SerialModule(),
new TestDatabaseModule(testDir), new TransportModule());
connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class); connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class);
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class); protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class);
@@ -191,7 +193,7 @@ public class FileReadWriteTest extends TestCase {
assertEquals(16, offset); assertEquals(16, offset);
// Use Bob's secret for reading // Use Bob's secret for reading
ConnectionReader r = connectionReaderFactory.createConnectionReader(in, ConnectionReader r = connectionReaderFactory.createConnectionReader(in,
true, transportId, connection, bobSecret); iv, bobSecret);
in = r.getInputStream(); in = r.getInputStream();
ProtocolReader protocolReader = ProtocolReader protocolReader =
protocolReaderFactory.createProtocolReader(in); protocolReaderFactory.createProtocolReader(in);

View File

@@ -0,0 +1,34 @@
package net.sf.briar;
import java.io.File;
import net.sf.briar.api.crypto.Password;
import net.sf.briar.api.db.DatabaseDirectory;
import net.sf.briar.api.db.DatabaseMaxSize;
import net.sf.briar.api.db.DatabasePassword;
import com.google.inject.AbstractModule;
public class TestDatabaseModule extends AbstractModule {
private final File dir;
private final Password password;
public TestDatabaseModule(File dir) {
this.dir = dir;
this.password = new Password() {
public char[] getPassword() {
return "foo bar".toCharArray();
}
};
}
@Override
protected void configure() {
bind(File.class).annotatedWith(DatabaseDirectory.class).toInstance(dir);
bind(Password.class).annotatedWith(
DatabasePassword.class).toInstance(password);
bind(long.class).annotatedWith(
DatabaseMaxSize.class).toInstance(Long.MAX_VALUE);
}
}

View File

@@ -13,6 +13,7 @@ import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.TestDatabaseModule;
import net.sf.briar.TestUtils; import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.Rating; import net.sf.briar.api.Rating;
@@ -71,8 +72,8 @@ public class H2DatabaseTest extends TestCase {
public H2DatabaseTest() throws Exception { public H2DatabaseTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule(), new DatabaseModule(), new ProtocolModule(), new SerialModule(),
new TransportModule()); new TransportModule(), new TestDatabaseModule(testDir));
connectionWindowFactory = i.getInstance(ConnectionWindowFactory.class); connectionWindowFactory = i.getInstance(ConnectionWindowFactory.class);
groupFactory = i.getInstance(GroupFactory.class); groupFactory = i.getInstance(GroupFactory.class);
authorId = new AuthorId(TestUtils.getRandomId()); authorId = new AuthorId(TestUtils.getRandomId());

View File

@@ -1,5 +1,7 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.util.Arrays; import java.util.Arrays;
@@ -22,8 +24,8 @@ public class ConnectionDecrypterImplTest extends TestCase {
private static final int MAC_LENGTH = 32; private static final int MAC_LENGTH = 32;
private final Cipher frameCipher; private final Cipher ivCipher, frameCipher;
private final SecretKey frameKey; private final SecretKey ivKey, frameKey;
private final int transportId = 1234; private final int transportId = 1234;
private final long connection = 12345L; private final long connection = 12345L;
@@ -31,23 +33,12 @@ public class ConnectionDecrypterImplTest extends TestCase {
super(); super();
Injector i = Guice.createInjector(new CryptoModule()); Injector i = Guice.createInjector(new CryptoModule());
CryptoComponent crypto = i.getInstance(CryptoComponent.class); CryptoComponent crypto = i.getInstance(CryptoComponent.class);
ivCipher = crypto.getIvCipher();
frameCipher = crypto.getFrameCipher(); frameCipher = crypto.getFrameCipher();
ivKey = crypto.generateSecretKey();
frameKey = crypto.generateSecretKey(); frameKey = crypto.generateSecretKey();
} }
@Test
public void testSingleByteFrame() throws Exception {
// Create a fake ciphertext frame: one byte plus a MAC
byte[] ciphertext = new byte[1 + MAC_LENGTH];
ByteArrayInputStream in = new ByteArrayInputStream(ciphertext);
// Check that one byte plus a MAC can be read
ConnectionDecrypter d = new ConnectionDecrypterImpl(in, true,
transportId, connection, frameCipher, frameKey);
assertFalse(d.getInputStream().read() == -1);
d.readMac(new byte[MAC_LENGTH]);
assertTrue(d.getInputStream().read() == -1);
}
@Test @Test
public void testInitiatorDecryption() throws Exception { public void testInitiatorDecryption() throws Exception {
testDecryption(true); testDecryption(true);
@@ -59,34 +50,48 @@ public class ConnectionDecrypterImplTest extends TestCase {
} }
private void testDecryption(boolean initiator) throws Exception { private void testDecryption(boolean initiator) throws Exception {
// Calculate the plaintext and ciphertext for the IV
byte[] iv = IvEncoder.encodeIv(initiator, transportId, connection);
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] encryptedIv = ivCipher.doFinal(iv);
assertEquals(IV_LENGTH, encryptedIv.length);
// Calculate the expected plaintext for the first frame // Calculate the expected plaintext for the first frame
byte[] ciphertext = new byte[123]; byte[] ciphertext = new byte[123];
byte[] iv = IvEncoder.encodeIv(initiator, transportId, connection); byte[] ciphertextMac = new byte[MAC_LENGTH];
IvParameterSpec ivSpec = new IvParameterSpec(iv); IvParameterSpec ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec);
byte[] plaintext = frameCipher.doFinal(ciphertext); byte[] plaintext = new byte[ciphertext.length + ciphertextMac.length];
int offset = frameCipher.update(ciphertext, 0, ciphertext.length,
plaintext);
frameCipher.doFinal(ciphertextMac, 0, ciphertextMac.length, plaintext,
offset);
// Calculate the expected plaintext for the second frame // Calculate the expected plaintext for the second frame
byte[] ciphertext1 = new byte[1234]; byte[] ciphertext1 = new byte[1234];
IvEncoder.updateIv(iv, 1L); IvEncoder.updateIv(iv, 1L);
ivSpec = new IvParameterSpec(iv); ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec);
byte[] plaintext1 = frameCipher.doFinal(ciphertext1); byte[] plaintext1 = new byte[ciphertext1.length + ciphertextMac.length];
assertEquals(ciphertext1.length, plaintext1.length); offset = frameCipher.update(ciphertext1, 0, ciphertext1.length,
plaintext1);
frameCipher.doFinal(ciphertextMac, 0, ciphertextMac.length, plaintext1,
offset);
// Concatenate the ciphertexts // Concatenate the ciphertexts
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
out.write(ciphertext); out.write(ciphertext);
out.write(ciphertextMac);
out.write(ciphertext1); out.write(ciphertext1);
out.write(ciphertextMac);
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
// Use a ConnectionDecrypter to decrypt the ciphertext // Use a ConnectionDecrypter to decrypt the ciphertext
ConnectionDecrypter d = new ConnectionDecrypterImpl(in, initiator, ConnectionDecrypter d = new ConnectionDecrypterImpl(in, encryptedIv,
transportId, connection, frameCipher, frameKey); ivCipher, frameCipher, ivKey, frameKey);
// First frame // First frame
byte[] decrypted = new byte[plaintext.length - MAC_LENGTH]; byte[] decrypted = new byte[ciphertext.length];
TestUtils.readFully(d.getInputStream(), decrypted); TestUtils.readFully(d.getInputStream(), decrypted);
byte[] decryptedMac = new byte[MAC_LENGTH]; byte[] decryptedMac = new byte[MAC_LENGTH];
d.readMac(decryptedMac); d.readMac(decryptedMac);
// Second frame // Second frame
byte[] decrypted1 = new byte[plaintext1.length - MAC_LENGTH]; byte[] decrypted1 = new byte[ciphertext1.length];
TestUtils.readFully(d.getInputStream(), decrypted1); TestUtils.readFully(d.getInputStream(), decrypted1);
byte[] decryptedMac1 = new byte[MAC_LENGTH]; byte[] decryptedMac1 = new byte[MAC_LENGTH];
d.readMac(decryptedMac1); d.readMac(decryptedMac1);

View File

@@ -37,19 +37,6 @@ public class ConnectionEncrypterImplTest extends TestCase {
frameKey = crypto.generateSecretKey(); frameKey = crypto.generateSecretKey();
} }
@Test
public void testSingleByteFrame() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE,
true, transportId, connection, ivCipher, frameCipher, ivKey,
frameKey);
e.getOutputStream().write((byte) 0);
e.writeMac(new byte[MAC_LENGTH]);
byte[] ciphertext = out.toByteArray();
assertEquals(IV_LENGTH + 1 + MAC_LENGTH, ciphertext.length);
assertEquals(Long.MAX_VALUE - ciphertext.length, e.getCapacity());
}
@Test @Test
public void testInitiatorEncryption() throws Exception { public void testInitiatorEncryption() throws Exception {
testEncryption(true); testEncryption(true);
@@ -63,7 +50,6 @@ public class ConnectionEncrypterImplTest extends TestCase {
private void testEncryption(boolean initiator) throws Exception { private void testEncryption(boolean initiator) throws Exception {
// Calculate the expected ciphertext for the IV // Calculate the expected ciphertext for the IV
byte[] iv = IvEncoder.encodeIv(initiator, transportId, connection); byte[] iv = IvEncoder.encodeIv(initiator, transportId, connection);
assertEquals(IV_LENGTH, iv.length);
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] encryptedIv = ivCipher.doFinal(iv); byte[] encryptedIv = ivCipher.doFinal(iv);
assertEquals(IV_LENGTH, encryptedIv.length); assertEquals(IV_LENGTH, encryptedIv.length);
@@ -95,9 +81,9 @@ public class ConnectionEncrypterImplTest extends TestCase {
byte[] expected = out.toByteArray(); byte[] expected = out.toByteArray();
// Use a ConnectionEncrypter to encrypt the plaintext // Use a ConnectionEncrypter to encrypt the plaintext
out.reset(); out.reset();
iv = IvEncoder.encodeIv(initiator, transportId, connection);
ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE, ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE,
initiator, transportId, connection, ivCipher, frameCipher, iv, ivCipher, frameCipher, ivKey, frameKey);
ivKey, frameKey);
e.getOutputStream().write(plaintext); e.getOutputStream().write(plaintext);
e.writeMac(plaintextMac); e.writeMac(plaintextMac);
e.getOutputStream().write(plaintext1); e.getOutputStream().write(plaintext1);
@@ -105,6 +91,6 @@ public class ConnectionEncrypterImplTest extends TestCase {
byte[] actual = out.toByteArray(); byte[] actual = out.toByteArray();
// Check that the actual ciphertext matches the expected ciphertext // Check that the actual ciphertext matches the expected ciphertext
assertTrue(Arrays.equals(expected, actual)); assertTrue(Arrays.equals(expected, actual));
assertEquals(Long.MAX_VALUE - actual.length, e.getCapacity()); assertEquals(Long.MAX_VALUE - actual.length, e.getRemainingCapacity());
} }
} }

View File

@@ -24,12 +24,13 @@ public class ConnectionReaderImplTest extends TransportTest {
byte[] frame = new byte[headerLength + payloadLength + macLength]; byte[] frame = new byte[headerLength + payloadLength + macLength];
writeHeader(frame, payloadLength, 0); writeHeader(frame, payloadLength, 0);
// Calculate the MAC // Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, headerLength + payloadLength); mac.update(frame, 0, headerLength + payloadLength);
mac.doFinal(frame, headerLength + payloadLength); mac.doFinal(frame, headerLength + payloadLength);
// Read the frame // Read the frame
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
try { try {
r.getInputStream().read(); r.getInputStream().read();
fail(); fail();
@@ -42,12 +43,13 @@ public class ConnectionReaderImplTest extends TransportTest {
byte[] frame = new byte[headerLength + payloadLength + macLength]; byte[] frame = new byte[headerLength + payloadLength + macLength];
writeHeader(frame, payloadLength, 0); writeHeader(frame, payloadLength, 0);
// Calculate the MAC // Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, headerLength + payloadLength); mac.update(frame, 0, headerLength + payloadLength);
mac.doFinal(frame, headerLength + payloadLength); mac.doFinal(frame, headerLength + payloadLength);
// Read the frame // Read the frame
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
// There should be one byte available before EOF // There should be one byte available before EOF
assertEquals(0, r.getInputStream().read()); assertEquals(0, r.getInputStream().read());
assertEquals(-1, r.getInputStream().read()); assertEquals(-1, r.getInputStream().read());
@@ -58,6 +60,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// First frame: max payload length // First frame: max payload length
byte[] frame = new byte[MAX_FRAME_LENGTH]; byte[] frame = new byte[MAX_FRAME_LENGTH];
writeHeader(frame, maxPayloadLength, 0); writeHeader(frame, maxPayloadLength, 0);
mac.init(macKey);
mac.update(frame, 0, headerLength + maxPayloadLength); mac.update(frame, 0, headerLength + maxPayloadLength);
mac.doFinal(frame, headerLength + maxPayloadLength); mac.doFinal(frame, headerLength + maxPayloadLength);
// Second frame: max payload length plus one // Second frame: max payload length plus one
@@ -72,7 +75,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// Read the first frame // Read the first frame
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
byte[] read = new byte[maxPayloadLength]; byte[] read = new byte[maxPayloadLength];
TestUtils.readFully(r.getInputStream(), read); TestUtils.readFully(r.getInputStream(), read);
// Try to read the second frame // Try to read the second frame
@@ -89,6 +92,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// First frame: max payload length, including padding // First frame: max payload length, including padding
byte[] frame = new byte[MAX_FRAME_LENGTH]; byte[] frame = new byte[MAX_FRAME_LENGTH];
writeHeader(frame, maxPayloadLength - paddingLength, paddingLength); writeHeader(frame, maxPayloadLength - paddingLength, paddingLength);
mac.init(macKey);
mac.update(frame, 0, headerLength + maxPayloadLength); mac.update(frame, 0, headerLength + maxPayloadLength);
mac.doFinal(frame, headerLength + maxPayloadLength); mac.doFinal(frame, headerLength + maxPayloadLength);
// Second frame: max payload length plus one, including padding // Second frame: max payload length plus one, including padding
@@ -104,7 +108,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// Read the first frame // Read the first frame
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
byte[] read = new byte[maxPayloadLength - paddingLength]; byte[] read = new byte[maxPayloadLength - paddingLength];
TestUtils.readFully(r.getInputStream(), read); TestUtils.readFully(r.getInputStream(), read);
// Try to read the second frame // Try to read the second frame
@@ -120,6 +124,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// First frame: 123-byte payload // First frame: 123-byte payload
byte[] frame = new byte[headerLength + 123 + mac.getMacLength()]; byte[] frame = new byte[headerLength + 123 + mac.getMacLength()];
writeHeader(frame, 123, 0); writeHeader(frame, 123, 0);
mac.init(macKey);
mac.update(frame, 0, headerLength + 123); mac.update(frame, 0, headerLength + 123);
mac.doFinal(frame, headerLength + 123); mac.doFinal(frame, headerLength + 123);
// Second frame: 1234-byte payload // Second frame: 1234-byte payload
@@ -134,7 +139,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// Read the frames // Read the frames
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
byte[] read = new byte[123]; byte[] read = new byte[123];
TestUtils.readFully(r.getInputStream(), read); TestUtils.readFully(r.getInputStream(), read);
assertTrue(Arrays.equals(new byte[123], read)); assertTrue(Arrays.equals(new byte[123], read));
@@ -149,6 +154,7 @@ public class ConnectionReaderImplTest extends TransportTest {
byte[] frame = new byte[headerLength + payloadLength + macLength]; byte[] frame = new byte[headerLength + payloadLength + macLength];
writeHeader(frame, payloadLength, 0); writeHeader(frame, payloadLength, 0);
// Calculate the MAC // Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, headerLength + payloadLength); mac.update(frame, 0, headerLength + payloadLength);
mac.doFinal(frame, headerLength + payloadLength); mac.doFinal(frame, headerLength + payloadLength);
// Modify the payload // Modify the payload
@@ -156,7 +162,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// Try to read the frame - not a single byte should be read // Try to read the frame - not a single byte should be read
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
try { try {
r.getInputStream().read(); r.getInputStream().read();
fail(); fail();
@@ -169,6 +175,7 @@ public class ConnectionReaderImplTest extends TransportTest {
byte[] frame = new byte[headerLength + payloadLength + macLength]; byte[] frame = new byte[headerLength + payloadLength + macLength];
writeHeader(frame, payloadLength, 0); writeHeader(frame, payloadLength, 0);
// Calculate the MAC // Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, headerLength + payloadLength); mac.update(frame, 0, headerLength + payloadLength);
mac.doFinal(frame, headerLength + payloadLength); mac.doFinal(frame, headerLength + payloadLength);
// Modify the MAC // Modify the MAC
@@ -176,7 +183,7 @@ public class ConnectionReaderImplTest extends TransportTest {
// Try to read the frame - not a single byte should be read // Try to read the frame - not a single byte should be read
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
try { try {
r.getInputStream().read(); r.getInputStream().read();
fail(); fail();

View File

@@ -20,7 +20,7 @@ public class ConnectionWriterImplTest extends TransportTest {
public void testFlushWithoutWriteProducesNothing() throws Exception { public void testFlushWithoutWriteProducesNothing() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out);
ConnectionWriter w = new ConnectionWriterImpl(e, mac); ConnectionWriter w = new ConnectionWriterImpl(e, mac, macKey);
w.getOutputStream().flush(); w.getOutputStream().flush();
w.getOutputStream().flush(); w.getOutputStream().flush();
w.getOutputStream().flush(); w.getOutputStream().flush();
@@ -33,12 +33,13 @@ public class ConnectionWriterImplTest extends TransportTest {
byte[] frame = new byte[headerLength + payloadLength + macLength]; byte[] frame = new byte[headerLength + payloadLength + macLength];
writeHeader(frame, payloadLength, 0); writeHeader(frame, payloadLength, 0);
// Calculate the MAC // Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, headerLength + payloadLength); mac.update(frame, 0, headerLength + payloadLength);
mac.doFinal(frame, headerLength + payloadLength); mac.doFinal(frame, headerLength + payloadLength);
// Check that the ConnectionWriter gets the same results // Check that the ConnectionWriter gets the same results
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out);
ConnectionWriter w = new ConnectionWriterImpl(e, mac); ConnectionWriter w = new ConnectionWriterImpl(e, mac, macKey);
w.getOutputStream().write(0); w.getOutputStream().write(0);
w.getOutputStream().flush(); w.getOutputStream().flush();
assertTrue(Arrays.equals(frame, out.toByteArray())); assertTrue(Arrays.equals(frame, out.toByteArray()));
@@ -48,7 +49,7 @@ public class ConnectionWriterImplTest extends TransportTest {
public void testWriteByteToMaxLengthWritesFrame() throws Exception { public void testWriteByteToMaxLengthWritesFrame() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out);
ConnectionWriter w = new ConnectionWriterImpl(e, mac); ConnectionWriter w = new ConnectionWriterImpl(e, mac, macKey);
OutputStream out1 = w.getOutputStream(); OutputStream out1 = w.getOutputStream();
// The first maxPayloadLength - 1 bytes should be buffered // The first maxPayloadLength - 1 bytes should be buffered
for(int i = 0; i < maxPayloadLength - 1; i++) out1.write(0); for(int i = 0; i < maxPayloadLength - 1; i++) out1.write(0);
@@ -62,7 +63,7 @@ public class ConnectionWriterImplTest extends TransportTest {
public void testWriteArrayToMaxLengthWritesFrame() throws Exception { public void testWriteArrayToMaxLengthWritesFrame() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out);
ConnectionWriter w = new ConnectionWriterImpl(e, mac); ConnectionWriter w = new ConnectionWriterImpl(e, mac, macKey);
OutputStream out1 = w.getOutputStream(); OutputStream out1 = w.getOutputStream();
// The first maxPayloadLength - 1 bytes should be buffered // The first maxPayloadLength - 1 bytes should be buffered
out1.write(new byte[maxPayloadLength - 1]); out1.write(new byte[maxPayloadLength - 1]);
@@ -77,6 +78,7 @@ public class ConnectionWriterImplTest extends TransportTest {
// First frame: 123-byte payload // First frame: 123-byte payload
byte[] frame = new byte[headerLength + 123 + macLength]; byte[] frame = new byte[headerLength + 123 + macLength];
writeHeader(frame, 123, 0); writeHeader(frame, 123, 0);
mac.init(macKey);
mac.update(frame, 0, headerLength + 123); mac.update(frame, 0, headerLength + 123);
mac.doFinal(frame, headerLength + 123); mac.doFinal(frame, headerLength + 123);
// Second frame: 1234-byte payload // Second frame: 1234-byte payload
@@ -92,7 +94,7 @@ public class ConnectionWriterImplTest extends TransportTest {
// Check that the ConnectionWriter gets the same results // Check that the ConnectionWriter gets the same results
out.reset(); out.reset();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out);
ConnectionWriter w = new ConnectionWriterImpl(e, mac); ConnectionWriter w = new ConnectionWriterImpl(e, mac, macKey);
w.getOutputStream().write(new byte[123]); w.getOutputStream().write(new byte[123]);
w.getOutputStream().flush(); w.getOutputStream().flush();
w.getOutputStream().write(new byte[1234]); w.getOutputStream().write(new byte[1234]);

View File

@@ -4,11 +4,16 @@ import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MIN_CONNECTION_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MIN_CONNECTION_LENGTH;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.TestDatabaseModule;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.db.DatabaseModule;
import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.serial.SerialModule;
import org.junit.Test; import org.junit.Test;
@@ -25,7 +30,8 @@ public class ConnectionWriterTest extends TestCase {
public ConnectionWriterTest() throws Exception { public ConnectionWriterTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Injector i = Guice.createInjector(new CryptoModule(),
new TransportModule()); new DatabaseModule(), new ProtocolModule(), new SerialModule(),
new TestDatabaseModule(new File(".")), new TransportModule());
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
} }
@@ -36,7 +42,7 @@ public class ConnectionWriterTest extends TestCase {
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
MIN_CONNECTION_LENGTH, true, transportId, connection, secret); MIN_CONNECTION_LENGTH, true, transportId, connection, secret);
// Check that the connection writer thinks there's room for a packet // Check that the connection writer thinks there's room for a packet
long capacity = w.getCapacity(); long capacity = w.getRemainingCapacity();
assertTrue(capacity >= MAX_PACKET_LENGTH); assertTrue(capacity >= MAX_PACKET_LENGTH);
assertTrue(capacity <= MIN_CONNECTION_LENGTH); assertTrue(capacity <= MIN_CONNECTION_LENGTH);
// Check that there really is room for a packet // Check that there really is room for a packet

View File

@@ -60,9 +60,8 @@ public class FrameReadWriteTest extends TestCase {
} }
private void testWriteAndRead(boolean initiator) throws Exception { private void testWriteAndRead(boolean initiator) throws Exception {
// Calculate the expected ciphertext for the IV // Create and encrypt the IV
byte[] iv = IvEncoder.encodeIv(initiator, transportId, connection); byte[] iv = IvEncoder.encodeIv(initiator, transportId, connection);
assertEquals(IV_LENGTH, iv.length);
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] encryptedIv = ivCipher.doFinal(iv); byte[] encryptedIv = ivCipher.doFinal(iv);
assertEquals(IV_LENGTH, encryptedIv.length); assertEquals(IV_LENGTH, encryptedIv.length);
@@ -74,23 +73,24 @@ public class FrameReadWriteTest extends TestCase {
// Write the frames // Write the frames
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out, ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out,
Long.MAX_VALUE, initiator, transportId, connection, ivCipher, Long.MAX_VALUE, iv, ivCipher, frameCipher, ivKey, frameKey);
frameCipher, ivKey, frameKey); ConnectionWriter writer = new ConnectionWriterImpl(encrypter, mac,
mac.init(macKey); macKey);
ConnectionWriter writer = new ConnectionWriterImpl(encrypter, mac);
OutputStream out1 = writer.getOutputStream(); OutputStream out1 = writer.getOutputStream();
out1.write(frame); out1.write(frame);
out1.flush(); out1.flush();
out1.write(frame1); out1.write(frame1);
out1.flush(); out1.flush();
// Read the frames back // Read the IV back
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
byte[] recoveredIv = new byte[IV_LENGTH]; byte[] recoveredIv = new byte[IV_LENGTH];
assertEquals(IV_LENGTH, in.read(recoveredIv)); assertEquals(IV_LENGTH, in.read(recoveredIv));
assertTrue(Arrays.equals(encryptedIv, recoveredIv)); assertTrue(Arrays.equals(encryptedIv, recoveredIv));
// Read the frames back
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in,
initiator, transportId, connection, frameCipher, frameKey); recoveredIv, ivCipher, frameCipher, ivKey, frameKey);
ConnectionReader reader = new ConnectionReaderImpl(decrypter, mac); ConnectionReader reader = new ConnectionReaderImpl(decrypter, mac,
macKey);
InputStream in1 = reader.getInputStream(); InputStream in1 = reader.getInputStream();
byte[] recovered = new byte[frame.length]; byte[] recovered = new byte[frame.length];
int offset = 0; int offset = 0;

View File

@@ -28,7 +28,7 @@ implements ConnectionEncrypter {
capacity -= mac.length; capacity -= mac.length;
} }
public long getCapacity() { public long getRemainingCapacity() {
return capacity; return capacity;
} }

View File

@@ -24,7 +24,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
public void testWriteByteDoesNotBlockUntilBufferIsFull() throws Exception { public void testWriteByteDoesNotBlockUntilBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
ConnectionWriter w = new PaddedConnectionWriter(e, mac); ConnectionWriter w = new PaddedConnectionWriter(e, mac, macKey);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean finished = new AtomicBoolean(false); final AtomicBoolean finished = new AtomicBoolean(false);
@@ -53,7 +53,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
public void testWriteByteBlocksWhenBufferIsFull() throws Exception { public void testWriteByteBlocksWhenBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac); PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac, macKey);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean finished = new AtomicBoolean(false); final AtomicBoolean finished = new AtomicBoolean(false);
@@ -87,7 +87,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
public void testWriteArrayDoesNotBlockUntilBufferIsFull() throws Exception { public void testWriteArrayDoesNotBlockUntilBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
ConnectionWriter w = new PaddedConnectionWriter(e, mac); ConnectionWriter w = new PaddedConnectionWriter(e, mac, macKey);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean finished = new AtomicBoolean(false); final AtomicBoolean finished = new AtomicBoolean(false);
@@ -116,7 +116,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
public void testWriteArrayBlocksWhenBufferIsFull() throws Exception { public void testWriteArrayBlocksWhenBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac); PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac, macKey);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean finished = new AtomicBoolean(false); final AtomicBoolean finished = new AtomicBoolean(false);
@@ -150,7 +150,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
public void testWriteFullFrameInsertsPadding() throws Exception { public void testWriteFullFrameInsertsPadding() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac); PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac, macKey);
w.getOutputStream().write(0); w.getOutputStream().write(0);
w.writeFullFrame(); w.writeFullFrame();
// A full frame should have been written // A full frame should have been written

View File

@@ -3,18 +3,20 @@ package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import javax.crypto.Mac; import javax.crypto.Mac;
import javax.crypto.SecretKey;
import junit.framework.TestCase;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Guice; import com.google.inject.Guice;
import com.google.inject.Injector; import com.google.inject.Injector;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.util.ByteUtils;
import junit.framework.TestCase;
public abstract class TransportTest extends TestCase { public abstract class TransportTest extends TestCase {
protected final Mac mac; protected final Mac mac;
protected final SecretKey macKey;
protected final int headerLength = 4, macLength, maxPayloadLength; protected final int headerLength = 4, macLength, maxPayloadLength;
public TransportTest() throws Exception { public TransportTest() throws Exception {
@@ -22,7 +24,7 @@ public abstract class TransportTest extends TestCase {
Injector i = Guice.createInjector(new CryptoModule()); Injector i = Guice.createInjector(new CryptoModule());
CryptoComponent crypto = i.getInstance(CryptoComponent.class); CryptoComponent crypto = i.getInstance(CryptoComponent.class);
mac = crypto.getMac(); mac = crypto.getMac();
mac.init(crypto.generateSecretKey()); macKey = crypto.generateSecretKey();
macLength = mac.getMacLength(); macLength = mac.getMacLength();
maxPayloadLength = MAX_FRAME_LENGTH - headerLength - macLength; maxPayloadLength = MAX_FRAME_LENGTH - headerLength - macLength;
} }

View File

@@ -1,38 +1,24 @@
package net.sf.briar.transport.batch; package net.sf.briar.transport.batch;
import java.io.FilterInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import net.sf.briar.api.transport.batch.BatchTransportReader; import net.sf.briar.api.transport.batch.BatchTransportReader;
class TestBatchTransportReader extends FilterInputStream class TestBatchTransportReader implements BatchTransportReader {
implements BatchTransportReader {
private final InputStream in;
TestBatchTransportReader(InputStream in) { TestBatchTransportReader(InputStream in) {
super(in); this.in = in;
} }
public InputStream getInputStream() { public InputStream getInputStream() {
return this; return in;
} }
public void dispose() throws IOException { public void dispose() throws IOException {
// Nothing to do // The input stream may have been left open
} in.close();
@Override
public int read() throws IOException {
return in.read();
}
@Override
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
return in.read(b, off, len);
} }
} }

View File

@@ -1,18 +1,17 @@
package net.sf.briar.transport.batch; package net.sf.briar.transport.batch;
import java.io.FilterOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import net.sf.briar.api.transport.batch.BatchTransportWriter; import net.sf.briar.api.transport.batch.BatchTransportWriter;
class TestBatchTransportWriter extends FilterOutputStream class TestBatchTransportWriter implements BatchTransportWriter {
implements BatchTransportWriter {
private int capacity; private final OutputStream out;
private final int capacity;
TestBatchTransportWriter(OutputStream out, int capacity) { TestBatchTransportWriter(OutputStream out, int capacity) {
super(out); this.out = out;
this.capacity = capacity; this.capacity = capacity;
} }
@@ -21,29 +20,11 @@ implements BatchTransportWriter {
} }
public OutputStream getOutputStream() { public OutputStream getOutputStream() {
return this; return out;
} }
public void dispose() throws IOException { public void dispose() throws IOException {
// Nothing to do // The output stream may have been left open
} out.close();
@Override
public void write(int b) throws IOException {
if(capacity < 1) throw new IllegalArgumentException();
out.write(b);
capacity--;
}
@Override
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
if(len > capacity) throw new IllegalArgumentException();
out.write(b, off, len);
capacity -= len;
} }
} }