SessionBuilder improvements, more extensive SessionBuilder tests.

This commit is contained in:
Moxie Marlinspike 2014-04-23 17:12:47 -07:00
parent 72af8b11c2
commit af45e5d544
10 changed files with 511 additions and 57 deletions

View File

@ -0,0 +1,47 @@
package org.whispersystems.test;
import org.whispersystems.libaxolotl.IdentityKey;
import org.whispersystems.libaxolotl.IdentityKeyPair;
import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.state.IdentityKeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.HashMap;
import java.util.Map;
public class InMemoryIdentityKeyStore implements IdentityKeyStore {
private final Map<Long, IdentityKey> trustedKeys = new HashMap<>();
private final IdentityKeyPair identityKeyPair;
private final int localRegistrationId;
public InMemoryIdentityKeyStore() {
try {
ECKeyPair identityKeyPairKeys = Curve.generateKeyPair(false);
this.identityKeyPair = new IdentityKeyPair(new IdentityKey(identityKeyPairKeys.getPublicKey()),
identityKeyPairKeys.getPrivateKey());
this.localRegistrationId = SecureRandom.getInstance("SHA1PRNG").nextInt(16380) + 1;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
@Override
public IdentityKeyPair getIdentityKeyPair() {
return identityKeyPair;
}
@Override
public int getLocalRegistrationId() {
return localRegistrationId;
}
@Override
public void saveIdentity(long recipientId, IdentityKey identityKey) {
trustedKeys.put(recipientId, identityKey);
}
}

View File

@ -0,0 +1,37 @@
package org.whispersystems.test;
import org.whispersystems.libaxolotl.InvalidKeyIdException;
import org.whispersystems.libaxolotl.state.PreKeyRecord;
import org.whispersystems.libaxolotl.state.PreKeyStore;
import java.util.HashMap;
import java.util.Map;
public class InMemoryPreKeyStore implements PreKeyStore {
private final Map<Integer, PreKeyRecord> store = new HashMap<>();
@Override
public PreKeyRecord load(int preKeyId) throws InvalidKeyIdException {
if (!store.containsKey(preKeyId)) {
throw new InvalidKeyIdException("No such prekeyrecord!");
}
return store.get(preKeyId);
}
@Override
public void store(int preKeyId, PreKeyRecord record) {
store.put(preKeyId, record);
}
@Override
public boolean contains(int preKeyId) {
return store.containsKey(preKeyId);
}
@Override
public void remove(int preKeyId) {
store.remove(preKeyId);
}
}

View File

@ -3,12 +3,14 @@ package org.whispersystems.test;
import org.whispersystems.libaxolotl.IdentityKey; import org.whispersystems.libaxolotl.IdentityKey;
import org.whispersystems.libaxolotl.IdentityKeyPair; import org.whispersystems.libaxolotl.IdentityKeyPair;
import org.whispersystems.libaxolotl.InvalidKeyException; import org.whispersystems.libaxolotl.InvalidKeyException;
import org.whispersystems.libaxolotl.state.SessionState; import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair; import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.ecc.ECPrivateKey;
import org.whispersystems.libaxolotl.ecc.ECPublicKey; import org.whispersystems.libaxolotl.ecc.ECPublicKey;
import org.whispersystems.libaxolotl.ratchet.ChainKey; import org.whispersystems.libaxolotl.ratchet.ChainKey;
import org.whispersystems.libaxolotl.ratchet.MessageKeys; import org.whispersystems.libaxolotl.ratchet.MessageKeys;
import org.whispersystems.libaxolotl.ratchet.RootKey; import org.whispersystems.libaxolotl.ratchet.RootKey;
import org.whispersystems.libaxolotl.state.SessionState;
import org.whispersystems.libaxolotl.util.Pair; import org.whispersystems.libaxolotl.util.Pair;
import java.util.HashMap; import java.util.HashMap;
@ -36,23 +38,74 @@ public class InMemorySessionState implements SessionState {
private int remoteRegistrationId; private int remoteRegistrationId;
private int localRegistrationId; private int localRegistrationId;
private InMemoryPendingKeyExchange pendingKeyExchange;
public InMemorySessionState() {} public InMemorySessionState() {}
public InMemorySessionState(SessionState sessionState) { public InMemorySessionState(SessionState sessionState) {
try { try {
this.needsRefresh = sessionState.getNeedsRefresh(); this.needsRefresh = sessionState.getNeedsRefresh();
this.sessionVersion = sessionState.getSessionVersion(); this.sessionVersion = sessionState.getSessionVersion();
if (sessionState.getRemoteIdentityKey() != null) {
this.remoteIdentityKey = new IdentityKey(sessionState.getRemoteIdentityKey().serialize(), 0); this.remoteIdentityKey = new IdentityKey(sessionState.getRemoteIdentityKey().serialize(), 0);
}
if (sessionState.getLocalIdentityKey() != null) {
this.localIdentityKey = new IdentityKey(sessionState.getLocalIdentityKey().serialize(), 0); this.localIdentityKey = new IdentityKey(sessionState.getLocalIdentityKey().serialize(), 0);
}
this.previousCounter = sessionState.getPreviousCounter(); this.previousCounter = sessionState.getPreviousCounter();
if (sessionState.getRootKey() != null) {
this.rootKey = new RootKey(sessionState.getRootKey().getKeyBytes()); this.rootKey = new RootKey(sessionState.getRootKey().getKeyBytes());
}
this.senderEphemeral = sessionState.getSenderEphemeralPair(); this.senderEphemeral = sessionState.getSenderEphemeralPair();
if (sessionState.getSenderChainKey() != null) {
this.senderChainKey = new ChainKey(sessionState.getSenderChainKey().getKey(), this.senderChainKey = new ChainKey(sessionState.getSenderChainKey().getKey(),
sessionState.getSenderChainKey().getIndex()); sessionState.getSenderChainKey().getIndex());
}
if (sessionState.getPendingPreKey() != null) {
this.pendingPreKeyid = sessionState.getPendingPreKey().first(); this.pendingPreKeyid = sessionState.getPendingPreKey().first();
}
if (sessionState.getPendingPreKey() != null) {
this.pendingPreKey = sessionState.getPendingPreKey().second(); this.pendingPreKey = sessionState.getPendingPreKey().second();
}
this.remoteRegistrationId = sessionState.getRemoteRegistrationId(); this.remoteRegistrationId = sessionState.getRemoteRegistrationId();
this.localRegistrationId = sessionState.getLocalRegistrationId(); this.localRegistrationId = sessionState.getLocalRegistrationId();
if (sessionState.hasPendingKeyExchange()) {
pendingKeyExchange = new InMemoryPendingKeyExchange();
pendingKeyExchange.sequence = sessionState.getPendingKeyExchangeSequence();
pendingKeyExchange.localBaseKey = sessionState.getPendingKeyExchangeBaseKey()
.getPublicKey().serialize();
pendingKeyExchange.localBaseKeyPrivate = sessionState.getPendingKeyExchangeBaseKey()
.getPrivateKey().serialize();
pendingKeyExchange.localEphemeralKey = sessionState.getPendingKeyExchangeEphemeralKey()
.getPublicKey().serialize();
pendingKeyExchange.localEphemeralKeyPrivate = sessionState.getPendingKeyExchangeEphemeralKey()
.getPrivateKey().serialize();
pendingKeyExchange.localIdentityKey = sessionState.getPendingKeyExchangeIdentityKey()
.getPublicKey().serialize();
pendingKeyExchange.localIdentityKeyPrivate = sessionState.getPendingKeyExchangeIdentityKey()
.getPrivateKey().serialize();
}
for (ECPublicKey key : ((InMemorySessionState)sessionState).receiverChains.keySet()) {
ECPublicKey chainKey = Curve.decodePoint(key.serialize(), 0);
InMemoryChain ourChain = new InMemoryChain();
InMemoryChain theirChain = ((InMemorySessionState)sessionState).receiverChains.get(key);
ourChain.chainKey = theirChain.chainKey;
ourChain.index = theirChain.index;
ourChain.messageKeys = theirChain.messageKeys;
receiverChains.put(chainKey, ourChain);
}
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }
@ -230,33 +283,51 @@ public class InMemorySessionState implements SessionState {
} }
@Override @Override
public void setPendingKeyExchange(int sequence, ECKeyPair ourBaseKey, ECKeyPair ourEphemeralKey, IdentityKeyPair ourIdentityKey) { public void setPendingKeyExchange(int sequence, ECKeyPair ourBaseKey, ECKeyPair ourEphemeralKey,
throw new AssertionError(); IdentityKeyPair ourIdentityKey)
{
pendingKeyExchange = new InMemoryPendingKeyExchange();
pendingKeyExchange.sequence = sequence;
pendingKeyExchange.localBaseKey = ourBaseKey.getPublicKey().serialize();
pendingKeyExchange.localBaseKeyPrivate = ourBaseKey.getPrivateKey().serialize();
pendingKeyExchange.localEphemeralKey = ourEphemeralKey.getPublicKey().serialize();
pendingKeyExchange.localEphemeralKeyPrivate = ourEphemeralKey.getPrivateKey().serialize();
pendingKeyExchange.localIdentityKey = ourIdentityKey.getPublicKey().serialize();
pendingKeyExchange.localIdentityKeyPrivate = ourIdentityKey.getPrivateKey().serialize();
} }
@Override @Override
public int getPendingKeyExchangeSequence() { public int getPendingKeyExchangeSequence() {
throw new AssertionError(); return pendingKeyExchange == null ? 0 : pendingKeyExchange.sequence;
} }
@Override @Override
public ECKeyPair getPendingKeyExchangeBaseKey() throws InvalidKeyException { public ECKeyPair getPendingKeyExchangeBaseKey() throws InvalidKeyException {
throw new AssertionError(); ECPublicKey publicKey = Curve.decodePoint(pendingKeyExchange.localBaseKey, 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(pendingKeyExchange.localBaseKeyPrivate);
return new ECKeyPair(publicKey, privateKey);
} }
@Override @Override
public ECKeyPair getPendingKeyExchangeEphemeralKey() throws InvalidKeyException { public ECKeyPair getPendingKeyExchangeEphemeralKey() throws InvalidKeyException {
throw new AssertionError(); ECPublicKey publicKey = Curve.decodePoint(pendingKeyExchange.localEphemeralKey, 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(pendingKeyExchange.localEphemeralKeyPrivate);
return new ECKeyPair(publicKey, privateKey);
} }
@Override @Override
public IdentityKeyPair getPendingKeyExchangeIdentityKey() throws InvalidKeyException { public IdentityKeyPair getPendingKeyExchangeIdentityKey() throws InvalidKeyException {
throw new AssertionError(); IdentityKey publicKey = new IdentityKey(pendingKeyExchange.localIdentityKey, 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(pendingKeyExchange.localIdentityKeyPrivate);
return new IdentityKeyPair(publicKey, privateKey);
} }
@Override @Override
public boolean hasPendingKeyExchange() { public boolean hasPendingKeyExchange() {
throw new AssertionError(); return pendingKeyExchange != null;
} }
@Override @Override
@ -318,4 +389,14 @@ public class InMemorySessionState implements SessionState {
byte[] macKey; byte[] macKey;
} }
} }
private static class InMemoryPendingKeyExchange {
int sequence;
byte[] localBaseKey;
byte[] localBaseKeyPrivate;
byte[] localEphemeralKey;
byte[] localEphemeralKeyPrivate;
byte[] localIdentityKey;
byte[] localIdentityKeyPrivate;
}
} }

View File

@ -16,7 +16,7 @@ public class InMemorySessionStore implements SessionStore {
public InMemorySessionStore() {} public InMemorySessionStore() {}
@Override @Override
public SessionRecord get(long recipientId, int deviceId) { public synchronized SessionRecord get(long recipientId, int deviceId) {
if (contains(recipientId, deviceId)) { if (contains(recipientId, deviceId)) {
return new InMemorySessionRecord(sessions.get(new Pair<>(recipientId, deviceId))); return new InMemorySessionRecord(sessions.get(new Pair<>(recipientId, deviceId)));
} else { } else {
@ -25,7 +25,7 @@ public class InMemorySessionStore implements SessionStore {
} }
@Override @Override
public List<Integer> getSubDeviceSessions(long recipientId) { public synchronized List<Integer> getSubDeviceSessions(long recipientId) {
List<Integer> deviceIds = new LinkedList<>(); List<Integer> deviceIds = new LinkedList<>();
for (Pair<Long, Integer> key : sessions.keySet()) { for (Pair<Long, Integer> key : sessions.keySet()) {
@ -38,22 +38,22 @@ public class InMemorySessionStore implements SessionStore {
} }
@Override @Override
public void put(long recipientId, int deviceId, SessionRecord record) { public synchronized void put(long recipientId, int deviceId, SessionRecord record) {
sessions.put(new Pair<>(recipientId, deviceId), record); sessions.put(new Pair<>(recipientId, deviceId), record);
} }
@Override @Override
public boolean contains(long recipientId, int deviceId) { public synchronized boolean contains(long recipientId, int deviceId) {
return sessions.containsKey(new Pair<>(recipientId, deviceId)); return sessions.containsKey(new Pair<>(recipientId, deviceId));
} }
@Override @Override
public void delete(long recipientId, int deviceId) { public synchronized void delete(long recipientId, int deviceId) {
sessions.remove(new Pair<>(recipientId, deviceId)); sessions.remove(new Pair<>(recipientId, deviceId));
} }
@Override @Override
public void deleteAll(long recipientId) { public synchronized void deleteAll(long recipientId) {
for (Pair<Long, Integer> key : sessions.keySet()) { for (Pair<Long, Integer> key : sessions.keySet()) {
if (key.first() == recipientId) { if (key.first() == recipientId) {
sessions.remove(key); sessions.remove(key);

View File

@ -0,0 +1,275 @@
package org.whispersystems.test;
import android.test.AndroidTestCase;
import android.util.Log;
import org.whispersystems.libaxolotl.DuplicateMessageException;
import org.whispersystems.libaxolotl.IdentityKey;
import org.whispersystems.libaxolotl.InvalidKeyException;
import org.whispersystems.libaxolotl.InvalidKeyIdException;
import org.whispersystems.libaxolotl.InvalidMessageException;
import org.whispersystems.libaxolotl.InvalidVersionException;
import org.whispersystems.libaxolotl.LegacyMessageException;
import org.whispersystems.libaxolotl.SessionBuilder;
import org.whispersystems.libaxolotl.SessionCipher;
import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.ecc.ECPublicKey;
import org.whispersystems.libaxolotl.protocol.CiphertextMessage;
import org.whispersystems.libaxolotl.protocol.KeyExchangeMessage;
import org.whispersystems.libaxolotl.protocol.PreKeyWhisperMessage;
import org.whispersystems.libaxolotl.state.IdentityKeyStore;
import org.whispersystems.libaxolotl.state.PreKey;
import org.whispersystems.libaxolotl.state.PreKeyRecord;
import org.whispersystems.libaxolotl.state.PreKeyStore;
import org.whispersystems.libaxolotl.state.SessionStore;
import org.whispersystems.libaxolotl.util.Pair;
import java.util.HashSet;
import java.util.Set;
public class SessionBuilderTest extends AndroidTestCase {
private static final long ALICE_RECIPIENT_ID = 5L;
private static final long BOB_RECIPIENT_ID = 2L;
public void testBasicPreKey()
throws InvalidKeyException, InvalidVersionException, InvalidMessageException, InvalidKeyIdException, DuplicateMessageException, LegacyMessageException
{
SessionStore aliceSessionStore = new InMemorySessionStore();
PreKeyStore alicePreKeyStore = new InMemoryPreKeyStore();
IdentityKeyStore aliceIdentityKeyStore = new InMemoryIdentityKeyStore();
SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceSessionStore, alicePreKeyStore,
aliceIdentityKeyStore,
BOB_RECIPIENT_ID, 1);
SessionStore bobSessionStore = new InMemorySessionStore();
PreKeyStore bobPreKeyStore = new InMemoryPreKeyStore();
IdentityKeyStore bobIdentityKeyStore = new InMemoryIdentityKeyStore();
SessionBuilder bobSessionBuilder = new SessionBuilder(bobSessionStore, bobPreKeyStore,
bobIdentityKeyStore,
ALICE_RECIPIENT_ID, 1);
InMemoryPreKey bobPreKey = new InMemoryPreKey(31337, Curve.generateKeyPair(true),
bobIdentityKeyStore.getIdentityKeyPair().getPublicKey(),
bobIdentityKeyStore.getLocalRegistrationId());
aliceSessionBuilder.process(bobPreKey);
assertTrue(aliceSessionStore.contains(BOB_RECIPIENT_ID, 1));
assertTrue(!aliceSessionStore.get(BOB_RECIPIENT_ID, 1).getSessionState().getNeedsRefresh());
String originalMessage = "L'homme est condamné à être libre";
SessionCipher aliceSessionCipher = new SessionCipher(aliceSessionStore, BOB_RECIPIENT_ID, 1);
CiphertextMessage outgoingMessage = aliceSessionCipher.encrypt(originalMessage.getBytes());
assertTrue(outgoingMessage.getType() == CiphertextMessage.PREKEY_TYPE);
PreKeyWhisperMessage incomingMessage = new PreKeyWhisperMessage(outgoingMessage.serialize());
bobPreKeyStore.store(31337, bobPreKey);
bobSessionBuilder.process(incomingMessage);
assertTrue(bobSessionStore.contains(ALICE_RECIPIENT_ID, 1));
SessionCipher bobSessionCipher = new SessionCipher(bobSessionStore, ALICE_RECIPIENT_ID, 1);
byte[] plaintext = bobSessionCipher.decrypt(incomingMessage.getWhisperMessage().serialize());
assertTrue(originalMessage.equals(new String(plaintext)));
}
public void testBasicKeyExchange() throws InvalidKeyException, LegacyMessageException, InvalidMessageException, DuplicateMessageException {
SessionStore aliceSessionStore = new InMemorySessionStore();
PreKeyStore alicePreKeyStore = new InMemoryPreKeyStore();
IdentityKeyStore aliceIdentityKeyStore = new InMemoryIdentityKeyStore();
SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceSessionStore, alicePreKeyStore,
aliceIdentityKeyStore,
BOB_RECIPIENT_ID, 1);
SessionStore bobSessionStore = new InMemorySessionStore();
PreKeyStore bobPreKeyStore = new InMemoryPreKeyStore();
IdentityKeyStore bobIdentityKeyStore = new InMemoryIdentityKeyStore();
SessionBuilder bobSessionBuilder = new SessionBuilder(bobSessionStore, bobPreKeyStore,
bobIdentityKeyStore,
ALICE_RECIPIENT_ID, 1);
KeyExchangeMessage aliceKeyExchangeMessage = aliceSessionBuilder.process();
KeyExchangeMessage bobKeyExchangeMessage = bobSessionBuilder.process(aliceKeyExchangeMessage);
Log.w("SessionBuilderTest", "Record from test: " + bobSessionStore.get(ALICE_RECIPIENT_ID, 1));
assertTrue(bobKeyExchangeMessage != null);
assertTrue(aliceKeyExchangeMessage != null);
KeyExchangeMessage response = aliceSessionBuilder.process(bobKeyExchangeMessage);
Log.w("SessionBuilderTest", "Record from test 2: " + bobSessionStore.get(ALICE_RECIPIENT_ID, 1));
assertTrue(response == null);
assertTrue(aliceSessionStore.contains(BOB_RECIPIENT_ID, 1));
assertTrue(bobSessionStore.contains(ALICE_RECIPIENT_ID, 1));
runInteraction(aliceSessionStore, bobSessionStore);
}
public void testSimultaneousKeyExchange()
throws InvalidKeyException, DuplicateMessageException, LegacyMessageException, InvalidMessageException
{
SessionStore aliceSessionStore = new InMemorySessionStore();
PreKeyStore alicePreKeyStore = new InMemoryPreKeyStore();
IdentityKeyStore aliceIdentityKeyStore = new InMemoryIdentityKeyStore();
SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceSessionStore, alicePreKeyStore,
aliceIdentityKeyStore,
BOB_RECIPIENT_ID, 1);
SessionStore bobSessionStore = new InMemorySessionStore();
PreKeyStore bobPreKeyStore = new InMemoryPreKeyStore();
IdentityKeyStore bobIdentityKeyStore = new InMemoryIdentityKeyStore();
SessionBuilder bobSessionBuilder = new SessionBuilder(bobSessionStore, bobPreKeyStore,
bobIdentityKeyStore,
ALICE_RECIPIENT_ID, 1);
KeyExchangeMessage aliceKeyExchange = aliceSessionBuilder.process();
KeyExchangeMessage bobKeyExchange = bobSessionBuilder.process();
assertTrue(aliceKeyExchange != null);
assertTrue(bobKeyExchange != null);
KeyExchangeMessage aliceResponse = aliceSessionBuilder.process(bobKeyExchange);
KeyExchangeMessage bobResponse = bobSessionBuilder.process(aliceKeyExchange);
assertTrue(aliceResponse != null);
assertTrue(bobResponse != null);
KeyExchangeMessage aliceAck = aliceSessionBuilder.process(bobResponse);
KeyExchangeMessage bobAck = bobSessionBuilder.process(aliceResponse);
assertTrue(aliceAck == null);
assertTrue(bobAck == null);
runInteraction(aliceSessionStore, bobSessionStore);
}
private void runInteraction(SessionStore aliceSessionStore, SessionStore bobSessionStore)
throws DuplicateMessageException, LegacyMessageException, InvalidMessageException
{
SessionCipher aliceSessionCipher = new SessionCipher(aliceSessionStore, BOB_RECIPIENT_ID, 1);
SessionCipher bobSessionCipher = new SessionCipher(bobSessionStore, ALICE_RECIPIENT_ID, 1);
String originalMessage = "smert ze smert";
CiphertextMessage aliceMessage = aliceSessionCipher.encrypt(originalMessage.getBytes());
assertTrue(aliceMessage.getType() == CiphertextMessage.WHISPER_TYPE);
byte[] plaintext = bobSessionCipher.decrypt(aliceMessage.serialize());
assertTrue(new String(plaintext).equals(originalMessage));
CiphertextMessage bobMessage = bobSessionCipher.encrypt(originalMessage.getBytes());
assertTrue(bobMessage.getType() == CiphertextMessage.WHISPER_TYPE);
plaintext = aliceSessionCipher.decrypt(bobMessage.serialize());
assertTrue(new String(plaintext).equals(originalMessage));
for (int i=0;i<10;i++) {
String loopingMessage = ("You can only desire based on what you know: " + i);
CiphertextMessage aliceLoopingMessage = aliceSessionCipher.encrypt(loopingMessage.getBytes());
byte[] loopingPlaintext = bobSessionCipher.decrypt(aliceLoopingMessage.serialize());
assertTrue(new String(loopingPlaintext).equals(loopingMessage));
}
for (int i=0;i<10;i++) {
String loopingMessage = ("You can only desire based on what you know: " + i);
CiphertextMessage bobLoopingMessage = bobSessionCipher.encrypt(loopingMessage.getBytes());
byte[] loopingPlaintext = aliceSessionCipher.decrypt(bobLoopingMessage.serialize());
assertTrue(new String(loopingPlaintext).equals(loopingMessage));
}
Set<Pair<String, CiphertextMessage>> aliceOutOfOrderMessages = new HashSet<>();
for (int i=0;i<10;i++) {
String loopingMessage = ("You can only desire based on what you know: " + i);
CiphertextMessage aliceLoopingMessage = aliceSessionCipher.encrypt(loopingMessage.getBytes());
aliceOutOfOrderMessages.add(new Pair<>(loopingMessage, aliceLoopingMessage));
}
for (int i=0;i<10;i++) {
String loopingMessage = ("You can only desire based on what you know: " + i);
CiphertextMessage aliceLoopingMessage = aliceSessionCipher.encrypt(loopingMessage.getBytes());
byte[] loopingPlaintext = bobSessionCipher.decrypt(aliceLoopingMessage.serialize());
assertTrue(new String(loopingPlaintext).equals(loopingMessage));
}
for (int i=0;i<10;i++) {
String loopingMessage = ("You can only desire based on what you know: " + i);
CiphertextMessage bobLoopingMessage = bobSessionCipher.encrypt(loopingMessage.getBytes());
byte[] loopingPlaintext = aliceSessionCipher.decrypt(bobLoopingMessage.serialize());
assertTrue(new String(loopingPlaintext).equals(loopingMessage));
}
for (Pair<String, CiphertextMessage> aliceOutOfOrderMessage : aliceOutOfOrderMessages) {
byte[] outOfOrderPlaintext = bobSessionCipher.decrypt(aliceOutOfOrderMessage.second().serialize());
assertTrue(new String(outOfOrderPlaintext).equals(aliceOutOfOrderMessage.first()));
}
}
private class InMemoryPreKey implements PreKey, PreKeyRecord {
private final int keyId;
private final ECKeyPair keyPair;
private final IdentityKey identityKey;
private final int registrationId;
public InMemoryPreKey(int keyId, ECKeyPair keyPair, IdentityKey identityKey, int registrationId) {
this.keyId = keyId;
this.keyPair = keyPair;
this.identityKey = identityKey;
this.registrationId = registrationId;
}
@Override
public int getDeviceId() {
return 1;
}
@Override
public int getKeyId() {
return keyId;
}
@Override
public ECPublicKey getPublicKey() {
return keyPair.getPublicKey();
}
@Override
public IdentityKey getIdentityKey() {
return identityKey;
}
@Override
public int getRegistrationId() {
return registrationId;
}
@Override
public int getId() {
return keyId;
}
@Override
public ECKeyPair getKeyPair() {
return keyPair;
}
@Override
public byte[] serialize() {
throw new AssertionError("nyi");
}
}
}

View File

@ -14,6 +14,7 @@ import org.whispersystems.libaxolotl.state.PreKeyRecord;
import org.whispersystems.libaxolotl.state.PreKeyStore; import org.whispersystems.libaxolotl.state.PreKeyStore;
import org.whispersystems.libaxolotl.state.SessionRecord; import org.whispersystems.libaxolotl.state.SessionRecord;
import org.whispersystems.libaxolotl.state.SessionStore; import org.whispersystems.libaxolotl.state.SessionStore;
import org.whispersystems.libaxolotl.util.Helper;
import org.whispersystems.libaxolotl.util.Medium; import org.whispersystems.libaxolotl.util.Medium;
public class SessionBuilder { public class SessionBuilder {
@ -176,5 +177,22 @@ public class SessionBuilder {
return responseMessage; return responseMessage;
} }
public KeyExchangeMessage process() {
int sequence = Helper.getRandomSequence(65534) + 1;
int flags = KeyExchangeMessage.INITIATE_FLAG;
ECKeyPair baseKey = Curve.generateKeyPair(true);
ECKeyPair ephemeralKey = Curve.generateKeyPair(true);
IdentityKeyPair identityKey = identityKeyStore.getIdentityKeyPair();
SessionRecord sessionRecord = sessionStore.get(recipientId, deviceId);
sessionRecord.getSessionState().setPendingKeyExchange(sequence, baseKey, ephemeralKey, identityKey);
sessionStore.put(recipientId, deviceId, sessionRecord);
return new KeyExchangeMessage(sequence, flags,
baseKey.getPublicKey(),
ephemeralKey.getPublicKey(),
identityKey.getPublicKey());
}
} }

View File

@ -0,0 +1,16 @@
package org.whispersystems.libaxolotl.util;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
public class Helper {
public static int getRandomSequence(int max) {
try {
return SecureRandom.getInstance("SHA1PRNG").nextInt(max);
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
}

View File

@ -34,9 +34,8 @@ public class Hex {
public static String toString(byte[] bytes, int offset, int length) { public static String toString(byte[] bytes, int offset, int length) {
StringBuffer buf = new StringBuffer(); StringBuffer buf = new StringBuffer();
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
buf.append("(byte)0x");
appendHexChar(buf, bytes[offset + i]); appendHexChar(buf, bytes[offset + i]);
buf.append(", "); buf.append(" ");
} }
return buf.toString(); return buf.toString();
} }

View File

@ -26,20 +26,18 @@ import org.thoughtcrime.securesms.recipients.Recipient;
import org.thoughtcrime.securesms.sms.MessageSender; import org.thoughtcrime.securesms.sms.MessageSender;
import org.thoughtcrime.securesms.sms.OutgoingKeyExchangeMessage; import org.thoughtcrime.securesms.sms.OutgoingKeyExchangeMessage;
import org.thoughtcrime.securesms.util.Dialogs; import org.thoughtcrime.securesms.util.Dialogs;
import org.whispersystems.libaxolotl.IdentityKeyPair; import org.whispersystems.libaxolotl.SessionBuilder;
import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.protocol.KeyExchangeMessage; import org.whispersystems.libaxolotl.protocol.KeyExchangeMessage;
import org.whispersystems.libaxolotl.state.IdentityKeyStore;
import org.whispersystems.libaxolotl.state.PreKeyStore;
import org.whispersystems.libaxolotl.state.SessionRecord; import org.whispersystems.libaxolotl.state.SessionRecord;
import org.whispersystems.libaxolotl.state.SessionStore; import org.whispersystems.libaxolotl.state.SessionStore;
import org.whispersystems.textsecure.crypto.MasterSecret; import org.whispersystems.textsecure.crypto.MasterSecret;
import org.whispersystems.textsecure.storage.RecipientDevice; import org.whispersystems.textsecure.storage.RecipientDevice;
import org.whispersystems.textsecure.storage.TextSecurePreKeyStore;
import org.whispersystems.textsecure.storage.TextSecureSessionStore; import org.whispersystems.textsecure.storage.TextSecureSessionStore;
import org.whispersystems.textsecure.util.Base64; import org.whispersystems.textsecure.util.Base64;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
public class KeyExchangeInitiator { public class KeyExchangeInitiator {
public static void initiate(final Context context, final MasterSecret masterSecret, final Recipient recipient, boolean promptOnExisting) { public static void initiate(final Context context, final MasterSecret masterSecret, final Recipient recipient, boolean promptOnExisting) {
@ -62,23 +60,17 @@ public class KeyExchangeInitiator {
} }
private static void initiateKeyExchange(Context context, MasterSecret masterSecret, Recipient recipient) { private static void initiateKeyExchange(Context context, MasterSecret masterSecret, Recipient recipient) {
int sequence = getRandomSequence();
int flags = KeyExchangeMessage.INITIATE_FLAG;
ECKeyPair baseKey = Curve.generateKeyPair(true);
ECKeyPair ephemeralKey = Curve.generateKeyPair(true);
IdentityKeyPair identityKey = IdentityKeyUtil.getIdentityKeyPair(context, masterSecret);
KeyExchangeMessage message = new KeyExchangeMessage(sequence, flags,
baseKey.getPublicKey(),
ephemeralKey.getPublicKey(),
identityKey.getPublicKey());
OutgoingKeyExchangeMessage textMessage = new OutgoingKeyExchangeMessage(recipient, Base64.encodeBytesWithoutPadding(message.serialize()));
SessionStore sessionStore = new TextSecureSessionStore(context, masterSecret); SessionStore sessionStore = new TextSecureSessionStore(context, masterSecret);
SessionRecord sessionRecord = sessionStore.get(recipient.getRecipientId(), RecipientDevice.DEFAULT_DEVICE_ID); PreKeyStore preKeyStore = new TextSecurePreKeyStore(context, masterSecret);
IdentityKeyStore identityKeyStore = new TextSecureIdentityKeyStore(context, masterSecret);
sessionRecord.getSessionState().setPendingKeyExchange(sequence, baseKey, ephemeralKey, identityKey); SessionBuilder sessionBuilder = new SessionBuilder(sessionStore, preKeyStore, identityKeyStore,
sessionStore.put(recipient.getRecipientId(), RecipientDevice.DEFAULT_DEVICE_ID, sessionRecord); recipient.getRecipientId(),
RecipientDevice.DEFAULT_DEVICE_ID);
KeyExchangeMessage keyExchangeMessage = sessionBuilder.process();
String serializedMessage = Base64.encodeBytesWithoutPadding(keyExchangeMessage.serialize());
OutgoingKeyExchangeMessage textMessage = new OutgoingKeyExchangeMessage(recipient, serializedMessage);
MessageSender.send(context, masterSecret, textMessage, -1, false); MessageSender.send(context, masterSecret, textMessage, -1, false);
} }
@ -91,15 +83,4 @@ public class KeyExchangeInitiator {
return sessionRecord.getSessionState().hasPendingPreKey(); return sessionRecord.getSessionState().hasPendingPreKey();
} }
private static int getRandomSequence() {
try {
SecureRandom random = SecureRandom.getInstance("SHA1PRNG");
int candidate = Math.abs(random.nextInt());
return candidate % 65535;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
} }