Removed signatures from headers and bundles, since the transport's

authentication will make them redundant.
This commit is contained in:
akwizgran
2011-07-20 18:33:06 +01:00
parent 45b4bef348
commit f727420838
7 changed files with 319 additions and 140 deletions

View File

@@ -3,9 +3,6 @@ package net.sf.briar.protocol;
import java.io.IOException; import java.io.IOException;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.List; import java.util.List;
import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.Batch;
@@ -17,17 +14,12 @@ import net.sf.briar.api.serial.Reader;
public class BatchReader implements ObjectReader<Batch> { public class BatchReader implements ObjectReader<Batch> {
private final PublicKey publicKey;
private final Signature signature;
private final MessageDigest messageDigest; private final MessageDigest messageDigest;
private final ObjectReader<Message> messageReader; private final ObjectReader<Message> messageReader;
private final BatchFactory batchFactory; private final BatchFactory batchFactory;
BatchReader(PublicKey publicKey, Signature signature, BatchReader(MessageDigest messageDigest,
MessageDigest messageDigest, ObjectReader<Message> messageReader, ObjectReader<Message> messageReader, BatchFactory batchFactory) {
BatchFactory batchFactory) {
this.publicKey = publicKey;
this.signature = signature;
this.messageDigest = messageDigest; this.messageDigest = messageDigest;
this.messageReader = messageReader; this.messageReader = messageReader;
this.batchFactory = batchFactory; this.batchFactory = batchFactory;
@@ -35,25 +27,18 @@ public class BatchReader implements ObjectReader<Batch> {
public Batch readObject(Reader reader) throws IOException, public Batch readObject(Reader reader) throws IOException,
GeneralSecurityException { GeneralSecurityException {
// Initialise the input stream // Initialise the consumers
CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE); CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE);
DigestingConsumer digesting = new DigestingConsumer(messageDigest); DigestingConsumer digesting = new DigestingConsumer(messageDigest);
messageDigest.reset(); messageDigest.reset();
SigningConsumer signing = new SigningConsumer(signature); // Read and digest the data
signature.initVerify(publicKey);
// Read the signed data
reader.addConsumer(counting); reader.addConsumer(counting);
reader.addConsumer(digesting); reader.addConsumer(digesting);
reader.addConsumer(signing);
reader.addObjectReader(Tags.MESSAGE, messageReader); reader.addObjectReader(Tags.MESSAGE, messageReader);
List<Message> messages = reader.readList(Message.class); List<Message> messages = reader.readList(Message.class);
reader.removeObjectReader(Tags.MESSAGE); reader.removeObjectReader(Tags.MESSAGE);
reader.removeConsumer(signing);
// Read and verify the signature
byte[] sig = reader.readRaw();
reader.removeConsumer(digesting); reader.removeConsumer(digesting);
reader.removeConsumer(counting); reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Build and return the batch // Build and return the batch
BatchId id = new BatchId(messageDigest.digest()); BatchId id = new BatchId(messageDigest.digest());
return batchFactory.createBatch(id, messages); return batchFactory.createBatch(id, messages);

View File

@@ -13,7 +13,7 @@ import net.sf.briar.api.serial.Reader;
class BundleReaderImpl implements BundleReader { class BundleReaderImpl implements BundleReader {
private static enum State { START, FIRST_BATCH, MORE_BATCHES, END }; private static enum State { START, BATCHES, END };
private final Reader reader; private final Reader reader;
private final ObjectReader<Header> headerReader; private final ObjectReader<Header> headerReader;
@@ -29,21 +29,19 @@ class BundleReaderImpl implements BundleReader {
public Header getHeader() throws IOException, GeneralSecurityException { public Header getHeader() throws IOException, GeneralSecurityException {
if(state != State.START) throw new IllegalStateException(); if(state != State.START) throw new IllegalStateException();
reader.addObjectReader(Tags.HEADER, headerReader);
reader.readUserDefinedTag(Tags.HEADER); reader.readUserDefinedTag(Tags.HEADER);
reader.addObjectReader(Tags.HEADER, headerReader);
Header h = reader.readUserDefinedObject(Tags.HEADER, Header.class); Header h = reader.readUserDefinedObject(Tags.HEADER, Header.class);
reader.removeObjectReader(Tags.HEADER); reader.removeObjectReader(Tags.HEADER);
state = State.FIRST_BATCH; // Expect a list of batches
reader.readListStart();
reader.addObjectReader(Tags.BATCH, batchReader);
state = State.BATCHES;
return h; return h;
} }
public Batch getNextBatch() throws IOException, GeneralSecurityException { public Batch getNextBatch() throws IOException, GeneralSecurityException {
if(state == State.FIRST_BATCH) { if(state != State.BATCHES) throw new IllegalStateException();
reader.readListStart();
reader.addObjectReader(Tags.BATCH, batchReader);
state = State.MORE_BATCHES;
}
if(state != State.MORE_BATCHES) throw new IllegalStateException();
if(reader.hasListEnd()) { if(reader.hasListEnd()) {
reader.removeObjectReader(Tags.BATCH); reader.removeObjectReader(Tags.BATCH);
reader.readListEnd(); reader.readListEnd();

View File

@@ -2,10 +2,9 @@ package net.sf.briar.protocol;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.security.DigestOutputStream;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.PrivateKey;
import java.security.Signature;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
@@ -21,22 +20,17 @@ class BundleWriterImpl implements BundleWriter {
private static enum State { START, FIRST_BATCH, MORE_BATCHES, END }; private static enum State { START, FIRST_BATCH, MORE_BATCHES, END };
private final SigningDigestingOutputStream out; private final DigestOutputStream out;
private final Writer writer; private final Writer writer;
private final PrivateKey privateKey;
private final Signature signature;
private final MessageDigest messageDigest; private final MessageDigest messageDigest;
private final long capacity; private final long capacity;
private State state = State.START; private State state = State.START;
BundleWriterImpl(OutputStream out, WriterFactory writerFactory, BundleWriterImpl(OutputStream out, WriterFactory writerFactory,
PrivateKey privateKey, Signature signature,
MessageDigest messageDigest, long capacity) { MessageDigest messageDigest, long capacity) {
this.out = this.out = new DigestOutputStream(out, messageDigest);
new SigningDigestingOutputStream(out, signature, messageDigest); this.out.on(false); // Turn off the digest until we need it
writer = writerFactory.createWriter(this.out); writer = writerFactory.createWriter(this.out);
this.privateKey = privateKey;
this.signature = signature;
this.messageDigest = messageDigest; this.messageDigest = messageDigest;
this.capacity = capacity; this.capacity = capacity;
} }
@@ -49,24 +43,13 @@ class BundleWriterImpl implements BundleWriter {
Map<String, String> transports) throws IOException, Map<String, String> transports) throws IOException,
GeneralSecurityException { GeneralSecurityException {
if(state != State.START) throw new IllegalStateException(); if(state != State.START) throw new IllegalStateException();
// Initialise the output stream
signature.initSign(privateKey);
// Write the initial tag // Write the initial tag
writer.writeUserDefinedTag(Tags.HEADER); writer.writeUserDefinedTag(Tags.HEADER);
// Write the data to be signed // Write the data
out.setSigning(true);
// Acks
writer.writeList(acks); writer.writeList(acks);
// Subs
writer.writeList(subs); writer.writeList(subs);
// Transports
writer.writeMap(transports); writer.writeMap(transports);
// Timestamp
writer.writeInt64(System.currentTimeMillis()); writer.writeInt64(System.currentTimeMillis());
out.setSigning(false);
// Create and write the signature
byte[] sig = signature.sign();
writer.writeRaw(sig);
// Expect a (possibly empty) list of batches // Expect a (possibly empty) list of batches
state = State.FIRST_BATCH; state = State.FIRST_BATCH;
} }
@@ -78,26 +61,21 @@ class BundleWriterImpl implements BundleWriter {
state = State.MORE_BATCHES; state = State.MORE_BATCHES;
} }
if(state != State.MORE_BATCHES) throw new IllegalStateException(); if(state != State.MORE_BATCHES) throw new IllegalStateException();
// Initialise the output stream
signature.initSign(privateKey);
messageDigest.reset();
// Write the initial tag // Write the initial tag
writer.writeUserDefinedTag(Tags.BATCH); writer.writeUserDefinedTag(Tags.BATCH);
// Write the data to be signed // Start digesting
out.setDigesting(true); messageDigest.reset();
out.setSigning(true); out.on(true);
// Write the data
writer.writeListStart(); writer.writeListStart();
// Bypass the writer and write the raw messages directly // Bypass the writer and write each raw message directly
for(Raw message : messages) { for(Raw message : messages) {
writer.writeUserDefinedTag(Tags.MESSAGE); writer.writeUserDefinedTag(Tags.MESSAGE);
out.write(message.getBytes()); out.write(message.getBytes());
} }
writer.writeListEnd(); writer.writeListEnd();
out.setSigning(false); // Stop digesting
// Create and write the signature out.on(false);
byte[] sig = signature.sign();
writer.writeRaw(sig);
out.setDigesting(false);
// Calculate and return the ID // Calculate and return the ID
return new BatchId(messageDigest.digest()); return new BatchId(messageDigest.digest());
} }

View File

@@ -2,9 +2,6 @@ package net.sf.briar.protocol;
import java.io.IOException; import java.io.IOException;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
@@ -18,26 +15,17 @@ import net.sf.briar.api.serial.Reader;
class HeaderReader implements ObjectReader<Header> { class HeaderReader implements ObjectReader<Header> {
private final PublicKey publicKey;
private final Signature signature;
private final HeaderFactory headerFactory; private final HeaderFactory headerFactory;
HeaderReader(PublicKey publicKey, Signature signature, HeaderReader(HeaderFactory headerFactory) {
HeaderFactory headerFactory) {
this.publicKey = publicKey;
this.signature = signature;
this.headerFactory = headerFactory; this.headerFactory = headerFactory;
} }
public Header readObject(Reader reader) throws IOException, public Header readObject(Reader reader) throws IOException,
GeneralSecurityException { GeneralSecurityException {
// Initialise the input stream // Initialise and add the consumer
CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE); CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE);
SigningConsumer signing = new SigningConsumer(signature);
signature.initVerify(publicKey);
// Read the signed data
reader.addConsumer(counting); reader.addConsumer(counting);
reader.addConsumer(signing);
// Acks // Acks
reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader()); reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader());
Collection<BatchId> acks = reader.readList(BatchId.class); Collection<BatchId> acks = reader.readList(BatchId.class);
@@ -52,11 +40,8 @@ class HeaderReader implements ObjectReader<Header> {
// Timestamp // Timestamp
long timestamp = reader.readInt64(); long timestamp = reader.readInt64();
if(timestamp < 0L) throw new FormatException(); if(timestamp < 0L) throw new FormatException();
reader.removeConsumer(signing); // Remove the consumer
// Read and verify the signature
byte[] sig = reader.readRaw();
reader.removeConsumer(counting); reader.removeConsumer(counting);
if(!signature.verify(sig)) throw new SignatureException();
// Build and return the header // Build and return the header
return headerFactory.createHeader(acks, subs, transports, timestamp); return headerFactory.createHeader(acks, subs, transports, timestamp);
} }

View File

@@ -20,6 +20,7 @@
<test name='net.sf.briar.i18n.FontManagerTest'/> <test name='net.sf.briar.i18n.FontManagerTest'/>
<test name='net.sf.briar.i18n.I18nTest'/> <test name='net.sf.briar.i18n.I18nTest'/>
<test name='net.sf.briar.invitation.InvitationWorkerTest'/> <test name='net.sf.briar.invitation.InvitationWorkerTest'/>
<test name='net.sf.briar.protocol.BundleReaderImplTest'/>
<test name='net.sf.briar.protocol.BundleReadWriteTest'/> <test name='net.sf.briar.protocol.BundleReadWriteTest'/>
<test name='net.sf.briar.protocol.ConsumersTest'/> <test name='net.sf.briar.protocol.ConsumersTest'/>
<test name='net.sf.briar.protocol.SigningDigestingOutputStreamTest'/> <test name='net.sf.briar.protocol.SigningDigestingOutputStreamTest'/>

View File

@@ -3,8 +3,6 @@ package net.sf.briar.protocol;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.RandomAccessFile;
import java.security.GeneralSecurityException;
import java.security.KeyFactory; import java.security.KeyFactory;
import java.security.KeyPair; import java.security.KeyPair;
import java.security.KeyPairGenerator; import java.security.KeyPairGenerator;
@@ -66,25 +64,23 @@ public class BundleReadWriteTest extends TestCase {
private final String nick = "Foo Bar"; private final String nick = "Foo Bar";
private final String messageBody = "This is the message body! Wooooooo!"; private final String messageBody = "This is the message body! Wooooooo!";
private final ReaderFactory rf; private final ReaderFactory readerFactory;
private final WriterFactory wf; private final WriterFactory writerFactory;
private final Signature signature;
private final KeyPair keyPair; private final MessageDigest messageDigest, batchDigest;
private final Signature sig, sig1;
private final MessageDigest dig, dig1;
private final KeyParser keyParser; private final KeyParser keyParser;
private final Message message; private final Message message;
public BundleReadWriteTest() throws Exception { public BundleReadWriteTest() throws Exception {
super(); super();
// Inject the reader and writer factories, since they belong to
// a different component
Injector i = Guice.createInjector(new SerialModule()); Injector i = Guice.createInjector(new SerialModule());
rf = i.getInstance(ReaderFactory.class); readerFactory = i.getInstance(ReaderFactory.class);
wf = i.getInstance(WriterFactory.class); writerFactory = i.getInstance(WriterFactory.class);
keyPair = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair(); signature = Signature.getInstance(SIGNATURE_ALGO);
sig = Signature.getInstance(SIGNATURE_ALGO); messageDigest = MessageDigest.getInstance(DIGEST_ALGO);
sig1 = Signature.getInstance(SIGNATURE_ALGO); batchDigest = MessageDigest.getInstance(DIGEST_ALGO);
dig = MessageDigest.getInstance(DIGEST_ALGO);
dig1 = MessageDigest.getInstance(DIGEST_ALGO);
final KeyFactory keyFactory = KeyFactory.getInstance(KEY_PAIR_ALGO); final KeyFactory keyFactory = KeyFactory.getInstance(KEY_PAIR_ALGO);
keyParser = new KeyParser() { keyParser = new KeyParser() {
public PublicKey parsePublicKey(byte[] encodedKey) public PublicKey parsePublicKey(byte[] encodedKey)
@@ -93,8 +89,13 @@ public class BundleReadWriteTest extends TestCase {
return keyFactory.generatePublic(e); return keyFactory.generatePublic(e);
} }
}; };
assertEquals(dig.getDigestLength(), UniqueId.LENGTH); assertEquals(messageDigest.getDigestLength(), UniqueId.LENGTH);
MessageEncoder messageEncoder = new MessageEncoderImpl(sig, dig, wf); assertEquals(batchDigest.getDigestLength(), UniqueId.LENGTH);
// Create and encode a test message
MessageEncoder messageEncoder = new MessageEncoderImpl(signature,
messageDigest, writerFactory);
KeyPair keyPair =
KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair();
message = messageEncoder.encodeMessage(MessageId.NONE, sub, nick, message = messageEncoder.encodeMessage(MessageId.NONE, sub, nick,
keyPair, messageBody.getBytes("UTF-8")); keyPair, messageBody.getBytes("UTF-8"));
} }
@@ -107,8 +108,8 @@ public class BundleReadWriteTest extends TestCase {
@Test @Test
public void testWriteBundle() throws Exception { public void testWriteBundle() throws Exception {
FileOutputStream out = new FileOutputStream(bundle); FileOutputStream out = new FileOutputStream(bundle);
BundleWriter w = new BundleWriterImpl(out, wf, keyPair.getPrivate(), BundleWriter w = new BundleWriterImpl(out, writerFactory, batchDigest,
sig, dig, capacity); capacity);
Raw messageRaw = new RawByteArray(message.getBytes()); Raw messageRaw = new RawByteArray(message.getBytes());
w.addHeader(acks, subs, transports); w.addHeader(acks, subs, transports);
@@ -125,12 +126,12 @@ public class BundleReadWriteTest extends TestCase {
testWriteBundle(); testWriteBundle();
FileInputStream in = new FileInputStream(bundle); FileInputStream in = new FileInputStream(bundle);
Reader reader = rf.createReader(in); Reader reader = readerFactory.createReader(in);
MessageReader messageReader = new MessageReader(keyParser, sig1, dig1); MessageReader messageReader =
HeaderReader headerReader = new HeaderReader(keyPair.getPublic(), sig, new MessageReader(keyParser, signature, messageDigest);
new HeaderFactoryImpl()); HeaderReader headerReader = new HeaderReader(new HeaderFactoryImpl());
BatchReader batchReader = new BatchReader(keyPair.getPublic(), sig, dig, BatchReader batchReader = new BatchReader(batchDigest, messageReader,
messageReader, new BatchFactoryImpl()); new BatchFactoryImpl());
BundleReader r = new BundleReaderImpl(reader, headerReader, BundleReader r = new BundleReaderImpl(reader, headerReader,
batchReader); batchReader);
@@ -153,40 +154,6 @@ public class BundleReadWriteTest extends TestCase {
r.finish(); r.finish();
} }
@Test
public void testModifyingBundleBreaksSignature() throws Exception {
testWriteBundle();
RandomAccessFile f = new RandomAccessFile(bundle, "rw");
f.seek(bundle.length() - 100);
byte b = f.readByte();
f.seek(bundle.length() - 100);
f.writeByte(b + 1);
f.close();
FileInputStream in = new FileInputStream(bundle);
Reader reader = rf.createReader(in);
MessageReader messageReader = new MessageReader(keyParser, sig1, dig1);
HeaderReader headerReader = new HeaderReader(keyPair.getPublic(), sig,
new HeaderFactoryImpl());
BatchReader batchReader = new BatchReader(keyPair.getPublic(), sig, dig,
messageReader, new BatchFactoryImpl());
BundleReader r = new BundleReaderImpl(reader, headerReader,
batchReader);
Header h = r.getHeader();
assertEquals(acks, h.getAcks());
assertEquals(subs, h.getSubscriptions());
assertEquals(transports, h.getTransports());
try {
r.getNextBatch();
assertTrue(false);
} catch(GeneralSecurityException expected) {}
r.finish();
}
@After @After
public void tearDown() { public void tearDown() {
TestUtils.deleteTestDirectory(testDir); TestUtils.deleteTestDirectory(testDir);

View File

@@ -1,5 +1,270 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
public class BundleReaderImplTest { import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import junit.framework.TestCase;
import net.sf.briar.api.protocol.Batch;
import net.sf.briar.api.protocol.BatchId;
import net.sf.briar.api.protocol.GroupId;
import net.sf.briar.api.protocol.Header;
import net.sf.briar.api.protocol.Message;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.serial.Writer;
import net.sf.briar.api.serial.WriterFactory;
import net.sf.briar.serial.SerialModule;
import org.junit.Test;
import com.google.inject.Guice;
import com.google.inject.Injector;
public class BundleReaderImplTest extends TestCase {
private final ReaderFactory readerFactory;
private final WriterFactory writerFactory;
public BundleReaderImplTest() {
Injector i = Guice.createInjector(new SerialModule());
readerFactory = i.getInstance(ReaderFactory.class);
writerFactory = i.getInstance(WriterFactory.class);
}
@Test
public void testEmptyBundleThrowsFormatException() throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(new byte[] {});
Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
try {
b.getHeader();
assertTrue(false);
} catch(FormatException expected) {}
}
@Test
public void testReadingBatchBeforeHeaderThrowsIllegalStateException()
throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
try {
b.getNextBatch();
assertTrue(false);
} catch(IllegalStateException expected) {}
}
@Test
public void testMissingHeaderThrowsFormatException() throws Exception {
// Create a headless bundle
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeListStart();
w.writeUserDefinedTag(Tags.BATCH);
w.writeList(Collections.emptyList());
w.writeListEnd();
w.close();
byte[] headless = out.toByteArray();
// Try to read a header from the headless bundle
ByteArrayInputStream in = new ByteArrayInputStream(headless);
Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
try {
b.getHeader();
assertTrue(false);
} catch(FormatException expected) {}
}
@Test
public void testMissingBatchListThrowsFormatException() throws Exception {
// Create a header-only bundle
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.HEADER);
w.writeList(Collections.emptyList()); // Acks
w.writeList(Collections.emptyList()); // Subs
w.writeMap(Collections.emptyMap()); // Transports
w.writeInt64(System.currentTimeMillis()); // Timestamp
w.close();
byte[] headerOnly = out.toByteArray();
// Try to read a header from the header-only bundle
ByteArrayInputStream in = new ByteArrayInputStream(headerOnly);
final Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
try {
b.getHeader();
assertTrue(false);
} catch(FormatException expected) {}
}
@Test
public void testEmptyBatchListIsAcceptable() throws Exception {
// Create a bundle with no batches
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.HEADER);
w.writeList(Collections.emptyList()); // Acks
w.writeList(Collections.emptyList()); // Subs
w.writeMap(Collections.emptyMap()); // Transports
w.writeInt64(System.currentTimeMillis()); // Timestamp
w.writeListStart();
w.writeListEnd();
w.close();
byte[] batchless = out.toByteArray();
// It should be possible to read the header and null
ByteArrayInputStream in = new ByteArrayInputStream(batchless);
final Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
assertNotNull(b.getHeader());
assertNull(b.getNextBatch());
}
@Test
public void testValidBundle() throws Exception {
// It should be possible to read the header, a batch, and null
ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
final Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
assertNotNull(b.getHeader());
assertNotNull(b.getNextBatch());
assertNull(b.getNextBatch());
}
@Test
public void testReadingBatchAfterNullThrowsIllegalStateException()
throws Exception {
// Trying to read another batch after null should not succeed
ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
final Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
assertNotNull(b.getHeader());
assertNotNull(b.getNextBatch());
assertNull(b.getNextBatch());
try {
b.getNextBatch();
assertTrue(false);
} catch(IllegalStateException expected) {}
}
@Test
public void testReadingHeaderTwiceThrowsIllegalStateException()
throws Exception {
// Trying to read the header twice should not succeed
ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
final Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
assertNotNull(b.getHeader());
try {
b.getHeader();
assertTrue(false);
} catch(IllegalStateException expected) {}
}
@Test
public void testReadingHeaderAfterBatchThrowsIllegalStateException()
throws Exception {
// Trying to read the header after a batch should not succeed
ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
final Reader r = readerFactory.createReader(in);
BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
new TestBatchReader());
assertNotNull(b.getHeader());
assertNotNull(b.getNextBatch());
try {
b.getHeader();
assertTrue(false);
} catch(IllegalStateException expected) {}
}
private byte[] createValidBundle() throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
Writer w = writerFactory.createWriter(out);
w.writeUserDefinedTag(Tags.HEADER);
w.writeList(Collections.emptyList()); // Acks
w.writeList(Collections.emptyList()); // Subs
w.writeMap(Collections.emptyMap()); // Transports
w.writeInt64(System.currentTimeMillis()); // Timestamp
w.writeListStart();
w.writeUserDefinedTag(Tags.BATCH);
w.writeList(Collections.emptyList()); // Messages
w.writeListEnd();
w.close();
return out.toByteArray();
}
private static class TestHeaderReader implements ObjectReader<Header> {
public Header readObject(Reader r) throws IOException,
GeneralSecurityException {
r.readList();
r.readList();
r.readMap();
r.readInt64();
return new TestHeader();
}
}
private static class TestHeader implements Header {
public Set<BatchId> getAcks() {
return null;
}
public Set<GroupId> getSubscriptions() {
return null;
}
public Map<String, String> getTransports() {
return null;
}
public long getTimestamp() {
return 0;
}
}
private static class TestBatchReader implements ObjectReader<Batch> {
public Batch readObject(Reader r) throws IOException,
GeneralSecurityException {
r.readList();
return new TestBatch();
}
}
private static class TestBatch implements Batch {
public BatchId getId() {
return null;
}
public Iterable<Message> getMessages() {
return null;
}
}
} }