Added a flag to indicate the last frame of the connection.

This commit is contained in:
akwizgran
2012-02-07 10:46:01 +00:00
parent e8660c13ca
commit 4ca5be7c06
13 changed files with 117 additions and 45 deletions

View File

@@ -4,8 +4,6 @@ import java.io.IOException;
public interface ProtocolWriter {
void flush() throws IOException;
int getMaxBatchesForAck(long capacity);
int getMaxMessagesForOffer(long capacity);
@@ -23,4 +21,8 @@ public interface ProtocolWriter {
void writeSubscriptionUpdate(SubscriptionUpdate s) throws IOException;
void writeTransportUpdate(TransportUpdate t) throws IOException;
void flush() throws IOException;
void close() throws IOException;
}

View File

@@ -9,7 +9,7 @@ public interface TransportConstants {
static final int MAX_FRAME_LENGTH = 65536; // 2^16, 64 KiB
/** The length of the frame header in bytes. */
static final int FRAME_HEADER_LENGTH = 8;
static final int FRAME_HEADER_LENGTH = 9;
/** The length of the MAC in bytes. */
static final int MAC_LENGTH = 32;

View File

@@ -147,4 +147,8 @@ class ProtocolWriterImpl implements ProtocolWriter {
public void flush() throws IOException {
out.flush();
}
public void close() throws IOException {
out.close();
}
}

View File

@@ -215,6 +215,7 @@ abstract class DuplexConnection implements DatabaseListener {
task.run();
}
writer.flush();
writer.close();
if(!disposed.getAndSet(true)) transport.dispose(false, true);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());

View File

@@ -99,6 +99,7 @@ class OutgoingSimplexConnection {
b = db.generateBatch(contactId, (int) capacity);
}
writer.flush();
writer.close();
transport.dispose(false);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());

View File

@@ -5,6 +5,7 @@ import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
import java.io.IOException;
import java.io.InputStream;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.transport.ConnectionReader;
class ConnectionReaderImpl extends InputStream implements ConnectionReader {
@@ -52,15 +53,14 @@ class ConnectionReaderImpl extends InputStream implements ConnectionReader {
private boolean readFrame() throws IOException {
assert length == 0;
while(true) {
frame.reset();
if(!in.readFrame(frame)) {
length = -1;
return false;
}
offset = FRAME_HEADER_LENGTH;
length = HeaderEncoder.getPayloadLength(frame.getBuffer());
return true;
if(HeaderEncoder.isLastFrame(frame.getBuffer())) {
length = -1;
return false;
}
frame.reset();
if(!in.readFrame(frame)) throw new FormatException();
offset = FRAME_HEADER_LENGTH;
length = HeaderEncoder.getPayloadLength(frame.getBuffer());
return true;
}
}

View File

@@ -45,16 +45,23 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
return Math.max(0L, capacity - frames * overheadPerFrame);
}
@Override
public void close() throws IOException {
if(offset > FRAME_HEADER_LENGTH || frameNumber > 0L) writeFrame(true);
out.flush();
super.close();
}
@Override
public void flush() throws IOException {
if(offset > FRAME_HEADER_LENGTH) writeFrame();
if(offset > FRAME_HEADER_LENGTH) writeFrame(false);
out.flush();
}
@Override
public void write(int b) throws IOException {
frame.getBuffer()[offset++] = (byte) b;
if(offset + MAC_LENGTH == MAX_FRAME_LENGTH) writeFrame();
if(offset + MAC_LENGTH == MAX_FRAME_LENGTH) writeFrame(false);
}
@Override
@@ -69,7 +76,7 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
while(available <= len) {
System.arraycopy(b, off, buf, offset, available);
offset += available;
writeFrame();
writeFrame(false);
off += available;
len -= available;
available = MAX_FRAME_LENGTH - offset - MAC_LENGTH;
@@ -78,11 +85,12 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
offset += len;
}
private void writeFrame() throws IOException {
private void writeFrame(boolean lastFrame) throws IOException {
if(frameNumber > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
int payload = offset - FRAME_HEADER_LENGTH;
assert payload > 0;
HeaderEncoder.encodeHeader(frame.getBuffer(), frameNumber, payload, 0);
assert payload >= 0;
HeaderEncoder.encodeHeader(frame.getBuffer(), frameNumber, payload, 0,
lastFrame);
frame.setLength(offset + MAC_LENGTH);
out.writeFrame(frame);
frame.reset();

View File

@@ -1,12 +1,13 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import net.sf.briar.util.ByteUtils;
class HeaderEncoder {
static void encodeHeader(byte[] header, long frameNumber, int payload,
int padding) {
int padding, boolean lastFrame) {
if(header.length < FRAME_HEADER_LENGTH)
throw new IllegalArgumentException();
if(frameNumber < 0 || frameNumber > ByteUtils.MAX_32_BIT_UNSIGNED)
@@ -18,6 +19,18 @@ class HeaderEncoder {
ByteUtils.writeUint32(frameNumber, header, 0);
ByteUtils.writeUint16(payload, header, 4);
ByteUtils.writeUint16(padding, header, 6);
if(lastFrame) header[8] = 1;
}
static boolean checkHeader(byte[] header, int length) {
if(header.length < FRAME_HEADER_LENGTH)
throw new IllegalArgumentException();
int payload = getPayloadLength(header);
int padding = getPaddingLength(header);
if(FRAME_HEADER_LENGTH + payload + padding + MAC_LENGTH != length)
return false;
if(header[8] != 0 && header[8] != 1) return false;
return true;
}
static long getFrameNumber(byte[] header) {
@@ -37,4 +50,10 @@ class HeaderEncoder {
throw new IllegalArgumentException();
return ByteUtils.readUint16(header, 6);
}
static boolean isLastFrame(byte[] header) {
if(header.length < FRAME_HEADER_LENGTH)
throw new IllegalArgumentException();
return header[8] == 1;
}
}

View File

@@ -39,13 +39,12 @@ class IncomingAuthenticationLayerImpl implements FrameReader {
if(length < FRAME_HEADER_LENGTH + MAC_LENGTH)
throw new FormatException();
if(length > MAX_FRAME_LENGTH) throw new FormatException();
// Check that the payload and padding lengths are correct
// Check that the header fields are legal and match the length
byte[] buf = f.getBuffer();
if(!HeaderEncoder.checkHeader(buf, length)) throw new FormatException();
// Check that the padding is all zeroes
int payload = HeaderEncoder.getPayloadLength(buf);
int padding = HeaderEncoder.getPaddingLength(buf);
if(length != FRAME_HEADER_LENGTH + payload + padding + MAC_LENGTH)
throw new FormatException();
// Check that the padding is all zeroes
int paddingStart = FRAME_HEADER_LENGTH + payload;
for(int i = paddingStart; i < paddingStart + padding; i++) {
if(buf[i] != 0) throw new FormatException();

View File

@@ -116,11 +116,11 @@ public class SimplexConnectionReadWriteTest extends BriarTestCase {
alice.getInstance(ProtocolWriterFactory.class);
TestSimplexTransportWriter transport = new TestSimplexTransportWriter(out,
Long.MAX_VALUE, false);
OutgoingSimplexConnection batchOut = new OutgoingSimplexConnection(db,
OutgoingSimplexConnection simplex = new OutgoingSimplexConnection(db,
connRegistry, connFactory, protoFactory, contactId, transportId,
transportIndex, transport);
// Write whatever needs to be written
batchOut.write();
simplex.write();
assertTrue(transport.getDisposed());
assertFalse(transport.getException());
// Close Alice's database
@@ -171,14 +171,14 @@ public class SimplexConnectionReadWriteTest extends BriarTestCase {
ProtocolReaderFactory protoFactory =
bob.getInstance(ProtocolReaderFactory.class);
TestSimplexTransportReader transport = new TestSimplexTransportReader(in);
IncomingSimplexConnection batchIn = new IncomingSimplexConnection(
IncomingSimplexConnection simplex = new IncomingSimplexConnection(
new ImmediateExecutor(), new ImmediateExecutor(), db,
connRegistry, connFactory, protoFactory, ctx, transportId,
transport);
// No messages should have been added yet
assertFalse(listener.messagesAdded);
// Read whatever needs to be read
batchIn.read();
simplex.read();
assertTrue(transport.getDisposed());
assertFalse(transport.getException());
assertTrue(transport.getRecognised());

View File

@@ -26,7 +26,7 @@ public class ConnectionReaderImplTest extends TransportTest {
int payloadLength = 0;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, true);
// Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);
@@ -43,7 +43,7 @@ public class ConnectionReaderImplTest extends TransportTest {
int payloadLength = 1;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, true);
// Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);
@@ -60,13 +60,13 @@ public class ConnectionReaderImplTest extends TransportTest {
public void testMaxLength() throws Exception {
// First frame: max payload length
byte[] frame = new byte[MAX_FRAME_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, MAX_PAYLOAD_LENGTH, 0);
HeaderEncoder.encodeHeader(frame, 0, MAX_PAYLOAD_LENGTH, 0, false);
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH);
mac.doFinal(frame, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH);
// Second frame: max payload length plus one
byte[] frame1 = new byte[MAX_FRAME_LENGTH + 1];
HeaderEncoder.encodeHeader(frame1, 1, MAX_PAYLOAD_LENGTH + 1, 0);
HeaderEncoder.encodeHeader(frame1, 1, MAX_PAYLOAD_LENGTH + 1, 0, false);
mac.update(frame1, 0, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH + 1);
mac.doFinal(frame1, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH + 1);
// Concatenate the frames
@@ -92,14 +92,14 @@ public class ConnectionReaderImplTest extends TransportTest {
// First frame: max payload length, including padding
byte[] frame = new byte[MAX_FRAME_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, MAX_PAYLOAD_LENGTH - paddingLength,
paddingLength);
paddingLength, false);
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH);
mac.doFinal(frame, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH);
// Second frame: max payload length plus one, including padding
byte[] frame1 = new byte[MAX_FRAME_LENGTH + 1];
HeaderEncoder.encodeHeader(frame1, 1,
MAX_PAYLOAD_LENGTH + 1 - paddingLength, paddingLength);
MAX_PAYLOAD_LENGTH + 1 - paddingLength, paddingLength, false);
mac.update(frame1, 0, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH + 1);
mac.doFinal(frame1, FRAME_HEADER_LENGTH + MAX_PAYLOAD_LENGTH + 1);
// Concatenate the frames
@@ -124,7 +124,8 @@ public class ConnectionReaderImplTest extends TransportTest {
int payloadLength = 10, paddingLength = 10;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ paddingLength + MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, paddingLength);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, paddingLength,
false);
// Set a byte of the padding to a non-zero value
frame[FRAME_HEADER_LENGTH + payloadLength] = 1;
mac.init(macKey);
@@ -147,7 +148,7 @@ public class ConnectionReaderImplTest extends TransportTest {
int payloadLength = 123;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, false);
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);
mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength);
@@ -155,7 +156,7 @@ public class ConnectionReaderImplTest extends TransportTest {
int payloadLength1 = 1234;
byte[] frame1 = new byte[FRAME_HEADER_LENGTH + payloadLength1
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame1, 1, payloadLength1, 0);
HeaderEncoder.encodeHeader(frame1, 1, payloadLength1, 0, true);
mac.update(frame1, 0, FRAME_HEADER_LENGTH + payloadLength1);
mac.doFinal(frame1, FRAME_HEADER_LENGTH + payloadLength1);
// Concatenate the frames
@@ -171,6 +172,43 @@ public class ConnectionReaderImplTest extends TransportTest {
byte[] read1 = new byte[payloadLength1];
TestUtils.readFully(r.getInputStream(), read1);
assertArrayEquals(new byte[payloadLength1], read1);
assertEquals(-1, r.getInputStream().read());
}
@Test
public void testLastFrameNotMarkedAsSuch() throws Exception {
// First frame: 123-byte payload
int payloadLength = 123;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, false);
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);
mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength);
// Second frame: 1234-byte payload
int payloadLength1 = 1234;
byte[] frame1 = new byte[FRAME_HEADER_LENGTH + payloadLength1
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame1, 1, payloadLength1, 0, false);
mac.update(frame1, 0, FRAME_HEADER_LENGTH + payloadLength1);
mac.doFinal(frame1, FRAME_HEADER_LENGTH + payloadLength1);
// Concatenate the frames
ByteArrayOutputStream out = new ByteArrayOutputStream();
out.write(frame);
out.write(frame1);
// Read the frames
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ConnectionReader r = createConnectionReader(in);
byte[] read = new byte[payloadLength];
TestUtils.readFully(r.getInputStream(), read);
assertArrayEquals(new byte[payloadLength], read);
byte[] read1 = new byte[payloadLength1];
TestUtils.readFully(r.getInputStream(), read1);
assertArrayEquals(new byte[payloadLength1], read1);
try {
r.getInputStream().read();
fail();
} catch(FormatException expected) {}
}
@Test
@@ -178,7 +216,7 @@ public class ConnectionReaderImplTest extends TransportTest {
int payloadLength = 8;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, false);
// Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);
@@ -199,7 +237,7 @@ public class ConnectionReaderImplTest extends TransportTest {
int payloadLength = 8;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, false);
// Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);

View File

@@ -33,7 +33,7 @@ public class ConnectionWriterImplTest extends TransportTest {
int payloadLength = 1;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, false);
// Calculate the MAC
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);
@@ -78,7 +78,7 @@ public class ConnectionWriterImplTest extends TransportTest {
int payloadLength = 123;
byte[] frame = new byte[FRAME_HEADER_LENGTH + payloadLength
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0);
HeaderEncoder.encodeHeader(frame, 0, payloadLength, 0, false);
mac.init(macKey);
mac.update(frame, 0, FRAME_HEADER_LENGTH + payloadLength);
mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength);
@@ -86,7 +86,7 @@ public class ConnectionWriterImplTest extends TransportTest {
int payloadLength1 = 1234;
byte[] frame1 = new byte[FRAME_HEADER_LENGTH + payloadLength1
+ MAC_LENGTH];
HeaderEncoder.encodeHeader(frame1, 1, payloadLength1, 0);
HeaderEncoder.encodeHeader(frame1, 1, payloadLength1, 0, false);
mac.update(frame1, 0, FRAME_HEADER_LENGTH + 1234);
mac.doFinal(frame1, FRAME_HEADER_LENGTH + 1234);
// Concatenate the frames

View File

@@ -42,14 +42,14 @@ public class IncomingEncryptionLayerImplTest extends BriarTestCase {
TagEncoder.encodeTag(tag, tagCipher, tagKey);
// Calculate the ciphertext for the first frame
byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH];
HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0);
HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0, false);
byte[] iv = IvEncoder.encodeIv(0L, frameCipher.getBlockSize());
IvParameterSpec ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
byte[] ciphertext = frameCipher.doFinal(plaintext, 0, plaintext.length);
// Calculate the ciphertext for the second frame
byte[] plaintext1 = new byte[FRAME_HEADER_LENGTH + 1234 + MAC_LENGTH];
HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0);
HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0, false);
IvEncoder.updateIv(iv, 1L);
ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
@@ -87,14 +87,14 @@ public class IncomingEncryptionLayerImplTest extends BriarTestCase {
public void testDecryptionWithoutTag() throws Exception {
// Calculate the ciphertext for the first frame
byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH];
HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0);
HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0, false);
byte[] iv = IvEncoder.encodeIv(0L, frameCipher.getBlockSize());
IvParameterSpec ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
byte[] ciphertext = frameCipher.doFinal(plaintext, 0, plaintext.length);
// Calculate the ciphertext for the second frame
byte[] plaintext1 = new byte[FRAME_HEADER_LENGTH + 1234 + MAC_LENGTH];
HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0);
HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0, false);
IvEncoder.updateIv(iv, 1L);
ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);