Changed the root package from net.sf.briar to org.briarproject.

This commit is contained in:
akwizgran
2014-01-08 16:18:30 +00:00
parent dce70f487c
commit 832476412c
427 changed files with 2507 additions and 2507 deletions

View File

@@ -0,0 +1,71 @@
package org.briarproject.crypto;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import org.briarproject.api.crypto.AuthenticatedCipher;
import org.briarproject.api.crypto.SecretKey;
import org.spongycastle.crypto.DataLengthException;
import org.spongycastle.crypto.InvalidCipherTextException;
import org.spongycastle.crypto.modes.AEADBlockCipher;
import org.spongycastle.crypto.params.AEADParameters;
import org.spongycastle.crypto.params.KeyParameter;
class AuthenticatedCipherImpl implements AuthenticatedCipher {
private final AEADBlockCipher cipher;
private final int macLength;
AuthenticatedCipherImpl(AEADBlockCipher cipher, int macLength) {
this.cipher = cipher;
this.macLength = macLength;
}
public int doFinal(byte[] input, int inputOff, int len, byte[] output,
int outputOff) throws GeneralSecurityException {
int processed = 0;
if(len != 0) {
processed = cipher.processBytes(input, inputOff, len, output,
outputOff);
}
try {
return processed + cipher.doFinal(output, outputOff + processed);
} catch(DataLengthException e) {
throw new GeneralSecurityException(e.getMessage());
} catch(InvalidCipherTextException e) {
throw new GeneralSecurityException(e.getMessage());
}
}
public void init(int opmode, SecretKey key, byte[] iv, byte[] aad)
throws GeneralSecurityException {
KeyParameter k = new KeyParameter(key.getEncoded());
AEADParameters params = new AEADParameters(k, macLength * 8, iv, aad);
try {
switch(opmode) {
case Cipher.ENCRYPT_MODE:
case Cipher.WRAP_MODE:
cipher.init(true, params);
break;
case Cipher.DECRYPT_MODE:
case Cipher.UNWRAP_MODE:
cipher.init(false, params);
break;
default:
throw new IllegalArgumentException();
}
} catch(IllegalArgumentException e) {
throw new GeneralSecurityException(e.getMessage());
}
}
public int getMacLength() {
return macLength;
}
public int getBlockSize() {
return cipher.getUnderlyingCipher().getBlockSize();
}
}

View File

@@ -0,0 +1,526 @@
package org.briarproject.crypto;
import static java.util.logging.Level.INFO;
import static javax.crypto.Cipher.DECRYPT_MODE;
import static javax.crypto.Cipher.ENCRYPT_MODE;
import static org.briarproject.api.invitation.InvitationConstants.CODE_BITS;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import static org.briarproject.crypto.EllipticCurveConstants.P;
import static org.briarproject.crypto.EllipticCurveConstants.PARAMETERS;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
import org.briarproject.api.crypto.AuthenticatedCipher;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyPair;
import org.briarproject.api.crypto.KeyParser;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.crypto.PrivateKey;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.crypto.PublicKey;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.crypto.Signature;
import org.briarproject.util.ByteUtils;
import org.spongycastle.crypto.AsymmetricCipherKeyPair;
import org.spongycastle.crypto.BlockCipher;
import org.spongycastle.crypto.CipherParameters;
import org.spongycastle.crypto.Mac;
import org.spongycastle.crypto.agreement.ECDHCBasicAgreement;
import org.spongycastle.crypto.digests.SHA384Digest;
import org.spongycastle.crypto.engines.AESLightEngine;
import org.spongycastle.crypto.generators.ECKeyPairGenerator;
import org.spongycastle.crypto.generators.PKCS5S2ParametersGenerator;
import org.spongycastle.crypto.macs.HMac;
import org.spongycastle.crypto.modes.AEADBlockCipher;
import org.spongycastle.crypto.modes.GCMBlockCipher;
import org.spongycastle.crypto.params.ECKeyGenerationParameters;
import org.spongycastle.crypto.params.ECPrivateKeyParameters;
import org.spongycastle.crypto.params.ECPublicKeyParameters;
import org.spongycastle.crypto.params.KeyParameter;
import org.spongycastle.util.Strings;
class CryptoComponentImpl implements CryptoComponent {
private static final Logger LOG =
Logger.getLogger(CryptoComponentImpl.class.getName());
private static final int CIPHER_KEY_BYTES = 32; // 256 bits
private static final int AGREEMENT_KEY_PAIR_BITS = 384;
private static final int SIGNATURE_KEY_PAIR_BITS = 384;
private static final int MAC_BYTES = 16; // 128 bits
private static final int STORAGE_IV_BYTES = 16; // 128 bits
private static final int PBKDF_SALT_BYTES = 16; // 128 bits
private static final int PBKDF_TARGET_MILLIS = 500;
private static final int PBKDF_SAMPLES = 30;
// Labels for secret derivation
private static final byte[] MASTER = { 'M', 'A', 'S', 'T', 'E', 'R', '\0' };
private static final byte[] SALT = { 'S', 'A', 'L', 'T', '\0' };
private static final byte[] FIRST = { 'F', 'I', 'R', 'S', 'T', '\0' };
private static final byte[] ROTATE = { 'R', 'O', 'T', 'A', 'T', 'E', '\0' };
// Label for confirmation code derivation
private static final byte[] CODE = { 'C', 'O', 'D', 'E', '\0' };
// Label for invitation nonce derivation
private static final byte[] NONCE = { 'N', 'O', 'N', 'C', 'E', '\0' };
// Labels for key derivation
private static final byte[] A_TAG = { 'A', '_', 'T', 'A', 'G', '\0' };
private static final byte[] B_TAG = { 'B', '_', 'T', 'A', 'G', '\0' };
private static final byte[] A_FRAME_A =
{ 'A', '_', 'F', 'R', 'A', 'M', 'E', '_', 'A', '\0' };
private static final byte[] A_FRAME_B =
{ 'A', '_', 'F', 'R', 'A', 'M', 'E', '_', 'B', '\0' };
private static final byte[] B_FRAME_A =
{ 'B', '_', 'F', 'R', 'A', 'M', 'E', '_', 'A', '\0' };
private static final byte[] B_FRAME_B =
{ 'B', '_', 'F', 'R', 'A', 'M', 'E', '_', 'B', '\0' };
// Blank secret for argument validation
private static final byte[] BLANK_SECRET = new byte[CIPHER_KEY_BYTES];
private final KeyParser agreementKeyParser, signatureKeyParser;
private final SecureRandom secureRandom;
private final ECKeyPairGenerator agreementKeyPairGenerator;
private final ECKeyPairGenerator signatureKeyPairGenerator;
CryptoComponentImpl() {
agreementKeyParser = new Sec1KeyParser(PARAMETERS, P,
AGREEMENT_KEY_PAIR_BITS);
signatureKeyParser = new Sec1KeyParser(PARAMETERS, P,
SIGNATURE_KEY_PAIR_BITS);
secureRandom = new SecureRandom();
ECKeyGenerationParameters params = new ECKeyGenerationParameters(
PARAMETERS, secureRandom);
agreementKeyPairGenerator = new ECKeyPairGenerator();
agreementKeyPairGenerator.init(params);
signatureKeyPairGenerator = new ECKeyPairGenerator();
signatureKeyPairGenerator.init(params);
}
public SecretKey generateSecretKey() {
byte[] b = new byte[CIPHER_KEY_BYTES];
secureRandom.nextBytes(b);
return new SecretKeyImpl(b);
}
public MessageDigest getMessageDigest() {
return new DoubleDigest(new SHA384Digest());
}
public PseudoRandom getPseudoRandom(int seed1, int seed2) {
return new PseudoRandomImpl(getMessageDigest(), seed1, seed2);
}
public SecureRandom getSecureRandom() {
return secureRandom;
}
public Signature getSignature() {
return new SignatureImpl(secureRandom);
}
public KeyPair generateAgreementKeyPair() {
AsymmetricCipherKeyPair keyPair =
agreementKeyPairGenerator.generateKeyPair();
// Return a wrapper that uses the SEC 1 encoding
ECPublicKeyParameters ecPublicKey =
(ECPublicKeyParameters) keyPair.getPublic();
PublicKey publicKey = new Sec1PublicKey(ecPublicKey,
AGREEMENT_KEY_PAIR_BITS);
ECPrivateKeyParameters ecPrivateKey =
(ECPrivateKeyParameters) keyPair.getPrivate();
PrivateKey privateKey = new Sec1PrivateKey(ecPrivateKey,
AGREEMENT_KEY_PAIR_BITS);
return new KeyPair(publicKey, privateKey);
}
public KeyParser getAgreementKeyParser() {
return agreementKeyParser;
}
public KeyPair generateSignatureKeyPair() {
AsymmetricCipherKeyPair keyPair =
signatureKeyPairGenerator.generateKeyPair();
// Return a wrapper that uses the SEC 1 encoding
ECPublicKeyParameters ecPublicKey =
(ECPublicKeyParameters) keyPair.getPublic();
PublicKey publicKey = new Sec1PublicKey(ecPublicKey,
SIGNATURE_KEY_PAIR_BITS);
ECPrivateKeyParameters ecPrivateKey =
(ECPrivateKeyParameters) keyPair.getPrivate();
PrivateKey privateKey = new Sec1PrivateKey(ecPrivateKey,
SIGNATURE_KEY_PAIR_BITS);
return new KeyPair(publicKey, privateKey);
}
public KeyParser getSignatureKeyParser() {
return signatureKeyParser;
}
public int generateInvitationCode() {
int codeBytes = (int) Math.ceil(CODE_BITS / 8.0);
byte[] random = new byte[codeBytes];
secureRandom.nextBytes(random);
return ByteUtils.readUint(random, CODE_BITS);
}
public int[] deriveConfirmationCodes(byte[] secret) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
byte[] alice = counterModeKdf(secret, CODE, 0);
byte[] bob = counterModeKdf(secret, CODE, 1);
int[] codes = new int[2];
codes[0] = ByteUtils.readUint(alice, CODE_BITS);
codes[1] = ByteUtils.readUint(bob, CODE_BITS);
ByteUtils.erase(alice);
ByteUtils.erase(bob);
return codes;
}
public byte[][] deriveInvitationNonces(byte[] secret) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
byte[] alice = counterModeKdf(secret, NONCE, 0);
byte[] bob = counterModeKdf(secret, NONCE, 1);
return new byte[][] { alice, bob };
}
public byte[] deriveMasterSecret(byte[] theirPublicKey,
KeyPair ourKeyPair, boolean alice) throws GeneralSecurityException {
MessageDigest messageDigest = getMessageDigest();
byte[] ourPublicKey = ourKeyPair.getPublic().getEncoded();
byte[] ourHash = messageDigest.digest(ourPublicKey);
byte[] theirHash = messageDigest.digest(theirPublicKey);
byte[] aliceInfo, bobInfo;
if(alice) {
aliceInfo = ourHash;
bobInfo = theirHash;
} else {
aliceInfo = theirHash;
bobInfo = ourHash;
}
PrivateKey ourPriv = ourKeyPair.getPrivate();
PublicKey theirPub = agreementKeyParser.parsePublicKey(theirPublicKey);
// The raw secret comes from the key agreement algorithm
byte[] raw = deriveSharedSecret(ourPriv, theirPub);
// Derive the cooked secret from the raw secret using the
// concatenation KDF
byte[] cooked = concatenationKdf(raw, MASTER, aliceInfo, bobInfo);
ByteUtils.erase(raw);
return cooked;
}
// Package access for testing
byte[] deriveSharedSecret(PrivateKey priv, PublicKey pub)
throws GeneralSecurityException {
if(!(priv instanceof Sec1PrivateKey))
throw new IllegalArgumentException();
if(!(pub instanceof Sec1PublicKey))
throw new IllegalArgumentException();
ECPrivateKeyParameters ecPriv = ((Sec1PrivateKey) priv).getKey();
ECPublicKeyParameters ecPub = ((Sec1PublicKey) pub).getKey();
ECDHCBasicAgreement agreement = new ECDHCBasicAgreement();
agreement.init(ecPriv);
// FIXME: Should we use another format for the shared secret?
return agreement.calculateAgreement(ecPub).toByteArray();
}
public byte[] deriveGroupSalt(byte[] secret) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
return counterModeKdf(secret, SALT, 0);
}
public byte[] deriveInitialSecret(byte[] secret, int transportIndex) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
if(transportIndex < 0) throw new IllegalArgumentException();
return counterModeKdf(secret, FIRST, transportIndex);
}
public byte[] deriveNextSecret(byte[] secret, long period) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
if(period < 0 || period > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
return counterModeKdf(secret, ROTATE, period);
}
public SecretKey deriveTagKey(byte[] secret, boolean alice) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
if(alice) return deriveKey(secret, A_TAG, 0);
else return deriveKey(secret, B_TAG, 0);
}
public SecretKey deriveFrameKey(byte[] secret, long connection,
boolean alice, boolean initiator) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
if(connection < 0 || connection > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(alice) {
if(initiator) return deriveKey(secret, A_FRAME_A, connection);
else return deriveKey(secret, A_FRAME_B, connection);
} else {
if(initiator) return deriveKey(secret, B_FRAME_A, connection);
else return deriveKey(secret, B_FRAME_B, connection);
}
}
private SecretKey deriveKey(byte[] secret, byte[] label, long context) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
byte[] key = counterModeKdf(secret, label, context);
return new SecretKeyImpl(key);
}
public AuthenticatedCipher getFrameCipher() {
AEADBlockCipher cipher = new GCMBlockCipher(new AESLightEngine());
return new AuthenticatedCipherImpl(cipher, MAC_BYTES);
}
public void encodeTag(byte[] tag, SecretKey tagKey, long connection) {
if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
if(connection < 0 || connection > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
for(int i = 0; i < TAG_LENGTH; i++) tag[i] = 0;
ByteUtils.writeUint32(connection, tag, 0);
BlockCipher cipher = new AESLightEngine();
assert cipher.getBlockSize() == TAG_LENGTH;
KeyParameter k = new KeyParameter(tagKey.getEncoded());
cipher.init(true, k);
cipher.processBlock(tag, 0, tag, 0);
ByteUtils.erase(k.getKey());
}
public byte[] encryptWithPassword(byte[] input, char[] password) {
// Generate a random salt
byte[] salt = new byte[PBKDF_SALT_BYTES];
secureRandom.nextBytes(salt);
// Calibrate the KDF
int iterations = chooseIterationCount(PBKDF_TARGET_MILLIS);
// Derive the key from the password
byte[] keyBytes = pbkdf2(password, salt, iterations);
SecretKey key = new SecretKeyImpl(keyBytes);
// Generate a random IV
byte[] iv = new byte[STORAGE_IV_BYTES];
secureRandom.nextBytes(iv);
// The output contains the salt, iterations, IV, ciphertext and MAC
int outputLen = salt.length + 4 + iv.length + input.length + MAC_BYTES;
byte[] output = new byte[outputLen];
System.arraycopy(salt, 0, output, 0, salt.length);
ByteUtils.writeUint32(iterations, output, salt.length);
System.arraycopy(iv, 0, output, salt.length + 4, iv.length);
// Initialise the cipher and encrypt the plaintext
try {
AEADBlockCipher c = new GCMBlockCipher(new AESLightEngine());
AuthenticatedCipher cipher = new AuthenticatedCipherImpl(c,
MAC_BYTES);
cipher.init(ENCRYPT_MODE, key, iv, null);
int outputOff = salt.length + 4 + iv.length;
cipher.doFinal(input, 0, input.length, output, outputOff);
return output;
} catch(GeneralSecurityException e) {
throw new RuntimeException(e);
} finally {
key.erase();
}
}
public byte[] decryptWithPassword(byte[] input, char[] password) {
// The input contains the salt, iterations, IV, ciphertext and MAC
if(input.length < PBKDF_SALT_BYTES + 4 + STORAGE_IV_BYTES + MAC_BYTES)
return null; // Invalid
byte[] salt = new byte[PBKDF_SALT_BYTES];
System.arraycopy(input, 0, salt, 0, salt.length);
long iterations = ByteUtils.readUint32(input, salt.length);
if(iterations < 0 || iterations > Integer.MAX_VALUE)
return null; // Invalid
byte[] iv = new byte[STORAGE_IV_BYTES];
System.arraycopy(input, salt.length + 4, iv, 0, iv.length);
// Derive the key from the password
byte[] keyBytes = pbkdf2(password, salt, (int) iterations);
SecretKey key = new SecretKeyImpl(keyBytes);
// Initialise the cipher
AuthenticatedCipher cipher;
try {
AEADBlockCipher c = new GCMBlockCipher(new AESLightEngine());
cipher = new AuthenticatedCipherImpl(c, MAC_BYTES);
cipher.init(DECRYPT_MODE, key, iv, null);
} catch(GeneralSecurityException e) {
key.erase();
throw new RuntimeException(e);
}
// Try to decrypt the ciphertext (may be invalid)
try {
int inputOff = salt.length + 4 + iv.length;
int inputLen = input.length - inputOff;
byte[] output = new byte[inputLen - MAC_BYTES];
cipher.doFinal(input, inputOff, inputLen, output, 0);
return output;
} catch(GeneralSecurityException e) {
return null; // Invalid
} finally {
key.erase();
}
}
// Key derivation function based on a hash function - see NIST SP 800-56A,
// section 5.8
private byte[] concatenationKdf(byte[] rawSecret, byte[] label,
byte[] initiatorInfo, byte[] responderInfo) {
// The output of the hash function must be long enough to use as a key
MessageDigest messageDigest = getMessageDigest();
if(messageDigest.getDigestLength() < CIPHER_KEY_BYTES)
throw new RuntimeException();
// The length of every field must fit in an unsigned 8-bit integer
if(rawSecret.length > 255) throw new IllegalArgumentException();
if(label.length > 255) throw new IllegalArgumentException();
if(initiatorInfo.length > 255) throw new IllegalArgumentException();
if(responderInfo.length > 255) throw new IllegalArgumentException();
// All fields are length-prefixed
messageDigest.update((byte) rawSecret.length);
messageDigest.update(rawSecret);
messageDigest.update((byte) label.length);
messageDigest.update(label);
messageDigest.update((byte) initiatorInfo.length);
messageDigest.update(initiatorInfo);
messageDigest.update((byte) responderInfo.length);
messageDigest.update(responderInfo);
byte[] hash = messageDigest.digest();
// The secret is the first CIPHER_KEY_BYTES bytes of the hash
byte[] output = new byte[CIPHER_KEY_BYTES];
System.arraycopy(hash, 0, output, 0, output.length);
ByteUtils.erase(hash);
return output;
}
// Key derivation function based on a PRF in counter mode - see
// NIST SP 800-108, section 5.1
private byte[] counterModeKdf(byte[] secret, byte[] label, long context) {
if(secret.length != CIPHER_KEY_BYTES)
throw new IllegalArgumentException();
if(Arrays.equals(secret, BLANK_SECRET))
throw new IllegalArgumentException();
// The label must be null-terminated
if(label[label.length - 1] != '\0')
throw new IllegalArgumentException();
// Initialise the PRF
Mac prf = new HMac(new SHA384Digest());
KeyParameter k = new KeyParameter(secret);
prf.init(k);
int macLength = prf.getMacSize();
// The output of the PRF must be long enough to use as a key
if(macLength < CIPHER_KEY_BYTES) throw new RuntimeException();
byte[] mac = new byte[macLength], output = new byte[CIPHER_KEY_BYTES];
prf.update((byte) 0); // Counter
prf.update(label, 0, label.length); // Null-terminated
byte[] contextBytes = new byte[4];
ByteUtils.writeUint32(context, contextBytes, 0);
prf.update(contextBytes, 0, contextBytes.length);
prf.update((byte) CIPHER_KEY_BYTES); // Output length
prf.doFinal(mac, 0);
System.arraycopy(mac, 0, output, 0, output.length);
ByteUtils.erase(mac);
ByteUtils.erase(k.getKey());
return output;
}
// Password-based key derivation function - see PKCS#5 v2.1, section 5.2
private byte[] pbkdf2(char[] password, byte[] salt, int iterations) {
byte[] utf8 = toUtf8ByteArray(password);
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator();
gen.init(utf8, salt, iterations);
int keyLengthInBits = CIPHER_KEY_BYTES * 8;
CipherParameters p = gen.generateDerivedParameters(keyLengthInBits);
ByteUtils.erase(utf8);
return ((KeyParameter) p).getKey();
}
// Package access for testing
int chooseIterationCount(int targetMillis) {
List<Long> quickSamples = new ArrayList<Long>(PBKDF_SAMPLES);
List<Long> slowSamples = new ArrayList<Long>(PBKDF_SAMPLES);
long iterationNanos = 0, initNanos = 0;
while(iterationNanos <= 0 || initNanos <= 0) {
// Sample the running time with one iteration and two iterations
for(int i = 0; i < PBKDF_SAMPLES; i++) {
quickSamples.add(sampleRunningTime(1));
slowSamples.add(sampleRunningTime(2));
}
// Calculate the iteration time and the initialisation time
long quickMedian = median(quickSamples);
long slowMedian = median(slowSamples);
iterationNanos = slowMedian - quickMedian;
initNanos = quickMedian - iterationNanos;
if(LOG.isLoggable(INFO)) {
LOG.info("Init: " + initNanos + ", iteration: "
+ iterationNanos);
}
}
long targetNanos = targetMillis * 1000L * 1000L;
long iterations = (targetNanos - initNanos) / iterationNanos;
if(LOG.isLoggable(INFO)) LOG.info("Target iterations: " + iterations);
if(iterations < 1) return 1;
if(iterations > Integer.MAX_VALUE) return Integer.MAX_VALUE;
return (int) iterations;
}
private long sampleRunningTime(int iterations) {
byte[] password = { 'p', 'a', 's', 's', 'w', 'o', 'r', 'd' };
byte[] salt = new byte[PBKDF_SALT_BYTES];
int keyLengthInBits = CIPHER_KEY_BYTES * 8;
long start = System.nanoTime();
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator();
gen.init(password, salt, iterations);
gen.generateDerivedParameters(keyLengthInBits);
return System.nanoTime() - start;
}
private long median(List<Long> list) {
int size = list.size();
if(size == 0) throw new IllegalArgumentException();
Collections.sort(list);
if(size % 2 == 1) return list.get(size / 2);
return list.get(size / 2 - 1) + list.get(size / 2) / 2;
}
byte[] toUtf8ByteArray(char[] c) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
try {
Strings.toUTF8ByteArray(c, out);
byte[] utf8 = out.toByteArray();
// Erase the output stream's buffer
out.reset();
out.write(new byte[utf8.length]);
return utf8;
} catch(IOException e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -0,0 +1,50 @@
package org.briarproject.crypto;
import static java.util.concurrent.TimeUnit.SECONDS;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import javax.inject.Singleton;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.CryptoExecutor;
import org.briarproject.api.lifecycle.LifecycleManager;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
public class CryptoModule extends AbstractModule {
/** The maximum number of executor threads. */
private static final int MAX_EXECUTOR_THREADS =
Runtime.getRuntime().availableProcessors();
private final ExecutorService cryptoExecutor;
public CryptoModule() {
// The queue is unbounded, so tasks can be dependent
BlockingQueue<Runnable> queue = new LinkedBlockingQueue<Runnable>();
// Discard tasks that are submitted during shutdown
RejectedExecutionHandler policy =
new ThreadPoolExecutor.DiscardPolicy();
// Create a limited # of threads and keep them in the pool for 60 secs
cryptoExecutor = new ThreadPoolExecutor(0, MAX_EXECUTOR_THREADS,
60, SECONDS, queue, policy);
}
protected void configure() {
bind(CryptoComponent.class).to(
CryptoComponentImpl.class).in(Singleton.class);
}
@Provides @Singleton @CryptoExecutor
Executor getCryptoExecutor(LifecycleManager lifecycleManager) {
lifecycleManager.registerForShutdown(cryptoExecutor);
return cryptoExecutor;
}
}

View File

@@ -0,0 +1,63 @@
package org.briarproject.crypto;
import org.briarproject.api.crypto.MessageDigest;
import org.spongycastle.crypto.Digest;
/**
* A message digest that prevents length extension attacks - see Ferguson and
* Schneier, <i>Practical Cryptography</i>, chapter 6.
* <p>
* "Let h be an interative hash function. The hash function h<sub>d</sub> is
* defined by h<sub>d</sub> := h(h(m)), and has a claimed security level of
* min(k, n/2) where k is the security level of h and n is the size of the hash
* result."
*/
class DoubleDigest implements MessageDigest {
private final Digest delegate;
DoubleDigest(Digest delegate) {
this.delegate = delegate;
}
public byte[] digest() {
byte[] digest = new byte[delegate.getDigestSize()];
delegate.doFinal(digest, 0); // h(m)
delegate.update(digest, 0, digest.length);
delegate.doFinal(digest, 0); // h(h(m))
return digest;
}
public byte[] digest(byte[] input) {
delegate.update(input, 0, input.length);
return digest();
}
public int digest(byte[] buf, int offset, int len) {
byte[] digest = digest();
len = Math.min(len, digest.length);
System.arraycopy(digest, 0, buf, offset, len);
return len;
}
public int getDigestLength() {
return delegate.getDigestSize();
}
public void reset() {
delegate.reset();
}
public void update(byte input) {
delegate.update(input);
}
public void update(byte[] input) {
delegate.update(input, 0, input.length);
}
public void update(byte[] input, int offset, int len) {
delegate.update(input, offset, len);
}
}

View File

@@ -0,0 +1,70 @@
package org.briarproject.crypto;
import java.math.BigInteger;
import org.spongycastle.crypto.params.ECDomainParameters;
import org.spongycastle.math.ec.ECCurve;
import org.spongycastle.math.ec.ECFieldElement;
import org.spongycastle.math.ec.ECPoint;
/** Parameters for curve brainpoolP384r1 - see RFC 5639. */
interface EllipticCurveConstants {
/**
* The prime specifying the finite field. (This is called p in RFC 5639 and
* q in SEC 2.)
*/
BigInteger P = new BigInteger("8CB91E82" + "A3386D28" + "0F5D6F7E" +
"50E641DF" + "152F7109" + "ED5456B4" + "12B1DA19" + "7FB71123" +
"ACD3A729" + "901D1A71" + "87470013" + "3107EC53", 16);
/**
* A coefficient of the equation y^2 = x^3 + A*x + B defining the elliptic
* curve. (This is called A in RFC 5639 and a in SEC 2.)
*/
BigInteger A = new BigInteger("7BC382C6" + "3D8C150C" + "3C72080A" +
"CE05AFA0" + "C2BEA28E" + "4FB22787" + "139165EF" + "BA91F90F" +
"8AA5814A" + "503AD4EB" + "04A8C7DD" + "22CE2826", 16);
/**
* A coefficient of the equation y^2 = x^3 + A*x + B defining the elliptic
* curve. (This is called B in RFC 5639 b in SEC 2.)
*/
BigInteger B = new BigInteger("04A8C7DD" + "22CE2826" + "8B39B554" +
"16F0447C" + "2FB77DE1" + "07DCD2A6" + "2E880EA5" + "3EEB62D5" +
"7CB43902" + "95DBC994" + "3AB78696" + "FA504C11", 16);
/**
* The x co-ordinate of the base point G. (This is called x in RFC 5639 and
* SEC 2.)
*/
BigInteger X = new BigInteger("1D1C64F0" + "68CF45FF" + "A2A63A81" +
"B7C13F6B" + "8847A3E7" + "7EF14FE3" + "DB7FCAFE" + "0CBD10E8" +
"E826E034" + "36D646AA" + "EF87B2E2" + "47D4AF1E", 16);
/**
* The y co-ordinate of the base point G. (This is called y in RFC 5639 and
* SEC 2.)
*/
BigInteger Y = new BigInteger("8ABE1D75" + "20F9C2A4" + "5CB1EB8E" +
"95CFD552" + "62B70B29" + "FEEC5864" + "E19C054F" + "F9912928" +
"0E464621" + "77918111" + "42820341" + "263C5315", 16);
/**
* The order of the base point G. (This is called q in RFC 5639 and n in
* SEC 2.)
*/
BigInteger Q = new BigInteger("8CB91E82" + "A3386D28" + "0F5D6F7E" +
"50E641DF" + "152F7109" + "ED5456B3" + "1F166E6C" + "AC0425A7" +
"CF3AB6AF" + "6B7FC310" + "3B883202" + "E9046565", 16);
/** The cofactor of G. (This is called h in RFC 5639 and SEC 2.) */
BigInteger H = BigInteger.ONE;
// Static parameter objects derived from the above parameters
ECCurve CURVE = new ECCurve.Fp(P, A, B);
ECPoint G = new ECPoint.Fp(CURVE,
new ECFieldElement.Fp(P, X),
new ECFieldElement.Fp(P, Y));
ECDomainParameters PARAMETERS = new ECDomainParameters(CURVE, G, Q, H);
}

View File

@@ -0,0 +1,41 @@
package org.briarproject.crypto;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.util.ByteUtils;
class PseudoRandomImpl implements PseudoRandom {
private final MessageDigest messageDigest;
private byte[] state;
private int offset;
PseudoRandomImpl(MessageDigest messageDigest, int seed1, int seed2) {
this.messageDigest = messageDigest;
byte[] seedBytes = new byte[8];
ByteUtils.writeUint32(seed1, seedBytes, 0);
ByteUtils.writeUint32(seed2, seedBytes, 4);
messageDigest.update(seedBytes);
state = messageDigest.digest();
offset = 0;
}
public synchronized byte[] nextBytes(int bytes) {
byte[] b = new byte[bytes];
int half = state.length / 2;
int off = 0, len = b.length, available = half - offset;
while(available < len) {
System.arraycopy(state, offset, b, off, available);
off += available;
len -= available;
messageDigest.update(state, half, half);
state = messageDigest.digest();
offset = 0;
available = half;
}
System.arraycopy(state, offset, b, off, len);
offset += len;
return b;
}
}

View File

@@ -0,0 +1,86 @@
package org.briarproject.crypto;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import org.briarproject.api.crypto.KeyParser;
import org.briarproject.api.crypto.PrivateKey;
import org.briarproject.api.crypto.PublicKey;
import org.spongycastle.crypto.params.ECDomainParameters;
import org.spongycastle.crypto.params.ECPrivateKeyParameters;
import org.spongycastle.crypto.params.ECPublicKeyParameters;
import org.spongycastle.math.ec.ECFieldElement;
import org.spongycastle.math.ec.ECPoint;
/**
* A key parser that uses the encoding defined in "SEC 1: Elliptic Curve
* Cryptography", section 2.3 (Certicom Corporation, May 2009). Point
* compression is not used.
*/
class Sec1KeyParser implements KeyParser {
private final ECDomainParameters params;
private final BigInteger modulus;
private final int keyBits, bytesPerInt, publicKeyBytes, privateKeyBytes;
Sec1KeyParser(ECDomainParameters params, BigInteger modulus, int keyBits) {
this.params = params;
this.modulus = modulus;
this.keyBits = keyBits;
bytesPerInt = (keyBits + 7) / 8;
publicKeyBytes = 1 + 2 * bytesPerInt;
privateKeyBytes = bytesPerInt;
}
public PublicKey parsePublicKey(byte[] encodedKey)
throws GeneralSecurityException {
// The validation procedure comes from SEC 1, section 3.2.2.1. Note
// that SEC 1 parameter names are used below, not RFC 5639 names
if(encodedKey.length != publicKeyBytes)
throw new GeneralSecurityException();
// The first byte must be 0x04
if(encodedKey[0] != 4) throw new GeneralSecurityException();
// The x co-ordinate must be >= 0 and < p
byte[] xBytes = new byte[bytesPerInt];
System.arraycopy(encodedKey, 1, xBytes, 0, bytesPerInt);
BigInteger x = new BigInteger(1, xBytes); // Positive signum
if(x.compareTo(modulus) >= 0) throw new GeneralSecurityException();
// The y co-ordinate must be >= 0 and < p
byte[] yBytes = new byte[bytesPerInt];
System.arraycopy(encodedKey, 1 + bytesPerInt, yBytes, 0, bytesPerInt);
BigInteger y = new BigInteger(1, yBytes); // Positive signum
if(y.compareTo(modulus) >= 0) throw new GeneralSecurityException();
// Verify that y^2 == x^3 + ax + b (mod p)
BigInteger a = params.getCurve().getA().toBigInteger();
BigInteger b = params.getCurve().getB().toBigInteger();
BigInteger lhs = y.multiply(y).mod(modulus);
BigInteger rhs = x.multiply(x).add(a).multiply(x).add(b).mod(modulus);
if(!lhs.equals(rhs)) throw new GeneralSecurityException();
// We know the point (x, y) is on the curve, so we can create the point
ECFieldElement elementX = new ECFieldElement.Fp(modulus, x);
ECFieldElement elementY = new ECFieldElement.Fp(modulus, y);
ECPoint pub = new ECPoint.Fp(params.getCurve(), elementX, elementY);
// Verify that the point (x, y) is not the point at infinity
if(pub.isInfinity()) throw new GeneralSecurityException();
// Verify that the point (x, y) times n is the point at infinity
if(!pub.multiply(params.getN()).isInfinity())
throw new GeneralSecurityException();
// Construct a public key from the point (x, y) and the params
ECPublicKeyParameters k = new ECPublicKeyParameters(pub, params);
return new Sec1PublicKey(k, keyBits);
}
public PrivateKey parsePrivateKey(byte[] encodedKey)
throws GeneralSecurityException {
if(encodedKey.length != privateKeyBytes)
throw new GeneralSecurityException();
BigInteger d = new BigInteger(1, encodedKey); // Positive signum
// Verify that the private value is < n
if(d.compareTo(params.getN()) >= 0)
throw new GeneralSecurityException();
// Construct a private key from the private value and the params
ECPrivateKeyParameters k = new ECPrivateKeyParameters(d, params);
return new Sec1PrivateKey(k, keyBits);
}
}

View File

@@ -0,0 +1,27 @@
package org.briarproject.crypto;
import org.briarproject.api.crypto.PrivateKey;
import org.spongycastle.crypto.params.ECPrivateKeyParameters;
class Sec1PrivateKey implements PrivateKey {
private final ECPrivateKeyParameters key;
private final int bytesPerInt;
Sec1PrivateKey(ECPrivateKeyParameters key, int keyBits) {
this.key = key;
bytesPerInt = (keyBits + 7) / 8;
}
public byte[] getEncoded() {
byte[] encodedKey = new byte[bytesPerInt];
byte[] d = key.getD().toByteArray();
Sec1Utils.convertToFixedLength(d, encodedKey, bytesPerInt, 0);
return encodedKey;
}
ECPrivateKeyParameters getKey() {
return key;
}
}

View File

@@ -0,0 +1,37 @@
package org.briarproject.crypto;
import org.briarproject.api.crypto.PublicKey;
import org.spongycastle.crypto.params.ECPublicKeyParameters;
/**
* An elliptic curve public key that uses the encoding defined in "SEC 1:
* Elliptic Curve Cryptography", section 2.3 (Certicom Corporation, May 2009).
* Point compression is not used.
*/
class Sec1PublicKey implements PublicKey {
private final ECPublicKeyParameters key;
private final int bytesPerInt, publicKeyBytes;
Sec1PublicKey(ECPublicKeyParameters key, int keyBits) {
this.key = key;
bytesPerInt = (keyBits + 7) / 8;
publicKeyBytes = 1 + 2 * bytesPerInt;
}
public byte[] getEncoded() {
byte[] encodedKey = new byte[publicKeyBytes];
encodedKey[0] = 4;
byte[] x = key.getQ().getX().toBigInteger().toByteArray();
Sec1Utils.convertToFixedLength(x, encodedKey, bytesPerInt, 1);
byte[] y = key.getQ().getY().toBigInteger().toByteArray();
Sec1Utils.convertToFixedLength(y, encodedKey, bytesPerInt,
1 + bytesPerInt);
return encodedKey;
}
ECPublicKeyParameters getKey() {
return key;
}
}

View File

@@ -0,0 +1,15 @@
package org.briarproject.crypto;
class Sec1Utils {
static void convertToFixedLength(byte[] src, byte[] dest, int destLen,
int destOff) {
if(src.length < destLen) {
destOff += destLen - src.length;
System.arraycopy(src, 0, dest, destOff, src.length);
} else {
int srcOff = src.length - destLen;
System.arraycopy(src, srcOff, dest, destOff, destLen);
}
}
}

View File

@@ -0,0 +1,30 @@
package org.briarproject.crypto;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.util.ByteUtils;
class SecretKeyImpl implements SecretKey {
private final byte[] key;
private boolean erased = false; // Locking: this
SecretKeyImpl(byte[] key) {
this.key = key;
}
public synchronized byte[] getEncoded() {
if(erased) throw new IllegalStateException();
return key;
}
public SecretKey copy() {
return new SecretKeyImpl(key.clone());
}
public synchronized void erase() {
if(erased) throw new IllegalStateException();
ByteUtils.erase(key);
erased = true;
}
}

View File

@@ -0,0 +1,58 @@
package org.briarproject.crypto;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import org.briarproject.api.crypto.PrivateKey;
import org.briarproject.api.crypto.PublicKey;
import org.briarproject.api.crypto.Signature;
import org.spongycastle.crypto.digests.SHA384Digest;
import org.spongycastle.crypto.params.ECPrivateKeyParameters;
import org.spongycastle.crypto.params.ECPublicKeyParameters;
import org.spongycastle.crypto.params.ParametersWithRandom;
import org.spongycastle.crypto.signers.DSADigestSigner;
import org.spongycastle.crypto.signers.ECDSASigner;
class SignatureImpl implements Signature {
private final SecureRandom secureRandom;
private final DSADigestSigner signer;
SignatureImpl(SecureRandom secureRandom) {
this.secureRandom = secureRandom;
signer = new DSADigestSigner(new ECDSASigner(), new SHA384Digest());
}
public void initSign(PrivateKey k) throws GeneralSecurityException {
if(!(k instanceof Sec1PrivateKey)) throw new GeneralSecurityException();
ECPrivateKeyParameters priv = ((Sec1PrivateKey) k).getKey();
signer.init(true, new ParametersWithRandom(priv, secureRandom));
}
public void initVerify(PublicKey k) throws GeneralSecurityException {
if(!(k instanceof Sec1PublicKey)) throw new GeneralSecurityException();
ECPublicKeyParameters pub = ((Sec1PublicKey) k).getKey();
signer.init(false, pub);
}
public void update(byte b) {
signer.update(b);
}
public void update(byte[] b) {
update(b, 0, b.length);
}
public void update(byte[] b, int off, int len) {
signer.update(b, off, len);
}
public byte[] sign() {
return signer.generateSignature();
}
public boolean verify(byte[] signature) {
return signer.verifySignature(signature);
}
}

View File

@@ -0,0 +1,797 @@
package org.briarproject.db;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorId;
import org.briarproject.api.Contact;
import org.briarproject.api.ContactId;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportConfig;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.db.DbException;
import org.briarproject.api.db.MessageHeader;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupId;
import org.briarproject.api.messaging.GroupStatus;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.transport.Endpoint;
import org.briarproject.api.transport.TemporarySecret;
// FIXME: Document the preconditions for calling each method
/**
* A low-level interface to the database (DatabaseComponent provides a
* high-level interface). Most operations take a transaction argument, which is
* obtained by calling {@link #startTransaction()}. Every transaction must be
* terminated by calling either {@link #abortTransaction(T)} or
* {@link #commitTransaction(T)}, even if an exception is thrown.
* <p>
* Locking is provided by the DatabaseComponent implementation. To prevent
* deadlock, locks must be acquired in the following (alphabetical) order:
* <ul>
* <li> contact
* <li> identity
* <li> message
* <li> retention
* <li> subscription
* <li> transport
* <li> window
* </ul>
* If table A has a foreign key pointing to table B, we get a read lock on A to
* read A, a write lock on A to write A, and write locks on A and B to write B.
*/
interface Database<T> {
/** Opens the database and returns true if the database already existed. */
boolean open() throws DbException, IOException;
/**
* Prevents new transactions from starting, waits for all current
* transactions to finish, and closes the database.
*/
void close() throws DbException, IOException;
/** Starts a new transaction and returns an object representing it. */
T startTransaction() throws DbException;
/**
* Aborts the given transaction - no changes made during the transaction
* will be applied to the database.
*/
void abortTransaction(T txn);
/**
* Commits the given transaction - all changes made during the transaction
* will be applied to the database.
*/
void commitTransaction(T txn) throws DbException;
/**
* Stores a contact associated with the given local and remote pseudonyms,
* and returns an ID for the contact.
* <p>
* Locking: contact write, message write, retention write,
* subscription write, transport write, window write.
*/
ContactId addContact(T txn, Author remote, AuthorId local)
throws DbException;
/**
* Stores an endpoint.
* <p>
* Locking: window write.
*/
void addEndpoint(T txn, Endpoint ep) throws DbException;
/**
* Subscribes to a group, or returns false if the user already has the
* maximum number of subscriptions.
* <p>
* Locking: message write, subscription write.
*/
boolean addGroup(T txn, Group g) throws DbException;
/**
* Stores a local pseudonym.
* <p>
* Locking: contact write, identity write, message write, retention write,
* subscription write, transport write, window write.
*/
void addLocalAuthor(T txn, LocalAuthor a) throws DbException;
/**
* Stores a message.
* <p>
* Locking: message write.
*/
void addMessage(T txn, Message m, boolean incoming) throws DbException;
/**
* Records that a message has been offered by the given contact.
* <p>
* Locking: message write.
*/
void addOfferedMessage(T txn, ContactId c, MessageId m) throws DbException;
/**
* Stores the given temporary secrets and deletes any secrets that have
* been made obsolete.
* <p>
* Locking: window write.
*/
void addSecrets(T txn, Collection<TemporarySecret> secrets)
throws DbException;
/**
* Initialises the status of the given message with respect to the given
* contact.
* @param ack whether the message needs to be acknowledged.
* @param seen whether the contact has seen the message.
* <p>
* Locking: message write.
*/
void addStatus(T txn, ContactId c, MessageId m, boolean ack, boolean seen)
throws DbException;
/**
* Stores a transport and returns true if the transport was not previously
* in the database.
* <p>
* Locking: transport write, window write.
*/
boolean addTransport(T txn, TransportId t, long maxLatency)
throws DbException;
/**
* Makes a group visible to the given contact.
* <p>
* Locking: subscription write.
*/
void addVisibility(T txn, ContactId c, GroupId g) throws DbException;
/**
* Returns true if the database contains the given contact.
* <p>
* Locking: contact read.
*/
boolean containsContact(T txn, AuthorId a) throws DbException;
/**
* Returns true if the database contains the given contact.
* <p>
* Locking: contact read.
*/
boolean containsContact(T txn, ContactId c) throws DbException;
/**
* Returns true if the user subscribes to the given group.
* <p>
* Locking: subscription read.
*/
boolean containsGroup(T txn, GroupId g) throws DbException;
/**
* Returns true if the database contains the given local pseudonym.
* <p>
* Locking: identity read.
*/
boolean containsLocalAuthor(T txn, AuthorId a) throws DbException;
/**
* Returns true if the database contains the given message.
* <p>
* Locking: message read.
*/
boolean containsMessage(T txn, MessageId m) throws DbException;
/**
* Returns true if the database contains the given transport.
* <p>
* Locking: transport read.
*/
boolean containsTransport(T txn, TransportId t) throws DbException;
/**
* Returns true if the user subscribes to the given group and the group is
* visible to the given contact.
* <p>
* Locking: subscription read.
*/
boolean containsVisibleGroup(T txn, ContactId c, GroupId g)
throws DbException;
/**
* Returns true if the database contains the given message and the message
* is visible to the given contact.
* <p>
* Locking: message read, subscription read.
*/
boolean containsVisibleMessage(T txn, ContactId c, MessageId m)
throws DbException;
/**
* Returns the number of messages offered by the given contact.
* <p>
* Locking: message read.
*/
int countOfferedMessages(T txn, ContactId c) throws DbException;
/**
* Returns the status of all groups to which the user subscribes or can
* subscribe, excluding inbox groups.
* <p>
* Locking: subscription read.
*/
Collection<GroupStatus> getAvailableGroups(T txn) throws DbException;
/**
* Returns the configuration for the given transport.
* <p>
* Locking: transport read.
*/
TransportConfig getConfig(T txn, TransportId t) throws DbException;
/**
* Returns the contact with the given ID.
* <p>
* Locking: contact read.
*/
Contact getContact(T txn, ContactId c) throws DbException;
/**
* Returns the IDs of all contacts.
* <p>
* Locking: contact read.
*/
Collection<ContactId> getContactIds(T txn) throws DbException;
/**
* Returns all contacts.
* <p>
* Locking: contact read, window read.
*/
Collection<Contact> getContacts(T txn) throws DbException;
/**
* Returns all contacts associated with the given local pseudonym.
* <p>
* Locking: contact read.
*/
Collection<ContactId> getContacts(T txn, AuthorId a) throws DbException;
/**
* Returns all endpoints.
* <p>
* Locking: window read.
*/
Collection<Endpoint> getEndpoints(T txn) throws DbException;
/**
* Returns the amount of free storage space available to the database, in
* bytes. This is based on the minimum of the space available on the device
* where the database is stored and the database's configured size.
*/
long getFreeSpace() throws DbException;
/**
* Returns the group with the given ID, if the user subscribes to it.
* <p>
* Locking: subscription read.
*/
Group getGroup(T txn, GroupId g) throws DbException;
/**
* Returns all groups to which the user subscribes.
* <p>
* Locking: subscription read.
*/
Collection<Group> getGroups(T txn) throws DbException;
/**
* Returns the ID of the inbox group for the given contact, or null if no
* inbox group has been set.
* <p>
* Locking: contact read, subscription read.
*/
GroupId getInboxGroupId(T txn, ContactId c) throws DbException;
/**
* Returns the headers of all messages in the inbox group for the given
* contact, or null if no inbox group has been set.
* <p>
* Locking: contact read, identity read, message read, subscription read.
*/
Collection<MessageHeader> getInboxMessageHeaders(T txn, ContactId c)
throws DbException;
/**
* Returns the time at which a connection to each contact was last opened
* or closed.
* <p>
* Locking: window read.
*/
Map<ContactId, Long> getLastConnected(T txn) throws DbException;
/**
* Returns the local pseudonym with the given ID.
* <p>
* Locking: identity read.
*/
LocalAuthor getLocalAuthor(T txn, AuthorId a) throws DbException;
/**
* Returns all local pseudonyms.
* <p>
* Locking: identity read.
*/
Collection<LocalAuthor> getLocalAuthors(T txn) throws DbException;
/**
* Returns the local transport properties for all transports.
* <p>
* Locking: transport read.
*/
Map<TransportId, TransportProperties> getLocalProperties(T txn)
throws DbException;
/**
* Returns the local transport properties for the given transport.
* <p>
* Locking: transport read.
*/
TransportProperties getLocalProperties(T txn, TransportId t)
throws DbException;
/**
* Returns the body of the message identified by the given ID.
* <p>
* Locking: message read.
*/
byte[] getMessageBody(T txn, MessageId m) throws DbException;
/**
* Returns the headers of all messages in the given group.
* <p>
* Locking: message read.
*/
Collection<MessageHeader> getMessageHeaders(T txn, GroupId g)
throws DbException;
/**
* Returns the IDs of some messages received from the given contact that
* need to be acknowledged, up to the given number of messages.
* <p>
* Locking: message read.
*/
Collection<MessageId> getMessagesToAck(T txn, ContactId c, int maxMessages)
throws DbException;
/**
* Returns the IDs of some messages that are eligible to be offered to the
* given contact, up to the given number of messages.
* <p>
* Locking: message read, subscription read.
*/
Collection<MessageId> getMessagesToOffer(T txn, ContactId c,
int maxMessages) throws DbException;
/**
* Returns the IDs of some messages that are eligible to be sent to the
* given contact, up to the given total length.
* <p>
* Locking: message read, subscription read.
*/
Collection<MessageId> getMessagesToSend(T txn, ContactId c, int maxLength)
throws DbException;
/**
* Returns the IDs of some messages that are eligible to be requested from
* the given contact, up to the given number of messages.
* <p>
* Locking: message read.
*/
Collection<MessageId> getMessagesToRequest(T txn, ContactId c,
int maxMessages) throws DbException;
/**
* Returns the IDs of the oldest messages in the database, with a total
* size less than or equal to the given size.
* <p>
* Locking: message read.
*/
Collection<MessageId> getOldMessages(T txn, int size) throws DbException;
/**
* Returns the parent of the given message, or null if either the message
* has no parent, or the parent is absent from the database, or the parent
* belongs to a different group.
* <p>
* Locking: message read.
*/
MessageId getParent(T txn, MessageId m) throws DbException;
/**
* Returns the message identified by the given ID, in serialised form.
* <p>
* Locking: message read.
*/
byte[] getRawMessage(T txn, MessageId m) throws DbException;
/**
* Returns true if the given message is marked as read.
* <p>
* Locking: message read.
*/
boolean getReadFlag(T txn, MessageId m) throws DbException;
/**
* Returns all remote properties for the given transport.
* <p>
* Locking: transport read.
*/
Map<ContactId, TransportProperties> getRemoteProperties(T txn,
TransportId t) throws DbException;
/**
* Returns the IDs of some messages that are eligible to be sent to the
* given contact and have been requested by the contact, up to the given
* total length.
* <p>
* Locking: message read, subscription read.
*/
Collection<MessageId> getRequestedMessagesToSend(T txn, ContactId c,
int maxLength) throws DbException;
/**
* Returns a retention ack for the given contact, or null if no ack is due.
* <p>
* Locking: retention write.
*/
RetentionAck getRetentionAck(T txn, ContactId c) throws DbException;
/**
* Returns a retention update for the given contact and updates its expiry
* time using the given latency, or returns null if no update is due.
* <p>
* Locking: message read, retention write.
*/
RetentionUpdate getRetentionUpdate(T txn, ContactId c, long maxLatency)
throws DbException;
/**
* Returns all temporary secrets.
* <p>
* Locking: window read.
*/
Collection<TemporarySecret> getSecrets(T txn) throws DbException;
/**
* Returns a subscription ack for the given contact, or null if no ack is
* due.
* <p>
* Locking: subscription write.
*/
SubscriptionAck getSubscriptionAck(T txn, ContactId c) throws DbException;
/**
* Returns a subscription update for the given contact and updates its
* expiry time using the given latency, or returns null if no update is due.
* <p>
* Locking: subscription write.
*/
SubscriptionUpdate getSubscriptionUpdate(T txn, ContactId c,
long maxLatency) throws DbException;
/**
* Returns a collection of transport acks for the given contact, or null if
* no acks are due.
* <p>
* Locking: transport write.
*/
Collection<TransportAck> getTransportAcks(T txn, ContactId c)
throws DbException;
/**
* Returns the maximum latencies of all local transports.
* <p>
* Locking: transport read.
*/
Map<TransportId, Long> getTransportLatencies(T txn) throws DbException;
/**
* Returns a collection of transport updates for the given contact and
* updates their expiry times using the given latency, or returns null if
* no updates are due.
* <p>
* Locking: transport write.
*/
Collection<TransportUpdate> getTransportUpdates(T txn, ContactId c,
long maxLatency) throws DbException;
/**
* Returns the number of unread messages in each subscribed group.
* <p>
* Locking: message read.
*/
Map<GroupId, Integer> getUnreadMessageCounts(T txn) throws DbException;
/**
* Returns the IDs of all contacts to which the given group is visible.
* <p>
* Locking: subscription read.
*/
Collection<ContactId> getVisibility(T txn, GroupId g) throws DbException;
/**
* Increments the outgoing connection counter for the given endpoint
* in the given rotation period and returns the old value, or -1 if the
* counter does not exist.
* <p>
* Locking: window write.
*/
long incrementConnectionCounter(T txn, ContactId c, TransportId t,
long period) throws DbException;
/**
* Increments the retention time versions for all contacts to indicate that
* the database's retention time has changed and updates should be sent.
* <p>
* Locking: retention write.
*/
void incrementRetentionVersions(T txn) throws DbException;
/**
* Marks the given messages as not needing to be acknowledged to the
* given contact.
* <p>
* Locking: message write.
*/
void lowerAckFlag(T txn, ContactId c, Collection<MessageId> acked)
throws DbException;
/**
* Marks the given messages as not having been requested by the given
* contact.
* <p>
* Locking: message write.
*/
void lowerRequestedFlag(T txn, ContactId c, Collection<MessageId> requested)
throws DbException;
/**
* Merges the given configuration with the existing configuration for the
* given transport.
* <p>
* Locking: transport write.
*/
void mergeConfig(T txn, TransportId t, TransportConfig config)
throws DbException;
/**
* Merges the given properties with the existing local properties for the
* given transport.
* <p>
* Locking: transport write.
*/
void mergeLocalProperties(T txn, TransportId t, TransportProperties p)
throws DbException;
/**
* Marks a message as needing to be acknowledged to the given contact.
* <p>
* Locking: message write.
*/
void raiseAckFlag(T txn, ContactId c, MessageId m) throws DbException;
/**
* Marks a message as having been requested by the given contact.
* <p>
* Locking: message write.
*/
void raiseRequestedFlag(T txn, ContactId c, MessageId m) throws DbException;
/**
* Marks a message as having been seen by the given contact.
* <p>
* Locking: message write.
*/
void raiseSeenFlag(T txn, ContactId c, MessageId m) throws DbException;
/**
* Removes a contact from the database.
* <p>
* Locking: contact write, message write, retention write,
* subscription write, transport write, window write.
*/
void removeContact(T txn, ContactId c) throws DbException;
/**
* Unsubscribes from a group. Any messages belonging to the group are
* deleted from the database.
* <p>
* Locking: message write, subscription write.
*/
void removeGroup(T txn, GroupId g) throws DbException;
/**
* Removes a local pseudonym (and all associated contacts) from the
* database.
* <p>
* Locking: contact write, identity write, message write, retention write,
* subscription write, transport write, window write.
*/
void removeLocalAuthor(T txn, AuthorId a) throws DbException;
/**
* Removes a message (and all associated state) from the database.
* <p>
* Locking: message write.
*/
void removeMessage(T txn, MessageId m) throws DbException;
/**
* Removes an offered message that was offered by the given contact, or
* returns false if there is no such message.
* <p>
* Locking: message write.
*/
boolean removeOfferedMessage(T txn, ContactId c, MessageId m)
throws DbException;
/**
* Removes the given offered messages that were offered by the given
* contact.
* <p>
* Locking: message write.
*/
void removeOfferedMessages(T txn, ContactId c,
Collection<MessageId> requested) throws DbException;
/**
* Removes a transport (and all associated state) from the database.
* <p>
* Locking: transport write, window write.
*/
void removeTransport(T txn, TransportId t) throws DbException;
/**
* Makes a group invisible to the given contact.
* <p>
* Locking: subscription write.
*/
void removeVisibility(T txn, ContactId c, GroupId g) throws DbException;
/**
* Resets the transmission count and expiry time of the given message with
* respect to the given contact.
* <p>
* Locking: message write.
*/
void resetExpiryTime(T txn, ContactId c, MessageId m) throws DbException;
/**
* Sets the connection reordering window for the given endpoint in the
* given rotation period.
* <p>
* Locking: window write.
*/
void setConnectionWindow(T txn, ContactId c, TransportId t, long period,
long centre, byte[] bitmap) throws DbException;
/**
* Updates the groups to which the given contact subscribes and returns
* true, unless an update with an equal or higher version number has
* already been received from the contact.
* <p>
* Locking: subscription write.
*/
boolean setGroups(T txn, ContactId c, Collection<Group> groups,
long version) throws DbException;
/**
* Makes a group visible to the given contact, adds it to the contact's
* subscriptions, and sets it as the inbox group for the contact.
* <p>
* Locking: contact read, message write, subscription write.
*/
public void setInboxGroup(T txn, ContactId c, Group g) throws DbException;
/**
* Sets the time at which a connection to the given contact was last made.
* <p>
* Locking: window write.
*/
void setLastConnected(T txn, ContactId c, long now) throws DbException;
/**
* Marks a message as read or unread.
* <p>
* Locking: message write.
*/
void setReadFlag(T txn, MessageId m, boolean read) throws DbException;
/**
* Sets the remote transport properties for the given contact, replacing
* any existing properties.
* <p>
* Locking: transport write.
*/
void setRemoteProperties(T txn, ContactId c,
Map<TransportId, TransportProperties> p) throws DbException;
/**
* Updates the remote transport properties for the given contact and the
* given transport, replacing any existing properties, and returns true,
* unless an update with an equal or higher version number has already been
* received from the contact.
* <p>
* Locking: transport write.
*/
boolean setRemoteProperties(T txn, ContactId c, TransportId t,
TransportProperties p, long version) throws DbException;
/**
* Sets the retention time of the given contact's database and returns
* true, unless an update with an equal or higher version number has
* already been received from the contact.
* <p>
* Locking: retention write.
*/
boolean setRetentionTime(T txn, ContactId c, long retention, long version)
throws DbException;
/**
* Records a retention ack from the given contact for the given version,
* unless the contact has already acked an equal or higher version.
* <p>
* Locking: retention write.
*/
void setRetentionUpdateAcked(T txn, ContactId c, long version)
throws DbException;
/**
* Records a subscription ack from the given contact for the given version,
* unless the contact has already acked an equal or higher version.
* <p>
* Locking: subscription write.
*/
void setSubscriptionUpdateAcked(T txn, ContactId c, long version)
throws DbException;
/**
* Records a transport ack from the give contact for the given version,
* unless the contact has already acked an equal or higher version.
* <p>
* Locking: transport write.
*/
void setTransportUpdateAcked(T txn, ContactId c, TransportId t,
long version) throws DbException;
/**
* Makes a group visible or invisible to future contacts by default.
* <p>
* Locking: subscription write.
*/
void setVisibleToAll(T txn, GroupId g, boolean all) throws DbException;
/**
* Updates the transmission count and expiry time of the given message
* with respect to the given contact, using the latency of the transport
* over which it was sent.
* <p>
* Locking: message write.
*/
void updateExpiryTime(T txn, ContactId c, MessageId m, long maxLatency)
throws DbException;
}

View File

@@ -0,0 +1,34 @@
package org.briarproject.db;
import org.briarproject.api.db.DbException;
interface DatabaseCleaner {
/**
* Starts a new thread to monitor the amount of free storage space
* available to the database and expire old messages as necessary. The
* cleaner will pause for the given number of milliseconds between sweeps.
*/
void startCleaning(Callback callback, long msBetweenSweeps);
/** Tells the cleaner thread to exit. */
void stopCleaning();
interface Callback {
/**
* Checks how much free storage space is available to the database, and
* if necessary expires old messages until the free space is at least
* DatabaseConstants.MIN_FREE_SPACE. If the free space is less than
* DatabaseConstants.CRITICAL_FREE_SPACE and there are no more messages
* to expire, an Error will be thrown.
*/
void checkFreeSpaceAndClean() throws DbException;
/**
* Returns true if the amount of free storage space available to the
* database should be checked.
*/
boolean shouldCheckFreeSpace();
}
}

View File

@@ -0,0 +1,55 @@
package org.briarproject.db;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.util.TimerTask;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.db.DbClosedException;
import org.briarproject.api.db.DbException;
import org.briarproject.api.system.Timer;
class DatabaseCleanerImpl extends TimerTask implements DatabaseCleaner {
private static final Logger LOG =
Logger.getLogger(DatabaseCleanerImpl.class.getName());
private final Timer timer;
private volatile Callback callback = null;
@Inject
DatabaseCleanerImpl(Timer timer) {
this.timer = timer;
}
public void startCleaning(Callback callback, long msBetweenSweeps) {
this.callback = callback;
timer.scheduleAtFixedRate(this, 0, msBetweenSweeps);
}
public void stopCleaning() {
timer.cancel();
}
public void run() {
if(callback == null) throw new IllegalStateException();
try {
if(callback.shouldCheckFreeSpace()) {
if(LOG.isLoggable(INFO)) LOG.info("Checking free space");
callback.checkFreeSpaceAndClean();
}
} catch(DbClosedException e) {
if(LOG.isLoggable(INFO)) LOG.info("Database closed, exiting");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
throw new Error(e); // Kill the application
} catch(RuntimeException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
throw new Error(e); // Kill the application
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,45 @@
package org.briarproject.db;
interface DatabaseConstants {
/**
* The maximum number of offered messages from each contact that will be
* stored. If offers arrive more quickly than requests can be sent and this
* limit is reached, additional offers will not be stored.
*/
int MAX_OFFERED_MESSAGES = 1000;
// FIXME: These should be configurable
/**
* The minimum amount of space in bytes that should be kept free on the
* device where the database is stored. Whenever less than this much space
* is free, old messages will be expired from the database.
*/
long MIN_FREE_SPACE = 50 * 1024 * 1024; // 50 MiB
/**
* The minimum amount of space in bytes that must be kept free on the device
* where the database is stored. If less than this much space is free and
* there are no more messages to expire, an Error will be thrown.
*/
long CRITICAL_FREE_SPACE = 10 * 1024 * 1024; // 10 MiB
/**
* The amount of free space will be checked whenever this many bytes of
* messages have been added to the database since the last check.
*/
int MAX_BYTES_BETWEEN_SPACE_CHECKS = 1024 * 1024; // 1 MiB
/**
* The amount of free space will be checked whenever this many milliseconds
* have passed since the last check.
*/
long MAX_MS_BETWEEN_SPACE_CHECKS = 60 * 1000; // 1 min
/**
* Up to this many bytes of messages will be expired from the database each
* time it is necessary to expire messages.
*/
int BYTES_PER_SWEEP = 10 * 1024 * 1024; // 10 MiB
}

View File

@@ -0,0 +1,67 @@
package org.briarproject.db;
import static java.util.concurrent.TimeUnit.SECONDS;
import java.sql.Connection;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import javax.inject.Singleton;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DatabaseConfig;
import org.briarproject.api.db.DatabaseExecutor;
import org.briarproject.api.lifecycle.LifecycleManager;
import org.briarproject.api.lifecycle.ShutdownManager;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.FileUtils;
import org.briarproject.api.system.SystemClock;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
public class DatabaseModule extends AbstractModule {
/** The maximum number of executor threads. */
private static final int MAX_EXECUTOR_THREADS = 10;
private final ExecutorService databaseExecutor;
public DatabaseModule() {
// The queue is unbounded, so tasks can be dependent
BlockingQueue<Runnable> queue = new LinkedBlockingQueue<Runnable>();
// Discard tasks that are submitted during shutdown
RejectedExecutionHandler policy =
new ThreadPoolExecutor.DiscardPolicy();
// Create a limited # of threads and keep them in the pool for 60 secs
databaseExecutor = new ThreadPoolExecutor(0, MAX_EXECUTOR_THREADS,
60, SECONDS, queue, policy);
}
protected void configure() {
bind(DatabaseCleaner.class).to(DatabaseCleanerImpl.class);
}
@Provides
Database<Connection> getDatabase(DatabaseConfig config,
FileUtils fileUtils) {
return new H2Database(config, fileUtils, new SystemClock());
}
@Provides @Singleton
DatabaseComponent getDatabaseComponent(Database<Connection> db,
DatabaseCleaner cleaner, ShutdownManager shutdown, Clock clock) {
return new DatabaseComponentImpl<Connection>(db, cleaner, shutdown,
clock);
}
@Provides @Singleton @DatabaseExecutor
Executor getDatabaseExecutor(LifecycleManager lifecycleManager) {
lifecycleManager.registerForShutdown(databaseExecutor);
return databaseExecutor;
}
}

View File

@@ -0,0 +1,9 @@
package org.briarproject.db;
import java.sql.SQLException;
/** Thrown when the database is in an illegal state. */
class DbStateException extends SQLException {
private static final long serialVersionUID = 10793396057218891L;
}

View File

@@ -0,0 +1,30 @@
package org.briarproject.db;
class ExponentialBackoff {
/**
* Returns the expiry time of a packet transmitted at time <tt>now</tt>
* over a transport with maximum latency <tt>maxLatency</tt>, where the
* packet has previously been transmitted <tt>txCount</tt> times. All times
* are in milliseconds. The expiry time is
* <tt>now + maxLatency * 2 ^ (txCount + 1)</tt>, so the interval between
* transmissions increases exponentially. If the expiry time would
* be greater than Long.MAX_VALUE, Long.MAX_VALUE is returned.
*/
static long calculateExpiry(long now, long maxLatency, int txCount) {
if(now < 0) throw new IllegalArgumentException();
if(maxLatency <= 0) throw new IllegalArgumentException();
if(txCount < 0) throw new IllegalArgumentException();
// The maximum round-trip time is twice the maximum latency
long roundTrip = maxLatency * 2;
if(roundTrip < 0) return Long.MAX_VALUE;
// The interval between transmissions is roundTrip * 2 ^ txCount
for(int i = 0; i < txCount; i++) {
roundTrip <<= 1;
if(roundTrip < 0) return Long.MAX_VALUE;
}
// The expiry time is the current time plus the interval
long expiry = now + roundTrip;
return expiry < 0 ? Long.MAX_VALUE : expiry;
}
}

View File

@@ -0,0 +1,112 @@
package org.briarproject.db;
import java.io.File;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Properties;
import javax.inject.Inject;
import org.briarproject.api.db.DatabaseConfig;
import org.briarproject.api.db.DbException;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.FileUtils;
import org.briarproject.util.StringUtils;
/** Contains all the H2-specific code for the database. */
class H2Database extends JdbcDatabase {
private static final String HASH_TYPE = "BINARY(48)";
private static final String BINARY_TYPE = "BINARY";
private static final String COUNTER_TYPE = "INT NOT NULL AUTO_INCREMENT";
private static final String SECRET_TYPE = "BINARY(32)";
private final DatabaseConfig config;
private final FileUtils fileUtils;
private final String url;
@Inject
H2Database(DatabaseConfig config, FileUtils fileUtils, Clock clock) {
super(HASH_TYPE, BINARY_TYPE, COUNTER_TYPE, SECRET_TYPE, clock);
this.config = config;
this.fileUtils = fileUtils;
String path = new File(config.getDatabaseDirectory(), "db").getPath();
// FIXME: Remove WRITE_DELAY=0 after implementing BTPv2?
url = "jdbc:h2:split:" + path + ";CIPHER=AES;MULTI_THREADED=1"
+ ";WRITE_DELAY=0;DB_CLOSE_ON_EXIT=false";
}
public boolean open() throws DbException, IOException {
boolean reopen = config.databaseExists();
if(!reopen) config.getDatabaseDirectory().mkdirs();
super.open("org.h2.Driver", reopen);
return reopen;
}
public void close() throws DbException {
// H2 will close the database when the last connection closes
try {
super.closeAllConnections();
} catch(SQLException e) {
throw new DbException(e);
}
}
public long getFreeSpace() throws DbException {
File dir = config.getDatabaseDirectory();
long maxSize = config.getMaxSize();
try {
long free = fileUtils.getFreeSpace(dir);
long used = getDiskSpace(dir);
long quota = maxSize - used;
long min = Math.min(free, quota);
return min;
} catch(IOException e) {
throw new DbException(e);
}
}
private long getDiskSpace(File f) {
long total = 0;
if(f.isDirectory()) {
for(File child : f.listFiles()) total += getDiskSpace(child);
return total;
} else return f.length();
}
protected Connection createConnection() throws SQLException {
byte[] key = config.getEncryptionKey();
if(key == null) throw new IllegalStateException();
char[] password = encodePassword(key);
Properties props = new Properties();
props.setProperty("user", "user");
props.put("password", password);
try {
return DriverManager.getConnection(url, props);
} finally {
for(int i = 0; i < password.length; i++) password[i] = 0;
}
}
private char[] encodePassword(byte[] key) {
// The database password is the hex-encoded key
char[] hex = StringUtils.toHexChars(key);
// Separate the database password from the user password with a space
char[] user = "password".toCharArray();
char[] combined = new char[hex.length + 1 + user.length];
System.arraycopy(hex, 0, combined, 0, hex.length);
combined[hex.length] = ' ';
System.arraycopy(user, 0, combined, hex.length + 1, user.length);
// Erase the hex-encoded key
for(int i = 0; i < hex.length; i++) hex[i] = 0;
return combined;
}
protected void flushBuffersToDisk(Statement s) throws SQLException {
// FIXME: Remove this after implementing BTPv2?
s.execute("CHECKPOINT SYNC");
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,186 @@
package org.briarproject.invitation;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.Map;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.ConnectionDispatcher;
import org.briarproject.api.transport.ConnectionReader;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionWriter;
import org.briarproject.api.transport.ConnectionWriterFactory;
/** A connection thread for the peer being Alice in the invitation protocol. */
class AliceConnector extends Connector {
private static final Logger LOG =
Logger.getLogger(AliceConnector.class.getName());
AliceConnector(CryptoComponent crypto, DatabaseComponent db,
ReaderFactory readerFactory, WriterFactory writerFactory,
ConnectionReaderFactory connectionReaderFactory,
ConnectionWriterFactory connectionWriterFactory,
AuthorFactory authorFactory, GroupFactory groupFactory,
KeyManager keyManager, ConnectionDispatcher connectionDispatcher,
Clock clock, ConnectorGroup group, DuplexPlugin plugin,
LocalAuthor localAuthor,
Map<TransportId, TransportProperties> localProps,
PseudoRandom random) {
super(crypto, db, readerFactory, writerFactory, connectionReaderFactory,
connectionWriterFactory, authorFactory, groupFactory,
keyManager, connectionDispatcher, clock, group, plugin,
localAuthor, localProps, random);
}
@Override
public void run() {
// Create an incoming or outgoing connection
DuplexTransportConnection conn = createInvitationConnection();
if(conn == null) return;
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " connected");
// Don't proceed with more than one connection
if(group.getAndSetConnected()) {
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " redundant");
tryToClose(conn, false);
return;
}
// Carry out the key agreement protocol
InputStream in;
OutputStream out;
Reader r;
Writer w;
byte[] secret;
try {
in = conn.getInputStream();
out = conn.getOutputStream();
r = readerFactory.createReader(in);
w = writerFactory.createWriter(out);
// Alice goes first
sendPublicKeyHash(w);
byte[] hash = receivePublicKeyHash(r);
sendPublicKey(w);
byte[] key = receivePublicKey(r);
secret = deriveMasterSecret(hash, key, true);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.keyAgreementFailed();
tryToClose(conn, true);
return;
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.keyAgreementFailed();
tryToClose(conn, true);
return;
}
// The key agreement succeeded - derive the confirmation codes
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " agreement succeeded");
int[] codes = crypto.deriveConfirmationCodes(secret);
int aliceCode = codes[0], bobCode = codes[1];
group.keyAgreementSucceeded(aliceCode, bobCode);
// Exchange confirmation results
boolean localMatched, remoteMatched;
try {
localMatched = group.waitForLocalConfirmationResult();
sendConfirmation(w, localMatched);
remoteMatched = receiveConfirmation(r);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.remoteConfirmationFailed();
tryToClose(conn, true);
return;
} catch(InterruptedException e) {
if(LOG.isLoggable(WARNING))
LOG.warning("Interrupted while waiting for confirmation");
group.remoteConfirmationFailed();
tryToClose(conn, true);
Thread.currentThread().interrupt();
return;
}
if(remoteMatched) group.remoteConfirmationSucceeded();
else group.remoteConfirmationFailed();
if(!(localMatched && remoteMatched)) {
tryToClose(conn, false);
return;
}
// The timestamp is taken after exhanging confirmation results
long localTimestamp = clock.currentTimeMillis();
// Confirmation succeeded - upgrade to a secure connection
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " confirmation succeeded");
int maxFrameLength = conn.getMaxFrameLength();
ConnectionReader connectionReader =
connectionReaderFactory.createInvitationConnectionReader(in,
maxFrameLength, secret, false);
r = readerFactory.createReader(connectionReader.getInputStream());
ConnectionWriter connectionWriter =
connectionWriterFactory.createInvitationConnectionWriter(out,
maxFrameLength, secret, true);
w = writerFactory.createWriter(connectionWriter.getOutputStream());
// Derive the invitation nonces
byte[][] nonces = crypto.deriveInvitationNonces(secret);
byte[] aliceNonce = nonces[0], bobNonce = nonces[1];
// Exchange pseudonyms, signed nonces, timestamps and transports
Author remoteAuthor;
long remoteTimestamp;
Map<TransportId, TransportProperties> remoteProps;
try {
sendPseudonym(w, aliceNonce);
sendTimestamp(w, localTimestamp);
sendTransportProperties(w);
remoteAuthor = receivePseudonym(r, bobNonce);
remoteTimestamp = receiveTimestamp(r);
remoteProps = receiveTransportProperties(r);
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.pseudonymExchangeFailed();
tryToClose(conn, true);
return;
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.pseudonymExchangeFailed();
tryToClose(conn, true);
return;
}
// The epoch is the minimum of the peers' timestamps
long epoch = Math.min(localTimestamp, remoteTimestamp);
// Add the contact and store the transports
try {
addContact(remoteAuthor, remoteProps, secret, epoch, true);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
tryToClose(conn, true);
group.pseudonymExchangeFailed();
return;
}
// Pseudonym exchange succeeded
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " pseudonym exchange succeeded");
group.pseudonymExchangeSucceeded(remoteAuthor);
// Reuse the connection as an outgoing BTP connection
reuseConnection(conn, true);
}
}

View File

@@ -0,0 +1,186 @@
package org.briarproject.invitation;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.Map;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.ConnectionDispatcher;
import org.briarproject.api.transport.ConnectionReader;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionWriter;
import org.briarproject.api.transport.ConnectionWriterFactory;
/** A connection thread for the peer being Bob in the invitation protocol. */
class BobConnector extends Connector {
private static final Logger LOG =
Logger.getLogger(BobConnector.class.getName());
BobConnector(CryptoComponent crypto, DatabaseComponent db,
ReaderFactory readerFactory, WriterFactory writerFactory,
ConnectionReaderFactory connectionReaderFactory,
ConnectionWriterFactory connectionWriterFactory,
AuthorFactory authorFactory, GroupFactory groupFactory,
KeyManager keyManager, ConnectionDispatcher connectionDispatcher,
Clock clock, ConnectorGroup group, DuplexPlugin plugin,
LocalAuthor localAuthor,
Map<TransportId, TransportProperties> localProps,
PseudoRandom random) {
super(crypto, db, readerFactory, writerFactory, connectionReaderFactory,
connectionWriterFactory, authorFactory, groupFactory,
keyManager, connectionDispatcher, clock, group, plugin,
localAuthor, localProps, random);
}
@Override
public void run() {
// Create an incoming or outgoing connection
DuplexTransportConnection conn = createInvitationConnection();
if(conn == null) return;
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " connected");
// Carry out the key agreement protocol
InputStream in;
OutputStream out;
Reader r;
Writer w;
byte[] secret;
try {
in = conn.getInputStream();
out = conn.getOutputStream();
r = readerFactory.createReader(in);
w = writerFactory.createWriter(out);
// Alice goes first
byte[] hash = receivePublicKeyHash(r);
// Don't proceed with more than one connection
if(group.getAndSetConnected()) {
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " redundant");
tryToClose(conn, false);
return;
}
sendPublicKeyHash(w);
byte[] key = receivePublicKey(r);
sendPublicKey(w);
secret = deriveMasterSecret(hash, key, false);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.keyAgreementFailed();
tryToClose(conn, true);
return;
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.keyAgreementFailed();
tryToClose(conn, true);
return;
}
// The key agreement succeeded - derive the confirmation codes
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " agreement succeeded");
int[] codes = crypto.deriveConfirmationCodes(secret);
int aliceCode = codes[0], bobCode = codes[1];
group.keyAgreementSucceeded(bobCode, aliceCode);
// Exchange confirmation results
boolean localMatched, remoteMatched;
try {
remoteMatched = receiveConfirmation(r);
localMatched = group.waitForLocalConfirmationResult();
sendConfirmation(w, localMatched);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.remoteConfirmationFailed();
tryToClose(conn, true);
return;
} catch(InterruptedException e) {
if(LOG.isLoggable(WARNING))
LOG.warning("Interrupted while waiting for confirmation");
group.remoteConfirmationFailed();
tryToClose(conn, true);
Thread.currentThread().interrupt();
return;
}
if(remoteMatched) group.remoteConfirmationSucceeded();
else group.remoteConfirmationFailed();
if(!(localMatched && remoteMatched)) {
tryToClose(conn, false);
return;
}
// The timestamp is taken after exhanging confirmation results
long localTimestamp = clock.currentTimeMillis();
// Confirmation succeeded - upgrade to a secure connection
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " confirmation succeeded");
int maxFrameLength = conn.getMaxFrameLength();
ConnectionReader connectionReader =
connectionReaderFactory.createInvitationConnectionReader(in,
maxFrameLength, secret, true);
r = readerFactory.createReader(connectionReader.getInputStream());
ConnectionWriter connectionWriter =
connectionWriterFactory.createInvitationConnectionWriter(out,
maxFrameLength, secret, false);
w = writerFactory.createWriter(connectionWriter.getOutputStream());
// Derive the nonces
byte[][] nonces = crypto.deriveInvitationNonces(secret);
byte[] aliceNonce = nonces[0], bobNonce = nonces[1];
// Exchange pseudonyms, signed nonces, timestamps and transports
Author remoteAuthor;
long remoteTimestamp;
Map<TransportId, TransportProperties> remoteProps;
try {
remoteAuthor = receivePseudonym(r, aliceNonce);
remoteTimestamp = receiveTimestamp(r);
remoteProps = receiveTransportProperties(r);
sendPseudonym(w, bobNonce);
sendTimestamp(w, localTimestamp);
sendTransportProperties(w);
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.pseudonymExchangeFailed();
tryToClose(conn, true);
return;
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
group.pseudonymExchangeFailed();
tryToClose(conn, true);
return;
}
// The epoch is the minimum of the peers' timestamps
long epoch = Math.min(localTimestamp, remoteTimestamp);
// Add the contact and store the transports
try {
addContact(remoteAuthor, remoteProps, secret, epoch, false);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
tryToClose(conn, true);
group.pseudonymExchangeFailed();
return;
}
// Pseudonym exchange succeeded
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " pseudonym exchange succeeded");
group.pseudonymExchangeSucceeded(remoteAuthor);
// Reuse the connection as an incoming BTP connection
reuseConnection(conn, false);
}
}

View File

@@ -0,0 +1,342 @@
package org.briarproject.invitation;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.AuthorConstants.MAX_AUTHOR_NAME_LENGTH;
import static org.briarproject.api.AuthorConstants.MAX_SIGNATURE_LENGTH;
import static org.briarproject.api.TransportPropertyConstants.MAX_PROPERTIES_PER_TRANSPORT;
import static org.briarproject.api.TransportPropertyConstants.MAX_PROPERTY_LENGTH;
import static org.briarproject.api.invitation.InvitationConstants.CONNECTION_TIMEOUT;
import static org.briarproject.api.invitation.InvitationConstants.HASH_LENGTH;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorConstants;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.ContactId;
import org.briarproject.api.FormatException;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.UniqueId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.KeyPair;
import org.briarproject.api.crypto.KeyParser;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.crypto.Signature;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.db.NoSuchTransportException;
import org.briarproject.api.invitation.InvitationConstants;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.ConnectionDispatcher;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionWriterFactory;
import org.briarproject.api.transport.Endpoint;
abstract class Connector extends Thread {
private static final Logger LOG =
Logger.getLogger(Connector.class.getName());
protected final CryptoComponent crypto;
protected final DatabaseComponent db;
protected final ReaderFactory readerFactory;
protected final WriterFactory writerFactory;
protected final ConnectionReaderFactory connectionReaderFactory;
protected final ConnectionWriterFactory connectionWriterFactory;
protected final AuthorFactory authorFactory;
protected final GroupFactory groupFactory;
protected final KeyManager keyManager;
protected final ConnectionDispatcher connectionDispatcher;
protected final Clock clock;
protected final ConnectorGroup group;
protected final DuplexPlugin plugin;
protected final LocalAuthor localAuthor;
protected final Map<TransportId, TransportProperties> localProps;
protected final PseudoRandom random;
protected final String pluginName;
private final KeyPair keyPair;
private final KeyParser keyParser;
private final MessageDigest messageDigest;
private volatile ContactId contactId = null;
Connector(CryptoComponent crypto, DatabaseComponent db,
ReaderFactory readerFactory, WriterFactory writerFactory,
ConnectionReaderFactory connectionReaderFactory,
ConnectionWriterFactory connectionWriterFactory,
AuthorFactory authorFactory, GroupFactory groupFactory,
KeyManager keyManager, ConnectionDispatcher connectionDispatcher,
Clock clock, ConnectorGroup group, DuplexPlugin plugin,
LocalAuthor localAuthor,
Map<TransportId, TransportProperties> localProps,
PseudoRandom random) {
super("Connector");
this.crypto = crypto;
this.db = db;
this.readerFactory = readerFactory;
this.writerFactory = writerFactory;
this.connectionReaderFactory = connectionReaderFactory;
this.connectionWriterFactory = connectionWriterFactory;
this.authorFactory = authorFactory;
this.groupFactory = groupFactory;
this.keyManager = keyManager;
this.connectionDispatcher = connectionDispatcher;
this.clock = clock;
this.group = group;
this.plugin = plugin;
this.localAuthor = localAuthor;
this.localProps = localProps;
this.random = random;
pluginName = plugin.getClass().getName();
keyPair = crypto.generateAgreementKeyPair();
keyParser = crypto.getAgreementKeyParser();
messageDigest = crypto.getMessageDigest();
}
protected DuplexTransportConnection createInvitationConnection() {
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " creating invitation connection");
return plugin.createInvitationConnection(random, CONNECTION_TIMEOUT);
}
protected void sendPublicKeyHash(Writer w) throws IOException {
w.writeBytes(messageDigest.digest(keyPair.getPublic().getEncoded()));
w.flush();
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " sent hash");
}
protected byte[] receivePublicKeyHash(Reader r) throws IOException {
byte[] b = r.readBytes(HASH_LENGTH);
if(b.length < HASH_LENGTH) throw new FormatException();
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " received hash");
return b;
}
protected void sendPublicKey(Writer w) throws IOException {
byte[] key = keyPair.getPublic().getEncoded();
w.writeBytes(key);
w.flush();
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " sent key");
}
protected byte[] receivePublicKey(Reader r) throws GeneralSecurityException,
IOException {
byte[] b = r.readBytes(InvitationConstants.MAX_PUBLIC_KEY_LENGTH);
keyParser.parsePublicKey(b);
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " received key");
return b;
}
protected byte[] deriveMasterSecret(byte[] hash, byte[] key, boolean alice)
throws GeneralSecurityException {
// Check that the hash matches the key
if(!Arrays.equals(hash, messageDigest.digest(key))) {
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " hash does not match key");
throw new GeneralSecurityException();
}
// Derive the master secret
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " deriving master secret");
return crypto.deriveMasterSecret(key, keyPair, alice);
}
protected void sendConfirmation(Writer w, boolean matched)
throws IOException {
w.writeBoolean(matched);
w.flush();
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " sent confirmation: " + matched);
}
protected boolean receiveConfirmation(Reader r) throws IOException {
boolean matched = r.readBoolean();
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " received confirmation: " + matched);
return matched;
}
protected void sendPseudonym(Writer w, byte[] nonce)
throws GeneralSecurityException, IOException {
// Sign the nonce
Signature signature = crypto.getSignature();
KeyParser keyParser = crypto.getSignatureKeyParser();
byte[] privateKey = localAuthor.getPrivateKey();
signature.initSign(keyParser.parsePrivateKey(privateKey));
signature.update(nonce);
byte[] sig = signature.sign();
// Write the name, public key and signature
w.writeString(localAuthor.getName());
w.writeBytes(localAuthor.getPublicKey());
w.writeBytes(sig);
w.flush();
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " sent pseudonym");
}
protected Author receivePseudonym(Reader r, byte[] nonce)
throws GeneralSecurityException, IOException {
// Read the name, public key and signature
String name = r.readString(MAX_AUTHOR_NAME_LENGTH);
byte[] publicKey = r.readBytes(AuthorConstants.MAX_PUBLIC_KEY_LENGTH);
byte[] sig = r.readBytes(MAX_SIGNATURE_LENGTH);
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " received pseudonym");
// Verify the signature
Signature signature = crypto.getSignature();
KeyParser keyParser = crypto.getSignatureKeyParser();
signature.initVerify(keyParser.parsePublicKey(publicKey));
signature.update(nonce);
if(!signature.verify(sig)) {
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " invalid signature");
throw new GeneralSecurityException();
}
return authorFactory.createAuthor(name, publicKey);
}
protected void sendTimestamp(Writer w, long timestamp) throws IOException {
w.writeInt64(timestamp);
w.flush();
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " sent timestamp");
}
protected long receiveTimestamp(Reader r) throws IOException {
long timestamp = r.readInt64();
if(timestamp < 0) throw new FormatException();
if(LOG.isLoggable(INFO)) LOG.info(pluginName + " received timestamp");
return timestamp;
}
protected void sendTransportProperties(Writer w) throws IOException {
w.writeListStart();
for(Entry<TransportId, TransportProperties> e : localProps.entrySet()) {
w.writeBytes(e.getKey().getBytes());
w.writeMap(e.getValue());
}
w.writeListEnd();
w.flush();
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " sent transport properties");
}
protected Map<TransportId, TransportProperties> receiveTransportProperties(
Reader r) throws IOException {
Map<TransportId, TransportProperties> remoteProps =
new HashMap<TransportId, TransportProperties>();
r.readListStart();
while(!r.hasListEnd()) {
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH) throw new FormatException();
TransportId id = new TransportId(b);
Map<String, String> p = new HashMap<String, String>();
r.readMapStart();
for(int i = 0; !r.hasMapEnd(); i++) {
if(i == MAX_PROPERTIES_PER_TRANSPORT)
throw new FormatException();
String key = r.readString(MAX_PROPERTY_LENGTH);
String value = r.readString(MAX_PROPERTY_LENGTH);
p.put(key, value);
}
r.readMapEnd();
remoteProps.put(id, new TransportProperties(p));
}
r.readListEnd();
if(LOG.isLoggable(INFO))
LOG.info(pluginName + " received transport properties");
return remoteProps;
}
protected void addContact(Author remoteAuthor,
Map<TransportId, TransportProperties> remoteProps, byte[] secret,
long epoch, boolean alice) throws DbException {
// Add the contact to the database
contactId = db.addContact(remoteAuthor, localAuthor.getId());
// Create and store the inbox group
byte[] salt = crypto.deriveGroupSalt(secret);
Group inbox = groupFactory.createGroup("Inbox", salt);
db.addGroup(inbox);
db.setInboxGroup(contactId, inbox);
// Store the remote transport properties
db.setRemoteProperties(contactId, remoteProps);
// Create an endpoint for each transport shared with the contact
List<TransportId> ids = new ArrayList<TransportId>();
Map<TransportId, Long> latencies = db.getTransportLatencies();
for(TransportId id : localProps.keySet()) {
if(latencies.containsKey(id) && remoteProps.containsKey(id))
ids.add(id);
}
// Assign indices to the transports deterministically and derive keys
Collections.sort(ids, TransportIdComparator.INSTANCE);
int size = ids.size();
for(int i = 0; i < size; i++) {
TransportId id = ids.get(i);
Endpoint ep = new Endpoint(contactId, id, epoch, alice);
long maxLatency = latencies.get(id);
try {
db.addEndpoint(ep);
} catch(NoSuchTransportException e) {
continue;
}
byte[] initialSecret = crypto.deriveInitialSecret(secret, i);
keyManager.endpointAdded(ep, maxLatency, initialSecret);
}
}
protected void tryToClose(DuplexTransportConnection conn,
boolean exception) {
try {
if(LOG.isLoggable(INFO)) LOG.info("Closing connection");
conn.dispose(exception, true);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
protected void reuseConnection(DuplexTransportConnection conn,
boolean alice) {
if(contactId == null) throw new IllegalStateException();
TransportId t = plugin.getId();
if(alice)
connectionDispatcher.dispatchOutgoingConnection(contactId, t, conn);
else connectionDispatcher.dispatchIncomingConnection(t, conn);
}
private static class TransportIdComparator
implements Comparator<TransportId> {
private static final TransportIdComparator INSTANCE =
new TransportIdComparator();
public int compare(TransportId t1, TransportId t2) {
byte[] b1 = t1.getBytes(), b2 = t2.getBytes();
for(int i = 0; i < UniqueId.LENGTH; i++) {
if((b1[i] & 0xff) < (b2[i] & 0xff)) return -1;
if((b1[i] & 0xff) > (b2[i] & 0xff)) return 1;
}
return 0;
}
}
}

View File

@@ -0,0 +1,263 @@
package org.briarproject.invitation;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.invitation.InvitationConstants.CONFIRMATION_TIMEOUT;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.AuthorId;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.invitation.InvitationListener;
import org.briarproject.api.invitation.InvitationState;
import org.briarproject.api.invitation.InvitationTask;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.PluginManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.WriterFactory;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.ConnectionDispatcher;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionWriterFactory;
/** A task consisting of one or more parallel connection attempts. */
class ConnectorGroup extends Thread implements InvitationTask {
private static final Logger LOG =
Logger.getLogger(ConnectorGroup.class.getName());
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final ReaderFactory readerFactory;
private final WriterFactory writerFactory;
private final ConnectionReaderFactory connectionReaderFactory;
private final ConnectionWriterFactory connectionWriterFactory;
private final AuthorFactory authorFactory;
private final GroupFactory groupFactory;
private final KeyManager keyManager;
private final ConnectionDispatcher connectionDispatcher;
private final Clock clock;
private final PluginManager pluginManager;
private final AuthorId localAuthorId;
private final int localInvitationCode, remoteInvitationCode;
private final Collection<InvitationListener> listeners;
private final AtomicBoolean connected;
private final CountDownLatch localConfirmationLatch;
/*
* All of the following require locking: this. We don't want to call the
* listeners with a lock held, but we need to avoid a race condition in
* addListener(), so the state that's accessed there after calling
* listeners.add() must be guarded by a lock.
*/
private int localConfirmationCode = -1, remoteConfirmationCode = -1;
private boolean connectionFailed = false;
private boolean localCompared = false, remoteCompared = false;
private boolean localMatched = false, remoteMatched = false;
private String remoteName = null;
ConnectorGroup(CryptoComponent crypto, DatabaseComponent db,
ReaderFactory readerFactory, WriterFactory writerFactory,
ConnectionReaderFactory connectionReaderFactory,
ConnectionWriterFactory connectionWriterFactory,
AuthorFactory authorFactory, GroupFactory groupFactory,
KeyManager keyManager, ConnectionDispatcher connectionDispatcher,
Clock clock, PluginManager pluginManager, AuthorId localAuthorId,
int localInvitationCode, int remoteInvitationCode) {
super("ConnectorGroup");
this.crypto = crypto;
this.db = db;
this.readerFactory = readerFactory;
this.writerFactory = writerFactory;
this.connectionReaderFactory = connectionReaderFactory;
this.connectionWriterFactory = connectionWriterFactory;
this.authorFactory = authorFactory;
this.groupFactory = groupFactory;
this.keyManager = keyManager;
this.connectionDispatcher = connectionDispatcher;
this.clock = clock;
this.pluginManager = pluginManager;
this.localAuthorId = localAuthorId;
this.localInvitationCode = localInvitationCode;
this.remoteInvitationCode = remoteInvitationCode;
listeners = new CopyOnWriteArrayList<InvitationListener>();
connected = new AtomicBoolean(false);
localConfirmationLatch = new CountDownLatch(1);
}
public synchronized InvitationState addListener(InvitationListener l) {
listeners.add(l);
return new InvitationState(localInvitationCode, remoteInvitationCode,
localConfirmationCode, remoteConfirmationCode, connected.get(),
connectionFailed, localCompared, remoteCompared, localMatched,
remoteMatched, remoteName);
}
public void removeListener(InvitationListener l) {
listeners.remove(l);
}
public void connect() {
start();
}
@Override
public void run() {
LocalAuthor localAuthor;
Map<TransportId, TransportProperties> localProps;
// Load the local pseudonym and transport properties
try {
localAuthor = db.getLocalAuthor(localAuthorId);
localProps = db.getLocalProperties();
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
synchronized(this) {
connectionFailed = true;
}
for(InvitationListener l : listeners) l.connectionFailed();
return;
}
// Start the connection threads
Collection<Connector> connectors = new ArrayList<Connector>();
// Alice is the party with the smaller invitation code
if(localInvitationCode < remoteInvitationCode) {
for(DuplexPlugin plugin : pluginManager.getInvitationPlugins()) {
Connector c = createAliceConnector(plugin, localAuthor,
localProps);
connectors.add(c);
c.start();
}
} else {
for(DuplexPlugin plugin: pluginManager.getInvitationPlugins()) {
Connector c = createBobConnector(plugin, localAuthor,
localProps);
connectors.add(c);
c.start();
}
}
// Wait for the connection threads to finish
try {
for(Connector c : connectors) c.join();
} catch(InterruptedException e) {
if(LOG.isLoggable(WARNING))
LOG.warning("Interrupted while waiting for connectors");
}
// If none of the threads connected, inform the listeners
if(!connected.get()) {
synchronized(this) {
connectionFailed = true;
}
for(InvitationListener l : listeners) l.connectionFailed();
}
}
private Connector createAliceConnector(DuplexPlugin plugin,
LocalAuthor localAuthor,
Map<TransportId, TransportProperties> localProps) {
PseudoRandom random = crypto.getPseudoRandom(localInvitationCode,
remoteInvitationCode);
return new AliceConnector(crypto, db, readerFactory, writerFactory,
connectionReaderFactory, connectionWriterFactory, authorFactory,
groupFactory, keyManager, connectionDispatcher, clock, this,
plugin, localAuthor, localProps, random);
}
private Connector createBobConnector(DuplexPlugin plugin,
LocalAuthor localAuthor,
Map<TransportId, TransportProperties> localProps) {
PseudoRandom random = crypto.getPseudoRandom(remoteInvitationCode,
localInvitationCode);
return new BobConnector(crypto, db, readerFactory, writerFactory,
connectionReaderFactory, connectionWriterFactory, authorFactory,
groupFactory, keyManager, connectionDispatcher, clock, this,
plugin, localAuthor, localProps, random);
}
public void localConfirmationSucceeded() {
synchronized(this) {
localCompared = true;
localMatched = true;
}
localConfirmationLatch.countDown();
}
public void localConfirmationFailed() {
synchronized(this) {
localCompared = true;
localMatched = false;
}
localConfirmationLatch.countDown();
}
boolean getAndSetConnected() {
boolean redundant = connected.getAndSet(true);
if(!redundant)
for(InvitationListener l : listeners) l.connectionSucceeded();
return redundant;
}
void keyAgreementSucceeded(int localCode, int remoteCode) {
synchronized(this) {
localConfirmationCode = localCode;
remoteConfirmationCode = remoteCode;
}
for(InvitationListener l : listeners)
l.keyAgreementSucceeded(localCode, remoteCode);
}
void keyAgreementFailed() {
for(InvitationListener l : listeners) l.keyAgreementFailed();
}
boolean waitForLocalConfirmationResult() throws InterruptedException {
localConfirmationLatch.await(CONFIRMATION_TIMEOUT, MILLISECONDS);
synchronized(this) {
return localMatched;
}
}
void remoteConfirmationSucceeded() {
synchronized(this) {
remoteCompared = true;
remoteMatched = true;
}
for(InvitationListener l : listeners) l.remoteConfirmationSucceeded();
}
void remoteConfirmationFailed() {
synchronized(this) {
remoteCompared = true;
remoteMatched = false;
}
for(InvitationListener l : listeners) l.remoteConfirmationFailed();
}
void pseudonymExchangeSucceeded(Author remoteAuthor) {
String name = remoteAuthor.getName();
synchronized(this) {
remoteName = name;
}
for(InvitationListener l : listeners)
l.pseudonymExchangeSucceeded(name);
}
void pseudonymExchangeFailed() {
for(InvitationListener l : listeners) l.pseudonymExchangeFailed();
}
}

View File

@@ -0,0 +1,15 @@
package org.briarproject.invitation;
import javax.inject.Singleton;
import org.briarproject.api.invitation.InvitationTaskFactory;
import com.google.inject.AbstractModule;
public class InvitationModule extends AbstractModule {
protected void configure() {
bind(InvitationTaskFactory.class).to(
InvitationTaskFactoryImpl.class).in(Singleton.class);
}
}

View File

@@ -0,0 +1,65 @@
package org.briarproject.invitation;
import javax.inject.Inject;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.AuthorId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.invitation.InvitationTask;
import org.briarproject.api.invitation.InvitationTaskFactory;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.plugins.PluginManager;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.WriterFactory;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.ConnectionDispatcher;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionWriterFactory;
class InvitationTaskFactoryImpl implements InvitationTaskFactory {
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final ReaderFactory readerFactory;
private final WriterFactory writerFactory;
private final ConnectionReaderFactory connectionReaderFactory;
private final ConnectionWriterFactory connectionWriterFactory;
private final AuthorFactory authorFactory;
private final GroupFactory groupFactory;
private final KeyManager keyManager;
private final ConnectionDispatcher connectionDispatcher;
private final Clock clock;
private final PluginManager pluginManager;
@Inject
InvitationTaskFactoryImpl(CryptoComponent crypto, DatabaseComponent db,
ReaderFactory readerFactory, WriterFactory writerFactory,
ConnectionReaderFactory connectionReaderFactory,
ConnectionWriterFactory connectionWriterFactory,
AuthorFactory authorFactory, GroupFactory groupFactory,
KeyManager keyManager, ConnectionDispatcher connectionDispatcher,
Clock clock, PluginManager pluginManager) {
this.crypto = crypto;
this.db = db;
this.readerFactory = readerFactory;
this.writerFactory = writerFactory;
this.connectionReaderFactory = connectionReaderFactory;
this.connectionWriterFactory = connectionWriterFactory;
this.authorFactory = authorFactory;
this.groupFactory = groupFactory;
this.keyManager = keyManager;
this.connectionDispatcher = connectionDispatcher;
this.clock = clock;
this.pluginManager = pluginManager;
}
public InvitationTask createTask(AuthorId localAuthorId, int localCode,
int remoteCode) {
return new ConnectorGroup(crypto, db, readerFactory, writerFactory,
connectionReaderFactory, connectionWriterFactory,
authorFactory, groupFactory, keyManager, connectionDispatcher,
clock, pluginManager, localAuthorId, localCode, remoteCode);
}
}

View File

@@ -0,0 +1,111 @@
package org.briarproject.lifecycle;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.util.Collection;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.lifecycle.LifecycleManager;
import org.briarproject.api.lifecycle.Service;
class LifecycleManagerImpl implements LifecycleManager {
private static final Logger LOG =
Logger.getLogger(LifecycleManagerImpl.class.getName());
private final DatabaseComponent db;
private final Collection<Service> services;
private final Collection<ExecutorService> executors;
private final CountDownLatch dbLatch = new CountDownLatch(1);
private final CountDownLatch startupLatch = new CountDownLatch(1);
private final CountDownLatch shutdownLatch = new CountDownLatch(1);
@Inject
LifecycleManagerImpl(DatabaseComponent db) {
this.db = db;
services = new CopyOnWriteArrayList<Service>();
executors = new CopyOnWriteArrayList<ExecutorService>();
}
public void register(Service s) {
if(LOG.isLoggable(INFO))
LOG.info("Registering service " + s.getClass().getName());
services.add(s);
}
public void registerForShutdown(ExecutorService e) {
if(LOG.isLoggable(INFO))
LOG.info("Registering executor " + e.getClass().getName());
executors.add(e);
}
public void startServices() {
try {
if(LOG.isLoggable(INFO)) LOG.info("Starting");
boolean reopened = db.open();
if(LOG.isLoggable(INFO)) {
if(reopened) LOG.info("Database reopened");
else LOG.info("Database created");
}
dbLatch.countDown();
for(Service s : services) {
boolean started = s.start();
if(LOG.isLoggable(INFO)) {
String name = s.getClass().getName();
if(started) LOG.info("Service started: " + name);
else LOG.info("Service failed to start: " + name);
}
}
startupLatch.countDown();
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
public void stopServices() {
try {
if(LOG.isLoggable(INFO)) LOG.info("Shutting down");
for(Service s : services) {
boolean stopped = s.stop();
if(LOG.isLoggable(INFO)) {
String name = s.getClass().getName();
if(stopped) LOG.info("Service stopped: " + name);
else LOG.warning("Service failed to stop: " + name);
}
}
for(ExecutorService e : executors) e.shutdownNow();
if(LOG.isLoggable(INFO))
LOG.info(executors.size() + " executors shut down");
db.close();
if(LOG.isLoggable(INFO)) LOG.info("Database closed");
shutdownLatch.countDown();
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
public void waitForDatabase() throws InterruptedException {
dbLatch.await();
}
public void waitForStartup() throws InterruptedException {
startupLatch.await();
}
public void waitForShutdown() throws InterruptedException {
shutdownLatch.await();
}
}

View File

@@ -0,0 +1,18 @@
package org.briarproject.lifecycle;
import javax.inject.Singleton;
import org.briarproject.api.lifecycle.LifecycleManager;
import org.briarproject.api.lifecycle.ShutdownManager;
import com.google.inject.AbstractModule;
public class LifecycleModule extends AbstractModule {
protected void configure() {
bind(LifecycleManager.class).to(
LifecycleManagerImpl.class).in(Singleton.class);
bind(ShutdownManager.class).to(
ShutdownManagerImpl.class).in(Singleton.class);
}
}

View File

@@ -0,0 +1,35 @@
package org.briarproject.lifecycle;
import java.util.HashMap;
import java.util.Map;
import org.briarproject.api.lifecycle.ShutdownManager;
class ShutdownManagerImpl implements ShutdownManager {
protected final Map<Integer, Thread> hooks; // Locking: this
private int nextHandle = 0; // Locking: this
ShutdownManagerImpl() {
hooks = new HashMap<Integer, Thread>();
}
public synchronized int addShutdownHook(Runnable r) {
int handle = nextHandle++;
Thread hook = createThread(r);
hooks.put(handle, hook);
Runtime.getRuntime().addShutdownHook(hook);
return handle;
}
protected Thread createThread(Runnable r) {
return new Thread(r, "ShutdownManager");
}
public synchronized boolean removeShutdownHook(int handle) {
Thread hook = hooks.remove(handle);
if(hook == null) return false;
else return Runtime.getRuntime().removeShutdownHook(hook);
}
}

View File

@@ -0,0 +1,56 @@
package org.briarproject.messaging;
import static org.briarproject.api.messaging.Types.AUTHOR;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import javax.inject.Inject;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.AuthorId;
import org.briarproject.api.LocalAuthor;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
class AuthorFactoryImpl implements AuthorFactory {
private final CryptoComponent crypto;
private final WriterFactory writerFactory;
@Inject
AuthorFactoryImpl(CryptoComponent crypto, WriterFactory writerFactory) {
this.crypto = crypto;
this.writerFactory = writerFactory;
}
public Author createAuthor(String name, byte[] publicKey) {
return new Author(getId(name, publicKey), name, publicKey);
}
public LocalAuthor createLocalAuthor(String name, byte[] publicKey,
byte[] privateKey) {
return new LocalAuthor(getId(name, publicKey), name, publicKey,
privateKey);
}
private AuthorId getId(String name, byte[] publicKey) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
try {
w.writeStructStart(AUTHOR);
w.writeString(name);
w.writeBytes(publicKey);
w.writeStructEnd();
} catch(IOException e) {
// Shouldn't happen with ByteArrayOutputStream
throw new RuntimeException();
}
MessageDigest messageDigest = crypto.getMessageDigest();
messageDigest.update(out.toByteArray());
return new AuthorId(messageDigest.digest());
}
}

View File

@@ -0,0 +1,40 @@
package org.briarproject.messaging;
import static org.briarproject.api.AuthorConstants.MAX_AUTHOR_NAME_LENGTH;
import static org.briarproject.api.AuthorConstants.MAX_PUBLIC_KEY_LENGTH;
import static org.briarproject.api.messaging.Types.AUTHOR;
import java.io.IOException;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.serial.DigestingConsumer;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.StructReader;
class AuthorReader implements StructReader<Author> {
private final MessageDigest messageDigest;
AuthorReader(CryptoComponent crypto) {
messageDigest = crypto.getMessageDigest();
}
public Author readStruct(Reader r) throws IOException {
// Set up the reader
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
r.addConsumer(digesting);
// Read and digest the data
r.readStructStart(AUTHOR);
String name = r.readString(MAX_AUTHOR_NAME_LENGTH);
byte[] publicKey = r.readBytes(MAX_PUBLIC_KEY_LENGTH);
r.readStructEnd();
// Reset the reader
r.removeConsumer(digesting);
// Build and return the author
AuthorId id = new AuthorId(messageDigest.digest());
return new Author(id, name, publicKey);
}
}

View File

@@ -0,0 +1,53 @@
package org.briarproject.messaging;
import static org.briarproject.api.messaging.MessagingConstants.GROUP_SALT_LENGTH;
import static org.briarproject.api.messaging.Types.GROUP;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import javax.inject.Inject;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.messaging.GroupId;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
class GroupFactoryImpl implements GroupFactory {
private final CryptoComponent crypto;
private final WriterFactory writerFactory;
@Inject
GroupFactoryImpl(CryptoComponent crypto, WriterFactory writerFactory) {
this.crypto = crypto;
this.writerFactory = writerFactory;
}
public Group createGroup(String name) {
byte[] salt = new byte[GROUP_SALT_LENGTH];
crypto.getSecureRandom().nextBytes(salt);
return createGroup(name, salt);
}
public Group createGroup(String name, byte[] salt) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
try {
w.writeStructStart(GROUP);
w.writeString(name);
w.writeBytes(salt);
w.writeStructEnd();
} catch(IOException e) {
// Shouldn't happen with ByteArrayOutputStream
throw new RuntimeException();
}
MessageDigest messageDigest = crypto.getMessageDigest();
messageDigest.update(out.toByteArray());
GroupId id = new GroupId(messageDigest.digest());
return new Group(id, name, salt);
}
}

View File

@@ -0,0 +1,40 @@
package org.briarproject.messaging;
import static org.briarproject.api.AuthorConstants.MAX_PUBLIC_KEY_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_GROUP_NAME_LENGTH;
import static org.briarproject.api.messaging.Types.GROUP;
import java.io.IOException;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupId;
import org.briarproject.api.serial.DigestingConsumer;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.StructReader;
class GroupReader implements StructReader<Group> {
private final MessageDigest messageDigest;
GroupReader(CryptoComponent crypto) {
messageDigest = crypto.getMessageDigest();
}
public Group readStruct(Reader r) throws IOException {
DigestingConsumer digesting = new DigestingConsumer(messageDigest);
// Read and digest the data
r.addConsumer(digesting);
r.readStructStart(GROUP);
String name = r.readString(MAX_GROUP_NAME_LENGTH);
byte[] publicKey = null;
if(r.hasNull()) r.readNull();
else publicKey = r.readBytes(MAX_PUBLIC_KEY_LENGTH);
r.readStructEnd();
r.removeConsumer(digesting);
// Build and return the group
GroupId id = new GroupId(messageDigest.digest());
return new Group(id, name, publicKey);
}
}

View File

@@ -0,0 +1,134 @@
package org.briarproject.messaging;
import static org.briarproject.api.AuthorConstants.MAX_SIGNATURE_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_BODY_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_CONTENT_TYPE_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_PACKET_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MESSAGE_SALT_LENGTH;
import static org.briarproject.api.messaging.Types.AUTHOR;
import static org.briarproject.api.messaging.Types.GROUP;
import static org.briarproject.api.messaging.Types.MESSAGE;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import javax.inject.Inject;
import org.briarproject.api.Author;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.crypto.PrivateKey;
import org.briarproject.api.crypto.Signature;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageFactory;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.serial.Consumer;
import org.briarproject.api.serial.CountingConsumer;
import org.briarproject.api.serial.DigestingConsumer;
import org.briarproject.api.serial.SigningConsumer;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
class MessageFactoryImpl implements MessageFactory {
private final Signature signature;
private final SecureRandom random;
private final MessageDigest messageDigest;
private final WriterFactory writerFactory;
@Inject
MessageFactoryImpl(CryptoComponent crypto, WriterFactory writerFactory) {
signature = crypto.getSignature();
random = crypto.getSecureRandom();
messageDigest = crypto.getMessageDigest();
this.writerFactory = writerFactory;
}
public Message createAnonymousMessage(MessageId parent, Group group,
String contentType, long timestamp, byte[] body) throws IOException,
GeneralSecurityException {
return createMessage(parent, group, null, null, contentType, timestamp,
body);
}
public Message createPseudonymousMessage(MessageId parent, Group group,
Author author, PrivateKey privateKey, String contentType,
long timestamp, byte[] body) throws IOException,
GeneralSecurityException {
return createMessage(parent, group, author, privateKey, contentType,
timestamp, body);
}
private Message createMessage(MessageId parent, Group group, Author author,
PrivateKey privateKey, String contentType, long timestamp,
byte[] body) throws IOException, GeneralSecurityException {
// Validate the arguments
if((author == null) != (privateKey == null))
throw new IllegalArgumentException();
if(contentType.getBytes("UTF-8").length > MAX_CONTENT_TYPE_LENGTH)
throw new IllegalArgumentException();
if(body.length > MAX_BODY_LENGTH)
throw new IllegalArgumentException();
// Serialise the message to a buffer
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
// Initialise the consumers
CountingConsumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
w.addConsumer(counting);
Consumer digestingConsumer = new DigestingConsumer(messageDigest);
w.addConsumer(digestingConsumer);
Consumer signingConsumer = null;
if(privateKey != null) {
signature.initSign(privateKey);
signingConsumer = new SigningConsumer(signature);
w.addConsumer(signingConsumer);
}
// Write the message
w.writeStructStart(MESSAGE);
if(parent == null) w.writeNull();
else w.writeBytes(parent.getBytes());
writeGroup(w, group);
if(author == null) w.writeNull();
else writeAuthor(w, author);
w.writeString(contentType);
w.writeIntAny(timestamp);
byte[] salt = new byte[MESSAGE_SALT_LENGTH];
random.nextBytes(salt);
w.writeBytes(salt);
w.writeBytes(body);
int bodyStart = (int) counting.getCount() - body.length;
// Sign the message with the author's private key, if there is one
if(privateKey == null) {
w.writeNull();
} else {
w.removeConsumer(signingConsumer);
byte[] sig = signature.sign();
if(sig.length > MAX_SIGNATURE_LENGTH)
throw new IllegalArgumentException();
w.writeBytes(sig);
}
w.writeStructEnd();
// Hash the message, including the signature, to get the message ID
w.removeConsumer(digestingConsumer);
MessageId id = new MessageId(messageDigest.digest());
return new MessageImpl(id, parent, group, author, contentType,
timestamp, out.toByteArray(), bodyStart, body.length);
}
private void writeGroup(Writer w, Group g) throws IOException {
w.writeStructStart(GROUP);
w.writeString(g.getName());
w.writeBytes(g.getSalt());
w.writeStructEnd();
}
private void writeAuthor(Writer w, Author a) throws IOException {
w.writeStructStart(AUTHOR);
w.writeString(a.getName());
w.writeBytes(a.getPublicKey());
w.writeStructEnd();
}
}

View File

@@ -0,0 +1,83 @@
package org.briarproject.messaging;
import static org.briarproject.api.messaging.MessagingConstants.MAX_BODY_LENGTH;
import org.briarproject.api.Author;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageId;
/** A simple in-memory implementation of a message. */
class MessageImpl implements Message {
private final MessageId id, parent;
private final Group group;
private final Author author;
private final String contentType;
private final long timestamp;
private final byte[] raw;
private final int bodyStart, bodyLength;
public MessageImpl(MessageId id, MessageId parent, Group group,
Author author, String contentType, long timestamp,
byte[] raw, int bodyStart, int bodyLength) {
if(bodyStart + bodyLength > raw.length)
throw new IllegalArgumentException();
if(bodyLength > MAX_BODY_LENGTH)
throw new IllegalArgumentException();
this.id = id;
this.parent = parent;
this.group = group;
this.author = author;
this.contentType = contentType;
this.timestamp = timestamp;
this.raw = raw;
this.bodyStart = bodyStart;
this.bodyLength = bodyLength;
}
public MessageId getId() {
return id;
}
public MessageId getParent() {
return parent;
}
public Group getGroup() {
return group;
}
public Author getAuthor() {
return author;
}
public String getContentType() {
return contentType;
}
public long getTimestamp() {
return timestamp;
}
public byte[] getSerialised() {
return raw;
}
public int getBodyStart() {
return bodyStart;
}
public int getBodyLength() {
return bodyLength;
}
@Override
public int hashCode() {
return id.hashCode();
}
@Override
public boolean equals(Object o) {
return o instanceof Message && id.equals(((Message) o).getId());
}
}

View File

@@ -0,0 +1,84 @@
package org.briarproject.messaging;
import static org.briarproject.api.AuthorConstants.MAX_SIGNATURE_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_BODY_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_CONTENT_TYPE_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_PACKET_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MESSAGE_SALT_LENGTH;
import static org.briarproject.api.messaging.Types.MESSAGE;
import java.io.IOException;
import org.briarproject.api.Author;
import org.briarproject.api.FormatException;
import org.briarproject.api.UniqueId;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.serial.CopyingConsumer;
import org.briarproject.api.serial.CountingConsumer;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.StructReader;
class MessageReader implements StructReader<UnverifiedMessage> {
private final StructReader<Group> groupReader;
private final StructReader<Author> authorReader;
MessageReader(StructReader<Group> groupReader,
StructReader<Author> authorReader) {
this.groupReader = groupReader;
this.authorReader = authorReader;
}
public UnverifiedMessage readStruct(Reader r) throws IOException {
CopyingConsumer copying = new CopyingConsumer();
CountingConsumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
r.addConsumer(copying);
r.addConsumer(counting);
// Read the start of the struct
r.readStructStart(MESSAGE);
// Read the parent's message ID, if there is one
MessageId parent = null;
if(r.hasNull()) {
r.readNull();
} else {
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length < UniqueId.LENGTH) throw new FormatException();
parent = new MessageId(b);
}
// Read the group
Group group = groupReader.readStruct(r);
// Read the author, if there is one
Author author = null;
if(r.hasNull()) r.readNull();
else author = authorReader.readStruct(r);
// Read the content type
String contentType = r.readString(MAX_CONTENT_TYPE_LENGTH);
// Read the timestamp
long timestamp = r.readIntAny();
if(timestamp < 0) throw new FormatException();
// Read the salt
byte[] salt = r.readBytes(MESSAGE_SALT_LENGTH);
if(salt.length < MESSAGE_SALT_LENGTH) throw new FormatException();
// Read the message body
byte[] body = r.readBytes(MAX_BODY_LENGTH);
// Record the offset of the body within the message
int bodyStart = (int) counting.getCount() - body.length;
// Record the length of the data covered by the author's signature
int signedLength = (int) counting.getCount();
// Read the author's signature, if there is one
byte[] signature = null;
if(author == null) r.readNull();
else signature = r.readBytes(MAX_SIGNATURE_LENGTH);
// Read the end of the struct
r.readStructEnd();
// The signature will be verified later
r.removeConsumer(counting);
r.removeConsumer(copying);
byte[] raw = copying.getCopy();
return new UnverifiedMessage(parent, group, author, contentType,
timestamp, raw, signature, bodyStart, body.length,
signedLength);
}
}

View File

@@ -0,0 +1,59 @@
package org.briarproject.messaging;
import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
import java.security.GeneralSecurityException;
import javax.inject.Inject;
import org.briarproject.api.Author;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.KeyParser;
import org.briarproject.api.crypto.MessageDigest;
import org.briarproject.api.crypto.PublicKey;
import org.briarproject.api.crypto.Signature;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.system.Clock;
class MessageVerifierImpl implements MessageVerifier {
private final CryptoComponent crypto;
private final Clock clock;
private final KeyParser keyParser;
@Inject
MessageVerifierImpl(CryptoComponent crypto, Clock clock) {
this.crypto = crypto;
this.clock = clock;
keyParser = crypto.getSignatureKeyParser();
}
public Message verifyMessage(UnverifiedMessage m)
throws GeneralSecurityException {
MessageDigest messageDigest = crypto.getMessageDigest();
Signature signature = crypto.getSignature();
// Reject the message if it's too far in the future
long now = clock.currentTimeMillis();
if(m.getTimestamp() > now + MAX_CLOCK_DIFFERENCE)
throw new GeneralSecurityException();
// Hash the message to get the message ID
byte[] raw = m.getSerialised();
messageDigest.update(raw);
MessageId id = new MessageId(messageDigest.digest());
// Verify the author's signature, if there is one
Author author = m.getAuthor();
if(author != null) {
PublicKey k = keyParser.parsePublicKey(author.getPublicKey());
signature.initVerify(k);
signature.update(raw, 0, m.getSignedLength());
if(!signature.verify(m.getSignature()))
throw new GeneralSecurityException();
}
return new MessageImpl(id, m.getParent(), m.getGroup(), author,
m.getContentType(), m.getTimestamp(), raw, m.getBodyStart(),
m.getBodyLength());
}
}

View File

@@ -0,0 +1,52 @@
package org.briarproject.messaging;
import org.briarproject.api.Author;
import org.briarproject.api.AuthorFactory;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.messaging.MessageFactory;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.serial.StructReader;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
public class MessagingModule extends AbstractModule {
protected void configure() {
bind(AuthorFactory.class).to(AuthorFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class);
bind(MessageFactory.class).to(MessageFactoryImpl.class);
bind(MessageVerifier.class).to(MessageVerifierImpl.class);
bind(PacketReaderFactory.class).to(PacketReaderFactoryImpl.class);
bind(PacketWriterFactory.class).to(PacketWriterFactoryImpl.class);
}
@Provides
StructReader<Author> getAuthorReader(CryptoComponent crypto) {
return new AuthorReader(crypto);
}
@Provides
StructReader<Group> getGroupReader(CryptoComponent crypto) {
return new GroupReader(crypto);
}
@Provides
StructReader<UnverifiedMessage> getMessageReader(
StructReader<Group> groupReader,
StructReader<Author> authorReader) {
return new MessageReader(groupReader, authorReader);
}
@Provides
StructReader<SubscriptionUpdate> getSubscriptionUpdateReader(
StructReader<Group> groupReader) {
return new SubscriptionUpdateReader(groupReader);
}
}

View File

@@ -0,0 +1,33 @@
package org.briarproject.messaging;
import java.io.InputStream;
import javax.inject.Inject;
import org.briarproject.api.messaging.PacketReader;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.StructReader;
class PacketReaderFactoryImpl implements PacketReaderFactory {
private final ReaderFactory readerFactory;
private final StructReader<UnverifiedMessage> messageReader;
private final StructReader<SubscriptionUpdate> subscriptionUpdateReader;
@Inject
PacketReaderFactoryImpl(ReaderFactory readerFactory,
StructReader<UnverifiedMessage> messageReader,
StructReader<SubscriptionUpdate> subscriptionUpdateReader) {
this.readerFactory = readerFactory;
this.messageReader = messageReader;
this.subscriptionUpdateReader = subscriptionUpdateReader;
}
public PacketReader createPacketReader(InputStream in) {
return new PacketReaderImpl(readerFactory, messageReader,
subscriptionUpdateReader, in);
}
}

View File

@@ -0,0 +1,257 @@
package org.briarproject.messaging;
import static org.briarproject.api.TransportPropertyConstants.MAX_PROPERTIES_PER_TRANSPORT;
import static org.briarproject.api.TransportPropertyConstants.MAX_PROPERTY_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_PACKET_LENGTH;
import static org.briarproject.api.messaging.Types.ACK;
import static org.briarproject.api.messaging.Types.MESSAGE;
import static org.briarproject.api.messaging.Types.OFFER;
import static org.briarproject.api.messaging.Types.REQUEST;
import static org.briarproject.api.messaging.Types.RETENTION_ACK;
import static org.briarproject.api.messaging.Types.RETENTION_UPDATE;
import static org.briarproject.api.messaging.Types.SUBSCRIPTION_ACK;
import static org.briarproject.api.messaging.Types.SUBSCRIPTION_UPDATE;
import static org.briarproject.api.messaging.Types.TRANSPORT_ACK;
import static org.briarproject.api.messaging.Types.TRANSPORT_UPDATE;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.briarproject.api.FormatException;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.UniqueId;
import org.briarproject.api.messaging.Ack;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.messaging.Offer;
import org.briarproject.api.messaging.PacketReader;
import org.briarproject.api.messaging.Request;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.serial.Consumer;
import org.briarproject.api.serial.CountingConsumer;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.StructReader;
// This class is not thread-safe
class PacketReaderImpl implements PacketReader {
private final StructReader<UnverifiedMessage> messageReader;
private final StructReader<SubscriptionUpdate> subscriptionUpdateReader;
private final Reader r;
PacketReaderImpl(ReaderFactory readerFactory,
StructReader<UnverifiedMessage> messageReader,
StructReader<SubscriptionUpdate> subscriptionUpdateReader,
InputStream in) {
this.messageReader = messageReader;
this.subscriptionUpdateReader = subscriptionUpdateReader;
r = readerFactory.createReader(in);
}
public boolean eof() throws IOException {
return r.eof();
}
public boolean hasAck() throws IOException {
return r.hasStruct(ACK);
}
public Ack readAck() throws IOException {
// Set up the reader
Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
r.addConsumer(counting);
// Read the start of the struct
r.readStructStart(ACK);
// Read the message IDs
List<MessageId> acked = new ArrayList<MessageId>();
r.readListStart();
while(!r.hasListEnd()) {
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH)
throw new FormatException();
acked.add(new MessageId(b));
}
if(acked.isEmpty()) throw new FormatException();
r.readListEnd();
// Read the end of the struct
r.readStructEnd();
// Reset the reader
r.removeConsumer(counting);
// Build and return the ack
return new Ack(Collections.unmodifiableList(acked));
}
public boolean hasMessage() throws IOException {
return r.hasStruct(MESSAGE);
}
public UnverifiedMessage readMessage() throws IOException {
return messageReader.readStruct(r);
}
public boolean hasOffer() throws IOException {
return r.hasStruct(OFFER);
}
public Offer readOffer() throws IOException {
// Set up the reader
Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
r.addConsumer(counting);
// Read the start of the struct
r.readStructStart(OFFER);
// Read the message IDs
List<MessageId> offered = new ArrayList<MessageId>();
r.readListStart();
while(!r.hasListEnd()) {
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH)
throw new FormatException();
offered.add(new MessageId(b));
}
if(offered.isEmpty()) throw new FormatException();
r.readListEnd();
// Read the end of the struct
r.readStructEnd();
// Reset the reader
r.removeConsumer(counting);
// Build and return the offer
return new Offer(Collections.unmodifiableList(offered));
}
public boolean hasRequest() throws IOException {
return r.hasStruct(REQUEST);
}
public Request readRequest() throws IOException {
// Set up the reader
Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
r.addConsumer(counting);
// Read the start of the struct
r.readStructStart(REQUEST);
// Read the message IDs
List<MessageId> requested = new ArrayList<MessageId>();
r.readListStart();
while(!r.hasListEnd()) {
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length != UniqueId.LENGTH)
throw new FormatException();
requested.add(new MessageId(b));
}
if(requested.isEmpty()) throw new FormatException();
r.readListEnd();
// Read the end of the struct
r.readStructEnd();
// Reset the reader
r.removeConsumer(counting);
// Build and return the request
return new Request(Collections.unmodifiableList(requested));
}
public boolean hasRetentionAck() throws IOException {
return r.hasStruct(RETENTION_ACK);
}
public RetentionAck readRetentionAck() throws IOException {
r.readStructStart(RETENTION_ACK);
long version = r.readIntAny();
if(version < 0) throw new FormatException();
r.readStructEnd();
return new RetentionAck(version);
}
public boolean hasRetentionUpdate() throws IOException {
return r.hasStruct(RETENTION_UPDATE);
}
public RetentionUpdate readRetentionUpdate() throws IOException {
r.readStructStart(RETENTION_UPDATE);
long retention = r.readIntAny();
if(retention < 0) throw new FormatException();
long version = r.readIntAny();
if(version < 0) throw new FormatException();
r.readStructEnd();
return new RetentionUpdate(retention, version);
}
public boolean hasSubscriptionAck() throws IOException {
return r.hasStruct(SUBSCRIPTION_ACK);
}
public SubscriptionAck readSubscriptionAck() throws IOException {
r.readStructStart(SUBSCRIPTION_ACK);
long version = r.readIntAny();
if(version < 0) throw new FormatException();
r.readStructEnd();
return new SubscriptionAck(version);
}
public boolean hasSubscriptionUpdate() throws IOException {
return r.hasStruct(SUBSCRIPTION_UPDATE);
}
public SubscriptionUpdate readSubscriptionUpdate() throws IOException {
return subscriptionUpdateReader.readStruct(r);
}
public boolean hasTransportAck() throws IOException {
return r.hasStruct(TRANSPORT_ACK);
}
public TransportAck readTransportAck() throws IOException {
r.readStructStart(TRANSPORT_ACK);
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length < UniqueId.LENGTH) throw new FormatException();
long version = r.readIntAny();
if(version < 0) throw new FormatException();
r.readStructEnd();
return new TransportAck(new TransportId(b), version);
}
public boolean hasTransportUpdate() throws IOException {
return r.hasStruct(TRANSPORT_UPDATE);
}
public TransportUpdate readTransportUpdate() throws IOException {
// Set up the reader
Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
r.addConsumer(counting);
// Read the start of the struct
r.readStructStart(TRANSPORT_UPDATE);
// Read the transport ID
byte[] b = r.readBytes(UniqueId.LENGTH);
if(b.length < UniqueId.LENGTH) throw new FormatException();
TransportId id = new TransportId(b);
// Read the transport properties
Map<String, String> p = new HashMap<String, String>();
r.readMapStart();
for(int i = 0; !r.hasMapEnd(); i++) {
if(i == MAX_PROPERTIES_PER_TRANSPORT)
throw new FormatException();
String key = r.readString(MAX_PROPERTY_LENGTH);
String value = r.readString(MAX_PROPERTY_LENGTH);
p.put(key, value);
}
r.readMapEnd();
// Read the version number
long version = r.readIntAny();
if(version < 0) throw new FormatException();
// Read the end of the struct
r.readStructEnd();
// Reset the reader
r.removeConsumer(counting);
// Build and return the transport update
return new TransportUpdate(id, new TransportProperties(p), version);
}
}

View File

@@ -0,0 +1,28 @@
package org.briarproject.messaging;
import java.io.OutputStream;
import javax.inject.Inject;
import org.briarproject.api.messaging.PacketWriter;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.serial.SerialComponent;
import org.briarproject.api.serial.WriterFactory;
class PacketWriterFactoryImpl implements PacketWriterFactory {
private final SerialComponent serial;
private final WriterFactory writerFactory;
@Inject
PacketWriterFactoryImpl(SerialComponent serial,
WriterFactory writerFactory) {
this.serial = serial;
this.writerFactory = writerFactory;
}
public PacketWriter createPacketWriter(OutputStream out,
boolean flush) {
return new PacketWriterImpl(serial, writerFactory, out, flush);
}
}

View File

@@ -0,0 +1,164 @@
package org.briarproject.messaging;
import static org.briarproject.api.messaging.MessagingConstants.MAX_PACKET_LENGTH;
import static org.briarproject.api.messaging.Types.ACK;
import static org.briarproject.api.messaging.Types.GROUP;
import static org.briarproject.api.messaging.Types.OFFER;
import static org.briarproject.api.messaging.Types.REQUEST;
import static org.briarproject.api.messaging.Types.RETENTION_ACK;
import static org.briarproject.api.messaging.Types.RETENTION_UPDATE;
import static org.briarproject.api.messaging.Types.SUBSCRIPTION_ACK;
import static org.briarproject.api.messaging.Types.SUBSCRIPTION_UPDATE;
import static org.briarproject.api.messaging.Types.TRANSPORT_ACK;
import static org.briarproject.api.messaging.Types.TRANSPORT_UPDATE;
import java.io.IOException;
import java.io.OutputStream;
import org.briarproject.api.messaging.Ack;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.MessageId;
import org.briarproject.api.messaging.Offer;
import org.briarproject.api.messaging.PacketWriter;
import org.briarproject.api.messaging.Request;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.serial.SerialComponent;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
// This class is not thread-safe
class PacketWriterImpl implements PacketWriter {
private final SerialComponent serial;
private final OutputStream out;
private final boolean flush;
private final Writer w;
PacketWriterImpl(SerialComponent serial, WriterFactory writerFactory,
OutputStream out, boolean flush) {
this.serial = serial;
this.out = out;
this.flush = flush;
w = writerFactory.createWriter(out);
}
public int getMaxMessagesForRequest(long capacity) {
int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH);
int overhead = serial.getSerialisedStructStartLength(ACK)
+ serial.getSerialisedListStartLength()
+ serial.getSerialisedListEndLength()
+ serial.getSerialisedStructEndLength();
int idLength = serial.getSerialisedUniqueIdLength();
return (packet - overhead) / idLength;
}
public int getMaxMessagesForOffer(long capacity) {
int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH);
int overhead = serial.getSerialisedStructStartLength(OFFER)
+ serial.getSerialisedListStartLength()
+ serial.getSerialisedListEndLength()
+ serial.getSerialisedStructEndLength();
int idLength = serial.getSerialisedUniqueIdLength();
return (packet - overhead) / idLength;
}
public void writeAck(Ack a) throws IOException {
w.writeStructStart(ACK);
w.writeListStart();
for(MessageId m : a.getMessageIds()) w.writeBytes(m.getBytes());
w.writeListEnd();
w.writeStructEnd();
if(flush) out.flush();
}
public void writeMessage(byte[] raw) throws IOException {
out.write(raw);
if(flush) out.flush();
}
public void writeOffer(Offer o) throws IOException {
w.writeStructStart(OFFER);
w.writeListStart();
for(MessageId m : o.getMessageIds()) w.writeBytes(m.getBytes());
w.writeListEnd();
w.writeStructEnd();
if(flush) out.flush();
}
public void writeRequest(Request r) throws IOException {
w.writeStructStart(REQUEST);
w.writeListStart();
for(MessageId m : r.getMessageIds()) w.writeBytes(m.getBytes());
w.writeListEnd();
w.writeStructEnd();
if(flush) out.flush();
}
public void writeRetentionAck(RetentionAck a) throws IOException {
w.writeStructStart(RETENTION_ACK);
w.writeIntAny(a.getVersion());
w.writeStructEnd();
if(flush) out.flush();
}
public void writeRetentionUpdate(RetentionUpdate u) throws IOException {
w.writeStructStart(RETENTION_UPDATE);
w.writeIntAny(u.getRetentionTime());
w.writeIntAny(u.getVersion());
w.writeStructEnd();
if(flush) out.flush();
}
public void writeSubscriptionAck(SubscriptionAck a) throws IOException {
w.writeStructStart(SUBSCRIPTION_ACK);
w.writeIntAny(a.getVersion());
w.writeStructEnd();
if(flush) out.flush();
}
public void writeSubscriptionUpdate(SubscriptionUpdate u)
throws IOException {
w.writeStructStart(SUBSCRIPTION_UPDATE);
w.writeListStart();
for(Group g : u.getGroups()) {
w.writeStructStart(GROUP);
w.writeString(g.getName());
w.writeBytes(g.getSalt());
w.writeStructEnd();
}
w.writeListEnd();
w.writeIntAny(u.getVersion());
w.writeStructEnd();
if(flush) out.flush();
}
public void writeTransportAck(TransportAck a) throws IOException {
w.writeStructStart(TRANSPORT_ACK);
w.writeBytes(a.getId().getBytes());
w.writeIntAny(a.getVersion());
w.writeStructEnd();
if(flush) out.flush();
}
public void writeTransportUpdate(TransportUpdate u) throws IOException {
w.writeStructStart(TRANSPORT_UPDATE);
w.writeBytes(u.getId().getBytes());
w.writeMap(u.getProperties());
w.writeIntAny(u.getVersion());
w.writeStructEnd();
if(flush) out.flush();
}
public void flush() throws IOException {
out.flush();
}
public void close() throws IOException {
out.close();
}
}

View File

@@ -0,0 +1,51 @@
package org.briarproject.messaging;
import static org.briarproject.api.messaging.MessagingConstants.MAX_PACKET_LENGTH;
import static org.briarproject.api.messaging.MessagingConstants.MAX_SUBSCRIPTIONS;
import static org.briarproject.api.messaging.Types.SUBSCRIPTION_UPDATE;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.briarproject.api.FormatException;
import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.serial.Consumer;
import org.briarproject.api.serial.CountingConsumer;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.StructReader;
class SubscriptionUpdateReader implements StructReader<SubscriptionUpdate> {
private final StructReader<Group> groupReader;
SubscriptionUpdateReader(StructReader<Group> groupReader) {
this.groupReader = groupReader;
}
public SubscriptionUpdate readStruct(Reader r) throws IOException {
// Set up the reader
Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
r.addConsumer(counting);
// Read the start of the struct
r.readStructStart(SUBSCRIPTION_UPDATE);
// Read the subscriptions
List<Group> groups = new ArrayList<Group>();
r.readListStart();
for(int i = 0; i < MAX_SUBSCRIPTIONS && !r.hasListEnd(); i++)
groups.add(groupReader.readStruct(r));
r.readListEnd();
// Read the version number
long version = r.readIntAny();
if(version < 0) throw new FormatException();
// Read the end of the struct
r.readStructEnd();
// Reset the reader
r.removeConsumer(counting);
// Build and return the subscription update
groups = Collections.unmodifiableList(groups);
return new SubscriptionUpdate(groups, version);
}
}

View File

@@ -0,0 +1,879 @@
package org.briarproject.messaging.duplex;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.messaging.MessagingConstants.MAX_PACKET_LENGTH;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.Collection;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;
import org.briarproject.api.ContactId;
import org.briarproject.api.FormatException;
import org.briarproject.api.TransportId;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.event.ContactRemovedEvent;
import org.briarproject.api.event.Event;
import org.briarproject.api.event.EventListener;
import org.briarproject.api.event.LocalSubscriptionsUpdatedEvent;
import org.briarproject.api.event.LocalTransportsUpdatedEvent;
import org.briarproject.api.event.MessageAddedEvent;
import org.briarproject.api.event.MessageExpiredEvent;
import org.briarproject.api.event.MessageRequestedEvent;
import org.briarproject.api.event.MessageToAckEvent;
import org.briarproject.api.event.MessageToRequestEvent;
import org.briarproject.api.event.RemoteRetentionTimeUpdatedEvent;
import org.briarproject.api.event.RemoteSubscriptionsUpdatedEvent;
import org.briarproject.api.event.RemoteTransportsUpdatedEvent;
import org.briarproject.api.messaging.Ack;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.Offer;
import org.briarproject.api.messaging.PacketReader;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriter;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.messaging.Request;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionReader;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionRegistry;
import org.briarproject.api.transport.ConnectionWriter;
import org.briarproject.api.transport.ConnectionWriterFactory;
import org.briarproject.util.ByteUtils;
abstract class DuplexConnection implements EventListener {
private static final Logger LOG =
Logger.getLogger(DuplexConnection.class.getName());
private static final Runnable CLOSE = new Runnable() {
public void run() {}
};
private static final Runnable DIE = new Runnable() {
public void run() {}
};
protected final DatabaseComponent db;
protected final ConnectionRegistry connRegistry;
protected final ConnectionReaderFactory connReaderFactory;
protected final ConnectionWriterFactory connWriterFactory;
protected final PacketReaderFactory packetReaderFactory;
protected final PacketWriterFactory packetWriterFactory;
protected final ConnectionContext ctx;
protected final DuplexTransportConnection transport;
protected final ContactId contactId;
protected final TransportId transportId;
private final Executor dbExecutor, cryptoExecutor;
private final MessageVerifier messageVerifier;
private final long maxLatency;
private final AtomicBoolean disposed;
private final BlockingQueue<Runnable> writerTasks;
private volatile PacketWriter writer = null;
DuplexConnection(Executor dbExecutor, Executor cryptoExecutor,
MessageVerifier messageVerifier, DatabaseComponent db,
ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory,
PacketReaderFactory packetReaderFactory,
PacketWriterFactory packetWriterFactory, ConnectionContext ctx,
DuplexTransportConnection transport) {
this.dbExecutor = dbExecutor;
this.cryptoExecutor = cryptoExecutor;
this.messageVerifier = messageVerifier;
this.db = db;
this.connRegistry = connRegistry;
this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory;
this.packetReaderFactory = packetReaderFactory;
this.packetWriterFactory = packetWriterFactory;
this.ctx = ctx;
this.transport = transport;
contactId = ctx.getContactId();
transportId = ctx.getTransportId();
maxLatency = transport.getMaxLatency();
disposed = new AtomicBoolean(false);
writerTasks = new LinkedBlockingQueue<Runnable>();
}
protected abstract ConnectionReader createConnectionReader()
throws IOException;
protected abstract ConnectionWriter createConnectionWriter()
throws IOException;
public void eventOccurred(Event e) {
if(e instanceof ContactRemovedEvent) {
ContactRemovedEvent c = (ContactRemovedEvent) e;
if(contactId.equals(c.getContactId())) writerTasks.add(CLOSE);
} else if(e instanceof MessageAddedEvent) {
dbExecutor.execute(new GenerateOffer());
} else if(e instanceof MessageExpiredEvent) {
dbExecutor.execute(new GenerateRetentionUpdate());
} else if(e instanceof LocalSubscriptionsUpdatedEvent) {
LocalSubscriptionsUpdatedEvent l =
(LocalSubscriptionsUpdatedEvent) e;
if(l.getAffectedContacts().contains(contactId)) {
dbExecutor.execute(new GenerateSubscriptionUpdate());
dbExecutor.execute(new GenerateOffer());
}
} else if(e instanceof LocalTransportsUpdatedEvent) {
dbExecutor.execute(new GenerateTransportUpdates());
} else if(e instanceof MessageRequestedEvent) {
if(((MessageRequestedEvent) e).getContactId().equals(contactId))
dbExecutor.execute(new GenerateBatch());
} else if(e instanceof MessageToAckEvent) {
if(((MessageToAckEvent) e).getContactId().equals(contactId))
dbExecutor.execute(new GenerateAck());
} else if(e instanceof MessageToRequestEvent) {
if(((MessageToRequestEvent) e).getContactId().equals(contactId))
dbExecutor.execute(new GenerateRequest());
} else if(e instanceof RemoteRetentionTimeUpdatedEvent) {
dbExecutor.execute(new GenerateRetentionAck());
} else if(e instanceof RemoteSubscriptionsUpdatedEvent) {
dbExecutor.execute(new GenerateSubscriptionAck());
dbExecutor.execute(new GenerateOffer());
} else if(e instanceof RemoteTransportsUpdatedEvent) {
dbExecutor.execute(new GenerateTransportAcks());
}
}
void read() {
try {
InputStream in = createConnectionReader().getInputStream();
PacketReader reader = packetReaderFactory.createPacketReader(in);
if(LOG.isLoggable(INFO)) LOG.info("Starting to read");
while(!reader.eof()) {
if(reader.hasAck()) {
Ack a = reader.readAck();
if(LOG.isLoggable(INFO)) LOG.info("Received ack");
dbExecutor.execute(new ReceiveAck(a));
} else if(reader.hasMessage()) {
UnverifiedMessage m = reader.readMessage();
if(LOG.isLoggable(INFO)) LOG.info("Received message");
cryptoExecutor.execute(new VerifyMessage(m));
} else if(reader.hasOffer()) {
Offer o = reader.readOffer();
if(LOG.isLoggable(INFO)) LOG.info("Received offer");
dbExecutor.execute(new ReceiveOffer(o));
} else if(reader.hasRequest()) {
Request r = reader.readRequest();
if(LOG.isLoggable(INFO)) LOG.info("Received request");
dbExecutor.execute(new ReceiveRequest(r));
} else if(reader.hasRetentionAck()) {
RetentionAck a = reader.readRetentionAck();
if(LOG.isLoggable(INFO)) LOG.info("Received retention ack");
dbExecutor.execute(new ReceiveRetentionAck(a));
} else if(reader.hasRetentionUpdate()) {
RetentionUpdate u = reader.readRetentionUpdate();
if(LOG.isLoggable(INFO))
LOG.info("Received retention update");
dbExecutor.execute(new ReceiveRetentionUpdate(u));
} else if(reader.hasSubscriptionAck()) {
SubscriptionAck a = reader.readSubscriptionAck();
if(LOG.isLoggable(INFO))
LOG.info("Received subscription ack");
dbExecutor.execute(new ReceiveSubscriptionAck(a));
} else if(reader.hasSubscriptionUpdate()) {
SubscriptionUpdate u = reader.readSubscriptionUpdate();
if(LOG.isLoggable(INFO))
LOG.info("Received subscription update");
dbExecutor.execute(new ReceiveSubscriptionUpdate(u));
} else if(reader.hasTransportAck()) {
TransportAck a = reader.readTransportAck();
if(LOG.isLoggable(INFO))
LOG.info("Received transport ack");
dbExecutor.execute(new ReceiveTransportAck(a));
} else if(reader.hasTransportUpdate()) {
TransportUpdate u = reader.readTransportUpdate();
if(LOG.isLoggable(INFO))
LOG.info("Received transport update");
dbExecutor.execute(new ReceiveTransportUpdate(u));
} else {
throw new FormatException();
}
}
if(LOG.isLoggable(INFO)) LOG.info("Finished reading");
writerTasks.add(CLOSE);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
writerTasks.add(DIE);
}
}
void write() {
connRegistry.registerConnection(contactId, transportId);
db.addListener(this);
try {
OutputStream out = createConnectionWriter().getOutputStream();
writer = packetWriterFactory.createPacketWriter(out,
transport.shouldFlush());
if(LOG.isLoggable(INFO)) LOG.info("Starting to write");
// Send the initial packets
dbExecutor.execute(new GenerateTransportAcks());
dbExecutor.execute(new GenerateTransportUpdates());
dbExecutor.execute(new GenerateSubscriptionAck());
dbExecutor.execute(new GenerateSubscriptionUpdate());
dbExecutor.execute(new GenerateRetentionAck());
dbExecutor.execute(new GenerateRetentionUpdate());
dbExecutor.execute(new GenerateAck());
dbExecutor.execute(new GenerateBatch());
dbExecutor.execute(new GenerateOffer());
dbExecutor.execute(new GenerateRequest());
// Main loop
Runnable task = null;
while(true) {
if(LOG.isLoggable(INFO))
LOG.info("Waiting for something to write");
task = writerTasks.take();
if(task == CLOSE || task == DIE) break;
task.run();
}
if(LOG.isLoggable(INFO)) LOG.info("Finished writing");
if(task == CLOSE) {
writer.flush();
writer.close();
dispose(false, true);
} else {
dispose(true, true);
}
} catch(InterruptedException e) {
if(LOG.isLoggable(INFO))
LOG.info("Interrupted while waiting for task");
dispose(true, true);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
db.removeListener(this);
connRegistry.unregisterConnection(contactId, transportId);
}
private void dispose(boolean exception, boolean recognised) {
if(disposed.getAndSet(true)) return;
if(LOG.isLoggable(INFO))
LOG.info("Disposing: " + exception + ", " + recognised);
ByteUtils.erase(ctx.getSecret());
try {
transport.dispose(exception, recognised);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
// This task runs on a database thread
private class ReceiveAck implements Runnable {
private final Ack ack;
private ReceiveAck(Ack ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveAck(contactId, ack);
if(LOG.isLoggable(INFO)) LOG.info("DB received ack");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a crypto thread
private class VerifyMessage implements Runnable {
private final UnverifiedMessage message;
private VerifyMessage(UnverifiedMessage message) {
this.message = message;
}
public void run() {
try {
Message m = messageVerifier.verifyMessage(message);
if(LOG.isLoggable(INFO)) LOG.info("Verified message");
dbExecutor.execute(new ReceiveMessage(m));
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveMessage implements Runnable {
private final Message message;
private ReceiveMessage(Message message) {
this.message = message;
}
public void run() {
try {
db.receiveMessage(contactId, message);
if(LOG.isLoggable(INFO)) LOG.info("DB received message");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveOffer implements Runnable {
private final Offer offer;
private ReceiveOffer(Offer offer) {
this.offer = offer;
}
public void run() {
try {
db.receiveOffer(contactId, offer);
if(LOG.isLoggable(INFO)) LOG.info("DB received offer");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveRequest implements Runnable {
private final Request request;
private ReceiveRequest(Request request) {
this.request = request;
}
public void run() {
try {
db.receiveRequest(contactId, request);
if(LOG.isLoggable(INFO)) LOG.info("DB received request");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveRetentionAck implements Runnable {
private final RetentionAck ack;
private ReceiveRetentionAck(RetentionAck ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveRetentionAck(contactId, ack);
if(LOG.isLoggable(INFO)) LOG.info("DB received retention ack");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveRetentionUpdate implements Runnable {
private final RetentionUpdate update;
private ReceiveRetentionUpdate(RetentionUpdate update) {
this.update = update;
}
public void run() {
try {
db.receiveRetentionUpdate(contactId, update);
if(LOG.isLoggable(INFO))
LOG.info("DB received retention update");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveSubscriptionAck implements Runnable {
private final SubscriptionAck ack;
private ReceiveSubscriptionAck(SubscriptionAck ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveSubscriptionAck(contactId, ack);
if(LOG.isLoggable(INFO))
LOG.info("DB received subscription ack");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveSubscriptionUpdate implements Runnable {
private final SubscriptionUpdate update;
private ReceiveSubscriptionUpdate(SubscriptionUpdate update) {
this.update = update;
}
public void run() {
try {
db.receiveSubscriptionUpdate(contactId, update);
if(LOG.isLoggable(INFO))
LOG.info("DB received subscription update");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveTransportAck implements Runnable {
private final TransportAck ack;
private ReceiveTransportAck(TransportAck ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveTransportAck(contactId, ack);
if(LOG.isLoggable(INFO)) LOG.info("DB received transport ack");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class ReceiveTransportUpdate implements Runnable {
private final TransportUpdate update;
private ReceiveTransportUpdate(TransportUpdate update) {
this.update = update;
}
public void run() {
try {
db.receiveTransportUpdate(contactId, update);
if(LOG.isLoggable(INFO))
LOG.info("DB received transport update");
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on a database thread
private class GenerateAck implements Runnable {
public void run() {
assert writer != null;
int maxMessages = writer.getMaxMessagesForRequest(Long.MAX_VALUE);
try {
Ack a = db.generateAck(contactId, maxMessages);
if(LOG.isLoggable(INFO))
LOG.info("Generated ack: " + (a != null));
if(a != null) writerTasks.add(new WriteAck(a));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on the writer thread
private class WriteAck implements Runnable {
private final Ack ack;
private WriteAck(Ack ack) {
this.ack = ack;
}
public void run() {
assert writer != null;
try {
writer.writeAck(ack);
if(LOG.isLoggable(INFO)) LOG.info("Sent ack");
dbExecutor.execute(new GenerateAck());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateBatch implements Runnable {
public void run() {
assert writer != null;
try {
Collection<byte[]> b = db.generateRequestedBatch(contactId,
MAX_PACKET_LENGTH, maxLatency);
if(LOG.isLoggable(INFO))
LOG.info("Generated batch: " + (b != null));
if(b != null) writerTasks.add(new WriteBatch(b));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on the writer thread
private class WriteBatch implements Runnable {
private final Collection<byte[]> batch;
private WriteBatch(Collection<byte[]> batch) {
this.batch = batch;
}
public void run() {
assert writer != null;
try {
for(byte[] raw : batch) writer.writeMessage(raw);
if(LOG.isLoggable(INFO)) LOG.info("Sent batch");
dbExecutor.execute(new GenerateBatch());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateOffer implements Runnable {
public void run() {
assert writer != null;
int maxMessages = writer.getMaxMessagesForOffer(Long.MAX_VALUE);
try {
Offer o = db.generateOffer(contactId, maxMessages, maxLatency);
if(LOG.isLoggable(INFO))
LOG.info("Generated offer: " + (o != null));
if(o != null) writerTasks.add(new WriteOffer(o));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on the writer thread
private class WriteOffer implements Runnable {
private final Offer offer;
private WriteOffer(Offer offer) {
this.offer = offer;
}
public void run() {
assert writer != null;
try {
writer.writeOffer(offer);
if(LOG.isLoggable(INFO)) LOG.info("Sent offer");
dbExecutor.execute(new GenerateOffer());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateRequest implements Runnable {
public void run() {
assert writer != null;
int maxMessages = writer.getMaxMessagesForRequest(Long.MAX_VALUE);
try {
Request r = db.generateRequest(contactId, maxMessages);
if(LOG.isLoggable(INFO))
LOG.info("Generated request: " + (r != null));
if(r != null) writerTasks.add(new WriteRequest(r));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on the writer thread
private class WriteRequest implements Runnable {
private final Request request;
private WriteRequest(Request request) {
this.request = request;
}
public void run() {
assert writer != null;
try {
writer.writeRequest(request);
if(LOG.isLoggable(INFO)) LOG.info("Sent request");
dbExecutor.execute(new GenerateRequest());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateRetentionAck implements Runnable {
public void run() {
try {
RetentionAck a = db.generateRetentionAck(contactId);
if(LOG.isLoggable(INFO))
LOG.info("Generated retention ack: " + (a != null));
if(a != null) writerTasks.add(new WriteRetentionAck(a));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This tasks runs on the writer thread
private class WriteRetentionAck implements Runnable {
private final RetentionAck ack;
private WriteRetentionAck(RetentionAck ack) {
this.ack = ack;
}
public void run() {
assert writer != null;
try {
writer.writeRetentionAck(ack);
if(LOG.isLoggable(INFO)) LOG.info("Sent retention ack");
dbExecutor.execute(new GenerateRetentionAck());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateRetentionUpdate implements Runnable {
public void run() {
try {
RetentionUpdate u =
db.generateRetentionUpdate(contactId, maxLatency);
if(LOG.isLoggable(INFO))
LOG.info("Generated retention update: " + (u != null));
if(u != null) writerTasks.add(new WriteRetentionUpdate(u));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on the writer thread
private class WriteRetentionUpdate implements Runnable {
private final RetentionUpdate update;
private WriteRetentionUpdate(RetentionUpdate update) {
this.update = update;
}
public void run() {
assert writer != null;
try {
writer.writeRetentionUpdate(update);
if(LOG.isLoggable(INFO)) LOG.info("Sent retention update");
dbExecutor.execute(new GenerateRetentionUpdate());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateSubscriptionAck implements Runnable {
public void run() {
try {
SubscriptionAck a = db.generateSubscriptionAck(contactId);
if(LOG.isLoggable(INFO))
LOG.info("Generated subscription ack: " + (a != null));
if(a != null) writerTasks.add(new WriteSubscriptionAck(a));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This tasks runs on the writer thread
private class WriteSubscriptionAck implements Runnable {
private final SubscriptionAck ack;
private WriteSubscriptionAck(SubscriptionAck ack) {
this.ack = ack;
}
public void run() {
assert writer != null;
try {
writer.writeSubscriptionAck(ack);
if(LOG.isLoggable(INFO)) LOG.info("Sent subscription ack");
dbExecutor.execute(new GenerateSubscriptionAck());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateSubscriptionUpdate implements Runnable {
public void run() {
try {
SubscriptionUpdate u =
db.generateSubscriptionUpdate(contactId, maxLatency);
if(LOG.isLoggable(INFO))
LOG.info("Generated subscription update: " + (u != null));
if(u != null) writerTasks.add(new WriteSubscriptionUpdate(u));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on the writer thread
private class WriteSubscriptionUpdate implements Runnable {
private final SubscriptionUpdate update;
private WriteSubscriptionUpdate(SubscriptionUpdate update) {
this.update = update;
}
public void run() {
assert writer != null;
try {
writer.writeSubscriptionUpdate(update);
if(LOG.isLoggable(INFO)) LOG.info("Sent subscription update");
dbExecutor.execute(new GenerateSubscriptionUpdate());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateTransportAcks implements Runnable {
public void run() {
try {
Collection<TransportAck> acks =
db.generateTransportAcks(contactId);
if(LOG.isLoggable(INFO))
LOG.info("Generated transport acks: " + (acks != null));
if(acks != null) writerTasks.add(new WriteTransportAcks(acks));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This tasks runs on the writer thread
private class WriteTransportAcks implements Runnable {
private final Collection<TransportAck> acks;
private WriteTransportAcks(Collection<TransportAck> acks) {
this.acks = acks;
}
public void run() {
assert writer != null;
try {
for(TransportAck a : acks) writer.writeTransportAck(a);
if(LOG.isLoggable(INFO)) LOG.info("Sent transport acks");
dbExecutor.execute(new GenerateTransportAcks());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
// This task runs on a database thread
private class GenerateTransportUpdates implements Runnable {
public void run() {
try {
Collection<TransportUpdate> t =
db.generateTransportUpdates(contactId, maxLatency);
if(LOG.isLoggable(INFO))
LOG.info("Generated transport updates: " + (t != null));
if(t != null) writerTasks.add(new WriteTransportUpdates(t));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
// This task runs on the writer thread
private class WriteTransportUpdates implements Runnable {
private final Collection<TransportUpdate> updates;
private WriteTransportUpdates(Collection<TransportUpdate> updates) {
this.updates = updates;
}
public void run() {
assert writer != null;
try {
for(TransportUpdate u : updates) writer.writeTransportUpdate(u);
if(LOG.isLoggable(INFO)) LOG.info("Sent transport updates");
dbExecutor.execute(new GenerateTransportUpdates());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
}
}
}
}

View File

@@ -0,0 +1,107 @@
package org.briarproject.messaging.duplex;
import static java.util.logging.Level.WARNING;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoExecutor;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DatabaseExecutor;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.messaging.duplex.DuplexConnectionFactory;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionRegistry;
import org.briarproject.api.transport.ConnectionWriterFactory;
class DuplexConnectionFactoryImpl implements DuplexConnectionFactory {
private static final Logger LOG =
Logger.getLogger(DuplexConnectionFactoryImpl.class.getName());
private final Executor dbExecutor, cryptoExecutor;
private final MessageVerifier messageVerifier;
private final DatabaseComponent db;
private final KeyManager keyManager;
private final ConnectionRegistry connRegistry;
private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory;
private final PacketReaderFactory packetReaderFactory;
private final PacketWriterFactory packetWriterFactory;
@Inject
DuplexConnectionFactoryImpl(@DatabaseExecutor Executor dbExecutor,
@CryptoExecutor Executor cryptoExecutor,
MessageVerifier messageVerifier, DatabaseComponent db,
KeyManager keyManager, ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory,
PacketReaderFactory packetReaderFactory,
PacketWriterFactory packetWriterFactory) {
this.dbExecutor = dbExecutor;
this.cryptoExecutor = cryptoExecutor;
this.messageVerifier = messageVerifier;
this.db = db;
this.keyManager = keyManager;
this.connRegistry = connRegistry;
this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory;
this.packetReaderFactory = packetReaderFactory;
this.packetWriterFactory = packetWriterFactory;
}
public void createIncomingConnection(ConnectionContext ctx,
DuplexTransportConnection transport) {
final DuplexConnection conn = new IncomingDuplexConnection(dbExecutor,
cryptoExecutor, messageVerifier, db, connRegistry,
connReaderFactory, connWriterFactory, packetReaderFactory,
packetWriterFactory, ctx, transport);
Runnable write = new Runnable() {
public void run() {
conn.write();
}
};
new Thread(write, "DuplexConnectionWriter").start();
Runnable read = new Runnable() {
public void run() {
conn.read();
}
};
new Thread(read, "DuplexConnectionReader").start();
}
public void createOutgoingConnection(ContactId c, TransportId t,
DuplexTransportConnection transport) {
ConnectionContext ctx = keyManager.getConnectionContext(c, t);
if(ctx == null) {
if(LOG.isLoggable(WARNING))
LOG.warning("Could not create outgoing connection context");
return;
}
final DuplexConnection conn = new OutgoingDuplexConnection(dbExecutor,
cryptoExecutor, messageVerifier, db, connRegistry,
connReaderFactory, connWriterFactory, packetReaderFactory,
packetWriterFactory, ctx, transport);
Runnable write = new Runnable() {
public void run() {
conn.write();
}
};
new Thread(write, "DuplexConnectionWriter").start();
Runnable read = new Runnable() {
public void run() {
conn.read();
}
};
new Thread(read, "DuplexConnectionReader").start();
}
}

View File

@@ -0,0 +1,15 @@
package org.briarproject.messaging.duplex;
import javax.inject.Singleton;
import org.briarproject.api.messaging.duplex.DuplexConnectionFactory;
import com.google.inject.AbstractModule;
public class DuplexMessagingModule extends AbstractModule {
protected void configure() {
bind(DuplexConnectionFactory.class).to(
DuplexConnectionFactoryImpl.class).in(Singleton.class);
}
}

View File

@@ -0,0 +1,50 @@
package org.briarproject.messaging.duplex;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.Executor;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionReader;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionRegistry;
import org.briarproject.api.transport.ConnectionWriter;
import org.briarproject.api.transport.ConnectionWriterFactory;
class IncomingDuplexConnection extends DuplexConnection {
IncomingDuplexConnection(Executor dbExecutor, Executor cryptoExecutor,
MessageVerifier messageVerifier, DatabaseComponent db,
ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory,
PacketReaderFactory packetReaderFactory,
PacketWriterFactory packetWriterFactory,
ConnectionContext ctx, DuplexTransportConnection transport) {
super(dbExecutor, cryptoExecutor, messageVerifier, db, connRegistry,
connReaderFactory, connWriterFactory, packetReaderFactory,
packetWriterFactory, ctx, transport);
}
@Override
protected ConnectionReader createConnectionReader() throws IOException {
InputStream in = transport.getInputStream();
int maxFrameLength = transport.getMaxFrameLength();
return connReaderFactory.createConnectionReader(in, maxFrameLength,
ctx, true, true);
}
@Override
protected ConnectionWriter createConnectionWriter() throws IOException {
OutputStream out = transport.getOutputStream();
int maxFrameLength = transport.getMaxFrameLength();
return connWriterFactory.createConnectionWriter(out, maxFrameLength,
Long.MAX_VALUE, ctx, true, false);
}
}

View File

@@ -0,0 +1,50 @@
package org.briarproject.messaging.duplex;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.Executor;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionReader;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionRegistry;
import org.briarproject.api.transport.ConnectionWriter;
import org.briarproject.api.transport.ConnectionWriterFactory;
class OutgoingDuplexConnection extends DuplexConnection {
OutgoingDuplexConnection(Executor dbExecutor, Executor cryptoExecutor,
MessageVerifier messageVerifier, DatabaseComponent db,
ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory,
PacketReaderFactory packetReaderFactory,
PacketWriterFactory packetWriterFactory, ConnectionContext ctx,
DuplexTransportConnection transport) {
super(dbExecutor, cryptoExecutor, messageVerifier, db, connRegistry,
connReaderFactory, connWriterFactory, packetReaderFactory,
packetWriterFactory, ctx, transport);
}
@Override
protected ConnectionReader createConnectionReader() throws IOException {
InputStream in = transport.getInputStream();
int maxFrameLength = transport.getMaxFrameLength();
return connReaderFactory.createConnectionReader(in, maxFrameLength,
ctx, false, false);
}
@Override
protected ConnectionWriter createConnectionWriter() throws IOException {
OutputStream out = transport.getOutputStream();
int maxFrameLength = transport.getMaxFrameLength();
return connWriterFactory.createConnectionWriter(out, maxFrameLength,
Long.MAX_VALUE, ctx, false, true);
}
}

View File

@@ -0,0 +1,280 @@
package org.briarproject.messaging.simplex;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import org.briarproject.api.ContactId;
import org.briarproject.api.FormatException;
import org.briarproject.api.TransportId;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.messaging.Ack;
import org.briarproject.api.messaging.Message;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.PacketReader;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.plugins.simplex.SimplexTransportReader;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionReader;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionRegistry;
import org.briarproject.util.ByteUtils;
class IncomingSimplexConnection {
private static final Logger LOG =
Logger.getLogger(IncomingSimplexConnection.class.getName());
private final Executor dbExecutor, cryptoExecutor;
private final MessageVerifier messageVerifier;
private final DatabaseComponent db;
private final ConnectionRegistry connRegistry;
private final ConnectionReaderFactory connReaderFactory;
private final PacketReaderFactory packetReaderFactory;
private final ConnectionContext ctx;
private final SimplexTransportReader transport;
private final ContactId contactId;
private final TransportId transportId;
IncomingSimplexConnection(Executor dbExecutor, Executor cryptoExecutor,
MessageVerifier messageVerifier, DatabaseComponent db,
ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory,
PacketReaderFactory packetReaderFactory, ConnectionContext ctx,
SimplexTransportReader transport) {
this.dbExecutor = dbExecutor;
this.cryptoExecutor = cryptoExecutor;
this.messageVerifier = messageVerifier;
this.db = db;
this.connRegistry = connRegistry;
this.connReaderFactory = connReaderFactory;
this.packetReaderFactory = packetReaderFactory;
this.ctx = ctx;
this.transport = transport;
contactId = ctx.getContactId();
transportId = ctx.getTransportId();
}
void read() {
connRegistry.registerConnection(contactId, transportId);
try {
InputStream in = transport.getInputStream();
int maxFrameLength = transport.getMaxFrameLength();
ConnectionReader conn = connReaderFactory.createConnectionReader(in,
maxFrameLength, ctx, true, true);
in = conn.getInputStream();
PacketReader reader = packetReaderFactory.createPacketReader(in);
// Read packets until EOF
while(!reader.eof()) {
if(reader.hasAck()) {
Ack a = reader.readAck();
dbExecutor.execute(new ReceiveAck(a));
} else if(reader.hasMessage()) {
UnverifiedMessage m = reader.readMessage();
cryptoExecutor.execute(new VerifyMessage(m));
} else if(reader.hasRetentionAck()) {
RetentionAck a = reader.readRetentionAck();
dbExecutor.execute(new ReceiveRetentionAck(a));
} else if(reader.hasRetentionUpdate()) {
RetentionUpdate u = reader.readRetentionUpdate();
dbExecutor.execute(new ReceiveRetentionUpdate(u));
} else if(reader.hasSubscriptionAck()) {
SubscriptionAck a = reader.readSubscriptionAck();
dbExecutor.execute(new ReceiveSubscriptionAck(a));
} else if(reader.hasSubscriptionUpdate()) {
SubscriptionUpdate u = reader.readSubscriptionUpdate();
dbExecutor.execute(new ReceiveSubscriptionUpdate(u));
} else if(reader.hasTransportAck()) {
TransportAck a = reader.readTransportAck();
dbExecutor.execute(new ReceiveTransportAck(a));
} else if(reader.hasTransportUpdate()) {
TransportUpdate u = reader.readTransportUpdate();
dbExecutor.execute(new ReceiveTransportUpdate(u));
} else {
throw new FormatException();
}
}
dispose(false, true);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, true);
} finally {
connRegistry.unregisterConnection(contactId, transportId);
}
}
private void dispose(boolean exception, boolean recognised) {
ByteUtils.erase(ctx.getSecret());
try {
transport.dispose(exception, recognised);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
private class ReceiveAck implements Runnable {
private final Ack ack;
private ReceiveAck(Ack ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveAck(contactId, ack);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class VerifyMessage implements Runnable {
private final UnverifiedMessage message;
private VerifyMessage(UnverifiedMessage message) {
this.message = message;
}
public void run() {
try {
Message m = messageVerifier.verifyMessage(message);
dbExecutor.execute(new ReceiveMessage(m));
} catch(GeneralSecurityException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class ReceiveMessage implements Runnable {
private final Message message;
private ReceiveMessage(Message message) {
this.message = message;
}
public void run() {
try {
db.receiveMessage(contactId, message);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class ReceiveRetentionAck implements Runnable {
private final RetentionAck ack;
private ReceiveRetentionAck(RetentionAck ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveRetentionAck(contactId, ack);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class ReceiveRetentionUpdate implements Runnable {
private final RetentionUpdate update;
private ReceiveRetentionUpdate(RetentionUpdate update) {
this.update = update;
}
public void run() {
try {
db.receiveRetentionUpdate(contactId, update);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class ReceiveSubscriptionAck implements Runnable {
private final SubscriptionAck ack;
private ReceiveSubscriptionAck(SubscriptionAck ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveSubscriptionAck(contactId, ack);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class ReceiveSubscriptionUpdate implements Runnable {
private final SubscriptionUpdate update;
private ReceiveSubscriptionUpdate(SubscriptionUpdate update) {
this.update = update;
}
public void run() {
try {
db.receiveSubscriptionUpdate(contactId, update);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class ReceiveTransportAck implements Runnable {
private final TransportAck ack;
private ReceiveTransportAck(TransportAck ack) {
this.ack = ack;
}
public void run() {
try {
db.receiveTransportAck(contactId, ack);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class ReceiveTransportUpdate implements Runnable {
private final TransportUpdate update;
private ReceiveTransportUpdate(TransportUpdate update) {
this.update = update;
}
public void run() {
try {
db.receiveTransportUpdate(contactId, update);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
}

View File

@@ -0,0 +1,187 @@
package org.briarproject.messaging.simplex;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.messaging.MessagingConstants.MAX_PACKET_LENGTH;
import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.logging.Logger;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.messaging.Ack;
import org.briarproject.api.messaging.PacketWriter;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.messaging.RetentionAck;
import org.briarproject.api.messaging.RetentionUpdate;
import org.briarproject.api.messaging.SubscriptionAck;
import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.TransportAck;
import org.briarproject.api.messaging.TransportUpdate;
import org.briarproject.api.plugins.simplex.SimplexTransportWriter;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionRegistry;
import org.briarproject.api.transport.ConnectionWriter;
import org.briarproject.api.transport.ConnectionWriterFactory;
import org.briarproject.util.ByteUtils;
class OutgoingSimplexConnection {
private static final Logger LOG =
Logger.getLogger(OutgoingSimplexConnection.class.getName());
private final DatabaseComponent db;
private final ConnectionRegistry connRegistry;
private final ConnectionWriterFactory connWriterFactory;
private final PacketWriterFactory packetWriterFactory;
private final ConnectionContext ctx;
private final SimplexTransportWriter transport;
private final ContactId contactId;
private final TransportId transportId;
private final long maxLatency;
OutgoingSimplexConnection(DatabaseComponent db,
ConnectionRegistry connRegistry,
ConnectionWriterFactory connWriterFactory,
PacketWriterFactory packetWriterFactory, ConnectionContext ctx,
SimplexTransportWriter transport) {
this.db = db;
this.connRegistry = connRegistry;
this.connWriterFactory = connWriterFactory;
this.packetWriterFactory = packetWriterFactory;
this.ctx = ctx;
this.transport = transport;
contactId = ctx.getContactId();
transportId = ctx.getTransportId();
maxLatency = transport.getMaxLatency();
}
void write() {
connRegistry.registerConnection(contactId, transportId);
try {
OutputStream out = transport.getOutputStream();
long capacity = transport.getCapacity();
int maxFrameLength = transport.getMaxFrameLength();
ConnectionWriter conn = connWriterFactory.createConnectionWriter(
out, maxFrameLength, capacity, ctx, false, true);
out = conn.getOutputStream();
if(conn.getRemainingCapacity() < MAX_PACKET_LENGTH)
throw new EOFException();
PacketWriter writer = packetWriterFactory.createPacketWriter(out,
transport.shouldFlush());
// Send the initial packets: updates and acks
boolean hasSpace = writeTransportAcks(conn, writer);
if(hasSpace) hasSpace = writeTransportUpdates(conn, writer);
if(hasSpace) hasSpace = writeSubscriptionAck(conn, writer);
if(hasSpace) hasSpace = writeSubscriptionUpdate(conn, writer);
if(hasSpace) hasSpace = writeRetentionAck(conn, writer);
if(hasSpace) hasSpace = writeRetentionUpdate(conn, writer);
// Write acks until you can't write acks no more
capacity = conn.getRemainingCapacity();
int maxMessages = writer.getMaxMessagesForRequest(capacity);
Ack a = db.generateAck(contactId, maxMessages);
while(a != null) {
writer.writeAck(a);
capacity = conn.getRemainingCapacity();
maxMessages = writer.getMaxMessagesForRequest(capacity);
a = db.generateAck(contactId, maxMessages);
}
// Write messages until you can't write messages no more
capacity = conn.getRemainingCapacity();
int maxLength = (int) Math.min(capacity, MAX_PACKET_LENGTH);
Collection<byte[]> batch = db.generateBatch(contactId, maxLength,
maxLatency);
while(batch != null) {
for(byte[] raw : batch) writer.writeMessage(raw);
capacity = conn.getRemainingCapacity();
maxLength = (int) Math.min(capacity, MAX_PACKET_LENGTH);
batch = db.generateBatch(contactId, maxLength, maxLatency);
}
writer.flush();
writer.close();
dispose(false);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true);
}
connRegistry.unregisterConnection(contactId, transportId);
}
private boolean writeTransportAcks(ConnectionWriter conn,
PacketWriter writer) throws DbException, IOException {
assert conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
Collection<TransportAck> acks = db.generateTransportAcks(contactId);
if(acks == null) return true;
for(TransportAck a : acks) {
writer.writeTransportAck(a);
if(conn.getRemainingCapacity() < MAX_PACKET_LENGTH) return false;
}
return true;
}
private boolean writeTransportUpdates(ConnectionWriter conn,
PacketWriter writer) throws DbException, IOException {
assert conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
Collection<TransportUpdate> updates =
db.generateTransportUpdates(contactId, maxLatency);
if(updates == null) return true;
for(TransportUpdate u : updates) {
writer.writeTransportUpdate(u);
if(conn.getRemainingCapacity() < MAX_PACKET_LENGTH) return false;
}
return true;
}
private boolean writeSubscriptionAck(ConnectionWriter conn,
PacketWriter writer) throws DbException, IOException {
assert conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
SubscriptionAck a = db.generateSubscriptionAck(contactId);
if(a == null) return true;
writer.writeSubscriptionAck(a);
return conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
}
private boolean writeSubscriptionUpdate(ConnectionWriter conn,
PacketWriter writer) throws DbException, IOException {
assert conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
SubscriptionUpdate u =
db.generateSubscriptionUpdate(contactId, maxLatency);
if(u == null) return true;
writer.writeSubscriptionUpdate(u);
return conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
}
private boolean writeRetentionAck(ConnectionWriter conn,
PacketWriter writer) throws DbException, IOException {
assert conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
RetentionAck a = db.generateRetentionAck(contactId);
if(a == null) return true;
writer.writeRetentionAck(a);
return conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
}
private boolean writeRetentionUpdate(ConnectionWriter conn,
PacketWriter writer) throws DbException, IOException {
assert conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
RetentionUpdate u = db.generateRetentionUpdate(contactId, maxLatency);
if(u == null) return true;
writer.writeRetentionUpdate(u);
return conn.getRemainingCapacity() >= MAX_PACKET_LENGTH;
}
private void dispose(boolean exception) {
ByteUtils.erase(ctx.getSecret());
try {
transport.dispose(exception);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}

View File

@@ -0,0 +1,93 @@
package org.briarproject.messaging.simplex;
import static java.util.logging.Level.WARNING;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoExecutor;
import org.briarproject.api.crypto.KeyManager;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DatabaseExecutor;
import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.messaging.simplex.SimplexConnectionFactory;
import org.briarproject.api.plugins.simplex.SimplexTransportReader;
import org.briarproject.api.plugins.simplex.SimplexTransportWriter;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionReaderFactory;
import org.briarproject.api.transport.ConnectionRegistry;
import org.briarproject.api.transport.ConnectionWriterFactory;
class SimplexConnectionFactoryImpl implements SimplexConnectionFactory {
private static final Logger LOG =
Logger.getLogger(SimplexConnectionFactoryImpl.class.getName());
private final Executor dbExecutor, cryptoExecutor;
private final MessageVerifier messageVerifier;
private final DatabaseComponent db;
private final KeyManager keyManager;
private final ConnectionRegistry connRegistry;
private final ConnectionReaderFactory connReaderFactory;
private final ConnectionWriterFactory connWriterFactory;
private final PacketReaderFactory packetReaderFactory;
private final PacketWriterFactory packetWriterFactory;
@Inject
SimplexConnectionFactoryImpl(@DatabaseExecutor Executor dbExecutor,
@CryptoExecutor Executor cryptoExecutor,
MessageVerifier messageVerifier, DatabaseComponent db,
KeyManager keyManager, ConnectionRegistry connRegistry,
ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory,
PacketReaderFactory packetReaderFactory,
PacketWriterFactory packetWriterFactory) {
this.dbExecutor = dbExecutor;
this.cryptoExecutor = cryptoExecutor;
this.messageVerifier = messageVerifier;
this.db = db;
this.keyManager = keyManager;
this.connRegistry = connRegistry;
this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory;
this.packetReaderFactory = packetReaderFactory;
this.packetWriterFactory = packetWriterFactory;
}
public void createIncomingConnection(ConnectionContext ctx,
SimplexTransportReader r) {
final IncomingSimplexConnection conn = new IncomingSimplexConnection(
dbExecutor, cryptoExecutor, messageVerifier, db, connRegistry,
connReaderFactory, packetReaderFactory, ctx, r);
Runnable read = new Runnable() {
public void run() {
conn.read();
}
};
new Thread(read, "SimplexConnectionReader").start();
}
public void createOutgoingConnection(ContactId c, TransportId t,
SimplexTransportWriter w) {
ConnectionContext ctx = keyManager.getConnectionContext(c, t);
if(ctx == null) {
if(LOG.isLoggable(WARNING))
LOG.warning("Could not create outgoing connection context");
return;
}
final OutgoingSimplexConnection conn = new OutgoingSimplexConnection(db,
connRegistry, connWriterFactory, packetWriterFactory, ctx, w);
Runnable write = new Runnable() {
public void run() {
conn.write();
}
};
new Thread(write, "SimplexConnectionWriter").start();
}
}

View File

@@ -0,0 +1,15 @@
package org.briarproject.messaging.simplex;
import javax.inject.Singleton;
import org.briarproject.api.messaging.simplex.SimplexConnectionFactory;
import com.google.inject.AbstractModule;
public class SimplexMessagingModule extends AbstractModule {
protected void configure() {
bind(SimplexConnectionFactory.class).to(
SimplexConnectionFactoryImpl.class).in(Singleton.class);
}
}

View File

@@ -0,0 +1,362 @@
package org.briarproject.plugins;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportConfig;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.plugins.Plugin;
import org.briarproject.api.plugins.PluginCallback;
import org.briarproject.api.plugins.PluginExecutor;
import org.briarproject.api.plugins.PluginManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexPluginCallback;
import org.briarproject.api.plugins.duplex.DuplexPluginConfig;
import org.briarproject.api.plugins.duplex.DuplexPluginFactory;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.plugins.simplex.SimplexPlugin;
import org.briarproject.api.plugins.simplex.SimplexPluginCallback;
import org.briarproject.api.plugins.simplex.SimplexPluginConfig;
import org.briarproject.api.plugins.simplex.SimplexPluginFactory;
import org.briarproject.api.plugins.simplex.SimplexTransportReader;
import org.briarproject.api.plugins.simplex.SimplexTransportWriter;
import org.briarproject.api.transport.ConnectionDispatcher;
import org.briarproject.api.ui.UiCallback;
// FIXME: Don't make alien calls with a lock held (that includes waiting on a
// latch that depends on an alien call)
class PluginManagerImpl implements PluginManager {
private static final Logger LOG =
Logger.getLogger(PluginManagerImpl.class.getName());
private final Executor pluginExecutor;
private final SimplexPluginConfig simplexPluginConfig;
private final DuplexPluginConfig duplexPluginConfig;
private final DatabaseComponent db;
private final Poller poller;
private final ConnectionDispatcher dispatcher;
private final UiCallback uiCallback;
private final List<SimplexPlugin> simplexPlugins;
private final List<DuplexPlugin> duplexPlugins;
@Inject
PluginManagerImpl(@PluginExecutor Executor pluginExecutor,
SimplexPluginConfig simplexPluginConfig,
DuplexPluginConfig duplexPluginConfig, DatabaseComponent db,
Poller poller, ConnectionDispatcher dispatcher,
UiCallback uiCallback) {
this.pluginExecutor = pluginExecutor;
this.simplexPluginConfig = simplexPluginConfig;
this.duplexPluginConfig = duplexPluginConfig;
this.db = db;
this.poller = poller;
this.dispatcher = dispatcher;
this.uiCallback = uiCallback;
simplexPlugins = new CopyOnWriteArrayList<SimplexPlugin>();
duplexPlugins = new CopyOnWriteArrayList<DuplexPlugin>();
}
public synchronized boolean start() {
// Instantiate and start the simplex plugins
if(LOG.isLoggable(INFO)) LOG.info("Starting simplex plugins");
Collection<SimplexPluginFactory> sFactories =
simplexPluginConfig.getFactories();
final CountDownLatch sLatch = new CountDownLatch(sFactories.size());
for(SimplexPluginFactory factory : sFactories)
pluginExecutor.execute(new SimplexPluginStarter(factory, sLatch));
// Instantiate and start the duplex plugins
if(LOG.isLoggable(INFO)) LOG.info("Starting duplex plugins");
Collection<DuplexPluginFactory> dFactories =
duplexPluginConfig.getFactories();
final CountDownLatch dLatch = new CountDownLatch(dFactories.size());
for(DuplexPluginFactory factory : dFactories)
pluginExecutor.execute(new DuplexPluginStarter(factory, dLatch));
// Wait for the plugins to start
try {
sLatch.await();
dLatch.await();
} catch(InterruptedException e) {
if(LOG.isLoggable(WARNING))
LOG.warning("Interrupted while starting plugins");
Thread.currentThread().interrupt();
return false;
}
// Start the poller
if(LOG.isLoggable(INFO)) LOG.info("Starting poller");
List<Plugin> plugins = new ArrayList<Plugin>();
plugins.addAll(simplexPlugins);
plugins.addAll(duplexPlugins);
poller.start(Collections.unmodifiableList(plugins));
return true;
}
public synchronized boolean stop() {
// Stop the poller
if(LOG.isLoggable(INFO)) LOG.info("Stopping poller");
poller.stop();
int plugins = simplexPlugins.size() + duplexPlugins.size();
final CountDownLatch latch = new CountDownLatch(plugins);
// Stop the simplex plugins
if(LOG.isLoggable(INFO)) LOG.info("Stopping simplex plugins");
for(SimplexPlugin plugin : simplexPlugins)
pluginExecutor.execute(new PluginStopper(plugin, latch));
// Stop the duplex plugins
if(LOG.isLoggable(INFO)) LOG.info("Stopping duplex plugins");
for(DuplexPlugin plugin : duplexPlugins)
pluginExecutor.execute(new PluginStopper(plugin, latch));
simplexPlugins.clear();
duplexPlugins.clear();
// Wait for all the plugins to stop
try {
latch.await();
} catch(InterruptedException e) {
if(LOG.isLoggable(WARNING))
LOG.warning("Interrupted while stopping plugins");
Thread.currentThread().interrupt();
return false;
}
return true;
}
public Collection<DuplexPlugin> getInvitationPlugins() {
List<DuplexPlugin> supported = new ArrayList<DuplexPlugin>();
for(DuplexPlugin d : duplexPlugins)
if(d.supportsInvitations()) supported.add(d);
return Collections.unmodifiableList(supported);
}
private class SimplexPluginStarter implements Runnable {
private final SimplexPluginFactory factory;
private final CountDownLatch latch;
private SimplexPluginStarter(SimplexPluginFactory factory,
CountDownLatch latch) {
this.factory = factory;
this.latch = latch;
}
public void run() {
try {
TransportId id = factory.getId();
SimplexCallback callback = new SimplexCallback(id);
SimplexPlugin plugin = factory.createPlugin(callback);
if(plugin == null) {
if(LOG.isLoggable(INFO)) {
String name = factory.getClass().getSimpleName();
LOG.info(name + " did not create a plugin");
}
return;
}
try {
db.addTransport(id, plugin.getMaxLatency());
} catch(DbException e) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
return;
}
try {
if(plugin.start()) {
simplexPlugins.add(plugin);
} else {
if(LOG.isLoggable(INFO)) {
String name = plugin.getClass().getSimpleName();
LOG.info(name + " did not start");
}
}
} catch(IOException e) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
}
} finally {
latch.countDown();
}
}
}
private class DuplexPluginStarter implements Runnable {
private final DuplexPluginFactory factory;
private final CountDownLatch latch;
private DuplexPluginStarter(DuplexPluginFactory factory,
CountDownLatch latch) {
this.factory = factory;
this.latch = latch;
}
public void run() {
try {
TransportId id = factory.getId();
DuplexCallback callback = new DuplexCallback(id);
DuplexPlugin plugin = factory.createPlugin(callback);
if(plugin == null) {
if(LOG.isLoggable(INFO)) {
String name = factory.getClass().getSimpleName();
LOG.info(name + " did not create a plugin");
}
return;
}
try {
db.addTransport(id, plugin.getMaxLatency());
} catch(DbException e) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
return;
}
try {
if(plugin.start()) {
duplexPlugins.add(plugin);
} else {
if(LOG.isLoggable(INFO)) {
String name = plugin.getClass().getSimpleName();
LOG.info(name + " did not start");
}
}
} catch(IOException e) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
}
} finally {
latch.countDown();
}
}
}
private class PluginStopper implements Runnable {
private final Plugin plugin;
private final CountDownLatch latch;
private PluginStopper(Plugin plugin, CountDownLatch latch) {
this.plugin = plugin;
this.latch = latch;
}
public void run() {
try {
plugin.stop();
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} finally {
latch.countDown();
}
}
}
private abstract class PluginCallbackImpl implements PluginCallback {
protected final TransportId id;
protected PluginCallbackImpl(TransportId id) {
this.id = id;
}
public TransportConfig getConfig() {
try {
return db.getConfig(id);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return new TransportConfig();
}
}
public TransportProperties getLocalProperties() {
try {
TransportProperties p = db.getLocalProperties(id);
return p == null ? new TransportProperties() : p;
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return new TransportProperties();
}
}
public Map<ContactId, TransportProperties> getRemoteProperties() {
try {
return db.getRemoteProperties(id);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return Collections.emptyMap();
}
}
public void mergeConfig(TransportConfig c) {
try {
db.mergeConfig(id, c);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
public void mergeLocalProperties(TransportProperties p) {
try {
db.mergeLocalProperties(id, p);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
public int showChoice(String[] options, String... message) {
return uiCallback.showChoice(options, message);
}
public boolean showConfirmationMessage(String... message) {
return uiCallback.showConfirmationMessage(message);
}
public void showMessage(String... message) {
uiCallback.showMessage(message);
}
}
private class SimplexCallback extends PluginCallbackImpl
implements SimplexPluginCallback {
private SimplexCallback(TransportId id) {
super(id);
}
public void readerCreated(SimplexTransportReader r) {
dispatcher.dispatchReader(id, r);
}
public void writerCreated(ContactId c, SimplexTransportWriter w) {
dispatcher.dispatchWriter(c, id, w);
}
}
private class DuplexCallback extends PluginCallbackImpl
implements DuplexPluginCallback {
private DuplexCallback(TransportId id) {
super(id);
}
public void incomingConnectionCreated(DuplexTransportConnection d) {
dispatcher.dispatchIncomingConnection(id, d);
}
public void outgoingConnectionCreated(ContactId c,
DuplexTransportConnection d) {
dispatcher.dispatchOutgoingConnection(c, id, d);
}
}
}

View File

@@ -0,0 +1,52 @@
package org.briarproject.plugins;
import static java.util.concurrent.TimeUnit.SECONDS;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import javax.inject.Singleton;
import org.briarproject.api.lifecycle.LifecycleManager;
import org.briarproject.api.plugins.PluginExecutor;
import org.briarproject.api.plugins.PluginManager;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
public class PluginsModule extends AbstractModule {
private final ExecutorService pluginExecutor;
public PluginsModule() {
// The thread pool is unbounded, so use direct handoff
BlockingQueue<Runnable> queue = new SynchronousQueue<Runnable>();
// Discard tasks that are submitted during shutdown
RejectedExecutionHandler policy =
new ThreadPoolExecutor.DiscardPolicy();
// Create threads as required and keep them in the pool for 60 seconds
pluginExecutor = new ThreadPoolExecutor(0, Integer.MAX_VALUE,
60, SECONDS, queue, policy);
}
protected void configure() {
bind(Poller.class).to(PollerImpl.class);
}
@Provides @Singleton
PluginManager getPluginManager(LifecycleManager lifecycleManager,
PluginManagerImpl pluginManager) {
lifecycleManager.register(pluginManager);
return pluginManager;
}
@Provides @Singleton @PluginExecutor
Executor getPluginExecutor(LifecycleManager lifecycleManager) {
lifecycleManager.registerForShutdown(pluginExecutor);
return pluginExecutor;
}
}

View File

@@ -0,0 +1,14 @@
package org.briarproject.plugins;
import java.util.Collection;
import org.briarproject.api.plugins.Plugin;
interface Poller {
/** Starts a new thread to poll the given collection of plugins. */
void start(Collection<Plugin> plugins);
/** Tells the poller thread to exit. */
void stop();
}

View File

@@ -0,0 +1,126 @@
package org.briarproject.plugins;
import static java.util.logging.Level.INFO;
import java.util.Collection;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.plugins.Plugin;
import org.briarproject.api.plugins.PluginExecutor;
import org.briarproject.api.system.Clock;
import org.briarproject.api.transport.ConnectionRegistry;
class PollerImpl implements Poller, Runnable {
private static final Logger LOG =
Logger.getLogger(PollerImpl.class.getName());
private final Executor pluginExecutor;
private final ConnectionRegistry connRegistry;
private final Clock clock;
private final SortedSet<PollTime> pollTimes;
@Inject
PollerImpl(@PluginExecutor Executor pluginExecutor,
ConnectionRegistry connRegistry, Clock clock) {
this.pluginExecutor = pluginExecutor;
this.connRegistry = connRegistry;
this.clock = clock;
pollTimes = new TreeSet<PollTime>();
}
public synchronized void start(Collection<Plugin> plugins) {
for(Plugin plugin : plugins) schedule(plugin, true);
new Thread(this, "Poller").start();
}
private synchronized void schedule(Plugin plugin, boolean randomise) {
if(plugin.shouldPoll()) {
long now = clock.currentTimeMillis();
long interval = plugin.getPollingInterval();
// Randomise intervals at startup to spread out connection attempts
if(randomise) interval = (long) (interval * Math.random());
pollTimes.add(new PollTime(now + interval, plugin));
}
}
public synchronized void stop() {
pollTimes.clear();
notifyAll();
}
public void run() {
while(true) {
synchronized(this) {
if(pollTimes.isEmpty()) {
if(LOG.isLoggable(INFO)) LOG.info("Finished polling");
return;
}
long now = clock.currentTimeMillis();
final PollTime p = pollTimes.first();
if(now >= p.time) {
boolean removed = pollTimes.remove(p);
assert removed;
final Collection<ContactId> connected =
connRegistry.getConnectedContacts(p.plugin.getId());
if(LOG.isLoggable(INFO))
LOG.info("Polling " + p.plugin.getClass().getName());
pluginExecutor.execute(new Runnable() {
public void run() {
p.plugin.poll(connected);
}
});
schedule(p.plugin, false);
} else {
try {
wait(p.time - now);
} catch(InterruptedException e) {
if(LOG.isLoggable(INFO))
LOG.info("Interrupted while waiting to poll");
Thread.currentThread().interrupt();
return;
}
}
}
}
}
private static class PollTime implements Comparable<PollTime> {
private final long time;
private final Plugin plugin;
private PollTime(long time, Plugin plugin) {
this.time = time;
this.plugin = plugin;
}
// Must be consistent with equals()
public int compareTo(PollTime p) {
if(time < p.time) return -1;
if(time > p.time) return 1;
return 0;
}
// Must be consistent with equals()
@Override
public int hashCode() {
return (int) (time ^ (time >>> 32)) ^ plugin.hashCode();
}
@Override
public boolean equals(Object o) {
if(o instanceof PollTime) {
PollTime p = (PollTime) o;
return time == p.time && plugin == p.plugin;
}
return false;
}
}
}

View File

@@ -0,0 +1,122 @@
package org.briarproject.plugins.file;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.transport.TransportConstants.MIN_CONNECTION_LENGTH;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import org.briarproject.api.ContactId;
import org.briarproject.api.plugins.simplex.SimplexPlugin;
import org.briarproject.api.plugins.simplex.SimplexPluginCallback;
import org.briarproject.api.plugins.simplex.SimplexTransportReader;
import org.briarproject.api.plugins.simplex.SimplexTransportWriter;
import org.briarproject.api.system.FileUtils;
public abstract class FilePlugin implements SimplexPlugin {
private static final Logger LOG =
Logger.getLogger(FilePlugin.class.getName());
protected final Executor pluginExecutor;
protected final FileUtils fileUtils;
protected final SimplexPluginCallback callback;
protected final int maxFrameLength;
protected final long maxLatency;
protected volatile boolean running = false;
protected abstract File chooseOutputDirectory();
protected abstract Collection<File> findFilesByName(String filename);
protected abstract void writerFinished(File f);
protected abstract void readerFinished(File f);
protected FilePlugin(Executor pluginExecutor, FileUtils fileUtils,
SimplexPluginCallback callback, int maxFrameLength,
long maxLatency) {
this.pluginExecutor = pluginExecutor;
this.fileUtils = fileUtils;
this.callback = callback;
this.maxFrameLength = maxFrameLength;
this.maxLatency = maxLatency;
}
public int getMaxFrameLength() {
return maxFrameLength;
}
public long getMaxLatency() {
return maxLatency;
}
public SimplexTransportReader createReader(ContactId c) {
return null;
}
public SimplexTransportWriter createWriter(ContactId c) {
if(!running) return null;
return createWriter(createConnectionFilename());
}
private String createConnectionFilename() {
StringBuilder s = new StringBuilder(12);
for(int i = 0; i < 8; i++) s.append((char) ('a' + Math.random() * 26));
s.append(".dat");
return s.toString();
}
// Package access for testing
boolean isPossibleConnectionFilename(String filename) {
return filename.toLowerCase().matches("[a-z]{8}\\.dat");
}
private SimplexTransportWriter createWriter(String filename) {
if(!running) return null;
File dir = chooseOutputDirectory();
if(dir == null || !dir.exists() || !dir.isDirectory()) return null;
File f = new File(dir, filename);
try {
long capacity = fileUtils.getFreeSpace(dir);
if(capacity < MIN_CONNECTION_LENGTH) return null;
OutputStream out = new FileOutputStream(f);
return new FileTransportWriter(f, out, capacity, this);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
f.delete();
return null;
}
}
protected void createReaderFromFile(final File f) {
if(!running) return;
pluginExecutor.execute(new ReaderCreator(f));
}
private class ReaderCreator implements Runnable {
private final File file;
private ReaderCreator(File file) {
this.file = file;
}
public void run() {
if(isPossibleConnectionFilename(file.getName())) {
try {
FileInputStream in = new FileInputStream(file);
callback.readerCreated(new FileTransportReader(file, in,
FilePlugin.this));
} catch(IOException e) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
}
}
}
}
}

View File

@@ -0,0 +1,46 @@
package org.briarproject.plugins.file;
import static java.util.logging.Level.WARNING;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.logging.Logger;
import org.briarproject.api.plugins.simplex.SimplexTransportReader;
class FileTransportReader implements SimplexTransportReader {
private static final Logger LOG =
Logger.getLogger(FileTransportReader.class.getName());
private final File file;
private final InputStream in;
private final FilePlugin plugin;
FileTransportReader(File file, InputStream in, FilePlugin plugin) {
this.file = file;
this.in = in;
this.plugin = plugin;
}
public int getMaxFrameLength() {
return plugin.getMaxFrameLength();
}
public InputStream getInputStream() {
return in;
}
public void dispose(boolean exception, boolean recognised) {
try {
in.close();
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
if(recognised) {
file.delete();
plugin.readerFinished(file);
}
}
}

View File

@@ -0,0 +1,59 @@
package org.briarproject.plugins.file;
import static java.util.logging.Level.WARNING;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.util.logging.Logger;
import org.briarproject.api.plugins.simplex.SimplexTransportWriter;
class FileTransportWriter implements SimplexTransportWriter {
private static final Logger LOG =
Logger.getLogger(FileTransportWriter.class.getName());
private final File file;
private final OutputStream out;
private final long capacity;
private final FilePlugin plugin;
FileTransportWriter(File file, OutputStream out, long capacity,
FilePlugin plugin) {
this.file = file;
this.out = out;
this.capacity = capacity;
this.plugin = plugin;
}
public long getCapacity() {
return capacity;
}
public int getMaxFrameLength() {
return plugin.getMaxFrameLength();
}
public long getMaxLatency() {
return plugin.getMaxLatency();
}
public OutputStream getOutputStream() {
return out;
}
public boolean shouldFlush() {
return false;
}
public void dispose(boolean exception) {
try {
out.close();
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
if(exception) file.delete();
else plugin.writerFinished(file);
}
}

View File

@@ -0,0 +1,345 @@
package org.briarproject.plugins.tcp;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MulticastSocket;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.plugins.duplex.DuplexPluginCallback;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.system.Clock;
import org.briarproject.util.ByteUtils;
import org.briarproject.util.LatchedReference;
import org.briarproject.util.StringUtils;
/** A socket plugin that supports exchanging invitations over a LAN. */
class LanTcpPlugin extends TcpPlugin {
static final byte[] TRANSPORT_ID =
StringUtils.fromHexString("0d79357fd7f74d66c2f6f6ad0f7fff81"
+ "d21c53a43b90b0507ed0683872d8e2fc"
+ "5a88e8f953638228dc26669639757bbf");
static final TransportId ID = new TransportId(TRANSPORT_ID);
private static final Logger LOG =
Logger.getLogger(LanTcpPlugin.class.getName());
private static final int MULTICAST_INTERVAL = 1000; // 1 second
private final Clock clock;
LanTcpPlugin(Executor pluginExecutor, Clock clock,
DuplexPluginCallback callback, int maxFrameLength, long maxLatency,
long pollingInterval) {
super(pluginExecutor, callback, maxFrameLength, maxLatency,
pollingInterval);
this.clock = clock;
}
public TransportId getId() {
return ID;
}
public String getName() {
return "LAN_TCP_PLUGIN_NAME";
}
@Override
protected List<SocketAddress> getLocalSocketAddresses() {
List<SocketAddress> addrs = new ArrayList<SocketAddress>();
// Prefer a previously used address and port if available
TransportProperties p = callback.getLocalProperties();
String addrString = p.get("address");
String portString = p.get("port");
InetAddress addr = null;
if(!StringUtils.isNullOrEmpty(addrString) &&
!StringUtils.isNullOrEmpty(portString)) {
try {
addr = InetAddress.getByName(addrString);
int port = Integer.parseInt(portString);
addrs.add(new InetSocketAddress(addr, port));
addrs.add(new InetSocketAddress(addr, 0));
} catch(NumberFormatException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(UnknownHostException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
List<NetworkInterface> ifaces;
try {
ifaces = Collections.list(NetworkInterface.getNetworkInterfaces());
} catch(SocketException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return addrs;
}
// Prefer interfaces with link-local or site-local addresses
for(NetworkInterface iface : ifaces) {
for(InetAddress a : Collections.list(iface.getInetAddresses())) {
if(addr != null && a.equals(addr)) continue;
if(a instanceof Inet6Address) continue;
if(a.isLoopbackAddress()) continue;
boolean link = a.isLinkLocalAddress();
boolean site = a.isSiteLocalAddress();
if(link || site) addrs.add(new InetSocketAddress(a, 0));
}
}
// Accept interfaces without link-local or site-local addresses
for(NetworkInterface iface : ifaces) {
for(InetAddress a : Collections.list(iface.getInetAddresses())) {
if(addr != null && a.equals(addr)) continue;
if(a instanceof Inet6Address) continue;
if(a.isLoopbackAddress()) continue;
boolean link = a.isLinkLocalAddress();
boolean site = a.isSiteLocalAddress();
if(!link && !site) addrs.add(new InetSocketAddress(a, 0));
}
}
return addrs;
}
public boolean supportsInvitations() {
return true;
}
public DuplexTransportConnection createInvitationConnection(PseudoRandom r,
long timeout) {
if(!running) return null;
// Use the invitation codes to generate the group address and port
InetSocketAddress group = chooseMulticastGroup(r);
// Bind a multicast socket for sending and receiving packets
InetAddress iface = null;
MulticastSocket ms = null;
try {
iface = chooseInvitationInterface();
if(iface == null) return null;
ms = new MulticastSocket(group.getPort());
ms.setInterface(iface);
ms.joinGroup(group.getAddress());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
if(ms != null) tryToClose(ms, group.getAddress());
return null;
}
// Bind a server socket for receiving invitation connections
ServerSocket ss = null;
try {
ss = new ServerSocket();
ss.bind(new InetSocketAddress(iface, 0));
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
if(ss != null) tryToClose(ss);
return null;
}
// Start the listener threads
LatchedReference<Socket> socketLatch = new LatchedReference<Socket>();
new MulticastListenerThread(socketLatch, ms, iface).start();
new TcpListenerThread(socketLatch, ss).start();
// Send packets until a connection is made or we run out of time
byte[] buffer = new byte[2];
ByteUtils.writeUint16(ss.getLocalPort(), buffer, 0);
DatagramPacket packet = new DatagramPacket(buffer, buffer.length);
packet.setAddress(group.getAddress());
packet.setPort(group.getPort());
long now = clock.currentTimeMillis();
long end = now + timeout;
try {
while(now < end && running) {
// Send a packet
if(LOG.isLoggable(INFO)) LOG.info("Sending multicast packet");
ms.send(packet);
// Wait for an incoming or outgoing connection
try {
Socket s = socketLatch.waitForReference(MULTICAST_INTERVAL);
if(s != null) return new TcpTransportConnection(this, s);
} catch(InterruptedException e) {
if(LOG.isLoggable(INFO))
LOG.info("Interrupted while exchanging invitations");
Thread.currentThread().interrupt();
return null;
}
now = clock.currentTimeMillis();
}
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} finally {
// Closing the sockets will terminate the listener threads
tryToClose(ms, group.getAddress());
tryToClose(ss);
}
return null;
}
private InetSocketAddress chooseMulticastGroup(PseudoRandom r) {
byte[] b = r.nextBytes(5);
// The group address is 239.random.random.random, excluding 0 and 255
byte[] group = new byte[4];
group[0] = (byte) 239;
group[1] = legalAddressByte(b[0]);
group[2] = legalAddressByte(b[1]);
group[3] = legalAddressByte(b[2]);
// The port is random in the range 32768 - 65535, inclusive
int port = ByteUtils.readUint16(b, 3);
if(port < 32768) port += 32768;
InetAddress address;
try {
address = InetAddress.getByAddress(group);
} catch(UnknownHostException badAddressLength) {
throw new RuntimeException(badAddressLength);
}
return new InetSocketAddress(address, port);
}
private byte legalAddressByte(byte b) {
if(b == 0) return 1;
if(b == (byte) 255) return (byte) 254;
return b;
}
private InetAddress chooseInvitationInterface() throws IOException {
List<NetworkInterface> ifaces =
Collections.list(NetworkInterface.getNetworkInterfaces());
// Prefer an interface with a link-local or site-local address
for(NetworkInterface iface : ifaces) {
for(InetAddress addr : Collections.list(iface.getInetAddresses())) {
if(addr.isLoopbackAddress()) continue;
boolean link = addr.isLinkLocalAddress();
boolean site = addr.isSiteLocalAddress();
if(link || site) return addr;
}
}
// Accept an interface without a link-local or site-local address
for(NetworkInterface iface : ifaces) {
for(InetAddress addr : Collections.list(iface.getInetAddresses())) {
if(!addr.isLoopbackAddress()) return addr;
}
}
// No suitable interfaces
return null;
}
private void tryToClose(MulticastSocket ms, InetAddress addr) {
try {
ms.leaveGroup(addr);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
ms.close();
}
private class MulticastListenerThread extends Thread {
private final LatchedReference<Socket> socketLatch;
private final MulticastSocket multicastSocket;
private final InetAddress localAddress;
private MulticastListenerThread(LatchedReference<Socket> socketLatch,
MulticastSocket multicastSocket, InetAddress localAddress) {
this.socketLatch = socketLatch;
this.multicastSocket = multicastSocket;
this.localAddress = localAddress;
}
@Override
public void run() {
if(LOG.isLoggable(INFO))
LOG.info("Listening for multicast packets");
// Listen until a valid packet is received or the socket is closed
byte[] buffer = new byte[2];
DatagramPacket packet = new DatagramPacket(buffer, buffer.length);
try {
while(running) {
multicastSocket.receive(packet);
if(LOG.isLoggable(INFO))
LOG.info("Received multicast packet");
parseAndConnectBack(packet);
}
} catch(IOException e) {
// This is expected when the socket is closed
if(LOG.isLoggable(INFO)) LOG.log(INFO, e.toString(), e);
}
}
private void parseAndConnectBack(DatagramPacket packet) {
InetAddress addr = packet.getAddress();
if(addr.equals(localAddress)) {
if(LOG.isLoggable(INFO)) LOG.info("Ignoring own packet");
return;
}
byte[] b = packet.getData();
int off = packet.getOffset();
int len = packet.getLength();
if(len != 2) {
if(LOG.isLoggable(INFO)) LOG.info("Invalid length: " + len);
return;
}
int port = ByteUtils.readUint16(b, off);
if(port < 32768 || port >= 65536) {
if(LOG.isLoggable(INFO)) LOG.info("Invalid port: " + port);
return;
}
if(LOG.isLoggable(INFO))
LOG.info("Packet from " + getHostAddress(addr) + ":" + port);
try {
// Connect back on the advertised TCP port
Socket s = new Socket(addr, port);
if(LOG.isLoggable(INFO)) LOG.info("Outgoing connection");
if(!socketLatch.set(s)) {
if(LOG.isLoggable(INFO))
LOG.info("Closing redundant connection");
s.close();
}
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
private class TcpListenerThread extends Thread {
private final LatchedReference<Socket> socketLatch;
private final ServerSocket serverSocket;
private TcpListenerThread(LatchedReference<Socket> socketLatch,
ServerSocket serverSocket) {
this.socketLatch = socketLatch;
this.serverSocket = serverSocket;
}
@Override
public void run() {
if(LOG.isLoggable(INFO))
LOG.info("Listening for invitation connections");
// Listen until a connection is received or the socket is closed
try {
Socket s = serverSocket.accept();
if(LOG.isLoggable(INFO)) LOG.info("Incoming connection");
if(!socketLatch.set(s)) {
if(LOG.isLoggable(INFO))
LOG.info("Closing redundant connection");
s.close();
}
} catch(IOException e) {
// This is expected when the socket is closed
if(LOG.isLoggable(INFO)) LOG.log(INFO, e.toString(), e);
}
}
}
}

View File

@@ -0,0 +1,34 @@
package org.briarproject.plugins.tcp;
import java.util.concurrent.Executor;
import org.briarproject.api.TransportId;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexPluginCallback;
import org.briarproject.api.plugins.duplex.DuplexPluginFactory;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.SystemClock;
public class LanTcpPluginFactory implements DuplexPluginFactory {
private static final int MAX_FRAME_LENGTH = 1024;
private static final long MAX_LATENCY = 60 * 1000; // 1 minute
private static final long POLLING_INTERVAL = 60 * 1000; // 1 minute
private final Executor pluginExecutor;
private final Clock clock;
public LanTcpPluginFactory(Executor pluginExecutor) {
this.pluginExecutor = pluginExecutor;
clock = new SystemClock();
}
public TransportId getId() {
return LanTcpPlugin.ID;
}
public DuplexPlugin createPlugin(DuplexPluginCallback callback) {
return new LanTcpPlugin(pluginExecutor, clock, callback,
MAX_FRAME_LENGTH, MAX_LATENCY, POLLING_INTERVAL);
}
}

View File

@@ -0,0 +1,31 @@
package org.briarproject.plugins.tcp;
import java.net.InetAddress;
import java.net.InetSocketAddress;
class MappingResult {
private final InetAddress internal, external;
private final int port;
private final boolean succeeded;
MappingResult(InetAddress internal, InetAddress external, int port,
boolean succeeded) {
this.internal = internal;
this.external = external;
this.port = port;
this.succeeded = succeeded;
}
InetSocketAddress getInternal() {
return isUsable() ? new InetSocketAddress(internal, port) : null;
}
InetSocketAddress getExternal() {
return isUsable() ? new InetSocketAddress(external, port) : null;
}
boolean isUsable() {
return internal != null && external != null && port != 0 && succeeded;
}
}

View File

@@ -0,0 +1,6 @@
package org.briarproject.plugins.tcp;
interface PortMapper {
MappingResult map(int port);
}

View File

@@ -0,0 +1,97 @@
package org.briarproject.plugins.tcp;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.net.InetAddress;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;
import javax.xml.parsers.ParserConfigurationException;
import org.briarproject.api.lifecycle.ShutdownManager;
import org.bitlet.weupnp.GatewayDevice;
import org.bitlet.weupnp.GatewayDiscover;
import org.xml.sax.SAXException;
class PortMapperImpl implements PortMapper {
private static final Logger LOG =
Logger.getLogger(PortMapperImpl.class.getName());
private final ShutdownManager shutdownManager;
private final AtomicBoolean started = new AtomicBoolean(false);
private volatile GatewayDevice gateway = null;
PortMapperImpl(ShutdownManager shutdownManager) {
this.shutdownManager = shutdownManager;
}
public MappingResult map(final int port) {
if(!started.getAndSet(true)) start();
if(gateway == null) return null;
InetAddress internal = gateway.getLocalAddress();
if(internal == null) return null;
if(LOG.isLoggable(INFO))
LOG.info("Internal address " + getHostAddress(internal));
boolean succeeded = false;
InetAddress external = null;
try {
succeeded = gateway.addPortMapping(port, port,
getHostAddress(internal), "TCP", "TCP");
if(succeeded) {
shutdownManager.addShutdownHook(new Runnable() {
public void run() {
deleteMapping(port);
}
});
}
String externalString = gateway.getExternalIPAddress();
if(LOG.isLoggable(INFO))
LOG.info("External address " + externalString);
if(externalString != null)
external = InetAddress.getByName(externalString);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(SAXException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
return new MappingResult(internal, external, port, succeeded);
}
private String getHostAddress(InetAddress a) {
String addr = a.getHostAddress();
int percent = addr.indexOf('%');
if(percent == -1) return addr;
return addr.substring(0, percent);
}
private void start() {
GatewayDiscover d = new GatewayDiscover();
try {
d.discover();
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(SAXException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(ParserConfigurationException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
gateway = d.getValidGateway();
}
private void deleteMapping(int port) {
try {
gateway.deletePortMapping(port, "TCP");
if(LOG.isLoggable(INFO))
LOG.info("Deleted mapping for port " + port);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(SAXException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}

View File

@@ -0,0 +1,210 @@
package org.briarproject.plugins.tcp;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexPluginCallback;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.util.StringUtils;
abstract class TcpPlugin implements DuplexPlugin {
private static final Logger LOG =
Logger.getLogger(TcpPlugin.class.getName());
protected final Executor pluginExecutor;
protected final DuplexPluginCallback callback;
protected final int maxFrameLength;
protected final long maxLatency, pollingInterval;
protected volatile boolean running = false;
private volatile ServerSocket socket = null;
/**
* Returns zero or more socket addresses on which the plugin should listen,
* in order of preference. At most one of the addresses will be bound.
*/
protected abstract List<SocketAddress> getLocalSocketAddresses();
protected TcpPlugin(Executor pluginExecutor, DuplexPluginCallback callback,
int maxFrameLength, long maxLatency, long pollingInterval) {
this.pluginExecutor = pluginExecutor;
this.callback = callback;
this.maxFrameLength = maxFrameLength;
this.maxLatency = maxLatency;
this.pollingInterval = pollingInterval;
}
public int getMaxFrameLength() {
return maxFrameLength;
}
public long getMaxLatency() {
return maxLatency;
}
public boolean start() {
running = true;
pluginExecutor.execute(new Runnable() {
public void run() {
bind();
}
});
return true;
}
private void bind() {
ServerSocket ss = null;
boolean found = false;
for(SocketAddress addr : getLocalSocketAddresses()) {
try {
ss = new ServerSocket();
ss.bind(addr);
found = true;
break;
} catch(IOException e) {
if(LOG.isLoggable(INFO)) LOG.info("Failed to bind " + addr);
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
tryToClose(ss);
continue;
}
}
if(!found) {
if(LOG.isLoggable(INFO)) LOG.info("Could not bind server socket");
return;
}
if(!running) {
tryToClose(ss);
return;
}
socket = ss;
if(LOG.isLoggable(INFO))
LOG.info("Listening on " + ss.getLocalSocketAddress());
setLocalSocketAddress((InetSocketAddress) ss.getLocalSocketAddress());
acceptContactConnections(ss);
}
protected void tryToClose(ServerSocket ss) {
try {
ss.close();
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
protected String getHostAddress(InetAddress a) {
String addr = a.getHostAddress();
int percent = addr.indexOf('%');
return percent == -1 ? addr : addr.substring(0, percent);
}
protected void setLocalSocketAddress(InetSocketAddress a) {
TransportProperties p = new TransportProperties();
p.put("address", getHostAddress(a.getAddress()));
p.put("port", String.valueOf(a.getPort()));
callback.mergeLocalProperties(p);
}
private void acceptContactConnections(ServerSocket ss) {
while(true) {
Socket s;
try {
s = ss.accept();
} catch(IOException e) {
// This is expected when the socket is closed
if(LOG.isLoggable(INFO)) LOG.log(INFO, e.toString(), e);
tryToClose(ss);
return;
}
if(LOG.isLoggable(INFO))
LOG.info("Connection from " + s.getRemoteSocketAddress());
TcpTransportConnection conn = new TcpTransportConnection(this, s);
callback.incomingConnectionCreated(conn);
if(!running) return;
}
}
public void stop() {
running = false;
if(socket != null) tryToClose(socket);
}
public boolean shouldPoll() {
return true;
}
public long getPollingInterval() {
return pollingInterval;
}
public void poll(Collection<ContactId> connected) {
if(!running) return;
Map<ContactId, TransportProperties> remote =
callback.getRemoteProperties();
for(final ContactId c : remote.keySet()) {
if(connected.contains(c)) continue;
pluginExecutor.execute(new Runnable() {
public void run() {
connectAndCallBack(c);
}
});
}
}
private void connectAndCallBack(ContactId c) {
DuplexTransportConnection d = createConnection(c);
if(d != null) callback.outgoingConnectionCreated(c, d);
}
public DuplexTransportConnection createConnection(ContactId c) {
if(!running) return null;
SocketAddress addr = getRemoteSocketAddress(c);
if(addr == null) return null;
Socket s = new Socket();
try {
if(LOG.isLoggable(INFO)) LOG.info("Connecting to " + addr);
s.connect(addr);
if(LOG.isLoggable(INFO)) LOG.info("Connected to " + addr);
return new TcpTransportConnection(this, s);
} catch(IOException e) {
if(LOG.isLoggable(INFO)) LOG.log(INFO, e.toString(), e);
return null;
}
}
private SocketAddress getRemoteSocketAddress(ContactId c) {
TransportProperties p = callback.getRemoteProperties().get(c);
if(p == null) return null;
String addrString = p.get("address");
if(StringUtils.isNullOrEmpty(addrString)) return null;
String portString = p.get("port");
if(StringUtils.isNullOrEmpty(portString)) return null;
try {
InetAddress addr = InetAddress.getByName(addrString);
int port = Integer.parseInt(portString);
return new InetSocketAddress(addr, port);
} catch(NumberFormatException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return null;
} catch(UnknownHostException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return null;
}
}
}

View File

@@ -0,0 +1,45 @@
package org.briarproject.plugins.tcp;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import org.briarproject.api.plugins.Plugin;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
class TcpTransportConnection implements DuplexTransportConnection {
private final Plugin plugin;
private final Socket socket;
TcpTransportConnection(Plugin plugin, Socket socket) {
this.plugin = plugin;
this.socket = socket;
}
public int getMaxFrameLength() {
return plugin.getMaxFrameLength();
}
public long getMaxLatency() {
return plugin.getMaxLatency();
}
public InputStream getInputStream() throws IOException {
return socket.getInputStream();
}
public OutputStream getOutputStream() throws IOException {
return socket.getOutputStream();
}
public boolean shouldFlush() {
return true;
}
public void dispose(boolean exception, boolean recognised)
throws IOException {
socket.close();
}
}

View File

@@ -0,0 +1,132 @@
package org.briarproject.plugins.tcp;
import static java.util.logging.Level.WARNING;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import org.briarproject.api.TransportId;
import org.briarproject.api.TransportProperties;
import org.briarproject.api.crypto.PseudoRandom;
import org.briarproject.api.plugins.duplex.DuplexPluginCallback;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.util.StringUtils;
class WanTcpPlugin extends TcpPlugin {
static final byte[] TRANSPORT_ID =
StringUtils.fromHexString("58c66d999e492b85065924acfd739d80"
+ "c65a62f87e5a4fc6c284f95908b9007d"
+ "512a93ebf89bf68f50a29e96eebf97b6");
static final TransportId ID = new TransportId(TRANSPORT_ID);
private static final Logger LOG =
Logger.getLogger(WanTcpPlugin.class.getName());
private final PortMapper portMapper;
private volatile MappingResult mappingResult;
WanTcpPlugin(Executor pluginExecutor, DuplexPluginCallback callback,
int maxFrameLength, long maxLatency, long pollingInterval,
PortMapper portMapper) {
super(pluginExecutor, callback, maxFrameLength, maxLatency,
pollingInterval);
this.portMapper = portMapper;
}
public TransportId getId() {
return ID;
}
public String getName() {
return "WAN_TCP_PLUGIN_NAME";
}
@Override
protected List<SocketAddress> getLocalSocketAddresses() {
List<SocketAddress> addrs = new ArrayList<SocketAddress>();
// Prefer a previously used address and port if available
TransportProperties p = callback.getLocalProperties();
String addrString = p.get("address");
String portString = p.get("port");
InetAddress addr = null;
int port = 0;
if(!StringUtils.isNullOrEmpty(addrString) &&
!StringUtils.isNullOrEmpty(portString)) {
try {
addr = InetAddress.getByName(addrString);
port = Integer.parseInt(portString);
addrs.add(new InetSocketAddress(addr, port));
addrs.add(new InetSocketAddress(addr, 0));
} catch(NumberFormatException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} catch(UnknownHostException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
// Get a list of the device's network interfaces
List<NetworkInterface> ifaces;
try {
ifaces = Collections.list(NetworkInterface.getNetworkInterfaces());
} catch(SocketException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return addrs;
}
// Accept interfaces without link-local or site-local addresses
for(NetworkInterface iface : ifaces) {
for(InetAddress a : Collections.list(iface.getInetAddresses())) {
if(addr != null && a.equals(addr)) continue;
if(a instanceof Inet6Address) continue;
if(a.isLoopbackAddress()) continue;
boolean link = a.isLinkLocalAddress();
boolean site = a.isSiteLocalAddress();
if(!link && !site) addrs.add(new InetSocketAddress(a, 0));
}
}
// Accept interfaces with local addresses that can be port-mapped
if(port == 0) port = chooseEphemeralPort();
mappingResult = portMapper.map(port);
if(mappingResult != null && mappingResult.isUsable()) {
InetSocketAddress a = mappingResult.getInternal();
if(!(a.getAddress() instanceof Inet6Address)) addrs.add(a);
}
return addrs;
}
private int chooseEphemeralPort() {
return 32768 + (int) (Math.random() * 32768);
}
@Override
protected void setLocalSocketAddress(InetSocketAddress a) {
if(mappingResult != null && mappingResult.isUsable()) {
// Advertise the external address to contacts
if(a.equals(mappingResult.getInternal()))
a = mappingResult.getExternal();
}
TransportProperties p = new TransportProperties();
p.put("address", getHostAddress(a.getAddress()));
p.put("port", String.valueOf(a.getPort()));
callback.mergeLocalProperties(p);
}
public boolean supportsInvitations() {
return false;
}
public DuplexTransportConnection createInvitationConnection(PseudoRandom r,
long timeout) {
throw new UnsupportedOperationException();
}
}

View File

@@ -0,0 +1,35 @@
package org.briarproject.plugins.tcp;
import java.util.concurrent.Executor;
import org.briarproject.api.TransportId;
import org.briarproject.api.lifecycle.ShutdownManager;
import org.briarproject.api.plugins.duplex.DuplexPlugin;
import org.briarproject.api.plugins.duplex.DuplexPluginCallback;
import org.briarproject.api.plugins.duplex.DuplexPluginFactory;
public class WanTcpPluginFactory implements DuplexPluginFactory {
private static final int MAX_FRAME_LENGTH = 1024;
private static final long MAX_LATENCY = 60 * 1000; // 1 minute
private static final long POLLING_INTERVAL = 5 * 60 * 1000; // 5 minutes
private final Executor pluginExecutor;
private final ShutdownManager shutdownManager;
public WanTcpPluginFactory(Executor pluginExecutor,
ShutdownManager shutdownManager) {
this.pluginExecutor = pluginExecutor;
this.shutdownManager = shutdownManager;
}
public TransportId getId() {
return WanTcpPlugin.ID;
}
public DuplexPlugin createPlugin(DuplexPluginCallback callback) {
return new WanTcpPlugin(pluginExecutor, callback, MAX_FRAME_LENGTH,
MAX_LATENCY, POLLING_INTERVAL,
new PortMapperImpl(shutdownManager));
}
}

View File

@@ -0,0 +1,27 @@
package org.briarproject.reliability;
import org.briarproject.util.ByteUtils;
class Ack extends Frame {
static final int LENGTH = 11;
Ack() {
super(new byte[LENGTH]);
buf[0] = (byte) Frame.ACK_FLAG;
}
Ack(byte[] buf) {
super(buf);
if(buf.length != LENGTH) throw new IllegalArgumentException();
buf[0] = (byte) Frame.ACK_FLAG;
}
int getWindowSize() {
return ByteUtils.readUint16(buf, 5);
}
void setWindowSize(int windowSize) {
ByteUtils.writeUint16(windowSize, buf, 5);
}
}

View File

@@ -0,0 +1,27 @@
package org.briarproject.reliability;
class Crc32 {
private static final long[] TABLE = new long[256];
static {
for(int i = 0; i < 256; i++) {
long c = i;
for(int j = 0; j < 8; j++) {
if((c & 1) != 0) c = 0xedb88320L ^ (c >> 1);
else c >>= 1;
}
TABLE[i] = c;
}
}
private static long update(long c, byte[] b, int off, int len) {
for(int i = off; i < off + len; i++)
c = TABLE[(int) ((c ^ b[i]) & 0xff)] ^ (c >> 8);
return c;
}
static long crc(byte[] b, int off, int len) {
return update(0xffffffffL, b, off, len) ^ 0xffffffffL;
}
}

View File

@@ -0,0 +1,27 @@
package org.briarproject.reliability;
class Data extends Frame {
static final int HEADER_LENGTH = 5, FOOTER_LENGTH = 4;
static final int MIN_LENGTH = HEADER_LENGTH + FOOTER_LENGTH;
static final int MAX_PAYLOAD_LENGTH = 1024;
static final int MAX_LENGTH = MIN_LENGTH + MAX_PAYLOAD_LENGTH;
Data(byte[] buf) {
super(buf);
if(buf.length < MIN_LENGTH || buf.length > MAX_LENGTH)
throw new IllegalArgumentException();
}
boolean isLastFrame() {
return buf[0] == Frame.FIN_FLAG;
}
void setLastFrame(boolean lastFrame) {
if(lastFrame) buf[0] = (byte) Frame.FIN_FLAG;
}
int getPayloadLength() {
return buf.length - MIN_LENGTH;
}
}

View File

@@ -0,0 +1,58 @@
package org.briarproject.reliability;
import org.briarproject.util.ByteUtils;
abstract class Frame {
static final byte ACK_FLAG = (byte) 128, FIN_FLAG = 64;
protected final byte[] buf;
protected Frame(byte[] buf) {
this.buf = buf;
}
byte[] getBuffer() {
return buf;
}
int getLength() {
return buf.length;
}
long getChecksum() {
return ByteUtils.readUint32(buf, buf.length - 4);
}
void setChecksum(long checksum) {
ByteUtils.writeUint32(checksum, buf, buf.length - 4);
}
long calculateChecksum() {
return Crc32.crc(buf, 0, buf.length - 4);
}
long getSequenceNumber() {
return ByteUtils.readUint32(buf, 1);
}
void setSequenceNumber(long sequenceNumber) {
ByteUtils.writeUint32(sequenceNumber, buf, 1);
}
@Override
public int hashCode() {
long sequenceNumber = getSequenceNumber();
return buf[0] ^ (int) (sequenceNumber ^ (sequenceNumber >>> 32));
}
@Override
public boolean equals(Object o) {
if(o instanceof Frame) {
Frame f = (Frame) o;
return buf[0] == f.buf[0] &&
getSequenceNumber() == f.getSequenceNumber();
}
return false;
}
}

View File

@@ -0,0 +1,130 @@
package org.briarproject.reliability;
import java.io.IOException;
import java.util.Comparator;
import java.util.Iterator;
import java.util.SortedSet;
import java.util.TreeSet;
import org.briarproject.api.reliability.ReadHandler;
import org.briarproject.api.system.Clock;
class Receiver implements ReadHandler {
private static final int READ_TIMEOUT = 5 * 60 * 1000; // Milliseconds
private static final int MAX_WINDOW_SIZE = 8 * Data.MAX_PAYLOAD_LENGTH;
private final Clock clock;
private final Sender sender;
private final SortedSet<Data> dataFrames; // Locking: this
private int windowSize = MAX_WINDOW_SIZE; // Locking: this
private long finalSequenceNumber = Long.MAX_VALUE;
private long nextSequenceNumber = 1;
private volatile boolean valid = true;
Receiver(Clock clock, Sender sender) {
this.sender = sender;
this.clock = clock;
dataFrames = new TreeSet<Data>(new SequenceNumberComparator());
}
synchronized Data read() throws IOException, InterruptedException {
long now = clock.currentTimeMillis(), end = now + READ_TIMEOUT;
while(now < end && valid) {
if(dataFrames.isEmpty()) {
// Wait for a data frame
wait(end - now);
} else {
Data d = dataFrames.first();
if(d.getSequenceNumber() == nextSequenceNumber) {
dataFrames.remove(d);
// Update the window
windowSize += d.getPayloadLength();
sender.sendAck(0, windowSize);
nextSequenceNumber++;
return d;
} else {
// Wait for the next in-order data frame
wait(end - now);
}
}
now = clock.currentTimeMillis();
}
if(valid) throw new IOException("Read timed out");
throw new IOException("Connection closed");
}
void invalidate() {
valid = false;
synchronized(this) {
notifyAll();
}
}
public void handleRead(byte[] b) throws IOException {
if(!valid) throw new IOException("Connection closed");
switch(b[0]) {
case 0:
case Frame.FIN_FLAG:
handleData(b);
break;
case Frame.ACK_FLAG:
sender.handleAck(b);
break;
default:
// Ignore unknown frame type
return;
}
}
private synchronized void handleData(byte[] b) throws IOException {
if(b.length < Data.MIN_LENGTH || b.length > Data.MAX_LENGTH) {
// Ignore data frame with invalid length
return;
}
Data d = new Data(b);
int payloadLength = d.getPayloadLength();
if(payloadLength > windowSize) return; // No space in the window
if(d.getChecksum() != d.calculateChecksum()) {
// Ignore data frame with invalid checksum
return;
}
long sequenceNumber = d.getSequenceNumber();
if(sequenceNumber == 0) {
// Window probe
} else if(sequenceNumber < nextSequenceNumber) {
// Duplicate data frame
} else if(d.isLastFrame()) {
finalSequenceNumber = sequenceNumber;
// Remove any data frames with higher sequence numbers
Iterator<Data> it = dataFrames.iterator();
while(it.hasNext()) {
Data d1 = it.next();
if(d1.getSequenceNumber() >= finalSequenceNumber) it.remove();
}
if(dataFrames.add(d)) {
windowSize -= payloadLength;
notifyAll();
}
} else if(sequenceNumber < finalSequenceNumber) {
if(dataFrames.add(d)) {
windowSize -= payloadLength;
notifyAll();
}
}
// Acknowledge the data frame even if it's a duplicate
sender.sendAck(sequenceNumber, windowSize);
}
private static class SequenceNumberComparator implements Comparator<Data> {
public int compare(Data d1, Data d2) {
long s1 = d1.getSequenceNumber(), s2 = d2.getSequenceNumber();
if(s1 < s2) return -1;
if(s1 > s2) return 1;
return 0;
}
}
}

View File

@@ -0,0 +1,59 @@
package org.briarproject.reliability;
import java.io.IOException;
import java.io.InputStream;
class ReceiverInputStream extends InputStream {
private final Receiver receiver;
private Data data = null;
private int offset = 0, length = 0;
ReceiverInputStream(Receiver receiver) {
this.receiver = receiver;
}
@Override
public int read() throws IOException {
if(length == -1) return -1;
while(length == 0) if(!receive()) return -1;
int b = data.getBuffer()[offset] & 0xff;
offset++;
length--;
return b;
}
@Override
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
if(length == -1) return -1;
while(length == 0) if(!receive()) return -1;
len = Math.min(len, length);
System.arraycopy(data.getBuffer(), offset, b, off, len);
offset += len;
length -= len;
return len;
}
private boolean receive() throws IOException {
assert length == 0;
if(data != null && data.isLastFrame()) {
length = -1;
return false;
}
try {
data = receiver.read();
} catch(InterruptedException e) {
Thread.currentThread().interrupt();
throw new IOException("Interrupted while reading");
}
offset = Data.HEADER_LENGTH;
length = data.getLength() - Data.MIN_LENGTH;
return true;
}
}

View File

@@ -0,0 +1,28 @@
package org.briarproject.reliability;
import java.util.concurrent.Executor;
import javax.inject.Inject;
import org.briarproject.api.reliability.ReliabilityExecutor;
import org.briarproject.api.reliability.ReliabilityLayer;
import org.briarproject.api.reliability.ReliabilityLayerFactory;
import org.briarproject.api.reliability.WriteHandler;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.SystemClock;
class ReliabilityLayerFactoryImpl implements ReliabilityLayerFactory {
private final Executor executor;
private final Clock clock;
@Inject
ReliabilityLayerFactoryImpl(@ReliabilityExecutor Executor executor) {
this.executor = executor;
clock = new SystemClock();
}
public ReliabilityLayer createReliabilityLayer(WriteHandler writeHandler) {
return new ReliabilityLayerImpl(executor, clock, writeHandler);
}
}

View File

@@ -0,0 +1,109 @@
package org.briarproject.reliability;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.logging.Level.WARNING;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.logging.Logger;
import org.briarproject.api.reliability.ReliabilityLayer;
import org.briarproject.api.reliability.WriteHandler;
import org.briarproject.api.system.Clock;
class ReliabilityLayerImpl implements ReliabilityLayer, WriteHandler {
private static final int TICK_INTERVAL = 500; // Milliseconds
private static final Logger LOG =
Logger.getLogger(ReliabilityLayerImpl.class.getName());
private final Executor executor;
private final Clock clock;
private final WriteHandler writeHandler;
private final BlockingQueue<byte[]> writes;
private volatile Receiver receiver = null;
private volatile SlipDecoder decoder = null;
private volatile ReceiverInputStream inputStream = null;
private volatile SenderOutputStream outputStream = null;
private volatile boolean running = false;
ReliabilityLayerImpl(Executor executor, Clock clock,
WriteHandler writeHandler) {
this.executor = executor;
this.clock = clock;
this.writeHandler = writeHandler;
writes = new LinkedBlockingQueue<byte[]>();
}
public void start() {
SlipEncoder encoder = new SlipEncoder(this);
final Sender sender = new Sender(clock, encoder);
receiver = new Receiver(clock, sender);
decoder = new SlipDecoder(receiver, Data.MAX_LENGTH);
inputStream = new ReceiverInputStream(receiver);
outputStream = new SenderOutputStream(sender);
running = true;
executor.execute(new Runnable() {
public void run() {
long now = clock.currentTimeMillis();
long next = now + TICK_INTERVAL;
try {
while(running) {
byte[] b = null;
while(now < next && b == null) {
b = writes.poll(next - now, MILLISECONDS);
if(!running) return;
now = clock.currentTimeMillis();
}
if(b == null) {
sender.tick();
while(next <= now) next += TICK_INTERVAL;
} else {
if(b.length == 0) return; // Poison pill
writeHandler.handleWrite(b);
}
}
} catch(InterruptedException e) {
if(LOG.isLoggable(WARNING))
LOG.warning("Interrupted while waiting to write");
Thread.currentThread().interrupt();
running = false;
} catch(IOException e) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
running = false;
}
}
});
}
public void stop() {
running = false;
receiver.invalidate();
writes.add(new byte[0]); // Poison pill
}
public InputStream getInputStream() {
return inputStream;
}
public OutputStream getOutputStream() {
return outputStream;
}
// The lower layer calls this method to pass data up to the SLIP decoder
public void handleRead(byte[] b) throws IOException {
if(running) decoder.handleRead(b);
}
// The SLIP encoder calls this method to pass data down to the lower layer
public void handleWrite(byte[] b) {
if(running && b.length > 0) writes.add(b);
}
}

View File

@@ -0,0 +1,46 @@
package org.briarproject.reliability;
import static java.util.concurrent.TimeUnit.SECONDS;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import javax.inject.Singleton;
import org.briarproject.api.lifecycle.LifecycleManager;
import org.briarproject.api.reliability.ReliabilityExecutor;
import org.briarproject.api.reliability.ReliabilityLayerFactory;
import com.google.inject.AbstractModule;
import com.google.inject.Provides;
public class ReliabilityModule extends AbstractModule {
private final ExecutorService reliabilityExecutor;
public ReliabilityModule() {
// The thread pool is unbounded, so use direct handoff
BlockingQueue<Runnable> queue = new SynchronousQueue<Runnable>();
// Discard tasks that are submitted during shutdown
RejectedExecutionHandler policy =
new ThreadPoolExecutor.DiscardPolicy();
// Create threads as required and keep them in the pool for 60 seconds
reliabilityExecutor = new ThreadPoolExecutor(0, Integer.MAX_VALUE,
60, SECONDS, queue, policy);
}
protected void configure() {
bind(ReliabilityLayerFactory.class).to(
ReliabilityLayerFactoryImpl.class);
}
@Provides @Singleton @ReliabilityExecutor
Executor getReliabilityExecutor(LifecycleManager lifecycleManager) {
lifecycleManager.registerForShutdown(reliabilityExecutor);
return reliabilityExecutor;
}
}

View File

@@ -0,0 +1,188 @@
package org.briarproject.reliability;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.briarproject.api.reliability.WriteHandler;
import org.briarproject.api.system.Clock;
class Sender {
// All times are in milliseconds
private static final int WRITE_TIMEOUT = 5 * 60 * 1000;
private static final int MIN_RTO = 1000;
private static final int MAX_RTO = 60 * 1000;
private static final int INITIAL_RTT = 0;
private static final int INITIAL_RTT_VAR = 3 * 1000;
private static final int MAX_WINDOW_SIZE = 64 * Data.MAX_PAYLOAD_LENGTH;
private final Clock clock;
private final WriteHandler writeHandler;
private final LinkedList<Outstanding> outstanding; // Locking: this
// All of the following are locking: this
private int outstandingBytes = 0;
private int windowSize = Data.MAX_PAYLOAD_LENGTH;
private int rtt = INITIAL_RTT, rttVar = INITIAL_RTT_VAR;
private int rto = rtt + (rttVar << 2);
private long lastWindowUpdateOrProbe = Long.MAX_VALUE;
private boolean dataWaiting = false;
Sender(Clock clock, WriteHandler writeHandler) {
this.clock = clock;
this.writeHandler = writeHandler;
outstanding = new LinkedList<Outstanding>();
}
void sendAck(long sequenceNumber, int windowSize) throws IOException {
Ack a = new Ack();
a.setSequenceNumber(sequenceNumber);
a.setWindowSize(windowSize);
a.setChecksum(a.calculateChecksum());
writeHandler.handleWrite(a.getBuffer());
}
void handleAck(byte[] b) throws IOException {
if(b.length != Ack.LENGTH) {
// Ignore ack frame with invalid length
return;
}
Ack a = new Ack(b);
if(a.getChecksum() != a.calculateChecksum()) {
// Ignore ack frame with invalid checksum
return;
}
long sequenceNumber = a.getSequenceNumber();
long now = clock.currentTimeMillis();
Outstanding fastRetransmit = null;
synchronized(this) {
// Remove the acked data frame if it's outstanding
int foundIndex = -1;
Iterator<Outstanding> it = outstanding.iterator();
for(int i = 0; it.hasNext(); i++) {
Outstanding o = it.next();
if(o.data.getSequenceNumber() == sequenceNumber) {
it.remove();
outstandingBytes -= o.data.getPayloadLength();
foundIndex = i;
// Update the round-trip time and retransmission timeout
if(!o.retransmitted) {
int sample = (int) (now - o.lastTransmitted);
int error = sample - rtt;
rtt += (error >> 3);
rttVar += (Math.abs(error) - rttVar) >> 2;
rto = rtt + (rttVar << 2);
if(rto < MIN_RTO) rto = MIN_RTO;
else if(rto > MAX_RTO) rto = MAX_RTO;
}
break;
}
}
// If any older data frames are outstanding, retransmit the oldest
if(foundIndex > 0) {
fastRetransmit = outstanding.poll();
fastRetransmit.lastTransmitted = now;
fastRetransmit.retransmitted = true;
outstanding.add(fastRetransmit);
}
// Update the window
lastWindowUpdateOrProbe = now;
int oldWindowSize = windowSize;
// Don't accept an unreasonably large window size
windowSize = Math.min(a.getWindowSize(), MAX_WINDOW_SIZE);
// If space has become available, notify any waiting writers
if(windowSize > oldWindowSize || foundIndex != -1) notifyAll();
}
// Fast retransmission
if(fastRetransmit != null)
writeHandler.handleWrite(fastRetransmit.data.getBuffer());
}
void tick() throws IOException {
long now = clock.currentTimeMillis();
List<Outstanding> retransmit = null;
boolean sendProbe = false;
synchronized(this) {
if(outstanding.isEmpty()) {
if(dataWaiting && now - lastWindowUpdateOrProbe > rto) {
sendProbe = true;
rto <<= 1;
if(rto > MAX_RTO) rto = MAX_RTO;
}
} else {
Iterator<Outstanding> it = outstanding.iterator();
while(it.hasNext()) {
Outstanding o = it.next();
if(now - o.lastTransmitted > rto) {
it.remove();
if(retransmit == null)
retransmit = new ArrayList<Outstanding>();
retransmit.add(o);
// Update the retransmission timeout
rto <<= 1;
if(rto > MAX_RTO) rto = MAX_RTO;
}
}
if(retransmit != null) {
for(Outstanding o : retransmit) {
o.lastTransmitted = now;
o.retransmitted = true;
outstanding.add(o);
}
}
}
}
// Send a window probe if necessary
if(sendProbe) {
byte[] buf = new byte[Data.MIN_LENGTH];
Data probe = new Data(buf);
probe.setChecksum(probe.calculateChecksum());
writeHandler.handleWrite(buf);
}
// Retransmit any lost data frames
if(retransmit != null) {
for(Outstanding o : retransmit)
writeHandler.handleWrite(o.data.getBuffer());
}
}
void write(Data d) throws IOException, InterruptedException {
int payloadLength = d.getPayloadLength();
synchronized(this) {
// Wait for space in the window
long now = clock.currentTimeMillis(), end = now + WRITE_TIMEOUT;
while(now < end && outstandingBytes + payloadLength >= windowSize) {
dataWaiting = true;
wait(end - now);
now = clock.currentTimeMillis();
}
if(outstandingBytes + payloadLength >= windowSize)
throw new IOException("Write timed out");
outstanding.add(new Outstanding(d, now));
outstandingBytes += payloadLength;
dataWaiting = false;
}
writeHandler.handleWrite(d.getBuffer());
}
synchronized void flush() throws IOException, InterruptedException {
while(dataWaiting || !outstanding.isEmpty()) wait();
}
private static class Outstanding {
private final Data data;
private volatile long lastTransmitted;
private volatile boolean retransmitted;
private Outstanding(Data data, long lastTransmitted) {
this.data = data;
this.lastTransmitted = lastTransmitted;
retransmitted = false;
}
}
}

View File

@@ -0,0 +1,82 @@
package org.briarproject.reliability;
import java.io.IOException;
import java.io.OutputStream;
class SenderOutputStream extends OutputStream {
private final Sender sender;
private final byte[] buf = new byte[Data.MAX_LENGTH];
private int offset = Data.HEADER_LENGTH;
private long sequenceNumber = 1;
SenderOutputStream(Sender sender) {
this.sender = sender;
}
@Override
public void close() throws IOException {
send(true);
try {
sender.flush();
} catch(InterruptedException e) {
Thread.currentThread().interrupt();
throw new IOException("Interrupted while closing");
}
}
@Override
public void flush() throws IOException {
if(offset > Data.HEADER_LENGTH) send(false);
try {
sender.flush();
} catch(InterruptedException e) {
Thread.currentThread().interrupt();
throw new IOException("Interrupted while flushing");
}
}
@Override
public void write(int b) throws IOException {
buf[offset] = (byte) b;
offset++;
if(offset == Data.HEADER_LENGTH + Data.MAX_PAYLOAD_LENGTH) send(false);
}
@Override
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
int available = Data.MAX_LENGTH - offset - Data.FOOTER_LENGTH;
while(available <= len) {
System.arraycopy(b, off, buf, offset, available);
offset += available;
send(false);
off += available;
len -= available;
available = Data.MAX_LENGTH - offset - Data.FOOTER_LENGTH;
}
System.arraycopy(b, off, buf, offset, len);
offset += len;
}
private void send(boolean lastFrame) throws IOException {
byte[] frame = new byte[offset + Data.FOOTER_LENGTH];
System.arraycopy(buf, 0, frame, 0, frame.length);
Data d = new Data(frame);
d.setLastFrame(lastFrame);
d.setSequenceNumber(sequenceNumber++);
d.setChecksum(d.calculateChecksum());
try {
sender.write(d);
} catch(InterruptedException e) {
Thread.currentThread().interrupt();
throw new IOException("Interrupted while writing");
}
offset = Data.HEADER_LENGTH;
}
}

View File

@@ -0,0 +1,75 @@
package org.briarproject.reliability;
import java.io.IOException;
import org.briarproject.api.reliability.ReadHandler;
class SlipDecoder implements ReadHandler {
// https://tools.ietf.org/html/rfc1055
private static final byte END = (byte) 192, ESC = (byte) 219;
private static final byte TEND = (byte) 220, TESC = (byte) 221;
private final ReadHandler readHandler;
private final byte[] buf;
private int decodedLength = 0;
private boolean escape = false;
SlipDecoder(ReadHandler readHandler, int maxDecodedLength) {
this.readHandler = readHandler;
buf = new byte[maxDecodedLength];
}
public void handleRead(byte[] b) throws IOException {
for(int i = 0; i < b.length; i++) {
switch(b[i]) {
case END:
if(escape) {
reset(true);
} else {
if(decodedLength > 0) {
byte[] decoded = new byte[decodedLength];
System.arraycopy(buf, 0, decoded, 0, decodedLength);
readHandler.handleRead(decoded);
}
reset(false);
}
break;
case ESC:
if(escape) reset(true);
else escape = true;
break;
case TEND:
if(escape) {
escape = false;
if(decodedLength == buf.length) reset(true);
else buf[decodedLength++] = END;
} else {
if(decodedLength == buf.length) reset(true);
else buf[decodedLength++] = TEND;
}
break;
case TESC:
if(escape) {
escape = false;
if(decodedLength == buf.length) reset(true);
else buf[decodedLength++] = ESC;
} else {
if(decodedLength == buf.length) reset(true);
else buf[decodedLength++] = TESC;
}
break;
default:
if(escape || decodedLength == buf.length) reset(true);
else buf[decodedLength++] = b[i];
break;
}
}
}
private void reset(boolean error) {
escape = false;
decodedLength = 0;
}
}

View File

@@ -0,0 +1,39 @@
package org.briarproject.reliability;
import java.io.IOException;
import org.briarproject.api.reliability.WriteHandler;
class SlipEncoder implements WriteHandler {
// https://tools.ietf.org/html/rfc1055
private static final byte END = (byte) 192, ESC = (byte) 219;
private static final byte TEND = (byte) 220, TESC = (byte) 221;
private final WriteHandler writeHandler;
SlipEncoder(WriteHandler writeHandler) {
this.writeHandler = writeHandler;
}
public void handleWrite(byte[] b) throws IOException {
int encodedLength = b.length + 2;
for(int i = 0; i < b.length; i++)
if(b[i] == END || b[i] == ESC) encodedLength++;
byte[] encoded = new byte[encodedLength];
encoded[0] = END;
for(int i = 0, j = 1; i < b.length; i++) {
if(b[i] == END) {
encoded[j++] = ESC;
encoded[j++] = TEND;
} else if(b[i] == ESC) {
encoded[j++] = ESC;
encoded[j++] = TESC;
} else {
encoded[j++] = b[i];
}
}
encoded[encodedLength - 1] = END;
writeHandler.handleWrite(encoded);
}
}

View File

@@ -0,0 +1,13 @@
package org.briarproject.serial;
import java.io.InputStream;
import org.briarproject.api.serial.Reader;
import org.briarproject.api.serial.ReaderFactory;
class ReaderFactoryImpl implements ReaderFactory {
public Reader createReader(InputStream in) {
return new ReaderImpl(in);
}
}

View File

@@ -0,0 +1,498 @@
package org.briarproject.serial;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import org.briarproject.api.FormatException;
import org.briarproject.api.serial.Consumer;
import org.briarproject.api.serial.Reader;
// This class is not thread-safe
class ReaderImpl implements Reader {
private static final byte[] EMPTY_BUFFER = new byte[] {};
private final InputStream in;
private final Collection<Consumer> consumers = new ArrayList<Consumer>(0);
private boolean hasLookahead = false, eof = false;
private byte next, nextStructId;
private byte[] buf = new byte[8];
ReaderImpl(InputStream in) {
this.in = in;
}
public boolean eof() throws IOException {
if(!hasLookahead) readLookahead();
return eof;
}
private void readLookahead() throws IOException {
assert !eof;
// If one or two lookahead bytes have been read, feed the consumers
if(hasLookahead) consumeLookahead();
// Read a lookahead byte
int i = in.read();
if(i == -1) {
eof = true;
return;
}
next = (byte) i;
// If necessary, read another lookahead byte
if(next == Tag.STRUCT) {
i = in.read();
if(i == -1) throw new FormatException();
nextStructId = (byte) i;
}
hasLookahead = true;
}
private void consumeLookahead() throws IOException {
assert hasLookahead;
for(Consumer c : consumers) {
c.write(next);
if(next == Tag.STRUCT) c.write(nextStructId);
}
hasLookahead = false;
}
public void close() throws IOException {
in.close();
}
public void addConsumer(Consumer c) {
consumers.add(c);
}
public void removeConsumer(Consumer c) {
if(!consumers.remove(c)) throw new IllegalArgumentException();
}
public boolean hasBoolean() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.FALSE || next == Tag.TRUE;
}
public boolean readBoolean() throws IOException {
if(!hasBoolean()) throw new FormatException();
consumeLookahead();
return next == Tag.TRUE;
}
public void skipBoolean() throws IOException {
if(!hasBoolean()) throw new FormatException();
hasLookahead = false;
}
public boolean hasUint7() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next >= 0;
}
public byte readUint7() throws IOException {
if(!hasUint7()) throw new FormatException();
consumeLookahead();
return next;
}
public void skipUint7() throws IOException {
if(!hasUint7()) throw new FormatException();
hasLookahead = false;
}
public boolean hasInt8() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.INT8;
}
public byte readInt8() throws IOException {
if(!hasInt8()) throw new FormatException();
consumeLookahead();
int i = in.read();
if(i == -1) {
eof = true;
throw new FormatException();
}
byte b = (byte) i;
// Feed the hungry mouths
for(Consumer c : consumers) c.write(b);
return b;
}
public void skipInt8() throws IOException {
if(!hasInt8()) throw new FormatException();
if(in.read() == -1) {
eof = true;
throw new FormatException();
}
hasLookahead = false;
}
public boolean hasInt16() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.INT16;
}
public short readInt16() throws IOException {
if(!hasInt16()) throw new FormatException();
consumeLookahead();
readIntoBuffer(2);
return (short) (((buf[0] & 0xFF) << 8) | (buf[1] & 0xFF));
}
private void readIntoBuffer(int length) throws IOException {
if(buf.length < length) buf = new byte[length];
readIntoBuffer(buf, length);
}
private void readIntoBuffer(byte[] b, int length) throws IOException {
assert !hasLookahead;
int offset = 0;
while(offset < length) {
int read = in.read(b, offset, length - offset);
if(read == -1) {
eof = true;
throw new FormatException();
}
offset += read;
}
// Feed the hungry mouths
for(Consumer c : consumers) c.write(b, 0, length);
}
public void skipInt16() throws IOException {
if(!hasInt16()) throw new FormatException();
hasLookahead = false;
skip(2);
}
private void skip(int length) throws IOException {
while(length > 0) {
int read = in.read(buf, 0, Math.min(length, buf.length));
if(read == -1) {
eof = true;
throw new FormatException();
}
length -= read;
}
}
public boolean hasInt32() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.INT32;
}
public int readInt32() throws IOException {
if(!hasInt32()) throw new FormatException();
consumeLookahead();
return readInt32Bits();
}
private int readInt32Bits() throws IOException {
readIntoBuffer(4);
return ((buf[0] & 0xFF) << 24) | ((buf[1] & 0xFF) << 16) |
((buf[2] & 0xFF) << 8) | (buf[3] & 0xFF);
}
public void skipInt32() throws IOException {
if(!hasInt32()) throw new FormatException();
hasLookahead = false;
skip(4);
}
public boolean hasInt64() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.INT64;
}
public long readInt64() throws IOException {
if(!hasInt64()) throw new FormatException();
consumeLookahead();
return readInt64Bits();
}
private long readInt64Bits() throws IOException {
readIntoBuffer(8);
return ((buf[0] & 0xFFL) << 56) | ((buf[1] & 0xFFL) << 48) |
((buf[2] & 0xFFL) << 40) | ((buf[3] & 0xFFL) << 32) |
((buf[4] & 0xFFL) << 24) | ((buf[5] & 0xFFL) << 16) |
((buf[6] & 0xFFL) << 8) | (buf[7] & 0xFFL);
}
public void skipInt64() throws IOException {
if(!hasInt64()) throw new FormatException();
hasLookahead = false;
skip(8);
}
public boolean hasIntAny() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next >= 0 || next == Tag.INT8 || next == Tag.INT16
|| next == Tag.INT32 || next == Tag.INT64;
}
public long readIntAny() throws IOException {
if(!hasIntAny()) throw new FormatException();
if(next >= 0) return readUint7();
if(next == Tag.INT8) return readInt8();
if(next == Tag.INT16) return readInt16();
if(next == Tag.INT32) return readInt32();
if(next == Tag.INT64) return readInt64();
throw new IllegalStateException();
}
public void skipIntAny() throws IOException {
if(!hasIntAny()) throw new FormatException();
if(next >= 0) skipUint7();
else if(next == Tag.INT8) skipInt8();
else if(next == Tag.INT16) skipInt16();
else if(next == Tag.INT32) skipInt32();
else if(next == Tag.INT64) skipInt64();
else throw new IllegalStateException();
}
public boolean hasFloat32() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.FLOAT32;
}
public float readFloat32() throws IOException {
if(!hasFloat32()) throw new FormatException();
consumeLookahead();
return Float.intBitsToFloat(readInt32Bits());
}
public void skipFloat32() throws IOException {
if(!hasFloat32()) throw new FormatException();
hasLookahead = false;
skip(4);
}
public boolean hasFloat64() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.FLOAT64;
}
public double readFloat64() throws IOException {
if(!hasFloat64()) throw new FormatException();
consumeLookahead();
return Double.longBitsToDouble(readInt64Bits());
}
public void skipFloat64() throws IOException {
if(!hasFloat64()) throw new FormatException();
hasLookahead = false;
skip(8);
}
public boolean hasString() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.STRING;
}
public String readString(int maxLength) throws IOException {
if(!hasString()) throw new FormatException();
consumeLookahead();
int length = readLength();
if(length > maxLength) throw new FormatException();
if(length == 0) return "";
readIntoBuffer(length);
return new String(buf, 0, length, "UTF-8");
}
private int readLength() throws IOException {
if(!hasLength()) throw new FormatException();
int length;
if(next >= 0) length = readUint7();
else if(next == Tag.INT8) length = readInt8();
else if(next == Tag.INT16) length = readInt16();
else if(next == Tag.INT32) length = readInt32();
else throw new IllegalStateException();
if(length < 0) throw new FormatException();
return length;
}
private boolean hasLength() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next >= 0 || next == Tag.INT8 || next == Tag.INT16
|| next == Tag.INT32;
}
public void skipString(int maxLength) throws IOException {
if(!hasString()) throw new FormatException();
hasLookahead = false;
int length = readLength();
if(length > maxLength) throw new FormatException();
hasLookahead = false;
skip(length);
}
public boolean hasBytes() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.BYTES;
}
public byte[] readBytes(int maxLength) throws IOException {
if(!hasBytes()) throw new FormatException();
consumeLookahead();
int length = readLength();
if(length > maxLength) throw new FormatException();
if(length == 0) return EMPTY_BUFFER;
byte[] b = new byte[length];
readIntoBuffer(b, length);
return b;
}
public void skipBytes(int maxLength) throws IOException {
if(!hasBytes()) throw new FormatException();
hasLookahead = false;
int length = readLength();
if(length > maxLength) throw new FormatException();
hasLookahead = false;
skip(length);
}
public boolean hasList() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.LIST;
}
public void readListStart() throws IOException {
if(!hasList()) throw new FormatException();
consumeLookahead();
}
public boolean hasListEnd() throws IOException {
return hasEnd();
}
private boolean hasEnd() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.END;
}
public void readListEnd() throws IOException {
readEnd();
}
private void readEnd() throws IOException {
if(!hasEnd()) throw new FormatException();
consumeLookahead();
}
public void skipList() throws IOException {
if(!hasList()) throw new FormatException();
hasLookahead = false;
while(!hasListEnd()) skipObject();
hasLookahead = false;
}
private void skipObject() throws IOException {
if(hasBoolean()) skipBoolean();
else if(hasIntAny()) skipIntAny();
else if(hasFloat32()) skipFloat32();
else if(hasFloat64()) skipFloat64();
else if(hasString()) skipString(Integer.MAX_VALUE);
else if(hasBytes()) skipBytes(Integer.MAX_VALUE);
else if(hasList()) skipList();
else if(hasMap()) skipMap();
else if(hasStruct()) skipStruct();
else if(hasNull()) skipNull();
else throw new FormatException();
}
public boolean hasMap() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.MAP;
}
public void readMapStart() throws IOException {
if(!hasMap()) throw new FormatException();
consumeLookahead();
}
public boolean hasMapEnd() throws IOException {
return hasEnd();
}
public void readMapEnd() throws IOException {
readEnd();
}
public void skipMap() throws IOException {
if(!hasMap()) throw new FormatException();
hasLookahead = false;
while(!hasMapEnd()) {
skipObject();
skipObject();
}
hasLookahead = false;
}
public boolean hasStruct() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.STRUCT;
}
public boolean hasStruct(int id) throws IOException {
if(id < 0 || id > 255) throw new IllegalArgumentException();
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.STRUCT && (nextStructId & 0xFF) == id;
}
public void readStructStart(int id) throws IOException {
if(!hasStruct(id)) throw new FormatException();
consumeLookahead();
}
public boolean hasStructEnd() throws IOException {
return hasEnd();
}
public void readStructEnd() throws IOException {
readEnd();
}
public void skipStruct() throws IOException {
if(!hasStruct()) throw new FormatException();
hasLookahead = false;
while(!hasStructEnd()) skipObject();
hasLookahead = false;
}
public boolean hasNull() throws IOException {
if(!hasLookahead) readLookahead();
if(eof) return false;
return next == Tag.NULL;
}
public void readNull() throws IOException {
if(!hasNull()) throw new FormatException();
consumeLookahead();
}
public void skipNull() throws IOException {
if(!hasNull()) throw new FormatException();
hasLookahead = false;
}
}

View File

@@ -0,0 +1,39 @@
package org.briarproject.serial;
import org.briarproject.api.UniqueId;
import org.briarproject.api.serial.SerialComponent;
class SerialComponentImpl implements SerialComponent {
public int getSerialisedListStartLength() {
// LIST tag
return 1;
}
public int getSerialisedListEndLength() {
// END tag
return 1;
}
public int getSerialisedStructStartLength(int id) {
// STRUCT tag, ID
return 2;
}
public int getSerialisedStructEndLength() {
// END tag
return 1;
}
public int getSerialisedUniqueIdLength() {
// BYTES tag, length spec, bytes
return 1 + getSerialisedLengthSpecLength(UniqueId.LENGTH)
+ UniqueId.LENGTH;
}
private int getSerialisedLengthSpecLength(int length) {
if(length < 0) throw new IllegalArgumentException();
// Uint7, int16 or int32
return length <= Byte.MAX_VALUE ? 1 : length <= Short.MAX_VALUE ? 3 : 5;
}
}

View File

@@ -0,0 +1,19 @@
package org.briarproject.serial;
import javax.inject.Singleton;
import org.briarproject.api.serial.ReaderFactory;
import org.briarproject.api.serial.SerialComponent;
import org.briarproject.api.serial.WriterFactory;
import com.google.inject.AbstractModule;
public class SerialModule extends AbstractModule {
protected void configure() {
bind(ReaderFactory.class).to(ReaderFactoryImpl.class);
bind(SerialComponent.class).to(
SerialComponentImpl.class).in(Singleton.class);
bind(WriterFactory.class).to(WriterFactoryImpl.class);
}
}

View File

@@ -0,0 +1,20 @@
package org.briarproject.serial;
interface Tag {
byte FALSE = (byte) 0xFF;
byte TRUE = (byte) 0xFE;
byte INT8 = (byte) 0xFD;
byte INT16 = (byte) 0xFC;
byte INT32 = (byte) 0xFB;
byte INT64 = (byte) 0xFA;
byte FLOAT32 = (byte) 0xF9;
byte FLOAT64 = (byte) 0xF8;
byte STRING = (byte) 0xF7;
byte BYTES = (byte) 0xF6;
byte LIST = (byte) 0xF5;
byte MAP = (byte) 0xF4;
byte STRUCT = (byte) 0xF3;
byte END = (byte) 0xF2;
byte NULL = (byte) 0xF1;
}

View File

@@ -0,0 +1,13 @@
package org.briarproject.serial;
import java.io.OutputStream;
import org.briarproject.api.serial.Writer;
import org.briarproject.api.serial.WriterFactory;
class WriterFactoryImpl implements WriterFactory {
public Writer createWriter(OutputStream out) {
return new WriterImpl(out);
}
}

View File

@@ -0,0 +1,203 @@
package org.briarproject.serial;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.briarproject.api.Bytes;
import org.briarproject.api.serial.Consumer;
import org.briarproject.api.serial.Writer;
// This class is not thread-safe
class WriterImpl implements Writer {
private final OutputStream out;
private final Collection<Consumer> consumers = new ArrayList<Consumer>(0);
WriterImpl(OutputStream out) {
this.out = out;
}
public void flush() throws IOException {
out.flush();
}
public void close() throws IOException {
out.close();
}
public void addConsumer(Consumer c) {
consumers.add(c);
}
public void removeConsumer(Consumer c) {
if(!consumers.remove(c)) throw new IllegalArgumentException();
}
public void writeBoolean(boolean b) throws IOException {
if(b) write(Tag.TRUE);
else write(Tag.FALSE);
}
public void writeUint7(byte b) throws IOException {
if(b < 0) throw new IllegalArgumentException();
write(b);
}
public void writeInt8(byte b) throws IOException {
write(Tag.INT8);
write(b);
}
public void writeInt16(short s) throws IOException {
write(Tag.INT16);
write((byte) (s >> 8));
write((byte) ((s << 8) >> 8));
}
public void writeInt32(int i) throws IOException {
write(Tag.INT32);
writeInt32Bits(i);
}
private void writeInt32Bits(int i) throws IOException {
write((byte) (i >> 24));
write((byte) ((i << 8) >> 24));
write((byte) ((i << 16) >> 24));
write((byte) ((i << 24) >> 24));
}
public void writeInt64(long l) throws IOException {
write(Tag.INT64);
writeInt64Bits(l);
}
private void writeInt64Bits(long l) throws IOException {
write((byte) (l >> 56));
write((byte) ((l << 8) >> 56));
write((byte) ((l << 16) >> 56));
write((byte) ((l << 24) >> 56));
write((byte) ((l << 32) >> 56));
write((byte) ((l << 40) >> 56));
write((byte) ((l << 48) >> 56));
write((byte) ((l << 56) >> 56));
}
public void writeIntAny(long l) throws IOException {
if(l >= 0 && l <= Byte.MAX_VALUE)
writeUint7((byte) l);
else if(l >= Byte.MIN_VALUE && l <= Byte.MAX_VALUE)
writeInt8((byte) l);
else if(l >= Short.MIN_VALUE && l <= Short.MAX_VALUE)
writeInt16((short) l);
else if(l >= Integer.MIN_VALUE && l <= Integer.MAX_VALUE)
writeInt32((int) l);
else writeInt64(l);
}
public void writeFloat32(float f) throws IOException {
write(Tag.FLOAT32);
writeInt32Bits(Float.floatToRawIntBits(f));
}
public void writeFloat64(double d) throws IOException {
write(Tag.FLOAT64);
writeInt64Bits(Double.doubleToRawLongBits(d));
}
public void writeString(String s) throws IOException {
byte[] b = s.getBytes("UTF-8");
write(Tag.STRING);
writeLength(b.length);
write(b);
}
private void writeLength(int i) throws IOException {
assert i >= 0;
// Fun fact: it's never worth writing a length as an int8
if(i <= Byte.MAX_VALUE) writeUint7((byte) i);
else if(i <= Short.MAX_VALUE) writeInt16((short) i);
else writeInt32(i);
}
public void writeBytes(byte[] b) throws IOException {
write(Tag.BYTES);
writeLength(b.length);
write(b);
}
public void writeList(Collection<?> c) throws IOException {
write(Tag.LIST);
for(Object o : c) writeObject(o);
write(Tag.END);
}
private void writeObject(Object o) throws IOException {
if(o instanceof Boolean) writeBoolean((Boolean) o);
else if(o instanceof Byte) writeIntAny((Byte) o);
else if(o instanceof Short) writeIntAny((Short) o);
else if(o instanceof Integer) writeIntAny((Integer) o);
else if(o instanceof Long) writeIntAny((Long) o);
else if(o instanceof Float) writeFloat32((Float) o);
else if(o instanceof Double) writeFloat64((Double) o);
else if(o instanceof String) writeString((String) o);
else if(o instanceof Bytes) writeBytes(((Bytes) o).getBytes());
else if(o instanceof List<?>) writeList((List<?>) o);
else if(o instanceof Map<?, ?>) writeMap((Map<?, ?>) o);
else if(o == null) writeNull();
else throw new IllegalStateException();
}
public void writeListStart() throws IOException {
write(Tag.LIST);
}
public void writeListEnd() throws IOException {
write(Tag.END);
}
public void writeMap(Map<?, ?> m) throws IOException {
write(Tag.MAP);
for(Entry<?, ?> e : m.entrySet()) {
writeObject(e.getKey());
writeObject(e.getValue());
}
write(Tag.END);
}
public void writeMapStart() throws IOException {
write(Tag.MAP);
}
public void writeMapEnd() throws IOException {
write(Tag.END);
}
public void writeStructStart(int id) throws IOException {
if(id < 0 || id > 255) throw new IllegalArgumentException();
write(Tag.STRUCT);
write((byte) id);
}
public void writeStructEnd() throws IOException {
write(Tag.END);
}
public void writeNull() throws IOException {
write(Tag.NULL);
}
private void write(byte b) throws IOException {
out.write(b);
for(Consumer c : consumers) c.write(b);
}
private void write(byte[] b) throws IOException {
out.write(b);
for(Consumer c : consumers) c.write(b, 0, b.length);
}
}

View File

@@ -0,0 +1,16 @@
package org.briarproject.system;
import org.briarproject.api.system.Clock;
import org.briarproject.api.system.SystemClock;
import org.briarproject.api.system.SystemTimer;
import org.briarproject.api.system.Timer;
import com.google.inject.AbstractModule;
public class SystemModule extends AbstractModule {
protected void configure() {
bind(Clock.class).to(SystemClock.class);
bind(Timer.class).to(SystemTimer.class);
}
}

View File

@@ -0,0 +1,160 @@
package org.briarproject.transport;
import static java.util.logging.Level.WARNING;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.db.DbException;
import org.briarproject.api.messaging.duplex.DuplexConnectionFactory;
import org.briarproject.api.messaging.simplex.SimplexConnectionFactory;
import org.briarproject.api.plugins.duplex.DuplexTransportConnection;
import org.briarproject.api.plugins.simplex.SimplexTransportReader;
import org.briarproject.api.plugins.simplex.SimplexTransportWriter;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionDispatcher;
import org.briarproject.api.transport.ConnectionRecogniser;
import org.briarproject.api.transport.IncomingConnectionExecutor;
class ConnectionDispatcherImpl implements ConnectionDispatcher {
private static final Logger LOG =
Logger.getLogger(ConnectionDispatcherImpl.class.getName());
private final Executor connExecutor;
private final ConnectionRecogniser recogniser;
private final SimplexConnectionFactory simplexConnFactory;
private final DuplexConnectionFactory duplexConnFactory;
@Inject
ConnectionDispatcherImpl(@IncomingConnectionExecutor Executor connExecutor,
ConnectionRecogniser recogniser,
SimplexConnectionFactory simplexConnFactory,
DuplexConnectionFactory duplexConnFactory) {
this.connExecutor = connExecutor;
this.recogniser = recogniser;
this.simplexConnFactory = simplexConnFactory;
this.duplexConnFactory = duplexConnFactory;
}
public void dispatchReader(TransportId t, SimplexTransportReader r) {
connExecutor.execute(new DispatchSimplexConnection(t, r));
}
public void dispatchWriter(ContactId c, TransportId t,
SimplexTransportWriter w) {
simplexConnFactory.createOutgoingConnection(c, t, w);
}
public void dispatchIncomingConnection(TransportId t,
DuplexTransportConnection d) {
connExecutor.execute(new DispatchDuplexConnection(t, d));
}
public void dispatchOutgoingConnection(ContactId c, TransportId t,
DuplexTransportConnection d) {
duplexConnFactory.createOutgoingConnection(c, t, d);
}
private byte[] readTag(InputStream in) throws IOException {
byte[] b = new byte[TAG_LENGTH];
int offset = 0;
while(offset < b.length) {
int read = in.read(b, offset, b.length - offset);
if(read == -1) throw new EOFException();
offset += read;
}
return b;
}
private class DispatchSimplexConnection implements Runnable {
private final TransportId transportId;
private final SimplexTransportReader transport;
private DispatchSimplexConnection(TransportId transportId,
SimplexTransportReader transport) {
this.transportId = transportId;
this.transport = transport;
}
public void run() {
try {
byte[] tag = readTag(transport.getInputStream());
ConnectionContext ctx = recogniser.acceptConnection(transportId,
tag);
if(ctx == null) {
transport.dispose(false, false);
} else {
simplexConnFactory.createIncomingConnection(ctx,
transport);
}
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
try {
transport.dispose(true, false);
} catch(IOException e1) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e1.toString(), e1);
}
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
try {
transport.dispose(true, false);
} catch(IOException e1) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e1.toString(), e1);
}
}
}
}
private class DispatchDuplexConnection implements Runnable {
private final TransportId transportId;
private final DuplexTransportConnection transport;
private DispatchDuplexConnection(TransportId transportId,
DuplexTransportConnection transport) {
this.transportId = transportId;
this.transport = transport;
}
public void run() {
byte[] tag;
try {
tag = readTag(transport.getInputStream());
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, false);
return;
}
ConnectionContext ctx = null;
try {
ctx = recogniser.acceptConnection(transportId, tag);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
dispose(true, false);
return;
}
if(ctx == null) dispose(false, false);
else duplexConnFactory.createIncomingConnection(ctx, transport);
}
private void dispose(boolean exception, boolean recognised) {
try {
transport.dispose(exception, recognised);
} catch(IOException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
}
}

View File

@@ -0,0 +1,43 @@
package org.briarproject.transport;
import java.io.InputStream;
import javax.inject.Inject;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionReader;
import org.briarproject.api.transport.ConnectionReaderFactory;
class ConnectionReaderFactoryImpl implements ConnectionReaderFactory {
private final CryptoComponent crypto;
@Inject
ConnectionReaderFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public ConnectionReader createConnectionReader(InputStream in,
int maxFrameLength, ConnectionContext ctx, boolean incoming,
boolean initiator) {
byte[] secret = ctx.getSecret();
long connection = ctx.getConnectionNumber();
boolean weAreAlice = ctx.getAlice();
boolean initiatorIsAlice = incoming ? !weAreAlice : weAreAlice;
SecretKey frameKey = crypto.deriveFrameKey(secret, connection,
initiatorIsAlice, initiator);
FrameReader encryption = new IncomingEncryptionLayer(in,
crypto.getFrameCipher(), frameKey, maxFrameLength);
return new ConnectionReaderImpl(encryption, maxFrameLength);
}
public ConnectionReader createInvitationConnectionReader(InputStream in,
int maxFrameLength, byte[] secret, boolean alice) {
SecretKey frameKey = crypto.deriveFrameKey(secret, 0, true, alice);
FrameReader encryption = new IncomingEncryptionLayer(in,
crypto.getFrameCipher(), frameKey, maxFrameLength);
return new ConnectionReaderImpl(encryption, maxFrameLength);
}
}

View File

@@ -0,0 +1,62 @@
package org.briarproject.transport;
import static org.briarproject.api.transport.TransportConstants.HEADER_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAC_LENGTH;
import java.io.IOException;
import java.io.InputStream;
import org.briarproject.api.transport.ConnectionReader;
class ConnectionReaderImpl extends InputStream implements ConnectionReader {
private final FrameReader in;
private final byte[] frame;
private int offset = 0, length = 0;
ConnectionReaderImpl(FrameReader in, int frameLength) {
this.in = in;
frame = new byte[frameLength - MAC_LENGTH];
}
public InputStream getInputStream() {
return this;
}
@Override
public int read() throws IOException {
while(length <= 0) {
if(length == -1) return -1;
readFrame();
}
int b = frame[offset] & 0xff;
offset++;
length--;
return b;
}
@Override
public int read(byte[] b) throws IOException {
return read(b, 0, b.length);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
while(length <= 0) {
if(length == -1) return -1;
readFrame();
}
len = Math.min(len, length);
System.arraycopy(frame, offset, b, off, len);
offset += len;
length -= len;
return len;
}
private void readFrame() throws IOException {
assert length == 0;
offset = HEADER_LENGTH;
length = in.readFrame(frame);
}
}

View File

@@ -0,0 +1,75 @@
package org.briarproject.transport;
import java.util.HashMap;
import java.util.Map;
import javax.inject.Inject;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.db.DatabaseComponent;
import org.briarproject.api.db.DbException;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionRecogniser;
import org.briarproject.api.transport.TemporarySecret;
class ConnectionRecogniserImpl implements ConnectionRecogniser {
private final CryptoComponent crypto;
private final DatabaseComponent db;
// Locking: this
private final Map<TransportId, TransportConnectionRecogniser> recognisers;
@Inject
ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db) {
this.crypto = crypto;
this.db = db;
recognisers = new HashMap<TransportId, TransportConnectionRecogniser>();
}
public ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r == null) return null;
return r.acceptConnection(tag);
}
public void addSecret(TemporarySecret s) {
TransportId t = s.getTransportId();
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
if(r == null) {
r = new TransportConnectionRecogniser(crypto, db, t);
recognisers.put(t, r);
}
}
r.addSecret(s);
}
public void removeSecret(ContactId c, TransportId t, long period) {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r != null) r.removeSecret(c, period);
}
public synchronized void removeSecrets(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values())
r.removeSecrets(c);
}
public synchronized void removeSecrets(TransportId t) {
recognisers.remove(t);
}
public synchronized void removeSecrets() {
for(TransportConnectionRecogniser r : recognisers.values())
r.removeSecrets();
}
}

View File

@@ -0,0 +1,110 @@
package org.briarproject.transport;
import static java.util.logging.Level.INFO;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.logging.Logger;
import org.briarproject.api.ContactId;
import org.briarproject.api.TransportId;
import org.briarproject.api.transport.ConnectionListener;
import org.briarproject.api.transport.ConnectionRegistry;
class ConnectionRegistryImpl implements ConnectionRegistry {
private static final Logger LOG =
Logger.getLogger(ConnectionRegistryImpl.class.getName());
// Locking: this
private final Map<TransportId, Map<ContactId, Integer>> connections;
// Locking: this
private final Map<ContactId, Integer> contactCounts;
private final List<ConnectionListener> listeners;
ConnectionRegistryImpl() {
connections = new HashMap<TransportId, Map<ContactId, Integer>>();
contactCounts = new HashMap<ContactId, Integer>();
listeners = new CopyOnWriteArrayList<ConnectionListener>();
}
public void addListener(ConnectionListener c) {
listeners.add(c);
}
public void removeListener(ConnectionListener c) {
listeners.remove(c);
}
public void registerConnection(ContactId c, TransportId t) {
if(LOG.isLoggable(INFO)) LOG.info("Connection registered");
boolean firstConnection = false;
synchronized(this) {
Map<ContactId, Integer> m = connections.get(t);
if(m == null) {
m = new HashMap<ContactId, Integer>();
connections.put(t, m);
}
Integer count = m.get(c);
if(count == null) m.put(c, 1);
else m.put(c, count + 1);
count = contactCounts.get(c);
if(count == null) {
firstConnection = true;
contactCounts.put(c, 1);
} else {
contactCounts.put(c, count + 1);
}
}
if(firstConnection) {
if(LOG.isLoggable(INFO)) LOG.info("Contact connected");
for(ConnectionListener l : listeners) l.contactConnected(c);
}
}
public void unregisterConnection(ContactId c, TransportId t) {
if(LOG.isLoggable(INFO)) LOG.info("Connection unregistered");
boolean lastConnection = false;
synchronized(this) {
Map<ContactId, Integer> m = connections.get(t);
if(m == null) throw new IllegalArgumentException();
Integer count = m.remove(c);
if(count == null) throw new IllegalArgumentException();
if(count == 1) {
if(m.isEmpty()) connections.remove(t);
} else {
m.put(c, count - 1);
}
count = contactCounts.get(c);
if(count == null) throw new IllegalArgumentException();
if(count == 1) {
lastConnection = true;
contactCounts.remove(c);
} else {
contactCounts.put(c, count - 1);
}
}
if(lastConnection) {
if(LOG.isLoggable(INFO)) LOG.info("Contact disconnected");
for(ConnectionListener l : listeners) l.contactDisconnected(c);
}
}
public synchronized Collection<ContactId> getConnectedContacts(
TransportId t) {
Map<ContactId, Integer> m = connections.get(t);
if(m == null) return Collections.emptyList();
List<ContactId> ids = new ArrayList<ContactId>(m.keySet());
if(LOG.isLoggable(INFO)) LOG.info(ids.size() + " contacts connected");
return Collections.unmodifiableList(ids);
}
public synchronized boolean isConnected(ContactId c) {
return contactCounts.containsKey(c);
}
}

View File

@@ -0,0 +1,102 @@
package org.briarproject.transport;
import static org.briarproject.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
// This class is not thread-safe
class ConnectionWindow {
private final Set<Long> unseen;
private long centre;
ConnectionWindow() {
unseen = new HashSet<Long>();
for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) unseen.add(l);
centre = 0;
}
ConnectionWindow(long centre, byte[] bitmap) {
if(centre < 0 || centre > MAX_32_BIT_UNSIGNED + 1)
throw new IllegalArgumentException();
if(bitmap.length != CONNECTION_WINDOW_SIZE / 8)
throw new IllegalArgumentException();
this.centre = centre;
unseen = new HashSet<Long>();
long bitmapBottom = centre - CONNECTION_WINDOW_SIZE / 2;
for(int bytes = 0; bytes < bitmap.length; bytes++) {
for(int bits = 0; bits < 8; bits++) {
long connection = bitmapBottom + bytes * 8 + bits;
if(connection >= 0 && connection <= MAX_32_BIT_UNSIGNED) {
if((bitmap[bytes] & (128 >> bits)) == 0)
unseen.add(connection);
}
}
}
}
boolean isSeen(long connection) {
return !unseen.contains(connection);
}
Collection<Long> setSeen(long connection) {
long bottom = getBottom(centre);
long top = getTop(centre);
if(connection < bottom || connection > top)
throw new IllegalArgumentException();
if(!unseen.remove(connection))
throw new IllegalArgumentException();
Collection<Long> changed = new ArrayList<Long>();
if(connection >= centre) {
centre = connection + 1;
long newBottom = getBottom(centre);
long newTop = getTop(centre);
for(long l = bottom; l < newBottom; l++) {
if(unseen.remove(l)) changed.add(l);
}
for(long l = top + 1; l <= newTop; l++) {
if(unseen.add(l)) changed.add(l);
}
}
return changed;
}
long getCentre() {
return centre;
}
byte[] getBitmap() {
byte[] bitmap = new byte[CONNECTION_WINDOW_SIZE / 8];
long bitmapBottom = centre - CONNECTION_WINDOW_SIZE / 2;
for(int bytes = 0; bytes < bitmap.length; bytes++) {
for(int bits = 0; bits < 8; bits++) {
long connection = bitmapBottom + bytes * 8 + bits;
if(connection >= 0 && connection <= MAX_32_BIT_UNSIGNED) {
if(!unseen.contains(connection))
bitmap[bytes] |= 128 >> bits;
}
}
}
return bitmap;
}
// Returns the lowest value contained in a window with the given centre
private static long getBottom(long centre) {
return Math.max(0, centre - CONNECTION_WINDOW_SIZE / 2);
}
// Returns the highest value contained in a window with the given centre
private static long getTop(long centre) {
return Math.min(MAX_32_BIT_UNSIGNED,
centre + CONNECTION_WINDOW_SIZE / 2 - 1);
}
public Collection<Long> getUnseen() {
return unseen;
}
}

View File

@@ -0,0 +1,56 @@
package org.briarproject.transport;
import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
import java.io.OutputStream;
import javax.inject.Inject;
import org.briarproject.api.crypto.CryptoComponent;
import org.briarproject.api.crypto.SecretKey;
import org.briarproject.api.transport.ConnectionContext;
import org.briarproject.api.transport.ConnectionWriter;
import org.briarproject.api.transport.ConnectionWriterFactory;
class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
private final CryptoComponent crypto;
@Inject
ConnectionWriterFactoryImpl(CryptoComponent crypto) {
this.crypto = crypto;
}
public ConnectionWriter createConnectionWriter(OutputStream out,
int maxFrameLength, long capacity, ConnectionContext ctx,
boolean incoming, boolean initiator) {
byte[] secret = ctx.getSecret();
long connection = ctx.getConnectionNumber();
boolean weAreAlice = ctx.getAlice();
boolean initiatorIsAlice = incoming ? !weAreAlice : weAreAlice;
SecretKey frameKey = crypto.deriveFrameKey(secret, connection,
initiatorIsAlice, initiator);
FrameWriter encryption;
if(initiator) {
byte[] tag = new byte[TAG_LENGTH];
SecretKey tagKey = crypto.deriveTagKey(secret, initiatorIsAlice);
crypto.encodeTag(tag, tagKey, connection);
tagKey.erase();
encryption = new OutgoingEncryptionLayer(out, capacity,
crypto.getFrameCipher(), frameKey, maxFrameLength, tag);
} else {
encryption = new OutgoingEncryptionLayer(out, capacity,
crypto.getFrameCipher(), frameKey, maxFrameLength);
}
return new ConnectionWriterImpl(encryption, maxFrameLength);
}
public ConnectionWriter createInvitationConnectionWriter(OutputStream out,
int maxFrameLength, byte[] secret, boolean alice) {
SecretKey frameKey = crypto.deriveFrameKey(secret, 0, true, alice);
FrameWriter encryption = new OutgoingEncryptionLayer(out,
Long.MAX_VALUE, crypto.getFrameCipher(), frameKey,
maxFrameLength);
return new ConnectionWriterImpl(encryption, maxFrameLength);
}
}

View File

@@ -0,0 +1,84 @@
package org.briarproject.transport;
import static org.briarproject.api.transport.TransportConstants.HEADER_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAC_LENGTH;
import java.io.IOException;
import java.io.OutputStream;
import org.briarproject.api.transport.ConnectionWriter;
/**
* A ConnectionWriter that buffers its input and writes a frame whenever there
* is a full frame to write or the {@link #flush()} method is called.
* <p>
* This class is not thread-safe.
*/
class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
private final FrameWriter out;
private final byte[] frame;
private final int frameLength;
private int length = 0;
ConnectionWriterImpl(FrameWriter out, int frameLength) {
this.out = out;
this.frameLength = frameLength;
frame = new byte[frameLength - MAC_LENGTH];
}
public OutputStream getOutputStream() {
return this;
}
public long getRemainingCapacity() {
return out.getRemainingCapacity();
}
@Override
public void close() throws IOException {
writeFrame(true);
out.flush();
super.close();
}
@Override
public void flush() throws IOException {
if(length > 0) writeFrame(false);
out.flush();
}
@Override
public void write(int b) throws IOException {
frame[HEADER_LENGTH + length] = (byte) b;
length++;
if(HEADER_LENGTH + length + MAC_LENGTH == frameLength)
writeFrame(false);
}
@Override
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
int available = frameLength - HEADER_LENGTH - length - MAC_LENGTH;
while(available <= len) {
System.arraycopy(b, off, frame, HEADER_LENGTH + length, available);
length += available;
writeFrame(false);
off += available;
len -= available;
available = frameLength - HEADER_LENGTH - length - MAC_LENGTH;
}
System.arraycopy(b, off, frame, HEADER_LENGTH + length, len);
length += len;
}
private void writeFrame(boolean finalFrame) throws IOException {
out.writeFrame(frame, length, finalFrame);
length = 0;
}
}

View File

@@ -0,0 +1,53 @@
package org.briarproject.transport;
import static org.briarproject.api.transport.TransportConstants.AAD_LENGTH;
import static org.briarproject.api.transport.TransportConstants.HEADER_LENGTH;
import static org.briarproject.api.transport.TransportConstants.IV_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAC_LENGTH;
import static org.briarproject.api.transport.TransportConstants.MAX_FRAME_LENGTH;
import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import org.briarproject.util.ByteUtils;
class FrameEncoder {
static void encodeIv(byte[] iv, long frameNumber) {
if(iv.length < IV_LENGTH) throw new IllegalArgumentException();
if(frameNumber < 0 || frameNumber > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
ByteUtils.writeUint32(frameNumber, iv, 0);
for(int i = 4; i < IV_LENGTH; i++) iv[i] = 0;
}
static void encodeAad(byte[] aad, long frameNumber, int plaintextLength) {
if(aad.length < AAD_LENGTH) throw new IllegalArgumentException();
if(frameNumber < 0 || frameNumber > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(plaintextLength < HEADER_LENGTH)
throw new IllegalArgumentException();
if(plaintextLength > MAX_FRAME_LENGTH - MAC_LENGTH)
throw new IllegalArgumentException();
ByteUtils.writeUint32(frameNumber, aad, 0);
ByteUtils.writeUint16(plaintextLength, aad, 4);
}
static void encodeHeader(byte[] header, boolean finalFrame,
int payloadLength) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
if(payloadLength < 0)
throw new IllegalArgumentException();
if(payloadLength > MAX_FRAME_LENGTH - HEADER_LENGTH - MAC_LENGTH)
throw new IllegalArgumentException();
ByteUtils.writeUint16(payloadLength, header, 0);
if(finalFrame) header[0] |= 0x80;
}
static boolean isFinalFrame(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return (header[0] & 0x80) == 0x80;
}
static int getPayloadLength(byte[] header) {
if(header.length < HEADER_LENGTH) throw new IllegalArgumentException();
return ByteUtils.readUint16(header, 0) & 0x7FFF;
}
}

Some files were not shown because too many files have changed in this diff Show More