Skip to content

Commit

Permalink
improved bip32 derivation (#112)
Browse files Browse the repository at this point in the history
* improved bip32 derivation

* save redundant EC multiplication

* cleaned up code
  • Loading branch information
fametrano authored Jul 7, 2023
1 parent 28bbeaf commit 92633a3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 38 deletions.
58 changes: 25 additions & 33 deletions btclib/bip32/bip32.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,43 +309,27 @@ def __init__(
self.assert_valid()


def __prv_key_derivation(xkey: _BIP32KeyData, index: int) -> None:
xkey.index = index
Q_bytes = bytes_from_point(mult(xkey.prv_key_int))
xkey.parent_fingerprint = hash160(Q_bytes)[:4]
hmac_ = (
hmac.new(
xkey.chain_code,
xkey.key + index.to_bytes(4, byteorder="big", signed=False),
"sha512",
).digest()
if xkey.is_hardened
else hmac.new(
xkey.chain_code,
Q_bytes + index.to_bytes(4, byteorder="big", signed=False),
"sha512",
).digest()
def __prv_key_derivation(xkey: _BIP32KeyData, index: int, pub_key: bytes) -> None:
xb = (
xkey.key
if index >= 0x80000000
else pub_key or bytes_from_point(mult(xkey.prv_key_int))
)
xb += index.to_bytes(4, byteorder="big", signed=False)
hmac_ = hmac.new(xkey.chain_code, xb, "sha512").digest()
xkey.chain_code = hmac_[32:]
offset = int.from_bytes(hmac_[:32], byteorder="big", signed=False)
xkey.prv_key_int = (xkey.prv_key_int + offset) % ec.n
xkey.key = b"\x00" + xkey.prv_key_int.to_bytes(32, byteorder="big", signed=False)
xkey.pub_key_point = INF


def __pub_key_derivation(xkey: _BIP32KeyData, index: int) -> None:
xkey.index = index
xkey.parent_fingerprint = hash160(xkey.key)[:4]
hmac_ = hmac.new(
xkey.chain_code,
xkey.key + index.to_bytes(4, byteorder="big", signed=False),
"sha512",
).digest()
xb = xkey.key + index.to_bytes(4, byteorder="big", signed=False)
hmac_ = hmac.new(xkey.chain_code, xb, "sha512").digest()
xkey.chain_code = hmac_[32:]
offset = int.from_bytes(hmac_[:32], byteorder="big", signed=False)
xkey.pub_key_point = ec.add(xkey.pub_key_point, mult(offset))
xkey.key = bytes_from_point(xkey.pub_key_point)
xkey.prv_key_int = 0


def _derive(
Expand Down Expand Up @@ -382,14 +366,22 @@ def _derive(
raise BTClibValueError(err_msg)
xkey.version = fversion

if xkey.is_private:
for index in indexes:
__prv_key_derivation(xkey, index)
else:
if any(index >= 0x80000000 for index in indexes):
raise BTClibValueError("invalid hardened derivation from public key")
for index in indexes:
__pub_key_derivation(xkey, index)
if indexes:
if xkey.is_private:
for index in indexes[:-1]:
__prv_key_derivation(xkey, index, b"")
pub_key = bytes_from_point(mult(xkey.prv_key_int))
xkey.parent_fingerprint = hash160(pub_key)[:4]
__prv_key_derivation(xkey, indexes[-1], pub_key)
else:
if any(index >= 0x80000000 for index in indexes):
raise BTClibValueError("invalid hardened derivation from public key")
for index in indexes[:-1]:
__pub_key_derivation(xkey, index)
xkey.parent_fingerprint = hash160(xkey.key)[:4]
__pub_key_derivation(xkey, indexes[-1])

xkey.index = indexes[-1]

return xkey

Expand Down
23 changes: 19 additions & 4 deletions tests/bip32/test_bip32.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest

from btclib import base58, hashes
from btclib import base58
from btclib.b58 import p2pkh # FIXME why it is needed here
from btclib.bip32 import (
BIP32KeyData,
Expand All @@ -29,6 +29,7 @@
from btclib.bip32.bip32 import _derive
from btclib.bip32.der_path import _indexes_from_bip32_path_str
from btclib.exceptions import BTClibValueError
from btclib.hashes import hash160
from btclib.to_pub_key import pub_keyinfo_from_key


Expand Down Expand Up @@ -78,7 +79,7 @@ def test_assert_valid2() -> None:
xkey_data.assert_valid()

xkey_data = BIP32KeyData.b58decode(xkey)
xkey_data.depth = tuple() # type: ignore[assignment]
xkey_data.depth = () # type: ignore[assignment]
with pytest.raises(TypeError):
xkey_data.assert_valid()

Expand All @@ -103,7 +104,7 @@ def test_assert_valid2() -> None:
xkey_data.assert_valid()

xkey_data = BIP32KeyData.b58decode(xkey)
xkey_data.index = tuple() # type: ignore[assignment]
xkey_data.index = () # type: ignore[assignment]
with pytest.raises(TypeError):
xkey_data.assert_valid()

Expand Down Expand Up @@ -224,7 +225,7 @@ def test_derive_exceptions() -> None:
assert rootmxprv == derive(xprv, "m")
assert rootmxprv == derive(xprv, "")

fingerprint = hashes.hash160(pub_keyinfo_from_key(xprv)[0])[:4]
fingerprint = hash160(pub_keyinfo_from_key(xprv)[0])[:4]
assert fingerprint == _derive(xprv, bytes.fromhex("80000000")).parent_fingerprint

for der_path in ("/1", "800000", "80000000"):
Expand Down Expand Up @@ -362,3 +363,17 @@ def test_bips_pr905() -> None:
assert derive(xroot, der_path) == xprv
xpub = "xpub6CpsfWjghR6XdCB8yDq7jQRpRKEDP2LT3ZRUgURF9g5xevB7YoTpogkFRqq5nQtVSN8YCMZo2CD8u4zCaxRv85ctCWmzEi9gQ5DBhBFaTNo"
assert xpub_from_xprv(xprv) == xpub


def test_pub_key_derivation() -> None:
parent_xpub = "xpub6CpsfWjghR6XdCB8yDq7jQRpRKEDP2LT3ZRUgURF9g5xevB7YoTpogkFRqq5nQtVSN8YCMZo2CD8u4zCaxRv85ctCWmzEi9gQ5DBhBFaTNo"
proper_child = "xpub6FCCuDg6j52SWRVZ1TugkjrnGkqPcDuNNKDzohU2mmd4dxiGJypZa535iqYT8KcN2oouRF7A6tXEGAX6HCSjQe7HVSDR4LQ4yUT3HwF1Tqi"
assert derive(parent_xpub, "m/0") == proper_child
parent_key = BIP32KeyData.b58decode(parent_xpub).key
parent_fingerprint = hash160(parent_key)[:4]
assert BIP32KeyData.b58decode(proper_child).parent_fingerprint == parent_fingerprint

orphan_child_key = BIP32KeyData.b58decode(proper_child)
orphan_child_key.parent_fingerprint = b"\x00" * 4
orphan_child = "xpub6DXuQW1FgeHbhsSchbuDWE9Bj8mPiPUpiroAmAvRdRqYbGHXHTyEkttkxSvtCac64QzpasL1Tvd5Znvn5GQMswQUrpRBsPRz7npvyZ8ExWi"
assert orphan_child_key.b58encode() == orphan_child
2 changes: 1 addition & 1 deletion tests/mnemonic/test_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_exceptions() -> None:
bin_str_entropy_from_entropy(bytes_entropy216, 224)

with pytest.raises(BTClibValueError, match=err_msg):
bin_str_entropy_from_entropy(tuple()) # type: ignore[arg-type]
bin_str_entropy_from_entropy(()) # type: ignore[arg-type]

with pytest.raises(ValueError):
bin_str_entropy_from_int("not an int")
Expand Down

0 comments on commit 92633a3

Please sign in to comment.