Merge branch 'AbrahamKiggundu/briar-master': better lock encapsulation

This commit is contained in:
akwizgran
2015-01-29 11:27:30 +00:00
23 changed files with 944 additions and 517 deletions

2
.gitignore vendored
View File

@@ -1,2 +1,4 @@
build build
.gradle .gradle
.metadata
*.tmp

View File

@@ -11,6 +11,8 @@ import static java.util.logging.Level.WARNING;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.inject.Inject; import javax.inject.Inject;
@@ -47,6 +49,10 @@ Service, EventListener {
private static final int PRIVATE_MESSAGE_NOTIFICATION_ID = 3; private static final int PRIVATE_MESSAGE_NOTIFICATION_ID = 3;
private static final int GROUP_POST_NOTIFICATION_ID = 4; private static final int GROUP_POST_NOTIFICATION_ID = 4;
private static final String CONTACT_URI =
"content://org.briarproject/contact";
private static final String GROUP_URI =
"content://org.briarproject/group";
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(AndroidNotificationManagerImpl.class.getName()); Logger.getLogger(AndroidNotificationManagerImpl.class.getName());
@@ -55,13 +61,15 @@ Service, EventListener {
private final Executor dbExecutor; private final Executor dbExecutor;
private final EventBus eventBus; private final EventBus eventBus;
private final Context appContext; private final Context appContext;
private final Map<ContactId, Integer> contactCounts = private final Lock synchLock = new ReentrantLock();
new HashMap<ContactId, Integer>(); // Locking: this
private final Map<GroupId, Integer> groupCounts =
new HashMap<GroupId, Integer>(); // Locking: this
private int privateTotal = 0, groupTotal = 0; // Locking: this // The following are locking: synchLock
private int nextRequestId = 0; // Locking: this private final Map<ContactId, Integer> contactCounts =
new HashMap<ContactId, Integer>();
private final Map<GroupId, Integer> groupCounts =
new HashMap<GroupId, Integer>();
private int privateTotal = 0, groupTotal = 0;
private int nextRequestId = 0;
private volatile Settings settings = new Settings(); private volatile Settings settings = new Settings();
@@ -103,22 +111,32 @@ Service, EventListener {
if(e instanceof SettingsUpdatedEvent) loadSettings(); if(e instanceof SettingsUpdatedEvent) loadSettings();
} }
public synchronized void showPrivateMessageNotification(ContactId c) { public void showPrivateMessageNotification(ContactId c) {
Integer count = contactCounts.get(c); synchLock.lock();
if(count == null) contactCounts.put(c, 1); try {
else contactCounts.put(c, count + 1); Integer count = contactCounts.get(c);
privateTotal++; if(count == null) contactCounts.put(c, 1);
updatePrivateMessageNotification(); else contactCounts.put(c, count + 1);
privateTotal++;
updatePrivateMessageNotification();
} finally {
synchLock.unlock();
}
} }
public synchronized void clearPrivateMessageNotification(ContactId c) { public void clearPrivateMessageNotification(ContactId c) {
Integer count = contactCounts.remove(c); synchLock.lock();
if(count == null) return; // Already cleared try {
privateTotal -= count; Integer count = contactCounts.remove(c);
updatePrivateMessageNotification(); if(count == null) return; // Already cleared
privateTotal -= count;
updatePrivateMessageNotification();
} finally {
synchLock.unlock();
}
} }
// Locking: this // Locking: synchLock
private void updatePrivateMessageNotification() { private void updatePrivateMessageNotification() {
if(privateTotal == 0) { if(privateTotal == 0) {
clearPrivateMessageNotification(); clearPrivateMessageNotification();
@@ -143,6 +161,7 @@ Service, EventListener {
Intent i = new Intent(appContext, ConversationActivity.class); Intent i = new Intent(appContext, ConversationActivity.class);
ContactId c = contactCounts.keySet().iterator().next(); ContactId c = contactCounts.keySet().iterator().next();
i.putExtra("briar.CONTACT_ID", c.getInt()); i.putExtra("briar.CONTACT_ID", c.getInt());
i.setData(Uri.parse(CONTACT_URI + "/" + c.getInt()));
i.setFlags(FLAG_ACTIVITY_CLEAR_TOP | FLAG_ACTIVITY_SINGLE_TOP); i.setFlags(FLAG_ACTIVITY_CLEAR_TOP | FLAG_ACTIVITY_SINGLE_TOP);
TaskStackBuilder t = TaskStackBuilder.create(appContext); TaskStackBuilder t = TaskStackBuilder.create(appContext);
t.addParentStack(ConversationActivity.class); t.addParentStack(ConversationActivity.class);
@@ -162,7 +181,7 @@ Service, EventListener {
} }
} }
// Locking: this // Locking: synchLock
private void clearPrivateMessageNotification() { private void clearPrivateMessageNotification() {
Object o = appContext.getSystemService(NOTIFICATION_SERVICE); Object o = appContext.getSystemService(NOTIFICATION_SERVICE);
NotificationManager nm = (NotificationManager) o; NotificationManager nm = (NotificationManager) o;
@@ -180,22 +199,32 @@ Service, EventListener {
return defaults; return defaults;
} }
public synchronized void showGroupPostNotification(GroupId g) { public void showGroupPostNotification(GroupId g) {
Integer count = groupCounts.get(g); synchLock.lock();
if(count == null) groupCounts.put(g, 1); try {
else groupCounts.put(g, count + 1); Integer count = groupCounts.get(g);
groupTotal++; if(count == null) groupCounts.put(g, 1);
updateGroupPostNotification(); else groupCounts.put(g, count + 1);
groupTotal++;
updateGroupPostNotification();
} finally {
synchLock.unlock();
}
} }
public synchronized void clearGroupPostNotification(GroupId g) { public void clearGroupPostNotification(GroupId g) {
Integer count = groupCounts.remove(g); synchLock.lock();
if(count == null) return; // Already cleared try {
groupTotal -= count; Integer count = groupCounts.remove(g);
updateGroupPostNotification(); if(count == null) return; // Already cleared
groupTotal -= count;
updateGroupPostNotification();
} finally {
synchLock.unlock();
}
} }
// Locking: this // Locking: synchLock
private void updateGroupPostNotification() { private void updateGroupPostNotification() {
if(groupTotal == 0) { if(groupTotal == 0) {
clearGroupPostNotification(); clearGroupPostNotification();
@@ -219,6 +248,8 @@ Service, EventListener {
Intent i = new Intent(appContext, GroupActivity.class); Intent i = new Intent(appContext, GroupActivity.class);
GroupId g = groupCounts.keySet().iterator().next(); GroupId g = groupCounts.keySet().iterator().next();
i.putExtra("briar.GROUP_ID", g.getBytes()); i.putExtra("briar.GROUP_ID", g.getBytes());
String idHex = StringUtils.toHexString(g.getBytes());
i.setData(Uri.parse(GROUP_URI + "/" + idHex));
i.setFlags(FLAG_ACTIVITY_CLEAR_TOP | FLAG_ACTIVITY_SINGLE_TOP); i.setFlags(FLAG_ACTIVITY_CLEAR_TOP | FLAG_ACTIVITY_SINGLE_TOP);
TaskStackBuilder t = TaskStackBuilder.create(appContext); TaskStackBuilder t = TaskStackBuilder.create(appContext);
t.addParentStack(GroupActivity.class); t.addParentStack(GroupActivity.class);
@@ -238,18 +269,23 @@ Service, EventListener {
} }
} }
// Locking: this // Locking: synchLock
private void clearGroupPostNotification() { private void clearGroupPostNotification() {
Object o = appContext.getSystemService(NOTIFICATION_SERVICE); Object o = appContext.getSystemService(NOTIFICATION_SERVICE);
NotificationManager nm = (NotificationManager) o; NotificationManager nm = (NotificationManager) o;
nm.cancel(GROUP_POST_NOTIFICATION_ID); nm.cancel(GROUP_POST_NOTIFICATION_ID);
} }
public synchronized void clearNotifications() { public void clearNotifications() {
contactCounts.clear(); synchLock.lock();
groupCounts.clear(); try {
privateTotal = groupTotal = 0; contactCounts.clear();
clearPrivateMessageNotification(); groupCounts.clear();
clearGroupPostNotification(); privateTotal = groupTotal = 0;
clearPrivateMessageNotification();
clearGroupPostNotification();
} finally {
synchLock.unlock();
}
} }
} }

View File

@@ -4,6 +4,8 @@ import static java.util.logging.Level.INFO;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import org.briarproject.api.android.ReferenceManager; import org.briarproject.api.android.ReferenceManager;
@@ -13,49 +15,67 @@ class ReferenceManagerImpl implements ReferenceManager {
private static final Logger LOG = private static final Logger LOG =
Logger.getLogger(ReferenceManagerImpl.class.getName()); Logger.getLogger(ReferenceManagerImpl.class.getName());
// Locking: this private final Lock synchLock = new ReentrantLock();
// The following are locking: synchLock
private final Map<Class<?>, Map<Long, Object>> outerMap = private final Map<Class<?>, Map<Long, Object>> outerMap =
new HashMap<Class<?>, Map<Long, Object>>(); new HashMap<Class<?>, Map<Long, Object>>();
private long nextHandle = 0;
private long nextHandle = 0; // Locking: this public <T> T getReference(long handle, Class<T> c) {
synchLock.lock();
public synchronized <T> T getReference(long handle, Class<T> c) { try {
Map<Long, Object> innerMap = outerMap.get(c); Map<Long, Object> innerMap = outerMap.get(c);
if(innerMap == null) { if(innerMap == null) {
if(LOG.isLoggable(INFO))
LOG.info("0 handles for " + c.getName());
return null;
}
if(LOG.isLoggable(INFO)) if(LOG.isLoggable(INFO))
LOG.info("0 handles for " + c.getName()); LOG.info(innerMap.size() + " handles for " + c.getName());
return null; Object o = innerMap.get(handle);
return c.cast(o);
} finally {
synchLock.unlock();
} }
if(LOG.isLoggable(INFO))
LOG.info(innerMap.size() + " handles for " + c.getName());
Object o = innerMap.get(handle);
return c.cast(o);
} }
public synchronized <T> long putReference(T reference, Class<T> c) { public <T> long putReference(T reference, Class<T> c) {
Map<Long, Object> innerMap = outerMap.get(c); synchLock.lock();
if(innerMap == null) { try {
innerMap = new HashMap<Long, Object>(); Map<Long, Object> innerMap = outerMap.get(c);
outerMap.put(c, innerMap); if(innerMap == null) {
innerMap = new HashMap<Long, Object>();
outerMap.put(c, innerMap);
}
long handle = nextHandle++;
innerMap.put(handle, reference);
if(LOG.isLoggable(INFO)) {
LOG.info(innerMap.size() + " handles for " + c.getName() +
" after put");
}
return handle;
} finally {
synchLock.unlock();
} }
long handle = nextHandle++;
innerMap.put(handle, reference);
if(LOG.isLoggable(INFO)) {
LOG.info(innerMap.size() + " handles for " + c.getName() +
" after put");
}
return handle;
} }
public synchronized <T> T removeReference(long handle, Class<T> c) { public <T> T removeReference(long handle, Class<T> c) {
Map<Long, Object> innerMap = outerMap.get(c); synchLock.lock();
if(innerMap == null) return null; try {
Object o = innerMap.remove(handle); Map<Long, Object> innerMap = outerMap.get(c);
if(innerMap.isEmpty()) outerMap.remove(c); if(innerMap == null) return null;
if(LOG.isLoggable(INFO)) { Object o = innerMap.remove(handle);
LOG.info(innerMap.size() + " handles for " + c.getName() + if(innerMap.isEmpty()) outerMap.remove(c);
" after remove"); if(LOG.isLoggable(INFO)) {
LOG.info(innerMap.size() + " handles for " + c.getName() +
" after remove");
}
return c.cast(o);
} finally {
synchLock.unlock();
} }
return c.cast(o);
} }
} }

View File

@@ -34,9 +34,6 @@ class AndroidLocationUtils implements LocationUtils {
* <ul> * <ul>
* <li>Phone network. This works even when no SIM card is inserted, or a * <li>Phone network. This works even when no SIM card is inserted, or a
* foreign SIM card is inserted.</li> * foreign SIM card is inserted.</li>
* <li><del>Location service (GPS/WiFi/etc).</del> <em>This is disabled for
* now, until we figure out an offline method of converting a long/lat
* into a country code, that doesn't involve a network call.</em>
* <li>SIM card. This is only an heuristic and assumes the user is not * <li>SIM card. This is only an heuristic and assumes the user is not
* roaming.</li> * roaming.</li>
* <li>User locale. This is an even worse heuristic.</li> * <li>User locale. This is an even worse heuristic.</li>

View File

@@ -1,5 +1,8 @@
package org.briarproject.crypto; package org.briarproject.crypto;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.briarproject.api.crypto.MessageDigest; import org.briarproject.api.crypto.MessageDigest;
import org.spongycastle.crypto.BlockCipher; import org.spongycastle.crypto.BlockCipher;
import org.spongycastle.crypto.digests.SHA256Digest; import org.spongycastle.crypto.digests.SHA256Digest;
@@ -16,7 +19,9 @@ class FortunaGenerator {
private static final int KEY_BYTES = 32; private static final int KEY_BYTES = 32;
private static final int BLOCK_BYTES = 16; private static final int BLOCK_BYTES = 16;
// All of the following are locking: this private final Lock synchLock = new ReentrantLock();
// The following are locking: synchLock
private final MessageDigest digest = new DoubleDigest(new SHA256Digest()); private final MessageDigest digest = new DoubleDigest(new SHA256Digest());
private final BlockCipher cipher = new AESLightEngine(); private final BlockCipher cipher = new AESLightEngine();
private final byte[] key = new byte[KEY_BYTES]; private final byte[] key = new byte[KEY_BYTES];
@@ -28,56 +33,78 @@ class FortunaGenerator {
reseed(seed); reseed(seed);
} }
synchronized void reseed(byte[] seed) { void reseed(byte[] seed) {
digest.update(key); synchLock.lock();
digest.update(seed); try {
digest.digest(key, 0, KEY_BYTES); digest.update(key);
incrementCounter(); digest.update(seed);
digest.digest(key, 0, KEY_BYTES);
incrementCounter();
} finally {
synchLock.unlock();
}
} }
// Package access for testing // Package access for testing
synchronized void incrementCounter() { void incrementCounter() {
counter[0]++; synchLock.lock();
for(int i = 0; counter[i] == 0; i++) { try {
if(i + 1 == BLOCK_BYTES) counter[0]++;
throw new RuntimeException("Counter exhausted"); for(int i = 0; counter[i] == 0; i++) {
counter[i + 1]++; if(i + 1 == BLOCK_BYTES)
throw new RuntimeException("Counter exhausted");
counter[i + 1]++;
}
} finally {
synchLock.unlock();
} }
} }
// Package access for testing // Package access for testing
synchronized byte[] getCounter() { byte[] getCounter() {
return counter; synchLock.lock();
try {
return counter;
} finally {
synchLock.unlock();
}
} }
synchronized int nextBytes(byte[] dest, int off, int len) { int nextBytes(byte[] dest, int off, int len) {
// Don't write more than the maximum number of bytes in one request synchLock.lock();
if(len > MAX_BYTES_PER_REQUEST) len = MAX_BYTES_PER_REQUEST; try {
cipher.init(true, new KeyParameter(key)); // Don't write more than the maximum number of bytes in one request
// Generate full blocks directly into the output buffer if(len > MAX_BYTES_PER_REQUEST) len = MAX_BYTES_PER_REQUEST;
int fullBlocks = len / BLOCK_BYTES; cipher.init(true, new KeyParameter(key));
for(int i = 0; i < fullBlocks; i++) { // Generate full blocks directly into the output buffer
cipher.processBlock(counter, 0, dest, off + i * BLOCK_BYTES); int fullBlocks = len / BLOCK_BYTES;
incrementCounter(); for(int i = 0; i < fullBlocks; i++) {
cipher.processBlock(counter, 0, dest, off + i * BLOCK_BYTES);
incrementCounter();
}
// Generate a partial block if needed
int done = fullBlocks * BLOCK_BYTES, remaining = len - done;
assert remaining < BLOCK_BYTES;
if(remaining > 0) {
cipher.processBlock(counter, 0, buffer, 0);
incrementCounter();
// Copy the partial block to the output buffer and erase our copy
System.arraycopy(buffer, 0, dest, off + done, remaining);
for(int i = 0; i < BLOCK_BYTES; i++) buffer[i] = 0;
}
// Generate a new key
for(int i = 0; i < KEY_BYTES / BLOCK_BYTES; i++) {
cipher.processBlock(counter, 0, newKey, i * BLOCK_BYTES);
incrementCounter();
}
System.arraycopy(newKey, 0, key, 0, KEY_BYTES);
for(int i = 0; i < KEY_BYTES; i++) newKey[i] = 0;
// Return the number of bytes written
return len;
} finally {
synchLock.unlock();
} }
// Generate a partial block if needed
int done = fullBlocks * BLOCK_BYTES, remaining = len - done;
assert remaining < BLOCK_BYTES;
if(remaining > 0) {
cipher.processBlock(counter, 0, buffer, 0);
incrementCounter();
// Copy the partial block to the output buffer and erase our copy
System.arraycopy(buffer, 0, dest, off + done, remaining);
for(int i = 0; i < BLOCK_BYTES; i++) buffer[i] = 0;
}
// Generate a new key
for(int i = 0; i < KEY_BYTES / BLOCK_BYTES; i++) {
cipher.processBlock(counter, 0, newKey, i * BLOCK_BYTES);
incrementCounter();
}
System.arraycopy(newKey, 0, key, 0, KEY_BYTES);
for(int i = 0; i < KEY_BYTES; i++) newKey[i] = 0;
// Return the number of bytes written
return len;
} }
} }

View File

@@ -14,7 +14,7 @@ class PseudoRandomImpl implements PseudoRandom {
generator = new FortunaGenerator(seed); generator = new FortunaGenerator(seed);
} }
public synchronized byte[] nextBytes(int length) { public byte[] nextBytes(int length) {
byte[] b = new byte[length]; byte[] b = new byte[length];
int offset = 0; int offset = 0;
while(offset < length) offset += generator.nextBytes(b, offset, length); while(offset < length) offset += generator.nextBytes(b, offset, length);

View File

@@ -27,6 +27,9 @@ import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import org.briarproject.api.Author; import org.briarproject.api.Author;
@@ -313,16 +316,19 @@ abstract class JdbcDatabase implements Database<Connection> {
private final Clock clock; private final Clock clock;
private final LinkedList<Connection> connections = private final LinkedList<Connection> connections =
new LinkedList<Connection>(); // Locking: self new LinkedList<Connection>(); // Locking: connectionsLock
private final AtomicInteger transactionCount = new AtomicInteger(0); private final AtomicInteger transactionCount = new AtomicInteger(0);
private int openConnections = 0; // Locking: connections private int openConnections = 0; // Locking: connectionsLock
private boolean closed = false; // Locking: connections private boolean closed = false; // Locking: connectionsLock
protected abstract Connection createConnection() throws SQLException; protected abstract Connection createConnection() throws SQLException;
protected abstract void flushBuffersToDisk(Statement s) throws SQLException; protected abstract void flushBuffersToDisk(Statement s) throws SQLException;
private final Lock connectionsLock = new ReentrantLock();
private final Condition connectionsChanged = connectionsLock.newCondition();
JdbcDatabase(String hashType, String binaryType, String counterType, JdbcDatabase(String hashType, String binaryType, String counterType,
String secretType, Clock clock) { String secretType, Clock clock) {
this.hashType = hashType; this.hashType = hashType;
@@ -431,9 +437,12 @@ abstract class JdbcDatabase implements Database<Connection> {
public Connection startTransaction() throws DbException { public Connection startTransaction() throws DbException {
Connection txn = null; Connection txn = null;
synchronized(connections) { connectionsLock.lock();
try {
if(closed) throw new DbClosedException(); if(closed) throw new DbClosedException();
txn = connections.poll(); txn = connections.poll();
} finally {
connectionsLock.unlock();
} }
try { try {
if(txn == null) { if(txn == null) {
@@ -441,8 +450,11 @@ abstract class JdbcDatabase implements Database<Connection> {
txn = createConnection(); txn = createConnection();
if(txn == null) throw new DbException(); if(txn == null) throw new DbException();
txn.setAutoCommit(false); txn.setAutoCommit(false);
synchronized(connections) { connectionsLock.lock();
try {
openConnections++; openConnections++;
} finally {
connectionsLock.unlock();
} }
} }
} catch(SQLException e) { } catch(SQLException e) {
@@ -455,9 +467,12 @@ abstract class JdbcDatabase implements Database<Connection> {
public void abortTransaction(Connection txn) { public void abortTransaction(Connection txn) {
try { try {
txn.rollback(); txn.rollback();
synchronized(connections) { connectionsLock.lock();
try {
connections.add(txn); connections.add(txn);
connections.notifyAll(); connectionsChanged.signalAll();
} finally {
connectionsLock.unlock();
} }
} catch(SQLException e) { } catch(SQLException e) {
// Try to close the connection // Try to close the connection
@@ -468,9 +483,12 @@ abstract class JdbcDatabase implements Database<Connection> {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e1.toString(), e1); if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e1.toString(), e1);
} }
// Whatever happens, allow the database to close // Whatever happens, allow the database to close
synchronized(connections) { connectionsLock.lock();
try {
openConnections--; openConnections--;
connections.notifyAll(); connectionsChanged.signalAll();
} finally {
connectionsLock.unlock();
} }
} }
} }
@@ -486,9 +504,12 @@ abstract class JdbcDatabase implements Database<Connection> {
tryToClose(s); tryToClose(s);
throw new DbException(e); throw new DbException(e);
} }
synchronized(connections) { connectionsLock.lock();
try {
connections.add(txn); connections.add(txn);
connections.notifyAll(); connectionsChanged.signalAll();
} finally {
connectionsLock.unlock();
} }
} }
@@ -502,14 +523,15 @@ abstract class JdbcDatabase implements Database<Connection> {
protected void closeAllConnections() throws SQLException { protected void closeAllConnections() throws SQLException {
boolean interrupted = false; boolean interrupted = false;
synchronized(connections) { connectionsLock.lock();
try {
closed = true; closed = true;
for(Connection c : connections) c.close(); for(Connection c : connections) c.close();
openConnections -= connections.size(); openConnections -= connections.size();
connections.clear(); connections.clear();
while(openConnections > 0) { while(openConnections > 0) {
try { try {
connections.wait(); connectionsChanged.await();
} catch(InterruptedException e) { } catch(InterruptedException e) {
LOG.warning("Interrupted while closing connections"); LOG.warning("Interrupted while closing connections");
interrupted = true; interrupted = true;
@@ -518,7 +540,10 @@ abstract class JdbcDatabase implements Database<Connection> {
openConnections -= connections.size(); openConnections -= connections.size();
connections.clear(); connections.clear();
} }
} finally {
connectionsLock.unlock();
} }
if(interrupted) Thread.currentThread().interrupt(); if(interrupted) Thread.currentThread().interrupt();
} }

View File

@@ -10,6 +10,8 @@ import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import org.briarproject.api.Author; import org.briarproject.api.Author;
@@ -60,13 +62,9 @@ class ConnectorGroup extends Thread implements InvitationTask {
private final Collection<InvitationListener> listeners; private final Collection<InvitationListener> listeners;
private final AtomicBoolean connected; private final AtomicBoolean connected;
private final CountDownLatch localConfirmationLatch; private final CountDownLatch localConfirmationLatch;
private final Lock synchLock = new ReentrantLock();
/* // The following are locking: synchLock
* All of the following require locking: this. We don't want to call the
* listeners with a lock held, but we need to avoid a race condition in
* addListener(), so the state that's accessed in addListener() after
* calling listeners.add() must be guarded by a lock.
*/
private int localConfirmationCode = -1, remoteConfirmationCode = -1; private int localConfirmationCode = -1, remoteConfirmationCode = -1;
private boolean connectionFailed = false; private boolean connectionFailed = false;
private boolean localCompared = false, remoteCompared = false; private boolean localCompared = false, remoteCompared = false;
@@ -104,12 +102,18 @@ class ConnectorGroup extends Thread implements InvitationTask {
localConfirmationLatch = new CountDownLatch(1); localConfirmationLatch = new CountDownLatch(1);
} }
public synchronized InvitationState addListener(InvitationListener l) { public InvitationState addListener(InvitationListener l) {
listeners.add(l); synchLock.lock();
return new InvitationState(localInvitationCode, remoteInvitationCode, try {
localConfirmationCode, remoteConfirmationCode, connected.get(), listeners.add(l);
connectionFailed, localCompared, remoteCompared, localMatched, return new InvitationState(localInvitationCode,
remoteMatched, remoteName); remoteInvitationCode, localConfirmationCode,
remoteConfirmationCode, connected.get(), connectionFailed,
localCompared, remoteCompared, localMatched, remoteMatched,
remoteName);
} finally {
synchLock.unlock();
}
} }
public void removeListener(InvitationListener l) { public void removeListener(InvitationListener l) {
@@ -130,8 +134,11 @@ class ConnectorGroup extends Thread implements InvitationTask {
localProps = db.getLocalProperties(); localProps = db.getLocalProperties();
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
synchronized(this) { synchLock.lock();
try {
connectionFailed = true; connectionFailed = true;
} finally {
synchLock.unlock();
} }
for(InvitationListener l : listeners) l.connectionFailed(); for(InvitationListener l : listeners) l.connectionFailed();
return; return;
@@ -163,8 +170,11 @@ class ConnectorGroup extends Thread implements InvitationTask {
} }
// If none of the threads connected, inform the listeners // If none of the threads connected, inform the listeners
if(!connected.get()) { if(!connected.get()) {
synchronized(this) { synchLock.lock();
try {
connectionFailed = true; connectionFailed = true;
} finally {
synchLock.unlock();
} }
for(InvitationListener l : listeners) l.connectionFailed(); for(InvitationListener l : listeners) l.connectionFailed();
} }
@@ -193,17 +203,23 @@ class ConnectorGroup extends Thread implements InvitationTask {
} }
public void localConfirmationSucceeded() { public void localConfirmationSucceeded() {
synchronized(this) { synchLock.lock();
try {
localCompared = true; localCompared = true;
localMatched = true; localMatched = true;
} finally {
synchLock.unlock();
} }
localConfirmationLatch.countDown(); localConfirmationLatch.countDown();
} }
public void localConfirmationFailed() { public void localConfirmationFailed() {
synchronized(this) { synchLock.lock();
try {
localCompared = true; localCompared = true;
localMatched = false; localMatched = false;
} finally {
synchLock.unlock();
} }
localConfirmationLatch.countDown(); localConfirmationLatch.countDown();
} }
@@ -216,9 +232,12 @@ class ConnectorGroup extends Thread implements InvitationTask {
} }
void keyAgreementSucceeded(int localCode, int remoteCode) { void keyAgreementSucceeded(int localCode, int remoteCode) {
synchronized(this) { synchLock.lock();
try {
localConfirmationCode = localCode; localConfirmationCode = localCode;
remoteConfirmationCode = remoteCode; remoteConfirmationCode = remoteCode;
} finally {
synchLock.unlock();
} }
for(InvitationListener l : listeners) for(InvitationListener l : listeners)
l.keyAgreementSucceeded(localCode, remoteCode); l.keyAgreementSucceeded(localCode, remoteCode);
@@ -230,31 +249,43 @@ class ConnectorGroup extends Thread implements InvitationTask {
boolean waitForLocalConfirmationResult() throws InterruptedException { boolean waitForLocalConfirmationResult() throws InterruptedException {
localConfirmationLatch.await(CONFIRMATION_TIMEOUT, MILLISECONDS); localConfirmationLatch.await(CONFIRMATION_TIMEOUT, MILLISECONDS);
synchronized(this) { synchLock.lock();
try {
return localMatched; return localMatched;
} finally {
synchLock.unlock();
} }
} }
void remoteConfirmationSucceeded() { void remoteConfirmationSucceeded() {
synchronized(this) { synchLock.lock();
try {
remoteCompared = true; remoteCompared = true;
remoteMatched = true; remoteMatched = true;
} finally {
synchLock.unlock();
} }
for(InvitationListener l : listeners) l.remoteConfirmationSucceeded(); for(InvitationListener l : listeners) l.remoteConfirmationSucceeded();
} }
void remoteConfirmationFailed() { void remoteConfirmationFailed() {
synchronized(this) { synchLock.lock();
try {
remoteCompared = true; remoteCompared = true;
remoteMatched = false; remoteMatched = false;
} finally {
synchLock.unlock();
} }
for(InvitationListener l : listeners) l.remoteConfirmationFailed(); for(InvitationListener l : listeners) l.remoteConfirmationFailed();
} }
void pseudonymExchangeSucceeded(Author remoteAuthor) { void pseudonymExchangeSucceeded(Author remoteAuthor) {
String name = remoteAuthor.getName(); String name = remoteAuthor.getName();
synchronized(this) { synchLock.lock();
try {
remoteName = name; remoteName = name;
} finally {
synchLock.unlock();
} }
for(InvitationListener l : listeners) for(InvitationListener l : listeners)
l.pseudonymExchangeSucceeded(name); l.pseudonymExchangeSucceeded(name);

View File

@@ -2,34 +2,50 @@ package org.briarproject.lifecycle;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.briarproject.api.lifecycle.ShutdownManager; import org.briarproject.api.lifecycle.ShutdownManager;
class ShutdownManagerImpl implements ShutdownManager { class ShutdownManagerImpl implements ShutdownManager {
protected final Map<Integer, Thread> hooks; // Locking: this private final Lock synchLock = new ReentrantLock();
private int nextHandle = 0; // Locking: this // The following are locking: synchLock
protected final Map<Integer, Thread> hooks;
private int nextHandle = 0;
ShutdownManagerImpl() { ShutdownManagerImpl() {
hooks = new HashMap<Integer, Thread>(); hooks = new HashMap<Integer, Thread>();
} }
public synchronized int addShutdownHook(Runnable r) { public int addShutdownHook(Runnable r) {
int handle = nextHandle++; synchLock.lock();
Thread hook = createThread(r); try {
hooks.put(handle, hook); int handle = nextHandle++;
Runtime.getRuntime().addShutdownHook(hook); Thread hook = createThread(r);
return handle; hooks.put(handle, hook);
Runtime.getRuntime().addShutdownHook(hook);
return handle;
} finally {
synchLock.unlock();
}
} }
protected Thread createThread(Runnable r) { protected Thread createThread(Runnable r) {
return new Thread(r, "ShutdownManager"); return new Thread(r, "ShutdownManager");
} }
public synchronized boolean removeShutdownHook(int handle) { public boolean removeShutdownHook(int handle) {
Thread hook = hooks.remove(handle); synchLock.lock();
if(hook == null) return false; try {
else return Runtime.getRuntime().removeShutdownHook(hook); Thread hook = hooks.remove(handle);
if(hook == null) return false;
else return Runtime.getRuntime().removeShutdownHook(hook);
} finally {
synchLock.unlock();
}
} }
} }

View File

@@ -9,9 +9,9 @@ import org.briarproject.api.messaging.Group;
import org.briarproject.api.messaging.GroupFactory; import org.briarproject.api.messaging.GroupFactory;
import org.briarproject.api.messaging.MessageFactory; import org.briarproject.api.messaging.MessageFactory;
import org.briarproject.api.messaging.MessageVerifier; import org.briarproject.api.messaging.MessageVerifier;
import org.briarproject.api.messaging.MessagingSessionFactory;
import org.briarproject.api.messaging.PacketReaderFactory; import org.briarproject.api.messaging.PacketReaderFactory;
import org.briarproject.api.messaging.PacketWriterFactory; import org.briarproject.api.messaging.PacketWriterFactory;
import org.briarproject.api.messaging.MessagingSessionFactory;
import org.briarproject.api.messaging.SubscriptionUpdate; import org.briarproject.api.messaging.SubscriptionUpdate;
import org.briarproject.api.messaging.UnverifiedMessage; import org.briarproject.api.messaging.UnverifiedMessage;
import org.briarproject.api.serial.StructReader; import org.briarproject.api.serial.StructReader;

View File

@@ -8,6 +8,8 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import org.briarproject.api.ContactId; import org.briarproject.api.ContactId;
@@ -25,9 +27,10 @@ class ConnectionRegistryImpl implements ConnectionRegistry {
Logger.getLogger(ConnectionRegistryImpl.class.getName()); Logger.getLogger(ConnectionRegistryImpl.class.getName());
private final EventBus eventBus; private final EventBus eventBus;
// Locking: this private final Lock synchLock = new ReentrantLock();
// The following are locking: synchLock
private final Map<TransportId, Map<ContactId, Integer>> connections; private final Map<TransportId, Map<ContactId, Integer>> connections;
// Locking: this
private final Map<ContactId, Integer> contactCounts; private final Map<ContactId, Integer> contactCounts;
@Inject @Inject
@@ -40,7 +43,8 @@ class ConnectionRegistryImpl implements ConnectionRegistry {
public void registerConnection(ContactId c, TransportId t) { public void registerConnection(ContactId c, TransportId t) {
LOG.info("Connection registered"); LOG.info("Connection registered");
boolean firstConnection = false; boolean firstConnection = false;
synchronized(this) { synchLock.lock();
try {
Map<ContactId, Integer> m = connections.get(t); Map<ContactId, Integer> m = connections.get(t);
if(m == null) { if(m == null) {
m = new HashMap<ContactId, Integer>(); m = new HashMap<ContactId, Integer>();
@@ -56,7 +60,10 @@ class ConnectionRegistryImpl implements ConnectionRegistry {
} else { } else {
contactCounts.put(c, count + 1); contactCounts.put(c, count + 1);
} }
} finally {
synchLock.unlock();
} }
if(firstConnection) { if(firstConnection) {
LOG.info("Contact connected"); LOG.info("Contact connected");
eventBus.broadcast(new ContactConnectedEvent(c)); eventBus.broadcast(new ContactConnectedEvent(c));
@@ -66,7 +73,8 @@ class ConnectionRegistryImpl implements ConnectionRegistry {
public void unregisterConnection(ContactId c, TransportId t) { public void unregisterConnection(ContactId c, TransportId t) {
LOG.info("Connection unregistered"); LOG.info("Connection unregistered");
boolean lastConnection = false; boolean lastConnection = false;
synchronized(this) { synchLock.lock();
try {
Map<ContactId, Integer> m = connections.get(t); Map<ContactId, Integer> m = connections.get(t);
if(m == null) throw new IllegalArgumentException(); if(m == null) throw new IllegalArgumentException();
Integer count = m.remove(c); Integer count = m.remove(c);
@@ -84,23 +92,38 @@ class ConnectionRegistryImpl implements ConnectionRegistry {
} else { } else {
contactCounts.put(c, count - 1); contactCounts.put(c, count - 1);
} }
} finally {
synchLock.unlock();
} }
if(lastConnection) { if(lastConnection) {
LOG.info("Contact disconnected"); LOG.info("Contact disconnected");
eventBus.broadcast(new ContactDisconnectedEvent(c)); eventBus.broadcast(new ContactDisconnectedEvent(c));
} }
} }
public synchronized Collection<ContactId> getConnectedContacts( public Collection<ContactId> getConnectedContacts(
TransportId t) { TransportId t) {
Map<ContactId, Integer> m = connections.get(t); synchLock.lock();
if(m == null) return Collections.emptyList(); try {
List<ContactId> ids = new ArrayList<ContactId>(m.keySet()); Map<ContactId, Integer> m = connections.get(t);
if(LOG.isLoggable(INFO)) LOG.info(ids.size() + " contacts connected"); if(m == null) return Collections.emptyList();
return Collections.unmodifiableList(ids); List<ContactId> ids = new ArrayList<ContactId>(m.keySet());
if(LOG.isLoggable(INFO)) LOG.info(ids.size() + " contacts connected");
return Collections.unmodifiableList(ids);
} finally {
synchLock.unlock();
}
} }
public synchronized boolean isConnected(ContactId c) { public boolean isConnected(ContactId c) {
return contactCounts.containsKey(c); synchLock.lock();
try {
return contactCounts.containsKey(c);
} finally {
synchLock.unlock();
}
} }
} }

View File

@@ -1,10 +1,15 @@
package org.briarproject.reliability; package org.briarproject.reliability;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import java.io.IOException; import java.io.IOException;
import java.util.Comparator; import java.util.Comparator;
import java.util.Iterator; import java.util.Iterator;
import java.util.SortedSet; import java.util.SortedSet;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.briarproject.api.reliability.ReadHandler; import org.briarproject.api.reliability.ReadHandler;
import org.briarproject.api.system.Clock; import org.briarproject.api.system.Clock;
@@ -16,9 +21,13 @@ class Receiver implements ReadHandler {
private final Clock clock; private final Clock clock;
private final Sender sender; private final Sender sender;
private final SortedSet<Data> dataFrames; // Locking: this private final Lock windowLock = new ReentrantLock();
private final Condition dataFrameAvailable = windowLock.newCondition();
// The following are locking: windowLock
private final SortedSet<Data> dataFrames;
private int windowSize = MAX_WINDOW_SIZE;
private int windowSize = MAX_WINDOW_SIZE; // Locking: this
private long finalSequenceNumber = Long.MAX_VALUE; private long finalSequenceNumber = Long.MAX_VALUE;
private long nextSequenceNumber = 1; private long nextSequenceNumber = 1;
@@ -30,36 +39,44 @@ class Receiver implements ReadHandler {
dataFrames = new TreeSet<Data>(new SequenceNumberComparator()); dataFrames = new TreeSet<Data>(new SequenceNumberComparator());
} }
synchronized Data read() throws IOException, InterruptedException { Data read() throws IOException, InterruptedException {
long now = clock.currentTimeMillis(), end = now + READ_TIMEOUT; windowLock.lock();
while(now < end && valid) { try {
if(dataFrames.isEmpty()) { long now = clock.currentTimeMillis(), end = now + READ_TIMEOUT;
// Wait for a data frame while(now < end && valid) {
wait(end - now); if(dataFrames.isEmpty()) {
} else { // Wait for a data frame
Data d = dataFrames.first(); dataFrameAvailable.await(end - now, MILLISECONDS);
if(d.getSequenceNumber() == nextSequenceNumber) {
dataFrames.remove(d);
// Update the window
windowSize += d.getPayloadLength();
sender.sendAck(0, windowSize);
nextSequenceNumber++;
return d;
} else { } else {
// Wait for the next in-order data frame Data d = dataFrames.first();
wait(end - now); if(d.getSequenceNumber() == nextSequenceNumber) {
dataFrames.remove(d);
// Update the window
windowSize += d.getPayloadLength();
sender.sendAck(0, windowSize);
nextSequenceNumber++;
return d;
} else {
// Wait for the next in-order data frame
dataFrameAvailable.await(end - now, MILLISECONDS);
}
} }
now = clock.currentTimeMillis();
} }
now = clock.currentTimeMillis(); if(valid) throw new IOException("Read timed out");
throw new IOException("Connection closed");
} finally {
windowLock.unlock();
} }
if(valid) throw new IOException("Read timed out");
throw new IOException("Connection closed");
} }
void invalidate() { void invalidate() {
valid = false; valid = false;
synchronized(this) { windowLock.lock();
notifyAll(); try {
dataFrameAvailable.signalAll();
} finally {
windowLock.unlock();
} }
} }
@@ -79,43 +96,48 @@ class Receiver implements ReadHandler {
} }
} }
private synchronized void handleData(byte[] b) throws IOException { private void handleData(byte[] b) throws IOException {
if(b.length < Data.MIN_LENGTH || b.length > Data.MAX_LENGTH) { windowLock.lock();
// Ignore data frame with invalid length try {
return; if(b.length < Data.MIN_LENGTH || b.length > Data.MAX_LENGTH) {
} // Ignore data frame with invalid length
Data d = new Data(b); return;
int payloadLength = d.getPayloadLength();
if(payloadLength > windowSize) return; // No space in the window
if(d.getChecksum() != d.calculateChecksum()) {
// Ignore data frame with invalid checksum
return;
}
long sequenceNumber = d.getSequenceNumber();
if(sequenceNumber == 0) {
// Window probe
} else if(sequenceNumber < nextSequenceNumber) {
// Duplicate data frame
} else if(d.isLastFrame()) {
finalSequenceNumber = sequenceNumber;
// Remove any data frames with higher sequence numbers
Iterator<Data> it = dataFrames.iterator();
while(it.hasNext()) {
Data d1 = it.next();
if(d1.getSequenceNumber() >= finalSequenceNumber) it.remove();
} }
if(dataFrames.add(d)) { Data d = new Data(b);
windowSize -= payloadLength; int payloadLength = d.getPayloadLength();
notifyAll(); if(payloadLength > windowSize) return; // No space in the window
if(d.getChecksum() != d.calculateChecksum()) {
// Ignore data frame with invalid checksum
return;
} }
} else if(sequenceNumber < finalSequenceNumber) { long sequenceNumber = d.getSequenceNumber();
if(dataFrames.add(d)) { if(sequenceNumber == 0) {
windowSize -= payloadLength; // Window probe
notifyAll(); } else if(sequenceNumber < nextSequenceNumber) {
// Duplicate data frame
} else if(d.isLastFrame()) {
finalSequenceNumber = sequenceNumber;
// Remove any data frames with higher sequence numbers
Iterator<Data> it = dataFrames.iterator();
while(it.hasNext()) {
Data d1 = it.next();
if(d1.getSequenceNumber() >= finalSequenceNumber) it.remove();
}
if(dataFrames.add(d)) {
windowSize -= payloadLength;
dataFrameAvailable.signalAll();
}
} else if(sequenceNumber < finalSequenceNumber) {
if(dataFrames.add(d)) {
windowSize -= payloadLength;
dataFrameAvailable.signalAll();
}
} }
// Acknowledge the data frame even if it's a duplicate
sender.sendAck(sequenceNumber, windowSize);
} finally {
windowLock.unlock();
} }
// Acknowledge the data frame even if it's a duplicate
sender.sendAck(sequenceNumber, windowSize);
} }
private static class SequenceNumberComparator implements Comparator<Data> { private static class SequenceNumberComparator implements Comparator<Data> {

View File

@@ -1,10 +1,15 @@
package org.briarproject.reliability; package org.briarproject.reliability;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.briarproject.api.reliability.WriteHandler; import org.briarproject.api.reliability.WriteHandler;
import org.briarproject.api.system.Clock; import org.briarproject.api.system.Clock;
@@ -21,9 +26,11 @@ class Sender {
private final Clock clock; private final Clock clock;
private final WriteHandler writeHandler; private final WriteHandler writeHandler;
private final LinkedList<Outstanding> outstanding; // Locking: this private final Lock windowLock = new ReentrantLock();
private final Condition sendWindowAvailable = windowLock.newCondition();
// All of the following are locking: this // The following are locking: windowLock
private final LinkedList<Outstanding> outstanding;
private int outstandingBytes = 0; private int outstandingBytes = 0;
private int windowSize = Data.MAX_PAYLOAD_LENGTH; private int windowSize = Data.MAX_PAYLOAD_LENGTH;
private int rtt = INITIAL_RTT, rttVar = INITIAL_RTT_VAR; private int rtt = INITIAL_RTT, rttVar = INITIAL_RTT_VAR;
@@ -58,7 +65,8 @@ class Sender {
long sequenceNumber = a.getSequenceNumber(); long sequenceNumber = a.getSequenceNumber();
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
Outstanding fastRetransmit = null; Outstanding fastRetransmit = null;
synchronized(this) { windowLock.lock();
try {
// Remove the acked data frame if it's outstanding // Remove the acked data frame if it's outstanding
int foundIndex = -1; int foundIndex = -1;
Iterator<Outstanding> it = outstanding.iterator(); Iterator<Outstanding> it = outstanding.iterator();
@@ -94,7 +102,10 @@ class Sender {
// Don't accept an unreasonably large window size // Don't accept an unreasonably large window size
windowSize = Math.min(a.getWindowSize(), MAX_WINDOW_SIZE); windowSize = Math.min(a.getWindowSize(), MAX_WINDOW_SIZE);
// If space has become available, notify any waiting writers // If space has become available, notify any waiting writers
if(windowSize > oldWindowSize || foundIndex != -1) notifyAll(); if(windowSize > oldWindowSize || foundIndex != -1)
sendWindowAvailable.signalAll();
} finally {
windowLock.unlock();
} }
// Fast retransmission // Fast retransmission
if(fastRetransmit != null) if(fastRetransmit != null)
@@ -105,7 +116,8 @@ class Sender {
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
List<Outstanding> retransmit = null; List<Outstanding> retransmit = null;
boolean sendProbe = false; boolean sendProbe = false;
synchronized(this) { windowLock.lock();
try {
if(outstanding.isEmpty()) { if(outstanding.isEmpty()) {
if(dataWaiting && now - lastWindowUpdateOrProbe > rto) { if(dataWaiting && now - lastWindowUpdateOrProbe > rto) {
sendProbe = true; sendProbe = true;
@@ -134,6 +146,8 @@ class Sender {
} }
} }
} }
} finally {
windowLock.unlock();
} }
// Send a window probe if necessary // Send a window probe if necessary
if(sendProbe) { if(sendProbe) {
@@ -151,12 +165,13 @@ class Sender {
void write(Data d) throws IOException, InterruptedException { void write(Data d) throws IOException, InterruptedException {
int payloadLength = d.getPayloadLength(); int payloadLength = d.getPayloadLength();
synchronized(this) { windowLock.lock();
try {
// Wait for space in the window // Wait for space in the window
long now = clock.currentTimeMillis(), end = now + WRITE_TIMEOUT; long now = clock.currentTimeMillis(), end = now + WRITE_TIMEOUT;
while(now < end && outstandingBytes + payloadLength >= windowSize) { while(now < end && outstandingBytes + payloadLength >= windowSize) {
dataWaiting = true; dataWaiting = true;
wait(end - now); sendWindowAvailable.await(end - now, MILLISECONDS);
now = clock.currentTimeMillis(); now = clock.currentTimeMillis();
} }
if(outstandingBytes + payloadLength >= windowSize) if(outstandingBytes + payloadLength >= windowSize)
@@ -164,12 +179,20 @@ class Sender {
outstanding.add(new Outstanding(d, now)); outstanding.add(new Outstanding(d, now));
outstandingBytes += payloadLength; outstandingBytes += payloadLength;
dataWaiting = false; dataWaiting = false;
} finally {
windowLock.unlock();
} }
writeHandler.handleWrite(d.getBuffer()); writeHandler.handleWrite(d.getBuffer());
} }
synchronized void flush() throws IOException, InterruptedException { void flush() throws IOException, InterruptedException {
while(dataWaiting || !outstanding.isEmpty()) wait(); windowLock.lock();
try {
while(dataWaiting || !outstanding.isEmpty())
sendWindowAvailable.await();
} finally {
windowLock.unlock();
}
} }
private static class Outstanding { private static class Outstanding {

View File

@@ -11,6 +11,8 @@ import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.TimerTask; import java.util.TimerTask;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.inject.Inject; import javax.inject.Inject;
@@ -48,8 +50,9 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
private final TagRecogniser tagRecogniser; private final TagRecogniser tagRecogniser;
private final Clock clock; private final Clock clock;
private final Timer timer; private final Timer timer;
private final Lock synchLock = new ReentrantLock();
// All of the following are locking: this // The following are locking: synchLock
private final Map<TransportId, Integer> maxLatencies; private final Map<TransportId, Integer> maxLatencies;
private final Map<EndpointKey, TemporarySecret> oldSecrets; private final Map<EndpointKey, TemporarySecret> oldSecrets;
private final Map<EndpointKey, TemporarySecret> currentSecrets; private final Map<EndpointKey, TemporarySecret> currentSecrets;
@@ -71,45 +74,54 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
newSecrets = new HashMap<EndpointKey, TemporarySecret>(); newSecrets = new HashMap<EndpointKey, TemporarySecret>();
} }
public synchronized boolean start() { public boolean start() {
eventBus.addListener(this); synchLock.lock();
// Load the temporary secrets and transport latencies from the database
Collection<TemporarySecret> secrets;
try { try {
secrets = db.getSecrets(); eventBus.addListener(this);
maxLatencies.putAll(db.getTransportLatencies()); // Load the temporary secrets and transport latencies from the DB
} catch(DbException e) { Collection<TemporarySecret> secrets;
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return false;
}
// Work out what phase of its lifecycle each secret is in
long now = clock.currentTimeMillis();
Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets);
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
if(!created.isEmpty()) {
// Store any secrets that have been created, removing any dead ones
try { try {
db.addSecrets(created); secrets = db.getSecrets();
maxLatencies.putAll(db.getTransportLatencies());
} catch(DbException e) { } catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return false; return false;
} }
// Work out what phase of its lifecycle each secret is in
long now = clock.currentTimeMillis();
Collection<TemporarySecret> dead =
assignSecretsToMaps(now, secrets);
// Replace any dead secrets
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
if(!created.isEmpty()) {
// Store any secrets that have been created,
// removing any dead ones
try {
db.addSecrets(created);
} catch(DbException e) {
if(LOG.isLoggable(WARNING))
LOG.log(WARNING, e.toString(), e);
return false;
}
}
// Pass the old, current and new secrets to the recogniser
for(TemporarySecret s : oldSecrets.values())
tagRecogniser.addSecret(s);
for(TemporarySecret s : currentSecrets.values())
tagRecogniser.addSecret(s);
for(TemporarySecret s : newSecrets.values())
tagRecogniser.addSecret(s);
// Schedule periodic key rotation
timer.scheduleAtFixedRate(this, MS_BETWEEN_CHECKS,
MS_BETWEEN_CHECKS);
return true;
} finally {
synchLock.unlock();
} }
// Pass the old, current and new secrets to the recogniser
for(TemporarySecret s : oldSecrets.values())
tagRecogniser.addSecret(s);
for(TemporarySecret s : currentSecrets.values())
tagRecogniser.addSecret(s);
for(TemporarySecret s : newSecrets.values())
tagRecogniser.addSecret(s);
// Schedule periodic key rotation
timer.scheduleAtFixedRate(this, MS_BETWEEN_CHECKS, MS_BETWEEN_CHECKS);
return true;
} }
// Assigns secrets to the appropriate maps and returns any dead secrets // Assigns secrets to the appropriate maps and returns any dead secrets
// Locking: this // Locking: synchLock
private Collection<TemporarySecret> assignSecretsToMaps(long now, private Collection<TemporarySecret> assignSecretsToMaps(long now,
Collection<TemporarySecret> secrets) { Collection<TemporarySecret> secrets) {
Collection<TemporarySecret> dead = new ArrayList<TemporarySecret>(); Collection<TemporarySecret> dead = new ArrayList<TemporarySecret>();
@@ -142,7 +154,7 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
} }
// Replaces the given secrets and returns any secrets created // Replaces the given secrets and returns any secrets created
// Locking: this // Locking: synchLock
private Collection<TemporarySecret> replaceDeadSecrets(long now, private Collection<TemporarySecret> replaceDeadSecrets(long now,
Collection<TemporarySecret> dead) { Collection<TemporarySecret> dead) {
// If there are several dead secrets for an endpoint, use the newest // If there are several dead secrets for an endpoint, use the newest
@@ -200,105 +212,125 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
return created; return created;
} }
public synchronized boolean stop() { public boolean stop() {
eventBus.removeListener(this); synchLock.lock();
timer.cancel(); try {
tagRecogniser.removeSecrets(); eventBus.removeListener(this);
maxLatencies.clear(); timer.cancel();
oldSecrets.clear(); tagRecogniser.removeSecrets();
currentSecrets.clear(); maxLatencies.clear();
newSecrets.clear(); oldSecrets.clear();
return true; currentSecrets.clear();
newSecrets.clear();
return true;
} finally {
synchLock.unlock();
}
} }
public synchronized StreamContext getStreamContext(ContactId c, public StreamContext getStreamContext(ContactId c,
TransportId t) { TransportId t) {
TemporarySecret s = currentSecrets.get(new EndpointKey(c, t)); synchLock.lock();
if(s == null) {
LOG.info("No secret for endpoint");
return null;
}
long streamNumber;
try { try {
streamNumber = db.incrementStreamCounter(c, t, s.getPeriod()); TemporarySecret s = currentSecrets.get(new EndpointKey(c, t));
if(streamNumber == -1) { if(s == null) {
LOG.info("No counter for period"); LOG.info("No secret for endpoint");
return null; return null;
} }
} catch(DbException e) { long streamNumber;
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); try {
return null; streamNumber = db.incrementStreamCounter(c, t, s.getPeriod());
if(streamNumber == -1) {
LOG.info("No counter for period");
return null;
}
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return null;
}
byte[] secret = s.getSecret();
return new StreamContext(c, t, secret, streamNumber, s.getAlice());
} finally {
synchLock.unlock();
} }
byte[] secret = s.getSecret();
return new StreamContext(c, t, secret, streamNumber, s.getAlice());
} }
public synchronized void endpointAdded(Endpoint ep, int maxLatency, public void endpointAdded(Endpoint ep, int maxLatency,
byte[] initialSecret) { byte[] initialSecret) {
maxLatencies.put(ep.getTransportId(), maxLatency); synchLock.lock();
// Work out which rotation period we're in
long elapsed = clock.currentTimeMillis() - ep.getEpoch();
long rotation = maxLatency + MAX_CLOCK_DIFFERENCE;
long period = (elapsed / rotation) + 1;
if(period < 1) throw new IllegalStateException();
// Derive the old, current and new secrets
byte[] b1 = initialSecret;
for(long p = 0; p < period; p++)
b1 = crypto.deriveNextSecret(b1, p);
byte[] b2 = crypto.deriveNextSecret(b1, period);
byte[] b3 = crypto.deriveNextSecret(b2, period + 1);
TemporarySecret s1 = new TemporarySecret(ep, period - 1, b1);
TemporarySecret s2 = new TemporarySecret(ep, period, b2);
TemporarySecret s3 = new TemporarySecret(ep, period + 1, b3);
// Add the incoming secrets to their respective maps
EndpointKey k = new EndpointKey(ep);
oldSecrets.put(k, s1);
currentSecrets.put(k, s2);
newSecrets.put(k, s3);
// Store the new secrets
try { try {
db.addSecrets(Arrays.asList(s1, s2, s3)); maxLatencies.put(ep.getTransportId(), maxLatency);
} catch(DbException e) { // Work out which rotation period we're in
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); long elapsed = clock.currentTimeMillis() - ep.getEpoch();
return; long rotation = maxLatency + MAX_CLOCK_DIFFERENCE;
long period = (elapsed / rotation) + 1;
if(period < 1) throw new IllegalStateException();
// Derive the old, current and new secrets
byte[] b1 = initialSecret;
for(long p = 0; p < period; p++)
b1 = crypto.deriveNextSecret(b1, p);
byte[] b2 = crypto.deriveNextSecret(b1, period);
byte[] b3 = crypto.deriveNextSecret(b2, period + 1);
TemporarySecret s1 = new TemporarySecret(ep, period - 1, b1);
TemporarySecret s2 = new TemporarySecret(ep, period, b2);
TemporarySecret s3 = new TemporarySecret(ep, period + 1, b3);
// Add the incoming secrets to their respective maps
EndpointKey k = new EndpointKey(ep);
oldSecrets.put(k, s1);
currentSecrets.put(k, s2);
newSecrets.put(k, s3);
// Store the new secrets
try {
db.addSecrets(Arrays.asList(s1, s2, s3));
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
return;
}
// Pass the new secrets to the recogniser
tagRecogniser.addSecret(s1);
tagRecogniser.addSecret(s2);
tagRecogniser.addSecret(s3);
} finally {
synchLock.unlock();
} }
// Pass the new secrets to the recogniser
tagRecogniser.addSecret(s1);
tagRecogniser.addSecret(s2);
tagRecogniser.addSecret(s3);
} }
@Override @Override
public synchronized void run() { public void run() {
// Rebuild the maps because we may be running a whole period late synchLock.lock();
Collection<TemporarySecret> secrets = new ArrayList<TemporarySecret>(); try {
secrets.addAll(oldSecrets.values()); // Rebuild the maps because we may be running a whole period late
secrets.addAll(currentSecrets.values()); Collection<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
secrets.addAll(newSecrets.values()); secrets.addAll(oldSecrets.values());
oldSecrets.clear(); secrets.addAll(currentSecrets.values());
currentSecrets.clear(); secrets.addAll(newSecrets.values());
newSecrets.clear(); oldSecrets.clear();
// Work out what phase of its lifecycle each secret is in currentSecrets.clear();
long now = clock.currentTimeMillis(); newSecrets.clear();
Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets); // Work out what phase of its lifecycle each secret is in
// Remove any dead secrets from the recogniser long now = clock.currentTimeMillis();
for(TemporarySecret s : dead) { Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets);
ContactId c = s.getContactId(); // Remove any dead secrets from the recogniser
TransportId t = s.getTransportId(); for(TemporarySecret s : dead) {
long period = s.getPeriod(); ContactId c = s.getContactId();
tagRecogniser.removeSecret(c, t, period); TransportId t = s.getTransportId();
} long period = s.getPeriod();
// Replace any dead secrets tagRecogniser.removeSecret(c, t, period);
Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
if(!created.isEmpty()) {
// Store any secrets that have been created
try {
db.addSecrets(created);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
} }
// Pass any secrets that have been created to the recogniser // Replace any dead secrets
for(TemporarySecret s : created) tagRecogniser.addSecret(s); Collection<TemporarySecret> created = replaceDeadSecrets(now, dead);
if(!created.isEmpty()) {
// Store any secrets that have been created
try {
db.addSecrets(created);
} catch(DbException e) {
if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
// Pass any secrets that have been created to the recogniser
for(TemporarySecret s : created) tagRecogniser.addSecret(s);
}
} finally {
synchLock.unlock();
} }
} }
@@ -315,14 +347,14 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
} }
} }
// Locking: this // Locking: synchLock
private void removeSecrets(ContactId c, Map<?, TemporarySecret> m) { private void removeSecrets(ContactId c, Map<?, TemporarySecret> m) {
Iterator<TemporarySecret> it = m.values().iterator(); Iterator<TemporarySecret> it = m.values().iterator();
while(it.hasNext()) while(it.hasNext())
if(it.next().getContactId().equals(c)) it.remove(); if(it.next().getContactId().equals(c)) it.remove();
} }
// Locking: this // Locking: synchLock
private void removeSecrets(TransportId t, Map<?, TemporarySecret> m) { private void removeSecrets(TransportId t, Map<?, TemporarySecret> m) {
Iterator<TemporarySecret> it = m.values().iterator(); Iterator<TemporarySecret> it = m.values().iterator();
while(it.hasNext()) while(it.hasNext())
@@ -371,10 +403,13 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
public void run() { public void run() {
ContactId c = event.getContactId(); ContactId c = event.getContactId();
tagRecogniser.removeSecrets(c); tagRecogniser.removeSecrets(c);
synchronized(KeyManagerImpl.this) { synchLock.lock();
try {
removeSecrets(c, oldSecrets); removeSecrets(c, oldSecrets);
removeSecrets(c, currentSecrets); removeSecrets(c, currentSecrets);
removeSecrets(c, newSecrets); removeSecrets(c, newSecrets);
} finally {
synchLock.unlock();
} }
} }
} }
@@ -389,8 +424,11 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
@Override @Override
public void run() { public void run() {
synchronized(KeyManagerImpl.this) { synchLock.lock();
try {
maxLatencies.put(event.getTransportId(), event.getMaxLatency()); maxLatencies.put(event.getTransportId(), event.getMaxLatency());
} finally {
synchLock.unlock();
} }
} }
} }
@@ -407,11 +445,14 @@ class KeyManagerImpl extends TimerTask implements KeyManager, EventListener {
public void run() { public void run() {
TransportId t = event.getTransportId(); TransportId t = event.getTransportId();
tagRecogniser.removeSecrets(t); tagRecogniser.removeSecrets(t);
synchronized(KeyManagerImpl.this) { synchLock.lock();
try {
maxLatencies.remove(t); maxLatencies.remove(t);
removeSecrets(t, oldSecrets); removeSecrets(t, oldSecrets);
removeSecrets(t, currentSecrets); removeSecrets(t, currentSecrets);
removeSecrets(t, newSecrets); removeSecrets(t, newSecrets);
} finally {
synchLock.unlock();
} }
} }
} }

View File

@@ -2,6 +2,8 @@ package org.briarproject.transport;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import javax.inject.Inject; import javax.inject.Inject;
@@ -18,7 +20,9 @@ class TagRecogniserImpl implements TagRecogniser {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final DatabaseComponent db; private final DatabaseComponent db;
// Locking: this private final Lock synchLock = new ReentrantLock();
// Locking: synchLock
private final Map<TransportId, TransportTagRecogniser> recognisers; private final Map<TransportId, TransportTagRecogniser> recognisers;
@Inject @Inject
@@ -31,8 +35,11 @@ class TagRecogniserImpl implements TagRecogniser {
public StreamContext recogniseTag(TransportId t, byte[] tag) public StreamContext recogniseTag(TransportId t, byte[] tag)
throws DbException { throws DbException {
TransportTagRecogniser r; TransportTagRecogniser r;
synchronized(this) { synchLock.lock();
try {
r = recognisers.get(t); r = recognisers.get(t);
} finally {
synchLock.unlock();
} }
if(r == null) return null; if(r == null) return null;
return r.recogniseTag(tag); return r.recogniseTag(tag);
@@ -41,35 +48,58 @@ class TagRecogniserImpl implements TagRecogniser {
public void addSecret(TemporarySecret s) { public void addSecret(TemporarySecret s) {
TransportId t = s.getTransportId(); TransportId t = s.getTransportId();
TransportTagRecogniser r; TransportTagRecogniser r;
synchronized(this) { synchLock.lock();
try {
r = recognisers.get(t); r = recognisers.get(t);
if(r == null) { if(r == null) {
r = new TransportTagRecogniser(crypto, db, t); r = new TransportTagRecogniser(crypto, db, t);
recognisers.put(t, r); recognisers.put(t, r);
} }
} finally {
synchLock.unlock();
} }
r.addSecret(s); r.addSecret(s);
} }
public void removeSecret(ContactId c, TransportId t, long period) { public void removeSecret(ContactId c, TransportId t, long period) {
TransportTagRecogniser r; TransportTagRecogniser r;
synchronized(this) { synchLock.lock();
try {
r = recognisers.get(t); r = recognisers.get(t);
} finally {
synchLock.unlock();
} }
if(r != null) r.removeSecret(c, period); if(r != null) r.removeSecret(c, period);
} }
public synchronized void removeSecrets(ContactId c) { public void removeSecrets(ContactId c) {
for(TransportTagRecogniser r : recognisers.values()) synchLock.lock();
r.removeSecrets(c); try {
for(TransportTagRecogniser r : recognisers.values())
r.removeSecrets(c);
} finally {
synchLock.unlock();
}
} }
public synchronized void removeSecrets(TransportId t) { public void removeSecrets(TransportId t) {
recognisers.remove(t); synchLock.lock();
try {
recognisers.remove(t);
} finally {
synchLock.unlock();
}
} }
public synchronized void removeSecrets() { public void removeSecrets() {
for(TransportTagRecogniser r : recognisers.values()) synchLock.lock();
r.removeSecrets(); try {
for(TransportTagRecogniser r : recognisers.values())
r.removeSecrets();
} finally {
synchLock.unlock();
}
} }
} }

View File

@@ -6,6 +6,8 @@ import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.briarproject.api.Bytes; import org.briarproject.api.Bytes;
import org.briarproject.api.ContactId; import org.briarproject.api.ContactId;
@@ -27,8 +29,11 @@ class TransportTagRecogniser {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final DatabaseComponent db; private final DatabaseComponent db;
private final TransportId transportId; private final TransportId transportId;
private final Map<Bytes, TagContext> tagMap; // Locking: this private final Lock synchLock = new ReentrantLock();
private final Map<RemovalKey, RemovalContext> removalMap; // Locking: this
// The following are locking: synchLock
private final Map<Bytes, TagContext> tagMap;
private final Map<RemovalKey, RemovalContext> removalMap;
TransportTagRecogniser(CryptoComponent crypto, DatabaseComponent db, TransportTagRecogniser(CryptoComponent crypto, DatabaseComponent db,
TransportId transportId) { TransportId transportId) {
@@ -39,61 +44,76 @@ class TransportTagRecogniser {
removalMap = new HashMap<RemovalKey, RemovalContext>(); removalMap = new HashMap<RemovalKey, RemovalContext>();
} }
synchronized StreamContext recogniseTag(byte[] tag) throws DbException { StreamContext recogniseTag(byte[] tag) throws DbException {
TagContext t = tagMap.remove(new Bytes(tag)); synchLock.lock();
if(t == null) return null; // The tag was not expected try {
// Update the reordering window and the expected tags TagContext t = tagMap.remove(new Bytes(tag));
SecretKey key = crypto.deriveTagKey(t.secret, !t.alice); if(t == null) return null; // The tag was not expected
for(long streamNumber : t.window.setSeen(t.streamNumber)) { // Update the reordering window and the expected tags
byte[] tag1 = new byte[TAG_LENGTH]; SecretKey key = crypto.deriveTagKey(t.secret, !t.alice);
crypto.encodeTag(tag1, key, streamNumber); for(long streamNumber : t.window.setSeen(t.streamNumber)) {
if(streamNumber < t.streamNumber) { byte[] tag1 = new byte[TAG_LENGTH];
TagContext removed = tagMap.remove(new Bytes(tag1)); crypto.encodeTag(tag1, key, streamNumber);
assert removed != null; if(streamNumber < t.streamNumber) {
} else { TagContext removed = tagMap.remove(new Bytes(tag1));
TagContext added = new TagContext(t, streamNumber); assert removed != null;
TagContext duplicate = tagMap.put(new Bytes(tag1), added); } else {
TagContext added = new TagContext(t, streamNumber);
TagContext duplicate = tagMap.put(new Bytes(tag1), added);
assert duplicate == null;
}
}
// Store the updated reordering window in the DB
db.setReorderingWindow(t.contactId, transportId, t.period,
t.window.getCentre(), t.window.getBitmap());
return new StreamContext(t.contactId, transportId, t.secret,
t.streamNumber, t.alice);
} finally {
synchLock.unlock();
}
}
void addSecret(TemporarySecret s) {
synchLock.lock();
try {
ContactId contactId = s.getContactId();
boolean alice = s.getAlice();
long period = s.getPeriod();
byte[] secret = s.getSecret();
long centre = s.getWindowCentre();
byte[] bitmap = s.getWindowBitmap();
// Create the reordering window and the expected tags
SecretKey key = crypto.deriveTagKey(secret, !alice);
ReorderingWindow window = new ReorderingWindow(centre, bitmap);
for(long streamNumber : window.getUnseen()) {
byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, key, streamNumber);
TagContext added = new TagContext(contactId, alice, period,
secret, window, streamNumber);
TagContext duplicate = tagMap.put(new Bytes(tag), added);
assert duplicate == null; assert duplicate == null;
} }
// Create a removal context to remove the window and the tags later
RemovalContext r = new RemovalContext(window, secret, alice);
removalMap.put(new RemovalKey(contactId, period), r);
} finally {
synchLock.unlock();
} }
// Store the updated reordering window in the DB
db.setReorderingWindow(t.contactId, transportId, t.period,
t.window.getCentre(), t.window.getBitmap());
return new StreamContext(t.contactId, transportId, t.secret,
t.streamNumber, t.alice);
} }
synchronized void addSecret(TemporarySecret s) { void removeSecret(ContactId contactId, long period) {
ContactId contactId = s.getContactId(); synchLock.lock();
boolean alice = s.getAlice(); try {
long period = s.getPeriod(); RemovalKey k = new RemovalKey(contactId, period);
byte[] secret = s.getSecret(); RemovalContext removed = removalMap.remove(k);
long centre = s.getWindowCentre(); if(removed == null) throw new IllegalArgumentException();
byte[] bitmap = s.getWindowBitmap(); removeSecret(removed);
// Create the reordering window and the expected tags } finally {
SecretKey key = crypto.deriveTagKey(secret, !alice); synchLock.unlock();
ReorderingWindow window = new ReorderingWindow(centre, bitmap);
for(long streamNumber : window.getUnseen()) {
byte[] tag = new byte[TAG_LENGTH];
crypto.encodeTag(tag, key, streamNumber);
TagContext added = new TagContext(contactId, alice, period,
secret, window, streamNumber);
TagContext duplicate = tagMap.put(new Bytes(tag), added);
assert duplicate == null;
} }
// Create a removal context to remove the window and the tags later
RemovalContext r = new RemovalContext(window, secret, alice);
removalMap.put(new RemovalKey(contactId, period), r);
} }
synchronized void removeSecret(ContactId contactId, long period) { // Locking: synchLock
RemovalKey k = new RemovalKey(contactId, period);
RemovalContext removed = removalMap.remove(k);
if(removed == null) throw new IllegalArgumentException();
removeSecret(removed);
}
// Locking: this
private void removeSecret(RemovalContext r) { private void removeSecret(RemovalContext r) {
// Remove the expected tags // Remove the expected tags
SecretKey key = crypto.deriveTagKey(r.secret, !r.alice); SecretKey key = crypto.deriveTagKey(r.secret, !r.alice);
@@ -105,17 +125,28 @@ class TransportTagRecogniser {
} }
} }
synchronized void removeSecrets(ContactId c) { void removeSecrets(ContactId c) {
Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>(); synchLock.lock();
for(RemovalKey k : removalMap.keySet()) try {
if(k.contactId.equals(c)) keysToRemove.add(k); Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
for(RemovalKey k : keysToRemove) removeSecret(k.contactId, k.period); for(RemovalKey k : removalMap.keySet())
if(k.contactId.equals(c)) keysToRemove.add(k);
for(RemovalKey k : keysToRemove)
removeSecret(k.contactId, k.period);
} finally {
synchLock.unlock();
}
} }
synchronized void removeSecrets() { void removeSecrets() {
for(RemovalContext r : removalMap.values()) removeSecret(r); synchLock.lock();
assert tagMap.isEmpty(); try {
removalMap.clear(); for(RemovalContext r : removalMap.values()) removeSecret(r);
assert tagMap.isEmpty();
removalMap.clear();
} finally {
synchLock.unlock();
}
} }
private static class TagContext { private static class TagContext {

View File

@@ -1,7 +1,8 @@
package org.briarproject.util; package org.briarproject.util;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
public class LatchedReference<T> { public class LatchedReference<T> {
@@ -23,7 +24,7 @@ public class LatchedReference<T> {
} }
public T waitForReference(long timeout) throws InterruptedException { public T waitForReference(long timeout) throws InterruptedException {
latch.await(timeout, TimeUnit.MILLISECONDS); latch.await(timeout, MILLISECONDS);
return reference.get(); return reference.get();
} }
} }

View File

@@ -8,6 +8,8 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import org.briarproject.util.OsUtils; import org.briarproject.util.OsUtils;
@@ -36,8 +38,9 @@ class WindowsShutdownManagerImpl extends ShutdownManagerImpl {
private static final int WS_MINIMIZE = 0x20000000; private static final int WS_MINIMIZE = 0x20000000;
private final Map<String, Object> options; private final Map<String, Object> options;
private final Lock synchLock = new ReentrantLock();
private boolean initialised = false; // Locking: this private boolean initialised = false; // Locking: synchLock
WindowsShutdownManagerImpl() { WindowsShutdownManagerImpl() {
// Use the Unicode versions of Win32 API calls // Use the Unicode versions of Win32 API calls
@@ -48,9 +51,14 @@ class WindowsShutdownManagerImpl extends ShutdownManagerImpl {
} }
@Override @Override
public synchronized int addShutdownHook(Runnable r) { public int addShutdownHook(Runnable r) {
if(!initialised) initialise(); synchLock.lock();
return super.addShutdownHook(r); try {
if(!initialised) initialise();
return super.addShutdownHook(r);
} finally {
synchLock.unlock();
}
} }
@Override @Override
@@ -58,7 +66,7 @@ class WindowsShutdownManagerImpl extends ShutdownManagerImpl {
return new StartOnce(r); return new StartOnce(r);
} }
// Locking: this // Locking: synchLock
private void initialise() { private void initialise() {
if(OsUtils.isWindows()) { if(OsUtils.isWindows()) {
new EventLoop().start(); new EventLoop().start();
@@ -69,20 +77,25 @@ class WindowsShutdownManagerImpl extends ShutdownManagerImpl {
} }
// Package access for testing // Package access for testing
synchronized void runShutdownHooks() { void runShutdownHooks() {
boolean interrupted = false; synchLock.lock();
// Start each hook in its own thread try {
for(Thread hook : hooks.values()) hook.start(); boolean interrupted = false;
// Wait for all the hooks to finish // Start each hook in its own thread
for(Thread hook : hooks.values()) { for(Thread hook : hooks.values()) hook.start();
try { // Wait for all the hooks to finish
hook.join(); for(Thread hook : hooks.values()) {
} catch(InterruptedException e) { try {
LOG.warning("Interrupted while running shutdown hooks"); hook.join();
interrupted = true; } catch(InterruptedException e) {
LOG.warning("Interrupted while running shutdown hooks");
interrupted = true;
}
} }
if(interrupted) Thread.currentThread().interrupt();
} finally {
synchLock.unlock();
} }
if(interrupted) Thread.currentThread().interrupt();
} }
private class EventLoop extends Thread { private class EventLoop extends Thread {

View File

@@ -1,9 +1,14 @@
package org.briarproject.plugins.file; package org.briarproject.plugins.file;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
class PollingRemovableDriveMonitor implements RemovableDriveMonitor, Runnable { class PollingRemovableDriveMonitor implements RemovableDriveMonitor, Runnable {
@@ -14,7 +19,9 @@ class PollingRemovableDriveMonitor implements RemovableDriveMonitor, Runnable {
private final Executor ioExecutor; private final Executor ioExecutor;
private final RemovableDriveFinder finder; private final RemovableDriveFinder finder;
private final int pollingInterval; private final int pollingInterval;
private final Object pollingLock = new Object();
private final Lock pollingLock = new ReentrantLock();
private final Condition stopPolling = pollingLock.newCondition();
private volatile boolean running = false; private volatile boolean running = false;
private volatile Callback callback = null; private volatile Callback callback = null;
@@ -34,8 +41,12 @@ class PollingRemovableDriveMonitor implements RemovableDriveMonitor, Runnable {
public void stop() throws IOException { public void stop() throws IOException {
running = false; running = false;
synchronized(pollingLock) { pollingLock.lock();
pollingLock.notifyAll(); try {
stopPolling.signalAll();
}
finally {
pollingLock.unlock();
} }
} }
@@ -43,8 +54,11 @@ class PollingRemovableDriveMonitor implements RemovableDriveMonitor, Runnable {
try { try {
Collection<File> drives = finder.findRemovableDrives(); Collection<File> drives = finder.findRemovableDrives();
while(running) { while(running) {
synchronized(pollingLock) { pollingLock.lock();
pollingLock.wait(pollingInterval); try {
stopPolling.await(pollingInterval, MILLISECONDS);
} finally {
pollingLock.unlock();
} }
if(!running) return; if(!running) return;
Collection<File> newDrives = finder.findRemovableDrives(); Collection<File> newDrives = finder.findRemovableDrives();

View File

@@ -4,6 +4,8 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import net.contentobjects.jnotify.JNotify; import net.contentobjects.jnotify.JNotify;
import net.contentobjects.jnotify.JNotifyListener; import net.contentobjects.jnotify.JNotifyListener;
@@ -11,14 +13,19 @@ import net.contentobjects.jnotify.JNotifyListener;
abstract class UnixRemovableDriveMonitor implements RemovableDriveMonitor, abstract class UnixRemovableDriveMonitor implements RemovableDriveMonitor,
JNotifyListener { JNotifyListener {
private static boolean triedLoad = false; // Locking: class //TODO: rationalise this in a further refactor
private static Throwable loadError = null; // Locking: class private static final Lock staticSynchLock = new ReentrantLock();
// Locking: this // The following are locking: staticSynchLock
private static boolean triedLoad = false;
private static Throwable loadError = null;
private final Lock synchLock = new ReentrantLock();
// The following are locking: synchLock
private final List<Integer> watches = new ArrayList<Integer>(); private final List<Integer> watches = new ArrayList<Integer>();
private boolean started = false;
private boolean started = false; // Locking: this private Callback callback = null;
private Callback callback = null; // Locking: this
protected abstract String[] getPathsToWatch(); protected abstract String[] getPathsToWatch();
@@ -33,12 +40,17 @@ JNotifyListener {
} }
} }
public static synchronized void checkEnabled() throws IOException { public static void checkEnabled() throws IOException {
if(!triedLoad) { staticSynchLock.lock();
loadError = tryLoad(); try {
triedLoad = true; if(!triedLoad) {
loadError = tryLoad();
triedLoad = true;
}
if(loadError != null) throw new IOException(loadError.toString());
} finally {
staticSynchLock.unlock();
} }
if(loadError != null) throw new IOException(loadError.toString());
} }
public void start(Callback callback) throws IOException { public void start(Callback callback) throws IOException {
@@ -49,33 +61,42 @@ JNotifyListener {
if(new File(path).exists()) if(new File(path).exists())
watches.add(JNotify.addWatch(path, mask, false, this)); watches.add(JNotify.addWatch(path, mask, false, this));
} }
synchronized(this) { synchLock.lock();
try {
assert !started; assert !started;
assert this.callback == null; assert this.callback == null;
started = true; started = true;
this.callback = callback; this.callback = callback;
this.watches.addAll(watches); this.watches.addAll(watches);
} finally {
synchLock.unlock();
} }
} }
public void stop() throws IOException { public void stop() throws IOException {
checkEnabled(); checkEnabled();
List<Integer> watches; List<Integer> watches;
synchronized(this) { synchLock.lock();
try {
assert started; assert started;
assert callback != null; assert callback != null;
started = false; started = false;
callback = null; callback = null;
watches = new ArrayList<Integer>(this.watches); watches = new ArrayList<Integer>(this.watches);
this.watches.clear(); this.watches.clear();
} finally {
synchLock.unlock();
} }
for(Integer w : watches) JNotify.removeWatch(w); for(Integer w : watches) JNotify.removeWatch(w);
} }
public void fileCreated(int wd, String rootPath, String name) { public void fileCreated(int wd, String rootPath, String name) {
Callback callback; Callback callback;
synchronized(this) { synchLock.lock();
try {
callback = this.callback; callback = this.callback;
} finally {
synchLock.unlock();
} }
if(callback != null) if(callback != null)
callback.driveInserted(new File(rootPath + "/" + name)); callback.driveInserted(new File(rootPath + "/" + name));

View File

@@ -1,5 +1,6 @@
package org.briarproject.plugins.modem; package org.briarproject.plugins.modem;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.logging.Level.INFO; import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING; import static java.util.logging.Level.WARNING;
import static jssc.SerialPort.PURGE_RXCLEAR; import static jssc.SerialPort.PURGE_RXCLEAR;
@@ -10,6 +11,9 @@ import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Logger; import java.util.logging.Logger;
import jssc.SerialPortEvent; import jssc.SerialPortEvent;
@@ -40,10 +44,15 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
private final Semaphore stateChange; private final Semaphore stateChange;
private final byte[] line; private final byte[] line;
private int lineLen = 0; private final Lock synchLock = new ReentrantLock();
private final Condition connectedStateChanged = synchLock.newCondition();
private final Condition initialisedStateChanged = synchLock.newCondition();
private ReliabilityLayer reliability = null; // Locking: this // The following are locking: synchLock
private boolean initialised = false, connected = false; // Locking: this private ReliabilityLayer reliability = null;
private boolean initialised = false, connected = false;
private int lineLen = 0;
ModemImpl(Executor ioExecutor, ReliabilityLayerFactory reliabilityFactory, ModemImpl(Executor ioExecutor, ReliabilityLayerFactory reliabilityFactory,
Clock clock, Callback callback, SerialPort port) { Clock clock, Callback callback, SerialPort port) {
@@ -91,14 +100,17 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
// Wait for the event thread to receive "OK" // Wait for the event thread to receive "OK"
boolean success = false; boolean success = false;
try { try {
synchronized(this) { synchLock.lock();
try {
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
long end = now + OK_TIMEOUT; long end = now + OK_TIMEOUT;
while(now < end && !initialised) { while(now < end && !initialised) {
wait(end - now); initialisedStateChanged.await(end - now, MILLISECONDS);
now = clock.currentTimeMillis(); now = clock.currentTimeMillis();
} }
success = initialised; success = initialised;
} finally {
synchLock.unlock();
} }
} catch(InterruptedException e) { } catch(InterruptedException e) {
tryToClose(port); tryToClose(port);
@@ -123,11 +135,15 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
public void stop() throws IOException { public void stop() throws IOException {
LOG.info("Stopping"); LOG.info("Stopping");
// Wake any threads that are waiting to connect synchLock.lock();
synchronized(this) { try {
// Wake any threads that are waiting to connect
initialised = false; initialised = false;
connected = false; connected = false;
notifyAll(); initialisedStateChanged.signalAll();
connectedStateChanged.signalAll();
} finally {
synchLock.unlock();
} }
// Hang up if necessary and close the port // Hang up if necessary and close the port
try { try {
@@ -148,7 +164,8 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
// Locking: stateChange // Locking: stateChange
private void hangUpInner() throws IOException { private void hangUpInner() throws IOException {
ReliabilityLayer reliability; ReliabilityLayer reliability;
synchronized(this) { synchLock.lock();
try {
if(this.reliability == null) { if(this.reliability == null) {
LOG.info("Not hanging up - already on the hook"); LOG.info("Not hanging up - already on the hook");
return; return;
@@ -156,6 +173,8 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
reliability = this.reliability; reliability = this.reliability;
this.reliability = null; this.reliability = null;
connected = false; connected = false;
} finally {
synchLock.unlock();
} }
reliability.stop(); reliability.stop();
LOG.info("Hanging up"); LOG.info("Hanging up");
@@ -182,7 +201,8 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
try { try {
ReliabilityLayer reliability = ReliabilityLayer reliability =
reliabilityFactory.createReliabilityLayer(this); reliabilityFactory.createReliabilityLayer(this);
synchronized(this) { synchLock.lock();
try {
if(!initialised) { if(!initialised) {
LOG.info("Not dialling - modem not initialised"); LOG.info("Not dialling - modem not initialised");
return false; return false;
@@ -192,6 +212,8 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
return false; return false;
} }
this.reliability = reliability; this.reliability = reliability;
} finally {
synchLock.unlock();
} }
reliability.start(); reliability.start();
LOG.info("Dialling"); LOG.info("Dialling");
@@ -204,14 +226,17 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
} }
// Wait for the event thread to receive "CONNECT" // Wait for the event thread to receive "CONNECT"
try { try {
synchronized(this) { synchLock.lock();
try {
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
long end = now + CONNECT_TIMEOUT; long end = now + CONNECT_TIMEOUT;
while(now < end && initialised && !connected) { while(now < end && initialised && !connected) {
wait(end - now); connectedStateChanged.await(end - now, MILLISECONDS);
now = clock.currentTimeMillis(); now = clock.currentTimeMillis();
} }
if(connected) return true; if(connected) return true;
} finally {
synchLock.unlock();
} }
} catch(InterruptedException e) { } catch(InterruptedException e) {
tryToClose(port); tryToClose(port);
@@ -227,8 +252,11 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
public InputStream getInputStream() throws IOException { public InputStream getInputStream() throws IOException {
ReliabilityLayer reliability; ReliabilityLayer reliability;
synchronized(this) { synchLock.lock();
try {
reliability = this.reliability; reliability = this.reliability;
} finally {
synchLock.unlock();
} }
if(reliability == null) throw new IOException("Not connected"); if(reliability == null) throw new IOException("Not connected");
return reliability.getInputStream(); return reliability.getInputStream();
@@ -236,8 +264,11 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
public OutputStream getOutputStream() throws IOException { public OutputStream getOutputStream() throws IOException {
ReliabilityLayer reliability; ReliabilityLayer reliability;
synchronized(this) { synchLock.lock();
try {
reliability = this.reliability; reliability = this.reliability;
} finally {
synchLock.unlock();
} }
if(reliability == null) throw new IOException("Not connected"); if(reliability == null) throw new IOException("Not connected");
return reliability.getOutputStream(); return reliability.getOutputStream();
@@ -288,8 +319,11 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
private boolean handleData(byte[] b) throws IOException { private boolean handleData(byte[] b) throws IOException {
ReliabilityLayer reliability; ReliabilityLayer reliability;
synchronized(this) { synchLock.lock();
try {
reliability = this.reliability; reliability = this.reliability;
} finally {
synchLock.unlock();
} }
if(reliability == null) return false; if(reliability == null) return false;
reliability.handleRead(b); reliability.handleRead(b);
@@ -309,9 +343,12 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
lineLen = 0; lineLen = 0;
if(LOG.isLoggable(INFO)) LOG.info("Modem status: " + s); if(LOG.isLoggable(INFO)) LOG.info("Modem status: " + s);
if(s.startsWith("CONNECT")) { if(s.startsWith("CONNECT")) {
synchronized(this) { synchLock.lock();
try {
connected = true; connected = true;
notifyAll(); connectedStateChanged.signalAll();
} finally {
synchLock.unlock();
} }
// There might be data in the buffer as well as text // There might be data in the buffer as well as text
int off = i + 1; int off = i + 1;
@@ -323,14 +360,20 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
return; return;
} else if(s.equals("BUSY") || s.equals("NO DIALTONE") } else if(s.equals("BUSY") || s.equals("NO DIALTONE")
|| s.equals("NO CARRIER")) { || s.equals("NO CARRIER")) {
synchronized(this) { synchLock.lock();
try {
connected = false; connected = false;
notifyAll(); connectedStateChanged.signalAll();
} finally {
synchLock.unlock();
} }
} else if(s.equals("OK")) { } else if(s.equals("OK")) {
synchronized(this) { synchLock.lock();
try {
initialised = true; initialised = true;
notifyAll(); initialisedStateChanged.signalAll();
} finally {
synchLock.unlock();
} }
} else if(s.equals("RING")) { } else if(s.equals("RING")) {
ioExecutor.execute(new Runnable() { ioExecutor.execute(new Runnable() {
@@ -358,7 +401,8 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
try { try {
ReliabilityLayer reliability = ReliabilityLayer reliability =
reliabilityFactory.createReliabilityLayer(this); reliabilityFactory.createReliabilityLayer(this);
synchronized(this) { synchLock.lock();
try {
if(!initialised) { if(!initialised) {
LOG.info("Not answering - modem not initialised"); LOG.info("Not answering - modem not initialised");
return; return;
@@ -368,6 +412,8 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
return; return;
} }
this.reliability = reliability; this.reliability = reliability;
} finally {
synchLock.unlock();
} }
reliability.start(); reliability.start();
LOG.info("Answering"); LOG.info("Answering");
@@ -380,14 +426,17 @@ class ModemImpl implements Modem, WriteHandler, SerialPortEventListener {
// Wait for the event thread to receive "CONNECT" // Wait for the event thread to receive "CONNECT"
boolean success = false; boolean success = false;
try { try {
synchronized(this) { synchLock.lock();
try {
long now = clock.currentTimeMillis(); long now = clock.currentTimeMillis();
long end = now + CONNECT_TIMEOUT; long end = now + CONNECT_TIMEOUT;
while(now < end && initialised && !connected) { while(now < end && initialised && !connected) {
wait(end - now); connectedStateChanged.await(end - now, MILLISECONDS);
now = clock.currentTimeMillis(); now = clock.currentTimeMillis();
} }
success = connected; success = connected;
} finally {
synchLock.unlock();
} }
} catch(InterruptedException e) { } catch(InterruptedException e) {
tryToClose(port); tryToClose(port);

View File

@@ -38,8 +38,8 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, false, payloadLength, 0); FrameEncoder.encodeHeader(header, false, payloadLength, 0);
byte[] expected = new byte[TAG_LENGTH + HEADER_LENGTH + payloadLength int frameLength = HEADER_LENGTH + payloadLength + MAC_LENGTH;
+ MAC_LENGTH]; byte[] expected = new byte[TAG_LENGTH + frameLength];
System.arraycopy(tag, 0, expected, 0, TAG_LENGTH); System.arraycopy(tag, 0, expected, 0, TAG_LENGTH);
System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH); System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH, System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH,
@@ -53,6 +53,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, tag); frameKey, tag);
int payloadLength = 123; int payloadLength = 123;
int frameLength = HEADER_LENGTH + payloadLength + MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
@@ -60,8 +61,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, true, payloadLength, 0); FrameEncoder.encodeHeader(header, true, payloadLength, 0);
byte[] expected = new byte[TAG_LENGTH + HEADER_LENGTH + payloadLength byte[] expected = new byte[TAG_LENGTH + frameLength];
+ MAC_LENGTH];
System.arraycopy(tag, 0, expected, 0, TAG_LENGTH); System.arraycopy(tag, 0, expected, 0, TAG_LENGTH);
System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH); System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH, System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH,
@@ -75,6 +75,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, null); frameKey, null);
int payloadLength = 123; int payloadLength = 123;
int frameLength = HEADER_LENGTH + payloadLength + MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
@@ -82,7 +83,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, false, payloadLength, 0); FrameEncoder.encodeHeader(header, false, payloadLength, 0);
byte[] expected = new byte[HEADER_LENGTH + payloadLength + MAC_LENGTH]; byte[] expected = new byte[frameLength];
System.arraycopy(header, 0, expected, 0, HEADER_LENGTH); System.arraycopy(header, 0, expected, 0, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength); System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength);
assertArrayEquals(expected, out.toByteArray()); assertArrayEquals(expected, out.toByteArray());
@@ -94,6 +95,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, null); frameKey, null);
int payloadLength = 123; int payloadLength = 123;
int frameLength = HEADER_LENGTH + payloadLength + MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
@@ -101,7 +103,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, true, payloadLength, 0); FrameEncoder.encodeHeader(header, true, payloadLength, 0);
byte[] expected = new byte[HEADER_LENGTH + payloadLength + MAC_LENGTH]; byte[] expected = new byte[frameLength];
System.arraycopy(header, 0, expected, 0, HEADER_LENGTH); System.arraycopy(header, 0, expected, 0, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength); System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength);
assertArrayEquals(expected, out.toByteArray()); assertArrayEquals(expected, out.toByteArray());
@@ -113,6 +115,8 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, tag); frameKey, tag);
int payloadLength = 123, paddingLength = 234; int payloadLength = 123, paddingLength = 234;
int frameLength = HEADER_LENGTH + payloadLength + paddingLength
+ MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
@@ -120,8 +124,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, false, payloadLength, paddingLength); FrameEncoder.encodeHeader(header, false, payloadLength, paddingLength);
byte[] expected = new byte[TAG_LENGTH + HEADER_LENGTH + payloadLength byte[] expected = new byte[TAG_LENGTH + frameLength];
+ paddingLength + MAC_LENGTH];
System.arraycopy(tag, 0, expected, 0, TAG_LENGTH); System.arraycopy(tag, 0, expected, 0, TAG_LENGTH);
System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH); System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH, System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH,
@@ -135,6 +138,8 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, tag); frameKey, tag);
int payloadLength = 123, paddingLength = 234; int payloadLength = 123, paddingLength = 234;
int frameLength = HEADER_LENGTH + payloadLength + paddingLength
+ MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
@@ -142,8 +147,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, true, payloadLength, paddingLength); FrameEncoder.encodeHeader(header, true, payloadLength, paddingLength);
byte[] expected = new byte[TAG_LENGTH + HEADER_LENGTH + payloadLength byte[] expected = new byte[TAG_LENGTH + frameLength];
+ paddingLength + MAC_LENGTH];
System.arraycopy(tag, 0, expected, 0, TAG_LENGTH); System.arraycopy(tag, 0, expected, 0, TAG_LENGTH);
System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH); System.arraycopy(header, 0, expected, TAG_LENGTH, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH, System.arraycopy(payload, 0, expected, TAG_LENGTH + HEADER_LENGTH,
@@ -157,6 +161,8 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, null); frameKey, null);
int payloadLength = 123, paddingLength = 234; int payloadLength = 123, paddingLength = 234;
int frameLength = HEADER_LENGTH + payloadLength + paddingLength
+ MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
@@ -164,8 +170,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, false, payloadLength, paddingLength); FrameEncoder.encodeHeader(header, false, payloadLength, paddingLength);
byte[] expected = new byte[HEADER_LENGTH + payloadLength byte[] expected = new byte[frameLength];
+ paddingLength + MAC_LENGTH];
System.arraycopy(header, 0, expected, 0, HEADER_LENGTH); System.arraycopy(header, 0, expected, 0, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength); System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength);
assertArrayEquals(expected, out.toByteArray()); assertArrayEquals(expected, out.toByteArray());
@@ -177,6 +182,8 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, null); frameKey, null);
int payloadLength = 123, paddingLength = 234; int payloadLength = 123, paddingLength = 234;
int frameLength = HEADER_LENGTH + payloadLength + paddingLength
+ MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
@@ -184,8 +191,7 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header = new byte[HEADER_LENGTH]; byte[] header = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header, true, payloadLength, paddingLength); FrameEncoder.encodeHeader(header, true, payloadLength, paddingLength);
byte[] expected = new byte[HEADER_LENGTH + payloadLength byte[] expected = new byte[frameLength];
+ paddingLength + MAC_LENGTH];
System.arraycopy(header, 0, expected, 0, HEADER_LENGTH); System.arraycopy(header, 0, expected, 0, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength); System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength);
assertArrayEquals(expected, out.toByteArray()); assertArrayEquals(expected, out.toByteArray());
@@ -197,9 +203,13 @@ public class StreamEncrypterImplTest extends BriarTestCase {
StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher, StreamEncrypterImpl s = new StreamEncrypterImpl(out, frameCipher,
frameKey, null); frameKey, null);
int payloadLength = 123, paddingLength = 234; int payloadLength = 123, paddingLength = 234;
int frameLength = HEADER_LENGTH + payloadLength + paddingLength
+ MAC_LENGTH;
byte[] payload = new byte[payloadLength]; byte[] payload = new byte[payloadLength];
new Random().nextBytes(payload); new Random().nextBytes(payload);
int payloadLength1 = 345, paddingLength1 = 456; int payloadLength1 = 345, paddingLength1 = 456;
int frameLength1 = HEADER_LENGTH + payloadLength1 + paddingLength1
+ MAC_LENGTH;
byte[] payload1 = new byte[payloadLength1]; byte[] payload1 = new byte[payloadLength1];
new Random().nextBytes(payload1); new Random().nextBytes(payload1);
@@ -211,16 +221,12 @@ public class StreamEncrypterImplTest extends BriarTestCase {
byte[] header1 = new byte[HEADER_LENGTH]; byte[] header1 = new byte[HEADER_LENGTH];
FrameEncoder.encodeHeader(header1, true, payloadLength1, FrameEncoder.encodeHeader(header1, true, payloadLength1,
paddingLength1); paddingLength1);
byte[] expected = new byte[HEADER_LENGTH + payloadLength byte[] expected = new byte[frameLength + frameLength1];
+ paddingLength + MAC_LENGTH
+ HEADER_LENGTH + payloadLength1
+ paddingLength1 + MAC_LENGTH];
System.arraycopy(header, 0, expected, 0, HEADER_LENGTH); System.arraycopy(header, 0, expected, 0, HEADER_LENGTH);
System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength); System.arraycopy(payload, 0, expected, HEADER_LENGTH, payloadLength);
System.arraycopy(header1, 0, expected, HEADER_LENGTH + payloadLength System.arraycopy(header1, 0, expected, frameLength, HEADER_LENGTH);
+ paddingLength + MAC_LENGTH, HEADER_LENGTH); System.arraycopy(payload1, 0, expected, frameLength + HEADER_LENGTH,
System.arraycopy(payload1, 0, expected, HEADER_LENGTH + payloadLength payloadLength1);
+ paddingLength + MAC_LENGTH + HEADER_LENGTH, payloadLength1);
assertArrayEquals(expected, out.toByteArray()); assertArrayEquals(expected, out.toByteArray());
} }

View File

@@ -11,7 +11,6 @@ import org.briarproject.api.event.ContactConnectedEvent;
import org.briarproject.api.event.ContactDisconnectedEvent; import org.briarproject.api.event.ContactDisconnectedEvent;
import org.briarproject.api.event.EventBus; import org.briarproject.api.event.EventBus;
import org.briarproject.api.plugins.ConnectionRegistry; import org.briarproject.api.plugins.ConnectionRegistry;
import org.briarproject.plugins.ConnectionRegistryImpl;
import org.jmock.Expectations; import org.jmock.Expectations;
import org.jmock.Mockery; import org.jmock.Mockery;
import org.junit.Test; import org.junit.Test;