Merge branch '55-key-manager-refactoring' into 'master'

Refactor KeyManager and TagRecogniser, implement new key rotation logic.

This patch implements the new key rotation logic for BTP version 2, the new transport protocol (#111).

KeyManager and TagRecogniser have been merged (#55). They no longer need to hold locks while calling each other's methods (#3, #4). TransportKeyManager holds a lock while calling CryptoComponent methods, but those methods don't block or acquire any locks.

The maximum clock difference has been increased from one hour to 24 hours because some people adjust the time rather than the timezone when travelling (#18). This will cause keys to be rotated less frequently.

For the same reason, the key manager no longer throws an Error when the clock moves backwards - keys that belong to future rotation periods are kept until they become current, then rotation resumes (#19).

The new KeyManagerImpl and TransportKeyManager need unit tests. I'm putting this up for review while I write the tests.

TransportKeyManager needs further work before #55 is complete - when a stream context is created, TransportKeyManager should wait for the database write to complete before returning the context, to avoid the risk of key reuse if the write fails. I'll make a separate patch for that as this one's big enough already.

Enjoy!

See merge request !13
This commit is contained in:
akwizgran
2015-12-15 13:18:30 +00:00
59 changed files with 2118 additions and 3840 deletions

View File

@@ -32,7 +32,6 @@ import org.briarproject.api.event.Event;
import org.briarproject.api.event.EventBus;
import org.briarproject.api.event.EventListener;
import org.briarproject.api.event.SettingsUpdatedEvent;
import org.briarproject.api.lifecycle.Service;
import org.briarproject.api.messaging.GroupId;
import org.briarproject.util.StringUtils;
@@ -45,7 +44,7 @@ import android.support.v4.app.NotificationCompat;
import android.support.v4.app.TaskStackBuilder;
class AndroidNotificationManagerImpl implements AndroidNotificationManager,
Service, EventListener {
EventListener {
private static final int PRIVATE_MESSAGE_NOTIFICATION_ID = 3;
private static final int GROUP_POST_NOTIFICATION_ID = 4;

View File

@@ -1,13 +1,14 @@
package org.briarproject.api.android;
import org.briarproject.api.ContactId;
import org.briarproject.api.lifecycle.Service;
import org.briarproject.api.messaging.GroupId;
/**
* Manages notifications for private messages and group posts. All methods must
* be called from the Android UI thread.
*/
public interface AndroidNotificationManager {
public interface AndroidNotificationManager extends Service {
public void showPrivateMessageNotification(ContactId c);

View File

@@ -1,5 +1,8 @@
package org.briarproject.api.crypto;
import org.briarproject.api.TransportId;
import org.briarproject.api.transport.TransportKeys;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
@@ -26,56 +29,50 @@ public interface CryptoComponent {
/** Generates a random invitation code. */
int generateInvitationCode();
/**
* Derives two confirmation codes from the given master secret. The first
* code is for Alice to give to Bob; the second is for Bob to give to
* Alice.
*/
int[] deriveConfirmationCodes(byte[] secret);
/**
* Derives two nonces from the given master secret. The first nonce is for
* Alice to sign; the second is for Bob to sign.
*/
byte[][] deriveInvitationNonces(byte[] secret);
/**
* Derives a shared master secret from two public keys and one of the
* corresponding private keys.
* @param alice indicates whether the private key belongs to Alice or Bob.
* @param alice whether the private key belongs to Alice or Bob.
*/
byte[] deriveMasterSecret(byte[] theirPublicKey, KeyPair ourKeyPair,
SecretKey deriveMasterSecret(byte[] theirPublicKey, KeyPair ourKeyPair,
boolean alice) throws GeneralSecurityException;
/** Derives a group salt from the given master secret. */
byte[] deriveGroupSalt(byte[] secret);
/**
* Derives a confirmation code from the given master secret.
* @param alice whether the code is for use by Alice or Bob.
*/
int deriveConfirmationCode(SecretKey master, boolean alice);
/**
* Derives an initial secret for the given transport from the given master
* Derives a header key for an invitation stream from the given master
* secret.
* @param alice whether the key is for use by Alice or Bob.
*/
byte[] deriveInitialSecret(byte[] secret, int transportIndex);
SecretKey deriveInvitationKey(SecretKey master, boolean alice);
/**
* Derives a temporary secret for the given period from the given secret,
* which is either the initial shared secret or the previous period's
* temporary secret.
* Derives a nonce from the given master secret for one of the parties to
* sign.
* @param alice whether the nonce is for use by Alice or Bob.
*/
byte[] deriveNextSecret(byte[] secret, long period);
byte[] deriveSignatureNonce(SecretKey master, boolean alice);
/** Derives a group salt from the given master secret. */
byte[] deriveGroupSalt(SecretKey master);
/**
* Derives a tag key from the given temporary secret.
* @param alice indicates whether the key is for streams initiated by
* Alice or Bob.
* Derives initial transport keys for the given transport in the given
* rotation period from the given master secret.
* @param alice whether the keys are for use by Alice or Bob.
*/
SecretKey deriveTagKey(byte[] secret, boolean alice);
TransportKeys deriveTransportKeys(TransportId t, SecretKey master,
long rotationPeriod, boolean alice);
/**
* Derives a frame key from the given temporary secret and stream number.
* @param alice indicates whether the key is for a stream initiated by
* Alice or Bob.
* Rotates the given transport keys to the given rotation period. If the
* keys are for a future rotation period they are not rotated.
*/
SecretKey deriveFrameKey(byte[] secret, long streamNumber, boolean alice);
TransportKeys rotateTransportKeys(TransportKeys k, long rotationPeriod);
/** Encodes the pseudo-random tag that is used to recognise a stream. */
void encodeTag(byte[] tag, SecretKey tagKey, long streamNumber);

View File

@@ -1,21 +0,0 @@
package org.briarproject.api.crypto;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.lifecycle.Service;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.StreamContext;
public interface KeyManager extends Service {
/**
* Returns a {@link org.briarproject.api.transport.StreamContext
* StreamContext} for sending data to the given contact over the given
* transport, or null if an error occurs or the contact does not support
* the transport.
*/
StreamContext getStreamContext(ContactId c, TransportId t);
/** Called whenever an endpoint has been added. */
void endpointAdded(Endpoint ep, int maxLatency, byte[] initialSecret);
}

View File

@@ -13,5 +13,5 @@ public interface StreamDecrypterFactory {
* Creates a {@link StreamDecrypter} for decrypting an invitation stream.
*/
StreamDecrypter createInvitationStreamDecrypter(InputStream in,
byte[] secret, boolean alice);
SecretKey headerKey);
}

View File

@@ -13,5 +13,5 @@ public interface StreamEncrypterFactory {
* Creates a {@link StreamEncrypter} for encrypting an invitation stream.
*/
StreamEncrypter createInvitationStreamEncrypter(OutputStream out,
byte[] secret, boolean alice);
SecretKey headerKey);
}

View File

@@ -1,9 +1,5 @@
package org.briarproject.api.db;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorId;
import org.briarproject.api.Contact;
@@ -26,8 +22,11 @@ import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.api.transport.TransportKeys;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
/**
* Encapsulates the database implementation and exposes high-level operations
@@ -47,9 +46,6 @@ public interface DatabaseComponent {
*/
ContactId addContact(Author remote, AuthorId local) throws DbException;
/** Stores an endpoint. */
void addEndpoint(Endpoint ep) throws DbException;
/**
* Subscribes to a group, or returns false if the user already has the
* maximum number of public subscriptions.
@@ -62,18 +58,17 @@ public interface DatabaseComponent {
/** Stores a local message. */
void addLocalMessage(Message m) throws DbException;
/**
* Stores the given temporary secrets and deletes any secrets that have
* been made obsolete.
*/
void addSecrets(Collection<TemporarySecret> secrets) throws DbException;
/**
* Stores a transport and returns true if the transport was not previously
* in the database.
*/
boolean addTransport(TransportId t, int maxLatency) throws DbException;
/**
* Stores the given transport keys for a newly added contact.
*/
void addTransportKeys(ContactId c, TransportKeys k) throws DbException;
/**
* Returns an acknowledgement for the given contact, or null if there are
* no messages to acknowledge.
@@ -214,16 +209,17 @@ public interface DatabaseComponent {
Map<ContactId, TransportProperties> getRemoteProperties(TransportId t)
throws DbException;
/** Returns all temporary secrets. */
Collection<TemporarySecret> getSecrets() throws DbException;
/** Returns all settings. */
Settings getSettings() throws DbException;
/** Returns all contacts who subscribe to the given group. */
Collection<Contact> getSubscribers(GroupId g) throws DbException;
/** Returns the maximum latencies of all supported transports. */
/** Returns all transport keys for the given transport. */
Map<ContactId, TransportKeys> getTransportKeys(TransportId t)
throws DbException;
/** Returns the maximum latencies in milliseconds of all transports. */
Map<TransportId, Integer> getTransportLatencies() throws DbException;
/** Returns the number of unread messages in each subscribed group. */
@@ -233,11 +229,10 @@ public interface DatabaseComponent {
Collection<ContactId> getVisibility(GroupId g) throws DbException;
/**
* Increments the outgoing stream counter for the given endpoint in the
* given rotation period and returns the old value, or -1 if the counter
* does not exist.
* Increments the outgoing stream counter for the given contact and
* transport in the given rotation period .
*/
long incrementStreamCounter(ContactId c, TransportId t, long period)
void incrementStreamCounter(ContactId c, TransportId t, long rotationPeriod)
throws DbException;
/**
@@ -310,18 +305,11 @@ public interface DatabaseComponent {
*/
void removeTransport(TransportId t) throws DbException;
/**
* Sets the reordering window for the given endpoint in the given rotation
* period.
*/
void setReorderingWindow(ContactId c, TransportId t, long period,
long centre, byte[] bitmap) throws DbException;
/**
* Makes a group visible to the given contact, adds it to the contact's
* subscriptions, and sets it as the inbox group for the contact.
*/
public void setInboxGroup(ContactId c, Group g) throws DbException;
void setInboxGroup(ContactId c, Group g) throws DbException;
/**
* Marks a message as read or unread.
@@ -335,6 +323,13 @@ public interface DatabaseComponent {
void setRemoteProperties(ContactId c,
Map<TransportId, TransportProperties> p) throws DbException;
/**
* Sets the reordering window for the given contact and transport in the
* given rotation period.
*/
void setReorderingWindow(ContactId c, TransportId t, long rotationPeriod,
long base, byte[] bitmap) throws DbException;
/**
* Makes a group visible to the given set of contacts and invisible to any
* other current or future contacts.
@@ -347,4 +342,10 @@ public interface DatabaseComponent {
* to future contacts.
*/
void setVisibleToAll(GroupId g, boolean all) throws DbException;
/**
* Stores the given transport keys, deleting any keys they have replaced.
*/
void updateTransportKeys(Map<ContactId, TransportKeys> keys)
throws DbException;
}

View File

@@ -28,7 +28,7 @@ public interface Message {
/** Returns the message's content type. */
String getContentType();
/** Returns the message's timestamp. */
/** Returns the message's timestamp in milliseconds since the Unix epoch. */
long getTimestamp();
/** Returns the serialised message. */

View File

@@ -1,36 +0,0 @@
package org.briarproject.api.transport;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
public class Endpoint {
protected final ContactId contactId;
protected final TransportId transportId;
private final long epoch;
private final boolean alice;
public Endpoint(ContactId contactId, TransportId transportId, long epoch,
boolean alice) {
this.contactId = contactId;
this.transportId = transportId;
this.epoch = epoch;
this.alice = alice;
}
public ContactId getContactId() {
return contactId;
}
public TransportId getTransportId() {
return transportId;
}
public long getEpoch() {
return epoch;
}
public boolean getAlice() {
return alice;
}
}

View File

@@ -0,0 +1,51 @@
package org.briarproject.api.transport;
import org.briarproject.api.crypto.SecretKey;
import static org.briarproject.api.transport.TransportConstants.REORDERING_WINDOW_SIZE;
/**
* Contains transport keys for receiving streams from a given contact over a
* given transport in a given rotation period.
*/
public class IncomingKeys {
private final SecretKey tagKey, headerKey;
private final long rotationPeriod, windowBase;
private final byte[] windowBitmap;
public IncomingKeys(SecretKey tagKey, SecretKey headerKey,
long rotationPeriod) {
this(tagKey, headerKey, rotationPeriod, 0,
new byte[REORDERING_WINDOW_SIZE / 8]);
}
public IncomingKeys(SecretKey tagKey, SecretKey headerKey,
long rotationPeriod, long windowBase, byte[] windowBitmap) {
this.tagKey = tagKey;
this.headerKey = headerKey;
this.rotationPeriod = rotationPeriod;
this.windowBase = windowBase;
this.windowBitmap = windowBitmap;
}
public SecretKey getTagKey() {
return tagKey;
}
public SecretKey getHeaderKey() {
return headerKey;
}
public long getRotationPeriod() {
return rotationPeriod;
}
public long getWindowBase() {
return windowBase;
}
public byte[] getWindowBitmap() {
return windowBitmap;
}
}

View File

@@ -0,0 +1,36 @@
package org.briarproject.api.transport;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.db.DbException;
import org.briarproject.api.lifecycle.Service;
import java.util.Collection;
/**
* Responsible for managing transport keys and recognising the pseudo-random
* tags of incoming streams.
*/
public interface KeyManager extends Service {
/**
* Informs the key manager that a new contact has been added.
* {@link StreamContext StreamContexts} for the contact can be created
* after this method has returned.
*/
void contactAdded(ContactId c, Collection<TransportKeys> keys);
/**
* Returns a {@link StreamContext} for sending a stream to the given
* contact over the given transport, or null if an error occurs or the
* contact does not support the transport.
*/
StreamContext getStreamContext(ContactId c, TransportId t);
/**
* Looks up the given tag and returns a {@link StreamContext} for reading
* from the corresponding stream if the tag was expected, or null if the
* tag was unexpected.
*/
StreamContext recogniseTag(TransportId t, byte[] tag) throws DbException;
}

View File

@@ -0,0 +1,42 @@
package org.briarproject.api.transport;
import org.briarproject.api.crypto.SecretKey;
/**
* Contains transport keys for sending streams to a given contact over a given
* transport in a given rotation period.
*/
public class OutgoingKeys {
private final SecretKey tagKey, headerKey;
private final long rotationPeriod, streamCounter;
public OutgoingKeys(SecretKey tagKey, SecretKey headerKey,
long rotationPeriod) {
this(tagKey, headerKey, rotationPeriod, 0);
}
public OutgoingKeys(SecretKey tagKey, SecretKey headerKey,
long rotationPeriod, long streamCounter) {
this.tagKey = tagKey;
this.headerKey = headerKey;
this.rotationPeriod = rotationPeriod;
this.streamCounter = streamCounter;
}
public SecretKey getTagKey() {
return tagKey;
}
public SecretKey getHeaderKey() {
return headerKey;
}
public long getRotationPeriod() {
return rotationPeriod;
}
public long getStreamCounter() {
return streamCounter;
}
}

View File

@@ -2,22 +2,22 @@ package org.briarproject.api.transport;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.SecretKey;
public class StreamContext {
private final ContactId contactId;
private final TransportId transportId;
private final byte[] secret;
private final SecretKey tagKey, headerKey;
private final long streamNumber;
private final boolean alice;
public StreamContext(ContactId contactId, TransportId transportId,
byte[] secret, long streamNumber, boolean alice) {
SecretKey tagKey, SecretKey headerKey, long streamNumber) {
this.contactId = contactId;
this.transportId = transportId;
this.secret = secret;
this.tagKey = tagKey;
this.headerKey = headerKey;
this.streamNumber = streamNumber;
this.alice = alice;
}
public ContactId getContactId() {
@@ -28,15 +28,15 @@ public class StreamContext {
return transportId;
}
public byte[] getSecret() {
return secret;
public SecretKey getTagKey() {
return tagKey;
}
public SecretKey getHeaderKey() {
return headerKey;
}
public long getStreamNumber() {
return streamNumber;
}
public boolean getAlice() {
return alice;
}
}

View File

@@ -2,6 +2,8 @@ package org.briarproject.api.transport;
import java.io.InputStream;
import org.briarproject.api.crypto.SecretKey;
public interface StreamReaderFactory {
/**
@@ -15,5 +17,5 @@ public interface StreamReaderFactory {
* invitation stream.
*/
InputStream createInvitationStreamReader(InputStream in,
byte[] secret, boolean alice);
SecretKey headerKey);
}

View File

@@ -2,6 +2,8 @@ package org.briarproject.api.transport;
import java.io.OutputStream;
import org.briarproject.api.crypto.SecretKey;
public interface StreamWriterFactory {
/**
@@ -15,5 +17,5 @@ public interface StreamWriterFactory {
* invitation stream.
*/
OutputStream createInvitationStreamWriter(OutputStream out,
byte[] secret, boolean alice);
SecretKey headerKey);
}

View File

@@ -1,26 +0,0 @@
package org.briarproject.api.transport;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.db.DbException;
/** Keeps track of expected tags and uses them to recognise incoming streams. */
public interface TagRecogniser {
/**
* Looks up the given tag and returns a {@link StreamContext} for reading
* from the stream if the tag was expected, or null if the tag was
* unexpected.
*/
StreamContext recogniseTag(TransportId t, byte[] tag) throws DbException;
void addSecret(TemporarySecret s);
void removeSecret(ContactId c, TransportId t, long period);
void removeSecrets(ContactId c);
void removeSecrets(TransportId t);
void removeSecrets();
}

View File

@@ -1,73 +0,0 @@
package org.briarproject.api.transport;
import static org.briarproject.api.transport.TransportConstants.REORDERING_WINDOW_SIZE;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
public class TemporarySecret extends Endpoint {
private final long period, outgoing, centre;
private final byte[] secret, bitmap;
/** Creates a temporary secret with the given reordering window. */
public TemporarySecret(ContactId contactId, TransportId transportId,
long epoch, boolean alice, long period, byte[] secret,
long outgoing, long centre, byte[] bitmap) {
super(contactId, transportId, epoch, alice);
this.period = period;
this.secret = secret;
this.outgoing = outgoing;
this.centre = centre;
this.bitmap = bitmap;
}
/** Creates a temporary secret with a new reordering window. */
public TemporarySecret(ContactId contactId, TransportId transportId,
long epoch, boolean alice, long period, byte[] secret) {
this(contactId, transportId, epoch, alice, period, secret, 0, 0,
new byte[REORDERING_WINDOW_SIZE / 8]);
}
/** Creates a temporary secret derived from the given endpoint. */
public TemporarySecret(Endpoint ep, long period, byte[] secret) {
this(ep.getContactId(), ep.getTransportId(), ep.getEpoch(),
ep.getAlice(), period, secret);
}
public long getPeriod() {
return period;
}
public byte[] getSecret() {
return secret;
}
public long getOutgoingStreamCounter() {
return outgoing;
}
public long getWindowCentre() {
return centre;
}
public byte[] getWindowBitmap() {
return bitmap;
}
@Override
public int hashCode() {
int periodHashCode = (int) (period ^ (period >>> 32));
return contactId.hashCode() ^ transportId.hashCode() ^ periodHashCode;
}
@Override
public boolean equals(Object o) {
if (o instanceof TemporarySecret) {
TemporarySecret s = (TemporarySecret) o;
return contactId.equals(s.contactId) &&
transportId.equals(s.transportId) && period == s.period;
}
return false;
}
}

View File

@@ -28,8 +28,8 @@ public interface TransportConstants {
*/
int MIN_STREAM_LENGTH = 64 * 1024; // 64 KiB
/** The maximum difference between two communicating devices' clocks. */
int MAX_CLOCK_DIFFERENCE = 60 * 60 * 1000; // 1 hour
/** The maximum difference in milliseconds between two peers' clocks. */
int MAX_CLOCK_DIFFERENCE = 24 * 60 * 60 * 1000; // 24 hours
/** The size of the reordering window. */
int REORDERING_WINDOW_SIZE = 32;

View File

@@ -0,0 +1,44 @@
package org.briarproject.api.transport;
import org.briarproject.api.TransportId;
/** Keys for communicating with a given contact over a given transport. */
public class TransportKeys {
private final TransportId transportId;
private final IncomingKeys inPrev, inCurr, inNext;
private final OutgoingKeys outCurr;
public TransportKeys(TransportId transportId, IncomingKeys inPrev,
IncomingKeys inCurr, IncomingKeys inNext, OutgoingKeys outCurr) {
this.transportId = transportId;
this.inPrev = inPrev;
this.inCurr = inCurr;
this.inNext = inNext;
this.outCurr = outCurr;
}
public TransportId getTransportId() {
return transportId;
}
public IncomingKeys getPreviousIncomingKeys() {
return inPrev;
}
public IncomingKeys getCurrentIncomingKeys() {
return inCurr;
}
public IncomingKeys getNextIncomingKeys() {
return inNext;
}
public OutgoingKeys getCurrentOutgoingKeys() {
return outCurr;
}
public long getRotationPeriod() {
return outCurr.getRotationPeriod();
}
}

View File

@@ -1,20 +1,6 @@
package org.briarproject.crypto;
import static java.util.logging.Level.INFO;
import static org.briarproject.api.invitation.InvitationConstants.CODE_BITS;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.crypto.EllipticCurveConstants.PARAMETERS;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyPair;
import org.briarproject.api.crypto.KeyParser;
@@ -25,6 +11,9 @@ import org.briarproject.api.crypto.PublicKey;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.Signature;
import org.briarproject.api.system.SeedProvider;
import org.briarproject.api.transport.IncomingKeys;
import org.briarproject.api.transport.OutgoingKeys;
import org.briarproject.api.transport.TransportKeys;
import org.briarproject.util.ByteUtils;
import org.briarproject.util.StringUtils;
import org.spongycastle.crypto.AsymmetricCipherKeyPair;
@@ -43,6 +32,21 @@ import org.spongycastle.crypto.params.ECPrivateKeyParameters;
import org.spongycastle.crypto.params.ECPublicKeyParameters;
import org.spongycastle.crypto.params.KeyParameter;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
import javax.inject.Inject;
import static java.util.logging.Level.INFO;
import static org.briarproject.api.invitation.InvitationConstants.CODE_BITS;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.crypto.EllipticCurveConstants.PARAMETERS;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
class CryptoComponentImpl implements CryptoComponent {
private static final Logger LOG =
@@ -55,22 +59,33 @@ class CryptoComponentImpl implements CryptoComponent {
private static final int PBKDF_TARGET_MILLIS = 500;
private static final int PBKDF_SAMPLES = 30;
// Labels for secret derivation
private static final byte[] MASTER = { 'M', 'A', 'S', 'T', 'E', 'R', '\0' };
private static final byte[] SALT = { 'S', 'A', 'L', 'T', '\0' };
private static final byte[] FIRST = { 'F', 'I', 'R', 'S', 'T', '\0' };
private static final byte[] ROTATE = { 'R', 'O', 'T', 'A', 'T', 'E', '\0' };
// Label for confirmation code derivation
private static final byte[] CODE = { 'C', 'O', 'D', 'E', '\0' };
// Label for invitation nonce derivation
private static final byte[] NONCE = { 'N', 'O', 'N', 'C', 'E', '\0' };
// Labels for key derivation
private static final byte[] A_TAG = { 'A', '_', 'T', 'A', 'G', '\0' };
private static final byte[] B_TAG = { 'B', '_', 'T', 'A', 'G', '\0' };
private static final byte[] A_FRAME =
{ 'A', '_', 'F', 'R', 'A', 'M', 'E', '\0' };
private static final byte[] B_FRAME =
{ 'B', '_', 'F', 'R', 'A', 'M', 'E', '\0' };
// KDF label for master key derivation
private static final byte[] MASTER = { 'M', 'A', 'S', 'T', 'E', 'R' };
// KDF labels for confirmation code derivation
private static final byte[] A_CONFIRM =
{ 'A', '_', 'C', 'O', 'N', 'F', 'I', 'R', 'M' };
private static final byte[] B_CONFIRM =
{ 'B', '_', 'C', 'O', 'N', 'F', 'I', 'R', 'M' };
// KDF labels for invitation stream header key derivation
private static final byte[] A_INVITE =
{ 'A', '_', 'I', 'N', 'V', 'I', 'T', 'E' };
private static final byte[] B_INVITE =
{ 'B', '_', 'I', 'N', 'V', 'I', 'T', 'E' };
// KDF labels for signature nonce derivation
private static final byte[] A_NONCE = { 'A', '_', 'N', 'O', 'N', 'C', 'E' };
private static final byte[] B_NONCE = { 'B', '_', 'N', 'O', 'N', 'C', 'E' };
// KDF label for group salt derivation
private static final byte[] SALT = { 'S', 'A', 'L', 'T' };
// KDF labels for tag key derivation
private static final byte[] A_TAG = { 'A', '_', 'T', 'A', 'G' };
private static final byte[] B_TAG = { 'B', '_', 'T', 'A', 'G' };
// KDF labels for header key derivation
private static final byte[] A_HEADER =
{ 'A', '_', 'H', 'E', 'A', 'D', 'E', 'R' };
private static final byte[] B_HEADER =
{ 'B', '_', 'H', 'E', 'A', 'D', 'E', 'R' };
// KDF label for key rotation
private static final byte[] ROTATE = { 'R', 'O', 'T', 'A', 'T', 'E' };
private final SecureRandom secureRandom;
private final ECKeyPairGenerator agreementKeyPairGenerator;
@@ -167,26 +182,7 @@ class CryptoComponentImpl implements CryptoComponent {
return ByteUtils.readUint(random, CODE_BITS);
}
public int[] deriveConfirmationCodes(byte[] secret) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
byte[] alice = counterModeKdf(secret, CODE, 0);
byte[] bob = counterModeKdf(secret, CODE, 1);
int[] codes = new int[2];
codes[0] = ByteUtils.readUint(alice, CODE_BITS);
codes[1] = ByteUtils.readUint(bob, CODE_BITS);
return codes;
}
public byte[][] deriveInvitationNonces(byte[] secret) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
byte[] alice = counterModeKdf(secret, NONCE, 0);
byte[] bob = counterModeKdf(secret, NONCE, 1);
return new byte[][] { alice, bob };
}
public byte[] deriveMasterSecret(byte[] theirPublicKey,
public SecretKey deriveMasterSecret(byte[] theirPublicKey,
KeyPair ourKeyPair, boolean alice) throws GeneralSecurityException {
MessageDigest messageDigest = getMessageDigest();
byte[] ourPublicKey = ourKeyPair.getPublic().getEncoded();
@@ -204,9 +200,8 @@ class CryptoComponentImpl implements CryptoComponent {
PublicKey theirPub = agreementKeyParser.parsePublicKey(theirPublicKey);
// The raw secret comes from the key agreement algorithm
byte[] raw = deriveSharedSecret(ourPriv, theirPub);
// Derive the cooked secret from the raw secret using the
// concatenation KDF
return concatenationKdf(raw, MASTER, aliceInfo, bobInfo);
// Derive the master secret from the raw secret using the hash KDF
return new SecretKey(hashKdf(raw, MASTER, aliceInfo, bobInfo));
}
// Package access for testing
@@ -228,46 +223,90 @@ class CryptoComponentImpl implements CryptoComponent {
return secret;
}
public byte[] deriveGroupSalt(byte[] secret) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
return counterModeKdf(secret, SALT, 0);
public int deriveConfirmationCode(SecretKey master, boolean alice) {
byte[] b = macKdf(master, alice ? A_CONFIRM : B_CONFIRM);
return ByteUtils.readUint(b, CODE_BITS);
}
public byte[] deriveInitialSecret(byte[] secret, int transportIndex) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
if (transportIndex < 0) throw new IllegalArgumentException();
return counterModeKdf(secret, FIRST, transportIndex);
public SecretKey deriveInvitationKey(SecretKey master, boolean alice) {
return new SecretKey(macKdf(master, alice ? A_INVITE : B_INVITE));
}
public byte[] deriveNextSecret(byte[] secret, long period) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
if (period < 0 || period > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
return counterModeKdf(secret, ROTATE, period);
public byte[] deriveSignatureNonce(SecretKey master, boolean alice) {
return macKdf(master, alice ? A_NONCE : B_NONCE);
}
public SecretKey deriveTagKey(byte[] secret, boolean alice) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
if (alice) return deriveKey(secret, A_TAG, 0);
else return deriveKey(secret, B_TAG, 0);
public byte[] deriveGroupSalt(SecretKey master) {
return macKdf(master, SALT);
}
public SecretKey deriveFrameKey(byte[] secret, long streamNumber,
public TransportKeys deriveTransportKeys(TransportId t,
SecretKey master, long rotationPeriod, boolean alice) {
// Keys for the previous period are derived from the master secret
SecretKey inTagPrev = deriveTagKey(master, t, !alice);
SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice);
SecretKey outTagPrev = deriveTagKey(master, t, alice);
SecretKey outHeaderPrev = deriveHeaderKey(master, t, alice);
// Derive the keys for the current and next periods
SecretKey inTagCurr = rotateKey(inTagPrev, rotationPeriod);
SecretKey inHeaderCurr = rotateKey(inHeaderPrev, rotationPeriod);
SecretKey inTagNext = rotateKey(inTagCurr, rotationPeriod + 1);
SecretKey inHeaderNext = rotateKey(inHeaderCurr, rotationPeriod + 1);
SecretKey outTagCurr = rotateKey(outTagPrev, rotationPeriod);
SecretKey outHeaderCurr = rotateKey(outHeaderPrev, rotationPeriod);
// Initialise the reordering windows and stream counters
IncomingKeys inPrev = new IncomingKeys(inTagPrev, inHeaderPrev,
rotationPeriod - 1);
IncomingKeys inCurr = new IncomingKeys(inTagCurr, inHeaderCurr,
rotationPeriod);
IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext,
rotationPeriod + 1);
OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr,
rotationPeriod);
// Collect and return the keys
return new TransportKeys(t, inPrev, inCurr, inNext, outCurr);
}
public TransportKeys rotateTransportKeys(TransportKeys k,
long rotationPeriod) {
if (k.getRotationPeriod() >= rotationPeriod) return k;
IncomingKeys inPrev = k.getPreviousIncomingKeys();
IncomingKeys inCurr = k.getCurrentIncomingKeys();
IncomingKeys inNext = k.getNextIncomingKeys();
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
long startPeriod = outCurr.getRotationPeriod();
// Rotate the keys
for (long p = startPeriod + 1; p <= rotationPeriod; p++) {
inPrev = inCurr;
inCurr = inNext;
SecretKey inNextTag = rotateKey(inNext.getTagKey(), p + 1);
SecretKey inNextHeader = rotateKey(inNext.getHeaderKey(), p + 1);
inNext = new IncomingKeys(inNextTag, inNextHeader, p);
SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p);
SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p);
outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p);
}
// Collect and return the keys
return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext,
outCurr);
}
private SecretKey rotateKey(SecretKey k, long rotationPeriod) {
byte[] period = new byte[4];
ByteUtils.writeUint32(rotationPeriod, period, 0);
return new SecretKey(macKdf(k, ROTATE, period));
}
private SecretKey deriveTagKey(SecretKey master, TransportId t,
boolean alice) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
if (streamNumber < 0 || streamNumber > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if (alice) return deriveKey(secret, A_FRAME, streamNumber);
else return deriveKey(secret, B_FRAME, streamNumber);
byte[] id = StringUtils.toUtf8(t.getString());
return new SecretKey(macKdf(master, alice ? A_TAG : B_TAG, id));
}
private SecretKey deriveKey(byte[] secret, byte[] label, long context) {
return new SecretKey(counterModeKdf(secret, label, context));
private SecretKey deriveHeaderKey(SecretKey master, TransportId t,
boolean alice) {
byte[] id = StringUtils.toUtf8(t.getString());
return new SecretKey(macKdf(master, alice ? A_HEADER : B_HEADER, id));
}
public void encodeTag(byte[] tag, SecretKey tagKey, long streamNumber) {
@@ -277,7 +316,8 @@ class CryptoComponentImpl implements CryptoComponent {
for (int i = 0; i < TAG_LENGTH; i++) tag[i] = 0;
ByteUtils.writeUint32(streamNumber, tag, 0);
BlockCipher cipher = new AESLightEngine();
assert cipher.getBlockSize() == TAG_LENGTH;
if (cipher.getBlockSize() != TAG_LENGTH)
throw new IllegalStateException();
KeyParameter k = new KeyParameter(tagKey.getBytes());
cipher.init(true, k);
cipher.processBlock(tag, 0, tag, 0);
@@ -348,16 +388,16 @@ class CryptoComponentImpl implements CryptoComponent {
// Key derivation function based on a hash function - see NIST SP 800-56A,
// section 5.8
private byte[] concatenationKdf(byte[]... inputs) {
private byte[] hashKdf(byte[]... inputs) {
// The output of the hash function must be long enough to use as a key
MessageDigest messageDigest = getMessageDigest();
if (messageDigest.getDigestLength() < SecretKey.LENGTH)
throw new RuntimeException();
// Each input is length-prefixed - the length must fit in an
// unsigned 8-bit integer
throw new IllegalStateException();
// Calculate the hash over the concatenated length-prefixed inputs
byte[] length = new byte[4];
for (byte[] input : inputs) {
if (input.length > 255) throw new IllegalArgumentException();
messageDigest.update((byte) input.length);
ByteUtils.writeUint32(input.length, length, 0);
messageDigest.update(length);
messageDigest.update(input);
}
byte[] hash = messageDigest.digest();
@@ -368,28 +408,24 @@ class CryptoComponentImpl implements CryptoComponent {
return truncated;
}
// Key derivation function based on a PRF in counter mode - see
// Key derivation function based on a pseudo-random function - see
// NIST SP 800-108, section 5.1
private byte[] counterModeKdf(byte[] secret, byte[] label, long context) {
if (secret.length != SecretKey.LENGTH)
throw new IllegalArgumentException();
// The label must be null-terminated
if (label[label.length - 1] != '\0')
throw new IllegalArgumentException();
private byte[] macKdf(SecretKey key, byte[]... inputs) {
// Initialise the PRF
Mac prf = new HMac(new SHA256Digest());
KeyParameter k = new KeyParameter(secret);
prf.init(k);
int macLength = prf.getMacSize();
prf.init(new KeyParameter(key.getBytes()));
// The output of the PRF must be long enough to use as a key
if (macLength < SecretKey.LENGTH) throw new RuntimeException();
int macLength = prf.getMacSize();
if (macLength < SecretKey.LENGTH)
throw new IllegalStateException();
// Calculate the PRF over the concatenated length-prefixed inputs
byte[] length = new byte[4];
for (byte[] input : inputs) {
ByteUtils.writeUint32(input.length, length, 0);
prf.update(length, 0, length.length);
prf.update(input, 0, input.length);
}
byte[] mac = new byte[macLength];
prf.update((byte) 0); // Counter
prf.update(label, 0, label.length); // Null-terminated
byte[] contextBytes = new byte[4];
ByteUtils.writeUint32(context, contextBytes, 0);
prf.update(contextBytes, 0, contextBytes.length);
prf.update((byte) SecretKey.LENGTH); // Output length
prf.doFinal(mac, 0);
// The output is the first SecretKey.LENGTH bytes of the MAC
if (mac.length == SecretKey.LENGTH) return mac;

View File

@@ -5,7 +5,6 @@ import java.io.InputStream;
import javax.inject.Inject;
import javax.inject.Provider;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.StreamDecrypter;
import org.briarproject.api.crypto.StreamDecrypterFactory;
@@ -13,34 +12,21 @@ import org.briarproject.api.transport.StreamContext;
class StreamDecrypterFactoryImpl implements StreamDecrypterFactory {
private final CryptoComponent crypto;
private final Provider<AuthenticatedCipher> cipherProvider;
@Inject
StreamDecrypterFactoryImpl(CryptoComponent crypto,
Provider<AuthenticatedCipher> cipherProvider) {
this.crypto = crypto;
StreamDecrypterFactoryImpl(Provider<AuthenticatedCipher> cipherProvider) {
this.cipherProvider = cipherProvider;
}
public StreamDecrypter createStreamDecrypter(InputStream in,
StreamContext ctx) {
// Derive the frame key
byte[] secret = ctx.getSecret();
long streamNumber = ctx.getStreamNumber();
boolean alice = !ctx.getAlice();
SecretKey frameKey = crypto.deriveFrameKey(secret, streamNumber, alice);
// Create the decrypter
AuthenticatedCipher cipher = cipherProvider.get();
return new StreamDecrypterImpl(in, cipher, frameKey);
return new StreamDecrypterImpl(in, cipher, ctx.getHeaderKey());
}
public StreamDecrypter createInvitationStreamDecrypter(InputStream in,
byte[] secret, boolean alice) {
// Derive the frame key
SecretKey frameKey = crypto.deriveFrameKey(secret, 0, alice);
// Create the decrypter
AuthenticatedCipher cipher = cipherProvider.get();
return new StreamDecrypterImpl(in, cipher, frameKey);
SecretKey headerKey) {
return new StreamDecrypterImpl(in, cipherProvider.get(), headerKey);
}
}

View File

@@ -1,75 +1,77 @@
package org.briarproject.crypto;
import org.briarproject.api.FormatException;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.StreamDecrypter;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import static org.briarproject.api.transport.TransportConstants.HEADER_LENGTH;
import static org.briarproject.api.transport.TransportConstants.IV_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAC_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAX_PAYLOAD_LENGTH;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import org.briarproject.api.FormatException;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.StreamDecrypter;
// FIXME: Implementation is incomplete, doesn't read the stream header
class StreamDecrypterImpl implements StreamDecrypter {
private final InputStream in;
private final AuthenticatedCipher frameCipher;
private final SecretKey frameKey;
private final byte[] iv, header, ciphertext;
private final byte[] iv, frameHeader, frameCiphertext;
private long frameNumber;
private boolean finalFrame;
StreamDecrypterImpl(InputStream in, AuthenticatedCipher frameCipher,
SecretKey frameKey) {
SecretKey headerKey) {
this.in = in;
this.frameCipher = frameCipher;
this.frameKey = frameKey;
this.frameKey = headerKey; // FIXME
iv = new byte[IV_LENGTH];
header = new byte[HEADER_LENGTH];
ciphertext = new byte[MAX_FRAME_LENGTH];
frameHeader = new byte[HEADER_LENGTH];
frameCiphertext = new byte[MAX_FRAME_LENGTH];
frameNumber = 0;
finalFrame = false;
}
public int readFrame(byte[] payload) throws IOException {
// The buffer must be big enough for a full-size frame
if (payload.length < MAX_PAYLOAD_LENGTH)
throw new IllegalArgumentException();
if (finalFrame) return -1;
// Read the header
// Read the frame header
int offset = 0;
while (offset < HEADER_LENGTH) {
int read = in.read(ciphertext, offset, HEADER_LENGTH - offset);
int read = in.read(frameCiphertext, offset, HEADER_LENGTH - offset);
if (read == -1) throw new EOFException();
offset += read;
}
// Decrypt and authenticate the header
// Decrypt and authenticate the frame header
FrameEncoder.encodeIv(iv, frameNumber, true);
try {
frameCipher.init(false, frameKey, iv);
int decrypted = frameCipher.process(ciphertext, 0, HEADER_LENGTH,
header, 0);
int decrypted = frameCipher.process(frameCiphertext, 0,
HEADER_LENGTH, frameHeader, 0);
if (decrypted != HEADER_LENGTH - MAC_LENGTH)
throw new RuntimeException();
} catch (GeneralSecurityException e) {
throw new FormatException();
}
// Decode and validate the header
finalFrame = FrameEncoder.isFinalFrame(header);
int payloadLength = FrameEncoder.getPayloadLength(header);
int paddingLength = FrameEncoder.getPaddingLength(header);
// Decode and validate the frame header
finalFrame = FrameEncoder.isFinalFrame(frameHeader);
int payloadLength = FrameEncoder.getPayloadLength(frameHeader);
int paddingLength = FrameEncoder.getPaddingLength(frameHeader);
if (payloadLength + paddingLength > MAX_PAYLOAD_LENGTH)
throw new FormatException();
// Read the payload and padding
int frameLength = HEADER_LENGTH + payloadLength + paddingLength
+ MAC_LENGTH;
while (offset < frameLength) {
int read = in.read(ciphertext, offset, frameLength - offset);
int read = in.read(frameCiphertext, offset, frameLength - offset);
if (read == -1) throw new EOFException();
offset += read;
}
@@ -77,7 +79,7 @@ class StreamDecrypterImpl implements StreamDecrypter {
FrameEncoder.encodeIv(iv, frameNumber, false);
try {
frameCipher.init(false, frameKey, iv);
int decrypted = frameCipher.process(ciphertext, HEADER_LENGTH,
int decrypted = frameCipher.process(frameCiphertext, HEADER_LENGTH,
payloadLength + paddingLength + MAC_LENGTH, payload, 0);
if (decrypted != payloadLength + paddingLength)
throw new RuntimeException();

View File

@@ -27,26 +27,15 @@ class StreamEncrypterFactoryImpl implements StreamEncrypterFactory {
public StreamEncrypter createStreamEncrypter(OutputStream out,
StreamContext ctx) {
byte[] secret = ctx.getSecret();
long streamNumber = ctx.getStreamNumber();
boolean alice = ctx.getAlice();
// Encode the tag
byte[] tag = new byte[TAG_LENGTH];
SecretKey tagKey = crypto.deriveTagKey(secret, alice);
crypto.encodeTag(tag, tagKey, streamNumber);
// Derive the frame key
SecretKey frameKey = crypto.deriveFrameKey(secret, streamNumber, alice);
// Create the encrypter
crypto.encodeTag(tag, ctx.getTagKey(), ctx.getStreamNumber());
AuthenticatedCipher cipher = cipherProvider.get();
return new StreamEncrypterImpl(out, cipher, frameKey, tag);
return new StreamEncrypterImpl(out, cipher, ctx.getHeaderKey(), tag);
}
public StreamEncrypter createInvitationStreamEncrypter(OutputStream out,
byte[] secret, boolean alice) {
// Derive the frame key
SecretKey frameKey = crypto.deriveFrameKey(secret, 0, alice);
// Create the encrypter
SecretKey headerKey) {
AuthenticatedCipher cipher = cipherProvider.get();
return new StreamEncrypterImpl(out, cipher, frameKey, null);
return new StreamEncrypterImpl(out, cipher, headerKey, null);
}
}

View File

@@ -1,5 +1,12 @@
package org.briarproject.crypto;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.StreamEncrypter;
import java.io.IOException;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import static org.briarproject.api.transport.TransportConstants.HEADER_LENGTH;
import static org.briarproject.api.transport.TransportConstants.IV_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAC_LENGTH;
@@ -7,32 +14,26 @@ import static org.briarproject.api.transport.TransportConstants.MAX_FRAME_LENGTH
import static org.briarproject.api.transport.TransportConstants.MAX_PAYLOAD_LENGTH;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.io.IOException;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.StreamEncrypter;
// FIXME: Implementation is incomplete, doesn't write the stream header
class StreamEncrypterImpl implements StreamEncrypter {
private final OutputStream out;
private final AuthenticatedCipher frameCipher;
private final SecretKey frameKey;
private final byte[] tag, iv, plaintext, ciphertext;
private final byte[] tag, iv, framePlaintext, frameCiphertext;
private long frameNumber;
private boolean writeTag;
StreamEncrypterImpl(OutputStream out, AuthenticatedCipher frameCipher,
SecretKey frameKey, byte[] tag) {
SecretKey headerKey, byte[] tag) {
this.out = out;
this.frameCipher = frameCipher;
this.frameKey = frameKey;
this.frameKey = headerKey; // FIXME
this.tag = tag;
iv = new byte[IV_LENGTH];
plaintext = new byte[HEADER_LENGTH + MAX_PAYLOAD_LENGTH];
ciphertext = new byte[MAX_FRAME_LENGTH];
framePlaintext = new byte[HEADER_LENGTH + MAX_PAYLOAD_LENGTH];
frameCiphertext = new byte[MAX_FRAME_LENGTH];
frameNumber = 0;
writeTag = (tag != null);
}
@@ -48,37 +49,39 @@ class StreamEncrypterImpl implements StreamEncrypter {
out.write(tag, 0, tag.length);
writeTag = false;
}
// Encode the header
FrameEncoder.encodeHeader(plaintext, finalFrame, payloadLength,
// Encode the frame header
FrameEncoder.encodeHeader(framePlaintext, finalFrame, payloadLength,
paddingLength);
// Encrypt and authenticate the header
// Encrypt and authenticate the frame header
FrameEncoder.encodeIv(iv, frameNumber, true);
try {
frameCipher.init(true, frameKey, iv);
int encrypted = frameCipher.process(plaintext, 0,
HEADER_LENGTH - MAC_LENGTH, ciphertext, 0);
int encrypted = frameCipher.process(framePlaintext, 0,
HEADER_LENGTH - MAC_LENGTH, frameCiphertext, 0);
if (encrypted != HEADER_LENGTH) throw new RuntimeException();
} catch (GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
// Combine the payload and padding
System.arraycopy(payload, 0, plaintext, HEADER_LENGTH, payloadLength);
System.arraycopy(payload, 0, framePlaintext, HEADER_LENGTH,
payloadLength);
for (int i = 0; i < paddingLength; i++)
plaintext[HEADER_LENGTH + payloadLength + i] = 0;
framePlaintext[HEADER_LENGTH + payloadLength + i] = 0;
// Encrypt and authenticate the payload and padding
FrameEncoder.encodeIv(iv, frameNumber, false);
try {
frameCipher.init(true, frameKey, iv);
int encrypted = frameCipher.process(plaintext, HEADER_LENGTH,
payloadLength + paddingLength, ciphertext, HEADER_LENGTH);
int encrypted = frameCipher.process(framePlaintext, HEADER_LENGTH,
payloadLength + paddingLength, frameCiphertext,
HEADER_LENGTH);
if (encrypted != payloadLength + paddingLength + MAC_LENGTH)
throw new RuntimeException();
} catch (GeneralSecurityException badCipher) {
throw new RuntimeException(badCipher);
}
// Write the frame
out.write(ciphertext, 0, HEADER_LENGTH + payloadLength + paddingLength
+ MAC_LENGTH);
out.write(frameCiphertext, 0, HEADER_LENGTH + payloadLength
+ paddingLength + MAC_LENGTH);
frameNumber++;
}

View File

@@ -1,9 +1,5 @@
package org.briarproject.db;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorId;
import org.briarproject.api.Contact;
@@ -25,8 +21,11 @@ import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.api.transport.TransportKeys;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
// FIXME: Document the preconditions for calling each method
@@ -89,13 +88,6 @@ interface Database<T> {
ContactId addContact(T txn, Author remote, AuthorId local)
throws DbException;
/**
* Stores an endpoint.
* <p>
* Locking: write.
*/
void addEndpoint(T txn, Endpoint ep) throws DbException;
/**
* Subscribes to a group, or returns false if the user already has the
* maximum number of subscriptions.
@@ -125,15 +117,6 @@ interface Database<T> {
*/
void addOfferedMessage(T txn, ContactId c, MessageId m) throws DbException;
/**
* Stores the given temporary secrets and deletes any secrets that have
* been made obsolete.
* <p>
* Locking: write.
*/
void addSecrets(T txn, Collection<TemporarySecret> secrets)
throws DbException;
/**
* Initialises the status of the given message with respect to the given
* contact.
@@ -154,6 +137,13 @@ interface Database<T> {
boolean addTransport(T txn, TransportId t, int maxLatency)
throws DbException;
/**
* Stores the given transport keys for a newly added contact.
* <p>
* Locking: write.
*/
void addTransportKeys(T txn, ContactId c, TransportKeys k) throws DbException;
/**
* Makes a group visible to the given contact.
* <p>
@@ -270,13 +260,6 @@ interface Database<T> {
*/
Collection<ContactId> getContacts(T txn, AuthorId a) throws DbException;
/**
* Returns all endpoints.
* <p>
* Locking: read.
*/
Collection<Endpoint> getEndpoints(T txn) throws DbException;
/**
* Returns the amount of free storage space available to the database, in
* bytes. This is based on the minimum of the space available on the device
@@ -461,13 +444,6 @@ interface Database<T> {
RetentionUpdate getRetentionUpdate(T txn, ContactId c, int maxLatency)
throws DbException;
/**
* Returns all temporary secrets.
* <p>
* Locking: read.
*/
Collection<TemporarySecret> getSecrets(T txn) throws DbException;
/**
* Returns all settings.
* <p>
@@ -509,7 +485,15 @@ interface Database<T> {
throws DbException;
/**
* Returns the maximum latencies of all supported transports.
* Returns all transport keys for the given transport.
* <p>
* Locking: read.
*/
Map<ContactId, TransportKeys> getTransportKeys(T txn, TransportId t)
throws DbException;
/**
* Returns the maximum latencies in milliseconds of all transports.
* <p>
* Locking: read.
*/
@@ -540,14 +524,13 @@ interface Database<T> {
Collection<ContactId> getVisibility(T txn, GroupId g) throws DbException;
/**
* Increments the outgoing stream counter for the given endpoint in the
* given rotation period and returns the old value, or -1 if the counter
* does not exist.
* Increments the outgoing stream counter for the given contact and
* transport in the given rotation period.
* <p>
* Locking: write.
*/
long incrementStreamCounter(T txn, ContactId c, TransportId t, long period)
throws DbException;
void incrementStreamCounter(T txn, ContactId c, TransportId t,
long rotationPeriod) throws DbException;
/**
* Increments the retention time versions for all contacts to indicate that
@@ -692,13 +675,13 @@ interface Database<T> {
void resetExpiryTime(T txn, ContactId c, MessageId m) throws DbException;
/**
* Sets the reordering window for the given endpoint in the given rotation
* period.
* Sets the reordering window for the given contact and transport in the
* given rotation period.
* <p>
* Locking: write.
*/
void setReorderingWindow(T txn, ContactId c, TransportId t, long period,
long centre, byte[] bitmap) throws DbException;
void setReorderingWindow(T txn, ContactId c, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException;
/**
* Updates the groups to which the given contact subscribes and returns
@@ -716,7 +699,7 @@ interface Database<T> {
* <p>
* Locking: write.
*/
public void setInboxGroup(T txn, ContactId c, Group g) throws DbException;
void setInboxGroup(T txn, ContactId c, Group g) throws DbException;
/**
* Marks a message as read or unread.
@@ -798,4 +781,12 @@ interface Database<T> {
*/
void updateExpiryTime(T txn, ContactId c, MessageId m, int maxLatency)
throws DbException;
/**
* Stores the given transport keys, deleting any keys they have replaced.
* <p>
* Locking: write.
*/
void updateTransportKeys(T txn, Map<ContactId, TransportKeys> keys)
throws DbException;
}

View File

@@ -1,25 +1,5 @@
package org.briarproject.db;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.db.DatabaseConstants.BYTES_PER_SWEEP;
import static org.briarproject.db.DatabaseConstants.CRITICAL_FREE_SPACE;
import static org.briarproject.db.DatabaseConstants.MAX_OFFERED_MESSAGES;
import static org.briarproject.db.DatabaseConstants.MAX_TRANSACTIONS_BETWEEN_SPACE_CHECKS;
import static org.briarproject.db.DatabaseConstants.MIN_FREE_SPACE;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorId;
import org.briarproject.api.Contact;
@@ -75,8 +55,29 @@ import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.api.transport.TransportKeys;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.logging.Logger;
import javax.inject.Inject;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.db.DatabaseConstants.BYTES_PER_SWEEP;
import static org.briarproject.db.DatabaseConstants.CRITICAL_FREE_SPACE;
import static org.briarproject.db.DatabaseConstants.MAX_OFFERED_MESSAGES;
import static org.briarproject.db.DatabaseConstants.MAX_TRANSACTIONS_BETWEEN_SPACE_CHECKS;
import static org.briarproject.db.DatabaseConstants.MIN_FREE_SPACE;
/**
* An implementation of DatabaseComponent using reentrant read-write locks.
@@ -85,7 +86,7 @@ import org.briarproject.api.transport.TemporarySecret;
* implementation is safe on a given JVM.
*/
class DatabaseComponentImpl<T> implements DatabaseComponent,
DatabaseCleaner.Callback {
DatabaseCleaner.Callback {
private static final Logger LOG =
Logger.getLogger(DatabaseComponentImpl.class.getName());
@@ -180,26 +181,6 @@ DatabaseCleaner.Callback {
return c;
}
public void addEndpoint(Endpoint ep) throws DbException {
lock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
if (!db.containsContact(txn, ep.getContactId()))
throw new NoSuchContactException();
if (!db.containsTransport(txn, ep.getTransportId()))
throw new NoSuchTransportException();
db.addEndpoint(txn, ep);
db.commitTransaction(txn);
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.writeLock().unlock();
}
}
public boolean addGroup(Group g) throws DbException {
boolean added = false;
lock.writeLock().lock();
@@ -290,30 +271,6 @@ DatabaseCleaner.Callback {
}
}
public void addSecrets(Collection<TemporarySecret> secrets)
throws DbException {
lock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
Collection<TemporarySecret> relevant =
new ArrayList<TemporarySecret>();
for (TemporarySecret s : secrets) {
if (db.containsContact(txn, s.getContactId()))
if (db.containsTransport(txn, s.getTransportId()))
relevant.add(s);
}
if (!secrets.isEmpty()) db.addSecrets(txn, relevant);
db.commitTransaction(txn);
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.writeLock().unlock();
}
}
public boolean addTransport(TransportId t, int maxLatency)
throws DbException {
boolean added;
@@ -334,6 +291,27 @@ DatabaseCleaner.Callback {
return added;
}
public void addTransportKeys(ContactId c, TransportKeys k)
throws DbException {
lock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, k.getTransportId()))
throw new NoSuchTransportException();
db.addTransportKeys(txn, c, k);
db.commitTransaction(txn);
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.writeLock().unlock();
}
}
public Ack generateAck(ContactId c, int maxMessages) throws DbException {
Collection<MessageId> ids;
lock.writeLock().lock();
@@ -883,23 +861,6 @@ DatabaseCleaner.Callback {
}
}
public Collection<TemporarySecret> getSecrets() throws DbException {
lock.readLock().lock();
try {
T txn = db.startTransaction();
try {
Collection<TemporarySecret> secrets = db.getSecrets(txn);
db.commitTransaction(txn);
return secrets;
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.readLock().unlock();
}
}
public Settings getSettings() throws DbException {
lock.readLock().lock();
try {
@@ -934,6 +895,27 @@ DatabaseCleaner.Callback {
}
}
public Map<ContactId, TransportKeys> getTransportKeys(TransportId t)
throws DbException {
lock.readLock().lock();
try {
T txn = db.startTransaction();
try {
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
Map<ContactId, TransportKeys> keys =
db.getTransportKeys(txn, t);
db.commitTransaction(txn);
return keys;
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.readLock().unlock();
}
}
public Map<TransportId, Integer> getTransportLatencies()
throws DbException {
lock.readLock().lock();
@@ -989,8 +971,8 @@ DatabaseCleaner.Callback {
}
}
public long incrementStreamCounter(ContactId c, TransportId t,
long period) throws DbException {
public void incrementStreamCounter(ContactId c, TransportId t,
long rotationPeriod) throws DbException {
lock.writeLock().lock();
try {
T txn = db.startTransaction();
@@ -999,9 +981,8 @@ DatabaseCleaner.Callback {
throw new NoSuchContactException();
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
long counter = db.incrementStreamCounter(txn, c, t, period);
db.incrementStreamCounter(txn, c, t, rotationPeriod);
db.commitTransaction(txn);
return counter;
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
@@ -1404,27 +1385,6 @@ DatabaseCleaner.Callback {
eventBus.broadcast(new TransportRemovedEvent(t));
}
public void setReorderingWindow(ContactId c, TransportId t, long period,
long centre, byte[] bitmap) throws DbException {
lock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.setReorderingWindow(txn, c, t, period, centre, bitmap);
db.commitTransaction(txn);
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.writeLock().unlock();
}
}
public void setInboxGroup(ContactId c, Group g) throws DbException {
lock.writeLock().lock();
try {
@@ -1480,6 +1440,27 @@ DatabaseCleaner.Callback {
}
}
public void setReorderingWindow(ContactId c, TransportId t,
long rotationPeriod, long base, byte[] bitmap) throws DbException {
lock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
if (!db.containsContact(txn, c))
throw new NoSuchContactException();
if (!db.containsTransport(txn, t))
throw new NoSuchTransportException();
db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap);
db.commitTransaction(txn);
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.writeLock().unlock();
}
}
public void setVisibility(GroupId g, Collection<ContactId> visible)
throws DbException {
Collection<ContactId> affected = new ArrayList<ContactId>();
@@ -1552,6 +1533,33 @@ DatabaseCleaner.Callback {
eventBus.broadcast(new LocalSubscriptionsUpdatedEvent(affected));
}
public void updateTransportKeys(Map<ContactId, TransportKeys> keys)
throws DbException {
lock.writeLock().lock();
try {
T txn = db.startTransaction();
try {
Map<ContactId, TransportKeys> filtered =
new HashMap<ContactId, TransportKeys>();
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue();
if (db.containsContact(txn, c)
&& db.containsTransport(txn, k.getTransportId())) {
filtered.put(c, k);
}
}
db.updateTransportKeys(txn, filtered);
db.commitTransaction(txn);
} catch (DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
lock.writeLock().unlock();
}
}
public void checkFreeSpaceAndClean() throws DbException {
long freeSpace = db.getFreeSpace();
if (LOG.isLoggable(INFO)) LOG.info(freeSpace + " bytes free space");

View File

@@ -1,14 +1,33 @@
package org.briarproject.db;
import static java.sql.Types.BINARY;
import static java.sql.Types.VARCHAR;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.Author.Status.ANONYMOUS;
import static org.briarproject.api.Author.Status.UNKNOWN;
import static org.briarproject.api.Author.Status.VERIFIED;
import static org.briarproject.api.messaging.MessagingConstants.MAX_SUBSCRIPTIONS;
import static org.briarproject.api.messaging.MessagingConstants.RETENTION_GRANULARITY;
import static org.briarproject.db.ExponentialBackoff.calculateExpiry;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorId;
import org.briarproject.api.Contact;
import org.briarproject.api.ContactId;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.Settings;
import org.briarproject.api.TransportConfig;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.db.DbClosedException;
import org.briarproject.api.db.DbException;
import org.briarproject.api.db.MessageHeader;
import org.briarproject.api.db.MessageHeader.State;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupId;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.IncomingKeys;
import org.briarproject.api.transport.OutgoingKeys;
import org.briarproject.api.transport.TransportKeys;
import java.io.IOException;
import java.sql.Connection;
@@ -32,32 +51,15 @@ import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorId;
import org.briarproject.api.Contact;
import org.briarproject.api.ContactId;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.Settings;
import org.briarproject.api.TransportConfig;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.db.DbClosedException;
import org.briarproject.api.db.DbException;
import org.briarproject.api.db.MessageHeader;
import org.briarproject.api.db.MessageHeader.State;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupId;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.TemporarySecret;
import static java.sql.Types.BINARY;
import static java.sql.Types.VARCHAR;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.Author.Status.ANONYMOUS;
import static org.briarproject.api.Author.Status.UNKNOWN;
import static org.briarproject.api.Author.Status.VERIFIED;
import static org.briarproject.api.messaging.MessagingConstants.MAX_SUBSCRIPTIONS;
import static org.briarproject.api.messaging.MessagingConstants.RETENTION_GRANULARITY;
import static org.briarproject.db.ExponentialBackoff.calculateExpiry;
/**
* A generic database implementation that can be used with any JDBC-compatible
@@ -65,8 +67,8 @@ import org.briarproject.api.transport.TemporarySecret;
*/
abstract class JdbcDatabase implements Database<Connection> {
private static final int SCHEMA_VERSION = 9;
private static final int MIN_SCHEMA_VERSION = 9;
private static final int SCHEMA_VERSION = 10;
private static final int MIN_SCHEMA_VERSION = 10;
private static final String CREATE_SETTINGS =
"CREATE TABLE settings"
@@ -277,13 +279,16 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_ENDPOINTS =
"CREATE TABLE endpoints"
private static final String CREATE_INCOMING_KEYS =
"CREATE TABLE incomingKeys"
+ " (contactId INT NOT NULL,"
+ " transportId VARCHAR NOT NULL,"
+ " epoch BIGINT NOT NULL,"
+ " alice BOOLEAN NOT NULL,"
+ " PRIMARY KEY (contactId, transportId),"
+ " period BIGINT NOT NULL,"
+ " tagKey SECRET NOT NULL,"
+ " headerKey SECRET NOT NULL,"
+ " base BIGINT NOT NULL,"
+ " bitmap BINARY NOT NULL,"
+ " PRIMARY KEY (contactId, transportId, period),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
@@ -291,16 +296,15 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " REFERENCES transports (transportId)"
+ " ON DELETE CASCADE)";
private static final String CREATE_SECRETS =
"CREATE TABLE secrets"
private static final String CREATE_OUTGOING_KEYS =
"CREATE TABLE outgoingKeys"
+ " (contactId INT NOT NULL,"
+ " transportId VARCHAR NOT NULL,"
+ " period BIGINT NOT NULL,"
+ " secret SECRET NOT NULL,"
+ " outgoing BIGINT NOT NULL,"
+ " centre BIGINT NOT NULL,"
+ " bitmap BINARY NOT NULL,"
+ " PRIMARY KEY (contactId, transportId, period),"
+ " tagKey SECRET NOT NULL,"
+ " headerKey SECRET NOT NULL,"
+ " stream BIGINT NOT NULL,"
+ " PRIMARY KEY (contactId, transportId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE,"
@@ -324,6 +328,7 @@ abstract class JdbcDatabase implements Database<Connection> {
private boolean closed = false; // Locking: connectionsLock
protected abstract Connection createConnection() throws SQLException;
protected abstract void flushBuffersToDisk(Statement s) throws SQLException;
private final Lock connectionsLock = new ReentrantLock();
@@ -339,7 +344,7 @@ abstract class JdbcDatabase implements Database<Connection> {
}
protected void open(String driverClass, boolean reopen) throws DbException,
IOException {
IOException {
// Load the JDBC driver
try {
Class.forName(driverClass);
@@ -382,7 +387,7 @@ abstract class JdbcDatabase implements Database<Connection> {
try {
if (rs != null) rs.close();
} catch (SQLException e) {
if (LOG.isLoggable(WARNING))LOG.log(WARNING, e.toString(), e);
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
@@ -390,7 +395,7 @@ abstract class JdbcDatabase implements Database<Connection> {
try {
if (s != null) s.close();
} catch (SQLException e) {
if (LOG.isLoggable(WARNING))LOG.log(WARNING, e.toString(), e);
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
@@ -418,8 +423,8 @@ abstract class JdbcDatabase implements Database<Connection> {
s.executeUpdate(insertTypeNames(CREATE_TRANSPORT_VERSIONS));
s.executeUpdate(insertTypeNames(CREATE_CONTACT_TRANSPORT_PROPS));
s.executeUpdate(insertTypeNames(CREATE_CONTACT_TRANSPORT_VERSIONS));
s.executeUpdate(insertTypeNames(CREATE_ENDPOINTS));
s.executeUpdate(insertTypeNames(CREATE_SECRETS));
s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS));
s.close();
} catch (SQLException e) {
tryToClose(s);
@@ -480,7 +485,8 @@ abstract class JdbcDatabase implements Database<Connection> {
try {
txn.close();
} catch (SQLException e1) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e1.toString(), e1);
if (LOG.isLoggable(WARNING))
LOG.log(WARNING, e1.toString(), e1);
}
// Whatever happens, allow the database to close
connectionsLock.lock();
@@ -679,26 +685,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public void addEndpoint(Connection txn, Endpoint ep) throws DbException {
PreparedStatement ps = null;
try {
String sql = "INSERT INTO endpoints"
+ " (contactId, transportId, epoch, alice)"
+ " VALUES (?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, ep.getContactId().getInt());
ps.setString(2, ep.getTransportId().getString());
ps.setLong(3, ep.getEpoch());
ps.setBoolean(4, ep.getAlice());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
public boolean addGroup(Connection txn, Group g) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
@@ -824,52 +810,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public void addSecrets(Connection txn, Collection<TemporarySecret> secrets)
throws DbException {
PreparedStatement ps = null;
try {
// Store the new secrets
String sql = "INSERT INTO secrets (contactId, transportId, period,"
+ " secret, outgoing, centre, bitmap)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
for (TemporarySecret s : secrets) {
ps.setInt(1, s.getContactId().getInt());
ps.setString(2, s.getTransportId().getString());
ps.setLong(3, s.getPeriod());
ps.setBytes(4, s.getSecret());
ps.setLong(5, s.getOutgoingStreamCounter());
ps.setLong(6, s.getWindowCentre());
ps.setBytes(7, s.getWindowBitmap());
ps.addBatch();
}
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != secrets.size())
throw new DbStateException();
for (int i = 0; i < batchAffected.length; i++) {
if (batchAffected[i] != 1) throw new DbStateException();
}
ps.close();
// Delete any obsolete secrets
sql = "DELETE FROM secrets"
+ " WHERE contactId = ? AND transportId = ? AND period < ?";
ps = txn.prepareStatement(sql);
for (TemporarySecret s : secrets) {
ps.setInt(1, s.getContactId().getInt());
ps.setString(2, s.getTransportId().getString());
ps.setLong(3, s.getPeriod() - 2);
ps.addBatch();
}
batchAffected = ps.executeBatch();
if (batchAffected.length != secrets.size())
throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
public void addStatus(Connection txn, ContactId c, MessageId m, boolean ack,
boolean seen) throws DbException {
PreparedStatement ps = null;
@@ -947,6 +887,68 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public void addTransportKeys(Connection txn, ContactId c, TransportKeys k)
throws DbException {
PreparedStatement ps = null;
try {
// Store the incoming keys
String sql = "INSERT INTO incomingKeys (contactId, transportId,"
+ " period, tagKey, headerKey, base, bitmap)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString());
// Previous rotation period
IncomingKeys inPrev = k.getPreviousIncomingKeys();
ps.setLong(3, inPrev.getRotationPeriod());
ps.setBytes(4, inPrev.getTagKey().getBytes());
ps.setBytes(5, inPrev.getHeaderKey().getBytes());
ps.setLong(6, inPrev.getWindowBase());
ps.setBytes(7, inPrev.getWindowBitmap());
ps.addBatch();
// Current rotation period
IncomingKeys inCurr = k.getCurrentIncomingKeys();
ps.setLong(3, inCurr.getRotationPeriod());
ps.setBytes(4, inCurr.getTagKey().getBytes());
ps.setBytes(5, inCurr.getHeaderKey().getBytes());
ps.setLong(6, inCurr.getWindowBase());
ps.setBytes(7, inCurr.getWindowBitmap());
ps.addBatch();
// Next rotation period
IncomingKeys inNext = k.getNextIncomingKeys();
ps.setLong(3, inNext.getRotationPeriod());
ps.setBytes(4, inNext.getTagKey().getBytes());
ps.setBytes(5, inNext.getHeaderKey().getBytes());
ps.setLong(6, inNext.getWindowBase());
ps.setBytes(7, inNext.getWindowBitmap());
ps.addBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != 3) throw new DbStateException();
for (int i = 0; i < batchAffected.length; i++) {
if (batchAffected[i] != 1) throw new DbStateException();
}
ps.close();
// Store the outgoing keys
sql = "INSERT INTO outgoingKeys (contactId, transportId, period,"
+ " tagKey, headerKey, stream)"
+ " VALUES (?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, k.getTransportId().getString());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
ps.setLong(3, outCurr.getRotationPeriod());
ps.setBytes(4, outCurr.getTagKey().getBytes());
ps.setBytes(5, outCurr.getHeaderKey().getBytes());
ps.setLong(6, outCurr.getStreamCounter());
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
public void addVisibility(Connection txn, ContactId c, GroupId g)
throws DbException {
PreparedStatement ps = null;
@@ -1326,32 +1328,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public Collection<Endpoint> getEndpoints(Connection txn)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT contactId, transportId, epoch, alice"
+ " FROM endpoints";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
List<Endpoint> endpoints = new ArrayList<Endpoint>();
while (rs.next()) {
ContactId contactId = new ContactId(rs.getInt(1));
TransportId transportId = new TransportId(rs.getString(2));
long epoch = rs.getLong(3);
boolean alice = rs.getBoolean(4);
endpoints.add(new Endpoint(contactId, transportId, epoch,
alice));
}
return Collections.unmodifiableList(endpoints);
} catch (SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
public Group getGroup(Connection txn, GroupId g) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
@@ -2098,43 +2074,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public Collection<TemporarySecret> getSecrets(Connection txn)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT e.contactId, e.transportId, epoch, alice,"
+ " period, secret, outgoing, centre, bitmap"
+ " FROM endpoints AS e"
+ " JOIN secrets AS s"
+ " ON e.contactId = s.contactId"
+ " AND e.transportId = s.transportId";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
List<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
while (rs.next()) {
ContactId contactId = new ContactId(rs.getInt(1));
TransportId transportId = new TransportId(rs.getString(2));
long epoch = rs.getLong(3);
boolean alice = rs.getBoolean(4);
long period = rs.getLong(5);
byte[] secret = rs.getBytes(6);
long outgoing = rs.getLong(7);
long centre = rs.getLong(8);
byte[] bitmap = rs.getBytes(9);
secrets.add(new TemporarySecret(contactId, transportId, epoch,
alice, period, secret, outgoing, centre, bitmap));
}
rs.close();
ps.close();
return Collections.unmodifiableList(secrets);
} catch (SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
public Settings getSettings(Connection txn) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
@@ -2317,6 +2256,67 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public Map<ContactId, TransportKeys> getTransportKeys(Connection txn,
TransportId t) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Retrieve the incoming keys
String sql = "SELECT period, tagKey, headerKey, base, bitmap"
+ " FROM incomingKeys"
+ " WHERE transportId = ?"
+ " ORDER BY contactId, period";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
List<IncomingKeys> inKeys = new ArrayList<IncomingKeys>();
while (rs.next()) {
long rotationPeriod = rs.getLong(1);
SecretKey tagKey = new SecretKey(rs.getBytes(2));
SecretKey headerKey = new SecretKey(rs.getBytes(3));
long windowBase = rs.getLong(4);
byte[] windowBitmap = rs.getBytes(5);
inKeys.add(new IncomingKeys(tagKey, headerKey, rotationPeriod,
windowBase, windowBitmap));
}
rs.close();
ps.close();
// Retrieve the outgoing keys in the same order
sql = "SELECT contactId, period, tagKey, headerKey, stream"
+ " FROM outgoingKeys"
+ " WHERE transportId = ?"
+ " ORDER BY contactId, period";
ps = txn.prepareStatement(sql);
ps.setString(1, t.getString());
rs = ps.executeQuery();
Map<ContactId, TransportKeys> keys =
new HashMap<ContactId, TransportKeys>();
for (int i = 0; rs.next(); i++) {
// There should be three times as many incoming keys
if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
ContactId contactId = new ContactId(rs.getInt(1));
long rotationPeriod = rs.getLong(2);
SecretKey tagKey = new SecretKey(rs.getBytes(3));
SecretKey headerKey = new SecretKey(rs.getBytes(4));
long streamCounter = rs.getLong(5);
OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
rotationPeriod, streamCounter);
IncomingKeys inPrev = inKeys.get(i * 3);
IncomingKeys inCurr = inKeys.get(i * 3 + 1);
IncomingKeys inNext = inKeys.get(i * 3 + 2);
keys.put(contactId, new TransportKeys(t, inPrev, inCurr,
inNext, outCurr));
}
rs.close();
ps.close();
return Collections.unmodifiableMap(keys);
} catch (SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
public Map<TransportId, Integer> getTransportLatencies(Connection txn)
throws DbException {
PreparedStatement ps = null;
@@ -2327,7 +2327,7 @@ abstract class JdbcDatabase implements Database<Connection> {
rs = ps.executeQuery();
Map<TransportId, Integer> latencies =
new HashMap<TransportId, Integer>();
while (rs.next()){
while (rs.next()) {
TransportId id = new TransportId(rs.getString(1));
latencies.put(id, rs.getInt(2));
}
@@ -2392,7 +2392,7 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setString(3, u.getId().getString());
ps.addBatch();
}
int [] batchAffected = ps.executeBatch();
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != updates.size())
throw new DbStateException();
for (i = 0; i < batchAffected.length; i++) {
@@ -2455,42 +2455,21 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public long incrementStreamCounter(Connection txn, ContactId c,
TransportId t, long period) throws DbException {
public void incrementStreamCounter(Connection txn, ContactId c,
TransportId t, long rotationPeriod) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Get the current stream counter
String sql = "SELECT outgoing FROM secrets"
String sql = "UPDATE outgoingKeys SET stream = stream + 1"
+ " WHERE contactId = ? AND transportId = ? AND period = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, t.getString());
ps.setLong(3, period);
rs = ps.executeQuery();
if (!rs.next()) {
rs.close();
ps.close();
return -1;
}
long streamNumber = rs.getLong(1);
if (rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Increment the stream counter
sql = "UPDATE secrets SET outgoing = outgoing + 1"
+ " WHERE contactId = ? AND transportId = ? AND period = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setString(2, t.getString());
ps.setLong(3, period);
ps.setLong(3, rotationPeriod);
int affected = ps.executeUpdate();
if (affected != 1) throw new DbStateException();
ps.close();
return streamNumber;
} catch (SQLException e) {
tryToClose(ps);
tryToClose(rs);
throw new DbException(e);
}
}
@@ -2929,18 +2908,19 @@ abstract class JdbcDatabase implements Database<Connection> {
throw new DbException(e);
}
}
public void setReorderingWindow(Connection txn, ContactId c, TransportId t,
long period, long centre, byte[] bitmap) throws DbException {
long rotationPeriod, long base, byte[] bitmap) throws DbException {
PreparedStatement ps = null;
try {
String sql = "UPDATE secrets SET centre = ?, bitmap = ?"
String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?"
+ " WHERE contactId = ? AND transportId = ? AND period = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, centre);
ps.setLong(1, base);
ps.setBytes(2, bitmap);
ps.setInt(3, c.getInt());
ps.setString(4, t.getString());
ps.setLong(5, period);
ps.setLong(5, rotationPeriod);
int affected = ps.executeUpdate();
if (affected < 0 || affected > 1) throw new DbStateException();
ps.close();
@@ -3139,7 +3119,7 @@ abstract class JdbcDatabase implements Database<Connection> {
public boolean setRemoteProperties(Connection txn, ContactId c,
TransportId t, TransportProperties p, long version)
throws DbException {
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
@@ -3354,4 +3334,46 @@ abstract class JdbcDatabase implements Database<Connection> {
throw new DbException(e);
}
}
public void updateTransportKeys(Connection txn,
Map<ContactId, TransportKeys> keys) throws DbException {
PreparedStatement ps = null;
try {
// Delete any existing incoming keys
String sql = "DELETE FROM incomingKeys"
+ " WHERE contactId = ?"
+ " AND transportId = ?";
ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ps.setInt(1, e.getKey().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch();
}
int[] batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
ps.close();
// Delete any existing outgoing keys
sql = "DELETE FROM outgoingKeys"
+ " WHERE contactId = ?"
+ " AND transportId = ?";
ps = txn.prepareStatement(sql);
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
ps.setInt(1, e.getKey().getInt());
ps.setString(2, e.getValue().getTransportId().getString());
ps.addBatch();
}
batchAffected = ps.executeBatch();
if (batchAffected.length != keys.size())
throw new DbStateException();
ps.close();
} catch (SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
// Store the new keys
for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
addTransportKeys(txn, e.getKey(), e.getValue());
}
}
}

View File

@@ -1,23 +1,13 @@
package org.briarproject.invitation;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.Map;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.data.Reader;
import org.briarproject.api.data.ReaderFactory;
import org.briarproject.api.data.Writer;
@@ -29,9 +19,20 @@ import org.briarproject.api.plugins.ConnectionManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.Map;
import java.util.logging.Logger;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
/** A connection thread for the peer being Alice in the invitation protocol. */
class AliceConnector extends Connector {
@@ -49,9 +50,9 @@ class AliceConnector extends Connector {
Map<TransportId, TransportProperties> localProps,
PseudoRandom random) {
super(crypto, db, readerFactory, writerFactory, streamReaderFactory,
streamWriterFactory, authorFactory, groupFactory, keyManager,
connectionManager, clock, reuseConnection, group, plugin,
localAuthor, localProps, random);
streamWriterFactory, authorFactory, groupFactory,
keyManager, connectionManager, clock, reuseConnection, group,
plugin, localAuthor, localProps, random);
}
@Override
@@ -71,7 +72,7 @@ class AliceConnector extends Connector {
OutputStream out;
Reader r;
Writer w;
byte[] secret;
SecretKey master;
try {
in = conn.getReader().getInputStream();
out = conn.getWriter().getOutputStream();
@@ -82,7 +83,7 @@ class AliceConnector extends Connector {
byte[] hash = receivePublicKeyHash(r);
sendPublicKey(w);
byte[] key = receivePublicKey(r);
secret = deriveMasterSecret(hash, key, true);
master = deriveMasterSecret(hash, key, true);
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.keyAgreementFailed();
@@ -96,8 +97,8 @@ class AliceConnector extends Connector {
}
// The key agreement succeeded - derive the confirmation codes
if (LOG.isLoggable(INFO)) LOG.info(pluginName + " agreement succeeded");
int[] codes = crypto.deriveConfirmationCodes(secret);
int aliceCode = codes[0], bobCode = codes[1];
int aliceCode = crypto.deriveConfirmationCode(master, true);
int bobCode = crypto.deriveConfirmationCode(master, false);
group.keyAgreementSucceeded(aliceCode, bobCode);
// Exchange confirmation results
boolean localMatched, remoteMatched;
@@ -130,19 +131,22 @@ class AliceConnector extends Connector {
// Confirmation succeeded - upgrade to a secure connection
if (LOG.isLoggable(INFO))
LOG.info(pluginName + " confirmation succeeded");
// Derive the header keys
SecretKey aliceHeaderKey = crypto.deriveInvitationKey(master, true);
SecretKey bobHeaderKey = crypto.deriveInvitationKey(master, false);
// Create the readers
InputStream streamReader =
streamReaderFactory.createInvitationStreamReader(in,
secret, false); // Bob's stream
bobHeaderKey);
r = readerFactory.createReader(streamReader);
// Create the writers
OutputStream streamWriter =
streamWriterFactory.createInvitationStreamWriter(out,
secret, true); // Alice's stream
aliceHeaderKey);
w = writerFactory.createWriter(streamWriter);
// Derive the invitation nonces
byte[][] nonces = crypto.deriveInvitationNonces(secret);
byte[] aliceNonce = nonces[0], bobNonce = nonces[1];
byte[] aliceNonce = crypto.deriveSignatureNonce(master, true);
byte[] bobNonce = crypto.deriveSignatureNonce(master, false);
// Exchange pseudonyms, signed nonces, timestamps and transports
Author remoteAuthor;
long remoteTimestamp;
@@ -171,11 +175,11 @@ class AliceConnector extends Connector {
tryToClose(conn, true);
return;
}
// The epoch is the minimum of the peers' timestamps
long epoch = Math.min(localTimestamp, remoteTimestamp);
// The agreed timestamp is the minimum of the peers' timestamps
long timestamp = Math.min(localTimestamp, remoteTimestamp);
// Add the contact and store the transports
try {
addContact(remoteAuthor, remoteProps, secret, epoch, true);
addContact(remoteAuthor, remoteProps, master, timestamp, true);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
tryToClose(conn, true);
@@ -190,4 +194,4 @@ class AliceConnector extends Connector {
LOG.info(pluginName + " pseudonym exchange succeeded");
group.pseudonymExchangeSucceeded(remoteAuthor);
}
}
}

View File

@@ -1,23 +1,13 @@
package org.briarproject.invitation;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.Map;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.data.Reader;
import org.briarproject.api.data.ReaderFactory;
import org.briarproject.api.data.Writer;
@@ -29,9 +19,20 @@ import org.briarproject.api.plugins.ConnectionManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.Map;
import java.util.logging.Logger;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
/** A connection thread for the peer being Bob in the invitation protocol. */
class BobConnector extends Connector {
@@ -49,9 +50,9 @@ class BobConnector extends Connector {
Map<TransportId, TransportProperties> localProps,
PseudoRandom random) {
super(crypto, db, readerFactory, writerFactory, streamReaderFactory,
streamWriterFactory, authorFactory, groupFactory, keyManager,
connectionManager, clock, reuseConnection, group, plugin,
localAuthor, localProps, random);
streamWriterFactory, authorFactory, groupFactory,
keyManager, connectionManager, clock, reuseConnection, group,
plugin, localAuthor, localProps, random);
}
@Override
@@ -65,7 +66,7 @@ class BobConnector extends Connector {
OutputStream out;
Reader r;
Writer w;
byte[] secret;
SecretKey master;
try {
in = conn.getReader().getInputStream();
out = conn.getWriter().getOutputStream();
@@ -82,7 +83,7 @@ class BobConnector extends Connector {
sendPublicKeyHash(w);
byte[] key = receivePublicKey(r);
sendPublicKey(w);
secret = deriveMasterSecret(hash, key, false);
master = deriveMasterSecret(hash, key, false);
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.keyAgreementFailed();
@@ -96,8 +97,8 @@ class BobConnector extends Connector {
}
// The key agreement succeeded - derive the confirmation codes
if (LOG.isLoggable(INFO)) LOG.info(pluginName + " agreement succeeded");
int[] codes = crypto.deriveConfirmationCodes(secret);
int aliceCode = codes[0], bobCode = codes[1];
int aliceCode = crypto.deriveConfirmationCode(master, true);
int bobCode = crypto.deriveConfirmationCode(master, false);
group.keyAgreementSucceeded(bobCode, aliceCode);
// Exchange confirmation results
boolean localMatched, remoteMatched;
@@ -130,19 +131,22 @@ class BobConnector extends Connector {
// Confirmation succeeded - upgrade to a secure connection
if (LOG.isLoggable(INFO))
LOG.info(pluginName + " confirmation succeeded");
// Derive the header keys
SecretKey aliceHeaderKey = crypto.deriveInvitationKey(master, true);
SecretKey bobHeaderKey = crypto.deriveInvitationKey(master, false);
// Create the readers
InputStream streamReader =
streamReaderFactory.createInvitationStreamReader(in,
secret, true); // Alice's stream
aliceHeaderKey);
r = readerFactory.createReader(streamReader);
// Create the writers
OutputStream streamWriter =
streamWriterFactory.createInvitationStreamWriter(out,
secret, false); // Bob's stream
bobHeaderKey);
w = writerFactory.createWriter(streamWriter);
// Derive the nonces
byte[][] nonces = crypto.deriveInvitationNonces(secret);
byte[] aliceNonce = nonces[0], bobNonce = nonces[1];
byte[] aliceNonce = crypto.deriveSignatureNonce(master, true);
byte[] bobNonce = crypto.deriveSignatureNonce(master, false);
// Exchange pseudonyms, signed nonces, timestamps and transports
Author remoteAuthor;
long remoteTimestamp;
@@ -171,11 +175,11 @@ class BobConnector extends Connector {
tryToClose(conn, true);
return;
}
// The epoch is the minimum of the peers' timestamps
long epoch = Math.min(localTimestamp, remoteTimestamp);
// The agreed timestamp is the minimum of the peers' timestamps
long timestamp = Math.min(localTimestamp, remoteTimestamp);
// Add the contact and store the transports
try {
addContact(remoteAuthor, remoteProps, secret, epoch, false);
addContact(remoteAuthor, remoteProps, master, timestamp, false);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
tryToClose(conn, true);

View File

@@ -1,5 +1,46 @@
package org.briarproject.invitation;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.ContactId;
import org.briarproject.api.FormatException;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyPair;
import org.briarproject.api.crypto.KeyParser;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.Signature;
import org.briarproject.api.data.Reader;
import org.briarproject.api.data.ReaderFactory;
import org.briarproject.api.data.Writer;
import org.briarproject.api.data.WriterFactory;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.ConnectionManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import org.briarproject.api.transport.TransportKeys;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Logger;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.AuthorConstants.MAX_AUTHOR_NAME_LENGTH;
@@ -9,50 +50,9 @@ import static org.briarproject.api.TransportPropertyConstants.MAX_PROPERTIES_PER
import static org.briarproject.api.TransportPropertyConstants.MAX_PROPERTY_LENGTH;
import static org.briarproject.api.TransportPropertyConstants.MAX_TRANSPORT_ID_LENGTH;
import static org.briarproject.api.invitation.InvitationConstants.CONNECTION_TIMEOUT;
import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.ContactId;
import org.briarproject.api.FormatException;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.KeyPair;
import org.briarproject.api.crypto.KeyParser;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.crypto.Signature;
import org.briarproject.api.data.Reader;
import org.briarproject.api.data.ReaderFactory;
import org.briarproject.api.data.Writer;
import org.briarproject.api.data.WriterFactory;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.db.NoSuchTransportException;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.ConnectionManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
// FIXME: This class has way too many dependencies
abstract class Connector extends Thread {
private static final Logger LOG =
@@ -152,8 +152,8 @@ abstract class Connector extends Thread {
return b;
}
protected byte[] deriveMasterSecret(byte[] hash, byte[] key, boolean alice)
throws GeneralSecurityException {
protected SecretKey deriveMasterSecret(byte[] hash, byte[] key,
boolean alice) throws GeneralSecurityException {
// Check that the hash matches the key
if (!Arrays.equals(hash, messageDigest.digest(key))) {
if (LOG.isLoggable(INFO))
@@ -271,39 +271,34 @@ abstract class Connector extends Thread {
}
protected void addContact(Author remoteAuthor,
Map<TransportId, TransportProperties> remoteProps, byte[] secret,
long epoch, boolean alice) throws DbException {
Map<TransportId, TransportProperties> remoteProps, SecretKey master,
long timestamp, boolean alice) throws DbException {
// Add the contact to the database
contactId = db.addContact(remoteAuthor, localAuthor.getId());
// Create and store the inbox group
byte[] salt = crypto.deriveGroupSalt(secret);
byte[] salt = crypto.deriveGroupSalt(master);
Group inbox = groupFactory.createGroup("Inbox", salt);
db.addGroup(inbox);
db.setInboxGroup(contactId, inbox);
// Store the remote transport properties
db.setRemoteProperties(contactId, remoteProps);
// Create an endpoint for each transport shared with the contact
List<TransportId> ids = new ArrayList<TransportId>();
// Derive transport keys for each transport shared with the contact
Map<TransportId, Integer> latencies = db.getTransportLatencies();
for (TransportId id : localProps.keySet()) {
if (latencies.containsKey(id) && remoteProps.containsKey(id))
ids.add(id);
}
// Assign indices to the transports deterministically and derive keys
Collections.sort(ids, TransportIdComparator.INSTANCE);
int size = ids.size();
for (int i = 0; i < size; i++) {
TransportId id = ids.get(i);
Endpoint ep = new Endpoint(contactId, id, epoch, alice);
int maxLatency = latencies.get(id);
try {
db.addEndpoint(ep);
} catch (NoSuchTransportException e) {
continue;
List<TransportKeys> keys = new ArrayList<TransportKeys>();
for (TransportId t : localProps.keySet()) {
if (remoteProps.containsKey(t) && latencies.containsKey(t)) {
// Work out what rotation period the timestamp belongs to
long latency = latencies.get(t);
long rotationPeriodLength = latency + MAX_CLOCK_DIFFERENCE;
long rotationPeriod = timestamp / rotationPeriodLength;
// Derive the transport keys
TransportKeys k = crypto.deriveTransportKeys(t, master,
rotationPeriod, alice);
db.addTransportKeys(contactId, k);
keys.add(k);
}
byte[] initialSecret = crypto.deriveInitialSecret(secret, i);
keyManager.endpointAdded(ep, maxLatency, initialSecret);
}
keyManager.contactAdded(contactId, keys);
}
protected void tryToClose(DuplexTransportConnection conn,
@@ -322,16 +317,4 @@ abstract class Connector extends Thread {
TransportId t = plugin.getId();
connectionManager.manageOutgoingConnection(contactId, t, conn);
}
private static class TransportIdComparator
implements Comparator<TransportId> {
private static final TransportIdComparator INSTANCE =
new TransportIdComparator();
public int compare(TransportId t1, TransportId t2) {
String s1 = t1.getString(), s2 = t2.getString();
return String.CASE_INSENSITIVE_ORDER.compare(s1, s2);
}
}
}

View File

@@ -1,19 +1,5 @@
package org.briarproject.invitation;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.invitation.InvitationConstants.CONFIRMATION_TIMEOUT;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.AuthorId;
@@ -21,7 +7,6 @@ import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.data.ReaderFactory;
import org.briarproject.api.data.WriterFactory;
@@ -35,9 +20,24 @@ import org.briarproject.api.plugins.ConnectionManager;
import org.briarproject.api.plugins.PluginManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.invitation.InvitationConstants.CONFIRMATION_TIMEOUT;
/** A task consisting of one or more parallel connection attempts. */
class ConnectorGroup extends Thread implements InvitationTask {

View File

@@ -5,7 +5,6 @@ import javax.inject.Inject;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.AuthorId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.data.ReaderFactory;
import org.briarproject.api.data.WriterFactory;
import org.briarproject.api.db.DatabaseComponent;
@@ -15,6 +14,7 @@ import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.ConnectionManager;
import org.briarproject.api.plugins.PluginManager;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;

View File

@@ -14,7 +14,6 @@ import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.db.DbException;
import org.briarproject.api.lifecycle.IoExecutor;
import org.briarproject.api.messaging.MessagingSession;
@@ -24,10 +23,10 @@ import org.briarproject.api.plugins.ConnectionRegistry;
import org.briarproject.api.plugins.TransportConnectionReader;
import org.briarproject.api.plugins.TransportConnectionWriter;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import org.briarproject.api.transport.TagRecogniser;
class ConnectionManagerImpl implements ConnectionManager {
@@ -36,7 +35,6 @@ class ConnectionManagerImpl implements ConnectionManager {
private final Executor ioExecutor;
private final KeyManager keyManager;
private final TagRecogniser tagRecogniser;
private final StreamReaderFactory streamReaderFactory;
private final StreamWriterFactory streamWriterFactory;
private final MessagingSessionFactory messagingSessionFactory;
@@ -44,14 +42,12 @@ class ConnectionManagerImpl implements ConnectionManager {
@Inject
ConnectionManagerImpl(@IoExecutor Executor ioExecutor,
KeyManager keyManager, TagRecogniser tagRecogniser,
StreamReaderFactory streamReaderFactory,
KeyManager keyManager, StreamReaderFactory streamReaderFactory,
StreamWriterFactory streamWriterFactory,
MessagingSessionFactory messagingSessionFactory,
ConnectionRegistry connectionRegistry) {
this.ioExecutor = ioExecutor;
this.keyManager = keyManager;
this.tagRecogniser = tagRecogniser;
this.streamReaderFactory = streamReaderFactory;
this.streamWriterFactory = streamWriterFactory;
this.messagingSessionFactory = messagingSessionFactory;
@@ -134,7 +130,7 @@ class ConnectionManagerImpl implements ConnectionManager {
StreamContext ctx;
try {
byte[] tag = readTag(transportId, reader);
ctx = tagRecogniser.recogniseTag(transportId, tag);
ctx = keyManager.recogniseTag(transportId, tag);
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, false);
@@ -238,7 +234,7 @@ class ConnectionManagerImpl implements ConnectionManager {
StreamContext ctx;
try {
byte[] tag = readTag(transportId, reader);
ctx = tagRecogniser.recogniseTag(transportId, tag);
ctx = keyManager.recogniseTag(transportId, tag);
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, false);
@@ -367,7 +363,7 @@ class ConnectionManagerImpl implements ConnectionManager {
StreamContext ctx;
try {
byte[] tag = readTag(transportId, reader);
ctx = tagRecogniser.recogniseTag(transportId, tag);
ctx = keyManager.recogniseTag(transportId, tag);
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
disposeReader(true, true);
@@ -420,4 +416,4 @@ class ConnectionManagerImpl implements ConnectionManager {
}
}
}
}
}

View File

@@ -1,22 +1,5 @@
package org.briarproject.plugins;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportConfig;
import org.briarproject.api.TransportId;
@@ -42,6 +25,23 @@ import org.briarproject.api.plugins.simplex.SimplexPluginFactory;
import org.briarproject.api.system.Clock;
import org.briarproject.api.ui.UiCallback;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
class PluginManagerImpl implements PluginManager {
private static final Logger LOG =
@@ -367,7 +367,7 @@ class PluginManagerImpl implements PluginManager {
}
private class SimplexCallback extends PluginCallbackImpl
implements SimplexPluginCallback {
implements SimplexPluginCallback {
private SimplexCallback(TransportId id) {
super(id);
@@ -383,7 +383,7 @@ class PluginManagerImpl implements PluginManager {
}
private class DuplexCallback extends PluginCallbackImpl
implements DuplexPluginCallback {
implements DuplexPluginCallback {
private DuplexCallback(TransportId id) {
super(id);

View File

@@ -1,27 +1,10 @@
package org.briarproject.transport;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TimerTask;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DatabaseExecutor;
import org.briarproject.api.db.DbException;
import org.briarproject.api.event.ContactRemovedEvent;
import org.briarproject.api.event.Event;
@@ -31,429 +14,115 @@ import org.briarproject.api.event.TransportAddedEvent;
import org.briarproject.api.event.TransportRemovedEvent;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.Timer;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.TagRecogniser;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.api.transport.TransportKeys;
// FIXME: Don't make alien calls with a lock held
class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
import java.util.Collection;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
private static final int MS_BETWEEN_CHECKS = 60 * 1000;
import javax.inject.Inject;
import static java.util.logging.Level.WARNING;
class KeyManagerImpl implements KeyManager, EventListener {
private static final Logger LOG =
Logger.getLogger(KeyManagerImpl.class.getName());
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final CryptoComponent crypto;
private final Executor dbExecutor;
private final EventBus eventBus;
private final TagRecogniser tagRecogniser;
private final Clock clock;
private final Timer timer;
private final Lock lock = new ReentrantLock();
// The following are locking: lock
private final Map<TransportId, Integer> maxLatencies;
private final Map<EndpointKey, TemporarySecret> oldSecrets;
private final Map<EndpointKey, TemporarySecret> currentSecrets;
private final Map<EndpointKey, TemporarySecret> newSecrets;
private final Clock clock;
private final ConcurrentHashMap<TransportId, TransportKeyManager> managers;
@Inject
KeyManagerImpl(CryptoComponent crypto, DatabaseComponent db,
EventBus eventBus, TagRecogniser tagRecogniser, Clock clock,
Timer timer) {
this.crypto = crypto;
KeyManagerImpl(DatabaseComponent db, CryptoComponent crypto,
@DatabaseExecutor Executor dbExecutor, EventBus eventBus,
Timer timer, Clock clock) {
this.db = db;
this.crypto = crypto;
this.dbExecutor = dbExecutor;
this.eventBus = eventBus;
this.tagRecogniser = tagRecogniser;
this.clock = clock;
this.timer = timer;
maxLatencies = new HashMap<TransportId, Integer>();
oldSecrets = new HashMap<EndpointKey, TemporarySecret>();
currentSecrets = new HashMap<EndpointKey, TemporarySecret>();
newSecrets = new HashMap<EndpointKey, TemporarySecret>();
this.clock = clock;
managers = new ConcurrentHashMap<TransportId, TransportKeyManager>();
}
public boolean start() {
lock.lock();
eventBus.addListener(this);
try {
eventBus.addListener(this);
// Load the temporary secrets and transport latencies from the DB
Collection<TemporarySecret> secrets;
try {
secrets = db.getSecrets();
maxLatencies.putAll(db.getTransportLatencies());
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return false;
}
// Work out what phase of its lifecycle each secret is in
long now = clock.currentTimeMillis();
Collection<TemporarySecret> dead =
assignSecretsToMaps(now, secrets);
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
if (!created.isEmpty()) {
// Store any secrets that have been created,
// removing any dead ones
try {
db.addSecrets(created);
} catch (DbException e) {
if (LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
return false;
}
}
// Pass the old, current and new secrets to the recogniser
for (TemporarySecret s : oldSecrets.values())
tagRecogniser.addSecret(s);
for (TemporarySecret s : currentSecrets.values())
tagRecogniser.addSecret(s);
for (TemporarySecret s : newSecrets.values())
tagRecogniser.addSecret(s);
// Schedule periodic key rotation
timer.scheduleAtFixedRate(this, MS_BETWEEN_CHECKS,
MS_BETWEEN_CHECKS);
return true;
} finally {
lock.unlock();
Map<TransportId, Integer> latencies = db.getTransportLatencies();
for (Entry<TransportId, Integer> e : latencies.entrySet())
addTransport(e.getKey(), e.getValue());
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return false;
}
}
// Assigns secrets to the appropriate maps and returns any dead secrets
// Locking: lock
private Collection<TemporarySecret> assignSecretsToMaps(long now,
Collection<TemporarySecret> secrets) {
Collection<TemporarySecret> dead = new ArrayList<TemporarySecret>();
for (TemporarySecret s : secrets) {
// Discard the secret if the transport has been removed
Integer maxLatency = maxLatencies.get(s.getTransportId());
if (maxLatency == null) {
LOG.info("Discarding obsolete secret");
continue;
}
long rotation = maxLatency + MAX_CLOCK_DIFFERENCE;
long creationTime = s.getEpoch() + rotation * (s.getPeriod() - 2);
long activationTime = creationTime + rotation;
long deactivationTime = activationTime + rotation;
long destructionTime = deactivationTime + rotation;
if (now >= destructionTime) {
dead.add(s);
} else if (now >= deactivationTime) {
oldSecrets.put(new EndpointKey(s), s);
} else if (now >= activationTime) {
currentSecrets.put(new EndpointKey(s), s);
} else if (now >= creationTime) {
newSecrets.put(new EndpointKey(s), s);
} else {
// FIXME: Work out what to do here
throw new Error("Clock has moved backwards");
}
}
return dead;
}
// Replaces the given secrets and returns any secrets created
// Locking: lock
private Collection<TemporarySecret> replaceDeadSecrets(long now,
Collection<TemporarySecret> dead) {
// If there are several dead secrets for an endpoint, use the newest
Map<EndpointKey, TemporarySecret> newest =
new HashMap<EndpointKey, TemporarySecret>();
for (TemporarySecret s : dead) {
EndpointKey k = new EndpointKey(s);
TemporarySecret exists = newest.get(k);
if (exists == null) {
// There's no other secret for this endpoint
newest.put(k, s);
} else if (exists.getPeriod() < s.getPeriod()) {
// There's an older secret - use this one instead
newest.put(k, s);
} else {
// There's a newer secret - keep using it
}
}
Collection<TemporarySecret> created = new ArrayList<TemporarySecret>();
for (Entry<EndpointKey, TemporarySecret> e : newest.entrySet()) {
TemporarySecret s = e.getValue();
Integer maxLatency = maxLatencies.get(s.getTransportId());
if (maxLatency == null) throw new IllegalStateException();
// Work out which rotation period we're in
long elapsed = now - s.getEpoch();
long rotation = maxLatency + MAX_CLOCK_DIFFERENCE;
long period = (elapsed / rotation) + 1;
if (period < 1) throw new IllegalStateException();
if (period - s.getPeriod() < 2)
throw new IllegalStateException();
// Derive the old, current and new secrets
byte[] b1 = s.getSecret();
for (long p = s.getPeriod() + 1; p < period; p++)
b1 = crypto.deriveNextSecret(b1, p);
byte[] b2 = crypto.deriveNextSecret(b1, period);
byte[] b3 = crypto.deriveNextSecret(b2, period + 1);
// Add the secrets to their respective maps if not already present
EndpointKey k = e.getKey();
if (!oldSecrets.containsKey(k)) {
TemporarySecret s1 = new TemporarySecret(s, period - 1, b1);
oldSecrets.put(k, s1);
created.add(s1);
}
if (!currentSecrets.containsKey(k)) {
TemporarySecret s2 = new TemporarySecret(s, period, b2);
currentSecrets.put(k, s2);
created.add(s2);
}
if (!newSecrets.containsKey(k)) {
TemporarySecret s3 = new TemporarySecret(s, period + 1, b3);
newSecrets.put(k, s3);
created.add(s3);
}
}
return created;
return true;
}
public boolean stop() {
lock.lock();
try {
eventBus.removeListener(this);
timer.cancel();
tagRecogniser.removeSecrets();
maxLatencies.clear();
oldSecrets.clear();
currentSecrets.clear();
newSecrets.clear();
return true;
} finally {
lock.unlock();
eventBus.removeListener(this);
return true;
}
public void contactAdded(ContactId c, Collection<TransportKeys> keys) {
for (TransportKeys k : keys) {
TransportKeyManager m = managers.get(k.getTransportId());
if (m != null) m.addContact(c, k);
}
}
public StreamContext getStreamContext(ContactId c,
TransportId t) {
lock.lock();
try {
TemporarySecret s = currentSecrets.get(new EndpointKey(c, t));
if (s == null) {
LOG.info("No secret for endpoint");
return null;
}
long streamNumber;
try {
streamNumber = db.incrementStreamCounter(c, t, s.getPeriod());
if (streamNumber == -1) {
LOG.info("No counter for period");
return null;
}
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return null;
}
byte[] secret = s.getSecret();
return new StreamContext(c, t, secret, streamNumber, s.getAlice());
} finally {
lock.unlock();
}
public StreamContext getStreamContext(ContactId c, TransportId t) {
TransportKeyManager m = managers.get(t);
return m == null ? null : m.getStreamContext(c);
}
public void endpointAdded(Endpoint ep, int maxLatency,
byte[] initialSecret) {
lock.lock();
try {
maxLatencies.put(ep.getTransportId(), maxLatency);
// Work out which rotation period we're in
long elapsed = clock.currentTimeMillis() - ep.getEpoch();
long rotation = maxLatency + MAX_CLOCK_DIFFERENCE;
long period = (elapsed / rotation) + 1;
if (period < 1) throw new IllegalStateException();
// Derive the old, current and new secrets
byte[] b1 = initialSecret;
for (long p = 0; p < period; p++)
b1 = crypto.deriveNextSecret(b1, p);
byte[] b2 = crypto.deriveNextSecret(b1, period);
byte[] b3 = crypto.deriveNextSecret(b2, period + 1);
TemporarySecret s1 = new TemporarySecret(ep, period - 1, b1);
TemporarySecret s2 = new TemporarySecret(ep, period, b2);
TemporarySecret s3 = new TemporarySecret(ep, period + 1, b3);
// Add the incoming secrets to their respective maps
EndpointKey k = new EndpointKey(ep);
oldSecrets.put(k, s1);
currentSecrets.put(k, s2);
newSecrets.put(k, s3);
// Store the new secrets
try {
db.addSecrets(Arrays.asList(s1, s2, s3));
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return;
}
// Pass the new secrets to the recogniser
tagRecogniser.addSecret(s1);
tagRecogniser.addSecret(s2);
tagRecogniser.addSecret(s3);
} finally {
lock.unlock();
}
}
@Override
public void run() {
lock.lock();
try {
// Rebuild the maps because we may be running a whole period late
Collection<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
secrets.addAll(oldSecrets.values());
secrets.addAll(currentSecrets.values());
secrets.addAll(newSecrets.values());
oldSecrets.clear();
currentSecrets.clear();
newSecrets.clear();
// Work out what phase of its lifecycle each secret is in
long now = clock.currentTimeMillis();
Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets);
// Remove any dead secrets from the recogniser
for (TemporarySecret s : dead) {
ContactId c = s.getContactId();
TransportId t = s.getTransportId();
long period = s.getPeriod();
tagRecogniser.removeSecret(c, t, period);
}
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
if (!created.isEmpty()) {
// Store any secrets that have been created
try {
db.addSecrets(created);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
// Pass any secrets that have been created to the recogniser
for (TemporarySecret s : created) tagRecogniser.addSecret(s);
}
} finally {
lock.unlock();
}
public StreamContext recogniseTag(TransportId t, byte[] tag)
throws DbException {
TransportKeyManager m = managers.get(t);
return m == null ? null : m.recogniseTag(tag);
}
public void eventOccurred(Event e) {
if (e instanceof ContactRemovedEvent) {
ContactRemovedEvent c = (ContactRemovedEvent) e;
timer.schedule(new ContactRemovedTask(c), 0);
} else if (e instanceof TransportAddedEvent) {
if (e instanceof TransportAddedEvent) {
TransportAddedEvent t = (TransportAddedEvent) e;
timer.schedule(new TransportAddedTask(t), 0);
addTransport(t.getTransportId(), t.getMaxLatency());
} else if (e instanceof TransportRemovedEvent) {
TransportRemovedEvent t = (TransportRemovedEvent) e;
timer.schedule(new TransportRemovedTask(t), 0);
removeTransport(((TransportRemovedEvent) e).getTransportId());
} else if (e instanceof ContactRemovedEvent) {
removeContact(((ContactRemovedEvent) e).getContactId());
}
}
// Locking: lock
private void removeSecrets(ContactId c, Map<?, TemporarySecret> m) {
Iterator<TemporarySecret> it = m.values().iterator();
while (it.hasNext())
if (it.next().getContactId().equals(c)) it.remove();
}
// Locking: lock
private void removeSecrets(TransportId t, Map<?, TemporarySecret> m) {
Iterator<TemporarySecret> it = m.values().iterator();
while (it.hasNext())
if (it.next().getTransportId().equals(t)) it.remove();
}
private static class EndpointKey {
private final ContactId contactId;
private final TransportId transportId;
private EndpointKey(ContactId contactId, TransportId transportId) {
this.contactId = contactId;
this.transportId = transportId;
}
private EndpointKey(Endpoint ep) {
this(ep.getContactId(), ep.getTransportId());
}
@Override
public int hashCode() {
return contactId.hashCode() ^ transportId.hashCode();
}
@Override
public boolean equals(Object o) {
if (o instanceof EndpointKey) {
EndpointKey k = (EndpointKey) o;
return contactId.equals(k.contactId) &&
transportId.equals(k.transportId);
private void addTransport(final TransportId t, final int maxLatency) {
dbExecutor.execute(new Runnable() {
public void run() {
TransportKeyManager m = new TransportKeyManager(db, crypto,
dbExecutor, timer, clock, t, maxLatency);
// Don't add transport twice if event is received during startup
if (managers.putIfAbsent(t, m) == null) m.start();
}
return false;
}
});
}
private class ContactRemovedTask extends TimerTask {
private final ContactRemovedEvent event;
private ContactRemovedTask(ContactRemovedEvent event) {
this.event = event;
}
@Override
public void run() {
ContactId c = event.getContactId();
tagRecogniser.removeSecrets(c);
lock.lock();
try {
removeSecrets(c, oldSecrets);
removeSecrets(c, currentSecrets);
removeSecrets(c, newSecrets);
} finally {
lock.unlock();
}
}
private void removeTransport(TransportId t) {
managers.remove(t);
}
private class TransportAddedTask extends TimerTask {
private final TransportAddedEvent event;
private TransportAddedTask(TransportAddedEvent event) {
this.event = event;
}
@Override
public void run() {
lock.lock();
try {
maxLatencies.put(event.getTransportId(), event.getMaxLatency());
} finally {
lock.unlock();
private void removeContact(final ContactId c) {
dbExecutor.execute(new Runnable() {
public void run() {
for (TransportKeyManager m : managers.values())
m.removeContact(c);
}
}
}
private class TransportRemovedTask extends TimerTask {
private TransportRemovedEvent event;
private TransportRemovedTask(TransportRemovedEvent event) {
this.event = event;
}
@Override
public void run() {
TransportId t = event.getTransportId();
tagRecogniser.removeSecrets(t);
lock.lock();
try {
maxLatencies.remove(t);
removeSecrets(t, oldSecrets);
removeSecrets(t, currentSecrets);
removeSecrets(t, newSecrets);
} finally {
lock.unlock();
}
}
});
}
}

View File

@@ -0,0 +1,40 @@
package org.briarproject.transport;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.transport.IncomingKeys;
// This class is not thread-safe
class MutableIncomingKeys {
private final SecretKey tagKey, headerKey;
private final long rotationPeriod;
private final ReorderingWindow window;
MutableIncomingKeys(IncomingKeys in) {
tagKey = in.getTagKey();
headerKey = in.getHeaderKey();
rotationPeriod = in.getRotationPeriod();
window = new ReorderingWindow(in.getWindowBase(), in.getWindowBitmap());
}
IncomingKeys snapshot() {
return new IncomingKeys(tagKey, headerKey, rotationPeriod,
window.getBase(), window.getBitmap());
}
SecretKey getTagKey() {
return tagKey;
}
SecretKey getHeaderKey() {
return headerKey;
}
long getRotationPeriod() {
return rotationPeriod;
}
ReorderingWindow getWindow() {
return window;
}
}

View File

@@ -0,0 +1,44 @@
package org.briarproject.transport;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.transport.OutgoingKeys;
// This class is not thread-safe
class MutableOutgoingKeys {
private final SecretKey tagKey, headerKey;
private final long rotationPeriod;
private long streamCounter;
MutableOutgoingKeys(OutgoingKeys out) {
tagKey = out.getTagKey();
headerKey = out.getHeaderKey();
rotationPeriod = out.getRotationPeriod();
streamCounter = out.getStreamCounter();
}
OutgoingKeys snapshot() {
return new OutgoingKeys(tagKey, headerKey, rotationPeriod,
streamCounter);
}
SecretKey getTagKey() {
return tagKey;
}
SecretKey getHeaderKey() {
return headerKey;
}
long getRotationPeriod() {
return rotationPeriod;
}
long getStreamCounter() {
return streamCounter;
}
void incrementStreamCounter() {
streamCounter++;
}
}

View File

@@ -0,0 +1,44 @@
package org.briarproject.transport;
import org.briarproject.api.TransportId;
import org.briarproject.api.transport.TransportKeys;
class MutableTransportKeys {
private final TransportId transportId;
private final MutableIncomingKeys inPrev, inCurr, inNext;
private final MutableOutgoingKeys outCurr;
MutableTransportKeys(TransportKeys k) {
transportId = k.getTransportId();
inPrev = new MutableIncomingKeys(k.getPreviousIncomingKeys());
inCurr = new MutableIncomingKeys(k.getCurrentIncomingKeys());
inNext = new MutableIncomingKeys(k.getNextIncomingKeys());
outCurr = new MutableOutgoingKeys(k.getCurrentOutgoingKeys());
}
TransportKeys snapshot() {
return new TransportKeys(transportId, inPrev.snapshot(),
inCurr.snapshot(), inNext.snapshot(), outCurr.snapshot());
}
TransportId getTransportId() {
return transportId;
}
MutableIncomingKeys getPreviousIncomingKeys() {
return inPrev;
}
MutableIncomingKeys getCurrentIncomingKeys() {
return inCurr;
}
MutableIncomingKeys getNextIncomingKeys() {
return inNext;
}
MutableOutgoingKeys getCurrentOutgoingKeys() {
return outCurr;
}
}

View File

@@ -1,102 +1,98 @@
package org.briarproject.transport;
import static org.briarproject.api.transport.TransportConstants.REORDERING_WINDOW_SIZE;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.Collections;
import java.util.List;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
// This class is not thread-safe
class ReorderingWindow {
private final Set<Long> unseen;
private long base;
private boolean[] seen;
private long centre;
ReorderingWindow() {
unseen = new HashSet<Long>();
for (long l = 0; l < REORDERING_WINDOW_SIZE / 2; l++) unseen.add(l);
centre = 0;
}
ReorderingWindow(long centre, byte[] bitmap) {
if (centre < 0 || centre > MAX_32_BIT_UNSIGNED + 1)
ReorderingWindow(long base, byte[] bitmap) {
if (base < 0) throw new IllegalArgumentException();
if (base > MAX_32_BIT_UNSIGNED + 1)
throw new IllegalArgumentException();
if (bitmap.length != REORDERING_WINDOW_SIZE / 8)
throw new IllegalArgumentException();
this.centre = centre;
unseen = new HashSet<Long>();
long bitmapBottom = centre - REORDERING_WINDOW_SIZE / 2;
for (int bytes = 0; bytes < bitmap.length; bytes++) {
for (int bits = 0; bits < 8; bits++) {
long streamNumber = bitmapBottom + bytes * 8 + bits;
if (streamNumber >= 0 && streamNumber <= MAX_32_BIT_UNSIGNED) {
if ((bitmap[bytes] & (128 >> bits)) == 0)
unseen.add(streamNumber);
}
this.base = base;
seen = new boolean[bitmap.length * 8];
for (int i = 0; i < bitmap.length; i++) {
for (int j = 0; j < 8; j++) {
if ((bitmap[i] & (128 >> j)) != 0) seen[i * 8 + j] = true;
}
}
}
boolean isSeen(long streamNumber) {
return !unseen.contains(streamNumber);
}
Collection<Long> setSeen(long streamNumber) {
long bottom = getBottom(centre);
long top = getTop(centre);
if (streamNumber < bottom || streamNumber > top)
throw new IllegalArgumentException();
if (!unseen.remove(streamNumber))
throw new IllegalArgumentException();
Collection<Long> changed = new ArrayList<Long>();
if (streamNumber >= centre) {
centre = streamNumber + 1;
long newBottom = getBottom(centre);
long newTop = getTop(centre);
for (long l = bottom; l < newBottom; l++) {
if (unseen.remove(l)) changed.add(l);
}
for (long l = top + 1; l <= newTop; l++) {
if (unseen.add(l)) changed.add(l);
}
}
return changed;
}
long getCentre() {
return centre;
long getBase() {
return base;
}
byte[] getBitmap() {
byte[] bitmap = new byte[REORDERING_WINDOW_SIZE / 8];
long bitmapBottom = centre - REORDERING_WINDOW_SIZE / 2;
for (int bytes = 0; bytes < bitmap.length; bytes++) {
for (int bits = 0; bits < 8; bits++) {
long streamNumber = bitmapBottom + bytes * 8 + bits;
if (streamNumber >= 0 && streamNumber <= MAX_32_BIT_UNSIGNED) {
if (!unseen.contains(streamNumber))
bitmap[bytes] |= 128 >> bits;
}
byte[] bitmap = new byte[seen.length / 8];
for (int i = 0; i < bitmap.length; i++) {
for (int j = 0; j < 8; j++) {
if (seen[i * 8 + j]) bitmap[i] |= 128 >> j;
}
}
return bitmap;
}
// Returns the lowest value contained in a window with the given centre
private static long getBottom(long centre) {
return Math.max(0, centre - REORDERING_WINDOW_SIZE / 2);
}
// Returns the highest value contained in a window with the given centre
private static long getTop(long centre) {
return Math.min(MAX_32_BIT_UNSIGNED,
centre + REORDERING_WINDOW_SIZE / 2 - 1);
}
public Collection<Long> getUnseen() {
List<Long> getUnseen() {
List<Long> unseen = new ArrayList<Long>(seen.length);
for (int i = 0; i < seen.length; i++)
if (!seen[i]) unseen.add(base + i);
return unseen;
}
Change setSeen(long index) {
if (index < base) throw new IllegalArgumentException();
if (index >= base + seen.length) throw new IllegalArgumentException();
if (index > MAX_32_BIT_UNSIGNED) throw new IllegalArgumentException();
int offset = (int) (index - base);
if (seen[offset]) throw new IllegalArgumentException();
seen[offset] = true;
// Rule 1: Slide until all elements above the midpoint are unseen
int slide = Math.max(0, offset + 1 - seen.length / 2);
// Rule 2: Slide until the lowest element is unseen
while (seen[slide]) slide++;
// If the window doesn't need to slide, return
if (slide == 0) {
List<Long> added = Collections.emptyList();
List<Long> removed = Collections.singletonList(index);
return new Change(added, removed);
}
// Record the elements that will be added and removed
List<Long> added = new ArrayList<Long>(slide);
List<Long> removed = new ArrayList<Long>(slide);
for (int i = 0; i < slide; i++) {
if (!seen[i]) removed.add(base + i);
added.add(base + seen.length + i);
}
removed.add(index);
// Update the window
base += slide;
for (int i = 0; i + slide < seen.length; i++) seen[i] = seen[i + slide];
for (int i = seen.length - slide; i < seen.length; i++) seen[i] = false;
return new Change(added, removed);
}
static class Change {
private final List<Long> added, removed;
Change(List<Long> added, List<Long> removed) {
this.added = added;
this.removed = removed;
}
List<Long> getAdded() {
return added;
}
List<Long> getRemoved() {
return removed;
}
}
}

View File

@@ -4,6 +4,7 @@ import java.io.InputStream;
import javax.inject.Inject;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.StreamDecrypterFactory;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.StreamReaderFactory;
@@ -23,9 +24,9 @@ class StreamReaderFactoryImpl implements StreamReaderFactory {
}
public InputStream createInvitationStreamReader(InputStream in,
byte[] secret, boolean alice) {
SecretKey headerKey) {
return new StreamReaderImpl(
streamDecrypterFactory.createInvitationStreamDecrypter(in,
secret, alice));
headerKey));
}
}

View File

@@ -7,6 +7,12 @@ import java.io.InputStream;
import org.briarproject.api.crypto.StreamDecrypter;
/**
* An {@link java.io.InputStream InputStream} that unpacks payload data from
* transport frames.
* <p>
* This class is not thread-safe.
*/
class StreamReaderImpl extends InputStream {
private final StreamDecrypter decrypter;
@@ -50,7 +56,7 @@ class StreamReaderImpl extends InputStream {
}
private void readFrame() throws IOException {
assert length == 0;
if (length != 0) throw new IllegalStateException();
offset = 0;
length = decrypter.readFrame(payload);
}

View File

@@ -4,6 +4,7 @@ import java.io.OutputStream;
import javax.inject.Inject;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.StreamEncrypterFactory;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.StreamWriterFactory;
@@ -24,9 +25,9 @@ class StreamWriterFactoryImpl implements StreamWriterFactory {
}
public OutputStream createInvitationStreamWriter(OutputStream out,
byte[] secret, boolean alice) {
SecretKey headerKey) {
return new StreamWriterImpl(
streamEncrypterFactory.createInvitationStreamEncrypter(out,
secret, alice));
headerKey));
}
}

View File

@@ -8,9 +8,9 @@ import java.io.OutputStream;
import org.briarproject.api.crypto.StreamEncrypter;
/**
* A {@link org.briarproject.api.transport.StreamWriter StreamWriter} that
* buffers its input and writes a frame whenever there is a full frame to write
* or the {@link #flush()} method is called.
* An {@link java.io.OutputStream OutputStream} that packs data into transport
* frames, writing a frame whenever there is a full frame to write or the
* {@link #flush()} method is called.
* <p>
* This class is not thread-safe.
*/

View File

@@ -1,105 +0,0 @@
package org.briarproject.transport;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.TagRecogniser;
import org.briarproject.api.transport.TemporarySecret;
class TagRecogniserImpl implements TagRecogniser {
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final Lock lock = new ReentrantLock();
// Locking: lock
private final Map<TransportId, TransportTagRecogniser> recognisers;
@Inject
TagRecogniserImpl(CryptoComponent crypto, DatabaseComponent db) {
this.crypto = crypto;
this.db = db;
recognisers = new HashMap<TransportId, TransportTagRecogniser>();
}
public StreamContext recogniseTag(TransportId t, byte[] tag)
throws DbException {
TransportTagRecogniser r;
lock.lock();
try {
r = recognisers.get(t);
} finally {
lock.unlock();
}
if (r == null) return null;
return r.recogniseTag(tag);
}
public void addSecret(TemporarySecret s) {
TransportId t = s.getTransportId();
TransportTagRecogniser r;
lock.lock();
try {
r = recognisers.get(t);
if (r == null) {
r = new TransportTagRecogniser(crypto, db, t);
recognisers.put(t, r);
}
} finally {
lock.unlock();
}
r.addSecret(s);
}
public void removeSecret(ContactId c, TransportId t, long period) {
TransportTagRecogniser r;
lock.lock();
try {
r = recognisers.get(t);
} finally {
lock.unlock();
}
if (r != null) r.removeSecret(c, period);
}
public void removeSecrets(ContactId c) {
lock.lock();
try {
for (TransportTagRecogniser r : recognisers.values())
r.removeSecrets(c);
} finally {
lock.unlock();
}
}
public void removeSecrets(TransportId t) {
lock.lock();
try {
recognisers.remove(t);
} finally {
lock.unlock();
}
}
public void removeSecrets() {
lock.lock();
try {
for (TransportTagRecogniser r : recognisers.values())
r.removeSecrets();
} finally {
lock.unlock();
}
}
}

View File

@@ -0,0 +1,294 @@
package org.briarproject.transport;
import org.briarproject.api.Bytes;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.Timer;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.TransportKeys;
import org.briarproject.transport.ReorderingWindow.Change;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TimerTask;
import java.util.concurrent.Executor;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
class TransportKeyManager extends TimerTask {
private static final Logger LOG =
Logger.getLogger(TransportKeyManager.class.getName());
private final DatabaseComponent db;
private final CryptoComponent crypto;
private final Executor dbExecutor;
private final Timer timer;
private final Clock clock;
private final TransportId transportId;
private final long rotationPeriodLength;
private final ReentrantLock lock;
// The following are locking: lock
private final Map<Bytes, TagContext> inContexts;
private final Map<ContactId, MutableOutgoingKeys> outContexts;
private final Map<ContactId, MutableTransportKeys> keys;
TransportKeyManager(DatabaseComponent db, CryptoComponent crypto,
Executor dbExecutor, Timer timer, Clock clock,
TransportId transportId, long maxLatency) {
this.db = db;
this.crypto = crypto;
this.dbExecutor = dbExecutor;
this.timer = timer;
this.clock = clock;
this.transportId = transportId;
rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
lock = new ReentrantLock();
inContexts = new HashMap<Bytes, TagContext>();
outContexts = new HashMap<ContactId, MutableOutgoingKeys>();
keys = new HashMap<ContactId, MutableTransportKeys>();
}
void start() {
// Load the transport keys from the DB
Map<ContactId, TransportKeys> loaded;
try {
loaded = db.getTransportKeys(transportId);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return;
}
// Rotate the keys to the current rotation period
Map<ContactId, TransportKeys> rotated =
new HashMap<ContactId, TransportKeys>();
Map<ContactId, TransportKeys> current =
new HashMap<ContactId, TransportKeys>();
long now = clock.currentTimeMillis();
long rotationPeriod = now / rotationPeriodLength;
for (Entry<ContactId, TransportKeys> e : loaded.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue();
TransportKeys k1 = crypto.rotateTransportKeys(k, rotationPeriod);
if (k1.getRotationPeriod() > k.getRotationPeriod())
rotated.put(c, k1);
current.put(c, k1);
}
lock.lock();
try {
// Initialise mutable state for all contacts
for (Entry<ContactId, TransportKeys> e : current.entrySet())
addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
// Write any rotated keys back to the DB
saveTransportKeys(rotated);
} finally {
lock.unlock();
}
// Schedule a periodic task to rotate the keys
long delay = rotationPeriodLength - now % rotationPeriodLength;
timer.scheduleAtFixedRate(this, delay, rotationPeriodLength);
}
// Locking: lock
private void addKeys(ContactId c, MutableTransportKeys m) {
encodeTags(c, m.getPreviousIncomingKeys());
encodeTags(c, m.getCurrentIncomingKeys());
encodeTags(c, m.getNextIncomingKeys());
outContexts.put(c, m.getCurrentOutgoingKeys());
keys.put(c, m);
}
// Locking: lock
private void encodeTags(ContactId c, MutableIncomingKeys inKeys) {
for (long streamNumber : inKeys.getWindow().getUnseen()) {
TagContext tagCtx = new TagContext(c, inKeys, streamNumber);
byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, inKeys.getTagKey(), streamNumber);
inContexts.put(new Bytes(tag), tagCtx);
}
}
private void saveTransportKeys(final Map<ContactId, TransportKeys> rotated) {
dbExecutor.execute(new Runnable() {
public void run() {
try {
db.updateTransportKeys(rotated);
} catch (DbException e) {
if (LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
}
}
});
}
void addContact(ContactId c, TransportKeys k) {
lock.lock();
try {
// Initialise mutable state for the contact
addKeys(c, new MutableTransportKeys(k));
} finally {
lock.unlock();
}
}
void removeContact(ContactId c) {
lock.lock();
try {
// Remove mutable state for the contact
Iterator<Entry<Bytes, TagContext>> it =
inContexts.entrySet().iterator();
while (it.hasNext())
if (it.next().getValue().contactId.equals(c)) it.remove();
outContexts.remove(c);
keys.remove(c);
} finally {
lock.unlock();
}
}
StreamContext getStreamContext(ContactId c) {
StreamContext ctx;
lock.lock();
try {
// Look up the outgoing keys for the contact
MutableOutgoingKeys outKeys = outContexts.get(c);
if (outKeys == null) return null;
if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
// Create a stream context
ctx = new StreamContext(c, transportId, outKeys.getTagKey(),
outKeys.getHeaderKey(), outKeys.getStreamCounter());
// Increment the stream counter and write it back to the DB
outKeys.incrementStreamCounter();
saveIncrementedStreamCounter(c, outKeys.getRotationPeriod());
} finally {
lock.unlock();
}
// TODO: Wait for save to complete, return null if it fails
return ctx;
}
private void saveIncrementedStreamCounter(final ContactId c,
final long rotationPeriod) {
dbExecutor.execute(new Runnable() {
public void run() {
try {
db.incrementStreamCounter(c, transportId, rotationPeriod);
} catch (DbException e) {
if (LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
}
}
});
}
StreamContext recogniseTag(byte[] tag) {
StreamContext ctx;
lock.lock();
try {
// Look up the incoming keys for the tag
TagContext tagCtx = inContexts.remove(new Bytes(tag));
if (tagCtx == null) return null;
MutableIncomingKeys inKeys = tagCtx.inKeys;
// Create a stream context
ctx = new StreamContext(tagCtx.contactId, transportId,
inKeys.getTagKey(), inKeys.getHeaderKey(),
tagCtx.streamNumber);
// Update the reordering window
ReorderingWindow window = inKeys.getWindow();
Change change = window.setSeen(tagCtx.streamNumber);
// Add tags for any stream numbers added to the window
for (long streamNumber : change.getAdded()) {
byte[] addTag = new byte[TAG_LENGTH];
crypto.encodeTag(addTag, inKeys.getTagKey(), streamNumber);
inContexts.put(new Bytes(addTag), new TagContext(
tagCtx.contactId, inKeys, streamNumber));
}
// Remove tags for any stream numbers removed from the window
for (long streamNumber : change.getRemoved()) {
byte[] removeTag = new byte[TAG_LENGTH];
crypto.encodeTag(removeTag, inKeys.getTagKey(), streamNumber);
inContexts.remove(new Bytes(removeTag));
}
// Write the window back to the DB
saveReorderingWindow(tagCtx.contactId, inKeys.getRotationPeriod(),
window.getBase(), window.getBitmap());
} finally {
lock.unlock();
}
// TODO: Wait for save to complete, return null if it fails
return ctx;
}
private void saveReorderingWindow(final ContactId c,
final long rotationPeriod, final long base, final byte[] bitmap) {
dbExecutor.execute(new Runnable() {
public void run() {
try {
db.setReorderingWindow(c, transportId, rotationPeriod,
base, bitmap);
} catch (DbException e) {
if (LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
}
}
});
}
@Override
public void run() {
lock.lock();
try {
// Rotate the keys to the current rotation period
Map<ContactId, TransportKeys> rotated =
new HashMap<ContactId, TransportKeys>();
Map<ContactId, TransportKeys> current =
new HashMap<ContactId, TransportKeys>();
long now = clock.currentTimeMillis();
long rotationPeriod = now / rotationPeriodLength;
for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet()) {
ContactId c = e.getKey();
TransportKeys k = e.getValue().snapshot();
TransportKeys k1 = crypto.rotateTransportKeys(k,
rotationPeriod);
if (k1.getRotationPeriod() > k.getRotationPeriod())
rotated.put(c, k1);
current.put(c, k1);
}
// Rebuild the mutable state for all contacts
inContexts.clear();
outContexts.clear();
keys.clear();
for (Entry<ContactId, TransportKeys> e : current.entrySet())
addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
// Write any rotated keys back to the DB
saveTransportKeys(rotated);
} finally {
lock.unlock();
}
}
private static class TagContext {
private final ContactId contactId;
private final MutableIncomingKeys inKeys;
private final long streamNumber;
private TagContext(ContactId contactId, MutableIncomingKeys inKeys,
long streamNumber) {
this.contactId = contactId;
this.inKeys = inKeys;
this.streamNumber = streamNumber;
}
}
}

View File

@@ -1,23 +1,20 @@
package org.briarproject.transport;
import javax.inject.Singleton;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.lifecycle.LifecycleManager;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import org.briarproject.api.transport.TagRecogniser;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
import org.briarproject.api.lifecycle.LifecycleManager;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import javax.inject.Singleton;
public class TransportModule extends AbstractModule {
@Override
protected void configure() {
bind(StreamReaderFactory.class).to(StreamReaderFactoryImpl.class);
bind(TagRecogniser.class).to(
TagRecogniserImpl.class).in(Singleton.class);
bind(StreamWriterFactory.class).to(StreamWriterFactoryImpl.class);
}

View File

@@ -1,215 +0,0 @@
package org.briarproject.transport;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.briarproject.api.Bytes;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.TemporarySecret;
// FIXME: Don't make alien calls with a lock held
/**
* A {@link org.briarproject.api.transport.TagRecogniser TagRecogniser} for a
* specific transport.
*/
class TransportTagRecogniser {
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final TransportId transportId;
private final Lock lock = new ReentrantLock();
// The following are locking: lock
private final Map<Bytes, TagContext> tagMap;
private final Map<RemovalKey, RemovalContext> removalMap;
TransportTagRecogniser(CryptoComponent crypto, DatabaseComponent db,
TransportId transportId) {
this.crypto = crypto;
this.db = db;
this.transportId = transportId;
tagMap = new HashMap<Bytes, TagContext>();
removalMap = new HashMap<RemovalKey, RemovalContext>();
}
StreamContext recogniseTag(byte[] tag) throws DbException {
lock.lock();
try {
TagContext t = tagMap.remove(new Bytes(tag));
if (t == null) return null; // The tag was not expected
// Update the reordering window and the expected tags
SecretKey key = crypto.deriveTagKey(t.secret, !t.alice);
for (long streamNumber : t.window.setSeen(t.streamNumber)) {
byte[] tag1 = new byte[TAG_LENGTH];
crypto.encodeTag(tag1, key, streamNumber);
if (streamNumber < t.streamNumber) {
TagContext removed = tagMap.remove(new Bytes(tag1));
assert removed != null;
} else {
TagContext added = new TagContext(t, streamNumber);
TagContext duplicate = tagMap.put(new Bytes(tag1), added);
assert duplicate == null;
}
}
// Store the updated reordering window in the DB
db.setReorderingWindow(t.contactId, transportId, t.period,
t.window.getCentre(), t.window.getBitmap());
return new StreamContext(t.contactId, transportId, t.secret,
t.streamNumber, t.alice);
} finally {
lock.unlock();
}
}
void addSecret(TemporarySecret s) {
lock.lock();
try {
ContactId contactId = s.getContactId();
boolean alice = s.getAlice();
long period = s.getPeriod();
byte[] secret = s.getSecret();
long centre = s.getWindowCentre();
byte[] bitmap = s.getWindowBitmap();
// Create the reordering window and the expected tags
SecretKey key = crypto.deriveTagKey(secret, !alice);
ReorderingWindow window = new ReorderingWindow(centre, bitmap);
for (long streamNumber : window.getUnseen()) {
byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, key, streamNumber);
TagContext added = new TagContext(contactId, alice, period,
secret, window, streamNumber);
TagContext duplicate = tagMap.put(new Bytes(tag), added);
assert duplicate == null;
}
// Create a removal context to remove the window and the tags later
RemovalContext r = new RemovalContext(window, secret, alice);
removalMap.put(new RemovalKey(contactId, period), r);
} finally {
lock.unlock();
}
}
void removeSecret(ContactId contactId, long period) {
lock.lock();
try {
RemovalKey k = new RemovalKey(contactId, period);
RemovalContext removed = removalMap.remove(k);
if (removed == null) throw new IllegalArgumentException();
removeSecret(removed);
} finally {
lock.unlock();
}
}
// Locking: lock
private void removeSecret(RemovalContext r) {
// Remove the expected tags
SecretKey key = crypto.deriveTagKey(r.secret, !r.alice);
byte[] tag = new byte[TAG_LENGTH];
for (long streamNumber : r.window.getUnseen()) {
crypto.encodeTag(tag, key, streamNumber);
TagContext removed = tagMap.remove(new Bytes(tag));
assert removed != null;
}
}
void removeSecrets(ContactId c) {
lock.lock();
try {
Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
for (RemovalKey k : removalMap.keySet())
if (k.contactId.equals(c)) keysToRemove.add(k);
for (RemovalKey k : keysToRemove)
removeSecret(k.contactId, k.period);
} finally {
lock.unlock();
}
}
void removeSecrets() {
lock.lock();
try {
for (RemovalContext r : removalMap.values()) removeSecret(r);
assert tagMap.isEmpty();
removalMap.clear();
} finally {
lock.unlock();
}
}
private static class TagContext {
private final ContactId contactId;
private final boolean alice;
private final long period;
private final byte[] secret;
private final ReorderingWindow window;
private final long streamNumber;
private TagContext(ContactId contactId, boolean alice, long period,
byte[] secret, ReorderingWindow window, long streamNumber) {
this.contactId = contactId;
this.alice = alice;
this.period = period;
this.secret = secret;
this.window = window;
this.streamNumber = streamNumber;
}
private TagContext(TagContext t, long streamNumber) {
this(t.contactId, t.alice, t.period, t.secret, t.window,
streamNumber);
}
}
private static class RemovalKey {
private final ContactId contactId;
private final long period;
private RemovalKey(ContactId contactId, long period) {
this.contactId = contactId;
this.period = period;
}
@Override
public int hashCode() {
return contactId.hashCode() ^ (int) (period ^ (period >>> 32));
}
@Override
public boolean equals(Object o) {
if (o instanceof RemovalKey) {
RemovalKey k = (RemovalKey) o;
return contactId.equals(k.contactId) && period == k.period;
}
return false;
}
}
private static class RemovalContext {
private final ReorderingWindow window;
private final byte[] secret;
private final boolean alice;
private RemovalContext(ReorderingWindow window, byte[] secret,
boolean alice) {
this.window = window;
this.secret = secret;
this.alice = alice;
}
}
}

View File

@@ -10,6 +10,7 @@ import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyPair;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.messaging.Ack;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupFactory;
@@ -44,7 +45,6 @@ import java.io.OutputStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Random;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import static org.junit.Assert.assertArrayEquals;
@@ -61,14 +61,9 @@ public class ProtocolIntegrationTest extends BriarTestCase {
private final MessageVerifier messageVerifier;
private final ContactId contactId;
private final byte[] secret;
private final Author author;
private final SecretKey tagKey, headerKey;
private final Group group;
private final Message message, message1;
private final String authorName = "Alice";
private final String contentType = "text/plain";
private final long timestamp = System.currentTimeMillis();
private final String messageBody = "Hello world";
private final Collection<MessageId> messageIds;
private final TransportId transportId;
private final TransportProperties transportProperties;
@@ -85,9 +80,9 @@ public class ProtocolIntegrationTest extends BriarTestCase {
packetWriterFactory = i.getInstance(PacketWriterFactory.class);
messageVerifier = i.getInstance(MessageVerifier.class);
contactId = new ContactId(234);
// Create a shared secret
secret = new byte[32];
new Random().nextBytes(secret);
// Create the transport keys
tagKey = TestUtils.createSecretKey();
headerKey = TestUtils.createSecretKey();
// Create a group
GroupFactory groupFactory = i.getInstance(GroupFactory.class);
group = groupFactory.createGroup("Group");
@@ -95,12 +90,15 @@ public class ProtocolIntegrationTest extends BriarTestCase {
AuthorFactory authorFactory = i.getInstance(AuthorFactory.class);
CryptoComponent crypto = i.getInstance(CryptoComponent.class);
KeyPair authorKeyPair = crypto.generateSignatureKeyPair();
author = authorFactory.createAuthor(authorName,
Author author = authorFactory.createAuthor("Alice",
authorKeyPair.getPublic().getEncoded());
// Create two messages to the group: one anonymous, one pseudonymous
MessageFactory messageFactory = i.getInstance(MessageFactory.class);
String contentType = "text/plain";
long timestamp = System.currentTimeMillis();
String messageBody = "Hello world";
message = messageFactory.createAnonymousMessage(null, group,
contentType, timestamp, messageBody.getBytes("UTF-8"));
"text/plain", timestamp, messageBody.getBytes("UTF-8"));
message1 = messageFactory.createPseudonymousMessage(null, group,
author, authorKeyPair.getPrivate(), contentType, timestamp,
messageBody.getBytes("UTF-8"));
@@ -118,8 +116,8 @@ public class ProtocolIntegrationTest extends BriarTestCase {
private byte[] write() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
StreamContext ctx = new StreamContext(contactId, transportId, secret,
0, true);
StreamContext ctx = new StreamContext(contactId, transportId, tagKey,
headerKey, 0);
OutputStream streamWriter =
streamWriterFactory.createStreamWriter(out, ctx);
PacketWriter packetWriter = packetWriterFactory.createPacketWriter(
@@ -134,7 +132,8 @@ public class ProtocolIntegrationTest extends BriarTestCase {
packetWriter.writeRequest(new Request(messageIds));
SubscriptionUpdate su = new SubscriptionUpdate(Arrays.asList(group), 1);
SubscriptionUpdate su = new SubscriptionUpdate(
Collections.singletonList(group), 1);
packetWriter.writeSubscriptionUpdate(su);
TransportUpdate tu = new TransportUpdate(transportId,
@@ -150,8 +149,8 @@ public class ProtocolIntegrationTest extends BriarTestCase {
byte[] tag = new byte[TAG_LENGTH];
assertEquals(TAG_LENGTH, in.read(tag, 0, TAG_LENGTH));
// FIXME: Check that the expected tag was received
StreamContext ctx = new StreamContext(contactId, transportId, secret,
0, false);
StreamContext ctx = new StreamContext(contactId, transportId, tagKey,
headerKey, 0);
InputStream streamReader =
streamReaderFactory.createStreamReader(in, ctx);
PacketReader packetReader = packetReaderFactory.createPacketReader(
@@ -184,7 +183,7 @@ public class ProtocolIntegrationTest extends BriarTestCase {
// Read the subscription update
assertTrue(packetReader.hasSubscriptionUpdate());
SubscriptionUpdate su = packetReader.readSubscriptionUpdate();
assertEquals(Arrays.asList(group), su.getGroups());
assertEquals(Collections.singletonList(group), su.getGroups());
assertEquals(1, su.getVersion());
// Read the transport update

View File

@@ -7,6 +7,9 @@ import java.io.File;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import org.briarproject.api.UniqueId;
import org.briarproject.api.crypto.SecretKey;
public class TestUtils {
private static final AtomicInteger nextTestDir =
@@ -35,4 +38,10 @@ public class TestUtils {
c[i] = (char) ('a' + random.nextInt(26));
return new String(c);
}
public static SecretKey createSecretKey() {
byte[] b = new byte[SecretKey.LENGTH];
random.nextBytes(b);
return new SecretKey(b);
}
}

View File

@@ -4,6 +4,7 @@ import org.briarproject.BriarTestCase;
import org.briarproject.TestSeedProvider;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyPair;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.system.SeedProvider;
import org.junit.Test;
@@ -19,8 +20,8 @@ public class KeyAgreementTest extends BriarTestCase {
byte[] aPub = aPair.getPublic().getEncoded();
KeyPair bPair = crypto.generateAgreementKeyPair();
byte[] bPub = bPair.getPublic().getEncoded();
byte[] aSecret = crypto.deriveMasterSecret(aPub, bPair, true);
byte[] bSecret = crypto.deriveMasterSecret(bPub, aPair, false);
assertArrayEquals(aSecret, bSecret);
SecretKey aMaster = crypto.deriveMasterSecret(aPub, bPair, true);
SecretKey bMaster = crypto.deriveMasterSecret(bPub, aPair, false);
assertArrayEquals(aMaster.getBytes(), bMaster.getBytes());
}
}

View File

@@ -2,72 +2,164 @@ package org.briarproject.crypto;
import org.briarproject.BriarTestCase;
import org.briarproject.TestSeedProvider;
import org.briarproject.TestUtils;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.transport.TransportKeys;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertFalse;
public class KeyDerivationTest extends BriarTestCase {
private final TransportId transportId = new TransportId("id");
private final CryptoComponent crypto;
private final byte[] secret;
private final SecretKey master;
public KeyDerivationTest() {
crypto = new CryptoComponentImpl(new TestSeedProvider());
secret = new byte[32];
new Random().nextBytes(secret);
master = TestUtils.createSecretKey();
}
@Test
public void testKeysAreDistinct() {
List<SecretKey> keys = new ArrayList<SecretKey>();
keys.add(crypto.deriveFrameKey(secret, 0, true));
keys.add(crypto.deriveFrameKey(secret, 0, false));
keys.add(crypto.deriveTagKey(secret, true));
keys.add(crypto.deriveTagKey(secret, false));
for (int i = 0; i < 4; i++) {
byte[] keyI = keys.get(i).getBytes();
for (int j = 0; j < 4; j++) {
byte[] keyJ = keys.get(j).getBytes();
assertEquals(i == j, Arrays.equals(keyI, keyJ));
}
}
TransportKeys k = crypto.deriveTransportKeys(transportId, master,
123, true);
assertAllDifferent(k);
}
@Test
public void testSecretAffectsDerivation() {
Random r = new Random();
List<byte[]> secrets = new ArrayList<byte[]>();
for (int i = 0; i < 20; i++) {
byte[] b = new byte[32];
r.nextBytes(b);
secrets.add(crypto.deriveNextSecret(b, 0));
}
for (int i = 0; i < 20; i++) {
byte[] secretI = secrets.get(i);
for (int j = 0; j < 20; j++) {
byte[] secretJ = secrets.get(j);
assertEquals(i == j, Arrays.equals(secretI, secretJ));
}
}
public void testCurrentKeysMatchCurrentKeysOfContact() {
// Start in rotation period 123
TransportKeys kA = crypto.deriveTransportKeys(transportId, master,
123, true);
TransportKeys kB = crypto.deriveTransportKeys(transportId, master,
123, false);
// Alice's incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getCurrentIncomingKeys().getHeaderKey().getBytes(),
kB.getCurrentOutgoingKeys().getHeaderKey().getBytes());
// Alice's outgoing keys should equal Bob's incoming keys
assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(),
kB.getCurrentIncomingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getCurrentOutgoingKeys().getHeaderKey().getBytes(),
kB.getCurrentIncomingKeys().getHeaderKey().getBytes());
// Rotate into the future
kA = crypto.rotateTransportKeys(kA, 456);
kB = crypto.rotateTransportKeys(kB, 456);
// Alice's incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getCurrentIncomingKeys().getHeaderKey().getBytes(),
kB.getCurrentOutgoingKeys().getHeaderKey().getBytes());
// Alice's outgoing keys should equal Bob's incoming keys
assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(),
kB.getCurrentIncomingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getCurrentOutgoingKeys().getHeaderKey().getBytes(),
kB.getCurrentIncomingKeys().getHeaderKey().getBytes());
}
@Test
public void testStreamNumberAffectsDerivation() {
List<byte[]> secrets = new ArrayList<byte[]>();
for (int i = 0; i < 20; i++)
secrets.add(crypto.deriveNextSecret(secret, i));
for (int i = 0; i < 20; i++) {
byte[] secretI = secrets.get(i);
for (int j = 0; j < 20; j++) {
byte[] secretJ = secrets.get(j);
assertEquals(i == j, Arrays.equals(secretI, secretJ));
public void testPreviousKeysMatchPreviousKeysOfContact() {
// Start in rotation period 123
TransportKeys kA = crypto.deriveTransportKeys(transportId, master,
123, true);
TransportKeys kB = crypto.deriveTransportKeys(transportId, master,
123, false);
// Compare Alice's previous keys in period 456 with Bob's current keys
// in period 455
kA = crypto.rotateTransportKeys(kA, 456);
kB = crypto.rotateTransportKeys(kB, 455);
// Alice's previous incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getPreviousIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getPreviousIncomingKeys().getHeaderKey().getBytes(),
kB.getCurrentOutgoingKeys().getHeaderKey().getBytes());
// Compare Alice's current keys in period 456 with Bob's previous keys
// in period 457
kB = crypto.rotateTransportKeys(kB, 457);
// Alice's outgoing keys should equal Bob's previous incoming keys
assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(),
kB.getPreviousIncomingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getCurrentOutgoingKeys().getHeaderKey().getBytes(),
kB.getPreviousIncomingKeys().getHeaderKey().getBytes());
}
@Test
public void testNextKeysMatchNextKeysOfContact() {
// Start in rotation period 123
TransportKeys kA = crypto.deriveTransportKeys(transportId, master,
123, true);
TransportKeys kB = crypto.deriveTransportKeys(transportId, master,
123, false);
// Compare Alice's current keys in period 456 with Bob's next keys in
// period 455
kA = crypto.rotateTransportKeys(kA, 456);
kB = crypto.rotateTransportKeys(kB, 455);
// Alice's outgoing keys should equal Bob's next incoming keys
assertArrayEquals(kA.getCurrentOutgoingKeys().getTagKey().getBytes(),
kB.getNextIncomingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getCurrentOutgoingKeys().getHeaderKey().getBytes(),
kB.getNextIncomingKeys().getHeaderKey().getBytes());
// Compare Alice's next keys in period 456 with Bob's current keys
// in period 457
kB = crypto.rotateTransportKeys(kB, 457);
// Alice's next incoming keys should equal Bob's outgoing keys
assertArrayEquals(kA.getNextIncomingKeys().getTagKey().getBytes(),
kB.getCurrentOutgoingKeys().getTagKey().getBytes());
assertArrayEquals(kA.getNextIncomingKeys().getHeaderKey().getBytes(),
kB.getCurrentOutgoingKeys().getHeaderKey().getBytes());
}
@Test
public void testMasterKeyAffectsOutput() {
SecretKey master1 = TestUtils.createSecretKey();
assertFalse(Arrays.equals(master.getBytes(), master1.getBytes()));
TransportKeys k = crypto.deriveTransportKeys(transportId, master,
123, true);
TransportKeys k1 = crypto.deriveTransportKeys(transportId, master1,
123, true);
assertAllDifferent(k, k1);
}
@Test
public void testTransportIdAffectsOutput() {
TransportId transportId1 = new TransportId("id1");
assertFalse(transportId.getString().equals(transportId1.getString()));
TransportKeys k = crypto.deriveTransportKeys(transportId, master,
123, true);
TransportKeys k1 = crypto.deriveTransportKeys(transportId1, master,
123, true);
assertAllDifferent(k, k1);
}
private void assertAllDifferent(TransportKeys... transportKeys) {
List<SecretKey> secretKeys = new ArrayList<SecretKey>();
for (TransportKeys k : transportKeys) {
secretKeys.add(k.getPreviousIncomingKeys().getTagKey());
secretKeys.add(k.getPreviousIncomingKeys().getHeaderKey());
secretKeys.add(k.getCurrentIncomingKeys().getTagKey());
secretKeys.add(k.getCurrentIncomingKeys().getHeaderKey());
secretKeys.add(k.getNextIncomingKeys().getTagKey());
secretKeys.add(k.getNextIncomingKeys().getHeaderKey());
secretKeys.add(k.getCurrentOutgoingKeys().getTagKey());
secretKeys.add(k.getCurrentOutgoingKeys().getHeaderKey());
}
assertAllDifferent(secretKeys);
}
private void assertAllDifferent(List<SecretKey> keys) {
for (SecretKey ki : keys) {
for (SecretKey kj : keys) {
if (ki == kj) assertArrayEquals(ki.getBytes(), kj.getBytes());
else assertFalse(Arrays.equals(ki.getBytes(), kj.getBytes()));
}
}
}

View File

@@ -11,6 +11,7 @@ import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportConfig;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.NoSuchContactException;
import org.briarproject.api.db.NoSuchLocalAuthorException;
@@ -45,8 +46,9 @@ import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.api.transport.IncomingKeys;
import org.briarproject.api.transport.OutgoingKeys;
import org.briarproject.api.transport.TransportKeys;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.junit.Test;
@@ -84,8 +86,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
protected final int maxLatency;
protected final ContactId contactId;
protected final Contact contact;
protected final Endpoint endpoint;
protected final TemporarySecret temporarySecret;
public DatabaseComponentTest() {
groupId = new GroupId(TestUtils.getRandomId());
@@ -112,9 +112,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
maxLatency = Integer.MAX_VALUE;
contactId = new ContactId(234);
contact = new Contact(contactId, author, localAuthorId);
endpoint = new Endpoint(contactId, transportId, 123, true);
temporarySecret = new TemporarySecret(contactId, transportId, 123,
false, 234, new byte[32], 345, 456, new byte[4]);
}
protected abstract <T> DatabaseComponent createDatabaseComponent(
@@ -157,7 +154,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
oneOf(eventBus).broadcast(with(any(ContactAddedEvent.class)));
// getContacts()
oneOf(database).getContacts(txn);
will(returnValue(Arrays.asList(contact)));
will(returnValue(Collections.singletonList(contact)));
// getRemoteProperties()
oneOf(database).getRemoteProperties(txn, transportId);
will(returnValue(Collections.emptyMap()));
@@ -177,7 +174,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
will(returnValue(Collections.emptyList()));
// getGroups()
oneOf(database).getGroups(txn);
will(returnValue(Arrays.asList(groupId)));
will(returnValue(Collections.singletonList(group)));
// removeGroup()
oneOf(database).containsGroup(txn, groupId);
will(returnValue(true));
@@ -213,13 +210,13 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
assertFalse(db.open());
db.addLocalAuthor(localAuthor);
assertEquals(contactId, db.addContact(author, localAuthorId));
assertEquals(Arrays.asList(contact), db.getContacts());
assertEquals(Collections.singletonList(contact), db.getContacts());
assertEquals(Collections.emptyMap(),
db.getRemoteProperties(transportId));
db.addGroup(group); // First time - listeners called
db.addGroup(group); // Second time - not called
assertEquals(Collections.emptyList(), db.getMessageHeaders(groupId));
assertEquals(Arrays.asList(groupId), db.getGroups());
assertEquals(Collections.singletonList(group), db.getGroups());
db.removeGroup(group);
db.removeContact(contactId);
db.removeLocalAuthor(localAuthorId);
@@ -297,9 +294,9 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
oneOf(database).addMessage(txn, message, true);
oneOf(database).setReadFlag(txn, messageId, true);
oneOf(database).getVisibility(txn, groupId);
will(returnValue(Arrays.asList(contactId)));
will(returnValue(Collections.singletonList(contactId)));
oneOf(database).getContactIds(txn);
will(returnValue(Arrays.asList(contactId)));
will(returnValue(Collections.singletonList(contactId)));
oneOf(database).removeOfferedMessage(txn, contactId, messageId);
will(returnValue(false));
oneOf(database).addStatus(txn, contactId, messageId, false, false);
@@ -336,7 +333,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
eventBus, shutdown);
try {
db.addEndpoint(endpoint);
db.addTransportKeys(contactId, createTransportKeys());
fail();
} catch (NoSuchContactException expected) {}
@@ -401,7 +398,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
} catch (NoSuchContactException expected) {}
try {
Ack a = new Ack(Arrays.asList(messageId));
Ack a = new Ack(Collections.singletonList(messageId));
db.receiveAck(contactId, a);
fail();
} catch (NoSuchContactException expected) {}
@@ -412,7 +409,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
} catch (NoSuchContactException expected) {}
try {
Offer o = new Offer(Arrays.asList(messageId));
Offer o = new Offer(Collections.singletonList(messageId));
db.receiveOffer(contactId, o);
fail();
} catch (NoSuchContactException expected) {}
@@ -594,7 +591,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
// Check whether the transport is in the DB (which it's not)
exactly(8).of(database).startTransaction();
will(returnValue(txn));
exactly(3).of(database).containsContact(txn, contactId);
exactly(2).of(database).containsContact(txn, contactId);
will(returnValue(true));
exactly(8).of(database).containsTransport(txn, transportId);
will(returnValue(false));
@@ -606,11 +603,6 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
db.addLocalAuthor(localAuthor);
assertEquals(contactId, db.addContact(author, localAuthorId));
try {
db.addEndpoint(endpoint);
fail();
} catch (NoSuchTransportException expected) {}
try {
db.getConfig(transportId);
fail();
@@ -621,6 +613,11 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
fail();
} catch (NoSuchTransportException expected) {}
try {
db.getTransportKeys(transportId);
fail();
} catch (NoSuchTransportException expected) {}
try {
db.mergeConfig(transportId, new TransportConfig());
fail();
@@ -909,7 +906,8 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).getSubscriptionUpdate(txn, contactId, maxLatency);
will(returnValue(new SubscriptionUpdate(Arrays.asList(group), 1)));
will(returnValue(new SubscriptionUpdate(
Collections.singletonList(group), 1)));
oneOf(database).commitTransaction(txn);
}});
DatabaseComponent db = createDatabaseComponent(database, cleaner,
@@ -917,7 +915,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
SubscriptionUpdate u = db.generateSubscriptionUpdate(contactId,
maxLatency);
assertEquals(Arrays.asList(group), u.getGroups());
assertEquals(Collections.singletonList(group), u.getGroups());
assertEquals(1, u.getVersion());
context.assertIsSatisfied();
@@ -962,8 +960,8 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).getTransportUpdates(txn, contactId, maxLatency);
will(returnValue(Arrays.asList(new TransportUpdate(transportId,
transportProperties, 1))));
will(returnValue(Collections.singletonList(new TransportUpdate(
transportId, transportProperties, 1))));
oneOf(database).commitTransaction(txn);
}});
DatabaseComponent db = createDatabaseComponent(database, cleaner,
@@ -1003,7 +1001,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
DatabaseComponent db = createDatabaseComponent(database, cleaner,
eventBus, shutdown);
db.receiveAck(contactId, new Ack(Arrays.asList(messageId)));
db.receiveAck(contactId, new Ack(Collections.singletonList(messageId)));
context.assertIsSatisfied();
}
@@ -1027,9 +1025,9 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
will(returnValue(true));
oneOf(database).addMessage(txn, message, false);
oneOf(database).getVisibility(txn, groupId);
will(returnValue(Arrays.asList(contactId)));
will(returnValue(Collections.singletonList(contactId)));
oneOf(database).getContactIds(txn);
will(returnValue(Arrays.asList(contactId)));
will(returnValue(Collections.singletonList(contactId)));
oneOf(database).removeOfferedMessage(txn, contactId, messageId);
will(returnValue(false));
oneOf(database).addStatus(txn, contactId, messageId, false, true);
@@ -1176,7 +1174,8 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
DatabaseComponent db = createDatabaseComponent(database, cleaner,
eventBus, shutdown);
db.receiveRequest(contactId, new Request(Arrays.asList(messageId)));
db.receiveRequest(contactId, new Request(Collections.singletonList(
messageId)));
context.assertIsSatisfied();
}
@@ -1244,13 +1243,15 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
will(returnValue(txn));
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).setGroups(txn, contactId, Arrays.asList(group), 1);
oneOf(database).setGroups(txn, contactId,
Collections.singletonList(group), 1);
oneOf(database).commitTransaction(txn);
}});
DatabaseComponent db = createDatabaseComponent(database, cleaner,
eventBus, shutdown);
SubscriptionUpdate u = new SubscriptionUpdate(Arrays.asList(group), 1);
SubscriptionUpdate u = new SubscriptionUpdate(
Collections.singletonList(group), 1);
db.receiveSubscriptionUpdate(contactId, u);
context.assertIsSatisfied();
@@ -1398,7 +1399,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
DatabaseComponent db = createDatabaseComponent(database, cleaner,
eventBus, shutdown);
db.setVisibility(groupId, Arrays.asList(contactId));
db.setVisibility(groupId, Collections.singletonList(contactId));
context.assertIsSatisfied();
}
@@ -1467,7 +1468,7 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
will(returnValue(true));
oneOf(database).setVisibleToAll(txn, groupId, true);
oneOf(database).getVisibility(txn, groupId);
will(returnValue(Arrays.asList(contactId)));
will(returnValue(Collections.singletonList(contactId)));
oneOf(database).getContactIds(txn);
will(returnValue(both));
oneOf(database).addVisibility(txn, contactId1, groupId);
@@ -1478,14 +1479,15 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
DatabaseComponent db = createDatabaseComponent(database, cleaner,
eventBus, shutdown);
db.setVisibility(groupId, Arrays.asList(contactId));
db.setVisibility(groupId, Collections.singletonList(contactId));
db.setVisibleToAll(groupId, true);
context.assertIsSatisfied();
}
@Test
public void testTemporarySecrets() throws Exception {
public void testTransportKeys() throws Exception {
final TransportKeys keys = createTransportKeys();
Mockery context = new Mockery();
@SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class);
@@ -1493,28 +1495,52 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
final ShutdownManager shutdown = context.mock(ShutdownManager.class);
final EventBus eventBus = context.mock(EventBus.class);
context.checking(new Expectations() {{
// addSecrets()
// updateTransportKeys()
oneOf(database).startTransaction();
will(returnValue(txn));
oneOf(database).containsContact(txn, contactId);
will(returnValue(true));
oneOf(database).containsTransport(txn, transportId);
will(returnValue(true));
oneOf(database).addSecrets(txn, Arrays.asList(temporarySecret));
oneOf(database).updateTransportKeys(txn,
Collections.singletonMap(contactId, keys));
oneOf(database).commitTransaction(txn);
// getSecrets()
// getTransportKeys()
oneOf(database).startTransaction();
will(returnValue(txn));
oneOf(database).getSecrets(txn);
will(returnValue(Arrays.asList(temporarySecret)));
oneOf(database).containsTransport(txn, transportId);
will(returnValue(true));
oneOf(database).getTransportKeys(txn, transportId);
will(returnValue(Collections.singletonMap(contactId, keys)));
oneOf(database).commitTransaction(txn);
}});
DatabaseComponent db = createDatabaseComponent(database, cleaner,
eventBus, shutdown);
db.addSecrets(Arrays.asList(temporarySecret));
assertEquals(Arrays.asList(temporarySecret), db.getSecrets());
db.updateTransportKeys(Collections.singletonMap(contactId, keys));
assertEquals(Collections.singletonMap(contactId, keys),
db.getTransportKeys(transportId));
context.assertIsSatisfied();
}
private TransportKeys createTransportKeys() {
SecretKey inPrevTagKey = TestUtils.createSecretKey();
SecretKey inPrevHeaderKey = TestUtils.createSecretKey();
IncomingKeys inPrev = new IncomingKeys(inPrevTagKey, inPrevHeaderKey,
1, 123, new byte[4]);
SecretKey inCurrTagKey = TestUtils.createSecretKey();
SecretKey inCurrHeaderKey = TestUtils.createSecretKey();
IncomingKeys inCurr = new IncomingKeys(inCurrTagKey, inCurrHeaderKey,
2, 234, new byte[4]);
SecretKey inNextTagKey = TestUtils.createSecretKey();
SecretKey inNextHeaderKey = TestUtils.createSecretKey();
IncomingKeys inNext = new IncomingKeys(inNextTagKey, inNextHeaderKey,
3, 345, new byte[4]);
SecretKey outCurrTagKey = TestUtils.createSecretKey();
SecretKey outCurrHeaderKey = TestUtils.createSecretKey();
OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey,
2, 456);
return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr);
}
}

View File

@@ -11,14 +11,16 @@ import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportConfig;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.db.DbException;
import org.briarproject.api.db.MessageHeader;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupId;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.api.transport.IncomingKeys;
import org.briarproject.api.transport.OutgoingKeys;
import org.briarproject.api.transport.TransportKeys;
import org.briarproject.system.SystemClock;
import org.junit.After;
import org.junit.Before;
@@ -34,6 +36,7 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -59,7 +62,6 @@ public class H2DatabaseTest extends BriarTestCase {
private final Random random = new Random();
private final GroupId groupId;
private final Group group;
private final AuthorId authorId;
private final Author author;
private final AuthorId localAuthorId;
private final LocalAuthor localAuthor;
@@ -75,7 +77,7 @@ public class H2DatabaseTest extends BriarTestCase {
public H2DatabaseTest() throws Exception {
groupId = new GroupId(TestUtils.getRandomId());
group = new Group(groupId, "Group", new byte[GROUP_SALT_LENGTH]);
authorId = new AuthorId(TestUtils.getRandomId());
AuthorId authorId = new AuthorId(TestUtils.getRandomId());
author = new Author(authorId, "Alice", new byte[MAX_PUBLIC_KEY_LENGTH]);
localAuthorId = new AuthorId(TestUtils.getRandomId());
localAuthor = new LocalAuthor(localAuthorId, "Bob",
@@ -171,7 +173,7 @@ public class H2DatabaseTest extends BriarTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.addVisibility(txn, contactId, groupId);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
db.addMessage(txn, message, true);
// The message has no status yet, so it should not be sendable
@@ -216,7 +218,7 @@ public class H2DatabaseTest extends BriarTestCase {
assertTrue(ids.isEmpty());
// The contact subscribing should make the message sendable
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertFalse(ids.isEmpty());
Iterator<MessageId> it = ids.iterator();
@@ -243,7 +245,7 @@ public class H2DatabaseTest extends BriarTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.addVisibility(txn, contactId, groupId);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
db.addMessage(txn, message, true);
db.addStatus(txn, contactId, messageId, false, false);
@@ -273,7 +275,7 @@ public class H2DatabaseTest extends BriarTestCase {
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
db.addMessage(txn, message, true);
db.addStatus(txn, contactId, messageId, false, false);
@@ -305,7 +307,7 @@ public class H2DatabaseTest extends BriarTestCase {
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
// Add some messages to ack
MessageId messageId1 = new MessageId(TestUtils.getRandomId());
@@ -342,7 +344,7 @@ public class H2DatabaseTest extends BriarTestCase {
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
// Receive the same message twice
db.addMessage(txn, message, true);
@@ -352,10 +354,10 @@ public class H2DatabaseTest extends BriarTestCase {
// The message ID should only be returned once
Collection<MessageId> ids = db.getMessagesToAck(txn, contactId, 1234);
assertEquals(Arrays.asList(messageId), ids);
assertEquals(Collections.singletonList(messageId), ids);
// Remove the message ID
db.lowerAckFlag(txn, contactId, Arrays.asList(messageId));
db.lowerAckFlag(txn, contactId, Collections.singletonList(messageId));
// The message ID should have been removed
assertEquals(Collections.emptyList(), db.getMessagesToAck(txn,
@@ -375,7 +377,7 @@ public class H2DatabaseTest extends BriarTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.addVisibility(txn, contactId, groupId);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
db.addMessage(txn, message, true);
db.addStatus(txn, contactId, messageId, false, false);
@@ -674,7 +676,7 @@ public class H2DatabaseTest extends BriarTestCase {
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.addVisibility(txn, contactId, groupId);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
// The message is not in the database
assertFalse(db.containsVisibleMessage(txn, contactId, messageId));
@@ -692,7 +694,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Add a contact with a subscription
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
// There's no local subscription for the group
assertFalse(db.containsVisibleMessage(txn, contactId, messageId));
@@ -711,7 +713,7 @@ public class H2DatabaseTest extends BriarTestCase {
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addGroup(txn, group);
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
db.addMessage(txn, message, true);
db.addStatus(txn, contactId, messageId, false, false);
@@ -737,7 +739,8 @@ public class H2DatabaseTest extends BriarTestCase {
// Make the group visible to the contact
db.addVisibility(txn, contactId, groupId);
assertEquals(Arrays.asList(contactId), db.getVisibility(txn, groupId));
assertEquals(Collections.singletonList(contactId),
db.getVisibility(txn, groupId));
// Make the group invisible again
db.removeVisibility(txn, contactId, groupId);
@@ -1111,183 +1114,97 @@ public class H2DatabaseTest extends BriarTestCase {
}
@Test
public void testTemporarySecrets() throws Exception {
// Create an endpoint and four consecutive temporary secrets
long epoch = 123;
int latency = 234;
boolean alice = false;
long outgoing1 = 345, centre1 = 456;
long outgoing2 = 567, centre2 = 678;
long outgoing3 = 789, centre3 = 890;
long outgoing4 = 901, centre4 = 123;
Endpoint ep = new Endpoint(contactId, transportId, epoch, alice);
Random random = new Random();
byte[] secret1 = new byte[32], bitmap1 = new byte[4];
random.nextBytes(secret1);
random.nextBytes(bitmap1);
TemporarySecret s1 = new TemporarySecret(contactId, transportId, epoch,
alice, 0, secret1, outgoing1, centre1, bitmap1);
byte[] secret2 = new byte[32], bitmap2 = new byte[4];
random.nextBytes(secret2);
random.nextBytes(bitmap2);
TemporarySecret s2 = new TemporarySecret(contactId, transportId, epoch,
alice, 1, secret2, outgoing2, centre2, bitmap2);
byte[] secret3 = new byte[32], bitmap3 = new byte[4];
random.nextBytes(secret3);
random.nextBytes(bitmap3);
TemporarySecret s3 = new TemporarySecret(contactId, transportId, epoch,
alice, 2, secret3, outgoing3, centre3, bitmap3);
byte[] secret4 = new byte[32], bitmap4 = new byte[4];
random.nextBytes(secret4);
random.nextBytes(bitmap4);
TemporarySecret s4 = new TemporarySecret(contactId, transportId, epoch,
alice, 3, secret4, outgoing4, centre4, bitmap4);
public void testTransportKeys() throws Exception {
TransportKeys keys = createTransportKeys();
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Initially there should be no secrets in the database
assertEquals(Collections.emptyList(), db.getSecrets(txn));
// Initially there should be no transport keys in the database
assertEquals(Collections.emptyMap(),
db.getTransportKeys(txn, transportId));
// Add the contact, the transport, the endpoint and the first three
// secrets (periods 0, 1 and 2)
// Add the contact, the transport and the transport keys
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addTransport(txn, transportId, latency);
db.addEndpoint(txn, ep);
db.addSecrets(txn, Arrays.asList(s1, s2, s3));
db.addTransport(txn, transportId, 123);
db.addTransportKeys(txn, contactId, keys);
// Retrieve the first three secrets
Collection<TemporarySecret> secrets = db.getSecrets(txn);
assertEquals(3, secrets.size());
boolean foundFirst = false, foundSecond = false, foundThird = false;
for (TemporarySecret s : secrets) {
assertEquals(contactId, s.getContactId());
assertEquals(transportId, s.getTransportId());
assertEquals(epoch, s.getEpoch());
assertEquals(alice, s.getAlice());
if (s.getPeriod() == 0) {
assertArrayEquals(secret1, s.getSecret());
assertEquals(outgoing1, s.getOutgoingStreamCounter());
assertEquals(centre1, s.getWindowCentre());
assertArrayEquals(bitmap1, s.getWindowBitmap());
foundFirst = true;
} else if (s.getPeriod() == 1) {
assertArrayEquals(secret2, s.getSecret());
assertEquals(outgoing2, s.getOutgoingStreamCounter());
assertEquals(centre2, s.getWindowCentre());
assertArrayEquals(bitmap2, s.getWindowBitmap());
foundSecond = true;
} else if (s.getPeriod() == 2) {
assertArrayEquals(secret3, s.getSecret());
assertEquals(outgoing3, s.getOutgoingStreamCounter());
assertEquals(centre3, s.getWindowCentre());
assertArrayEquals(bitmap3, s.getWindowBitmap());
foundThird = true;
} else {
fail();
}
}
assertTrue(foundFirst);
assertTrue(foundSecond);
assertTrue(foundThird);
// Retrieve the transport keys
Map<ContactId, TransportKeys> newKeys =
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e =
newKeys.entrySet().iterator().next();
assertEquals(contactId, e.getKey());
TransportKeys k = e.getValue();
assertEquals(transportId, k.getTransportId());
assertKeysEquals(keys.getPreviousIncomingKeys(),
k.getPreviousIncomingKeys());
assertKeysEquals(keys.getCurrentIncomingKeys(),
k.getCurrentIncomingKeys());
assertKeysEquals(keys.getNextIncomingKeys(),
k.getNextIncomingKeys());
assertKeysEquals(keys.getCurrentOutgoingKeys(),
k.getCurrentOutgoingKeys());
// Adding the fourth secret (period 3) should delete the first
db.addSecrets(txn, Arrays.asList(s4));
secrets = db.getSecrets(txn);
assertEquals(3, secrets.size());
foundSecond = foundThird = false;
boolean foundFourth = false;
for (TemporarySecret s : secrets) {
assertEquals(contactId, s.getContactId());
assertEquals(transportId, s.getTransportId());
assertEquals(epoch, s.getEpoch());
assertEquals(alice, s.getAlice());
if (s.getPeriod() == 1) {
assertArrayEquals(secret2, s.getSecret());
assertEquals(outgoing2, s.getOutgoingStreamCounter());
assertEquals(centre2, s.getWindowCentre());
assertArrayEquals(bitmap2, s.getWindowBitmap());
foundSecond = true;
} else if (s.getPeriod() == 2) {
assertArrayEquals(secret3, s.getSecret());
assertEquals(outgoing3, s.getOutgoingStreamCounter());
assertEquals(centre3, s.getWindowCentre());
assertArrayEquals(bitmap3, s.getWindowBitmap());
foundThird = true;
} else if (s.getPeriod() == 3) {
assertArrayEquals(secret4, s.getSecret());
assertEquals(outgoing4, s.getOutgoingStreamCounter());
assertEquals(centre4, s.getWindowCentre());
assertArrayEquals(bitmap4, s.getWindowBitmap());
foundFourth = true;
} else {
fail();
}
}
assertTrue(foundSecond);
assertTrue(foundThird);
assertTrue(foundFourth);
// Removing the contact should remove the secrets
// Removing the contact should remove the transport keys
db.removeContact(txn, contactId);
assertEquals(Collections.emptyList(), db.getSecrets(txn));
assertEquals(Collections.emptyMap(),
db.getTransportKeys(txn, transportId));
db.commitTransaction(txn);
db.close();
}
private void assertKeysEquals(IncomingKeys expected, IncomingKeys actual) {
assertArrayEquals(expected.getTagKey().getBytes(),
actual.getTagKey().getBytes());
assertArrayEquals(expected.getHeaderKey().getBytes(),
actual.getHeaderKey().getBytes());
assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod());
assertEquals(expected.getWindowBase(), actual.getWindowBase());
assertArrayEquals(expected.getWindowBitmap(), actual.getWindowBitmap());
}
private void assertKeysEquals(OutgoingKeys expected, OutgoingKeys actual) {
assertArrayEquals(expected.getTagKey().getBytes(),
actual.getTagKey().getBytes());
assertArrayEquals(expected.getHeaderKey().getBytes(),
actual.getHeaderKey().getBytes());
assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod());
assertEquals(expected.getStreamCounter(), actual.getStreamCounter());
}
@Test
public void testIncrementStreamCounter() throws Exception {
// Create an endpoint and a temporary secret
long epoch = 123;
int latency = 234;
boolean alice = false;
long period = 345, outgoing = 456, centre = 567;
Endpoint ep = new Endpoint(contactId, transportId, epoch, alice);
Random random = new Random();
byte[] secret = new byte[32], bitmap = new byte[4];
random.nextBytes(secret);
TemporarySecret s = new TemporarySecret(contactId, transportId, epoch,
alice, period, secret, outgoing, centre, bitmap);
TransportKeys keys = createTransportKeys();
long rotationPeriod = keys.getCurrentOutgoingKeys().getRotationPeriod();
long streamCounter = keys.getCurrentOutgoingKeys().getStreamCounter();
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Add the contact, transport, endpoint and temporary secret
// Add the contact, transport and transport keys
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addTransport(txn, transportId, latency);
db.addEndpoint(txn, ep);
db.addSecrets(txn, Arrays.asList(s));
db.addTransport(txn, transportId, 123);
db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys));
// Retrieve the secret
Collection<TemporarySecret> secrets = db.getSecrets(txn);
assertEquals(1, secrets.size());
s = secrets.iterator().next();
assertEquals(contactId, s.getContactId());
assertEquals(transportId, s.getTransportId());
assertEquals(period, s.getPeriod());
assertArrayEquals(secret, s.getSecret());
assertEquals(outgoing, s.getOutgoingStreamCounter());
assertEquals(centre, s.getWindowCentre());
assertArrayEquals(bitmap, s.getWindowBitmap());
// Increment the stream counter twice and retrieve the secret again
assertEquals(outgoing, db.incrementStreamCounter(txn,
s.getContactId(), s.getTransportId(), s.getPeriod()));
assertEquals(outgoing + 1, db.incrementStreamCounter(txn,
s.getContactId(), s.getTransportId(), s.getPeriod()));
secrets = db.getSecrets(txn);
assertEquals(1, secrets.size());
s = secrets.iterator().next();
assertEquals(contactId, s.getContactId());
assertEquals(transportId, s.getTransportId());
assertEquals(period, s.getPeriod());
assertArrayEquals(secret, s.getSecret());
assertEquals(outgoing + 2, s.getOutgoingStreamCounter());
assertEquals(centre, s.getWindowCentre());
assertArrayEquals(bitmap, s.getWindowBitmap());
// Increment the stream counter twice and retrieve the transport keys
db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod);
db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod);
Map<ContactId, TransportKeys> newKeys =
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e =
newKeys.entrySet().iterator().next();
assertEquals(contactId, e.getKey());
TransportKeys k = e.getValue();
assertEquals(transportId, k.getTransportId());
OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
assertEquals(rotationPeriod, outCurr.getRotationPeriod());
assertEquals(streamCounter + 2, outCurr.getStreamCounter());
db.commitTransaction(txn);
db.close();
@@ -1295,123 +1212,36 @@ public class H2DatabaseTest extends BriarTestCase {
@Test
public void testSetReorderingWindow() throws Exception {
// Create an endpoint and a temporary secret
long epoch = 123;
int latency = 234;
boolean alice = false;
long period = 345, outgoing = 456, centre = 567;
Endpoint ep = new Endpoint(contactId, transportId, epoch, alice);
Random random = new Random();
byte[] secret = new byte[32], bitmap = new byte[4];
random.nextBytes(secret);
TemporarySecret s = new TemporarySecret(contactId, transportId, epoch,
alice, period, secret, outgoing, centre, bitmap);
TransportKeys keys = createTransportKeys();
long rotationPeriod = keys.getCurrentIncomingKeys().getRotationPeriod();
long base = keys.getCurrentIncomingKeys().getWindowBase();
byte[] bitmap = keys.getCurrentIncomingKeys().getWindowBitmap();
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Add the contact, transport, endpoint and temporary secret
// Add the contact, transport and transport keys
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addTransport(txn, transportId, latency);
db.addEndpoint(txn, ep);
db.addSecrets(txn, Arrays.asList(s));
db.addTransport(txn, transportId, 123);
db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys));
// Retrieve the secret
Collection<TemporarySecret> secrets = db.getSecrets(txn);
assertEquals(1, secrets.size());
s = secrets.iterator().next();
assertEquals(contactId, s.getContactId());
assertEquals(transportId, s.getTransportId());
assertEquals(period, s.getPeriod());
assertArrayEquals(secret, s.getSecret());
assertEquals(outgoing, s.getOutgoingStreamCounter());
assertEquals(centre, s.getWindowCentre());
assertArrayEquals(bitmap, s.getWindowBitmap());
// Update the reordering window and retrieve the secret again
// Update the reordering window and retrieve the transport keys
random.nextBytes(bitmap);
db.setReorderingWindow(txn, contactId, transportId, period, centre,
bitmap);
secrets = db.getSecrets(txn);
assertEquals(1, secrets.size());
s = secrets.iterator().next();
assertEquals(contactId, s.getContactId());
assertEquals(transportId, s.getTransportId());
assertEquals(period, s.getPeriod());
assertArrayEquals(secret, s.getSecret());
assertEquals(outgoing, s.getOutgoingStreamCounter());
assertEquals(centre, s.getWindowCentre());
assertArrayEquals(bitmap, s.getWindowBitmap());
// Updating a nonexistent window should not throw an exception
db.setReorderingWindow(txn, contactId, transportId, period + 1, 1,
bitmap);
// The nonexistent window should not have been created
secrets = db.getSecrets(txn);
assertEquals(1, secrets.size());
s = secrets.iterator().next();
assertEquals(contactId, s.getContactId());
assertEquals(transportId, s.getTransportId());
assertEquals(period, s.getPeriod());
assertArrayEquals(secret, s.getSecret());
assertEquals(outgoing, s.getOutgoingStreamCounter());
assertEquals(centre, s.getWindowCentre());
assertArrayEquals(bitmap, s.getWindowBitmap());
db.commitTransaction(txn);
db.close();
}
@Test
public void testEndpoints() throws Exception {
// Create some endpoints
long epoch1 = 123, epoch2 = 234;
int latency1 = 345, latency2 = 456;
boolean alice1 = true, alice2 = false;
TransportId transportId1 = new TransportId("bar");
TransportId transportId2 = new TransportId("baz");
Endpoint ep1 = new Endpoint(contactId, transportId1, epoch1, alice1);
Endpoint ep2 = new Endpoint(contactId, transportId2, epoch2, alice2);
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Initially there should be no endpoints in the database
assertEquals(Collections.emptyList(), db.getEndpoints(txn));
// Add the contact, the transports and the endpoints
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.addTransport(txn, transportId1, latency1);
db.addTransport(txn, transportId2, latency2);
db.addEndpoint(txn, ep1);
db.addEndpoint(txn, ep2);
// Retrieve the endpoints
Collection<Endpoint> endpoints = db.getEndpoints(txn);
assertEquals(2, endpoints.size());
boolean foundFirst = false, foundSecond = false;
for (Endpoint ep : endpoints) {
assertEquals(contactId, ep.getContactId());
if (ep.getTransportId().equals(transportId1)) {
assertEquals(epoch1, ep.getEpoch());
assertEquals(alice1, ep.getAlice());
foundFirst = true;
} else if (ep.getTransportId().equals(transportId2)) {
assertEquals(epoch2, ep.getEpoch());
assertEquals(alice2, ep.getAlice());
foundSecond = true;
} else {
fail();
}
}
assertTrue(foundFirst);
assertTrue(foundSecond);
// Removing the contact should remove the endpoints
db.removeContact(txn, contactId);
assertEquals(Collections.emptyList(), db.getEndpoints(txn));
db.setReorderingWindow(txn, contactId, transportId, rotationPeriod,
base + 1, bitmap);
Map<ContactId, TransportKeys> newKeys =
db.getTransportKeys(txn, transportId);
assertEquals(1, newKeys.size());
Entry<ContactId, TransportKeys> e =
newKeys.entrySet().iterator().next();
assertEquals(contactId, e.getKey());
TransportKeys k = e.getValue();
assertEquals(transportId, k.getTransportId());
IncomingKeys inCurr = k.getCurrentIncomingKeys();
assertEquals(rotationPeriod, inCurr.getRotationPeriod());
assertEquals(base + 1, inCurr.getWindowBase());
assertArrayEquals(bitmap, inCurr.getWindowBitmap());
db.commitTransaction(txn);
db.close();
@@ -1431,27 +1261,30 @@ public class H2DatabaseTest extends BriarTestCase {
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
assertEquals(contactId1, db.addContact(txn, author1, localAuthorId));
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId1, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
db.setGroups(txn, contactId1, Collections.singletonList(group), 1);
// The group should be available
assertEquals(Collections.emptyList(), db.getGroups(txn));
assertEquals(Arrays.asList(group), db.getAvailableGroups(txn));
assertEquals(Collections.singletonList(group),
db.getAvailableGroups(txn));
// Subscribe to the group - it should no longer be available
db.addGroup(txn, group);
assertEquals(Arrays.asList(group), db.getGroups(txn));
assertEquals(Collections.singletonList(group), db.getGroups(txn));
assertEquals(Collections.emptyList(), db.getAvailableGroups(txn));
// Unsubscribe from the group - it should be available again
db.removeGroup(txn, groupId);
assertEquals(Collections.emptyList(), db.getGroups(txn));
assertEquals(Arrays.asList(group), db.getAvailableGroups(txn));
assertEquals(Collections.singletonList(group),
db.getAvailableGroups(txn));
// The first contact unsubscribes - it should still be available
db.setGroups(txn, contactId, Collections.<Group>emptyList(), 2);
assertEquals(Collections.emptyList(), db.getGroups(txn));
assertEquals(Arrays.asList(group), db.getAvailableGroups(txn));
assertEquals(Collections.singletonList(group),
db.getAvailableGroups(txn));
// The second contact unsubscribes - it should no longer be available
db.setGroups(txn, contactId1, Collections.<Group>emptyList(), 2);
@@ -1501,9 +1334,8 @@ public class H2DatabaseTest extends BriarTestCase {
db.getInboxMessageHeaders(txn, contactId));
// Add a message to the inbox group - the header should be returned
boolean local = true, seen = false;
db.addMessage(txn, message, local);
db.addStatus(txn, contactId, messageId, false, seen);
db.addMessage(txn, message, true);
db.addStatus(txn, contactId, messageId, false, false);
Collection<MessageHeader> headers =
db.getInboxMessageHeaders(txn, contactId);
assertEquals(1, headers.size());
@@ -1514,7 +1346,7 @@ public class H2DatabaseTest extends BriarTestCase {
assertEquals(localAuthor, header.getAuthor());
assertEquals(contentType, header.getContentType());
assertEquals(timestamp, header.getTimestamp());
assertEquals(local, header.isLocal());
assertEquals(true, header.isLocal());
assertEquals(false, header.isRead());
assertEquals(STORED, header.getStatus());
assertFalse(header.isRead());
@@ -1560,7 +1392,7 @@ public class H2DatabaseTest extends BriarTestCase {
// Add a contact who subscribes to a group
db.addLocalAuthor(txn, localAuthor);
assertEquals(contactId, db.addContact(txn, author, localAuthorId));
db.setGroups(txn, contactId, Arrays.asList(group), 1);
db.setGroups(txn, contactId, Collections.singletonList(group), 1);
// Subscribe to the group and make it visible to the contact
db.addGroup(txn, group);
@@ -1571,7 +1403,7 @@ public class H2DatabaseTest extends BriarTestCase {
db.addStatus(txn, contactId, messageId, false, false);
Collection<MessageId> sendable = db.getMessagesToSend(txn, contactId,
ONE_MEGABYTE);
assertEquals(Arrays.asList(messageId), sendable);
assertEquals(Collections.singletonList(messageId), sendable);
// Mark the message as seen - it should no longer be sendable
db.raiseSeenFlag(txn, contactId, messageId);
@@ -1584,9 +1416,9 @@ public class H2DatabaseTest extends BriarTestCase {
assertEquals(Collections.emptyList(), sendable);
// The contact resubscribes - the message should be sendable again
db.setGroups(txn, contactId, Arrays.asList(group), 3);
db.setGroups(txn, contactId, Collections.singletonList(group), 3);
sendable = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
assertEquals(Arrays.asList(messageId), sendable);
assertEquals(Collections.singletonList(messageId), sendable);
db.commitTransaction(txn);
db.close();
@@ -1616,6 +1448,26 @@ public class H2DatabaseTest extends BriarTestCase {
return db;
}
private TransportKeys createTransportKeys() {
SecretKey inPrevTagKey = TestUtils.createSecretKey();
SecretKey inPrevHeaderKey = TestUtils.createSecretKey();
IncomingKeys inPrev = new IncomingKeys(inPrevTagKey, inPrevHeaderKey,
1, 123, new byte[4]);
SecretKey inCurrTagKey = TestUtils.createSecretKey();
SecretKey inCurrHeaderKey = TestUtils.createSecretKey();
IncomingKeys inCurr = new IncomingKeys(inCurrTagKey, inCurrHeaderKey,
2, 234, new byte[4]);
SecretKey inNextTagKey = TestUtils.createSecretKey();
SecretKey inNextHeaderKey = TestUtils.createSecretKey();
IncomingKeys inNext = new IncomingKeys(inNextTagKey, inNextHeaderKey,
3, 345, new byte[4]);
SecretKey outCurrTagKey = TestUtils.createSecretKey();
SecretKey outCurrHeaderKey = TestUtils.createSecretKey();
OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey,
2, 456);
return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr);
}
@After
public void tearDown() {
TestUtils.deleteTestDirectory(testDir);

View File

@@ -13,7 +13,8 @@ import org.briarproject.api.AuthorId;
import org.briarproject.api.ContactId;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.event.Event;
import org.briarproject.api.event.EventBus;
@@ -29,11 +30,11 @@ import org.briarproject.api.messaging.PacketReader;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriter;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.KeyManager;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.StreamReaderFactory;
import org.briarproject.api.transport.StreamWriterFactory;
import org.briarproject.api.transport.TagRecogniser;
import org.briarproject.api.transport.TransportKeys;
import org.briarproject.crypto.CryptoModule;
import org.briarproject.data.DataModule;
import org.briarproject.db.DatabaseModule;
@@ -49,7 +50,7 @@ import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Random;
import java.util.Collections;
import static org.briarproject.api.AuthorConstants.MAX_PUBLIC_KEY_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.GROUP_SALT_LENGTH;
@@ -63,26 +64,18 @@ import static org.junit.Assert.assertTrue;
public class SimplexMessagingIntegrationTest extends BriarTestCase {
private static final int MAX_LATENCY = 2 * 60 * 1000; // 2 minutes
private static final int ROTATION_PERIOD =
MAX_CLOCK_DIFFERENCE + MAX_LATENCY;
private static final long ROTATION_PERIOD_LENGTH =
MAX_LATENCY + MAX_CLOCK_DIFFERENCE;
private final File testDir = TestUtils.getTestDirectory();
private final File aliceDir = new File(testDir, "alice");
private final File bobDir = new File(testDir, "bob");
private final TransportId transportId;
private final byte[] initialSecret;
private final long epoch;
private final TransportId transportId = new TransportId("id");
private final SecretKey master = TestUtils.createSecretKey();
private final long timestamp = System.currentTimeMillis();
private Injector alice, bob;
public SimplexMessagingIntegrationTest() throws Exception {
transportId = new TransportId("id");
// Create matching secrets for Alice and Bob
initialSecret = new byte[32];
new Random().nextBytes(initialSecret);
epoch = System.currentTimeMillis() - 2 * ROTATION_PERIOD;
}
@Before
public void setUp() {
testDir.mkdirs();
@@ -125,14 +118,17 @@ public class SimplexMessagingIntegrationTest extends BriarTestCase {
Group group = gf.createGroup("Group", new byte[GROUP_SALT_LENGTH]);
db.addGroup(group);
db.setInboxGroup(contactId, group);
// Add the transport and the endpoint
// Add the transport
db.addTransport(transportId, MAX_LATENCY);
Endpoint ep = new Endpoint(contactId, transportId, epoch, true);
db.addEndpoint(ep);
keyManager.endpointAdded(ep, MAX_LATENCY, initialSecret);
// Derive and store the transport keys
long rotationPeriod = timestamp / ROTATION_PERIOD_LENGTH;
CryptoComponent crypto = alice.getInstance(CryptoComponent.class);
TransportKeys keys = crypto.deriveTransportKeys(transportId, master,
rotationPeriod, true);
db.addTransportKeys(contactId, keys);
keyManager.contactAdded(contactId, Collections.singletonList(keys));
// Send Bob a message
String contentType = "text/plain";
long timestamp = System.currentTimeMillis();
byte[] body = "Hi Bob!".getBytes("UTF-8");
MessageFactory messageFactory = alice.getInstance(MessageFactory.class);
Message message = messageFactory.createAnonymousMessage(null, group,
@@ -166,7 +162,7 @@ public class SimplexMessagingIntegrationTest extends BriarTestCase {
return out.toByteArray();
}
private void read(byte[] b) throws Exception {
private void read(byte[] stream) throws Exception {
// Open Bob's database
DatabaseComponent db = bob.getInstance(DatabaseComponent.class);
assertFalse(db.open());
@@ -188,21 +184,24 @@ public class SimplexMessagingIntegrationTest extends BriarTestCase {
Group group = gf.createGroup("Group", new byte[GROUP_SALT_LENGTH]);
db.addGroup(group);
db.setInboxGroup(contactId, group);
// Add the transport and the endpoint
// Add the transport
db.addTransport(transportId, MAX_LATENCY);
Endpoint ep = new Endpoint(contactId, transportId, epoch, false);
db.addEndpoint(ep);
keyManager.endpointAdded(ep, MAX_LATENCY, initialSecret);
// Derive and store the transport keys
long rotationPeriod = timestamp / ROTATION_PERIOD_LENGTH;
CryptoComponent crypto = bob.getInstance(CryptoComponent.class);
TransportKeys keys = crypto.deriveTransportKeys(transportId, master,
rotationPeriod, false);
db.addTransportKeys(contactId, keys);
keyManager.contactAdded(contactId, Collections.singletonList(keys));
// Set up an event listener
MessageListener listener = new MessageListener();
bob.getInstance(EventBus.class).addListener(listener);
// Create a tag recogniser and recognise the tag
ByteArrayInputStream in = new ByteArrayInputStream(b);
TagRecogniser rec = bob.getInstance(TagRecogniser.class);
// Read and recognise the tag
ByteArrayInputStream in = new ByteArrayInputStream(stream);
byte[] tag = new byte[TAG_LENGTH];
int read = in.read(tag);
assertEquals(tag.length, read);
StreamContext ctx = rec.recogniseTag(transportId, tag);
StreamContext ctx = keyManager.recogniseTag(transportId, tag);
assertNotNull(ctx);
// Create a stream reader
StreamReaderFactory streamReaderFactory =

View File

@@ -1,600 +1,14 @@
package org.briarproject.transport;
import org.briarproject.BriarTestCase;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.event.EventBus;
import org.briarproject.api.event.EventListener;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.Timer;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.TagRecogniser;
import org.briarproject.api.transport.TemporarySecret;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.junit.Test;
import java.util.Arrays;
import java.util.Collections;
import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class KeyManagerImplTest extends BriarTestCase {
private static final long EPOCH = 1000L * 1000L * 1000L * 1000L;
private static final int MAX_LATENCY = 2 * 60 * 1000; // 2 minutes
private static final int ROTATION_PERIOD =
MAX_CLOCK_DIFFERENCE + MAX_LATENCY;
private final ContactId contactId;
private final TransportId transportId;
private final byte[] secret0, secret1, secret2, secret3, secret4;
private final byte[] initialSecret;
public KeyManagerImplTest() {
contactId = new ContactId(234);
transportId = new TransportId("id");
secret0 = new byte[32];
secret1 = new byte[32];
secret2 = new byte[32];
secret3 = new byte[32];
secret4 = new byte[32];
for (int i = 0; i < secret0.length; i++) secret0[i] = 1;
for (int i = 0; i < secret1.length; i++) secret1[i] = 2;
for (int i = 0; i < secret2.length; i++) secret2[i] = 3;
for (int i = 0; i < secret3.length; i++) secret3[i] = 4;
for (int i = 0; i < secret4.length; i++) secret4[i] = 5;
initialSecret = new byte[32];
for (int i = 0; i < initialSecret.length; i++) initialSecret[i] = 123;
}
@Test
public void testStartAndStop() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Collections.emptyList()));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.emptyMap()));
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testEndpointAdded() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The secrets for periods 0 - 2 should be derived
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Collections.emptyList()));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// endpointAdded() during rotation period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(crypto).deriveNextSecret(initialSecret, 0);
will(returnValue(secret0));
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(db).addSecrets(Arrays.asList(s0, s1, s2));
// The secrets for periods 0 - 2 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s0);
oneOf(tagRecogniser).addSecret(s1);
oneOf(tagRecogniser).addSecret(s2);
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.endpointAdded(ep, MAX_LATENCY, initialSecret);
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testEndpointAddedAndGetConnectionContext() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The secrets for periods 0 - 2 should be derived
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Collections.emptyList()));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// endpointAdded() during rotation period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(crypto).deriveNextSecret(initialSecret, 0);
will(returnValue(secret0));
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(db).addSecrets(Arrays.asList(s0, s1, s2));
// The secrets for periods 0 - 2 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s0);
oneOf(tagRecogniser).addSecret(s1);
oneOf(tagRecogniser).addSecret(s2);
// getConnectionContext()
oneOf(db).incrementStreamCounter(contactId, transportId, 1);
will(returnValue(0L));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.endpointAdded(ep, MAX_LATENCY, initialSecret);
StreamContext ctx =
keyManager.getStreamContext(contactId, transportId);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertArrayEquals(secret1, ctx.getSecret());
assertEquals(0, ctx.getStreamNumber());
assertEquals(true, ctx.getAlice());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAtEpoch() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the epoch, the start of period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
// The secrets for periods 0 - 2 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s0);
oneOf(tagRecogniser).addSecret(s1);
oneOf(tagRecogniser).addSecret(s2);
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAtStartOfPeriod2() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
// The secret for period 3 should be derived and stored
final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the start of period 2
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH + ROTATION_PERIOD));
// The secret for period 3 should be derived and stored
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(crypto).deriveNextSecret(secret2, 3);
will(returnValue(secret3));
oneOf(db).addSecrets(Arrays.asList(s3));
// The secrets for periods 1 - 3 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s1);
oneOf(tagRecogniser).addSecret(s2);
oneOf(tagRecogniser).addSecret(s3);
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAtEndOfPeriod3() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
// The secrets for periods 3 and 4 should be derived and stored
final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3);
final TemporarySecret s4 = new TemporarySecret(ep, 4, secret4);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the end of period 3
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH + 3 * ROTATION_PERIOD - 1));
// The secrets for periods 3 and 4 should be derived from secret 1
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(crypto).deriveNextSecret(secret2, 3);
will(returnValue(secret3));
oneOf(crypto).deriveNextSecret(secret3, 4);
will(returnValue(secret4));
// The new secrets should be stored
oneOf(db).addSecrets(Arrays.asList(s3, s4));
// The secrets for periods 2 - 4 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s2);
oneOf(tagRecogniser).addSecret(s3);
oneOf(tagRecogniser).addSecret(s4);
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAndRotateInSamePeriod() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the epoch, the start of period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
// The secrets for periods 0 - 2 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s0);
oneOf(tagRecogniser).addSecret(s1);
oneOf(tagRecogniser).addSecret(s2);
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// run() during period 1: the secrets should not be affected
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH + 1));
// getConnectionContext()
oneOf(db).incrementStreamCounter(contactId, transportId, 1);
will(returnValue(0L));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.run();
StreamContext ctx =
keyManager.getStreamContext(contactId, transportId);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertArrayEquals(secret1, ctx.getSecret());
assertEquals(0, ctx.getStreamNumber());
assertEquals(true, ctx.getAlice());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAndRotateInNextPeriod() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
// The secret for period 3 should be derived and stored
final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the epoch, the start of period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
// The secrets for periods 0 - 2 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s0);
oneOf(tagRecogniser).addSecret(s1);
oneOf(tagRecogniser).addSecret(s2);
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// run() during period 2: the secrets should be rotated
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH + ROTATION_PERIOD + 1));
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(crypto).deriveNextSecret(secret2, 3);
will(returnValue(secret3));
oneOf(tagRecogniser).removeSecret(contactId, transportId, 0);
oneOf(db).addSecrets(Arrays.asList(s3));
oneOf(tagRecogniser).addSecret(s3);
// getConnectionContext()
oneOf(db).incrementStreamCounter(contactId, transportId, 2);
will(returnValue(0L));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.run();
StreamContext ctx =
keyManager.getStreamContext(contactId, transportId);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertArrayEquals(secret2, ctx.getSecret());
assertEquals(0, ctx.getStreamNumber());
assertEquals(true, ctx.getAlice());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAndRotateAWholePeriodLate() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final TagRecogniser tagRecogniser = context.mock(TagRecogniser.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
// The secrets for periods 3 and 4 should be derived and stored
final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3);
final TemporarySecret s4 = new TemporarySecret(ep, 4, secret4);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the epoch, the start of period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
// The secrets for periods 0 - 2 should be added to the recogniser
oneOf(tagRecogniser).addSecret(s0);
oneOf(tagRecogniser).addSecret(s1);
oneOf(tagRecogniser).addSecret(s2);
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// run() during period 3 (late): the secrets should be rotated
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH + 2 * ROTATION_PERIOD + 1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(crypto).deriveNextSecret(secret2, 3);
will(returnValue(secret3));
oneOf(crypto).deriveNextSecret(secret3, 4);
will(returnValue(secret4));
oneOf(tagRecogniser).removeSecret(contactId, transportId, 0);
oneOf(tagRecogniser).removeSecret(contactId, transportId, 1);
oneOf(db).addSecrets(Arrays.asList(s3, s4));
oneOf(tagRecogniser).addSecret(s3);
oneOf(tagRecogniser).addSecret(s4);
// getConnectionContext()
oneOf(db).incrementStreamCounter(contactId, transportId, 3);
will(returnValue(0L));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
oneOf(tagRecogniser).removeSecrets();
}});
assertTrue(keyManager.start());
keyManager.run();
StreamContext ctx =
keyManager.getStreamContext(contactId, transportId);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertArrayEquals(secret3, ctx.getSecret());
assertEquals(0, ctx.getStreamNumber());
assertEquals(true, ctx.getAlice());
keyManager.stop();
context.assertIsSatisfied();
public void testUnitTestsExist() {
fail(); // FIXME: Write tests
}
}

View File

@@ -1,772 +0,0 @@
package org.briarproject.transport;
import org.briarproject.BriarTestCase;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.event.EventBus;
import org.briarproject.api.event.EventListener;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.Timer;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.TagRecogniser;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.util.ByteUtils;
import org.hamcrest.Description;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.jmock.api.Action;
import org.jmock.api.Invocation;
import org.junit.Test;
import java.util.Arrays;
import java.util.Collections;
import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
public class KeyRotationIntegrationTest extends BriarTestCase {
private static final long EPOCH = 1000L * 1000L * 1000L * 1000L;
private static final int MAX_LATENCY = 2 * 60 * 1000; // 2 minutes
private static final int ROTATION_PERIOD =
MAX_CLOCK_DIFFERENCE + MAX_LATENCY;
private final ContactId contactId;
private final TransportId transportId;
private final byte[] secret0, secret1, secret2, secret3, secret4;
private final byte[] key0, key1, key2, key3, key4;
private final SecretKey k0, k1, k2, k3, k4;
private final byte[] initialSecret;
public KeyRotationIntegrationTest() {
contactId = new ContactId(234);
transportId = new TransportId("id");
secret0 = new byte[32];
secret1 = new byte[32];
secret2 = new byte[32];
secret3 = new byte[32];
secret4 = new byte[32];
for (int i = 0; i < secret0.length; i++) secret0[i] = 1;
for (int i = 0; i < secret1.length; i++) secret1[i] = 2;
for (int i = 0; i < secret2.length; i++) secret2[i] = 3;
for (int i = 0; i < secret3.length; i++) secret3[i] = 4;
for (int i = 0; i < secret4.length; i++) secret4[i] = 5;
key0 = new byte[32];
key1 = new byte[32];
key2 = new byte[32];
key3 = new byte[32];
key4 = new byte[32];
k0 = new SecretKey(key0);
k1 = new SecretKey(key1);
k2 = new SecretKey(key2);
k3 = new SecretKey(key3);
k4 = new SecretKey(key4);
for (int i = 0; i < key0.length; i++) key0[i] = 1;
for (int i = 0; i < key1.length; i++) key1[i] = 2;
for (int i = 0; i < key2.length; i++) key2[i] = 3;
for (int i = 0; i < key3.length; i++) key3[i] = 4;
for (int i = 0; i < key4.length; i++) key4[i] = 5;
initialSecret = new byte[32];
for (int i = 0; i < initialSecret.length; i++) initialSecret[i] = 123;
}
@Test
public void testStartAndStop() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final TagRecogniser tagRecogniser = new TagRecogniserImpl(crypto, db);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Collections.emptyList()));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.emptyMap()));
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testEndpointAdded() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final TagRecogniser tagRecogniser = new TagRecogniserImpl(crypto, db);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The secrets for periods 0 - 2 should be derived
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Collections.emptyList()));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// endpointAdded() during rotation period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(crypto).deriveNextSecret(initialSecret, 0);
will(returnValue(secret0));
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(db).addSecrets(Arrays.asList(s0, s1, s2));
// The recogniser should derive the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// stop()
// The recogniser should derive the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// Remove the listener and stop the timer
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
}});
assertTrue(keyManager.start());
keyManager.endpointAdded(ep, MAX_LATENCY, initialSecret);
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testEndpointAddedAndGetConnectionContext() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final TagRecogniser tagRecogniser = new TagRecogniserImpl(crypto, db);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The secrets for periods 0 - 2 should be derived
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Collections.emptyList()));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// endpointAdded() during rotation period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(crypto).deriveNextSecret(initialSecret, 0);
will(returnValue(secret0));
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(db).addSecrets(Arrays.asList(s0, s1, s2));
// The recogniser should derive the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// getConnectionContext()
oneOf(db).incrementStreamCounter(contactId, transportId, 1);
will(returnValue(0L));
// stop()
// The recogniser should derive the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// Remove the listener and stop the timer
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
}});
assertTrue(keyManager.start());
keyManager.endpointAdded(ep, MAX_LATENCY, initialSecret);
StreamContext ctx =
keyManager.getStreamContext(contactId, transportId);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertArrayEquals(secret1, ctx.getSecret());
assertEquals(0, ctx.getStreamNumber());
assertEquals(true, ctx.getAlice());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testEndpointAddedAndAcceptConnection() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final TagRecogniser tagRecogniser = new TagRecogniserImpl(crypto, db);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The secrets for periods 0 - 2 should be derived
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Collections.emptyList()));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// endpointAdded() during rotation period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
oneOf(crypto).deriveNextSecret(initialSecret, 0);
will(returnValue(secret0));
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(db).addSecrets(Arrays.asList(s0, s1, s2));
// The recogniser should derive the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// acceptConnection()
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with(16L));
will(new EncodeTagAction());
oneOf(db).setReorderingWindow(contactId, transportId, 2, 1,
new byte[] {0, 1, 0, 0});
// stop()
// The recogniser should derive the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the updated tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 1; i < 17; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// Remove the listener and stop the timer
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
}});
assertTrue(keyManager.start());
keyManager.endpointAdded(ep, MAX_LATENCY, initialSecret);
// Recognise the tag for connection 0 in period 2
byte[] tag = new byte[TAG_LENGTH];
encodeTag(tag, key2, 0);
StreamContext ctx = tagRecogniser.recogniseTag(transportId, tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertArrayEquals(secret2, ctx.getSecret());
assertEquals(0, ctx.getStreamNumber());
assertEquals(true, ctx.getAlice());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAtEpoch() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final TagRecogniser tagRecogniser = new TagRecogniserImpl(crypto, db);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the epoch, the start of period 1
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH));
// The recogniser should derive the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// Start the timer
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
// The recogniser should remove the tags for period 0
oneOf(crypto).deriveTagKey(secret0, false);
will(returnValue(k0));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k0),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// Remove the listener and stop the timer
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAtStartOfPeriod2() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final TagRecogniser tagRecogniser = new TagRecogniserImpl(crypto, db);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
// The secret for period 3 should be derived and stored
final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the start of period 2
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH + ROTATION_PERIOD));
// The secret for period 3 should be derived and stored
oneOf(crypto).deriveNextSecret(secret0, 1);
will(returnValue(secret1));
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(crypto).deriveNextSecret(secret2, 3);
will(returnValue(secret3));
oneOf(db).addSecrets(Arrays.asList(s3));
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 3
oneOf(crypto).deriveTagKey(secret3, false);
will(returnValue(k3));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k3),
with((long) i));
will(new EncodeTagAction());
}
// Start the timer
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
// The recogniser should derive the tags for period 1
oneOf(crypto).deriveTagKey(secret1, false);
will(returnValue(k1));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k1),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should remove the tags for period 3
oneOf(crypto).deriveTagKey(secret3, false);
will(returnValue(k3));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k3),
with((long) i));
will(new EncodeTagAction());
}
// Remove the listener and stop the timer
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
@Test
public void testLoadSecretsAtEndOfPeriod3() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final DatabaseComponent db = context.mock(DatabaseComponent.class);
final EventBus eventBus = context.mock(EventBus.class);
final Clock clock = context.mock(Clock.class);
final Timer timer = context.mock(Timer.class);
final TagRecogniser tagRecogniser = new TagRecogniserImpl(crypto, db);
final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db,
eventBus, tagRecogniser, clock, timer);
// The DB contains the secrets for periods 0 - 2
Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true);
final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0);
final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1);
final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2);
// The secrets for periods 3 and 4 should be derived and stored
final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3);
final TemporarySecret s4 = new TemporarySecret(ep, 4, secret4);
context.checking(new Expectations() {{
// start()
oneOf(eventBus).addListener(with(any(EventListener.class)));
oneOf(db).getSecrets();
will(returnValue(Arrays.asList(s0, s1, s2)));
oneOf(db).getTransportLatencies();
will(returnValue(Collections.singletonMap(transportId,
MAX_LATENCY)));
// The current time is the end of period 3
oneOf(clock).currentTimeMillis();
will(returnValue(EPOCH + 3 * ROTATION_PERIOD - 1));
// The secrets for periods 3 and 4 should be derived from secret 1
oneOf(crypto).deriveNextSecret(secret1, 2);
will(returnValue(secret2));
oneOf(crypto).deriveNextSecret(secret2, 3);
will(returnValue(secret3));
oneOf(crypto).deriveNextSecret(secret3, 4);
will(returnValue(secret4));
// The new secrets should be stored
oneOf(db).addSecrets(Arrays.asList(s3, s4));
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 3
oneOf(crypto).deriveTagKey(secret3, false);
will(returnValue(k3));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k3),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 4
oneOf(crypto).deriveTagKey(secret4, false);
will(returnValue(k4));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k4),
with((long) i));
will(new EncodeTagAction());
}
// Start the timer
oneOf(timer).scheduleAtFixedRate(with(keyManager),
with(any(long.class)), with(any(long.class)));
// stop()
// The recogniser should derive the tags for period 2
oneOf(crypto).deriveTagKey(secret2, false);
will(returnValue(k2));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k2),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should remove the tags for period 3
oneOf(crypto).deriveTagKey(secret3, false);
will(returnValue(k3));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k3),
with((long) i));
will(new EncodeTagAction());
}
// The recogniser should derive the tags for period 4
oneOf(crypto).deriveTagKey(secret4, false);
will(returnValue(k4));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(k4),
with((long) i));
will(new EncodeTagAction());
}
// Remove the listener and stop the timer
oneOf(eventBus).removeListener(with(any(EventListener.class)));
oneOf(timer).cancel();
}});
assertTrue(keyManager.start());
keyManager.stop();
context.assertIsSatisfied();
}
private void encodeTag(byte[] tag, byte[] rawKey, long streamNumber) {
// Encode a fake tag based on the key and stream number
System.arraycopy(rawKey, 0, tag, 0, tag.length);
ByteUtils.writeUint32(streamNumber, tag, 0);
}
private class EncodeTagAction implements Action {
public void describeTo(Description description) {
description.appendText("Encodes a tag");
}
public Object invoke(Invocation invocation) throws Throwable {
byte[] tag = (byte[]) invocation.getParameter(0);
SecretKey key = (SecretKey) invocation.getParameter(1);
long streamNumber = (Long) invocation.getParameter(2);
encodeTag(tag, key.getBytes(), streamNumber);
return null;
}
}
}

View File

@@ -1,5 +1,10 @@
package org.briarproject.transport;
import org.briarproject.BriarTestCase;
import org.briarproject.api.transport.TransportConstants;
import org.briarproject.transport.ReorderingWindow.Change;
import org.junit.Assert;
import org.junit.Test;
import org.briarproject.BriarTestCase;
import org.junit.Test;
@@ -13,148 +18,102 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;
import static org.briarproject.api.transport.TransportConstants.REORDERING_WINDOW_SIZE;
import static org.junit.Assert.assertArrayEquals;
public class ReorderingWindowTest extends BriarTestCase {
@Test
public void testWindowSliding() {
ReorderingWindow w = new ReorderingWindow();
for (int i = 0; i < 100; i++) {
assertFalse(w.isSeen(i));
w.setSeen(i);
assertTrue(w.isSeen(i));
}
}
@Test
public void testWindowJumping() {
ReorderingWindow w = new ReorderingWindow();
for (int i = 0; i < 100; i += 13) {
assertFalse(w.isSeen(i));
w.setSeen(i);
assertTrue(w.isSeen(i));
}
}
@Test
public void testWindowUpperLimit() {
ReorderingWindow w = new ReorderingWindow();
// Centre is 0, highest value in window is 15
w.setSeen(15);
// Centre is 16, highest value in window is 31
w.setSeen(31);
try {
// Centre is 32, highest value in window is 47
w.setSeen(48);
fail();
} catch (IllegalArgumentException expected) {}
// Centre is max - 1, highest value in window is max
public void testBitmapConversion() {
Random random = new Random();
byte[] bitmap = new byte[REORDERING_WINDOW_SIZE / 8];
w = new ReorderingWindow(MAX_32_BIT_UNSIGNED - 1, bitmap);
assertFalse(w.isSeen(MAX_32_BIT_UNSIGNED - 1));
assertFalse(w.isSeen(MAX_32_BIT_UNSIGNED));
// Values greater than max should never be allowed
try {
w.setSeen(MAX_32_BIT_UNSIGNED + 1);
fail();
} catch (IllegalArgumentException expected) {}
w.setSeen(MAX_32_BIT_UNSIGNED);
assertTrue(w.isSeen(MAX_32_BIT_UNSIGNED));
// Centre should have moved to max + 1
assertEquals(MAX_32_BIT_UNSIGNED + 1, w.getCentre());
// The bit corresponding to max should be set
byte[] expectedBitmap = new byte[REORDERING_WINDOW_SIZE / 8];
expectedBitmap[expectedBitmap.length / 2 - 1] = 1; // 00000001
assertArrayEquals(expectedBitmap, w.getBitmap());
// Values greater than max should never be allowed even if centre > max
try {
w.setSeen(MAX_32_BIT_UNSIGNED + 1);
fail();
} catch (IllegalArgumentException expected) {}
for (int i = 0; i < 1000; i++) {
random.nextBytes(bitmap);
ReorderingWindow window = new ReorderingWindow(0L, bitmap);
assertArrayEquals(bitmap, window.getBitmap());
}
}
@Test
public void testWindowLowerLimit() {
ReorderingWindow w = new ReorderingWindow();
// Centre is 0, negative values should never be allowed
try {
w.setSeen(-1);
fail();
} catch (IllegalArgumentException expected) {}
// Slide the window
w.setSeen(15);
// Centre is 16, lowest value in window is 0
w.setSeen(0);
// Slide the window
w.setSeen(16);
// Centre is 17, lowest value in window is 1
w.setSeen(1);
try {
w.setSeen(0);
fail();
} catch (IllegalArgumentException expected) {}
// Slide the window
w.setSeen(25);
// Centre is 26, lowest value in window is 10
w.setSeen(10);
try {
w.setSeen(9);
fail();
} catch (IllegalArgumentException expected) {}
// Centre should still be 26
assertEquals(26, w.getCentre());
// The bits corresponding to 10, 15, 16 and 25 should be set
byte[] expectedBitmap = new byte[REORDERING_WINDOW_SIZE / 8];
expectedBitmap[0] = (byte) 134; // 10000110
expectedBitmap[1] = 1; // 00000001
assertArrayEquals(expectedBitmap, w.getBitmap());
public void testWindowSlidesWhenFirstElementIsSeen() {
byte[] bitmap = new byte[REORDERING_WINDOW_SIZE / 8];
ReorderingWindow window = new ReorderingWindow(0L, bitmap);
// Set the first element seen
Change change = window.setSeen(0L);
// The window should slide by one element
assertEquals(1L, window.getBase());
assertEquals(Collections.singletonList((long) REORDERING_WINDOW_SIZE), change.getAdded());
assertEquals(Collections.singletonList(0L), change.getRemoved());
// All elements in the window should be unseen
assertArrayEquals(bitmap, window.getBitmap());
}
@Test
public void testCannotSetSeenTwice() {
ReorderingWindow w = new ReorderingWindow();
w.setSeen(15);
try {
w.setSeen(15);
fail();
} catch (IllegalArgumentException expected) {}
public void testWindowDoesNotSlideWhenElementBelowMidpointIsSeen() {
byte[] bitmap = new byte[REORDERING_WINDOW_SIZE / 8];
ReorderingWindow window = new ReorderingWindow(0L, bitmap);
// Set an element below the midpoint seen
Change change = window.setSeen(1L);
// The window should not slide
assertEquals(0L, window.getBase());
assertEquals(Collections.emptyList(), change.getAdded());
assertEquals(Collections.singletonList(1L), change.getRemoved());
// The second element in the window should be seen
bitmap[0] = 0x40; // 0100 0000
assertArrayEquals(bitmap, window.getBitmap());
}
@Test
public void testGetUnseenStreamNumbers() {
ReorderingWindow w = new ReorderingWindow();
// Centre is 0; window should cover 0 to 15, inclusive, with none seen
Collection<Long> unseen = w.getUnseen();
assertEquals(16, unseen.size());
for (int i = 0; i < 16; i++) {
assertTrue(unseen.contains(Long.valueOf(i)));
assertFalse(w.isSeen(i));
}
w.setSeen(3);
w.setSeen(4);
// Centre is 5; window should cover 0 to 20, inclusive, with two seen
unseen = w.getUnseen();
assertEquals(19, unseen.size());
for (int i = 0; i < 21; i++) {
if (i == 3 || i == 4) {
assertFalse(unseen.contains(Long.valueOf(i)));
assertTrue(w.isSeen(i));
} else {
assertTrue(unseen.contains(Long.valueOf(i)));
assertFalse(w.isSeen(i));
}
}
w.setSeen(19);
// Centre is 20; window should cover 4 to 35, inclusive, with two seen
unseen = w.getUnseen();
assertEquals(30, unseen.size());
for (int i = 4; i < 36; i++) {
if (i == 4 || i == 19) {
assertFalse(unseen.contains(Long.valueOf(i)));
assertTrue(w.isSeen(i));
} else {
assertTrue(unseen.contains(Long.valueOf(i)));
assertFalse(w.isSeen(i));
}
}
public void testWindowSlidesWhenElementAboveMidpointIsSeen() {
byte[] bitmap = new byte[REORDERING_WINDOW_SIZE / 8];
ReorderingWindow window = new ReorderingWindow(0, bitmap);
long aboveMidpoint = REORDERING_WINDOW_SIZE / 2;
// Set an element above the midpoint seen
Change change = window.setSeen(aboveMidpoint);
// The window should slide by one element
assertEquals(1L, window.getBase());
assertEquals(Collections.singletonList((long) REORDERING_WINDOW_SIZE), change.getAdded());
assertEquals(Arrays.asList(0L, aboveMidpoint), change.getRemoved());
// The highest element below the midpoint should be seen
bitmap[bitmap.length / 2 - 1] = (byte) 0x01; // 0000 0001
assertArrayEquals(bitmap, window.getBitmap());
}
@Test
public void testWindowSlidesUntilLowestElementIsUnseenWhenFirstElementIsSeen() {
byte[] bitmap = new byte[REORDERING_WINDOW_SIZE / 8];
ReorderingWindow window = new ReorderingWindow(0L, bitmap);
window.setSeen(1L);
// Set the first element seen
Change change = window.setSeen(0L);
// The window should slide by two elements
assertEquals(2L, window.getBase());
assertEquals(Arrays.asList((long) REORDERING_WINDOW_SIZE,
(long) (REORDERING_WINDOW_SIZE + 1)), change.getAdded());
assertEquals(Collections.singletonList(0L), change.getRemoved());
// All elements in the window should be unseen
assertArrayEquals(bitmap, window.getBitmap());
}
@Test
public void testWindowSlidesUntilLowestElementIsUnseenWhenElementAboveMidpointIsSeen() {
byte[] bitmap = new byte[REORDERING_WINDOW_SIZE / 8];
ReorderingWindow window = new ReorderingWindow(0L, bitmap);
window.setSeen(1L);
long aboveMidpoint = REORDERING_WINDOW_SIZE / 2;
// Set an element above the midpoint seen
Change change = window.setSeen(aboveMidpoint);
// The window should slide by two elements
assertEquals(2L, window.getBase());
assertEquals(Arrays.asList((long) REORDERING_WINDOW_SIZE,
(long) (REORDERING_WINDOW_SIZE + 1)), change.getAdded());
assertEquals(Arrays.asList(0L, aboveMidpoint), change.getRemoved());
// The second-highest element below the midpoint should be seen
bitmap[bitmap.length / 2 - 1] = (byte) 0x02; // 0000 0010
assertArrayEquals(bitmap, window.getBitmap());
}
}

View File

@@ -0,0 +1,14 @@
package org.briarproject.transport;
import org.briarproject.BriarTestCase;
import org.junit.Test;
import static org.junit.Assert.fail;
public class TransportKeyManagerTest extends BriarTestCase {
@Test
public void testUnitTestsExist() {
fail(); // FIXME: Write tests
}
}

View File

@@ -1,130 +0,0 @@
package org.briarproject.transport;
import org.briarproject.BriarTestCase;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.transport.StreamContext;
import org.briarproject.api.transport.TemporarySecret;
import org.briarproject.util.ByteUtils;
import org.hamcrest.Description;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.jmock.api.Action;
import org.jmock.api.Invocation;
import org.junit.Test;
import java.util.Random;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
public class TransportTagRecogniserTest extends BriarTestCase {
private final ContactId contactId = new ContactId(234);
private final TransportId transportId = new TransportId("id");
private final SecretKey tagKey = new SecretKey(new byte[32]);
@Test
public void testAddAndRemoveSecret() {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final byte[] secret = new byte[32];
new Random().nextBytes(secret);
final boolean alice = false;
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Add secret
oneOf(crypto).deriveTagKey(secret, !alice);
will(returnValue(tagKey));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(tagKey),
with((long) i));
will(new EncodeTagAction());
}
// Remove secret
oneOf(crypto).deriveTagKey(secret, !alice);
will(returnValue(tagKey));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(tagKey),
with((long) i));
will(new EncodeTagAction());
}
}});
TemporarySecret s = new TemporarySecret(contactId, transportId, 123,
alice, 0, secret, 0, 0, new byte[4]);
TransportTagRecogniser recogniser =
new TransportTagRecogniser(crypto, db, transportId);
recogniser.addSecret(s);
recogniser.removeSecret(contactId, 0);
context.assertIsSatisfied();
}
@Test
public void testRecogniseTag() throws Exception {
Mockery context = new Mockery();
final CryptoComponent crypto = context.mock(CryptoComponent.class);
final byte[] secret = new byte[32];
new Random().nextBytes(secret);
final boolean alice = false;
final DatabaseComponent db = context.mock(DatabaseComponent.class);
context.checking(new Expectations() {{
// Add secret
oneOf(crypto).deriveTagKey(secret, !alice);
will(returnValue(tagKey));
for (int i = 0; i < 16; i++) {
oneOf(crypto).encodeTag(with(any(byte[].class)), with(tagKey),
with((long) i));
will(new EncodeTagAction());
}
// Recognise tag 0
oneOf(crypto).deriveTagKey(secret, !alice);
will(returnValue(tagKey));
// The window should slide to include tag 16
oneOf(crypto).encodeTag(with(any(byte[].class)), with(tagKey),
with(16L));
will(new EncodeTagAction());
// The updated window should be stored
oneOf(db).setReorderingWindow(contactId, transportId, 0, 1,
new byte[] {0, 1, 0, 0});
// Recognise tag again - no expectations
}});
TemporarySecret s = new TemporarySecret(contactId, transportId, 123,
alice, 0, secret, 0, 0, new byte[4]);
TransportTagRecogniser recogniser =
new TransportTagRecogniser(crypto, db, transportId);
recogniser.addSecret(s);
// Tag 0 should be expected
byte[] tag = new byte[TAG_LENGTH];
StreamContext ctx = recogniser.recogniseTag(tag);
assertNotNull(ctx);
assertEquals(contactId, ctx.getContactId());
assertEquals(transportId, ctx.getTransportId());
assertArrayEquals(secret, ctx.getSecret());
assertEquals(0, ctx.getStreamNumber());
assertEquals(alice, ctx.getAlice());
// Tag 0 should not be expected again
assertNull(recogniser.recogniseTag(tag));
context.assertIsSatisfied();
}
private static class EncodeTagAction implements Action {
public void describeTo(Description description) {
description.appendText("Encodes a tag");
}
public Object invoke(Invocation invocation) throws Throwable {
byte[] tag = (byte[]) invocation.getParameter(0);
long streamNumber = (Long) invocation.getParameter(2);
// Encode a fake tag based on the stream number
ByteUtils.writeUint32(streamNumber, tag, 0);
return null;
}
}
}