mirror of
https://github.com/oxen-io/session-android.git
synced 2024-11-25 02:55:23 +00:00
Add MnemonicCodecTest
This commit is contained in:
parent
9288701556
commit
72dccaa1d3
@ -15,7 +15,6 @@ import network.loki.messenger.R
|
||||
import org.session.libsignal.crypto.MnemonicCodec
|
||||
import org.session.libsignal.crypto.MnemonicCodec.DecodingError.InputTooShort
|
||||
import org.session.libsignal.crypto.MnemonicCodec.DecodingError.InvalidWord
|
||||
import org.session.libsignal.crypto.MnemonicCodec.DecodingError.MissingLastWord
|
||||
import org.session.libsignal.utilities.Hex
|
||||
import org.thoughtcrime.securesms.crypto.MnemonicUtilities
|
||||
import javax.inject.Inject
|
||||
@ -45,17 +44,21 @@ internal class LinkDeviceViewModel @Inject constructor(
|
||||
|
||||
fun onContinue() {
|
||||
viewModelScope.launch {
|
||||
runDecodeCatching(state.value.recoveryPhrase)
|
||||
.onSuccess(::onSuccess)
|
||||
.onFailure(::onFailure)
|
||||
try {
|
||||
decode(state.value.recoveryPhrase).let(::onSuccess)
|
||||
} catch (e: Exception) {
|
||||
onFailure(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun onScanQrCode(string: String) {
|
||||
viewModelScope.launch {
|
||||
runDecodeCatching(string)
|
||||
.onSuccess(::onSuccess)
|
||||
.onFailure(::onQrCodeScanFailure)
|
||||
try {
|
||||
decode(string).let(::onSuccess)
|
||||
} catch (e: Exception) {
|
||||
onFailure(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,9 +73,8 @@ internal class LinkDeviceViewModel @Inject constructor(
|
||||
state.update {
|
||||
it.copy(
|
||||
error = when (error) {
|
||||
is InputTooShort,
|
||||
is MissingLastWord -> R.string.recoveryPasswordErrorMessageShort
|
||||
is InvalidWord -> R.string.recoveryPasswordErrorMessageIncorrect
|
||||
is InputTooShort -> R.string.recoveryPasswordErrorMessageShort
|
||||
else -> R.string.recoveryPasswordErrorMessageGeneric
|
||||
}.let(application::getString)
|
||||
)
|
||||
@ -83,8 +85,5 @@ internal class LinkDeviceViewModel @Inject constructor(
|
||||
viewModelScope.launch { _qrErrors.emit(error) }
|
||||
}
|
||||
|
||||
private fun runDecodeCatching(mnemonic: String) = runCatching {
|
||||
decode(mnemonic)
|
||||
}
|
||||
private fun decode(mnemonic: String) = codec.decode(mnemonic).let(Hex::fromStringCondensed)!!
|
||||
}
|
||||
|
File diff suppressed because one or more lines are too long
@ -36,7 +36,6 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
|
||||
sealed class DecodingError(val description: String) : Exception(description) {
|
||||
object Generic : DecodingError("Something went wrong. Please check your mnemonic and try again.")
|
||||
object InputTooShort : DecodingError("Looks like you didn't enter enough words. Please check your mnemonic and try again.")
|
||||
object MissingLastWord : DecodingError("You seem to be missing the last word of your mnemonic. Please check what you entered and try again.")
|
||||
object InvalidWord : DecodingError("There appears to be an invalid word in your mnemonic. Please check what you entered and try again.")
|
||||
object VerificationFailed : DecodingError("Your mnemonic couldn't be verified. Please check what you entered and try again.")
|
||||
}
|
||||
@ -46,7 +45,7 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
|
||||
val language = Language(loadFileContents, languageConfiguration)
|
||||
val wordSet = language.loadWordSet()
|
||||
val prefixLength = languageConfiguration.prefixLength
|
||||
val result = mutableListOf<String>()
|
||||
|
||||
val n = wordSet.size.toLong()
|
||||
val characterCount = string.length
|
||||
for (chunkStartIndex in 0..(characterCount - 8) step 8) {
|
||||
@ -56,54 +55,56 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
|
||||
val p3 = string.substring(chunkEndIndex until characterCount)
|
||||
string = p1 + p2 + p3
|
||||
}
|
||||
for (chunkStartIndex in 0..(characterCount - 8) step 8) {
|
||||
val chunkEndIndex = chunkStartIndex + 8
|
||||
val x = string.substring(chunkStartIndex until chunkEndIndex).toLong(16)
|
||||
|
||||
return string.windowed(8, 8).map {
|
||||
val x = it.toLong(16)
|
||||
val w1 = x % n
|
||||
val w2 = ((x / n) + w1) % n
|
||||
val w3 = (((x / n) / n) + w2) % n
|
||||
result += listOf( wordSet[w1.toInt()], wordSet[w2.toInt()], wordSet[w3.toInt()] )
|
||||
}
|
||||
val checksumIndex = determineChecksumIndex(result, prefixLength)
|
||||
val checksumWord = result[checksumIndex]
|
||||
result.add(checksumWord)
|
||||
return result.joinToString(" ")
|
||||
listOf(w1, w2, w3).map(Long::toInt).map { wordSet[it] }
|
||||
}.flatten().let {
|
||||
val checksumIndex = determineChecksumIndex(it, prefixLength)
|
||||
it + it[checksumIndex]
|
||||
}.joinToString(" ")
|
||||
}
|
||||
|
||||
fun decode(mnemonic: String, languageConfiguration: Language.Configuration = Language.Configuration.english): String {
|
||||
val words = mnemonic.split(" ").toMutableList()
|
||||
val words = mnemonic.split(" ")
|
||||
val language = Language(loadFileContents, languageConfiguration)
|
||||
val truncatedWordSet = language.loadTruncatedWordSet()
|
||||
val prefixLength = languageConfiguration.prefixLength
|
||||
var result = ""
|
||||
val n = truncatedWordSet.size.toLong()
|
||||
|
||||
if (mnemonic.isEmpty()) throw IllegalArgumentException()
|
||||
if (words.isEmpty()) throw IllegalArgumentException()
|
||||
|
||||
fun String.prefix() = substring(0 until prefixLength)
|
||||
|
||||
// Throw on invalid words, as this is the most difficult issue for a user to solve, do this first.
|
||||
val wordPrefixes = words
|
||||
.onEach { if (it.length < prefixLength) throw DecodingError.InvalidWord }
|
||||
.map { it.prefix() }
|
||||
|
||||
val wordIndexes = wordPrefixes.map { truncatedWordSet.indexOf(it) }
|
||||
.onEach { if (it < 0) throw DecodingError.InvalidWord }
|
||||
|
||||
// Check preconditions
|
||||
if (words.size < 12) throw DecodingError.InputTooShort
|
||||
if (words.size % 3 == 0) throw DecodingError.MissingLastWord
|
||||
// Get checksum word
|
||||
val checksumWord = words.removeAt(words.lastIndex)
|
||||
// Decode
|
||||
for (chunkStartIndex in 0..(words.size - 3) step 3) {
|
||||
try {
|
||||
val w1 = truncatedWordSet.indexOf(words[chunkStartIndex].substring(0 until prefixLength))
|
||||
val w2 = truncatedWordSet.indexOf(words[chunkStartIndex + 1].substring(0 until prefixLength))
|
||||
val w3 = truncatedWordSet.indexOf(words[chunkStartIndex + 2].substring(0 until prefixLength))
|
||||
val x = w1 + n * ((n - w1 + w2) % n) + n * n * ((n - w2 + w3) % n)
|
||||
if (x % n != w1.toLong()) { throw DecodingError.Generic
|
||||
}
|
||||
val string = "0000000" + x.toString(16)
|
||||
result += swap(string.substring(string.length - 8 until string.length))
|
||||
} catch (e: Exception) {
|
||||
throw DecodingError.InvalidWord
|
||||
}
|
||||
}
|
||||
if (words.size < 13) throw DecodingError.InputTooShort
|
||||
|
||||
// Verify checksum
|
||||
val checksumIndex = determineChecksumIndex(words, prefixLength)
|
||||
val checksumIndex = determineChecksumIndex(words.dropLast(1), prefixLength)
|
||||
val expectedChecksumWord = words[checksumIndex]
|
||||
if (expectedChecksumWord.substring(0 until prefixLength) != checksumWord.substring(0 until prefixLength)) { throw DecodingError.VerificationFailed
|
||||
if (expectedChecksumWord.prefix() != wordPrefixes.last()) {
|
||||
throw DecodingError.VerificationFailed
|
||||
}
|
||||
// Return
|
||||
return result
|
||||
|
||||
// Decode
|
||||
return wordIndexes.windowed(3, 3) { (w1, w2, w3) ->
|
||||
val x = w1 + n * ((n - w1 + w2) % n) + n * n * ((n - w2 + w3) % n)
|
||||
if (x % n != w1.toLong()) throw DecodingError.Generic
|
||||
val string = "0000000" + x.toString(16)
|
||||
swap(string.substring(string.length - 8 until string.length))
|
||||
}.joinToString(separator = "") { it }
|
||||
}
|
||||
|
||||
private fun swap(x: String): String {
|
||||
@ -116,9 +117,7 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
|
||||
|
||||
private fun determineChecksumIndex(x: List<String>, prefixLength: Int): Int {
|
||||
val bytes = x.joinToString("") { it.substring(0 until prefixLength) }.toByteArray()
|
||||
val crc32 = CRC32()
|
||||
crc32.update(bytes)
|
||||
val checksum = crc32.value
|
||||
val checksum = CRC32().apply { update(bytes) }.value
|
||||
return (checksum % x.size.toLong()).toInt()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user