Collapse SessionRecord, SessionState, and PreKeyRecord interfaces.

This commit is contained in:
Moxie Marlinspike
2014-04-24 15:39:55 -07:00
parent 5a3c19fe3e
commit a601c56af1
25 changed files with 1271 additions and 1836 deletions

View File

@@ -4,25 +4,30 @@ import org.whispersystems.libaxolotl.InvalidKeyIdException;
import org.whispersystems.libaxolotl.state.PreKeyRecord;
import org.whispersystems.libaxolotl.state.PreKeyStore;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
public class InMemoryPreKeyStore implements PreKeyStore {
private final Map<Integer, PreKeyRecord> store = new HashMap<>();
private final Map<Integer, byte[]> store = new HashMap<>();
@Override
public PreKeyRecord load(int preKeyId) throws InvalidKeyIdException {
if (!store.containsKey(preKeyId)) {
throw new InvalidKeyIdException("No such prekeyrecord!");
}
try {
if (!store.containsKey(preKeyId)) {
throw new InvalidKeyIdException("No such prekeyrecord!");
}
return store.get(preKeyId);
return new PreKeyRecord(store.get(preKeyId));
} catch (IOException e) {
throw new AssertionError(e);
}
}
@Override
public void store(int preKeyId, PreKeyRecord record) {
store.put(preKeyId, record);
store.put(preKeyId, record.serialize());
}
@Override

View File

@@ -1,54 +0,0 @@
package org.whispersystems.test;
import org.whispersystems.libaxolotl.state.SessionRecord;
import org.whispersystems.libaxolotl.state.SessionState;
import java.util.LinkedList;
import java.util.List;
public class InMemorySessionRecord implements SessionRecord {
private SessionState currentSessionState;
private List<SessionState> previousSessionStates;
public InMemorySessionRecord() {
currentSessionState = new InMemorySessionState();
previousSessionStates = new LinkedList<>();
}
public InMemorySessionRecord(SessionRecord copy) {
currentSessionState = new InMemorySessionState(copy.getSessionState());
previousSessionStates = new LinkedList<>();
for (SessionState previousState : copy.getPreviousSessionStates()) {
previousSessionStates.add(new InMemorySessionState(previousState));
}
}
@Override
public SessionState getSessionState() {
return currentSessionState;
}
@Override
public List<SessionState> getPreviousSessionStates() {
return previousSessionStates;
}
@Override
public void reset() {
this.currentSessionState = new InMemorySessionState();
this.previousSessionStates = new LinkedList<>();
}
@Override
public void archiveCurrentState() {
this.previousSessionStates.add(currentSessionState);
this.currentSessionState = new InMemorySessionState();
}
@Override
public byte[] serialize() {
throw new AssertionError();
}
}

View File

@@ -1,402 +0,0 @@
package org.whispersystems.test;
import org.whispersystems.libaxolotl.IdentityKey;
import org.whispersystems.libaxolotl.IdentityKeyPair;
import org.whispersystems.libaxolotl.InvalidKeyException;
import org.whispersystems.libaxolotl.ecc.Curve;
import org.whispersystems.libaxolotl.ecc.ECKeyPair;
import org.whispersystems.libaxolotl.ecc.ECPrivateKey;
import org.whispersystems.libaxolotl.ecc.ECPublicKey;
import org.whispersystems.libaxolotl.ratchet.ChainKey;
import org.whispersystems.libaxolotl.ratchet.MessageKeys;
import org.whispersystems.libaxolotl.ratchet.RootKey;
import org.whispersystems.libaxolotl.state.SessionState;
import org.whispersystems.libaxolotl.util.Pair;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.crypto.spec.SecretKeySpec;
public class InMemorySessionState implements SessionState {
private Map<ECPublicKey, InMemoryChain> receiverChains = new HashMap<>();
private boolean needsRefresh;
private int sessionVersion;
private IdentityKey remoteIdentityKey;
private IdentityKey localIdentityKey;
private int previousCounter;
private RootKey rootKey;
private ECKeyPair senderEphemeral;
private ChainKey senderChainKey;
private int pendingPreKeyid;
private ECPublicKey pendingPreKey;
private int remoteRegistrationId;
private int localRegistrationId;
private InMemoryPendingKeyExchange pendingKeyExchange;
public InMemorySessionState() {}
public InMemorySessionState(SessionState sessionState) {
try {
this.needsRefresh = sessionState.getNeedsRefresh();
this.sessionVersion = sessionState.getSessionVersion();
if (sessionState.getRemoteIdentityKey() != null) {
this.remoteIdentityKey = new IdentityKey(sessionState.getRemoteIdentityKey().serialize(), 0);
}
if (sessionState.getLocalIdentityKey() != null) {
this.localIdentityKey = new IdentityKey(sessionState.getLocalIdentityKey().serialize(), 0);
}
this.previousCounter = sessionState.getPreviousCounter();
if (sessionState.getRootKey() != null) {
this.rootKey = new RootKey(sessionState.getRootKey().getKeyBytes());
}
this.senderEphemeral = sessionState.getSenderEphemeralPair();
if (sessionState.getSenderChainKey() != null) {
this.senderChainKey = new ChainKey(sessionState.getSenderChainKey().getKey(),
sessionState.getSenderChainKey().getIndex());
}
if (sessionState.getPendingPreKey() != null) {
this.pendingPreKeyid = sessionState.getPendingPreKey().first();
}
if (sessionState.getPendingPreKey() != null) {
this.pendingPreKey = sessionState.getPendingPreKey().second();
}
this.remoteRegistrationId = sessionState.getRemoteRegistrationId();
this.localRegistrationId = sessionState.getLocalRegistrationId();
if (sessionState.hasPendingKeyExchange()) {
pendingKeyExchange = new InMemoryPendingKeyExchange();
pendingKeyExchange.sequence = sessionState.getPendingKeyExchangeSequence();
pendingKeyExchange.localBaseKey = sessionState.getPendingKeyExchangeBaseKey()
.getPublicKey().serialize();
pendingKeyExchange.localBaseKeyPrivate = sessionState.getPendingKeyExchangeBaseKey()
.getPrivateKey().serialize();
pendingKeyExchange.localEphemeralKey = sessionState.getPendingKeyExchangeEphemeralKey()
.getPublicKey().serialize();
pendingKeyExchange.localEphemeralKeyPrivate = sessionState.getPendingKeyExchangeEphemeralKey()
.getPrivateKey().serialize();
pendingKeyExchange.localIdentityKey = sessionState.getPendingKeyExchangeIdentityKey()
.getPublicKey().serialize();
pendingKeyExchange.localIdentityKeyPrivate = sessionState.getPendingKeyExchangeIdentityKey()
.getPrivateKey().serialize();
}
for (ECPublicKey key : ((InMemorySessionState)sessionState).receiverChains.keySet()) {
ECPublicKey chainKey = Curve.decodePoint(key.serialize(), 0);
InMemoryChain ourChain = new InMemoryChain();
InMemoryChain theirChain = ((InMemorySessionState)sessionState).receiverChains.get(key);
ourChain.chainKey = theirChain.chainKey;
ourChain.index = theirChain.index;
ourChain.messageKeys = theirChain.messageKeys;
receiverChains.put(chainKey, ourChain);
}
} catch (InvalidKeyException e) {
throw new AssertionError(e);
}
}
@Override
public void setNeedsRefresh(boolean needsRefresh) {
this.needsRefresh = needsRefresh;
}
@Override
public boolean getNeedsRefresh() {
return needsRefresh;
}
@Override
public void setSessionVersion(int version) {
this.sessionVersion = version;
}
@Override
public int getSessionVersion() {
return sessionVersion;
}
@Override
public void setRemoteIdentityKey(IdentityKey identityKey) {
this.remoteIdentityKey = identityKey;
}
@Override
public void setLocalIdentityKey(IdentityKey identityKey) {
this.localIdentityKey = identityKey;
}
@Override
public IdentityKey getRemoteIdentityKey() {
return remoteIdentityKey;
}
@Override
public IdentityKey getLocalIdentityKey() {
return localIdentityKey;
}
@Override
public int getPreviousCounter() {
return previousCounter;
}
@Override
public void setPreviousCounter(int previousCounter) {
this.previousCounter = previousCounter;
}
@Override
public RootKey getRootKey() {
return rootKey;
}
@Override
public void setRootKey(RootKey rootKey) {
this.rootKey = rootKey;
}
@Override
public ECPublicKey getSenderEphemeral() {
return senderEphemeral.getPublicKey();
}
@Override
public ECKeyPair getSenderEphemeralPair() {
return senderEphemeral;
}
@Override
public boolean hasReceiverChain(ECPublicKey senderEphemeral) {
return receiverChains.containsKey(senderEphemeral);
}
@Override
public boolean hasSenderChain() {
return senderChainKey != null;
}
@Override
public ChainKey getReceiverChainKey(ECPublicKey senderEphemeral) {
InMemoryChain chain = receiverChains.get(senderEphemeral);
return new ChainKey(chain.chainKey, chain.index);
}
@Override
public void addReceiverChain(ECPublicKey senderEphemeral, ChainKey chainKey) {
InMemoryChain chain = new InMemoryChain();
chain.chainKey = chainKey.getKey();
chain.index = chainKey.getIndex();
receiverChains.put(senderEphemeral, chain);
}
@Override
public void setSenderChain(ECKeyPair senderEphemeralPair, ChainKey chainKey) {
this.senderEphemeral = senderEphemeralPair;
this.senderChainKey = chainKey;
}
@Override
public ChainKey getSenderChainKey() {
return senderChainKey;
}
@Override
public void setSenderChainKey(ChainKey nextChainKey) {
this.senderChainKey = nextChainKey;
}
@Override
public boolean hasMessageKeys(ECPublicKey senderEphemeral, int counter) {
InMemoryChain chain = receiverChains.get(senderEphemeral);
if (chain == null) return false;
for (InMemoryChain.InMemoryMessageKey messageKey : chain.messageKeys) {
if (messageKey.index == counter) {
return true;
}
}
return false;
}
@Override
public MessageKeys removeMessageKeys(ECPublicKey senderEphemeral, int counter) {
InMemoryChain chain = receiverChains.get(senderEphemeral);
MessageKeys results = null;
if (chain == null) return null;
Iterator<InMemoryChain.InMemoryMessageKey> iterator = chain.messageKeys.iterator();
while (iterator.hasNext()) {
InMemoryChain.InMemoryMessageKey messageKey = iterator.next();
if (messageKey.index == counter) {
results = new MessageKeys(new SecretKeySpec(messageKey.cipherKey, "AES"),
new SecretKeySpec(messageKey.macKey, "HmacSHA256"),
messageKey.index);
iterator.remove();
break;
}
}
return results;
}
@Override
public void setMessageKeys(ECPublicKey senderEphemeral, MessageKeys messageKeys) {
InMemoryChain chain = receiverChains.get(senderEphemeral);
InMemoryChain.InMemoryMessageKey key = new InMemoryChain.InMemoryMessageKey();
key.cipherKey = messageKeys.getCipherKey().getEncoded();
key.macKey = messageKeys.getMacKey().getEncoded();
key.index = messageKeys.getCounter();
chain.messageKeys.add(key);
}
@Override
public void setReceiverChainKey(ECPublicKey senderEphemeral, ChainKey chainKey) {
InMemoryChain chain = receiverChains.get(senderEphemeral);
chain.chainKey = chainKey.getKey();
chain.index = chainKey.getIndex();
}
@Override
public void setPendingKeyExchange(int sequence, ECKeyPair ourBaseKey, ECKeyPair ourEphemeralKey,
IdentityKeyPair ourIdentityKey)
{
pendingKeyExchange = new InMemoryPendingKeyExchange();
pendingKeyExchange.sequence = sequence;
pendingKeyExchange.localBaseKey = ourBaseKey.getPublicKey().serialize();
pendingKeyExchange.localBaseKeyPrivate = ourBaseKey.getPrivateKey().serialize();
pendingKeyExchange.localEphemeralKey = ourEphemeralKey.getPublicKey().serialize();
pendingKeyExchange.localEphemeralKeyPrivate = ourEphemeralKey.getPrivateKey().serialize();
pendingKeyExchange.localIdentityKey = ourIdentityKey.getPublicKey().serialize();
pendingKeyExchange.localIdentityKeyPrivate = ourIdentityKey.getPrivateKey().serialize();
}
@Override
public int getPendingKeyExchangeSequence() {
return pendingKeyExchange == null ? 0 : pendingKeyExchange.sequence;
}
@Override
public ECKeyPair getPendingKeyExchangeBaseKey() throws InvalidKeyException {
ECPublicKey publicKey = Curve.decodePoint(pendingKeyExchange.localBaseKey, 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(pendingKeyExchange.localBaseKeyPrivate);
return new ECKeyPair(publicKey, privateKey);
}
@Override
public ECKeyPair getPendingKeyExchangeEphemeralKey() throws InvalidKeyException {
ECPublicKey publicKey = Curve.decodePoint(pendingKeyExchange.localEphemeralKey, 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(pendingKeyExchange.localEphemeralKeyPrivate);
return new ECKeyPair(publicKey, privateKey);
}
@Override
public IdentityKeyPair getPendingKeyExchangeIdentityKey() throws InvalidKeyException {
IdentityKey publicKey = new IdentityKey(pendingKeyExchange.localIdentityKey, 0);
ECPrivateKey privateKey = Curve.decodePrivatePoint(pendingKeyExchange.localIdentityKeyPrivate);
return new IdentityKeyPair(publicKey, privateKey);
}
@Override
public boolean hasPendingKeyExchange() {
return pendingKeyExchange != null;
}
@Override
public void setPendingPreKey(int preKeyId, ECPublicKey baseKey) {
this.pendingPreKeyid = preKeyId;
this.pendingPreKey = baseKey;
}
@Override
public boolean hasPendingPreKey() {
return this.pendingPreKey != null;
}
@Override
public Pair<Integer, ECPublicKey> getPendingPreKey() {
return new Pair<>(pendingPreKeyid, pendingPreKey);
}
@Override
public void clearPendingPreKey() {
this.pendingPreKey = null;
this.pendingPreKeyid = -1;
}
@Override
public void setRemoteRegistrationId(int registrationId) {
this.remoteRegistrationId = registrationId;
}
@Override
public int getRemoteRegistrationId() {
return remoteRegistrationId;
}
@Override
public void setLocalRegistrationId(int registrationId) {
this.localRegistrationId = registrationId;
}
@Override
public int getLocalRegistrationId() {
return localRegistrationId;
}
@Override
public byte[] serialize() {
throw new AssertionError();
}
private static class InMemoryChain {
byte[] chainKey;
int index;
List<InMemoryMessageKey> messageKeys = new LinkedList<>();
public static class InMemoryMessageKey {
public InMemoryMessageKey(){}
int index;
byte[] cipherKey;
byte[] macKey;
}
}
private static class InMemoryPendingKeyExchange {
int sequence;
byte[] localBaseKey;
byte[] localBaseKeyPrivate;
byte[] localEphemeralKey;
byte[] localEphemeralKeyPrivate;
byte[] localIdentityKey;
byte[] localIdentityKeyPrivate;
}
}

View File

@@ -4,6 +4,7 @@ import org.whispersystems.libaxolotl.state.SessionRecord;
import org.whispersystems.libaxolotl.state.SessionStore;
import org.whispersystems.libaxolotl.util.Pair;
import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
@@ -11,16 +12,20 @@ import java.util.Map;
public class InMemorySessionStore implements SessionStore {
private Map<Pair<Long, Integer>, SessionRecord> sessions = new HashMap<>();
private Map<Pair<Long, Integer>, byte[]> sessions = new HashMap<>();
public InMemorySessionStore() {}
@Override
public synchronized SessionRecord load(long recipientId, int deviceId) {
if (contains(recipientId, deviceId)) {
return new InMemorySessionRecord(sessions.get(new Pair<>(recipientId, deviceId)));
} else {
return new InMemorySessionRecord();
try {
if (contains(recipientId, deviceId)) {
return new SessionRecord(sessions.get(new Pair<>(recipientId, deviceId)));
} else {
return new SessionRecord();
}
} catch (IOException e) {
throw new AssertionError(e);
}
}
@@ -39,7 +44,7 @@ public class InMemorySessionStore implements SessionStore {
@Override
public synchronized void store(long recipientId, int deviceId, SessionRecord record) {
sessions.put(new Pair<>(recipientId, deviceId), record);
sessions.put(new Pair<>(recipientId, deviceId), record.serialize());
}
@Override

View File

@@ -220,16 +220,13 @@ public class SessionBuilderTest extends AndroidTestCase {
}
}
private class InMemoryPreKey implements PreKey, PreKeyRecord {
private class InMemoryPreKey extends PreKeyRecord implements PreKey {
private final int keyId;
private final ECKeyPair keyPair;
private final IdentityKey identityKey;
private final int registrationId;
public InMemoryPreKey(int keyId, ECKeyPair keyPair, IdentityKey identityKey, int registrationId) {
this.keyId = keyId;
this.keyPair = keyPair;
super(keyId, keyPair);
this.identityKey = identityKey;
this.registrationId = registrationId;
}
@@ -241,12 +238,12 @@ public class SessionBuilderTest extends AndroidTestCase {
@Override
public int getKeyId() {
return keyId;
return getId();
}
@Override
public ECPublicKey getPublicKey() {
return keyPair.getPublicKey();
return getKeyPair().getPublicKey();
}
@Override
@@ -258,21 +255,6 @@ public class SessionBuilderTest extends AndroidTestCase {
public int getRegistrationId() {
return registrationId;
}
@Override
public int getId() {
return keyId;
}
@Override
public ECKeyPair getKeyPair() {
return keyPair;
}
@Override
public byte[] serialize() {
throw new AssertionError("nyi");
}
}
}

View File

@@ -25,8 +25,8 @@ public class SessionCipherTest extends AndroidTestCase {
throws InvalidKeyException, DuplicateMessageException,
LegacyMessageException, InvalidMessageException
{
SessionRecord aliceSessionRecord = new InMemorySessionRecord();
SessionRecord bobSessionRecord = new InMemorySessionRecord();
SessionRecord aliceSessionRecord = new SessionRecord();
SessionRecord bobSessionRecord = new SessionRecord();
initializeSessions(aliceSessionRecord.getSessionState(), bobSessionRecord.getSessionState());

View File

@@ -4,7 +4,6 @@ import android.test.AndroidTestCase;
import org.whispersystems.libaxolotl.IdentityKey;
import org.whispersystems.libaxolotl.IdentityKeyPair;
import org.whispersystems.test.InMemorySessionState;
import org.whispersystems.libaxolotl.InvalidKeyException;
import org.whispersystems.libaxolotl.state.SessionState;
import org.whispersystems.libaxolotl.ecc.Curve;
@@ -106,7 +105,7 @@ public class RatchetingSessionTest extends AndroidTestCase {
ECPublicKey aliceEphemeralPublicKey = Curve.decodePoint(aliceEphemeralPublic, 0);
IdentityKey aliceIdentityPublicKey = new IdentityKey(aliceIdentityPublic, 0);
SessionState session = new InMemorySessionState();
SessionState session = new SessionState();
RatchetingSession.initializeSession(session, bobBaseKey, aliceBasePublicKey,
bobEphemeralKey, aliceEphemeralPublicKey,
@@ -203,7 +202,7 @@ public class RatchetingSessionTest extends AndroidTestCase {
ECPrivateKey aliceIdentityPrivateKey = Curve.decodePrivatePoint(aliceIdentityPrivate);
IdentityKeyPair aliceIdentityKey = new IdentityKeyPair(aliceIdentityPublicKey, aliceIdentityPrivateKey);
SessionState session = new InMemorySessionState();
SessionState session = new SessionState();
RatchetingSession.initializeSession(session, aliceBaseKey, bobBasePublicKey,
aliceEphemeralKey, bobEphemeralPublicKey,