Add MnemonicCodecTest

This commit is contained in:
Andrew
2024-06-25 19:11:13 +09:30
parent 9288701556
commit 72dccaa1d3
3 changed files with 110 additions and 51 deletions

View File

@@ -36,7 +36,6 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
sealed class DecodingError(val description: String) : Exception(description) {
object Generic : DecodingError("Something went wrong. Please check your mnemonic and try again.")
object InputTooShort : DecodingError("Looks like you didn't enter enough words. Please check your mnemonic and try again.")
object MissingLastWord : DecodingError("You seem to be missing the last word of your mnemonic. Please check what you entered and try again.")
object InvalidWord : DecodingError("There appears to be an invalid word in your mnemonic. Please check what you entered and try again.")
object VerificationFailed : DecodingError("Your mnemonic couldn't be verified. Please check what you entered and try again.")
}
@@ -46,7 +45,7 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
val language = Language(loadFileContents, languageConfiguration)
val wordSet = language.loadWordSet()
val prefixLength = languageConfiguration.prefixLength
val result = mutableListOf<String>()
val n = wordSet.size.toLong()
val characterCount = string.length
for (chunkStartIndex in 0..(characterCount - 8) step 8) {
@@ -56,54 +55,56 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
val p3 = string.substring(chunkEndIndex until characterCount)
string = p1 + p2 + p3
}
for (chunkStartIndex in 0..(characterCount - 8) step 8) {
val chunkEndIndex = chunkStartIndex + 8
val x = string.substring(chunkStartIndex until chunkEndIndex).toLong(16)
return string.windowed(8, 8).map {
val x = it.toLong(16)
val w1 = x % n
val w2 = ((x / n) + w1) % n
val w3 = (((x / n) / n) + w2) % n
result += listOf( wordSet[w1.toInt()], wordSet[w2.toInt()], wordSet[w3.toInt()] )
}
val checksumIndex = determineChecksumIndex(result, prefixLength)
val checksumWord = result[checksumIndex]
result.add(checksumWord)
return result.joinToString(" ")
listOf(w1, w2, w3).map(Long::toInt).map { wordSet[it] }
}.flatten().let {
val checksumIndex = determineChecksumIndex(it, prefixLength)
it + it[checksumIndex]
}.joinToString(" ")
}
fun decode(mnemonic: String, languageConfiguration: Language.Configuration = Language.Configuration.english): String {
val words = mnemonic.split(" ").toMutableList()
val words = mnemonic.split(" ")
val language = Language(loadFileContents, languageConfiguration)
val truncatedWordSet = language.loadTruncatedWordSet()
val prefixLength = languageConfiguration.prefixLength
var result = ""
val n = truncatedWordSet.size.toLong()
if (mnemonic.isEmpty()) throw IllegalArgumentException()
if (words.isEmpty()) throw IllegalArgumentException()
fun String.prefix() = substring(0 until prefixLength)
// Throw on invalid words, as this is the most difficult issue for a user to solve, do this first.
val wordPrefixes = words
.onEach { if (it.length < prefixLength) throw DecodingError.InvalidWord }
.map { it.prefix() }
val wordIndexes = wordPrefixes.map { truncatedWordSet.indexOf(it) }
.onEach { if (it < 0) throw DecodingError.InvalidWord }
// Check preconditions
if (words.size < 12) throw DecodingError.InputTooShort
if (words.size % 3 == 0) throw DecodingError.MissingLastWord
// Get checksum word
val checksumWord = words.removeAt(words.lastIndex)
// Decode
for (chunkStartIndex in 0..(words.size - 3) step 3) {
try {
val w1 = truncatedWordSet.indexOf(words[chunkStartIndex].substring(0 until prefixLength))
val w2 = truncatedWordSet.indexOf(words[chunkStartIndex + 1].substring(0 until prefixLength))
val w3 = truncatedWordSet.indexOf(words[chunkStartIndex + 2].substring(0 until prefixLength))
val x = w1 + n * ((n - w1 + w2) % n) + n * n * ((n - w2 + w3) % n)
if (x % n != w1.toLong()) { throw DecodingError.Generic
}
val string = "0000000" + x.toString(16)
result += swap(string.substring(string.length - 8 until string.length))
} catch (e: Exception) {
throw DecodingError.InvalidWord
}
}
if (words.size < 13) throw DecodingError.InputTooShort
// Verify checksum
val checksumIndex = determineChecksumIndex(words, prefixLength)
val checksumIndex = determineChecksumIndex(words.dropLast(1), prefixLength)
val expectedChecksumWord = words[checksumIndex]
if (expectedChecksumWord.substring(0 until prefixLength) != checksumWord.substring(0 until prefixLength)) { throw DecodingError.VerificationFailed
if (expectedChecksumWord.prefix() != wordPrefixes.last()) {
throw DecodingError.VerificationFailed
}
// Return
return result
// Decode
return wordIndexes.windowed(3, 3) { (w1, w2, w3) ->
val x = w1 + n * ((n - w1 + w2) % n) + n * n * ((n - w2 + w3) % n)
if (x % n != w1.toLong()) throw DecodingError.Generic
val string = "0000000" + x.toString(16)
swap(string.substring(string.length - 8 until string.length))
}.joinToString(separator = "") { it }
}
private fun swap(x: String): String {
@@ -116,9 +117,7 @@ class MnemonicCodec(private val loadFileContents: (String) -> String) {
private fun determineChecksumIndex(x: List<String>, prefixLength: Int): Int {
val bytes = x.joinToString("") { it.substring(0 until prefixLength) }.toByteArray()
val crc32 = CRC32()
crc32.update(bytes)
val checksum = crc32.value
val checksum = CRC32().apply { update(bytes) }.value
return (checksum % x.size.toLong()).toInt()
}
}