diff --git a/src/embit/slip39.py b/src/embit/slip39.py index 86927b3..5ea03ef 100644 --- a/src/embit/slip39.py +++ b/src/embit/slip39.py @@ -28,25 +28,41 @@ def rs1024_polymod(values): return chk -def rs1024_verify_checksum(cs, data): +def cs_bstring(extendable: bool): + if extendable: + return b"shamir_extendable" + else: + return b"shamir" + + +def rs1024_verify_checksum(extendable: bool, data): + cs = cs_bstring(extendable) return rs1024_polymod([x for x in cs] + data) == 1 -def rs1024_create_checksum(cs, data): +def rs1024_create_checksum(extendable: bool, data): + cs = cs_bstring(extendable) values = [x for x in cs] + data polymod = rs1024_polymod(values + [0, 0, 0]) ^ 1 return [(polymod >> 10 * (2 - i)) & 1023 for i in range(3)] +def _get_salt(id, extendable): + if extendable: + return bytes() + else: + return b"shamir" + id.to_bytes(2, "big") + + # function for encryption/decryption -def _crypt(payload, id, exponent, passphrase, indices): +def _crypt(payload, id, extendable, exponent, passphrase, indices): if len(payload) % 2: raise ValueError("payload should be an even number of bytes") else: half = len(payload) // 2 left = payload[:half] right = payload[half:] - salt = b"shamir" + id.to_bytes(2, "big") + salt = _get_salt(id, extendable) for i in indices: f = hashlib.pbkdf2_hmac( "sha256", @@ -64,6 +80,7 @@ def __init__( self, share_bit_length, id, + extendable, exponent, group_index, group_threshold, @@ -74,6 +91,7 @@ def __init__( ): self.share_bit_length = share_bit_length self.id = id + self.extendable = extendable self.exponent = exponent self.group_index = group_index if group_index < 0 or group_index > 15: @@ -100,10 +118,11 @@ def parse(cls, mnemonic): # convert mnemonic into bits words = mnemonic.split() indices = [SLIP39_WORDS.index(word) for word in words] - if not rs1024_verify_checksum(b"shamir", indices): - raise ValueError("Invalid Checksum") id = (indices[0] << 5) | (indices[1] >> 5) - exponent = indices[1] & 31 + extendable = bool((indices[1] >> 4) & 1) + if not rs1024_verify_checksum(extendable, indices): + raise ValueError("Invalid Checksum") + exponent = indices[1] & 0x0F group_index = indices[2] >> 6 group_threshold = ((indices[2] >> 2) & 15) + 1 group_count = (((indices[2] & 3) << 2) | (indices[3] >> 8)) + 1 @@ -120,6 +139,7 @@ def parse(cls, mnemonic): return cls( share_bit_length, id, + extendable, exponent, group_index, group_threshold, @@ -130,7 +150,7 @@ def parse(cls, mnemonic): ) def mnemonic(self): - all_bits = (self.id << 5) | self.exponent + all_bits = (self.id << 5) | self.extendable << 4 | self.exponent all_bits <<= 4 all_bits |= self.group_index all_bits <<= 4 @@ -148,7 +168,7 @@ def mnemonic(self): indices = [ (all_bits >> 10 * (num_words - i - 1)) & 1023 for i in range(num_words) ] - checksum = rs1024_create_checksum(b"shamir", indices) + checksum = rs1024_create_checksum(self.extendable, indices) return " ".join([SLIP39_WORDS[index] for index in indices + checksum]) @@ -174,6 +194,10 @@ def __init__(self, shares): ids = {s.id for s in shares} if len(ids) != 1: raise TypeError("Shares are from different secrets") + # check that the extendable flags are the same + extendable = {s.extendable for s in shares} + if len(extendable) != 1: + raise TypeError("Shares should have the same extendable flag") # check that the exponents are the same exponents = {s.exponent for s in shares} if len(exponents) != 1: @@ -196,6 +220,7 @@ def __init__(self, shares): if len(xs) != len(shares): raise ValueError("Share indices should be unique") self.id = shares[0].id + self.extendable = shares[0].extendable self.salt = b"shamir" + self.id.to_bytes(2, "big") self.exponent = shares[0].exponent self.group_threshold = shares[0].group_threshold @@ -205,13 +230,13 @@ def __init__(self, shares): def decrypt(self, secret, passphrase=b""): # decryption does the reverse of encryption indices = (b"\x03", b"\x02", b"\x01", b"\x00") - return _crypt(secret, self.id, self.exponent, passphrase, indices) + return _crypt(secret, self.id, self.extendable, self.exponent, passphrase, indices) @classmethod - def encrypt(cls, payload, id, exponent, passphrase=b""): + def encrypt(cls, payload, id, extendable, exponent, passphrase=b""): # encryption goes from 0 to 3 in bytes indices = (b"\x00", b"\x01", b"\x02", b"\x03") - return _crypt(payload, id, exponent, passphrase, indices) + return _crypt(payload, id, extendable, exponent, passphrase, indices) @classmethod def interpolate(cls, x, share_data): @@ -317,21 +342,21 @@ def split_secret(cls, secret, k, n, randint=secure_randint): return more_data @classmethod - def generate_shares( - cls, mnemonic, k, n, passphrase=b"", exponent=0, randint=secure_randint + def generate_shares_from_secret( + cls, secret, k, n, passphrase=b"", extendable=False, exponent=0, identifier=-1, randint=secure_randint, ): - """Takes a BIP39 mnemonic along with k, n, passphrase and exponent. + """Takes a seed along with k, n, passphrase, extendable flag and exponent. Returns a list of SLIP39 mnemonics, any k of of which, along with the passphrase, recover the secret """ - # convert mnemonic to a shared secret - secret = mnemonic_to_bytes(mnemonic) num_bits = len(secret) * 8 if num_bits not in (128, 256): raise ValueError("mnemonic must be 12 or 24 words") - # generate id - id = randint(0, 32767) + + # generate id if set to -1 + id = identifier if identifier > -1 & identifier < 32768 else randint(0, 32767) + # encrypt secret with passphrase - encrypted = cls.encrypt(secret, id, exponent, passphrase) + encrypted = cls.encrypt(secret, id, extendable, exponent, passphrase) # split encrypted payload and create shares shares = [] data = cls.split_secret(encrypted, k, n, randint=randint) @@ -339,6 +364,7 @@ def generate_shares( share = Share( share_bit_length=num_bits, id=id, + extendable=extendable, exponent=exponent, group_index=group_index, group_threshold=k, @@ -350,6 +376,19 @@ def generate_shares( shares.append(share.mnemonic()) return shares + @classmethod + def generate_shares( + cls, mnemonic, k, n, passphrase=b"", extendable=False, exponent=0, randint=secure_randint + ): + """Takes a BIP39 mnemonic along with k, n, passphrase, extendable flag and exponent. + Returns a list of SLIP39 mnemonics, any k of of which, along with the passphrase, recover the secret + """ + # convert mnemonic to a shared secret + secret = mnemonic_to_bytes(mnemonic) + return cls.generate_shares_from_secret( + secret, k, n, passphrase, extendable, exponent, -1, randint + ) + @classmethod def recover_mnemonic(cls, share_mnemonics, passphrase=b""): """Recovers the BIP39 mnemonic from a bunch of SLIP39 mnemonics""" diff --git a/tests/tests/test_slip39.py b/tests/tests/test_slip39.py index 088257d..c1de533 100644 --- a/tests/tests/test_slip39.py +++ b/tests/tests/test_slip39.py @@ -247,7 +247,7 @@ def test_share_errors(self): "fraction necklace academic academic award teammate mouse regular testify coding building member verdict purchase blind camera duration email prepare spirit quarter" ], SyntaxError, - ], + ] ] for test_name, mnemonics, error in test_cases: with self.assertRaises(error): @@ -342,6 +342,47 @@ def test_recover(self): ], "5385577c8cfc6c1a8aa0f7f10ecde0a3318493262591e78b8c14c6686167123b", ], + + [ + "41. Valid mnemonics which can detect some errors in modular arithmetic", + [ + "herald flea academic cage avoid space trend estate dryer hairy evoke eyebrow improve airline artwork garlic premium duration prevent oven", + "herald flea academic client blue skunk class goat luxury deny presence impulse graduate clay join blanket bulge survive dish necklace", + "herald flea academic acne advance fused brother frozen broken game ranked ajar already believe check install theory angry exercise adult" + ], + "ad6f2ad8b59bbbaa01369b9006208d9a", + # "xprv9s21ZrQH143K2R4HJxcG1eUsudvHM753BZ9vaGkpYCoeEhCQx147C5qEcupPHxcXYfdYMwJmsKXrHDhtEwutxTTvFzdDCZVQwHneeQH8ioH" + ], + [ + "42. Valid extendable mnemonic without sharing (128 bits)", + [ + "testify swimming academic academic column loyalty smear include exotic bedroom exotic wrist lobe cover grief golden smart junior estimate learn" + ], + "1679b4516e0ee5954351d288a838f45e", + ], + [ + "43. Extendable basic sharing 2-of-3 (128 bits)", + [ + "enemy favorite academic acid cowboy phrase havoc level response walnut budget painting inside trash adjust froth kitchen learn tidy punish", + "enemy favorite academic always academic sniff script carpet romp kind promise scatter center unfair training emphasis evening belong fake enforce" + ], + "48b1a4b80b8c209ad42c33672bdaa428", + ], + [ + "44. Valid extendable mnemonic without sharing (256 bits)", + [ + "impulse calcium academic academic alcohol sugar lyrics pajamas column facility finance tension extend space birthday rainbow swimming purple syndrome facility trial warn duration snapshot shadow hormone rhyme public spine counter easy hawk album" + ], + "8340611602fe91af634a5f4608377b5235fa2d757c51d720c0c7656249a3035f", + ], + [ + "45. Extendable basic sharing 2-of-3 (256 bits)", + [ + "western apart academic always artist resident briefing sugar woman oven coding club ajar merit pecan answer prisoner artist fraction amount desktop mild false necklace muscle photo wealthy alpha category unwrap spew losing making", + "western apart academic acid answer ancient auction flip image penalty oasis beaver multiple thunder problem switch alive heat inherit superior teaspoon explain blanket pencil numb lend punish endless aunt garlic humidity kidney observe" + ], + "8dc652d6d6cd370d8c963141f6d79ba440300f25c467302c1d966bff8f62300d", + ] ] for test_name, mnemonics, expected in test_cases: share_set = ShareSet([Share.parse(m) for m in mnemonics]) @@ -355,6 +396,15 @@ def test_split(self): share_data = ShareSet.split_secret(secret, k, n) self.assertEqual(secret, ShareSet.interpolate(255, share_data[:k])) + + def test_split_extendable(self): + secret = unhexlify("7c3397a292a5941682d7a4ae2d898d11") + mnemonics = ShareSet.generate_shares_from_secret(secret, 3, 5, passphrase=b"TREZOR", identifier=42) + share_set = ShareSet([Share.parse(m) for m in mnemonics]) + self.assertEqual( + share_set.recover(passphrase=b"TREZOR"), unhexlify("7c3397a292a5941682d7a4ae2d898d11") + ) + def test_generate(self): test_cases = [ "zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo wrong",