Changed to fixed-length frames (mostly untested).

This commit is contained in:
akwizgran
2012-08-28 09:15:04 +01:00
parent 312ad9d534
commit ff73905330
31 changed files with 448 additions and 892 deletions

View File

@@ -0,0 +1,70 @@
package net.sf.briar.crypto;
import java.security.InvalidKeyException;
import java.security.Key;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import org.spongycastle.crypto.DataLengthException;
import org.spongycastle.crypto.InvalidCipherTextException;
import org.spongycastle.crypto.modes.AEADBlockCipher;
import org.spongycastle.crypto.params.AEADParameters;
import org.spongycastle.crypto.params.KeyParameter;
class AuthenticatedCipherImpl implements AuthenticatedCipher {
private final AEADBlockCipher cipher;
private final int macLength;
AuthenticatedCipherImpl(AEADBlockCipher cipher, int macLength) {
this.cipher = cipher;
this.macLength = macLength;
}
public int doFinal(byte[] input, int inputOff, int len, byte[] output,
int outputOff) throws IllegalBlockSizeException,
BadPaddingException {
int processed = 0;
if(len != 0) {
processed = cipher.processBytes(input, inputOff, len, output,
outputOff);
}
try {
return processed + cipher.doFinal(output, outputOff + processed);
} catch(DataLengthException e) {
throw new IllegalBlockSizeException(e.getMessage());
} catch(InvalidCipherTextException e) {
throw new BadPaddingException(e.getMessage());
}
}
public void init(int opmode, Key key, byte[] iv, byte[] aad)
throws InvalidKeyException {
KeyParameter k = new KeyParameter(key.getEncoded());
AEADParameters params = new AEADParameters(k, macLength * 8, iv, aad);
try {
switch(opmode) {
case Cipher.ENCRYPT_MODE:
case Cipher.WRAP_MODE:
cipher.init(true, params);
break;
case Cipher.DECRYPT_MODE:
case Cipher.UNWRAP_MODE:
cipher.init(false, params);
break;
default:
throw new IllegalArgumentException();
}
} catch(Exception e) {
throw new InvalidKeyException(e.getMessage());
}
}
public int getMacLength() {
return macLength;
}
}

View File

@@ -15,14 +15,17 @@ import javax.crypto.Cipher;
import javax.crypto.KeyAgreement;
import javax.crypto.spec.IvParameterSpec;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.crypto.IvEncoder;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.crypto.PseudoRandom;
import net.sf.briar.util.ByteUtils;
import org.spongycastle.crypto.engines.AESEngine;
import org.spongycastle.crypto.modes.AEADBlockCipher;
import org.spongycastle.crypto.modes.GCMBlockCipher;
import org.spongycastle.jce.provider.BouncyCastleProvider;
import com.google.inject.Inject;
@@ -42,8 +45,7 @@ class CryptoComponentImpl implements CryptoComponent {
private static final int SIGNATURE_KEY_PAIR_BITS = 384;
private static final String SIGNATURE_ALGO = "ECDSA";
private static final String TAG_CIPHER_ALGO = "AES/ECB/NoPadding";
private static final String FRAME_CIPHER_ALGO = "AES/GCM/NoPadding";
private static final String FRAME_PEEKING_CIPHER_ALGO = "AES/CTR/NoPadding";
private static final int GCM_MAC_LENGTH = 16; // 128 bits
// Labels for key derivation
private static final byte[] TAG = { 'T', 'A', 'G' };
@@ -275,27 +277,10 @@ class CryptoComponentImpl implements CryptoComponent {
}
}
public Cipher getFrameCipher() {
try {
return Cipher.getInstance(FRAME_CIPHER_ALGO, PROVIDER);
} catch(GeneralSecurityException e) {
throw new RuntimeException(e);
}
}
public Cipher getFramePeekingCipher() {
try {
return Cipher.getInstance(FRAME_PEEKING_CIPHER_ALGO, PROVIDER);
} catch(GeneralSecurityException e) {
throw new RuntimeException(e);
}
}
public IvEncoder getFrameIvEncoder() {
return new FrameIvEncoder();
}
public IvEncoder getFramePeekingIvEncoder() {
return new FramePeekingIvEncoder();
public AuthenticatedCipher getFrameCipher() {
// This code is specific to BouncyCastle because javax.crypto.Cipher
// doesn't support additional authenticated data until Java 7
AEADBlockCipher cipher = new GCMBlockCipher(new AESEngine());
return new AuthenticatedCipherImpl(cipher, GCM_MAC_LENGTH);
}
}

View File

@@ -1,26 +0,0 @@
package net.sf.briar.crypto;
import net.sf.briar.api.crypto.IvEncoder;
import net.sf.briar.util.ByteUtils;
class FrameIvEncoder implements IvEncoder {
// AES-GCM uses a 96-bit IV; the bytes 0x00, 0x00, 0x00, 0x02 are
// appended internally (see NIST SP 800-38D, section 7.1)
private static final int IV_LENGTH = 12;
public byte[] encodeIv(long frame) {
if(frame < 0 || frame > ByteUtils.MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
byte[] iv = new byte[IV_LENGTH];
updateIv(iv, frame);
return iv;
}
public void updateIv(byte[] iv, long frame) {
if(frame < 0 || frame > ByteUtils.MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
// Encode the frame number as a uint32
ByteUtils.writeUint32(frame, iv, 0);
}
}

View File

@@ -1,20 +0,0 @@
package net.sf.briar.crypto;
import net.sf.briar.util.ByteUtils;
class FramePeekingIvEncoder extends FrameIvEncoder {
// AES/CTR uses a 128-bit IV; to match the AES/GCM IV we have to append
// the bytes 0x00, 0x00, 0x00, 0x02 (see NIST SP 800-38D, section 7.1)
private static final int IV_LENGTH = 16;
@Override
public byte[] encodeIv(long frame) {
if(frame < 0 || frame > ByteUtils.MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
byte[] iv = new byte[IV_LENGTH];
iv[IV_LENGTH - 1] = 2;
updateIv(iv, frame);
return iv;
}
}

View File

@@ -1,12 +1,14 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import java.io.InputStream;
import javax.crypto.Cipher;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.crypto.IvEncoder;
import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.util.ByteUtils;
@@ -30,13 +32,9 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
ByteUtils.erase(secret);
// Create the reader
Cipher tagCipher = crypto.getTagCipher();
Cipher frameCipher = crypto.getFrameCipher();
Cipher framePeekingCipher = crypto.getFramePeekingCipher();
IvEncoder frameIvEncoder = crypto.getFrameIvEncoder();
IvEncoder framePeekingIvEncoder = crypto.getFramePeekingIvEncoder();
AuthenticatedCipher frameCipher = crypto.getFrameCipher();
FrameReader encryption = new IncomingEncryptionLayer(in, tagCipher,
frameCipher, framePeekingCipher, frameIvEncoder,
framePeekingIvEncoder, tagKey, frameKey, !initiator);
return new ConnectionReaderImpl(encryption);
frameCipher, tagKey, frameKey, !initiator, MAX_FRAME_LENGTH);
return new ConnectionReaderImpl(encryption, MAX_FRAME_LENGTH);
}
}

View File

@@ -1,12 +1,10 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import java.io.IOException;
import java.io.InputStream;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.transport.ConnectionReader;
class ConnectionReaderImpl extends InputStream implements ConnectionReader {
@@ -16,10 +14,9 @@ class ConnectionReaderImpl extends InputStream implements ConnectionReader {
private int offset = 0, length = 0;
ConnectionReaderImpl(FrameReader in) {
ConnectionReaderImpl(FrameReader in, int frameLength) {
this.in = in;
frame = new byte[MAX_FRAME_LENGTH];
offset = HEADER_LENGTH;
frame = new byte[frameLength];
}
public InputStream getInputStream() {
@@ -28,8 +25,10 @@ class ConnectionReaderImpl extends InputStream implements ConnectionReader {
@Override
public int read() throws IOException {
if(length == -1) return -1;
while(length == 0) if(!readFrame()) return -1;
while(length <= 0) {
if(length == -1) return -1;
readFrame();
}
int b = frame[offset] & 0xff;
offset++;
length--;
@@ -43,8 +42,10 @@ class ConnectionReaderImpl extends InputStream implements ConnectionReader {
@Override
public int read(byte[] b, int off, int len) throws IOException {
if(length == -1) return -1;
while(length == 0) if(!readFrame()) return -1;
while(length <= 0) {
if(length == -1) return -1;
readFrame();
}
len = Math.min(len, length);
System.arraycopy(frame, offset, b, off, len);
offset += len;
@@ -52,20 +53,9 @@ class ConnectionReaderImpl extends InputStream implements ConnectionReader {
return len;
}
private boolean readFrame() throws IOException {
private void readFrame() throws IOException {
assert length == 0;
if(HeaderEncoder.isLastFrame(frame)) {
length = -1;
return false;
}
if(!in.readFrame(frame)) throw new FormatException();
offset = HEADER_LENGTH;
length = HeaderEncoder.getPayloadLength(frame);
// The padding must be all zeroes
int padding = HeaderEncoder.getPaddingLength(frame);
for(int i = offset + length; i < offset + length + padding; i++) {
if(frame[i] != 0) throw new FormatException();
}
return true;
length = in.readFrame(frame);
}
}

View File

@@ -1,12 +1,14 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import java.io.OutputStream;
import javax.crypto.Cipher;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.crypto.IvEncoder;
import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.util.ByteUtils;
@@ -30,11 +32,10 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
ByteUtils.erase(secret);
// Create the writer
Cipher tagCipher = crypto.getTagCipher();
Cipher frameCipher = crypto.getFrameCipher();
IvEncoder frameIvEncoder = crypto.getFrameIvEncoder();
FrameWriter encryption = new OutgoingEncryptionLayer(
out, capacity, tagCipher, frameCipher, frameIvEncoder, tagKey,
frameKey);
return new ConnectionWriterImpl(encryption);
AuthenticatedCipher frameCipher = crypto.getFrameCipher();
FrameWriter encryption = new OutgoingEncryptionLayer(out, capacity,
tagCipher, frameCipher, tagKey, frameKey, initiator,
MAX_FRAME_LENGTH);
return new ConnectionWriterImpl(encryption, MAX_FRAME_LENGTH);
}
}

View File

@@ -2,7 +2,6 @@ package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.io.IOException;
@@ -12,7 +11,7 @@ import net.sf.briar.api.transport.ConnectionWriter;
/**
* A ConnectionWriter that buffers its input and writes a frame whenever there
* is a full-size frame to write or the flush() method is called.
* is a full frame to write or the flush() method is called.
* <p>
* This class is not thread-safe.
*/
@@ -20,15 +19,15 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
private final FrameWriter out;
private final byte[] frame;
private final int frameLength;
private int offset;
private long frameNumber;
private int length = 0;
private long frameNumber = 0L;
ConnectionWriterImpl(FrameWriter out) {
ConnectionWriterImpl(FrameWriter out, int frameLength) {
this.out = out;
frame = new byte[MAX_FRAME_LENGTH];
offset = HEADER_LENGTH;
frameNumber = 0L;
this.frameLength = frameLength;
frame = new byte[frameLength];
}
public OutputStream getOutputStream() {
@@ -37,31 +36,31 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
public long getRemainingCapacity() {
long capacity = out.getRemainingCapacity();
// If there's any data buffered, subtract it and its overhead
if(offset > HEADER_LENGTH) capacity -= offset + MAC_LENGTH;
// Subtract the overhead from the remaining capacity
long frames = (long) Math.ceil((double) capacity / MAX_FRAME_LENGTH);
int overheadPerFrame = HEADER_LENGTH + MAC_LENGTH;
return Math.max(0L, capacity - frames * overheadPerFrame);
int maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH;
long frames = (long) Math.ceil((double) capacity / maxPayloadLength);
long overhead = (frames + 1) * (HEADER_LENGTH + MAC_LENGTH);
return capacity - overhead - length;
}
@Override
public void close() throws IOException {
if(offset > HEADER_LENGTH || frameNumber > 0L) writeFrame(true);
writeFrame(true);
out.flush();
super.close();
}
@Override
public void flush() throws IOException {
if(offset > HEADER_LENGTH) writeFrame(false);
if(length > 0) writeFrame(false);
out.flush();
}
@Override
public void write(int b) throws IOException {
frame[offset++] = (byte) b;
if(offset + MAC_LENGTH == MAX_FRAME_LENGTH) writeFrame(false);
frame[HEADER_LENGTH + length] = (byte) b;
length++;
if(HEADER_LENGTH + length + MAC_LENGTH == frameLength)
writeFrame(false);
}
@Override
@@ -71,26 +70,26 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
@Override
public void write(byte[] b, int off, int len) throws IOException {
int available = MAX_FRAME_LENGTH - offset - MAC_LENGTH;
int available = frameLength - HEADER_LENGTH - length - MAC_LENGTH;
while(available <= len) {
System.arraycopy(b, off, frame, offset, available);
offset += available;
System.arraycopy(b, off, frame, HEADER_LENGTH + length, available);
length += available;
writeFrame(false);
off += available;
len -= available;
available = MAX_FRAME_LENGTH - offset - MAC_LENGTH;
available = frameLength - HEADER_LENGTH - length - MAC_LENGTH;
}
System.arraycopy(b, off, frame, offset, len);
offset += len;
System.arraycopy(b, off, frame, HEADER_LENGTH + length, len);
length += len;
}
private void writeFrame(boolean lastFrame) throws IOException {
if(frameNumber > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
int payload = offset - HEADER_LENGTH;
assert payload >= 0;
HeaderEncoder.encodeHeader(frame, frameNumber, payload, 0, lastFrame);
out.writeFrame(frame);
offset = HEADER_LENGTH;
int capacity = (int) Math.min(frameLength, out.getRemainingCapacity());
int paddingLength = capacity - HEADER_LENGTH - length - MAC_LENGTH;
if(paddingLength < 0) throw new IllegalStateException();
out.writeFrame(frame, length, lastFrame ? 0 : paddingLength, lastFrame);
length = 0;
frameNumber++;
}
}

View File

@@ -0,0 +1,53 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.AAD_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import net.sf.briar.util.ByteUtils;
class FrameEncoder {
static void encodeIv(byte[] iv, long frameNumber) {
if(iv.length < IV_LENGTH) throw new IllegalArgumentException();
if(frameNumber < 0L || frameNumber > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
ByteUtils.writeUint32(frameNumber, iv, 0);
for(int i = 4; i < IV_LENGTH; i++) iv[i] = 0;
}
static void encodeAad(byte[] aad, long frameNumber, int plaintextLength) {
if(aad.length < AAD_LENGTH) throw new IllegalArgumentException();
if(frameNumber < 0L || frameNumber > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(plaintextLength < HEADER_LENGTH)
throw new IllegalArgumentException();
if(plaintextLength > MAX_FRAME_LENGTH - MAC_LENGTH)
throw new IllegalArgumentException();
ByteUtils.writeUint32(frameNumber, aad, 0);
ByteUtils.writeUint16(plaintextLength, aad, 4);
}
static void encodeHeader(byte[] header, boolean lastFrame,
int payloadLength) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
if(payloadLength < 0)
throw new IllegalArgumentException();
if(payloadLength > MAX_FRAME_LENGTH - HEADER_LENGTH - MAC_LENGTH)
throw new IllegalArgumentException();
ByteUtils.writeUint16(payloadLength, header, 0);
if(lastFrame) header[0] |= 0x80;
}
static boolean isLastFrame(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return (header[0] & 0x80) == 0x80;
}
static int getPayloadLength(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return ByteUtils.readUint16(header, 0) & 0x7FFF;
}
}

View File

@@ -5,8 +5,8 @@ import java.io.IOException;
interface FrameReader {
/**
* Reads a frame into the given buffer. Returns false if no more frames can
* be read from the connection.
* Reads a frame into the given buffer and returns its payload length, or
* -1 if no more frames can be read from the connection.
*/
boolean readFrame(byte[] frame) throws IOException;
int readFrame(byte[] frame) throws IOException;
}

View File

@@ -5,7 +5,8 @@ import java.io.IOException;
interface FrameWriter {
/** Writes the given frame. */
void writeFrame(byte[] frame) throws IOException;
void writeFrame(byte[] frame, int payloadLength, int paddingLength,
boolean lastFrame) throws IOException;
/** Flushes the stack. */
void flush() throws IOException;

View File

@@ -1,54 +0,0 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import net.sf.briar.util.ByteUtils;
class HeaderEncoder {
static void encodeHeader(byte[] header, long frameNumber, int payload,
int padding, boolean lastFrame) {
if(header.length < HEADER_LENGTH)
throw new IllegalArgumentException();
if(frameNumber < 0 || frameNumber > ByteUtils.MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(payload < 0 || payload > ByteUtils.MAX_16_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(padding < 0 || padding > ByteUtils.MAX_16_BIT_UNSIGNED)
throw new IllegalArgumentException();
ByteUtils.writeUint32(frameNumber, header, 0);
ByteUtils.writeUint16(payload, header, 4);
ByteUtils.writeUint16(padding, header, 6);
if(lastFrame) header[8] = 1;
}
static boolean checkHeader(byte[] header, int length) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
int payload = getPayloadLength(header);
int padding = getPaddingLength(header);
if(HEADER_LENGTH + payload + padding + MAC_LENGTH != length)
return false;
if(header[8] != 0 && header[8] != 1) return false;
return true;
}
static long getFrameNumber(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return ByteUtils.readUint32(header, 0);
}
static int getPayloadLength(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return ByteUtils.readUint16(header, 4);
}
static int getPaddingLength(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return ByteUtils.readUint16(header, 6);
}
static boolean isLastFrame(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return header[8] == 1;
}
}

View File

@@ -1,8 +1,10 @@
package net.sf.briar.transport;
import static javax.crypto.Cipher.DECRYPT_MODE;
import static net.sf.briar.api.transport.TransportConstants.AAD_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.EOFException;
@@ -11,110 +13,98 @@ import java.io.InputStream;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.crypto.IvEncoder;
class IncomingEncryptionLayer implements FrameReader {
private final InputStream in;
private final Cipher tagCipher, frameCipher, framePeekingCipher;
private final IvEncoder frameIvEncoder, framePeekingIvEncoder;
private final Cipher tagCipher;
private final AuthenticatedCipher frameCipher;
private final ErasableKey tagKey, frameKey;
private final int blockSize;
private final byte[] frameIv, framePeekingIv, ciphertext;
private final byte[] iv, aad, ciphertext;
private final int maxFrameLength;
private boolean readTag;
private boolean readTag, lastFrame;
private long frameNumber;
IncomingEncryptionLayer(InputStream in, Cipher tagCipher,
Cipher frameCipher, Cipher framePeekingCipher,
IvEncoder frameIvEncoder, IvEncoder framePeekingIvEncoder,
ErasableKey tagKey, ErasableKey frameKey, boolean readTag) {
AuthenticatedCipher frameCipher, ErasableKey tagKey,
ErasableKey frameKey, boolean readTag, int maxFrameLength) {
this.in = in;
this.tagCipher = tagCipher;
this.frameCipher = frameCipher;
this.framePeekingCipher = framePeekingCipher;
this.frameIvEncoder = frameIvEncoder;
this.framePeekingIvEncoder = framePeekingIvEncoder;
this.tagKey = tagKey;
this.frameKey = frameKey;
this.readTag = readTag;
blockSize = frameCipher.getBlockSize();
if(blockSize < HEADER_LENGTH) throw new IllegalArgumentException();
frameIv = frameIvEncoder.encodeIv(0L);
framePeekingIv = framePeekingIvEncoder.encodeIv(0L);
ciphertext = new byte[MAX_FRAME_LENGTH];
this.maxFrameLength = maxFrameLength;
lastFrame = false;
iv = new byte[IV_LENGTH];
aad = new byte[AAD_LENGTH];
ciphertext = new byte[maxFrameLength];
frameNumber = 0L;
}
public boolean readFrame(byte[] frame) throws IOException {
try {
// Read the tag if it hasn't already been read
if(readTag) {
int offset = 0;
public int readFrame(byte[] frame) throws IOException {
if(lastFrame) return -1;
// Read the tag if required
if(readTag) {
int offset = 0;
try {
while(offset < TAG_LENGTH) {
int read = in.read(ciphertext, offset,
TAG_LENGTH - offset);
if(read == -1) {
if(offset == 0) return false;
throw new EOFException();
}
int read = in.read(ciphertext, offset, TAG_LENGTH - offset);
if(read == -1) throw new EOFException();
offset += read;
}
if(!TagEncoder.decodeTag(ciphertext, tagCipher, tagKey))
throw new FormatException();
}
// Read the first block of the frame
int offset = 0;
while(offset < blockSize) {
int read = in.read(ciphertext, offset, blockSize - offset);
if(read == -1) throw new EOFException();
offset += read;
} catch(IOException e) {
frameKey.erase();
tagKey.erase();
throw e;
}
if(!TagEncoder.decodeTag(ciphertext, tagCipher, tagKey))
throw new FormatException();
readTag = false;
// Decrypt the first block of the frame to peek at the header
framePeekingIvEncoder.updateIv(framePeekingIv, frameNumber);
IvParameterSpec ivSpec = new IvParameterSpec(framePeekingIv);
try {
framePeekingCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec);
int decrypted = framePeekingCipher.update(ciphertext, 0,
blockSize, frame);
if(decrypted != blockSize) throw new RuntimeException();
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
// Read the frame
int ciphertextLength = 0;
try {
while(ciphertextLength < maxFrameLength) {
int read = in.read(ciphertext, ciphertextLength,
maxFrameLength - ciphertextLength);
if(read == -1) break; // We'll check the length later
ciphertextLength += read;
}
// Parse the frame header
int payload = HeaderEncoder.getPayloadLength(frame);
int padding = HeaderEncoder.getPaddingLength(frame);
int length = HEADER_LENGTH + payload + padding + MAC_LENGTH;
if(length > MAX_FRAME_LENGTH) throw new FormatException();
// Read the remainder of the frame
while(offset < length) {
int read = in.read(ciphertext, offset, length - offset);
if(read == -1) throw new EOFException();
offset += read;
}
// Decrypt and authenticate the entire frame
frameIvEncoder.updateIv(frameIv, frameNumber);
ivSpec = new IvParameterSpec(frameIv);
try {
frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec);
int decrypted = frameCipher.doFinal(ciphertext, 0, length,
frame);
if(decrypted != length - MAC_LENGTH)
throw new RuntimeException();
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
frameNumber++;
return true;
} catch(IOException e) {
frameKey.erase();
tagKey.erase();
throw e;
}
int plaintextLength = ciphertextLength - MAC_LENGTH;
if(plaintextLength < HEADER_LENGTH) throw new EOFException();
// Decrypt and authenticate the frame
FrameEncoder.encodeIv(iv, frameNumber);
FrameEncoder.encodeAad(aad, frameNumber, plaintextLength);
try {
frameCipher.init(DECRYPT_MODE, frameKey, iv, aad);
int decrypted = frameCipher.doFinal(ciphertext, 0, ciphertextLength,
frame, 0);
if(decrypted != plaintextLength) throw new RuntimeException();
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
// Decode and validate the header
lastFrame = FrameEncoder.isLastFrame(frame);
if(!lastFrame && ciphertextLength < maxFrameLength)
throw new EOFException();
int payloadLength = FrameEncoder.getPayloadLength(frame);
if(payloadLength > plaintextLength - HEADER_LENGTH)
throw new FormatException();
// If there's any padding it must be all zeroes
for(int i = HEADER_LENGTH + payloadLength; i < plaintextLength; i++)
if(frame[i] != 0) throw new FormatException();
frameNumber++;
return payloadLength;
}
}

View File

@@ -1,8 +1,10 @@
package net.sf.briar.transport;
import static javax.crypto.Cipher.ENCRYPT_MODE;
import static net.sf.briar.api.transport.TransportConstants.AAD_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.IOException;
@@ -10,62 +12,85 @@ import java.io.OutputStream;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.crypto.IvEncoder;
class OutgoingEncryptionLayer implements FrameWriter {
private final OutputStream out;
private final Cipher tagCipher, frameCipher;
private final IvEncoder frameIvEncoder;
private final Cipher tagCipher;
private final AuthenticatedCipher frameCipher;
private final ErasableKey tagKey, frameKey;
private final byte[] frameIv, ciphertext;
private final byte[] iv, aad, ciphertext;
private final int maxFrameLength;
private boolean writeTag;
private long capacity, frameNumber;
OutgoingEncryptionLayer(OutputStream out, long capacity, Cipher tagCipher,
Cipher frameCipher, IvEncoder frameIvEncoder, ErasableKey tagKey,
ErasableKey frameKey) {
AuthenticatedCipher frameCipher, ErasableKey tagKey,
ErasableKey frameKey, boolean writeTag, int maxFrameLength) {
this.out = out;
this.capacity = capacity;
this.tagCipher = tagCipher;
this.frameCipher = frameCipher;
this.frameIvEncoder = frameIvEncoder;
this.tagKey = tagKey;
this.frameKey = frameKey;
frameIv = frameIvEncoder.encodeIv(0L);
ciphertext = new byte[TAG_LENGTH + MAX_FRAME_LENGTH];
this.writeTag = writeTag;
this.maxFrameLength = maxFrameLength;
iv = new byte[IV_LENGTH];
aad = new byte[AAD_LENGTH];
ciphertext = new byte[maxFrameLength];
frameNumber = 0L;
}
public void writeFrame(byte[] frame) throws IOException {
int payload = HeaderEncoder.getPayloadLength(frame);
int padding = HeaderEncoder.getPaddingLength(frame);
int offset = 0, length = HEADER_LENGTH + payload + padding;
if(frameNumber == 0) {
public void writeFrame(byte[] frame, int payloadLength, int paddingLength,
boolean lastFrame) throws IOException {
int plaintextLength = HEADER_LENGTH + payloadLength + paddingLength;
int ciphertextLength = plaintextLength + MAC_LENGTH;
if(ciphertextLength > maxFrameLength)
throw new IllegalArgumentException();
if(!lastFrame && ciphertextLength < maxFrameLength)
throw new IllegalArgumentException();
// Write the tag if required
if(writeTag) {
TagEncoder.encodeTag(ciphertext, tagCipher, tagKey);
offset = TAG_LENGTH;
try {
out.write(ciphertext, 0, TAG_LENGTH);
} catch(IOException e) {
frameKey.erase();
tagKey.erase();
throw e;
}
capacity -= TAG_LENGTH;
writeTag = false;
}
frameIvEncoder.updateIv(frameIv, frameNumber);
IvParameterSpec ivSpec = new IvParameterSpec(frameIv);
// Encode the header
FrameEncoder.encodeHeader(frame, lastFrame, payloadLength);
// If there's any padding it must all be zeroes
for(int i = HEADER_LENGTH + payloadLength; i < plaintextLength; i++)
frame[i] = 0;
// Encrypt and authenticate the frame
FrameEncoder.encodeIv(iv, frameNumber);
FrameEncoder.encodeAad(aad, frameNumber, plaintextLength);
try {
frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
int encrypted = frameCipher.doFinal(frame, 0, length, ciphertext,
offset);
if(encrypted != length + MAC_LENGTH) throw new RuntimeException();
frameCipher.init(ENCRYPT_MODE, frameKey, iv, aad);
int encrypted = frameCipher.doFinal(frame, 0, plaintextLength,
ciphertext, 0);
if(encrypted != ciphertextLength) throw new RuntimeException();
} catch(GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
// Write the frame
try {
out.write(ciphertext, 0, offset + length + MAC_LENGTH);
out.write(ciphertext, 0, ciphertextLength);
} catch(IOException e) {
frameKey.erase();
tagKey.erase();
throw e;
}
capacity -= offset + length + MAC_LENGTH;
capacity -= ciphertextLength;
frameNumber++;
}
@@ -74,6 +99,6 @@ class OutgoingEncryptionLayer implements FrameWriter {
}
public long getRemainingCapacity() {
return capacity;
return writeTag ? capacity - TAG_LENGTH : capacity;
}
}

View File

@@ -1,5 +1,7 @@
package net.sf.briar.transport;
import static javax.crypto.Cipher.DECRYPT_MODE;
import static javax.crypto.Cipher.ENCRYPT_MODE;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.security.GeneralSecurityException;
@@ -15,7 +17,7 @@ class TagEncoder {
// Blank plaintext
for(int i = 0; i < TAG_LENGTH; i++) tag[i] = 0;
try {
tagCipher.init(Cipher.ENCRYPT_MODE, tagKey);
tagCipher.init(ENCRYPT_MODE, tagKey);
int encrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
if(encrypted != TAG_LENGTH) throw new IllegalArgumentException();
} catch(GeneralSecurityException e) {
@@ -27,7 +29,7 @@ class TagEncoder {
static boolean decodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey) {
if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
try {
tagCipher.init(Cipher.DECRYPT_MODE, tagKey);
tagCipher.init(DECRYPT_MODE, tagKey);
int decrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
if(decrypted != TAG_LENGTH) throw new IllegalArgumentException();
//The plaintext should be blank