Added a method for retrieving the message body from the DB.

This commit is contained in:
akwizgran
2011-10-21 20:42:43 +01:00
parent 1173e0a563
commit 9ec8feec78
13 changed files with 217 additions and 93 deletions

View File

@@ -1,7 +1,5 @@
package net.sf.briar.api.protocol;
import java.io.InputStream;
public interface Message extends MessageHeader {
/**
@@ -21,15 +19,15 @@ public interface Message extends MessageHeader {
/** The length of the random salt in bytes. */
static final int SALT_LENGTH = 8;
/** Returns the length of the message in bytes. */
/** Returns the length of the serialised message in bytes. */
int getLength();
/** Returns the serialised representation of the entire message. */
byte[] getSerialisedBytes();
/** Returns the serialised message. */
byte[] getSerialised();
/**
* Returns a stream for reading the serialised representation of the entire
* message.
*/
InputStream getSerialisedStream();
/** Returns the offset of the message body within the serialised message. */
int getBodyStart();
/** Returns the length of the message body in bytes. */
int getBodyLength();
}

View File

@@ -239,12 +239,19 @@ interface Database<T> {
Collection<BatchId> getLostBatches(T txn, ContactId c) throws DbException;
/**
* Returns the message identified by the given ID, in raw format.
* Returns the message identified by the given ID, in serialised form.
* <p>
* Locking: messages read.
*/
byte[] getMessage(T txn, MessageId m) throws DbException;
/**
* Returns the body of the message identified by the given ID.
* <p>
* Locking: messages read.
*/
byte[] getMessageBody(T txn, MessageId m) throws DbException;
/**
* Returns the message identified by the given ID, in raw format, or null
* if the message is not present in the database or is not sendable to the

View File

@@ -432,8 +432,7 @@ DatabaseCleaner.Callback {
int capacity = b.getCapacity();
ids = db.getSendableMessages(txn, c, capacity);
for(MessageId m : ids) {
byte[] raw = db.getMessage(txn, m);
messages.add(new Bytes(raw));
messages.add(new Bytes(db.getMessage(txn, m)));
}
db.commitTransaction(txn);
} catch(DbException e) {

View File

@@ -1,7 +1,9 @@
package net.sf.briar.db;
import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
@@ -61,8 +63,11 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " parentId HASH,"
+ " groupId HASH,"
+ " authorId HASH,"
+ " subject VARCHAR NOT NULL,"
+ " timestamp BIGINT NOT NULL,"
+ " size INT NOT NULL,"
+ " length INT NOT NULL,"
+ " bodyStart INT NOT NULL,"
+ " bodyLength INT NOT NULL,"
+ " raw BLOB NOT NULL,"
+ " sendability INT,"
+ " contactId INT,"
@@ -536,10 +541,10 @@ abstract class JdbcDatabase implements Database<Connection> {
if(containsMessage(txn, m.getId())) return false;
PreparedStatement ps = null;
try {
String sql = "INSERT INTO messages"
+ " (messageId, parentId, groupId, authorId, timestamp, size,"
+ " raw, sendability)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ZERO())";
String sql = "INSERT INTO messages (messageId, parentId, groupId,"
+ " authorId, subject, timestamp, length, bodyStart,"
+ " bodyLength, raw, sendability)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ZERO())";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getId().getBytes());
if(m.getParent() == null) ps.setNull(2, Types.BINARY);
@@ -547,10 +552,12 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setBytes(3, m.getGroup().getBytes());
if(m.getAuthor() == null) ps.setNull(4, Types.BINARY);
else ps.setBytes(4, m.getAuthor().getBytes());
ps.setLong(5, m.getTimestamp());
int length = m.getLength();
ps.setInt(6, length);
ps.setBinaryStream(7, m.getSerialisedStream(), length);
ps.setString(5, m.getSubject());
ps.setLong(6, m.getTimestamp());
ps.setInt(7, m.getLength());
ps.setInt(8, m.getBodyStart());
ps.setInt(9, m.getBodyLength());
ps.setBytes(10, m.getSerialised());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
@@ -626,18 +633,20 @@ abstract class JdbcDatabase implements Database<Connection> {
if(containsMessage(txn, m.getId())) return false;
PreparedStatement ps = null;
try {
String sql = "INSERT INTO messages"
+ " (messageId, parentId, timestamp, size, raw, contactId)"
+ " VALUES (?, ?, ?, ?, ?, ?)";
String sql = "INSERT INTO messages (messageId, parentId, subject,"
+ " timestamp, length, bodyStart, bodyLength, raw, contactId)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getId().getBytes());
if(m.getParent() == null) ps.setNull(2, Types.BINARY);
else ps.setBytes(2, m.getParent().getBytes());
ps.setLong(3, m.getTimestamp());
int length = m.getLength();
ps.setInt(4, length);
ps.setBinaryStream(5, m.getSerialisedStream(), length);
ps.setInt(6, c.getInt());
ps.setString(3, m.getSubject());
ps.setLong(4, m.getTimestamp());
ps.setInt(5, m.getLength());
ps.setInt(6, m.getBodyStart());
ps.setInt(7, m.getBodyLength());
ps.setBytes(8, m.getSerialised());
ps.setInt(9, c.getInt());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
@@ -1042,14 +1051,14 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT size, raw FROM messages WHERE messageId = ?";
String sql = "SELECT length, raw FROM messages WHERE messageId = ?";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getBytes());
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
int size = rs.getInt(1);
byte[] raw = rs.getBlob(2).getBytes(1, size);
if(raw.length != size) throw new DbStateException();
int length = rs.getInt(1);
byte[] raw = rs.getBlob(2).getBytes(1, length);
if(raw.length != length) throw new DbStateException();
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
@@ -1061,13 +1070,60 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public byte[] getMessageBody(Connection txn, MessageId m)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT bodyStart, bodyLength, raw FROM messages"
+ " WHERE messageId = ?";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getBytes());
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
int bodyStart = rs.getInt(1);
int bodyLength = rs.getInt(2);
InputStream in = rs.getBlob(3).getBinaryStream();
// FIXME: We have to read and discard the header because
// InputStream.skip() is broken for blobs - find out why
byte[] head = new byte[bodyStart];
byte[] body = new byte[bodyLength];
try {
int offset = 0;
while(offset < head.length) {
int read = in.read(head, offset, head.length - offset);
if(read == -1) throw new SQLException(new EOFException());
offset += read;
}
offset = 0;
while(offset < body.length) {
int read = in.read(body, offset, body.length - offset);
if(read == -1) throw new SQLException(new EOFException());
offset += read;
}
in.close();
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
throw new SQLException(e);
}
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
return body;
} catch(SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
public byte[] getMessageIfSendable(Connection txn, ContactId c, MessageId m)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Do we have a sendable private message with the given ID?
String sql = "SELECT size, raw FROM messages"
String sql = "SELECT length, raw FROM messages"
+ " JOIN statuses ON messages.messageId = statuses.messageId"
+ " WHERE messages.messageId = ? AND messages.contactId = ?"
+ " AND status = ?";
@@ -1078,16 +1134,16 @@ abstract class JdbcDatabase implements Database<Connection> {
rs = ps.executeQuery();
byte[] raw = null;
if(rs.next()) {
int size = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, size);
if(raw.length != size) throw new DbStateException();
int length = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, length);
if(raw.length != length) throw new DbStateException();
}
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
if(raw != null) return raw;
// Do we have a sendable group message with the given ID?
sql = "SELECT size, raw FROM messages"
sql = "SELECT length, raw FROM messages"
+ " JOIN contactSubscriptions"
+ " ON messages.groupId = contactSubscriptions.groupId"
+ " JOIN visibilities"
@@ -1107,9 +1163,9 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setShort(3, (short) Status.NEW.ordinal());
rs = ps.executeQuery();
if(rs.next()) {
int size = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, size);
if(raw.length != size) throw new DbStateException();
int length = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, length);
if(raw.length != length) throw new DbStateException();
}
if(rs.next()) throw new DbStateException();
rs.close();
@@ -1203,17 +1259,17 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT size, messageId FROM messages"
String sql = "SELECT length, messageId FROM messages"
+ " ORDER BY timestamp";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
Collection<MessageId> ids = new ArrayList<MessageId>();
int total = 0;
while(rs.next()) {
int size = rs.getInt(1);
if(total + size > capacity) break;
int length = rs.getInt(1);
if(total + length > capacity) break;
ids.add(new MessageId(rs.getBytes(2)));
total += size;
total += length;
}
rs.close();
ps.close();
@@ -1361,7 +1417,7 @@ abstract class JdbcDatabase implements Database<Connection> {
ResultSet rs = null;
try {
// Do we have any sendable private messages?
String sql = "SELECT size, messages.messageId FROM messages"
String sql = "SELECT length, messages.messageId FROM messages"
+ " JOIN statuses ON messages.messageId = statuses.messageId"
+ " WHERE messages.contactId = ? AND status = ?"
+ " ORDER BY timestamp";
@@ -1372,10 +1428,10 @@ abstract class JdbcDatabase implements Database<Connection> {
Collection<MessageId> ids = new ArrayList<MessageId>();
int total = 0;
while(rs.next()) {
int size = rs.getInt(1);
if(total + size > capacity) break;
int length = rs.getInt(1);
if(total + length > capacity) break;
ids.add(new MessageId(rs.getBytes(2)));
total += size;
total += length;
}
rs.close();
ps.close();
@@ -1384,7 +1440,7 @@ abstract class JdbcDatabase implements Database<Connection> {
total + "/" + capacity + " bytes");
if(total == capacity) return ids;
// Do we have any sendable group messages?
sql = "SELECT size, messages.messageId FROM messages"
sql = "SELECT length, messages.messageId FROM messages"
+ " JOIN contactSubscriptions"
+ " ON messages.groupId = contactSubscriptions.groupId"
+ " JOIN visibilities"
@@ -1403,10 +1459,10 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setShort(2, (short) Status.NEW.ordinal());
rs = ps.executeQuery();
while(rs.next()) {
int size = rs.getInt(1);
if(total + size > capacity) break;
int length = rs.getInt(1);
if(total + length > capacity) break;
ids.add(new MessageId(rs.getBytes(2)));
total += size;
total += length;
}
rs.close();
ps.close();

View File

@@ -16,6 +16,7 @@ import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.Writer;
@@ -81,6 +82,9 @@ class MessageEncoderImpl implements MessageEncoder {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
// Initialise the consumers
CountingConsumer counting = new CountingConsumer(
ProtocolConstants.MAX_PACKET_LENGTH);
w.addConsumer(counting);
Consumer digestingConsumer = new DigestingConsumer(messageDigest);
w.addConsumer(digestingConsumer);
Consumer authorConsumer = null;
@@ -110,6 +114,7 @@ class MessageEncoderImpl implements MessageEncoder {
random.nextBytes(salt);
w.writeBytes(salt);
w.writeBytes(body);
int bodyStart = (int) counting.getCount() - body.length;
// Sign the message with the author's private key, if there is one
if(authorKey == null) {
w.writeNull();
@@ -137,6 +142,6 @@ class MessageEncoderImpl implements MessageEncoder {
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, parent, groupId, authorId, subject,
timestamp, raw);
timestamp, raw, bodyStart, body.length);
}
}

View File

@@ -1,8 +1,5 @@
package net.sf.briar.protocol;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
@@ -17,9 +14,15 @@ class MessageImpl implements Message {
private final String subject;
private final long timestamp;
private final byte[] raw;
private final int bodyStart, bodyLength;
public MessageImpl(MessageId id, MessageId parent, GroupId group,
AuthorId author, String subject, long timestamp, byte[] raw) {
AuthorId author, String subject, long timestamp, byte[] raw,
int bodyStart, int bodyLength) {
if(bodyStart + bodyLength > raw.length)
throw new IllegalArgumentException();
if(bodyLength > Message.MAX_BODY_LENGTH)
throw new IllegalArgumentException();
this.id = id;
this.parent = parent;
this.group = group;
@@ -27,6 +30,8 @@ class MessageImpl implements Message {
this.subject = subject;
this.timestamp = timestamp;
this.raw = raw;
this.bodyStart = bodyStart;
this.bodyLength = bodyLength;
}
public MessageId getId() {
@@ -57,12 +62,16 @@ class MessageImpl implements Message {
return raw.length;
}
public byte[] getSerialisedBytes() {
public byte[] getSerialised() {
return raw;
}
public InputStream getSerialisedStream() {
return new ByteArrayInputStream(raw);
public int getBodyStart() {
return bodyStart;
}
public int getBodyLength() {
return bodyLength;
}
@Override

View File

@@ -84,8 +84,10 @@ class MessageReader implements ObjectReader<Message> {
// Read the salt
byte[] salt = r.readBytes(Message.SALT_LENGTH);
if(salt.length != Message.SALT_LENGTH) throw new FormatException();
// Skip the message body
r.readBytes(Message.MAX_BODY_LENGTH);
// Read the message body
byte[] body = r.readBytes(Message.MAX_BODY_LENGTH);
// Record the offset of the body within the message
int bodyStart = (int) counting.getCount() - body.length;
// Record the length of the data covered by the author's signature
int signedByAuthor = (int) counting.getCount();
// Read the author's signature, if there is one
@@ -131,6 +133,6 @@ class MessageReader implements ObjectReader<Message> {
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, parent, groupId, authorId, subject,
timestamp, raw);
timestamp, raw, bodyStart, body.length);
}
}

View File

@@ -272,26 +272,21 @@ class ReaderImpl implements Reader {
}
public String readString() throws IOException {
return readString(maxStringLength);
}
public String readString(int maxLength) throws IOException {
if(!hasString()) throw new FormatException();
consumeLookahead();
int length;
if(next == Tag.STRING) length = readLength();
else length = 0xFF & next ^ Tag.SHORT_STRING;
if(length > maxStringLength) throw new FormatException();
if(length > maxLength) throw new FormatException();
if(length == 0) return "";
readIntoBuffer(length);
return new String(buf, 0, length, "UTF-8");
}
public String readString(int maxLength) throws IOException {
setMaxStringLength(maxLength);
try {
return readString();
} finally {
resetMaxStringLength();
}
}
private int readLength() throws IOException {
if(!hasLength()) throw new FormatException();
if(next >= 0) return readUint7();
@@ -315,27 +310,22 @@ class ReaderImpl implements Reader {
}
public byte[] readBytes() throws IOException {
return readBytes(maxBytesLength);
}
public byte[] readBytes(int maxLength) throws IOException {
if(!hasBytes()) throw new FormatException();
consumeLookahead();
int length;
if(next == Tag.BYTES) length = readLength();
else length = 0xFF & next ^ Tag.SHORT_BYTES;
if(length > maxBytesLength) throw new FormatException();
if(length > maxLength) throw new FormatException();
if(length == 0) return EMPTY_BUFFER;
byte[] b = new byte[length];
readIntoBuffer(b, length);
return b;
}
public byte[] readBytes(int maxLength) throws IOException {
setMaxBytesLength(maxLength);
try {
return readBytes();
} finally {
resetMaxBytesLength();
}
}
public boolean hasList() throws IOException {
if(!hasLookahead) readLookahead(true);
if(eof) return false;

View File

@@ -142,10 +142,10 @@ public class ProtocolIntegrationTest extends TestCase {
a.finish();
BatchWriter b = protocolWriterFactory.createBatchWriter(out1);
assertTrue(b.writeMessage(message.getSerialisedBytes()));
assertTrue(b.writeMessage(message1.getSerialisedBytes()));
assertTrue(b.writeMessage(message2.getSerialisedBytes()));
assertTrue(b.writeMessage(message3.getSerialisedBytes()));
assertTrue(b.writeMessage(message.getSerialised()));
assertTrue(b.writeMessage(message1.getSerialised()));
assertTrue(b.writeMessage(message2.getSerialised()));
assertTrue(b.writeMessage(message3.getSerialised()));
b.finish();
OfferWriter o = protocolWriterFactory.createOfferWriter(out1);
@@ -255,6 +255,6 @@ public class ProtocolIntegrationTest extends TestCase {
assertEquals(m1.getGroup(), m2.getGroup());
assertEquals(m1.getAuthor(), m2.getAuthor());
assertEquals(m1.getTimestamp(), m2.getTimestamp());
assertArrayEquals(m1.getSerialisedBytes(), m2.getSerialisedBytes());
assertArrayEquals(m1.getSerialised(), m2.getSerialised());
}
}

View File

@@ -4,6 +4,7 @@ import static org.junit.Assert.assertArrayEquals;
import java.io.File;
import java.sql.Connection;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
@@ -1606,6 +1607,46 @@ public class H2DatabaseTest extends TestCase {
db.close();
}
@Test
public void testGetMessageBody() throws Exception {
Database<Connection> db = open(false);
Connection txn = db.startTransaction();
// Add a contact and subscribe to a group
assertEquals(contactId, db.addContact(txn, transports, secret));
db.addSubscription(txn, group);
// Store a couple of messages
int bodyLength = raw.length - 20;
Message message1 = new TestMessage(messageId, null, groupId, null,
subject, timestamp, raw, 5, bodyLength);
Message privateMessage1 = new TestMessage(privateMessageId, null, null,
null, subject, timestamp, raw, 10, bodyLength);
db.addGroupMessage(txn, message1);
db.addPrivateMessage(txn, privateMessage1, contactId);
// Calculate the expected message bodies
byte[] expectedBody = new byte[bodyLength];
System.arraycopy(raw, 5, expectedBody, 0, bodyLength);
assertFalse(Arrays.equals(expectedBody, new byte[bodyLength]));
byte[] expectedBody1 = new byte[bodyLength];
System.arraycopy(raw, 10, expectedBody1, 0, bodyLength);
System.arraycopy(raw, 10, expectedBody1, 0, bodyLength);
// Retrieve the raw messages
assertArrayEquals(raw, db.getMessage(txn, messageId));
assertArrayEquals(raw, db.getMessage(txn, privateMessageId));
// Retrieve the message bodies
byte[] body = db.getMessageBody(txn, messageId);
assertArrayEquals(expectedBody, body);
byte[] body1 = db.getMessageBody(txn, privateMessageId);
assertArrayEquals(expectedBody1, body1);
db.commitTransaction(txn);
db.close();
}
@Test
public void testExceptionHandling() throws Exception {
Database<Connection> db = open(false);

View File

@@ -16,9 +16,16 @@ class TestMessage implements Message {
private final String subject;
private final long timestamp;
private final byte[] raw;
private final int bodyStart, bodyLength;
public TestMessage(MessageId id, MessageId parent, GroupId group,
AuthorId author, String subject, long timestamp, byte[] raw) {
this(id, parent, group, author, subject, timestamp, raw, 0, raw.length);
}
public TestMessage(MessageId id, MessageId parent, GroupId group,
AuthorId author, String subject, long timestamp, byte[] raw,
int bodyStart, int bodyLength) {
this.id = id;
this.parent = parent;
this.group = group;
@@ -26,6 +33,8 @@ class TestMessage implements Message {
this.subject = subject;
this.timestamp = timestamp;
this.raw = raw;
this.bodyStart = bodyStart;
this.bodyLength = bodyLength;
}
public MessageId getId() {
@@ -56,10 +65,18 @@ class TestMessage implements Message {
return raw.length;
}
public byte[] getSerialisedBytes() {
public byte[] getSerialised() {
return raw;
}
public int getBodyStart() {
return bodyStart;
}
public int getBodyLength() {
return bodyLength;
}
public InputStream getSerialisedStream() {
return new ByteArrayInputStream(raw);
}

View File

@@ -87,7 +87,7 @@ public class ProtocolReadWriteTest extends TestCase {
a.finish();
BatchWriter b = writerFactory.createBatchWriter(out);
b.writeMessage(message.getSerialisedBytes());
b.writeMessage(message.getSerialised());
b.finish();
OfferWriter o = writerFactory.createOfferWriter(out);

View File

@@ -113,7 +113,7 @@ public class ConstantsTest extends TestCase {
ProtocolConstants.MAX_PACKET_LENGTH);
BatchWriter b = new BatchWriterImpl(out, serial, writerFactory,
crypto.getMessageDigest());
assertTrue(b.writeMessage(message.getSerialisedBytes()));
assertTrue(b.writeMessage(message.getSerialised()));
b.finish();
// Check the size of the serialised batch
assertTrue(out.size() > UniqueId.LENGTH + Group.MAX_NAME_LENGTH +