Add MnemonicCodecTest

This commit is contained in:
Andrew 2024-06-25 19:11:13 +09:30
parent 9288701556
commit 72dccaa1d3
3 changed files with 110 additions and 51 deletions

View File

@ -15,7 +15,6 @@ import network.loki.messenger.R
import org.session.libsignal.crypto.MnemonicCodec
import org.session.libsignal.crypto.MnemonicCodec.DecodingError.InputTooShort
import org.session.libsignal.crypto.MnemonicCodec.DecodingError.InvalidWord
import org.session.libsignal.crypto.MnemonicCodec.DecodingError.MissingLastWord
import org.session.libsignal.utilities.Hex
import org.thoughtcrime.securesms.crypto.MnemonicUtilities
import javax.inject.Inject
@ -45,17 +44,21 @@ internal class LinkDeviceViewModel @Inject constructor(
fun onContinue() {
viewModelScope.launch {
runDecodeCatching(state.value.recoveryPhrase)
.onSuccess(::onSuccess)
.onFailure(::onFailure)
try {
decode(state.value.recoveryPhrase).let(::onSuccess)
} catch (e: Exception) {
onFailure(e)
}
}
}
fun onScanQrCode(string: String) {
viewModelScope.launch {
runDecodeCatching(string)
.onSuccess(::onSuccess)
.onFailure(::onQrCodeScanFailure)
try {
decode(string).let(::onSuccess)
} catch (e: Exception) {
onFailure(e)
}
}
}
@ -70,9 +73,8 @@ internal class LinkDeviceViewModel @Inject constructor(
state.update {
it.copy(
error = when (error) {
is InputTooShort,
is MissingLastWord -> R.string.recoveryPasswordErrorMessageShort
is InvalidWord -> R.string.recoveryPasswordErrorMessageIncorrect
is InputTooShort -> R.string.recoveryPasswordErrorMessageShort
else -> R.string.recoveryPasswordErrorMessageGeneric
}.let(application::getString)
)
@ -83,8 +85,5 @@ internal class LinkDeviceViewModel @Inject constructor(
viewModelScope.launch { _qrErrors.emit(error) }
}
private fun runDecodeCatching(mnemonic: String) = runCatching {
decode(mnemonic)
}
private fun decode(mnemonic: String) = codec.decode(mnemonic).let(Hex::fromStringCondensed)!!
}

File diff suppressed because one or more lines are too long

View File

@ -36,7 +36,6 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
sealed class DecodingError(val description: String) : Exception(description) {
object Generic : DecodingError("Something went wrong. Please check your mnemonic and try again.")
object InputTooShort : DecodingError("Looks like you didn't enter enough words. Please check your mnemonic and try again.")
object MissingLastWord : DecodingError("You seem to be missing the last word of your mnemonic. Please check what you entered and try again.")
object InvalidWord : DecodingError("There appears to be an invalid word in your mnemonic. Please check what you entered and try again.")
object VerificationFailed : DecodingError("Your mnemonic couldn't be verified. Please check what you entered and try again.")
}
@ -46,7 +45,7 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
val language = Language(loadFileContents, languageConfiguration)
val wordSet = language.loadWordSet()
val prefixLength = languageConfiguration.prefixLength
val result = mutableListOf<String>()
val n = wordSet.size.toLong()
val characterCount = string.length
for (chunkStartIndex in 0..(characterCount - 8) step 8) {
@ -56,54 +55,56 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
val p3 = string.substring(chunkEndIndex until characterCount)
string = p1 + p2 + p3
}
for (chunkStartIndex in 0..(characterCount - 8) step 8) {
val chunkEndIndex = chunkStartIndex + 8
val x = string.substring(chunkStartIndex until chunkEndIndex).toLong(16)
return string.windowed(8, 8).map {
val x = it.toLong(16)
val w1 = x % n
val w2 = ((x / n) + w1) % n
val w3 = (((x / n) / n) + w2) % n
result += listOf( wordSet[w1.toInt()], wordSet[w2.toInt()], wordSet[w3.toInt()] )
}
val checksumIndex = determineChecksumIndex(result, prefixLength)
val checksumWord = result[checksumIndex]
result.add(checksumWord)
return result.joinToString(" ")
listOf(w1, w2, w3).map(Long::toInt).map { wordSet[it] }
}.flatten().let {
val checksumIndex = determineChecksumIndex(it, prefixLength)
it + it[checksumIndex]
}.joinToString(" ")
}
fun decode(mnemonic: String, languageConfiguration: Language.Configuration = Language.Configuration.english): String {
val words = mnemonic.split(" ").toMutableList()
val words = mnemonic.split(" ")
val language = Language(loadFileContents, languageConfiguration)
val truncatedWordSet = language.loadTruncatedWordSet()
val prefixLength = languageConfiguration.prefixLength
var result = ""
val n = truncatedWordSet.size.toLong()
if (mnemonic.isEmpty()) throw IllegalArgumentException()
if (words.isEmpty()) throw IllegalArgumentException()
fun String.prefix() = substring(0 until prefixLength)
// Throw on invalid words, as this is the most difficult issue for a user to solve, do this first.
val wordPrefixes = words
.onEach { if (it.length < prefixLength) throw DecodingError.InvalidWord }
.map { it.prefix() }
val wordIndexes = wordPrefixes.map { truncatedWordSet.indexOf(it) }
.onEach { if (it < 0) throw DecodingError.InvalidWord }
// Check preconditions
if (words.size < 12) throw DecodingError.InputTooShort
if (words.size % 3 == 0) throw DecodingError.MissingLastWord
// Get checksum word
val checksumWord = words.removeAt(words.lastIndex)
// Decode
for (chunkStartIndex in 0..(words.size - 3) step 3) {
try {
val w1 = truncatedWordSet.indexOf(words[chunkStartIndex].substring(0 until prefixLength))
val w2 = truncatedWordSet.indexOf(words[chunkStartIndex + 1].substring(0 until prefixLength))
val w3 = truncatedWordSet.indexOf(words[chunkStartIndex + 2].substring(0 until prefixLength))
val x = w1 + n * ((n - w1 + w2) % n) + n * n * ((n - w2 + w3) % n)
if (x % n != w1.toLong()) { throw DecodingError.Generic
}
val string = "0000000" + x.toString(16)
result += swap(string.substring(string.length - 8 until string.length))
} catch (e: Exception) {
throw DecodingError.InvalidWord
}
}
if (words.size < 13) throw DecodingError.InputTooShort
// Verify checksum
val checksumIndex = determineChecksumIndex(words, prefixLength)
val checksumIndex = determineChecksumIndex(words.dropLast(1), prefixLength)
val expectedChecksumWord = words[checksumIndex]
if (expectedChecksumWord.substring(0 until prefixLength) != checksumWord.substring(0 until prefixLength)) { throw DecodingError.VerificationFailed
if (expectedChecksumWord.prefix() != wordPrefixes.last()) {
throw DecodingError.VerificationFailed
}
// Return
return result
// Decode
return wordIndexes.windowed(3, 3) { (w1, w2, w3) ->
val x = w1 + n * ((n - w1 + w2) % n) + n * n * ((n - w2 + w3) % n)
if (x % n != w1.toLong()) throw DecodingError.Generic
val string = "0000000" + x.toString(16)
swap(string.substring(string.length - 8 until string.length))
}.joinToString(separator = "") { it }
}
private fun swap(x: String): String {
@ -116,9 +117,7 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
private fun determineChecksumIndex(x: List<String>, prefixLength: Int): Int {
val bytes = x.joinToString("") { it.substring(0 until prefixLength) }.toByteArray()
val crc32 = CRC32()
crc32.update(bytes)
val checksum = crc32.value
val checksum = CRC32().apply { update(bytes) }.value
return (checksum % x.size.toLong()).toInt()
}
}