Forward secrecy.

Each connection's keys are derived from a secret that is erased after
deriving the keys and the secret for the next connection.
This commit is contained in:
akwizgran
2011-11-16 15:35:16 +00:00
parent d02a68edfc
commit f6ae4734ce
45 changed files with 506 additions and 430 deletions

View File

@@ -15,7 +15,7 @@ public interface CryptoComponent {
ErasableKey deriveMacKey(byte[] secret, boolean initiator); ErasableKey deriveMacKey(byte[] secret, boolean initiator);
byte[] deriveNextSecret(byte[] secret, long connection); byte[] deriveNextSecret(byte[] secret, int index, long connection);
KeyPair generateKeyPair(); KeyPair generateKeyPair();

View File

@@ -57,8 +57,7 @@ public interface DatabaseComponent {
* Adds a new contact to the database with the given secrets and returns an * Adds a new contact to the database with the given secrets and returns an
* ID for the contact. * ID for the contact.
*/ */
ContactId addContact(byte[] incomingSecret, byte[] outgoingSecret) ContactId addContact(byte[] inSecret, byte[] outSecret) throws DbException;
throws DbException;
/** Adds a locally generated group message to the database. */ /** Adds a locally generated group message to the database. */
void addLocalGroupMessage(Message m) throws DbException; void addLocalGroupMessage(Message m) throws DbException;
@@ -160,9 +159,6 @@ public interface DatabaseComponent {
Map<ContactId, TransportProperties> getRemoteProperties(TransportId t) Map<ContactId, TransportProperties> getRemoteProperties(TransportId t)
throws DbException; throws DbException;
/** Returns the secret shared with the given contact. */
byte[] getSharedSecret(ContactId c, boolean incoming) throws DbException;
/** Returns the set of groups to which the user subscribes. */ /** Returns the set of groups to which the user subscribes. */
Collection<Group> getSubscriptions() throws DbException; Collection<Group> getSubscriptions() throws DbException;

View File

@@ -5,9 +5,9 @@ import net.sf.briar.api.protocol.TransportIndex;
public interface BatchConnectionFactory { public interface BatchConnectionFactory {
void createIncomingConnection(TransportIndex i, ContactId c, void createIncomingConnection(ConnectionContext ctx,
BatchTransportReader r, byte[] encryptedIv); BatchTransportReader r, byte[] encryptedIv);
void createOutgoingConnection(TransportIndex i, ContactId c, void createOutgoingConnection(ContactId c, TransportIndex i,
BatchTransportWriter w); BatchTransportWriter w);
} }

View File

@@ -10,4 +10,6 @@ public interface ConnectionContext {
TransportIndex getTransportIndex(); TransportIndex getTransportIndex();
long getConnectionNumber(); long getConnectionNumber();
byte[] getSecret();
} }

View File

@@ -6,5 +6,8 @@ import net.sf.briar.api.protocol.TransportIndex;
public interface ConnectionContextFactory { public interface ConnectionContextFactory {
ConnectionContext createConnectionContext(ContactId c, TransportIndex i, ConnectionContext createConnectionContext(ContactId c, TransportIndex i,
long connection); long connection, byte[] secret);
ConnectionContext createNextConnectionContext(ContactId c, TransportIndex i,
long connection, byte[] previousSecret);
} }

View File

@@ -8,10 +8,10 @@ public interface ConnectionDispatcher {
void dispatchReader(TransportId t, BatchTransportReader r); void dispatchReader(TransportId t, BatchTransportReader r);
void dispatchWriter(TransportIndex i, ContactId c, BatchTransportWriter w); void dispatchWriter(ContactId c, TransportIndex i, BatchTransportWriter w);
void dispatchIncomingConnection(TransportId t, StreamTransportConnection s); void dispatchIncomingConnection(TransportId t, StreamTransportConnection s);
void dispatchOutgoingConnection(TransportIndex i, ContactId c, void dispatchOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s); StreamTransportConnection s);
} }

View File

@@ -2,22 +2,19 @@ package net.sf.briar.api.transport;
import java.io.InputStream; import java.io.InputStream;
import net.sf.briar.api.protocol.TransportIndex;
public interface ConnectionReaderFactory { public interface ConnectionReaderFactory {
/** /**
* Creates a connection reader for a batch-mode connection or the * Creates a connection reader for a batch-mode connection or the
* initiator's side of a stream-mode connection. The secret is erased before * initiator's side of a stream-mode connection.
* returning.
*/ */
ConnectionReader createConnectionReader(InputStream in, TransportIndex i, ConnectionReader createConnectionReader(InputStream in,
byte[] encryptedIv, byte[] secret); ConnectionContext ctx, byte[] encryptedIv);
/** /**
* Creates a connection reader for the responder's side of a stream-mode * Creates a connection reader for the responder's side of a stream-mode
* connection. The secret is erased before returning. * connection.
*/ */
ConnectionReader createConnectionReader(InputStream in, TransportIndex i, ConnectionReader createConnectionReader(InputStream in,
long connection, byte[] secret); ConnectionContext ctx);
} }

View File

@@ -1,6 +1,6 @@
package net.sf.briar.api.transport; package net.sf.briar.api.transport;
import java.util.Collection; import java.util.Map;
public interface ConnectionWindow { public interface ConnectionWindow {
@@ -8,5 +8,5 @@ public interface ConnectionWindow {
void setSeen(long connection); void setSeen(long connection);
Collection<Long> getUnseen(); Map<Long, byte[]> getUnseen();
} }

View File

@@ -1,10 +1,13 @@
package net.sf.briar.api.transport; package net.sf.briar.api.transport;
import java.util.Collection; import java.util.Map;
import net.sf.briar.api.protocol.TransportIndex;
public interface ConnectionWindowFactory { public interface ConnectionWindowFactory {
ConnectionWindow createConnectionWindow(); ConnectionWindow createConnectionWindow(TransportIndex i, byte[] secret);
ConnectionWindow createConnectionWindow(Collection<Long> unseen); ConnectionWindow createConnectionWindow(TransportIndex i,
Map<Long, byte[]> unseen);
} }

View File

@@ -2,22 +2,19 @@ package net.sf.briar.api.transport;
import java.io.OutputStream; import java.io.OutputStream;
import net.sf.briar.api.protocol.TransportIndex;
public interface ConnectionWriterFactory { public interface ConnectionWriterFactory {
/** /**
* Creates a connection writer for a batch-mode connection or the * Creates a connection writer for a batch-mode connection or the
* initiator's side of a stream-mode connection. The secret is erased before * initiator's side of a stream-mode connection.
* returning.
*/ */
ConnectionWriter createConnectionWriter(OutputStream out, long capacity, ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
TransportIndex i, long connection, byte[] secret); ConnectionContext ctx);
/** /**
* Creates a connection writer for the responder's side of a stream-mode * Creates a connection writer for the responder's side of a stream-mode
* connection. The secret is erased before returning. * connection.
*/ */
ConnectionWriter createConnectionWriter(OutputStream out, long capacity, ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
TransportIndex i, byte[] encryptedIv, byte[] secret); ConnectionContext ctx, byte[] encryptedIv);
} }

View File

@@ -5,9 +5,9 @@ import net.sf.briar.api.protocol.TransportIndex;
public interface StreamConnectionFactory { public interface StreamConnectionFactory {
void createIncomingConnection(TransportIndex i, ContactId c, void createIncomingConnection(ConnectionContext ctx,
StreamTransportConnection s, byte[] encryptedIv); StreamTransportConnection s, byte[] encryptedIv);
void createOutgoingConnection(TransportIndex i, ContactId c, void createOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s); StreamTransportConnection s);
} }

View File

@@ -88,14 +88,14 @@ class CryptoComponentImpl implements CryptoComponent {
if(secret.length != SECRET_KEY_BYTES) if(secret.length != SECRET_KEY_BYTES)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
ErasableKey key = new ErasableKeyImpl(secret, SECRET_KEY_ALGO); ErasableKey key = new ErasableKeyImpl(secret, SECRET_KEY_ALGO);
// The context must leave four bytes free for the length // The context must leave two bytes free for the length
if(context.length + 4 > SECRET_KEY_BYTES) if(context.length + 2 > SECRET_KEY_BYTES)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
byte[] input = new byte[SECRET_KEY_BYTES]; byte[] input = new byte[SECRET_KEY_BYTES];
// The initial bytes of the input are the context // The input starts with the length of the context as a big-endian int16
System.arraycopy(context, 0, input, 0, context.length); ByteUtils.writeUint16(context.length, input, 0);
// The final bytes of the input are the length as a big-endian uint32 // The remaining bytes of the input are the context
ByteUtils.writeUint32(context.length, input, input.length - 4); System.arraycopy(context, 0, input, 2, context.length);
// Initialise the counter to zero // Initialise the counter to zero
byte[] zero = new byte[KEY_DERIVATION_IV_BYTES]; byte[] zero = new byte[KEY_DERIVATION_IV_BYTES];
IvParameterSpec iv = new IvParameterSpec(zero); IvParameterSpec iv = new IvParameterSpec(zero);
@@ -110,12 +110,15 @@ class CryptoComponentImpl implements CryptoComponent {
} }
} }
public byte[] deriveNextSecret(byte[] secret, long connection) { public byte[] deriveNextSecret(byte[] secret, int index, long connection) {
if(index < 0 || index > ByteUtils.MAX_16_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(connection < 0 || connection > ByteUtils.MAX_32_BIT_UNSIGNED) if(connection < 0 || connection > ByteUtils.MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
byte[] context = new byte[NEXT.length + 4]; byte[] context = new byte[NEXT.length + 6];
System.arraycopy(NEXT, 0, context, 0, NEXT.length); System.arraycopy(NEXT, 0, context, 0, NEXT.length);
ByteUtils.writeUint32(connection, context, NEXT.length); ByteUtils.writeUint16(index, context, NEXT.length);
ByteUtils.writeUint32(connection, context, NEXT.length + 2);
return counterModeKdf(secret, context); return counterModeKdf(secret, context);
} }

View File

@@ -3,6 +3,7 @@ package net.sf.briar.crypto;
import java.util.Arrays; import java.util.Arrays;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.util.ByteUtils;
class ErasableKeyImpl implements ErasableKey { class ErasableKeyImpl implements ErasableKey {
@@ -34,7 +35,7 @@ class ErasableKeyImpl implements ErasableKey {
public void erase() { public void erase() {
if(erased) throw new IllegalStateException(); if(erased) throw new IllegalStateException();
for(int i = 0; i < key.length; i++) key[i] = 0; ByteUtils.erase(key);
erased = true; erased = true;
} }

View File

@@ -84,10 +84,14 @@ interface Database<T> {
* Adds a new contact to the database with the given secrets and returns an * Adds a new contact to the database with the given secrets and returns an
* ID for the contact. * ID for the contact.
* <p> * <p>
* Any secrets generated by the method are stored in the given collection
* and should be erased by the caller once the transaction has been
* committed or aborted.
* <p>
* Locking: contact write. * Locking: contact write.
*/ */
ContactId addContact(T txn, byte[] incomingSecret, byte[] outgoingSecret) ContactId addContact(T txn, byte[] inSecret, byte[] outSecret,
throws DbException; Collection<byte[]> erase) throws DbException;
/** /**
* Returns false if the given message is already in the database. Otherwise * Returns false if the given message is already in the database. Otherwise
@@ -187,10 +191,14 @@ interface Database<T> {
* Returns an outgoing connection context for the given contact and * Returns an outgoing connection context for the given contact and
* transport. * transport.
* <p> * <p>
* Any secrets generated by the method are stored in the given collection
* and should be erased by the caller once the transaction has been
* committed or aborted.
* <p>
* Locking: contact read, window write. * Locking: contact read, window write.
*/ */
ConnectionContext getConnectionContext(T txn, ContactId c, TransportIndex i) ConnectionContext getConnectionContext(T txn, ContactId c, TransportIndex i,
throws DbException; Collection<byte[]> erase) throws DbException;
/** /**
* Returns the connection reordering window for the given contact and * Returns the connection reordering window for the given contact and
@@ -373,14 +381,6 @@ interface Database<T> {
Collection<MessageId> getSendableMessages(T txn, ContactId c, int capacity) Collection<MessageId> getSendableMessages(T txn, ContactId c, int capacity)
throws DbException; throws DbException;
/**
* Returns the secret shared with the given contact.
* <p>
* Locking: contact read.
*/
byte[] getSharedSecret(T txn, ContactId c, boolean incoming)
throws DbException;
/** /**
* Returns true if the given message has been starred. * Returns true if the given message has been starred.
* <p> * <p>

View File

@@ -62,6 +62,7 @@ import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -136,23 +137,24 @@ DatabaseCleaner.Callback {
} }
} }
public ContactId addContact(byte[] incomingSecret, byte[] outgoingSecret) public ContactId addContact(byte[] inSecret, byte[] outSecret)
throws DbException { throws DbException {
if(LOG.isLoggable(Level.FINE)) LOG.fine("Adding contact");
ContactId c; ContactId c;
Collection<byte[]> erase = new ArrayList<byte[]>();
contactLock.writeLock().lock(); contactLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
c = db.addContact(txn, incomingSecret, outgoingSecret); c = db.addContact(txn, inSecret, outSecret, erase);
db.commitTransaction(txn); db.commitTransaction(txn);
if(LOG.isLoggable(Level.FINE)) LOG.fine("Added contact " + c);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
throw e; throw e;
} }
} finally { } finally {
contactLock.writeLock().unlock(); contactLock.writeLock().unlock();
// Erase the secrets after committing or aborting the transaction
for(byte[] b : erase) ByteUtils.erase(b);
} }
// Call the listeners outside the lock // Call the listeners outside the lock
callListeners(new ContactAddedEvent(c)); callListeners(new ContactAddedEvent(c));
@@ -703,6 +705,7 @@ DatabaseCleaner.Callback {
public ConnectionContext getConnectionContext(ContactId c, TransportIndex i) public ConnectionContext getConnectionContext(ContactId c, TransportIndex i)
throws DbException { throws DbException {
Collection<byte[]> erase = new ArrayList<byte[]>();
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException(); if(!containsContact(c)) throw new NoSuchContactException();
@@ -710,7 +713,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
ConnectionContext ctx = db.getConnectionContext(txn, c, i); ConnectionContext ctx =
db.getConnectionContext(txn, c, i, erase);
db.commitTransaction(txn); db.commitTransaction(txn);
return ctx; return ctx;
} catch(DbException e) { } catch(DbException e) {
@@ -722,6 +726,8 @@ DatabaseCleaner.Callback {
} }
} finally { } finally {
contactLock.readLock().unlock(); contactLock.readLock().unlock();
// Erase the secrets after committing or aborting the transaction
for(byte[] b : erase) ByteUtils.erase(b);
} }
} }
@@ -907,25 +913,6 @@ DatabaseCleaner.Callback {
} }
} }
public byte[] getSharedSecret(ContactId c, boolean incoming)
throws DbException {
contactLock.readLock().lock();
try {
if(!containsContact(c)) throw new NoSuchContactException();
T txn = db.startTransaction();
try {
byte[] secret = db.getSharedSecret(txn, c, incoming);
db.commitTransaction(txn);
return secret;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
contactLock.readLock().unlock();
}
}
public Collection<Group> getSubscriptions() throws DbException { public Collection<Group> getSubscriptions() throws DbException {
subscriptionLock.readLock().lock(); subscriptionLock.readLock().lock();
try { try {

View File

@@ -29,6 +29,11 @@ class H2Database extends JdbcDatabase {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(H2Database.class.getName()); Logger.getLogger(H2Database.class.getName());
private static final String HASH_TYPE = "BINARY(32)";
private static final String BINARY_TYPE = "BINARY";
private static final String COUNTER_TYPE = "INT NOT NULL AUTO_INCREMENT";
private static final String SECRET_TYPE = "BINARY(32)";
private final File home; private final File home;
private final Password password; private final Password password;
private final String url; private final String url;
@@ -42,7 +47,7 @@ class H2Database extends JdbcDatabase {
ConnectionWindowFactory connectionWindowFactory, ConnectionWindowFactory connectionWindowFactory,
GroupFactory groupFactory) { GroupFactory groupFactory) {
super(connectionContextFactory, connectionWindowFactory, groupFactory, super(connectionContextFactory, connectionWindowFactory, groupFactory,
"BINARY(32)", "BINARY", "INT NOT NULL AUTO_INCREMENT"); HASH_TYPE, BINARY_TYPE, COUNTER_TYPE, SECRET_TYPE);
home = new File(dir, "db"); home = new File(dir, "db");
this.password = password; this.password = password;
url = "jdbc:h2:split:" + home.getPath() url = "jdbc:h2:split:" + home.getPath()

View File

@@ -58,8 +58,6 @@ abstract class JdbcDatabase implements Database<Connection> {
private static final String CREATE_CONTACTS = private static final String CREATE_CONTACTS =
"CREATE TABLE contacts" "CREATE TABLE contacts"
+ " (contactId COUNTER," + " (contactId COUNTER,"
+ " incomingSecret BINARY NOT NULL,"
+ " outgoingSecret BINARY NOT NULL,"
+ " PRIMARY KEY (contactId))"; + " PRIMARY KEY (contactId))";
private static final String CREATE_MESSAGES = private static final String CREATE_MESSAGES =
@@ -221,7 +219,8 @@ abstract class JdbcDatabase implements Database<Connection> {
"CREATE TABLE connections" "CREATE TABLE connections"
+ " (contactId INT NOT NULL," + " (contactId INT NOT NULL,"
+ " index INT NOT NULL," + " index INT NOT NULL,"
+ " outgoing BIGINT NOT NULL," + " connection BIGINT NOT NULL,"
+ " secret SECRET NOT NULL,"
+ " PRIMARY KEY (contactId, index)," + " PRIMARY KEY (contactId, index),"
+ " FOREIGN KEY (contactId) REFERENCES contacts (contactId)" + " FOREIGN KEY (contactId) REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE)"; + " ON DELETE CASCADE)";
@@ -231,6 +230,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " (contactId INT NOT NULL," + " (contactId INT NOT NULL,"
+ " index INT NOT NULL," + " index INT NOT NULL,"
+ " unseen BIGINT NOT NULL," + " unseen BIGINT NOT NULL,"
+ " secret SECRET NOT NULL,"
+ " PRIMARY KEY (contactId, index, unseen)," + " PRIMARY KEY (contactId, index, unseen),"
+ " FOREIGN KEY (contactId) REFERENCES contacts (contactId)" + " FOREIGN KEY (contactId) REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE)"; + " ON DELETE CASCADE)";
@@ -271,7 +271,7 @@ abstract class JdbcDatabase implements Database<Connection> {
private final ConnectionWindowFactory connectionWindowFactory; private final ConnectionWindowFactory connectionWindowFactory;
private final GroupFactory groupFactory; private final GroupFactory groupFactory;
// Different database libraries use different names for certain types // Different database libraries use different names for certain types
private final String hashType, binaryType, counterType; private final String hashType, binaryType, counterType, secretType;
private final LinkedList<Connection> connections = private final LinkedList<Connection> connections =
new LinkedList<Connection>(); // Locking: self new LinkedList<Connection>(); // Locking: self
@@ -284,13 +284,14 @@ abstract class JdbcDatabase implements Database<Connection> {
JdbcDatabase(ConnectionContextFactory connectionContextFactory, JdbcDatabase(ConnectionContextFactory connectionContextFactory,
ConnectionWindowFactory connectionWindowFactory, ConnectionWindowFactory connectionWindowFactory,
GroupFactory groupFactory, String hashType, String binaryType, GroupFactory groupFactory, String hashType, String binaryType,
String counterType) { String counterType, String secretType) {
this.connectionContextFactory = connectionContextFactory; this.connectionContextFactory = connectionContextFactory;
this.connectionWindowFactory = connectionWindowFactory; this.connectionWindowFactory = connectionWindowFactory;
this.groupFactory = groupFactory; this.groupFactory = groupFactory;
this.hashType = hashType; this.hashType = hashType;
this.binaryType = binaryType; this.binaryType = binaryType;
this.counterType = counterType; this.counterType = counterType;
this.secretType = secretType;
} }
protected void open(boolean resume, File dir, String driverClass) protected void open(boolean resume, File dir, String driverClass)
@@ -371,6 +372,7 @@ abstract class JdbcDatabase implements Database<Connection> {
s = s.replaceAll("HASH", hashType); s = s.replaceAll("HASH", hashType);
s = s.replaceAll("BINARY", binaryType); s = s.replaceAll("BINARY", binaryType);
s = s.replaceAll("COUNTER", counterType); s = s.replaceAll("COUNTER", counterType);
s = s.replaceAll("SECRET", secretType);
return s; return s;
} }
@@ -515,17 +517,14 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public ContactId addContact(Connection txn, byte[] incomingSecret, public ContactId addContact(Connection txn, byte[] inSecret,
byte[] outgoingSecret) throws DbException { byte[] outSecret, Collection<byte[]> erase) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
// Create a new contact row // Create a new contact row
String sql = "INSERT INTO contacts (incomingSecret, outgoingSecret)" String sql = "INSERT INTO contacts DEFAULT VALUES";
+ " VALUES (?, ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setBytes(1, incomingSecret);
ps.setBytes(2, outgoingSecret);
int affected = ps.executeUpdate(); int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException(); if(affected != 1) throw new DbStateException();
ps.close(); ps.close();
@@ -558,13 +557,20 @@ abstract class JdbcDatabase implements Database<Connection> {
affected = ps.executeUpdate(); affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException(); if(affected != 1) throw new DbStateException();
ps.close(); ps.close();
// Initialise the connection numbers for all transports // Initialise the outgoing connection contexts for all transports
sql = "INSERT INTO connections (contactId, index, outgoing)" sql = "INSERT INTO connections"
+ " VALUES (?, ?, ZERO())"; + " (contactId, index, connection, secret)"
+ " VALUES (?, ?, ZERO(), ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) { for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) {
ps.setInt(2, i); ps.setInt(2, i);
ConnectionContext ctx =
connectionContextFactory.createNextConnectionContext(c,
new TransportIndex(i), 0L, outSecret);
byte[] secret = ctx.getSecret();
erase.add(secret);
ps.setBytes(3, secret);
ps.addBatch(); ps.addBatch();
} }
int[] batchAffected = ps.executeBatch(); int[] batchAffected = ps.executeBatch();
@@ -574,18 +580,23 @@ abstract class JdbcDatabase implements Database<Connection> {
if(batchAffected[i] != 1) throw new DbStateException(); if(batchAffected[i] != 1) throw new DbStateException();
} }
ps.close(); ps.close();
// Initialise the connection windows for all transports // Initialise the incoming connection windows for all transports
sql = "INSERT INTO connectionWindows (contactId, index, unseen)" sql = "INSERT INTO connectionWindows"
+ " VALUES (?, ?, ?)"; + " (contactId, index, unseen, secret)"
+ " VALUES (?, ?, ?, ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
int batchSize = 0; int batchSize = 0;
for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) { for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) {
ps.setInt(2, i); ps.setInt(2, i);
ConnectionWindow w = ConnectionWindow w =
connectionWindowFactory.createConnectionWindow(); connectionWindowFactory.createConnectionWindow(
for(long l : w.getUnseen()) { new TransportIndex(i), inSecret);
ps.setLong(3, l); for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
ps.setLong(3, e.getKey());
byte[] secret = e.getValue();
erase.add(secret);
ps.setBytes(4, secret);
ps.addBatch(); ps.addBatch();
batchSize++; batchSize++;
} }
@@ -945,31 +956,43 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
public ConnectionContext getConnectionContext(Connection txn, ContactId c, public ConnectionContext getConnectionContext(Connection txn, ContactId c,
TransportIndex i) throws DbException { TransportIndex i, Collection<byte[]> erase) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
String sql = "UPDATE connections SET outgoing = outgoing + 1" // Retrieve the current context
+ " WHERE contactId = ? AND index = ?"; String sql = "SELECT connection, secret FROM connections"
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(2, i.getInt());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
sql = "SELECT outgoing FROM connections"
+ " WHERE contactId = ? AND index = ?"; + " WHERE contactId = ? AND index = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
ps.setInt(2, i.getInt()); ps.setInt(2, i.getInt());
rs = ps.executeQuery(); rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException(); if(!rs.next()) throw new DbStateException();
long outgoing = rs.getLong(1); long connection = rs.getLong(1);
byte[] secret = rs.getBytes(2);
if(rs.next()) throw new DbStateException(); if(rs.next()) throw new DbStateException();
rs.close(); rs.close();
ps.close(); ps.close();
return connectionContextFactory.createConnectionContext(c, i, ConnectionContext ctx =
outgoing); connectionContextFactory.createConnectionContext(c, i,
connection, secret);
// Calculate and store the next context
ConnectionContext next =
connectionContextFactory.createNextConnectionContext(c, i,
connection + 1, secret);
byte[] nextSecret = next.getSecret();
erase.add(nextSecret);
sql = "UPDATE connections"
+ " SET connection = connection + 1, secret = ?"
+ " WHERE contactId = ? AND index = ?";
ps = txn.prepareStatement(sql);
ps.setBytes(1, nextSecret);
ps.setInt(2, c.getInt());
ps.setInt(3, i.getInt());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
return ctx;
} catch(SQLException e) { } catch(SQLException e) {
tryToClose(rs); tryToClose(rs);
tryToClose(ps); tryToClose(ps);
@@ -982,17 +1005,17 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
String sql = "SELECT unseen FROM connectionWindows" String sql = "SELECT unseen, secret FROM connectionWindows"
+ " WHERE contactId = ? AND index = ?"; + " WHERE contactId = ? AND index = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
ps.setInt(2, i.getInt()); ps.setInt(2, i.getInt());
rs = ps.executeQuery(); rs = ps.executeQuery();
Collection<Long> unseen = new ArrayList<Long>(); Map<Long, byte[]> unseen = new HashMap<Long, byte[]>();
while(rs.next()) unseen.add(rs.getLong(1)); while(rs.next()) unseen.put(rs.getLong(1), rs.getBytes(2));
rs.close(); rs.close();
ps.close(); ps.close();
return connectionWindowFactory.createConnectionWindow(unseen); return connectionWindowFactory.createConnectionWindow(i, unseen);
} catch(SQLException e) { } catch(SQLException e) {
tryToClose(rs); tryToClose(rs);
tryToClose(ps); tryToClose(ps);
@@ -1652,29 +1675,6 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public byte[] getSharedSecret(Connection txn, ContactId c, boolean incoming)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String col = incoming ? "incomingSecret" : "outgoingSecret";
String sql = "SELECT " + col + " FROM contacts WHERE contactId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
byte[] secret = rs.getBytes(1);
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
return secret;
} catch(SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
public boolean getStarred(Connection txn, MessageId m) throws DbException { public boolean getStarred(Connection txn, MessageId m) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
@@ -2197,14 +2197,16 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.executeUpdate(); ps.executeUpdate();
ps.close(); ps.close();
// Store the new connection window // Store the new connection window
sql = "INSERT INTO connectionWindows (contactId, index, unseen)" sql = "INSERT INTO connectionWindows"
+ " VALUES(?, ?, ?)"; + " (contactId, index, unseen, secret)"
+ " VALUES(?, ?, ?, ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
ps.setInt(2, i.getInt()); ps.setInt(2, i.getInt());
Collection<Long> unseen = w.getUnseen(); Map<Long, byte[]> unseen = w.getUnseen();
for(long l : unseen) { for(Entry<Long, byte[]> e : unseen.entrySet()) {
ps.setLong(3, l); ps.setLong(3, e.getKey());
ps.setBytes(4, e.getValue());
ps.addBatch(); ps.addBatch();
} }
int[] affectedBatch = ps.executeBatch(); int[] affectedBatch = ps.executeBatch();

View File

@@ -292,7 +292,7 @@ class PluginManagerImpl implements PluginManager {
public void writerCreated(ContactId c, BatchTransportWriter w) { public void writerCreated(ContactId c, BatchTransportWriter w) {
assert index != null; assert index != null;
dispatcher.dispatchWriter(index, c, w); dispatcher.dispatchWriter(c, index, w);
} }
} }
@@ -307,7 +307,7 @@ class PluginManagerImpl implements PluginManager {
public void outgoingConnectionCreated(ContactId c, public void outgoingConnectionCreated(ContactId c,
StreamTransportConnection s) { StreamTransportConnection s) {
assert index != null; assert index != null;
dispatcher.dispatchOutgoingConnection(index, c, s); dispatcher.dispatchOutgoingConnection(c, index, s);
} }
} }
} }

View File

@@ -1,14 +1,31 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionContextFactory; import net.sf.briar.api.transport.ConnectionContextFactory;
import com.google.inject.Inject;
class ConnectionContextFactoryImpl implements ConnectionContextFactory { class ConnectionContextFactoryImpl implements ConnectionContextFactory {
private final CryptoComponent crypto;
@Inject
ConnectionContextFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public ConnectionContext createConnectionContext(ContactId c, public ConnectionContext createConnectionContext(ContactId c,
TransportIndex i, long connection) { TransportIndex i, long connection, byte[] secret) {
return new ConnectionContextImpl(c, i, connection); return new ConnectionContextImpl(c, i, connection, secret);
}
public ConnectionContext createNextConnectionContext(ContactId c,
TransportIndex i, long connection, byte[] previousSecret) {
byte[] secret = crypto.deriveNextSecret(previousSecret, i.getInt(),
connection);
return new ConnectionContextImpl(c, i, connection, secret);
} }
} }

View File

@@ -9,12 +9,14 @@ class ConnectionContextImpl implements ConnectionContext {
private final ContactId contactId; private final ContactId contactId;
private final TransportIndex transportIndex; private final TransportIndex transportIndex;
private final long connectionNumber; private final long connectionNumber;
private final byte[] secret;
ConnectionContextImpl(ContactId contactId, TransportIndex transportIndex, ConnectionContextImpl(ContactId contactId, TransportIndex transportIndex,
long connectionNumber) { long connectionNumber, byte[] secret) {
this.contactId = contactId; this.contactId = contactId;
this.transportIndex = transportIndex; this.transportIndex = transportIndex;
this.connectionNumber = connectionNumber; this.connectionNumber = connectionNumber;
this.secret = secret;
} }
public ContactId getContactId() { public ContactId getContactId() {
@@ -28,4 +30,8 @@ class ConnectionContextImpl implements ConnectionContext {
public long getConnectionNumber() { public long getConnectionNumber() {
return connectionNumber; return connectionNumber;
} }
public byte[] getSecret() {
return secret;
}
} }

View File

@@ -62,8 +62,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
r.dispose(false); r.dispose(false);
return; return;
} }
batchConnFactory.createIncomingConnection(ctx.getTransportIndex(), batchConnFactory.createIncomingConnection(ctx, r, encryptedIv);
ctx.getContactId(), r, encryptedIv);
} }
private byte[] readIv(InputStream in) throws IOException { private byte[] readIv(InputStream in) throws IOException {
@@ -77,9 +76,9 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
return b; return b;
} }
public void dispatchWriter(TransportIndex i, ContactId c, public void dispatchWriter(ContactId c, TransportIndex i,
BatchTransportWriter w) { BatchTransportWriter w) {
batchConnFactory.createOutgoingConnection(i, c, w); batchConnFactory.createOutgoingConnection(c, i, w);
} }
public void dispatchIncomingConnection(TransportId t, public void dispatchIncomingConnection(TransportId t,
@@ -106,12 +105,11 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
s.dispose(false); s.dispose(false);
return; return;
} }
streamConnFactory.createIncomingConnection(ctx.getTransportIndex(), streamConnFactory.createIncomingConnection(ctx, s, encryptedIv);
ctx.getContactId(), s, encryptedIv);
} }
public void dispatchOutgoingConnection(TransportIndex i, ContactId c, public void dispatchOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s) { StreamTransportConnection s) {
streamConnFactory.createOutgoingConnection(i, c, s); streamConnFactory.createOutgoingConnection(c, i, s);
} }
} }

View File

@@ -7,12 +7,13 @@ import javax.crypto.BadPaddingException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException; import javax.crypto.IllegalBlockSizeException;
import javax.crypto.Mac; import javax.crypto.Mac;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -26,10 +27,10 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
} }
public ConnectionReader createConnectionReader(InputStream in, public ConnectionReader createConnectionReader(InputStream in,
TransportIndex i, byte[] encryptedIv, byte[] secret) { ConnectionContext ctx, byte[] encryptedIv) {
// Decrypt the IV // Decrypt the IV
Cipher ivCipher = crypto.getIvCipher(); Cipher ivCipher = crypto.getIvCipher();
ErasableKey ivKey = crypto.deriveIvKey(secret, true); ErasableKey ivKey = crypto.deriveIvKey(ctx.getSecret(), true);
byte[] iv; byte[] iv;
try { try {
ivCipher.init(Cipher.DECRYPT_MODE, ivKey); ivCipher.init(Cipher.DECRYPT_MODE, ivKey);
@@ -42,27 +43,25 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
throw new IllegalArgumentException(badKey); throw new IllegalArgumentException(badKey);
} }
// Validate the IV // Validate the IV
if(!IvEncoder.validateIv(iv, true, i)) if(!IvEncoder.validateIv(iv, true, ctx))
throw new IllegalArgumentException(); throw new IllegalArgumentException();
// Copy the connection number return createConnectionReader(in, true, ctx);
long connection = IvEncoder.getConnectionNumber(iv);
return createConnectionReader(in, true, i, connection, secret);
} }
public ConnectionReader createConnectionReader(InputStream in, public ConnectionReader createConnectionReader(InputStream in,
TransportIndex i, long connection, byte[] secret) { ConnectionContext ctx) {
return createConnectionReader(in, false, i, connection, secret); return createConnectionReader(in, false, ctx);
} }
private ConnectionReader createConnectionReader(InputStream in, private ConnectionReader createConnectionReader(InputStream in,
boolean initiator, TransportIndex i, long connection, boolean initiator, ConnectionContext ctx) {
byte[] secret) {
// Derive the keys and erase the secret // Derive the keys and erase the secret
byte[] secret = ctx.getSecret();
ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator);
ErasableKey macKey = crypto.deriveMacKey(secret, initiator); ErasableKey macKey = crypto.deriveMacKey(secret, initiator);
for(int j = 0; j < secret.length; j++) secret[j] = 0; ByteUtils.erase(secret);
// Create the decrypter // Create the decrypter
byte[] iv = IvEncoder.encodeIv(initiator, i, connection); byte[] iv = IvEncoder.encodeIv(initiator, ctx);
Cipher frameCipher = crypto.getFrameCipher(); Cipher frameCipher = crypto.getFrameCipher();
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, iv, ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, iv,
frameCipher, frameKey); frameCipher, frameKey);

View File

@@ -8,25 +8,26 @@ import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.crypto.BadPaddingException; import javax.crypto.BadPaddingException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException; import javax.crypto.IllegalBlockSizeException;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.Bytes; import net.sf.briar.api.Bytes;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.NoSuchContactException; import net.sf.briar.api.db.NoSuchContactException;
import net.sf.briar.api.db.event.ContactRemovedEvent; import net.sf.briar.api.db.event.ContactRemovedEvent;
import net.sf.briar.api.db.event.DatabaseEvent; import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener; import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent; import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent;
import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
@@ -75,30 +76,29 @@ DatabaseListener {
} }
private synchronized void calculateIvs(ContactId c) throws DbException { private synchronized void calculateIvs(ContactId c) throws DbException {
byte[] secret = db.getSharedSecret(c, true);
ErasableKey ivKey = crypto.deriveIvKey(secret, true);
for(int i = 0; i < secret.length; i++) secret[i] = 0;
for(TransportId t : localTransportIds) { for(TransportId t : localTransportIds) {
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i != null) { if(i != null) {
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
calculateIvs(c, i, ivKey, w); calculateIvs(c, i, w);
} }
} }
} }
private synchronized void calculateIvs(ContactId c, TransportIndex i, private synchronized void calculateIvs(ContactId c, TransportIndex i,
ErasableKey ivKey, ConnectionWindow w) ConnectionWindow w) throws DbException {
throws DbException { for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
for(Long unseen : w.getUnseen()) { long unseen = e.getKey();
byte[] secret = e.getValue();
ErasableKey ivKey = crypto.deriveIvKey(secret, true);
Bytes iv = new Bytes(encryptIv(i, unseen, ivKey)); Bytes iv = new Bytes(encryptIv(i, unseen, ivKey));
expected.put(iv, new ConnectionContextImpl(c, i, unseen)); expected.put(iv, new ConnectionContextImpl(c, i, unseen, secret));
} }
} }
private synchronized byte[] encryptIv(TransportIndex i, long connection, private synchronized byte[] encryptIv(TransportIndex i, long connection,
ErasableKey ivKey) { ErasableKey ivKey) {
byte[] iv = IvEncoder.encodeIv(true, i, connection); byte[] iv = IvEncoder.encodeIv(true, i.getInt(), connection);
try { try {
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
return ivCipher.doFinal(iv); return ivCipher.doFinal(iv);
@@ -133,10 +133,7 @@ DatabaseListener {
TransportIndex i1 = ctx1.getTransportIndex(); TransportIndex i1 = ctx1.getTransportIndex();
if(c1.equals(c) && i1.equals(i)) it.remove(); if(c1.equals(c) && i1.equals(i)) it.remove();
} }
byte[] secret = db.getSharedSecret(c, true); calculateIvs(c, i, w);
ErasableKey ivKey = crypto.deriveIvKey(secret, true);
for(int j = 0; j < secret.length; j++) secret[j] = 0;
calculateIvs(c, i, ivKey, w);
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - clean up when we get the event // The contact was removed - clean up when we get the event
} }
@@ -185,13 +182,10 @@ DatabaseListener {
private synchronized void calculateIvs(TransportId t) throws DbException { private synchronized void calculateIvs(TransportId t) throws DbException {
for(ContactId c : db.getContacts()) { for(ContactId c : db.getContacts()) {
try { try {
byte[] secret = db.getSharedSecret(c, true);
ErasableKey ivKey = crypto.deriveIvKey(secret, true);
for(int i = 0; i < secret.length; i++) secret[i] = 0;
TransportIndex i = db.getRemoteIndex(c, t); TransportIndex i = db.getRemoteIndex(c, t);
if(i != null) { if(i != null) {
ConnectionWindow w = db.getConnectionWindow(c, i); ConnectionWindow w = db.getConnectionWindow(c, i);
calculateIvs(c, i, ivKey, w); calculateIvs(c, i, w);
} }
} catch(NoSuchContactException e) { } catch(NoSuchContactException e) {
// The contact was removed - clean up when we get the event // The contact was removed - clean up when we get the event

View File

@@ -1,17 +1,30 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.util.Collection; import java.util.Map;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.api.transport.ConnectionWindowFactory; import net.sf.briar.api.transport.ConnectionWindowFactory;
import com.google.inject.Inject;
class ConnectionWindowFactoryImpl implements ConnectionWindowFactory { class ConnectionWindowFactoryImpl implements ConnectionWindowFactory {
public ConnectionWindow createConnectionWindow() { private final CryptoComponent crypto;
return new ConnectionWindowImpl();
@Inject
ConnectionWindowFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
} }
public ConnectionWindow createConnectionWindow(Collection<Long> unseen) { public ConnectionWindow createConnectionWindow(TransportIndex i,
return new ConnectionWindowImpl(unseen); byte[] secret) {
return new ConnectionWindowImpl(crypto, i, secret);
}
public ConnectionWindow createConnectionWindow(TransportIndex i,
Map<Long, byte[]> unseen) {
return new ConnectionWindowImpl(crypto, i, unseen);
} }
} }

View File

@@ -2,28 +2,38 @@ package net.sf.briar.transport;
import static net.sf.briar.api.protocol.ProtocolConstants.CONNECTION_WINDOW_SIZE; import static net.sf.briar.api.protocol.ProtocolConstants.CONNECTION_WINDOW_SIZE;
import java.util.Collection; import java.util.HashMap;
import java.util.Set; import java.util.Map;
import java.util.TreeSet;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
class ConnectionWindowImpl implements ConnectionWindow { class ConnectionWindowImpl implements ConnectionWindow {
private final Set<Long> unseen; private final CryptoComponent crypto;
private final int index;
private final Map<Long, byte[]> unseen;
private long centre; private long centre;
ConnectionWindowImpl() { ConnectionWindowImpl(CryptoComponent crypto, TransportIndex i,
unseen = new TreeSet<Long>(); byte[] secret) {
for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) unseen.add(l); this.crypto = crypto;
index = i.getInt();
unseen = new HashMap<Long, byte[]>();
for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) {
secret = crypto.deriveNextSecret(secret, index, l);
unseen.put(l, secret);
}
centre = 0; centre = 0;
} }
ConnectionWindowImpl(Collection<Long> unseen) { ConnectionWindowImpl(CryptoComponent crypto, TransportIndex i,
Map<Long, byte[]> unseen) {
long min = Long.MAX_VALUE, max = Long.MIN_VALUE; long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
for(long l : unseen) { for(long l : unseen.keySet()) {
if(l < 0 || l > ByteUtils.MAX_32_BIT_UNSIGNED) if(l < 0 || l > ByteUtils.MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(l < min) min = l; if(l < min) min = l;
@@ -31,15 +41,17 @@ class ConnectionWindowImpl implements ConnectionWindow {
} }
if(max - min > CONNECTION_WINDOW_SIZE) if(max - min > CONNECTION_WINDOW_SIZE)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
this.unseen = new TreeSet<Long>(unseen);
centre = max - CONNECTION_WINDOW_SIZE / 2 + 1; centre = max - CONNECTION_WINDOW_SIZE / 2 + 1;
for(long l = centre; l <= max; l++) { for(long l = centre; l <= max; l++) {
if(!this.unseen.contains(l)) throw new IllegalArgumentException(); if(!unseen.containsKey(l)) throw new IllegalArgumentException();
} }
this.crypto = crypto;
index = i.getInt();
this.unseen = unseen;
} }
public boolean isSeen(long connection) { public boolean isSeen(long connection) {
return !unseen.contains(connection); return !unseen.containsKey(connection);
} }
public void setSeen(long connection) { public void setSeen(long connection) {
@@ -47,14 +59,26 @@ class ConnectionWindowImpl implements ConnectionWindow {
long top = getTop(centre); long top = getTop(centre);
if(connection < bottom || connection > top) if(connection < bottom || connection > top)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(!unseen.remove(connection)) throw new IllegalArgumentException(); if(!unseen.containsKey(connection))
throw new IllegalArgumentException();
if(connection >= centre) { if(connection >= centre) {
centre = connection + 1; centre = connection + 1;
long newBottom = getBottom(centre); long newBottom = getBottom(centre);
long newTop = getTop(centre); long newTop = getTop(centre);
for(long l = bottom; l < newBottom; l++) unseen.remove(l); for(long l = bottom; l < newBottom; l++) {
for(long l = top + 1; l <= newTop; l++) unseen.add(l); byte[] expired = unseen.remove(l);
if(expired != null) ByteUtils.erase(expired);
}
byte[] topSecret = unseen.get(top);
assert topSecret != null;
for(long l = top + 1; l <= newTop; l++) {
topSecret = crypto.deriveNextSecret(topSecret, index, l);
unseen.put(l, topSecret);
}
} }
byte[] seen = unseen.remove(connection);
assert seen != null;
ByteUtils.erase(seen);
} }
// Returns the lowest value contained in a window with the given centre // Returns the lowest value contained in a window with the given centre
@@ -68,7 +92,7 @@ class ConnectionWindowImpl implements ConnectionWindow {
centre + CONNECTION_WINDOW_SIZE / 2 - 1); centre + CONNECTION_WINDOW_SIZE / 2 - 1);
} }
public Collection<Long> getUnseen() { public Map<Long, byte[]> getUnseen() {
return unseen; return unseen;
} }
} }

View File

@@ -7,12 +7,13 @@ import javax.crypto.BadPaddingException;
import javax.crypto.Cipher; import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException; import javax.crypto.IllegalBlockSizeException;
import javax.crypto.Mac; import javax.crypto.Mac;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -26,17 +27,15 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
} }
public ConnectionWriter createConnectionWriter(OutputStream out, public ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, TransportIndex i, long connection, byte[] secret) { long capacity, ConnectionContext ctx) {
return createConnectionWriter(out, capacity, true, i, connection, return createConnectionWriter(out, capacity, true, ctx);
secret);
} }
public ConnectionWriter createConnectionWriter(OutputStream out, public ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, TransportIndex i, byte[] encryptedIv, long capacity, ConnectionContext ctx, byte[] encryptedIv) {
byte[] secret) {
// Decrypt the IV // Decrypt the IV
Cipher ivCipher = crypto.getIvCipher(); Cipher ivCipher = crypto.getIvCipher();
ErasableKey ivKey = crypto.deriveIvKey(secret, true); ErasableKey ivKey = crypto.deriveIvKey(ctx.getSecret(), true);
byte[] iv; byte[] iv;
try { try {
ivCipher.init(Cipher.DECRYPT_MODE, ivKey); ivCipher.init(Cipher.DECRYPT_MODE, ivKey);
@@ -49,26 +48,23 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
throw new RuntimeException(badKey); throw new RuntimeException(badKey);
} }
// Validate the IV // Validate the IV
if(!IvEncoder.validateIv(iv, true, i)) if(!IvEncoder.validateIv(iv, true, ctx))
throw new IllegalArgumentException(); throw new IllegalArgumentException();
// Copy the connection number return createConnectionWriter(out, capacity, false, ctx);
long connection = IvEncoder.getConnectionNumber(iv);
return createConnectionWriter(out, capacity, false, i, connection,
secret);
} }
private ConnectionWriter createConnectionWriter(OutputStream out, private ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, boolean initiator, TransportIndex i, long connection, long capacity, boolean initiator, ConnectionContext ctx) {
byte[] secret) {
// Derive the keys and erase the secret // Derive the keys and erase the secret
byte[] secret = ctx.getSecret();
ErasableKey ivKey = crypto.deriveIvKey(secret, initiator); ErasableKey ivKey = crypto.deriveIvKey(secret, initiator);
ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator);
ErasableKey macKey = crypto.deriveMacKey(secret, initiator); ErasableKey macKey = crypto.deriveMacKey(secret, initiator);
for(int j = 0; j < secret.length; j++) secret[j] = 0; ByteUtils.erase(secret);
// Create the encrypter // Create the encrypter
Cipher ivCipher = crypto.getIvCipher(); Cipher ivCipher = crypto.getIvCipher();
Cipher frameCipher = crypto.getFrameCipher(); Cipher frameCipher = crypto.getFrameCipher();
byte[] iv = IvEncoder.encodeIv(initiator, i, connection); byte[] iv = IvEncoder.encodeIv(initiator, ctx);
ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out, ConnectionEncrypter encrypter = new ConnectionEncrypterImpl(out,
capacity, iv, ivCipher, frameCipher, ivKey, frameKey); capacity, iv, ivCipher, frameCipher, ivKey, frameKey);
// Create the writer // Create the writer

View File

@@ -1,18 +1,22 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH; import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
class IvEncoder { class IvEncoder {
static byte[] encodeIv(boolean initiator, TransportIndex i, static byte[] encodeIv(boolean initiator, ConnectionContext ctx) {
long connection) { return encodeIv(initiator, ctx.getTransportIndex().getInt(),
ctx.getConnectionNumber());
}
static byte[] encodeIv(boolean initiator, int index, long connection) {
byte[] iv = new byte[IV_LENGTH]; byte[] iv = new byte[IV_LENGTH];
// Bit 31 is the initiator flag // Bit 31 is the initiator flag
if(initiator) iv[3] = 1; if(initiator) iv[3] = 1;
// Encode the transport identifier as an unsigned 16-bit integer // Encode the transport index as an unsigned 16-bit integer
ByteUtils.writeUint16(i.getInt(), iv, 4); ByteUtils.writeUint16(index, iv, 4);
// Encode the connection number as an unsigned 32-bit integer // Encode the connection number as an unsigned 32-bit integer
ByteUtils.writeUint32(connection, iv, 6); ByteUtils.writeUint32(connection, iv, 6);
return iv; return iv;
@@ -24,7 +28,14 @@ class IvEncoder {
ByteUtils.writeUint32(frame, iv, 10); ByteUtils.writeUint32(frame, iv, 10);
} }
static boolean validateIv(byte[] iv, boolean initiator, TransportIndex i) { static boolean validateIv(byte[] iv, boolean initiator,
ConnectionContext ctx) {
return validateIv(iv, initiator, ctx.getTransportIndex().getInt(),
ctx.getConnectionNumber());
}
static boolean validateIv(byte[] iv, boolean initiator, int index,
long connection) {
if(iv.length != IV_LENGTH) return false; if(iv.length != IV_LENGTH) return false;
// Check that the reserved bits are all zero // Check that the reserved bits are all zero
for(int j = 0; j < 2; j++) if(iv[j] != 0) return false; for(int j = 0; j < 2; j++) if(iv[j] != 0) return false;
@@ -32,8 +43,10 @@ class IvEncoder {
for(int j = 10; j < iv.length; j++) if(iv[j] != 0) return false; for(int j = 10; j < iv.length; j++) if(iv[j] != 0) return false;
// Check that the initiator flag matches // Check that the initiator flag matches
if(initiator != getInitiatorFlag(iv)) return false; if(initiator != getInitiatorFlag(iv)) return false;
// Check that the transport ID matches // Check that the transport index matches
if(i.getInt() != getTransportId(iv)) return false; if(index != getTransportIndex(iv)) return false;
// Check that the connection number matches
if(connection != getConnectionNumber(iv)) return false;
// The IV is valid // The IV is valid
return true; return true;
} }
@@ -43,7 +56,7 @@ class IvEncoder {
return (iv[3] & 1) == 1; return (iv[3] & 1) == 1;
} }
static int getTransportId(byte[] iv) { static int getTransportIndex(byte[] iv) {
if(iv.length != IV_LENGTH) throw new IllegalArgumentException(); if(iv.length != IV_LENGTH) throw new IllegalArgumentException();
return ByteUtils.readUint16(iv, 4); return ByteUtils.readUint16(iv, 4);
} }

View File

@@ -8,6 +8,7 @@ import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.transport.BatchConnectionFactory; import net.sf.briar.api.transport.BatchConnectionFactory;
import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportReader;
import net.sf.briar.api.transport.BatchTransportWriter; import net.sf.briar.api.transport.BatchTransportWriter;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
@@ -33,11 +34,10 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
} }
public void createIncomingConnection(TransportIndex i, ContactId c, public void createIncomingConnection(ConnectionContext ctx,
BatchTransportReader r, byte[] encryptedIv) { BatchTransportReader r, byte[] encryptedIv) {
final IncomingBatchConnection conn = new IncomingBatchConnection( final IncomingBatchConnection conn = new IncomingBatchConnection(
connReaderFactory, db, protoReaderFactory, i, c, r, connReaderFactory, db, protoReaderFactory, ctx, r, encryptedIv);
encryptedIv);
Runnable read = new Runnable() { Runnable read = new Runnable() {
public void run() { public void run() {
conn.read(); conn.read();
@@ -46,10 +46,10 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory {
new Thread(read).start(); new Thread(read).start();
} }
public void createOutgoingConnection(TransportIndex i, ContactId c, public void createOutgoingConnection(ContactId c, TransportIndex i,
BatchTransportWriter w) { BatchTransportWriter w) {
final OutgoingBatchConnection conn = new OutgoingBatchConnection( final OutgoingBatchConnection conn = new OutgoingBatchConnection(
connWriterFactory, db, protoWriterFactory, i, c, w); connWriterFactory, db, protoWriterFactory, c, i, w);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();

View File

@@ -13,9 +13,9 @@ import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportReader;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
@@ -27,46 +27,43 @@ class IncomingBatchConnection {
private final ConnectionReaderFactory connFactory; private final ConnectionReaderFactory connFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
private final ProtocolReaderFactory protoFactory; private final ProtocolReaderFactory protoFactory;
private final TransportIndex transportIndex; private final ConnectionContext ctx;
private final ContactId contactId;
private final BatchTransportReader reader; private final BatchTransportReader reader;
private final byte[] encryptedIv; private final byte[] encryptedIv;
IncomingBatchConnection(ConnectionReaderFactory connFactory, IncomingBatchConnection(ConnectionReaderFactory connFactory,
DatabaseComponent db, ProtocolReaderFactory protoFactory, DatabaseComponent db, ProtocolReaderFactory protoFactory,
TransportIndex transportIndex, ContactId contactId, ConnectionContext ctx, BatchTransportReader reader,
BatchTransportReader reader, byte[] encryptedIv) { byte[] encryptedIv) {
this.connFactory = connFactory; this.connFactory = connFactory;
this.db = db; this.db = db;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.transportIndex = transportIndex; this.ctx = ctx;
this.contactId = contactId;
this.reader = reader; this.reader = reader;
this.encryptedIv = encryptedIv; this.encryptedIv = encryptedIv;
} }
void read() { void read() {
try { try {
byte[] secret = db.getSharedSecret(contactId, true);
ConnectionReader conn = connFactory.createConnectionReader( ConnectionReader conn = connFactory.createConnectionReader(
reader.getInputStream(), transportIndex, encryptedIv, reader.getInputStream(), ctx, encryptedIv);
secret);
ProtocolReader proto = protoFactory.createProtocolReader( ProtocolReader proto = protoFactory.createProtocolReader(
conn.getInputStream()); conn.getInputStream());
ContactId c = ctx.getContactId();
// Read packets until EOF // Read packets until EOF
while(!proto.eof()) { while(!proto.eof()) {
if(proto.hasAck()) { if(proto.hasAck()) {
Ack a = proto.readAck(); Ack a = proto.readAck();
db.receiveAck(contactId, a); db.receiveAck(c, a);
} else if(proto.hasBatch()) { } else if(proto.hasBatch()) {
Batch b = proto.readBatch(); Batch b = proto.readBatch();
db.receiveBatch(contactId, b); db.receiveBatch(c, b);
} else if(proto.hasSubscriptionUpdate()) { } else if(proto.hasSubscriptionUpdate()) {
SubscriptionUpdate s = proto.readSubscriptionUpdate(); SubscriptionUpdate s = proto.readSubscriptionUpdate();
db.receiveSubscriptionUpdate(contactId, s); db.receiveSubscriptionUpdate(c, s);
} else if(proto.hasTransportUpdate()) { } else if(proto.hasTransportUpdate()) {
TransportUpdate t = proto.readTransportUpdate(); TransportUpdate t = proto.readTransportUpdate();
db.receiveTransportUpdate(contactId, t); db.receiveTransportUpdate(c, t);
} else { } else {
throw new FormatException(); throw new FormatException();
} }

View File

@@ -29,30 +29,28 @@ class OutgoingBatchConnection {
private final ConnectionWriterFactory connFactory; private final ConnectionWriterFactory connFactory;
private final DatabaseComponent db; private final DatabaseComponent db;
private final ProtocolWriterFactory protoFactory; private final ProtocolWriterFactory protoFactory;
private final TransportIndex transportIndex;
private final ContactId contactId; private final ContactId contactId;
private final TransportIndex transportIndex;
private final BatchTransportWriter writer; private final BatchTransportWriter writer;
OutgoingBatchConnection(ConnectionWriterFactory connFactory, OutgoingBatchConnection(ConnectionWriterFactory connFactory,
DatabaseComponent db, ProtocolWriterFactory protoFactory, DatabaseComponent db, ProtocolWriterFactory protoFactory,
TransportIndex transportIndex, ContactId contactId, ContactId contactId, TransportIndex transportIndex,
BatchTransportWriter writer) { BatchTransportWriter writer) {
this.connFactory = connFactory; this.connFactory = connFactory;
this.db = db; this.db = db;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.transportIndex = transportIndex;
this.contactId = contactId; this.contactId = contactId;
this.transportIndex = transportIndex;
this.writer = writer; this.writer = writer;
} }
void write() { void write() {
try { try {
byte[] secret = db.getSharedSecret(contactId, false); ConnectionContext ctx = db.getConnectionContext(contactId,
ConnectionContext ctx = transportIndex);
db.getConnectionContext(contactId, transportIndex);
ConnectionWriter conn = connFactory.createConnectionWriter( ConnectionWriter conn = connFactory.createConnectionWriter(
writer.getOutputStream(), writer.getCapacity(), writer.getOutputStream(), writer.getCapacity(), ctx);
transportIndex, ctx.getConnectionNumber(), secret);
OutputStream out = conn.getOutputStream(); OutputStream out = conn.getOutputStream();
// There should be enough space for a packet // There should be enough space for a packet
long capacity = conn.getRemainingCapacity(); long capacity = conn.getRemainingCapacity();

View File

@@ -2,12 +2,11 @@ package net.sf.briar.transport.stream;
import java.io.IOException; import java.io.IOException;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
@@ -16,35 +15,32 @@ import net.sf.briar.api.transport.StreamTransportConnection;
public class IncomingStreamConnection extends StreamConnection { public class IncomingStreamConnection extends StreamConnection {
private final ConnectionContext ctx;
private final byte[] encryptedIv; private final byte[] encryptedIv;
IncomingStreamConnection(ConnectionReaderFactory connReaderFactory, IncomingStreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ProtocolWriterFactory protoWriterFactory,
TransportIndex transportIndex, ContactId contactId, ConnectionContext ctx, StreamTransportConnection connection,
StreamTransportConnection connection,
byte[] encryptedIv) { byte[] encryptedIv) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory, super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, transportIndex, contactId, connection); protoWriterFactory, ctx.getContactId(), connection);
this.ctx = ctx;
this.encryptedIv = encryptedIv; this.encryptedIv = encryptedIv;
} }
@Override @Override
protected ConnectionReader createConnectionReader() throws DbException, protected ConnectionReader createConnectionReader() throws DbException,
IOException { IOException {
byte[] secret = db.getSharedSecret(contactId, true);
return connReaderFactory.createConnectionReader( return connReaderFactory.createConnectionReader(
connection.getInputStream(), transportIndex, encryptedIv, connection.getInputStream(), ctx, encryptedIv);
secret);
} }
@Override @Override
protected ConnectionWriter createConnectionWriter() throws DbException, protected ConnectionWriter createConnectionWriter() throws DbException,
IOException { IOException {
byte[] secret = db.getSharedSecret(contactId, false);
return connWriterFactory.createConnectionWriter( return connWriterFactory.createConnectionWriter(
connection.getOutputStream(), Long.MAX_VALUE, transportIndex, connection.getOutputStream(), Long.MAX_VALUE, ctx, encryptedIv);
encryptedIv, secret);
} }
} }

View File

@@ -17,43 +17,40 @@ import net.sf.briar.api.transport.StreamTransportConnection;
public class OutgoingStreamConnection extends StreamConnection { public class OutgoingStreamConnection extends StreamConnection {
private final TransportIndex transportIndex;
private ConnectionContext ctx = null; // Locking: this private ConnectionContext ctx = null; // Locking: this
OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory, OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
TransportIndex transportIndex, ContactId contactId, TransportIndex transportIndex,
StreamTransportConnection connection) { StreamTransportConnection connection) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory, super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, transportIndex, contactId, connection); protoWriterFactory, contactId, connection);
this.transportIndex = transportIndex;
} }
@Override @Override
protected ConnectionReader createConnectionReader() throws DbException, protected ConnectionReader createConnectionReader() throws DbException,
IOException { IOException {
synchronized(this) { synchronized(this) {
if(ctx == null) { if(ctx == null)
ctx = db.getConnectionContext(contactId, transportIndex); ctx = db.getConnectionContext(contactId, transportIndex);
}
} }
byte[] secret = db.getSharedSecret(contactId, true);
return connReaderFactory.createConnectionReader( return connReaderFactory.createConnectionReader(
connection.getInputStream(), transportIndex, connection.getInputStream(), ctx);
ctx.getConnectionNumber(), secret);
} }
@Override @Override
protected ConnectionWriter createConnectionWriter() throws DbException, protected ConnectionWriter createConnectionWriter() throws DbException,
IOException { IOException {
synchronized(this) { synchronized(this) {
if(ctx == null) { if(ctx == null)
ctx = db.getConnectionContext(contactId, transportIndex); ctx = db.getConnectionContext(contactId, transportIndex);
}
} }
byte[] secret = db.getSharedSecret(contactId, false);
return connWriterFactory.createConnectionWriter( return connWriterFactory.createConnectionWriter(
connection.getOutputStream(), Long.MAX_VALUE, transportIndex, connection.getOutputStream(), Long.MAX_VALUE, ctx);
ctx.getConnectionNumber(), secret);
} }
} }

View File

@@ -29,7 +29,6 @@ import net.sf.briar.api.protocol.ProtocolReader;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter; import net.sf.briar.api.protocol.writers.BatchWriter;
@@ -56,7 +55,6 @@ abstract class StreamConnection implements DatabaseListener {
protected final DatabaseComponent db; protected final DatabaseComponent db;
protected final ProtocolReaderFactory protoReaderFactory; protected final ProtocolReaderFactory protoReaderFactory;
protected final ProtocolWriterFactory protoWriterFactory; protected final ProtocolWriterFactory protoWriterFactory;
protected final TransportIndex transportIndex;
protected final ContactId contactId; protected final ContactId contactId;
protected final StreamTransportConnection connection; protected final StreamTransportConnection connection;
@@ -69,15 +67,13 @@ abstract class StreamConnection implements DatabaseListener {
StreamConnection(ConnectionReaderFactory connReaderFactory, StreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ProtocolWriterFactory protoWriterFactory, ContactId contactId,
TransportIndex transportIndex, ContactId contactId,
StreamTransportConnection connection) { StreamTransportConnection connection) {
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db; this.db = db;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
this.transportIndex = transportIndex;
this.contactId = contactId; this.contactId = contactId;
this.connection = connection; this.connection = connection;
} }

View File

@@ -5,6 +5,7 @@ import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.StreamConnectionFactory; import net.sf.briar.api.transport.StreamConnectionFactory;
@@ -32,11 +33,11 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
} }
public void createIncomingConnection(TransportIndex i, ContactId c, public void createIncomingConnection(ConnectionContext ctx,
StreamTransportConnection s, byte[] encryptedIv) { StreamTransportConnection s, byte[] encryptedIv) {
final StreamConnection conn = new IncomingStreamConnection( final StreamConnection conn = new IncomingStreamConnection(
connReaderFactory, connWriterFactory, db, protoReaderFactory, connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, i, c, s, encryptedIv); protoWriterFactory, ctx, s, encryptedIv);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();
@@ -51,11 +52,11 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
new Thread(read).start(); new Thread(read).start();
} }
public void createOutgoingConnection(TransportIndex i, ContactId c, public void createOutgoingConnection(ContactId c, TransportIndex i,
StreamTransportConnection s) { StreamTransportConnection s) {
final StreamConnection conn = new OutgoingStreamConnection( final StreamConnection conn = new OutgoingStreamConnection(
connReaderFactory, connWriterFactory, db, protoReaderFactory, connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, i, c, s); protoWriterFactory, c, i, s);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();

View File

@@ -16,6 +16,7 @@ import java.util.Map;
import java.util.Random; import java.util.Random;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.Author;
@@ -36,7 +37,6 @@ import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.writers.AckWriter;
import net.sf.briar.api.protocol.writers.BatchWriter; import net.sf.briar.api.protocol.writers.BatchWriter;
import net.sf.briar.api.protocol.writers.OfferWriter; import net.sf.briar.api.protocol.writers.OfferWriter;
@@ -44,6 +44,8 @@ import net.sf.briar.api.protocol.writers.ProtocolWriterFactory;
import net.sf.briar.api.protocol.writers.RequestWriter; import net.sf.briar.api.protocol.writers.RequestWriter;
import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter;
import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.api.protocol.writers.TransportUpdateWriter;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionContextFactory;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
@@ -68,12 +70,14 @@ public class ProtocolIntegrationTest extends TestCase {
private final BatchId ack = new BatchId(TestUtils.getRandomId()); private final BatchId ack = new BatchId(TestUtils.getRandomId());
private final long timestamp = System.currentTimeMillis(); private final long timestamp = System.currentTimeMillis();
private final ConnectionContextFactory connectionContextFactory;
private final ConnectionReaderFactory connectionReaderFactory; private final ConnectionReaderFactory connectionReaderFactory;
private final ConnectionWriterFactory connectionWriterFactory; private final ConnectionWriterFactory connectionWriterFactory;
private final ProtocolReaderFactory protocolReaderFactory; private final ProtocolReaderFactory protocolReaderFactory;
private final ProtocolWriterFactory protocolWriterFactory; private final ProtocolWriterFactory protocolWriterFactory;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final byte[] aliceToBobSecret; private final byte[] secret;
private final ContactId contactId = new ContactId(13);
private final TransportIndex transportIndex = new TransportIndex(13); private final TransportIndex transportIndex = new TransportIndex(13);
private final long connection = 12345L; private final long connection = 12345L;
private final Author author; private final Author author;
@@ -91,16 +95,17 @@ public class ProtocolIntegrationTest extends TestCase {
new ProtocolWritersModule(), new SerialModule(), new ProtocolWritersModule(), new SerialModule(),
new TestDatabaseModule(), new TransportBatchModule(), new TestDatabaseModule(), new TransportBatchModule(),
new TransportModule(), new TransportStreamModule()); new TransportModule(), new TransportStreamModule());
connectionContextFactory =
i.getInstance(ConnectionContextFactory.class);
connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class); connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class);
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class); protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class);
protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class);
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
assertEquals(crypto.getMessageDigest().getDigestLength(), // Create a shared secret
UniqueId.LENGTH);
Random r = new Random(); Random r = new Random();
aliceToBobSecret = new byte[32]; secret = new byte[32];
r.nextBytes(aliceToBobSecret); r.nextBytes(secret);
// Create two groups: one restricted, one unrestricted // Create two groups: one restricted, one unrestricted
GroupFactory groupFactory = i.getInstance(GroupFactory.class); GroupFactory groupFactory = i.getInstance(GroupFactory.class);
group = groupFactory.createGroup("Unrestricted group", null); group = groupFactory.createGroup("Unrestricted group", null);
@@ -139,9 +144,11 @@ public class ProtocolIntegrationTest extends TestCase {
private byte[] write() throws Exception { private byte[] write() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
byte[] copyOfSecret = Arrays.clone(aliceToBobSecret); ConnectionContext ctx =
connectionContextFactory.createConnectionContext(contactId,
transportIndex, connection, Arrays.clone(secret));
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
Long.MAX_VALUE, transportIndex, connection, copyOfSecret); Long.MAX_VALUE, ctx);
OutputStream out1 = w.getOutputStream(); OutputStream out1 = w.getOutputStream();
AckWriter a = protocolWriterFactory.createAckWriter(out1); AckWriter a = protocolWriterFactory.createAckWriter(out1);
@@ -184,19 +191,15 @@ public class ProtocolIntegrationTest extends TestCase {
return out.toByteArray(); return out.toByteArray();
} }
private void read(byte[] connection) throws Exception { private void read(byte[] connectionData) throws Exception {
InputStream in = new ByteArrayInputStream(connection); InputStream in = new ByteArrayInputStream(connectionData);
byte[] encryptedIv = new byte[16]; byte[] encryptedIv = new byte[16];
int offset = 0; assertEquals(16, in.read(encryptedIv, 0, 16));
while(offset < 16) { ConnectionContext ctx =
int read = in.read(encryptedIv, offset, 16 - offset); connectionContextFactory.createConnectionContext(contactId,
if(read == -1) break; transportIndex, connection, Arrays.clone(secret));
offset += read;
}
assertEquals(16, offset);
byte[] copyOfSecret = Arrays.clone(aliceToBobSecret);
ConnectionReader r = connectionReaderFactory.createConnectionReader(in, ConnectionReader r = connectionReaderFactory.createConnectionReader(in,
transportIndex, encryptedIv, copyOfSecret); ctx, encryptedIv);
in = r.getInputStream(); in = r.getInputStream();
ProtocolReader protocolReader = ProtocolReader protocolReader =
protocolReaderFactory.createProtocolReader(in); protocolReaderFactory.createProtocolReader(in);

View File

@@ -47,7 +47,6 @@ import net.sf.briar.api.transport.ConnectionWindow;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.Mockery; import org.jmock.Mockery;
import static org.junit.Assert.assertArrayEquals;
import org.junit.Test; import org.junit.Test;
public abstract class DatabaseComponentTest extends TestCase { public abstract class DatabaseComponentTest extends TestCase {
@@ -107,9 +106,9 @@ public abstract class DatabaseComponentTest extends TestCase {
Database<T> database, DatabaseCleaner cleaner); Database<T> database, DatabaseCleaner cleaner);
@Test @Test
@SuppressWarnings("unchecked")
public void testSimpleCalls() throws Exception { public void testSimpleCalls() throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ConnectionWindow connectionWindow = final ConnectionWindow connectionWindow =
@@ -138,7 +137,8 @@ public abstract class DatabaseComponentTest extends TestCase {
oneOf(database).setRating(txn, authorId, Rating.GOOD); oneOf(database).setRating(txn, authorId, Rating.GOOD);
will(returnValue(Rating.GOOD)); will(returnValue(Rating.GOOD));
// addContact() // addContact()
oneOf(database).addContact(txn, inSecret, outSecret); oneOf(database).addContact(with(txn), with(inSecret),
with(outSecret), with(any(Collection.class)));
will(returnValue(contactId)); will(returnValue(contactId));
oneOf(listener).eventOccurred(with(any(ContactAddedEvent.class))); oneOf(listener).eventOccurred(with(any(ContactAddedEvent.class)));
// getContacts() // getContacts()
@@ -149,16 +149,6 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(true)); will(returnValue(true));
oneOf(database).getConnectionWindow(txn, contactId, remoteIndex); oneOf(database).getConnectionWindow(txn, contactId, remoteIndex);
will(returnValue(connectionWindow)); will(returnValue(connectionWindow));
// getSharedSecret(contactId, true)
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).getSharedSecret(txn, contactId, true);
will(returnValue(inSecret));
// getSharedSecret(contactId, false)
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).getSharedSecret(txn, contactId, false);
will(returnValue(outSecret));
// getTransportProperties(transportId) // getTransportProperties(transportId)
oneOf(database).getRemoteProperties(txn, transportId); oneOf(database).getRemoteProperties(txn, transportId);
will(returnValue(remoteProperties)); will(returnValue(remoteProperties));
@@ -213,8 +203,6 @@ public abstract class DatabaseComponentTest extends TestCase {
assertEquals(Collections.singletonList(contactId), db.getContacts()); assertEquals(Collections.singletonList(contactId), db.getContacts());
assertEquals(connectionWindow, assertEquals(connectionWindow,
db.getConnectionWindow(contactId, remoteIndex)); db.getConnectionWindow(contactId, remoteIndex));
assertArrayEquals(inSecret, db.getSharedSecret(contactId, true));
assertArrayEquals(outSecret, db.getSharedSecret(contactId, false));
assertEquals(remoteProperties, db.getRemoteProperties(transportId)); assertEquals(remoteProperties, db.getRemoteProperties(transportId));
db.subscribe(group); // First time - listeners called db.subscribe(group); // First time - listeners called
db.subscribe(group); // Second time - not called db.subscribe(group); // Second time - not called
@@ -516,11 +504,11 @@ public abstract class DatabaseComponentTest extends TestCase {
context.mock(TransportUpdate.class); context.mock(TransportUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// Check whether the contact is still in the DB (which it's not) // Check whether the contact is still in the DB (which it's not)
exactly(20).of(database).startTransaction(); exactly(19).of(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
exactly(20).of(database).containsContact(txn, contactId); exactly(19).of(database).containsContact(txn, contactId);
will(returnValue(false)); will(returnValue(false));
exactly(20).of(database).commitTransaction(txn); exactly(19).of(database).commitTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner); DatabaseComponent db = createDatabaseComponent(database, cleaner);
@@ -575,11 +563,6 @@ public abstract class DatabaseComponentTest extends TestCase {
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try {
db.getSharedSecret(contactId, true);
fail();
} catch(NoSuchContactException expected) {}
try { try {
db.hasSendableMessages(contactId); db.hasSendableMessages(contactId);
fail(); fail();

View File

@@ -4,6 +4,7 @@ import static org.junit.Assert.assertArrayEquals;
import java.io.File; import java.io.File;
import java.sql.Connection; import java.sql.Connection;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@@ -88,6 +89,7 @@ public class H2DatabaseTest extends TestCase {
private final Collection<Transport> remoteTransports; private final Collection<Transport> remoteTransports;
private final Map<Group, Long> subscriptions; private final Map<Group, Long> subscriptions;
private final byte[] inSecret, outSecret; private final byte[] inSecret, outSecret;
private final Collection<byte[]> erase;
public H2DatabaseTest() throws Exception { public H2DatabaseTest() throws Exception {
super(); super();
@@ -131,6 +133,7 @@ public class H2DatabaseTest extends TestCase {
r.nextBytes(inSecret); r.nextBytes(inSecret);
outSecret = new byte[32]; outSecret = new byte[32];
r.nextBytes(outSecret); r.nextBytes(outSecret);
erase = new ArrayList<byte[]>();
} }
@Before @Before
@@ -144,8 +147,7 @@ public class H2DatabaseTest extends TestCase {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
assertFalse(db.containsContact(txn, contactId)); assertFalse(db.containsContact(txn, contactId));
assertEquals(contactId, assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addContact(txn, inSecret, outSecret));
assertTrue(db.containsContact(txn, contactId)); assertTrue(db.containsContact(txn, contactId));
assertFalse(db.containsSubscription(txn, groupId)); assertFalse(db.containsSubscription(txn, groupId));
db.addSubscription(txn, group); db.addSubscription(txn, group);
@@ -201,20 +203,23 @@ public class H2DatabaseTest extends TestCase {
// Create three contacts // Create three contacts
assertFalse(db.containsContact(txn, contactId)); assertFalse(db.containsContact(txn, contactId));
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
assertTrue(db.containsContact(txn, contactId)); assertTrue(db.containsContact(txn, contactId));
assertFalse(db.containsContact(txn, contactId1)); assertFalse(db.containsContact(txn, contactId1));
assertEquals(contactId1, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId1,
db.addContact(txn, inSecret, outSecret, erase));
assertTrue(db.containsContact(txn, contactId1)); assertTrue(db.containsContact(txn, contactId1));
assertFalse(db.containsContact(txn, contactId2)); assertFalse(db.containsContact(txn, contactId2));
assertEquals(contactId2, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId2,
db.addContact(txn, inSecret, outSecret, erase));
assertTrue(db.containsContact(txn, contactId2)); assertTrue(db.containsContact(txn, contactId2));
// Delete the contact with the highest ID // Delete the contact with the highest ID
db.removeContact(txn, contactId2); db.removeContact(txn, contactId2);
assertFalse(db.containsContact(txn, contactId2)); assertFalse(db.containsContact(txn, contactId2));
// Add another contact - a new ID should be created // Add another contact - a new ID should be created
assertFalse(db.containsContact(txn, contactId3)); assertFalse(db.containsContact(txn, contactId3));
assertEquals(contactId3, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId3,
db.addContact(txn, inSecret, outSecret, erase));
assertTrue(db.containsContact(txn, contactId3)); assertTrue(db.containsContact(txn, contactId3));
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -261,7 +266,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and store a private message // Add a contact and store a private message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addPrivateMessage(txn, privateMessage, contactId); db.addPrivateMessage(txn, privateMessage, contactId);
// Removing the contact should remove the message // Removing the contact should remove the message
@@ -280,7 +285,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and store a private message // Add a contact and store a private message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addPrivateMessage(txn, privateMessage, contactId); db.addPrivateMessage(txn, privateMessage, contactId);
// The message has no status yet, so it should not be sendable // The message has no status yet, so it should not be sendable
@@ -319,7 +324,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and store a private message // Add a contact and store a private message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addPrivateMessage(txn, privateMessage, contactId); db.addPrivateMessage(txn, privateMessage, contactId);
db.setStatus(txn, contactId, privateMessageId, Status.NEW); db.setStatus(txn, contactId, privateMessageId, Status.NEW);
@@ -347,7 +352,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -385,7 +390,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -427,7 +432,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -466,7 +471,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -501,7 +506,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -532,7 +537,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -565,7 +570,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and some batches to ack // Add a contact and some batches to ack
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId);
db.addBatchToAck(txn, contactId, batchId1); db.addBatchToAck(txn, contactId, batchId1);
@@ -592,7 +597,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and receive the same batch twice // Add a contact and receive the same batch twice
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId);
db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId);
@@ -618,7 +623,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -643,8 +648,8 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add two contacts, subscribe to a group and store a message // Add two contacts, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
ContactId contactId1 = db.addContact(txn, inSecret, outSecret); ContactId contactId1 = db.addContact(txn, inSecret, outSecret, erase);
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -666,7 +671,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -705,7 +710,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -750,7 +755,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact // Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
// Add some outstanding batches, a few ms apart // Add some outstanding batches, a few ms apart
for(int i = 0; i < ids.length; i++) { for(int i = 0; i < ids.length; i++) {
@@ -790,7 +795,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact // Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
// Add some outstanding batches, a few ms apart // Add some outstanding batches, a few ms apart
for(int i = 0; i < ids.length; i++) { for(int i = 0; i < ids.length; i++) {
@@ -1010,7 +1015,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with a transport // Add a contact with a transport
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.setTransports(txn, contactId, remoteTransports, 1); db.setTransports(txn, contactId, remoteTransports, 1);
assertEquals(remoteProperties, assertEquals(remoteProperties,
db.getRemoteProperties(txn, transportId)); db.getRemoteProperties(txn, transportId));
@@ -1103,7 +1108,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with a transport // Add a contact with a transport
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.setTransports(txn, contactId, remoteTransports, 1); db.setTransports(txn, contactId, remoteTransports, 1);
assertEquals(remoteProperties, assertEquals(remoteProperties,
db.getRemoteProperties(txn, transportId)); db.getRemoteProperties(txn, transportId));
@@ -1147,7 +1152,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with some subscriptions // Add a contact with some subscriptions
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
assertEquals(Collections.singletonList(group), assertEquals(Collections.singletonList(group),
db.getSubscriptions(txn, contactId)); db.getSubscriptions(txn, contactId));
@@ -1172,7 +1177,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with some subscriptions // Add a contact with some subscriptions
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.setSubscriptions(txn, contactId, subscriptions, 2); db.setSubscriptions(txn, contactId, subscriptions, 2);
assertEquals(Collections.singletonList(group), assertEquals(Collections.singletonList(group),
db.getSubscriptions(txn, contactId)); db.getSubscriptions(txn, contactId));
@@ -1196,7 +1201,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -1214,7 +1219,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -1237,7 +1242,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -1260,7 +1265,7 @@ public class H2DatabaseTest extends TestCase {
// Add a contact, subscribe to a group and store a message - // Add a contact, subscribe to a group and store a message -
// the message is older than the contact's subscription // the message is older than the contact's subscription
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
Map<Group, Long> subs = Collections.singletonMap(group, timestamp + 1); Map<Group, Long> subs = Collections.singletonMap(group, timestamp + 1);
@@ -1284,7 +1289,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -1309,7 +1314,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -1328,7 +1333,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with a subscription // Add a contact with a subscription
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
// There's no local subscription for the group // There's no local subscription for the group
@@ -1345,7 +1350,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
db.setStatus(txn, contactId, messageId, Status.NEW); db.setStatus(txn, contactId, messageId, Status.NEW);
@@ -1364,7 +1369,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -1384,7 +1389,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -1406,7 +1411,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.setVisibility(txn, groupId, Collections.singletonList(contactId)); db.setVisibility(txn, groupId, Collections.singletonList(contactId));
db.setSubscriptions(txn, contactId, subscriptions, 1); db.setSubscriptions(txn, contactId, subscriptions, 1);
@@ -1427,7 +1432,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
// The group should not be visible to the contact // The group should not be visible to the contact
assertEquals(Collections.emptyList(), db.getVisibility(txn, groupId)); assertEquals(Collections.emptyList(), db.getVisibility(txn, groupId));
@@ -1450,7 +1455,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact // Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
// Get the connection window for a new index // Get the connection window for a new index
ConnectionWindow w = db.getConnectionWindow(txn, contactId, ConnectionWindow w = db.getConnectionWindow(txn, contactId,
remoteIndex); remoteIndex);
@@ -1469,18 +1474,18 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact // Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
// Get the connection window for a new index // Get the connection window for a new index
ConnectionWindow w = db.getConnectionWindow(txn, contactId, ConnectionWindow w = db.getConnectionWindow(txn, contactId,
remoteIndex); remoteIndex);
// The connection window should exist and be in the initial state // The connection window should exist and be in the initial state
assertNotNull(w); assertNotNull(w);
Collection<Long> unseen = w.getUnseen(); Map<Long, byte[]> unseen = w.getUnseen();
long top = ProtocolConstants.CONNECTION_WINDOW_SIZE / 2 - 1; long top = ProtocolConstants.CONNECTION_WINDOW_SIZE / 2 - 1;
assertEquals(top + 1, unseen.size()); assertEquals(top + 1, unseen.size());
for(long l = 0; l <= top; l++) { for(long l = 0; l <= top; l++) {
assertFalse(w.isSeen(l)); assertFalse(w.isSeen(l));
assertTrue(unseen.contains(l)); assertTrue(unseen.containsKey(l));
} }
// Update the connection window and store it // Update the connection window and store it
w.setSeen(5); w.setSeen(5);
@@ -1573,7 +1578,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
// A message with a private parent should return null // A message with a private parent should return null
@@ -1622,7 +1627,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact // Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
// The subscription and transport timestamps should be initialised to 0 // The subscription and transport timestamps should be initialised to 0
assertEquals(0L, db.getSubscriptionsModified(txn, contactId)); assertEquals(0L, db.getSubscriptionsModified(txn, contactId));
@@ -1653,7 +1658,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret)); assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
db.addSubscription(txn, group); db.addSubscription(txn, group);
// Store a couple of messages // Store a couple of messages
@@ -1897,6 +1902,7 @@ public class H2DatabaseTest extends TestCase {
@After @After
public void tearDown() { public void tearDown() {
erase.clear();
TestUtils.deleteTestDirectory(testDir); TestUtils.deleteTestDirectory(testDir);
} }

View File

@@ -52,7 +52,8 @@ public class ConnectionDecrypterImplTest extends TestCase {
private void testDecryption(boolean initiator) throws Exception { private void testDecryption(boolean initiator) throws Exception {
// Calculate the plaintext and ciphertext for the IV // Calculate the plaintext and ciphertext for the IV
byte[] iv = IvEncoder.encodeIv(initiator, transportIndex, connection); byte[] iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(),
connection);
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] encryptedIv = ivCipher.doFinal(iv); byte[] encryptedIv = ivCipher.doFinal(iv);
assertEquals(IV_LENGTH, encryptedIv.length); assertEquals(IV_LENGTH, encryptedIv.length);
@@ -85,8 +86,8 @@ public class ConnectionDecrypterImplTest extends TestCase {
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
// Use a ConnectionDecrypter to decrypt the ciphertext // Use a ConnectionDecrypter to decrypt the ciphertext
ConnectionDecrypter d = new ConnectionDecrypterImpl(in, ConnectionDecrypter d = new ConnectionDecrypterImpl(in,
IvEncoder.encodeIv(initiator, transportIndex, connection), IvEncoder.encodeIv(initiator, transportIndex.getInt(),
frameCipher, frameKey); connection), frameCipher, frameKey);
// First frame // First frame
byte[] decrypted = new byte[ciphertext.length]; byte[] decrypted = new byte[ciphertext.length];
TestUtils.readFully(d.getInputStream(), decrypted); TestUtils.readFully(d.getInputStream(), decrypted);

View File

@@ -50,7 +50,8 @@ public class ConnectionEncrypterImplTest extends TestCase {
private void testEncryption(boolean initiator) throws Exception { private void testEncryption(boolean initiator) throws Exception {
// Calculate the expected ciphertext for the IV // Calculate the expected ciphertext for the IV
byte[] iv = IvEncoder.encodeIv(initiator, transportIndex, connection); byte[] iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(),
connection);
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] encryptedIv = ivCipher.doFinal(iv); byte[] encryptedIv = ivCipher.doFinal(iv);
assertEquals(IV_LENGTH, encryptedIv.length); assertEquals(IV_LENGTH, encryptedIv.length);
@@ -82,7 +83,7 @@ public class ConnectionEncrypterImplTest extends TestCase {
byte[] expected = out.toByteArray(); byte[] expected = out.toByteArray();
// Use a ConnectionEncrypter to encrypt the plaintext // Use a ConnectionEncrypter to encrypt the plaintext
out.reset(); out.reset();
iv = IvEncoder.encodeIv(initiator, transportIndex, connection); iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(), connection);
ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE, ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE,
iv, ivCipher, frameCipher, ivKey, frameKey); iv, ivCipher, frameCipher, ivKey, frameKey);
e.getOutputStream().write(plaintext); e.getOutputStream().write(plaintext);

View File

@@ -4,6 +4,7 @@ import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import java.util.Random; import java.util.Random;
import javax.crypto.Cipher; import javax.crypto.Cipher;
@@ -51,7 +52,8 @@ public class ConnectionRecogniserImplTest extends TestCase {
Transport transport = new Transport(transportId, localIndex, Transport transport = new Transport(transportId, localIndex,
Collections.singletonMap("foo", "bar")); Collections.singletonMap("foo", "bar"));
transports = Collections.singletonList(transport); transports = Collections.singletonList(transport);
connectionWindow = new ConnectionWindowImpl(); connectionWindow = new ConnectionWindowImpl(crypto, remoteIndex,
inSecret);
} }
@Test @Test
@@ -65,8 +67,6 @@ public class ConnectionRecogniserImplTest extends TestCase {
will(returnValue(transports)); will(returnValue(transports));
oneOf(db).getContacts(); oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId))); will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getSharedSecret(contactId, true);
will(returnValue(inSecret));
oneOf(db).getRemoteIndex(contactId, transportId); oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex)); will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex); oneOf(db).getConnectionWindow(contactId, remoteIndex);
@@ -80,11 +80,16 @@ public class ConnectionRecogniserImplTest extends TestCase {
@Test @Test
public void testExpectedIv() throws Exception { public void testExpectedIv() throws Exception {
// Calculate the shared secret for connection number 3
byte[] secret = inSecret;
for(int i = 0; i < 4; i++) {
secret = crypto.deriveNextSecret(secret, remoteIndex.getInt(), i);
}
// Calculate the expected IV for connection number 3 // Calculate the expected IV for connection number 3
ErasableKey ivKey = crypto.deriveIvKey(inSecret, true); ErasableKey ivKey = crypto.deriveIvKey(secret, true);
Cipher ivCipher = crypto.getIvCipher(); Cipher ivCipher = crypto.getIvCipher();
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] iv = IvEncoder.encodeIv(true, remoteIndex, 3L); byte[] iv = IvEncoder.encodeIv(true, remoteIndex.getInt(), 3);
byte[] encryptedIv = ivCipher.doFinal(iv); byte[] encryptedIv = ivCipher.doFinal(iv);
Mockery context = new Mockery(); Mockery context = new Mockery();
@@ -96,8 +101,6 @@ public class ConnectionRecogniserImplTest extends TestCase {
will(returnValue(transports)); will(returnValue(transports));
oneOf(db).getContacts(); oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId))); will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getSharedSecret(contactId, true);
will(returnValue(inSecret));
oneOf(db).getRemoteIndex(contactId, transportId); oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex)); will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex); oneOf(db).getConnectionWindow(contactId, remoteIndex);
@@ -107,8 +110,6 @@ public class ConnectionRecogniserImplTest extends TestCase {
will(returnValue(connectionWindow)); will(returnValue(connectionWindow));
oneOf(db).setConnectionWindow(contactId, remoteIndex, oneOf(db).setConnectionWindow(contactId, remoteIndex,
connectionWindow); connectionWindow);
oneOf(db).getSharedSecret(contactId, true);
will(returnValue(inSecret));
}}); }});
final ConnectionRecogniserImpl c = final ConnectionRecogniserImpl c =
new ConnectionRecogniserImpl(crypto, db); new ConnectionRecogniserImpl(crypto, db);
@@ -121,11 +122,11 @@ public class ConnectionRecogniserImplTest extends TestCase {
// Second time - the IV should no longer be expected // Second time - the IV should no longer be expected
assertNull(c.acceptConnection(encryptedIv)); assertNull(c.acceptConnection(encryptedIv));
// The window should have advanced // The window should have advanced
Collection<Long> unseen = connectionWindow.getUnseen(); Map<Long, byte[]> unseen = connectionWindow.getUnseen();
assertEquals(19, unseen.size()); assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) { for(int i = 0; i < 19; i++) {
if(i == 3) continue; if(i == 3) continue;
assertTrue(unseen.contains(Long.valueOf(i))); assertTrue(unseen.containsKey(Long.valueOf(i)));
} }
context.assertIsSatisfied(); context.assertIsSatisfied();
} }

View File

@@ -1,18 +1,39 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.util.ArrayList; import java.util.HashMap;
import java.util.Collection; import java.util.Map;
import java.util.Random;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
import org.junit.Test; import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class ConnectionWindowImplTest extends TestCase { public class ConnectionWindowImplTest extends TestCase {
private final CryptoComponent crypto;
private final byte[] secret;
private final TransportIndex transportIndex = new TransportIndex(13);
public ConnectionWindowImplTest(String name) {
super(name);
Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class);
secret = new byte[32];
new Random().nextBytes(secret);
}
@Test @Test
public void testWindowSliding() { public void testWindowSliding() {
ConnectionWindowImpl w = new ConnectionWindowImpl(); ConnectionWindow w = new ConnectionWindowImpl(crypto,
transportIndex, secret);
for(int i = 0; i < 100; i++) { for(int i = 0; i < 100; i++) {
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
w.setSeen(i); w.setSeen(i);
@@ -22,7 +43,8 @@ public class ConnectionWindowImplTest extends TestCase {
@Test @Test
public void testWindowJumping() { public void testWindowJumping() {
ConnectionWindowImpl w = new ConnectionWindowImpl(); ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex,
secret);
for(int i = 0; i < 100; i += 13) { for(int i = 0; i < 100; i += 13) {
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
w.setSeen(i); w.setSeen(i);
@@ -32,7 +54,8 @@ public class ConnectionWindowImplTest extends TestCase {
@Test @Test
public void testWindowUpperLimit() { public void testWindowUpperLimit() {
ConnectionWindowImpl w = new ConnectionWindowImpl(); ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex,
secret);
// Centre is 0, highest value in window is 15 // Centre is 0, highest value in window is 15
w.setSeen(15); w.setSeen(15);
// Centre is 16, highest value in window is 31 // Centre is 16, highest value in window is 31
@@ -43,11 +66,11 @@ public class ConnectionWindowImplTest extends TestCase {
fail(); fail();
} catch(IllegalArgumentException expected) {} } catch(IllegalArgumentException expected) {}
// Values greater than 2^32 - 1 should never be allowed // Values greater than 2^32 - 1 should never be allowed
Collection<Long> unseen = new ArrayList<Long>(); Map<Long, byte[]> unseen = new HashMap<Long, byte[]>();
for(int i = 0; i < 32; i++) { for(int i = 0; i < 32; i++) {
unseen.add(ByteUtils.MAX_32_BIT_UNSIGNED - i); unseen.put(ByteUtils.MAX_32_BIT_UNSIGNED - i, secret);
} }
w = new ConnectionWindowImpl(unseen); w = new ConnectionWindowImpl(crypto, transportIndex, unseen);
w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED); w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED);
try { try {
w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED + 1); w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED + 1);
@@ -57,7 +80,8 @@ public class ConnectionWindowImplTest extends TestCase {
@Test @Test
public void testWindowLowerLimit() { public void testWindowLowerLimit() {
ConnectionWindowImpl w = new ConnectionWindowImpl(); ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex,
secret);
// Centre is 0, negative values should never be allowed // Centre is 0, negative values should never be allowed
try { try {
w.setSeen(-1); w.setSeen(-1);
@@ -87,7 +111,8 @@ public class ConnectionWindowImplTest extends TestCase {
@Test @Test
public void testCannotSetSeenTwice() { public void testCannotSetSeenTwice() {
ConnectionWindowImpl w = new ConnectionWindowImpl(); ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex,
secret);
w.setSeen(15); w.setSeen(15);
try { try {
w.setSeen(15); w.setSeen(15);
@@ -97,12 +122,13 @@ public class ConnectionWindowImplTest extends TestCase {
@Test @Test
public void testGetUnseenConnectionNumbers() { public void testGetUnseenConnectionNumbers() {
ConnectionWindowImpl w = new ConnectionWindowImpl(); ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex,
secret);
// Centre is 0; window should cover 0 to 15, inclusive, with none seen // Centre is 0; window should cover 0 to 15, inclusive, with none seen
Collection<Long> unseen = w.getUnseen(); Map<Long, byte[]> unseen = w.getUnseen();
assertEquals(16, unseen.size()); assertEquals(16, unseen.size());
for(int i = 0; i < 16; i++) { for(int i = 0; i < 16; i++) {
assertTrue(unseen.contains(Long.valueOf(i))); assertTrue(unseen.containsKey(Long.valueOf(i)));
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
} }
w.setSeen(3); w.setSeen(3);
@@ -112,10 +138,10 @@ public class ConnectionWindowImplTest extends TestCase {
assertEquals(19, unseen.size()); assertEquals(19, unseen.size());
for(int i = 0; i < 21; i++) { for(int i = 0; i < 21; i++) {
if(i == 3 || i == 4) { if(i == 3 || i == 4) {
assertFalse(unseen.contains(Long.valueOf(i))); assertFalse(unseen.containsKey(Long.valueOf(i)));
assertTrue(w.isSeen(i)); assertTrue(w.isSeen(i));
} else { } else {
assertTrue(unseen.contains(Long.valueOf(i))); assertTrue(unseen.containsKey(Long.valueOf(i)));
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
} }
} }
@@ -125,10 +151,10 @@ public class ConnectionWindowImplTest extends TestCase {
assertEquals(30, unseen.size()); assertEquals(30, unseen.size());
for(int i = 4; i < 36; i++) { for(int i = 4; i < 36; i++) {
if(i == 4 || i == 19) { if(i == 4 || i == 19) {
assertFalse(unseen.contains(Long.valueOf(i))); assertFalse(unseen.containsKey(Long.valueOf(i)));
assertTrue(w.isSeen(i)); assertTrue(w.isSeen(i));
} else { } else {
assertTrue(unseen.contains(Long.valueOf(i))); assertTrue(unseen.containsKey(Long.valueOf(i)));
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
} }
} }

View File

@@ -8,7 +8,10 @@ import java.util.Random;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.TestDatabaseModule; import net.sf.briar.TestDatabaseModule;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionContextFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
@@ -26,8 +29,10 @@ import com.google.inject.Injector;
public class ConnectionWriterTest extends TestCase { public class ConnectionWriterTest extends TestCase {
private final ConnectionContextFactory connectionContextFactory;
private final ConnectionWriterFactory connectionWriterFactory; private final ConnectionWriterFactory connectionWriterFactory;
private final byte[] outSecret; private final byte[] secret;
private final ContactId contactId = new ContactId(13);
private final TransportIndex transportIndex = new TransportIndex(13); private final TransportIndex transportIndex = new TransportIndex(13);
private final long connection = 12345L; private final long connection = 12345L;
@@ -38,17 +43,22 @@ public class ConnectionWriterTest extends TestCase {
new ProtocolWritersModule(), new SerialModule(), new ProtocolWritersModule(), new SerialModule(),
new TestDatabaseModule(), new TransportBatchModule(), new TestDatabaseModule(), new TransportBatchModule(),
new TransportModule(), new TransportStreamModule()); new TransportModule(), new TransportStreamModule());
connectionContextFactory =
i.getInstance(ConnectionContextFactory.class);
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
outSecret = new byte[32]; secret = new byte[32];
new Random().nextBytes(outSecret); new Random().nextBytes(secret);
} }
@Test @Test
public void testOverhead() throws Exception { public void testOverhead() throws Exception {
ByteArrayOutputStream out = ByteArrayOutputStream out =
new ByteArrayOutputStream(MIN_CONNECTION_LENGTH); new ByteArrayOutputStream(MIN_CONNECTION_LENGTH);
ConnectionContext ctx =
connectionContextFactory.createConnectionContext(contactId,
transportIndex, connection, secret);
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
MIN_CONNECTION_LENGTH, transportIndex, connection, outSecret); MIN_CONNECTION_LENGTH, ctx);
// Check that the connection writer thinks there's room for a packet // Check that the connection writer thinks there's room for a packet
long capacity = w.getRemainingCapacity(); long capacity = w.getRemainingCapacity();
assertTrue(capacity >= MAX_PACKET_LENGTH); assertTrue(capacity >= MAX_PACKET_LENGTH);

View File

@@ -64,7 +64,8 @@ public class FrameReadWriteTest extends TestCase {
private void testWriteAndRead(boolean initiator) throws Exception { private void testWriteAndRead(boolean initiator) throws Exception {
// Create and encrypt the IV // Create and encrypt the IV
byte[] iv = IvEncoder.encodeIv(initiator, transportIndex, connection); byte[] iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(),
connection);
ivCipher.init(Cipher.ENCRYPT_MODE, ivKey); ivCipher.init(Cipher.ENCRYPT_MODE, ivKey);
byte[] encryptedIv = ivCipher.doFinal(iv); byte[] encryptedIv = ivCipher.doFinal(iv);
assertEquals(IV_LENGTH, encryptedIv.length); assertEquals(IV_LENGTH, encryptedIv.length);
@@ -92,7 +93,7 @@ public class FrameReadWriteTest extends TestCase {
// Decrypt the IV // Decrypt the IV
ivCipher.init(Cipher.DECRYPT_MODE, ivKey); ivCipher.init(Cipher.DECRYPT_MODE, ivKey);
byte[] recoveredIv = ivCipher.doFinal(recoveredEncryptedIv); byte[] recoveredIv = ivCipher.doFinal(recoveredEncryptedIv);
iv = IvEncoder.encodeIv(initiator, transportIndex, connection); iv = IvEncoder.encodeIv(initiator, transportIndex.getInt(), connection);
assertArrayEquals(iv, recoveredIv); assertArrayEquals(iv, recoveredIv);
// Read the frames back // Read the frames back
ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, iv, ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, iv,

View File

@@ -119,7 +119,7 @@ public class BatchConnectionReadWriteTest extends TestCase {
alice.getInstance(ProtocolWriterFactory.class); alice.getInstance(ProtocolWriterFactory.class);
BatchTransportWriter writer = new TestBatchTransportWriter(out); BatchTransportWriter writer = new TestBatchTransportWriter(out);
OutgoingBatchConnection batchOut = new OutgoingBatchConnection( OutgoingBatchConnection batchOut = new OutgoingBatchConnection(
connFactory, db, protoFactory, transportIndex, contactId, connFactory, db, protoFactory, contactId, transportIndex,
writer); writer);
// Write whatever needs to be written // Write whatever needs to be written
batchOut.write(); batchOut.write();
@@ -170,8 +170,7 @@ public class BatchConnectionReadWriteTest extends TestCase {
bob.getInstance(ProtocolReaderFactory.class); bob.getInstance(ProtocolReaderFactory.class);
BatchTransportReader reader = new TestBatchTransportReader(in); BatchTransportReader reader = new TestBatchTransportReader(in);
IncomingBatchConnection batchIn = new IncomingBatchConnection( IncomingBatchConnection batchIn = new IncomingBatchConnection(
connFactory, db, protoFactory, transportIndex, contactId, connFactory, db, protoFactory, ctx, reader, encryptedIv);
reader, encryptedIv);
// No messages should have been added yet // No messages should have been added yet
assertFalse(listener.messagesAdded); assertFalse(listener.messagesAdded);
// Read whatever needs to be read // Read whatever needs to be read

View File

@@ -40,4 +40,8 @@ public class ByteUtils {
return ((b[offset] & 0xFFL) << 24) | ((b[offset + 1] & 0xFFL) << 16) return ((b[offset] & 0xFFL) << 24) | ((b[offset + 1] & 0xFFL) << 16)
| ((b[offset + 2] & 0xFFL) << 8) | (b[offset + 3] & 0xFFL); | ((b[offset + 2] & 0xFFL) << 8) | (b[offset + 3] & 0xFFL);
} }
public static void erase(byte[] b) {
for(int i = 0; i < b.length; i++) b[i] = 0;
}
} }