Don't forget to check the MAC.

This commit is contained in:
akwizgran
2011-08-12 22:11:56 +02:00
parent 0504a2d6fd
commit a8994a3354
3 changed files with 75 additions and 12 deletions

View File

@@ -1,6 +1,5 @@
package net.sf.briar.protocol; package net.sf.briar.protocol;
import java.io.IOException;
import java.security.MessageDigest; import java.security.MessageDigest;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
@@ -14,15 +13,15 @@ class DigestingConsumer implements Consumer {
this.messageDigest = messageDigest; this.messageDigest = messageDigest;
} }
public void write(byte b) throws IOException { public void write(byte b) {
messageDigest.update(b); messageDigest.update(b);
} }
public void write(byte[] b) throws IOException { public void write(byte[] b) {
messageDigest.update(b); messageDigest.update(b);
} }
public void write(byte[] b, int off, int len) throws IOException { public void write(byte[] b, int off, int len) {
messageDigest.update(b, off, len); messageDigest.update(b, off, len);
} }
} }

View File

@@ -0,0 +1,27 @@
package net.sf.briar.transport;
import javax.crypto.Mac;
import net.sf.briar.api.serial.Consumer;
/** A consumer that passes its input through a MAC. */
class MacConsumer implements Consumer {
private final Mac mac;
MacConsumer(Mac mac) {
this.mac = mac;
}
public void write(byte b) {
mac.update(b);
}
public void write(byte[] b) {
mac.update(b);
}
public void write(byte[] b, int off, int len) {
mac.update(b, off, len);
}
}

View File

@@ -2,6 +2,7 @@ package net.sf.briar.transport;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Arrays;
import javax.crypto.Mac; import javax.crypto.Mac;
@@ -13,6 +14,7 @@ import net.sf.briar.api.protocol.SubscriptionUpdate;
import net.sf.briar.api.protocol.Tags; import net.sf.briar.api.protocol.Tags;
import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.TransportUpdate;
import net.sf.briar.api.protocol.writers.ProtocolReaderFactory; import net.sf.briar.api.protocol.writers.ProtocolReaderFactory;
import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.ReaderFactory;
import net.sf.briar.api.transport.PacketReader; import net.sf.briar.api.transport.PacketReader;
@@ -22,7 +24,7 @@ class PacketReaderImpl implements PacketReader {
private final Reader reader; private final Reader reader;
private final PacketDecrypter decrypter; private final PacketDecrypter decrypter;
private final Mac mac; private final Mac mac;
private final int transportId; private final int macLength, transportId;
private final long connection; private final long connection;
private long packet = 0L; private long packet = 0L;
@@ -41,8 +43,10 @@ class PacketReaderImpl implements PacketReader {
protocol.createSubscriptionReader(in)); protocol.createSubscriptionReader(in));
reader.addObjectReader(Tags.TRANSPORTS, reader.addObjectReader(Tags.TRANSPORTS,
protocol.createTransportReader(in)); protocol.createTransportReader(in));
reader.addConsumer(new MacConsumer(mac));
this.decrypter = decrypter; this.decrypter = decrypter;
this.mac = mac; this.mac = mac;
macLength = mac.getMacLength();
this.transportId = transportId; this.transportId = transportId;
this.connection = connection; this.connection = connection;
} }
@@ -62,7 +66,7 @@ class PacketReaderImpl implements PacketReader {
throw new IllegalStateException(); throw new IllegalStateException();
byte[] tag = decrypter.readTag(); byte[] tag = decrypter.readTag();
if(!TagDecoder.decodeTag(tag, transportId, connection, packet)) if(!TagDecoder.decodeTag(tag, transportId, connection, packet))
throw new IOException(); throw new FormatException();
mac.update(tag); mac.update(tag);
packet++; packet++;
betweenPackets = false; betweenPackets = false;
@@ -70,7 +74,24 @@ class PacketReaderImpl implements PacketReader {
public Ack readAck() throws IOException { public Ack readAck() throws IOException {
if(betweenPackets) readTag(); if(betweenPackets) readTag();
return reader.readUserDefined(Tags.ACK, Ack.class); Ack a = reader.readUserDefined(Tags.ACK, Ack.class);
readMac();
betweenPackets = true;
return a;
}
private void readMac() throws IOException {
byte[] expectedMac = mac.doFinal();
byte[] actualMac = new byte[macLength];
InputStream in = decrypter.getInputStream();
int offset = 0;
while(offset < macLength) {
int read = in.read(actualMac, offset, actualMac.length - offset);
if(read == -1) break;
offset += read;
}
if(offset < macLength) throw new FormatException();
if(!Arrays.equals(expectedMac, actualMac)) throw new FormatException();
} }
public boolean hasBatch() throws IOException { public boolean hasBatch() throws IOException {
@@ -80,7 +101,10 @@ class PacketReaderImpl implements PacketReader {
public Batch readBatch() throws IOException { public Batch readBatch() throws IOException {
if(betweenPackets) readTag(); if(betweenPackets) readTag();
return reader.readUserDefined(Tags.BATCH, Batch.class); Batch b = reader.readUserDefined(Tags.BATCH, Batch.class);
readMac();
betweenPackets = true;
return b;
} }
public boolean hasOffer() throws IOException { public boolean hasOffer() throws IOException {
@@ -90,7 +114,10 @@ class PacketReaderImpl implements PacketReader {
public Offer readOffer() throws IOException { public Offer readOffer() throws IOException {
if(betweenPackets) readTag(); if(betweenPackets) readTag();
return reader.readUserDefined(Tags.OFFER, Offer.class); Offer o = reader.readUserDefined(Tags.OFFER, Offer.class);
readMac();
betweenPackets = true;
return o;
} }
public boolean hasRequest() throws IOException { public boolean hasRequest() throws IOException {
@@ -100,7 +127,10 @@ class PacketReaderImpl implements PacketReader {
public Request readRequest() throws IOException { public Request readRequest() throws IOException {
if(betweenPackets) readTag(); if(betweenPackets) readTag();
return reader.readUserDefined(Tags.REQUEST, Request.class); Request r = reader.readUserDefined(Tags.REQUEST, Request.class);
readMac();
betweenPackets = true;
return r;
} }
public boolean hasSubscriptionUpdate() throws IOException { public boolean hasSubscriptionUpdate() throws IOException {
@@ -110,8 +140,11 @@ class PacketReaderImpl implements PacketReader {
public SubscriptionUpdate readSubscriptionUpdate() throws IOException { public SubscriptionUpdate readSubscriptionUpdate() throws IOException {
if(betweenPackets) readTag(); if(betweenPackets) readTag();
return reader.readUserDefined(Tags.SUBSCRIPTIONS, SubscriptionUpdate s = reader.readUserDefined(Tags.SUBSCRIPTIONS,
SubscriptionUpdate.class); SubscriptionUpdate.class);
readMac();
betweenPackets = true;
return s;
} }
public boolean hasTransportUpdate() throws IOException { public boolean hasTransportUpdate() throws IOException {
@@ -121,6 +154,10 @@ class PacketReaderImpl implements PacketReader {
public TransportUpdate readTransportUpdate() throws IOException { public TransportUpdate readTransportUpdate() throws IOException {
if(betweenPackets) readTag(); if(betweenPackets) readTag();
return reader.readUserDefined(Tags.TRANSPORTS, TransportUpdate.class); TransportUpdate t = reader.readUserDefined(Tags.TRANSPORTS,
TransportUpdate.class);
readMac();
betweenPackets = true;
return t;
} }
} }