First stage of key rotation refactoring. Some tests are failing.

This commit is contained in:
akwizgran
2012-09-23 17:40:54 +01:00
parent eb360475aa
commit e779210ced
78 changed files with 1601 additions and 3411 deletions

View File

@@ -0,0 +1,36 @@
package net.sf.briar.api;
import net.sf.briar.api.protocol.TransportId;
public class ContactTransport extends TemporarySecret {
private final long epoch, clockDiff, latency;
private final boolean alice;
public ContactTransport(ContactId contactId, TransportId transportId,
long epoch, long clockDiff, long latency, boolean alice,
long period, byte[] secret, long outgoing, long centre,
byte[] bitmap) {
super(contactId, transportId, period, secret, outgoing, centre, bitmap);
this.epoch = epoch;
this.clockDiff = clockDiff;
this.latency = latency;
this.alice = alice;
}
public long getEpoch() {
return epoch;
}
public long getClockDifference() {
return clockDiff;
}
public long getLatency() {
return latency;
}
public boolean getAlice() {
return alice;
}
}

View File

@@ -0,0 +1,51 @@
package net.sf.briar.api;
import net.sf.briar.api.protocol.TransportId;
public class TemporarySecret {
protected final ContactId contactId;
protected final TransportId transportId;
protected final long period, outgoing, centre;
protected final byte[] secret, bitmap;
public TemporarySecret(ContactId contactId, TransportId transportId,
long period, byte[] secret, long outgoing, long centre,
byte[] bitmap) {
this.contactId = contactId;
this.transportId = transportId;
this.period = period;
this.secret = secret;
this.outgoing = outgoing;
this.centre = centre;
this.bitmap = bitmap;
}
public ContactId getContactId() {
return contactId;
}
public TransportId getTransportId() {
return transportId;
}
public long getPeriod() {
return period;
}
public byte[] getSecret() {
return secret;
}
public long getOutgoingConnectionCounter() {
return outgoing;
}
public long getWindowCentre() {
return centre;
}
public byte[] getWindowBitmap() {
return bitmap;
}
}

View File

@@ -9,16 +9,49 @@ import javax.crypto.Cipher;
public interface CryptoComponent { public interface CryptoComponent {
ErasableKey deriveTagKey(byte[] secret, boolean initiator); /**
* Derives a tag key from the given temporary secret.
* @param alice Indicates whether the key is for connections initiated by
* Alice or Bob.
*/
ErasableKey deriveTagKey(byte[] secret, boolean alice);
ErasableKey deriveFrameKey(byte[] secret, boolean initiator); /**
* Derives a frame key from the given temporary secret and connection
* number.
* @param alice Indicates whether the key is for a connection initiated by
* Alice or Bob.
* @param initiator Indicates whether the key is for the initiator's or the
* responder's side of the connection.
*/
ErasableKey deriveFrameKey(byte[] secret, long connection, boolean alice,
boolean initiator);
byte[][] deriveInitialSecrets(byte[] ourPublicKey, byte[] theirPublicKey, /**
PrivateKey ourPrivateKey, int invitationCode, boolean initiator); * Derives an initial shared secret from two public keys and one of the
* corresponding private keys.
* @param alice Indicates whether the private key belongs to Alice or Bob.
*/
byte[] deriveInitialSecret(byte[] ourPublicKey, byte[] theirPublicKey,
PrivateKey ourPrivateKey, boolean alice);
int deriveConfirmationCode(byte[] secret); /**
* Generates a random invitation code.
*/
int generateInvitationCode();
byte[] deriveNextSecret(byte[] secret, int index, long connection); /**
* Derives two confirmation codes from the given initial shared secret. The
* first code is for Alice to give to Bob; the second is for Bob to give to
* Alice.
*/
int[] deriveConfirmationCodes(byte[] secret);
/**
* Derives a temporary secret for the given period from the previous
* period's temporary secret.
*/
byte[] deriveNextSecret(byte[] secret, long period);
KeyPair generateAgreementKeyPair(); KeyPair generateAgreementKeyPair();

View File

@@ -0,0 +1,15 @@
package net.sf.briar.api.crypto;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
public interface KeyManager {
/**
* Returns a connection context for connecting to the given contact over
* the given transport, or null if the contact does not support the
* transport.
*/
ConnectionContext getConnectionContext(ContactId c, TransportId t);
}

View File

@@ -5,7 +5,9 @@ import java.util.Collection;
import java.util.Map; import java.util.Map;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.ContactTransport;
import net.sf.briar.api.Rating; import net.sf.briar.api.Rating;
import net.sf.briar.api.TemporarySecret;
import net.sf.briar.api.TransportConfig; import net.sf.briar.api.TransportConfig;
import net.sf.briar.api.TransportProperties; import net.sf.briar.api.TransportProperties;
import net.sf.briar.api.db.event.DatabaseListener; import net.sf.briar.api.db.event.DatabaseListener;
@@ -22,10 +24,7 @@ import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWindow;
/** /**
* Encapsulates the database implementation and exposes high-level operations * Encapsulates the database implementation and exposes high-level operations
@@ -50,10 +49,9 @@ public interface DatabaseComponent {
void removeListener(DatabaseListener d); void removeListener(DatabaseListener d);
/** /**
* Adds a new contact to the database with the given secrets and returns an * Adds a new contact to the database and returns an ID for the contact.
* ID for the contact.
*/ */
ContactId addContact(byte[] inSecret, byte[] outSecret) throws DbException; ContactId addContact() throws DbException;
/** Adds a locally generated group message to the database. */ /** Adds a locally generated group message to the database. */
void addLocalGroupMessage(Message m) throws DbException; void addLocalGroupMessage(Message m) throws DbException;
@@ -62,10 +60,10 @@ public interface DatabaseComponent {
void addLocalPrivateMessage(Message m, ContactId c) throws DbException; void addLocalPrivateMessage(Message m, ContactId c) throws DbException;
/** /**
* Allocates and returns a local index for the given transport. Returns * Stores the given temporary secrets and deletes any secrets that have
* null if all indices have been allocated. * been made obsolete.
*/ */
TransportIndex addTransport(TransportId t) throws DbException; void addSecrets(Collection<TemporarySecret> secrets) throws DbException;
/** /**
* Generates an acknowledgement for the given contact. Returns null if * Generates an acknowledgement for the given contact. Returns null if
@@ -101,7 +99,7 @@ public interface DatabaseComponent {
* an update is not due. * an update is not due.
*/ */
SubscriptionUpdate generateSubscriptionUpdate(ContactId c) SubscriptionUpdate generateSubscriptionUpdate(ContactId c)
throws DbException; throws DbException;
/** /**
* Generates a transport update for the given contact. Returns null if an * Generates a transport update for the given contact. Returns null if an
@@ -112,28 +110,11 @@ public interface DatabaseComponent {
/** Returns the configuration for the given transport. */ /** Returns the configuration for the given transport. */
TransportConfig getConfig(TransportId t) throws DbException; TransportConfig getConfig(TransportId t) throws DbException;
/**
* Returns an outgoing connection context for the given contact and
* transport.
*/
ConnectionContext getConnectionContext(ContactId c, TransportIndex i)
throws DbException;
/**
* Returns the connection reordering window for the given contact and
* transport.
*/
ConnectionWindow getConnectionWindow(ContactId c, TransportIndex i)
throws DbException;
/** Returns the IDs of all contacts. */ /** Returns the IDs of all contacts. */
Collection<ContactId> getContacts() throws DbException; Collection<ContactId> getContacts() throws DbException;
/** /** Returns all contact transports. */
* Returns the local index for the given transport, or null if no index Collection<ContactTransport> getContactTransports() throws DbException;
* has been allocated.
*/
TransportIndex getLocalIndex(TransportId t) throws DbException;
/** Returns the local transport properties for the given transport. */ /** Returns the local transport properties for the given transport. */
TransportProperties getLocalProperties(TransportId t) throws DbException; TransportProperties getLocalProperties(TransportId t) throws DbException;
@@ -147,16 +128,9 @@ public interface DatabaseComponent {
/** Returns the user's rating for the given author. */ /** Returns the user's rating for the given author. */
Rating getRating(AuthorId a) throws DbException; Rating getRating(AuthorId a) throws DbException;
/**
* Returns the given contact's index for the given transport, or null if
* the contact does not support the transport.
*/
TransportIndex getRemoteIndex(ContactId c, TransportId t)
throws DbException;
/** Returns all remote transport properties for the given transport. */ /** Returns all remote transport properties for the given transport. */
Map<ContactId, TransportProperties> getRemoteProperties(TransportId t) Map<ContactId, TransportProperties> getRemoteProperties(TransportId t)
throws DbException; throws DbException;
/** Returns the set of groups to which the user subscribes. */ /** Returns the set of groups to which the user subscribes. */
Collection<Group> getSubscriptions() throws DbException; Collection<Group> getSubscriptions() throws DbException;
@@ -170,6 +144,13 @@ public interface DatabaseComponent {
/** Returns true if any messages are sendable to the given contact. */ /** Returns true if any messages are sendable to the given contact. */
boolean hasSendableMessages(ContactId c) throws DbException; boolean hasSendableMessages(ContactId c) throws DbException;
/**
* Increments the outgoing connection counter for the given contact
* transport in the given rotation period.
*/
void incrementConnectionCounter(ContactId c, TransportId t, long period)
throws DbException;
/** Processes an acknowledgement from the given contact. */ /** Processes an acknowledgement from the given contact. */
void receiveAck(ContactId c, Ack a) throws DbException; void receiveAck(ContactId c, Ack a) throws DbException;
@@ -188,11 +169,11 @@ public interface DatabaseComponent {
/** Processes a subscription update from the given contact. */ /** Processes a subscription update from the given contact. */
void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s) void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s)
throws DbException; throws DbException;
/** Processes a transport update from the given contact. */ /** Processes a transport update from the given contact. */
void receiveTransportUpdate(ContactId c, TransportUpdate t) void receiveTransportUpdate(ContactId c, TransportUpdate t)
throws DbException; throws DbException;
/** Removes a contact (and all associated state) from the database. */ /** Removes a contact (and all associated state) from the database. */
void removeContact(ContactId c) throws DbException; void removeContact(ContactId c) throws DbException;
@@ -204,18 +185,18 @@ public interface DatabaseComponent {
void setConfig(TransportId t, TransportConfig c) throws DbException; void setConfig(TransportId t, TransportConfig c) throws DbException;
/** /**
* Sets the connection reordering window for the given contact and * Sets the connection reordering window for the given contact transport
* transport. * in the given rotation period.
*/ */
void setConnectionWindow(ContactId c, TransportIndex i, void setConnectionWindow(ContactId c, TransportId t, long period,
ConnectionWindow w) throws DbException; long centre, byte[] bitmap) throws DbException;
/** /**
* Sets the local transport properties for the given transport, replacing * Sets the local transport properties for the given transport, replacing
* any existing properties for that transport. * any existing properties for that transport.
*/ */
void setLocalProperties(TransportId t, TransportProperties p) void setLocalProperties(TransportId t, TransportProperties p)
throws DbException; throws DbException;
/** Records the user's rating for the given author. */ /** Records the user's rating for the given author. */
void setRating(AuthorId a, Rating r) throws DbException; void setRating(AuthorId a, Rating r) throws DbException;
@@ -228,7 +209,7 @@ public interface DatabaseComponent {
* to any other contacts. * to any other contacts.
*/ */
void setVisibility(GroupId g, Collection<ContactId> visible) void setVisibility(GroupId g, Collection<ContactId> visible)
throws DbException; throws DbException;
/** Subscribes to the given group. */ /** Subscribes to the given group. */
void subscribe(Group g) throws DbException; void subscribe(Group g) throws DbException;

View File

@@ -7,8 +7,4 @@ package net.sf.briar.api.db;
public class NoSuchContactException extends DbException { public class NoSuchContactException extends DbException {
private static final long serialVersionUID = -7048538231308207386L; private static final long serialVersionUID = -7048538231308207386L;
public NoSuchContactException() {
super();
}
} }

View File

@@ -0,0 +1,10 @@
package net.sf.briar.api.db;
/**
* Thrown when a database operation is attempted for a contact transport that
* is not in the database.
*/
public class NoSuchContactTransportException extends DbException {
private static final long serialVersionUID = -6274982612759573100L;
}

View File

@@ -1,17 +0,0 @@
package net.sf.briar.api.db.event;
import net.sf.briar.api.protocol.TransportId;
/** An event that is broadcast when a transport is added. */
public class TransportAddedEvent extends DatabaseEvent {
private final TransportId transportId;
public TransportAddedEvent(TransportId transportId) {
this.transportId = transportId;
}
public TransportId getTransportId() {
return transportId;
}
}

View File

@@ -1,6 +0,0 @@
package net.sf.briar.api.plugins;
public interface IncomingInvitationCallback extends InvitationCallback {
int enterInvitationCode();
}

View File

@@ -1,14 +0,0 @@
package net.sf.briar.api.plugins;
public interface InvitationCallback {
boolean isCancelled();
int enterConfirmationCode(int code);
void showProgress(String... message);
void showFailure(String... message);
void showSuccess();
}

View File

@@ -1,12 +0,0 @@
package net.sf.briar.api.plugins;
import net.sf.briar.api.plugins.duplex.DuplexPlugin;
public interface InvitationStarter {
void startIncomingInvitation(DuplexPlugin plugin,
IncomingInvitationCallback callback);
void startOutgoingInvitation(DuplexPlugin plugin,
OutgoingInvitationCallback callback);
}

View File

@@ -1,6 +0,0 @@
package net.sf.briar.api.plugins;
public interface OutgoingInvitationCallback extends InvitationCallback {
void showInvitationCode(int code);
}

View File

@@ -8,29 +8,21 @@ public class Transport extends TreeMap<String, String> {
private static final long serialVersionUID = 4900420175715429560L; private static final long serialVersionUID = 4900420175715429560L;
private final TransportId id; private final TransportId id;
private final TransportIndex index;
public Transport(TransportId id, TransportIndex index, public Transport(TransportId id, Map<String, String> p) {
Map<String, String> p) {
super(p); super(p);
this.id = id; this.id = id;
this.index = index;
} }
public Transport(TransportId id, TransportIndex index) { public Transport(TransportId id) {
super(); super();
this.id = id; this.id = id;
this.index = index;
} }
public TransportId getId() { public TransportId getId() {
return id; return id;
} }
public TransportIndex getIndex() {
return index;
}
@Override @Override
public int hashCode() { public int hashCode() {
return id.hashCode(); return id.hashCode();
@@ -40,7 +32,7 @@ public class Transport extends TreeMap<String, String> {
public boolean equals(Object o) { public boolean equals(Object o) {
if(o instanceof Transport) { if(o instanceof Transport) {
Transport t = (Transport) o; Transport t = (Transport) o;
return id.equals(t.id) && index.equals(t.index) && super.equals(o); return id.equals(t.id) && super.equals(o);
} }
return false; return false;
} }

View File

@@ -1,33 +0,0 @@
package net.sf.briar.api.protocol;
/**
* Type-safe wrapper for an integer that uniquely identifies a transport plugin
* within the scope of a single node.
*/
public class TransportIndex {
private final int index;
public TransportIndex(int index) {
if(index < 0 || index >= ProtocolConstants.MAX_TRANSPORTS)
throw new IllegalArgumentException();
this.index = index;
}
public int getInt() {
return index;
}
@Override
public boolean equals(Object o) {
if(o instanceof TransportIndex)
return index == ((TransportIndex) o).index;
return false;
}
@Override
public int hashCode() {
return index;
}
}

View File

@@ -3,14 +3,12 @@ package net.sf.briar.api.protocol.duplex;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.plugins.duplex.DuplexTransportConnection; import net.sf.briar.api.plugins.duplex.DuplexTransportConnection;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
public interface DuplexConnectionFactory { public interface DuplexConnectionFactory {
void createIncomingConnection(ConnectionContext ctx, TransportId t, void createIncomingConnection(ConnectionContext ctx, DuplexTransportConnection d);
DuplexTransportConnection d);
void createOutgoingConnection(ContactId c, TransportId t, TransportIndex i, void createOutgoingConnection(ContactId c, TransportId t,
DuplexTransportConnection d); DuplexTransportConnection d);
} }

View File

@@ -4,14 +4,12 @@ import net.sf.briar.api.ContactId;
import net.sf.briar.api.plugins.simplex.SimplexTransportReader; import net.sf.briar.api.plugins.simplex.SimplexTransportReader;
import net.sf.briar.api.plugins.simplex.SimplexTransportWriter; import net.sf.briar.api.plugins.simplex.SimplexTransportWriter;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
public interface SimplexConnectionFactory { public interface SimplexConnectionFactory {
void createIncomingConnection(ConnectionContext ctx, TransportId t, void createIncomingConnection(ConnectionContext ctx, SimplexTransportReader r);
SimplexTransportReader r);
void createOutgoingConnection(ContactId c, TransportId t, TransportIndex i, void createOutgoingConnection(ContactId c, TransportId t,
SimplexTransportWriter w); SimplexTransportWriter w);
} }

View File

@@ -1,15 +1,47 @@
package net.sf.briar.api.transport; package net.sf.briar.api.transport;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportId;
public interface ConnectionContext { public class ConnectionContext {
ContactId getContactId(); private final ContactId contactId;
private final TransportId transportId;
private final byte[] tag, secret;
private final long connection;
private final boolean alice;
TransportIndex getTransportIndex(); public ConnectionContext(ContactId contactId, TransportId transportId,
byte[] tag, byte[] secret, long connection, boolean alice) {
this.contactId = contactId;
this.transportId = transportId;
this.tag = tag;
this.secret = secret;
this.connection = connection;
this.alice = alice;
}
long getConnectionNumber(); public ContactId getContactId() {
return contactId;
}
byte[] getSecret(); public TransportId getTransportId() {
return transportId;
}
public byte[] getTag() {
return tag;
}
public byte[] getSecret() {
return secret;
}
public long getConnectionNumber() {
return connection;
}
public boolean getAlice() {
return alice;
}
} }

View File

@@ -1,13 +0,0 @@
package net.sf.briar.api.transport;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportIndex;
public interface ConnectionContextFactory {
ConnectionContext createConnectionContext(ContactId c, TransportIndex i,
long connection, byte[] secret);
ConnectionContext createNextConnectionContext(ContactId c, TransportIndex i,
long connection, byte[] previousSecret);
}

View File

@@ -5,17 +5,16 @@ import net.sf.briar.api.plugins.duplex.DuplexTransportConnection;
import net.sf.briar.api.plugins.simplex.SimplexTransportReader; import net.sf.briar.api.plugins.simplex.SimplexTransportReader;
import net.sf.briar.api.plugins.simplex.SimplexTransportWriter; import net.sf.briar.api.plugins.simplex.SimplexTransportWriter;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
public interface ConnectionDispatcher { public interface ConnectionDispatcher {
void dispatchReader(TransportId t, SimplexTransportReader r); void dispatchReader(TransportId t, SimplexTransportReader r);
void dispatchWriter(ContactId c, TransportId t, TransportIndex i, void dispatchWriter(ContactId c, TransportId t,
SimplexTransportWriter w); SimplexTransportWriter w);
void dispatchIncomingConnection(TransportId t, DuplexTransportConnection d); void dispatchIncomingConnection(TransportId t, DuplexTransportConnection d);
void dispatchOutgoingConnection(ContactId c, TransportId t, void dispatchOutgoingConnection(ContactId c, TransportId t,
TransportIndex i, DuplexTransportConnection d); DuplexTransportConnection d);
} }

View File

@@ -8,6 +8,6 @@ public interface ConnectionReaderFactory {
* Creates a connection reader for a simplex connection or one side of a * Creates a connection reader for a simplex connection or one side of a
* duplex connection. The secret is erased before this method returns. * duplex connection. The secret is erased before this method returns.
*/ */
ConnectionReader createConnectionReader(InputStream in, byte[] secret, ConnectionReader createConnectionReader(InputStream in,
boolean initiator); ConnectionContext ctx, boolean initiator);
} }

View File

@@ -1,14 +1,12 @@
package net.sf.briar.api.transport; package net.sf.briar.api.transport;
import java.util.Map; import java.util.Set;
public interface ConnectionWindow { public interface ConnectionWindow {
boolean isSeen(long connection); boolean isSeen(long connection);
byte[] setSeen(long connection); void setSeen(long connection);
Map<Long, byte[]> getUnseen(); Set<Long> getUnseen();
void erase();
} }

View File

@@ -1,13 +0,0 @@
package net.sf.briar.api.transport;
import java.util.Map;
import net.sf.briar.api.protocol.TransportIndex;
public interface ConnectionWindowFactory {
ConnectionWindow createConnectionWindow(TransportIndex i, byte[] secret);
ConnectionWindow createConnectionWindow(TransportIndex i,
Map<Long, byte[]> unseen);
}

View File

@@ -9,5 +9,5 @@ public interface ConnectionWriterFactory {
* duplex connection. The secret is erased before this method returns. * duplex connection. The secret is erased before this method returns.
*/ */
ConnectionWriter createConnectionWriter(OutputStream out, long capacity, ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
byte[] secret, boolean initiator); ConnectionContext ctx, boolean initiator);
} }

View File

@@ -1,6 +1,7 @@
package net.sf.briar.crypto; package net.sf.briar.crypto;
import static net.sf.briar.api.plugins.InvitationConstants.CODE_BITS; import static net.sf.briar.api.plugins.InvitationConstants.CODE_BITS;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.KeyPair; import java.security.KeyPair;
@@ -48,18 +49,23 @@ class CryptoComponentImpl implements CryptoComponent {
private static final int GCM_MAC_LENGTH = 16; // 128 bits private static final int GCM_MAC_LENGTH = 16; // 128 bits
// Labels for key derivation // Labels for key derivation
private static final byte[] TAG = { 'T', 'A', 'G' }; private static final byte[] A_TAG = { 'A', '_', 'T', 'A', 'G', '\0' };
private static final byte[] FRAME = { 'F', 'R', 'A', 'M', 'E' }; private static final byte[] B_TAG = { 'B', '_', 'T', 'A', 'G', '\0' };
private static final byte[] A_FRAME_A =
{ 'A', '_', 'F', 'R', 'A', 'M', 'E', '_', 'A', '\0' };
private static final byte[] A_FRAME_B =
{ 'A', '_', 'F', 'R', 'A', 'M', 'E', '_', 'B', '\0' };
private static final byte[] B_FRAME_A =
{ 'B', '_', 'F', 'R', 'A', 'M', 'E', '_', 'A', '\0' };
private static final byte[] B_FRAME_B =
{ 'B', '_', 'F', 'R', 'A', 'M', 'E', '_', 'B', '\0' };
// Labels for secret derivation // Labels for secret derivation
private static final byte[] FIRST = { 'F', 'I', 'R', 'S', 'T' }; private static final byte[] FIRST = { 'F', 'I', 'R', 'S', 'T', '\0' };
private static final byte[] NEXT = { 'N', 'E', 'X', 'T' }; private static final byte[] ROTATE = { 'R', 'O', 'T', 'A', 'T', 'E', '\0' };
// Label for confirmation code derivation // Label for confirmation code derivation
private static final byte[] CODE = { 'C', 'O', 'D', 'E' }; private static final byte[] CODE = { 'C', 'O', 'D', 'E', '\0' };
// Context strings for key and confirmation code derivation
private static final byte[] INITIATOR = { 'I' };
private static final byte[] RESPONDER = { 'R' };
// Blank plaintext for key derivation // Blank plaintext for key derivation
private static final byte[] KEY_DERIVATION_INPUT = private static final byte[] KEY_DERIVATION_BLANK_PLAINTEXT =
new byte[SECRET_KEY_BYTES]; new byte[SECRET_KEY_BYTES];
private final KeyParser agreementKeyParser, signatureKeyParser; private final KeyParser agreementKeyParser, signatureKeyParser;
@@ -87,43 +93,46 @@ class CryptoComponentImpl implements CryptoComponent {
secureRandom = new SecureRandom(); secureRandom = new SecureRandom();
} }
public ErasableKey deriveTagKey(byte[] secret, boolean initiator) { public ErasableKey deriveTagKey(byte[] secret, boolean alice) {
if(initiator) return deriveKey(secret, TAG, INITIATOR); if(alice) return deriveKey(secret, A_TAG, 0L);
else return deriveKey(secret, TAG, RESPONDER); else return deriveKey(secret, B_TAG, 0L);
} }
public ErasableKey deriveFrameKey(byte[] secret, boolean initiator) { public ErasableKey deriveFrameKey(byte[] secret, long connection,
if(initiator) return deriveKey(secret, FRAME, INITIATOR); boolean alice, boolean initiator) {
else return deriveKey(secret, FRAME, RESPONDER); if(alice) {
if(initiator) return deriveKey(secret, A_FRAME_A, connection);
else return deriveKey(secret, A_FRAME_B, connection);
} else {
if(initiator) return deriveKey(secret, B_FRAME_A, connection);
else return deriveKey(secret, B_FRAME_B, connection);
}
} }
private ErasableKey deriveKey(byte[] secret, byte[] label, byte[] context) { private ErasableKey deriveKey(byte[] secret, byte[] label, long context) {
byte[] key = counterModeKdf(secret, label, context); byte[] key = counterModeKdf(secret, label, context);
return new ErasableKeyImpl(key, SECRET_KEY_ALGO); return new ErasableKeyImpl(key, SECRET_KEY_ALGO);
} }
// Key derivation function based on a block cipher in CTR mode - see // Key derivation function based on a block cipher in CTR mode - see
// NIST SP 800-108, section 5.1 // NIST SP 800-108, section 5.1
private byte[] counterModeKdf(byte[] secret, byte[] label, byte[] context) { private byte[] counterModeKdf(byte[] secret, byte[] label, long context) {
// The secret must be usable as a key // The secret must be usable as a key
if(secret.length != SECRET_KEY_BYTES) if(secret.length != SECRET_KEY_BYTES)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
// The label and context must leave a byte free for the counter // The label and context must leave a byte free for the counter
if(label.length + context.length + 2 >= KEY_DERIVATION_IV_BYTES) if(label.length + 4 >= KEY_DERIVATION_IV_BYTES)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
// The IV contains the length-prefixed label and context
byte[] ivBytes = new byte[KEY_DERIVATION_IV_BYTES]; byte[] ivBytes = new byte[KEY_DERIVATION_IV_BYTES];
ByteUtils.writeUint8(label.length, ivBytes, 0); System.arraycopy(label, 0, ivBytes, 0, label.length);
System.arraycopy(label, 0, ivBytes, 1, label.length); ByteUtils.writeUint32(context, ivBytes, label.length);
ByteUtils.writeUint8(context.length, ivBytes, label.length + 1);
System.arraycopy(context, 0, ivBytes, label.length + 2, context.length);
// Use the secret and the IV to encrypt a blank plaintext // Use the secret and the IV to encrypt a blank plaintext
IvParameterSpec iv = new IvParameterSpec(ivBytes); IvParameterSpec iv = new IvParameterSpec(ivBytes);
ErasableKey key = new ErasableKeyImpl(secret, SECRET_KEY_ALGO); ErasableKey key = new ErasableKeyImpl(secret, SECRET_KEY_ALGO);
try { try {
Cipher cipher = Cipher.getInstance(KEY_DERIVATION_ALGO, PROVIDER); Cipher cipher = Cipher.getInstance(KEY_DERIVATION_ALGO, PROVIDER);
cipher.init(Cipher.ENCRYPT_MODE, key, iv); cipher.init(Cipher.ENCRYPT_MODE, key, iv);
byte[] output = cipher.doFinal(KEY_DERIVATION_INPUT); byte[] output = cipher.doFinal(KEY_DERIVATION_BLANK_PLAINTEXT);
assert output.length == SECRET_KEY_BYTES; assert output.length == SECRET_KEY_BYTES;
return output; return output;
} catch(GeneralSecurityException e) { } catch(GeneralSecurityException e) {
@@ -131,27 +140,22 @@ class CryptoComponentImpl implements CryptoComponent {
} }
} }
public byte[][] deriveInitialSecrets(byte[] ourPublicKey, public byte[] deriveInitialSecret(byte[] ourPublicKey,
byte[] theirPublicKey, PrivateKey ourPrivateKey, int invitationCode, byte[] theirPublicKey, PrivateKey ourPrivateKey, boolean alice) {
boolean initiator) {
try { try {
PublicKey theirPublic = agreementKeyParser.parsePublicKey( PublicKey theirPublic = agreementKeyParser.parsePublicKey(
theirPublicKey); theirPublicKey);
MessageDigest messageDigest = getMessageDigest(); MessageDigest messageDigest = getMessageDigest();
byte[] ourHash = messageDigest.digest(ourPublicKey); byte[] ourHash = messageDigest.digest(ourPublicKey);
byte[] theirHash = messageDigest.digest(theirPublicKey); byte[] theirHash = messageDigest.digest(theirPublicKey);
// The initiator and responder info are hashes of the public keys byte[] aliceInfo, bobInfo;
byte[] initiatorInfo, responderInfo; if(alice) {
if(initiator) { aliceInfo = ourHash;
initiatorInfo = ourHash; bobInfo = theirHash;
responderInfo = theirHash;
} else { } else {
initiatorInfo = theirHash; aliceInfo = theirHash;
responderInfo = ourHash; bobInfo = ourHash;
} }
// The public info is the invitation code as a uint32
byte[] publicInfo = new byte[4];
ByteUtils.writeUint32(invitationCode, publicInfo, 0);
// The raw secret comes from the key agreement algorithm // The raw secret comes from the key agreement algorithm
KeyAgreement keyAgreement = KeyAgreement.getInstance(AGREEMENT_ALGO, KeyAgreement keyAgreement = KeyAgreement.getInstance(AGREEMENT_ALGO,
PROVIDER); PROVIDER);
@@ -160,17 +164,12 @@ class CryptoComponentImpl implements CryptoComponent {
byte[] rawSecret = keyAgreement.generateSecret(); byte[] rawSecret = keyAgreement.generateSecret();
// Derive the cooked secret from the raw secret using the // Derive the cooked secret from the raw secret using the
// concatenation KDF // concatenation KDF
byte[] cookedSecret = concatenationKdf(rawSecret, FIRST, byte[] cookedSecret = concatenationKdf(rawSecret, FIRST, aliceInfo,
initiatorInfo, responderInfo, publicInfo); bobInfo);
ByteUtils.erase(rawSecret); ByteUtils.erase(rawSecret);
// Derive the incoming and outgoing secrets from the cooked secret return cookedSecret;
// using the CTR mode KDF
byte[][] secrets = new byte[2][];
secrets[0] = counterModeKdf(cookedSecret, FIRST, INITIATOR);
secrets[1] = counterModeKdf(cookedSecret, FIRST, RESPONDER);
ByteUtils.erase(cookedSecret);
return secrets;
} catch(GeneralSecurityException e) { } catch(GeneralSecurityException e) {
// FIXME: Throw instead of returning null?
return null; return null;
} }
} }
@@ -178,7 +177,7 @@ class CryptoComponentImpl implements CryptoComponent {
// Key derivation function based on a hash function - see NIST SP 800-56A, // Key derivation function based on a hash function - see NIST SP 800-56A,
// section 5.8 // section 5.8
private byte[] concatenationKdf(byte[] rawSecret, byte[] label, private byte[] concatenationKdf(byte[] rawSecret, byte[] label,
byte[] initiatorInfo, byte[] responderInfo, byte[] publicInfo) { byte[] initiatorInfo, byte[] responderInfo) {
// The output of the hash function must be long enough to use as a key // The output of the hash function must be long enough to use as a key
MessageDigest messageDigest = getMessageDigest(); MessageDigest messageDigest = getMessageDigest();
if(messageDigest.getDigestLength() < SECRET_KEY_BYTES) if(messageDigest.getDigestLength() < SECRET_KEY_BYTES)
@@ -197,9 +196,6 @@ class CryptoComponentImpl implements CryptoComponent {
ByteUtils.writeUint8(responderInfo.length, length, 0); ByteUtils.writeUint8(responderInfo.length, length, 0);
messageDigest.update(length); messageDigest.update(length);
messageDigest.update(responderInfo); messageDigest.update(responderInfo);
ByteUtils.writeUint8(publicInfo.length, length, 0);
messageDigest.update(length);
messageDigest.update(publicInfo);
byte[] hash = messageDigest.digest(); byte[] hash = messageDigest.digest();
// The secret is the first SECRET_KEY_BYTES bytes of the hash // The secret is the first SECRET_KEY_BYTES bytes of the hash
byte[] output = new byte[SECRET_KEY_BYTES]; byte[] output = new byte[SECRET_KEY_BYTES];
@@ -208,22 +204,28 @@ class CryptoComponentImpl implements CryptoComponent {
return output; return output;
} }
public byte[] deriveNextSecret(byte[] secret, int index, long connection) { public byte[] deriveNextSecret(byte[] secret, long period) {
if(index < 0 || index > ByteUtils.MAX_16_BIT_UNSIGNED) if(period < 0 || period > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(connection < 0 || connection > ByteUtils.MAX_32_BIT_UNSIGNED) return counterModeKdf(secret, ROTATE, period);
throw new IllegalArgumentException();
byte[] context = new byte[6];
ByteUtils.writeUint16(index, context, 0);
ByteUtils.writeUint32(connection, context, 2);
return counterModeKdf(secret, NEXT, context);
} }
public int deriveConfirmationCode(byte[] secret) { public int generateInvitationCode() {
byte[] output = counterModeKdf(secret, CODE, CODE); int codeBytes = (int) Math.ceil(CODE_BITS / 8.0);
int code = ByteUtils.readUint(output, CODE_BITS); byte[] random = new byte[codeBytes];
ByteUtils.erase(output); secureRandom.nextBytes(random);
return code; return ByteUtils.readUint(random, CODE_BITS);
}
public int[] deriveConfirmationCodes(byte[] secret) {
byte[] alice = counterModeKdf(secret, CODE, 0);
byte[] bob = counterModeKdf(secret, CODE, 1);
int[] codes = new int[2];
codes[0] = ByteUtils.readUint(alice, CODE_BITS);
codes[1] = ByteUtils.readUint(bob, CODE_BITS);
ByteUtils.erase(alice);
ByteUtils.erase(bob);
return codes;
} }
public KeyPair generateAgreementKeyPair() { public KeyPair generateAgreementKeyPair() {

View File

@@ -1,4 +1,4 @@
package net.sf.briar.db; package net.sf.briar.crypto;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;

View File

@@ -1,4 +1,4 @@
package net.sf.briar.db; package net.sf.briar.crypto;
import java.util.Timer; import java.util.Timer;
import java.util.TimerTask; import java.util.TimerTask;

View File

@@ -5,7 +5,9 @@ import java.util.Collection;
import java.util.Map; import java.util.Map;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.ContactTransport;
import net.sf.briar.api.Rating; import net.sf.briar.api.Rating;
import net.sf.briar.api.TemporarySecret;
import net.sf.briar.api.TransportConfig; import net.sf.briar.api.TransportConfig;
import net.sf.briar.api.TransportProperties; import net.sf.briar.api.TransportProperties;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
@@ -19,9 +21,6 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWindow;
/** /**
* A low-level interface to the database (DatabaseComponent provides a * A low-level interface to the database (DatabaseComponent provides a
@@ -81,18 +80,19 @@ interface Database<T> {
void addBatchToAck(T txn, ContactId c, BatchId b) throws DbException; void addBatchToAck(T txn, ContactId c, BatchId b) throws DbException;
/** /**
* Adds a new contact to the database with the given secrets and returns an * Adds a new contact to the database and returns an ID for the contact.
* ID for the contact.
* <p> * <p>
* Any secrets generated by the method are stored in the given collection * Locking: contact write, subscription write, transport write.
* and should be erased by the caller once the transaction has been
* committed or aborted.
* <p>
* Locking: contact write, subscription write, transport write,
* window write.
*/ */
ContactId addContact(T txn, byte[] inSecret, byte[] outSecret, ContactId addContact(T txn) throws DbException;
Collection<byte[]> erase) throws DbException;
/**
* Adds a contact transport to the database.
* <p>
* Locking: contact read, window write.
*/
void addContactTransport(T txn, ContactTransport ct)
throws DbException;
/** /**
* Returns false if the given message is already in the database. Otherwise * Returns false if the given message is already in the database. Otherwise
@@ -118,6 +118,15 @@ interface Database<T> {
*/ */
boolean addPrivateMessage(T txn, Message m, ContactId c) throws DbException; boolean addPrivateMessage(T txn, Message m, ContactId c) throws DbException;
/**
* Stores the given temporary secrets and deletes any secrets that have
* been made obsolete.
* <p>
* Locking: contact read, window write.
*/
void addSecrets(T txn, Collection<TemporarySecret> secrets)
throws DbException;
/** /**
* Subscribes to the given group. * Subscribes to the given group.
* <p> * <p>
@@ -132,15 +141,7 @@ interface Database<T> {
* Locking: contact read, subscription write. * Locking: contact read, subscription write.
*/ */
void addSubscription(T txn, ContactId c, Group g, long start) void addSubscription(T txn, ContactId c, Group g, long start)
throws DbException; throws DbException;
/**
* Allocates and returns a local index for the given transport. Returns
* null if all indices have been allocated.
* <p>
* Locking: transport write.
*/
TransportIndex addTransport(T txn, TransportId t) throws DbException;
/** /**
* Makes the given group visible to the given contact. * Makes the given group visible to the given contact.
@@ -156,6 +157,14 @@ interface Database<T> {
*/ */
boolean containsContact(T txn, ContactId c) throws DbException; boolean containsContact(T txn, ContactId c) throws DbException;
/**
* Returns true if the database contains the given contact transport.
* <p>
* Locking: contact read, window read.
*/
boolean containsContactTransport(T txn, ContactId c, TransportId t)
throws DbException;
/** /**
* Returns true if the database contains the given message. * Returns true if the database contains the given message.
* <p> * <p>
@@ -177,7 +186,7 @@ interface Database<T> {
* Locking: subscription read. * Locking: subscription read.
*/ */
boolean containsSubscription(T txn, GroupId g, long time) boolean containsSubscription(T txn, GroupId g, long time)
throws DbException; throws DbException;
/** /**
* Returns true if the user is subscribed to the given group, the group is * Returns true if the user is subscribed to the given group, the group is
@@ -196,7 +205,7 @@ interface Database<T> {
* Locking: contact read, messageStatus read. * Locking: contact read, messageStatus read.
*/ */
Collection<BatchId> getBatchesToAck(T txn, ContactId c, int maxBatches) Collection<BatchId> getBatchesToAck(T txn, ContactId c, int maxBatches)
throws DbException; throws DbException;
/** /**
* Returns the configuration for the given transport. * Returns the configuration for the given transport.
@@ -205,28 +214,6 @@ interface Database<T> {
*/ */
TransportConfig getConfig(T txn, TransportId t) throws DbException; TransportConfig getConfig(T txn, TransportId t) throws DbException;
/**
* Returns an outgoing connection context for the given contact and
* transport.
* <p>
* Any secrets generated by the method are stored in the given collection
* and should be erased by the caller once the transaction has been
* committed or aborted.
* <p>
* Locking: contact read, window write.
*/
ConnectionContext getConnectionContext(T txn, ContactId c, TransportIndex i,
Collection<byte[]> erase) throws DbException;
/**
* Returns the connection reordering window for the given contact and
* transport.
* <p>
* Locking: contact read, window read.
*/
ConnectionWindow getConnectionWindow(T txn, ContactId c, TransportIndex i)
throws DbException;
/** /**
* Returns the IDs of all contacts. * Returns the IDs of all contacts.
* <p> * <p>
@@ -234,6 +221,13 @@ interface Database<T> {
*/ */
Collection<ContactId> getContacts(T txn) throws DbException; Collection<ContactId> getContacts(T txn) throws DbException;
/**
* Returns all contact transports.
* <p>
* Locking: contact read, window read.
*/
Collection<ContactTransport> getContactTransports(T txn) throws DbException;
/** /**
* Returns the approximate expiry time of the database. * Returns the approximate expiry time of the database.
* <p> * <p>
@@ -259,21 +253,13 @@ interface Database<T> {
*/ */
MessageId getGroupMessageParent(T txn, MessageId m) throws DbException; MessageId getGroupMessageParent(T txn, MessageId m) throws DbException;
/**
* Returns the local index for the given transport, or null if no index
* has been allocated.
* <p>
* Locking: transport read.
*/
TransportIndex getLocalIndex(T txn, TransportId t) throws DbException;
/** /**
* Returns the local transport properties for the given transport. * Returns the local transport properties for the given transport.
* <p> * <p>
* Locking: transport read. * Locking: transport read.
*/ */
TransportProperties getLocalProperties(T txn, TransportId t) TransportProperties getLocalProperties(T txn, TransportId t)
throws DbException; throws DbException;
/** /**
* Returns all local transports. * Returns all local transports.
@@ -310,7 +296,7 @@ interface Database<T> {
* Locking: message read, messageFlag read. * Locking: message read, messageFlag read.
*/ */
Collection<MessageHeader> getMessageHeaders(T txn, GroupId g) Collection<MessageHeader> getMessageHeaders(T txn, GroupId g)
throws DbException; throws DbException;
/** /**
* Returns the message identified by the given ID, in raw format, or null * Returns the message identified by the given ID, in raw format, or null
@@ -321,7 +307,7 @@ interface Database<T> {
* subscription read. * subscription read.
*/ */
byte[] getMessageIfSendable(T txn, ContactId c, MessageId m) byte[] getMessageIfSendable(T txn, ContactId c, MessageId m)
throws DbException; throws DbException;
/** /**
* Returns the IDs of all messages signed by the given author. * Returns the IDs of all messages signed by the given author.
@@ -329,7 +315,7 @@ interface Database<T> {
* Locking: message read. * Locking: message read.
*/ */
Collection<MessageId> getMessagesByAuthor(T txn, AuthorId a) Collection<MessageId> getMessagesByAuthor(T txn, AuthorId a)
throws DbException; throws DbException;
/** /**
* Returns the number of children of the message identified by the given * Returns the number of children of the message identified by the given
@@ -372,15 +358,6 @@ interface Database<T> {
*/ */
boolean getRead(T txn, MessageId m) throws DbException; boolean getRead(T txn, MessageId m) throws DbException;
/**
* Returns the given contact's index for the given transport, or null if
* the contact does not support the transport.
* <p>
* Locking: contact read, window read.
*/
TransportIndex getRemoteIndex(T txn, ContactId c, TransportId t)
throws DbException;
/** /**
* Returns all remote properties for the given transport. * Returns all remote properties for the given transport.
* <p> * <p>
@@ -404,7 +381,7 @@ interface Database<T> {
* subscription read. * subscription read.
*/ */
Collection<MessageId> getSendableMessages(T txn, ContactId c, int capacity) Collection<MessageId> getSendableMessages(T txn, ContactId c, int capacity)
throws DbException; throws DbException;
/** /**
* Returns true if the given message has been starred. * Returns true if the given message has been starred.
@@ -464,7 +441,7 @@ interface Database<T> {
* Locking: contact read, subscription read. * Locking: contact read, subscription read.
*/ */
Map<GroupId, GroupId> getVisibleHoles(T txn, ContactId c, long timestamp) Map<GroupId, GroupId> getVisibleHoles(T txn, ContactId c, long timestamp)
throws DbException; throws DbException;
/** /**
* Returns any subscriptions that are visible to the given contact, * Returns any subscriptions that are visible to the given contact,
@@ -474,7 +451,7 @@ interface Database<T> {
* Locking: contact read, subscription read. * Locking: contact read, subscription read.
*/ */
Map<Group, Long> getVisibleSubscriptions(T txn, ContactId c, long timestamp) Map<Group, Long> getVisibleSubscriptions(T txn, ContactId c, long timestamp)
throws DbException; throws DbException;
/** /**
* Returns true if any messages are sendable to the given contact. * Returns true if any messages are sendable to the given contact.
@@ -483,6 +460,15 @@ interface Database<T> {
*/ */
boolean hasSendableMessages(T txn, ContactId c) throws DbException; boolean hasSendableMessages(T txn, ContactId c) throws DbException;
/**
* Increments the outgoing connection counter for the given contact
* transport in the given rotation period.
* <p>
* Locking: contact read, window write.
*/
void incrementConnectionCounter(T txn, ContactId c, TransportId t,
long period) throws DbException;
/** /**
* Removes an outstanding batch that has been acknowledged. Any messages in * Removes an outstanding batch that has been acknowledged. Any messages in
* the batch that are still considered outstanding (Status.SENT) with * the batch that are still considered outstanding (Status.SENT) with
@@ -499,7 +485,7 @@ interface Database<T> {
* Locking: contact read, messageStatus write. * Locking: contact read, messageStatus write.
*/ */
void removeBatchesToAck(T txn, ContactId c, Collection<BatchId> sent) void removeBatchesToAck(T txn, ContactId c, Collection<BatchId> sent)
throws DbException; throws DbException;
/** /**
* Removes a contact (and all associated state) from the database. * Removes a contact (and all associated state) from the database.
@@ -543,7 +529,7 @@ interface Database<T> {
* with IDs greater than the first are removed. * with IDs greater than the first are removed.
*/ */
void removeSubscriptions(T txn, ContactId c, GroupId start, GroupId end) void removeSubscriptions(T txn, ContactId c, GroupId start, GroupId end)
throws DbException; throws DbException;
/** /**
* Makes the given group invisible to the given contact. * Makes the given group invisible to the given contact.
@@ -559,16 +545,16 @@ interface Database<T> {
* Locking: transport write. * Locking: transport write.
*/ */
void setConfig(T txn, TransportId t, TransportConfig config) void setConfig(T txn, TransportId t, TransportConfig config)
throws DbException; throws DbException;
/** /**
* Sets the connection reordering window for the given contact and * Sets the connection reordering window for the given contact transport in
* transport. * the given rotation period.
* <p> * <p>
* Locking: contact read, window write. * Locking: contact read, window write.
*/ */
void setConnectionWindow(T txn, ContactId c, TransportIndex i, void setConnectionWindow(T txn, ContactId c, TransportId t, long period,
ConnectionWindow w) throws DbException; long centre, byte[] bitmap) throws DbException;
/** /**
* Sets the given contact's database expiry time. * Sets the given contact's database expiry time.
@@ -584,7 +570,7 @@ interface Database<T> {
* Locking: transport write. * Locking: transport write.
*/ */
void setLocalProperties(T txn, TransportId t, TransportProperties p) void setLocalProperties(T txn, TransportId t, TransportProperties p)
throws DbException; throws DbException;
/** /**
* Sets the user's rating for the given author. * Sets the user's rating for the given author.
@@ -622,7 +608,7 @@ interface Database<T> {
* Locking: contact read, message read, messageStatus write. * Locking: contact read, message read, messageStatus write.
*/ */
void setStatus(T txn, ContactId c, MessageId m, Status s) void setStatus(T txn, ContactId c, MessageId m, Status s)
throws DbException; throws DbException;
/** /**
* If the database contains the given message and it belongs to a group * If the database contains the given message and it belongs to a group
@@ -634,7 +620,7 @@ interface Database<T> {
* subscription read. * subscription read.
*/ */
boolean setStatusSeenIfVisible(T txn, ContactId c, MessageId m) boolean setStatusSeenIfVisible(T txn, ContactId c, MessageId m)
throws DbException; throws DbException;
/** /**
* Records the time of the latest subscription update acknowledged by the * Records the time of the latest subscription update acknowledged by the
@@ -643,7 +629,7 @@ interface Database<T> {
* Locking: contact read, subscription write. * Locking: contact read, subscription write.
*/ */
void setSubscriptionsAcked(T txn, ContactId c, long timestamp) void setSubscriptionsAcked(T txn, ContactId c, long timestamp)
throws DbException; throws DbException;
/** /**
* Records the time of the latest subscription update received from the * Records the time of the latest subscription update received from the
@@ -652,7 +638,7 @@ interface Database<T> {
* Locking: contact read, subscription write. * Locking: contact read, subscription write.
*/ */
void setSubscriptionsReceived(T txn, ContactId c, long timestamp) void setSubscriptionsReceived(T txn, ContactId c, long timestamp)
throws DbException; throws DbException;
/** /**
* Sets the transports for the given contact, replacing any existing * Sets the transports for the given contact, replacing any existing
@@ -677,5 +663,5 @@ interface Database<T> {
* Locking: contact read, transport write. * Locking: contact read, transport write.
*/ */
void setTransportsSent(T txn, ContactId c, long timestamp) void setTransportsSent(T txn, ContactId c, long timestamp)
throws DbException; throws DbException;
} }

View File

@@ -22,7 +22,9 @@ import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.ContactTransport;
import net.sf.briar.api.Rating; import net.sf.briar.api.Rating;
import net.sf.briar.api.TemporarySecret;
import net.sf.briar.api.TransportConfig; import net.sf.briar.api.TransportConfig;
import net.sf.briar.api.TransportProperties; import net.sf.briar.api.TransportProperties;
import net.sf.briar.api.clock.Clock; import net.sf.briar.api.clock.Clock;
@@ -30,6 +32,7 @@ import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.MessageHeader; import net.sf.briar.api.db.MessageHeader;
import net.sf.briar.api.db.NoSuchContactException; import net.sf.briar.api.db.NoSuchContactException;
import net.sf.briar.api.db.NoSuchContactTransportException;
import net.sf.briar.api.db.Status; import net.sf.briar.api.db.Status;
import net.sf.briar.api.db.event.BatchReceivedEvent; import net.sf.briar.api.db.event.BatchReceivedEvent;
import net.sf.briar.api.db.event.ContactAddedEvent; import net.sf.briar.api.db.event.ContactAddedEvent;
@@ -41,7 +44,6 @@ import net.sf.briar.api.db.event.MessagesAddedEvent;
import net.sf.briar.api.db.event.RatingChangedEvent; import net.sf.briar.api.db.event.RatingChangedEvent;
import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent; import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent;
import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent; import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent;
import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.lifecycle.ShutdownManager; import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.AuthorId;
@@ -58,10 +60,7 @@ import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -76,7 +75,7 @@ class DatabaseComponentImpl<T> implements DatabaseComponent,
DatabaseCleaner.Callback { DatabaseCleaner.Callback {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(DatabaseComponentImpl.class.getName()); Logger.getLogger(DatabaseComponentImpl.class.getName());
/* /*
* Locks must always be acquired in alphabetical order. See the Database * Locks must always be acquired in alphabetical order. See the Database
@@ -84,21 +83,21 @@ DatabaseCleaner.Callback {
*/ */
private final ReentrantReadWriteLock contactLock = private final ReentrantReadWriteLock contactLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final ReentrantReadWriteLock messageLock = private final ReentrantReadWriteLock messageLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final ReentrantReadWriteLock messageFlagLock = private final ReentrantReadWriteLock messageFlagLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final ReentrantReadWriteLock messageStatusLock = private final ReentrantReadWriteLock messageStatusLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final ReentrantReadWriteLock ratingLock = private final ReentrantReadWriteLock ratingLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final ReentrantReadWriteLock subscriptionLock = private final ReentrantReadWriteLock subscriptionLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final ReentrantReadWriteLock transportLock = private final ReentrantReadWriteLock transportLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final ReentrantReadWriteLock windowLock = private final ReentrantReadWriteLock windowLock =
new ReentrantReadWriteLock(true); new ReentrantReadWriteLock(true);
private final Database<T> db; private final Database<T> db;
private final DatabaseCleaner cleaner; private final DatabaseCleaner cleaner;
@@ -107,7 +106,7 @@ DatabaseCleaner.Callback {
private final Clock clock; private final Clock clock;
private final Collection<DatabaseListener> listeners = private final Collection<DatabaseListener> listeners =
new CopyOnWriteArrayList<DatabaseListener>(); new CopyOnWriteArrayList<DatabaseListener>();
private final Object spaceLock = new Object(); private final Object spaceLock = new Object();
private long bytesStoredSinceLastCheck = 0L; // Locking: spaceLock private long bytesStoredSinceLastCheck = 0L; // Locking: spaceLock
@@ -172,8 +171,7 @@ DatabaseCleaner.Callback {
listeners.remove(d); listeners.remove(d);
} }
public ContactId addContact(byte[] inSecret, byte[] outSecret) public ContactId addContact() throws DbException {
throws DbException {
ContactId c; ContactId c;
Collection<byte[]> erase = new ArrayList<byte[]>(); Collection<byte[]> erase = new ArrayList<byte[]>();
contactLock.writeLock().lock(); contactLock.writeLock().lock();
@@ -186,7 +184,7 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
c = db.addContact(txn, inSecret, outSecret, erase); c = db.addContact(txn);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
@@ -266,7 +264,7 @@ DatabaseCleaner.Callback {
* @param sender may be null for a locally generated message. * @param sender may be null for a locally generated message.
*/ */
private boolean storeGroupMessage(T txn, Message m, ContactId sender) private boolean storeGroupMessage(T txn, Message m, ContactId sender)
throws DbException { throws DbException {
if(m.getGroup() == null) throw new IllegalArgumentException(); if(m.getGroup() == null) throw new IllegalArgumentException();
boolean stored = db.addGroupMessage(txn, m); boolean stored = db.addGroupMessage(txn, m);
// Mark the message as seen by the sender // Mark the message as seen by the sender
@@ -315,7 +313,7 @@ DatabaseCleaner.Callback {
* greater than 0, or false if it has changed from greater than 0 to 0. * greater than 0, or false if it has changed from greater than 0 to 0.
*/ */
private int updateAncestorSendability(T txn, MessageId m, boolean increment) private int updateAncestorSendability(T txn, MessageId m, boolean increment)
throws DbException { throws DbException {
int affected = 0; int affected = 0;
boolean changed = true; boolean changed = true;
while(changed) { while(changed) {
@@ -343,17 +341,18 @@ DatabaseCleaner.Callback {
} }
public void addLocalPrivateMessage(Message m, ContactId c) public void addLocalPrivateMessage(Message m, ContactId c)
throws DbException { throws DbException {
boolean added = false; boolean added = false;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.writeLock().lock(); messageLock.writeLock().lock();
try { try {
messageStatusLock.writeLock().lock(); messageStatusLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
added = storePrivateMessage(txn, m, c, false); added = storePrivateMessage(txn, m, c, false);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
@@ -373,6 +372,36 @@ DatabaseCleaner.Callback {
if(added) callListeners(new MessagesAddedEvent()); if(added) callListeners(new MessagesAddedEvent());
} }
public void addSecrets(Collection<TemporarySecret> secrets)
throws DbException {
contactLock.readLock().lock();
try {
windowLock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
Collection<TemporarySecret> relevant =
new ArrayList<TemporarySecret>();
for(TemporarySecret s : secrets) {
ContactId c = s.getContactId();
TransportId t = s.getTransportId();
if(db.containsContactTransport(txn, c, t))
relevant.add(s);
}
if(!secrets.isEmpty()) db.addSecrets(txn, relevant);
db.commitTransaction(txn);
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
windowLock.writeLock().unlock();
}
} finally {
contactLock.writeLock().unlock();
}
}
/** /**
* If the given message is already in the database, returns false. * If the given message is already in the database, returns false.
* Otherwise stores the message and marks it as new or seen with respect to * Otherwise stores the message and marks it as new or seen with respect to
@@ -396,52 +425,16 @@ DatabaseCleaner.Callback {
return true; return true;
} }
/**
* Returns true if the database contains the given contact.
* <p>
* Locking: contact read.
*/
private boolean containsContact(ContactId c) throws DbException {
T txn = db.startTransaction();
try {
boolean contains = db.containsContact(txn, c);
db.commitTransaction(txn);
return contains;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
}
public TransportIndex addTransport(TransportId t) throws DbException {
TransportIndex i;
transportLock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
i = db.addTransport(txn, t);
db.commitTransaction(txn);
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
transportLock.writeLock().unlock();
}
// Call the listeners outside the lock
if(i != null) callListeners(new TransportAddedEvent(t));
return i;
}
public Ack generateAck(ContactId c, int maxBatches) throws DbException { public Ack generateAck(ContactId c, int maxBatches) throws DbException {
Collection<BatchId> acked; Collection<BatchId> acked;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageStatusLock.readLock().lock(); messageStatusLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
acked = db.getBatchesToAck(txn, c, maxBatches); acked = db.getBatchesToAck(txn, c, maxBatches);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
@@ -473,14 +466,13 @@ DatabaseCleaner.Callback {
} }
public RawBatch generateBatch(ContactId c, int capacity) public RawBatch generateBatch(ContactId c, int capacity)
throws DbException { throws DbException {
Collection<MessageId> ids; Collection<MessageId> ids;
List<byte[]> messages = new ArrayList<byte[]>(); List<byte[]> messages = new ArrayList<byte[]>();
RawBatch b; RawBatch b;
// Get some sendable messages from the database // Get some sendable messages from the database
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageStatusLock.readLock().lock(); messageStatusLock.readLock().lock();
@@ -489,6 +481,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
ids = db.getSendableMessages(txn, c, capacity); ids = db.getSendableMessages(txn, c, capacity);
for(MessageId m : ids) { for(MessageId m : ids) {
messages.add(db.getMessage(txn, m)); messages.add(db.getMessage(txn, m));
@@ -537,7 +531,6 @@ DatabaseCleaner.Callback {
// Get some sendable messages from the database // Get some sendable messages from the database
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageStatusLock.readLock().lock(); messageStatusLock.readLock().lock();
@@ -546,6 +539,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
Iterator<MessageId> it = requested.iterator(); Iterator<MessageId> it = requested.iterator();
while(it.hasNext()) { while(it.hasNext()) {
MessageId m = it.next(); MessageId m = it.next();
@@ -595,17 +590,18 @@ DatabaseCleaner.Callback {
} }
public Offer generateOffer(ContactId c, int maxMessages) public Offer generateOffer(ContactId c, int maxMessages)
throws DbException { throws DbException {
Collection<MessageId> offered; Collection<MessageId> offered;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageStatusLock.readLock().lock(); messageStatusLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
offered = db.getOfferableMessages(txn, c, maxMessages); offered = db.getOfferableMessages(txn, c, maxMessages);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
@@ -625,17 +621,18 @@ DatabaseCleaner.Callback {
} }
public SubscriptionUpdate generateSubscriptionUpdate(ContactId c) public SubscriptionUpdate generateSubscriptionUpdate(ContactId c)
throws DbException { throws DbException {
Map<GroupId, GroupId> holes; Map<GroupId, GroupId> holes;
Map<Group, Long> subs; Map<Group, Long> subs;
long expiry, timestamp; long expiry, timestamp;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
subscriptionLock.readLock().lock(); subscriptionLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
timestamp = clock.currentTimeMillis() - 1; timestamp = clock.currentTimeMillis() - 1;
holes = db.getVisibleHoles(txn, c, timestamp); holes = db.getVisibleHoles(txn, c, timestamp);
subs = db.getVisibleSubscriptions(txn, c, timestamp); subs = db.getVisibleSubscriptions(txn, c, timestamp);
@@ -661,17 +658,18 @@ DatabaseCleaner.Callback {
} }
public TransportUpdate generateTransportUpdate(ContactId c) public TransportUpdate generateTransportUpdate(ContactId c)
throws DbException { throws DbException {
boolean due; boolean due;
Collection<Transport> transports; Collection<Transport> transports;
long timestamp; long timestamp;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
transportLock.readLock().lock(); transportLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
// Work out whether an update is due // Work out whether an update is due
long modified = db.getTransportsModified(txn); long modified = db.getTransportsModified(txn);
long sent = db.getTransportsSent(txn, c); long sent = db.getTransportsSent(txn, c);
@@ -723,58 +721,6 @@ DatabaseCleaner.Callback {
} }
} }
public ConnectionContext getConnectionContext(ContactId c, TransportIndex i)
throws DbException {
Collection<byte[]> erase = new ArrayList<byte[]>();
contactLock.readLock().lock();
try {
if(!containsContact(c)) throw new NoSuchContactException();
windowLock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
ConnectionContext ctx =
db.getConnectionContext(txn, c, i, erase);
db.commitTransaction(txn);
return ctx;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
windowLock.writeLock().unlock();
}
} finally {
contactLock.readLock().unlock();
// Erase the secrets after committing or aborting the transaction
for(byte[] b : erase) ByteUtils.erase(b);
}
}
public ConnectionWindow getConnectionWindow(ContactId c, TransportIndex i)
throws DbException {
contactLock.readLock().lock();
try {
if(!containsContact(c)) throw new NoSuchContactException();
windowLock.readLock().lock();
try {
T txn = db.startTransaction();
try {
ConnectionWindow w = db.getConnectionWindow(txn, c, i);
db.commitTransaction(txn);
return w;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
windowLock.readLock().unlock();
}
} finally {
contactLock.readLock().unlock();
}
}
public Collection<ContactId> getContacts() throws DbException { public Collection<ContactId> getContacts() throws DbException {
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
@@ -792,32 +738,39 @@ DatabaseCleaner.Callback {
} }
} }
public TransportIndex getLocalIndex(TransportId t) throws DbException { public Collection<ContactTransport> getContactTransports()
transportLock.readLock().lock(); throws DbException {
contactLock.readLock().lock();
try { try {
T txn = db.startTransaction(); windowLock.readLock().lock();
try { try {
TransportIndex i = db.getLocalIndex(txn, t); T txn = db.startTransaction();
db.commitTransaction(txn); try {
return i; Collection<ContactTransport> contactTransports =
} catch(DbException e) { db.getContactTransports(txn);
db.abortTransaction(txn); db.commitTransaction(txn);
throw e; return contactTransports;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
windowLock.readLock().unlock();
} }
} finally { } finally {
transportLock.readLock().unlock(); contactLock.readLock().unlock();
} }
} }
public TransportProperties getLocalProperties(TransportId t) public TransportProperties getLocalProperties(TransportId t)
throws DbException { throws DbException {
transportLock.readLock().lock(); transportLock.readLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
TransportProperties p = db.getLocalProperties(txn, t); TransportProperties properties = db.getLocalProperties(txn, t);
db.commitTransaction(txn); db.commitTransaction(txn);
return p; return properties;
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
throw e; throw e;
@@ -845,7 +798,7 @@ DatabaseCleaner.Callback {
} }
public Collection<MessageHeader> getMessageHeaders(GroupId g) public Collection<MessageHeader> getMessageHeaders(GroupId g)
throws DbException { throws DbException {
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageFlagLock.readLock().lock(); messageFlagLock.readLock().lock();
@@ -853,7 +806,7 @@ DatabaseCleaner.Callback {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
Collection<MessageHeader> headers = Collection<MessageHeader> headers =
db.getMessageHeaders(txn, g); db.getMessageHeaders(txn, g);
db.commitTransaction(txn); db.commitTransaction(txn);
return headers; return headers;
} catch(DbException e) { } catch(DbException e) {
@@ -885,30 +838,6 @@ DatabaseCleaner.Callback {
} }
} }
public TransportIndex getRemoteIndex(ContactId c, TransportId t)
throws DbException {
contactLock.readLock().lock();
try {
if(!containsContact(c)) throw new NoSuchContactException();
transportLock.readLock().lock();
try {
T txn = db.startTransaction();
try {
TransportIndex i = db.getRemoteIndex(txn, c, t);
db.commitTransaction(txn);
return i;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
transportLock.readLock().unlock();
}
} finally {
contactLock.readLock().unlock();
}
}
public Map<ContactId, TransportProperties> getRemoteProperties( public Map<ContactId, TransportProperties> getRemoteProperties(
TransportId t) throws DbException { TransportId t) throws DbException {
contactLock.readLock().lock(); contactLock.readLock().lock();
@@ -918,7 +847,7 @@ DatabaseCleaner.Callback {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
Map<ContactId, TransportProperties> properties = Map<ContactId, TransportProperties> properties =
db.getRemoteProperties(txn, t); db.getRemoteProperties(txn, t);
db.commitTransaction(txn); db.commitTransaction(txn);
return properties; return properties;
} catch(DbException e) { } catch(DbException e) {
@@ -960,7 +889,7 @@ DatabaseCleaner.Callback {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
Map<GroupId, Integer> counts = Map<GroupId, Integer> counts =
db.getUnreadMessageCounts(txn); db.getUnreadMessageCounts(txn);
db.commitTransaction(txn); db.commitTransaction(txn);
return counts; return counts;
} catch(DbException e) { } catch(DbException e) {
@@ -1003,7 +932,6 @@ DatabaseCleaner.Callback {
public boolean hasSendableMessages(ContactId c) throws DbException { public boolean hasSendableMessages(ContactId c) throws DbException {
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageStatusLock.readLock().lock(); messageStatusLock.readLock().lock();
@@ -1012,6 +940,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
boolean has = db.hasSendableMessages(txn, c); boolean has = db.hasSendableMessages(txn, c);
db.commitTransaction(txn); db.commitTransaction(txn);
return has; return has;
@@ -1033,17 +963,41 @@ DatabaseCleaner.Callback {
} }
} }
public void incrementConnectionCounter(ContactId c, TransportId t,
long period) throws DbException {
contactLock.readLock().lock();
try {
windowLock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
if(!db.containsContactTransport(txn, c, t))
throw new NoSuchContactTransportException();
db.incrementConnectionCounter(txn, c, t, period);
db.commitTransaction(txn);
} catch(DbException e) {
db.abortTransaction(txn);
}
} finally {
windowLock.writeLock().unlock();
}
} finally {
contactLock.readLock().unlock();
}
}
public void receiveAck(ContactId c, Ack a) throws DbException { public void receiveAck(ContactId c, Ack a) throws DbException {
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageStatusLock.writeLock().lock(); messageStatusLock.writeLock().lock();
try { try {
Collection<BatchId> acks = a.getBatchIds();
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
Collection<BatchId> acks = a.getBatchIds();
// Mark all messages in acked batches as seen // Mark all messages in acked batches as seen
for(BatchId b : acks) db.removeAckedBatch(txn, c, b); for(BatchId b : acks) db.removeAckedBatch(txn, c, b);
// Find any lost batches that need to be retransmitted // Find any lost batches that need to be retransmitted
@@ -1069,7 +1023,6 @@ DatabaseCleaner.Callback {
boolean anyAdded = false; boolean anyAdded = false;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.writeLock().lock(); messageLock.writeLock().lock();
try { try {
messageStatusLock.writeLock().lock(); messageStatusLock.writeLock().lock();
@@ -1078,6 +1031,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
anyAdded = storeMessages(txn, c, b.getMessages()); anyAdded = storeMessages(txn, c, b.getMessages());
db.addBatchToAck(txn, c, b.getId()); db.addBatchToAck(txn, c, b.getId());
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -1131,7 +1086,6 @@ DatabaseCleaner.Callback {
BitSet request; BitSet request;
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageStatusLock.writeLock().lock(); messageStatusLock.writeLock().lock();
@@ -1140,6 +1094,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
offered = o.getMessageIds(); offered = o.getMessageIds();
request = new BitSet(offered.size()); request = new BitSet(offered.size());
Iterator<MessageId> it = offered.iterator(); Iterator<MessageId> it = offered.iterator();
@@ -1171,15 +1127,16 @@ DatabaseCleaner.Callback {
} }
public void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s) public void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s)
throws DbException { throws DbException {
// Update the contact's subscriptions // Update the contact's subscriptions
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
subscriptionLock.writeLock().lock(); subscriptionLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
Map<GroupId, GroupId> holes = s.getHoles(); Map<GroupId, GroupId> holes = s.getHoles();
for(Entry<GroupId, GroupId> e : holes.entrySet()) { for(Entry<GroupId, GroupId> e : holes.entrySet()) {
GroupId start = e.getKey(), end = e.getValue(); GroupId start = e.getKey(), end = e.getValue();
@@ -1208,16 +1165,17 @@ DatabaseCleaner.Callback {
} }
public void receiveTransportUpdate(ContactId c, TransportUpdate t) public void receiveTransportUpdate(ContactId c, TransportUpdate t)
throws DbException { throws DbException {
Collection<Transport> transports; Collection<Transport> transports;
// Update the contact's transport properties // Update the contact's transport properties
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
transportLock.writeLock().lock(); transportLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
transports = t.getTransports(); transports = t.getTransports();
db.setTransports(txn, c, transports, t.getTimestamp()); db.setTransports(txn, c, transports, t.getTimestamp());
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -1238,7 +1196,6 @@ DatabaseCleaner.Callback {
public void removeContact(ContactId c) throws DbException { public void removeContact(ContactId c) throws DbException {
contactLock.writeLock().lock(); contactLock.writeLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.writeLock().lock(); messageLock.writeLock().lock();
try { try {
messageFlagLock.writeLock().lock(); messageFlagLock.writeLock().lock();
@@ -1253,6 +1210,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
db.removeContact(txn, c); db.removeContact(txn, c);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
@@ -1285,7 +1244,7 @@ DatabaseCleaner.Callback {
} }
public void setConfig(TransportId t, TransportConfig c) public void setConfig(TransportId t, TransportConfig c)
throws DbException { throws DbException {
transportLock.writeLock().lock(); transportLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
@@ -1301,16 +1260,17 @@ DatabaseCleaner.Callback {
} }
} }
public void setConnectionWindow(ContactId c, TransportIndex i, public void setConnectionWindow(ContactId c, TransportId t, long period,
ConnectionWindow w) throws DbException { long centre, byte[] bitmap) throws DbException {
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
windowLock.writeLock().lock(); windowLock.writeLock().lock();
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
db.setConnectionWindow(txn, c, i, w); if(!db.containsContactTransport(txn, c, t))
throw new NoSuchContactTransportException();
db.setConnectionWindow(txn, c, t, period, centre, bitmap);
db.commitTransaction(txn); db.commitTransaction(txn);
} catch(DbException e) { } catch(DbException e) {
db.abortTransaction(txn); db.abortTransaction(txn);
@@ -1324,7 +1284,7 @@ DatabaseCleaner.Callback {
} }
public void setLocalProperties(TransportId t, TransportProperties p) public void setLocalProperties(TransportId t, TransportProperties p)
throws DbException { throws DbException {
boolean changed = false; boolean changed = false;
transportLock.writeLock().lock(); transportLock.writeLock().lock();
try { try {
@@ -1378,10 +1338,9 @@ DatabaseCleaner.Callback {
} }
public void setSeen(ContactId c, Collection<MessageId> seen) public void setSeen(ContactId c, Collection<MessageId> seen)
throws DbException { throws DbException {
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
if(!containsContact(c)) throw new NoSuchContactException();
messageLock.readLock().lock(); messageLock.readLock().lock();
try { try {
messageStatusLock.writeLock().lock(); messageStatusLock.writeLock().lock();
@@ -1390,6 +1349,8 @@ DatabaseCleaner.Callback {
try { try {
T txn = db.startTransaction(); T txn = db.startTransaction();
try { try {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
for(MessageId m : seen) { for(MessageId m : seen) {
db.setStatusSeenIfVisible(txn, c, m); db.setStatusSeenIfVisible(txn, c, m);
} }
@@ -1421,7 +1382,7 @@ DatabaseCleaner.Callback {
* from not good to good, or false if it has changed from good to not good. * from not good to good, or false if it has changed from good to not good.
*/ */
private void updateAuthorSendability(T txn, AuthorId a, boolean increment) private void updateAuthorSendability(T txn, AuthorId a, boolean increment)
throws DbException { throws DbException {
for(MessageId id : db.getMessagesByAuthor(txn, a)) { for(MessageId id : db.getMessagesByAuthor(txn, a)) {
int sendability = db.getSendability(txn, id); int sendability = db.getSendability(txn, id);
if(increment) { if(increment) {
@@ -1438,7 +1399,7 @@ DatabaseCleaner.Callback {
} }
public void setVisibility(GroupId g, Collection<ContactId> visible) public void setVisibility(GroupId g, Collection<ContactId> visible)
throws DbException { throws DbException {
List<ContactId> affected = new ArrayList<ContactId>(); List<ContactId> affected = new ArrayList<ContactId>();
contactLock.readLock().lock(); contactLock.readLock().lock();
try { try {
@@ -1619,4 +1580,8 @@ DatabaseCleaner.Callback {
} }
return false; return false;
} }
public void rotateKeys() throws DbException {
}
} }

View File

@@ -14,8 +14,6 @@ import net.sf.briar.api.db.DatabasePassword;
import net.sf.briar.api.lifecycle.ShutdownManager; import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.transport.ConnectionContextFactory;
import net.sf.briar.api.transport.ConnectionWindowFactory;
import net.sf.briar.util.BoundedExecutor; import net.sf.briar.util.BoundedExecutor;
import com.google.inject.AbstractModule; import com.google.inject.AbstractModule;
@@ -50,11 +48,8 @@ public class DatabaseModule extends AbstractModule {
@Provides @Provides
Database<Connection> getDatabase(@DatabaseDirectory File dir, Database<Connection> getDatabase(@DatabaseDirectory File dir,
@DatabasePassword Password password, @DatabaseMaxSize long maxSize, @DatabasePassword Password password, @DatabaseMaxSize long maxSize,
ConnectionContextFactory connectionContextFactory,
ConnectionWindowFactory connectionWindowFactory,
GroupFactory groupFactory, Clock clock) { GroupFactory groupFactory, Clock clock) {
return new H2Database(dir, password, maxSize, connectionContextFactory, return new H2Database(dir, password, maxSize, groupFactory, clock);
connectionWindowFactory, groupFactory, clock);
} }
@Provides @Singleton @Provides @Singleton

View File

@@ -15,8 +15,6 @@ import net.sf.briar.api.db.DatabaseMaxSize;
import net.sf.briar.api.db.DatabasePassword; import net.sf.briar.api.db.DatabasePassword;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.transport.ConnectionContextFactory;
import net.sf.briar.api.transport.ConnectionWindowFactory;
import org.apache.commons.io.FileSystemUtils; import org.apache.commons.io.FileSystemUtils;
@@ -39,11 +37,9 @@ class H2Database extends JdbcDatabase {
H2Database(@DatabaseDirectory File dir, H2Database(@DatabaseDirectory File dir,
@DatabasePassword Password password, @DatabasePassword Password password,
@DatabaseMaxSize long maxSize, @DatabaseMaxSize long maxSize,
ConnectionContextFactory connectionContextFactory,
ConnectionWindowFactory connectionWindowFactory,
GroupFactory groupFactory, Clock clock) { GroupFactory groupFactory, Clock clock) {
super(connectionContextFactory, connectionWindowFactory, groupFactory, super(groupFactory, clock, HASH_TYPE, BINARY_TYPE, COUNTER_TYPE,
clock, HASH_TYPE, BINARY_TYPE, COUNTER_TYPE, SECRET_TYPE); SECRET_TYPE);
home = new File(dir, "db"); home = new File(dir, "db");
this.password = password; this.password = password;
url = "jdbc:h2:split:" + home.getPath() url = "jdbc:h2:split:" + home.getPath()

File diff suppressed because it is too large Load Diff

View File

@@ -1,220 +0,0 @@
package net.sf.briar.plugins;
import static net.sf.briar.api.plugins.InvitationConstants.HASH_LENGTH;
import static net.sf.briar.api.plugins.InvitationConstants.INVITATION_TIMEOUT;
import static net.sf.briar.api.plugins.InvitationConstants.MAX_CODE;
import static net.sf.briar.api.plugins.InvitationConstants.MAX_PUBLIC_KEY_LENGTH;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.KeyPair;
import java.util.Arrays;
import java.util.concurrent.Executor;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.MessageDigest;
import net.sf.briar.api.crypto.PseudoRandom;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.plugins.IncomingInvitationCallback;
import net.sf.briar.api.plugins.InvitationCallback;
import net.sf.briar.api.plugins.InvitationStarter;
import net.sf.briar.api.plugins.OutgoingInvitationCallback;
import net.sf.briar.api.plugins.PluginExecutor;
import net.sf.briar.api.plugins.duplex.DuplexPlugin;
import net.sf.briar.api.plugins.duplex.DuplexTransportConnection;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject;
class InvitationStarterImpl implements InvitationStarter {
private static final String TIMED_OUT = "INVITATION_TIMED_OUT";
private static final String IO_EXCEPTION = "INVITATION_IO_EXCEPTION";
private static final String INVALID_KEY = "INVITATION_INVALID_KEY";
private static final String WRONG_CODE = "INVITATION_WRONG_CODE";
private static final String DB_EXCEPTION = "INVITATION_DB_EXCEPTION";
private final Executor pluginExecutor;
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final ReaderFactory readerFactory;
private final WriterFactory writerFactory;
@Inject
InvitationStarterImpl(@PluginExecutor Executor pluginExecutor,
CryptoComponent crypto, DatabaseComponent db,
ReaderFactory readerFactory, WriterFactory writerFactory) {
this.pluginExecutor = pluginExecutor;
this.crypto = crypto;
this.db = db;
this.readerFactory = readerFactory;
this.writerFactory = writerFactory;
}
public void startIncomingInvitation(DuplexPlugin plugin,
IncomingInvitationCallback callback) {
pluginExecutor.execute(new IncomingInvitationWorker(plugin, callback));
}
public void startOutgoingInvitation(DuplexPlugin plugin,
OutgoingInvitationCallback callback) {
pluginExecutor.execute(new OutgoingInvitationWorker(plugin, callback));
}
private abstract class InvitationWorker implements Runnable {
private final DuplexPlugin plugin;
private final InvitationCallback callback;
private final boolean initiator;
protected InvitationWorker(DuplexPlugin plugin,
InvitationCallback callback, boolean initiator) {
this.plugin = plugin;
this.callback = callback;
this.initiator = initiator;
}
protected abstract int getInvitationCode();
public void run() {
long end = System.currentTimeMillis() + INVITATION_TIMEOUT;
// Use the invitation code to seed the PRNG
int code = getInvitationCode();
if(code == -1) return; // Cancelled
PseudoRandom r = crypto.getPseudoRandom(code);
long timeout = end - System.currentTimeMillis();
if(timeout <= 0) {
callback.showFailure(TIMED_OUT);
return;
}
// Create a connection
DuplexTransportConnection conn;
if(initiator) conn = plugin.sendInvitation(r, timeout);
else conn = plugin.acceptInvitation(r, timeout);
if(callback.isCancelled()) {
if(conn != null) conn.dispose(false, false);
return;
}
if(conn == null) {
callback.showFailure(TIMED_OUT);
return;
}
// Use an ephemeral key pair for key agreement
KeyPair ourKeyPair = crypto.generateAgreementKeyPair();
MessageDigest messageDigest = crypto.getMessageDigest();
byte[] ourKey = ourKeyPair.getPublic().getEncoded();
byte[] ourHash = messageDigest.digest(ourKey);
byte[] theirKey, theirHash;
try {
OutputStream out = conn.getOutputStream();
Writer writer = writerFactory.createWriter(out);
InputStream in = conn.getInputStream();
Reader reader = readerFactory.createReader(in);
if(initiator) {
// Send the public key hash
writer.writeBytes(ourHash);
out.flush();
// Receive the public key hash
theirHash = reader.readBytes(HASH_LENGTH);
// Send the public key
writer.writeBytes(ourKey);
out.flush();
// Receive the public key
theirKey = reader.readBytes(MAX_PUBLIC_KEY_LENGTH);
} else {
// Receive the public key hash
theirHash = reader.readBytes(HASH_LENGTH);
// Send the public key hash
writer.writeBytes(ourHash);
out.flush();
// Receive the public key
theirKey = reader.readBytes(MAX_PUBLIC_KEY_LENGTH);
// Send the public key
writer.writeBytes(ourKey);
out.flush();
}
} catch(IOException e) {
conn.dispose(true, false);
callback.showFailure(IO_EXCEPTION);
return;
}
conn.dispose(false, false);
if(callback.isCancelled()) return;
// Check that the received hash matches the received key
if(!Arrays.equals(theirHash, messageDigest.digest(theirKey))) {
callback.showFailure(INVALID_KEY);
return;
}
// Derive the initial shared secrets and the confirmation codes
byte[][] secrets = crypto.deriveInitialSecrets(ourKey, theirKey,
ourKeyPair.getPrivate(), code, initiator);
if(secrets == null) {
callback.showFailure(INVALID_KEY);
return;
}
int initCode = crypto.deriveConfirmationCode(secrets[0]);
int respCode = crypto.deriveConfirmationCode(secrets[1]);
int ourCode = initiator ? initCode : respCode;
int theirCode = initiator ? respCode : initCode;
// Compare the confirmation codes
if(callback.enterConfirmationCode(ourCode) != theirCode) {
callback.showFailure(WRONG_CODE);
ByteUtils.erase(secrets[0]);
ByteUtils.erase(secrets[1]);
return;
}
// Add the contact to the database
byte[] inSecret = initiator ? secrets[1] : secrets[0];
byte[] outSecret = initiator ? secrets[0] : secrets[1];
try {
db.addContact(inSecret, outSecret);
} catch(DbException e) {
callback.showFailure(DB_EXCEPTION);
ByteUtils.erase(secrets[0]);
ByteUtils.erase(secrets[1]);
return;
}
callback.showSuccess();
}
}
private class IncomingInvitationWorker extends InvitationWorker {
private final IncomingInvitationCallback callback;
IncomingInvitationWorker(DuplexPlugin plugin,
IncomingInvitationCallback callback) {
super(plugin, callback, false);
this.callback = callback;
}
@Override
protected int getInvitationCode() {
return callback.enterInvitationCode();
}
}
private class OutgoingInvitationWorker extends InvitationWorker {
private final OutgoingInvitationCallback callback;
OutgoingInvitationWorker(DuplexPlugin plugin,
OutgoingInvitationCallback callback) {
super(plugin, callback, true);
this.callback = callback;
}
@Override
protected int getInvitationCode() {
int code = crypto.getSecureRandom().nextInt(MAX_CODE + 1);
callback.showInvitationCode(code);
return code;
}
}
}

View File

@@ -32,7 +32,6 @@ import net.sf.briar.api.plugins.simplex.SimplexTransportReader;
import net.sf.briar.api.plugins.simplex.SimplexTransportWriter; import net.sf.briar.api.plugins.simplex.SimplexTransportWriter;
import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionDispatcher; import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.ui.UiCallback; import net.sf.briar.api.ui.UiCallback;
@@ -102,14 +101,7 @@ class PluginManagerImpl implements PluginManager {
LOG.warning("Duplicate transport ID: " + id); LOG.warning("Duplicate transport ID: " + id);
continue; continue;
} }
TransportIndex index = db.getLocalIndex(id); callback.init(id);
if(index == null) index = db.addTransport(id);
if(index == null) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning("Could not allocate index for ID: " + id);
continue;
}
callback.init(id, index);
plugin.start(); plugin.start();
simplexPlugins.add(plugin); simplexPlugins.add(plugin);
} catch(ClassCastException e) { } catch(ClassCastException e) {
@@ -142,14 +134,7 @@ class PluginManagerImpl implements PluginManager {
LOG.warning("Duplicate transport ID: " + id); LOG.warning("Duplicate transport ID: " + id);
continue; continue;
} }
TransportIndex index = db.getLocalIndex(id); callback.init(id);
if(index == null) index = db.addTransport(id);
if(index == null) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning("Could not allocate index for ID: " + id);
continue;
}
callback.init(id, index);
plugin.start(); plugin.start();
duplexPlugins.add(plugin); duplexPlugins.add(plugin);
} catch(ClassCastException e) { } catch(ClassCastException e) {
@@ -222,12 +207,10 @@ class PluginManagerImpl implements PluginManager {
private abstract class PluginCallbackImpl implements PluginCallback { private abstract class PluginCallbackImpl implements PluginCallback {
protected volatile TransportId id = null; protected volatile TransportId id = null;
protected volatile TransportIndex index = null;
protected void init(TransportId id, TransportIndex index) { protected void init(TransportId id) {
assert this.id == null && this.index == null; assert this.id == null;
this.id = id; this.id = id;
this.index = index;
} }
public TransportConfig getConfig() { public TransportConfig getConfig() {
@@ -320,8 +303,7 @@ class PluginManagerImpl implements PluginManager {
} }
public void writerCreated(ContactId c, SimplexTransportWriter w) { public void writerCreated(ContactId c, SimplexTransportWriter w) {
assert index != null; dispatcher.dispatchWriter(c, id, w);
dispatcher.dispatchWriter(c, id, index, w);
} }
} }
@@ -335,8 +317,7 @@ class PluginManagerImpl implements PluginManager {
public void outgoingConnectionCreated(ContactId c, public void outgoingConnectionCreated(ContactId c,
DuplexTransportConnection d) { DuplexTransportConnection d) {
assert index != null; dispatcher.dispatchOutgoingConnection(c, id, d);
dispatcher.dispatchOutgoingConnection(c, id, index, d);
} }
} }
} }

View File

@@ -30,7 +30,6 @@ import javax.mail.internet.MimeMessage;
import javax.mail.internet.MimeMultipart; import javax.mail.internet.MimeMultipart;
import javax.mail.search.FlagTerm; import javax.mail.search.FlagTerm;
import javax.mail.util.ByteArrayDataSource; import javax.mail.util.ByteArrayDataSource;
import javax.microedition.io.StreamConnection;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.TransportConfig; import net.sf.briar.api.TransportConfig;

View File

@@ -1,31 +0,0 @@
package net.sf.briar.plugins.email;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import javax.activation.DataSource;
public class PipeDataSource implements DataSource{
public String getContentType() {
return "application/octet-stream";
}
public PipedInputStream getInputStream() throws IOException {
return null;
}
public String getName() {
return "foo";
}
public PipedOutputStream getOutputStream() throws UnsupportedOperationException {
return null;
}
}

View File

@@ -148,7 +148,6 @@ class ProtocolWriterImpl implements ProtocolWriter {
for(Transport p : t.getTransports()) { for(Transport p : t.getTransports()) {
w.writeStructId(Types.TRANSPORT); w.writeStructId(Types.TRANSPORT);
w.writeBytes(p.getId().getBytes()); w.writeBytes(p.getId().getBytes());
w.writeInt32(p.getIndex().getInt());
w.writeMap(p); w.writeMap(p);
} }
w.writeListEnd(); w.writeListEnd();

View File

@@ -15,14 +15,13 @@ import net.sf.briar.api.FormatException;
import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.PacketFactory;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.CountingConsumer; import net.sf.briar.api.serial.CountingConsumer;
import net.sf.briar.api.serial.StructReader;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.StructReader;
class TransportUpdateReader implements StructReader<TransportUpdate> { class TransportUpdateReader implements StructReader<TransportUpdate> {
@@ -46,12 +45,10 @@ class TransportUpdateReader implements StructReader<TransportUpdate> {
if(transports.size() > MAX_TRANSPORTS) throw new FormatException(); if(transports.size() > MAX_TRANSPORTS) throw new FormatException();
long timestamp = r.readInt64(); long timestamp = r.readInt64();
r.removeConsumer(counting); r.removeConsumer(counting);
// Check for duplicate IDs or indices // Check for duplicate IDs
Set<TransportId> ids = new HashSet<TransportId>(); Set<TransportId> ids = new HashSet<TransportId>();
Set<TransportIndex> indices = new HashSet<TransportIndex>();
for(Transport t : transports) { for(Transport t : transports) {
if(!ids.add(t.getId())) throw new FormatException(); if(!ids.add(t.getId())) throw new FormatException();
if(!indices.add(t.getIndex())) throw new FormatException();
} }
// Build and return the transport update // Build and return the transport update
return packetFactory.createTransportUpdate(transports, timestamp); return packetFactory.createTransportUpdate(transports, timestamp);
@@ -65,17 +62,13 @@ class TransportUpdateReader implements StructReader<TransportUpdate> {
byte[] b = r.readBytes(UniqueId.LENGTH); byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH) throw new FormatException(); if(b.length != UniqueId.LENGTH) throw new FormatException();
TransportId id = new TransportId(b); TransportId id = new TransportId(b);
// Read the index
int i = r.readInt32();
if(i < 0 || i >= MAX_TRANSPORTS) throw new FormatException();
TransportIndex index = new TransportIndex(i);
// Read the properties // Read the properties
r.setMaxStringLength(MAX_PROPERTY_LENGTH); r.setMaxStringLength(MAX_PROPERTY_LENGTH);
Map<String, String> m = r.readMap(String.class, String.class); Map<String, String> m = r.readMap(String.class, String.class);
r.resetMaxStringLength(); r.resetMaxStringLength();
if(m.size() > MAX_PROPERTIES_PER_TRANSPORT) if(m.size() > MAX_PROPERTIES_PER_TRANSPORT)
throw new FormatException(); throw new FormatException();
return new Transport(id, index, m); return new Transport(id, m);
} }
} }
} }

View File

@@ -45,6 +45,7 @@ import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.UnverifiedBatch;
import net.sf.briar.api.protocol.VerificationExecutor; import net.sf.briar.api.protocol.VerificationExecutor;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRegistry; import net.sf.briar.api.transport.ConnectionRegistry;
@@ -54,7 +55,7 @@ import net.sf.briar.api.transport.ConnectionWriterFactory;
abstract class DuplexConnection implements DatabaseListener { abstract class DuplexConnection implements DatabaseListener {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(DuplexConnection.class.getName()); Logger.getLogger(DuplexConnection.class.getName());
private static final Runnable CLOSE = new Runnable() { private static final Runnable CLOSE = new Runnable() {
public void run() {} public void run() {}
@@ -66,9 +67,10 @@ abstract class DuplexConnection implements DatabaseListener {
protected final ConnectionWriterFactory connWriterFactory; protected final ConnectionWriterFactory connWriterFactory;
protected final ProtocolReaderFactory protoReaderFactory; protected final ProtocolReaderFactory protoReaderFactory;
protected final ProtocolWriterFactory protoWriterFactory; protected final ProtocolWriterFactory protoWriterFactory;
protected final ConnectionContext ctx;
protected final DuplexTransportConnection transport;
protected final ContactId contactId; protected final ContactId contactId;
protected final TransportId transportId; protected final TransportId transportId;
protected final DuplexTransportConnection transport;
private final Executor dbExecutor, verificationExecutor; private final Executor dbExecutor, verificationExecutor;
private final AtomicBoolean canSendOffer, disposed; private final AtomicBoolean canSendOffer, disposed;
@@ -84,8 +86,8 @@ abstract class DuplexConnection implements DatabaseListener {
ConnectionReaderFactory connReaderFactory, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ConnectionContext ctx,
TransportId transportId, DuplexTransportConnection transport) { DuplexTransportConnection transport) {
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.verificationExecutor = verificationExecutor; this.verificationExecutor = verificationExecutor;
this.db = db; this.db = db;
@@ -94,19 +96,20 @@ abstract class DuplexConnection implements DatabaseListener {
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
this.contactId = contactId; this.ctx = ctx;
this.transportId = transportId;
this.transport = transport; this.transport = transport;
contactId = ctx.getContactId();
transportId = ctx.getTransportId();
canSendOffer = new AtomicBoolean(false); canSendOffer = new AtomicBoolean(false);
disposed = new AtomicBoolean(false); disposed = new AtomicBoolean(false);
writerTasks = new LinkedBlockingQueue<Runnable>(); writerTasks = new LinkedBlockingQueue<Runnable>();
} }
protected abstract ConnectionReader createConnectionReader() protected abstract ConnectionReader createConnectionReader()
throws DbException, IOException; throws IOException;
protected abstract ConnectionWriter createConnectionWriter() protected abstract ConnectionWriter createConnectionWriter()
throws DbException, IOException; throws IOException;
public void eventOccurred(DatabaseEvent e) { public void eventOccurred(DatabaseEvent e) {
if(e instanceof BatchReceivedEvent) { if(e instanceof BatchReceivedEvent) {
@@ -121,7 +124,7 @@ abstract class DuplexConnection implements DatabaseListener {
dbExecutor.execute(new GenerateOffer()); dbExecutor.execute(new GenerateOffer());
} else if(e instanceof SubscriptionsUpdatedEvent) { } else if(e instanceof SubscriptionsUpdatedEvent) {
Collection<ContactId> affected = Collection<ContactId> affected =
((SubscriptionsUpdatedEvent) e).getAffectedContacts(); ((SubscriptionsUpdatedEvent) e).getAffectedContacts();
if(affected.contains(contactId)) { if(affected.contains(contactId)) {
dbExecutor.execute(new GenerateSubscriptionUpdate()); dbExecutor.execute(new GenerateSubscriptionUpdate());
} }
@@ -176,9 +179,6 @@ abstract class DuplexConnection implements DatabaseListener {
} }
// The writer will dispose of the transport // The writer will dispose of the transport
writerTasks.add(CLOSE); writerTasks.add(CLOSE);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
if(!disposed.getAndSet(true)) transport.dispose(true, true);
} catch(IOException e) { } catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
if(!disposed.getAndSet(true)) transport.dispose(true, true); if(!disposed.getAndSet(true)) transport.dispose(true, true);
@@ -217,9 +217,6 @@ abstract class DuplexConnection implements DatabaseListener {
writer.flush(); writer.flush();
writer.close(); writer.close();
if(!disposed.getAndSet(true)) transport.dispose(false, true); if(!disposed.getAndSet(true)) transport.dispose(false, true);
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
if(!disposed.getAndSet(true)) transport.dispose(true, true);
} catch(InterruptedException e) { } catch(InterruptedException e) {
if(LOG.isLoggable(Level.INFO)) if(LOG.isLoggable(Level.INFO))
LOG.info("Interrupted while waiting for task"); LOG.info("Interrupted while waiting for task");

View File

@@ -1,15 +1,17 @@
package net.sf.briar.protocol.duplex; package net.sf.briar.protocol.duplex;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.plugins.duplex.DuplexTransportConnection; import net.sf.briar.api.plugins.duplex.DuplexTransportConnection;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.VerificationExecutor; import net.sf.briar.api.protocol.VerificationExecutor;
import net.sf.briar.api.protocol.duplex.DuplexConnectionFactory; import net.sf.briar.api.protocol.duplex.DuplexConnectionFactory;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
@@ -21,8 +23,12 @@ import com.google.inject.Inject;
class DuplexConnectionFactoryImpl implements DuplexConnectionFactory { class DuplexConnectionFactoryImpl implements DuplexConnectionFactory {
private static final Logger LOG =
Logger.getLogger(DuplexConnectionFactoryImpl.class.getName());
private final Executor dbExecutor, verificationExecutor; private final Executor dbExecutor, verificationExecutor;
private final DatabaseComponent db; private final DatabaseComponent db;
private final KeyManager keyManager;
private final ConnectionRegistry connRegistry; private final ConnectionRegistry connRegistry;
private final ConnectionReaderFactory connReaderFactory; private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory; private final ConnectionWriterFactory connWriterFactory;
@@ -32,14 +38,15 @@ class DuplexConnectionFactoryImpl implements DuplexConnectionFactory {
@Inject @Inject
DuplexConnectionFactoryImpl(@DatabaseExecutor Executor dbExecutor, DuplexConnectionFactoryImpl(@DatabaseExecutor Executor dbExecutor,
@VerificationExecutor Executor verificationExecutor, @VerificationExecutor Executor verificationExecutor,
DatabaseComponent db, ConnectionRegistry connRegistry, DatabaseComponent db, KeyManager keyManager,
ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory) {
ProtocolWriterFactory protoWriterFactory) {
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.verificationExecutor = verificationExecutor; this.verificationExecutor = verificationExecutor;
this.db = db; this.db = db;
this.keyManager = keyManager;
this.connRegistry = connRegistry; this.connRegistry = connRegistry;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
@@ -47,12 +54,12 @@ class DuplexConnectionFactoryImpl implements DuplexConnectionFactory {
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
} }
public void createIncomingConnection(ConnectionContext ctx, TransportId t, public void createIncomingConnection(ConnectionContext ctx,
DuplexTransportConnection d) { DuplexTransportConnection transport) {
final DuplexConnection conn = new IncomingDuplexConnection(dbExecutor, final DuplexConnection conn = new IncomingDuplexConnection(dbExecutor,
verificationExecutor, db, connRegistry, connReaderFactory, verificationExecutor, db, connRegistry, connReaderFactory,
connWriterFactory, protoReaderFactory, protoWriterFactory, connWriterFactory, protoReaderFactory, protoWriterFactory, ctx,
ctx, t, d); transport);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();
@@ -68,11 +75,17 @@ class DuplexConnectionFactoryImpl implements DuplexConnectionFactory {
} }
public void createOutgoingConnection(ContactId c, TransportId t, public void createOutgoingConnection(ContactId c, TransportId t,
TransportIndex i, DuplexTransportConnection d) { DuplexTransportConnection transport) {
ConnectionContext ctx = keyManager.getConnectionContext(c, t);
if(ctx == null) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning("Could not create outgoing connection context");
return;
}
final DuplexConnection conn = new OutgoingDuplexConnection(dbExecutor, final DuplexConnection conn = new OutgoingDuplexConnection(dbExecutor,
verificationExecutor, db, connRegistry, connReaderFactory, verificationExecutor, db, connRegistry, connReaderFactory,
connWriterFactory, protoReaderFactory, protoWriterFactory, connWriterFactory, protoReaderFactory, protoWriterFactory, ctx,
c, t, i, d); transport);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();
@@ -86,5 +99,4 @@ class DuplexConnectionFactoryImpl implements DuplexConnectionFactory {
}; };
new Thread(read).start(); new Thread(read).start();
} }
} }

View File

@@ -8,7 +8,6 @@ import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.plugins.duplex.DuplexTransportConnection; import net.sf.briar.api.plugins.duplex.DuplexTransportConnection;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.VerificationExecutor; import net.sf.briar.api.protocol.VerificationExecutor;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
@@ -28,24 +27,22 @@ class IncomingDuplexConnection extends DuplexConnection {
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ProtocolWriterFactory protoWriterFactory,
ConnectionContext ctx, TransportId transportId, ConnectionContext ctx, DuplexTransportConnection transport) {
DuplexTransportConnection transport) {
super(dbExecutor, verificationExecutor, db, connRegistry, super(dbExecutor, verificationExecutor, db, connRegistry,
connReaderFactory, connWriterFactory, protoReaderFactory, connReaderFactory, connWriterFactory, protoReaderFactory,
protoWriterFactory, ctx.getContactId(), transportId, transport); protoWriterFactory, ctx, transport);
this.ctx = ctx; this.ctx = ctx;
} }
@Override @Override
protected ConnectionReader createConnectionReader() throws IOException { protected ConnectionReader createConnectionReader() throws IOException {
return connReaderFactory.createConnectionReader( return connReaderFactory.createConnectionReader(
transport.getInputStream(), ctx.getSecret(), true); transport.getInputStream(), ctx, true);
} }
@Override @Override
protected ConnectionWriter createConnectionWriter() throws IOException { protected ConnectionWriter createConnectionWriter() throws IOException {
return connWriterFactory.createConnectionWriter( return connWriterFactory.createConnectionWriter(
transport.getOutputStream(), Long.MAX_VALUE, ctx.getSecret(), transport.getOutputStream(), Long.MAX_VALUE, ctx, false);
false);
} }
} }

View File

@@ -3,15 +3,11 @@ package net.sf.briar.protocol.duplex;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.plugins.duplex.DuplexTransportConnection; import net.sf.briar.api.plugins.duplex.DuplexTransportConnection;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.VerificationExecutor; import net.sf.briar.api.protocol.VerificationExecutor;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
@@ -22,45 +18,28 @@ import net.sf.briar.api.transport.ConnectionWriterFactory;
class OutgoingDuplexConnection extends DuplexConnection { class OutgoingDuplexConnection extends DuplexConnection {
private final TransportIndex transportIndex;
private ConnectionContext ctx = null; // Locking: this
OutgoingDuplexConnection(@DatabaseExecutor Executor dbExecutor, OutgoingDuplexConnection(@DatabaseExecutor Executor dbExecutor,
@VerificationExecutor Executor verificationExecutor, @VerificationExecutor Executor verificationExecutor,
DatabaseComponent db, ConnectionRegistry connRegistry, DatabaseComponent db, ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, ConnectionContext ctx,
TransportId transportId, TransportIndex transportIndex,
DuplexTransportConnection transport) { DuplexTransportConnection transport) {
super(dbExecutor, verificationExecutor, db, connRegistry, super(dbExecutor, verificationExecutor, db, connRegistry,
connReaderFactory, connWriterFactory, protoReaderFactory, connReaderFactory, connWriterFactory, protoReaderFactory,
protoWriterFactory, contactId, transportId, transport); protoWriterFactory, ctx, transport);
this.transportIndex = transportIndex;
} }
@Override @Override
protected ConnectionReader createConnectionReader() throws DbException, protected ConnectionReader createConnectionReader() throws IOException {
IOException {
synchronized(this) {
if(ctx == null)
ctx = db.getConnectionContext(contactId, transportIndex);
}
return connReaderFactory.createConnectionReader( return connReaderFactory.createConnectionReader(
transport.getInputStream(), ctx.getSecret(), false); transport.getInputStream(), ctx, false);
} }
@Override @Override
protected ConnectionWriter createConnectionWriter() throws DbException, protected ConnectionWriter createConnectionWriter() throws IOException {
IOException {
synchronized(this) {
if(ctx == null)
ctx = db.getConnectionContext(contactId, transportIndex);
}
return connWriterFactory.createConnectionWriter( return connWriterFactory.createConnectionWriter(
transport.getOutputStream(), Long.MAX_VALUE, ctx.getSecret(), transport.getOutputStream(), Long.MAX_VALUE, ctx, true);
true);
} }
} }

View File

@@ -30,7 +30,7 @@ import net.sf.briar.api.transport.ConnectionRegistry;
class IncomingSimplexConnection { class IncomingSimplexConnection {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(IncomingSimplexConnection.class.getName()); Logger.getLogger(IncomingSimplexConnection.class.getName());
private final Executor dbExecutor, verificationExecutor; private final Executor dbExecutor, verificationExecutor;
private final DatabaseComponent db; private final DatabaseComponent db;
@@ -38,16 +38,16 @@ class IncomingSimplexConnection {
private final ConnectionReaderFactory connFactory; private final ConnectionReaderFactory connFactory;
private final ProtocolReaderFactory protoFactory; private final ProtocolReaderFactory protoFactory;
private final ConnectionContext ctx; private final ConnectionContext ctx;
private final TransportId transportId;
private final SimplexTransportReader transport; private final SimplexTransportReader transport;
private final ContactId contactId; private final ContactId contactId;
private final TransportId transportId;
IncomingSimplexConnection(@DatabaseExecutor Executor dbExecutor, IncomingSimplexConnection(@DatabaseExecutor Executor dbExecutor,
@VerificationExecutor Executor verificationExecutor, @VerificationExecutor Executor verificationExecutor,
DatabaseComponent db, ConnectionRegistry connRegistry, DatabaseComponent db, ConnectionRegistry connRegistry,
ConnectionReaderFactory connFactory, ConnectionReaderFactory connFactory,
ProtocolReaderFactory protoFactory, ConnectionContext ctx, ProtocolReaderFactory protoFactory, ConnectionContext ctx,
TransportId transportId, SimplexTransportReader transport) { SimplexTransportReader transport) {
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.verificationExecutor = verificationExecutor; this.verificationExecutor = verificationExecutor;
this.db = db; this.db = db;
@@ -55,16 +55,16 @@ class IncomingSimplexConnection {
this.connFactory = connFactory; this.connFactory = connFactory;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.ctx = ctx; this.ctx = ctx;
this.transportId = transportId;
this.transport = transport; this.transport = transport;
contactId = ctx.getContactId(); contactId = ctx.getContactId();
transportId = ctx.getTransportId();
} }
void read() { void read() {
connRegistry.registerConnection(contactId, transportId); connRegistry.registerConnection(contactId, transportId);
try { try {
ConnectionReader conn = connFactory.createConnectionReader( ConnectionReader conn = connFactory.createConnectionReader(
transport.getInputStream(), ctx.getSecret(), true); transport.getInputStream(), ctx, true);
InputStream in = conn.getInputStream(); InputStream in = conn.getInputStream();
ProtocolReader reader = protoFactory.createProtocolReader(in); ProtocolReader reader = protoFactory.createProtocolReader(in);
// Read packets until EOF // Read packets until EOF

View File

@@ -18,7 +18,6 @@ import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch; import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRegistry; import net.sf.briar.api.transport.ConnectionRegistry;
@@ -34,35 +33,32 @@ class OutgoingSimplexConnection {
private final ConnectionRegistry connRegistry; private final ConnectionRegistry connRegistry;
private final ConnectionWriterFactory connFactory; private final ConnectionWriterFactory connFactory;
private final ProtocolWriterFactory protoFactory; private final ProtocolWriterFactory protoFactory;
private final ConnectionContext ctx;
private final SimplexTransportWriter transport;
private final ContactId contactId; private final ContactId contactId;
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex transportIndex;
private final SimplexTransportWriter transport;
OutgoingSimplexConnection(DatabaseComponent db, OutgoingSimplexConnection(DatabaseComponent db,
ConnectionRegistry connRegistry, ConnectionRegistry connRegistry,
ConnectionWriterFactory connFactory, ConnectionWriterFactory connFactory,
ProtocolWriterFactory protoFactory, ContactId contactId, ProtocolWriterFactory protoFactory, ConnectionContext ctx,
TransportId transportId, TransportIndex transportIndex,
SimplexTransportWriter transport) { SimplexTransportWriter transport) {
this.db = db; this.db = db;
this.connRegistry = connRegistry; this.connRegistry = connRegistry;
this.connFactory = connFactory; this.connFactory = connFactory;
this.protoFactory = protoFactory; this.protoFactory = protoFactory;
this.contactId = contactId; this.ctx = ctx;
this.transportId = transportId;
this.transportIndex = transportIndex;
this.transport = transport; this.transport = transport;
contactId = ctx.getContactId();
transportId = ctx.getTransportId();
} }
void write() { void write() {
connRegistry.registerConnection(contactId, transportId); connRegistry.registerConnection(contactId, transportId);
try { try {
ConnectionContext ctx = db.getConnectionContext(contactId,
transportIndex);
ConnectionWriter conn = connFactory.createConnectionWriter( ConnectionWriter conn = connFactory.createConnectionWriter(
transport.getOutputStream(), transport.getCapacity(), transport.getOutputStream(), transport.getCapacity(),
ctx.getSecret(), true); ctx, true);
OutputStream out = conn.getOutputStream(); OutputStream out = conn.getOutputStream();
ProtocolWriter writer = protoFactory.createProtocolWriter(out, ProtocolWriter writer = protoFactory.createProtocolWriter(out,
transport.shouldFlush()); transport.shouldFlush());

View File

@@ -1,8 +1,11 @@
package net.sf.briar.protocol.simplex; package net.sf.briar.protocol.simplex;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.plugins.simplex.SimplexTransportReader; import net.sf.briar.api.plugins.simplex.SimplexTransportReader;
@@ -10,7 +13,6 @@ import net.sf.briar.api.plugins.simplex.SimplexTransportWriter;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.VerificationExecutor; import net.sf.briar.api.protocol.VerificationExecutor;
import net.sf.briar.api.protocol.simplex.SimplexConnectionFactory; import net.sf.briar.api.protocol.simplex.SimplexConnectionFactory;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
@@ -22,8 +24,12 @@ import com.google.inject.Inject;
class SimplexConnectionFactoryImpl implements SimplexConnectionFactory { class SimplexConnectionFactoryImpl implements SimplexConnectionFactory {
private static final Logger LOG =
Logger.getLogger(SimplexConnectionFactoryImpl.class.getName());
private final Executor dbExecutor, verificationExecutor; private final Executor dbExecutor, verificationExecutor;
private final DatabaseComponent db; private final DatabaseComponent db;
private final KeyManager keyManager;
private final ConnectionRegistry connRegistry; private final ConnectionRegistry connRegistry;
private final ConnectionReaderFactory connReaderFactory; private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory; private final ConnectionWriterFactory connWriterFactory;
@@ -33,7 +39,8 @@ class SimplexConnectionFactoryImpl implements SimplexConnectionFactory {
@Inject @Inject
SimplexConnectionFactoryImpl(@DatabaseExecutor Executor dbExecutor, SimplexConnectionFactoryImpl(@DatabaseExecutor Executor dbExecutor,
@VerificationExecutor Executor verificationExecutor, @VerificationExecutor Executor verificationExecutor,
DatabaseComponent db, ConnectionRegistry connRegistry, DatabaseComponent db, KeyManager keyManager,
ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory, ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, ConnectionWriterFactory connWriterFactory,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
@@ -41,6 +48,7 @@ class SimplexConnectionFactoryImpl implements SimplexConnectionFactory {
this.dbExecutor = dbExecutor; this.dbExecutor = dbExecutor;
this.verificationExecutor = verificationExecutor; this.verificationExecutor = verificationExecutor;
this.db = db; this.db = db;
this.keyManager = keyManager;
this.connRegistry = connRegistry; this.connRegistry = connRegistry;
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
@@ -48,11 +56,10 @@ class SimplexConnectionFactoryImpl implements SimplexConnectionFactory {
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
} }
public void createIncomingConnection(ConnectionContext ctx, TransportId t, public void createIncomingConnection(ConnectionContext ctx, SimplexTransportReader r) {
SimplexTransportReader r) {
final IncomingSimplexConnection conn = new IncomingSimplexConnection( final IncomingSimplexConnection conn = new IncomingSimplexConnection(
dbExecutor, verificationExecutor, db, connRegistry, dbExecutor, verificationExecutor, db, connRegistry,
connReaderFactory, protoReaderFactory, ctx, t, r); connReaderFactory, protoReaderFactory, ctx, r);
Runnable read = new Runnable() { Runnable read = new Runnable() {
public void run() { public void run() {
conn.read(); conn.read();
@@ -62,10 +69,15 @@ class SimplexConnectionFactoryImpl implements SimplexConnectionFactory {
} }
public void createOutgoingConnection(ContactId c, TransportId t, public void createOutgoingConnection(ContactId c, TransportId t,
TransportIndex i, SimplexTransportWriter w) { SimplexTransportWriter w) {
ConnectionContext ctx = keyManager.getConnectionContext(c, t);
if(ctx == null) {
if(LOG.isLoggable(Level.WARNING))
LOG.warning("Could not create outgoing connection context");
return;
}
final OutgoingSimplexConnection conn = new OutgoingSimplexConnection(db, final OutgoingSimplexConnection conn = new OutgoingSimplexConnection(db,
connRegistry, connWriterFactory, protoWriterFactory, connRegistry, connWriterFactory, protoWriterFactory, ctx, w);
c, t, i, w);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();

View File

@@ -1,31 +0,0 @@
package net.sf.briar.transport;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionContextFactory;
import com.google.inject.Inject;
class ConnectionContextFactoryImpl implements ConnectionContextFactory {
private final CryptoComponent crypto;
@Inject
ConnectionContextFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public ConnectionContext createConnectionContext(ContactId c,
TransportIndex i, long connection, byte[] secret) {
return new ConnectionContextImpl(c, i, connection, secret);
}
public ConnectionContext createNextConnectionContext(ContactId c,
TransportIndex i, long connection, byte[] previousSecret) {
byte[] secret = crypto.deriveNextSecret(previousSecret, i.getInt(),
connection);
return new ConnectionContextImpl(c, i, connection, secret);
}
}

View File

@@ -1,37 +0,0 @@
package net.sf.briar.transport;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
class ConnectionContextImpl implements ConnectionContext {
private final ContactId contactId;
private final TransportIndex transportIndex;
private final long connectionNumber;
private final byte[] secret;
ConnectionContextImpl(ContactId contactId, TransportIndex transportIndex,
long connectionNumber, byte[] secret) {
this.contactId = contactId;
this.transportIndex = transportIndex;
this.connectionNumber = connectionNumber;
this.secret = secret;
}
public ContactId getContactId() {
return contactId;
}
public TransportIndex getTransportIndex() {
return transportIndex;
}
public long getConnectionNumber() {
return connectionNumber;
}
public byte[] getSecret() {
return secret;
}
}

View File

@@ -13,7 +13,6 @@ import net.sf.briar.api.plugins.duplex.DuplexTransportConnection;
import net.sf.briar.api.plugins.simplex.SimplexTransportReader; import net.sf.briar.api.plugins.simplex.SimplexTransportReader;
import net.sf.briar.api.plugins.simplex.SimplexTransportWriter; import net.sf.briar.api.plugins.simplex.SimplexTransportWriter;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.duplex.DuplexConnectionFactory; import net.sf.briar.api.protocol.duplex.DuplexConnectionFactory;
import net.sf.briar.api.protocol.simplex.SimplexConnectionFactory; import net.sf.briar.api.protocol.simplex.SimplexConnectionFactory;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
@@ -31,27 +30,27 @@ class ConnectionDispatcherImpl implements ConnectionDispatcher {
private final Executor connExecutor; private final Executor connExecutor;
private final ConnectionRecogniser recogniser; private final ConnectionRecogniser recogniser;
private final SimplexConnectionFactory batchConnFactory; private final SimplexConnectionFactory simplexConnFactory;
private final DuplexConnectionFactory streamConnFactory; private final DuplexConnectionFactory duplexConnFactory;
@Inject @Inject
ConnectionDispatcherImpl(@IncomingConnectionExecutor Executor connExecutor, ConnectionDispatcherImpl(@IncomingConnectionExecutor Executor connExecutor,
ConnectionRecogniser recogniser, ConnectionRecogniser recogniser,
SimplexConnectionFactory batchConnFactory, SimplexConnectionFactory simplexConnFactory,
DuplexConnectionFactory streamConnFactory) { DuplexConnectionFactory duplexConnFactory) {
this.connExecutor = connExecutor; this.connExecutor = connExecutor;
this.recogniser = recogniser; this.recogniser = recogniser;
this.batchConnFactory = batchConnFactory; this.simplexConnFactory = simplexConnFactory;
this.streamConnFactory = streamConnFactory; this.duplexConnFactory = duplexConnFactory;
} }
public void dispatchReader(TransportId t, SimplexTransportReader r) { public void dispatchReader(TransportId t, SimplexTransportReader r) {
connExecutor.execute(new DispatchSimplexConnection(t, r)); connExecutor.execute(new DispatchSimplexConnection(t, r));
} }
public void dispatchWriter(ContactId c, TransportId t, TransportIndex i, public void dispatchWriter(ContactId c, TransportId t,
SimplexTransportWriter w) { SimplexTransportWriter w) {
batchConnFactory.createOutgoingConnection(c, t, i, w); simplexConnFactory.createOutgoingConnection(c, t, w);
} }
public void dispatchIncomingConnection(TransportId t, public void dispatchIncomingConnection(TransportId t,
@@ -60,8 +59,8 @@ class ConnectionDispatcherImpl implements ConnectionDispatcher {
} }
public void dispatchOutgoingConnection(ContactId c, TransportId t, public void dispatchOutgoingConnection(ContactId c, TransportId t,
TransportIndex i, DuplexTransportConnection d) { DuplexTransportConnection d) {
streamConnFactory.createOutgoingConnection(c, t, i, d); duplexConnFactory.createOutgoingConnection(c, t, d);
} }
private byte[] readTag(InputStream in) throws IOException { private byte[] readTag(InputStream in) throws IOException {
@@ -91,9 +90,12 @@ class ConnectionDispatcherImpl implements ConnectionDispatcher {
byte[] tag = readTag(transport.getInputStream()); byte[] tag = readTag(transport.getInputStream());
ConnectionContext ctx = recogniser.acceptConnection(transportId, ConnectionContext ctx = recogniser.acceptConnection(transportId,
tag); tag);
if(ctx == null) transport.dispose(false, false); if(ctx == null) {
else batchConnFactory.createIncomingConnection(ctx, transportId, transport.dispose(false, false);
transport); } else {
simplexConnFactory.createIncomingConnection(ctx,
transport);
}
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
try { try {
@@ -128,9 +130,11 @@ class ConnectionDispatcherImpl implements ConnectionDispatcher {
byte[] tag = readTag(transport.getInputStream()); byte[] tag = readTag(transport.getInputStream());
ConnectionContext ctx = recogniser.acceptConnection(transportId, ConnectionContext ctx = recogniser.acceptConnection(transportId,
tag); tag);
if(ctx == null) transport.dispose(false, false); if(ctx == null) {
else streamConnFactory.createIncomingConnection(ctx, transport.dispose(false, false);
transportId, transport); } else {
duplexConnFactory.createIncomingConnection(ctx, transport);
}
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString()); if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
transport.dispose(true, false); transport.dispose(true, false);

View File

@@ -4,14 +4,11 @@ import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import java.io.InputStream; import java.io.InputStream;
import javax.crypto.Cipher;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -25,27 +22,14 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
} }
public ConnectionReader createConnectionReader(InputStream in, public ConnectionReader createConnectionReader(InputStream in,
byte[] secret, boolean initiator) { ConnectionContext ctx, boolean initiator) {
if(initiator) { byte[] secret = ctx.getSecret();
// Derive the frame key and erase the secret long connection = ctx.getConnectionNumber();
ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); boolean alice = ctx.getAlice();
ByteUtils.erase(secret); ErasableKey frameKey = crypto.deriveFrameKey(secret, connection, alice,
// Create a reader for the responder's side of the connection initiator);
AuthenticatedCipher frameCipher = crypto.getFrameCipher(); FrameReader encryption = new IncomingEncryptionLayer(in,
FrameReader encryption = new IncomingEncryptionLayer(in, crypto.getFrameCipher(), frameKey, MAX_FRAME_LENGTH);
frameCipher, frameKey, MAX_FRAME_LENGTH); return new ConnectionReaderImpl(encryption, MAX_FRAME_LENGTH);
return new ConnectionReaderImpl(encryption, MAX_FRAME_LENGTH);
} else {
// Derive the tag and frame keys and erase the secret
ErasableKey tagKey = crypto.deriveTagKey(secret, initiator);
ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator);
ByteUtils.erase(secret);
// Create a reader for the initiator's side of the connection
Cipher tagCipher = crypto.getTagCipher();
AuthenticatedCipher frameCipher = crypto.getFrameCipher();
FrameReader encryption = new IncomingEncryptionLayer(in, tagCipher,
frameCipher, tagKey, frameKey, MAX_FRAME_LENGTH);
return new ConnectionReaderImpl(encryption, MAX_FRAME_LENGTH);
}
} }
} }

View File

@@ -1,263 +0,0 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.crypto.Cipher;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.NoSuchContactException;
import net.sf.briar.api.db.event.ContactRemovedEvent;
import net.sf.briar.api.db.event.DatabaseEvent;
import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent;
import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.api.transport.IncomingConnectionExecutor;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject;
class ConnectionRecogniserImpl implements ConnectionRecogniser,
DatabaseListener {
private static final Logger LOG =
Logger.getLogger(ConnectionRecogniserImpl.class.getName());
private final Executor connExecutor;
private final DatabaseComponent db;
private final CryptoComponent crypto;
private final Cipher tagCipher; // Locking: this
private final Set<TransportId> localTransportIds; // Locking: this
private final Map<Bytes, Context> expected; // Locking: this
private boolean initialised = false; // Locking: this
@Inject
ConnectionRecogniserImpl(@IncomingConnectionExecutor Executor connExecutor,
DatabaseComponent db, CryptoComponent crypto) {
this.connExecutor = connExecutor;
this.db = db;
this.crypto = crypto;
tagCipher = crypto.getTagCipher();
localTransportIds = new HashSet<TransportId>();
expected = new HashMap<Bytes, Context>();
}
// Package access for testing
synchronized boolean isInitialised() {
return initialised;
}
// Locking: this
private void initialise() throws DbException {
assert !initialised;
db.addListener(this);
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
Collection<TransportId> transports = new ArrayList<TransportId>();
for(Transport t : db.getLocalTransports()) transports.add(t.getId());
for(ContactId c : db.getContacts()) {
try {
for(TransportId t : transports) {
TransportIndex i = db.getRemoteIndex(c, t);
if(i == null) continue; // Contact doesn't support transport
ConnectionWindow w = db.getConnectionWindow(c, i);
for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateTag(ctx, e.getValue()), ctx);
}
w.erase();
}
} catch(NoSuchContactException e) {
// The contact was removed - clean up in removeContact()
continue;
}
}
localTransportIds.addAll(transports);
expected.putAll(ivs);
initialised = true;
}
// Locking: this
private Bytes calculateTag(Context ctx, byte[] secret) {
ErasableKey tagKey = crypto.deriveTagKey(secret, true);
byte[] tag = new byte[TAG_LENGTH];
TagEncoder.encodeTag(tag, tagCipher, tagKey);
tagKey.erase();
return new Bytes(tag);
}
public ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException {
if(tag.length != TAG_LENGTH)
throw new IllegalArgumentException();
synchronized(this) {
if(!initialised) initialise();
Bytes b = new Bytes(tag);
Context ctx = expected.get(b);
if(ctx == null || !ctx.transportId.equals(t)) return null;
// The IV was expected
expected.remove(b);
ContactId c = ctx.contactId;
TransportIndex i = ctx.transportIndex;
long connection = ctx.connection;
ConnectionWindow w = null;
byte[] secret = null;
// Get the secret and update the connection window
try {
w = db.getConnectionWindow(c, i);
secret = w.setSeen(connection);
db.setConnectionWindow(c, i, w);
} catch(NoSuchContactException e) {
// The contact was removed - reject the connection
if(w != null) w.erase();
if(secret != null) ByteUtils.erase(secret);
return null;
}
// Update the connection window's expected IVs
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) {
Context ctx1 = it.next();
if(ctx1.contactId.equals(c) && ctx1.transportIndex.equals(i))
it.remove();
}
for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
Context ctx1 = new Context(c, t, i, e.getKey());
expected.put(calculateTag(ctx1, e.getValue()), ctx1);
}
w.erase();
return new ConnectionContextImpl(c, i, connection, secret);
}
}
public void eventOccurred(DatabaseEvent e) {
if(e instanceof ContactRemovedEvent) {
// Remove the expected IVs for the ex-contact
final ContactId c = ((ContactRemovedEvent) e).getContactId();
connExecutor.execute(new Runnable() {
public void run() {
removeContact(c);
}
});
} else if(e instanceof TransportAddedEvent) {
// Add the expected IVs for the new transport
final TransportId t = ((TransportAddedEvent) e).getTransportId();
connExecutor.execute(new Runnable() {
public void run() {
addTransport(t);
}
});
} else if(e instanceof RemoteTransportsUpdatedEvent) {
// Update the expected IVs for the contact
RemoteTransportsUpdatedEvent r = (RemoteTransportsUpdatedEvent) e;
final ContactId c = r.getContactId();
final Collection<Transport> transports = r.getTransports();
connExecutor.execute(new Runnable() {
public void run() {
updateContact(c, transports);
}
});
}
}
private synchronized void removeContact(ContactId c) {
if(!initialised) return;
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
}
private synchronized void addTransport(TransportId t) {
if(!initialised) return;
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
try {
for(ContactId c : db.getContacts()) {
try {
TransportIndex i = db.getRemoteIndex(c, t);
if(i == null) continue; // Contact doesn't support transport
ConnectionWindow w = db.getConnectionWindow(c, i);
for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateTag(ctx, e.getValue()), ctx);
}
w.erase();
} catch(NoSuchContactException e) {
// The contact was removed - clean up in removeContact()
continue;
}
}
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return;
}
localTransportIds.add(t);
expected.putAll(ivs);
}
private synchronized void updateContact(ContactId c,
Collection<Transport> transports) {
if(!initialised) return;
// The ID <-> index mappings may have changed, so recalculate everything
Map<Bytes, Context> ivs = new HashMap<Bytes, Context>();
try {
for(Transport transport: transports) {
TransportId t = transport.getId();
if(!localTransportIds.contains(t)) continue;
TransportIndex i = transport.getIndex();
ConnectionWindow w = db.getConnectionWindow(c, i);
for(Entry<Long, byte[]> e : w.getUnseen().entrySet()) {
Context ctx = new Context(c, t, i, e.getKey());
ivs.put(calculateTag(ctx, e.getValue()), ctx);
}
w.erase();
}
} catch(NoSuchContactException e) {
// The contact was removed - clean up in removeContact()
return;
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
return;
}
// Remove the old IVs
Iterator<Context> it = expected.values().iterator();
while(it.hasNext()) if(it.next().contactId.equals(c)) it.remove();
// Store the new IVs
expected.putAll(ivs);
}
private static class Context {
private final ContactId contactId;
private final TransportId transportId;
private final TransportIndex transportIndex;
private final long connection;
private Context(ContactId contactId, TransportId transportId,
TransportIndex transportIndex, long connection) {
this.contactId = contactId;
this.transportId = transportId;
this.transportIndex = transportIndex;
this.connection = connection;
}
}
}

View File

@@ -1,30 +0,0 @@
package net.sf.briar.transport;
import java.util.Map;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.api.transport.ConnectionWindowFactory;
import com.google.inject.Inject;
class ConnectionWindowFactoryImpl implements ConnectionWindowFactory {
private final CryptoComponent crypto;
@Inject
ConnectionWindowFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public ConnectionWindow createConnectionWindow(TransportIndex i,
byte[] secret) {
return new ConnectionWindowImpl(crypto, i, secret);
}
public ConnectionWindow createConnectionWindow(TransportIndex i,
Map<Long, byte[]> unseen) {
return new ConnectionWindowImpl(crypto, i, unseen);
}
}

View File

@@ -1,41 +1,30 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE; import static net.sf.briar.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.util.HashMap; import java.util.HashSet;
import java.util.Map; import java.util.Set;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.util.ByteUtils;
// This class is not thread-safe // This class is not thread-safe
class ConnectionWindowImpl implements ConnectionWindow { class ConnectionWindowImpl implements ConnectionWindow {
private final CryptoComponent crypto; private final Set<Long> unseen;
private final int index;
private final Map<Long, byte[]> unseen;
private long centre; private long centre;
ConnectionWindowImpl(CryptoComponent crypto, TransportIndex i, ConnectionWindowImpl() {
byte[] secret) { unseen = new HashSet<Long>();
this.crypto = crypto; for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) unseen.add(l);
index = i.getInt();
unseen = new HashMap<Long, byte[]>();
for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) {
secret = crypto.deriveNextSecret(secret, index, l);
unseen.put(l, secret);
}
centre = 0; centre = 0;
} }
ConnectionWindowImpl(CryptoComponent crypto, TransportIndex i, ConnectionWindowImpl(Set<Long> unseen) {
Map<Long, byte[]> unseen) {
long min = Long.MAX_VALUE, max = Long.MIN_VALUE; long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
for(long l : unseen.keySet()) { for(long l : unseen) {
if(l < 0 || l > ByteUtils.MAX_32_BIT_UNSIGNED) if(l < 0 || l > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(l < min) min = l; if(l < min) min = l;
if(l > max) max = l; if(l > max) max = l;
@@ -44,42 +33,29 @@ class ConnectionWindowImpl implements ConnectionWindow {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
centre = max - CONNECTION_WINDOW_SIZE / 2 + 1; centre = max - CONNECTION_WINDOW_SIZE / 2 + 1;
for(long l = centre; l <= max; l++) { for(long l = centre; l <= max; l++) {
if(!unseen.containsKey(l)) throw new IllegalArgumentException(); if(!unseen.contains(l)) throw new IllegalArgumentException();
} }
this.crypto = crypto;
index = i.getInt();
this.unseen = unseen; this.unseen = unseen;
} }
public boolean isSeen(long connection) { public boolean isSeen(long connection) {
return !unseen.containsKey(connection); return !unseen.contains(connection);
} }
public byte[] setSeen(long connection) { public void setSeen(long connection) {
long bottom = getBottom(centre); long bottom = getBottom(centre);
long top = getTop(centre); long top = getTop(centre);
if(connection < bottom || connection > top) if(connection < bottom || connection > top)
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(!unseen.containsKey(connection)) if(!unseen.remove(connection))
throw new IllegalArgumentException(); throw new IllegalArgumentException();
if(connection >= centre) { if(connection >= centre) {
centre = connection + 1; centre = connection + 1;
long newBottom = getBottom(centre); long newBottom = getBottom(centre);
long newTop = getTop(centre); long newTop = getTop(centre);
for(long l = bottom; l < newBottom; l++) { for(long l = bottom; l < newBottom; l++) unseen.remove(l);
byte[] expired = unseen.remove(l); for(long l = top + 1; l <= newTop; l++) unseen.add(l);
if(expired != null) ByteUtils.erase(expired);
}
byte[] topSecret = unseen.get(top);
assert topSecret != null;
for(long l = top + 1; l <= newTop; l++) {
topSecret = crypto.deriveNextSecret(topSecret, index, l);
unseen.put(l, topSecret);
}
} }
byte[] seen = unseen.remove(connection);
assert seen != null;
return seen;
} }
// Returns the lowest value contained in a window with the given centre // Returns the lowest value contained in a window with the given centre
@@ -89,15 +65,11 @@ class ConnectionWindowImpl implements ConnectionWindow {
// Returns the highest value contained in a window with the given centre // Returns the highest value contained in a window with the given centre
private static long getTop(long centre) { private static long getTop(long centre) {
return Math.min(ByteUtils.MAX_32_BIT_UNSIGNED, return Math.min(MAX_32_BIT_UNSIGNED,
centre + CONNECTION_WINDOW_SIZE / 2 - 1); centre + CONNECTION_WINDOW_SIZE / 2 - 1);
} }
public Map<Long, byte[]> getUnseen() { public Set<Long> getUnseen() {
return unseen; return unseen;
} }
public void erase() {
for(byte[] secret : unseen.values()) ByteUtils.erase(secret);
}
} }

View File

@@ -4,14 +4,11 @@ import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import java.io.OutputStream; import java.io.OutputStream;
import javax.crypto.Cipher;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.util.ByteUtils;
import com.google.inject.Inject; import com.google.inject.Inject;
@@ -25,27 +22,21 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
} }
public ConnectionWriter createConnectionWriter(OutputStream out, public ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, byte[] secret, boolean initiator) { long capacity, ConnectionContext ctx, boolean initiator) {
byte[] secret = ctx.getSecret();
long connection = ctx.getConnectionNumber();
boolean alice = ctx.getAlice();
ErasableKey frameKey = crypto.deriveFrameKey(secret, connection, alice,
initiator);
FrameWriter encryption;
if(initiator) { if(initiator) {
// Derive the tag and frame keys and erase the secret encryption = new OutgoingEncryptionLayer(out, capacity,
ErasableKey tagKey = crypto.deriveTagKey(secret, initiator); crypto.getFrameCipher(), frameKey, MAX_FRAME_LENGTH,
ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); ctx.getTag());
ByteUtils.erase(secret);
// Create a writer for the initiator's side of the connection
Cipher tagCipher = crypto.getTagCipher();
AuthenticatedCipher frameCipher = crypto.getFrameCipher();
FrameWriter encryption = new OutgoingEncryptionLayer(out, capacity,
tagCipher, frameCipher, tagKey, frameKey, MAX_FRAME_LENGTH);
return new ConnectionWriterImpl(encryption, MAX_FRAME_LENGTH);
} else { } else {
// Derive the frame key and erase the secret encryption = new OutgoingEncryptionLayer(out, capacity,
ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); crypto.getFrameCipher(), frameKey, MAX_FRAME_LENGTH);
ByteUtils.erase(secret);
// Create a writer for the responder's side of the connection
AuthenticatedCipher frameCipher = crypto.getFrameCipher();
FrameWriter encryption = new OutgoingEncryptionLayer(out, capacity,
frameCipher, frameKey, MAX_FRAME_LENGTH);
return new ConnectionWriterImpl(encryption, MAX_FRAME_LENGTH);
} }
return new ConnectionWriterImpl(encryption, MAX_FRAME_LENGTH);
} }
} }

View File

@@ -5,15 +5,12 @@ import static net.sf.briar.api.transport.TransportConstants.AAD_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH; import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.EOFException; import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.AuthenticatedCipher; import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
@@ -21,70 +18,29 @@ import net.sf.briar.api.crypto.ErasableKey;
class IncomingEncryptionLayer implements FrameReader { class IncomingEncryptionLayer implements FrameReader {
private final InputStream in; private final InputStream in;
private final Cipher tagCipher;
private final AuthenticatedCipher frameCipher; private final AuthenticatedCipher frameCipher;
private final ErasableKey tagKey, frameKey; private final ErasableKey frameKey;
private final byte[] iv, aad, ciphertext; private final byte[] iv, aad, ciphertext;
private final int frameLength; private final int frameLength;
private long frameNumber; private long frameNumber;
private boolean readTag, finalFrame; private boolean finalFrame;
/** Constructor for the initiator's side of a connection. */
IncomingEncryptionLayer(InputStream in, Cipher tagCipher,
AuthenticatedCipher frameCipher, ErasableKey tagKey,
ErasableKey frameKey, int frameLength) {
this.in = in;
this.tagCipher = tagCipher;
this.frameCipher = frameCipher;
this.tagKey = tagKey;
this.frameKey = frameKey;
this.frameLength = frameLength;
iv = new byte[IV_LENGTH];
aad = new byte[AAD_LENGTH];
ciphertext = new byte[frameLength];
frameNumber = 0L;
readTag = true;
finalFrame = false;
}
/** Constructor for the responder's side of a connection. */
IncomingEncryptionLayer(InputStream in, AuthenticatedCipher frameCipher, IncomingEncryptionLayer(InputStream in, AuthenticatedCipher frameCipher,
ErasableKey frameKey, int frameLength) { ErasableKey frameKey, int frameLength) {
this.in = in; this.in = in;
this.frameCipher = frameCipher; this.frameCipher = frameCipher;
this.frameKey = frameKey; this.frameKey = frameKey;
this.frameLength = frameLength; this.frameLength = frameLength;
tagCipher = null;
tagKey = null;
iv = new byte[IV_LENGTH]; iv = new byte[IV_LENGTH];
aad = new byte[AAD_LENGTH]; aad = new byte[AAD_LENGTH];
ciphertext = new byte[frameLength]; ciphertext = new byte[frameLength];
frameNumber = 0L; frameNumber = 0L;
readTag = false;
finalFrame = false; finalFrame = false;
} }
public int readFrame(byte[] frame) throws IOException { public int readFrame(byte[] frame) throws IOException {
if(finalFrame) return -1; if(finalFrame) return -1;
// Read the tag if required
if(readTag) {
int offset = 0;
try {
while(offset < TAG_LENGTH) {
int read = in.read(ciphertext, offset, TAG_LENGTH - offset);
if(read == -1) throw new EOFException();
offset += read;
}
} catch(IOException e) {
frameKey.erase();
tagKey.erase();
throw e;
}
if(!TagEncoder.decodeTag(ciphertext, tagCipher, tagKey))
throw new FormatException();
readTag = false;
}
// Read the frame // Read the frame
int ciphertextLength = 0; int ciphertextLength = 0;
try { try {
@@ -96,7 +52,6 @@ class IncomingEncryptionLayer implements FrameReader {
} }
} catch(IOException e) { } catch(IOException e) {
frameKey.erase(); frameKey.erase();
tagKey.erase();
throw e; throw e;
} }
int plaintextLength = ciphertextLength - MAC_LENGTH; int plaintextLength = ciphertextLength - MAC_LENGTH;

View File

@@ -5,41 +5,36 @@ import static net.sf.briar.api.transport.TransportConstants.AAD_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH; import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED; import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import net.sf.briar.api.crypto.AuthenticatedCipher; import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
class OutgoingEncryptionLayer implements FrameWriter { class OutgoingEncryptionLayer implements FrameWriter {
private final OutputStream out; private final OutputStream out;
private final Cipher tagCipher;
private final AuthenticatedCipher frameCipher; private final AuthenticatedCipher frameCipher;
private final ErasableKey tagKey, frameKey; private final ErasableKey frameKey;
private final byte[] iv, aad, ciphertext; private final byte[] tag, iv, aad, ciphertext;
private final int frameLength, maxPayloadLength; private final int frameLength, maxPayloadLength;
private long capacity, frameNumber; private long capacity, frameNumber;
private boolean writeTag; private boolean writeTag;
/** Constructor for the initiator's side of a connection. */ /** Constructor for the initiator's side of a connection. */
OutgoingEncryptionLayer(OutputStream out, long capacity, Cipher tagCipher, OutgoingEncryptionLayer(OutputStream out, long capacity,
AuthenticatedCipher frameCipher, ErasableKey tagKey, AuthenticatedCipher frameCipher, ErasableKey frameKey,
ErasableKey frameKey, int frameLength) { int frameLength, byte[] tag) {
this.out = out; this.out = out;
this.capacity = capacity; this.capacity = capacity;
this.tagCipher = tagCipher;
this.frameCipher = frameCipher; this.frameCipher = frameCipher;
this.tagKey = tagKey;
this.frameKey = frameKey; this.frameKey = frameKey;
this.frameLength = frameLength; this.frameLength = frameLength;
this.tag = tag;
maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH;
iv = new byte[IV_LENGTH]; iv = new byte[IV_LENGTH];
aad = new byte[AAD_LENGTH]; aad = new byte[AAD_LENGTH];
@@ -57,9 +52,8 @@ class OutgoingEncryptionLayer implements FrameWriter {
this.frameCipher = frameCipher; this.frameCipher = frameCipher;
this.frameKey = frameKey; this.frameKey = frameKey;
this.frameLength = frameLength; this.frameLength = frameLength;
tag = null;
maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH;
tagCipher = null;
tagKey = null;
iv = new byte[IV_LENGTH]; iv = new byte[IV_LENGTH];
aad = new byte[AAD_LENGTH]; aad = new byte[AAD_LENGTH];
ciphertext = new byte[frameLength]; ciphertext = new byte[frameLength];
@@ -75,15 +69,13 @@ class OutgoingEncryptionLayer implements FrameWriter {
if(writeTag && finalFrame && payloadLength == 0) return; if(writeTag && finalFrame && payloadLength == 0) return;
// Write the tag if required // Write the tag if required
if(writeTag) { if(writeTag) {
TagEncoder.encodeTag(ciphertext, tagCipher, tagKey);
try { try {
out.write(ciphertext, 0, TAG_LENGTH); out.write(tag, 0, tag.length);
} catch(IOException e) { } catch(IOException e) {
frameKey.erase(); frameKey.erase();
tagKey.erase();
throw e; throw e;
} }
capacity -= TAG_LENGTH; capacity -= tag.length;
writeTag = false; writeTag = false;
} }
// Encode the header // Encode the header
@@ -117,7 +109,6 @@ class OutgoingEncryptionLayer implements FrameWriter {
out.write(ciphertext, 0, ciphertextLength); out.write(ciphertext, 0, ciphertextLength);
} catch(IOException e) { } catch(IOException e) {
frameKey.erase(); frameKey.erase();
tagKey.erase();
throw e; throw e;
} }
capacity -= ciphertextLength; capacity -= ciphertextLength;
@@ -132,7 +123,7 @@ class OutgoingEncryptionLayer implements FrameWriter {
// How many frame numbers can we use? // How many frame numbers can we use?
long frameNumbers = MAX_32_BIT_UNSIGNED - frameNumber + 1; long frameNumbers = MAX_32_BIT_UNSIGNED - frameNumber + 1;
// How many full frames do we have space for? // How many full frames do we have space for?
long bytes = writeTag ? capacity - TAG_LENGTH : capacity; long bytes = writeTag ? capacity - tag.length : capacity;
long fullFrames = bytes / frameLength; long fullFrames = bytes / frameLength;
// Are we limited by frame numbers or space? // Are we limited by frame numbers or space?
if(frameNumbers > fullFrames) { if(frameNumbers > fullFrames) {

View File

@@ -3,12 +3,9 @@ package net.sf.briar.transport;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import net.sf.briar.api.transport.ConnectionContextFactory;
import net.sf.briar.api.transport.ConnectionDispatcher; import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRegistry; import net.sf.briar.api.transport.ConnectionRegistry;
import net.sf.briar.api.transport.ConnectionWindowFactory;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.IncomingConnectionExecutor; import net.sf.briar.api.transport.IncomingConnectionExecutor;
@@ -18,15 +15,10 @@ public class TransportModule extends AbstractModule {
@Override @Override
protected void configure() { protected void configure() {
bind(ConnectionContextFactory.class).to(
ConnectionContextFactoryImpl.class);
bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class); bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class);
bind(ConnectionReaderFactory.class).to( bind(ConnectionReaderFactory.class).to(
ConnectionReaderFactoryImpl.class); ConnectionReaderFactoryImpl.class);
bind(ConnectionRecogniser.class).to(ConnectionRecogniserImpl.class);
bind(ConnectionRegistry.class).toInstance(new ConnectionRegistryImpl()); bind(ConnectionRegistry.class).toInstance(new ConnectionRegistryImpl());
bind(ConnectionWindowFactory.class).to(
ConnectionWindowFactoryImpl.class);
bind(ConnectionWriterFactory.class).to( bind(ConnectionWriterFactory.class).to(
ConnectionWriterFactoryImpl.class); ConnectionWriterFactoryImpl.class);
// The executor is unbounded, so tasks can be dependent or long-lived // The executor is unbounded, so tasks can be dependent or long-lived

View File

@@ -19,6 +19,7 @@
<test name='net.sf.briar.crypto.CounterModeTest'/> <test name='net.sf.briar.crypto.CounterModeTest'/>
<test name='net.sf.briar.crypto.ErasableKeyTest'/> <test name='net.sf.briar.crypto.ErasableKeyTest'/>
<test name='net.sf.briar.crypto.KeyDerivationTest'/> <test name='net.sf.briar.crypto.KeyDerivationTest'/>
<test name='net.sf.briar.crypto.KeyRotatorImplTest'/>
<test name='net.sf.briar.db.BasicH2Test'/> <test name='net.sf.briar.db.BasicH2Test'/>
<test name='net.sf.briar.db.DatabaseCleanerImplTest'/> <test name='net.sf.briar.db.DatabaseCleanerImplTest'/>
<test name='net.sf.briar.db.DatabaseComponentImplTest'/> <test name='net.sf.briar.db.DatabaseComponentImplTest'/>
@@ -46,7 +47,6 @@
<test name='net.sf.briar.serial.ReaderImplTest'/> <test name='net.sf.briar.serial.ReaderImplTest'/>
<test name='net.sf.briar.serial.WriterImplTest'/> <test name='net.sf.briar.serial.WriterImplTest'/>
<test name='net.sf.briar.transport.ConnectionReaderImplTest'/> <test name='net.sf.briar.transport.ConnectionReaderImplTest'/>
<test name='net.sf.briar.transport.ConnectionRecogniserImplTest'/>
<test name='net.sf.briar.transport.ConnectionRegistryImplTest'/> <test name='net.sf.briar.transport.ConnectionRegistryImplTest'/>
<test name='net.sf.briar.transport.ConnectionWindowImplTest'/> <test name='net.sf.briar.transport.ConnectionWindowImplTest'/>
<test name='net.sf.briar.transport.ConnectionWriterImplTest'/> <test name='net.sf.briar.transport.ConnectionWriterImplTest'/>

View File

@@ -17,6 +17,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.Author;
@@ -40,8 +41,8 @@ import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReader;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
@@ -72,8 +73,9 @@ public class ProtocolIntegrationTest extends BriarTestCase {
private final ProtocolWriterFactory protocolWriterFactory; private final ProtocolWriterFactory protocolWriterFactory;
private final PacketFactory packetFactory; private final PacketFactory packetFactory;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final ContactId contactId;
private final TransportId transportId;
private final byte[] secret; private final byte[] secret;
private final TransportIndex transportIndex = new TransportIndex(13);
private final Author author; private final Author author;
private final Group group, group1; private final Group group, group1;
private final Message message, message1, message2, message3; private final Message message, message1, message2, message3;
@@ -95,6 +97,8 @@ public class ProtocolIntegrationTest extends BriarTestCase {
protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class);
packetFactory = i.getInstance(PacketFactory.class); packetFactory = i.getInstance(PacketFactory.class);
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
contactId = new ContactId(234);
transportId = new TransportId(TestUtils.getRandomId());
// Create a shared secret // Create a shared secret
Random r = new Random(); Random r = new Random();
secret = new byte[32]; secret = new byte[32];
@@ -125,7 +129,7 @@ public class ProtocolIntegrationTest extends BriarTestCase {
subject, messageBody.getBytes("UTF-8")); subject, messageBody.getBytes("UTF-8"));
// Create some transports // Create some transports
TransportId transportId = new TransportId(TestUtils.getRandomId()); TransportId transportId = new TransportId(TestUtils.getRandomId());
Transport transport = new Transport(transportId, transportIndex, Transport transport = new Transport(transportId,
Collections.singletonMap("bar", "baz")); Collections.singletonMap("bar", "baz"));
transports = Collections.singletonList(transport); transports = Collections.singletonList(transport);
} }
@@ -137,8 +141,11 @@ public class ProtocolIntegrationTest extends BriarTestCase {
private byte[] write() throws Exception { private byte[] write() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
byte[] tag = new byte[TAG_LENGTH];
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
tag, secret.clone(), 0L, true);
ConnectionWriter conn = connectionWriterFactory.createConnectionWriter( ConnectionWriter conn = connectionWriterFactory.createConnectionWriter(
out, Long.MAX_VALUE, secret.clone(), true); out, Long.MAX_VALUE, ctx, true);
OutputStream out1 = conn.getOutputStream(); OutputStream out1 = conn.getOutputStream();
ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out1, ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out1,
false); false);
@@ -190,8 +197,11 @@ public class ProtocolIntegrationTest extends BriarTestCase {
InputStream in = new ByteArrayInputStream(connectionData); InputStream in = new ByteArrayInputStream(connectionData);
byte[] tag = new byte[TAG_LENGTH]; byte[] tag = new byte[TAG_LENGTH];
assertEquals(TAG_LENGTH, in.read(tag, 0, TAG_LENGTH)); assertEquals(TAG_LENGTH, in.read(tag, 0, TAG_LENGTH));
assertArrayEquals(new byte[TAG_LENGTH], tag);
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
tag, secret.clone(), 0L, true);
ConnectionReader conn = connectionReaderFactory.createConnectionReader( ConnectionReader conn = connectionReaderFactory.createConnectionReader(
in, secret.clone(), true); in, ctx, true);
InputStream in1 = conn.getInputStream(); InputStream in1 = conn.getInputStream();
ProtocolReader reader = protocolReaderFactory.createProtocolReader(in1); ProtocolReader reader = protocolReaderFactory.createProtocolReader(in1);

View File

@@ -8,7 +8,6 @@ import java.util.Random;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.protocol.ProtocolConstants;
import org.junit.Test; import org.junit.Test;
@@ -27,8 +26,10 @@ public class KeyDerivationTest extends BriarTestCase {
@Test @Test
public void testKeysAreDistinct() { public void testKeysAreDistinct() {
List<ErasableKey> keys = new ArrayList<ErasableKey>(); List<ErasableKey> keys = new ArrayList<ErasableKey>();
keys.add(crypto.deriveFrameKey(secret, true)); keys.add(crypto.deriveFrameKey(secret, 0, false, false));
keys.add(crypto.deriveFrameKey(secret, false)); keys.add(crypto.deriveFrameKey(secret, 0, false, true));
keys.add(crypto.deriveFrameKey(secret, 0, true, false));
keys.add(crypto.deriveFrameKey(secret, 0, true, true));
keys.add(crypto.deriveTagKey(secret, true)); keys.add(crypto.deriveTagKey(secret, true));
keys.add(crypto.deriveTagKey(secret, false)); keys.add(crypto.deriveTagKey(secret, false));
for(int i = 0; i < 4; i++) { for(int i = 0; i < 4; i++) {
@@ -47,7 +48,7 @@ public class KeyDerivationTest extends BriarTestCase {
for(int i = 0; i < 20; i++) { for(int i = 0; i < 20; i++) {
byte[] b = new byte[32]; byte[] b = new byte[32];
r.nextBytes(b); r.nextBytes(b);
secrets.add(crypto.deriveNextSecret(b, 0, 0)); secrets.add(crypto.deriveNextSecret(b, 0));
} }
for(int i = 0; i < 20; i++) { for(int i = 0; i < 20; i++) {
byte[] secretI = secrets.get(i); byte[] secretI = secrets.get(i);
@@ -58,26 +59,11 @@ public class KeyDerivationTest extends BriarTestCase {
} }
} }
@Test
public void testTransportIndexAffectsDerivation() {
List<byte[]> secrets = new ArrayList<byte[]>();
for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) {
secrets.add(crypto.deriveNextSecret(secret, i, 0));
}
for(int i = 0; i < ProtocolConstants.MAX_TRANSPORTS; i++) {
byte[] secretI = secrets.get(i);
for(int j = 0; j < ProtocolConstants.MAX_TRANSPORTS; j++) {
byte[] secretJ = secrets.get(j);
assertEquals(i == j, Arrays.equals(secretI, secretJ));
}
}
}
@Test @Test
public void testConnectionNumberAffectsDerivation() { public void testConnectionNumberAffectsDerivation() {
List<byte[]> secrets = new ArrayList<byte[]>(); List<byte[]> secrets = new ArrayList<byte[]>();
for(int i = 0; i < 20; i++) { for(int i = 0; i < 20; i++) {
secrets.add(crypto.deriveNextSecret(secret, 0, i)); secrets.add(crypto.deriveNextSecret(secret, i));
} }
for(int i = 0; i < 20; i++) { for(int i = 0; i < 20; i++) {
byte[] secretI = secrets.get(i); byte[] secretI = secrets.get(i);

View File

@@ -1,11 +1,12 @@
package net.sf.briar.db; package net.sf.briar.crypto;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.db.KeyRotator.Callback; import net.sf.briar.crypto.KeyRotatorImpl;
import net.sf.briar.crypto.KeyRotator.Callback;
import org.junit.Test; import org.junit.Test;

View File

@@ -23,7 +23,6 @@ import net.sf.briar.api.db.event.DatabaseListener;
import net.sf.briar.api.db.event.MessagesAddedEvent; import net.sf.briar.api.db.event.MessagesAddedEvent;
import net.sf.briar.api.db.event.RatingChangedEvent; import net.sf.briar.api.db.event.RatingChangedEvent;
import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent; import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent;
import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.lifecycle.ShutdownManager; import net.sf.briar.api.lifecycle.ShutdownManager;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.AuthorId;
@@ -40,9 +39,7 @@ import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionWindow;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.Mockery; import org.jmock.Mockery;
@@ -63,7 +60,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
private final Message message, privateMessage; private final Message message, privateMessage;
private final Group group; private final Group group;
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex localIndex, remoteIndex;
private final Collection<Transport> transports; private final Collection<Transport> transports;
private final Map<ContactId, TransportProperties> remoteProperties; private final Map<ContactId, TransportProperties> remoteProperties;
private final byte[] inSecret, outSecret; private final byte[] inSecret, outSecret;
@@ -72,7 +68,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
super(); super();
authorId = new AuthorId(TestUtils.getRandomId()); authorId = new AuthorId(TestUtils.getRandomId());
batchId = new BatchId(TestUtils.getRandomId()); batchId = new BatchId(TestUtils.getRandomId());
contactId = new ContactId(123); contactId = new ContactId(234);
groupId = new GroupId(TestUtils.getRandomId()); groupId = new GroupId(TestUtils.getRandomId());
messageId = new MessageId(TestUtils.getRandomId()); messageId = new MessageId(TestUtils.getRandomId());
parentId = new MessageId(TestUtils.getRandomId()); parentId = new MessageId(TestUtils.getRandomId());
@@ -86,13 +82,10 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
timestamp, raw); timestamp, raw);
group = new TestGroup(groupId, "The really exciting group", null); group = new TestGroup(groupId, "The really exciting group", null);
transportId = new TransportId(TestUtils.getRandomId()); transportId = new TransportId(TestUtils.getRandomId());
localIndex = new TransportIndex(0);
remoteIndex = new TransportIndex(13);
TransportProperties properties = new TransportProperties( TransportProperties properties = new TransportProperties(
Collections.singletonMap("foo", "bar")); Collections.singletonMap("foo", "bar"));
remoteProperties = Collections.singletonMap(contactId, properties); remoteProperties = Collections.singletonMap(contactId, properties);
Transport transport = new Transport(transportId, localIndex, Transport transport = new Transport(transportId, properties);
properties);
transports = Collections.singletonList(transport); transports = Collections.singletonList(transport);
Random r = new Random(); Random r = new Random();
inSecret = new byte[32]; inSecret = new byte[32];
@@ -108,14 +101,13 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testSimpleCalls() throws Exception { public void testSimpleCalls() throws Exception {
// FIXME: Test new methods
final int shutdownHandle = 12345; final int shutdownHandle = 12345;
Mockery context = new Mockery(); Mockery context = new Mockery();
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final ConnectionWindow connectionWindow =
context.mock(ConnectionWindow.class);
final Group group = context.mock(Group.class); final Group group = context.mock(Group.class);
final DatabaseListener listener = context.mock(DatabaseListener.class); final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
@@ -142,18 +134,12 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
oneOf(database).setRating(txn, authorId, Rating.GOOD); oneOf(database).setRating(txn, authorId, Rating.GOOD);
will(returnValue(Rating.GOOD)); will(returnValue(Rating.GOOD));
// addContact() // addContact()
oneOf(database).addContact(with(txn), with(inSecret), oneOf(database).addContact(txn);
with(outSecret), with(any(Collection.class)));
will(returnValue(contactId)); will(returnValue(contactId));
oneOf(listener).eventOccurred(with(any(ContactAddedEvent.class))); oneOf(listener).eventOccurred(with(any(ContactAddedEvent.class)));
// getContacts() // getContacts()
oneOf(database).getContacts(txn); oneOf(database).getContacts(txn);
will(returnValue(Collections.singletonList(contactId))); will(returnValue(Collections.singletonList(contactId)));
// getConnectionWindow(contactId, 13)
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).getConnectionWindow(txn, contactId, remoteIndex);
will(returnValue(connectionWindow));
// getTransportProperties(transportId) // getTransportProperties(transportId)
oneOf(database).getRemoteProperties(txn, transportId); oneOf(database).getRemoteProperties(txn, transportId);
will(returnValue(remoteProperties)); will(returnValue(remoteProperties));
@@ -183,11 +169,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
// unsubscribe(groupId) again // unsubscribe(groupId) again
oneOf(database).containsSubscription(txn, groupId); oneOf(database).containsSubscription(txn, groupId);
will(returnValue(false)); will(returnValue(false));
// setConnectionWindow(contactId, 13, connectionWindow)
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).setConnectionWindow(txn, contactId, remoteIndex,
connectionWindow);
// removeContact(contactId) // removeContact(contactId)
oneOf(database).containsContact(txn, contactId); oneOf(database).containsContact(txn, contactId);
will(returnValue(true)); will(returnValue(true));
@@ -206,10 +187,8 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
assertEquals(Rating.UNRATED, db.getRating(authorId)); assertEquals(Rating.UNRATED, db.getRating(authorId));
db.setRating(authorId, Rating.GOOD); // First time - listeners called db.setRating(authorId, Rating.GOOD); // First time - listeners called
db.setRating(authorId, Rating.GOOD); // Second time - not called db.setRating(authorId, Rating.GOOD); // Second time - not called
assertEquals(contactId, db.addContact(inSecret, outSecret)); assertEquals(contactId, db.addContact());
assertEquals(Collections.singletonList(contactId), db.getContacts()); assertEquals(Collections.singletonList(contactId), db.getContacts());
assertEquals(connectionWindow,
db.getConnectionWindow(contactId, remoteIndex));
assertEquals(remoteProperties, db.getRemoteProperties(transportId)); assertEquals(remoteProperties, db.getRemoteProperties(transportId));
db.subscribe(group); // First time - listeners called db.subscribe(group); // First time - listeners called
db.subscribe(group); // Second time - not called db.subscribe(group); // Second time - not called
@@ -217,7 +196,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
assertEquals(Collections.singletonList(groupId), db.getSubscriptions()); assertEquals(Collections.singletonList(groupId), db.getSubscriptions());
db.unsubscribe(groupId); // First time - listeners called db.unsubscribe(groupId); // First time - listeners called
db.unsubscribe(groupId); // Second time - not called db.unsubscribe(groupId); // Second time - not called
db.setConnectionWindow(contactId, remoteIndex, connectionWindow);
db.removeContact(contactId); db.removeContact(contactId);
db.removeListener(listener); db.removeListener(listener);
db.close(); db.close();
@@ -297,7 +275,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testAffectedParentContinuesBackwardInclusion() public void testAffectedParentContinuesBackwardInclusion()
throws Exception { throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -338,7 +316,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testGroupMessagesAreNotStoredUnlessSubscribed() public void testGroupMessagesAreNotStoredUnlessSubscribed()
throws Exception { throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -424,7 +402,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testAddingSendableMessageTriggersBackwardInclusion() public void testAddingSendableMessageTriggersBackwardInclusion()
throws Exception { throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -516,7 +494,8 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testVariousMethodsThrowExceptionIfContactIsMissing() public void testVariousMethodsThrowExceptionIfContactIsMissing()
throws Exception { throws Exception {
// FIXME: Test new methods
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -527,16 +506,16 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
final Batch batch = context.mock(Batch.class); final Batch batch = context.mock(Batch.class);
final Offer offer = context.mock(Offer.class); final Offer offer = context.mock(Offer.class);
final SubscriptionUpdate subscriptionUpdate = final SubscriptionUpdate subscriptionUpdate =
context.mock(SubscriptionUpdate.class); context.mock(SubscriptionUpdate.class);
final TransportUpdate transportUpdate = final TransportUpdate transportUpdate =
context.mock(TransportUpdate.class); context.mock(TransportUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
// Check whether the contact is still in the DB (which it's not) // Check whether the contact is still in the DB (which it's not)
exactly(19).of(database).startTransaction(); exactly(15).of(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
exactly(19).of(database).containsContact(txn, contactId); exactly(15).of(database).containsContact(txn, contactId);
will(returnValue(false)); will(returnValue(false));
exactly(19).of(database).commitTransaction(txn); exactly(15).of(database).abortTransaction(txn);
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown, packetFactory); shutdown, packetFactory);
@@ -577,21 +556,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try {
db.getConnectionContext(contactId, remoteIndex);
fail();
} catch(NoSuchContactException expected) {}
try {
db.getConnectionWindow(contactId, remoteIndex);
fail();
} catch(NoSuchContactException expected) {}
try {
db.getRemoteIndex(contactId, transportId);
fail();
} catch(NoSuchContactException expected) {}
try { try {
db.hasSendableMessages(contactId); db.hasSendableMessages(contactId);
fail(); fail();
@@ -627,11 +591,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
fail(); fail();
} catch(NoSuchContactException expected) {} } catch(NoSuchContactException expected) {}
try {
db.setConnectionWindow(contactId, remoteIndex, null);
fail();
} catch(NoSuchContactException expected) {}
try { try {
db.setSeen(contactId, Collections.singletonList(messageId)); db.setSeen(contactId, Collections.singletonList(messageId));
fail(); fail();
@@ -811,7 +770,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final SubscriptionUpdate subscriptionUpdate = final SubscriptionUpdate subscriptionUpdate =
context.mock(SubscriptionUpdate.class); context.mock(SubscriptionUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -883,7 +842,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final TransportUpdate transportUpdate = final TransportUpdate transportUpdate =
context.mock(TransportUpdate.class); context.mock(TransportUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -1015,7 +974,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testReceiveBatchDoesNotStoreGroupMessageUnlessSubscribed() public void testReceiveBatchDoesNotStoreGroupMessageUnlessSubscribed()
throws Exception { throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -1050,7 +1009,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testReceiveBatchDoesNotCalculateSendabilityForDuplicates() public void testReceiveBatchDoesNotCalculateSendabilityForDuplicates()
throws Exception { throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -1241,7 +1200,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final SubscriptionUpdate subscriptionUpdate = final SubscriptionUpdate subscriptionUpdate =
context.mock(SubscriptionUpdate.class); context.mock(SubscriptionUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -1281,7 +1240,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final TransportUpdate transportUpdate = final TransportUpdate transportUpdate =
context.mock(TransportUpdate.class); context.mock(TransportUpdate.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -1375,7 +1334,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testAddingDuplicateGroupMessageDoesNotCallListeners() public void testAddingDuplicateGroupMessageDoesNotCallListeners()
throws Exception { throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -1405,7 +1364,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testAddingDuplicatePrivateMessageDoesNotCallListeners() public void testAddingDuplicatePrivateMessageDoesNotCallListeners()
throws Exception { throws Exception {
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -1435,16 +1394,15 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testTransportPropertiesChangedCallsListeners() public void testTransportPropertiesChangedCallsListeners()
throws Exception { throws Exception {
final TransportProperties properties = final TransportProperties properties =
new TransportProperties(Collections.singletonMap("bar", "baz")); new TransportProperties(Collections.singletonMap("bar", "baz"));
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final ShutdownManager shutdown = context.mock(ShutdownManager.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final PacketFactory packetFactory = context.mock(PacketFactory.class); final PacketFactory packetFactory = context.mock(PacketFactory.class);
final DatabaseListener listener = context.mock(DatabaseListener.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(database).startTransaction(); oneOf(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -1454,13 +1412,10 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
oneOf(database).setTransportsModified(with(txn), oneOf(database).setTransportsModified(with(txn),
with(any(long.class))); with(any(long.class)));
oneOf(database).commitTransaction(txn); oneOf(database).commitTransaction(txn);
oneOf(listener).eventOccurred(with(any(
TransportAddedEvent.class)));
}}); }});
DatabaseComponent db = createDatabaseComponent(database, cleaner, DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown, packetFactory); shutdown, packetFactory);
db.addListener(listener);
db.setLocalProperties(transportId, properties); db.setLocalProperties(transportId, properties);
context.assertIsSatisfied(); context.assertIsSatisfied();
@@ -1468,9 +1423,9 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testTransportPropertiesUnchangedDoesNotCallListeners() public void testTransportPropertiesUnchangedDoesNotCallListeners()
throws Exception { throws Exception {
final TransportProperties properties = final TransportProperties properties =
new TransportProperties(Collections.singletonMap("bar", "baz")); new TransportProperties(Collections.singletonMap("bar", "baz"));
Mockery context = new Mockery(); Mockery context = new Mockery();
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
@@ -1521,7 +1476,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
@Test @Test
public void testVisibilityChangedCallsListeners() throws Exception { public void testVisibilityChangedCallsListeners() throws Exception {
final ContactId contactId1 = new ContactId(234); final ContactId contactId1 = new ContactId(123);
final Collection<ContactId> both = final Collection<ContactId> both =
Arrays.asList(new ContactId[] {contactId, contactId1}); Arrays.asList(new ContactId[] {contactId, contactId1});
Mockery context = new Mockery(); Mockery context = new Mockery();

View File

@@ -3,6 +3,7 @@ package net.sf.briar.db;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.sql.Connection; import java.sql.Connection;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@@ -18,7 +19,6 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.TestDatabaseModule;
import net.sf.briar.TestUtils; import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.Rating; import net.sf.briar.api.Rating;
@@ -38,28 +38,12 @@ import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContextFactory;
import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.api.transport.ConnectionWindowFactory;
import net.sf.briar.api.transport.TransportConstants;
import net.sf.briar.clock.ClockModule;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.lifecycle.LifecycleModule;
import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.duplex.DuplexProtocolModule;
import net.sf.briar.protocol.simplex.SimplexProtocolModule;
import net.sf.briar.serial.SerialModule;
import net.sf.briar.transport.TransportModule;
import org.apache.commons.io.FileSystemUtils; import org.apache.commons.io.FileSystemUtils;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class H2DatabaseTest extends BriarTestCase { public class H2DatabaseTest extends BriarTestCase {
private static final int ONE_MEGABYTE = 1024 * 1024; private static final int ONE_MEGABYTE = 1024 * 1024;
@@ -70,10 +54,8 @@ public class H2DatabaseTest extends BriarTestCase {
private final String passwordString = "foo bar"; private final String passwordString = "foo bar";
private final Password password = new TestPassword(); private final Password password = new TestPassword();
private final Random random = new Random(); private final Random random = new Random();
private final ConnectionContextFactory connectionContextFactory;
private final ConnectionWindowFactory connectionWindowFactory;
private final GroupFactory groupFactory; private final GroupFactory groupFactory;
private final Group group;
private final AuthorId authorId; private final AuthorId authorId;
private final BatchId batchId; private final BatchId batchId;
private final ContactId contactId; private final ContactId contactId;
@@ -84,33 +66,21 @@ public class H2DatabaseTest extends BriarTestCase {
private final int size; private final int size;
private final byte[] raw; private final byte[] raw;
private final Message message, privateMessage; private final Message message, privateMessage;
private final Group group;
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex localIndex, remoteIndex;
private final TransportProperties properties; private final TransportProperties properties;
private final Map<ContactId, TransportProperties> remoteProperties; private final Map<ContactId, TransportProperties> remoteProperties;
private final Collection<Transport> remoteTransports; private final Collection<Transport> remoteTransports;
private final byte[] inSecret, outSecret;
private final Collection<byte[]> erase;
public H2DatabaseTest() throws Exception { public H2DatabaseTest() throws Exception {
super(); super();
// FIXME: Use mocks for the factories rather than building the whole app groupFactory = new TestGroupFactory();
Injector i = Guice.createInjector(new ClockModule(), new CryptoModule(),
new DatabaseModule(), new LifecycleModule(),
new ProtocolModule(), new SerialModule(),
new SimplexProtocolModule(), new TransportModule(),
new DuplexProtocolModule(), new TestDatabaseModule(testDir));
connectionContextFactory =
i.getInstance(ConnectionContextFactory.class);
connectionWindowFactory = i.getInstance(ConnectionWindowFactory.class);
groupFactory = i.getInstance(GroupFactory.class);
authorId = new AuthorId(TestUtils.getRandomId()); authorId = new AuthorId(TestUtils.getRandomId());
batchId = new BatchId(TestUtils.getRandomId()); batchId = new BatchId(TestUtils.getRandomId());
contactId = new ContactId(1); contactId = new ContactId(1);
groupId = new GroupId(TestUtils.getRandomId()); groupId = new GroupId(TestUtils.getRandomId());
messageId = new MessageId(TestUtils.getRandomId()); messageId = new MessageId(TestUtils.getRandomId());
privateMessageId = new MessageId(TestUtils.getRandomId()); privateMessageId = new MessageId(TestUtils.getRandomId());
group = new TestGroup(groupId, "Foo", null);
subject = "Foo"; subject = "Foo";
timestamp = System.currentTimeMillis(); timestamp = System.currentTimeMillis();
size = 1234; size = 1234;
@@ -120,22 +90,12 @@ public class H2DatabaseTest extends BriarTestCase {
timestamp, raw); timestamp, raw);
privateMessage = new TestMessage(privateMessageId, null, null, null, privateMessage = new TestMessage(privateMessageId, null, null, null,
subject, timestamp, raw); subject, timestamp, raw);
group = groupFactory.createGroup(groupId, "Group name", null);
transportId = new TransportId(TestUtils.getRandomId()); transportId = new TransportId(TestUtils.getRandomId());
localIndex = new TransportIndex(1);
remoteIndex = new TransportIndex(13);
properties = new TransportProperties( properties = new TransportProperties(
Collections.singletonMap("foo", "bar")); Collections.singletonMap("foo", "bar"));
remoteProperties = Collections.singletonMap(contactId, properties); remoteProperties = Collections.singletonMap(contactId, properties);
Transport remoteTransport = new Transport(transportId, remoteIndex, Transport remoteTransport = new Transport(transportId, properties);
properties);
remoteTransports = Collections.singletonList(remoteTransport); remoteTransports = Collections.singletonList(remoteTransport);
Random r = new Random();
inSecret = new byte[32];
r.nextBytes(inSecret);
outSecret = new byte[32];
r.nextBytes(outSecret);
erase = new ArrayList<byte[]>();
} }
@Before @Before
@@ -149,7 +109,7 @@ public class H2DatabaseTest extends BriarTestCase {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
assertFalse(db.containsContact(txn, contactId)); assertFalse(db.containsContact(txn, contactId));
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
assertTrue(db.containsContact(txn, contactId)); assertTrue(db.containsContact(txn, contactId));
assertFalse(db.containsSubscription(txn, groupId)); assertFalse(db.containsSubscription(txn, groupId));
db.addSubscription(txn, group); db.addSubscription(txn, group);
@@ -205,23 +165,20 @@ public class H2DatabaseTest extends BriarTestCase {
// Create three contacts // Create three contacts
assertFalse(db.containsContact(txn, contactId)); assertFalse(db.containsContact(txn, contactId));
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
assertTrue(db.containsContact(txn, contactId)); assertTrue(db.containsContact(txn, contactId));
assertFalse(db.containsContact(txn, contactId1)); assertFalse(db.containsContact(txn, contactId1));
assertEquals(contactId1, assertEquals(contactId1, db.addContact(txn));
db.addContact(txn, inSecret, outSecret, erase));
assertTrue(db.containsContact(txn, contactId1)); assertTrue(db.containsContact(txn, contactId1));
assertFalse(db.containsContact(txn, contactId2)); assertFalse(db.containsContact(txn, contactId2));
assertEquals(contactId2, assertEquals(contactId2, db.addContact(txn));
db.addContact(txn, inSecret, outSecret, erase));
assertTrue(db.containsContact(txn, contactId2)); assertTrue(db.containsContact(txn, contactId2));
// Delete the contact with the highest ID // Delete the contact with the highest ID
db.removeContact(txn, contactId2); db.removeContact(txn, contactId2);
assertFalse(db.containsContact(txn, contactId2)); assertFalse(db.containsContact(txn, contactId2));
// Add another contact - a new ID should be created // Add another contact - a new ID should be created
assertFalse(db.containsContact(txn, contactId3)); assertFalse(db.containsContact(txn, contactId3));
assertEquals(contactId3, assertEquals(contactId3, db.addContact(txn));
db.addContact(txn, inSecret, outSecret, erase));
assertTrue(db.containsContact(txn, contactId3)); assertTrue(db.containsContact(txn, contactId3));
db.commitTransaction(txn); db.commitTransaction(txn);
@@ -268,7 +225,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and store a private message // Add a contact and store a private message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addPrivateMessage(txn, privateMessage, contactId); db.addPrivateMessage(txn, privateMessage, contactId);
// Removing the contact should remove the message // Removing the contact should remove the message
@@ -282,18 +239,18 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSendablePrivateMessagesMustHaveStatusNew() public void testSendablePrivateMessagesMustHaveStatusNew()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and store a private message // Add a contact and store a private message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addPrivateMessage(txn, privateMessage, contactId); db.addPrivateMessage(txn, privateMessage, contactId);
// The message has no status yet, so it should not be sendable // The message has no status yet, so it should not be sendable
assertFalse(db.hasSendableMessages(txn, contactId)); assertFalse(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// Changing the status to NEW should make the message sendable // Changing the status to NEW should make the message sendable
@@ -321,19 +278,19 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSendablePrivateMessagesMustFitCapacity() public void testSendablePrivateMessagesMustFitCapacity()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and store a private message // Add a contact and store a private message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addPrivateMessage(txn, privateMessage, contactId); db.addPrivateMessage(txn, privateMessage, contactId);
db.setStatus(txn, contactId, privateMessageId, Status.NEW); db.setStatus(txn, contactId, privateMessageId, Status.NEW);
// The message is sendable, but too large to send // The message is sendable, but too large to send
assertTrue(db.hasSendableMessages(txn, contactId)); assertTrue(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, size - 1).iterator(); db.getSendableMessages(txn, contactId, size - 1).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// The message is just the right size to send // The message is just the right size to send
@@ -349,12 +306,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSendableGroupMessagesMustHavePositiveSendability() public void testSendableGroupMessagesMustHavePositiveSendability()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -364,7 +321,7 @@ public class H2DatabaseTest extends BriarTestCase {
// The message should not be sendable // The message should not be sendable
assertFalse(db.hasSendableMessages(txn, contactId)); assertFalse(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// Changing the sendability to > 0 should make the message sendable // Changing the sendability to > 0 should make the message sendable
@@ -387,12 +344,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSendableGroupMessagesMustHaveStatusNew() public void testSendableGroupMessagesMustHaveStatusNew()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -402,7 +359,7 @@ public class H2DatabaseTest extends BriarTestCase {
// The message has no status yet, so it should not be sendable // The message has no status yet, so it should not be sendable
assertFalse(db.hasSendableMessages(txn, contactId)); assertFalse(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// Changing the status to Status.NEW should make the message sendable // Changing the status to Status.NEW should make the message sendable
@@ -434,7 +391,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -444,7 +401,7 @@ public class H2DatabaseTest extends BriarTestCase {
// The contact is not subscribed, so the message should not be sendable // The contact is not subscribed, so the message should not be sendable
assertFalse(db.hasSendableMessages(txn, contactId)); assertFalse(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// The contact subscribing should make the message sendable // The contact subscribing should make the message sendable
@@ -467,12 +424,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSendableGroupMessagesMustBeNewerThanSubscriptions() public void testSendableGroupMessagesMustBeNewerThanSubscriptions()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -484,7 +441,7 @@ public class H2DatabaseTest extends BriarTestCase {
db.addSubscription(txn, contactId, group, timestamp + 1); db.addSubscription(txn, contactId, group, timestamp + 1);
assertFalse(db.hasSendableMessages(txn, contactId)); assertFalse(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// Changing the contact's subscription should make the message sendable // Changing the contact's subscription should make the message sendable
@@ -506,7 +463,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -517,7 +474,7 @@ public class H2DatabaseTest extends BriarTestCase {
// The message is sendable, but too large to send // The message is sendable, but too large to send
assertTrue(db.hasSendableMessages(txn, contactId)); assertTrue(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, size - 1).iterator(); db.getSendableMessages(txn, contactId, size - 1).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// The message is just the right size to send // The message is just the right size to send
@@ -537,7 +494,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -548,7 +505,7 @@ public class H2DatabaseTest extends BriarTestCase {
// should not be sendable // should not be sendable
assertFalse(db.hasSendableMessages(txn, contactId)); assertFalse(db.hasSendableMessages(txn, contactId));
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// Making the subscription visible should make the message sendable // Making the subscription visible should make the message sendable
@@ -570,7 +527,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and some batches to ack // Add a contact and some batches to ack
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId);
db.addBatchToAck(txn, contactId, batchId1); db.addBatchToAck(txn, contactId, batchId1);
@@ -597,7 +554,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and receive the same batch twice // Add a contact and receive the same batch twice
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId);
db.addBatchToAck(txn, contactId, batchId); db.addBatchToAck(txn, contactId, batchId);
@@ -623,7 +580,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -648,8 +605,8 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add two contacts, subscribe to a group and store a message // Add two contacts, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
ContactId contactId1 = db.addContact(txn, inSecret, outSecret, erase); ContactId contactId1 = db.addContact(txn);
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -671,7 +628,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -681,7 +638,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Retrieve the message from the database and mark it as sent // Retrieve the message from the database and mark it as sent
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertTrue(it.hasNext()); assertTrue(it.hasNext());
assertEquals(messageId, it.next()); assertEquals(messageId, it.next());
assertFalse(it.hasNext()); assertFalse(it.hasNext());
@@ -710,7 +667,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -720,7 +677,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Get the message and mark it as sent // Get the message and mark it as sent
Iterator<MessageId> it = Iterator<MessageId> it =
db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertTrue(it.hasNext()); assertTrue(it.hasNext());
assertEquals(messageId, it.next()); assertEquals(messageId, it.next());
assertFalse(it.hasNext()); assertFalse(it.hasNext());
@@ -755,7 +712,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact // Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
// Add some outstanding batches, a few ms apart // Add some outstanding batches, a few ms apart
for(int i = 0; i < ids.length; i++) { for(int i = 0; i < ids.length; i++) {
@@ -795,7 +752,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact // Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
// Add some outstanding batches, a few ms apart // Add some outstanding batches, a few ms apart
for(int i = 0; i < ids.length; i++) { for(int i = 0; i < ids.length; i++) {
@@ -832,7 +789,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Check that each message is retrievable via its author // Check that each message is retrievable via its author
Iterator<MessageId> it = Iterator<MessageId> it =
db.getMessagesByAuthor(txn, authorId).iterator(); db.getMessagesByAuthor(txn, authorId).iterator();
assertTrue(it.hasNext()); assertTrue(it.hasNext());
assertEquals(messageId, it.next()); assertEquals(messageId, it.next());
assertFalse(it.hasNext()); assertFalse(it.hasNext());
@@ -1021,35 +978,29 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with a transport // Add a contact with a transport
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.setTransports(txn, contactId, remoteTransports, 1); db.setTransports(txn, contactId, remoteTransports, 1);
assertEquals(remoteProperties, assertEquals(remoteProperties,
db.getRemoteProperties(txn, transportId)); db.getRemoteProperties(txn, transportId));
// Replace the transport properties // Replace the transport properties
TransportProperties properties1 = TransportProperties properties1 =
new TransportProperties(Collections.singletonMap("baz", "bam")); new TransportProperties(Collections.singletonMap("baz", "bam"));
Transport remoteTransport1 = Transport remoteTransport1 = new Transport(transportId, properties1);
new Transport(transportId, remoteIndex, properties1);
Collection<Transport> remoteTransports1 = Collection<Transport> remoteTransports1 =
Collections.singletonList(remoteTransport1); Collections.singletonList(remoteTransport1);
Map<ContactId, TransportProperties> remoteProperties1 = Map<ContactId, TransportProperties> remoteProperties1 =
Collections.singletonMap(contactId, properties1); Collections.singletonMap(contactId, properties1);
db.setTransports(txn, contactId, remoteTransports1, 2); db.setTransports(txn, contactId, remoteTransports1, 2);
assertEquals(remoteProperties1, assertEquals(remoteProperties1,
db.getRemoteProperties(txn, transportId)); db.getRemoteProperties(txn, transportId));
// Remove the transport properties but leave the transport // Remove the transport properties
properties1 = new TransportProperties(); properties1 = new TransportProperties();
remoteTransport1 = new Transport(transportId, remoteIndex, properties1); remoteTransport1 = new Transport(transportId, properties1);
remoteTransports1 = Collections.singletonList(remoteTransport1); remoteTransports1 = Collections.singletonList(remoteTransport1);
remoteProperties1 = Collections.singletonMap(contactId, properties1); remoteProperties1 = Collections.singletonMap(contactId, properties1);
db.setTransports(txn, contactId, remoteTransports1, 3); db.setTransports(txn, contactId, remoteTransports1, 3);
assertEquals(remoteProperties1,
db.getRemoteProperties(txn, transportId));
// Remove the transport
db.setTransports(txn, contactId, Collections.<Transport>emptyList(), 4);
assertEquals(Collections.emptyMap(), assertEquals(Collections.emptyMap(),
db.getRemoteProperties(txn, transportId)); db.getRemoteProperties(txn, transportId));
@@ -1062,9 +1013,6 @@ public class H2DatabaseTest extends BriarTestCase {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Allocate a transport index
assertEquals(localIndex, db.addTransport(txn, transportId));
// Set the transport properties // Set the transport properties
db.setLocalProperties(txn, transportId, properties); db.setLocalProperties(txn, transportId, properties);
assertEquals(Collections.singletonList(properties), assertEquals(Collections.singletonList(properties),
@@ -1072,8 +1020,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Remove the transport properties but leave the transport // Remove the transport properties but leave the transport
db.setLocalProperties(txn, transportId, new TransportProperties()); db.setLocalProperties(txn, transportId, new TransportProperties());
assertEquals(Collections.singletonList(Collections.emptyMap()), assertEquals(Collections.emptyList(), db.getLocalTransports(txn));
db.getLocalTransports(txn));
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -1082,16 +1029,13 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testUpdateTransportConfig() throws Exception { public void testUpdateTransportConfig() throws Exception {
TransportConfig config = TransportConfig config =
new TransportConfig(Collections.singletonMap("foo", "bar")); new TransportConfig(Collections.singletonMap("foo", "bar"));
TransportConfig config1 = TransportConfig config1 =
new TransportConfig(Collections.singletonMap("baz", "bam")); new TransportConfig(Collections.singletonMap("baz", "bam"));
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Allocate a transport index
assertEquals(localIndex, db.addTransport(txn, transportId));
// Set the transport config // Set the transport config
db.setConfig(txn, transportId, config); db.setConfig(txn, transportId, config);
assertEquals(config, db.getConfig(txn, transportId)); assertEquals(config, db.getConfig(txn, transportId));
@@ -1114,31 +1058,29 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with a transport // Add a contact with a transport
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.setTransports(txn, contactId, remoteTransports, 1); db.setTransports(txn, contactId, remoteTransports, 1);
assertEquals(remoteProperties, assertEquals(remoteProperties,
db.getRemoteProperties(txn, transportId)); db.getRemoteProperties(txn, transportId));
// Replace the transport properties using a timestamp of 2 // Replace the transport properties using a timestamp of 2
TransportProperties properties1 = TransportProperties properties1 =
new TransportProperties(Collections.singletonMap("baz", "bam")); new TransportProperties(Collections.singletonMap("baz", "bam"));
Transport remoteTransport1 = Transport remoteTransport1 = new Transport(transportId, properties1);
new Transport(transportId, remoteIndex, properties1);
Collection<Transport> remoteTransports1 = Collection<Transport> remoteTransports1 =
Collections.singletonList(remoteTransport1); Collections.singletonList(remoteTransport1);
Map<ContactId, TransportProperties> remoteProperties1 = Map<ContactId, TransportProperties> remoteProperties1 =
Collections.singletonMap(contactId, properties1); Collections.singletonMap(contactId, properties1);
db.setTransports(txn, contactId, remoteTransports1, 2); db.setTransports(txn, contactId, remoteTransports1, 2);
assertEquals(remoteProperties1, assertEquals(remoteProperties1,
db.getRemoteProperties(txn, transportId)); db.getRemoteProperties(txn, transportId));
// Try to replace the transport properties using a timestamp of 1 // Try to replace the transport properties using a timestamp of 1
TransportProperties properties2 = TransportProperties properties2 =
new TransportProperties(Collections.singletonMap("quux", "etc")); new TransportProperties(Collections.singletonMap("quux", "etc"));
Transport remoteTransport2 = Transport remoteTransport2 = new Transport(transportId, properties2);
new Transport(transportId, remoteIndex, properties2);
Collection<Transport> remoteTransports2 = Collection<Transport> remoteTransports2 =
Collections.singletonList(remoteTransport2); Collections.singletonList(remoteTransport2);
db.setTransports(txn, contactId, remoteTransports2, 1); db.setTransports(txn, contactId, remoteTransports2, 1);
// The old properties should still be there // The old properties should still be there
@@ -1151,12 +1093,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testGetMessageIfSendableReturnsNullIfNotInDatabase() public void testGetMessageIfSendableReturnsNullIfNotInDatabase()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -1169,12 +1111,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testGetMessageIfSendableReturnsNullIfSeen() public void testGetMessageIfSendableReturnsNullIfSeen()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -1192,12 +1134,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testGetMessageIfSendableReturnsNullIfNotSendable() public void testGetMessageIfSendableReturnsNullIfNotSendable()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
@@ -1220,7 +1162,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Add a contact, subscribe to a group and store a message - // Add a contact, subscribe to a group and store a message -
// the message is older than the contact's subscription // the message is older than the contact's subscription
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, timestamp + 1); db.addSubscription(txn, contactId, group, timestamp + 1);
@@ -1243,7 +1185,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -1263,12 +1205,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSetStatusSeenIfVisibleRequiresMessageInDatabase() public void testSetStatusSeenIfVisibleRequiresMessageInDatabase()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -1282,12 +1224,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSetStatusSeenIfVisibleRequiresLocalSubscription() public void testSetStatusSeenIfVisibleRequiresLocalSubscription()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact with a subscription // Add a contact with a subscription
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
// There's no local subscription for the group // There's no local subscription for the group
@@ -1299,12 +1241,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSetStatusSeenIfVisibleRequiresContactSubscription() public void testSetStatusSeenIfVisibleRequiresContactSubscription()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
db.setStatus(txn, contactId, messageId, Status.NEW); db.setStatus(txn, contactId, messageId, Status.NEW);
@@ -1318,12 +1260,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSetStatusSeenIfVisibleRequiresVisibility() public void testSetStatusSeenIfVisibleRequiresVisibility()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addGroupMessage(txn, message); db.addGroupMessage(txn, message);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -1338,12 +1280,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSetStatusSeenIfVisibleReturnsTrueIfAlreadySeen() public void testSetStatusSeenIfVisibleReturnsTrueIfAlreadySeen()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -1360,12 +1302,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testSetStatusSeenIfVisibleReturnsTrueIfNew() public void testSetStatusSeenIfVisibleReturnsTrueIfNew()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
db.addVisibility(txn, contactId, groupId); db.addVisibility(txn, contactId, groupId);
db.addSubscription(txn, contactId, group, 0L); db.addSubscription(txn, contactId, group, 0L);
@@ -1386,7 +1328,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
// The group should not be visible to the contact // The group should not be visible to the contact
assertEquals(Collections.emptyList(), db.getVisibility(txn, groupId)); assertEquals(Collections.emptyList(), db.getVisibility(txn, groupId));
@@ -1402,58 +1344,6 @@ public class H2DatabaseTest extends BriarTestCase {
db.close(); db.close();
} }
@Test
public void testGettingUnknownConnectionWindowReturnsDefault()
throws Exception {
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
// Get the connection window for a new index
ConnectionWindow w = db.getConnectionWindow(txn, contactId,
remoteIndex);
// The connection window should exist and be in the initial state
assertNotNull(w);
long top = TransportConstants.CONNECTION_WINDOW_SIZE / 2 - 1;
for(long l = 0; l <= top; l++) assertFalse(w.isSeen(l));
db.commitTransaction(txn);
db.close();
}
@Test
public void testConnectionWindow() throws Exception {
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Add a contact
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase));
// Get the connection window for a new index
ConnectionWindow w = db.getConnectionWindow(txn, contactId,
remoteIndex);
// The connection window should exist and be in the initial state
assertNotNull(w);
Map<Long, byte[]> unseen = w.getUnseen();
long top = TransportConstants.CONNECTION_WINDOW_SIZE / 2 - 1;
assertEquals(top + 1, unseen.size());
for(long l = 0; l <= top; l++) {
assertFalse(w.isSeen(l));
assertTrue(unseen.containsKey(l));
}
// Update the connection window and store it
w.setSeen(5);
db.setConnectionWindow(txn, contactId, remoteIndex, w);
// Check that the connection window was stored
w = db.getConnectionWindow(txn, contactId, remoteIndex);
assertNotNull(w);
top += 5;
for(long l = 0; l <= top; l++) assertEquals(l == 5, w.isSeen(l));
db.commitTransaction(txn);
db.close();
}
@Test @Test
public void testGetGroupMessageParentWithNoParent() throws Exception { public void testGetGroupMessageParentWithNoParent() throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
@@ -1498,7 +1388,7 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testGetGroupMessageParentWithParentInAnotherGroup() public void testGetGroupMessageParentWithParentInAnotherGroup()
throws Exception { throws Exception {
GroupId groupId1 = new GroupId(TestUtils.getRandomId()); GroupId groupId1 = new GroupId(TestUtils.getRandomId());
Group group1 = groupFactory.createGroup(groupId1, "Group name", null); Group group1 = groupFactory.createGroup(groupId1, "Group name", null);
Database<Connection> db = open(false); Database<Connection> db = open(false);
@@ -1527,12 +1417,12 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testGetGroupMessageParentWithPrivateParent() public void testGetGroupMessageParentWithPrivateParent()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
// A message with a private parent should return null // A message with a private parent should return null
@@ -1551,7 +1441,7 @@ public class H2DatabaseTest extends BriarTestCase {
@Test @Test
public void testGetGroupMessageParentWithParentInSameGroup() public void testGetGroupMessageParentWithParentInSameGroup()
throws Exception { throws Exception {
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
@@ -1581,7 +1471,7 @@ public class H2DatabaseTest extends BriarTestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Add a contact and subscribe to a group // Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
db.addSubscription(txn, group); db.addSubscription(txn, group);
// Store a couple of messages // Store a couple of messages
@@ -1813,7 +1703,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Subscribe to the groups and add a contact // Subscribe to the groups and add a contact
for(Group g : groups) db.addSubscription(txn, g); for(Group g : groups) db.addSubscription(txn, g);
assertEquals(contactId, db.addContact(txn, inSecret, outSecret, erase)); assertEquals(contactId, db.addContact(txn));
// Make the groups visible to the contact // Make the groups visible to the contact
Collections.shuffle(groups); Collections.shuffle(groups);
@@ -1849,7 +1739,6 @@ public class H2DatabaseTest extends BriarTestCase {
private Database<Connection> open(boolean resume) throws Exception { private Database<Connection> open(boolean resume) throws Exception {
Database<Connection> db = new H2Database(testDir, password, MAX_SIZE, Database<Connection> db = new H2Database(testDir, password, MAX_SIZE,
connectionContextFactory, connectionWindowFactory,
groupFactory, new SystemClock()); groupFactory, new SystemClock());
db.open(resume); db.open(resume);
return db; return db;
@@ -1857,7 +1746,6 @@ public class H2DatabaseTest extends BriarTestCase {
@After @After
public void tearDown() { public void tearDown() {
erase.clear();
TestUtils.deleteTestDirectory(testDir); TestUtils.deleteTestDirectory(testDir);
} }
@@ -1867,4 +1755,17 @@ public class H2DatabaseTest extends BriarTestCase {
return passwordString.toCharArray(); return passwordString.toCharArray();
} }
} }
private class TestGroupFactory implements GroupFactory {
public Group createGroup(String name, byte[] publicKey)
throws IOException {
GroupId id = new GroupId(TestUtils.getRandomId());
return new TestGroup(id, name, publicKey);
}
public Group createGroup(GroupId id, String name, byte[] publicKey) {
return new TestGroup(id, name, publicKey);
}
}
} }

View File

@@ -16,7 +16,7 @@ abstract class DuplexTest {
protected static final String RESPONSE = "Potatoes!"; protected static final String RESPONSE = "Potatoes!";
protected static final long INVITATION_TIMEOUT = 30 * 1000; protected static final long INVITATION_TIMEOUT = 30 * 1000;
protected final ContactId contactId = new ContactId(0); protected final ContactId contactId = new ContactId(234);
protected DuplexPlugin plugin = null; protected DuplexPlugin plugin = null;

View File

@@ -16,6 +16,8 @@ import com.google.inject.Injector;
public class InvitationStarterImplTest extends BriarTestCase { public class InvitationStarterImplTest extends BriarTestCase {
// FIXME: This is actually a test of CryptoComponent
private final CryptoComponent crypto; private final CryptoComponent crypto;
public InvitationStarterImplTest() { public InvitationStarterImplTest() {
@@ -32,13 +34,8 @@ public class InvitationStarterImplTest extends BriarTestCase {
KeyPair b = crypto.generateAgreementKeyPair(); KeyPair b = crypto.generateAgreementKeyPair();
byte[] bPub = b.getPublic().getEncoded(); byte[] bPub = b.getPublic().getEncoded();
PrivateKey bPriv = b.getPrivate(); PrivateKey bPriv = b.getPrivate();
byte[][] aSecrets = crypto.deriveInitialSecrets(aPub, bPub, aPriv, 123, byte[] aSecret = crypto.deriveInitialSecret(aPub, bPub, aPriv, true);
true); byte[] bSecret = crypto.deriveInitialSecret(bPub, aPub, bPriv, false);
byte[][] bSecrets = crypto.deriveInitialSecrets(bPub, aPub, bPriv, 123, assertArrayEquals(aSecret, bSecret);
false);
assertEquals(2, aSecrets.length);
assertEquals(2, bSecrets.length);
assertArrayEquals(aSecrets[0], bSecrets[0]);
assertArrayEquals(aSecrets[1], bSecrets[1]);
} }
} }

View File

@@ -3,14 +3,12 @@ package net.sf.briar.plugins;
import java.util.Collection; import java.util.Collection;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.api.TransportConfig; import net.sf.briar.api.TransportConfig;
import net.sf.briar.api.TransportProperties; import net.sf.briar.api.TransportProperties;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionDispatcher; import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.ui.UiCallback; import net.sf.briar.api.ui.UiCallback;
@@ -29,13 +27,8 @@ public class PluginManagerImplTest extends BriarTestCase {
final ConnectionDispatcher dispatcher = final ConnectionDispatcher dispatcher =
context.mock(ConnectionDispatcher.class); context.mock(ConnectionDispatcher.class);
final UiCallback uiCallback = context.mock(UiCallback.class); final UiCallback uiCallback = context.mock(UiCallback.class);
final AtomicInteger index = new AtomicInteger(0);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(poller).start(with(any(Collection.class))); oneOf(poller).start(with(any(Collection.class)));
allowing(db).getLocalIndex(with(any(TransportId.class)));
will(returnValue(null));
allowing(db).addTransport(with(any(TransportId.class)));
will(returnValue(new TransportIndex(index.getAndIncrement())));
allowing(db).getConfig(with(any(TransportId.class))); allowing(db).getConfig(with(any(TransportId.class)));
will(returnValue(new TransportConfig())); will(returnValue(new TransportConfig()));
allowing(db).getLocalProperties(with(any(TransportId.class))); allowing(db).getLocalProperties(with(any(TransportId.class)));

View File

@@ -1,17 +1,14 @@
package net.sf.briar.plugins.email; package net.sf.briar.plugins.email;
import static org.junit.Assert.*; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import net.sf.briar.BriarTestCase;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.TransportConfig; import net.sf.briar.api.TransportConfig;
import net.sf.briar.api.TransportProperties; import net.sf.briar.api.TransportProperties;
@@ -19,7 +16,6 @@ import net.sf.briar.api.plugins.simplex.SimplexPluginCallback;
import net.sf.briar.api.plugins.simplex.SimplexTransportReader; import net.sf.briar.api.plugins.simplex.SimplexTransportReader;
import net.sf.briar.api.plugins.simplex.SimplexTransportWriter; import net.sf.briar.api.plugins.simplex.SimplexTransportWriter;
import org.jmock.Mockery;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@@ -47,7 +43,7 @@ public class GmailPluginTest {
props1 = new TransportProperties(); props1 = new TransportProperties();
props1.put("email", System.getenv("CONTACT1_EMAIL")); props1.put("email", System.getenv("CONTACT1_EMAIL"));
test1 = new ContactId(12); test1 = new ContactId(234);
map.put(test1, props1); map.put(test1, props1);
assertEquals(1, map.size()); assertEquals(1, map.size());
@@ -120,7 +116,7 @@ public class GmailPluginTest {
GmailPlugin pluginTest = new GmailPlugin( GmailPlugin pluginTest = new GmailPlugin(
Executors.newSingleThreadExecutor(), callback); Executors.newSingleThreadExecutor(), callback);
assertEquals(true, pluginTest.connectSMTP(test1)); assertEquals(true, pluginTest.connectSMTP(test1));
assertEquals(false, pluginTest.connectSMTP(new ContactId(7))); assertEquals(false, pluginTest.connectSMTP(new ContactId(123)));
pluginTest.stop(); pluginTest.stop();
} }

View File

@@ -28,7 +28,7 @@ import org.junit.Test;
public class RemovableDrivePluginTest extends BriarTestCase { public class RemovableDrivePluginTest extends BriarTestCase {
private final File testDir = TestUtils.getTestDirectory(); private final File testDir = TestUtils.getTestDirectory();
private final ContactId contactId = new ContactId(0); private final ContactId contactId = new ContactId(234);
@Before @Before
public void setUp() { public void setUp() {

View File

@@ -23,7 +23,7 @@ import org.junit.Test;
public class SimpleSocketPluginTest extends BriarTestCase { public class SimpleSocketPluginTest extends BriarTestCase {
private final ContactId contactId = new ContactId(0); private final ContactId contactId = new ContactId(234);
@Test @Test
public void testIncomingConnection() throws Exception { public void testIncomingConnection() throws Exception {

View File

@@ -22,7 +22,7 @@ import org.junit.Test;
public class TorPluginTest extends BriarTestCase { public class TorPluginTest extends BriarTestCase {
private final ContactId contactId = new ContactId(1); private final ContactId contactId = new ContactId(234);
@Test @Test
public void testHiddenService() throws Exception { public void testHiddenService() throws Exception {

View File

@@ -35,7 +35,6 @@ import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch; import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
@@ -158,8 +157,7 @@ public class ConstantsTest extends BriarTestCase {
Collection<Transport> transports = new ArrayList<Transport>(); Collection<Transport> transports = new ArrayList<Transport>();
for(int i = 0; i < MAX_TRANSPORTS; i++) { for(int i = 0; i < MAX_TRANSPORTS; i++) {
TransportId id = new TransportId(TestUtils.getRandomId()); TransportId id = new TransportId(TestUtils.getRandomId());
TransportIndex index = new TransportIndex(i); Transport t = new Transport(id);
Transport t = new Transport(id, index);
for(int j = 0; j < MAX_PROPERTIES_PER_TRANSPORT; j++) { for(int j = 0; j < MAX_PROPERTIES_PER_TRANSPORT; j++) {
String key = createRandomString(MAX_PROPERTY_LENGTH); String key = createRandomString(MAX_PROPERTY_LENGTH);
String value = createRandomString(MAX_PROPERTY_LENGTH); String value = createRandomString(MAX_PROPERTY_LENGTH);

View File

@@ -28,7 +28,6 @@ import net.sf.briar.api.protocol.Request;
import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
@@ -71,8 +70,7 @@ public class ProtocolReadWriteTest extends BriarTestCase {
bitSet.set(7); bitSet.set(7);
subscriptions = Collections.singletonMap(group, 123L); subscriptions = Collections.singletonMap(group, 123L);
TransportId transportId = new TransportId(TestUtils.getRandomId()); TransportId transportId = new TransportId(TestUtils.getRandomId());
TransportIndex transportIndex = new TransportIndex(13); Transport transport = new Transport(transportId,
Transport transport = new Transport(transportId, transportIndex,
Collections.singletonMap("bar", "baz")); Collections.singletonMap("bar", "baz"));
transports = Collections.singletonList(transport); transports = Collections.singletonList(transport);
} }

View File

@@ -1,7 +1,9 @@
package net.sf.briar.protocol.simplex; package net.sf.briar.protocol.simplex;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MIN_CONNECTION_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH; import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
@@ -12,20 +14,19 @@ import java.util.concurrent.Executors;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.TestUtils; import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor; import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.RawBatch; import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRegistry; import net.sf.briar.api.transport.ConnectionRegistry;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.TransportConstants;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.ProtocolModule;
import net.sf.briar.protocol.duplex.DuplexProtocolModule; import net.sf.briar.protocol.duplex.DuplexProtocolModule;
@@ -43,24 +44,31 @@ import com.google.inject.Module;
public class OutgoingSimplexConnectionTest extends BriarTestCase { public class OutgoingSimplexConnectionTest extends BriarTestCase {
// FIXME: This is an integration test, not a unit test
private final Mockery context; private final Mockery context;
private final DatabaseComponent db; private final DatabaseComponent db;
private final KeyManager keyManager;
private final ConnectionRecogniser connRecogniser;
private final ConnectionRegistry connRegistry; private final ConnectionRegistry connRegistry;
private final ConnectionWriterFactory connFactory; private final ConnectionWriterFactory connFactory;
private final ProtocolWriterFactory protoFactory; private final ProtocolWriterFactory protoFactory;
private final ContactId contactId; private final ContactId contactId;
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex transportIndex;
private final byte[] secret; private final byte[] secret;
public OutgoingSimplexConnectionTest() { public OutgoingSimplexConnectionTest() {
super(); super();
context = new Mockery(); context = new Mockery();
db = context.mock(DatabaseComponent.class); db = context.mock(DatabaseComponent.class);
keyManager = context.mock(KeyManager.class);
connRecogniser = context.mock(ConnectionRecogniser.class);
Module testModule = new AbstractModule() { Module testModule = new AbstractModule() {
@Override @Override
public void configure() { public void configure() {
bind(DatabaseComponent.class).toInstance(db); bind(DatabaseComponent.class).toInstance(db);
bind(KeyManager.class).toInstance(keyManager);
bind(ConnectionRecogniser.class).toInstance(connRecogniser);
bind(Executor.class).annotatedWith( bind(Executor.class).annotatedWith(
DatabaseExecutor.class).toInstance( DatabaseExecutor.class).toInstance(
Executors.newCachedThreadPool()); Executors.newCachedThreadPool());
@@ -73,9 +81,8 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
connRegistry = i.getInstance(ConnectionRegistry.class); connRegistry = i.getInstance(ConnectionRegistry.class);
connFactory = i.getInstance(ConnectionWriterFactory.class); connFactory = i.getInstance(ConnectionWriterFactory.class);
protoFactory = i.getInstance(ProtocolWriterFactory.class); protoFactory = i.getInstance(ProtocolWriterFactory.class);
contactId = new ContactId(1); contactId = new ContactId(234);
transportId = new TransportId(TestUtils.getRandomId()); transportId = new TransportId(TestUtils.getRandomId());
transportIndex = new TransportIndex(13);
secret = new byte[32]; secret = new byte[32];
} }
@@ -83,40 +90,31 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
public void testConnectionTooShort() throws Exception { public void testConnectionTooShort() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
TestSimplexTransportWriter transport = new TestSimplexTransportWriter( TestSimplexTransportWriter transport = new TestSimplexTransportWriter(
out, ProtocolConstants.MAX_PACKET_LENGTH, true); out, MAX_PACKET_LENGTH, true);
byte[] tag = new byte[TAG_LENGTH];
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
tag, secret, 0L, true);
OutgoingSimplexConnection connection = new OutgoingSimplexConnection(db, OutgoingSimplexConnection connection = new OutgoingSimplexConnection(db,
connRegistry, connFactory, protoFactory, contactId, transportId, connRegistry, connFactory, protoFactory, ctx, transport);
transportIndex, transport);
final ConnectionContext ctx = context.mock(ConnectionContext.class);
context.checking(new Expectations() {{
oneOf(db).getConnectionContext(contactId, transportIndex);
will(returnValue(ctx));
oneOf(ctx).getSecret();
will(returnValue(secret));
}});
connection.write(); connection.write();
// Nothing should have been written // Nothing should have been written
assertEquals(0, out.size()); assertEquals(0, out.size());
// The transport should have been disposed with exception == true // The transport should have been disposed with exception == true
assertTrue(transport.getDisposed()); assertTrue(transport.getDisposed());
assertTrue(transport.getException()); assertTrue(transport.getException());
context.assertIsSatisfied();
} }
@Test @Test
public void testNothingToSend() throws Exception { public void testNothingToSend() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
TestSimplexTransportWriter transport = new TestSimplexTransportWriter( TestSimplexTransportWriter transport = new TestSimplexTransportWriter(
out, TransportConstants.MIN_CONNECTION_LENGTH, true); out, MIN_CONNECTION_LENGTH, true);
byte[] tag = new byte[TAG_LENGTH];
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
tag, secret, 0L, true);
OutgoingSimplexConnection connection = new OutgoingSimplexConnection(db, OutgoingSimplexConnection connection = new OutgoingSimplexConnection(db,
connRegistry, connFactory, protoFactory, contactId, transportId, connRegistry, connFactory, protoFactory, ctx, transport);
transportIndex, transport);
final ConnectionContext ctx = context.mock(ConnectionContext.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(db).getConnectionContext(contactId, transportIndex);
will(returnValue(ctx));
oneOf(ctx).getSecret();
will(returnValue(secret));
// No transports to send // No transports to send
oneOf(db).generateTransportUpdate(contactId); oneOf(db).generateTransportUpdate(contactId);
will(returnValue(null)); will(returnValue(null));
@@ -143,20 +141,17 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
public void testSomethingToSend() throws Exception { public void testSomethingToSend() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
TestSimplexTransportWriter transport = new TestSimplexTransportWriter( TestSimplexTransportWriter transport = new TestSimplexTransportWriter(
out, TransportConstants.MIN_CONNECTION_LENGTH, true); out, MIN_CONNECTION_LENGTH, true);
byte[] tag = new byte[TAG_LENGTH];
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
tag, secret, 0L, true);
OutgoingSimplexConnection connection = new OutgoingSimplexConnection(db, OutgoingSimplexConnection connection = new OutgoingSimplexConnection(db,
connRegistry, connFactory, protoFactory, contactId, transportId, connRegistry, connFactory, protoFactory, ctx, transport);
transportIndex, transport);
final ConnectionContext ctx = context.mock(ConnectionContext.class);
final Ack ack = context.mock(Ack.class); final Ack ack = context.mock(Ack.class);
final BatchId batchId = new BatchId(TestUtils.getRandomId()); final BatchId batchId = new BatchId(TestUtils.getRandomId());
final RawBatch batch = context.mock(RawBatch.class); final RawBatch batch = context.mock(RawBatch.class);
final byte[] message = new byte[1234]; final byte[] message = new byte[1234];
context.checking(new Expectations() {{ context.checking(new Expectations() {{
oneOf(db).getConnectionContext(contactId, transportIndex);
will(returnValue(ctx));
oneOf(ctx).getSecret();
will(returnValue(secret));
// No transports to send // No transports to send
oneOf(db).generateTransportUpdate(contactId); oneOf(db).generateTransportUpdate(contactId);
will(returnValue(null)); will(returnValue(null));

View File

@@ -23,7 +23,6 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory;
import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.ProtocolWriterFactory;
import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionReaderFactory;
@@ -49,11 +48,12 @@ import com.google.inject.Injector;
public class SimplexConnectionReadWriteTest extends BriarTestCase { public class SimplexConnectionReadWriteTest extends BriarTestCase {
// FIXME: This is an integration test, not a unit test
private final File testDir = TestUtils.getTestDirectory(); private final File testDir = TestUtils.getTestDirectory();
private final File aliceDir = new File(testDir, "alice"); private final File aliceDir = new File(testDir, "alice");
private final File bobDir = new File(testDir, "bob"); private final File bobDir = new File(testDir, "bob");
private final TransportId transportId; private final TransportId transportId;
private final TransportIndex transportIndex;
private final byte[] aliceToBobSecret, bobToAliceSecret; private final byte[] aliceToBobSecret, bobToAliceSecret;
private Injector alice, bob; private Injector alice, bob;
@@ -61,7 +61,6 @@ public class SimplexConnectionReadWriteTest extends BriarTestCase {
public SimplexConnectionReadWriteTest() throws Exception { public SimplexConnectionReadWriteTest() throws Exception {
super(); super();
transportId = new TransportId(TestUtils.getRandomId()); transportId = new TransportId(TestUtils.getRandomId());
transportIndex = new TransportIndex(1);
// Create matching secrets for Alice and Bob // Create matching secrets for Alice and Bob
Random r = new Random(); Random r = new Random();
aliceToBobSecret = new byte[32]; aliceToBobSecret = new byte[32];
@@ -102,7 +101,7 @@ public class SimplexConnectionReadWriteTest extends BriarTestCase {
DatabaseComponent db = alice.getInstance(DatabaseComponent.class); DatabaseComponent db = alice.getInstance(DatabaseComponent.class);
db.open(false); db.open(false);
// Add Bob as a contact and send him a message // Add Bob as a contact and send him a message
ContactId contactId = db.addContact(bobToAliceSecret, aliceToBobSecret); ContactId contactId = db.addContact();
String subject = "Hello"; String subject = "Hello";
byte[] body = "Hi Bob!".getBytes("UTF-8"); byte[] body = "Hi Bob!".getBytes("UTF-8");
MessageFactory messageFactory = alice.getInstance(MessageFactory.class); MessageFactory messageFactory = alice.getInstance(MessageFactory.class);
@@ -111,16 +110,19 @@ public class SimplexConnectionReadWriteTest extends BriarTestCase {
// Create an outgoing batch connection // Create an outgoing batch connection
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
ConnectionRegistry connRegistry = ConnectionRegistry connRegistry =
alice.getInstance(ConnectionRegistry.class); alice.getInstance(ConnectionRegistry.class);
ConnectionWriterFactory connFactory = ConnectionWriterFactory connFactory =
alice.getInstance(ConnectionWriterFactory.class); alice.getInstance(ConnectionWriterFactory.class);
ProtocolWriterFactory protoFactory = ProtocolWriterFactory protoFactory =
alice.getInstance(ProtocolWriterFactory.class); alice.getInstance(ProtocolWriterFactory.class);
TestSimplexTransportWriter transport = new TestSimplexTransportWriter( TestSimplexTransportWriter transport = new TestSimplexTransportWriter(
out, Long.MAX_VALUE, false); out, Long.MAX_VALUE, false);
// FIXME: Encode the tag
byte[] tag = new byte[TAG_LENGTH];
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
tag, aliceToBobSecret, 0L, true);
OutgoingSimplexConnection simplex = new OutgoingSimplexConnection(db, OutgoingSimplexConnection simplex = new OutgoingSimplexConnection(db,
connRegistry, connFactory, protoFactory, contactId, transportId, connRegistry, connFactory, protoFactory, ctx, transport);
transportIndex, transport);
// Write whatever needs to be written // Write whatever needs to be written
simplex.write(); simplex.write();
assertTrue(transport.getDisposed()); assertTrue(transport.getDisposed());
@@ -139,14 +141,12 @@ public class SimplexConnectionReadWriteTest extends BriarTestCase {
MessageListener listener = new MessageListener(); MessageListener listener = new MessageListener();
db.addListener(listener); db.addListener(listener);
// Add Alice as a contact // Add Alice as a contact
ContactId contactId = db.addContact(aliceToBobSecret, bobToAliceSecret); ContactId contactId = db.addContact();
// Add the transport
assertEquals(transportIndex, db.addTransport(transportId));
// Fake a transport update from Alice // Fake a transport update from Alice
TransportUpdate transportUpdate = new TransportUpdate() { TransportUpdate transportUpdate = new TransportUpdate() {
public Collection<Transport> getTransports() { public Collection<Transport> getTransports() {
Transport t = new Transport(transportId, transportIndex); Transport t = new Transport(transportId);
return Collections.singletonList(t); return Collections.singletonList(t);
} }
@@ -164,19 +164,17 @@ public class SimplexConnectionReadWriteTest extends BriarTestCase {
ConnectionContext ctx = rec.acceptConnection(transportId, tag); ConnectionContext ctx = rec.acceptConnection(transportId, tag);
assertNotNull(ctx); assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId()); assertEquals(contactId, ctx.getContactId());
assertEquals(transportIndex, ctx.getTransportIndex());
// Create an incoming batch connection // Create an incoming batch connection
ConnectionRegistry connRegistry = ConnectionRegistry connRegistry =
bob.getInstance(ConnectionRegistry.class); bob.getInstance(ConnectionRegistry.class);
ConnectionReaderFactory connFactory = ConnectionReaderFactory connFactory =
bob.getInstance(ConnectionReaderFactory.class); bob.getInstance(ConnectionReaderFactory.class);
ProtocolReaderFactory protoFactory = ProtocolReaderFactory protoFactory =
bob.getInstance(ProtocolReaderFactory.class); bob.getInstance(ProtocolReaderFactory.class);
TestSimplexTransportReader transport = new TestSimplexTransportReader(in); TestSimplexTransportReader transport = new TestSimplexTransportReader(in);
IncomingSimplexConnection simplex = new IncomingSimplexConnection( IncomingSimplexConnection simplex = new IncomingSimplexConnection(
new ImmediateExecutor(), new ImmediateExecutor(), db, new ImmediateExecutor(), new ImmediateExecutor(), db,
connRegistry, connFactory, protoFactory, ctx, transportId, connRegistry, connFactory, protoFactory, ctx, transport);
transport);
// No messages should have been added yet // No messages should have been added yet
assertFalse(listener.messagesAdded); assertFalse(listener.messagesAdded);
// Read whatever needs to be read // Read whatever needs to be read

View File

@@ -1,624 +0,0 @@
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Executor;
import javax.crypto.Cipher;
import net.sf.briar.BriarTestCase;
import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.event.ContactRemovedEvent;
import net.sf.briar.api.db.event.RemoteTransportsUpdatedEvent;
import net.sf.briar.api.db.event.TransportAddedEvent;
import net.sf.briar.api.protocol.Transport;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.plugins.ImmediateExecutor;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class ConnectionRecogniserImplTest extends BriarTestCase {
private final CryptoComponent crypto;
private final ContactId contactId;
private final byte[] inSecret;
private final TransportId transportId;
private final TransportIndex localIndex, remoteIndex;
private final Collection<Transport> localTransports, remoteTransports;
public ConnectionRecogniserImplTest() {
super();
Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class);
contactId = new ContactId(1);
inSecret = new byte[32];
new Random().nextBytes(inSecret);
transportId = new TransportId(TestUtils.getRandomId());
localIndex = new TransportIndex(13);
remoteIndex = new TransportIndex(7);
Map<String, String> properties = Collections.singletonMap("foo", "bar");
Transport localTransport = new Transport(transportId, localIndex,
properties);
localTransports = Collections.singletonList(localTransport);
Transport remoteTransport = new Transport(transportId, remoteIndex,
properties);
remoteTransports = Collections.singletonList(remoteTransport);
}
@Test
public void testUnexpectedIv() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
assertNull(c.acceptConnection(transportId, new byte[TAG_LENGTH]));
context.assertIsSatisfied();
}
@Test
public void testExpectedIv() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
// Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).setConnectionWindow(contactId, remoteIndex, window);
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// The tag should not be expected by the wrong transport
TransportId wrong = new TransportId(TestUtils.getRandomId());
assertNull(c.acceptConnection(wrong, tag));
// The tag should be expected by the right transport
ConnectionContext ctx = c.acceptConnection(transportId, tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3, ctx.getConnectionNumber());
// The tag should no longer be expected
assertNull(c.acceptConnection(transportId, tag));
// The window should have advanced
Map<Long, byte[]> unseen = window.getUnseen();
assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) {
assertEquals(i != 3, unseen.containsKey(Long.valueOf(i)));
}
context.assertIsSatisfied();
}
@Test
public void testContactRemovedAfterInit() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise before removing contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Ensure the recogniser is initialised
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, new byte[TAG_LENGTH]));
assertTrue(c.isInitialised());
// Remove the contact
c.eventOccurred(new ContactRemovedEvent(contactId));
// The tag should not be expected
assertNull(c.acceptConnection(transportId, tag));
context.assertIsSatisfied();
}
@Test
public void testContactRemovedBeforeInit() throws Exception {
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise after removing contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.emptyList()));
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Remove the contact
c.eventOccurred(new ContactRemovedEvent(contactId));
// The tag should not be expected
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, tag));
assertTrue(c.isInitialised());
context.assertIsSatisfied();
}
@Test
public void testLocalTransportAddedAfterInit() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise before adding transport
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(Collections.emptyList()));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
// Add the transport
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
// Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).setConnectionWindow(contactId, remoteIndex, window);
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// The tag should not be expected
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, tag));
assertTrue(c.isInitialised());
// Add the transport
c.eventOccurred(new TransportAddedEvent(transportId));
// The tag should be expected
ConnectionContext ctx = c.acceptConnection(transportId, tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3, ctx.getConnectionNumber());
// The tag should no longer be expected
assertNull(c.acceptConnection(transportId, tag));
// The window should have advanced
Map<Long, byte[]> unseen = window.getUnseen();
assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) {
assertEquals(i != 3, unseen.containsKey(Long.valueOf(i)));
}
context.assertIsSatisfied();
}
@Test
public void testLocalTransportAddedBeforeInit() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise after adding transport
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
// Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).setConnectionWindow(contactId, remoteIndex, window);
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Add the transport
c.eventOccurred(new TransportAddedEvent(transportId));
// The tag should be expected
assertFalse(c.isInitialised());
ConnectionContext ctx = c.acceptConnection(transportId, tag);
assertTrue(c.isInitialised());
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3, ctx.getConnectionNumber());
// The tag should no longer be expected
assertNull(c.acceptConnection(transportId, tag));
// The window should have advanced
Map<Long, byte[]> unseen = window.getUnseen();
assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) {
assertEquals(i != 3, unseen.containsKey(Long.valueOf(i)));
}
context.assertIsSatisfied();
}
@Test
public void testRemoteTransportAddedAfterInit() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise before updating the contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(null));
// Update the contact
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
// Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).setConnectionWindow(contactId, remoteIndex, window);
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// The tag should not be expected
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, tag));
assertTrue(c.isInitialised());
// Update the contact
c.eventOccurred(new RemoteTransportsUpdatedEvent(contactId,
remoteTransports));
// The tag should be expected
ConnectionContext ctx = c.acceptConnection(transportId, tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3, ctx.getConnectionNumber());
// The tag should no longer be expected
assertNull(c.acceptConnection(transportId, tag));
// The window should have advanced
Map<Long, byte[]> unseen = window.getUnseen();
assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) {
assertEquals(i != 3, unseen.containsKey(Long.valueOf(i)));
}
context.assertIsSatisfied();
}
@Test
public void testRemoteTransportAddedBeforeInit() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise after updating the contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
// Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).setConnectionWindow(contactId, remoteIndex, window);
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Update the contact
c.eventOccurred(new RemoteTransportsUpdatedEvent(contactId,
remoteTransports));
// The tag should be expected
assertFalse(c.isInitialised());
ConnectionContext ctx = c.acceptConnection(transportId, tag);
assertTrue(c.isInitialised());
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3, ctx.getConnectionNumber());
// The tag should no longer be expected
assertNull(c.acceptConnection(transportId, tag));
// The window should have advanced
Map<Long, byte[]> unseen = window.getUnseen();
assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) {
assertEquals(i != 3, unseen.containsKey(Long.valueOf(i)));
}
context.assertIsSatisfied();
}
@Test
public void testRemoteTransportRemovedAfterInit() throws Exception {
final ConnectionWindow window = createConnectionWindow(remoteIndex);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise before updating the contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Ensure the recogniser is initialised
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, new byte[TAG_LENGTH]));
assertTrue(c.isInitialised());
// Update the contact
c.eventOccurred(new RemoteTransportsUpdatedEvent(contactId,
Collections.<Transport>emptyList()));
// The tag should not be expected
assertNull(c.acceptConnection(transportId, tag));
context.assertIsSatisfied();
}
@Test
public void testRemoteTransportRemovedBeforeInit() throws Exception {
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise after updating the contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(null));
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Update the contact
c.eventOccurred(new RemoteTransportsUpdatedEvent(contactId,
Collections.<Transport>emptyList()));
// The tag should not be expected
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, tag));
assertTrue(c.isInitialised());
context.assertIsSatisfied();
}
@Test
public void testRemoteTransportIndexChangedAfterInit() throws Exception {
// The contact changes the transport ID <-> index relationships
final TransportId transportId1 =
new TransportId(TestUtils.getRandomId());
final TransportIndex remoteIndex1 = new TransportIndex(11);
Map<String, String> properties = Collections.singletonMap("foo", "bar");
Transport remoteTransport = new Transport(transportId, remoteIndex1,
properties);
Transport remoteTransport1 = new Transport(transportId1, remoteIndex,
properties);
Collection<Transport> remoteTransports1 = Arrays.asList(
new Transport[] {remoteTransport, remoteTransport1});
// Use two local transports for this test
TransportIndex localIndex1 = new TransportIndex(17);
Transport localTransport = new Transport(transportId, localIndex,
properties);
Transport localTransport1 = new Transport(transportId1, localIndex1,
properties);
final Collection<Transport> localTransports1 = Arrays.asList(
new Transport[] {localTransport, localTransport1});
final ConnectionWindow window = createConnectionWindow(remoteIndex);
final ConnectionWindow window1 = createConnectionWindow(remoteIndex1);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise before updating the contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports1));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
// First, transportId <-> remoteIndex, transportId1 <-> remoteIndex
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).getRemoteIndex(contactId, transportId1);
will(returnValue(remoteIndex1));
oneOf(db).getConnectionWindow(contactId, remoteIndex1);
will(returnValue(window1));
// Later, transportId <-> remoteIndex1, transportId1 <-> remoteIndex
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).getConnectionWindow(contactId, remoteIndex1);
will(returnValue(window1));
// Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).setConnectionWindow(contactId, remoteIndex, window);
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Ensure the recogniser is initialised
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, new byte[TAG_LENGTH]));
assertTrue(c.isInitialised());
// Update the contact
c.eventOccurred(new RemoteTransportsUpdatedEvent(contactId,
remoteTransports1));
// The tag should not be expected by the old transport
assertNull(c.acceptConnection(transportId, tag));
// The tag should be expected by the new transport
ConnectionContext ctx = c.acceptConnection(transportId1, tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3, ctx.getConnectionNumber());
// The tag should no longer be expected
assertNull(c.acceptConnection(transportId1, tag));
// The window should have advanced
Map<Long, byte[]> unseen = window.getUnseen();
assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) {
assertEquals(i != 3, unseen.containsKey(Long.valueOf(i)));
}
context.assertIsSatisfied();
}
@Test
public void testRemoteTransportIndexChangedBeforeInit() throws Exception {
// The contact changes the transport ID <-> index relationships
final TransportId transportId1 =
new TransportId(TestUtils.getRandomId());
final TransportIndex remoteIndex1 = new TransportIndex(11);
Map<String, String> properties = Collections.singletonMap("foo", "bar");
Transport remoteTransport = new Transport(transportId, remoteIndex1,
properties);
Transport remoteTransport1 = new Transport(transportId1, remoteIndex,
properties);
Collection<Transport> remoteTransports1 = Arrays.asList(
new Transport[] {remoteTransport, remoteTransport1});
// Use two local transports for this test
TransportIndex localIndex1 = new TransportIndex(17);
Transport localTransport = new Transport(transportId, localIndex,
properties);
Transport localTransport1 = new Transport(transportId1, localIndex1,
properties);
final Collection<Transport> localTransports1 = Arrays.asList(
new Transport[] {localTransport, localTransport1});
final ConnectionWindow window = createConnectionWindow(remoteIndex);
final ConnectionWindow window1 = createConnectionWindow(remoteIndex1);
Mockery context = new Mockery();
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Initialise after updating the contact
oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class)));
oneOf(db).getLocalTransports();
will(returnValue(localTransports1));
oneOf(db).getContacts();
will(returnValue(Collections.singletonList(contactId)));
// First, transportId <-> remoteIndex1, transportId1 <-> remoteIndex
oneOf(db).getRemoteIndex(contactId, transportId);
will(returnValue(remoteIndex1));
oneOf(db).getConnectionWindow(contactId, remoteIndex1);
will(returnValue(window1));
oneOf(db).getRemoteIndex(contactId, transportId1);
will(returnValue(remoteIndex));
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
// Update the window
oneOf(db).getConnectionWindow(contactId, remoteIndex);
will(returnValue(window));
oneOf(db).setConnectionWindow(contactId, remoteIndex, window);
}});
Executor executor = new ImmediateExecutor();
ConnectionRecogniserImpl c = new ConnectionRecogniserImpl(executor, db,
crypto);
byte[] tag = calculateTag();
// Update the contact
c.eventOccurred(new RemoteTransportsUpdatedEvent(contactId,
remoteTransports1));
// The tag should not be expected by the old transport
assertFalse(c.isInitialised());
assertNull(c.acceptConnection(transportId, tag));
assertTrue(c.isInitialised());
// The tag should be expected by the new transport
ConnectionContext ctx = c.acceptConnection(transportId1, tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(remoteIndex, ctx.getTransportIndex());
assertEquals(3, ctx.getConnectionNumber());
// The tag should no longer be expected
assertNull(c.acceptConnection(transportId1, tag));
// The window should have advanced
Map<Long, byte[]> unseen = window.getUnseen();
assertEquals(19, unseen.size());
for(int i = 0; i < 19; i++) {
assertEquals(i != 3, unseen.containsKey(Long.valueOf(i)));
}
context.assertIsSatisfied();
}
private ConnectionWindow createConnectionWindow(TransportIndex index) {
return new ConnectionWindowImpl(crypto, index, inSecret) {
@Override
public void erase() {}
};
}
private byte[] calculateTag() throws Exception {
// Calculate the shared secret for connection number 3
byte[] secret = inSecret;
for(int i = 0; i < 4; i++) {
secret = crypto.deriveNextSecret(secret, remoteIndex.getInt(), i);
}
// Calculate the expected tag for connection number 3
ErasableKey tagKey = crypto.deriveTagKey(secret, true);
Cipher tagCipher = crypto.getTagCipher();
byte[] tag = new byte[TAG_LENGTH];
TagEncoder.encodeTag(tag, tagCipher, tagKey);
return tag;
}
}

View File

@@ -1,39 +1,19 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import java.util.HashMap; import java.util.HashSet;
import java.util.Map; import java.util.Set;
import java.util.Random;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.protocol.TransportIndex;
import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindow;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.util.ByteUtils; import net.sf.briar.util.ByteUtils;
import org.junit.Test; import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class ConnectionWindowImplTest extends BriarTestCase { public class ConnectionWindowImplTest extends BriarTestCase {
private final CryptoComponent crypto;
private final byte[] secret;
private final TransportIndex transportIndex = new TransportIndex(13);
public ConnectionWindowImplTest() {
super();
Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class);
secret = new byte[32];
new Random().nextBytes(secret);
}
@Test @Test
public void testWindowSliding() { public void testWindowSliding() {
ConnectionWindow w = new ConnectionWindowImpl(crypto, ConnectionWindow w = new ConnectionWindowImpl();
transportIndex, secret);
for(int i = 0; i < 100; i++) { for(int i = 0; i < 100; i++) {
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
w.setSeen(i); w.setSeen(i);
@@ -43,8 +23,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test @Test
public void testWindowJumping() { public void testWindowJumping() {
ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, ConnectionWindow w = new ConnectionWindowImpl();
secret);
for(int i = 0; i < 100; i += 13) { for(int i = 0; i < 100; i += 13) {
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
w.setSeen(i); w.setSeen(i);
@@ -54,8 +33,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test @Test
public void testWindowUpperLimit() { public void testWindowUpperLimit() {
ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, ConnectionWindow w = new ConnectionWindowImpl();
secret);
// Centre is 0, highest value in window is 15 // Centre is 0, highest value in window is 15
w.setSeen(15); w.setSeen(15);
// Centre is 16, highest value in window is 31 // Centre is 16, highest value in window is 31
@@ -66,11 +44,11 @@ public class ConnectionWindowImplTest extends BriarTestCase {
fail(); fail();
} catch(IllegalArgumentException expected) {} } catch(IllegalArgumentException expected) {}
// Values greater than 2^32 - 1 should never be allowed // Values greater than 2^32 - 1 should never be allowed
Map<Long, byte[]> unseen = new HashMap<Long, byte[]>(); Set<Long> unseen = new HashSet<Long>();
for(int i = 0; i < 32; i++) { for(int i = 0; i < 32; i++) {
unseen.put(ByteUtils.MAX_32_BIT_UNSIGNED - i, secret); unseen.add(ByteUtils.MAX_32_BIT_UNSIGNED - i);
} }
w = new ConnectionWindowImpl(crypto, transportIndex, unseen); w = new ConnectionWindowImpl(unseen);
w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED); w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED);
try { try {
w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED + 1); w.setSeen(ByteUtils.MAX_32_BIT_UNSIGNED + 1);
@@ -80,8 +58,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test @Test
public void testWindowLowerLimit() { public void testWindowLowerLimit() {
ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, ConnectionWindow w = new ConnectionWindowImpl();
secret);
// Centre is 0, negative values should never be allowed // Centre is 0, negative values should never be allowed
try { try {
w.setSeen(-1); w.setSeen(-1);
@@ -111,8 +88,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test @Test
public void testCannotSetSeenTwice() { public void testCannotSetSeenTwice() {
ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, ConnectionWindow w = new ConnectionWindowImpl();
secret);
w.setSeen(15); w.setSeen(15);
try { try {
w.setSeen(15); w.setSeen(15);
@@ -122,13 +98,12 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test @Test
public void testGetUnseenConnectionNumbers() { public void testGetUnseenConnectionNumbers() {
ConnectionWindow w = new ConnectionWindowImpl(crypto, transportIndex, ConnectionWindow w = new ConnectionWindowImpl();
secret);
// Centre is 0; window should cover 0 to 15, inclusive, with none seen // Centre is 0; window should cover 0 to 15, inclusive, with none seen
Map<Long, byte[]> unseen = w.getUnseen(); Set<Long> unseen = w.getUnseen();
assertEquals(16, unseen.size()); assertEquals(16, unseen.size());
for(int i = 0; i < 16; i++) { for(int i = 0; i < 16; i++) {
assertTrue(unseen.containsKey(Long.valueOf(i))); assertTrue(unseen.contains(Long.valueOf(i)));
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
} }
w.setSeen(3); w.setSeen(3);
@@ -138,10 +113,10 @@ public class ConnectionWindowImplTest extends BriarTestCase {
assertEquals(19, unseen.size()); assertEquals(19, unseen.size());
for(int i = 0; i < 21; i++) { for(int i = 0; i < 21; i++) {
if(i == 3 || i == 4) { if(i == 3 || i == 4) {
assertFalse(unseen.containsKey(Long.valueOf(i))); assertFalse(unseen.contains(Long.valueOf(i)));
assertTrue(w.isSeen(i)); assertTrue(w.isSeen(i));
} else { } else {
assertTrue(unseen.containsKey(Long.valueOf(i))); assertTrue(unseen.contains(Long.valueOf(i)));
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
} }
} }
@@ -151,10 +126,10 @@ public class ConnectionWindowImplTest extends BriarTestCase {
assertEquals(30, unseen.size()); assertEquals(30, unseen.size());
for(int i = 4; i < 36; i++) { for(int i = 4; i < 36; i++) {
if(i == 4 || i == 19) { if(i == 4 || i == 19) {
assertFalse(unseen.containsKey(Long.valueOf(i))); assertFalse(unseen.contains(Long.valueOf(i)));
assertTrue(w.isSeen(i)); assertTrue(w.isSeen(i));
} else { } else {
assertTrue(unseen.containsKey(Long.valueOf(i))); assertTrue(unseen.contains(Long.valueOf(i)));
assertFalse(w.isSeen(i)); assertFalse(w.isSeen(i));
} }
} }

View File

@@ -2,12 +2,17 @@ package net.sf.briar.transport;
import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MIN_CONNECTION_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MIN_CONNECTION_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.Random; import java.util.Random;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.TestDatabaseModule; import net.sf.briar.TestDatabaseModule;
import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWriter; import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.clock.ClockModule; import net.sf.briar.clock.ClockModule;
@@ -27,6 +32,8 @@ import com.google.inject.Injector;
public class ConnectionWriterTest extends BriarTestCase { public class ConnectionWriterTest extends BriarTestCase {
private final ConnectionWriterFactory connectionWriterFactory; private final ConnectionWriterFactory connectionWriterFactory;
private final ContactId contactId;
private final TransportId transportId;
private final byte[] secret; private final byte[] secret;
public ConnectionWriterTest() throws Exception { public ConnectionWriterTest() throws Exception {
@@ -37,6 +44,8 @@ public class ConnectionWriterTest extends BriarTestCase {
new TestDatabaseModule(), new SimplexProtocolModule(), new TestDatabaseModule(), new SimplexProtocolModule(),
new TransportModule(), new DuplexProtocolModule()); new TransportModule(), new DuplexProtocolModule());
connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
contactId = new ContactId(234);
transportId = new TransportId(TestUtils.getRandomId());
secret = new byte[32]; secret = new byte[32];
new Random().nextBytes(secret); new Random().nextBytes(secret);
} }
@@ -44,9 +53,12 @@ public class ConnectionWriterTest extends BriarTestCase {
@Test @Test
public void testOverheadWithTag() throws Exception { public void testOverheadWithTag() throws Exception {
ByteArrayOutputStream out = ByteArrayOutputStream out =
new ByteArrayOutputStream(MIN_CONNECTION_LENGTH); new ByteArrayOutputStream(MIN_CONNECTION_LENGTH);
byte[] tag = new byte[TAG_LENGTH];
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
tag, secret, 0L, true);
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
MIN_CONNECTION_LENGTH, secret, true); MIN_CONNECTION_LENGTH, ctx, true);
// Check that the connection writer thinks there's room for a packet // Check that the connection writer thinks there's room for a packet
long capacity = w.getRemainingCapacity(); long capacity = w.getRemainingCapacity();
assertTrue(capacity > MAX_PACKET_LENGTH); assertTrue(capacity > MAX_PACKET_LENGTH);
@@ -63,9 +75,11 @@ public class ConnectionWriterTest extends BriarTestCase {
@Test @Test
public void testOverheadWithoutTag() throws Exception { public void testOverheadWithoutTag() throws Exception {
ByteArrayOutputStream out = ByteArrayOutputStream out =
new ByteArrayOutputStream(MIN_CONNECTION_LENGTH); new ByteArrayOutputStream(MIN_CONNECTION_LENGTH);
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
null, secret, 0L, true);
ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
MIN_CONNECTION_LENGTH, secret, false); MIN_CONNECTION_LENGTH, ctx, false);
// Check that the connection writer thinks there's room for a packet // Check that the connection writer thinks there's room for a packet
long capacity = w.getRemainingCapacity(); long capacity = w.getRemainingCapacity();
assertTrue(capacity > MAX_PACKET_LENGTH); assertTrue(capacity > MAX_PACKET_LENGTH);

View File

@@ -1,6 +1,5 @@
package net.sf.briar.transport; package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
@@ -9,8 +8,6 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.Random; import java.util.Random;
import javax.crypto.Cipher;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.api.crypto.AuthenticatedCipher; import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
@@ -29,24 +26,21 @@ public class FrameReadWriteTest extends BriarTestCase {
private final int FRAME_LENGTH = 2048; private final int FRAME_LENGTH = 2048;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final Cipher tagCipher;
private final AuthenticatedCipher frameCipher; private final AuthenticatedCipher frameCipher;
private final Random random; private final Random random;
private final byte[] outSecret; private final byte[] outSecret;
private final ErasableKey tagKey, frameKey; private final ErasableKey frameKey;
public FrameReadWriteTest() { public FrameReadWriteTest() {
super(); super();
Injector i = Guice.createInjector(new CryptoModule()); Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
tagCipher = crypto.getTagCipher();
frameCipher = crypto.getFrameCipher(); frameCipher = crypto.getFrameCipher();
random = new Random(); random = new Random();
// Since we're sending frames to ourselves, we only need outgoing keys // Since we're sending frames to ourselves, we only need outgoing keys
outSecret = new byte[32]; outSecret = new byte[32];
random.nextBytes(outSecret); random.nextBytes(outSecret);
tagKey = crypto.deriveTagKey(outSecret, true); frameKey = crypto.deriveFrameKey(outSecret, 0L, true, true);
frameKey = crypto.deriveFrameKey(outSecret, true);
} }
@Test @Test
@@ -60,22 +54,17 @@ public class FrameReadWriteTest extends BriarTestCase {
} }
private void testWriteAndRead(boolean initiator) throws Exception { private void testWriteAndRead(boolean initiator) throws Exception {
// Encode the tag
byte[] tag = new byte[TAG_LENGTH];
TagEncoder.encodeTag(tag, tagCipher, tagKey);
// Generate two random frames // Generate two random frames
byte[] frame = new byte[1234]; byte[] frame = new byte[1234];
random.nextBytes(frame); random.nextBytes(frame);
byte[] frame1 = new byte[321]; byte[] frame1 = new byte[321];
random.nextBytes(frame1); random.nextBytes(frame1);
// Copy the keys - the copies will be erased // Copy the frame key - the copy will be erased
ErasableKey tagCopy = tagKey.copy();
ErasableKey frameCopy = frameKey.copy(); ErasableKey frameCopy = frameKey.copy();
// Write the frames // Write the frames
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
FrameWriter encryptionOut = new OutgoingEncryptionLayer(out, FrameWriter encryptionOut = new OutgoingEncryptionLayer(out,
Long.MAX_VALUE, tagCipher, frameCipher, tagCopy, frameCopy, Long.MAX_VALUE, frameCipher, frameCopy, FRAME_LENGTH);
FRAME_LENGTH);
ConnectionWriter writer = new ConnectionWriterImpl(encryptionOut, ConnectionWriter writer = new ConnectionWriterImpl(encryptionOut,
FRAME_LENGTH); FRAME_LENGTH);
OutputStream out1 = writer.getOutputStream(); OutputStream out1 = writer.getOutputStream();
@@ -84,11 +73,11 @@ public class FrameReadWriteTest extends BriarTestCase {
out1.write(frame1); out1.write(frame1);
out1.flush(); out1.flush();
byte[] output = out.toByteArray(); byte[] output = out.toByteArray();
assertEquals(TAG_LENGTH + FRAME_LENGTH * 2, output.length); assertEquals(FRAME_LENGTH * 2, output.length);
// Read the tag and the frames back // Read the tag and the frames back
ByteArrayInputStream in = new ByteArrayInputStream(output); ByteArrayInputStream in = new ByteArrayInputStream(output);
FrameReader encryptionIn = new IncomingEncryptionLayer(in, tagCipher, FrameReader encryptionIn = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
ConnectionReader reader = new ConnectionReaderImpl(encryptionIn, ConnectionReader reader = new ConnectionReaderImpl(encryptionIn,
FRAME_LENGTH); FRAME_LENGTH);
InputStream in1 = reader.getInputStream(); InputStream in1 = reader.getInputStream();

View File

@@ -5,12 +5,8 @@ import static net.sf.briar.api.transport.TransportConstants.AAD_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH; import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.EOFException;
import javax.crypto.Cipher;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
@@ -19,7 +15,6 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.crypto.CryptoModule; import net.sf.briar.crypto.CryptoModule;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import com.google.inject.Guice; import com.google.inject.Guice;
@@ -32,99 +27,46 @@ public class IncomingEncryptionLayerTest extends BriarTestCase {
FRAME_LENGTH - HEADER_LENGTH - MAC_LENGTH; FRAME_LENGTH - HEADER_LENGTH - MAC_LENGTH;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final Cipher tagCipher;
private final AuthenticatedCipher frameCipher; private final AuthenticatedCipher frameCipher;
private final ErasableKey frameKey;
private ErasableKey tagKey = null, frameKey = null;
public IncomingEncryptionLayerTest() { public IncomingEncryptionLayerTest() {
super(); super();
Injector i = Guice.createInjector(new CryptoModule()); Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
tagCipher = crypto.getTagCipher();
frameCipher = crypto.getFrameCipher(); frameCipher = crypto.getFrameCipher();
}
@Before
public void setUp() {
tagKey = crypto.generateTestKey();
frameKey = crypto.generateTestKey(); frameKey = crypto.generateTestKey();
} }
@Test @Test
public void testReadValidTagAndFrames() throws Exception { public void testReadValidFrames() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate two valid frames // Generate two valid frames
byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, false); byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, false);
byte[] frame1 = generateFrame(1L, FRAME_LENGTH, 123, false, false); byte[] frame1 = generateFrame(1L, FRAME_LENGTH, 123, false, false);
// Concatenate the tag and the frames // Concatenate the frames
byte[] valid = new byte[TAG_LENGTH + FRAME_LENGTH * 2]; byte[] valid = new byte[FRAME_LENGTH * 2];
System.arraycopy(tag, 0, valid, 0, TAG_LENGTH); System.arraycopy(frame, 0, valid, 0, FRAME_LENGTH);
System.arraycopy(frame, 0, valid, TAG_LENGTH, FRAME_LENGTH); System.arraycopy(frame1, 0, valid, FRAME_LENGTH, FRAME_LENGTH);
System.arraycopy(frame1, 0, valid, TAG_LENGTH + FRAME_LENGTH, // Read the frames
FRAME_LENGTH);
// Read the frames, which should first read the tag
ByteArrayInputStream in = new ByteArrayInputStream(valid); ByteArrayInputStream in = new ByteArrayInputStream(valid);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher, IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
byte[] buf = new byte[FRAME_LENGTH - MAC_LENGTH]; byte[] buf = new byte[FRAME_LENGTH - MAC_LENGTH];
assertEquals(123, i.readFrame(buf)); assertEquals(123, i.readFrame(buf));
assertEquals(123, i.readFrame(buf)); assertEquals(123, i.readFrame(buf));
} }
@Test
public void testTruncatedTagThrowsException() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Chop off the last byte
byte[] truncated = new byte[TAG_LENGTH - 1];
System.arraycopy(tag, 0, truncated, 0, TAG_LENGTH - 1);
// Try to read the frame, which should first try to read the tag
ByteArrayInputStream in = new ByteArrayInputStream(truncated);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher,
frameCipher, tagKey, crypto.generateTestKey(), FRAME_LENGTH);
try {
i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
fail();
} catch(EOFException expected) {}
}
@Test @Test
public void testTruncatedFrameThrowsException() throws Exception { public void testTruncatedFrameThrowsException() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate a valid frame // Generate a valid frame
byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, false); byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, false);
// Chop off the last byte // Chop off the last byte
byte[] truncated = new byte[TAG_LENGTH + FRAME_LENGTH - 1]; byte[] truncated = new byte[FRAME_LENGTH - 1];
System.arraycopy(tag, 0, truncated, 0, TAG_LENGTH); System.arraycopy(frame, 0, truncated, 0, FRAME_LENGTH - 1);
System.arraycopy(frame, 0, truncated, TAG_LENGTH, FRAME_LENGTH - 1);
// Try to read the frame, which should fail due to truncation // Try to read the frame, which should fail due to truncation
ByteArrayInputStream in = new ByteArrayInputStream(truncated); ByteArrayInputStream in = new ByteArrayInputStream(truncated);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher, IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
try {
i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
fail();
} catch(FormatException expected) {}
}
@Test
public void testModifiedTagThrowsException() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate a valid frame
byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, false);
// Modify a randomly chosen byte of the tag
byte[] modified = new byte[TAG_LENGTH + FRAME_LENGTH];
System.arraycopy(tag, 0, modified, 0, TAG_LENGTH);
System.arraycopy(frame, 0, modified, TAG_LENGTH, FRAME_LENGTH);
modified[(int) (Math.random() * TAG_LENGTH)] ^= 1;
// Try to read the frame, which should fail due to modification
ByteArrayInputStream in = new ByteArrayInputStream(modified);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH);
try { try {
i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]); i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
fail(); fail();
@@ -133,19 +75,14 @@ public class IncomingEncryptionLayerTest extends BriarTestCase {
@Test @Test
public void testModifiedFrameThrowsException() throws Exception { public void testModifiedFrameThrowsException() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate a valid frame // Generate a valid frame
byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, false); byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, false);
// Modify a randomly chosen byte of the frame // Modify a randomly chosen byte of the frame
byte[] modified = new byte[TAG_LENGTH + FRAME_LENGTH]; frame[(int) (Math.random() * FRAME_LENGTH)] ^= 1;
System.arraycopy(tag, 0, modified, 0, TAG_LENGTH);
System.arraycopy(frame, 0, modified, TAG_LENGTH, FRAME_LENGTH);
modified[TAG_LENGTH + (int) (Math.random() * FRAME_LENGTH)] ^= 1;
// Try to read the frame, which should fail due to modification // Try to read the frame, which should fail due to modification
ByteArrayInputStream in = new ByteArrayInputStream(modified); ByteArrayInputStream in = new ByteArrayInputStream(frame);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher, IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
try { try {
i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]); i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
fail(); fail();
@@ -154,18 +91,12 @@ public class IncomingEncryptionLayerTest extends BriarTestCase {
@Test @Test
public void testShortNonFinalFrameThrowsException() throws Exception { public void testShortNonFinalFrameThrowsException() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate a short non-final frame // Generate a short non-final frame
byte[] frame = generateFrame(0L, FRAME_LENGTH - 1, 123, false, false); byte[] frame = generateFrame(0L, FRAME_LENGTH - 1, 123, false, false);
// Concatenate the tag and the frame
byte[] tooShort = new byte[TAG_LENGTH + FRAME_LENGTH - 1];
System.arraycopy(tag, 0, tooShort, 0, TAG_LENGTH);
System.arraycopy(frame, 0, tooShort, TAG_LENGTH, FRAME_LENGTH - 1);
// Try to read the frame, which should fail due to invalid length // Try to read the frame, which should fail due to invalid length
ByteArrayInputStream in = new ByteArrayInputStream(tooShort); ByteArrayInputStream in = new ByteArrayInputStream(frame);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher, IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
try { try {
i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]); i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
fail(); fail();
@@ -174,37 +105,25 @@ public class IncomingEncryptionLayerTest extends BriarTestCase {
@Test @Test
public void testShortFinalFrameDoesNotThrowException() throws Exception { public void testShortFinalFrameDoesNotThrowException() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate a short final frame // Generate a short final frame
byte[] frame = generateFrame(0L, FRAME_LENGTH - 1, 123, true, false); byte[] frame = generateFrame(0L, FRAME_LENGTH - 1, 123, true, false);
// Concatenate the tag and the frame // Read the frame
byte[] valid = new byte[TAG_LENGTH + FRAME_LENGTH - 1]; ByteArrayInputStream in = new ByteArrayInputStream(frame);
System.arraycopy(tag, 0, valid, 0, TAG_LENGTH); IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
System.arraycopy(frame, 0, valid, TAG_LENGTH, FRAME_LENGTH - 1); frameKey, FRAME_LENGTH);
// Read the frame, which should first read the tag
ByteArrayInputStream in = new ByteArrayInputStream(valid);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH);
int length = i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]); int length = i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
assertEquals(123, length); assertEquals(123, length);
} }
@Test @Test
public void testInvalidPayloadLengthThrowsException() throws Exception { public void testInvalidPayloadLengthThrowsException() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate a frame with an invalid payload length // Generate a frame with an invalid payload length
byte[] frame = generateFrame(0L, FRAME_LENGTH, MAX_PAYLOAD_LENGTH + 1, byte[] frame = generateFrame(0L, FRAME_LENGTH, MAX_PAYLOAD_LENGTH + 1,
false, false); false, false);
// Concatenate the tag and the frame
byte[] tooLong = new byte[TAG_LENGTH + FRAME_LENGTH];
System.arraycopy(tag, 0, tooLong, 0, TAG_LENGTH);
System.arraycopy(frame, 0, tooLong, TAG_LENGTH, FRAME_LENGTH);
// Try to read the frame, which should fail due to invalid length // Try to read the frame, which should fail due to invalid length
ByteArrayInputStream in = new ByteArrayInputStream(tooLong); ByteArrayInputStream in = new ByteArrayInputStream(frame);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher, IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
try { try {
i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]); i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
fail(); fail();
@@ -213,18 +132,12 @@ public class IncomingEncryptionLayerTest extends BriarTestCase {
@Test @Test
public void testNonZeroPaddingThrowsException() throws Exception { public void testNonZeroPaddingThrowsException() throws Exception {
// Generate a valid tag // Generate a frame with bad padding
byte[] tag = generateTag(tagKey);
// Generate a frame with pad padding
byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, true); byte[] frame = generateFrame(0L, FRAME_LENGTH, 123, false, true);
// Concatenate the tag and the frame
byte[] badPadding = new byte[TAG_LENGTH + FRAME_LENGTH];
System.arraycopy(tag, 0, badPadding, 0, TAG_LENGTH);
System.arraycopy(frame, 0, badPadding, TAG_LENGTH, FRAME_LENGTH);
// Try to read the frame, which should fail due to bad padding // Try to read the frame, which should fail due to bad padding
ByteArrayInputStream in = new ByteArrayInputStream(badPadding); ByteArrayInputStream in = new ByteArrayInputStream(frame);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher, IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
try { try {
i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]); i.readFrame(new byte[FRAME_LENGTH - MAC_LENGTH]);
fail(); fail();
@@ -233,34 +146,24 @@ public class IncomingEncryptionLayerTest extends BriarTestCase {
@Test @Test
public void testCannotReadBeyondFinalFrame() throws Exception { public void testCannotReadBeyondFinalFrame() throws Exception {
// Generate a valid tag
byte[] tag = generateTag(tagKey);
// Generate a valid final frame and another valid final frame after it // Generate a valid final frame and another valid final frame after it
byte[] frame = generateFrame(0L, FRAME_LENGTH, MAX_PAYLOAD_LENGTH, true, byte[] frame = generateFrame(0L, FRAME_LENGTH, MAX_PAYLOAD_LENGTH, true,
false); false);
byte[] frame1 = generateFrame(1L, FRAME_LENGTH, 123, true, false); byte[] frame1 = generateFrame(1L, FRAME_LENGTH, 123, true, false);
// Concatenate the tag and the frames // Concatenate the frames
byte[] extraFrame = new byte[TAG_LENGTH + FRAME_LENGTH * 2]; byte[] extraFrame = new byte[FRAME_LENGTH * 2];
System.arraycopy(tag, 0, extraFrame, 0, TAG_LENGTH); System.arraycopy(frame, 0, extraFrame, 0, FRAME_LENGTH);
System.arraycopy(frame, 0, extraFrame, TAG_LENGTH, FRAME_LENGTH); System.arraycopy(frame1, 0, extraFrame, FRAME_LENGTH, FRAME_LENGTH);
System.arraycopy(frame1, 0, extraFrame, TAG_LENGTH + FRAME_LENGTH,
FRAME_LENGTH);
// Read the final frame, which should first read the tag // Read the final frame, which should first read the tag
ByteArrayInputStream in = new ByteArrayInputStream(extraFrame); ByteArrayInputStream in = new ByteArrayInputStream(extraFrame);
IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, tagCipher, IncomingEncryptionLayer i = new IncomingEncryptionLayer(in, frameCipher,
frameCipher, tagKey, frameKey, FRAME_LENGTH); frameKey, FRAME_LENGTH);
byte[] buf = new byte[FRAME_LENGTH - MAC_LENGTH]; byte[] buf = new byte[FRAME_LENGTH - MAC_LENGTH];
assertEquals(MAX_PAYLOAD_LENGTH, i.readFrame(buf)); assertEquals(MAX_PAYLOAD_LENGTH, i.readFrame(buf));
// The frame after the final frame should not be read // The frame after the final frame should not be read
assertEquals(-1, i.readFrame(buf)); assertEquals(-1, i.readFrame(buf));
} }
private byte[] generateTag(ErasableKey tagKey) {
byte[] tag = new byte[TAG_LENGTH];
TagEncoder.encodeTag(tag, tagCipher, tagKey);
return tag;
}
private byte[] generateFrame(long frameNumber, int frameLength, private byte[] generateFrame(long frameNumber, int frameLength,
int payloadLength, boolean finalFrame, boolean badPadding) int payloadLength, boolean finalFrame, boolean badPadding)
throws Exception { throws Exception {

View File

@@ -9,8 +9,6 @@ import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import javax.crypto.Cipher;
import net.sf.briar.BriarTestCase; import net.sf.briar.BriarTestCase;
import net.sf.briar.api.crypto.AuthenticatedCipher; import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.CryptoComponent;
@@ -29,28 +27,24 @@ public class OutgoingEncryptionLayerTest extends BriarTestCase {
FRAME_LENGTH - HEADER_LENGTH - MAC_LENGTH; FRAME_LENGTH - HEADER_LENGTH - MAC_LENGTH;
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final Cipher tagCipher;
private final AuthenticatedCipher frameCipher; private final AuthenticatedCipher frameCipher;
private final byte[] tag;
public OutgoingEncryptionLayerTest() { public OutgoingEncryptionLayerTest() {
super(); super();
Injector i = Guice.createInjector(new CryptoModule()); Injector i = Guice.createInjector(new CryptoModule());
crypto = i.getInstance(CryptoComponent.class); crypto = i.getInstance(CryptoComponent.class);
tagCipher = crypto.getTagCipher();
frameCipher = crypto.getFrameCipher(); frameCipher = crypto.getFrameCipher();
tag = new byte[TAG_LENGTH];
} }
@Test @Test
public void testEncryption() throws Exception { public void testEncryption() throws Exception {
int payloadLength = 123; int payloadLength = 123;
byte[] tag = new byte[TAG_LENGTH];
byte[] iv = new byte[IV_LENGTH], aad = new byte[AAD_LENGTH]; byte[] iv = new byte[IV_LENGTH], aad = new byte[AAD_LENGTH];
byte[] plaintext = new byte[FRAME_LENGTH - MAC_LENGTH]; byte[] plaintext = new byte[FRAME_LENGTH - MAC_LENGTH];
byte[] ciphertext = new byte[FRAME_LENGTH]; byte[] ciphertext = new byte[FRAME_LENGTH];
ErasableKey tagKey = crypto.generateTestKey();
ErasableKey frameKey = crypto.generateTestKey(); ErasableKey frameKey = crypto.generateTestKey();
// Calculate the expected tag
TagEncoder.encodeTag(tag, tagCipher, tagKey);
// Calculate the expected ciphertext // Calculate the expected ciphertext
FrameEncoder.encodeIv(iv, 0); FrameEncoder.encodeIv(iv, 0);
FrameEncoder.encodeAad(aad, 0, plaintext.length); FrameEncoder.encodeAad(aad, 0, plaintext.length);
@@ -60,14 +54,11 @@ public class OutgoingEncryptionLayerTest extends BriarTestCase {
// Check that the actual tag and ciphertext match what's expected // Check that the actual tag and ciphertext match what's expected
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out, OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out,
10 * FRAME_LENGTH, tagCipher, frameCipher, tagKey, frameKey, 10 * FRAME_LENGTH, frameCipher, frameKey, FRAME_LENGTH, tag);
FRAME_LENGTH);
o.writeFrame(new byte[FRAME_LENGTH - MAC_LENGTH], payloadLength, false); o.writeFrame(new byte[FRAME_LENGTH - MAC_LENGTH], payloadLength, false);
byte[] actual = out.toByteArray(); byte[] actual = out.toByteArray();
assertEquals(TAG_LENGTH + FRAME_LENGTH, actual.length); assertEquals(TAG_LENGTH + FRAME_LENGTH, actual.length);
for(int i = 0; i < TAG_LENGTH; i++) { for(int i = 0; i < TAG_LENGTH; i++) assertEquals(tag[i], actual[i]);
assertEquals(tag[i], actual[i]);
}
for(int i = 0; i < FRAME_LENGTH; i++) { for(int i = 0; i < FRAME_LENGTH; i++) {
assertEquals("" + i, ciphertext[i], actual[TAG_LENGTH + i]); assertEquals("" + i, ciphertext[i], actual[TAG_LENGTH + i]);
} }
@@ -78,9 +69,8 @@ public class OutgoingEncryptionLayerTest extends BriarTestCase {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
// Initiator's constructor // Initiator's constructor
OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out, OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out,
10 * FRAME_LENGTH, tagCipher, frameCipher, 10 * FRAME_LENGTH, frameCipher, crypto.generateTestKey(),
crypto.generateTestKey(), crypto.generateTestKey(), FRAME_LENGTH, tag);
FRAME_LENGTH);
// Write an empty final frame without having written any other frames // Write an empty final frame without having written any other frames
o.writeFrame(new byte[FRAME_LENGTH - MAC_LENGTH], 0, true); o.writeFrame(new byte[FRAME_LENGTH - MAC_LENGTH], 0, true);
// Nothing should be written to the output stream // Nothing should be written to the output stream
@@ -106,9 +96,8 @@ public class OutgoingEncryptionLayerTest extends BriarTestCase {
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
// Initiator's constructor // Initiator's constructor
OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out, OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out,
10 * FRAME_LENGTH, tagCipher, frameCipher, 10 * FRAME_LENGTH, frameCipher, crypto.generateTestKey(),
crypto.generateTestKey(), crypto.generateTestKey(), FRAME_LENGTH, tag);
FRAME_LENGTH);
// There should be space for nine full frames and one partial frame // There should be space for nine full frames and one partial frame
byte[] frame = new byte[FRAME_LENGTH - MAC_LENGTH]; byte[] frame = new byte[FRAME_LENGTH - MAC_LENGTH];
assertEquals(10 * MAX_PAYLOAD_LENGTH - TAG_LENGTH, assertEquals(10 * MAX_PAYLOAD_LENGTH - TAG_LENGTH,