From c330eef7b97460ebae5faa7345314b563578e482 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Thu, 24 Jul 2014 11:59:54 -0700 Subject: [PATCH] Make PreKeyWhisperMessage decrypt more reliably atomic. --- .../libaxolotl/SessionBuilder.java | 32 +++++---------- .../libaxolotl/SessionCipher.java | 39 ++++++++++--------- 2 files changed, 30 insertions(+), 41 deletions(-) diff --git a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java index 310b4672bf..a6dd95304d 100644 --- a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java +++ b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionBuilder.java @@ -88,35 +88,32 @@ public class SessionBuilder { * @throws org.whispersystems.libaxolotl.InvalidKeyException when the message is formatted incorrectly. * @throws org.whispersystems.libaxolotl.UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted. */ - /*package*/ boolean process(PreKeyWhisperMessage message) + /*package*/ void process(SessionRecord sessionRecord, PreKeyWhisperMessage message) throws InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException { int messageVersion = message.getMessageVersion(); IdentityKey theirIdentityKey = message.getIdentityKey(); - boolean createdSession; - if (!identityKeyStore.isTrustedIdentity(recipientId, theirIdentityKey)) { throw new UntrustedIdentityException(); } - if (messageVersion == 2) createdSession = processV2(message); - else if (messageVersion == 3) createdSession = processV3(message); - else throw new AssertionError("Unknown version: " + messageVersion); + switch (messageVersion) { + case 2: processV2(sessionRecord, message); break; + case 3: processV3(sessionRecord, message); break; + default: throw new AssertionError("Unknown version: " + messageVersion); + } identityKeyStore.saveIdentity(recipientId, theirIdentityKey); - - return createdSession; } - private boolean processV3(PreKeyWhisperMessage message) + private void processV3(SessionRecord sessionRecord, PreKeyWhisperMessage message) throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException { - SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); if (sessionRecord.hasSessionState(message.getMessageVersion(), message.getBaseKey().serialize())) { Log.w(TAG, "We've already setup a session for this V3 message, letting bundled message fall through..."); - return false; + return; } boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage(); @@ -147,16 +144,12 @@ public class SessionBuilder { if (simultaneousInitiate) sessionRecord.getSessionState().setNeedsRefresh(true); - sessionStore.storeSession(recipientId, deviceId, sessionRecord); - if (message.getPreKeyId() >= 0 && message.getPreKeyId() != Medium.MAX_VALUE) { preKeyStore.removePreKey(message.getPreKeyId()); } - - return true; } - private boolean processV2(PreKeyWhisperMessage message) + private void processV2(SessionRecord sessionRecord, PreKeyWhisperMessage message) throws UntrustedIdentityException, InvalidKeyIdException, InvalidKeyException { @@ -164,10 +157,9 @@ public class SessionBuilder { sessionStore.containsSession(recipientId, deviceId)) { Log.w(TAG, "We've already processed the prekey part of this V2 session, letting bundled message fall through..."); - return false; + return; } - SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); ECKeyPair ourPreKey = preKeyStore.loadPreKey(message.getPreKeyId()).getKeyPair(); boolean simultaneousInitiate = sessionRecord.getSessionState().hasUnacknowledgedPreKeyMessage(); @@ -193,10 +185,6 @@ public class SessionBuilder { if (message.getPreKeyId() != Medium.MAX_VALUE) { preKeyStore.removePreKey(message.getPreKeyId()); } - - sessionStore.storeSession(recipientId, deviceId, sessionRecord); - - return true; } /** diff --git a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java index 59ddcc02d2..d94f648b89 100644 --- a/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java +++ b/libaxolotl/src/main/java/org/whispersystems/libaxolotl/SessionCipher.java @@ -145,17 +145,13 @@ public class SessionCipher { InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException, NoSessionException { synchronized (SESSION_LOCK) { - boolean sessionCreated = sessionBuilder.process(ciphertext); + SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); - try { - return decrypt(ciphertext.getWhisperMessage()); - } catch (InvalidMessageException | DuplicateMessageException | LegacyMessageException e) { - if (sessionCreated) { - sessionStore.deleteSession(recipientId, deviceId); - } + sessionBuilder.process(sessionRecord, ciphertext); + byte[] plaintext = decrypt(sessionRecord, ciphertext.getWhisperMessage()); - throw e; - } + sessionStore.storeSession(recipientId, deviceId, sessionRecord); + return plaintext; } } @@ -182,26 +178,32 @@ public class SessionCipher { throw new NoSessionException("No session for: " + recipientId + ", " + deviceId); } - SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); + SessionRecord sessionRecord = sessionStore.loadSession(recipientId, deviceId); + byte[] plaintext = decrypt(sessionRecord, ciphertext); + + sessionStore.storeSession(recipientId, deviceId, sessionRecord); + + return plaintext; + } + } + + private byte[] decrypt(SessionRecord sessionRecord, WhisperMessage ciphertext) + throws DuplicateMessageException, LegacyMessageException, InvalidMessageException + { + synchronized (SESSION_LOCK) { SessionState sessionState = sessionRecord.getSessionState(); List previousStates = sessionRecord.getPreviousSessionStates(); List exceptions = new LinkedList<>(); try { - byte[] plaintext = decrypt(sessionState, ciphertext); - sessionStore.storeSession(recipientId, deviceId, sessionRecord); - - return plaintext; + return decrypt(sessionState, ciphertext); } catch (InvalidMessageException e) { exceptions.add(e); } for (SessionState previousState : previousStates) { try { - byte[] plaintext = decrypt(previousState, ciphertext); - sessionStore.storeSession(recipientId, deviceId, sessionRecord); - - return plaintext; + return decrypt(previousState, ciphertext); } catch (InvalidMessageException e) { exceptions.add(e); } @@ -240,7 +242,6 @@ public class SessionCipher { sessionState.clearUnacknowledgedPreKeyMessage(); return plaintext; - } public int getRemoteRegistrationId() {