Refactored transport component and renamed WritersModule.

The goal of the refactoring was to clean up the dependencies of
IncomingBatchConnection and OutgoingBatchConnection.
This commit is contained in:
akwizgran
2011-09-27 19:21:44 +01:00
parent 6ed8d89e59
commit 4aff0c4f88
21 changed files with 105 additions and 161 deletions

View File

@@ -11,10 +11,6 @@ public interface ConnectionWriter {
*/ */
OutputStream getOutputStream(); OutputStream getOutputStream();
/** /** Returns the maximum number of bytes that can be written. */
* Returns the number of bytes that can be written to this writer without long getCapacity();
* outputting more than the given number of bytes, including encryption and
* authentication overhead.
*/
long getCapacity(long capacity);
} }

View File

@@ -4,6 +4,7 @@ import java.io.OutputStream;
public interface ConnectionWriterFactory { public interface ConnectionWriterFactory {
ConnectionWriter createConnectionWriter(OutputStream out, boolean initiator, ConnectionWriter createConnectionWriter(OutputStream out,
int transportId, long connection, byte[] secret); long capacity, boolean initiator, int transportId, long connection,
byte[] secret);
} }

View File

@@ -10,7 +10,7 @@ import java.io.InputStream;
public interface BatchTransportReader { public interface BatchTransportReader {
/** Returns an input stream for reading from the transport. */ /** Returns an input stream for reading from the transport. */
InputStream getInputStream() throws IOException; InputStream getInputStream();
/** /**
* Closes the reader and disposes of any associated state. This method must * Closes the reader and disposes of any associated state. This method must

View File

@@ -10,10 +10,10 @@ import java.io.OutputStream;
public interface BatchTransportWriter { public interface BatchTransportWriter {
/** Returns the maximum number of bytes that can be written. */ /** Returns the maximum number of bytes that can be written. */
long getCapacity() throws IOException; long getCapacity();
/** Returns an output stream for writing to the transport. */ /** Returns an output stream for writing to the transport. */
OutputStream getOutputStream() throws IOException; OutputStream getOutputStream();
/** /**
* Closes the writer and disposes of any associated state. This method must * Closes the writer and disposes of any associated state. This method must

View File

@@ -4,7 +4,7 @@ import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import com.google.inject.AbstractModule; import com.google.inject.AbstractModule;
public class WritersModule extends AbstractModule { public class ProtocolWritersModule extends AbstractModule {
@Override @Override
protected void configure() { protected void configure() {

View File

@@ -12,9 +12,6 @@ interface ConnectionEncrypter {
/** Encrypts and writes the MAC for the current frame. */ /** Encrypts and writes the MAC for the current frame. */
void writeMac(byte[] mac) throws IOException; void writeMac(byte[] mac) throws IOException;
/** /** Returns the maximum number of bytes that can be written. */
* Returns the number of bytes that can be encrypted without outputting long getCapacity();
* more than the given number of bytes, including encryption overhead.
*/
long getCapacity(long capacity);
} }

View File

@@ -22,10 +22,10 @@ implements ConnectionEncrypter {
private final SecretKey frameKey; private final SecretKey frameKey;
private final byte[] iv; private final byte[] iv;
private long frame = 0L; private long capacity, frame = 0L;
private boolean ivWritten = false, betweenFrames = false; private boolean ivWritten = false, betweenFrames = false;
ConnectionEncrypterImpl(OutputStream out, boolean initiator, ConnectionEncrypterImpl(OutputStream out, long capacity, boolean initiator,
int transportId, long connection, Cipher ivCipher, int transportId, long connection, Cipher ivCipher,
Cipher frameCipher, SecretKey ivKey, SecretKey frameKey) { Cipher frameCipher, SecretKey ivKey, SecretKey frameKey) {
super(out); super(out);
@@ -40,6 +40,7 @@ implements ConnectionEncrypter {
} }
if(ivCipher.getOutputSize(IV_LENGTH) != IV_LENGTH) if(ivCipher.getOutputSize(IV_LENGTH) != IV_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
this.capacity = capacity;
} }
public OutputStream getOutputStream() { public OutputStream getOutputStream() {
@@ -55,12 +56,12 @@ implements ConnectionEncrypter {
} catch(IllegalBlockSizeException badCipher) { } catch(IllegalBlockSizeException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} }
capacity -= mac.length;
betweenFrames = true; betweenFrames = true;
} }
public long getCapacity(long capacity) { public long getCapacity() {
if(capacity < 0L) throw new IllegalArgumentException(); return capacity;
return ivWritten ? capacity : Math.max(0L, capacity - IV_LENGTH);
} }
@Override @Override
@@ -69,6 +70,7 @@ implements ConnectionEncrypter {
if(betweenFrames) initialiseCipher(); if(betweenFrames) initialiseCipher();
byte[] ciphertext = frameCipher.update(new byte[] {(byte) b}); byte[] ciphertext = frameCipher.update(new byte[] {(byte) b});
if(ciphertext != null) out.write(ciphertext); if(ciphertext != null) out.write(ciphertext);
capacity--;
} }
@Override @Override
@@ -82,6 +84,7 @@ implements ConnectionEncrypter {
if(betweenFrames) initialiseCipher(); if(betweenFrames) initialiseCipher();
byte[] ciphertext = frameCipher.update(b, off, len); byte[] ciphertext = frameCipher.update(b, off, len);
if(ciphertext != null) out.write(ciphertext); if(ciphertext != null) out.write(ciphertext);
capacity -= len;
} }
private void writeIv() throws IOException { private void writeIv() throws IOException {
@@ -94,6 +97,7 @@ implements ConnectionEncrypter {
} catch(IllegalBlockSizeException badCipher) { } catch(IllegalBlockSizeException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} }
capacity -= iv.length;
ivWritten = true; ivWritten = true;
betweenFrames = true; betweenFrames = true;
} }

View File

@@ -23,7 +23,7 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
} }
public ConnectionWriter createConnectionWriter(OutputStream out, public ConnectionWriter createConnectionWriter(OutputStream out,
boolean initiator, int transportId, long connection, long capacity, boolean initiator, int transportId, long connection,
byte[] secret) { byte[] secret) {
SecretKey macKey = crypto.deriveOutgoingMacKey(secret); SecretKey macKey = crypto.deriveOutgoingMacKey(secret);
SecretKey ivKey = crypto.deriveOutgoingIvKey(secret); SecretKey ivKey = crypto.deriveOutgoingIvKey(secret);
@@ -37,8 +37,8 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
throw new IllegalArgumentException(badKey); throw new IllegalArgumentException(badKey);
} }
ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out, ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out,
initiator, transportId, connection, ivCipher, frameCipher, capacity, initiator, transportId, connection, ivCipher,
ivKey, frameKey); frameCipher, ivKey, frameKey);
return new ConnectionWriterImpl(encrypter, mac); return new ConnectionWriterImpl(encrypter, mac);
} }
} }

View File

@@ -41,10 +41,8 @@ implements ConnectionWriter {
return this; return this;
} }
public long getCapacity(long capacity) { public long getCapacity() {
if(capacity < 0L) throw new IllegalArgumentException(); long capacity = encrypter.getCapacity();
// Subtract the encryption overhead
capacity = encrypter.getCapacity(capacity);
// If there's any data buffered, subtract it and its auth overhead // If there's any data buffered, subtract it and its auth overhead
int overheadPerFrame = header.length + mac.getMacLength(); int overheadPerFrame = header.length + mac.getMacLength();
if(buf.size() > 0) capacity -= buf.size() + overheadPerFrame; if(buf.size() > 0) capacity -= buf.size() + overheadPerFrame;

View File

@@ -14,36 +14,23 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.batch.BatchTransportReader;
class IncomingBatchConnection { class IncomingBatchConnection {
private final BatchTransportReader trans; private final ConnectionReader conn;
private final ConnectionReaderFactory connFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
private final ProtocolReaderFactory protoFactory; private final ProtocolReaderFactory protoFactory;
private final int transportId;
private final long connection;
private final ContactId contactId; private final ContactId contactId;
IncomingBatchConnection(BatchTransportReader trans, IncomingBatchConnection(ConnectionReader conn, DatabaseComponent db,
ConnectionReaderFactory connFactory, DatabaseComponent db, ProtocolReaderFactory protoFactory, ContactId contactId) {
ProtocolReaderFactory protoFactory, int transportId, this.conn = conn;
long connection, ContactId contactId) {
this.trans = trans;
this.connFactory = connFactory;
this.db = db; this.db = db;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.transportId = transportId;
this.connection = connection;
this.contactId = contactId; this.contactId = contactId;
} }
void read() throws DbException, IOException { void read() throws DbException, IOException {
byte[] secret = db.getSharedSecret(contactId);
ConnectionReader conn = connFactory.createConnectionReader(
trans.getInputStream(), false, transportId, connection, secret);
InputStream in = conn.getInputStream(); InputStream in = conn.getInputStream();
ProtocolReader proto = protoFactory.createProtocolReader(in); ProtocolReader proto = protoFactory.createProtocolReader(in);
// Read packets until EOF // Read packets until EOF

View File

@@ -14,45 +14,32 @@ import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.protocol.writers.SubscriptionWriter; import net.sf.briar.api.protocol.writers.SubscriptionWriter;
import net.sf.briar.api.protocol.writers.TransportWriter; import net.sf.briar.api.protocol.writers.TransportWriter;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.batch.BatchTransportWriter;
class OutgoingBatchConnection { class OutgoingBatchConnection {
private final BatchTransportWriter trans; private final ConnectionWriter conn;
private final ConnectionWriterFactory connFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
private final ProtocolWriterFactory protoFactory; private final ProtocolWriterFactory protoFactory;
private final int transportId;
private final long connection;
private final ContactId contactId; private final ContactId contactId;
OutgoingBatchConnection(BatchTransportWriter trans, OutgoingBatchConnection(ConnectionWriter conn, DatabaseComponent db,
ConnectionWriterFactory connFactory, DatabaseComponent db, ProtocolWriterFactory protoFactory, ContactId contactId) {
ProtocolWriterFactory protoFactory, int transportId, this.conn = conn;
long connection, ContactId contactId) {
this.trans = trans;
this.connFactory = connFactory;
this.db = db; this.db = db;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.transportId = transportId;
this.connection = connection;
this.contactId = contactId; this.contactId = contactId;
} }
void write() throws DbException, IOException { void write() throws DbException, IOException {
byte[] secret = db.getSharedSecret(contactId);
ConnectionWriter conn = connFactory.createConnectionWriter(
trans.getOutputStream(), true, transportId, connection, secret);
OutputStream out = conn.getOutputStream(); OutputStream out = conn.getOutputStream();
// There should be enough space for a packet // There should be enough space for a packet
long capacity = conn.getCapacity(trans.getCapacity()); long capacity = conn.getCapacity();
if(capacity < MAX_PACKET_LENGTH) throw new IOException(); if(capacity < MAX_PACKET_LENGTH) throw new IOException();
// Write a transport update // Write a transport update
TransportWriter t = protoFactory.createTransportWriter(out); TransportWriter t = protoFactory.createTransportWriter(out);
db.generateTransportUpdate(contactId, t); db.generateTransportUpdate(contactId, t);
// If there's space, write a subscription update // If there's space, write a subscription update
capacity = conn.getCapacity(trans.getCapacity()); capacity = conn.getCapacity();
if(capacity >= MAX_PACKET_LENGTH) { if(capacity >= MAX_PACKET_LENGTH) {
SubscriptionWriter s = protoFactory.createSubscriptionWriter(out); SubscriptionWriter s = protoFactory.createSubscriptionWriter(out);
db.generateSubscriptionUpdate(contactId, s); db.generateSubscriptionUpdate(contactId, s);
@@ -60,14 +47,14 @@ class OutgoingBatchConnection {
// Write acks until you can't write acks no more // Write acks until you can't write acks no more
AckWriter a = protoFactory.createAckWriter(out); AckWriter a = protoFactory.createAckWriter(out);
do { do {
capacity = conn.getCapacity(trans.getCapacity()); capacity = conn.getCapacity();
int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); int max = (int) Math.min(MAX_PACKET_LENGTH, capacity);
a.setMaxPacketLength(max); a.setMaxPacketLength(max);
} while(db.generateAck(contactId, a)); } while(db.generateAck(contactId, a));
// Write batches until you can't write batches no more // Write batches until you can't write batches no more
BatchWriter b = protoFactory.createBatchWriter(out); BatchWriter b = protoFactory.createBatchWriter(out);
do { do {
capacity = conn.getCapacity(trans.getCapacity()); capacity = conn.getCapacity();
int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); int max = (int) Math.min(MAX_PACKET_LENGTH, capacity);
b.setMaxPacketLength(max); b.setMaxPacketLength(max);
} while(db.generateBatch(contactId, b)); } while(db.generateBatch(contactId, b));

View File

@@ -46,7 +46,7 @@ import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.writers.WritersModule; import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import net.sf.briar.transport.TransportModule; import net.sf.briar.transport.TransportModule;
@@ -83,8 +83,8 @@ public class FileReadWriteTest extends TestCase {
public FileReadWriteTest() throws Exception { public FileReadWriteTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule(), new TransportModule(), new ProtocolModule(), new ProtocolWritersModule(),
new WritersModule()); new SerialModule(), new TransportModule());
connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class); connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class);
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class); protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class);
@@ -132,7 +132,7 @@ public class FileReadWriteTest extends TestCase {
OutputStream out = new FileOutputStream(file); OutputStream out = new FileOutputStream(file);
// Use Alice's secret for writing // Use Alice's secret for writing
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
true, transportId, connection, aliceSecret); Long.MAX_VALUE, true, transportId, connection, aliceSecret);
out = w.getOutputStream(); out = w.getOutputStream();
AckWriter a = protocolWriterFactory.createAckWriter(out); AckWriter a = protocolWriterFactory.createAckWriter(out);

View File

@@ -29,7 +29,7 @@ import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionWriter; import net.sf.briar.api.protocol.writers.SubscriptionWriter;
import net.sf.briar.api.protocol.writers.TransportWriter; import net.sf.briar.api.protocol.writers.TransportWriter;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.writers.WritersModule; import net.sf.briar.protocol.writers.ProtocolWritersModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import org.junit.Test; import org.junit.Test;
@@ -53,7 +53,8 @@ public class ProtocolReadWriteTest extends TestCase {
public ProtocolReadWriteTest() throws Exception { public ProtocolReadWriteTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new CryptoModule(), Injector i = Guice.createInjector(new CryptoModule(),
new ProtocolModule(), new SerialModule(), new WritersModule()); new ProtocolModule(), new ProtocolWritersModule(),
new SerialModule());
readerFactory = i.getInstance(ProtocolReaderFactory.class); readerFactory = i.getInstance(ProtocolReaderFactory.class);
writerFactory = i.getInstance(ProtocolWriterFactory.class); writerFactory = i.getInstance(ProtocolWriterFactory.class);
batchId = new BatchId(TestUtils.getRandomId()); batchId = new BatchId(TestUtils.getRandomId());

View File

@@ -40,12 +40,14 @@ public class ConnectionEncrypterImplTest extends TestCase {
@Test @Test
public void testSingleByteFrame() throws Exception { public void testSingleByteFrame() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new ConnectionEncrypterImpl(out, true, ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE,
transportId, connection, ivCipher, frameCipher, ivKey, true, transportId, connection, ivCipher, frameCipher, ivKey,
frameKey); frameKey);
e.getOutputStream().write((byte) 0); e.getOutputStream().write((byte) 0);
e.writeMac(new byte[MAC_LENGTH]); e.writeMac(new byte[MAC_LENGTH]);
assertEquals(IV_LENGTH + 1 + MAC_LENGTH, out.toByteArray().length); byte[] ciphertext = out.toByteArray();
assertEquals(IV_LENGTH + 1 + MAC_LENGTH, ciphertext.length);
assertEquals(Long.MAX_VALUE - ciphertext.length, e.getCapacity());
} }
@Test @Test
@@ -93,9 +95,9 @@ public class ConnectionEncrypterImplTest extends TestCase {
byte[] expected = out.toByteArray(); byte[] expected = out.toByteArray();
// Use a ConnectionEncrypter to encrypt the plaintext // Use a ConnectionEncrypter to encrypt the plaintext
out.reset(); out.reset();
ConnectionEncrypter e = new ConnectionEncrypterImpl(out, initiator, ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE,
transportId, connection, ivCipher, frameCipher, ivKey, initiator, transportId, connection, ivCipher, frameCipher,
frameKey); ivKey, frameKey);
e.getOutputStream().write(plaintext); e.getOutputStream().write(plaintext);
e.writeMac(plaintextMac); e.writeMac(plaintextMac);
e.getOutputStream().write(plaintext1); e.getOutputStream().write(plaintext1);
@@ -103,5 +105,6 @@ public class ConnectionEncrypterImplTest extends TestCase {
byte[] actual = out.toByteArray(); byte[] actual = out.toByteArray();
// Check that the actual ciphertext matches the expected ciphertext // Check that the actual ciphertext matches the expected ciphertext
assertTrue(Arrays.equals(expected, actual)); assertTrue(Arrays.equals(expected, actual));
assertEquals(Long.MAX_VALUE - actual.length, e.getCapacity());
} }
} }

View File

@@ -100,32 +100,4 @@ public class ConnectionWriterImplTest extends TransportTest {
byte[] actual = out.toByteArray(); byte[] actual = out.toByteArray();
assertTrue(Arrays.equals(expected, actual)); assertTrue(Arrays.equals(expected, actual));
} }
@Test
public void testGetCapacity() throws Exception {
int overheadPerFrame = headerLength + macLength;
ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out);
ConnectionWriterImpl w = new ConnectionWriterImpl(e, mac);
// Full frame
long capacity = w.getCapacity(MAX_FRAME_LENGTH);
assertEquals(MAX_FRAME_LENGTH - overheadPerFrame, capacity);
// Partial frame
capacity = w.getCapacity(overheadPerFrame + 1);
assertEquals(1, capacity);
// Full frame and partial frame
capacity = w.getCapacity(MAX_FRAME_LENGTH + 1);
assertEquals(MAX_FRAME_LENGTH + 1 - 2 * overheadPerFrame, capacity);
// Buffer some output
w.getOutputStream().write(0);
// Full frame minus buffered frame
capacity = w.getCapacity(MAX_FRAME_LENGTH);
assertEquals(MAX_FRAME_LENGTH - 1 - 2 * overheadPerFrame, capacity);
// Flush the buffer
w.flush();
assertEquals(1 + overheadPerFrame, out.size());
// Back to square one
capacity = w.getCapacity(MAX_FRAME_LENGTH);
assertEquals(MAX_FRAME_LENGTH - overheadPerFrame, capacity);
}
} }

View File

@@ -1,12 +1,13 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MIN_CONNECTION_LENGTH;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.TransportConstants;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import org.junit.Test; import org.junit.Test;
@@ -30,20 +31,20 @@ public class ConnectionWriterTest extends TestCase {
@Test @Test
public void testOverhead() throws Exception { public void testOverhead() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream( ByteArrayOutputStream out =
TransportConstants.MIN_CONNECTION_LENGTH); new ByteArrayOutputStream(MIN_CONNECTION_LENGTH);
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
true, transportId, connection, secret); MIN_CONNECTION_LENGTH, true, transportId, connection, secret);
// Check that the connection writer thinks there's room for a packet // Check that the connection writer thinks there's room for a packet
long capacity = w.getCapacity(TransportConstants.MIN_CONNECTION_LENGTH); long capacity = w.getCapacity();
assertTrue(capacity >= ProtocolConstants.MAX_PACKET_LENGTH); assertTrue(capacity >= MAX_PACKET_LENGTH);
assertTrue(capacity <= TransportConstants.MIN_CONNECTION_LENGTH); assertTrue(capacity <= MIN_CONNECTION_LENGTH);
// Check that there really is room for a packet // Check that there really is room for a packet
byte[] payload = new byte[ProtocolConstants.MAX_PACKET_LENGTH]; byte[] payload = new byte[MAX_PACKET_LENGTH];
w.getOutputStream().write(payload); w.getOutputStream().write(payload);
w.getOutputStream().flush(); w.getOutputStream().flush();
long used = out.size(); long used = out.size();
assertTrue(used >= ProtocolConstants.MAX_PACKET_LENGTH); assertTrue(used >= MAX_PACKET_LENGTH);
assertTrue(used <= TransportConstants.MIN_CONNECTION_LENGTH); assertTrue(used <= MIN_CONNECTION_LENGTH);
} }
} }

View File

@@ -74,8 +74,8 @@ public class FrameReadWriteTest extends TestCase {
// Write the frames // Write the frames
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out, ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out,
initiator, transportId, connection, ivCipher, frameCipher, Long.MAX_VALUE, initiator, transportId, connection, ivCipher,
ivKey, frameKey); frameCipher, ivKey, frameKey);
mac.init(macKey); mac.init(macKey);
ConnectionWriter writer = new ConnectionWriterImpl(encrypter, mac); ConnectionWriter writer = new ConnectionWriterImpl(encrypter, mac);
OutputStream out1 = writer.getOutputStream(); OutputStream out1 = writer.getOutputStream();

View File

@@ -1,26 +1,51 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.io.FilterOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
/** A ConnectionEncrypter that performs no encryption. */ /** A ConnectionEncrypter that performs no encryption. */
class NullConnectionEncrypter implements ConnectionEncrypter { class NullConnectionEncrypter extends FilterOutputStream
implements ConnectionEncrypter {
private final OutputStream out; private long capacity;
NullConnectionEncrypter(OutputStream out) { NullConnectionEncrypter(OutputStream out) {
this.out = out; this(out, Long.MAX_VALUE);
}
NullConnectionEncrypter(OutputStream out, long capacity) {
super(out);
this.capacity = capacity;
} }
public OutputStream getOutputStream() { public OutputStream getOutputStream() {
return out; return this;
} }
public void writeMac(byte[] mac) throws IOException { public void writeMac(byte[] mac) throws IOException {
out.write(mac); out.write(mac);
capacity -= mac.length;
} }
public long getCapacity(long capacity) { public long getCapacity() {
return capacity; return capacity;
} }
@Override
public void write(int b) throws IOException {
out.write(b);
capacity--;
}
@Override
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len);
capacity -= len;
}
} }

View File

@@ -23,7 +23,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
@Test @Test
public void testWriteByteDoesNotBlockUntilBufferIsFull() throws Exception { public void testWriteByteDoesNotBlockUntilBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
ConnectionWriter w = new PaddedConnectionWriter(e, mac); ConnectionWriter w = new PaddedConnectionWriter(e, mac);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
@@ -52,7 +52,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
@Test @Test
public void testWriteByteBlocksWhenBufferIsFull() throws Exception { public void testWriteByteBlocksWhenBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac); PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
@@ -86,7 +86,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
@Test @Test
public void testWriteArrayDoesNotBlockUntilBufferIsFull() throws Exception { public void testWriteArrayDoesNotBlockUntilBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
ConnectionWriter w = new PaddedConnectionWriter(e, mac); ConnectionWriter w = new PaddedConnectionWriter(e, mac);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
@@ -115,7 +115,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
@Test @Test
public void testWriteArrayBlocksWhenBufferIsFull() throws Exception { public void testWriteArrayBlocksWhenBufferIsFull() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac); PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac);
final OutputStream out1 = w.getOutputStream(); final OutputStream out1 = w.getOutputStream();
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
@@ -149,7 +149,7 @@ public class PaddedConnectionWriterTest extends TransportTest {
@Test @Test
public void testWriteFullFrameInsertsPadding() throws Exception { public void testWriteFullFrameInsertsPadding() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out); ConnectionEncrypter e = new NullConnectionEncrypter(out, Long.MAX_VALUE);
PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac); PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac);
w.getOutputStream().write(0); w.getOutputStream().write(0);
w.writeFullFrame(); w.writeFullFrame();
@@ -160,32 +160,4 @@ public class PaddedConnectionWriterTest extends TransportTest {
assertEquals(1, ByteUtils.readUint16(frame, 0)); // Payload length assertEquals(1, ByteUtils.readUint16(frame, 0)); // Payload length
assertEquals(maxPayloadLength - 1, ByteUtils.readUint16(frame, 2)); assertEquals(maxPayloadLength - 1, ByteUtils.readUint16(frame, 2));
} }
@Test
public void testGetCapacity() throws Exception {
int overheadPerFrame = headerLength + macLength;
ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionEncrypter e = new NullConnectionEncrypter(out);
PaddedConnectionWriter w = new PaddedConnectionWriter(e, mac);
// Full frame
long capacity = w.getCapacity(MAX_FRAME_LENGTH);
assertEquals(MAX_FRAME_LENGTH - overheadPerFrame, capacity);
// Partial frame
capacity = w.getCapacity(overheadPerFrame + 1);
assertEquals(1, capacity);
// Full frame and partial frame
capacity = w.getCapacity(MAX_FRAME_LENGTH + 1);
assertEquals(MAX_FRAME_LENGTH + 1 - 2 * overheadPerFrame, capacity);
// Buffer some output
w.getOutputStream().write(0);
// Full frame minus buffered frame
capacity = w.getCapacity(MAX_FRAME_LENGTH);
assertEquals(MAX_FRAME_LENGTH - 1 - 2 * overheadPerFrame, capacity);
// Flush the buffer
w.writeFullFrame();
assertEquals(MAX_FRAME_LENGTH, out.size());
// Back to square one
capacity = w.getCapacity(MAX_FRAME_LENGTH);
assertEquals(MAX_FRAME_LENGTH - overheadPerFrame, capacity);
}
} }

View File

@@ -13,7 +13,7 @@ implements BatchTransportReader {
super(in); super(in);
} }
public InputStream getInputStream() throws IOException { public InputStream getInputStream() {
return this; return this;
} }

View File

@@ -16,11 +16,11 @@ implements BatchTransportWriter {
this.capacity = capacity; this.capacity = capacity;
} }
public long getCapacity() throws IOException { public long getCapacity() {
return capacity; return capacity;
} }
public OutputStream getOutputStream() throws IOException { public OutputStream getOutputStream() {
return this; return this;
} }