Frame-at-a-time decryption.

This commit is contained in:
akwizgran
2012-01-12 18:41:43 +00:00
parent f55f98f506
commit f6cad10868
13 changed files with 184 additions and 274 deletions

View File

@@ -1,14 +1,13 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
/** Decrypts unauthenticated data received over a connection. */ /** Decrypts unauthenticated data received over a connection. */
interface ConnectionDecrypter { interface ConnectionDecrypter {
/** Returns an input stream from which decrypted data can be read. */ /**
InputStream getInputStream(); * Reads and decrypts a frame into the given buffer and returns the length
* of the decrypted frame, or -1 if no more frames can be read.
/** Reads and decrypts the remainder of the current frame. */ */
void readFinal(byte[] b) throws IOException; int readFrame(byte[] b) throws IOException;
} }

View File

@@ -1,9 +1,10 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED; import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.io.EOFException; import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
@@ -11,136 +12,35 @@ import java.security.GeneralSecurityException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.IvParameterSpec;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
class ConnectionDecrypterImpl extends FilterInputStream class ConnectionDecrypterImpl implements ConnectionDecrypter {
implements ConnectionDecrypter {
private final InputStream in;
private final Cipher frameCipher; private final Cipher frameCipher;
private final ErasableKey frameKey; private final ErasableKey frameKey;
private final byte[] iv, buf; private final int macLength, blockSize;
private final byte[] iv;
private int bufOff = 0, bufLen = 0;
private long frame = 0L; private long frame = 0L;
private boolean betweenFrames = true;
ConnectionDecrypterImpl(InputStream in, Cipher frameCipher, ConnectionDecrypterImpl(InputStream in, Cipher frameCipher,
ErasableKey frameKey) { ErasableKey frameKey, int macLength) {
super(in); this.in = in;
this.frameCipher = frameCipher; this.frameCipher = frameCipher;
this.frameKey = frameKey; this.frameKey = frameKey;
iv = IvEncoder.encodeIv(0, frameCipher.getBlockSize()); this.macLength = macLength;
buf = new byte[frameCipher.getBlockSize()]; blockSize = frameCipher.getBlockSize();
if(blockSize < FRAME_HEADER_LENGTH)
throw new IllegalArgumentException();
iv = IvEncoder.encodeIv(0, blockSize);
} }
public InputStream getInputStream() { public int readFrame(byte[] b) throws IOException {
return this; if(b.length < MAX_FRAME_LENGTH) throw new IllegalArgumentException();
}
public void readFinal(byte[] b) throws IOException {
try {
if(betweenFrames) throw new IllegalStateException();
// If we have any plaintext in the buffer, copy it into the frame
System.arraycopy(buf, bufOff, b, 0, bufLen);
// Read the remainder of the frame
int offset = bufLen;
while(offset < b.length) {
int read = in.read(b, offset, b.length - offset);
if(read == -1) break;
offset += read;
}
if(offset < b.length) throw new EOFException(); // Unexpected EOF
// Decrypt the remainder of the frame
try {
int length = b.length - bufLen;
int i = frameCipher.doFinal(b, bufLen, length, b, bufLen);
if(i < length) throw new RuntimeException();
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
bufOff = bufLen = 0;
betweenFrames = true;
} catch(IOException e) {
frameKey.erase();
throw e;
}
}
@Override
public int read() throws IOException {
try {
if(betweenFrames) initialiseCipher();
if(bufLen == 0) {
if(!readBlock()) {
frameKey.erase();
return -1;
}
bufOff = 0;
bufLen = buf.length;
}
int i = buf[bufOff];
bufOff++;
bufLen--;
return i < 0 ? i + 256 : i;
} catch(IOException e) {
frameKey.erase();
throw e;
}
}
@Override
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
try {
if(betweenFrames) initialiseCipher();
if(bufLen == 0) {
if(!readBlock()) {
frameKey.erase();
return -1;
}
bufOff = 0;
bufLen = buf.length;
}
int length = Math.min(len, bufLen);
System.arraycopy(buf, bufOff, b, off, length);
bufOff += length;
bufLen -= length;
return length;
} catch(IOException e) {
frameKey.erase();
throw e;
}
}
// Although we're using CTR mode, which doesn't require full blocks of
// ciphertext, the cipher still tries to operate a block at a time
private boolean readBlock() throws IOException {
// Try to read a block of ciphertext
int offset = 0;
while(offset < buf.length) {
int read = in.read(buf, offset, buf.length - offset);
if(read == -1) break;
offset += read;
}
if(offset == 0) return false;
if(offset < buf.length) throw new EOFException(); // Unexpected EOF
// Decrypt the block
try {
int i = frameCipher.update(buf, 0, offset, buf);
if(i < offset) throw new RuntimeException();
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
return true;
}
private void initialiseCipher() {
assert betweenFrames;
if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
// Initialise the cipher
IvEncoder.updateIv(iv, frame); IvEncoder.updateIv(iv, frame);
IvParameterSpec ivSpec = new IvParameterSpec(iv); IvParameterSpec ivSpec = new IvParameterSpec(iv);
try { try {
@@ -148,7 +48,52 @@ implements ConnectionDecrypter {
} catch(GeneralSecurityException badIvOrKey) { } catch(GeneralSecurityException badIvOrKey) {
throw new RuntimeException(badIvOrKey); throw new RuntimeException(badIvOrKey);
} }
frame++; try {
betweenFrames = false; // Read the first block
int offset = 0;
while(offset < blockSize) {
int read = in.read(b, offset, blockSize - offset);
if(read == -1) {
if(offset == 0) return -1;
if(offset < blockSize) throw new EOFException();
break;
}
offset += read;
}
// Decrypt the first block
try {
int decrypted = frameCipher.update(b, 0, blockSize, b);
assert decrypted == blockSize;
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
// Validate and parse the header
int max = MAX_FRAME_LENGTH - FRAME_HEADER_LENGTH - macLength;
if(!HeaderEncoder.validateHeader(b, frame, max))
throw new FormatException();
int payload = HeaderEncoder.getPayloadLength(b);
int padding = HeaderEncoder.getPaddingLength(b);
int length = FRAME_HEADER_LENGTH + payload + padding + macLength;
if(length > MAX_FRAME_LENGTH) throw new FormatException();
// Read the remainder of the frame
while(offset < length) {
int read = in.read(b, offset, length - offset);
if(read == -1) throw new EOFException();
offset += read;
}
// Decrypt the remainder of the frame
try {
int decrypted = frameCipher.doFinal(b, blockSize,
length - blockSize, b, blockSize);
assert decrypted == length - blockSize;
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
frame++;
return length;
} catch(IOException e) {
frameKey.erase();
throw e;
}
} }
} }

View File

@@ -6,7 +6,7 @@ import java.io.IOException;
interface ConnectionEncrypter { interface ConnectionEncrypter {
/** Encrypts and writes the given frame. */ /** Encrypts and writes the given frame. */
void writeFrame(byte[] b, int off, int len) throws IOException; void writeFrame(byte[] b, int len) throws IOException;
/** Flushes the output stream. */ /** Flushes the output stream. */
void flush() throws IOException; void flush() throws IOException;

View File

@@ -35,7 +35,7 @@ class ConnectionEncrypterImpl implements ConnectionEncrypter {
if(tag.length != TAG_LENGTH) throw new IllegalArgumentException(); if(tag.length != TAG_LENGTH) throw new IllegalArgumentException();
} }
public void writeFrame(byte[] b, int off, int len) throws IOException { public void writeFrame(byte[] b, int len) throws IOException {
try { try {
if(!tagWritten) { if(!tagWritten) {
out.write(tag); out.write(tag);
@@ -47,12 +47,12 @@ class ConnectionEncrypterImpl implements ConnectionEncrypter {
IvParameterSpec ivSpec = new IvParameterSpec(iv); IvParameterSpec ivSpec = new IvParameterSpec(iv);
try { try {
frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
int encrypted = frameCipher.doFinal(b, off, len, b, off); int encrypted = frameCipher.doFinal(b, 0, len, b, 0);
assert encrypted == len; assert encrypted == len;
} catch(GeneralSecurityException badCipher) { } catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher); throw new RuntimeException(badCipher);
} }
out.write(b, off, len); out.write(b, 0, len);
capacity -= len; capacity -= len;
frame++; frame++;
} catch(IOException e) { } catch(IOException e) {

View File

@@ -46,10 +46,10 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
ByteUtils.erase(secret); ByteUtils.erase(secret);
// Create the decrypter // Create the decrypter
Cipher frameCipher = crypto.getFrameCipher(); Cipher frameCipher = crypto.getFrameCipher();
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in,
frameCipher, frameKey);
// Create the reader
Mac mac = crypto.getMac(); Mac mac = crypto.getMac();
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in,
frameCipher, frameKey, mac.getMacLength());
// Create the reader
return new ConnectionReaderImpl(decrypter, mac, macKey); return new ConnectionReaderImpl(decrypter, mac, macKey);
} }
} }

View File

@@ -4,12 +4,9 @@ import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED; import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.security.InvalidKeyException; import java.security.InvalidKeyException;
import java.util.Arrays;
import javax.crypto.Mac; import javax.crypto.Mac;
@@ -17,21 +14,18 @@ import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
class ConnectionReaderImpl extends FilterInputStream class ConnectionReaderImpl extends InputStream implements ConnectionReader {
implements ConnectionReader {
private final ConnectionDecrypter decrypter; private final ConnectionDecrypter decrypter;
private final Mac mac; private final Mac mac;
private final int maxPayloadLength; private final int macLength;
private final byte[] header, payload, footer; private final byte[] buf;
private long frame = 0L; private long frame = 0L;
private int payloadOff = 0, payloadLen = 0; private int bufOffset = 0, bufLength = 0;
private boolean betweenFrames = true;
ConnectionReaderImpl(ConnectionDecrypter decrypter, Mac mac, ConnectionReaderImpl(ConnectionDecrypter decrypter, Mac mac,
ErasableKey macKey) { ErasableKey macKey) {
super(decrypter.getInputStream());
this.decrypter = decrypter; this.decrypter = decrypter;
this.mac = mac; this.mac = mac;
// Initialise the MAC // Initialise the MAC
@@ -41,11 +35,8 @@ implements ConnectionReader {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
macKey.erase(); macKey.erase();
maxPayloadLength = macLength = mac.getMacLength();
MAX_FRAME_LENGTH - FRAME_HEADER_LENGTH - mac.getMacLength(); buf = new byte[MAX_FRAME_LENGTH];
header = new byte[FRAME_HEADER_LENGTH];
payload = new byte[maxPayloadLength];
footer = new byte[mac.getMacLength()];
} }
public InputStream getInputStream() { public InputStream getInputStream() {
@@ -54,12 +45,11 @@ implements ConnectionReader {
@Override @Override
public int read() throws IOException { public int read() throws IOException {
if(betweenFrames && !readNonEmptyFrame()) return -1; while(bufLength == 0) if(!readFrame()) return -1;
int i = payload[payloadOff]; int b = buf[bufOffset] & 0xff;
payloadOff++; bufOffset++;
payloadLen--; bufLength--;
if(payloadLen == 0) betweenFrames = true; return b;
return i;
} }
@Override @Override
@@ -69,69 +59,44 @@ implements ConnectionReader {
@Override @Override
public int read(byte[] b, int off, int len) throws IOException { public int read(byte[] b, int off, int len) throws IOException {
if(betweenFrames && !readNonEmptyFrame()) return -1; while(bufLength == 0) if(!readFrame()) return -1;
len = Math.min(len, payloadLen); len = Math.min(len, bufLength);
System.arraycopy(payload, payloadOff, b, off, len); System.arraycopy(buf, bufOffset, b, off, len);
payloadOff += len; bufOffset += len;
payloadLen -= len; bufLength -= len;
if(payloadLen == 0) betweenFrames = true;
return len; return len;
} }
private boolean readNonEmptyFrame() throws IOException { private boolean readFrame() throws IOException {
int payload = 0; assert bufLength == 0;
do {
payload = readFrame();
} while(payload == 0);
return payload > 0;
}
private int readFrame() throws IOException {
assert betweenFrames;
// Don't allow more than 2^32 frames to be read // Don't allow more than 2^32 frames to be read
if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
// Read the header // Read a frame
int offset = 0; int length = decrypter.readFrame(buf);
while(offset < header.length) { if(length == -1) return false;
int read = in.read(header, offset, header.length - offset);
if(read == -1) break;
offset += read;
}
if(offset == 0) return -1; // EOF between frames
if(offset < header.length) throw new EOFException(); // Unexpected EOF
// Check that the frame number is correct and the length is legal // Check that the frame number is correct and the length is legal
if(!HeaderEncoder.validateHeader(header, frame, maxPayloadLength)) int max = MAX_FRAME_LENGTH - FRAME_HEADER_LENGTH - macLength;
if(!HeaderEncoder.validateHeader(buf, frame, max))
throw new FormatException();
int payload = HeaderEncoder.getPayloadLength(buf);
int padding = HeaderEncoder.getPaddingLength(buf);
if(length != FRAME_HEADER_LENGTH + payload + padding + macLength)
throw new FormatException(); throw new FormatException();
payloadLen = HeaderEncoder.getPayloadLength(header);
int paddingLen = HeaderEncoder.getPaddingLength(header);
mac.update(header);
// Read the payload
offset = 0;
while(offset < payloadLen) {
int read = in.read(payload, offset, payloadLen - offset);
if(read == -1) throw new EOFException(); // Unexpected EOF
mac.update(payload, offset, read);
offset += read;
}
payloadOff = 0;
// Read the padding
while(offset < payloadLen + paddingLen) {
int read = in.read(payload, offset,
payloadLen + paddingLen - offset);
if(read == -1) throw new EOFException(); // Unexpected EOF
mac.update(payload, offset, read);
offset += read;
}
// Check that the padding is all zeroes // Check that the padding is all zeroes
for(int i = payloadLen; i < payloadLen + paddingLen; i++) { int paddingStart = FRAME_HEADER_LENGTH + payload;
if(payload[i] != 0) throw new FormatException(); for(int i = paddingStart; i < paddingStart + padding; i++) {
if(buf[i] != 0) throw new FormatException();
} }
// Read the MAC // Check the MAC
int macStart = FRAME_HEADER_LENGTH + payload + padding;
mac.update(buf, 0, macStart);
byte[] expectedMac = mac.doFinal(); byte[] expectedMac = mac.doFinal();
decrypter.readFinal(footer); for(int i = 0; i < macLength; i++) {
if(!Arrays.equals(expectedMac, footer)) throw new FormatException(); if(expectedMac[i] != buf[macStart + i]) throw new FormatException();
}
bufOffset = FRAME_HEADER_LENGTH;
bufLength = payload;
frame++; frame++;
if(payloadLen > 0) betweenFrames = false; return true;
return payloadLen;
} }
} }

View File

@@ -101,7 +101,7 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
} catch(ShortBufferException badMac) { } catch(ShortBufferException badMac) {
throw new RuntimeException(badMac); throw new RuntimeException(badMac);
} }
encrypter.writeFrame(buf, 0, bufLength + mac.getMacLength()); encrypter.writeFrame(buf, bufLength + mac.getMacLength());
bufLength = FRAME_HEADER_LENGTH; bufLength = FRAME_HEADER_LENGTH;
frame++; frame++;
} }

View File

@@ -1,6 +1,7 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static org.junit.Assert.assertArrayEquals; import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
@@ -8,7 +9,6 @@ import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.IvParameterSpec;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.TestUtils;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
@@ -45,58 +45,40 @@ public class ConnectionDecrypterImplTest extends BriarTestCase {
} }
private void testDecryption(boolean initiator) throws Exception { private void testDecryption(boolean initiator) throws Exception {
// Calculate the expected plaintext for the first frame // Calculate the ciphertext for the first frame
byte[] iv = new byte[frameCipher.getBlockSize()]; byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH];
byte[] ciphertext = new byte[123]; HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0);
byte[] ciphertextMac = new byte[MAC_LENGTH]; byte[] iv = IvEncoder.encodeIv(0L, frameCipher.getBlockSize());
IvParameterSpec ivSpec = new IvParameterSpec(iv); IvParameterSpec ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
byte[] plaintext = new byte[ciphertext.length + ciphertextMac.length]; byte[] ciphertext = new byte[plaintext.length];
int offset = frameCipher.update(ciphertext, 0, ciphertext.length, frameCipher.doFinal(plaintext, 0, plaintext.length, ciphertext);
plaintext); // Calculate the ciphertext for the second frame
frameCipher.doFinal(ciphertextMac, 0, ciphertextMac.length, plaintext, byte[] plaintext1 = new byte[FRAME_HEADER_LENGTH + 1234 + MAC_LENGTH];
offset); HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0);
// Calculate the expected plaintext for the second frame
byte[] ciphertext1 = new byte[1234];
IvEncoder.updateIv(iv, 1L); IvEncoder.updateIv(iv, 1L);
ivSpec = new IvParameterSpec(iv); ivSpec = new IvParameterSpec(iv);
frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
byte[] plaintext1 = new byte[ciphertext1.length + ciphertextMac.length]; byte[] ciphertext1 = new byte[plaintext1.length];
offset = frameCipher.update(ciphertext1, 0, ciphertext1.length, frameCipher.doFinal(plaintext1, 0, plaintext1.length, ciphertext1);
plaintext1);
frameCipher.doFinal(ciphertextMac, 0, ciphertextMac.length, plaintext1,
offset);
// Concatenate the ciphertexts // Concatenate the ciphertexts
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
out.write(ciphertext); out.write(ciphertext);
out.write(ciphertextMac);
out.write(ciphertext1); out.write(ciphertext1);
out.write(ciphertextMac);
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
// Use a ConnectionDecrypter to decrypt the ciphertext // Use a ConnectionDecrypter to decrypt the ciphertext
ConnectionDecrypter d = new ConnectionDecrypterImpl(in, frameCipher, ConnectionDecrypter d = new ConnectionDecrypterImpl(in, frameCipher,
frameKey); frameKey, MAC_LENGTH);
// First frame // First frame
byte[] decrypted = new byte[ciphertext.length]; byte[] decrypted = new byte[MAX_FRAME_LENGTH];
TestUtils.readFully(d.getInputStream(), decrypted); assertEquals(plaintext.length, d.readFrame(decrypted));
byte[] decryptedMac = new byte[MAC_LENGTH]; for(int i = 0; i < plaintext.length; i++) {
d.readFinal(decryptedMac); assertEquals(plaintext[i], decrypted[i]);
}
// Second frame // Second frame
byte[] decrypted1 = new byte[ciphertext1.length]; assertEquals(plaintext1.length, d.readFrame(decrypted));
TestUtils.readFully(d.getInputStream(), decrypted1); for(int i = 0; i < plaintext1.length; i++) {
byte[] decryptedMac1 = new byte[MAC_LENGTH]; assertEquals(plaintext1[i], decrypted[i]);
d.readFinal(decryptedMac1); }
// Check that the actual plaintext matches the expected plaintext
out.reset();
out.write(plaintext);
out.write(plaintext1);
byte[] expected = out.toByteArray();
out.reset();
out.write(decrypted);
out.write(decryptedMac);
out.write(decrypted1);
out.write(decryptedMac1);
byte[] actual = out.toByteArray();
assertArrayEquals(expected, actual);
} }
} }

View File

@@ -69,8 +69,8 @@ public class ConnectionEncrypterImplTest extends BriarTestCase {
out.reset(); out.reset();
ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE, ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE,
tagCipher, frameCipher, tagKey, frameKey); tagCipher, frameCipher, tagKey, frameKey);
e.writeFrame(plaintext, 0, plaintext.length); e.writeFrame(plaintext, plaintext.length);
e.writeFrame(plaintext1, 0, plaintext1.length); e.writeFrame(plaintext1, plaintext1.length);
byte[] actual = out.toByteArray(); byte[] actual = out.toByteArray();
// Check that the actual ciphertext matches the expected ciphertext // Check that the actual ciphertext matches the expected ciphertext
assertArrayEquals(expected, actual); assertArrayEquals(expected, actual);

View File

@@ -31,7 +31,7 @@ public class ConnectionReaderImplTest extends TransportTest {
mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength); mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength);
// Read the frame // Read the frame
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
// There should be no bytes available before EOF // There should be no bytes available before EOF
assertEquals(-1, r.getInputStream().read()); assertEquals(-1, r.getInputStream().read());
@@ -49,7 +49,7 @@ public class ConnectionReaderImplTest extends TransportTest {
mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength); mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength);
// Read the frame // Read the frame
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
// There should be one byte available before EOF // There should be one byte available before EOF
assertEquals(0, r.getInputStream().read()); assertEquals(0, r.getInputStream().read());
@@ -75,7 +75,7 @@ public class ConnectionReaderImplTest extends TransportTest {
out.write(frame1); out.write(frame1);
// Read the first frame // Read the first frame
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
byte[] read = new byte[maxPayloadLength]; byte[] read = new byte[maxPayloadLength];
TestUtils.readFully(r.getInputStream(), read); TestUtils.readFully(r.getInputStream(), read);
@@ -109,7 +109,7 @@ public class ConnectionReaderImplTest extends TransportTest {
out.write(frame1); out.write(frame1);
// Read the first frame // Read the first frame
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
byte[] read = new byte[maxPayloadLength - paddingLength]; byte[] read = new byte[maxPayloadLength - paddingLength];
TestUtils.readFully(r.getInputStream(), read); TestUtils.readFully(r.getInputStream(), read);
@@ -135,7 +135,7 @@ public class ConnectionReaderImplTest extends TransportTest {
mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength + paddingLength); mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength + paddingLength);
// Read the frame // Read the frame
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
// The non-zero padding should be rejected // The non-zero padding should be rejected
try { try {
@@ -167,7 +167,7 @@ public class ConnectionReaderImplTest extends TransportTest {
out.write(frame1); out.write(frame1);
// Read the frames // Read the frames
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
byte[] read = new byte[payloadLength]; byte[] read = new byte[payloadLength];
TestUtils.readFully(r.getInputStream(), read); TestUtils.readFully(r.getInputStream(), read);
@@ -191,7 +191,7 @@ public class ConnectionReaderImplTest extends TransportTest {
frame[12] ^= 1; frame[12] ^= 1;
// Try to read the frame - not a single byte should be read // Try to read the frame - not a single byte should be read
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
try { try {
r.getInputStream().read(); r.getInputStream().read();
@@ -213,7 +213,7 @@ public class ConnectionReaderImplTest extends TransportTest {
frame[17] ^= 1; frame[17] ^= 1;
// Try to read the frame - not a single byte should be read // Try to read the frame - not a single byte should be read
ByteArrayInputStream in = new ByteArrayInputStream(frame); ByteArrayInputStream in = new ByteArrayInputStream(frame);
ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength);
ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey);
try { try {
r.getInputStream().read(); r.getInputStream().read();

View File

@@ -90,7 +90,7 @@ public class FrameReadWriteTest extends BriarTestCase {
assertTrue(TagEncoder.validateTag(tag, 0, tagCipher, tagKey)); assertTrue(TagEncoder.validateTag(tag, 0, tagCipher, tagKey));
// Read the frames back // Read the frames back
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in,
frameCipher, frameKey); frameCipher, frameKey, mac.getMacLength());
ConnectionReader reader = new ConnectionReaderImpl(decrypter, mac, ConnectionReader reader = new ConnectionReaderImpl(decrypter, mac,
macKey); macKey);
InputStream in1 = reader.getInputStream(); InputStream in1 = reader.getInputStream();

View File

@@ -1,29 +1,48 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import java.io.EOFException; import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import net.sf.briar.api.FormatException;
/** A ConnectionDecrypter that performs no decryption. */ /** A ConnectionDecrypter that performs no decryption. */
class NullConnectionDecrypter implements ConnectionDecrypter { class NullConnectionDecrypter implements ConnectionDecrypter {
private final InputStream in; private final InputStream in;
private final int macLength;
NullConnectionDecrypter(InputStream in) { NullConnectionDecrypter(InputStream in, int macLength) {
this.in = in; this.in = in;
this.macLength = macLength;
} }
public InputStream getInputStream() { public int readFrame(byte[] b) throws IOException {
return in; if(b.length < MAX_FRAME_LENGTH) throw new IllegalArgumentException();
} // Read the header to determine the frame length
int offset = 0, length = FRAME_HEADER_LENGTH;
public void readFinal(byte[] mac) throws IOException { while(offset < length) {
int offset = 0; int read = in.read(b, offset, length - offset);
while(offset < mac.length) { if(read == -1) {
int read = in.read(mac, offset, mac.length - offset); if(offset == 0) return -1;
if(read == -1) break; throw new EOFException();
}
offset += read; offset += read;
} }
if(offset < mac.length) throw new EOFException(); // Parse the header
int payload = HeaderEncoder.getPayloadLength(b);
int padding = HeaderEncoder.getPaddingLength(b);
length = FRAME_HEADER_LENGTH + payload + padding + macLength;
if(length > MAX_FRAME_LENGTH) throw new FormatException();
// Read the remainder of the frame
while(offset < length) {
int read = in.read(b, offset, length - offset);
if(read == -1) throw new EOFException();
offset += read;
}
return length;
} }
} }

View File

@@ -20,8 +20,8 @@ class NullConnectionEncrypter implements ConnectionEncrypter {
this.capacity = capacity; this.capacity = capacity;
} }
public void writeFrame(byte[] b, int off, int len) throws IOException { public void writeFrame(byte[] b, int len) throws IOException {
out.write(b, off, len); out.write(b, 0, len);
capacity -= len; capacity -= len;
} }