Double-check the initiator flag and transport ID of incoming

connections, and invert the flag for the responder's side.
This commit is contained in:
akwizgran
2011-10-15 14:15:25 +01:00
parent 2618fea0eb
commit 89001e4c91
8 changed files with 31 additions and 24 deletions

View File

@@ -10,5 +10,5 @@ public interface ConnectionWriterFactory {
boolean initiator, TransportId t, long connection, byte[] secret); boolean initiator, TransportId t, long connection, byte[] secret);
ConnectionWriter createConnectionWriter(OutputStream out, long capacity, ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
byte[] encryptedIv, byte[] secret); TransportId t, byte[] encryptedIv, byte[] secret);
} }

View File

@@ -5,8 +5,8 @@ import net.sf.briar.api.TransportId;
public interface StreamConnectionFactory { public interface StreamConnectionFactory {
void createIncomingConnection(ContactId c, StreamTransportConnection s, void createIncomingConnection(TransportId t, ContactId c,
byte[] encryptedIv); StreamTransportConnection s, byte[] encryptedIv);
void createOutgoingConnection(TransportId t, ContactId c, void createOutgoingConnection(TransportId t, ContactId c,
StreamTransportConnection s); StreamTransportConnection s);

View File

@@ -118,7 +118,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
s.dispose(false); s.dispose(false);
return; return;
} }
streamConnFactory.createIncomingConnection(c, s, encryptedIv); streamConnFactory.createIncomingConnection(t, c, s, encryptedIv);
} }
public void dispatchOutgoingConnection(TransportId t, ContactId c, public void dispatchOutgoingConnection(TransportId t, ContactId c,

View File

@@ -43,7 +43,7 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
} }
public ConnectionWriter createConnectionWriter(OutputStream out, public ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, byte[] encryptedIv, byte[] secret) { long capacity, TransportId t, byte[] encryptedIv, byte[] secret) {
// Decrypt the IV // Decrypt the IV
Cipher ivCipher = crypto.getIvCipher(); Cipher ivCipher = crypto.getIvCipher();
SecretKey ivKey = crypto.deriveIncomingIvKey(secret); SecretKey ivKey = crypto.deriveIncomingIvKey(secret);
@@ -58,10 +58,15 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
} catch(InvalidKeyException badKey) { } catch(InvalidKeyException badKey) {
throw new RuntimeException(badKey); throw new RuntimeException(badKey);
} }
boolean initiator = IvEncoder.getInitiatorFlag(iv); // Check that the initiator flag is raised
TransportId t = new TransportId(IvEncoder.getTransportId(iv)); if(!IvEncoder.getInitiatorFlag(iv))
throw new IllegalArgumentException();
// Check that the transport ID matches the expected ID
if(!t.equals(new TransportId(IvEncoder.getTransportId(iv))))
throw new IllegalArgumentException();
// Copy the connection number
long connection = IvEncoder.getConnectionNumber(iv); long connection = IvEncoder.getConnectionNumber(iv);
return createConnectionWriter(out, capacity, initiator, t, connection, return createConnectionWriter(out, capacity, false, t, connection,
secret); secret);
} }
} }

View File

@@ -3,6 +3,7 @@ package net.sf.briar.transport.stream;
import java.io.IOException; import java.io.IOException;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.TransportId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -20,10 +21,11 @@ public class IncomingStreamConnection extends StreamConnection {
IncomingStreamConnection(ConnectionReaderFactory connReaderFactory, IncomingStreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, TransportId transportId,
StreamTransportConnection connection, byte[] encryptedIv) { ContactId contactId, StreamTransportConnection connection,
byte[] encryptedIv) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory, super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, contactId, connection); protoWriterFactory, transportId, contactId, connection);
this.encryptedIv = encryptedIv; this.encryptedIv = encryptedIv;
} }
@@ -40,7 +42,7 @@ public class IncomingStreamConnection extends StreamConnection {
IOException { IOException {
byte[] secret = db.getSharedSecret(contactId); byte[] secret = db.getSharedSecret(contactId);
return connWriterFactory.createConnectionWriter( return connWriterFactory.createConnectionWriter(
connection.getOutputStream(), Long.MAX_VALUE, encryptedIv, connection.getOutputStream(), Long.MAX_VALUE, transportId,
secret); encryptedIv, secret);
} }
} }

View File

@@ -16,18 +16,15 @@ import net.sf.briar.api.transport.StreamTransportConnection;
public class OutgoingStreamConnection extends StreamConnection { public class OutgoingStreamConnection extends StreamConnection {
private final TransportId transportId;
private long connectionNum = -1L; // Locking: this private long connectionNum = -1L; // Locking: this
OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory, OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, TransportId transportId,
StreamTransportConnection connection, TransportId transportId) { ContactId contactId, StreamTransportConnection connection) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory, super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, contactId, connection); protoWriterFactory, transportId, contactId, connection);
this.transportId = transportId;
} }
@Override @Override

View File

@@ -12,6 +12,7 @@ import java.util.logging.Logger;
import net.sf.briar.api.ContactId; import net.sf.briar.api.ContactId;
import net.sf.briar.api.FormatException; import net.sf.briar.api.FormatException;
import net.sf.briar.api.TransportId;
import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseListener; import net.sf.briar.api.db.DatabaseListener;
import net.sf.briar.api.db.DbException; import net.sf.briar.api.db.DbException;
@@ -49,6 +50,7 @@ abstract class StreamConnection implements DatabaseListener {
protected final DatabaseComponent db; protected final DatabaseComponent db;
protected final ProtocolReaderFactory protoReaderFactory; protected final ProtocolReaderFactory protoReaderFactory;
protected final ProtocolWriterFactory protoWriterFactory; protected final ProtocolWriterFactory protoWriterFactory;
protected final TransportId transportId;
protected final ContactId contactId; protected final ContactId contactId;
protected final StreamTransportConnection connection; protected final StreamTransportConnection connection;
@@ -61,13 +63,14 @@ abstract class StreamConnection implements DatabaseListener {
StreamConnection(ConnectionReaderFactory connReaderFactory, StreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory, ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId, ProtocolWriterFactory protoWriterFactory, TransportId transportId,
StreamTransportConnection connection) { ContactId contactId, StreamTransportConnection connection) {
this.connReaderFactory = connReaderFactory; this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory; this.connWriterFactory = connWriterFactory;
this.db = db; this.db = db;
this.protoReaderFactory = protoReaderFactory; this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
this.transportId = transportId;
this.contactId = contactId; this.contactId = contactId;
this.connection = connection; this.connection = connection;
} }

View File

@@ -32,11 +32,11 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
this.protoWriterFactory = protoWriterFactory; this.protoWriterFactory = protoWriterFactory;
} }
public void createIncomingConnection(ContactId c, public void createIncomingConnection(TransportId t, ContactId c,
StreamTransportConnection s, byte[] encryptedIv) { StreamTransportConnection s, byte[] encryptedIv) {
final StreamConnection conn = new IncomingStreamConnection( final StreamConnection conn = new IncomingStreamConnection(
connReaderFactory, connWriterFactory, db, protoReaderFactory, connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, c, s, encryptedIv); protoWriterFactory, t, c, s, encryptedIv);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();
@@ -55,7 +55,7 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
StreamTransportConnection s) { StreamTransportConnection s) {
final StreamConnection conn = new OutgoingStreamConnection( final StreamConnection conn = new OutgoingStreamConnection(
connReaderFactory, connWriterFactory, db, protoReaderFactory, connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, c, s, t); protoWriterFactory, t, c, s);
Runnable write = new Runnable() { Runnable write = new Runnable() {
public void run() { public void run() {
conn.write(); conn.write();