Refactor duplicate task code into parent class

This commit is contained in:
ameba23
2021-04-14 17:59:29 +02:00
parent e6d80ec484
commit ed1ed7d3e1
3 changed files with 112 additions and 122 deletions

View File

@@ -1,15 +1,11 @@
package org.briarproject.briar.socialbackup.recovery;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.client.ClientHelper;
import org.briarproject.bramble.api.crypto.AgreementPublicKey;
import org.briarproject.bramble.api.crypto.AuthenticatedCipher;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.KeyPair;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.data.BdfList;
import org.briarproject.bramble.api.transport.StreamReaderFactory;
import org.briarproject.bramble.api.transport.StreamWriter;
import org.briarproject.bramble.api.transport.StreamWriterFactory;
import org.briarproject.briar.api.socialbackup.recovery.CustodianTask;
@@ -22,29 +18,23 @@ import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.logging.Logger;
import javax.inject.Inject;
import static java.util.logging.Logger.getLogger;
public class CustodianTaskImpl implements CustodianTask {
public class CustodianTaskImpl extends ReturnShardTaskImpl
implements CustodianTask {
private boolean cancelled = false;
private Observer observer;
private final ClientHelper clientHelper;
private InetSocketAddress remoteSocketAddress;
private final Socket socket = new Socket();
private final CryptoComponent crypto;
private final AuthenticatedCipher cipher;
private final KeyPair localKeyPair;
private final SecureRandom secureRandom;
private SecretKey sharedSecret;
private final int TIMEOUT = 120 * 1000;
private final int NONCE_LENGTH = 24; // TODO get this constant
private final StreamReaderFactory streamReaderFactory;
private final StreamWriterFactory streamWriterFactory;
// private final StreamReaderFactory streamReaderFactory;
// private final StreamWriterFactory streamWriterFactory;
private static final Logger LOG =
getLogger(CustodianTaskImpl.class.getName());
@@ -53,14 +43,12 @@ public class CustodianTaskImpl implements CustodianTask {
CustodianTaskImpl(CryptoComponent crypto, ClientHelper clientHelper,
AuthenticatedCipher cipher, StreamReaderFactory streamReaderFactory,
StreamWriterFactory streamWriterFactory) {
super(cipher, crypto);
this.clientHelper = clientHelper;
this.crypto = crypto;
this.streamReaderFactory = streamReaderFactory;
this.streamWriterFactory = streamWriterFactory;
this.secureRandom = crypto.getSecureRandom();
// this.streamReaderFactory = streamReaderFactory;
// this.streamWriterFactory = streamWriterFactory;
this.cipher = cipher;
localKeyPair = crypto.generateAgreementKeyPair();
}
@Override
@@ -95,14 +83,9 @@ public class CustodianTaskImpl implements CustodianTask {
remoteSocketAddress =
new InetSocketAddress(InetAddress.getByAddress(addressRaw),
port);
sharedSecret =
crypto.deriveSharedSecret("ShardReturn", remotePublicKey,
localKeyPair, addressRaw);
deriveSharedSecret(remotePublicKey, addressRaw);
LOG.info(
" Qr code decoded " + remotePublicKey.getEncoded().length +
" " +
remoteSocketAddress);
LOG.info("Qr code payload parsed successfully");
} catch (Exception e) {
observer.onStateChanged(new CustodianTask.State.Failure(
State.Failure.Reason.QR_CODE_INVALID));
@@ -122,12 +105,13 @@ public class CustodianTaskImpl implements CustodianTask {
// TODO insert the actual payload
byte[] payload = "crunchy".getBytes();
byte[] payloadNonce = new byte[NONCE_LENGTH];
secureRandom.nextBytes(payloadNonce);
byte[] payloadNonce = generateNonce();
byte[] payloadEncrypted = encrypt(payload, payloadNonce);
outputStream.write(localKeyPair.getPublic().getEncoded());
outputStream.write(payloadNonce);
outputStream.write(ByteBuffer.allocate(4).putInt(payloadEncrypted.length)
outputStream.write(ByteBuffer.allocate(4)
.putInt(payloadEncrypted.length)
.array());
LOG.info("Written payload header");
@@ -159,34 +143,20 @@ public class CustodianTaskImpl implements CustodianTask {
receiveAck();
}
private byte[] encrypt(byte[] message, byte[] nonce)
throws GeneralSecurityException {
cipher.init(true, sharedSecret, nonce);
byte[] cipherText = new byte[message.length + cipher.getMacBytes()];
cipher.process(message, 0, message.length, cipherText, 0);
return cipherText;
}
private byte[] decrypt(byte[] cipherText, byte[] nonce)
throws GeneralSecurityException {
cipher.init(false, sharedSecret, nonce);
byte[] message = new byte[cipherText.length - cipher.getMacBytes()];
cipher.process(cipherText, 0, cipherText.length, message, 0);
return message;
}
private void receiveAck() {
try {
InputStream inputStream = socket.getInputStream();
// InputStream inputStream = streamReaderFactory
// .createContactExchangeStreamReader(socket.getInputStream(),
// sharedSecret);
byte[] ackNonce = read(inputStream, NONCE_LENGTH);
byte[] ackMessageEncrypted = read(inputStream, 3 + cipher.getMacBytes());
byte[] ackNonce = read(inputStream, NONCE_LENGTH);
byte[] ackMessageEncrypted =
read(inputStream, 3 + cipher.getMacBytes());
byte[] ackMessage = decrypt(ackMessageEncrypted, ackNonce);
String ackMessageString = new String(ackMessage);
LOG.info("Received ack message: " + new String(ackMessage));
if (!ackMessageString.equals("ack")) throw new GeneralSecurityException("Bad ack message");
if (!ackMessageString.equals("ack"))
throw new GeneralSecurityException("Bad ack message");
observer.onStateChanged(new CustodianTask.State.Success());
socket.close();
} catch (IOException e) {
@@ -199,12 +169,4 @@ public class CustodianTaskImpl implements CustodianTask {
State.Failure.Reason.OTHER));
}
}
private byte[] read(InputStream inputStream, int length)
throws IOException {
byte[] output = new byte[length];
int bytesRead = inputStream.read(output);
if (bytesRead < 0) throw new IOException("Cannot read from socket");
return output;
}
}

View File

@@ -0,0 +1,68 @@
package org.briarproject.briar.socialbackup.recovery;
import org.briarproject.bramble.api.crypto.AgreementPublicKey;
import org.briarproject.bramble.api.crypto.AuthenticatedCipher;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.KeyPair;
import org.briarproject.bramble.api.crypto.SecretKey;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
public class ReturnShardTaskImpl {
private final AuthenticatedCipher cipher;
private final CryptoComponent crypto;
private final SecureRandom secureRandom;
final int PORT = 3002;
final int TIMEOUT = 120 * 1000;
final int NONCE_LENGTH = 24; // TODO get these constants
final int AGREEMENT_PUBLIC_KEY_LENGTH = 32;
SecretKey sharedSecret;
final KeyPair localKeyPair;
ReturnShardTaskImpl(AuthenticatedCipher cipher, CryptoComponent crypto) {
this.cipher = cipher;
this.crypto = crypto;
this.secureRandom = crypto.getSecureRandom();
localKeyPair = crypto.generateAgreementKeyPair();
}
byte[] generateNonce() {
byte[] nonce = new byte[NONCE_LENGTH];
secureRandom.nextBytes(nonce);
return nonce;
}
void deriveSharedSecret(AgreementPublicKey remotePublicKey, byte[] context) throws
GeneralSecurityException {
sharedSecret =
crypto.deriveSharedSecret("ShardReturn", remotePublicKey,
localKeyPair, context);
}
byte[] encrypt(byte[] message, byte[] nonce)
throws GeneralSecurityException {
cipher.init(true, sharedSecret, nonce);
byte[] cipherText = new byte[message.length + cipher.getMacBytes()];
cipher.process(message, 0, message.length, cipherText, 0);
return cipherText;
}
byte[] decrypt(byte[] cipherText, byte[] nonce)
throws GeneralSecurityException {
cipher.init(false, sharedSecret, nonce);
byte[] message = new byte[cipherText.length - cipher.getMacBytes()];
cipher.process(cipherText, 0, cipherText.length, message, 0);
return message;
}
byte[] read(InputStream inputStream, int length)
throws IOException {
byte[] output = new byte[length];
int bytesRead = inputStream.read(output);
if (bytesRead < 0) throw new IOException("Cannot read from socket");
return output;
}
}

View File

@@ -6,13 +6,8 @@ import org.briarproject.bramble.api.crypto.AgreementPublicKey;
import org.briarproject.bramble.api.crypto.AuthenticatedCipher;
import org.briarproject.bramble.api.crypto.CryptoComponent;
import org.briarproject.bramble.api.crypto.KeyPair;
import org.briarproject.bramble.api.crypto.SecretKey;
import org.briarproject.bramble.api.data.BdfList;
import org.briarproject.bramble.api.keyagreement.Payload;
import org.briarproject.bramble.api.lifecycle.IoExecutor;
import org.briarproject.bramble.api.transport.StreamReaderFactory;
import org.briarproject.bramble.api.transport.StreamWriter;
import org.briarproject.bramble.api.transport.StreamWriterFactory;
import org.briarproject.briar.api.socialbackup.recovery.SecretOwnerTask;
import java.io.IOException;
@@ -24,7 +19,6 @@ import java.net.ServerSocket;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.concurrent.Executor;
import java.util.logging.Logger;
@@ -32,38 +26,31 @@ import javax.inject.Inject;
import static java.util.logging.Logger.getLogger;
public class SecretOwnerTaskImpl implements SecretOwnerTask {
public class SecretOwnerTaskImpl extends ReturnShardTaskImpl
implements SecretOwnerTask {
private final CryptoComponent crypto;
private final Executor ioExecutor;
// private final Executor ioExecutor;
private final KeyPair localKeyPair;
private final AuthenticatedCipher cipher;
private boolean cancelled = false;
private InetSocketAddress socketAddress;
private ClientHelper clientHelper;
private final int PORT = 3002;
private Observer observer;
private ServerSocket serverSocket;
private Socket socket;
private SecretKey sharedSecret;
private final int NONCE_LENGTH = 24;
private final SecureRandom secureRandom;
private final StreamReaderFactory streamReaderFactory;
private final StreamWriterFactory streamWriterFactory;
// private final StreamReaderFactory streamReaderFactory;
// private final StreamWriterFactory streamWriterFactory;
private static final Logger LOG =
getLogger(SecretOwnerTaskImpl.class.getName());
@Inject
SecretOwnerTaskImpl(AuthenticatedCipher cipher, CryptoComponent crypto,
@IoExecutor Executor ioExecutor, ClientHelper clientHelper, StreamReaderFactory streamReaderFactory, StreamWriterFactory streamWriterFactory) {
this.crypto = crypto;
secureRandom = crypto.getSecureRandom();
this.cipher = cipher;
this.ioExecutor = ioExecutor;
@IoExecutor Executor ioExecutor, ClientHelper clientHelper) {
super(cipher, crypto);
// this.ioExecutor = ioExecutor;
this.clientHelper = clientHelper;
this.streamReaderFactory = streamReaderFactory;
this.streamWriterFactory = streamWriterFactory;
// this.streamReaderFactory = streamReaderFactory;
// this.streamWriterFactory = streamWriterFactory;
localKeyPair = crypto.generateAgreementKeyPair();
}
@@ -77,12 +64,13 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
LOG.info("InetAddress is " + inetAddress);
socketAddress = new InetSocketAddress(inetAddress, PORT);
// start listening on socketAddress
// Start listening on socketAddress
try {
serverSocket = new ServerSocket();
serverSocket.bind(socketAddress);
} catch (IOException e) {
LOG.warning("IO Error when listening on local socket" + e.getMessage());
LOG.warning(
"IO Error when listening on local socket" + e.getMessage());
observer.onStateChanged(new State.Failure());
// TODO could try incrementing the port number
return;
@@ -112,37 +100,40 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
InputStream inputStream = socket.getInputStream();
AgreementPublicKey remotePublicKey = new AgreementPublicKey(read(inputStream, 32));
AgreementPublicKey remotePublicKey =
new AgreementPublicKey(
read(inputStream, AGREEMENT_PUBLIC_KEY_LENGTH));
LOG.info("Read remote public key");
deriveSharedSecret(remotePublicKey);
byte[] addressRaw = socketAddress.getAddress().getAddress();
deriveSharedSecret(remotePublicKey, addressRaw);
byte[] payloadNonce = read(inputStream, NONCE_LENGTH);
LOG.info("Read payload nonce");
byte[] payloadLengthRaw = read(inputStream, 4);
int payloadLength = ByteBuffer.wrap(payloadLengthRaw).getInt();
LOG.info("Expected payload length " + payloadLength + " bytes");
LOG.info("Expected payload length " + payloadLength + " bytes");
byte[] payloadRaw = read(inputStream, payloadLength);
byte[] payloadRaw = read(inputStream, payloadLength);
// InputStream clearInputStream = streamReaderFactory.createContactExchangeStreamReader(inputStream, sharedSecret);
// byte[] payloadClear = read(clearInputStream, payloadLength);
byte[] payloadClear = decrypt(payloadRaw, payloadNonce);
byte[] payloadClear = decrypt(payloadRaw, payloadNonce);
LOG.info("Payload decrypted: " + new String(payloadClear));
LOG.info("Payload decrypted: " + new String(payloadClear));
// StreamWriter streamWriter = streamWriterFactory.createContactExchangeStreamWriter(socket.getOutputStream(), sharedSecret);
// OutputStream outputStream = streamWriter.getOutputStream();
OutputStream outputStream = socket.getOutputStream();
byte[] ackNonce = new byte[NONCE_LENGTH];
secureRandom.nextBytes(ackNonce);
outputStream.write(ackNonce);
byte[] ackNonce = generateNonce();
outputStream.write(ackNonce);
byte[] ackMessage = encrypt("ack".getBytes(), ackNonce);
outputStream.write(ackMessage);
LOG.info("Acknowledgement sent");
LOG.info("Acknowledgement sent");
serverSocket.close();
@@ -157,21 +148,6 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
}
}
private byte[] read (InputStream inputStream, int length) throws IOException {
byte[] output = new byte[length];
int read = inputStream.read(output);
if (read < 0) throw new IOException("Cannot read from socket");
return output;
}
private void deriveSharedSecret(AgreementPublicKey remotePublicKey) throws
GeneralSecurityException {
byte[] addressRaw = socketAddress.getAddress().getAddress();
sharedSecret =
crypto.deriveSharedSecret("ShardReturn", remotePublicKey,
localKeyPair, addressRaw);
}
@Override
public void cancel() {
cancelled = true;
@@ -183,20 +159,4 @@ public class SecretOwnerTaskImpl implements SecretOwnerTask {
}
observer.onStateChanged(new State.Failure());
}
private byte[] encrypt(byte[] message, byte[] nonce)
throws GeneralSecurityException {
cipher.init(true, sharedSecret, nonce);
byte[] cipherText = new byte[message.length + cipher.getMacBytes()];
cipher.process(message, 0, message.length, cipherText, 0);
return cipherText;
}
private byte[] decrypt(byte[] cipherText, byte[] nonce)
throws GeneralSecurityException {
cipher.init(false, sharedSecret, nonce);
byte[] message = new byte[cipherText.length - cipher.getMacBytes()];
cipher.process(cipherText, 0, cipherText.length, message, 0);
return message;
}
}