Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 59 additions & 20 deletions src/embit/slip39.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,41 @@ def rs1024_polymod(values):
return chk


def rs1024_verify_checksum(cs, data):
def cs_bstring(extendable: bool):
if extendable:
return b"shamir_extendable"
else:
return b"shamir"


def rs1024_verify_checksum(extendable: bool, data):
cs = cs_bstring(extendable)
return rs1024_polymod([x for x in cs] + data) == 1


def rs1024_create_checksum(cs, data):
def rs1024_create_checksum(extendable: bool, data):
cs = cs_bstring(extendable)
values = [x for x in cs] + data
polymod = rs1024_polymod(values + [0, 0, 0]) ^ 1
return [(polymod >> 10 * (2 - i)) & 1023 for i in range(3)]


def _get_salt(id, extendable):
if extendable:
return bytes()
else:
return b"shamir" + id.to_bytes(2, "big")


# function for encryption/decryption
def _crypt(payload, id, exponent, passphrase, indices):
def _crypt(payload, id, extendable, exponent, passphrase, indices):
if len(payload) % 2:
raise ValueError("payload should be an even number of bytes")
else:
half = len(payload) // 2
left = payload[:half]
right = payload[half:]
salt = b"shamir" + id.to_bytes(2, "big")
salt = _get_salt(id, extendable)
for i in indices:
f = hashlib.pbkdf2_hmac(
"sha256",
Expand All @@ -64,6 +80,7 @@ def __init__(
self,
share_bit_length,
id,
extendable,
exponent,
group_index,
group_threshold,
Expand All @@ -74,6 +91,7 @@ def __init__(
):
self.share_bit_length = share_bit_length
self.id = id
self.extendable = extendable
self.exponent = exponent
self.group_index = group_index
if group_index < 0 or group_index > 15:
Expand All @@ -100,10 +118,11 @@ def parse(cls, mnemonic):
# convert mnemonic into bits
words = mnemonic.split()
indices = [SLIP39_WORDS.index(word) for word in words]
if not rs1024_verify_checksum(b"shamir", indices):
raise ValueError("Invalid Checksum")
id = (indices[0] << 5) | (indices[1] >> 5)
exponent = indices[1] & 31
extendable = bool((indices[1] >> 4) & 1)
if not rs1024_verify_checksum(extendable, indices):
raise ValueError("Invalid Checksum")
exponent = indices[1] & 0x0F
group_index = indices[2] >> 6
group_threshold = ((indices[2] >> 2) & 15) + 1
group_count = (((indices[2] & 3) << 2) | (indices[3] >> 8)) + 1
Expand All @@ -120,6 +139,7 @@ def parse(cls, mnemonic):
return cls(
share_bit_length,
id,
extendable,
exponent,
group_index,
group_threshold,
Expand All @@ -130,7 +150,7 @@ def parse(cls, mnemonic):
)

def mnemonic(self):
all_bits = (self.id << 5) | self.exponent
all_bits = (self.id << 5) | self.extendable << 4 | self.exponent
all_bits <<= 4
all_bits |= self.group_index
all_bits <<= 4
Expand All @@ -148,7 +168,7 @@ def mnemonic(self):
indices = [
(all_bits >> 10 * (num_words - i - 1)) & 1023 for i in range(num_words)
]
checksum = rs1024_create_checksum(b"shamir", indices)
checksum = rs1024_create_checksum(self.extendable, indices)
return " ".join([SLIP39_WORDS[index] for index in indices + checksum])


Expand All @@ -174,6 +194,10 @@ def __init__(self, shares):
ids = {s.id for s in shares}
if len(ids) != 1:
raise TypeError("Shares are from different secrets")
# check that the extendable flags are the same
extendable = {s.extendable for s in shares}
if len(extendable) != 1:
raise TypeError("Shares should have the same extendable flag")
# check that the exponents are the same
exponents = {s.exponent for s in shares}
if len(exponents) != 1:
Expand All @@ -196,6 +220,7 @@ def __init__(self, shares):
if len(xs) != len(shares):
raise ValueError("Share indices should be unique")
self.id = shares[0].id
self.extendable = shares[0].extendable
self.salt = b"shamir" + self.id.to_bytes(2, "big")
self.exponent = shares[0].exponent
self.group_threshold = shares[0].group_threshold
Expand All @@ -205,13 +230,13 @@ def __init__(self, shares):
def decrypt(self, secret, passphrase=b""):
# decryption does the reverse of encryption
indices = (b"\x03", b"\x02", b"\x01", b"\x00")
return _crypt(secret, self.id, self.exponent, passphrase, indices)
return _crypt(secret, self.id, self.extendable, self.exponent, passphrase, indices)

@classmethod
def encrypt(cls, payload, id, exponent, passphrase=b""):
def encrypt(cls, payload, id, extendable, exponent, passphrase=b""):
# encryption goes from 0 to 3 in bytes
indices = (b"\x00", b"\x01", b"\x02", b"\x03")
return _crypt(payload, id, exponent, passphrase, indices)
return _crypt(payload, id, extendable, exponent, passphrase, indices)

@classmethod
def interpolate(cls, x, share_data):
Expand Down Expand Up @@ -317,28 +342,29 @@ def split_secret(cls, secret, k, n, randint=secure_randint):
return more_data

@classmethod
def generate_shares(
cls, mnemonic, k, n, passphrase=b"", exponent=0, randint=secure_randint
def generate_shares_from_secret(
cls, secret, k, n, passphrase=b"", extendable=False, exponent=0, identifier=-1, randint=secure_randint,
):
"""Takes a BIP39 mnemonic along with k, n, passphrase and exponent.
"""Takes a seed along with k, n, passphrase, extendable flag and exponent.
Returns a list of SLIP39 mnemonics, any k of of which, along with the passphrase, recover the secret
"""
# convert mnemonic to a shared secret
secret = mnemonic_to_bytes(mnemonic)
num_bits = len(secret) * 8
if num_bits not in (128, 256):
raise ValueError("mnemonic must be 12 or 24 words")
# generate id
id = randint(0, 32767)

# generate id if set to -1
id = identifier if identifier > -1 & identifier < 32768 else randint(0, 32767)

# encrypt secret with passphrase
encrypted = cls.encrypt(secret, id, exponent, passphrase)
encrypted = cls.encrypt(secret, id, extendable, exponent, passphrase)
# split encrypted payload and create shares
shares = []
data = cls.split_secret(encrypted, k, n, randint=randint)
for group_index, share_bytes in data:
share = Share(
share_bit_length=num_bits,
id=id,
extendable=extendable,
exponent=exponent,
group_index=group_index,
group_threshold=k,
Expand All @@ -350,6 +376,19 @@ def generate_shares(
shares.append(share.mnemonic())
return shares

@classmethod
def generate_shares(
cls, mnemonic, k, n, passphrase=b"", extendable=False, exponent=0, randint=secure_randint
):
"""Takes a BIP39 mnemonic along with k, n, passphrase, extendable flag and exponent.
Returns a list of SLIP39 mnemonics, any k of of which, along with the passphrase, recover the secret
"""
# convert mnemonic to a shared secret
secret = mnemonic_to_bytes(mnemonic)
return cls.generate_shares_from_secret(
secret, k, n, passphrase, extendable, exponent, -1, randint
)

@classmethod
def recover_mnemonic(cls, share_mnemonics, passphrase=b""):
"""Recovers the BIP39 mnemonic from a bunch of SLIP39 mnemonics"""
Expand Down
52 changes: 51 additions & 1 deletion tests/tests/test_slip39.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_share_errors(self):
"fraction necklace academic academic award teammate mouse regular testify coding building member verdict purchase blind camera duration email prepare spirit quarter"
],
SyntaxError,
],
]
]
for test_name, mnemonics, error in test_cases:
with self.assertRaises(error):
Expand Down Expand Up @@ -342,6 +342,47 @@ def test_recover(self):
],
"5385577c8cfc6c1a8aa0f7f10ecde0a3318493262591e78b8c14c6686167123b",
],

[
"41. Valid mnemonics which can detect some errors in modular arithmetic",
[
"herald flea academic cage avoid space trend estate dryer hairy evoke eyebrow improve airline artwork garlic premium duration prevent oven",
"herald flea academic client blue skunk class goat luxury deny presence impulse graduate clay join blanket bulge survive dish necklace",
"herald flea academic acne advance fused brother frozen broken game ranked ajar already believe check install theory angry exercise adult"
],
"ad6f2ad8b59bbbaa01369b9006208d9a",
# "xprv9s21ZrQH143K2R4HJxcG1eUsudvHM753BZ9vaGkpYCoeEhCQx147C5qEcupPHxcXYfdYMwJmsKXrHDhtEwutxTTvFzdDCZVQwHneeQH8ioH"
],
[
"42. Valid extendable mnemonic without sharing (128 bits)",
[
"testify swimming academic academic column loyalty smear include exotic bedroom exotic wrist lobe cover grief golden smart junior estimate learn"
],
"1679b4516e0ee5954351d288a838f45e",
],
[
"43. Extendable basic sharing 2-of-3 (128 bits)",
[
"enemy favorite academic acid cowboy phrase havoc level response walnut budget painting inside trash adjust froth kitchen learn tidy punish",
"enemy favorite academic always academic sniff script carpet romp kind promise scatter center unfair training emphasis evening belong fake enforce"
],
"48b1a4b80b8c209ad42c33672bdaa428",
],
[
"44. Valid extendable mnemonic without sharing (256 bits)",
[
"impulse calcium academic academic alcohol sugar lyrics pajamas column facility finance tension extend space birthday rainbow swimming purple syndrome facility trial warn duration snapshot shadow hormone rhyme public spine counter easy hawk album"
],
"8340611602fe91af634a5f4608377b5235fa2d757c51d720c0c7656249a3035f",
],
[
"45. Extendable basic sharing 2-of-3 (256 bits)",
[
"western apart academic always artist resident briefing sugar woman oven coding club ajar merit pecan answer prisoner artist fraction amount desktop mild false necklace muscle photo wealthy alpha category unwrap spew losing making",
"western apart academic acid answer ancient auction flip image penalty oasis beaver multiple thunder problem switch alive heat inherit superior teaspoon explain blanket pencil numb lend punish endless aunt garlic humidity kidney observe"
],
"8dc652d6d6cd370d8c963141f6d79ba440300f25c467302c1d966bff8f62300d",
]
]
for test_name, mnemonics, expected in test_cases:
share_set = ShareSet([Share.parse(m) for m in mnemonics])
Expand All @@ -355,6 +396,15 @@ def test_split(self):
share_data = ShareSet.split_secret(secret, k, n)
self.assertEqual(secret, ShareSet.interpolate(255, share_data[:k]))


def test_split_extendable(self):
secret = unhexlify("7c3397a292a5941682d7a4ae2d898d11")
mnemonics = ShareSet.generate_shares_from_secret(secret, 3, 5, passphrase=b"TREZOR", identifier=42)
share_set = ShareSet([Share.parse(m) for m in mnemonics])
self.assertEqual(
share_set.recover(passphrase=b"TREZOR"), unhexlify("7c3397a292a5941682d7a4ae2d898d11")
)

def test_generate(self):
test_cases = [
"zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo zoo wrong",
Expand Down