diff --git a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java index 92bb9fc38..0836b058e 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyAgreementTest.java @@ -11,12 +11,14 @@ import org.junit.Test; import org.whispersystems.curve25519.Curve25519; import java.security.GeneralSecurityException; +import java.util.Arrays; import java.util.Random; -import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.SHARED_SECRET_LABEL; import static org.briarproject.bramble.test.TestUtils.getRandomBytes; import static org.briarproject.bramble.util.StringUtils.fromHexString; +import static org.briarproject.bramble.util.StringUtils.getRandomString; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertFalse; public class KeyAgreementTest extends BrambleTestCase { @@ -36,6 +38,7 @@ public class KeyAgreementTest extends BrambleTestCase { private final CryptoComponent crypto = new CryptoComponentImpl(new TestSecureRandomProvider(), null); + private final String label = getRandomString(123); private final byte[][] inputs; public KeyAgreementTest() { @@ -49,9 +52,9 @@ public class KeyAgreementTest extends BrambleTestCase { public void testDerivesSharedSecret() throws Exception { KeyPair aPair = crypto.generateAgreementKeyPair(); KeyPair bPair = crypto.generateAgreementKeyPair(); - SecretKey aShared = crypto.deriveSharedSecret(SHARED_SECRET_LABEL, + SecretKey aShared = crypto.deriveSharedSecret(label, bPair.getPublic(), aPair, inputs); - SecretKey bShared = crypto.deriveSharedSecret(SHARED_SECRET_LABEL, + SecretKey bShared = crypto.deriveSharedSecret(label, aPair.getPublic(), bPair, inputs); assertArrayEquals(aShared.getBytes(), bShared.getBytes()); } @@ -60,17 +63,14 @@ public class KeyAgreementTest extends BrambleTestCase { public void testRejectsInvalidPublicKey() throws Exception { KeyPair keyPair = crypto.generateAgreementKeyPair(); PublicKey invalid = new AgreementPublicKey(new byte[32]); - crypto.deriveSharedSecret(SHARED_SECRET_LABEL, invalid, keyPair, - inputs); + crypto.deriveSharedSecret(label, invalid, keyPair, inputs); } @Test public void testRfc7748TestVector() { - // Private keys need to be clamped because curve25519-java does the - // clamping at key generation time, not multiplication time - byte[] aPriv = AgreementKeyParser.clamp(fromHexString(ALICE_PRIVATE)); + byte[] aPriv = parsePrivateKey(ALICE_PRIVATE); byte[] aPub = fromHexString(ALICE_PUBLIC); - byte[] bPriv = AgreementKeyParser.clamp(fromHexString(BOB_PRIVATE)); + byte[] bPriv = parsePrivateKey(BOB_PRIVATE); byte[] bPub = fromHexString(BOB_PUBLIC); byte[] sharedSecret = fromHexString(SHARED_SECRET); Curve25519 curve25519 = Curve25519.getInstance("java"); @@ -79,4 +79,82 @@ public class KeyAgreementTest extends BrambleTestCase { assertArrayEquals(sharedSecret, curve25519.calculateAgreement(bPub, aPriv)); } + + @Test + public void testDerivesSameSharedSecretFromEquivalentPublicKey() { + byte[] aPub = fromHexString(ALICE_PUBLIC); + byte[] bPriv = parsePrivateKey(BOB_PRIVATE); + byte[] sharedSecret = fromHexString(SHARED_SECRET); + Curve25519 curve25519 = Curve25519.getInstance("java"); + + // Flip the unused most significant bit of the little-endian public key + byte[] aPubEquiv = aPub.clone(); + aPubEquiv[31] ^= (byte) 128; + + // The public keys should be different but give the same shared secret + assertFalse(Arrays.equals(aPub, aPubEquiv)); + assertArrayEquals(sharedSecret, + curve25519.calculateAgreement(aPub, bPriv)); + assertArrayEquals(sharedSecret, + curve25519.calculateAgreement(aPubEquiv, bPriv)); + } + + @Test + public void testDerivesSameSharedSecretFromEquivalentPublicKeyWithoutPublicKeysHashedIn() + throws Exception { + KeyPair aPair = crypto.generateAgreementKeyPair(); + KeyPair bPair = crypto.generateAgreementKeyPair(); + + // Flip the unused most significant bit of the little-endian public key + byte[] aPub = aPair.getPublic().getEncoded(); + byte[] aPubEquiv = aPub.clone(); + aPubEquiv[31] ^= (byte) 128; + KeyPair aPairEquiv = new KeyPair(new AgreementPublicKey(aPubEquiv), + aPair.getPrivate()); + + // The public keys should be different but give the same shared secret + assertFalse(Arrays.equals(aPub, aPubEquiv)); + SecretKey shared = crypto.deriveSharedSecret(label, + aPair.getPublic(), bPair); + SecretKey sharedEquiv = crypto.deriveSharedSecret(label, + aPairEquiv.getPublic(), bPair); + assertArrayEquals(shared.getBytes(), sharedEquiv.getBytes()); + } + + @Test + public void testDerivesDifferentSharedSecretFromEquivalentPublicKeyWithPublicKeysHashedIn() + throws Exception { + KeyPair aPair = crypto.generateAgreementKeyPair(); + KeyPair bPair = crypto.generateAgreementKeyPair(); + + // Flip the unused most significant bit of the little-endian public key + byte[] aPub = aPair.getPublic().getEncoded(); + byte[] aPubEquiv = aPub.clone(); + aPubEquiv[31] ^= (byte) 128; + KeyPair aPairEquiv = new KeyPair(new AgreementPublicKey(aPubEquiv), + aPair.getPrivate()); + + // The public keys should be different and give different shared secrets + assertFalse(Arrays.equals(aPub, aPubEquiv)); + SecretKey shared = deriveSharedSecretWithPublicKeysHashedIn(label, + aPair.getPublic(), bPair); + SecretKey sharedEquiv = deriveSharedSecretWithPublicKeysHashedIn(label, + aPairEquiv.getPublic(), bPair); + assertFalse(Arrays.equals(shared.getBytes(), sharedEquiv.getBytes())); + } + + private SecretKey deriveSharedSecretWithPublicKeysHashedIn(String label, + PublicKey publicKey, KeyPair keyPair) throws Exception { + byte[][] inputs = new byte[][] { + publicKey.getEncoded(), + keyPair.getPublic().getEncoded() + }; + return crypto.deriveSharedSecret(label, publicKey, keyPair, inputs); + } + + private byte[] parsePrivateKey(String hex) { + // Private keys need to be clamped because curve25519-java does the + // clamping at key generation time, not multiplication time + return AgreementKeyParser.clamp(fromHexString(hex)); + } }