diff --git a/api/net/sf/briar/api/protocol/GroupFactory.java b/api/net/sf/briar/api/protocol/GroupFactory.java index 044fca597..55a4f79e2 100644 --- a/api/net/sf/briar/api/protocol/GroupFactory.java +++ b/api/net/sf/briar/api/protocol/GroupFactory.java @@ -2,5 +2,5 @@ package net.sf.briar.api.protocol; public interface GroupFactory { - Group createGroup(GroupId id, String name, byte[] salt, byte[] publicKey); + Group createGroup(GroupId id, String name, boolean restricted, byte[] b); } diff --git a/components/net/sf/briar/db/JdbcDatabase.java b/components/net/sf/briar/db/JdbcDatabase.java index c48bad47b..9e251e201 100644 --- a/components/net/sf/briar/db/JdbcDatabase.java +++ b/components/net/sf/briar/db/JdbcDatabase.java @@ -2,7 +2,6 @@ package net.sf.briar.db; import java.io.ByteArrayInputStream; import java.io.File; -import java.security.PublicKey; import java.sql.Blob; import java.sql.Connection; import java.sql.PreparedStatement; @@ -40,9 +39,9 @@ abstract class JdbcDatabase implements Database { private static final String CREATE_LOCAL_SUBSCRIPTIONS = "CREATE TABLE localSubscriptions" + " (groupId HASH NOT NULL," - + " name VARCHAR NOT NULL," - + " salt BINARY," - + " publicKey BINARY," + + " groupName VARCHAR NOT NULL," + + " restricted BOOLEAN NOT NULL," + + " groupKey BINARY NOT NULL," + " PRIMARY KEY (groupId))"; private static final String CREATE_MESSAGES = @@ -90,9 +89,9 @@ abstract class JdbcDatabase implements Database { "CREATE TABLE contactSubscriptions" + " (contactId INT NOT NULL," + " groupId HASH NOT NULL," - + " name VARCHAR NOT NULL," - + " salt BINARY," - + " publicKey BINARY," + + " groupName VARCHAR NOT NULL," + + " restricted BOOLEAN NOT NULL," + + " groupKey BINARY NOT NULL," + " PRIMARY KEY (contactId, groupId)," + " FOREIGN KEY (contactId) REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; @@ -531,14 +530,14 @@ abstract class JdbcDatabase implements Database { PreparedStatement ps = null; try { String sql = "INSERT INTO localSubscriptions" - + " (groupId, name, salt, publicKey)" + + " (groupId, groupName, restricted, groupKey)" + " VALUES (?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setBytes(1, g.getId().getBytes()); ps.setString(2, g.getName()); - ps.setBytes(3, g.getSalt()); - PublicKey k = g.getPublicKey(); - ps.setBytes(4, k == null ? null : k.getEncoded()); + ps.setBoolean(3, g.isRestricted()); + if(g.isRestricted()) ps.setBytes(4, g.getPublicKey().getEncoded()); + else ps.setBytes(4, g.getSalt()); int rowsAffected = ps.executeUpdate(); assert rowsAffected == 1; ps.close(); @@ -990,7 +989,7 @@ abstract class JdbcDatabase implements Database { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT (groupId, name, salt, publicKey)" + String sql = "SELECT groupId, groupName, restricted, groupKey" + " FROM localSubscriptions"; ps = txn.prepareStatement(sql); rs = ps.executeQuery(); @@ -998,9 +997,9 @@ abstract class JdbcDatabase implements Database { while(rs.next()) { GroupId id = new GroupId(rs.getBytes(1)); String name = rs.getString(2); - byte[] salt = rs.getBytes(3); - byte[] publicKey = rs.getBytes(4); - subs.add(groupFactory.createGroup(id, name, salt, publicKey)); + boolean restricted = rs.getBoolean(3); + byte[] key = rs.getBytes(4); + subs.add(groupFactory.createGroup(id, name, restricted, key)); } rs.close(); ps.close(); @@ -1018,7 +1017,7 @@ abstract class JdbcDatabase implements Database { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT (groupId, name, salt, publicKey)" + String sql = "SELECT groupId, groupName, restricted, groupKey" + " FROM contactSubscriptions" + " WHERE contactId = ?"; ps = txn.prepareStatement(sql); @@ -1028,9 +1027,9 @@ abstract class JdbcDatabase implements Database { while(rs.next()) { GroupId id = new GroupId(rs.getBytes(1)); String name = rs.getString(2); - byte[] salt = rs.getBytes(3); - byte[] publicKey = rs.getBytes(4); - subs.add(groupFactory.createGroup(id, name, salt, publicKey)); + boolean restricted = rs.getBoolean(3); + byte[] key = rs.getBytes(4); + subs.add(groupFactory.createGroup(id, name, restricted, key)); } rs.close(); ps.close(); @@ -1390,16 +1389,17 @@ abstract class JdbcDatabase implements Database { ps.close(); // Store the new subscriptions sql = "INSERT INTO contactSubscriptions" - + "(contactId, groupId, name, salt, publicKey)" + + "(contactId, groupId, groupName, restricted, groupKey)" + " VALUES (?, ?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); for(Group g : subs) { ps.setBytes(2, g.getId().getBytes()); ps.setString(3, g.getName()); - ps.setBytes(4, g.getSalt()); - PublicKey k = g.getPublicKey(); - ps.setBytes(5, k == null ? null : k.getEncoded()); + ps.setBoolean(4, g.isRestricted()); + if(g.isRestricted()) + ps.setBytes(5, g.getPublicKey().getEncoded()); + else ps.setBytes(5, g.getSalt()); ps.addBatch(); } int[] rowsAffectedArray = ps.executeBatch(); diff --git a/components/net/sf/briar/protocol/GroupFactoryImpl.java b/components/net/sf/briar/protocol/GroupFactoryImpl.java index e98a0398b..d9b194106 100644 --- a/components/net/sf/briar/protocol/GroupFactoryImpl.java +++ b/components/net/sf/briar/protocol/GroupFactoryImpl.java @@ -19,20 +19,15 @@ class GroupFactoryImpl implements GroupFactory { this.keyParser = keyParser; } - public Group createGroup(GroupId id, String name, byte[] salt, - byte[] publicKey) { - if(salt == null && publicKey == null) - throw new IllegalArgumentException(); - if(salt != null && publicKey != null) - throw new IllegalArgumentException(); - PublicKey key = null; - if(publicKey != null) { + public Group createGroup(GroupId id, String name, boolean restricted, + byte[] b) { + if(restricted) { try { - key = keyParser.parsePublicKey(publicKey); + PublicKey key = keyParser.parsePublicKey(b); + return new GroupImpl(id, name, null, key); } catch (InvalidKeySpecException e) { throw new IllegalArgumentException(e); } - } - return new GroupImpl(id, name, salt, key); + } else return new GroupImpl(id, name, b, null); } } diff --git a/components/net/sf/briar/protocol/GroupImpl.java b/components/net/sf/briar/protocol/GroupImpl.java index 2d62fa362..73a7c95b9 100644 --- a/components/net/sf/briar/protocol/GroupImpl.java +++ b/components/net/sf/briar/protocol/GroupImpl.java @@ -39,4 +39,14 @@ public class GroupImpl implements Group { public PublicKey getPublicKey() { return publicKey; } + + @Override + public boolean equals(Object o) { + return o instanceof Group && id.equals(((Group) o).getId()); + } + + @Override + public int hashCode() { + return id.hashCode(); + } } diff --git a/components/net/sf/briar/protocol/MessageImpl.java b/components/net/sf/briar/protocol/MessageImpl.java index e8b02bb7b..a8360a7ed 100644 --- a/components/net/sf/briar/protocol/MessageImpl.java +++ b/components/net/sf/briar/protocol/MessageImpl.java @@ -54,7 +54,7 @@ class MessageImpl implements Message { @Override public boolean equals(Object o) { - return o instanceof Message && id.equals(((Message)o).getId()); + return o instanceof Message && id.equals(((Message) o).getId()); } @Override diff --git a/test/build.xml b/test/build.xml index f5e38bbba..a0605ef09 100644 --- a/test/build.xml +++ b/test/build.xml @@ -13,6 +13,7 @@ + diff --git a/test/net/sf/briar/db/BasicH2Test.java b/test/net/sf/briar/db/BasicH2Test.java new file mode 100644 index 000000000..843b05196 --- /dev/null +++ b/test/net/sf/briar/db/BasicH2Test.java @@ -0,0 +1,117 @@ +package net.sf.briar.db; + +import java.io.File; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Random; + +import junit.framework.TestCase; +import net.sf.briar.TestUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class BasicH2Test extends TestCase { + + private static final String CREATE_TABLE = + "CREATE TABLE foo" + + " (uniqueId BINARY(32) NOT NULL," + + " name VARCHAR NOT NULL," + + " PRIMARY KEY (uniqueId))"; + + private final File testDir = TestUtils.getTestDirectory(); + private final File db = new File(testDir, "db"); + private final String url = "jdbc:h2:" + db.getPath(); + + private Connection connection = null; + + @Before + public void setUp() throws Exception { + testDir.mkdirs(); + Class.forName("org.h2.Driver"); + connection = DriverManager.getConnection(url); + } + + @Test + public void testCreateTableAndAddRow() throws Exception { + // Create the table + createTable(connection); + // Generate a unique ID + byte[] uniqueId = new byte[32]; + new Random().nextBytes(uniqueId); + // Insert the unique ID and name into the table + addRow(uniqueId, "foo"); + } + + @Test + public void testCreateTableAddAndRetrieveRow() throws Exception { + // Create the table + createTable(connection); + // Generate a unique ID + byte[] uniqueId = new byte[32]; + new Random().nextBytes(uniqueId); + // Insert the unique ID and name into the table + addRow(uniqueId, "foo"); + // Check that the name can be retrieved using the unique ID + assertEquals("foo", getName(uniqueId)); + } + + private void addRow(byte[] uniqueId, String name) throws SQLException { + String sql = "INSERT INTO foo (uniqueId, name) VALUES (?, ?)"; + PreparedStatement ps = null; + try { + ps = connection.prepareStatement(sql); + ps.setBytes(1, uniqueId); + ps.setString(2, name); + int rowsAffected = ps.executeUpdate(); + ps.close(); + assertEquals(1, rowsAffected); + } catch(SQLException e) { + connection.close(); + throw e; + } + } + + private String getName(byte[] uniqueId) throws SQLException { + String sql = "SELECT name FROM foo WHERE uniqueID = ?"; + PreparedStatement ps = null; + ResultSet rs = null; + try { + ps = connection.prepareStatement(sql); + ps.setBytes(1, uniqueId); + rs = ps.executeQuery(); + assertTrue(rs.next()); + String name = rs.getString(1); + assertFalse(rs.next()); + rs.close(); + ps.close(); + return name; + } catch(SQLException e) { + connection.close(); + throw e; + } + } + + private void createTable(Connection connection) throws SQLException { + Statement s; + try { + s = connection.createStatement(); + s.executeUpdate(CREATE_TABLE); + s.close(); + } catch(SQLException e) { + connection.close(); + throw e; + } + } + + @After + public void tearDown() throws Exception { + if(connection != null) connection.close(); + TestUtils.deleteTestDirectory(testDir); + } +} diff --git a/test/net/sf/briar/db/DatabaseComponentTest.java b/test/net/sf/briar/db/DatabaseComponentTest.java index 385baf084..65ba87ad8 100644 --- a/test/net/sf/briar/db/DatabaseComponentTest.java +++ b/test/net/sf/briar/db/DatabaseComponentTest.java @@ -468,6 +468,7 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(true)); oneOf(ackWriter).addBatchId(batchId1); will(returnValue(false)); + oneOf(ackWriter).finish(); // Record the batch that was acked oneOf(database).removeBatchesToAck(txn, contactId, acks); }}); diff --git a/test/net/sf/briar/db/H2DatabaseTest.java b/test/net/sf/briar/db/H2DatabaseTest.java index 72a6fe8f0..2ca8f0963 100644 --- a/test/net/sf/briar/db/H2DatabaseTest.java +++ b/test/net/sf/briar/db/H2DatabaseTest.java @@ -76,8 +76,8 @@ public class H2DatabaseTest extends TestCase { random.nextBytes(raw); message = new TestMessage(messageId, MessageId.NONE, groupId, authorId, timestamp, raw); - group = groupFactory.createGroup(groupId, "Group name", - TestUtils.getRandomId(), null); + group = groupFactory.createGroup(groupId, "Group name", false, + TestUtils.getRandomId()); } @Before @@ -533,7 +533,7 @@ public class H2DatabaseTest extends TestCase { MessageId childId3 = new MessageId(TestUtils.getRandomId()); GroupId groupId1 = new GroupId(TestUtils.getRandomId()); Group group1 = groupFactory.createGroup(groupId1, "Another group name", - TestUtils.getRandomId(), null); + false, TestUtils.getRandomId()); Message child1 = new TestMessage(childId1, messageId, groupId, authorId, timestamp, raw); Message child2 = new TestMessage(childId2, messageId, groupId, @@ -758,7 +758,7 @@ public class H2DatabaseTest extends TestCase { public void testUpdateSubscriptions() throws DbException { GroupId groupId1 = new GroupId(TestUtils.getRandomId()); Group group1 = groupFactory.createGroup(groupId1, "Another group name", - TestUtils.getRandomId(), null); + false, TestUtils.getRandomId()); Database db = open(false); Connection txn = db.startTransaction(); @@ -783,7 +783,7 @@ public class H2DatabaseTest extends TestCase { throws DbException { GroupId groupId1 = new GroupId(TestUtils.getRandomId()); Group group1 = groupFactory.createGroup(groupId1, "Another group name", - TestUtils.getRandomId(), null); + false, TestUtils.getRandomId()); Database db = open(false); Connection txn = db.startTransaction();