Don't forget to check the MAC.

This commit is contained in:
akwizgran
2011-08-12 22:11:56 +02:00
parent 0504a2d6fd
commit a8994a3354
3 changed files with 75 additions and 12 deletions

View File

@@ -1,6 +1,5 @@
package net.sf.briar.protocol;
import java.io.IOException;
import java.security.MessageDigest;
import net.sf.briar.api.serial.Consumer;
@@ -14,15 +13,15 @@ class DigestingConsumer implements Consumer {
this.messageDigest = messageDigest;
}
public void write(byte b) throws IOException {
public void write(byte b) {
messageDigest.update(b);
}
public void write(byte[] b) throws IOException {
public void write(byte[] b) {
messageDigest.update(b);
}
public void write(byte[] b, int off, int len) throws IOException {
public void write(byte[] b, int off, int len) {
messageDigest.update(b, off, len);
}
}

View File

@@ -0,0 +1,27 @@
package net.sf.briar.transport;
import javax.crypto.Mac;
import net.sf.briar.api.serial.Consumer;
/** A consumer that passes its input through a MAC. */
class MacConsumer implements Consumer {
private final Mac mac;
MacConsumer(Mac mac) {
this.mac = mac;
}
public void write(byte b) {
mac.update(b);
}
public void write(byte[] b) {
mac.update(b);
}
public void write(byte[] b, int off, int len) {
mac.update(b, off, len);
}
}

View File

@@ -2,6 +2,7 @@ package net.sf.briar.transport;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import javax.crypto.Mac;
@@ -13,6 +14,7 @@ import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.ProtocolReaderFactory;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.transport.PacketReader;
@@ -22,7 +24,7 @@ class PacketReaderImpl implements PacketReader {
private final Reader reader;
private final PacketDecrypter decrypter;
private final Mac mac;
private final int transportId;
private final int macLength, transportId;
private final long connection;
private long packet = 0L;
@@ -41,8 +43,10 @@ class PacketReaderImpl implements PacketReader {
protocol.createSubscriptionReader(in));
reader.addObjectReader(Tags.TRANSPORTS,
protocol.createTransportReader(in));
reader.addConsumer(new MacConsumer(mac));
this.decrypter = decrypter;
this.mac = mac;
macLength = mac.getMacLength();
this.transportId = transportId;
this.connection = connection;
}
@@ -62,7 +66,7 @@ class PacketReaderImpl implements PacketReader {
throw new IllegalStateException();
byte[] tag = decrypter.readTag();
if(!TagDecoder.decodeTag(tag, transportId, connection, packet))
throw new IOException();
throw new FormatException();
mac.update(tag);
packet++;
betweenPackets = false;
@@ -70,7 +74,24 @@ class PacketReaderImpl implements PacketReader {
public Ack readAck() throws IOException {
if(betweenPackets) readTag();
return reader.readUserDefined(Tags.ACK, Ack.class);
Ack a = reader.readUserDefined(Tags.ACK, Ack.class);
readMac();
betweenPackets = true;
return a;
}
private void readMac() throws IOException {
byte[] expectedMac = mac.doFinal();
byte[] actualMac = new byte[macLength];
InputStream in = decrypter.getInputStream();
int offset = 0;
while(offset < macLength) {
int read = in.read(actualMac, offset, actualMac.length - offset);
if(read == -1) break;
offset += read;
}
if(offset < macLength) throw new FormatException();
if(!Arrays.equals(expectedMac, actualMac)) throw new FormatException();
}
public boolean hasBatch() throws IOException {
@@ -80,7 +101,10 @@ class PacketReaderImpl implements PacketReader {
public Batch readBatch() throws IOException {
if(betweenPackets) readTag();
return reader.readUserDefined(Tags.BATCH, Batch.class);
Batch b = reader.readUserDefined(Tags.BATCH, Batch.class);
readMac();
betweenPackets = true;
return b;
}
public boolean hasOffer() throws IOException {
@@ -90,7 +114,10 @@ class PacketReaderImpl implements PacketReader {
public Offer readOffer() throws IOException {
if(betweenPackets) readTag();
return reader.readUserDefined(Tags.OFFER, Offer.class);
Offer o = reader.readUserDefined(Tags.OFFER, Offer.class);
readMac();
betweenPackets = true;
return o;
}
public boolean hasRequest() throws IOException {
@@ -100,7 +127,10 @@ class PacketReaderImpl implements PacketReader {
public Request readRequest() throws IOException {
if(betweenPackets) readTag();
return reader.readUserDefined(Tags.REQUEST, Request.class);
Request r = reader.readUserDefined(Tags.REQUEST, Request.class);
readMac();
betweenPackets = true;
return r;
}
public boolean hasSubscriptionUpdate() throws IOException {
@@ -110,8 +140,11 @@ class PacketReaderImpl implements PacketReader {
public SubscriptionUpdate readSubscriptionUpdate() throws IOException {
if(betweenPackets) readTag();
return reader.readUserDefined(Tags.SUBSCRIPTIONS,
SubscriptionUpdate s = reader.readUserDefined(Tags.SUBSCRIPTIONS,
SubscriptionUpdate.class);
readMac();
betweenPackets = true;
return s;
}
public boolean hasTransportUpdate() throws IOException {
@@ -121,6 +154,10 @@ class PacketReaderImpl implements PacketReader {
public TransportUpdate readTransportUpdate() throws IOException {
if(betweenPackets) readTag();
return reader.readUserDefined(Tags.TRANSPORTS, TransportUpdate.class);
TransportUpdate t = reader.readUserDefined(Tags.TRANSPORTS,
TransportUpdate.class);
readMac();
betweenPackets = true;
return t;
}
}