Encrypted shard return handshake

This commit is contained in:
ameba23
2021-04-14 17:22:34 +02:00
parent 9b4f5be6fe
commit 536905c260
2 changed files with 190 additions and 38 deletions

View File

@@ -8,6 +8,9 @@ import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.KeyPair; import org.briarproject.bramble.api.crypto.KeyPair;
import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.data.BdfList; import org.briarproject.bramble.api.data.BdfList;
import org.briarproject.bramble.api.transport.StreamReaderFactory;
import org.briarproject.bramble.api.transport.StreamWriter;
import org.briarproject.bramble.api.transport.StreamWriterFactory;
import org.briarproject.briar.api.socialbackup.recovery.CustodianTask; import org.briarproject.briar.api.socialbackup.recovery.CustodianTask;
import java.io.IOException; import java.io.IOException;
@@ -16,18 +19,23 @@ import java.io.OutputStream;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Socket; import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.logging.Logger;
import javax.inject.Inject; import javax.inject.Inject;
import static java.util.logging.Logger.getLogger;
public class CustodianTaskImpl implements CustodianTask { public class CustodianTaskImpl implements CustodianTask {
private boolean cancelled = false; private boolean cancelled = false;
private Observer observer; private Observer observer;
private ClientHelper clientHelper; private final ClientHelper clientHelper;
private InetSocketAddress remoteSocketAddress; private InetSocketAddress remoteSocketAddress;
private Socket socket = new Socket(); private final Socket socket = new Socket();
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final AuthenticatedCipher cipher; private final AuthenticatedCipher cipher;
private final KeyPair localKeyPair; private final KeyPair localKeyPair;
@@ -35,13 +43,22 @@ public class CustodianTaskImpl implements CustodianTask {
private SecretKey sharedSecret; private SecretKey sharedSecret;
private final int TIMEOUT = 120 * 1000; private final int TIMEOUT = 120 * 1000;
private final int NONCE_LENGTH = 24; // TODO get this constant private final int NONCE_LENGTH = 24; // TODO get this constant
private final StreamReaderFactory streamReaderFactory;
private final StreamWriterFactory streamWriterFactory;
private static final Logger LOG =
getLogger(CustodianTaskImpl.class.getName());
@Inject @Inject
CustodianTaskImpl(CryptoComponent crypto, ClientHelper clientHelper, CustodianTaskImpl(CryptoComponent crypto, ClientHelper clientHelper,
AuthenticatedCipher cipher) { AuthenticatedCipher cipher, StreamReaderFactory streamReaderFactory,
StreamWriterFactory streamWriterFactory) {
this.clientHelper = clientHelper; this.clientHelper = clientHelper;
this.crypto = crypto; this.crypto = crypto;
this.streamReaderFactory = streamReaderFactory;
this.streamWriterFactory = streamWriterFactory;
this.secureRandom = crypto.getSecureRandom(); this.secureRandom = crypto.getSecureRandom();
this.cipher = cipher; this.cipher = cipher;
localKeyPair = crypto.generateAgreementKeyPair(); localKeyPair = crypto.generateAgreementKeyPair();
} }
@@ -58,8 +75,10 @@ public class CustodianTaskImpl implements CustodianTask {
try { try {
socket.close(); socket.close();
} catch (IOException e) { } catch (IOException e) {
// The reason here is OTHER rather than NO_CONNECTION because
// the socket could fail to close because it is already closed
observer.onStateChanged(new CustodianTask.State.Failure( observer.onStateChanged(new CustodianTask.State.Failure(
State.Failure.Reason.NO_CONNECTION)); State.Failure.Reason.OTHER));
} }
observer.onStateChanged( observer.onStateChanged(
new CustodianTask.State.Failure(State.Failure.Reason.OTHER)); new CustodianTask.State.Failure(State.Failure.Reason.OTHER));
@@ -80,7 +99,7 @@ public class CustodianTaskImpl implements CustodianTask {
crypto.deriveSharedSecret("ShardReturn", remotePublicKey, crypto.deriveSharedSecret("ShardReturn", remotePublicKey,
localKeyPair, addressRaw); localKeyPair, addressRaw);
System.out.println( LOG.info(
" Qr code decoded " + remotePublicKey.getEncoded().length + " Qr code decoded " + remotePublicKey.getEncoded().length +
" " + " " +
remoteSocketAddress); remoteSocketAddress);
@@ -96,30 +115,48 @@ public class CustodianTaskImpl implements CustodianTask {
observer.onStateChanged(new CustodianTask.State.SendingShard()); observer.onStateChanged(new CustodianTask.State.SendingShard());
try { try {
socket.connect(remoteSocketAddress, TIMEOUT); socket.connect(remoteSocketAddress, TIMEOUT);
LOG.info("Connected to secret owner " + remoteSocketAddress);
OutputStream outputStream = socket.getOutputStream(); OutputStream outputStream = socket.getOutputStream();
outputStream.write(createPayload());
// TODO insert the actual payload
byte[] payload = "crunchy".getBytes();
byte[] payloadNonce = new byte[NONCE_LENGTH];
secureRandom.nextBytes(payloadNonce);
byte[] payloadEncrypted = encrypt(payload, payloadNonce);
outputStream.write(localKeyPair.getPublic().getEncoded());
outputStream.write(payloadNonce);
outputStream.write(ByteBuffer.allocate(4).putInt(payloadEncrypted.length)
.array());
LOG.info("Written payload header");
outputStream.write(payloadEncrypted);
// OutputStream encryptedOutputStream = streamWriterFactory
// .createContactExchangeStreamWriter(outputStream,
// sharedSecret).getOutputStream();
// encryptedOutputStream.write(payload);
LOG.info("Written payload");
observer.onStateChanged(new CustodianTask.State.ReceivingAck()); observer.onStateChanged(new CustodianTask.State.ReceivingAck());
} catch (IOException e) { } catch (IOException e) {
if (e instanceof SocketTimeoutException) {
observer.onStateChanged(new CustodianTask.State.Failure(
State.Failure.Reason.NO_CONNECTION));
return;
}
observer.onStateChanged(new CustodianTask.State.Failure( observer.onStateChanged(new CustodianTask.State.Failure(
State.Failure.Reason.QR_CODE_INVALID)); State.Failure.Reason.QR_CODE_INVALID));
return; return;
} // }
System.out.println("Connected *****");
receiveAck();
}
private byte[] createPayload() throws FormatException {
BdfList payloadList = new BdfList();
payloadList.add(localKeyPair.getPublic().getEncoded());
byte[] nonce = new byte[NONCE_LENGTH];
secureRandom.nextBytes(nonce);
payloadList.add(nonce);
try {
payloadList.add(encrypt("crunchy".getBytes(), nonce));
} catch (GeneralSecurityException e) { } catch (GeneralSecurityException e) {
throw new FormatException(); observer.onStateChanged(new CustodianTask.State.Failure(
State.Failure.Reason.OTHER));
return;
} }
return clientHelper.toByteArray(payloadList); receiveAck();
} }
private byte[] encrypt(byte[] message, byte[] nonce) private byte[] encrypt(byte[] message, byte[] nonce)
@@ -141,17 +178,33 @@ public class CustodianTaskImpl implements CustodianTask {
private void receiveAck() { private void receiveAck() {
try { try {
InputStream inputStream = socket.getInputStream(); InputStream inputStream = socket.getInputStream();
byte[] ackMessage = new byte[3]; // InputStream inputStream = streamReaderFactory
int read = inputStream.read(ackMessage); // .createContactExchangeStreamReader(socket.getInputStream(),
if (read < 0) throw new IOException("Ack not read"); // sharedSecret);
System.out.println("ack message: " + new String(ackMessage)); byte[] ackNonce = read(inputStream, NONCE_LENGTH);
byte[] ackMessageEncrypted = read(inputStream, 3 + cipher.getMacBytes());
byte[] ackMessage = decrypt(ackMessageEncrypted, ackNonce);
String ackMessageString = new String(ackMessage);
LOG.info("Received ack message: " + new String(ackMessage));
if (!ackMessageString.equals("ack")) throw new GeneralSecurityException("Bad ack message");
observer.onStateChanged(new CustodianTask.State.Success()); observer.onStateChanged(new CustodianTask.State.Success());
socket.close(); socket.close();
} catch (IOException e) { } catch (IOException e) {
LOG.warning("IO Error reading ack" + e.getMessage());
observer.onStateChanged(new CustodianTask.State.Failure( observer.onStateChanged(new CustodianTask.State.Failure(
State.Failure.Reason.QR_CODE_INVALID)); State.Failure.Reason.QR_CODE_INVALID));
return; } catch (GeneralSecurityException e) {
LOG.warning("Security Error reading ack" + e.getMessage());
observer.onStateChanged(new CustodianTask.State.Failure(
State.Failure.Reason.OTHER));
} }
} }
private byte[] read(InputStream inputStream, int length)
throws IOException {
byte[] output = new byte[length];
int bytesRead = inputStream.read(output);
if (bytesRead < 0) throw new IOException("Cannot read from socket");
return output;
}
} }

View File

@@ -1,10 +1,18 @@
package org.briarproject.briar.socialbackup.recovery; package org.briarproject.briar.socialbackup.recovery;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.client.ClientHelper; import org.briarproject.bramble.api.client.ClientHelper;
import org.briarproject.bramble.api.crypto.AgreementPublicKey;
import org.briarproject.bramble.api.crypto.AuthenticatedCipher;
import org.briarproject.bramble.api.crypto.CryptoComponent; import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.KeyPair; import org.briarproject.bramble.api.crypto.KeyPair;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.data.BdfList; import org.briarproject.bramble.api.data.BdfList;
import org.briarproject.bramble.api.keyagreement.Payload;
import org.briarproject.bramble.api.lifecycle.IoExecutor; import org.briarproject.bramble.api.lifecycle.IoExecutor;
import org.briarproject.bramble.api.transport.StreamReaderFactory;
import org.briarproject.bramble.api.transport.StreamWriter;
import org.briarproject.bramble.api.transport.StreamWriterFactory;
import org.briarproject.briar.api.socialbackup.recovery.SecretOwnerTask; import org.briarproject.briar.api.socialbackup.recovery.SecretOwnerTask;
import java.io.IOException; import java.io.IOException;
@@ -14,15 +22,22 @@ import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.logging.Logger;
import javax.inject.Inject; import javax.inject.Inject;
import static java.util.logging.Logger.getLogger;
public class SecretOwnerTaskImpl implements SecretOwnerTask { public class SecretOwnerTaskImpl implements SecretOwnerTask {
private final CryptoComponent crypto; private final CryptoComponent crypto;
private final Executor ioExecutor; private final Executor ioExecutor;
private final KeyPair localKeyPair; private final KeyPair localKeyPair;
private final AuthenticatedCipher cipher;
private boolean cancelled = false; private boolean cancelled = false;
private InetSocketAddress socketAddress; private InetSocketAddress socketAddress;
private ClientHelper clientHelper; private ClientHelper clientHelper;
@@ -30,21 +45,36 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
private Observer observer; private Observer observer;
private ServerSocket serverSocket; private ServerSocket serverSocket;
private Socket socket; private Socket socket;
private SecretKey sharedSecret;
private final int NONCE_LENGTH = 24;
private final SecureRandom secureRandom;
private final StreamReaderFactory streamReaderFactory;
private final StreamWriterFactory streamWriterFactory;
private static final Logger LOG =
getLogger(SecretOwnerTaskImpl.class.getName());
@Inject @Inject
SecretOwnerTaskImpl(CryptoComponent crypto, SecretOwnerTaskImpl(AuthenticatedCipher cipher, CryptoComponent crypto,
@IoExecutor Executor ioExecutor, ClientHelper clientHelper) { @IoExecutor Executor ioExecutor, ClientHelper clientHelper, StreamReaderFactory streamReaderFactory, StreamWriterFactory streamWriterFactory) {
this.crypto = crypto; this.crypto = crypto;
secureRandom = crypto.getSecureRandom();
this.cipher = cipher;
this.ioExecutor = ioExecutor; this.ioExecutor = ioExecutor;
this.clientHelper = clientHelper; this.clientHelper = clientHelper;
this.streamReaderFactory = streamReaderFactory;
this.streamWriterFactory = streamWriterFactory;
localKeyPair = crypto.generateAgreementKeyPair(); localKeyPair = crypto.generateAgreementKeyPair();
} }
@Override @Override
public void start(Observer observer, InetAddress inetAddress) { public void start(Observer observer, InetAddress inetAddress) {
this.observer = observer; this.observer = observer;
if (inetAddress == null) observer.onStateChanged(new State.Failure()); if (inetAddress == null) {
System.out.println("InetAddress is " + inetAddress); LOG.warning("Cannot retrieve local IP address, failing.");
observer.onStateChanged(new State.Failure());
}
LOG.info("InetAddress is " + inetAddress);
socketAddress = new InetSocketAddress(inetAddress, PORT); socketAddress = new InetSocketAddress(inetAddress, PORT);
// start listening on socketAddress // start listening on socketAddress
@@ -52,6 +82,7 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
serverSocket = new ServerSocket(); serverSocket = new ServerSocket();
serverSocket.bind(socketAddress); serverSocket.bind(socketAddress);
} catch (IOException e) { } catch (IOException e) {
LOG.warning("IO Error when listening on local socket" + e.getMessage());
observer.onStateChanged(new State.Failure()); observer.onStateChanged(new State.Failure());
// TODO could try incrementing the port number // TODO could try incrementing the port number
return; return;
@@ -65,34 +96,86 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
payloadList.add(socketAddress.getPort()); payloadList.add(socketAddress.getPort());
observer.onStateChanged( observer.onStateChanged(
new State.Listening(clientHelper.toByteArray(payloadList))); new State.Listening(clientHelper.toByteArray(payloadList)));
} catch (Exception e) { } catch (FormatException e) {
LOG.warning("Error encoding QR code");
observer.onStateChanged(new State.Failure()); observer.onStateChanged(new State.Failure());
return; return;
} }
receiveShard(); receivePayload();
} }
private void receiveShard() { private void receivePayload() {
try { try {
socket = serverSocket.accept(); socket = serverSocket.accept();
LOG.info("Client connected");
observer.onStateChanged(new State.ReceivingShard()); observer.onStateChanged(new State.ReceivingShard());
InputStream inputStream = socket.getInputStream(); InputStream inputStream = socket.getInputStream();
byte[] payloadRaw = new byte[7];
int read = inputStream.read(payloadRaw); AgreementPublicKey remotePublicKey = new AgreementPublicKey(read(inputStream, 32));
if (read < 0) throw new IOException("Payload not read"); LOG.info("Read remote public key");
System.out.println("payload message: " + new String(payloadRaw)); deriveSharedSecret(remotePublicKey);
byte[] payloadNonce = read(inputStream, NONCE_LENGTH);
LOG.info("Read payload nonce");
byte[] payloadLengthRaw = read(inputStream, 4);
int payloadLength = ByteBuffer.wrap(payloadLengthRaw).getInt();
LOG.info("Expected payload length " + payloadLength + " bytes");
byte[] payloadRaw = read(inputStream, payloadLength);
// InputStream clearInputStream = streamReaderFactory.createContactExchangeStreamReader(inputStream, sharedSecret);
// byte[] payloadClear = read(clearInputStream, payloadLength);
byte[] payloadClear = decrypt(payloadRaw, payloadNonce);
LOG.info("Payload decrypted: " + new String(payloadClear));
// StreamWriter streamWriter = streamWriterFactory.createContactExchangeStreamWriter(socket.getOutputStream(), sharedSecret);
// OutputStream outputStream = streamWriter.getOutputStream();
OutputStream outputStream = socket.getOutputStream(); OutputStream outputStream = socket.getOutputStream();
outputStream.write("ack".getBytes()); byte[] ackNonce = new byte[NONCE_LENGTH];
secureRandom.nextBytes(ackNonce);
outputStream.write(ackNonce);
byte[] ackMessage = encrypt("ack".getBytes(), ackNonce);
outputStream.write(ackMessage);
LOG.info("Acknowledgement sent");
serverSocket.close(); serverSocket.close();
observer.onStateChanged(new State.Success()); observer.onStateChanged(new State.Success());
} catch (IOException e) { } catch (IOException e) {
LOG.warning("IO Error receiving payload" + e.getMessage());
// TODO reasons
observer.onStateChanged(new State.Failure());
} catch (GeneralSecurityException e) {
LOG.warning("Security Error receiving payload" + e.getMessage());
observer.onStateChanged(new State.Failure()); observer.onStateChanged(new State.Failure());
} }
} }
private byte[] read (InputStream inputStream, int length) throws IOException {
byte[] output = new byte[length];
int read = inputStream.read(output);
if (read < 0) throw new IOException("Cannot read from socket");
return output;
}
private void deriveSharedSecret(AgreementPublicKey remotePublicKey) throws
GeneralSecurityException {
byte[] addressRaw = socketAddress.getAddress().getAddress();
sharedSecret =
crypto.deriveSharedSecret("ShardReturn", remotePublicKey,
localKeyPair, addressRaw);
}
@Override @Override
public void cancel() { public void cancel() {
cancelled = true; cancelled = true;
LOG.info("Cancel called, failing...");
try { try {
serverSocket.close(); serverSocket.close();
} catch (IOException e) { } catch (IOException e) {
@@ -100,4 +183,20 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
} }
observer.onStateChanged(new State.Failure()); observer.onStateChanged(new State.Failure());
} }
private byte[] encrypt(byte[] message, byte[] nonce)
throws GeneralSecurityException {
cipher.init(true, sharedSecret, nonce);
byte[] cipherText = new byte[message.length + cipher.getMacBytes()];
cipher.process(message, 0, message.length, cipherText, 0);
return cipherText;
}
private byte[] decrypt(byte[] cipherText, byte[] nonce)
throws GeneralSecurityException {
cipher.init(false, sharedSecret, nonce);
byte[] message = new byte[cipherText.length - cipher.getMacBytes()];
cipher.process(cipherText, 0, cipherText.length, message, 0);
return message;
}
} }