Store group details in the database. Some tests are still failing...

This commit is contained in:
akwizgran
2011-07-23 01:29:18 +01:00
parent de648daca5
commit 0edcb31d64
21 changed files with 322 additions and 147 deletions

View File

@@ -11,6 +11,7 @@ import net.sf.briar.api.protocol.AckWriter;
import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchWriter; import net.sf.briar.api.protocol.BatchWriter;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.SubscriptionWriter; import net.sf.briar.api.protocol.SubscriptionWriter;
@@ -81,7 +82,7 @@ public interface DatabaseComponent {
Rating getRating(AuthorId a) throws DbException; Rating getRating(AuthorId a) throws DbException;
/** Returns the set of groups to which the user subscribes. */ /** Returns the set of groups to which the user subscribes. */
Collection<GroupId> getSubscriptions() throws DbException; Collection<Group> getSubscriptions() throws DbException;
/** Returns the local transport details. */ /** Returns the local transport details. */
Map<String, String> getTransports() throws DbException; Map<String, String> getTransports() throws DbException;
@@ -108,7 +109,7 @@ public interface DatabaseComponent {
void setRating(AuthorId a, Rating r) throws DbException; void setRating(AuthorId a, Rating r) throws DbException;
/** Subscribes to the given group. */ /** Subscribes to the given group. */
void subscribe(GroupId g) throws DbException; void subscribe(Group g) throws DbException;
/** /**
* Unsubscribes from the given group. Any messages belonging to the group * Unsubscribes from the given group. Any messages belonging to the group

View File

@@ -5,7 +5,10 @@ import java.security.PublicKey;
/** A group to which users may subscribe. */ /** A group to which users may subscribe. */
public interface Group { public interface Group {
/** Returns the name of the group. */ /** Returns the group's unique identifier. */
GroupId getId();
/** Returns the group's name. */
String getName(); String getName();
/** /**

View File

@@ -0,0 +1,6 @@
package net.sf.briar.api.protocol;
public interface GroupFactory {
Group createGroup(GroupId id, String name, byte[] salt, byte[] publicKey);
}

View File

@@ -5,7 +5,6 @@ import java.io.IOException;
/** An interface for creating a subscription update. */ /** An interface for creating a subscription update. */
public interface SubscriptionWriter { public interface SubscriptionWriter {
// FIXME: This should work with groups, not IDs
/** Sets the contents of the update. */ /** Sets the contents of the update. */
void setSubscriptions(Iterable<GroupId> subs) throws IOException; void setSubscriptions(Iterable<Group> subs) throws IOException;
} }

View File

@@ -5,9 +5,8 @@ import java.util.Collection;
/** A packet updating the sender's subscriptions. */ /** A packet updating the sender's subscriptions. */
public interface Subscriptions { public interface Subscriptions {
// FIXME: This should work with groups, not IDs
/** Returns the subscriptions contained in the update. */ /** Returns the subscriptions contained in the update. */
Collection<GroupId> getSubscriptions(); Collection<Group> getSubscriptions();
/** /**
* Returns the update's timestamp. Updates that are older than the newest * Returns the update's timestamp. Updates that are older than the newest

View File

@@ -0,0 +1,23 @@
package net.sf.briar.crypto;
import java.security.NoSuchAlgorithmException;
import net.sf.briar.api.crypto.KeyParser;
import com.google.inject.AbstractModule;
public class CryptoModule extends AbstractModule {
public static final String DIGEST_ALGO = "SHA-256";
public static final String KEY_PAIR_ALGO = "RSA";
public static final String SIGNATURE_ALGO = "SHA256withRSA";
@Override
protected void configure() {
try {
bind(KeyParser.class).toInstance(new KeyParserImpl(KEY_PAIR_ALGO));
} catch(NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}
}

View File

@@ -0,0 +1,26 @@
package net.sf.briar.crypto;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.spec.EncodedKeySpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import net.sf.briar.api.crypto.KeyParser;
public class KeyParserImpl implements KeyParser {
private final KeyFactory keyFactory;
KeyParserImpl(String algorithm) throws NoSuchAlgorithmException {
keyFactory = KeyFactory.getInstance(algorithm);
}
public PublicKey parsePublicKey(byte[] encodedKey)
throws InvalidKeySpecException {
EncodedKeySpec e = new X509EncodedKeySpec(encodedKey);
return keyFactory.generatePublic(e);
}
}

View File

@@ -9,6 +9,7 @@ import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.Status; import net.sf.briar.api.db.Status;
import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
@@ -105,7 +106,7 @@ interface Database<T> {
* <p> * <p>
* Locking: subscriptions write. * Locking: subscriptions write.
*/ */
void addSubscription(T txn, GroupId g) throws DbException; void addSubscription(T txn, Group g) throws DbException;
/** /**
* Returns true iff the database contains the given contact. * Returns true iff the database contains the given contact.
@@ -235,14 +236,14 @@ interface Database<T> {
* <p> * <p>
* Locking: subscriptions read. * Locking: subscriptions read.
*/ */
Collection<GroupId> getSubscriptions(T txn) throws DbException; Collection<Group> getSubscriptions(T txn) throws DbException;
/** /**
* Returns the groups to which the given contact subscribes. * Returns the groups to which the given contact subscribes.
* <p> * <p>
* Locking: contacts read, subscriptions read. * Locking: contacts read, subscriptions read.
*/ */
Collection<GroupId> getSubscriptions(T txn, ContactId c) throws DbException; Collection<Group> getSubscriptions(T txn, ContactId c) throws DbException;
/** /**
* Returns the local transport details. * Returns the local transport details.
@@ -335,7 +336,7 @@ interface Database<T> {
* <p> * <p>
* Locking: contacts write, subscriptions write. * Locking: contacts write, subscriptions write.
*/ */
void setSubscriptions(T txn, ContactId c, Collection<GroupId> subs, void setSubscriptions(T txn, ContactId c, Collection<Group> subs,
long timestamp) throws DbException; long timestamp) throws DbException;
/** /**

View File

@@ -13,6 +13,7 @@ import java.util.logging.Logger;
import net.sf.briar.api.crypto.Password; import net.sf.briar.api.crypto.Password;
import net.sf.briar.api.db.DatabasePassword; import net.sf.briar.api.db.DatabasePassword;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.GroupFactory;
import org.apache.commons.io.FileSystemUtils; import org.apache.commons.io.FileSystemUtils;
@@ -30,8 +31,9 @@ class H2Database extends JdbcDatabase {
private final long maxSize; private final long maxSize;
@Inject @Inject
H2Database(File dir, @DatabasePassword Password password, long maxSize) { H2Database(File dir, @DatabasePassword Password password, long maxSize,
super("BINARY(32)", "BIGINT"); GroupFactory groupFactory) {
super(groupFactory, "BINARY(32)", "BIGINT", "BINARY");
home = new File(dir, "db"); home = new File(dir, "db");
this.password = password; this.password = password;
url = "jdbc:h2:split:" + home.getPath() url = "jdbc:h2:split:" + home.getPath()

View File

@@ -2,6 +2,7 @@ package net.sf.briar.db;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.File; import java.io.File;
import java.security.PublicKey;
import java.sql.Blob; import java.sql.Blob;
import java.sql.Connection; import java.sql.Connection;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
@@ -23,6 +24,8 @@ import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.Status; import net.sf.briar.api.db.Status;
import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
@@ -37,6 +40,9 @@ abstract class JdbcDatabase implements Database<Connection> {
private static final String CREATE_LOCAL_SUBSCRIPTIONS = private static final String CREATE_LOCAL_SUBSCRIPTIONS =
"CREATE TABLE localSubscriptions" "CREATE TABLE localSubscriptions"
+ " (groupId HASH NOT NULL," + " (groupId HASH NOT NULL,"
+ " name VARCHAR NOT NULL,"
+ " salt BINARY,"
+ " publicKey BINARY,"
+ " PRIMARY KEY (groupId))"; + " PRIMARY KEY (groupId))";
private static final String CREATE_MESSAGES = private static final String CREATE_MESSAGES =
@@ -84,6 +90,9 @@ abstract class JdbcDatabase implements Database<Connection> {
"CREATE TABLE contactSubscriptions" "CREATE TABLE contactSubscriptions"
+ " (contactId INT NOT NULL," + " (contactId INT NOT NULL,"
+ " groupId HASH NOT NULL," + " groupId HASH NOT NULL,"
+ " name VARCHAR NOT NULL,"
+ " salt BINARY,"
+ " publicKey BINARY,"
+ " PRIMARY KEY (contactId, groupId)," + " PRIMARY KEY (contactId, groupId),"
+ " FOREIGN KEY (contactId) REFERENCES contacts (contactId)" + " FOREIGN KEY (contactId) REFERENCES contacts (contactId)"
+ " ON DELETE CASCADE)"; + " ON DELETE CASCADE)";
@@ -157,7 +166,8 @@ abstract class JdbcDatabase implements Database<Connection> {
Logger.getLogger(JdbcDatabase.class.getName()); Logger.getLogger(JdbcDatabase.class.getName());
// Different database libraries use different names for certain types // Different database libraries use different names for certain types
private final String hashType, timestampType; private final String hashType, timestampType, binaryType;
private final GroupFactory groupFactory;
private final LinkedList<Connection> connections = private final LinkedList<Connection> connections =
new LinkedList<Connection>(); // Locking: self new LinkedList<Connection>(); // Locking: self
@@ -166,9 +176,12 @@ abstract class JdbcDatabase implements Database<Connection> {
protected abstract Connection createConnection() throws SQLException; protected abstract Connection createConnection() throws SQLException;
JdbcDatabase(String hashType, String timestampType) { JdbcDatabase(GroupFactory groupFactory, String hashType,
String timestampType, String binaryType) {
this.groupFactory = groupFactory;
this.hashType = hashType; this.hashType = hashType;
this.timestampType = timestampType; this.timestampType = timestampType;
this.binaryType = binaryType;
} }
protected void open(boolean resume, File dir, String driverClass) protected void open(boolean resume, File dir, String driverClass)
@@ -255,7 +268,9 @@ abstract class JdbcDatabase implements Database<Connection> {
private String insertTypeNames(String s) { private String insertTypeNames(String s) {
s = s.replaceAll("HASH", hashType); s = s.replaceAll("HASH", hashType);
return s.replaceAll("TIMESTAMP", timestampType); s = s.replaceAll("TIMESTAMP", timestampType);
s = s.replaceAll("BINARY", binaryType);
return s;
} }
private void tryToClose(Connection c) { private void tryToClose(Connection c) {
@@ -512,12 +527,18 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public void addSubscription(Connection txn, GroupId g) throws DbException { public void addSubscription(Connection txn, Group g) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
try { try {
String sql = "INSERT INTO localSubscriptions (groupId) VALUES (?)"; String sql = "INSERT INTO localSubscriptions"
+ " (groupId, name, salt, publicKey)"
+ " VALUES (?, ?, ?, ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setBytes(1, g.getBytes()); ps.setBytes(1, g.getId().getBytes());
ps.setString(2, g.getName());
ps.setBytes(3, g.getSalt());
PublicKey k = g.getPublicKey();
ps.setBytes(4, k == null ? null : k.getEncoded());
int rowsAffected = ps.executeUpdate(); int rowsAffected = ps.executeUpdate();
assert rowsAffected == 1; assert rowsAffected == 1;
ps.close(); ps.close();
@@ -964,19 +985,26 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public Collection<GroupId> getSubscriptions(Connection txn) public Collection<Group> getSubscriptions(Connection txn)
throws DbException { throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
String sql = "SELECT groupId FROM localSubscriptions"; String sql = "SELECT (groupId, name, salt, publicKey)"
+ " FROM localSubscriptions";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
rs = ps.executeQuery(); rs = ps.executeQuery();
Collection<GroupId> ids = new ArrayList<GroupId>(); Collection<Group> subs = new ArrayList<Group>();
while(rs.next()) ids.add(new GroupId(rs.getBytes(1))); while(rs.next()) {
GroupId id = new GroupId(rs.getBytes(1));
String name = rs.getString(2);
byte[] salt = rs.getBytes(3);
byte[] publicKey = rs.getBytes(4);
subs.add(groupFactory.createGroup(id, name, salt, publicKey));
}
rs.close(); rs.close();
ps.close(); ps.close();
return ids; return subs;
} catch(SQLException e) { } catch(SQLException e) {
tryToClose(rs); tryToClose(rs);
tryToClose(ps); tryToClose(ps);
@@ -985,21 +1013,28 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
} }
public Collection<GroupId> getSubscriptions(Connection txn, ContactId c) public Collection<Group> getSubscriptions(Connection txn, ContactId c)
throws DbException { throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
String sql = "SELECT groupId FROM contactSubscriptions" String sql = "SELECT (groupId, name, salt, publicKey)"
+ " FROM contactSubscriptions"
+ " WHERE contactId = ?"; + " WHERE contactId = ?";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
rs = ps.executeQuery(); rs = ps.executeQuery();
Collection<GroupId> ids = new ArrayList<GroupId>(); Collection<Group> subs = new ArrayList<Group>();
while(rs.next()) ids.add(new GroupId(rs.getBytes(1))); while(rs.next()) {
GroupId id = new GroupId(rs.getBytes(1));
String name = rs.getString(2);
byte[] salt = rs.getBytes(3);
byte[] publicKey = rs.getBytes(4);
subs.add(groupFactory.createGroup(id, name, salt, publicKey));
}
rs.close(); rs.close();
ps.close(); ps.close();
return ids; return subs;
} catch(SQLException e) { } catch(SQLException e) {
tryToClose(rs); tryToClose(rs);
tryToClose(ps); tryToClose(ps);
@@ -1329,7 +1364,7 @@ abstract class JdbcDatabase implements Database<Connection> {
} }
public void setSubscriptions(Connection txn, ContactId c, public void setSubscriptions(Connection txn, ContactId c,
Collection<GroupId> subs, long timestamp) throws DbException { Collection<Group> subs, long timestamp) throws DbException {
PreparedStatement ps = null; PreparedStatement ps = null;
ResultSet rs = null; ResultSet rs = null;
try { try {
@@ -1354,12 +1389,17 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.executeUpdate(); ps.executeUpdate();
ps.close(); ps.close();
// Store the new subscriptions // Store the new subscriptions
sql = "INSERT INTO contactSubscriptions (contactId, groupId)" sql = "INSERT INTO contactSubscriptions"
+ " VALUES (?, ?)"; + "(contactId, groupId, name, salt, publicKey)"
+ " VALUES (?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql); ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt()); ps.setInt(1, c.getInt());
for(GroupId g : subs) { for(Group g : subs) {
ps.setBytes(2, g.getBytes()); ps.setBytes(2, g.getId().getBytes());
ps.setString(3, g.getName());
ps.setBytes(4, g.getSalt());
PublicKey k = g.getPublicKey();
ps.setBytes(5, k == null ? null : k.getEncoded());
ps.addBatch(); ps.addBatch();
} }
int[] rowsAffectedArray = ps.executeBatch(); int[] rowsAffectedArray = ps.executeBatch();

View File

@@ -21,6 +21,7 @@ import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.BatchWriter; import net.sf.briar.api.protocol.BatchWriter;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
@@ -347,8 +348,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
try { try {
Txn txn = db.startTransaction(); Txn txn = db.startTransaction();
try { try {
// FIXME: This should deal in Groups, not GroupIds Collection<Group> subs = db.getSubscriptions(txn);
Collection<GroupId> subs = db.getSubscriptions(txn);
s.setSubscriptions(subs); s.setSubscriptions(subs);
if(LOG.isLoggable(Level.FINE)) if(LOG.isLoggable(Level.FINE))
LOG.fine("Added " + subs.size() + " subscriptions"); LOG.fine("Added " + subs.size() + " subscriptions");
@@ -431,12 +431,12 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
} }
} }
public Collection<GroupId> getSubscriptions() throws DbException { public Collection<Group> getSubscriptions() throws DbException {
subscriptionLock.readLock().lock(); subscriptionLock.readLock().lock();
try { try {
Txn txn = db.startTransaction(); Txn txn = db.startTransaction();
try { try {
Collection<GroupId> subs = db.getSubscriptions(txn); Collection<Group> subs = db.getSubscriptions(txn);
db.commitTransaction(txn); db.commitTransaction(txn);
return subs; return subs;
} catch(DbException e) { } catch(DbException e) {
@@ -575,7 +575,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
try { try {
Txn txn = db.startTransaction(); Txn txn = db.startTransaction();
try { try {
Collection<GroupId> subs = s.getSubscriptions(); Collection<Group> subs = s.getSubscriptions();
db.setSubscriptions(txn, c, subs, s.getTimestamp()); db.setSubscriptions(txn, c, subs, s.getTimestamp());
if(LOG.isLoggable(Level.FINE)) if(LOG.isLoggable(Level.FINE))
LOG.fine("Received " + subs.size() + " subscriptions"); LOG.fine("Received " + subs.size() + " subscriptions");
@@ -678,7 +678,7 @@ class ReadWriteLockDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
} }
} }
public void subscribe(GroupId g) throws DbException { public void subscribe(Group g) throws DbException {
if(LOG.isLoggable(Level.FINE)) LOG.fine("Subscribing to " + g); if(LOG.isLoggable(Level.FINE)) LOG.fine("Subscribing to " + g);
subscriptionLock.writeLock().lock(); subscriptionLock.writeLock().lock();
try { try {

View File

@@ -20,6 +20,7 @@ import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.BatchWriter; import net.sf.briar.api.protocol.BatchWriter;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
@@ -252,8 +253,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
synchronized(subscriptionLock) { synchronized(subscriptionLock) {
Txn txn = db.startTransaction(); Txn txn = db.startTransaction();
try { try {
// FIXME: This should deal in Groups, not GroupIds Collection<Group> subs = db.getSubscriptions(txn);
Collection<GroupId> subs = db.getSubscriptions(txn);
s.setSubscriptions(subs); s.setSubscriptions(subs);
if(LOG.isLoggable(Level.FINE)) if(LOG.isLoggable(Level.FINE))
LOG.fine("Added " + subs.size() + " subscriptions"); LOG.fine("Added " + subs.size() + " subscriptions");
@@ -320,11 +320,11 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
} }
} }
public Collection<GroupId> getSubscriptions() throws DbException { public Collection<Group> getSubscriptions() throws DbException {
synchronized(subscriptionLock) { synchronized(subscriptionLock) {
Txn txn = db.startTransaction(); Txn txn = db.startTransaction();
try { try {
Collection<GroupId> subs = db.getSubscriptions(txn); Collection<Group> subs = db.getSubscriptions(txn);
db.commitTransaction(txn); db.commitTransaction(txn);
return subs; return subs;
} catch(DbException e) { } catch(DbException e) {
@@ -429,7 +429,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
synchronized(subscriptionLock) { synchronized(subscriptionLock) {
Txn txn = db.startTransaction(); Txn txn = db.startTransaction();
try { try {
Collection<GroupId> subs = s.getSubscriptions(); Collection<Group> subs = s.getSubscriptions();
db.setSubscriptions(txn, c, subs, s.getTimestamp()); db.setSubscriptions(txn, c, subs, s.getTimestamp());
if(LOG.isLoggable(Level.FINE)) if(LOG.isLoggable(Level.FINE))
LOG.fine("Received " + subs.size() + " subscriptions"); LOG.fine("Received " + subs.size() + " subscriptions");
@@ -504,7 +504,7 @@ class SynchronizedDatabaseComponent<Txn> extends DatabaseComponentImpl<Txn> {
} }
} }
public void subscribe(GroupId g) throws DbException { public void subscribe(Group g) throws DbException {
if(LOG.isLoggable(Level.FINE)) LOG.fine("Subscribing to " + g); if(LOG.isLoggable(Level.FINE)) LOG.fine("Subscribing to " + g);
synchronized(subscriptionLock) { synchronized(subscriptionLock) {
Txn txn = db.startTransaction(); Txn txn = db.startTransaction();

View File

@@ -0,0 +1,38 @@
package net.sf.briar.protocol;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import net.sf.briar.api.crypto.KeyParser;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.GroupId;
import com.google.inject.Inject;
class GroupFactoryImpl implements GroupFactory {
private final KeyParser keyParser;
@Inject
GroupFactoryImpl(KeyParser keyParser) {
this.keyParser = keyParser;
}
public Group createGroup(GroupId id, String name, byte[] salt,
byte[] publicKey) {
if(salt == null && publicKey == null)
throw new IllegalArgumentException();
if(salt != null && publicKey != null)
throw new IllegalArgumentException();
PublicKey key = null;
if(publicKey != null) {
try {
key = keyParser.parsePublicKey(publicKey);
} catch (InvalidKeySpecException e) {
throw new IllegalArgumentException(e);
}
}
return new GroupImpl(id, name, salt, key);
}
}

View File

@@ -0,0 +1,42 @@
package net.sf.briar.protocol;
import java.security.PublicKey;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId;
public class GroupImpl implements Group {
private final GroupId id;
private final String name;
private final byte[] salt;
private final PublicKey publicKey;
GroupImpl(GroupId id, String name, byte[] salt, PublicKey publicKey) {
assert salt == null || publicKey == null;
this.id = id;
this.name = name;
this.salt = salt;
this.publicKey = publicKey;
}
public GroupId getId() {
return id;
}
public String getName() {
return name;
}
public boolean isRestricted() {
return salt == null;
}
public byte[] getSalt() {
return salt;
}
public PublicKey getPublicKey() {
return publicKey;
}
}

View File

@@ -1,11 +1,15 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import net.sf.briar.api.protocol.GroupFactory;
import com.google.inject.AbstractModule; import com.google.inject.AbstractModule;
public class ProtocolModule extends AbstractModule { public class ProtocolModule extends AbstractModule {
@Override @Override
protected void configure() { protected void configure() {
bind(AckFactory.class).to(AckFactoryImpl.class);
bind(BatchFactory.class).to(BatchFactoryImpl.class); bind(BatchFactory.class).to(BatchFactoryImpl.class);
bind(GroupFactory.class).to(GroupFactoryImpl.class);
} }
} }

View File

@@ -17,6 +17,7 @@ import net.sf.briar.api.protocol.Ack;
import net.sf.briar.api.protocol.AckWriter; import net.sf.briar.api.protocol.AckWriter;
import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
@@ -72,6 +73,7 @@ public abstract class DatabaseComponentTest extends TestCase {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
final Database<Object> database = context.mock(Database.class); final Database<Object> database = context.mock(Database.class);
final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class);
final Group group = context.mock(Group.class);
context.checking(new Expectations() {{ context.checking(new Expectations() {{
allowing(database).startTransaction(); allowing(database).startTransaction();
will(returnValue(txn)); will(returnValue(txn));
@@ -93,8 +95,8 @@ public abstract class DatabaseComponentTest extends TestCase {
will(returnValue(true)); will(returnValue(true));
oneOf(database).getTransports(txn, contactId); oneOf(database).getTransports(txn, contactId);
will(returnValue(transports)); will(returnValue(transports));
// subscribe(groupId) // subscribe(group)
oneOf(database).addSubscription(txn, groupId); oneOf(database).addSubscription(txn, group);
// getSubscriptions() // getSubscriptions()
oneOf(database).getSubscriptions(txn); oneOf(database).getSubscriptions(txn);
will(returnValue(subs)); will(returnValue(subs));
@@ -113,7 +115,7 @@ public abstract class DatabaseComponentTest extends TestCase {
assertEquals(contactId, db.addContact(transports)); assertEquals(contactId, db.addContact(transports));
assertEquals(contacts, db.getContacts()); assertEquals(contacts, db.getContacts());
assertEquals(transports, db.getTransports(contactId)); assertEquals(transports, db.getTransports(contactId));
db.subscribe(groupId); db.subscribe(group);
assertEquals(subs, db.getSubscriptions()); assertEquals(subs, db.getSubscriptions());
db.unsubscribe(groupId); db.unsubscribe(groupId);
db.removeContact(contactId); db.removeContact(contactId);

View File

@@ -21,17 +21,22 @@ import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.Status; import net.sf.briar.api.db.Status;
import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.Group;
import net.sf.briar.api.protocol.GroupFactory;
import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.protocol.ProtocolModule;
import org.apache.commons.io.FileSystemUtils; import org.apache.commons.io.FileSystemUtils;
import org.jmock.Expectations;
import org.jmock.Mockery;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class H2DatabaseTest extends TestCase { public class H2DatabaseTest extends TestCase {
private static final int ONE_MEGABYTE = 1024 * 1024; private static final int ONE_MEGABYTE = 1024 * 1024;
@@ -40,7 +45,10 @@ public class H2DatabaseTest extends TestCase {
private final File testDir = TestUtils.getTestDirectory(); private final File testDir = TestUtils.getTestDirectory();
// The password has the format <file password> <space> <user password> // The password has the format <file password> <space> <user password>
private final String passwordString = "foo bar"; private final String passwordString = "foo bar";
private final Password password = new TestPassword();
private final Random random = new Random(); private final Random random = new Random();
private final GroupFactory groupFactory;
private final AuthorId authorId; private final AuthorId authorId;
private final BatchId batchId; private final BatchId batchId;
private final ContactId contactId; private final ContactId contactId;
@@ -50,9 +58,13 @@ public class H2DatabaseTest extends TestCase {
private final int size; private final int size;
private final byte[] raw; private final byte[] raw;
private final Message message; private final Message message;
private final Group group;
public H2DatabaseTest() { public H2DatabaseTest() throws Exception {
super(); super();
Injector i = Guice.createInjector(new ProtocolModule(),
new CryptoModule());
groupFactory = i.getInstance(GroupFactory.class);
authorId = new AuthorId(TestUtils.getRandomId()); authorId = new AuthorId(TestUtils.getRandomId());
batchId = new BatchId(TestUtils.getRandomId()); batchId = new BatchId(TestUtils.getRandomId());
contactId = new ContactId(1); contactId = new ContactId(1);
@@ -64,6 +76,8 @@ public class H2DatabaseTest extends TestCase {
random.nextBytes(raw); random.nextBytes(raw);
message = new TestMessage(messageId, MessageId.NONE, groupId, authorId, message = new TestMessage(messageId, MessageId.NONE, groupId, authorId,
timestamp, raw); timestamp, raw);
group = groupFactory.createGroup(groupId, "Group name",
TestUtils.getRandomId(), null);
} }
@Before @Before
@@ -81,7 +95,7 @@ public class H2DatabaseTest extends TestCase {
assertEquals(contactId, db.addContact(txn, transports)); assertEquals(contactId, db.addContact(txn, transports));
assertTrue(db.containsContact(txn, contactId)); assertTrue(db.containsContact(txn, contactId));
assertFalse(db.containsSubscription(txn, groupId)); assertFalse(db.containsSubscription(txn, groupId));
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
assertTrue(db.containsSubscription(txn, groupId)); assertTrue(db.containsSubscription(txn, groupId));
assertFalse(db.containsMessage(txn, messageId)); assertFalse(db.containsMessage(txn, messageId));
db.addMessage(txn, message); db.addMessage(txn, message);
@@ -158,7 +172,7 @@ public class H2DatabaseTest extends TestCase {
db.setRating(txn, authorId, Rating.GOOD); db.setRating(txn, authorId, Rating.GOOD);
// Check that the rating was stored // Check that the rating was stored
assertEquals(Rating.GOOD, db.getRating(txn, authorId)); assertEquals(Rating.GOOD, db.getRating(txn, authorId));
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
} }
@@ -169,7 +183,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Subscribe to a group and store a message // Subscribe to a group and store a message
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.addMessage(txn, message); db.addMessage(txn, message);
// Unsubscribing from the group should delete the message // Unsubscribing from the group should delete the message
@@ -188,8 +202,8 @@ public class H2DatabaseTest extends TestCase {
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, null)); assertEquals(contactId, db.addContact(txn, null));
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, Collections.singleton(groupId), 1); db.setSubscriptions(txn, contactId, Collections.singleton(group), 1);
db.addMessage(txn, message); db.addMessage(txn, message);
db.setStatus(txn, contactId, messageId, Status.NEW); db.setStatus(txn, contactId, messageId, Status.NEW);
@@ -221,8 +235,8 @@ public class H2DatabaseTest extends TestCase {
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, null)); assertEquals(contactId, db.addContact(txn, null));
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, Collections.singleton(groupId), 1); db.setSubscriptions(txn, contactId, Collections.singleton(group), 1);
db.addMessage(txn, message); db.addMessage(txn, message);
db.setSendability(txn, messageId, 1); db.setSendability(txn, messageId, 1);
@@ -258,7 +272,7 @@ public class H2DatabaseTest extends TestCase {
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, null)); assertEquals(contactId, db.addContact(txn, null));
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.addMessage(txn, message); db.addMessage(txn, message);
db.setSendability(txn, messageId, 1); db.setSendability(txn, messageId, 1);
db.setStatus(txn, contactId, messageId, Status.NEW); db.setStatus(txn, contactId, messageId, Status.NEW);
@@ -269,13 +283,13 @@ public class H2DatabaseTest extends TestCase {
assertFalse(it.hasNext()); assertFalse(it.hasNext());
// The contact subscribing should make the message sendable // The contact subscribing should make the message sendable
db.setSubscriptions(txn, contactId, Collections.singleton(groupId), 1); db.setSubscriptions(txn, contactId, Collections.singleton(group), 1);
it = db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); it = db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertTrue(it.hasNext()); assertTrue(it.hasNext());
assertEquals(messageId, it.next()); assertEquals(messageId, it.next());
// The contact unsubscribing should make the message unsendable // The contact unsubscribing should make the message unsendable
db.setSubscriptions(txn, contactId, Collections.<GroupId>emptySet(), 2); db.setSubscriptions(txn, contactId, Collections.<Group>emptySet(), 2);
it = db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator(); it = db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();
assertFalse(it.hasNext()); assertFalse(it.hasNext());
@@ -290,8 +304,8 @@ public class H2DatabaseTest extends TestCase {
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, null)); assertEquals(contactId, db.addContact(txn, null));
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, Collections.singleton(groupId), 1); db.setSubscriptions(txn, contactId, Collections.singleton(group), 1);
db.addMessage(txn, message); db.addMessage(txn, message);
db.setSendability(txn, messageId, 1); db.setSendability(txn, messageId, 1);
db.setStatus(txn, contactId, messageId, Status.NEW); db.setStatus(txn, contactId, messageId, Status.NEW);
@@ -346,8 +360,8 @@ public class H2DatabaseTest extends TestCase {
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, null)); assertEquals(contactId, db.addContact(txn, null));
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, Collections.singleton(groupId), 1); db.setSubscriptions(txn, contactId, Collections.singleton(group), 1);
db.addMessage(txn, message); db.addMessage(txn, message);
db.setSendability(txn, messageId, 1); db.setSendability(txn, messageId, 1);
db.setStatus(txn, contactId, messageId, Status.NEW); db.setStatus(txn, contactId, messageId, Status.NEW);
@@ -382,8 +396,8 @@ public class H2DatabaseTest extends TestCase {
// Add a contact, subscribe to a group and store a message // Add a contact, subscribe to a group and store a message
assertEquals(contactId, db.addContact(txn, null)); assertEquals(contactId, db.addContact(txn, null));
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.setSubscriptions(txn, contactId, Collections.singleton(groupId), 1); db.setSubscriptions(txn, contactId, Collections.singleton(group), 1);
db.addMessage(txn, message); db.addMessage(txn, message);
db.setSendability(txn, messageId, 1); db.setSendability(txn, messageId, 1);
db.setStatus(txn, contactId, messageId, Status.NEW); db.setStatus(txn, contactId, messageId, Status.NEW);
@@ -493,7 +507,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Subscribe to a group and store two messages // Subscribe to a group and store two messages
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.addMessage(txn, message); db.addMessage(txn, message);
db.addMessage(txn, message1); db.addMessage(txn, message1);
@@ -518,6 +532,8 @@ public class H2DatabaseTest extends TestCase {
MessageId childId2 = new MessageId(TestUtils.getRandomId()); MessageId childId2 = new MessageId(TestUtils.getRandomId());
MessageId childId3 = new MessageId(TestUtils.getRandomId()); MessageId childId3 = new MessageId(TestUtils.getRandomId());
GroupId groupId1 = new GroupId(TestUtils.getRandomId()); GroupId groupId1 = new GroupId(TestUtils.getRandomId());
Group group1 = groupFactory.createGroup(groupId1, "Another group name",
TestUtils.getRandomId(), null);
Message child1 = new TestMessage(childId1, messageId, groupId, Message child1 = new TestMessage(childId1, messageId, groupId,
authorId, timestamp, raw); authorId, timestamp, raw);
Message child2 = new TestMessage(childId2, messageId, groupId, Message child2 = new TestMessage(childId2, messageId, groupId,
@@ -529,8 +545,8 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Subscribe to the groups and store the messages // Subscribe to the groups and store the messages
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.addSubscription(txn, groupId1); db.addSubscription(txn, group1);
db.addMessage(txn, message); db.addMessage(txn, message);
db.addMessage(txn, child1); db.addMessage(txn, child1);
db.addMessage(txn, child2); db.addMessage(txn, child2);
@@ -560,7 +576,7 @@ public class H2DatabaseTest extends TestCase {
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
// Subscribe to a group and store two messages // Subscribe to a group and store two messages
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.addMessage(txn, message); db.addMessage(txn, message);
db.addMessage(txn, message1); db.addMessage(txn, message1);
@@ -598,7 +614,7 @@ public class H2DatabaseTest extends TestCase {
assertTrue(free > 0); assertTrue(free > 0);
// Storing a message should reduce the free space // Storing a message should reduce the free space
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
db.addSubscription(txn, groupId); db.addSubscription(txn, group);
db.addMessage(txn, message1); db.addMessage(txn, message1);
db.commitTransaction(txn); db.commitTransaction(txn);
assertTrue(db.getFreeSpace() < free); assertTrue(db.getFreeSpace() < free);
@@ -740,6 +756,9 @@ public class H2DatabaseTest extends TestCase {
@Test @Test
public void testUpdateSubscriptions() throws DbException { public void testUpdateSubscriptions() throws DbException {
GroupId groupId1 = new GroupId(TestUtils.getRandomId());
Group group1 = groupFactory.createGroup(groupId1, "Another group name",
TestUtils.getRandomId(), null);
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
@@ -747,19 +766,13 @@ public class H2DatabaseTest extends TestCase {
Map<String, String> transports = Collections.emptyMap(); Map<String, String> transports = Collections.emptyMap();
assertEquals(contactId, db.addContact(txn, transports)); assertEquals(contactId, db.addContact(txn, transports));
// Add some subscriptions // Add some subscriptions
Collection<GroupId> subs = new HashSet<GroupId>(); Collection<Group> subs = Collections.singletonList(group);
subs.add(new GroupId(TestUtils.getRandomId()));
subs.add(new GroupId(TestUtils.getRandomId()));
db.setSubscriptions(txn, contactId, subs, 1); db.setSubscriptions(txn, contactId, subs, 1);
assertEquals(subs, assertEquals(subs, db.getSubscriptions(txn, contactId));
new HashSet<GroupId>(db.getSubscriptions(txn, contactId)));
// Update the subscriptions // Update the subscriptions
Collection<GroupId> subs1 = new HashSet<GroupId>(); Collection<Group> subs1 = Collections.singletonList(group1);
subs1.add(new GroupId(TestUtils.getRandomId()));
subs1.add(new GroupId(TestUtils.getRandomId()));
db.setSubscriptions(txn, contactId, subs1, 2); db.setSubscriptions(txn, contactId, subs1, 2);
assertEquals(subs1, assertEquals(subs1, db.getSubscriptions(txn, contactId));
new HashSet<GroupId>(db.getSubscriptions(txn, contactId)));
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
@@ -768,6 +781,9 @@ public class H2DatabaseTest extends TestCase {
@Test @Test
public void testSubscriptionsNotUpdatedIfTimestampIsOld() public void testSubscriptionsNotUpdatedIfTimestampIsOld()
throws DbException { throws DbException {
GroupId groupId1 = new GroupId(TestUtils.getRandomId());
Group group1 = groupFactory.createGroup(groupId1, "Another group name",
TestUtils.getRandomId(), null);
Database<Connection> db = open(false); Database<Connection> db = open(false);
Connection txn = db.startTransaction(); Connection txn = db.startTransaction();
@@ -775,39 +791,23 @@ public class H2DatabaseTest extends TestCase {
Map<String, String> transports = Collections.emptyMap(); Map<String, String> transports = Collections.emptyMap();
assertEquals(contactId, db.addContact(txn, transports)); assertEquals(contactId, db.addContact(txn, transports));
// Add some subscriptions // Add some subscriptions
Collection<GroupId> subs = new HashSet<GroupId>(); Collection<Group> subs = Collections.singletonList(group);
subs.add(new GroupId(TestUtils.getRandomId()));
subs.add(new GroupId(TestUtils.getRandomId()));
db.setSubscriptions(txn, contactId, subs, 2); db.setSubscriptions(txn, contactId, subs, 2);
assertEquals(subs, assertEquals(subs, db.getSubscriptions(txn, contactId));
new HashSet<GroupId>(db.getSubscriptions(txn, contactId)));
// Try to update the subscriptions using a timestamp of 1 // Try to update the subscriptions using a timestamp of 1
Collection<GroupId> subs1 = new HashSet<GroupId>(); Collection<Group> subs1 = Collections.singletonList(group1);
subs1.add(new GroupId(TestUtils.getRandomId()));
subs1.add(new GroupId(TestUtils.getRandomId()));
db.setSubscriptions(txn, contactId, subs1, 1); db.setSubscriptions(txn, contactId, subs1, 1);
// The old subscriptions should still be there // The old subscriptions should still be there
assertEquals(subs, assertEquals(subs, db.getSubscriptions(txn, contactId));
new HashSet<GroupId>(db.getSubscriptions(txn, contactId)));
db.commitTransaction(txn); db.commitTransaction(txn);
db.close(); db.close();
} }
private Database<Connection> open(boolean resume) throws DbException { private Database<Connection> open(boolean resume) throws DbException {
final char[] passwordArray = passwordString.toCharArray(); Database<Connection> db = new H2Database(testDir, password, MAX_SIZE,
Mockery context = new Mockery(); groupFactory);
final Password password = context.mock(Password.class);
context.checking(new Expectations() {{
oneOf(password).getPassword();
will(returnValue(passwordArray));
}});
Database<Connection> db = new H2Database(testDir, password, MAX_SIZE);
db.open(resume); db.open(resume);
context.assertIsSatisfied();
// The password array should be cleared after use
assertTrue(Arrays.equals(new char[passwordString.length()],
passwordArray));
return db; return db;
} }
@@ -815,4 +815,11 @@ public class H2DatabaseTest extends TestCase {
public void tearDown() { public void tearDown() {
TestUtils.deleteTestDirectory(testDir); TestUtils.deleteTestDirectory(testDir);
} }
private class TestPassword implements Password {
public char[] getPassword() {
return passwordString.toCharArray();
}
}
} }

View File

@@ -17,6 +17,7 @@ import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory; import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import org.jmock.Expectations; import org.jmock.Expectations;
@@ -29,8 +30,6 @@ import com.google.inject.Injector;
public class BatchReaderTest extends TestCase { public class BatchReaderTest extends TestCase {
private static final String DIGEST_ALGO = "SHA-256";
private final ReaderFactory readerFactory; private final ReaderFactory readerFactory;
private final WriterFactory writerFactory; private final WriterFactory writerFactory;
private final MessageDigest messageDigest; private final MessageDigest messageDigest;
@@ -42,7 +41,7 @@ public class BatchReaderTest extends TestCase {
Injector i = Guice.createInjector(new SerialModule()); Injector i = Guice.createInjector(new SerialModule());
readerFactory = i.getInstance(ReaderFactory.class); readerFactory = i.getInstance(ReaderFactory.class);
writerFactory = i.getInstance(WriterFactory.class); writerFactory = i.getInstance(WriterFactory.class);
messageDigest = MessageDigest.getInstance(DIGEST_ALGO); messageDigest = MessageDigest.getInstance(CryptoModule.DIGEST_ALGO);
context = new Mockery(); context = new Mockery();
message = context.mock(Message.class); message = context.mock(Message.class);
} }

View File

@@ -9,19 +9,18 @@ import java.util.Random;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.serial.FormatException; import net.sf.briar.api.serial.FormatException;
import net.sf.briar.crypto.CryptoModule;
import org.junit.Test; import org.junit.Test;
public class ConsumersTest extends TestCase { public class ConsumersTest extends TestCase {
private static final String SIGNATURE_ALGO = "SHA256withRSA";
private static final String KEY_PAIR_ALGO = "RSA";
private static final String DIGEST_ALGO = "SHA-256";
@Test @Test
public void testSigningConsumer() throws Exception { public void testSigningConsumer() throws Exception {
Signature s = Signature.getInstance(SIGNATURE_ALGO); Signature s = Signature.getInstance(CryptoModule.SIGNATURE_ALGO);
KeyPair k = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).genKeyPair(); KeyPairGenerator gen =
KeyPairGenerator.getInstance(CryptoModule.KEY_PAIR_ALGO);
KeyPair k = gen.genKeyPair();
byte[] data = new byte[1234]; byte[] data = new byte[1234];
// Generate some random data and sign it // Generate some random data and sign it
new Random().nextBytes(data); new Random().nextBytes(data);
@@ -40,7 +39,7 @@ public class ConsumersTest extends TestCase {
@Test @Test
public void testDigestingConsumer() throws Exception { public void testDigestingConsumer() throws Exception {
MessageDigest m = MessageDigest.getInstance(DIGEST_ALGO); MessageDigest m = MessageDigest.getInstance(CryptoModule.DIGEST_ALGO);
byte[] data = new byte[1234]; byte[] data = new byte[1234];
// Generate some random data and digest it // Generate some random data and digest it
new Random().nextBytes(data); new Random().nextBytes(data);

View File

@@ -3,15 +3,10 @@ package net.sf.briar.protocol;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.security.KeyFactory;
import java.security.KeyPair; import java.security.KeyPair;
import java.security.KeyPairGenerator; import java.security.KeyPairGenerator;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.Signature; import java.security.Signature;
import java.security.spec.EncodedKeySpec;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
@@ -33,6 +28,7 @@ import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.WriterFactory; import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.crypto.CryptoModule;
import net.sf.briar.serial.SerialModule; import net.sf.briar.serial.SerialModule;
import org.junit.After; import org.junit.After;
@@ -44,10 +40,6 @@ import com.google.inject.Injector;
public class FileReadWriteTest extends TestCase { public class FileReadWriteTest extends TestCase {
private static final String SIGNATURE_ALGO = "SHA256withRSA";
private static final String KEY_PAIR_ALGO = "RSA";
private static final String DIGEST_ALGO = "SHA-256";
private final File testDir = TestUtils.getTestDirectory(); private final File testDir = TestUtils.getTestDirectory();
private final File file = new File(testDir, "foo"); private final File file = new File(testDir, "foo");
@@ -65,29 +57,22 @@ public class FileReadWriteTest extends TestCase {
public FileReadWriteTest() throws Exception { public FileReadWriteTest() throws Exception {
super(); super();
// Inject the reader and writer factories, since they belong to Injector i = Guice.createInjector(new SerialModule(),
// a different component new CryptoModule());
Injector i = Guice.createInjector(new SerialModule());
readerFactory = i.getInstance(ReaderFactory.class); readerFactory = i.getInstance(ReaderFactory.class);
writerFactory = i.getInstance(WriterFactory.class); writerFactory = i.getInstance(WriterFactory.class);
signature = Signature.getInstance(SIGNATURE_ALGO); keyParser = i.getInstance(KeyParser.class);
messageDigest = MessageDigest.getInstance(DIGEST_ALGO); signature = Signature.getInstance(CryptoModule.SIGNATURE_ALGO);
batchDigest = MessageDigest.getInstance(DIGEST_ALGO); messageDigest = MessageDigest.getInstance(CryptoModule.DIGEST_ALGO);
final KeyFactory keyFactory = KeyFactory.getInstance(KEY_PAIR_ALGO); batchDigest = MessageDigest.getInstance(CryptoModule.DIGEST_ALGO);
keyParser = new KeyParser() {
public PublicKey parsePublicKey(byte[] encodedKey)
throws InvalidKeySpecException {
EncodedKeySpec e = new X509EncodedKeySpec(encodedKey);
return keyFactory.generatePublic(e);
}
};
assertEquals(messageDigest.getDigestLength(), UniqueId.LENGTH); assertEquals(messageDigest.getDigestLength(), UniqueId.LENGTH);
assertEquals(batchDigest.getDigestLength(), UniqueId.LENGTH); assertEquals(batchDigest.getDigestLength(), UniqueId.LENGTH);
// Create and encode a test message // Create and encode a test message
MessageEncoder messageEncoder = new MessageEncoderImpl(signature, MessageEncoder messageEncoder = new MessageEncoderImpl(signature,
messageDigest, writerFactory); messageDigest, writerFactory);
KeyPair keyPair = KeyPairGenerator gen =
KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair(); KeyPairGenerator.getInstance(CryptoModule.KEY_PAIR_ALGO);
KeyPair keyPair = gen.generateKeyPair();
message = messageEncoder.encodeMessage(MessageId.NONE, sub, nick, message = messageEncoder.encodeMessage(MessageId.NONE, sub, nick,
keyPair, messageBody.getBytes("UTF-8")); keyPair, messageBody.getBytes("UTF-8"));
} }

View File

@@ -10,25 +10,24 @@ import java.util.Arrays;
import java.util.Random; import java.util.Random;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.crypto.CryptoModule;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
public class SigningDigestingOutputStreamTest extends TestCase { public class SigningDigestingOutputStreamTest extends TestCase {
private static final String SIGNATURE_ALGO = "SHA256withRSA";
private static final String KEY_PAIR_ALGO = "RSA";
private static final String DIGEST_ALGO = "SHA-256";
private KeyPair keyPair = null; private KeyPair keyPair = null;
private Signature sig = null; private Signature sig = null;
private MessageDigest dig = null; private MessageDigest dig = null;
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
keyPair = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair(); KeyPairGenerator gen =
sig = Signature.getInstance(SIGNATURE_ALGO); KeyPairGenerator.getInstance(CryptoModule.KEY_PAIR_ALGO);
dig = MessageDigest.getInstance(DIGEST_ALGO); keyPair = gen.generateKeyPair();
sig = Signature.getInstance(CryptoModule.SIGNATURE_ALGO);
dig = MessageDigest.getInstance(CryptoModule.DIGEST_ALGO);
} }
@Test @Test