diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/io/BlockSource.java b/bramble-api/src/main/java/org/briarproject/bramble/api/io/BlockSource.java new file mode 100644 index 000000000..83a88a29f --- /dev/null +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/io/BlockSource.java @@ -0,0 +1,11 @@ +package org.briarproject.bramble.api.io; + +import org.briarproject.bramble.api.db.DbException; +import org.briarproject.bramble.api.sync.MessageId; + +public interface BlockSource { + + int getBlockCount(MessageId m) throws DbException; + + byte[] getBlock(MessageId m, int blockNumber) throws DbException; +} diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/io/MessageInputStreamFactory.java b/bramble-api/src/main/java/org/briarproject/bramble/api/io/MessageInputStreamFactory.java new file mode 100644 index 000000000..68e9f85e1 --- /dev/null +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/io/MessageInputStreamFactory.java @@ -0,0 +1,17 @@ +package org.briarproject.bramble.api.io; + +import org.briarproject.bramble.api.sync.MessageId; + +import java.io.IOException; +import java.io.InputStream; + +public interface MessageInputStreamFactory { + + /** + * Returns an {@link InputStream} for reading the given message from the + * database. This method returns immediately. If the message is not in the + * database or cannot be read, reading from the stream will throw an + * {@link IOException}; + */ + InputStream getMessageInputStream(MessageId m); +} diff --git a/bramble-core/src/main/java/org/briarproject/bramble/io/BlockInputStream.java b/bramble-core/src/main/java/org/briarproject/bramble/io/BlockInputStream.java new file mode 100644 index 000000000..e4ae53931 --- /dev/null +++ b/bramble-core/src/main/java/org/briarproject/bramble/io/BlockInputStream.java @@ -0,0 +1,155 @@ +package org.briarproject.bramble.io; + +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import static java.lang.System.arraycopy; +import static java.lang.Thread.currentThread; + +/** + * An {@link InputStream} that asynchronously fetches blocks of data on demand. + */ +@ThreadSafe +@NotNullByDefault +abstract class BlockInputStream extends InputStream { + + private final int minBufferBytes; + private final BlockingQueue queue = new ArrayBlockingQueue<>(1); + private final Object lock = new Object(); + + @GuardedBy("lock") + @Nullable + private Buffer buffer = null; + + @GuardedBy("lock") + private int offset = 0; + + @GuardedBy("lock") + private boolean fetchingBlock = false; + + abstract void fetchBlockAsync(int blockNumber); + + BlockInputStream(int minBufferBytes) { + this.minBufferBytes = minBufferBytes; + } + + @Override + public int read() throws IOException { + synchronized (lock) { + if (!prepareRead()) return -1; + if (buffer == null) throw new AssertionError(); + return buffer.data[offset++] & 0xFF; + } + } + + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (off < 0 || len < 0 || off + len > b.length) + throw new IllegalArgumentException(); + synchronized (lock) { + if (!prepareRead()) return -1; + if (buffer == null) throw new AssertionError(); + len = Math.min(len, buffer.length - offset); + if (len < 0) throw new AssertionError(); + arraycopy(buffer.data, offset, b, off, len); + offset += len; + return len; + } + } + + private boolean prepareRead() throws IOException { + throwExceptionIfNecessary(); + if (isEndOfStream()) return false; + if (shouldFetchBlock()) fetchBlockAsync(); + waitForBlock(); + if (buffer == null) throw new AssertionError(); + return offset < buffer.length; + } + + @GuardedBy("lock") + private void throwExceptionIfNecessary() throws IOException { + if (buffer != null && buffer.exception != null) + throw new IOException(buffer.exception); + } + + @GuardedBy("lock") + private boolean isEndOfStream() { + return buffer != null && offset == buffer.length && !fetchingBlock; + } + + @GuardedBy("lock") + private boolean shouldFetchBlock() { + if (fetchingBlock) return false; + if (buffer == null) return true; + if (buffer.length == 0) return false; + return buffer.length - offset < minBufferBytes; + } + + @GuardedBy("lock") + private void fetchBlockAsync() { + if (buffer == null) fetchBlockAsync(0); + else fetchBlockAsync(buffer.blockNumber + 1); + fetchingBlock = true; + } + + @GuardedBy("lock") + private void waitForBlock() throws IOException { + if (buffer != null && offset < buffer.length) return; + try { + buffer = queue.take(); + } catch (InterruptedException e) { + currentThread().interrupt(); + throw new InterruptedIOException(); + } + fetchingBlock = false; + offset = 0; + throwExceptionIfNecessary(); + } + + void fetchSucceeded(int blockNumber, byte[] data, int length) { + queue.add(new Buffer(blockNumber, data, length)); + } + + void fetchFailed(int blockNumber, Exception exception) { + queue.add(new Buffer(blockNumber, exception)); + } + + private static class Buffer { + + private final int blockNumber; + private final byte[] data; + private final int length; + @Nullable + private final Exception exception; + + private Buffer(int blockNumber, byte[] data, int length) { + if (length < 0 || length > data.length) + throw new IllegalArgumentException(); + this.blockNumber = blockNumber; + this.data = data; + this.length = length; + exception = null; + } + + private Buffer(int blockNumber, Exception exception) { + this.blockNumber = blockNumber; + this.exception = exception; + data = new byte[0]; + length = 0; + } + } +} diff --git a/bramble-core/src/main/java/org/briarproject/bramble/io/BlockSourceInputStream.java b/bramble-core/src/main/java/org/briarproject/bramble/io/BlockSourceInputStream.java new file mode 100644 index 000000000..959a79a47 --- /dev/null +++ b/bramble-core/src/main/java/org/briarproject/bramble/io/BlockSourceInputStream.java @@ -0,0 +1,53 @@ +package org.briarproject.bramble.io; + +import org.briarproject.bramble.api.db.DbException; +import org.briarproject.bramble.api.io.BlockSource; +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; +import org.briarproject.bramble.api.sync.MessageId; + +import java.util.concurrent.Executor; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * A {@link BlockInputStream} that fetches data from a {@link BlockSource}. + */ +@ThreadSafe +@NotNullByDefault +class BlockSourceInputStream extends BlockInputStream { + + private final Executor executor; + private final BlockSource blockSource; + private final MessageId messageId; + + private volatile int blockCount = -1; + + BlockSourceInputStream(int minBufferBytes, Executor executor, + BlockSource blockSource, MessageId messageId) { + super(minBufferBytes); + this.executor = executor; + this.blockSource = blockSource; + this.messageId = messageId; + } + + @Override + void fetchBlockAsync(int blockNumber) { + executor.execute(() -> { + try { + if (blockCount == -1) { + blockCount = blockSource.getBlockCount(messageId); + } + if (blockNumber > blockCount) { + fetchFailed(blockNumber, new IllegalArgumentException()); + } else if (blockNumber == blockCount) { + fetchSucceeded(blockNumber, new byte[0], 0); // EOF + } else { + byte[] block = blockSource.getBlock(messageId, blockNumber); + fetchSucceeded(blockNumber, block, block.length); + } + } catch (DbException e) { + fetchFailed(blockNumber, e); + } + }); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/io/BlockSourceInputStreamTest.java b/bramble-core/src/test/java/org/briarproject/bramble/io/BlockSourceInputStreamTest.java new file mode 100644 index 000000000..5f3c7539a --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/io/BlockSourceInputStreamTest.java @@ -0,0 +1,152 @@ +package org.briarproject.bramble.io; + +import org.briarproject.bramble.api.db.DbException; +import org.briarproject.bramble.api.io.BlockSource; +import org.briarproject.bramble.api.sync.MessageId; +import org.briarproject.bramble.test.BrambleMockTestCase; +import org.jmock.Expectations; +import org.jmock.lib.concurrent.Synchroniser; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Random; +import java.util.concurrent.Executor; + +import static java.util.concurrent.Executors.newSingleThreadExecutor; +import static org.briarproject.bramble.test.TestUtils.getRandomId; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.spongycastle.util.Arrays.copyOfRange; + +public class BlockSourceInputStreamTest extends BrambleMockTestCase { + + private static final int MAX_DATA_BYTES = 1_000_000; + private static final int READ_BUFFER_BYTES = 4 * 1024; + private static final int BLOCK_BYTES = 32 * 1024; + private static final int MIN_BUFFER_BYTES = 32 * 1024; + + private final BlockSource blockSource; + + private final Random random = new Random(); + private final Executor executor = newSingleThreadExecutor(); + private final MessageId messageId = new MessageId(getRandomId()); + + public BlockSourceInputStreamTest() { + context.setThreadingPolicy(new Synchroniser()); + blockSource = context.mock(BlockSource.class); + } + + @Test + public void testReadSingleBytes() throws IOException { + byte[] data = createRandomData(); + BlockSource source = new ByteArrayBlockSource(data, BLOCK_BYTES); + InputStream in = new BlockSourceInputStream(MIN_BUFFER_BYTES, executor, + source, messageId); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + //noinspection ForLoopReplaceableByForEach + for (int i = 0; i < data.length; i++) { + int read = in.read(); + assertNotEquals(-1, read); + out.write(read); + } + assertEquals(-1, in.read()); + in.close(); + out.flush(); + out.close(); + assertArrayEquals(data, out.toByteArray()); + } + + @Test + public void testReadByteArrays() throws IOException { + byte[] data = createRandomData(); + BlockSource source = new ByteArrayBlockSource(data, BLOCK_BYTES); + InputStream in = new BlockSourceInputStream(MIN_BUFFER_BYTES, executor, + source, messageId); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + byte[] buf = new byte[READ_BUFFER_BYTES]; + int dataOffset = 0; + while (dataOffset < data.length) { + int length = Math.min(random.nextInt(buf.length) + 1, + data.length - dataOffset); + int bufOffset = 0; + if (length < buf.length) + bufOffset = random.nextInt(buf.length - length); + int read = in.read(buf, bufOffset, length); + assertNotEquals(-1, read); + out.write(buf, bufOffset, read); + dataOffset += read; + } + assertEquals(-1, in.read(buf, 0, 0)); + in.close(); + out.flush(); + out.close(); + assertArrayEquals(data, out.toByteArray()); + } + + @Test(expected = IOException.class) + public void testDbExceptionFromGetBlockCountIsRethrown() throws Exception { + context.checking(new Expectations() {{ + oneOf(blockSource).getBlockCount(messageId); + will(throwException(new DbException())); + }}); + + InputStream in = new BlockSourceInputStream(MIN_BUFFER_BYTES, executor, + blockSource, messageId); + //noinspection ResultOfMethodCallIgnored + in.read(); + } + + @Test(expected = IOException.class) + public void testDbExceptionFromGetBlockIsRethrown() throws Exception { + context.checking(new Expectations() {{ + oneOf(blockSource).getBlockCount(messageId); + will(returnValue(1)); + oneOf(blockSource).getBlock(messageId, 0); + will(throwException(new DbException())); + }}); + + InputStream in = new BlockSourceInputStream(MIN_BUFFER_BYTES, executor, + blockSource, messageId); + //noinspection ResultOfMethodCallIgnored + in.read(); + } + + @Test + public void testReadFullBlockAtEndOfMessage() throws Exception { + testReadBlockAtEndOfMessage(BLOCK_BYTES); + } + + @Test + public void testReadPartialBlockAtEndOfMessage() throws Exception { + testReadBlockAtEndOfMessage(BLOCK_BYTES - 1); + } + + private void testReadBlockAtEndOfMessage(int blockLength) throws Exception { + byte[] block = new byte[blockLength]; + random.nextBytes(block); + + context.checking(new Expectations() {{ + oneOf(blockSource).getBlockCount(messageId); + will(returnValue(1)); + oneOf(blockSource).getBlock(messageId, 0); + will(returnValue(block)); + }}); + + InputStream in = new BlockSourceInputStream(MIN_BUFFER_BYTES, executor, + blockSource, messageId); + byte[] buf = new byte[BLOCK_BYTES * 2]; + assertEquals(block.length, in.read(buf, 0, buf.length)); + assertArrayEquals(block, copyOfRange(buf, 0, block.length)); + assertEquals(-1, in.read(buf, 0, buf.length)); + } + + private byte[] createRandomData() { + int length = random.nextInt(MAX_DATA_BYTES) + 1; + byte[] data = new byte[length]; + random.nextBytes(data); + return data; + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/io/ByteArrayBlockSource.java b/bramble-core/src/test/java/org/briarproject/bramble/io/ByteArrayBlockSource.java new file mode 100644 index 000000000..e96e1baaa --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/io/ByteArrayBlockSource.java @@ -0,0 +1,32 @@ +package org.briarproject.bramble.io; + +import org.briarproject.bramble.api.io.BlockSource; +import org.briarproject.bramble.api.sync.MessageId; + +import static java.lang.System.arraycopy; + +class ByteArrayBlockSource implements BlockSource { + + private final byte[] data; + private final int blockBytes; + + ByteArrayBlockSource(byte[] data, int blockBytes) { + this.data = data; + this.blockBytes = blockBytes; + } + + @Override + public int getBlockCount(MessageId m) { + return (data.length + blockBytes - 1) / blockBytes; + } + + @Override + public byte[] getBlock(MessageId m, int blockNumber) { + int offset = blockNumber * blockBytes; + if (offset >= data.length) throw new IllegalArgumentException(); + int length = Math.min(blockBytes, data.length - offset); + byte[] block = new byte[length]; + arraycopy(data, offset, block, 0, length); + return block; + } +}