Allow empty frames.

This commit is contained in:
akwizgran
2011-12-02 16:10:44 +00:00
parent c8338f9866
commit 4ab5dfcac0
3 changed files with 21 additions and 17 deletions

View File

@@ -54,7 +54,7 @@ implements ConnectionReader {
@Override @Override
public int read() throws IOException { public int read() throws IOException {
if(betweenFrames && !readFrame()) return -1; if(betweenFrames && !readNonEmptyFrame()) return -1;
int i = payload[payloadOff]; int i = payload[payloadOff];
payloadOff++; payloadOff++;
payloadLen--; payloadLen--;
@@ -69,7 +69,7 @@ implements ConnectionReader {
@Override @Override
public int read(byte[] b, int off, int len) throws IOException { public int read(byte[] b, int off, int len) throws IOException {
if(betweenFrames && !readFrame()) return -1; if(betweenFrames && !readNonEmptyFrame()) return -1;
len = Math.min(len, payloadLen); len = Math.min(len, payloadLen);
System.arraycopy(payload, payloadOff, b, off, len); System.arraycopy(payload, payloadOff, b, off, len);
payloadOff += len; payloadOff += len;
@@ -78,7 +78,15 @@ implements ConnectionReader {
return len; return len;
} }
private boolean readFrame() throws IOException { private boolean readNonEmptyFrame() throws IOException {
int payload = 0;
do {
payload = readFrame();
} while(payload == 0);
return payload > 0;
}
private int readFrame() throws IOException {
assert betweenFrames; assert betweenFrames;
// Don't allow more than 2^32 frames to be read // Don't allow more than 2^32 frames to be read
if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
@@ -89,7 +97,7 @@ implements ConnectionReader {
if(read == -1) break; if(read == -1) break;
offset += read; offset += read;
} }
if(offset == 0) return false; // EOF between frames if(offset == 0) return -1; // EOF between frames
if(offset < header.length) throw new EOFException(); // Unexpected EOF if(offset < header.length) throw new EOFException(); // Unexpected EOF
// Check that the frame number is correct and the length is legal // Check that the frame number is correct and the length is legal
if(!HeaderEncoder.validateHeader(header, frame, maxPayloadLength)) if(!HeaderEncoder.validateHeader(header, frame, maxPayloadLength))
@@ -122,8 +130,8 @@ implements ConnectionReader {
byte[] expectedMac = mac.doFinal(); byte[] expectedMac = mac.doFinal();
decrypter.readMac(footer); decrypter.readMac(footer);
if(!Arrays.equals(expectedMac, footer)) throw new FormatException(); if(!Arrays.equals(expectedMac, footer)) throw new FormatException();
betweenFrames = false;
frame++; frame++;
return true; if(payloadLen > 0) betweenFrames = false;
return payloadLen;
} }
} }

View File

@@ -1,13 +1,13 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import net.sf.briar.api.transport.TransportConstants; import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
class HeaderEncoder { class HeaderEncoder {
static void encodeHeader(byte[] header, long frame, int payload, static void encodeHeader(byte[] header, long frame, int payload,
int padding) { int padding) {
if(header.length < TransportConstants.FRAME_HEADER_LENGTH) if(header.length < FRAME_HEADER_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(frame < 0 || frame > ByteUtils.MAX_32_BIT_UNSIGNED) if(frame < 0 || frame > ByteUtils.MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
@@ -21,24 +21,22 @@ class HeaderEncoder {
} }
static boolean validateHeader(byte[] header, long frame, int max) { static boolean validateHeader(byte[] header, long frame, int max) {
if(header.length < TransportConstants.FRAME_HEADER_LENGTH) if(header.length < FRAME_HEADER_LENGTH) return false;
return false;
if(ByteUtils.readUint32(header, 0) != frame) return false; if(ByteUtils.readUint32(header, 0) != frame) return false;
int payload = ByteUtils.readUint16(header, 4); int payload = ByteUtils.readUint16(header, 4);
int padding = ByteUtils.readUint16(header, 6); int padding = ByteUtils.readUint16(header, 6);
if(payload + padding == 0) return false;
if(payload + padding > max) return false; if(payload + padding > max) return false;
return true; return true;
} }
static int getPayloadLength(byte[] header) { static int getPayloadLength(byte[] header) {
if(header.length < TransportConstants.FRAME_HEADER_LENGTH) if(header.length < FRAME_HEADER_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
return ByteUtils.readUint16(header, 4); return ByteUtils.readUint16(header, 4);
} }
static int getPaddingLength(byte[] header) { static int getPaddingLength(byte[] header) {
if(header.length < TransportConstants.FRAME_HEADER_LENGTH) if(header.length < FRAME_HEADER_LENGTH)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
return ByteUtils.readUint16(header, 6); return ByteUtils.readUint16(header, 6);
} }

View File

@@ -33,10 +33,8 @@ public class ConnectionReaderImplTest extends TransportTest {
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
try { // There should be no bytes available before EOF
r.getInputStream().read(); assertEquals(-1, r.getInputStream().read());
fail();
} catch(FormatException expected) {}
} }
@Test @Test