Double-check the initiator flag and transport ID of incoming

connections, and invert the flag for the responder's side.
This commit is contained in:
akwizgran
2011-10-15 14:15:25 +01:00
parent 2618fea0eb
commit 89001e4c91
8 changed files with 31 additions and 24 deletions

View File

@@ -10,5 +10,5 @@ public interface ConnectionWriterFactory {
boolean initiator, TransportId t, long connection, byte[] secret);
ConnectionWriter createConnectionWriter(OutputStream out, long capacity,
byte[] encryptedIv, byte[] secret);
TransportId t, byte[] encryptedIv, byte[] secret);
}

View File

@@ -5,8 +5,8 @@ import net.sf.briar.api.TransportId;
public interface StreamConnectionFactory {
void createIncomingConnection(ContactId c, StreamTransportConnection s,
byte[] encryptedIv);
void createIncomingConnection(TransportId t, ContactId c,
StreamTransportConnection s, byte[] encryptedIv);
void createOutgoingConnection(TransportId t, ContactId c,
StreamTransportConnection s);

View File

@@ -118,7 +118,7 @@ public class ConnectionDispatcherImpl implements ConnectionDispatcher {
s.dispose(false);
return;
}
streamConnFactory.createIncomingConnection(c, s, encryptedIv);
streamConnFactory.createIncomingConnection(t, c, s, encryptedIv);
}
public void dispatchOutgoingConnection(TransportId t, ContactId c,

View File

@@ -43,7 +43,7 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
}
public ConnectionWriter createConnectionWriter(OutputStream out,
long capacity, byte[] encryptedIv, byte[] secret) {
long capacity, TransportId t, byte[] encryptedIv, byte[] secret) {
// Decrypt the IV
Cipher ivCipher = crypto.getIvCipher();
SecretKey ivKey = crypto.deriveIncomingIvKey(secret);
@@ -58,10 +58,15 @@ class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
} catch(InvalidKeyException badKey) {
throw new RuntimeException(badKey);
}
boolean initiator = IvEncoder.getInitiatorFlag(iv);
TransportId t = new TransportId(IvEncoder.getTransportId(iv));
// Check that the initiator flag is raised
if(!IvEncoder.getInitiatorFlag(iv))
throw new IllegalArgumentException();
// Check that the transport ID matches the expected ID
if(!t.equals(new TransportId(IvEncoder.getTransportId(iv))))
throw new IllegalArgumentException();
// Copy the connection number
long connection = IvEncoder.getConnectionNumber(iv);
return createConnectionWriter(out, capacity, initiator, t, connection,
return createConnectionWriter(out, capacity, false, t, connection,
secret);
}
}

View File

@@ -3,6 +3,7 @@ package net.sf.briar.transport.stream;
import java.io.IOException;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.TransportId;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.ProtocolReaderFactory;
@@ -20,10 +21,11 @@ public class IncomingStreamConnection extends StreamConnection {
IncomingStreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId,
StreamTransportConnection connection, byte[] encryptedIv) {
ProtocolWriterFactory protoWriterFactory, TransportId transportId,
ContactId contactId, StreamTransportConnection connection,
byte[] encryptedIv) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, contactId, connection);
protoWriterFactory, transportId, contactId, connection);
this.encryptedIv = encryptedIv;
}
@@ -40,7 +42,7 @@ public class IncomingStreamConnection extends StreamConnection {
IOException {
byte[] secret = db.getSharedSecret(contactId);
return connWriterFactory.createConnectionWriter(
connection.getOutputStream(), Long.MAX_VALUE, encryptedIv,
secret);
connection.getOutputStream(), Long.MAX_VALUE, transportId,
encryptedIv, secret);
}
}

View File

@@ -16,18 +16,15 @@ import net.sf.briar.api.transport.StreamTransportConnection;
public class OutgoingStreamConnection extends StreamConnection {
private final TransportId transportId;
private long connectionNum = -1L; // Locking: this
OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId,
StreamTransportConnection connection, TransportId transportId) {
ProtocolWriterFactory protoWriterFactory, TransportId transportId,
ContactId contactId, StreamTransportConnection connection) {
super(connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, contactId, connection);
this.transportId = transportId;
protoWriterFactory, transportId, contactId, connection);
}
@Override

View File

@@ -12,6 +12,7 @@ import java.util.logging.Logger;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.TransportId;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseListener;
import net.sf.briar.api.db.DbException;
@@ -49,6 +50,7 @@ abstract class StreamConnection implements DatabaseListener {
protected final DatabaseComponent db;
protected final ProtocolReaderFactory protoReaderFactory;
protected final ProtocolWriterFactory protoWriterFactory;
protected final TransportId transportId;
protected final ContactId contactId;
protected final StreamTransportConnection connection;
@@ -61,13 +63,14 @@ abstract class StreamConnection implements DatabaseListener {
StreamConnection(ConnectionReaderFactory connReaderFactory,
ConnectionWriterFactory connWriterFactory, DatabaseComponent db,
ProtocolReaderFactory protoReaderFactory,
ProtocolWriterFactory protoWriterFactory, ContactId contactId,
StreamTransportConnection connection) {
ProtocolWriterFactory protoWriterFactory, TransportId transportId,
ContactId contactId, StreamTransportConnection connection) {
this.connReaderFactory = connReaderFactory;
this.connWriterFactory = connWriterFactory;
this.db = db;
this.protoReaderFactory = protoReaderFactory;
this.protoWriterFactory = protoWriterFactory;
this.transportId = transportId;
this.contactId = contactId;
this.connection = connection;
}

View File

@@ -32,11 +32,11 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
this.protoWriterFactory = protoWriterFactory;
}
public void createIncomingConnection(ContactId c,
public void createIncomingConnection(TransportId t, ContactId c,
StreamTransportConnection s, byte[] encryptedIv) {
final StreamConnection conn = new IncomingStreamConnection(
connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, c, s, encryptedIv);
protoWriterFactory, t, c, s, encryptedIv);
Runnable write = new Runnable() {
public void run() {
conn.write();
@@ -55,7 +55,7 @@ public class StreamConnectionFactoryImpl implements StreamConnectionFactory {
StreamTransportConnection s) {
final StreamConnection conn = new OutgoingStreamConnection(
connReaderFactory, connWriterFactory, db, protoReaderFactory,
protoWriterFactory, c, s, t);
protoWriterFactory, t, c, s);
Runnable write = new Runnable() {
public void run() {
conn.write();