Validate the decrypted IV before creating a reader/writer.

This commit is contained in:
akwizgran
2011-10-18 15:58:10 +01:00
parent 2f457162a5
commit d7a417f36d
18 changed files with 121 additions and 80 deletions

View File

@@ -5,8 +5,8 @@ import net.sf.briar.api.TransportId;
public interface BatchConnectionFactory {
void createIncomingConnection(ContactId c, BatchTransportReader r,
byte[] encryptedIv);
void createIncomingConnection(TransportId t, ContactId c,
BatchTransportReader r, byte[] encryptedIv);
void createOutgoingConnection(TransportId t, ContactId c,
BatchTransportWriter w);

View File

@@ -6,9 +6,17 @@ import net.sf.briar.api.TransportId;
public interface ConnectionReaderFactory {
ConnectionReader createConnectionReader(InputStream in, byte[] encryptedIv,
byte[] secret);
/**
* Creates a connection reader for a batch-mode connection or the
* initiator's side of a stream-mode connection.
*/
ConnectionReader createConnectionReader(InputStream in, TransportId t,
byte[] encryptedIv, byte[] secret);
ConnectionReader createConnectionReader(InputStream in, boolean initiator,
TransportId t, long connection, byte[] secret);
/**
* Creates a connection reader for the responder's side of a stream-mode
* connection.
*/
ConnectionReader createConnectionReader(InputStream in, TransportId t,
long connection, byte[] secret);
}

View File

@@ -6,9 +6,17 @@ import net.sf.briar.api.TransportId;
public interface ConnectionWriterFactory {
/**
* Creates a connection writer for a batch-mode connection or the
* initiator's side of a stream-mode connection.
*/
ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
boolean initiator, TransportId t, long connection, byte[] secret);
TransportId t, long connection, byte[] secret);
/**
* Creates a connection writer for the responder's side of a stream-mode
* connection.
*/
ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
TransportId t, byte[] encryptedIv, byte[] secret);
}

View File

@@ -28,26 +28,6 @@ implements ConnectionDecrypter {
private long frame = 0L;
private boolean betweenFrames = true;
ConnectionDecrypterImpl(InputStream in, byte[] encryptedIv, Cipher ivCipher,
Cipher frameCipher, SecretKey ivKey, SecretKey frameKey) {
super(in);
this.frameCipher = frameCipher;
this.frameKey = frameKey;
// Decrypt the IV
try {
ivCipher.init(Cipher.DECRYPT_MODE, ivKey);
iv = ivCipher.doFinal(encryptedIv);
} catch(BadPaddingException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(IllegalBlockSizeException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(InvalidKeyException badKey) {
throw new IllegalArgumentException(badKey);
}
if(iv.length != IV_LENGTH) throw new IllegalArgumentException();
buf = new byte[IV_LENGTH];
}
ConnectionDecrypterImpl(InputStream in, byte[] iv, Cipher frameCipher,
SecretKey frameKey) {
super(in);

View File

@@ -63,7 +63,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
r.dispose(false);
return;
}
batchConnFactory.createIncomingConnection(c, r, encryptedIv);
batchConnFactory.createIncomingConnection(t, c, r, encryptedIv);
}
private byte[] readIv(InputStream in) throws IOException {

View File

@@ -1,8 +1,11 @@
package net.sf.briar.transport;
import java.io.InputStream;
import java.security.InvalidKeyException;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
@@ -23,21 +26,35 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
}
public ConnectionReader createConnectionReader(InputStream in,
byte[] encryptedIv, byte[] secret) {
// Create the decrypter
TransportId t, byte[] encryptedIv, byte[] secret) {
// Decrypt the IV
Cipher ivCipher = crypto.getIvCipher();
Cipher frameCipher = crypto.getFrameCipher();
SecretKey ivKey = crypto.deriveIncomingIvKey(secret);
SecretKey frameKey = crypto.deriveIncomingFrameKey(secret);
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in,
encryptedIv, ivCipher, frameCipher, ivKey, frameKey);
// Create the reader
Mac mac = crypto.getMac();
SecretKey macKey = crypto.deriveIncomingMacKey(secret);
return new ConnectionReaderImpl(decrypter, mac, macKey);
byte[] iv;
try {
ivCipher.init(Cipher.DECRYPT_MODE, ivKey);
iv = ivCipher.doFinal(encryptedIv);
} catch(BadPaddingException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(IllegalBlockSizeException badCipher) {
throw new IllegalArgumentException(badCipher);
} catch(InvalidKeyException badKey) {
throw new IllegalArgumentException(badKey);
}
// Validate the IV
if(!IvEncoder.validateIv(iv, true, t))
throw new IllegalArgumentException();
// Copy the connection number
long connection = IvEncoder.getConnectionNumber(iv);
return createConnectionReader(in, true, t, connection, secret);
}
public ConnectionReader createConnectionReader(InputStream in,
TransportId t, long connection, byte[] secret) {
return createConnectionReader(in, false, t, connection, secret);
}
private ConnectionReader createConnectionReader(InputStream in,
boolean initiator, TransportId t, long connection, byte[] secret) {
byte[] iv = IvEncoder.encodeIv(initiator, t, connection);
// Create the decrypter

View File

@@ -26,20 +26,9 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
}
public ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, boolean initiator, TransportId t, long connection,
byte[] secret) {
// Create the encrypter
Cipher ivCipher = crypto.getIvCipher();
Cipher frameCipher = crypto.getFrameCipher();
SecretKey ivKey = crypto.deriveOutgoingIvKey(secret);
SecretKey frameKey = crypto.deriveOutgoingFrameKey(secret);
byte[] iv = IvEncoder.encodeIv(initiator, t, connection);
ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out,
capacity, iv, ivCipher, frameCipher, ivKey, frameKey);
// Create the writer
Mac mac = crypto.getMac();
SecretKey macKey = crypto.deriveOutgoingMacKey(secret);
return new ConnectionWriterImpl(encrypter, mac, macKey);
long capacity, TransportId t, long connection, byte[] secret) {
return createConnectionWriter(out, capacity, true, t, connection,
secret);
}
public ConnectionWriter createConnectionWriter(OutputStream out,
@@ -58,15 +47,29 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
} catch(InvalidKeyException badKey) {
throw new RuntimeException(badKey);
}
// Check that the initiator flag is raised
if(!IvEncoder.getInitiatorFlag(iv))
throw new IllegalArgumentException();
// Check that the transport ID matches the expected ID
if(!t.equals(new TransportId(IvEncoder.getTransportId(iv))))
// Validate the IV
if(!IvEncoder.validateIv(iv, true, t))
throw new IllegalArgumentException();
// Copy the connection number
long connection = IvEncoder.getConnectionNumber(iv);
return createConnectionWriter(out, capacity, false, t, connection,
secret);
}
private ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, boolean initiator, TransportId t, long connection,
byte[] secret) {
// Create the encrypter
Cipher ivCipher = crypto.getIvCipher();
Cipher frameCipher = crypto.getFrameCipher();
SecretKey ivKey = crypto.deriveOutgoingIvKey(secret);
SecretKey frameKey = crypto.deriveOutgoingFrameKey(secret);
byte[] iv = IvEncoder.encodeIv(initiator, t, connection);
ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out,
capacity, iv, ivCipher, frameCipher, ivKey, frameKey);
// Create the writer
Mac mac = crypto.getMac();
SecretKey macKey = crypto.deriveOutgoingMacKey(secret);
return new ConnectionWriterImpl(encrypter, mac, macKey);
}
}

View File

@@ -24,6 +24,20 @@ class IvEncoder {
ByteUtils.writeUint32(frame, iv, 10);
}
static boolean validateIv(byte[] iv, boolean initiator, TransportId t) {
if(iv.length != IV_LENGTH) return false;
// Check that the reserved bits are all zero
for(int i = 0; i < 2; i++) if(iv[i] != 0) return false;
if(iv[3] != 0 && iv[3] != 1) return false;
for(int i = 10; i < iv.length; i++) if(iv[i] != 0) return false;
// Check that the initiator flag matches
if(initiator != getInitiatorFlag(iv)) return false;
// Check that the transport ID matches
if(t.getInt() != getTransportId(iv)) return false;
// The IV is valid
return true;
}
static boolean getInitiatorFlag(byte[] iv) {
if(iv.length != IV_LENGTH) throw new IllegalArgumentException();
return (iv[3] & 1) == 1;

View File

@@ -33,10 +33,11 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
this.protoWriterFactory = protoWriterFactory;
}
public void createIncomingConnection(ContactId c,
public void createIncomingConnection(TransportId t, ContactId c,
BatchTransportReader r, byte[] encryptedIv) {
final IncomingBatchConnection conn = new IncomingBatchConnection(
connReaderFactory, db, protoReaderFactory, c, r, encryptedIv);
connReaderFactory, db, protoReaderFactory, t, c, r,
encryptedIv);
Runnable read = new Runnable() {
public void run() {
conn.read();

View File

@@ -6,6 +6,7 @@ import java.util.logging.Logger;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.TransportId;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.Ack;
@@ -26,17 +27,19 @@ class IncomingBatchConnection {
private final ConnectionReaderFactory connFactory;
private final DatabaseComponent db;
private final ProtocolReaderFactory protoFactory;
private final TransportId transportId;
private final ContactId contactId;
private final BatchTransportReader reader;
private final byte[] encryptedIv;
IncomingBatchConnection(ConnectionReaderFactory connFactory,
DatabaseComponent db, ProtocolReaderFactory protoFactory,
ContactId contactId, BatchTransportReader reader,
byte[] encryptedIv) {
TransportId transportId, ContactId contactId,
BatchTransportReader reader, byte[] encryptedIv) {
this.connFactory = connFactory;
this.db = db;
this.protoFactory = protoFactory;
this.transportId = transportId;
this.contactId = contactId;
this.reader = reader;
this.encryptedIv = encryptedIv;
@@ -46,7 +49,7 @@ class IncomingBatchConnection {
try {
byte[] secret = db.getSharedSecret(contactId);
ConnectionReader conn = connFactory.createConnectionReader(
reader.getInputStream(), encryptedIv, secret);
reader.getInputStream(), transportId, encryptedIv, secret);
ProtocolReader proto = protoFactory.createProtocolReader(
conn.getInputStream());
// Read packets until EOF

View File

@@ -49,8 +49,8 @@ class OutgoingBatchConnection {
byte[] secret = db.getSharedSecret(contactId);
long connection = db.getConnectionNumber(contactId, transportId);
ConnectionWriter conn = connFactory.createConnectionWriter(
writer.getOutputStream(), writer.getCapacity(), true,
transportId, connection, secret);
writer.getOutputStream(), writer.getCapacity(), transportId,
connection, secret);
OutputStream out = conn.getOutputStream();
// There should be enough space for a packet
long capacity = conn.getRemainingCapacity();

View File

@@ -34,7 +34,7 @@ public class IncomingStreamConnection extends StreamConnection {
IOException {
byte[] secret = db.getSharedSecret(contactId);
return connReaderFactory.createConnectionReader(
connection.getInputStream(), encryptedIv, secret);
connection.getInputStream(), transportId, encryptedIv, secret);
}
@Override

View File

@@ -36,7 +36,7 @@ public class OutgoingStreamConnection extends StreamConnection {
}
byte[] secret = db.getSharedSecret(contactId);
return connReaderFactory.createConnectionReader(
connection.getInputStream(), false, transportId, connectionNum,
connection.getInputStream(), transportId, connectionNum,
secret);
}
@@ -49,7 +49,7 @@ public class OutgoingStreamConnection extends StreamConnection {
}
byte[] secret = db.getSharedSecret(contactId);
return connWriterFactory.createConnectionWriter(
connection.getOutputStream(), Long.MAX_VALUE, true, transportId,
connection.getOutputStream(), Long.MAX_VALUE, transportId,
connectionNum, secret);
}
}

View File

@@ -131,7 +131,7 @@ public class ProtocolIntegrationTest extends TestCase {
ByteArrayOutputStream out = new ByteArrayOutputStream();
// Use Alice's secret for writing
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
Long.MAX_VALUE, true, transportId, connection, aliceSecret);
Long.MAX_VALUE, transportId, connection, aliceSecret);
OutputStream out1 = w.getOutputStream();
AckWriter a = protocolWriterFactory.createAckWriter(out1);
@@ -175,17 +175,17 @@ public class ProtocolIntegrationTest extends TestCase {
private void read(byte[] connection) throws Exception {
InputStream in = new ByteArrayInputStream(connection);
byte[] iv = new byte[16];
byte[] encryptedIv = new byte[16];
int offset = 0;
while(offset < 16) {
int read = in.read(iv, offset, iv.length - offset);
int read = in.read(encryptedIv, offset, 16 - offset);
if(read == -1) break;
offset += read;
}
assertEquals(16, offset);
// Use Bob's secret for reading
ConnectionReader r = connectionReaderFactory.createConnectionReader(in,
iv, bobSecret);
transportId, encryptedIv, bobSecret);
in = r.getInputStream();
ProtocolReader protocolReader =
protocolReaderFactory.createProtocolReader(in);

View File

@@ -84,8 +84,9 @@ public class ConnectionDecrypterImplTest extends TestCase {
out.write(ciphertextMac);
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
// Use a ConnectionDecrypter to decrypt the ciphertext
ConnectionDecrypter d = new ConnectionDecrypterImpl(in, encryptedIv,
ivCipher, frameCipher, ivKey, frameKey);
ConnectionDecrypter d = new ConnectionDecrypterImpl(in,
IvEncoder.encodeIv(initiator, transportId, connection),
frameCipher, frameKey);
// First frame
byte[] decrypted = new byte[ciphertext.length];
TestUtils.readFully(d.getInputStream(), decrypted);

View File

@@ -40,7 +40,7 @@ public class ConnectionWriterTest extends TestCase {
ByteArrayOutputStream out =
new ByteArrayOutputStream(MIN_CONNECTION_LENGTH);
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
MIN_CONNECTION_LENGTH, true, transportId, connection, secret);
MIN_CONNECTION_LENGTH, transportId, connection, secret);
// Check that the connection writer thinks there's room for a packet
long capacity = w.getRemainingCapacity();
assertTrue(capacity >= MAX_PACKET_LENGTH);

View File

@@ -84,12 +84,17 @@ public class FrameReadWriteTest extends TestCase {
out1.flush();
// Read the IV back
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
byte[] recoveredIv = new byte[IV_LENGTH];
assertEquals(IV_LENGTH, in.read(recoveredIv));
assertArrayEquals(encryptedIv, recoveredIv);
byte[] recoveredEncryptedIv = new byte[IV_LENGTH];
assertEquals(IV_LENGTH, in.read(recoveredEncryptedIv));
assertArrayEquals(encryptedIv, recoveredEncryptedIv);
// Decrypt the IV
ivCipher.init(Cipher.DECRYPT_MODE, ivKey);
byte[] recoveredIv = ivCipher.doFinal(recoveredEncryptedIv);
iv = IvEncoder.encodeIv(initiator, transportId, connection);
assertArrayEquals(iv, recoveredIv);
// Read the frames back
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in,
recoveredIv, ivCipher, frameCipher, ivKey, frameKey);
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, iv,
frameCipher, frameKey);
ConnectionReader reader = new ConnectionReaderImpl(decrypter, mac,
macKey);
InputStream in1 = reader.getInputStream();

View File

@@ -145,7 +145,8 @@ public class BatchConnectionReadWriteTest extends TestCase {
bob.getInstance(ProtocolReaderFactory.class);
BatchTransportReader reader = new TestBatchTransportReader(in);
IncomingBatchConnection batchIn = new IncomingBatchConnection(
connFactory, db, protoFactory, contactId, reader, encryptedIv);
connFactory, db, protoFactory, transportId, contactId, reader,
encryptedIv);
// No messages should have been added yet
assertFalse(listener.messagesAdded);
// Read whatever needs to be read