Files
briar/components/net/sf/briar/transport/IncomingEncryptionLayer.java

128 lines
4.0 KiB
Java

package net.sf.briar.transport;
import static javax.crypto.Cipher.DECRYPT_MODE;
import static net.sf.briar.api.transport.TransportConstants.AAD_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.IV_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import net.sf.briar.api.FormatException;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey;
/**
 * A {@link FrameReader} that reads encrypted frames from an input stream,
 * decrypts and authenticates them with the frame cipher, and validates the
 * frame header and padding before handing the plaintext to the caller.
 * <p>
 * On the initiator's side the stream is expected to start with a tag, which
 * is read and checked before the first frame. On the responder's side no tag
 * is read, and the tag cipher and tag key are null.
 */
class IncomingEncryptionLayer implements FrameReader {

    private final InputStream in;
    private final Cipher tagCipher; // Null on the responder's side
    private final AuthenticatedCipher frameCipher;
    private final ErasableKey tagKey, frameKey; // tagKey is null on the responder's side
    private final byte[] iv, aad, ciphertext;
    private final int frameLength;

    private long frameNumber;
    private boolean readTag, lastFrame;

    /**
     * Constructor for the initiator's side of a connection. A tag will be
     * read from the stream and verified before the first frame is read.
     */
    IncomingEncryptionLayer(InputStream in, Cipher tagCipher,
            AuthenticatedCipher frameCipher, ErasableKey tagKey,
            ErasableKey frameKey, int frameLength) {
        this.in = in;
        this.tagCipher = tagCipher;
        this.frameCipher = frameCipher;
        this.tagKey = tagKey;
        this.frameKey = frameKey;
        this.frameLength = frameLength;
        iv = new byte[IV_LENGTH];
        aad = new byte[AAD_LENGTH];
        ciphertext = new byte[frameLength];
        frameNumber = 0L;
        readTag = true;
        lastFrame = false;
    }

    /**
     * Constructor for the responder's side of a connection. No tag is read
     * from the stream, so there is no tag cipher or tag key.
     */
    IncomingEncryptionLayer(InputStream in, AuthenticatedCipher frameCipher,
            ErasableKey frameKey, int frameLength) {
        this.in = in;
        this.frameCipher = frameCipher;
        this.frameKey = frameKey;
        this.frameLength = frameLength;
        tagCipher = null;
        tagKey = null;
        iv = new byte[IV_LENGTH];
        aad = new byte[AAD_LENGTH];
        ciphertext = new byte[frameLength];
        frameNumber = 0L;
        readTag = false;
        lastFrame = false;
    }

    /**
     * Reads, decrypts and validates the next frame, writing the plaintext
     * (header, payload and any padding) into the given buffer.
     *
     * @param frame the buffer that receives the decrypted frame
     * @return the payload length of the frame, or -1 if the last frame has
     * already been read
     * @throws EOFException if the stream ends before the tag has been fully
     * read, if the frame is too short to hold a header and MAC, or if a frame
     * that is not the last frame is shorter than the frame length
     * @throws FormatException if the tag is not valid, if the payload length
     * in the header exceeds the space available, or if the padding contains
     * a non-zero byte
     * @throws IOException if the underlying stream fails; the keys are erased
     * before the exception is rethrown
     * @throws RuntimeException if the frame cipher throws a
     * {@link GeneralSecurityException} or decrypts an unexpected number of
     * bytes
     */
    public int readFrame(byte[] frame) throws IOException {
        if(lastFrame) return -1;
        // Read the tag if required
        if(readTag) {
            int offset = 0;
            try {
                while(offset < TAG_LENGTH) {
                    int read = in.read(ciphertext, offset, TAG_LENGTH - offset);
                    if(read == -1) throw new EOFException();
                    offset += read;
                }
            } catch(IOException e) {
                eraseKeys();
                throw e;
            }
            if(!TagEncoder.decodeTag(ciphertext, tagCipher, tagKey))
                throw new FormatException();
            readTag = false;
        }
        // Read the frame
        int ciphertextLength = 0;
        try {
            while(ciphertextLength < frameLength) {
                int read = in.read(ciphertext, ciphertextLength,
                        frameLength - ciphertextLength);
                if(read == -1) break; // We'll check the length later
                ciphertextLength += read;
            }
        } catch(IOException e) {
            // The tag key is null on the responder's side, so the keys must
            // be erased null-safely or the I/O error would be masked
            eraseKeys();
            throw e;
        }
        int plaintextLength = ciphertextLength - MAC_LENGTH;
        if(plaintextLength < HEADER_LENGTH) throw new EOFException();
        // Decrypt and authenticate the frame
        FrameEncoder.encodeIv(iv, frameNumber);
        FrameEncoder.encodeAad(aad, frameNumber, plaintextLength);
        try {
            frameCipher.init(DECRYPT_MODE, frameKey, iv, aad);
            int decrypted = frameCipher.doFinal(ciphertext, 0, ciphertextLength,
                    frame, 0);
            if(decrypted != plaintextLength) throw new RuntimeException();
        } catch(GeneralSecurityException badCipher) {
            throw new RuntimeException(badCipher);
        }
        // Decode and validate the header
        lastFrame = FrameEncoder.isLastFrame(frame);
        // Only the last frame may be shorter than the full frame length
        if(!lastFrame && ciphertextLength < frameLength)
            throw new EOFException();
        int payloadLength = FrameEncoder.getPayloadLength(frame);
        if(payloadLength > plaintextLength - HEADER_LENGTH)
            throw new FormatException();
        // If there's any padding it must be all zeroes
        for(int i = HEADER_LENGTH + payloadLength; i < plaintextLength; i++)
            if(frame[i] != 0) throw new FormatException();
        frameNumber++;
        return payloadLength;
    }

    /** Erases the frame key and, if there is one, the tag key. */
    private void eraseKeys() {
        frameKey.erase();
        if(tagKey != null) tagKey.erase();
    }
}