Added a method for retrieving the message body from the DB.

This commit is contained in:
akwizgran
2011-10-21 20:42:43 +01:00
parent 1173e0a563
commit 9ec8feec78
13 changed files with 217 additions and 93 deletions

View File

@@ -239,12 +239,19 @@ interface Database<T> {
Collection<BatchId> getLostBatches(T txn, ContactId c) throws DbException;
/**
* Returns the message identified by the given ID, in raw format.
* Returns the message identified by the given ID, in serialised form.
* <p>
* Locking: messages read.
*/
byte[] getMessage(T txn, MessageId m) throws DbException;
/**
* Returns the body of the message identified by the given ID.
* <p>
* Locking: messages read.
*/
byte[] getMessageBody(T txn, MessageId m) throws DbException;
/**
* Returns the message identified by the given ID, in raw format, or null
* if the message is not present in the database or is not sendable to the

View File

@@ -432,8 +432,7 @@ DatabaseCleaner.Callback {
int capacity = b.getCapacity();
ids = db.getSendableMessages(txn, c, capacity);
for(MessageId m : ids) {
byte[] raw = db.getMessage(txn, m);
messages.add(new Bytes(raw));
messages.add(new Bytes(db.getMessage(txn, m)));
}
db.commitTransaction(txn);
} catch(DbException e) {

View File

@@ -1,7 +1,9 @@
package net.sf.briar.db;
import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
@@ -61,8 +63,11 @@ abstract class JdbcDatabase implements Database<Connection> {
+ " parentId HASH,"
+ " groupId HASH,"
+ " authorId HASH,"
+ " subject VARCHAR NOT NULL,"
+ " timestamp BIGINT NOT NULL,"
+ " size INT NOT NULL,"
+ " length INT NOT NULL,"
+ " bodyStart INT NOT NULL,"
+ " bodyLength INT NOT NULL,"
+ " raw BLOB NOT NULL,"
+ " sendability INT,"
+ " contactId INT,"
@@ -536,10 +541,10 @@ abstract class JdbcDatabase implements Database<Connection> {
if(containsMessage(txn, m.getId())) return false;
PreparedStatement ps = null;
try {
String sql = "INSERT INTO messages"
+ " (messageId, parentId, groupId, authorId, timestamp, size,"
+ " raw, sendability)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ZERO())";
String sql = "INSERT INTO messages (messageId, parentId, groupId,"
+ " authorId, subject, timestamp, length, bodyStart,"
+ " bodyLength, raw, sendability)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ZERO())";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getId().getBytes());
if(m.getParent() == null) ps.setNull(2, Types.BINARY);
@@ -547,10 +552,12 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setBytes(3, m.getGroup().getBytes());
if(m.getAuthor() == null) ps.setNull(4, Types.BINARY);
else ps.setBytes(4, m.getAuthor().getBytes());
ps.setLong(5, m.getTimestamp());
int length = m.getLength();
ps.setInt(6, length);
ps.setBinaryStream(7, m.getSerialisedStream(), length);
ps.setString(5, m.getSubject());
ps.setLong(6, m.getTimestamp());
ps.setInt(7, m.getLength());
ps.setInt(8, m.getBodyStart());
ps.setInt(9, m.getBodyLength());
ps.setBytes(10, m.getSerialised());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
@@ -626,18 +633,20 @@ abstract class JdbcDatabase implements Database<Connection> {
if(containsMessage(txn, m.getId())) return false;
PreparedStatement ps = null;
try {
String sql = "INSERT INTO messages"
+ " (messageId, parentId, timestamp, size, raw, contactId)"
+ " VALUES (?, ?, ?, ?, ?, ?)";
String sql = "INSERT INTO messages (messageId, parentId, subject,"
+ " timestamp, length, bodyStart, bodyLength, raw, contactId)"
+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getId().getBytes());
if(m.getParent() == null) ps.setNull(2, Types.BINARY);
else ps.setBytes(2, m.getParent().getBytes());
ps.setLong(3, m.getTimestamp());
int length = m.getLength();
ps.setInt(4, length);
ps.setBinaryStream(5, m.getSerialisedStream(), length);
ps.setInt(6, c.getInt());
ps.setString(3, m.getSubject());
ps.setLong(4, m.getTimestamp());
ps.setInt(5, m.getLength());
ps.setInt(6, m.getBodyStart());
ps.setInt(7, m.getBodyLength());
ps.setBytes(8, m.getSerialised());
ps.setInt(9, c.getInt());
int affected = ps.executeUpdate();
if(affected != 1) throw new DbStateException();
ps.close();
@@ -1042,14 +1051,14 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT size, raw FROM messages WHERE messageId = ?";
String sql = "SELECT length, raw FROM messages WHERE messageId = ?";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getBytes());
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
int size = rs.getInt(1);
byte[] raw = rs.getBlob(2).getBytes(1, size);
if(raw.length != size) throw new DbStateException();
int length = rs.getInt(1);
byte[] raw = rs.getBlob(2).getBytes(1, length);
if(raw.length != length) throw new DbStateException();
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
@@ -1061,13 +1070,60 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public byte[] getMessageBody(Connection txn, MessageId m)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT bodyStart, bodyLength, raw FROM messages"
+ " WHERE messageId = ?";
ps = txn.prepareStatement(sql);
ps.setBytes(1, m.getBytes());
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
int bodyStart = rs.getInt(1);
int bodyLength = rs.getInt(2);
InputStream in = rs.getBlob(3).getBinaryStream();
// FIXME: We have to read and discard the header because
// InputStream.skip() is broken for blobs - find out why
byte[] head = new byte[bodyStart];
byte[] body = new byte[bodyLength];
try {
int offset = 0;
while(offset < head.length) {
int read = in.read(head, offset, head.length - offset);
if(read == -1) throw new SQLException(new EOFException());
offset += read;
}
offset = 0;
while(offset < body.length) {
int read = in.read(body, offset, body.length - offset);
if(read == -1) throw new SQLException(new EOFException());
offset += read;
}
in.close();
} catch(IOException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage());
throw new SQLException(e);
}
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
return body;
} catch(SQLException e) {
tryToClose(rs);
tryToClose(ps);
throw new DbException(e);
}
}
public byte[] getMessageIfSendable(Connection txn, ContactId c, MessageId m)
throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
// Do we have a sendable private message with the given ID?
String sql = "SELECT size, raw FROM messages"
String sql = "SELECT length, raw FROM messages"
+ " JOIN statuses ON messages.messageId = statuses.messageId"
+ " WHERE messages.messageId = ? AND messages.contactId = ?"
+ " AND status = ?";
@@ -1078,16 +1134,16 @@ abstract class JdbcDatabase implements Database<Connection> {
rs = ps.executeQuery();
byte[] raw = null;
if(rs.next()) {
int size = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, size);
if(raw.length != size) throw new DbStateException();
int length = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, length);
if(raw.length != length) throw new DbStateException();
}
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
if(raw != null) return raw;
// Do we have a sendable group message with the given ID?
sql = "SELECT size, raw FROM messages"
sql = "SELECT length, raw FROM messages"
+ " JOIN contactSubscriptions"
+ " ON messages.groupId = contactSubscriptions.groupId"
+ " JOIN visibilities"
@@ -1107,9 +1163,9 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setShort(3, (short) Status.NEW.ordinal());
rs = ps.executeQuery();
if(rs.next()) {
int size = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, size);
if(raw.length != size) throw new DbStateException();
int length = rs.getInt(1);
raw = rs.getBlob(2).getBytes(1, length);
if(raw.length != length) throw new DbStateException();
}
if(rs.next()) throw new DbStateException();
rs.close();
@@ -1203,17 +1259,17 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT size, messageId FROM messages"
String sql = "SELECT length, messageId FROM messages"
+ " ORDER BY timestamp";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
Collection<MessageId> ids = new ArrayList<MessageId>();
int total = 0;
while(rs.next()) {
int size = rs.getInt(1);
if(total + size > capacity) break;
int length = rs.getInt(1);
if(total + length > capacity) break;
ids.add(new MessageId(rs.getBytes(2)));
total += size;
total += length;
}
rs.close();
ps.close();
@@ -1361,7 +1417,7 @@ abstract class JdbcDatabase implements Database<Connection> {
ResultSet rs = null;
try {
// Do we have any sendable private messages?
String sql = "SELECT size, messages.messageId FROM messages"
String sql = "SELECT length, messages.messageId FROM messages"
+ " JOIN statuses ON messages.messageId = statuses.messageId"
+ " WHERE messages.contactId = ? AND status = ?"
+ " ORDER BY timestamp";
@@ -1372,10 +1428,10 @@ abstract class JdbcDatabase implements Database<Connection> {
Collection<MessageId> ids = new ArrayList<MessageId>();
int total = 0;
while(rs.next()) {
int size = rs.getInt(1);
if(total + size > capacity) break;
int length = rs.getInt(1);
if(total + length > capacity) break;
ids.add(new MessageId(rs.getBytes(2)));
total += size;
total += length;
}
rs.close();
ps.close();
@@ -1384,7 +1440,7 @@ abstract class JdbcDatabase implements Database<Connection> {
total + "/" + capacity + " bytes");
if(total == capacity) return ids;
// Do we have any sendable group messages?
sql = "SELECT size, messages.messageId FROM messages"
sql = "SELECT length, messages.messageId FROM messages"
+ " JOIN contactSubscriptions"
+ " ON messages.groupId = contactSubscriptions.groupId"
+ " JOIN visibilities"
@@ -1403,10 +1459,10 @@ abstract class JdbcDatabase implements Database<Connection> {
ps.setShort(2, (short) Status.NEW.ordinal());
rs = ps.executeQuery();
while(rs.next()) {
int size = rs.getInt(1);
if(total + size > capacity) break;
int length = rs.getInt(1);
if(total + length > capacity) break;
ids.add(new MessageId(rs.getBytes(2)));
total += size;
total += length;
}
rs.close();
ps.close();

View File

@@ -16,6 +16,7 @@ import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.MessageEncoder;
import net.sf.briar.api.protocol.MessageId;
import net.sf.briar.api.protocol.ProtocolConstants;
import net.sf.briar.api.protocol.Types;
import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.Writer;
@@ -81,6 +82,9 @@ class MessageEncoderImpl implements MessageEncoder {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
// Initialise the consumers
CountingConsumer counting = new CountingConsumer(
ProtocolConstants.MAX_PACKET_LENGTH);
w.addConsumer(counting);
Consumer digestingConsumer = new DigestingConsumer(messageDigest);
w.addConsumer(digestingConsumer);
Consumer authorConsumer = null;
@@ -110,6 +114,7 @@ class MessageEncoderImpl implements MessageEncoder {
random.nextBytes(salt);
w.writeBytes(salt);
w.writeBytes(body);
int bodyStart = (int) counting.getCount() - body.length;
// Sign the message with the author's private key, if there is one
if(authorKey == null) {
w.writeNull();
@@ -137,6 +142,6 @@ class MessageEncoderImpl implements MessageEncoder {
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, parent, groupId, authorId, subject,
timestamp, raw);
timestamp, raw, bodyStart, body.length);
}
}

View File

@@ -1,8 +1,5 @@
package net.sf.briar.protocol;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import net.sf.briar.api.protocol.AuthorId;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Message;
@@ -17,9 +14,15 @@ class MessageImpl implements Message {
private final String subject;
private final long timestamp;
private final byte[] raw;
private final int bodyStart, bodyLength;
public MessageImpl(MessageId id, MessageId parent, GroupId group,
AuthorId author, String subject, long timestamp, byte[] raw) {
AuthorId author, String subject, long timestamp, byte[] raw,
int bodyStart, int bodyLength) {
if(bodyStart + bodyLength > raw.length)
throw new IllegalArgumentException();
if(bodyLength > Message.MAX_BODY_LENGTH)
throw new IllegalArgumentException();
this.id = id;
this.parent = parent;
this.group = group;
@@ -27,6 +30,8 @@ class MessageImpl implements Message {
this.subject = subject;
this.timestamp = timestamp;
this.raw = raw;
this.bodyStart = bodyStart;
this.bodyLength = bodyLength;
}
public MessageId getId() {
@@ -57,12 +62,16 @@ class MessageImpl implements Message {
return raw.length;
}
public byte[] getSerialisedBytes() {
public byte[] getSerialised() {
return raw;
}
public InputStream getSerialisedStream() {
return new ByteArrayInputStream(raw);
public int getBodyStart() {
return bodyStart;
}
public int getBodyLength() {
return bodyLength;
}
@Override

View File

@@ -84,8 +84,10 @@ class MessageReader implements ObjectReader<Message> {
// Read the salt
byte[] salt = r.readBytes(Message.SALT_LENGTH);
if(salt.length != Message.SALT_LENGTH) throw new FormatException();
// Skip the message body
r.readBytes(Message.MAX_BODY_LENGTH);
// Read the message body
byte[] body = r.readBytes(Message.MAX_BODY_LENGTH);
// Record the offset of the body within the message
int bodyStart = (int) counting.getCount() - body.length;
// Record the length of the data covered by the author's signature
int signedByAuthor = (int) counting.getCount();
// Read the author's signature, if there is one
@@ -131,6 +133,6 @@ class MessageReader implements ObjectReader<Message> {
GroupId groupId = group == null ? null : group.getId();
AuthorId authorId = author == null ? null : author.getId();
return new MessageImpl(id, parent, groupId, authorId, subject,
timestamp, raw);
timestamp, raw, bodyStart, body.length);
}
}

View File

@@ -272,26 +272,21 @@ class ReaderImpl implements Reader {
}
public String readString() throws IOException {
return readString(maxStringLength);
}
public String readString(int maxLength) throws IOException {
if(!hasString()) throw new FormatException();
consumeLookahead();
int length;
if(next == Tag.STRING) length = readLength();
else length = 0xFF & next ^ Tag.SHORT_STRING;
if(length > maxStringLength) throw new FormatException();
if(length > maxLength) throw new FormatException();
if(length == 0) return "";
readIntoBuffer(length);
return new String(buf, 0, length, "UTF-8");
}
public String readString(int maxLength) throws IOException {
setMaxStringLength(maxLength);
try {
return readString();
} finally {
resetMaxStringLength();
}
}
private int readLength() throws IOException {
if(!hasLength()) throw new FormatException();
if(next >= 0) return readUint7();
@@ -315,27 +310,22 @@ class ReaderImpl implements Reader {
}
public byte[] readBytes() throws IOException {
return readBytes(maxBytesLength);
}
public byte[] readBytes(int maxLength) throws IOException {
if(!hasBytes()) throw new FormatException();
consumeLookahead();
int length;
if(next == Tag.BYTES) length = readLength();
else length = 0xFF & next ^ Tag.SHORT_BYTES;
if(length > maxBytesLength) throw new FormatException();
if(length > maxLength) throw new FormatException();
if(length == 0) return EMPTY_BUFFER;
byte[] b = new byte[length];
readIntoBuffer(b, length);
return b;
}
public byte[] readBytes(int maxLength) throws IOException {
setMaxBytesLength(maxLength);
try {
return readBytes();
} finally {
resetMaxBytesLength();
}
}
public boolean hasList() throws IOException {
if(!hasLookahead) readLookahead(true);
if(eof) return false;