Merge branch '617-protocol-versioning-for-contact-exchange' into 'master'

Protocol versioning for the contact exchange protocol

Closes #617

See merge request akwizgran/briar!765
This commit is contained in:
akwizgran
2018-04-29 16:40:05 +00:00
49 changed files with 1164 additions and 747 deletions

View File

@@ -13,9 +13,9 @@ import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
public interface ContactExchangeTask {
/**
* The current version of the contact exchange protocol
* The current version of the contact exchange protocol.
*/
int PROTOCOL_VERSION = 0;
byte PROTOCOL_VERSION = 1;
/**
* Label for deriving Alice's header key from the master secret.

View File

@@ -0,0 +1,9 @@
package org.briarproject.bramble.api.contact;
/**
* Record types for the contact exchange protocol.
*/
public interface RecordTypes {
byte CONTACT_INFO = 0;
}

View File

@@ -24,9 +24,9 @@ public class BdfDictionary extends TreeMap<String, Object> {
* );
* </pre>
*/
public static BdfDictionary of(Entry<String, Object>... entries) {
public static BdfDictionary of(Entry<String, ?>... entries) {
BdfDictionary d = new BdfDictionary();
for (Entry<String, Object> e : entries) d.put(e.getKey(), e.getValue());
for (Entry<String, ?> e : entries) d.put(e.getKey(), e.getValue());
return d;
}

View File

@@ -0,0 +1,36 @@
package org.briarproject.bramble.api.record;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import javax.annotation.concurrent.Immutable;
@Immutable
@NotNullByDefault
public class Record {
public static final int RECORD_HEADER_BYTES = 4;
public static final int MAX_RECORD_PAYLOAD_BYTES = 48 * 1024; // 48 KiB
private final byte protocolVersion, recordType;
private final byte[] payload;
public Record(byte protocolVersion, byte recordType, byte[] payload) {
if (payload.length > MAX_RECORD_PAYLOAD_BYTES)
throw new IllegalArgumentException();
this.protocolVersion = protocolVersion;
this.recordType = recordType;
this.payload = payload;
}
public byte getProtocolVersion() {
return protocolVersion;
}
public byte getRecordType() {
return recordType;
}
public byte[] getPayload() {
return payload;
}
}

View File

@@ -0,0 +1,20 @@
package org.briarproject.bramble.api.record;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import java.io.EOFException;
import java.io.IOException;
@NotNullByDefault
public interface RecordReader {
/**
* Reads and returns the next record.
*
* @throws EOFException if the end of the stream is reached without reading
* a complete record
*/
Record readRecord() throws IOException;
void close() throws IOException;
}

View File

@@ -0,0 +1,8 @@
package org.briarproject.bramble.api.record;
import java.io.InputStream;
public interface RecordReaderFactory {
RecordReader createRecordReader(InputStream in);
}

View File

@@ -0,0 +1,15 @@
package org.briarproject.bramble.api.record;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import java.io.IOException;
@NotNullByDefault
public interface RecordWriter {
void writeRecord(Record r) throws IOException;
void flush() throws IOException;
void close() throws IOException;
}

View File

@@ -0,0 +1,8 @@
package org.briarproject.bramble.api.record;
import java.io.OutputStream;
public interface RecordWriterFactory {
RecordWriter createRecordWriter(OutputStream out);
}

View File

@@ -7,5 +7,7 @@ public interface MessageFactory {
Message createMessage(GroupId g, long timestamp, byte[] body);
Message createMessage(byte[] raw);
Message createMessage(MessageId m, byte[] raw);
}

View File

@@ -2,6 +2,8 @@ package org.briarproject.bramble.api.sync;
import org.briarproject.bramble.api.UniqueId;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
public interface SyncConstants {
/**
@@ -9,16 +11,6 @@ public interface SyncConstants {
*/
byte PROTOCOL_VERSION = 0;
/**
* The length of the record header in bytes.
*/
int RECORD_HEADER_LENGTH = 4;
/**
* The maximum length of the record payload in bytes.
*/
int MAX_RECORD_PAYLOAD_LENGTH = 48 * 1024; // 48 KiB
/**
* The maximum length of a group descriptor in bytes.
*/
@@ -42,5 +34,5 @@ public interface SyncConstants {
/**
* The maximum number of message IDs in an ack, offer or request record.
*/
int MAX_MESSAGE_IDS = MAX_RECORD_PAYLOAD_LENGTH / UniqueId.LENGTH;
int MAX_MESSAGE_IDS = MAX_RECORD_PAYLOAD_BYTES / UniqueId.LENGTH;
}

View File

@@ -5,7 +5,7 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import java.io.IOException;
@NotNullByDefault
public interface RecordReader {
public interface SyncRecordReader {
boolean eof() throws IOException;

View File

@@ -5,7 +5,7 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import java.io.InputStream;
@NotNullByDefault
public interface RecordReaderFactory {
public interface SyncRecordReaderFactory {
RecordReader createRecordReader(InputStream in);
SyncRecordReader createRecordReader(InputStream in);
}

View File

@@ -5,7 +5,7 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import java.io.IOException;
@NotNullByDefault
public interface RecordWriter {
public interface SyncRecordWriter {
void writeAck(Ack a) throws IOException;

View File

@@ -5,7 +5,7 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import java.io.OutputStream;
@NotNullByDefault
public interface RecordWriterFactory {
public interface SyncRecordWriterFactory {
RecordWriter createRecordWriter(OutputStream out);
SyncRecordWriter createRecordWriter(OutputStream out);
}

View File

@@ -13,6 +13,7 @@ import org.briarproject.bramble.keyagreement.KeyAgreementModule;
import org.briarproject.bramble.lifecycle.LifecycleModule;
import org.briarproject.bramble.plugin.PluginModule;
import org.briarproject.bramble.properties.PropertiesModule;
import org.briarproject.bramble.record.RecordModule;
import org.briarproject.bramble.reliability.ReliabilityModule;
import org.briarproject.bramble.reporting.ReportingModule;
import org.briarproject.bramble.settings.SettingsModule;
@@ -38,6 +39,7 @@ import dagger.Module;
LifecycleModule.class,
PluginModule.class,
PropertiesModule.class,
RecordModule.class,
ReliabilityModule.class,
ReportingModule.class,
SettingsModule.class,

View File

@@ -1,23 +1,20 @@
package org.briarproject.bramble.contact;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.client.ClientHelper;
import org.briarproject.bramble.api.contact.ContactExchangeListener;
import org.briarproject.bramble.api.contact.ContactExchangeTask;
import org.briarproject.bramble.api.contact.ContactId;
import org.briarproject.bramble.api.contact.ContactManager;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.data.BdfDictionary;
import org.briarproject.bramble.api.data.BdfList;
import org.briarproject.bramble.api.data.BdfReader;
import org.briarproject.bramble.api.data.BdfReaderFactory;
import org.briarproject.bramble.api.data.BdfWriter;
import org.briarproject.bramble.api.data.BdfWriterFactory;
import org.briarproject.bramble.api.db.ContactExistsException;
import org.briarproject.bramble.api.db.DatabaseComponent;
import org.briarproject.bramble.api.db.DbException;
import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.identity.Author;
import org.briarproject.bramble.api.identity.AuthorFactory;
import org.briarproject.bramble.api.identity.LocalAuthor;
import org.briarproject.bramble.api.nullsafety.MethodsNotNullByDefault;
import org.briarproject.bramble.api.nullsafety.ParametersNotNullByDefault;
@@ -26,30 +23,30 @@ import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.properties.TransportProperties;
import org.briarproject.bramble.api.properties.TransportPropertyManager;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import org.briarproject.bramble.api.system.Clock;
import org.briarproject.bramble.api.transport.StreamReaderFactory;
import org.briarproject.bramble.api.transport.StreamWriterFactory;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Logger;
import javax.inject.Inject;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.bramble.api.identity.Author.FORMAT_VERSION;
import static org.briarproject.bramble.api.identity.AuthorConstants.MAX_AUTHOR_NAME_LENGTH;
import static org.briarproject.bramble.api.identity.AuthorConstants.MAX_PUBLIC_KEY_LENGTH;
import static org.briarproject.bramble.api.contact.RecordTypes.CONTACT_INFO;
import static org.briarproject.bramble.api.identity.AuthorConstants.MAX_SIGNATURE_LENGTH;
import static org.briarproject.bramble.api.plugin.TransportId.MAX_TRANSPORT_ID_LENGTH;
import static org.briarproject.bramble.api.properties.TransportPropertyConstants.MAX_PROPERTIES_PER_TRANSPORT;
import static org.briarproject.bramble.api.properties.TransportPropertyConstants.MAX_PROPERTY_LENGTH;
import static org.briarproject.bramble.util.ValidationUtils.checkLength;
import static org.briarproject.bramble.util.ValidationUtils.checkSize;
@MethodsNotNullByDefault
@ParametersNotNullByDefault
@@ -62,9 +59,9 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask {
"org.briarproject.briar.contact/EXCHANGE";
private final DatabaseComponent db;
private final AuthorFactory authorFactory;
private final BdfReaderFactory bdfReaderFactory;
private final BdfWriterFactory bdfWriterFactory;
private final ClientHelper clientHelper;
private final RecordReaderFactory recordReaderFactory;
private final RecordWriterFactory recordWriterFactory;
private final Clock clock;
private final ConnectionManager connectionManager;
private final ContactManager contactManager;
@@ -81,17 +78,17 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask {
private volatile boolean alice;
@Inject
ContactExchangeTaskImpl(DatabaseComponent db,
AuthorFactory authorFactory, BdfReaderFactory bdfReaderFactory,
BdfWriterFactory bdfWriterFactory, Clock clock,
ContactExchangeTaskImpl(DatabaseComponent db, ClientHelper clientHelper,
RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory, Clock clock,
ConnectionManager connectionManager, ContactManager contactManager,
TransportPropertyManager transportPropertyManager,
CryptoComponent crypto, StreamReaderFactory streamReaderFactory,
StreamWriterFactory streamWriterFactory) {
this.db = db;
this.authorFactory = authorFactory;
this.bdfReaderFactory = bdfReaderFactory;
this.bdfWriterFactory = bdfWriterFactory;
this.clientHelper = clientHelper;
this.recordReaderFactory = recordReaderFactory;
this.recordWriterFactory = recordWriterFactory;
this.clock = clock;
this.connectionManager = connectionManager;
this.contactManager = contactManager;
@@ -126,18 +123,18 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask {
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
listener.contactExchangeFailed();
tryToClose(conn, true);
tryToClose(conn);
return;
}
// Get the local transport properties
Map<TransportId, TransportProperties> localProperties, remoteProperties;
Map<TransportId, TransportProperties> localProperties;
try {
localProperties = transportPropertyManager.getLocalProperties();
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
listener.contactExchangeFailed();
tryToClose(conn, true);
tryToClose(conn);
return;
}
@@ -151,159 +148,138 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask {
InputStream streamReader =
streamReaderFactory.createContactExchangeStreamReader(in,
alice ? bobHeaderKey : aliceHeaderKey);
BdfReader r = bdfReaderFactory.createReader(streamReader);
RecordReader recordReader =
recordReaderFactory.createRecordReader(streamReader);
// Create the writers
OutputStream streamWriter =
streamWriterFactory.createContactExchangeStreamWriter(out,
alice ? aliceHeaderKey : bobHeaderKey);
BdfWriter w = bdfWriterFactory.createWriter(streamWriter);
RecordWriter recordWriter =
recordWriterFactory.createRecordWriter(streamWriter);
// Derive the nonces to be signed
byte[] aliceNonce = crypto.mac(ALICE_NONCE_LABEL, masterSecret,
new byte[] {PROTOCOL_VERSION});
byte[] bobNonce = crypto.mac(BOB_NONCE_LABEL, masterSecret,
new byte[] {PROTOCOL_VERSION});
byte[] localNonce = alice ? aliceNonce : bobNonce;
byte[] remoteNonce = alice ? bobNonce : aliceNonce;
// Exchange pseudonyms, signed nonces, and timestamps
// Sign the nonce
byte[] localSignature = sign(localAuthor, localNonce);
// Exchange contact info
long localTimestamp = clock.currentTimeMillis();
Author remoteAuthor;
long remoteTimestamp;
ContactInfo remoteInfo;
try {
if (alice) {
sendPseudonym(w, aliceNonce);
sendTimestamp(w, localTimestamp);
sendTransportProperties(w, localProperties);
w.flush();
remoteAuthor = receivePseudonym(r, bobNonce);
remoteTimestamp = receiveTimestamp(r);
remoteProperties = receiveTransportProperties(r);
sendContactInfo(recordWriter, localAuthor, localProperties,
localSignature, localTimestamp);
recordWriter.flush();
remoteInfo = receiveContactInfo(recordReader);
} else {
remoteAuthor = receivePseudonym(r, aliceNonce);
remoteTimestamp = receiveTimestamp(r);
remoteProperties = receiveTransportProperties(r);
sendPseudonym(w, bobNonce);
sendTimestamp(w, localTimestamp);
sendTransportProperties(w, localProperties);
w.flush();
remoteInfo = receiveContactInfo(recordReader);
sendContactInfo(recordWriter, localAuthor, localProperties,
localSignature, localTimestamp);
recordWriter.flush();
}
// Close the outgoing stream and expect EOF on the incoming stream
w.close();
if (!r.eof()) LOG.warning("Unexpected data at end of connection");
} catch (GeneralSecurityException | IOException e) {
// Close the outgoing stream
recordWriter.close();
// Skip any remaining records from the incoming stream
try {
while (true) recordReader.readRecord();
} catch (EOFException expected) {
LOG.info("End of stream");
}
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
listener.contactExchangeFailed();
tryToClose(conn, true);
tryToClose(conn);
return;
}
// Verify the contact's signature
if (!verify(remoteInfo.author, remoteNonce, remoteInfo.signature)) {
LOG.warning("Invalid signature");
listener.contactExchangeFailed();
tryToClose(conn);
return;
}
// The agreed timestamp is the minimum of the peers' timestamps
long timestamp = Math.min(localTimestamp, remoteTimestamp);
long timestamp = Math.min(localTimestamp, remoteInfo.timestamp);
try {
// Add the contact
ContactId contactId = addContact(remoteAuthor, timestamp,
remoteProperties);
ContactId contactId = addContact(remoteInfo.author, timestamp,
remoteInfo.properties);
// Reuse the connection as a transport connection
connectionManager.manageOutgoingConnection(contactId, transportId,
conn);
// Pseudonym exchange succeeded
LOG.info("Pseudonym exchange succeeded");
listener.contactExchangeSucceeded(remoteAuthor);
listener.contactExchangeSucceeded(remoteInfo.author);
} catch (ContactExistsException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
tryToClose(conn, true);
listener.duplicateContact(remoteAuthor);
tryToClose(conn);
listener.duplicateContact(remoteInfo.author);
} catch (DbException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
tryToClose(conn, true);
tryToClose(conn);
listener.contactExchangeFailed();
}
}
private void sendPseudonym(BdfWriter w, byte[] nonce)
throws GeneralSecurityException, IOException {
// Sign the nonce
byte[] privateKey = localAuthor.getPrivateKey();
byte[] sig = crypto.sign(SIGNING_LABEL_EXCHANGE, nonce, privateKey);
// Write the name, public key and signature
w.writeListStart();
w.writeLong(localAuthor.getFormatVersion());
w.writeString(localAuthor.getName());
w.writeRaw(localAuthor.getPublicKey());
w.writeRaw(sig);
w.writeListEnd();
LOG.info("Sent pseudonym");
}
private Author receivePseudonym(BdfReader r, byte[] nonce)
throws GeneralSecurityException, IOException {
// Read the format version, name, public key and signature
r.readListStart();
int formatVersion = (int) r.readLong();
if (formatVersion != FORMAT_VERSION) throw new FormatException();
String name = r.readString(MAX_AUTHOR_NAME_LENGTH);
if (name.isEmpty()) throw new FormatException();
byte[] publicKey = r.readRaw(MAX_PUBLIC_KEY_LENGTH);
if (publicKey.length == 0) throw new FormatException();
byte[] sig = r.readRaw(MAX_SIGNATURE_LENGTH);
if (sig.length == 0) throw new FormatException();
r.readListEnd();
LOG.info("Received pseudonym");
// Verify the signature
if (!crypto.verifySignature(sig, SIGNING_LABEL_EXCHANGE, nonce,
publicKey)) {
if (LOG.isLoggable(INFO))
LOG.info("Invalid signature");
throw new GeneralSecurityException();
private byte[] sign(LocalAuthor author, byte[] nonce) {
try {
return crypto.sign(SIGNING_LABEL_EXCHANGE, nonce,
author.getPrivateKey());
} catch (GeneralSecurityException e) {
throw new AssertionError();
}
return authorFactory.createAuthor(formatVersion, name, publicKey);
}
private void sendTimestamp(BdfWriter w, long timestamp)
private boolean verify(Author author, byte[] nonce, byte[] signature) {
try {
return crypto.verifySignature(signature, SIGNING_LABEL_EXCHANGE,
nonce, author.getPublicKey());
} catch (GeneralSecurityException e) {
return false;
}
}
private void sendContactInfo(RecordWriter recordWriter, Author author,
Map<TransportId, TransportProperties> properties, byte[] signature,
long timestamp) throws IOException {
BdfList authorList = clientHelper.toList(author);
BdfDictionary props = clientHelper.toDictionary(properties);
BdfList payload = BdfList.of(authorList, props, signature, timestamp);
recordWriter.writeRecord(new Record(PROTOCOL_VERSION, CONTACT_INFO,
clientHelper.toByteArray(payload)));
LOG.info("Sent contact info");
}
private ContactInfo receiveContactInfo(RecordReader recordReader)
throws IOException {
w.writeLong(timestamp);
LOG.info("Sent timestamp");
}
private long receiveTimestamp(BdfReader r) throws IOException {
long timestamp = r.readLong();
Record record;
do {
record = recordReader.readRecord();
if (record.getProtocolVersion() != PROTOCOL_VERSION)
throw new FormatException();
} while (record.getRecordType() != CONTACT_INFO);
LOG.info("Received contact info");
BdfList payload = clientHelper.toList(record.getPayload());
checkSize(payload, 4);
Author author = clientHelper.parseAndValidateAuthor(payload.getList(0));
BdfDictionary props = payload.getDictionary(1);
Map<TransportId, TransportProperties> properties =
clientHelper.parseAndValidateTransportPropertiesMap(props);
byte[] signature = payload.getRaw(2);
checkLength(signature, 1, MAX_SIGNATURE_LENGTH);
long timestamp = payload.getLong(3);
if (timestamp < 0) throw new FormatException();
LOG.info("Received timestamp");
return timestamp;
}
private void sendTransportProperties(BdfWriter w,
Map<TransportId, TransportProperties> local) throws IOException {
w.writeListStart();
for (Entry<TransportId, TransportProperties> e : local.entrySet())
w.writeList(BdfList.of(e.getKey().getString(), e.getValue()));
w.writeListEnd();
}
private Map<TransportId, TransportProperties> receiveTransportProperties(
BdfReader r) throws IOException {
Map<TransportId, TransportProperties> remote = new HashMap<>();
r.readListStart();
while (!r.hasListEnd()) {
r.readListStart();
String id = r.readString(MAX_TRANSPORT_ID_LENGTH);
if (id.isEmpty()) throw new FormatException();
TransportProperties p = new TransportProperties();
r.readDictionaryStart();
while (!r.hasDictionaryEnd()) {
if (p.size() == MAX_PROPERTIES_PER_TRANSPORT)
throw new FormatException();
String key = r.readString(MAX_PROPERTY_LENGTH);
String value = r.readString(MAX_PROPERTY_LENGTH);
p.put(key, value);
}
r.readDictionaryEnd();
r.readListEnd();
remote.put(new TransportId(id), p);
}
r.readListEnd();
return remote;
return new ContactInfo(author, properties, signature, timestamp);
}
private ContactId addContact(Author remoteAuthor, long timestamp,
@@ -324,13 +300,30 @@ class ContactExchangeTaskImpl extends Thread implements ContactExchangeTask {
return contactId;
}
private void tryToClose(DuplexTransportConnection conn, boolean exception) {
private void tryToClose(DuplexTransportConnection conn) {
try {
LOG.info("Closing connection");
conn.getReader().dispose(exception, true);
conn.getWriter().dispose(exception);
conn.getReader().dispose(true, true);
conn.getWriter().dispose(true);
} catch (IOException e) {
if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
}
}
private static class ContactInfo {
private final Author author;
private final Map<TransportId, TransportProperties> properties;
private final byte[] signature;
private final long timestamp;
private ContactInfo(Author author,
Map<TransportId, TransportProperties> properties,
byte[] signature, long timestamp) {
this.author = author;
this.properties = properties;
this.signature = signature;
this.timestamp = timestamp;
}
}
}

View File

@@ -13,6 +13,8 @@ import org.briarproject.bramble.api.plugin.PluginManager;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexPlugin;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import java.io.IOException;
import java.io.InputStream;
@@ -44,6 +46,8 @@ class KeyAgreementConnector {
private final KeyAgreementCrypto keyAgreementCrypto;
private final PluginManager pluginManager;
private final ConnectionChooser connectionChooser;
private final RecordReaderFactory recordReaderFactory;
private final RecordWriterFactory recordWriterFactory;
private final List<KeyAgreementListener> listeners =
new CopyOnWriteArrayList<>();
@@ -54,11 +58,15 @@ class KeyAgreementConnector {
KeyAgreementConnector(Callbacks callbacks,
KeyAgreementCrypto keyAgreementCrypto, PluginManager pluginManager,
ConnectionChooser connectionChooser) {
ConnectionChooser connectionChooser,
RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory) {
this.callbacks = callbacks;
this.keyAgreementCrypto = keyAgreementCrypto;
this.pluginManager = pluginManager;
this.connectionChooser = connectionChooser;
this.recordReaderFactory = recordReaderFactory;
this.recordWriterFactory = recordWriterFactory;
}
Payload listen(KeyPair localKeyPair) {
@@ -119,7 +127,8 @@ class KeyAgreementConnector {
KeyAgreementConnection chosen =
connectionChooser.poll(CONNECTION_TIMEOUT);
if (chosen == null) return null;
return new KeyAgreementTransport(chosen);
return new KeyAgreementTransport(recordReaderFactory,
recordWriterFactory, chosen);
} catch (InterruptedException e) {
LOG.info("Interrupted while waiting for connection");
Thread.currentThread().interrupt();

View File

@@ -19,6 +19,8 @@ import org.briarproject.bramble.api.keyagreement.event.KeyAgreementWaitingEvent;
import org.briarproject.bramble.api.nullsafety.MethodsNotNullByDefault;
import org.briarproject.bramble.api.nullsafety.ParametersNotNullByDefault;
import org.briarproject.bramble.api.plugin.PluginManager;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import java.io.IOException;
import java.util.logging.Logger;
@@ -49,14 +51,17 @@ class KeyAgreementTaskImpl extends Thread implements KeyAgreementTask,
KeyAgreementTaskImpl(CryptoComponent crypto,
KeyAgreementCrypto keyAgreementCrypto, EventBus eventBus,
PayloadEncoder payloadEncoder, PluginManager pluginManager,
ConnectionChooser connectionChooser) {
ConnectionChooser connectionChooser,
RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory) {
this.crypto = crypto;
this.keyAgreementCrypto = keyAgreementCrypto;
this.eventBus = eventBus;
this.payloadEncoder = payloadEncoder;
localKeyPair = crypto.generateAgreementKeyPair();
connector = new KeyAgreementConnector(this, keyAgreementCrypto,
pluginManager, connectionChooser);
pluginManager, connectionChooser, recordReaderFactory,
recordWriterFactory);
}
@Override

View File

@@ -4,9 +4,12 @@ import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.util.ByteUtils;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -14,8 +17,6 @@ import java.util.logging.Logger;
import static java.util.logging.Level.WARNING;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_PAYLOAD_LENGTH_OFFSET;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY;
@@ -30,14 +31,17 @@ class KeyAgreementTransport {
Logger.getLogger(KeyAgreementTransport.class.getName());
private final KeyAgreementConnection kac;
private final InputStream in;
private final OutputStream out;
private final RecordReader reader;
private final RecordWriter writer;
KeyAgreementTransport(KeyAgreementConnection kac)
KeyAgreementTransport(RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory, KeyAgreementConnection kac)
throws IOException {
this.kac = kac;
in = kac.getConnection().getReader().getInputStream();
out = kac.getConnection().getWriter().getOutputStream();
InputStream in = kac.getConnection().getReader().getInputStream();
reader = recordReaderFactory.createRecordReader(in);
OutputStream out = kac.getConnection().getWriter().getOutputStream();
writer = recordWriterFactory.createRecordWriter(out);
}
public DuplexTransportConnection getConnection() {
@@ -74,9 +78,8 @@ class KeyAgreementTransport {
tryToClose(exception);
}
public void tryToClose(boolean exception) {
private void tryToClose(boolean exception) {
try {
LOG.info("Closing connection");
kac.getConnection().getReader().dispose(exception, true);
kac.getConnection().getWriter().dispose(exception);
} catch (IOException e) {
@@ -85,59 +88,27 @@ class KeyAgreementTransport {
}
private void writeRecord(byte type, byte[] payload) throws IOException {
byte[] recordHeader = new byte[RECORD_HEADER_LENGTH];
recordHeader[0] = PROTOCOL_VERSION;
recordHeader[1] = type;
ByteUtils.writeUint16(payload.length, recordHeader,
RECORD_HEADER_PAYLOAD_LENGTH_OFFSET);
out.write(recordHeader);
out.write(payload);
out.flush();
writer.writeRecord(new Record(PROTOCOL_VERSION, type, payload));
writer.flush();
}
private byte[] readRecord(byte expectedType) throws AbortException {
while (true) {
byte[] header = readHeader();
byte version = header[0], type = header[1];
int len = ByteUtils.readUint16(header,
RECORD_HEADER_PAYLOAD_LENGTH_OFFSET);
// Reject unrecognised protocol version
if (version != PROTOCOL_VERSION) throw new AbortException(false);
if (type == ABORT) throw new AbortException(true);
if (type == expectedType) {
try {
return readData(len);
} catch (IOException e) {
throw new AbortException(e);
}
}
// Reject recognised but unexpected record type
if (type == KEY || type == CONFIRM) throw new AbortException(false);
// Skip unrecognised record type
try {
readData(len);
Record record = reader.readRecord();
// Reject unrecognised protocol version
if (record.getProtocolVersion() != PROTOCOL_VERSION)
throw new AbortException(false);
byte type = record.getRecordType();
if (type == ABORT) throw new AbortException(true);
if (type == expectedType) return record.getPayload();
// Reject recognised but unexpected record type
if (type == KEY || type == CONFIRM)
throw new AbortException(false);
// Skip unrecognised record type
} catch (IOException e) {
throw new AbortException(e);
}
}
}
private byte[] readHeader() throws AbortException {
try {
return readData(RECORD_HEADER_LENGTH);
} catch (IOException e) {
throw new AbortException(e);
}
}
private byte[] readData(int len) throws IOException {
byte[] data = new byte[len];
int offset = 0;
while (offset < data.length) {
int read = in.read(data, offset, data.length - offset);
if (read == -1) throw new EOFException();
offset += read;
}
return data;
}
}

View File

@@ -0,0 +1,21 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import dagger.Module;
import dagger.Provides;
@Module
public class RecordModule {
@Provides
RecordReaderFactory provideRecordReaderFactory() {
return new RecordReaderFactoryImpl();
}
@Provides
RecordWriterFactory provideRecordWriterFactory() {
return new RecordWriterFactoryImpl();
}
}

View File

@@ -0,0 +1,14 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import java.io.InputStream;
class RecordReaderFactoryImpl implements RecordReaderFactory {
@Override
public RecordReader createRecordReader(InputStream in) {
return new RecordReaderImpl(in);
}
}

View File

@@ -0,0 +1,46 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.util.ByteUtils;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import javax.annotation.concurrent.NotThreadSafe;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES;
@NotThreadSafe
@NotNullByDefault
class RecordReaderImpl implements RecordReader {
private final DataInputStream in;
private final byte[] header = new byte[RECORD_HEADER_BYTES];
RecordReaderImpl(InputStream in) {
this.in = new DataInputStream(in);
}
@Override
public Record readRecord() throws IOException {
in.readFully(header);
byte protocolVersion = header[0];
byte recordType = header[1];
int payloadLength = ByteUtils.readUint16(header, 2);
if (payloadLength < 0 || payloadLength > MAX_RECORD_PAYLOAD_BYTES)
throw new FormatException();
byte[] payload = new byte[payloadLength];
in.readFully(payload);
return new Record(protocolVersion, recordType, payload);
}
@Override
public void close() throws IOException {
in.close();
}
}

View File

@@ -0,0 +1,14 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import java.io.OutputStream;
class RecordWriterFactoryImpl implements RecordWriterFactory {
@Override
public RecordWriter createRecordWriter(OutputStream out) {
return new RecordWriterImpl(out);
}
}

View File

@@ -0,0 +1,45 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.util.ByteUtils;
import java.io.IOException;
import java.io.OutputStream;
import javax.annotation.concurrent.NotThreadSafe;
import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES;
@NotThreadSafe
@NotNullByDefault
class RecordWriterImpl implements RecordWriter {
private final OutputStream out;
private final byte[] header = new byte[RECORD_HEADER_BYTES];
RecordWriterImpl(OutputStream out) {
this.out = out;
}
@Override
public void writeRecord(Record r) throws IOException {
byte[] payload = r.getPayload();
header[0] = r.getProtocolVersion();
header[1] = r.getRecordType();
ByteUtils.writeUint16(payload.length, header, 2);
out.write(header);
out.write(payload);
}
@Override
public void flush() throws IOException {
out.flush();
}
@Override
public void close() throws IOException {
out.close();
}
}

View File

@@ -14,8 +14,8 @@ import org.briarproject.bramble.api.lifecycle.event.LifecycleEvent;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.RecordWriter;
import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.api.sync.SyncRecordWriter;
import org.briarproject.bramble.api.sync.SyncSession;
import org.briarproject.bramble.api.sync.event.GroupVisibilityUpdatedEvent;
import org.briarproject.bramble.api.sync.event.MessageRequestedEvent;
@@ -39,8 +39,8 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.bramble.api.lifecycle.LifecycleManager.LifecycleState.STOPPING;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH;
/**
* An outgoing {@link SyncSession} suitable for duplex transports. The session
@@ -67,7 +67,7 @@ class DuplexOutgoingSession implements SyncSession, EventListener {
private final Clock clock;
private final ContactId contactId;
private final int maxLatency, maxIdleTime;
private final RecordWriter recordWriter;
private final SyncRecordWriter recordWriter;
private final BlockingQueue<ThrowingRunnable<IOException>> writerTasks;
private final AtomicBoolean generateAckQueued = new AtomicBoolean(false);
@@ -81,7 +81,7 @@ class DuplexOutgoingSession implements SyncSession, EventListener {
DuplexOutgoingSession(DatabaseComponent db, Executor dbExecutor,
EventBus eventBus, Clock clock, ContactId contactId, int maxLatency,
int maxIdleTime, RecordWriter recordWriter) {
int maxIdleTime, SyncRecordWriter recordWriter) {
this.db = db;
this.dbExecutor = dbExecutor;
this.eventBus = eventBus;
@@ -273,7 +273,7 @@ class DuplexOutgoingSession implements SyncSession, EventListener {
Transaction txn = db.startTransaction(false);
try {
b = db.generateRequestedBatch(txn, contactId,
MAX_RECORD_PAYLOAD_LENGTH, maxLatency);
MAX_RECORD_PAYLOAD_BYTES, maxLatency);
setNextSendTime(db.getNextSendTime(txn, contactId));
db.commitTransaction(txn);
} finally {

View File

@@ -20,6 +20,9 @@ import static org.briarproject.bramble.util.ByteUtils.INT_32_BYTES;
@NotNullByDefault
class GroupFactoryImpl implements GroupFactory {
private static final byte[] FORMAT_VERSION_BYTES =
new byte[] {FORMAT_VERSION};
private final CryptoComponent crypto;
@Inject
@@ -31,7 +34,7 @@ class GroupFactoryImpl implements GroupFactory {
public Group createGroup(ClientId c, int majorVersion, byte[] descriptor) {
byte[] majorVersionBytes = new byte[INT_32_BYTES];
ByteUtils.writeUint32(majorVersion, majorVersionBytes, 0);
byte[] hash = crypto.hash(LABEL, new byte[] {FORMAT_VERSION},
byte[] hash = crypto.hash(LABEL, FORMAT_VERSION_BYTES,
StringUtils.toUtf8(c.getString()), majorVersionBytes,
descriptor);
return new Group(new GroupId(hash), c, majorVersion, descriptor);

View File

@@ -16,8 +16,8 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.RecordReader;
import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.api.sync.SyncRecordReader;
import org.briarproject.bramble.api.sync.SyncSession;
import java.io.IOException;
@@ -43,13 +43,13 @@ class IncomingSession implements SyncSession, EventListener {
private final Executor dbExecutor;
private final EventBus eventBus;
private final ContactId contactId;
private final RecordReader recordReader;
private final SyncRecordReader recordReader;
private volatile boolean interrupted = false;
IncomingSession(DatabaseComponent db, Executor dbExecutor,
EventBus eventBus, ContactId contactId,
RecordReader recordReader) {
SyncRecordReader recordReader) {
this.db = db;
this.dbExecutor = dbExecutor;
this.eventBus = eventBus;

View File

@@ -16,6 +16,7 @@ import static org.briarproject.bramble.api.sync.Message.FORMAT_VERSION;
import static org.briarproject.bramble.api.sync.MessageId.BLOCK_LABEL;
import static org.briarproject.bramble.api.sync.MessageId.ID_LABEL;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_BODY_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH;
import static org.briarproject.bramble.util.ByteUtils.INT_64_BYTES;
@@ -23,6 +24,9 @@ import static org.briarproject.bramble.util.ByteUtils.INT_64_BYTES;
@NotNullByDefault
class MessageFactoryImpl implements MessageFactory {
private static final byte[] FORMAT_VERSION_BYTES =
new byte[] {FORMAT_VERSION};
private final CryptoComponent crypto;
@Inject
@@ -34,14 +38,7 @@ class MessageFactoryImpl implements MessageFactory {
public Message createMessage(GroupId g, long timestamp, byte[] body) {
if (body.length > MAX_MESSAGE_BODY_LENGTH)
throw new IllegalArgumentException();
byte[] versionBytes = new byte[] {FORMAT_VERSION};
// There's only one block, so the root hash is the hash of the block
byte[] rootHash = crypto.hash(BLOCK_LABEL, versionBytes, body);
byte[] timeBytes = new byte[INT_64_BYTES];
ByteUtils.writeUint64(timestamp, timeBytes, 0);
byte[] idHash = crypto.hash(ID_LABEL, versionBytes, g.getBytes(),
timeBytes, rootHash);
MessageId id = new MessageId(idHash);
MessageId id = getMessageId(g, timestamp, body);
byte[] raw = new byte[MESSAGE_HEADER_LENGTH + body.length];
System.arraycopy(g.getBytes(), 0, raw, 0, UniqueId.LENGTH);
ByteUtils.writeUint64(timestamp, raw, UniqueId.LENGTH);
@@ -49,10 +46,38 @@ class MessageFactoryImpl implements MessageFactory {
return new Message(id, g, timestamp, raw);
}
private MessageId getMessageId(GroupId g, long timestamp, byte[] body) {
// There's only one block, so the root hash is the hash of the block
byte[] rootHash = crypto.hash(BLOCK_LABEL, FORMAT_VERSION_BYTES, body);
byte[] timeBytes = new byte[INT_64_BYTES];
ByteUtils.writeUint64(timestamp, timeBytes, 0);
byte[] idHash = crypto.hash(ID_LABEL, FORMAT_VERSION_BYTES,
g.getBytes(), timeBytes, rootHash);
return new MessageId(idHash);
}
@Override
public Message createMessage(byte[] raw) {
if (raw.length < MESSAGE_HEADER_LENGTH)
throw new IllegalArgumentException();
if (raw.length > MAX_MESSAGE_LENGTH)
throw new IllegalArgumentException();
byte[] groupId = new byte[UniqueId.LENGTH];
System.arraycopy(raw, 0, groupId, 0, UniqueId.LENGTH);
GroupId g = new GroupId(groupId);
long timestamp = ByteUtils.readUint64(raw, UniqueId.LENGTH);
byte[] body = new byte[raw.length - MESSAGE_HEADER_LENGTH];
System.arraycopy(raw, MESSAGE_HEADER_LENGTH, body, 0, body.length);
MessageId id = getMessageId(g, timestamp, body);
return new Message(id, g, timestamp, raw);
}
@Override
public Message createMessage(MessageId m, byte[] raw) {
if (raw.length < MESSAGE_HEADER_LENGTH)
throw new IllegalArgumentException();
if (raw.length > MAX_MESSAGE_LENGTH)
throw new IllegalArgumentException();
byte[] groupId = new byte[UniqueId.LENGTH];
System.arraycopy(raw, 0, groupId, 0, UniqueId.LENGTH);
long timestamp = ByteUtils.readUint64(raw, UniqueId.LENGTH);

View File

@@ -1,28 +0,0 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.api.sync.RecordReader;
import org.briarproject.bramble.api.sync.RecordReaderFactory;
import java.io.InputStream;
import javax.annotation.concurrent.Immutable;
import javax.inject.Inject;
@Immutable
@NotNullByDefault
class RecordReaderFactoryImpl implements RecordReaderFactory {
private final MessageFactory messageFactory;
@Inject
RecordReaderFactoryImpl(MessageFactory messageFactory) {
this.messageFactory = messageFactory;
}
@Override
public RecordReader createRecordReader(InputStream in) {
return new RecordReaderImpl(messageFactory, in);
}
}

View File

@@ -1,16 +0,0 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.sync.RecordWriter;
import org.briarproject.bramble.api.sync.RecordWriterFactory;
import java.io.OutputStream;
@NotNullByDefault
class RecordWriterFactoryImpl implements RecordWriterFactory {
@Override
public RecordWriter createRecordWriter(OutputStream out) {
return new RecordWriterImpl(out);
}
}

View File

@@ -13,7 +13,7 @@ import org.briarproject.bramble.api.lifecycle.IoExecutor;
import org.briarproject.bramble.api.lifecycle.event.LifecycleEvent;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.RecordWriter;
import org.briarproject.bramble.api.sync.SyncRecordWriter;
import org.briarproject.bramble.api.sync.SyncSession;
import java.io.IOException;
@@ -29,8 +29,8 @@ import javax.annotation.concurrent.ThreadSafe;
import static java.util.logging.Level.INFO;
import static java.util.logging.Level.WARNING;
import static org.briarproject.bramble.api.lifecycle.LifecycleManager.LifecycleState.STOPPING;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH;
/**
* An outgoing {@link SyncSession} suitable for simplex transports. The session
@@ -51,7 +51,7 @@ class SimplexOutgoingSession implements SyncSession, EventListener {
private final EventBus eventBus;
private final ContactId contactId;
private final int maxLatency;
private final RecordWriter recordWriter;
private final SyncRecordWriter recordWriter;
private final AtomicInteger outstandingQueries;
private final BlockingQueue<ThrowingRunnable<IOException>> writerTasks;
@@ -59,7 +59,7 @@ class SimplexOutgoingSession implements SyncSession, EventListener {
SimplexOutgoingSession(DatabaseComponent db, Executor dbExecutor,
EventBus eventBus, ContactId contactId,
int maxLatency, RecordWriter recordWriter) {
int maxLatency, SyncRecordWriter recordWriter) {
this.db = db;
this.dbExecutor = dbExecutor;
this.eventBus = eventBus;
@@ -171,7 +171,7 @@ class SimplexOutgoingSession implements SyncSession, EventListener {
Transaction txn = db.startTransaction(false);
try {
b = db.generateBatch(txn, contactId,
MAX_RECORD_PAYLOAD_LENGTH, maxLatency);
MAX_RECORD_PAYLOAD_BYTES, maxLatency);
db.commitTransaction(txn);
} finally {
db.endTransaction(txn);

View File

@@ -9,8 +9,8 @@ import org.briarproject.bramble.api.event.EventBus;
import org.briarproject.bramble.api.lifecycle.LifecycleManager;
import org.briarproject.bramble.api.sync.GroupFactory;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.api.sync.RecordReaderFactory;
import org.briarproject.bramble.api.sync.RecordWriterFactory;
import org.briarproject.bramble.api.sync.SyncRecordReaderFactory;
import org.briarproject.bramble.api.sync.SyncRecordWriterFactory;
import org.briarproject.bramble.api.sync.SyncSessionFactory;
import org.briarproject.bramble.api.sync.ValidationManager;
import org.briarproject.bramble.api.system.Clock;
@@ -52,22 +52,23 @@ public class SyncModule {
}
@Provides
RecordReaderFactory provideRecordReaderFactory(
RecordReaderFactoryImpl recordReaderFactory) {
SyncRecordReaderFactory provideRecordReaderFactory(
SyncRecordReaderFactoryImpl recordReaderFactory) {
return recordReaderFactory;
}
@Provides
RecordWriterFactory provideRecordWriterFactory() {
return new RecordWriterFactoryImpl();
SyncRecordWriterFactory provideRecordWriterFactory(
SyncRecordWriterFactoryImpl recordWriterFactory) {
return recordWriterFactory;
}
@Provides
@Singleton
SyncSessionFactory provideSyncSessionFactory(DatabaseComponent db,
@DatabaseExecutor Executor dbExecutor, EventBus eventBus,
Clock clock, RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory) {
Clock clock, SyncRecordReaderFactory recordReaderFactory,
SyncRecordWriterFactory recordWriterFactory) {
return new SyncSessionFactoryImpl(db, dbExecutor, eventBus, clock,
recordReaderFactory, recordWriterFactory);
}

View File

@@ -0,0 +1,34 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.api.sync.SyncRecordReader;
import org.briarproject.bramble.api.sync.SyncRecordReaderFactory;
import java.io.InputStream;
import javax.annotation.concurrent.Immutable;
import javax.inject.Inject;
@Immutable
@NotNullByDefault
class SyncRecordReaderFactoryImpl implements SyncRecordReaderFactory {
private final MessageFactory messageFactory;
private final RecordReaderFactory recordReaderFactory;
@Inject
SyncRecordReaderFactoryImpl(MessageFactory messageFactory,
RecordReaderFactory recordReaderFactory) {
this.messageFactory = messageFactory;
this.recordReaderFactory = recordReaderFactory;
}
@Override
public SyncRecordReader createRecordReader(InputStream in) {
RecordReader reader = recordReaderFactory.createRecordReader(in);
return new SyncRecordReaderImpl(messageFactory, reader);
}
}

View File

@@ -3,82 +3,56 @@ package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.GroupId;
import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.RecordReader;
import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.api.sync.SyncRecordReader;
import org.briarproject.bramble.util.ByteUtils;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
import static org.briarproject.bramble.api.sync.RecordTypes.MESSAGE;
import static org.briarproject.bramble.api.sync.RecordTypes.OFFER;
import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENGTH;
@NotThreadSafe
@NotNullByDefault
class RecordReaderImpl implements RecordReader {
private enum State {BUFFER_EMPTY, BUFFER_FULL, EOF}
class SyncRecordReaderImpl implements SyncRecordReader {
private final MessageFactory messageFactory;
private final InputStream in;
private final byte[] header, payload;
private final RecordReader reader;
private State state = State.BUFFER_EMPTY;
private int payloadLength = 0;
@Nullable
private Record nextRecord = null;
private boolean eof = false;
RecordReaderImpl(MessageFactory messageFactory, InputStream in) {
SyncRecordReaderImpl(MessageFactory messageFactory, RecordReader reader) {
this.messageFactory = messageFactory;
this.in = in;
header = new byte[RECORD_HEADER_LENGTH];
payload = new byte[MAX_RECORD_PAYLOAD_LENGTH];
this.reader = reader;
}
private void readRecord() throws IOException {
if (state != State.BUFFER_EMPTY) throw new IllegalStateException();
assert nextRecord == null;
while (true) {
// Read the header
int offset = 0;
while (offset < RECORD_HEADER_LENGTH) {
int read =
in.read(header, offset, RECORD_HEADER_LENGTH - offset);
if (read == -1) {
if (offset > 0) throw new FormatException();
state = State.EOF;
return;
}
offset += read;
}
byte version = header[0], type = header[1];
payloadLength = ByteUtils.readUint16(header, 2);
nextRecord = reader.readRecord();
// Check the protocol version
byte version = nextRecord.getProtocolVersion();
if (version != PROTOCOL_VERSION) throw new FormatException();
// Check the payload length
if (payloadLength > MAX_RECORD_PAYLOAD_LENGTH)
throw new FormatException();
// Read the payload
offset = 0;
while (offset < payloadLength) {
int read = in.read(payload, offset, payloadLength - offset);
if (read == -1) throw new FormatException();
offset += read;
}
state = State.BUFFER_FULL;
byte type = nextRecord.getRecordType();
// Return if this is a known record type, otherwise continue
if (type == ACK || type == MESSAGE || type == OFFER ||
type == REQUEST) {
@@ -87,6 +61,11 @@ class RecordReaderImpl implements RecordReader {
}
}
private byte getNextRecordType() {
assert nextRecord != null;
return nextRecord.getRecordType();
}
/**
* Returns true if there's another record available or false if we've
* reached the end of the input stream.
@@ -97,14 +76,21 @@ class RecordReaderImpl implements RecordReader {
*/
@Override
public boolean eof() throws IOException {
if (state == State.BUFFER_EMPTY) readRecord();
if (state == State.BUFFER_EMPTY) throw new IllegalStateException();
return state == State.EOF;
if (nextRecord != null) return false;
if (eof) return true;
try {
readRecord();
return false;
} catch (EOFException e) {
nextRecord = null;
eof = true;
return true;
}
}
@Override
public boolean hasAck() throws IOException {
return !eof() && header[1] == ACK;
return !eof() && getNextRecordType() == ACK;
}
@Override
@@ -114,45 +100,41 @@ class RecordReaderImpl implements RecordReader {
}
private List<MessageId> readMessageIds() throws IOException {
if (payloadLength == 0) throw new FormatException();
if (payloadLength % UniqueId.LENGTH != 0) throw new FormatException();
List<MessageId> ids = new ArrayList<>();
for (int off = 0; off < payloadLength; off += UniqueId.LENGTH) {
assert nextRecord != null;
byte[] payload = nextRecord.getPayload();
if (payload.length == 0) throw new FormatException();
if (payload.length % UniqueId.LENGTH != 0) throw new FormatException();
List<MessageId> ids = new ArrayList<>(payload.length / UniqueId.LENGTH);
for (int off = 0; off < payload.length; off += UniqueId.LENGTH) {
byte[] id = new byte[UniqueId.LENGTH];
System.arraycopy(payload, off, id, 0, UniqueId.LENGTH);
ids.add(new MessageId(id));
}
state = State.BUFFER_EMPTY;
nextRecord = null;
return ids;
}
@Override
public boolean hasMessage() throws IOException {
return !eof() && header[1] == MESSAGE;
return !eof() && getNextRecordType() == MESSAGE;
}
@Override
public Message readMessage() throws IOException {
if (!hasMessage()) throw new FormatException();
if (payloadLength <= MESSAGE_HEADER_LENGTH) throw new FormatException();
// Group ID
byte[] id = new byte[UniqueId.LENGTH];
System.arraycopy(payload, 0, id, 0, UniqueId.LENGTH);
GroupId groupId = new GroupId(id);
// Timestamp
assert nextRecord != null;
byte[] payload = nextRecord.getPayload();
if (payload.length < MESSAGE_HEADER_LENGTH) throw new FormatException();
// Validate timestamp
long timestamp = ByteUtils.readUint64(payload, UniqueId.LENGTH);
if (timestamp < 0) throw new FormatException();
// Body
byte[] body = new byte[payloadLength - MESSAGE_HEADER_LENGTH];
System.arraycopy(payload, MESSAGE_HEADER_LENGTH, body, 0,
payloadLength - MESSAGE_HEADER_LENGTH);
state = State.BUFFER_EMPTY;
return messageFactory.createMessage(groupId, timestamp, body);
nextRecord = null;
return messageFactory.createMessage(payload);
}
@Override
public boolean hasOffer() throws IOException {
return !eof() && header[1] == OFFER;
return !eof() && getNextRecordType() == OFFER;
}
@Override
@@ -163,7 +145,7 @@ class RecordReaderImpl implements RecordReader {
@Override
public boolean hasRequest() throws IOException {
return !eof() && header[1] == REQUEST;
return !eof() && getNextRecordType() == REQUEST;
}
@Override

View File

@@ -0,0 +1,28 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import org.briarproject.bramble.api.sync.SyncRecordWriter;
import org.briarproject.bramble.api.sync.SyncRecordWriterFactory;
import java.io.OutputStream;
import javax.inject.Inject;
@NotNullByDefault
class SyncRecordWriterFactoryImpl implements SyncRecordWriterFactory {
private final RecordWriterFactory recordWriterFactory;
@Inject
SyncRecordWriterFactoryImpl(RecordWriterFactory recordWriterFactory) {
this.recordWriterFactory = recordWriterFactory;
}
@Override
public SyncRecordWriter createRecordWriter(OutputStream out) {
RecordWriter writer = recordWriterFactory.createRecordWriter(out);
return new SyncRecordWriterImpl(writer);
}
}

View File

@@ -1,81 +1,67 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.RecordTypes;
import org.briarproject.bramble.api.sync.RecordWriter;
import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.util.ByteUtils;
import org.briarproject.bramble.api.sync.SyncRecordWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import javax.annotation.concurrent.NotThreadSafe;
import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
import static org.briarproject.bramble.api.sync.RecordTypes.MESSAGE;
import static org.briarproject.bramble.api.sync.RecordTypes.OFFER;
import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
@NotThreadSafe
@NotNullByDefault
class RecordWriterImpl implements RecordWriter {
class SyncRecordWriterImpl implements SyncRecordWriter {
private final OutputStream out;
private final byte[] header;
private final ByteArrayOutputStream payload;
private final RecordWriter writer;
private final ByteArrayOutputStream payload = new ByteArrayOutputStream();
RecordWriterImpl(OutputStream out) {
this.out = out;
header = new byte[RECORD_HEADER_LENGTH];
header[0] = PROTOCOL_VERSION;
payload = new ByteArrayOutputStream(MAX_RECORD_PAYLOAD_LENGTH);
SyncRecordWriterImpl(RecordWriter writer) {
this.writer = writer;
}
private void writeRecord(byte recordType) throws IOException {
header[1] = recordType;
ByteUtils.writeUint16(payload.size(), header, 2);
out.write(header);
payload.writeTo(out);
writer.writeRecord(new Record(PROTOCOL_VERSION, recordType,
payload.toByteArray()));
payload.reset();
}
@Override
public void writeAck(Ack a) throws IOException {
if (payload.size() != 0) throw new IllegalStateException();
for (MessageId m : a.getMessageIds()) payload.write(m.getBytes());
writeRecord(ACK);
}
@Override
public void writeMessage(byte[] raw) throws IOException {
header[1] = RecordTypes.MESSAGE;
ByteUtils.writeUint16(raw.length, header, 2);
out.write(header);
out.write(raw);
writer.writeRecord(new Record(PROTOCOL_VERSION, MESSAGE, raw));
}
@Override
public void writeOffer(Offer o) throws IOException {
if (payload.size() != 0) throw new IllegalStateException();
for (MessageId m : o.getMessageIds()) payload.write(m.getBytes());
writeRecord(OFFER);
}
@Override
public void writeRequest(Request r) throws IOException {
if (payload.size() != 0) throw new IllegalStateException();
for (MessageId m : r.getMessageIds()) payload.write(m.getBytes());
writeRecord(REQUEST);
}
@Override
public void flush() throws IOException {
out.flush();
writer.flush();
}
}

View File

@@ -5,10 +5,10 @@ import org.briarproject.bramble.api.db.DatabaseComponent;
import org.briarproject.bramble.api.db.DatabaseExecutor;
import org.briarproject.bramble.api.event.EventBus;
import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
import org.briarproject.bramble.api.sync.RecordReader;
import org.briarproject.bramble.api.sync.RecordReaderFactory;
import org.briarproject.bramble.api.sync.RecordWriter;
import org.briarproject.bramble.api.sync.RecordWriterFactory;
import org.briarproject.bramble.api.sync.SyncRecordReader;
import org.briarproject.bramble.api.sync.SyncRecordReaderFactory;
import org.briarproject.bramble.api.sync.SyncRecordWriter;
import org.briarproject.bramble.api.sync.SyncRecordWriterFactory;
import org.briarproject.bramble.api.sync.SyncSession;
import org.briarproject.bramble.api.sync.SyncSessionFactory;
import org.briarproject.bramble.api.system.Clock;
@@ -28,14 +28,14 @@ class SyncSessionFactoryImpl implements SyncSessionFactory {
private final Executor dbExecutor;
private final EventBus eventBus;
private final Clock clock;
private final RecordReaderFactory recordReaderFactory;
private final RecordWriterFactory recordWriterFactory;
private final SyncRecordReaderFactory recordReaderFactory;
private final SyncRecordWriterFactory recordWriterFactory;
@Inject
SyncSessionFactoryImpl(DatabaseComponent db,
@DatabaseExecutor Executor dbExecutor, EventBus eventBus,
Clock clock, RecordReaderFactory recordReaderFactory,
RecordWriterFactory recordWriterFactory) {
Clock clock, SyncRecordReaderFactory recordReaderFactory,
SyncRecordWriterFactory recordWriterFactory) {
this.db = db;
this.dbExecutor = dbExecutor;
this.eventBus = eventBus;
@@ -46,14 +46,16 @@ class SyncSessionFactoryImpl implements SyncSessionFactory {
@Override
public SyncSession createIncomingSession(ContactId c, InputStream in) {
RecordReader recordReader = recordReaderFactory.createRecordReader(in);
SyncRecordReader recordReader =
recordReaderFactory.createRecordReader(in);
return new IncomingSession(db, dbExecutor, eventBus, c, recordReader);
}
@Override
public SyncSession createSimplexOutgoingSession(ContactId c,
int maxLatency, OutputStream out) {
RecordWriter recordWriter = recordWriterFactory.createRecordWriter(out);
SyncRecordWriter recordWriter =
recordWriterFactory.createRecordWriter(out);
return new SimplexOutgoingSession(db, dbExecutor, eventBus, c,
maxLatency, recordWriter);
}
@@ -61,7 +63,8 @@ class SyncSessionFactoryImpl implements SyncSessionFactory {
@Override
public SyncSession createDuplexOutgoingSession(ContactId c, int maxLatency,
int maxIdleTime, OutputStream out) {
RecordWriter recordWriter = recordWriterFactory.createRecordWriter(out);
SyncRecordWriter recordWriter =
recordWriterFactory.createRecordWriter(out);
return new DuplexOutgoingSession(db, dbExecutor, eventBus, clock, c,
maxLatency, maxIdleTime, recordWriter);
}

View File

@@ -5,23 +5,31 @@ import org.briarproject.bramble.api.plugin.TransportConnectionReader;
import org.briarproject.bramble.api.plugin.TransportConnectionWriter;
import org.briarproject.bramble.api.plugin.TransportId;
import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.record.RecordReaderFactory;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.api.record.RecordWriterFactory;
import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.TestUtils;
import org.briarproject.bramble.util.ByteUtils;
import org.briarproject.bramble.test.CaptureArgumentAction;
import org.jmock.Expectations;
import org.jmock.lib.legacy.ClassImposteriser;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.atomic.AtomicReference;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM;
import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY;
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
import static org.briarproject.bramble.test.TestUtils.getTransportId;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
public class KeyAgreementTransportTest extends BrambleMockTestCase {
@@ -31,222 +39,268 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase {
context.mock(TransportConnectionReader.class);
private final TransportConnectionWriter transportConnectionWriter =
context.mock(TransportConnectionWriter.class);
private final RecordReaderFactory recordReaderFactory =
context.mock(RecordReaderFactory.class);
private final RecordWriterFactory recordWriterFactory =
context.mock(RecordWriterFactory.class);
private final RecordReader recordReader = context.mock(RecordReader.class);
private final RecordWriter recordWriter = context.mock(RecordWriter.class);
private final TransportId transportId = getTransportId();
private final KeyAgreementConnection keyAgreementConnection =
new KeyAgreementConnection(duplexTransportConnection, transportId);
private ByteArrayInputStream inputStream;
private ByteArrayOutputStream outputStream;
private final InputStream inputStream;
private final OutputStream outputStream;
private KeyAgreementTransport kat;
public KeyAgreementTransportTest() {
context.setImposteriser(ClassImposteriser.INSTANCE);
inputStream = context.mock(InputStream.class);
outputStream = context.mock(OutputStream.class);
}
@Test
public void testSendKey() throws Exception {
setup(new byte[0]);
byte[] key = TestUtils.getRandomBytes(123);
byte[] key = getRandomBytes(123);
setup();
AtomicReference<Record> written = expectWriteRecord();
kat.sendKey(key);
assertRecordSent(KEY, key);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, KEY, key, written.get());
}
@Test
public void testSendConfirm() throws Exception {
setup(new byte[0]);
byte[] confirm = TestUtils.getRandomBytes(123);
byte[] confirm = getRandomBytes(123);
setup();
AtomicReference<Record> written = expectWriteRecord();
kat.sendConfirm(confirm);
assertRecordSent(CONFIRM, confirm);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, CONFIRM, confirm, written.get());
}
@Test
public void testSendAbortWithException() throws Exception {
setup(new byte[0]);
setup();
AtomicReference<Record> written = expectWriteRecord();
context.checking(new Expectations() {{
oneOf(transportConnectionReader).dispose(true, true);
oneOf(transportConnectionWriter).dispose(true);
}});
kat.sendAbort(true);
assertRecordSent(ABORT, new byte[0]);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
}
@Test
public void testSendAbortWithoutException() throws Exception {
setup(new byte[0]);
setup();
AtomicReference<Record> written = expectWriteRecord();
context.checking(new Expectations() {{
oneOf(transportConnectionReader).dispose(false, true);
oneOf(transportConnectionWriter).dispose(false);
}});
kat.sendAbort(false);
assertRecordSent(ABORT, new byte[0]);
assertNotNull(written.get());
assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get());
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfAtEndOfStream()
throws Exception {
setup(new byte[0]);
kat.receiveKey();
}
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(throwException(new EOFException()));
}});
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfHeaderIsTooShort()
throws Exception {
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
input[0] = PROTOCOL_VERSION;
input[1] = KEY;
setup(input);
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfPayloadIsTooShort()
throws Exception {
int payloadLength = 123;
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
input[0] = PROTOCOL_VERSION;
input[1] = KEY;
ByteUtils.writeUint16(payloadLength, input, 2);
setup(input);
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception {
setup(createRecord((byte) (PROTOCOL_VERSION + 1), KEY, new byte[123]));
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
byte[] key = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(unknownVersion, KEY, key)));
}});
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfAbortIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0]));
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
}});
kat.receiveKey();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfConfirmIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, CONFIRM, new byte[123]));
byte[] confirm = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, CONFIRM, confirm)));
}});
kat.receiveKey();
}
@Test
public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1),
new byte[123]);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2),
new byte[0]);
byte[] payload = TestUtils.getRandomBytes(123);
byte[] key = createRecord(PROTOCOL_VERSION, KEY, payload);
ByteArrayOutputStream input = new ByteArrayOutputStream();
input.write(skip1);
input.write(skip2);
input.write(key);
setup(input.toByteArray());
assertArrayEquals(payload, kat.receiveKey());
byte type1 = (byte) (ABORT + 1);
byte[] payload1 = getRandomBytes(123);
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
byte type2 = (byte) (ABORT + 2);
byte[] payload2 = new byte[0];
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
byte[] key = getRandomBytes(123);
Record keyRecord = new Record(PROTOCOL_VERSION, KEY, key);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord1));
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord2));
oneOf(recordReader).readRecord();
will(returnValue(keyRecord));
}});
assertArrayEquals(key, kat.receiveKey());
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfAtEndOfStream()
throws Exception {
setup(new byte[0]);
kat.receiveConfirm();
}
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(throwException(new EOFException()));
}});
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfHeaderIsTooShort()
throws Exception {
byte[] input = new byte[RECORD_HEADER_LENGTH - 1];
input[0] = PROTOCOL_VERSION;
input[1] = CONFIRM;
setup(input);
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfPayloadIsTooShort()
throws Exception {
int payloadLength = 123;
byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
input[0] = PROTOCOL_VERSION;
input[1] = CONFIRM;
ByteUtils.writeUint16(payloadLength, input, 2);
setup(input);
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception {
setup(createRecord((byte) (PROTOCOL_VERSION + 1), CONFIRM,
new byte[123]));
byte unknownVersion = (byte) (PROTOCOL_VERSION + 1);
byte[] confirm = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(unknownVersion, CONFIRM, confirm)));
}});
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveConfirmThrowsExceptionIfAbortIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0]));
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0])));
}});
kat.receiveConfirm();
}
@Test(expected = AbortException.class)
public void testReceiveKeyThrowsExceptionIfKeyIsReceived()
public void testReceiveConfirmThrowsExceptionIfKeyIsReceived()
throws Exception {
setup(createRecord(PROTOCOL_VERSION, KEY, new byte[123]));
byte[] key = getRandomBytes(123);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(new Record(PROTOCOL_VERSION, KEY, key)));
}});
kat.receiveConfirm();
}
@Test
public void testReceiveConfirmSkipsUnrecognisedRecordTypes()
throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1),
new byte[123]);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2),
new byte[0]);
byte[] payload = TestUtils.getRandomBytes(123);
byte[] confirm = createRecord(PROTOCOL_VERSION, CONFIRM, payload);
ByteArrayOutputStream input = new ByteArrayOutputStream();
input.write(skip1);
input.write(skip2);
input.write(confirm);
setup(input.toByteArray());
assertArrayEquals(payload, kat.receiveConfirm());
byte type1 = (byte) (ABORT + 1);
byte[] payload1 = getRandomBytes(123);
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
byte type2 = (byte) (ABORT + 2);
byte[] payload2 = new byte[0];
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
byte[] confirm = getRandomBytes(123);
Record confirmRecord = new Record(PROTOCOL_VERSION, CONFIRM, confirm);
setup();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord1));
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord2));
oneOf(recordReader).readRecord();
will(returnValue(confirmRecord));
}});
assertArrayEquals(confirm, kat.receiveConfirm());
}
private void setup(byte[] input) throws Exception {
inputStream = new ByteArrayInputStream(input);
outputStream = new ByteArrayOutputStream();
private void setup() throws Exception {
context.checking(new Expectations() {{
allowing(duplexTransportConnection).getReader();
will(returnValue(transportConnectionReader));
allowing(transportConnectionReader).getInputStream();
will(returnValue(inputStream));
oneOf(recordReaderFactory).createRecordReader(inputStream);
will(returnValue(recordReader));
allowing(duplexTransportConnection).getWriter();
will(returnValue(transportConnectionWriter));
allowing(transportConnectionWriter).getOutputStream();
will(returnValue(outputStream));
oneOf(recordWriterFactory).createRecordWriter(outputStream);
will(returnValue(recordWriter));
}});
kat = new KeyAgreementTransport(keyAgreementConnection);
kat = new KeyAgreementTransport(recordReaderFactory,
recordWriterFactory, keyAgreementConnection);
}
private void assertRecordSent(byte expectedType, byte[] expectedPayload) {
byte[] output = outputStream.toByteArray();
assertEquals(RECORD_HEADER_LENGTH + expectedPayload.length,
output.length);
assertEquals(PROTOCOL_VERSION, output[0]);
assertEquals(expectedType, output[1]);
assertEquals(expectedPayload.length, ByteUtils.readUint16(output, 2));
byte[] payload = new byte[output.length - RECORD_HEADER_LENGTH];
System.arraycopy(output, RECORD_HEADER_LENGTH, payload, 0,
payload.length);
assertArrayEquals(expectedPayload, payload);
private AtomicReference<Record> expectWriteRecord() throws Exception {
AtomicReference<Record> captured = new AtomicReference<>();
context.checking(new Expectations() {{
oneOf(recordWriter).writeRecord(with(any(Record.class)));
will(new CaptureArgumentAction<>(captured, Record.class, 0));
oneOf(recordWriter).flush();
}});
return captured;
}
private byte[] createRecord(byte version, byte type, byte[] payload) {
byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length];
b[0] = version;
b[1] = type;
ByteUtils.writeUint16(payload.length, b, 2);
System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length);
return b;
private void assertRecordEquals(byte expectedVersion, byte expectedType,
byte[] expectedPayload, Record actual) {
assertEquals(expectedVersion, actual.getProtocolVersion());
assertEquals(expectedType, actual.getRecordType());
assertArrayEquals(expectedPayload, actual.getPayload());
}
}

View File

@@ -0,0 +1,102 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.test.BrambleTestCase;
import org.briarproject.bramble.util.ByteUtils;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.EOFException;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class RecordReaderImplTest extends BrambleTestCase {
@Test
public void testAcceptsEmptyPayload() throws Exception {
// Version 1, type 2, payload length 0
byte[] header = new byte[] {1, 2, 0, 0};
ByteArrayInputStream in = new ByteArrayInputStream(header);
RecordReader reader = new RecordReaderImpl(in);
Record record = reader.readRecord();
assertEquals(1, record.getProtocolVersion());
assertEquals(2, record.getRecordType());
assertArrayEquals(new byte[0], record.getPayload());
}
@Test
public void testAcceptsMaxLengthPayload() throws Exception {
byte[] record =
new byte[RECORD_HEADER_BYTES + MAX_RECORD_PAYLOAD_BYTES];
// Version 1, type 2, payload length MAX_RECORD_PAYLOAD_BYTES
record[0] = 1;
record[1] = 2;
ByteUtils.writeUint16(MAX_RECORD_PAYLOAD_BYTES, record, 2);
ByteArrayInputStream in = new ByteArrayInputStream(record);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfPayloadLengthIsNegative()
throws Exception {
// Version 1, type 2, payload length -1
byte[] header = new byte[] {1, 2, (byte) 0xFF, (byte) 0xFF};
ByteArrayInputStream in = new ByteArrayInputStream(header);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfPayloadLengthIsTooLarge()
throws Exception {
// Version 1, type 2, payload length MAX_RECORD_PAYLOAD_BYTES + 1
byte[] header = new byte[] {1, 2, 0, 0};
ByteUtils.writeUint16(MAX_RECORD_PAYLOAD_BYTES + 1, header, 2);
ByteArrayInputStream in = new ByteArrayInputStream(header);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
@Test(expected = EOFException.class)
public void testEofExceptionIfProtocolVersionIsMissing() throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
@Test(expected = EOFException.class)
public void testEofExceptionIfRecordTypeIsMissing() throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(new byte[1]);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
@Test(expected = EOFException.class)
public void testEofExceptionIfPayloadLengthIsMissing() throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(new byte[2]);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
@Test(expected = EOFException.class)
public void testEofExceptionIfPayloadLengthIsTruncated() throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(new byte[3]);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
@Test(expected = EOFException.class)
public void testEofExceptionIfPayloadIsTruncated() throws Exception {
// Version 0, type 0, payload length 1
byte[] header = new byte[] {0, 0, 0, 1};
ByteArrayInputStream in = new ByteArrayInputStream(header);
RecordReader reader = new RecordReaderImpl(in);
reader.readRecord();
}
}

View File

@@ -0,0 +1,49 @@
package org.briarproject.bramble.record;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordWriter;
import org.briarproject.bramble.test.BrambleTestCase;
import org.briarproject.bramble.util.ByteUtils;
import org.junit.Test;
import java.io.ByteArrayOutputStream;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES;
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class RecordWriterImplTest extends BrambleTestCase {
@Test
public void testWritesEmptyRecord() throws Exception {
testWritesRecord(0);
}
@Test
public void testWritesMaxLengthRecord() throws Exception {
testWritesRecord(MAX_RECORD_PAYLOAD_BYTES);
}
private void testWritesRecord(int payloadLength) throws Exception {
byte protocolVersion = 123;
byte recordType = 45;
byte[] payload = getRandomBytes(payloadLength);
ByteArrayOutputStream out = new ByteArrayOutputStream();
RecordWriter writer = new RecordWriterImpl(out);
writer.writeRecord(new Record(protocolVersion, recordType, payload));
writer.flush();
byte[] written = out.toByteArray();
assertEquals(RECORD_HEADER_BYTES + payloadLength, written.length);
assertEquals(protocolVersion, written[0]);
assertEquals(recordType, written[1]);
assertEquals(payloadLength, ByteUtils.readUint16(written, 2));
byte[] writtenPayload = new byte[payloadLength];
System.arraycopy(written, RECORD_HEADER_BYTES, writtenPayload, 0,
payloadLength);
assertArrayEquals(payload, writtenPayload);
}
}

View File

@@ -1,219 +0,0 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.TestUtils;
import org.briarproject.bramble.util.ByteUtils;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
import static org.briarproject.bramble.api.sync.RecordTypes.OFFER;
import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENGTH;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class RecordReaderImplTest extends BrambleMockTestCase {
private final MessageFactory messageFactory =
context.mock(MessageFactory.class);
@Test(expected = FormatException.class)
public void testFormatExceptionIfAckIsTooLarge() throws Exception {
byte[] b = createAck(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readAck();
}
@Test
public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception {
byte[] b = createAck(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readAck();
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfAckIsEmpty() throws Exception {
byte[] b = createEmptyAck();
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readAck();
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfOfferIsTooLarge() throws Exception {
byte[] b = createOffer(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readOffer();
}
@Test
public void testNoFormatExceptionIfOfferIsMaximumSize() throws Exception {
byte[] b = createOffer(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readOffer();
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfOfferIsEmpty() throws Exception {
byte[] b = createEmptyOffer();
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readOffer();
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfRequestIsTooLarge() throws Exception {
byte[] b = createRequest(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readRequest();
}
@Test
public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception {
byte[] b = createRequest(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readRequest();
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfRequestIsEmpty() throws Exception {
byte[] b = createEmptyRequest();
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readRequest();
}
@Test
public void testEofReturnsTrueWhenAtEndOfStream() throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
assertTrue(reader.eof());
}
@Test
public void testEofReturnsFalseWhenNotAtEndOfStream() throws Exception {
byte[] b = createAck(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
assertFalse(reader.eof());
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfHeaderIsTooShort() throws Exception {
byte[] b = new byte[RECORD_HEADER_LENGTH - 1];
b[0] = PROTOCOL_VERSION;
b[1] = ACK;
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfPayloadIsTooShort() throws Exception {
int payloadLength = 123;
byte[] b = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
b[0] = PROTOCOL_VERSION;
b[1] = ACK;
ByteUtils.writeUint16(payloadLength, b, 2);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception {
byte version = (byte) (PROTOCOL_VERSION + 1);
byte[] b = createRecord(version, ACK, new byte[0]);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfPayloadIsTooLong() throws Exception {
byte[] payload = new byte[MAX_RECORD_PAYLOAD_LENGTH + 1];
byte[] b = createRecord(PROTOCOL_VERSION, ACK, payload);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test
public void testSkipsUnrecognisedRecordTypes() throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 1),
new byte[123]);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 2),
new byte[0]);
byte[] ack = createAck(false);
ByteArrayOutputStream input = new ByteArrayOutputStream();
input.write(skip1);
input.write(skip2);
input.write(ack);
ByteArrayInputStream in = new ByteArrayInputStream(input.toByteArray());
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
assertTrue(reader.hasAck());
Ack a = reader.readAck();
assertEquals(MAX_MESSAGE_IDS, a.getMessageIds().size());
}
private byte[] createAck(boolean tooBig) throws Exception {
return createRecord(PROTOCOL_VERSION, ACK, createPayload(tooBig));
}
private byte[] createEmptyAck() throws Exception {
return createRecord(PROTOCOL_VERSION, ACK, new byte[0]);
}
private byte[] createOffer(boolean tooBig) throws Exception {
return createRecord(PROTOCOL_VERSION, OFFER, createPayload(tooBig));
}
private byte[] createEmptyOffer() throws Exception {
return createRecord(PROTOCOL_VERSION, OFFER, new byte[0]);
}
private byte[] createRequest(boolean tooBig) throws Exception {
return createRecord(PROTOCOL_VERSION, REQUEST, createPayload(tooBig));
}
private byte[] createEmptyRequest() throws Exception {
return createRecord(PROTOCOL_VERSION, REQUEST, new byte[0]);
}
private byte[] createRecord(byte version, byte type, byte[] payload) {
byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length];
b[0] = version;
b[1] = type;
ByteUtils.writeUint16(payload.length, b, 2);
System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length);
return b;
}
private byte[] createPayload(boolean tooBig) throws Exception {
ByteArrayOutputStream payload = new ByteArrayOutputStream();
while (payload.size() + UniqueId.LENGTH <= MAX_RECORD_PAYLOAD_LENGTH) {
payload.write(TestUtils.getRandomId());
}
if (tooBig) payload.write(TestUtils.getRandomId());
assertEquals(tooBig, payload.size() > MAX_RECORD_PAYLOAD_LENGTH);
return payload.toByteArray();
}
}

View File

@@ -6,7 +6,7 @@ import org.briarproject.bramble.api.db.Transaction;
import org.briarproject.bramble.api.event.EventBus;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.RecordWriter;
import org.briarproject.bramble.api.sync.SyncRecordWriter;
import org.briarproject.bramble.test.BrambleTestCase;
import org.briarproject.bramble.test.ImmediateExecutor;
import org.briarproject.bramble.test.TestUtils;
@@ -29,14 +29,14 @@ public class SimplexOutgoingSessionTest extends BrambleTestCase {
private final ContactId contactId;
private final MessageId messageId;
private final int maxLatency;
private final RecordWriter recordWriter;
private final SyncRecordWriter recordWriter;
public SimplexOutgoingSessionTest() {
context = new Mockery();
db = context.mock(DatabaseComponent.class);
dbExecutor = new ImmediateExecutor();
eventBus = context.mock(EventBus.class);
recordWriter = context.mock(RecordWriter.class);
recordWriter = context.mock(SyncRecordWriter.class);
contactId = new ContactId(234);
messageId = new MessageId(TestUtils.getRandomId());
maxLatency = Integer.MAX_VALUE;

View File

@@ -12,11 +12,11 @@ import org.briarproject.bramble.api.sync.Message;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.api.sync.MessageId;
import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.RecordReader;
import org.briarproject.bramble.api.sync.RecordReaderFactory;
import org.briarproject.bramble.api.sync.RecordWriter;
import org.briarproject.bramble.api.sync.RecordWriterFactory;
import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.api.sync.SyncRecordReader;
import org.briarproject.bramble.api.sync.SyncRecordReaderFactory;
import org.briarproject.bramble.api.sync.SyncRecordWriter;
import org.briarproject.bramble.api.sync.SyncRecordWriterFactory;
import org.briarproject.bramble.api.transport.StreamContext;
import org.briarproject.bramble.api.transport.StreamReaderFactory;
import org.briarproject.bramble.api.transport.StreamWriterFactory;
@@ -54,9 +54,9 @@ public class SyncIntegrationTest extends BrambleTestCase {
@Inject
StreamWriterFactory streamWriterFactory;
@Inject
RecordReaderFactory recordReaderFactory;
SyncRecordReaderFactory recordReaderFactory;
@Inject
RecordWriterFactory recordWriterFactory;
SyncRecordWriterFactory recordWriterFactory;
@Inject
TransportCrypto transportCrypto;
@@ -104,7 +104,7 @@ public class SyncIntegrationTest extends BrambleTestCase {
headerKey, streamNumber);
OutputStream streamWriter = streamWriterFactory.createStreamWriter(out,
ctx);
RecordWriter recordWriter = recordWriterFactory.createRecordWriter(
SyncRecordWriter recordWriter = recordWriterFactory.createRecordWriter(
streamWriter);
recordWriter.writeAck(new Ack(messageIds));
@@ -112,8 +112,8 @@ public class SyncIntegrationTest extends BrambleTestCase {
recordWriter.writeMessage(message1.getRaw());
recordWriter.writeOffer(new Offer(messageIds));
recordWriter.writeRequest(new Request(messageIds));
recordWriter.flush();
streamWriter.flush();
return out.toByteArray();
}
@@ -134,7 +134,7 @@ public class SyncIntegrationTest extends BrambleTestCase {
headerKey, streamNumber);
InputStream streamReader = streamReaderFactory.createStreamReader(in,
ctx);
RecordReader recordReader = recordReaderFactory.createRecordReader(
SyncRecordReader recordReader = recordReaderFactory.createRecordReader(
streamReader);
// Read the ack

View File

@@ -1,6 +1,7 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.crypto.CryptoModule;
import org.briarproject.bramble.record.RecordModule;
import org.briarproject.bramble.system.SystemModule;
import org.briarproject.bramble.test.TestSecureRandomModule;
import org.briarproject.bramble.transport.TransportModule;
@@ -13,6 +14,7 @@ import dagger.Component;
@Component(modules = {
TestSecureRandomModule.class,
CryptoModule.class,
RecordModule.class,
SyncModule.class,
SystemModule.class,
TransportModule.class

View File

@@ -0,0 +1,195 @@
package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.api.record.Record;
import org.briarproject.bramble.api.record.RecordReader;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.api.sync.Offer;
import org.briarproject.bramble.api.sync.Request;
import org.briarproject.bramble.api.sync.SyncRecordReader;
import org.briarproject.bramble.test.BrambleMockTestCase;
import org.jmock.Expectations;
import org.junit.Test;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
import static org.briarproject.bramble.api.sync.RecordTypes.OFFER;
import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS;
import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.test.TestUtils.getRandomBytes;
import static org.briarproject.bramble.test.TestUtils.getRandomId;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class SyncRecordReaderImplTest extends BrambleMockTestCase {
private final MessageFactory messageFactory =
context.mock(MessageFactory.class);
private final RecordReader recordReader = context.mock(RecordReader.class);
@Test
public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception {
expectReadRecord(createAck());
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
Ack ack = reader.readAck();
assertEquals(MAX_MESSAGE_IDS, ack.getMessageIds().size());
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfAckIsEmpty() throws Exception {
expectReadRecord(createEmptyAck());
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
reader.readAck();
}
@Test
public void testNoFormatExceptionIfOfferIsMaximumSize() throws Exception {
expectReadRecord(createOffer());
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
Offer offer = reader.readOffer();
assertEquals(MAX_MESSAGE_IDS, offer.getMessageIds().size());
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfOfferIsEmpty() throws Exception {
expectReadRecord(createEmptyOffer());
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
reader.readOffer();
}
@Test
public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception {
expectReadRecord(createRequest());
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
Request request = reader.readRequest();
assertEquals(MAX_MESSAGE_IDS, request.getMessageIds().size());
}
@Test(expected = FormatException.class)
public void testFormatExceptionIfRequestIsEmpty() throws Exception {
expectReadRecord(createEmptyRequest());
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
reader.readRequest();
}
@Test
public void testEofReturnsTrueWhenAtEndOfStream() throws Exception {
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(throwException(new EOFException()));
}});
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
assertTrue(reader.eof());
assertTrue(reader.eof());
}
@Test
public void testEofReturnsFalseWhenNotAtEndOfStream() throws Exception {
expectReadRecord(createAck());
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
assertFalse(reader.eof());
assertFalse(reader.eof());
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception {
byte version = (byte) (PROTOCOL_VERSION + 1);
byte[] payload = getRandomId();
expectReadRecord(new Record(version, ACK, payload));
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
reader.eof();
}
@Test
public void testSkipsUnrecognisedRecordTypes() throws Exception {
byte type1 = (byte) (REQUEST + 1);
byte[] payload1 = getRandomBytes(123);
Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1);
byte type2 = (byte) (REQUEST + 2);
byte[] payload2 = new byte[0];
Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2);
Record ackRecord = createAck();
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord1));
oneOf(recordReader).readRecord();
will(returnValue(unknownRecord2));
oneOf(recordReader).readRecord();
will(returnValue(ackRecord));
}});
SyncRecordReader reader =
new SyncRecordReaderImpl(messageFactory, recordReader);
assertTrue(reader.hasAck());
Ack a = reader.readAck();
assertEquals(MAX_MESSAGE_IDS, a.getMessageIds().size());
}
private void expectReadRecord(Record record) throws Exception {
context.checking(new Expectations() {{
oneOf(recordReader).readRecord();
will(returnValue(record));
}});
}
private Record createAck() throws Exception {
return new Record(PROTOCOL_VERSION, ACK, createPayload());
}
private Record createEmptyAck() throws Exception {
return new Record(PROTOCOL_VERSION, ACK, new byte[0]);
}
private Record createOffer() throws Exception {
return new Record(PROTOCOL_VERSION, OFFER, createPayload());
}
private Record createEmptyOffer() throws Exception {
return new Record(PROTOCOL_VERSION, OFFER, new byte[0]);
}
private Record createRequest() throws Exception {
return new Record(PROTOCOL_VERSION, REQUEST, createPayload());
}
private Record createEmptyRequest() throws Exception {
return new Record(PROTOCOL_VERSION, REQUEST, new byte[0]);
}
private byte[] createPayload() throws Exception {
ByteArrayOutputStream payload = new ByteArrayOutputStream();
while (payload.size() + UniqueId.LENGTH <= MAX_RECORD_PAYLOAD_BYTES) {
payload.write(getRandomId());
}
return payload.toByteArray();
}
}

View File

@@ -10,6 +10,7 @@ import org.briarproject.bramble.event.EventModule;
import org.briarproject.bramble.identity.IdentityModule;
import org.briarproject.bramble.lifecycle.LifecycleModule;
import org.briarproject.bramble.properties.PropertiesModule;
import org.briarproject.bramble.record.RecordModule;
import org.briarproject.bramble.sync.SyncModule;
import org.briarproject.bramble.system.SystemModule;
import org.briarproject.bramble.test.TestDatabaseModule;
@@ -52,6 +53,7 @@ import dagger.Component;
MessagingModule.class,
PrivateGroupModule.class,
PropertiesModule.class,
RecordModule.class,
SharingModule.class,
SyncModule.class,
SystemModule.class,

View File

@@ -26,7 +26,7 @@ import javax.inject.Inject;
import static org.briarproject.bramble.api.identity.AuthorConstants.MAX_AUTHOR_NAME_LENGTH;
import static org.briarproject.bramble.api.identity.AuthorConstants.MAX_PUBLIC_KEY_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH;
import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES;
import static org.briarproject.bramble.test.TestUtils.getRandomId;
import static org.briarproject.bramble.util.StringUtils.getRandomString;
import static org.briarproject.briar.api.forum.ForumConstants.MAX_FORUM_POST_BODY_LENGTH;
@@ -63,7 +63,7 @@ public class MessageSizeIntegrationTest extends BriarTestCase {
int length = message.getMessage().getRaw().length;
assertTrue(
length > UniqueId.LENGTH + 8 + MAX_PRIVATE_MESSAGE_BODY_LENGTH);
assertTrue(length <= MAX_RECORD_PAYLOAD_LENGTH);
assertTrue(length <= MAX_RECORD_PAYLOAD_BYTES);
}
@Test
@@ -87,7 +87,7 @@ public class MessageSizeIntegrationTest extends BriarTestCase {
assertTrue(length > UniqueId.LENGTH + 8 + UniqueId.LENGTH + 4
+ MAX_AUTHOR_NAME_LENGTH + MAX_PUBLIC_KEY_LENGTH
+ MAX_FORUM_POST_BODY_LENGTH);
assertTrue(length <= MAX_RECORD_PAYLOAD_LENGTH);
assertTrue(length <= MAX_RECORD_PAYLOAD_BYTES);
}
private static void injectEagerSingletons(

View File

@@ -16,6 +16,7 @@ import org.briarproject.bramble.db.DatabaseModule;
import org.briarproject.bramble.event.EventModule;
import org.briarproject.bramble.identity.IdentityModule;
import org.briarproject.bramble.lifecycle.LifecycleModule;
import org.briarproject.bramble.record.RecordModule;
import org.briarproject.bramble.sync.SyncModule;
import org.briarproject.bramble.system.SystemModule;
import org.briarproject.bramble.test.TestCryptoExecutorModule;
@@ -48,6 +49,7 @@ import dagger.Component;
IdentityModule.class,
LifecycleModule.class,
MessagingModule.class,
RecordModule.class,
SyncModule.class,
SystemModule.class,
TransportModule.class,

View File

@@ -19,6 +19,7 @@ import org.briarproject.bramble.event.EventModule;
import org.briarproject.bramble.identity.IdentityModule;
import org.briarproject.bramble.lifecycle.LifecycleModule;
import org.briarproject.bramble.properties.PropertiesModule;
import org.briarproject.bramble.record.RecordModule;
import org.briarproject.bramble.sync.SyncModule;
import org.briarproject.bramble.system.SystemModule;
import org.briarproject.bramble.test.TestDatabaseModule;
@@ -73,6 +74,7 @@ import dagger.Component;
MessagingModule.class,
PrivateGroupModule.class,
PropertiesModule.class,
RecordModule.class,
SharingModule.class,
SyncModule.class,
SystemModule.class,