Retransmission with exponential backoff (untested).

This commit is contained in:
akwizgran
2013-02-06 20:07:08 +00:00
parent 4c5657321d
commit 379d6ed220
5 changed files with 162 additions and 82 deletions

View File

@@ -418,6 +418,15 @@ interface Database<T> {
SubscriptionUpdate getSubscriptionUpdate(T txn, ContactId c,
long maxLatency) throws DbException;
/**
* Returns the transmission count of the given message with respect to the
* given contact.
* <p>
* Locking: contact read, message read.
*/
int getTransmissionCount(T txn, ContactId c, MessageId m)
throws DbException;
/**
* Returns a collection of transport acks for the given contact, or null if
* no acks are due.
@@ -567,15 +576,6 @@ interface Database<T> {
void setConnectionWindow(T txn, ContactId c, TransportId t, long period,
long centre, byte[] bitmap) throws DbException;
/**
* Updates the expiry times of the given messages with respect to the given
* contact, using the latency of the transport over which they were sent.
* <p>
* Locking: contact read, message write.
*/
void setMessageExpiry(T txn, ContactId c, Collection<MessageId> sent,
long maxLatency) throws DbException;
/**
* Sets the user's rating for the given author.
* <p>
@@ -674,4 +674,14 @@ interface Database<T> {
*/
void setTransportUpdateAcked(T txn, ContactId c, TransportId t,
long version) throws DbException;
/**
* Updates the expiry times of the given messages with respect to the given
* contact, using the given transmission counts and the latency of the
* transport over which they were sent.
* <p>
* Locking: contact read, message write.
*/
void updateExpiryTimes(T txn, ContactId c, Map<MessageId, Integer> sent,
long maxLatency) throws DbException;
}

View File

@@ -13,6 +13,7 @@ import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
@@ -491,6 +492,7 @@ DatabaseCleaner.Callback {
public Collection<byte[]> generateBatch(ContactId c, int maxLength,
long maxLatency) throws DbException {
Collection<MessageId> ids;
Map<MessageId, Integer> sent = new HashMap<MessageId, Integer>();
List<byte[]> messages = new ArrayList<byte[]>();
// Get some sendable messages from the database
contactLock.readLock().lock();
@@ -504,8 +506,10 @@ DatabaseCleaner.Callback {
if(!db.containsContact(txn, c))
throw new NoSuchContactException();
ids = db.getSendableMessages(txn, c, maxLength);
for(MessageId m : ids)
for(MessageId m : ids) {
messages.add(db.getRawMessage(txn, m));
sent.put(m, db.getTransmissionCount(txn, c, m));
}
db.commitTransaction(txn);
} catch(DbException e) {
db.abortTransaction(txn);
@@ -523,7 +527,7 @@ DatabaseCleaner.Callback {
try {
T txn = db.startTransaction();
try {
db.setMessageExpiry(txn, c, ids, maxLatency);
db.updateExpiryTimes(txn, c, sent, maxLatency);
db.commitTransaction(txn);
} catch(DbException e) {
db.abortTransaction(txn);
@@ -541,7 +545,7 @@ DatabaseCleaner.Callback {
public Collection<byte[]> generateBatch(ContactId c, int maxLength,
long maxLatency, Collection<MessageId> requested)
throws DbException {
Collection<MessageId> ids = new ArrayList<MessageId>();
Map<MessageId, Integer> sent = new HashMap<MessageId, Integer>();
List<byte[]> messages = new ArrayList<byte[]>();
// Get some sendable messages from the database
contactLock.readLock().lock();
@@ -561,7 +565,7 @@ DatabaseCleaner.Callback {
if(raw != null) {
if(raw.length > maxLength) break;
messages.add(raw);
ids.add(m);
sent.put(m, db.getTransmissionCount(txn, c, m));
maxLength -= raw.length;
}
it.remove();
@@ -583,7 +587,7 @@ DatabaseCleaner.Callback {
try {
T txn = db.startTransaction();
try {
db.setMessageExpiry(txn, c, ids, maxLatency);
db.updateExpiryTimes(txn, c, sent, maxLatency);
db.commitTransaction(txn);
} catch(DbException e) {
db.abortTransaction(txn);

View File

@@ -99,6 +99,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " remoteVersion BIGINT UNSIGNED NOT NULL,"
+ " remoteAcked BOOLEAN NOT NULL,"
+ " expiry BIGINT UNSIGNED NOT NULL,"
+ " txCount INT UNSIGNED NOT NULL,"
+ " PRIMARY KEY (contactId),"
+ " FOREIGN KEY (contactid)"
+ " REFERENCES contacts (contactId)"
@@ -158,6 +159,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " contactId INT UNSIGNED NOT NULL,"
+ " seen BOOLEAN NOT NULL,"
+ " expiry BIGINT UNSIGNED NOT NULL,"
+ " txCount INT UNSIGNED NOT NULL,"
+ " PRIMARY KEY (messageId, contactId),"
+ " FOREIGN KEY (messageId)"
+ " REFERENCES messages (messageId)"
@@ -189,6 +191,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " remoteVersion BIGINT UNSIGNED NOT NULL,"
+ " remoteAcked BOOLEAN NOT NULL,"
+ " expiry BIGINT UNSIGNED NOT NULL,"
+ " txCount INT UNSIGNED NOT NULL,"
+ " PRIMARY KEY (contactId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
@@ -228,6 +231,7 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " localVersion BIGINT UNSIGNED NOT NULL,"
+ " localAcked BIGINT UNSIGNED NOT NULL,"
+ " expiry BIGINT UNSIGNED NOT NULL,"
+ " txCount INT UNSIGNED NOT NULL,"
+ " PRIMARY KEY (contactId, transportId),"
+ " FOREIGN KEY (contactId)"
+ " REFERENCES contacts (contactId)"
@@ -508,10 +512,11 @@ abstract class JdbcDatabase implements Database<Connection> {
rs.close();
ps.close();
// Create a retention version row
sql = "INSERT INTO retentionVersions"
+ " (contactId, retention, localVersion, localAcked,"
+ " remoteVersion, remoteAcked, expiry)"
+ " VALUES (?, ZERO(), ?, ZERO(), ZERO(), TRUE, ZERO())";
sql = "INSERT INTO retentionVersions (contactId, retention,"
+ " localVersion, localAcked, remoteVersion, remoteAcked,"
+ " expiry, txCount)"
+ " VALUES (?, ZERO(), ?, ZERO(), ZERO(), TRUE, ZERO(),"
+ " ZERO())";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(2, 1);
@@ -520,8 +525,9 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.close();
// Create a group version row
sql = "INSERT INTO groupVersions (contactId, localVersion,"
+ " localAcked, remoteVersion, remoteAcked, expiry)"
+ " VALUES (?, ?, ZERO(), ZERO(), TRUE, ZERO())";
+ " localAcked, remoteVersion, remoteAcked, expiry,"
+ " txCount)"
+ " VALUES (?, ?, ZERO(), ZERO(), TRUE, ZERO(), ZERO())";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(2, 1);
@@ -538,8 +544,8 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.close();
if(transports.isEmpty()) return c;
sql = "INSERT INTO transportVersions (contactId, transportId,"
+ " localVersion, localAcked, expiry)"
+ " VALUES (?, ?, ?, ZERO(), ZERO())";
+ " localVersion, localAcked, expiry, txCount)"
+ " VALUES (?, ?, ?, ZERO(), ZERO(), ZERO())";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setInt(3, 1);
@@ -687,8 +693,8 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
try {
String sql = "INSERT INTO statuses"
+ " (messageId, contactId, seen, expiry)"
+ " VALUES (?, ?, ?, ZERO())";
+ " (messageId, contactId, seen, expiry, txCount)"
+ " VALUES (?, ?, ?, ZERO(), ZERO())";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getBytes());
ps.setInt(2, c.getInt());
@@ -790,8 +796,8 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.close();
if(contacts.isEmpty()) return;
sql = "INSERT INTO transportVersions (contactId, transportId,"
+ " localVersion, localAcked, expiry)"
+ " VALUES (?, ?, ?, ZERO(), ZERO())";
+ " localVersion, localAcked, expiry, txCount)"
+ " VALUES (?, ?, ?, ZERO(), ZERO(), ZERO())";
ps = txn.prepareStatement(sql);
ps.setBytes(2, t.getBytes());
ps.setInt(3, 1);
@@ -826,7 +832,8 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.close();
// Bump the subscription version
sql = "UPDATE groupVersions"
+ " SET localVersion = localVersion + ?, expiry = ZERO()"
+ " SET localVersion = localVersion + ?,"
+ " expiry = ZERO(), txCount = ZERO()"
+ " WHERE contactId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, 1);
@@ -1541,7 +1548,7 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT timestamp, localVersion"
String sql = "SELECT timestamp, localVersion, txCount"
+ " FROM messages AS m"
+ " JOIN retentionVersions AS rv"
+ " WHERE rv.contactId = ?"
@@ -1561,13 +1568,17 @@ abstract class JdbcDatabase implements Database<Connection> {
long retention = rs.getLong(1);
retention -= retention % RETENTION_MODULUS;
long version = rs.getLong(2);
int txCount = rs.getInt(3);
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
sql = "UPDATE retentionVersions SET expiry = ? WHERE contactId = ?";
sql = "UPDATE retentionVersions"
+ " SET expiry = ?, txCount = txCount + ?"
+ " WHERE contactId = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, calculateExpiry(now, maxLatency));
ps.setInt(2, c.getInt());
ps.setLong(1, calculateExpiry(now, maxLatency, txCount));
ps.setInt(2, 1);
ps.setInt(3, c.getInt());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
@@ -1817,7 +1828,7 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT g.groupId, name, key, localVersion"
String sql = "SELECT g.groupId, name, key, localVersion, txCount"
+ " FROM groups AS g"
+ " JOIN groupVisibilities AS vis"
+ " ON g.groupId = vis.groupId"
@@ -1832,20 +1843,25 @@ abstract class JdbcDatabase implements Database<Connection> {
rs = ps.executeQuery();
List<Group> subs = new ArrayList<Group>();
long version = 0;
int txCount = 0;
while(rs.next()) {
byte[] id = rs.getBytes(1);
String name = rs.getString(2);
byte[] key = rs.getBytes(3);
version = rs.getLong(4);
subs.add(new Group(new GroupId(id), name, key));
version = rs.getLong(4);
txCount = rs.getInt(5);
}
rs.close();
ps.close();
if(subs.isEmpty()) return null;
sql = "UPDATE groupVersions SET expiry = ? WHERE contactId = ?";
sql = "UPDATE groupVersions"
+ " SET expiry = ?, txCount = txCount + ?"
+ " WHERE contactId = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, calculateExpiry(now, maxLatency));
ps.setInt(2, c.getInt());
ps.setLong(1, calculateExpiry(now, maxLatency, txCount));
ps.setInt(2, 1);
ps.setInt(3, c.getInt());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
@@ -1858,6 +1874,30 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public int getTransmissionCount(Connection txn, ContactId c, MessageId m)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT txCount FROM statuses"
+ " WHERE messageId = ? AND contactId = ?";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getBytes());
ps.setInt(2, c.getInt());
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
int txCount = rs.getInt(1);
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
return txCount;
} catch(SQLException e) {
tryToClose(ps);
tryToClose(rs);
throw new DbException(e);
}
}
public Collection<TransportAck> getTransportAcks(Connection txn,
ContactId c) throws DbException {
PreparedStatement ps = null;
@@ -1906,7 +1946,8 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT tp.transportId, key, value, localVersion"
String sql = "SELECT tp.transportId, key, value, localVersion,"
+ " txCount"
+ " FROM transportProperties AS tp"
+ " JOIN transportVersions AS tv"
+ " ON tp.transportId = tv.transportId"
@@ -1920,32 +1961,39 @@ abstract class JdbcDatabase implements Database<Connection> {
List<TransportUpdate> updates = new ArrayList<TransportUpdate>();
TransportId lastId = null;
TransportProperties p = null;
List<Integer> txCounts = new ArrayList<Integer>();
while(rs.next()) {
TransportId id = new TransportId(rs.getBytes(1));
String key = rs.getString(2), value = rs.getString(3);
long version = rs.getLong(4);
int txCount = rs.getInt(5);
if(!id.equals(lastId)) {
p = new TransportProperties();
updates.add(new TransportUpdate(id, p, version));
txCounts.add(txCount);
}
p.put(key, value);
}
rs.close();
ps.close();
if(updates.isEmpty()) return null;
sql = "UPDATE transportVersions SET expiry = ?"
sql = "UPDATE transportVersions"
+ " SET expiry = ?, txCount = txCount + ?"
+ " WHERE contactId = ? AND transportId = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, calculateExpiry(now, maxLatency));
ps.setInt(2, c.getInt());
ps.setInt(2, 1);
ps.setInt(3, c.getInt());
int i = 0;
for(TransportUpdate u : updates) {
int txCount = txCounts.get(i++);
ps.setLong(1, calculateExpiry(now, maxLatency, txCount));
ps.setBytes(3, u.getId().getBytes());
ps.addBatch();
}
int [] batchAffected = ps.executeBatch();
if(batchAffected.length != updates.size())
throw new DbStateException();
for(int i = 0; i < batchAffected.length; i++) {
for(i = 0; i < batchAffected.length; i++) {
if(batchAffected[i] != 1) throw new DbStateException();
}
ps.close();
@@ -2405,41 +2453,6 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public void setMessageExpiry(Connection txn, ContactId c,
Collection<MessageId> sent, long maxLatency) throws DbException {
long now = clock.currentTimeMillis();
PreparedStatement ps = null;
try {
String sql = "UPDATE statuses SET expiry = ?"
+ " WHERE messageId = ? AND contactId = ?";
ps = txn.prepareStatement(sql);
ps.setLong(1, calculateExpiry(now, maxLatency));
ps.setInt(3, c.getInt());
for(MessageId m : sent) {
ps.setBytes(2, m.getBytes());
ps.addBatch();
}
int[] batchAffected = ps.executeBatch();
if(batchAffected.length != sent.size())
throw new DbStateException();
for(int i = 0; i < batchAffected.length; i++) {
if(batchAffected[i] > 1) throw new DbStateException();
}
ps.close();
} catch(SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
private long calculateExpiry(long now, long maxLatency) {
long roundTrip = maxLatency * 2;
if(roundTrip < 0) return Long.MAX_VALUE; // Overflow;
long expiry = now + roundTrip;
if(expiry < 0) return Long.MAX_VALUE; // Overflow
return expiry;
}
public Rating setRating(Connection txn, AuthorId a, Rating r)
throws DbException {
PreparedStatement ps = null;
@@ -2835,4 +2848,46 @@ abstract class JdbcDatabase implements Database<Connection> {
throw new DbException(e);
}
}
public void updateExpiryTimes(Connection txn, ContactId c,
Map<MessageId, Integer> sent, long maxLatency) throws DbException {
long now = clock.currentTimeMillis();
PreparedStatement ps = null;
try {
String sql = "UPDATE statuses"
+ " SET expiry = ?, txCount = txCount + ?"
+ " WHERE messageId = ? AND contactId = ?";
ps = txn.prepareStatement(sql);
ps.setInt(2, 1);
ps.setInt(4, c.getInt());
for(Entry<MessageId, Integer> e : sent.entrySet()) {
ps.setLong(1, calculateExpiry(now, maxLatency, e.getValue()));
ps.setBytes(3, e.getKey().getBytes());
ps.addBatch();
}
int[] batchAffected = ps.executeBatch();
if(batchAffected.length != sent.size())
throw new DbStateException();
for(int i = 0; i < batchAffected.length; i++) {
if(batchAffected[i] > 1) throw new DbStateException();
}
ps.close();
} catch(SQLException e) {
tryToClose(ps);
throw new DbException(e);
}
}
// FIXME: Refactor the exponential backoff logic into a separate class
private long calculateExpiry(long now, long maxLatency, int txCount) {
long roundTrip = maxLatency * 2;
if(roundTrip < 0) return Long.MAX_VALUE;
for(int i = 0; i < txCount; i++) {
roundTrip <<= 1;
if(roundTrip < 0) return Long.MAX_VALUE;
}
long expiry = now + roundTrip;
if(expiry < 0) return Long.MAX_VALUE;
return expiry;
}
}

View File

@@ -8,6 +8,8 @@ import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import net.sf.briar.BriarTestCase;
import net.sf.briar.TestMessage;
@@ -677,6 +679,9 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
final Collection<MessageId> sendable = Arrays.asList(messageId,
messageId1);
final Collection<byte[]> messages = Arrays.asList(raw, raw1);
final Map<MessageId, Integer> sent = new HashMap<MessageId, Integer>();
sent.put(messageId, 1);
sent.put(messageId1, 2);
Mockery context = new Mockery();
@SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class);
@@ -688,15 +693,19 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
allowing(database).commitTransaction(txn);
allowing(database).containsContact(txn, contactId);
will(returnValue(true));
// Get the sendable messages
// Get the sendable messages and their transmission counts
oneOf(database).getSendableMessages(txn, contactId, size * 2);
will(returnValue(sendable));
oneOf(database).getRawMessage(txn, messageId);
will(returnValue(raw));
oneOf(database).getTransmissionCount(txn, contactId, messageId);
will(returnValue(1));
oneOf(database).getRawMessage(txn, messageId1);
will(returnValue(raw1));
oneOf(database).getTransmissionCount(txn, contactId, messageId1);
will(returnValue(2));
// Record the outstanding messages
oneOf(database).setMessageExpiry(txn, contactId, sendable,
oneOf(database).updateExpiryTimes(txn, contactId, sent,
Long.MAX_VALUE);
}});
DatabaseComponent db = createDatabaseComponent(database, cleaner,
@@ -731,11 +740,13 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
will(returnValue(null)); // Message is not sendable
oneOf(database).getRawMessageIfSendable(txn, contactId, messageId1);
will(returnValue(raw1)); // Message is sendable
oneOf(database).getTransmissionCount(txn, contactId, messageId1);
will(returnValue(2));
oneOf(database).getRawMessageIfSendable(txn, contactId, messageId2);
will(returnValue(null)); // Message is not sendable
// Mark the message as sent
oneOf(database).setMessageExpiry(txn, contactId,
Arrays.asList(messageId1), Long.MAX_VALUE);
oneOf(database).updateExpiryTimes(txn, contactId,
Collections.singletonMap(messageId1, 2), Long.MAX_VALUE);
}});
DatabaseComponent db = createDatabaseComponent(database, cleaner,
shutdown);

View File

@@ -523,8 +523,8 @@ public class H2DatabaseTest extends BriarTestCase {
assertTrue(it.hasNext());
assertEquals(messageId, it.next());
assertFalse(it.hasNext());
db.setMessageExpiry(txn, contactId, Arrays.asList(messageId),
Long.MAX_VALUE);
db.updateExpiryTimes(txn, contactId,
Collections.singletonMap(messageId, 0), Long.MAX_VALUE);
// The message should no longer be sendable
it = db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();