diff --git a/library/src/org/whispersystems/textsecure/crypto/SessionCipher.java b/library/src/org/whispersystems/textsecure/crypto/SessionCipher.java index f8757b27c5..115bc2c349 100644 --- a/library/src/org/whispersystems/textsecure/crypto/SessionCipher.java +++ b/library/src/org/whispersystems/textsecure/crypto/SessionCipher.java @@ -20,6 +20,7 @@ import android.content.Context; import android.util.Log; import org.spongycastle.crypto.params.ECPublicKeyParameters; +import org.whispersystems.textsecure.crypto.kdf.DerivedSecrets; import org.whispersystems.textsecure.crypto.protocol.CiphertextMessage; import org.whispersystems.textsecure.storage.CanonicalRecipientAddress; import org.whispersystems.textsecure.storage.InvalidKeyIdException; @@ -38,11 +39,8 @@ import javax.crypto.spec.SecretKeySpec; import java.math.BigInteger; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; -import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; -import java.util.LinkedList; -import java.util.List; /** * This is where the session encryption magic happens. Implements a compressed version of the OTR protocol. @@ -148,23 +146,16 @@ public class SessionCipher { } } - private SecretKeySpec deriveMacSecret(SecretKeySpec key) { - try { - MessageDigest md = MessageDigest.getInstance("SHA-1"); - byte[] secret = md.digest(key.getEncoded()); - - return new SecretKeySpec(secret, "HmacSHA1"); - } catch (NoSuchAlgorithmException e) { - throw new IllegalArgumentException("SHA-1 Not Supported!",e); - } - } - - private byte[] getPlaintext(byte[] cipherText, SecretKeySpec key, int counter) throws IllegalBlockSizeException, BadPaddingException { + private byte[] getPlaintext(byte[] cipherText, SecretKeySpec key, int counter) + throws IllegalBlockSizeException, BadPaddingException + { Cipher cipher = getCipher(Cipher.DECRYPT_MODE, key, counter); return cipher.doFinal(cipherText); } - private byte[] getCiphertext(byte[] message, SecretKeySpec key, int counter) throws IllegalBlockSizeException, BadPaddingException { + private byte[] getCiphertext(byte[] message, SecretKeySpec key, int counter) + throws IllegalBlockSizeException, BadPaddingException + { Cipher cipher = getCipher(Cipher.ENCRYPT_MODE, key, counter); return cipher.doFinal(message); } @@ -192,49 +183,53 @@ public class SessionCipher { throw new IllegalArgumentException("Bad IV?"); } } - - private SecretKeySpec deriveCipherSecret(int mode, List sharedSecret, - KeyRecords records, int localKeyId, - int remoteKeyId) + + private SessionKey getSessionKey(MasterSecret masterSecret, int mode, + int messageVersion, + IdentityKeyPair localIdentityKey, + KeyRecords records, + int localKeyId, int remoteKeyId) throws InvalidKeyIdException { - byte[] sharedSecretBytes = concatenateSharedSecrets(sharedSecret); - byte[] derivedBytes = deriveBytes(sharedSecretBytes, 16 * 2); - byte[] cipherSecret = new byte[16]; - - boolean isLowEnd = isLowEnd(records, localKeyId, remoteKeyId); - isLowEnd = (mode == Cipher.ENCRYPT_MODE ? isLowEnd : !isLowEnd); - - if (isLowEnd) { - System.arraycopy(derivedBytes, 16, cipherSecret, 0, 16); - } else { - System.arraycopy(derivedBytes, 0, cipherSecret, 0, 16); - } - - return new SecretKeySpec(cipherSecret, "AES"); - } + Log.w("SessionCipher", "Getting session key for local: " + localKeyId + " remote: " + remoteKeyId); + SessionKey sessionKey = records.getSessionRecord().getSessionKey(mode, localKeyId, remoteKeyId); - private byte[] concatenateSharedSecrets(List sharedSecrets) { - int totalByteSize = 0; - List byteValues = new LinkedList(); + if (sessionKey != null) + return sessionKey; - for (BigInteger sharedSecret : sharedSecrets) { - byte[] byteValue = sharedSecret.toByteArray(); - totalByteSize += byteValue.length; - byteValues.add(byteValue); - } + DerivedSecrets derivedSecrets = calculateSharedSecret(messageVersion, mode, localIdentityKey, + records, localKeyId, remoteKeyId); - byte[] combined = new byte[totalByteSize]; - int offset = 0; - - for (byte[] byteValue : byteValues) { - System.arraycopy(byteValue, 0, combined, offset, byteValue.length); - offset += byteValue.length; - } - - return combined; + return new SessionKey(mode, localKeyId, remoteKeyId, derivedSecrets.getCipherKey(), + derivedSecrets.getMacKey(), masterSecret); } + private DerivedSecrets calculateSharedSecret(int messageVersion, int mode, + IdentityKeyPair localIdentityKey, + KeyRecords records, + int localKeyId, int remoteKeyId) + throws InvalidKeyIdException + { + KeyPair localKeyPair = records.getLocalKeyRecord().getKeyPairForId(localKeyId); + ECPublicKeyParameters remoteKey = records.getRemoteKeyRecord().getKeyForId(remoteKeyId).getKey(); + IdentityKey remoteIdentityKey = records.getSessionRecord().getIdentityKey(); + boolean isLowEnd = isLowEnd(records, localKeyId, remoteKeyId); + + isLowEnd = (mode == Cipher.ENCRYPT_MODE ? isLowEnd : !isLowEnd); + + if (isInitiallyExchangedKeys(records, localKeyId, remoteKeyId) && + messageVersion >= CiphertextMessage.DHE3_INTRODUCED_VERSION) + { + return SharedSecretCalculator.calculateSharedSecret(isLowEnd, + localKeyPair, localKeyId, localIdentityKey, + remoteKey, remoteKeyId, remoteIdentityKey); + } else { + return SharedSecretCalculator.calculateSharedSecret(messageVersion, isLowEnd, + localKeyPair, localKeyId, + remoteKey, remoteKeyId); + } + } + private boolean isLowEnd(KeyRecords records, int localKeyId, int remoteKeyId) throws InvalidKeyIdException { @@ -247,67 +242,6 @@ public class SessionCipher { return local.compareTo(remote) < 0; } - private byte[] deriveBytes(byte[] seed, int bytesNeeded) { - MessageDigest md; - - try { - md = MessageDigest.getInstance("SHA-256"); - } catch (NoSuchAlgorithmException e) { - Log.w("SessionCipher",e); - throw new IllegalArgumentException("SHA-256 Not Supported!"); - } - - int rounds = bytesNeeded / md.getDigestLength(); - - for (int i=1;i<=rounds;i++) { - byte[] roundBytes = Conversions.intToByteArray(i); - md.update(roundBytes); - md.update(seed); - } - - return md.digest(); - } - - private SessionKey getSessionKey(MasterSecret masterSecret, int mode, - int messageVersion, - IdentityKeyPair localIdentityKey, - KeyRecords records, - int localKeyId, int remoteKeyId) - throws InvalidKeyIdException - { - Log.w("SessionCipher", "Getting session key for local: " + localKeyId + " remote: " + remoteKeyId); - SessionKey sessionKey = records.getSessionRecord().getSessionKey(localKeyId, remoteKeyId); - - if (sessionKey != null) - return sessionKey; - - List sharedSecret = calculateSharedSecret(messageVersion, localIdentityKey, records, localKeyId, remoteKeyId); - SecretKeySpec cipherKey = deriveCipherSecret(mode, sharedSecret, records, localKeyId, remoteKeyId); - SecretKeySpec macKey = deriveMacSecret(cipherKey); - - return new SessionKey(localKeyId, remoteKeyId, cipherKey, macKey, masterSecret); - } - - private List calculateSharedSecret(int messageVersion, - IdentityKeyPair localIdentityKey, - KeyRecords records, - int localKeyId, int remoteKeyId) - throws InvalidKeyIdException - { - KeyPair localKeyPair = records.getLocalKeyRecord().getKeyPairForId(localKeyId); - ECPublicKeyParameters remoteKey = records.getRemoteKeyRecord().getKeyForId(remoteKeyId).getKey(); - IdentityKey remoteIdentityKey = records.getSessionRecord().getIdentityKey(); - - if (isInitiallyExchangedKeys(records, localKeyId, remoteKeyId) && - messageVersion >= CiphertextMessage.CRADLE_AGREEMENT_VERSION) - { - return SharedSecretCalculator.calculateSharedSecret(localKeyPair, localIdentityKey, - remoteKey, remoteIdentityKey); - } else { - return SharedSecretCalculator.calculateSharedSecret(localKeyPair, remoteKey); - } - } - private boolean isInitiallyExchangedKeys(KeyRecords records, int localKeyId, int remoteKeyId) throws InvalidKeyIdException { diff --git a/library/src/org/whispersystems/textsecure/crypto/SharedSecretCalculator.java b/library/src/org/whispersystems/textsecure/crypto/SharedSecretCalculator.java index 2c51df6a41..9a75452df5 100644 --- a/library/src/org/whispersystems/textsecure/crypto/SharedSecretCalculator.java +++ b/library/src/org/whispersystems/textsecure/crypto/SharedSecretCalculator.java @@ -5,6 +5,12 @@ import android.util.Log; import org.spongycastle.crypto.CipherParameters; import org.spongycastle.crypto.agreement.ECDHBasicAgreement; import org.spongycastle.crypto.params.ECPublicKeyParameters; +import org.whispersystems.textsecure.crypto.kdf.DerivedSecrets; +import org.whispersystems.textsecure.crypto.kdf.HKDF; +import org.whispersystems.textsecure.crypto.kdf.KDF; +import org.whispersystems.textsecure.crypto.kdf.NKDF; +import org.whispersystems.textsecure.crypto.protocol.CiphertextMessage; +import org.whispersystems.textsecure.util.Conversions; import java.math.BigInteger; import java.util.LinkedList; @@ -12,15 +18,18 @@ import java.util.List; public class SharedSecretCalculator { - public static List calculateSharedSecret(KeyPair localKeyPair, - IdentityKeyPair localIdentityKeyPair, - ECPublicKeyParameters remoteKey, - IdentityKey remoteIdentityKey) + public static DerivedSecrets calculateSharedSecret(boolean isLowEnd, KeyPair localKeyPair, + int localKeyId, + IdentityKeyPair localIdentityKeyPair, + ECPublicKeyParameters remoteKey, + int remoteKeyId, + IdentityKey remoteIdentityKey) { Log.w("SharedSecretCalculator", "Calculating shared secret with cradle agreement..."); + KDF kdf = new HKDF(); List results = new LinkedList(); - if (isLowEnd(localKeyPair.getPublicKey().getKey(), remoteKey)) { + if (isSmaller(localKeyPair.getPublicKey().getKey(), remoteKey)) { results.add(calculateAgreement(localIdentityKeyPair.getPrivateKey(), remoteKey)); results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), @@ -33,17 +42,40 @@ public class SharedSecretCalculator { } results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), remoteKey)); - return results; + + return kdf.deriveSecrets(results, isLowEnd, getInfo(localKeyId,remoteKeyId)); } - public static List calculateSharedSecret(KeyPair localKeyPair, - ECPublicKeyParameters remoteKey) + public static DerivedSecrets calculateSharedSecret(int messageVersion, boolean isLowEnd, + KeyPair localKeyPair, int localKeyId, + ECPublicKeyParameters remoteKey, int remoteKeyId) { Log.w("SharedSecretCalculator", "Calculating shared secret with standard agreement..."); + KDF kdf; + + if (messageVersion >= CiphertextMessage.DHE3_INTRODUCED_VERSION) kdf = new HKDF(); + else kdf = new NKDF(); + + Log.w("SharedSecretCalculator", "Using kdf: " + kdf); + List results = new LinkedList(); results.add(calculateAgreement(localKeyPair.getKeyPair().getPrivate(), remoteKey)); - return results; + return kdf.deriveSecrets(results, isLowEnd, getInfo(localKeyId, remoteKeyId)); + } + + private static byte[] getInfo(int localKeyId, int remoteKeyId) { + byte[] info = new byte[3 * 2]; + + if (localKeyId < remoteKeyId) { + Conversions.mediumToByteArray(info, 0, localKeyId); + Conversions.mediumToByteArray(info, 3, remoteKeyId); + } else { + Conversions.mediumToByteArray(info, 0, remoteKeyId); + Conversions.mediumToByteArray(info, 3, localKeyId); + } + + return info; } private static BigInteger calculateAgreement(CipherParameters privateKey, @@ -56,8 +88,8 @@ public class SharedSecretCalculator { } - private static boolean isLowEnd(ECPublicKeyParameters localPublic, - ECPublicKeyParameters remotePublic) + private static boolean isSmaller(ECPublicKeyParameters localPublic, + ECPublicKeyParameters remotePublic) { BigInteger local = localPublic.getQ().getX().toBigInteger(); BigInteger remote = remotePublic.getQ().getX().toBigInteger(); diff --git a/library/src/org/whispersystems/textsecure/crypto/kdf/DerivedSecrets.java b/library/src/org/whispersystems/textsecure/crypto/kdf/DerivedSecrets.java new file mode 100644 index 0000000000..74f8b82e67 --- /dev/null +++ b/library/src/org/whispersystems/textsecure/crypto/kdf/DerivedSecrets.java @@ -0,0 +1,22 @@ +package org.whispersystems.textsecure.crypto.kdf; + +import javax.crypto.spec.SecretKeySpec; + +public class DerivedSecrets { + + private final SecretKeySpec cipherKey; + private final SecretKeySpec macKey; + + public DerivedSecrets(SecretKeySpec cipherKey, SecretKeySpec macKey) { + this.cipherKey = cipherKey; + this.macKey = macKey; + } + + public SecretKeySpec getCipherKey() { + return cipherKey; + } + + public SecretKeySpec getMacKey() { + return macKey; + } +} diff --git a/library/src/org/whispersystems/textsecure/crypto/kdf/HKDF.java b/library/src/org/whispersystems/textsecure/crypto/kdf/HKDF.java new file mode 100644 index 0000000000..69cf3afc0f --- /dev/null +++ b/library/src/org/whispersystems/textsecure/crypto/kdf/HKDF.java @@ -0,0 +1,101 @@ +package org.whispersystems.textsecure.crypto.kdf; + +import java.io.ByteArrayOutputStream; +import java.math.BigInteger; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.List; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; + +public class HKDF extends KDF { + + private static final int HASH_OUTPUT_SIZE = 32; + private static final int KEY_MATERIAL_SIZE = 72; + + private static final int CIPHER_KEYS_OFFSET = 0; + private static final int MAC_KEYS_OFFSET = 32; + + @Override + public DerivedSecrets deriveSecrets(List sharedSecret, + boolean isLowEnd, byte[] info) + { + byte[] inputKeyMaterial = concatenateSharedSecrets(sharedSecret); + byte[] salt = new byte[HASH_OUTPUT_SIZE]; + byte[] prk = extract(salt, inputKeyMaterial); + byte[] okm = expand(prk, info, KEY_MATERIAL_SIZE); + + SecretKeySpec cipherKey = deriveCipherKey(okm, isLowEnd); + SecretKeySpec macKey = deriveMacKey(okm, isLowEnd); + + return new DerivedSecrets(cipherKey, macKey); + } + + private SecretKeySpec deriveCipherKey(byte[] okm, boolean isLowEnd) { + byte[] cipherKey = new byte[16]; + + if (isLowEnd) { + System.arraycopy(okm, CIPHER_KEYS_OFFSET + 0, cipherKey, 0, cipherKey.length); + } else { + System.arraycopy(okm, CIPHER_KEYS_OFFSET + 16, cipherKey, 0, cipherKey.length); + } + + return new SecretKeySpec(cipherKey, "AES"); + } + + private SecretKeySpec deriveMacKey(byte[] okm, boolean isLowEnd) { + byte[] macKey = new byte[20]; + + if (isLowEnd) { + System.arraycopy(okm, MAC_KEYS_OFFSET + 0, macKey, 0, macKey.length); + } else { + System.arraycopy(okm, MAC_KEYS_OFFSET + 20, macKey, 0, macKey.length); + } + + return new SecretKeySpec(macKey, "HmacSHA1"); + } + + private byte[] extract(byte[] salt, byte[] inputKeyMaterial) { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(salt, "HmacSHA256")); + return mac.doFinal(inputKeyMaterial); + } catch (NoSuchAlgorithmException e) { + throw new AssertionError(e); + } catch (InvalidKeyException e) { + throw new AssertionError(e); + } + } + + private byte[] expand(byte[] prk, byte[] info, int outputSize) { + try { + int iterations = (int)Math.ceil((double)outputSize/(double)HASH_OUTPUT_SIZE); + byte[] mixin = new byte[0]; + ByteArrayOutputStream results = new ByteArrayOutputStream(); + + for (int i=0;i sharedSecret, + boolean isLowEnd, byte[] info); + + protected byte[] concatenateSharedSecrets(List sharedSecrets) { + int totalByteSize = 0; + List byteValues = new LinkedList(); + + for (BigInteger sharedSecret : sharedSecrets) { + byte[] byteValue = sharedSecret.toByteArray(); + totalByteSize += byteValue.length; + byteValues.add(byteValue); + } + + byte[] combined = new byte[totalByteSize]; + int offset = 0; + + for (byte[] byteValue : byteValues) { + System.arraycopy(byteValue, 0, combined, offset, byteValue.length); + offset += byteValue.length; + } + + return combined; + } + +} diff --git a/library/src/org/whispersystems/textsecure/crypto/kdf/NKDF.java b/library/src/org/whispersystems/textsecure/crypto/kdf/NKDF.java new file mode 100644 index 0000000000..5106af24f5 --- /dev/null +++ b/library/src/org/whispersystems/textsecure/crypto/kdf/NKDF.java @@ -0,0 +1,73 @@ +package org.whispersystems.textsecure.crypto.kdf; + +import android.util.Log; + +import org.whispersystems.textsecure.util.Conversions; + +import java.math.BigInteger; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.List; + +import javax.crypto.spec.SecretKeySpec; + +public class NKDF extends KDF { + + @Override + public DerivedSecrets deriveSecrets(List sharedSecret, + boolean isLowEnd, byte[] info) + { + SecretKeySpec cipherKey = deriveCipherSecret(isLowEnd, sharedSecret); + SecretKeySpec macKey = deriveMacSecret(cipherKey); + + return new DerivedSecrets(cipherKey, macKey); + } + + private SecretKeySpec deriveCipherSecret(boolean isLowEnd, List sharedSecret) { + byte[] sharedSecretBytes = concatenateSharedSecrets(sharedSecret); + byte[] derivedBytes = deriveBytes(sharedSecretBytes, 16 * 2); + byte[] cipherSecret = new byte[16]; + + if (isLowEnd) { + System.arraycopy(derivedBytes, 16, cipherSecret, 0, 16); + } else { + System.arraycopy(derivedBytes, 0, cipherSecret, 0, 16); + } + + return new SecretKeySpec(cipherSecret, "AES"); + } + + private SecretKeySpec deriveMacSecret(SecretKeySpec key) { + try { + MessageDigest md = MessageDigest.getInstance("SHA-1"); + byte[] secret = md.digest(key.getEncoded()); + + return new SecretKeySpec(secret, "HmacSHA1"); + } catch (NoSuchAlgorithmException e) { + throw new IllegalArgumentException("SHA-1 Not Supported!",e); + } + } + + private byte[] deriveBytes(byte[] seed, int bytesNeeded) { + MessageDigest md; + + try { + md = MessageDigest.getInstance("SHA-256"); + } catch (NoSuchAlgorithmException e) { + Log.w("NKDF", e); + throw new IllegalArgumentException("SHA-256 Not Supported!"); + } + + int rounds = bytesNeeded / md.getDigestLength(); + + for (int i=1;i<=rounds;i++) { + byte[] roundBytes = Conversions.intToByteArray(i); + md.update(roundBytes); + md.update(seed); + } + + return md.digest(); + } + + +} diff --git a/library/src/org/whispersystems/textsecure/crypto/protocol/CiphertextMessage.java b/library/src/org/whispersystems/textsecure/crypto/protocol/CiphertextMessage.java index 5db9cbfc80..63710d8eac 100644 --- a/library/src/org/whispersystems/textsecure/crypto/protocol/CiphertextMessage.java +++ b/library/src/org/whispersystems/textsecure/crypto/protocol/CiphertextMessage.java @@ -9,8 +9,8 @@ import org.whispersystems.textsecure.util.Conversions; public class CiphertextMessage { - public static final int SUPPORTED_VERSION = 2; - public static final int CRADLE_AGREEMENT_VERSION = 2; + public static final int SUPPORTED_VERSION = 2; + public static final int DHE3_INTRODUCED_VERSION = 2; static final int VERSION_LENGTH = 1; private static final int SENDER_KEY_ID_LENGTH = 3; diff --git a/library/src/org/whispersystems/textsecure/storage/SessionKey.java b/library/src/org/whispersystems/textsecure/storage/SessionKey.java index 3f0278bd74..c45fb3a080 100644 --- a/library/src/org/whispersystems/textsecure/storage/SessionKey.java +++ b/library/src/org/whispersystems/textsecure/storage/SessionKey.java @@ -16,6 +16,7 @@ */ package org.whispersystems.textsecure.storage; +import org.whispersystems.textsecure.crypto.InvalidMessageException; import org.whispersystems.textsecure.crypto.MasterCipher; import org.whispersystems.textsecure.crypto.MasterSecret; import org.whispersystems.textsecure.crypto.SessionCipher; @@ -34,13 +35,18 @@ import javax.crypto.spec.SecretKeySpec; public class SessionKey { - private int localKeyId; - private int remoteKeyId; + private int mode; + private int localKeyId; + private int remoteKeyId; private SecretKeySpec cipherKey; private SecretKeySpec macKey; - private MasterCipher masterCipher; + private MasterCipher masterCipher; - public SessionKey(int localKeyId, int remoteKeyId, SecretKeySpec cipherKey, SecretKeySpec macKey, MasterSecret masterSecret) { + public SessionKey(int mode, int localKeyId, int remoteKeyId, + SecretKeySpec cipherKey, SecretKeySpec macKey, + MasterSecret masterSecret) + { + this.mode = mode; this.localKeyId = localKeyId; this.remoteKeyId = remoteKeyId; this.cipherKey = cipherKey; @@ -48,7 +54,7 @@ public class SessionKey { this.masterCipher = new MasterCipher(masterSecret); } - public SessionKey(byte[] bytes, MasterSecret masterSecret) { + public SessionKey(byte[] bytes, MasterSecret masterSecret) throws InvalidMessageException { this.masterCipher = new MasterCipher(masterSecret); deserialize(bytes); } @@ -58,13 +64,16 @@ public class SessionKey { byte[] remoteKeyIdBytes = Conversions.mediumToByteArray(remoteKeyId); byte[] cipherKeyBytes = cipherKey.getEncoded(); byte[] macKeyBytes = macKey.getEncoded(); - byte[] combined = Util.combine(localKeyIdBytes, remoteKeyIdBytes, cipherKeyBytes, macKeyBytes); + byte[] modeBytes = {(byte)mode}; + byte[] combined = Util.combine(localKeyIdBytes, remoteKeyIdBytes, + cipherKeyBytes, macKeyBytes, modeBytes); return masterCipher.encryptBytes(combined); } - private void deserialize(byte[] bytes) { - byte[] decrypted = masterCipher.encryptBytes(bytes); + private void deserialize(byte[] bytes) throws InvalidMessageException { + byte[] decrypted = masterCipher.decryptBytes(bytes); + this.localKeyId = Conversions.byteArrayToMedium(decrypted, 0); this.remoteKeyId = Conversions.byteArrayToMedium(decrypted, 3); @@ -73,6 +82,12 @@ public class SessionKey { byte[] macBytes = new byte[SessionCipher.MAC_KEY_LENGTH]; System.arraycopy(decrypted, 6 + keyBytes.length, macBytes, 0, macBytes.length); + + if (decrypted.length < 6 + SessionCipher.CIPHER_KEY_LENGTH + SessionCipher.MAC_KEY_LENGTH + 1) { + throw new InvalidMessageException("No mode included"); + } + + this.mode = decrypted[6 + keyBytes.length + macBytes.length]; this.cipherKey = new SecretKeySpec(keyBytes, "AES"); this.macKey = new SecretKeySpec(macBytes, "HmacSHA1"); @@ -94,4 +109,7 @@ public class SessionKey { return this.macKey; } + public int getMode() { + return mode; + } } diff --git a/library/src/org/whispersystems/textsecure/storage/SessionRecord.java b/library/src/org/whispersystems/textsecure/storage/SessionRecord.java index de42a112c3..af3002189a 100644 --- a/library/src/org/whispersystems/textsecure/storage/SessionRecord.java +++ b/library/src/org/whispersystems/textsecure/storage/SessionRecord.java @@ -21,6 +21,7 @@ import android.util.Log; import org.whispersystems.textsecure.crypto.IdentityKey; import org.whispersystems.textsecure.crypto.InvalidKeyException; +import org.whispersystems.textsecure.crypto.InvalidMessageException; import org.whispersystems.textsecure.crypto.MasterSecret; import java.io.FileInputStream; @@ -41,16 +42,16 @@ public class SessionRecord extends Record { private static final int[] VALID_VERSION_MARKERS = {CURRENT_VERSION_MARKER, 0X55555556, 0X55555555}; private static final Object FILE_LOCK = new Object(); - private int counter; + private int counter; private byte[] localFingerprint; private byte[] remoteFingerprint; - private int negotiatedSessionVersion; - private int currentSessionVersion; + private int negotiatedSessionVersion; + private int currentSessionVersion; private IdentityKey identityKey; - private SessionKey sessionKeyRecord; - private boolean verifiedSessionKey; - private boolean prekeyBundleRequired; + private SessionKey sessionKeyRecord; + private boolean verifiedSessionKey; + private boolean prekeyBundleRequired; private final MasterSecret masterSecret; @@ -208,8 +209,14 @@ public class SessionRecord extends Record { this.remoteFingerprint = readBlob(in); this.currentSessionVersion = 31337; - if (in.available() != 0) - this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret); + if (in.available() != 0) { + try { + this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret); + } catch (InvalidMessageException e) { + Log.w("SessionRecord", e); + this.sessionKeyRecord = null; + } + } in.close(); } else { @@ -230,8 +237,14 @@ public class SessionRecord extends Record { this.negotiatedSessionVersion = currentSessionVersion; } - if (in.available() != 0) - this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret); + if (in.available() != 0) { + try { + this.sessionKeyRecord = new SessionKey(readBlob(in), masterSecret); + } catch (InvalidMessageException e) { + Log.w("SessionRecord", e); + this.sessionKeyRecord = null; + } + } in.close(); } @@ -245,12 +258,15 @@ public class SessionRecord extends Record { } } - public SessionKey getSessionKey(int localKeyId, int remoteKeyId) { + public SessionKey getSessionKey(int mode, int localKeyId, int remoteKeyId) { if (this.sessionKeyRecord == null) return null; - if ((this.sessionKeyRecord.getLocalKeyId() == localKeyId) && - (this.sessionKeyRecord.getRemoteKeyId() == remoteKeyId)) + if ((this.sessionKeyRecord.getLocalKeyId() == localKeyId) && + (this.sessionKeyRecord.getRemoteKeyId() == remoteKeyId) && + (this.sessionKeyRecord.getMode() == mode)) + { return this.sessionKeyRecord; + } return null; } diff --git a/library/src/org/whispersystems/textsecure/util/Util.java b/library/src/org/whispersystems/textsecure/util/Util.java index 119059e28f..64bbfd8486 100644 --- a/library/src/org/whispersystems/textsecure/util/Util.java +++ b/library/src/org/whispersystems/textsecure/util/Util.java @@ -18,32 +18,18 @@ import java.util.List; public class Util { - public static byte[] combine(byte[] one, byte[] two) { - byte[] combined = new byte[one.length + two.length]; - System.arraycopy(one, 0, combined, 0, one.length); - System.arraycopy(two, 0, combined, one.length, two.length); + public static byte[] combine(byte[]... elements) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); - return combined; - } - - public static byte[] combine(byte[] one, byte[] two, byte[] three) { - byte[] combined = new byte[one.length + two.length + three.length]; - System.arraycopy(one, 0, combined, 0, one.length); - System.arraycopy(two, 0, combined, one.length, two.length); - System.arraycopy(three, 0, combined, one.length + two.length, three.length); - - return combined; - } - - public static byte[] combine(byte[] one, byte[] two, byte[] three, byte[] four) { - byte[] combined = new byte[one.length + two.length + three.length + four.length]; - System.arraycopy(one, 0, combined, 0, one.length); - System.arraycopy(two, 0, combined, one.length, two.length); - System.arraycopy(three, 0, combined, one.length + two.length, three.length); - System.arraycopy(four, 0, combined, one.length + two.length + three.length, four.length); - - return combined; + for (byte[] element : elements) { + baos.write(element); + } + return baos.toByteArray(); + } catch (IOException e) { + throw new AssertionError(e); + } } public static byte[][] split(byte[] input, int firstLength, int secondLength) {