Add HKDF support for new sessions.

This commit is contained in:
Moxie Marlinspike 2013-11-01 13:52:41 -07:00
parent dbc070cd65
commit a03fff8b24
10 changed files with 386 additions and 171 deletions

@ -20,6 +20,7 @@ import android.content.Context;
import android.util.Log; import android.util.Log;
import org.spongycastle.crypto.params.ECPublicKeyParameters; import org.spongycastle.crypto.params.ECPublicKeyParameters;
import org.whispersystems.textsecure.crypto.kdf.DerivedSecrets;
import org.whispersystems.textsecure.crypto.protocol.CiphertextMessage; import org.whispersystems.textsecure.crypto.protocol.CiphertextMessage;
import org.whispersystems.textsecure.storage.CanonicalRecipientAddress; import org.whispersystems.textsecure.storage.CanonicalRecipientAddress;
import org.whispersystems.textsecure.storage.InvalidKeyIdException; import org.whispersystems.textsecure.storage.InvalidKeyIdException;
@ -38,11 +39,8 @@ import javax.crypto.spec.SecretKeySpec;
import java.math.BigInteger; import java.math.BigInteger;
import java.security.InvalidAlgorithmParameterException; import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException; import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.Arrays; import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
/** /**
* This is where the session encryption magic happens. Implements a compressed version of the OTR protocol. * This is where the session encryption magic happens. Implements a compressed version of the OTR protocol.
@ -148,23 +146,16 @@ public class SessionCipher {
} }
} }
private SecretKeySpec deriveMacSecret(SecretKeySpec key) { private byte[] getPlaintext(byte[] cipherText, SecretKeySpec key, int counter)
try { throws IllegalBlockSizeException, BadPaddingException
MessageDigest md = MessageDigest.getInstance("SHA-1"); {
byte[] secret = md.digest(key.getEncoded());
return new SecretKeySpec(secret, "HmacSHA1");
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException("SHA-1 Not Supported!",e);
}
}
private byte[] getPlaintext(byte[] cipherText, SecretKeySpec key, int counter) throws IllegalBlockSizeException, BadPaddingException {
Cipher cipher = getCipher(Cipher.DECRYPT_MODE, key, counter); Cipher cipher = getCipher(Cipher.DECRYPT_MODE, key, counter);
return cipher.doFinal(cipherText); return cipher.doFinal(cipherText);
} }
private byte[] getCiphertext(byte[] message, SecretKeySpec key, int counter) throws IllegalBlockSizeException, BadPaddingException { private byte[] getCiphertext(byte[] message, SecretKeySpec key, int counter)
throws IllegalBlockSizeException, BadPaddingException
{
Cipher cipher = getCipher(Cipher.ENCRYPT_MODE, key, counter); Cipher cipher = getCipher(Cipher.ENCRYPT_MODE, key, counter);
return cipher.doFinal(message); return cipher.doFinal(message);
} }
@ -193,46 +184,50 @@ public class SessionCipher {
} }
} }
private SecretKeySpec deriveCipherSecret(int mode, List<BigInteger> sharedSecret, private SessionKey getSessionKey(MasterSecret masterSecret, int mode,
KeyRecords records, int localKeyId, int messageVersion,
int remoteKeyId) IdentityKeyPair localIdentityKey,
KeyRecords records,
int localKeyId, int remoteKeyId)
throws InvalidKeyIdException throws InvalidKeyIdException
{ {
byte[] sharedSecretBytes = concatenateSharedSecrets(sharedSecret); Log.w("SessionCipher", "Getting session key for local: " + localKeyId + " remote: " + remoteKeyId);
byte[] derivedBytes = deriveBytes(sharedSecretBytes, 16 * 2); SessionKey sessionKey = records.getSessionRecord().getSessionKey(mode, localKeyId, remoteKeyId);
byte[] cipherSecret = new byte[16];
boolean isLowEnd = isLowEnd(records, localKeyId, remoteKeyId); if (sessionKey != null)
isLowEnd = (mode == Cipher.ENCRYPT_MODE ? isLowEnd : !isLowEnd); return sessionKey;
if (isLowEnd) { DerivedSecrets derivedSecrets = calculateSharedSecret(messageVersion, mode, localIdentityKey,
System.arraycopy(derivedBytes, 16, cipherSecret, 0, 16); records, localKeyId, remoteKeyId);
} else {
System.arraycopy(derivedBytes, 0, cipherSecret, 0, 16);
}
return new SecretKeySpec(cipherSecret, "AES"); return new SessionKey(mode, localKeyId, remoteKeyId, derivedSecrets.getCipherKey(),
derivedSecrets.getMacKey(), masterSecret);
} }
private byte[] concatenateSharedSecrets(List<BigInteger> sharedSecrets) { private DerivedSecrets calculateSharedSecret(int messageVersion, int mode,
int totalByteSize = 0; IdentityKeyPair localIdentityKey,
List<byte[]> byteValues = new LinkedList<byte[]>(); KeyRecords records,
int localKeyId, int remoteKeyId)
throws InvalidKeyIdException
{
KeyPair localKeyPair = records.getLocalKeyRecord().getKeyPairForId(localKeyId);
ECPublicKeyParameters remoteKey = records.getRemoteKeyRecord().getKeyForId(remoteKeyId).getKey();
IdentityKey remoteIdentityKey = records.getSessionRecord().getIdentityKey();
boolean isLowEnd = isLowEnd(records, localKeyId, remoteKeyId);
for (BigInteger sharedSecret : sharedSecrets) { isLowEnd = (mode == Cipher.ENCRYPT_MODE ? isLowEnd : !isLowEnd);
byte[] byteValue = sharedSecret.toByteArray();
totalByteSize += byteValue.length; if (isInitiallyExchangedKeys(records, localKeyId, remoteKeyId) &&
byteValues.add(byteValue); messageVersion >= CiphertextMessage.DHE3_INTRODUCED_VERSION)
{
return SharedSecretCalculator.calculateSharedSecret(isLowEnd,
localKeyPair, localKeyId, localIdentityKey,
remoteKey, remoteKeyId, remoteIdentityKey);
} else {
return SharedSecretCalculator.calculateSharedSecret(messageVersion, isLowEnd,
localKeyPair, localKeyId,
remoteKey, remoteKeyId);
} }
byte[] combined = new byte[totalByteSize];
int offset = 0;
for (byte[] byteValue : byteValues) {
System.arraycopy(byteValue, 0, combined, offset, byteValue.length);
offset += byteValue.length;
}
return combined;
} }
private boolean isLowEnd(KeyRecords records, int localKeyId, int remoteKeyId) private boolean isLowEnd(KeyRecords records, int localKeyId, int remoteKeyId)
@ -247,67 +242,6 @@ public class SessionCipher {
return local.compareTo(remote) < 0; return local.compareTo(remote) < 0;
} }
private byte[] deriveBytes(byte[] seed, int bytesNeeded) {
MessageDigest md;
try {
md = MessageDigest.getInstance("SHA-256");
} catch (NoSuchAlgorithmException e) {
Log.w("SessionCipher",e);
throw new IllegalArgumentException("SHA-256 Not Supported!");
}
int rounds = bytesNeeded / md.getDigestLength();
for (int i=1;i<=rounds;i++) {
byte[] roundBytes = Conversions.intToByteArray(i);
md.update(roundBytes);
md.update(seed);
}
return md.digest();
}
private SessionKey getSessionKey(MasterSecret masterSecret, int mode,
int messageVersion,
IdentityKeyPair localIdentityKey,
KeyRecords records,
int localKeyId, int remoteKeyId)
throws InvalidKeyIdException
{
Log.w("SessionCipher", "Getting session key for local: " + localKeyId + " remote: " + remoteKeyId);
SessionKey sessionKey = records.getSessionRecord().getSessionKey(localKeyId, remoteKeyId);
if (sessionKey != null)
return sessionKey;
List<BigInteger> sharedSecret = calculateSharedSecret(messageVersion, localIdentityKey, records, localKeyId, remoteKeyId);
SecretKeySpec cipherKey = deriveCipherSecret(mode, sharedSecret, records, localKeyId, remoteKeyId);
SecretKeySpec macKey = deriveMacSecret(cipherKey);
return new SessionKey(localKeyId, remoteKeyId, cipherKey, macKey, masterSecret);
}
private List<BigInteger> calculateSharedSecret(int messageVersion,
IdentityKeyPair localIdentityKey,
KeyRecords records,
int localKeyId, int remoteKeyId)
throws InvalidKeyIdException
{
KeyPair localKeyPair = records.getLocalKeyRecord().getKeyPairForId(localKeyId);
ECPublicKeyParameters remoteKey = records.getRemoteKeyRecord().getKeyForId(remoteKeyId).getKey();
IdentityKey remoteIdentityKey = records.getSessionRecord().getIdentityKey();
if (isInitiallyExchangedKeys(records, localKeyId, remoteKeyId) &&
messageVersion >= CiphertextMessage.CRADLE_AGREEMENT_VERSION)
{
return SharedSecretCalculator.calculateSharedSecret(localKeyPair, localIdentityKey,
remoteKey, remoteIdentityKey);
} else {
return SharedSecretCalculator.calculateSharedSecret(localKeyPair, remoteKey);
}
}
private boolean isInitiallyExchangedKeys(KeyRecords records, int localKeyId, int remoteKeyId) private boolean isInitiallyExchangedKeys(KeyRecords records, int localKeyId, int remoteKeyId)
throws InvalidKeyIdException throws InvalidKeyIdException
{ {

@ -5,6 +5,12 @@ import android.util.Log;
import org.spongycastle.crypto.CipherParameters; import org.spongycastle.crypto.CipherParameters;
import org.spongycastle.crypto.agreement.ECDHBasicAgreement; import org.spongycastle.crypto.agreement.ECDHBasicAgreement;
import org.spongycastle.crypto.params.ECPublicKeyParameters; import org.spongycastle.crypto.params.ECPublicKeyParameters;
import org.whispersystems.textsecure.crypto.kdf.DerivedSecrets;
import org.whispersystems.textsecure.crypto.kdf.HKDF;
import org.whispersystems.textsecure.crypto.kdf.KDF;
import org.whispersystems.textsecure.crypto.kdf.NKDF;
import org.whispersystems.textsecure.crypto.protocol.CiphertextMessage;
import org.whispersystems.textsecure.util.Conversions;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.LinkedList; import java.util.LinkedList;
@ -12,15 +18,18 @@ import java.util.List;
public class SharedSecretCalculator { public class SharedSecretCalculator {
public static List<BigInteger> calculateSharedSecret(KeyPair localKeyPair, public static DerivedSecrets calculateSharedSecret(boolean isLowEnd, KeyPair localKeyPair,
IdentityKeyPair localIdentityKeyPair, int localKeyId,
ECPublicKeyParameters remoteKey, IdentityKeyPair localIdentityKeyPair,
IdentityKey remoteIdentityKey) ECPublicKeyParameters remoteKey,
int remoteKeyId,
IdentityKey remoteIdentityKey)
{ {
Log.w("SharedSecretCalculator", "Calculating shared secret with cradle agreement..."); Log.w("SharedSecretCalculator", "Calculating shared secret with cradle agreement...");
KDF kdf = new HKDF();
List<BigInteger> results = new LinkedList<BigInteger>(); List<BigInteger> results = new LinkedList<BigInteger>();
if (isLowEnd(localKeyPair.getPublicKey().getKey(), remoteKey)) { if (isSmaller(localKeyPair.getPublicKey().getKey(), remoteKey)) {
results.add(calculateAgreement(localIdentityKeyPair.getPrivateKey(), remoteKey)); results.add(calculateAgreement(localIdentityKeyPair.getPrivateKey(), remoteKey));
results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(),
@ -33,17 +42,40 @@ public class SharedSecretCalculator {
} }
results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), remoteKey)); results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), remoteKey));
return results;
return kdf.deriveSecrets(results, isLowEnd, getInfo(localKeyId,remoteKeyId));
} }
public static List<BigInteger> calculateSharedSecret(KeyPair localKeyPair, public static DerivedSecrets calculateSharedSecret(int messageVersion, boolean isLowEnd,
ECPublicKeyParameters remoteKey) KeyPair localKeyPair, int localKeyId,
ECPublicKeyParameters remoteKey, int remoteKeyId)
{ {
Log.w("SharedSecretCalculator", "Calculating shared secret with standard agreement..."); Log.w("SharedSecretCalculator", "Calculating shared secret with standard agreement...");
KDF kdf;
if (messageVersion >= CiphertextMessage.DHE3_INTRODUCED_VERSION) kdf = new HKDF();
else kdf = new NKDF();
Log.w("SharedSecretCalculator", "Using kdf: " + kdf);
List<BigInteger> results = new LinkedList<BigInteger>(); List<BigInteger> results = new LinkedList<BigInteger>();
results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), remoteKey)); results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), remoteKey));
return results; return kdf.deriveSecrets(results, isLowEnd, getInfo(localKeyId, remoteKeyId));
}
private static byte[] getInfo(int localKeyId, int remoteKeyId) {
byte[] info = new byte[3 * 2];
if (localKeyId < remoteKeyId) {
Conversions.mediumToByteArray(info, 0, localKeyId);
Conversions.mediumToByteArray(info, 3, remoteKeyId);
} else {
Conversions.mediumToByteArray(info, 0, remoteKeyId);
Conversions.mediumToByteArray(info, 3, localKeyId);
}
return info;
} }
private static BigInteger calculateAgreement(CipherParameters privateKey, private static BigInteger calculateAgreement(CipherParameters privateKey,
@ -56,8 +88,8 @@ public class SharedSecretCalculator {
} }
private static boolean isLowEnd(ECPublicKeyParameters localPublic, private static boolean isSmaller(ECPublicKeyParameters localPublic,
ECPublicKeyParameters remotePublic) ECPublicKeyParameters remotePublic)
{ {
BigInteger local = localPublic.getQ().getX().toBigInteger(); BigInteger local = localPublic.getQ().getX().toBigInteger();
BigInteger remote = remotePublic.getQ().getX().toBigInteger(); BigInteger remote = remotePublic.getQ().getX().toBigInteger();

@ -0,0 +1,22 @@
package org.whispersystems.textsecure.crypto.kdf;
import javax.crypto.spec.SecretKeySpec;
public class DerivedSecrets {
private final SecretKeySpec cipherKey;
private final SecretKeySpec macKey;
public DerivedSecrets(SecretKeySpec cipherKey, SecretKeySpec macKey) {
this.cipherKey = cipherKey;
this.macKey = macKey;
}
public SecretKeySpec getCipherKey() {
return cipherKey;
}
public SecretKeySpec getMacKey() {
return macKey;
}
}

@ -0,0 +1,101 @@
package org.whispersystems.textsecure.crypto.kdf;
import java.io.ByteArrayOutputStream;
import java.math.BigInteger;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
public class HKDF extends KDF {
private static final int HASH_OUTPUT_SIZE = 32;
private static final int KEY_MATERIAL_SIZE = 72;
private static final int CIPHER_KEYS_OFFSET = 0;
private static final int MAC_KEYS_OFFSET = 32;
@Override
public DerivedSecrets deriveSecrets(List<BigInteger> sharedSecret,
boolean isLowEnd, byte[] info)
{
byte[] inputKeyMaterial = concatenateSharedSecrets(sharedSecret);
byte[] salt = new byte[HASH_OUTPUT_SIZE];
byte[] prk = extract(salt, inputKeyMaterial);
byte[] okm = expand(prk, info, KEY_MATERIAL_SIZE);
SecretKeySpec cipherKey = deriveCipherKey(okm, isLowEnd);
SecretKeySpec macKey = deriveMacKey(okm, isLowEnd);
return new DerivedSecrets(cipherKey, macKey);
}
private SecretKeySpec deriveCipherKey(byte[] okm, boolean isLowEnd) {
byte[] cipherKey = new byte[16];
if (isLowEnd) {
System.arraycopy(okm, CIPHER_KEYS_OFFSET + 0, cipherKey, 0, cipherKey.length);
} else {
System.arraycopy(okm, CIPHER_KEYS_OFFSET + 16, cipherKey, 0, cipherKey.length);
}
return new SecretKeySpec(cipherKey, "AES");
}
private SecretKeySpec deriveMacKey(byte[] okm, boolean isLowEnd) {
byte[] macKey = new byte[20];
if (isLowEnd) {
System.arraycopy(okm, MAC_KEYS_OFFSET + 0, macKey, 0, macKey.length);
} else {
System.arraycopy(okm, MAC_KEYS_OFFSET + 20, macKey, 0, macKey.length);
}
return new SecretKeySpec(macKey, "HmacSHA1");
}
private byte[] extract(byte[] salt, byte[] inputKeyMaterial) {
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(salt, "HmacSHA256"));
return mac.doFinal(inputKeyMaterial);
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
} catch (InvalidKeyException e) {
throw new AssertionError(e);
}
}
private byte[] expand(byte[] prk, byte[] info, int outputSize) {
try {
int iterations = (int)Math.ceil((double)outputSize/(double)HASH_OUTPUT_SIZE);
byte[] mixin = new byte[0];
ByteArrayOutputStream results = new ByteArrayOutputStream();
for (int i=0;i<iterations;i++) {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(prk, "HmacSHA256"));
mac.update(mixin);
mac.update(info);
mac.update((byte)i);
byte[] stepResult = mac.doFinal();
results.write(stepResult, 0, stepResult.length);
mixin = stepResult;
}
return results.toByteArray();
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
} catch (InvalidKeyException e) {
throw new AssertionError(e);
}
}
}

@ -0,0 +1,33 @@
package org.whispersystems.textsecure.crypto.kdf;
import java.math.BigInteger;
import java.util.LinkedList;
import java.util.List;
public abstract class KDF {
public abstract DerivedSecrets deriveSecrets(List<BigInteger> sharedSecret,
boolean isLowEnd, byte[] info);
protected byte[] concatenateSharedSecrets(List<BigInteger> sharedSecrets) {
int totalByteSize = 0;
List<byte[]> byteValues = new LinkedList<byte[]>();
for (BigInteger sharedSecret : sharedSecrets) {
byte[] byteValue = sharedSecret.toByteArray();
totalByteSize += byteValue.length;
byteValues.add(byteValue);
}
byte[] combined = new byte[totalByteSize];
int offset = 0;
for (byte[] byteValue : byteValues) {
System.arraycopy(byteValue, 0, combined, offset, byteValue.length);
offset += byteValue.length;
}
return combined;
}
}

@ -0,0 +1,73 @@
package org.whispersystems.textsecure.crypto.kdf;
import android.util.Log;
import org.whispersystems.textsecure.util.Conversions;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import javax.crypto.spec.SecretKeySpec;
public class NKDF extends KDF {
@Override
public DerivedSecrets deriveSecrets(List<BigInteger> sharedSecret,
boolean isLowEnd, byte[] info)
{
SecretKeySpec cipherKey = deriveCipherSecret(isLowEnd, sharedSecret);
SecretKeySpec macKey = deriveMacSecret(cipherKey);
return new DerivedSecrets(cipherKey, macKey);
}
private SecretKeySpec deriveCipherSecret(boolean isLowEnd, List<BigInteger> sharedSecret) {
byte[] sharedSecretBytes = concatenateSharedSecrets(sharedSecret);
byte[] derivedBytes = deriveBytes(sharedSecretBytes, 16 * 2);
byte[] cipherSecret = new byte[16];
if (isLowEnd) {
System.arraycopy(derivedBytes, 16, cipherSecret, 0, 16);
} else {
System.arraycopy(derivedBytes, 0, cipherSecret, 0, 16);
}
return new SecretKeySpec(cipherSecret, "AES");
}
private SecretKeySpec deriveMacSecret(SecretKeySpec key) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-1");
byte[] secret = md.digest(key.getEncoded());
return new SecretKeySpec(secret, "HmacSHA1");
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException("SHA-1 Not Supported!",e);
}
}
private byte[] deriveBytes(byte[] seed, int bytesNeeded) {
MessageDigest md;
try {
md = MessageDigest.getInstance("SHA-256");
} catch (NoSuchAlgorithmException e) {
Log.w("NKDF", e);
throw new IllegalArgumentException("SHA-256 Not Supported!");
}
int rounds = bytesNeeded / md.getDigestLength();
for (int i=1;i<=rounds;i++) {
byte[] roundBytes = Conversions.intToByteArray(i);
md.update(roundBytes);
md.update(seed);
}
return md.digest();
}
}

@ -9,8 +9,8 @@ import org.whispersystems.textsecure.util.Conversions;
public class CiphertextMessage { public class CiphertextMessage {
public static final int SUPPORTED_VERSION = 2; public static final int SUPPORTED_VERSION = 2;
public static final int CRADLE_AGREEMENT_VERSION = 2; public static final int DHE3_INTRODUCED_VERSION = 2;
static final int VERSION_LENGTH = 1; static final int VERSION_LENGTH = 1;
private static final int SENDER_KEY_ID_LENGTH = 3; private static final int SENDER_KEY_ID_LENGTH = 3;

@ -16,6 +16,7 @@
*/ */
package org.whispersystems.textsecure.storage; package org.whispersystems.textsecure.storage;
import org.whispersystems.textsecure.crypto.InvalidMessageException;
import org.whispersystems.textsecure.crypto.MasterCipher; import org.whispersystems.textsecure.crypto.MasterCipher;
import org.whispersystems.textsecure.crypto.MasterSecret; import org.whispersystems.textsecure.crypto.MasterSecret;
import org.whispersystems.textsecure.crypto.SessionCipher; import org.whispersystems.textsecure.crypto.SessionCipher;
@ -34,13 +35,18 @@ import javax.crypto.spec.SecretKeySpec;
public class SessionKey { public class SessionKey {
private int localKeyId; private int mode;
private int remoteKeyId; private int localKeyId;
private int remoteKeyId;
private SecretKeySpec cipherKey; private SecretKeySpec cipherKey;
private SecretKeySpec macKey; private SecretKeySpec macKey;
private MasterCipher masterCipher; private MasterCipher masterCipher;
public SessionKey(int localKeyId, int remoteKeyId, SecretKeySpec cipherKey, SecretKeySpec macKey, MasterSecret masterSecret) { public SessionKey(int mode, int localKeyId, int remoteKeyId,
SecretKeySpec cipherKey, SecretKeySpec macKey,
MasterSecret masterSecret)
{
this.mode = mode;
this.localKeyId = localKeyId; this.localKeyId = localKeyId;
this.remoteKeyId = remoteKeyId; this.remoteKeyId = remoteKeyId;
this.cipherKey = cipherKey; this.cipherKey = cipherKey;
@ -48,7 +54,7 @@ public class SessionKey {
this.masterCipher = new MasterCipher(masterSecret); this.masterCipher = new MasterCipher(masterSecret);
} }
public SessionKey(byte[] bytes, MasterSecret masterSecret) { public SessionKey(byte[] bytes, MasterSecret masterSecret) throws InvalidMessageException {
this.masterCipher = new MasterCipher(masterSecret); this.masterCipher = new MasterCipher(masterSecret);
deserialize(bytes); deserialize(bytes);
} }
@ -58,13 +64,16 @@ public class SessionKey {
byte[] remoteKeyIdBytes = Conversions.mediumToByteArray(remoteKeyId); byte[] remoteKeyIdBytes = Conversions.mediumToByteArray(remoteKeyId);
byte[] cipherKeyBytes = cipherKey.getEncoded(); byte[] cipherKeyBytes = cipherKey.getEncoded();
byte[] macKeyBytes = macKey.getEncoded(); byte[] macKeyBytes = macKey.getEncoded();
byte[] combined = Util.combine(localKeyIdBytes, remoteKeyIdBytes, cipherKeyBytes, macKeyBytes); byte[] modeBytes = {(byte)mode};
byte[] combined = Util.combine(localKeyIdBytes, remoteKeyIdBytes,
cipherKeyBytes, macKeyBytes, modeBytes);
return masterCipher.encryptBytes(combined); return masterCipher.encryptBytes(combined);
} }
private void deserialize(byte[] bytes) { private void deserialize(byte[] bytes) throws InvalidMessageException {
byte[] decrypted = masterCipher.encryptBytes(bytes); byte[] decrypted = masterCipher.decryptBytes(bytes);
this.localKeyId = Conversions.byteArrayToMedium(decrypted, 0); this.localKeyId = Conversions.byteArrayToMedium(decrypted, 0);
this.remoteKeyId = Conversions.byteArrayToMedium(decrypted, 3); this.remoteKeyId = Conversions.byteArrayToMedium(decrypted, 3);
@ -74,6 +83,12 @@ public class SessionKey {
byte[] macBytes = new byte[SessionCipher.MAC_KEY_LENGTH]; byte[] macBytes = new byte[SessionCipher.MAC_KEY_LENGTH];
System.arraycopy(decrypted, 6 + keyBytes.length, macBytes, 0, macBytes.length); System.arraycopy(decrypted, 6 + keyBytes.length, macBytes, 0, macBytes.length);
if (decrypted.length < 6 + SessionCipher.CIPHER_KEY_LENGTH + SessionCipher.MAC_KEY_LENGTH + 1) {
throw new InvalidMessageException("No mode included");
}
this.mode = decrypted[6 + keyBytes.length + macBytes.length];
this.cipherKey = new SecretKeySpec(keyBytes, "AES"); this.cipherKey = new SecretKeySpec(keyBytes, "AES");
this.macKey = new SecretKeySpec(macBytes, "HmacSHA1"); this.macKey = new SecretKeySpec(macBytes, "HmacSHA1");
} }
@ -94,4 +109,7 @@ public class SessionKey {
return this.macKey; return this.macKey;
} }
public int getMode() {
return mode;
}
} }

@ -21,6 +21,7 @@ import android.util.Log;
import org.whispersystems.textsecure.crypto.IdentityKey; import org.whispersystems.textsecure.crypto.IdentityKey;
import org.whispersystems.textsecure.crypto.InvalidKeyException; import org.whispersystems.textsecure.crypto.InvalidKeyException;
import org.whispersystems.textsecure.crypto.InvalidMessageException;
import org.whispersystems.textsecure.crypto.MasterSecret; import org.whispersystems.textsecure.crypto.MasterSecret;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -41,16 +42,16 @@ public class SessionRecord extends Record {
private static final int[] VALID_VERSION_MARKERS = {CURRENT_VERSION_MARKER, 0X55555556, 0X55555555}; private static final int[] VALID_VERSION_MARKERS = {CURRENT_VERSION_MARKER, 0X55555556, 0X55555555};
private static final Object FILE_LOCK = new Object(); private static final Object FILE_LOCK = new Object();
private int counter; private int counter;
private byte[] localFingerprint; private byte[] localFingerprint;
private byte[] remoteFingerprint; private byte[] remoteFingerprint;
private int negotiatedSessionVersion; private int negotiatedSessionVersion;
private int currentSessionVersion; private int currentSessionVersion;
private IdentityKey identityKey; private IdentityKey identityKey;
private SessionKey sessionKeyRecord; private SessionKey sessionKeyRecord;
private boolean verifiedSessionKey; private boolean verifiedSessionKey;
private boolean prekeyBundleRequired; private boolean prekeyBundleRequired;
private final MasterSecret masterSecret; private final MasterSecret masterSecret;
@ -208,8 +209,14 @@ public class SessionRecord extends Record {
this.remoteFingerprint = readBlob(in); this.remoteFingerprint = readBlob(in);
this.currentSessionVersion = 31337; this.currentSessionVersion = 31337;
if (in.available() != 0) if (in.available() != 0) {
this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret); try {
this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret);
} catch (InvalidMessageException e) {
Log.w("SessionRecord", e);
this.sessionKeyRecord = null;
}
}
in.close(); in.close();
} else { } else {
@ -230,8 +237,14 @@ public class SessionRecord extends Record {
this.negotiatedSessionVersion = currentSessionVersion; this.negotiatedSessionVersion = currentSessionVersion;
} }
if (in.available() != 0) if (in.available() != 0) {
this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret); try {
this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret);
} catch (InvalidMessageException e) {
Log.w("SessionRecord", e);
this.sessionKeyRecord = null;
}
}
in.close(); in.close();
} }
@ -245,12 +258,15 @@ public class SessionRecord extends Record {
} }
} }
public SessionKey getSessionKey(int localKeyId, int remoteKeyId) { public SessionKey getSessionKey(int mode, int localKeyId, int remoteKeyId) {
if (this.sessionKeyRecord == null) return null; if (this.sessionKeyRecord == null) return null;
if ((this.sessionKeyRecord.getLocalKeyId() == localKeyId) && if ((this.sessionKeyRecord.getLocalKeyId() == localKeyId) &&
(this.sessionKeyRecord.getRemoteKeyId() == remoteKeyId)) (this.sessionKeyRecord.getRemoteKeyId() == remoteKeyId) &&
(this.sessionKeyRecord.getMode() == mode))
{
return this.sessionKeyRecord; return this.sessionKeyRecord;
}
return null; return null;
} }

@ -18,32 +18,18 @@ import java.util.List;
public class Util { public class Util {
public static byte[] combine(byte[] one, byte[] two) { public static byte[] combine(byte[]... elements) {
byte[] combined = new byte[one.length + two.length]; try {
System.arraycopy(one, 0, combined, 0, one.length); ByteArrayOutputStream baos = new ByteArrayOutputStream();
System.arraycopy(two, 0, combined, one.length, two.length);
return combined; for (byte[] element : elements) {
} baos.write(element);
}
public static byte[] combine(byte[] one, byte[] two, byte[] three) {
byte[] combined = new byte[one.length + two.length + three.length];
System.arraycopy(one, 0, combined, 0, one.length);
System.arraycopy(two, 0, combined, one.length, two.length);
System.arraycopy(three, 0, combined, one.length + two.length, three.length);
return combined;
}
public static byte[] combine(byte[] one, byte[] two, byte[] three, byte[] four) {
byte[] combined = new byte[one.length + two.length + three.length + four.length];
System.arraycopy(one, 0, combined, 0, one.length);
System.arraycopy(two, 0, combined, one.length, two.length);
System.arraycopy(three, 0, combined, one.length + two.length, three.length);
System.arraycopy(four, 0, combined, one.length + two.length + three.length, four.length);
return combined;
return baos.toByteArray();
} catch (IOException e) {
throw new AssertionError(e);
}
} }
public static byte[][] split(byte[] input, int firstLength, int secondLength) { public static byte[][] split(byte[] input, int firstLength, int secondLength) {