Moved Bytes to the main package. Added a SharedSecret class to parse

encrypted secrets retrieved from the database.
This commit is contained in:
akwizgran
2011-08-11 19:14:20 +01:00
parent df972e294d
commit 7545a1cc8f
8 changed files with 68 additions and 26 deletions

View File

@@ -1,4 +1,4 @@
package net.sf.briar.api.serial; package net.sf.briar.api;
import java.util.Arrays; import java.util.Arrays;

View File

@@ -10,11 +10,11 @@ import javax.crypto.SecretKey;
public interface CryptoComponent { public interface CryptoComponent {
SecretKey deriveMacKey(byte[] secret, boolean alice); SecretKey deriveMacKey(byte[] secret);
SecretKey derivePacketKey(byte[] secret, boolean alice); SecretKey derivePacketKey(byte[] secret);
SecretKey deriveTagKey(byte[] secret, boolean alice); SecretKey deriveTagKey(byte[] secret);
KeyPair generateKeyPair(); KeyPair generateKeyPair();

View File

@@ -68,33 +68,31 @@ class CryptoComponentImpl implements CryptoComponent {
} }
} }
public SecretKey deriveMacKey(byte[] secret, boolean alice) { public SecretKey deriveMacKey(byte[] secret) {
if(alice) return deriveKey("MACA", secret); SharedSecret s = new SharedSecret(secret);
else return deriveKey("MACB", secret); if(s.getAlice()) return deriveKey("MACA", s.getIv(), s.getCiphertext());
else return deriveKey("MACB", s.getIv(), s.getCiphertext());
} }
private SecretKey deriveKey(String name, byte[] secret) { private SecretKey deriveKey(String name, IvParameterSpec iv,
byte[] ciphertext) {
MessageDigest digest = getMessageDigest(); MessageDigest digest = getMessageDigest();
try { try {
digest.update(name.getBytes("UTF-8")); digest.update(name.getBytes("UTF-8"));
} catch(UnsupportedEncodingException e) { } catch(UnsupportedEncodingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
byte[] decrypted = decryptSharedSecret(secret); byte[] decrypted = decryptSharedSecret(iv, ciphertext);
digest.update(decrypted); digest.update(decrypted);
Arrays.fill(decrypted, (byte) 0); Arrays.fill(decrypted, (byte) 0); // Destroy the plaintext secret
return new SecretKeySpec(digest.digest(), SECRET_KEY_ALGO); return new SecretKeySpec(digest.digest(), SECRET_KEY_ALGO);
} }
private byte[] decryptSharedSecret(byte[] secret) { private byte[] decryptSharedSecret(IvParameterSpec iv, byte[] ciphertext) {
// The first 16 bytes of the stored secret are the IV
if(secret.length <= 16) throw new IllegalArgumentException();
IvParameterSpec iv = new IvParameterSpec(secret, 0, 16);
try { try {
// Decrypt and return the remainder of the stored secret
Cipher c = Cipher.getInstance(SECRET_STORAGE_ALGO, PROVIDER); Cipher c = Cipher.getInstance(SECRET_STORAGE_ALGO, PROVIDER);
c.init(Cipher.DECRYPT_MODE, secretStorageKey, iv); c.init(Cipher.DECRYPT_MODE, secretStorageKey, iv);
return c.doFinal(secret, 16, secret.length - 16); return c.doFinal(ciphertext);
} catch(BadPaddingException e) { } catch(BadPaddingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} catch(IllegalBlockSizeException e) { } catch(IllegalBlockSizeException e) {
@@ -112,14 +110,16 @@ class CryptoComponentImpl implements CryptoComponent {
} }
} }
public SecretKey derivePacketKey(byte[] secret, boolean alice) { public SecretKey derivePacketKey(byte[] secret) {
if(alice) return deriveKey("PKTA", secret); SharedSecret s = new SharedSecret(secret);
else return deriveKey("PKTB", secret); if(s.getAlice()) return deriveKey("PKTA", s.getIv(), s.getCiphertext());
else return deriveKey("PKTB", s.getIv(), s.getCiphertext());
} }
public SecretKey deriveTagKey(byte[] secret, boolean alice) { public SecretKey deriveTagKey(byte[] secret) {
if(alice) return deriveKey("TAGA", secret); SharedSecret s = new SharedSecret(secret);
else return deriveKey("TAGB", secret); if(s.getAlice()) return deriveKey("TAGA", s.getIv(), s.getCiphertext());
else return deriveKey("TAGB", s.getIv(), s.getCiphertext());
} }
public KeyPair generateKeyPair() { public KeyPair generateKeyPair() {

View File

@@ -0,0 +1,42 @@
package net.sf.briar.crypto;
import java.util.Arrays;
import javax.crypto.spec.IvParameterSpec;
class SharedSecret {
private static final int IV_BYTES = 16;
private final IvParameterSpec iv;
private final boolean alice;
private final byte[] ciphertext;
SharedSecret(byte[] secret) {
if(secret.length < IV_BYTES + 2) throw new IllegalArgumentException();
iv = new IvParameterSpec(secret, 0, IV_BYTES);
switch(secret[IV_BYTES]) {
case 0:
alice = false;
break;
case 1:
alice = true;
break;
default:
throw new IllegalArgumentException();
}
ciphertext = Arrays.copyOfRange(secret, IV_BYTES + 1, secret.length);
}
IvParameterSpec getIv() {
return iv;
}
boolean getAlice() {
return alice;
}
byte[] getCiphertext() {
return ciphertext;
}
}

View File

@@ -8,7 +8,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import net.sf.briar.api.serial.Bytes; import net.sf.briar.api.Bytes;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.FormatException; import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;

View File

@@ -7,7 +7,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import net.sf.briar.api.serial.Bytes; import net.sf.briar.api.Bytes;
import net.sf.briar.api.serial.Tag; import net.sf.briar.api.serial.Tag;
import net.sf.briar.api.serial.Writable; import net.sf.briar.api.serial.Writable;
import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.Writer;

View File

@@ -11,7 +11,7 @@ import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.SecretKeySpec;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.serial.Bytes; import net.sf.briar.api.Bytes;
import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.Test; import org.junit.Test;

View File

@@ -9,10 +9,10 @@ import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import junit.framework.TestCase; import junit.framework.TestCase;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Consumer;
import net.sf.briar.api.serial.FormatException; import net.sf.briar.api.serial.FormatException;
import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ObjectReader;
import net.sf.briar.api.serial.Bytes;
import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Reader;
import net.sf.briar.util.StringUtils; import net.sf.briar.util.StringUtils;