electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

lnaddr.py (19074B)


      1 #! /usr/bin/env python3
      2 # This was forked from https://github.com/rustyrussell/lightning-payencode/tree/acc16ec13a3fa1dc16c07af6ec67c261bd8aff23
      3 
      4 import re
      5 import time
      6 from hashlib import sha256
      7 from binascii import hexlify
      8 from decimal import Decimal
      9 from typing import Optional, TYPE_CHECKING
     10 
     11 import random
     12 import bitstring
     13 
     14 from .bitcoin import hash160_to_b58_address, b58_address_to_hash160, TOTAL_COIN_SUPPLY_LIMIT_IN_BTC
     15 from .segwit_addr import bech32_encode, bech32_decode, CHARSET
     16 from . import constants
     17 from . import ecc
     18 from .bitcoin import COIN
     19 
     20 if TYPE_CHECKING:
     21     from .lnutil import LnFeatures
     22 
     23 
     24 # BOLT #11:
     25 #
     26 # A writer MUST encode `amount` as a positive decimal integer with no
     27 # leading zeroes, SHOULD use the shortest representation possible.
     28 def shorten_amount(amount):
     29     """ Given an amount in bitcoin, shorten it
     30     """
     31     # Convert to pico initially
     32     amount = int(amount * 10**12)
     33     units = ['p', 'n', 'u', 'm', '']
     34     for unit in units:
     35         if amount % 1000 == 0:
     36             amount //= 1000
     37         else:
     38             break
     39     return str(amount) + unit
     40 
     41 def unshorten_amount(amount) -> Decimal:
     42     """ Given a shortened amount, convert it into a decimal
     43     """
     44     # BOLT #11:
     45     # The following `multiplier` letters are defined:
     46     #
     47     #* `m` (milli): multiply by 0.001
     48     #* `u` (micro): multiply by 0.000001
     49     #* `n` (nano): multiply by 0.000000001
     50     #* `p` (pico): multiply by 0.000000000001
     51     units = {
     52         'p': 10**12,
     53         'n': 10**9,
     54         'u': 10**6,
     55         'm': 10**3,
     56     }
     57     unit = str(amount)[-1]
     58     # BOLT #11:
     59     # A reader SHOULD fail if `amount` contains a non-digit, or is followed by
     60     # anything except a `multiplier` in the table above.
     61     if not re.fullmatch("\\d+[pnum]?", str(amount)):
     62         raise ValueError("Invalid amount '{}'".format(amount))
     63 
     64     if unit in units.keys():
     65         return Decimal(amount[:-1]) / units[unit]
     66     else:
     67         return Decimal(amount)
     68 
     69 _INT_TO_BINSTR = {a: '0' * (5-len(bin(a)[2:])) + bin(a)[2:] for a in range(32)}
     70 
     71 # Bech32 spits out array of 5-bit values.  Shim here.
     72 def u5_to_bitarray(arr):
     73     b = ''.join(_INT_TO_BINSTR[a] for a in arr)
     74     return bitstring.BitArray(bin=b)
     75 
     76 def bitarray_to_u5(barr):
     77     assert barr.len % 5 == 0
     78     ret = []
     79     s = bitstring.ConstBitStream(barr)
     80     while s.pos != s.len:
     81         ret.append(s.read(5).uint)
     82     return ret
     83 
     84 def encode_fallback(fallback, currency):
     85     """ Encode all supported fallback addresses.
     86     """
     87     if currency in [constants.BitcoinMainnet.SEGWIT_HRP, constants.BitcoinTestnet.SEGWIT_HRP]:
     88         fbhrp, witness = bech32_decode(fallback, ignore_long_length=True)
     89         if fbhrp:
     90             if fbhrp != currency:
     91                 raise ValueError("Not a bech32 address for this currency")
     92             wver = witness[0]
     93             if wver > 16:
     94                 raise ValueError("Invalid witness version {}".format(witness[0]))
     95             wprog = u5_to_bitarray(witness[1:])
     96         else:
     97             addrtype, addr = b58_address_to_hash160(fallback)
     98             if is_p2pkh(currency, addrtype):
     99                 wver = 17
    100             elif is_p2sh(currency, addrtype):
    101                 wver = 18
    102             else:
    103                 raise ValueError("Unknown address type for {}".format(currency))
    104             wprog = addr
    105         return tagged('f', bitstring.pack("uint:5", wver) + wprog)
    106     else:
    107         raise NotImplementedError("Support for currency {} not implemented".format(currency))
    108 
    109 def parse_fallback(fallback, currency):
    110     if currency in [constants.BitcoinMainnet.SEGWIT_HRP, constants.BitcoinTestnet.SEGWIT_HRP]:
    111         wver = fallback[0:5].uint
    112         if wver == 17:
    113             addr=hash160_to_b58_address(fallback[5:].tobytes(), base58_prefix_map[currency][0])
    114         elif wver == 18:
    115             addr=hash160_to_b58_address(fallback[5:].tobytes(), base58_prefix_map[currency][1])
    116         elif wver <= 16:
    117             addr=bech32_encode(currency, bitarray_to_u5(fallback))
    118         else:
    119             return None
    120     else:
    121         addr=fallback.tobytes()
    122     return addr
    123 
    124 
    125 # Map of classical and witness address prefixes
    126 base58_prefix_map = {
    127     constants.BitcoinMainnet.SEGWIT_HRP : (constants.BitcoinMainnet.ADDRTYPE_P2PKH, constants.BitcoinMainnet.ADDRTYPE_P2SH),
    128     constants.BitcoinTestnet.SEGWIT_HRP : (constants.BitcoinTestnet.ADDRTYPE_P2PKH, constants.BitcoinTestnet.ADDRTYPE_P2SH)
    129 }
    130 
    131 def is_p2pkh(currency, prefix):
    132     return prefix == base58_prefix_map[currency][0]
    133 
    134 def is_p2sh(currency, prefix):
    135     return prefix == base58_prefix_map[currency][1]
    136 
    137 # Tagged field containing BitArray
    138 def tagged(char, l):
    139     # Tagged fields need to be zero-padded to 5 bits.
    140     while l.len % 5 != 0:
    141         l.append('0b0')
    142     return bitstring.pack("uint:5, uint:5, uint:5",
    143                           CHARSET.find(char),
    144                           (l.len / 5) / 32, (l.len / 5) % 32) + l
    145 
    146 # Tagged field containing bytes
    147 def tagged_bytes(char, l):
    148     return tagged(char, bitstring.BitArray(l))
    149 
    150 def trim_to_min_length(bits):
    151     """Ensures 'bits' have min number of leading zeroes.
    152     Assumes 'bits' is big-endian, and that it needs to be encoded in 5 bit blocks.
    153     """
    154     bits = bits[:]  # copy
    155     # make sure we can be split into 5 bit blocks
    156     while bits.len % 5 != 0:
    157         bits.prepend('0b0')
    158     # Get minimal length by trimming leading 5 bits at a time.
    159     while bits.startswith('0b00000'):
    160         if len(bits) == 5:
    161             break  # v == 0
    162         bits = bits[5:]
    163     return bits
    164 
    165 # Discard trailing bits, convert to bytes.
    166 def trim_to_bytes(barr):
    167     # Adds a byte if necessary.
    168     b = barr.tobytes()
    169     if barr.len % 8 != 0:
    170         return b[:-1]
    171     return b
    172 
    173 # Try to pull out tagged data: returns tag, tagged data and remainder.
    174 def pull_tagged(stream):
    175     tag = stream.read(5).uint
    176     length = stream.read(5).uint * 32 + stream.read(5).uint
    177     return (CHARSET[tag], stream.read(length * 5), stream)
    178 
    179 def lnencode(addr: 'LnAddr', privkey) -> str:
    180     if addr.amount:
    181         amount = addr.currency + shorten_amount(addr.amount)
    182     else:
    183         amount = addr.currency if addr.currency else ''
    184 
    185     hrp = 'ln' + amount
    186 
    187     # Start with the timestamp
    188     data = bitstring.pack('uint:35', addr.date)
    189 
    190     tags_set = set()
    191 
    192     # Payment hash
    193     data += tagged_bytes('p', addr.paymenthash)
    194     tags_set.add('p')
    195 
    196     if addr.payment_secret is not None:
    197         data += tagged_bytes('s', addr.payment_secret)
    198         tags_set.add('s')
    199 
    200     for k, v in addr.tags:
    201 
    202         # BOLT #11:
    203         #
    204         # A writer MUST NOT include more than one `d`, `h`, `n` or `x` fields,
    205         if k in ('d', 'h', 'n', 'x', 'p', 's'):
    206             if k in tags_set:
    207                 raise ValueError("Duplicate '{}' tag".format(k))
    208 
    209         if k == 'r':
    210             route = bitstring.BitArray()
    211             for step in v:
    212                 pubkey, channel, feebase, feerate, cltv = step
    213                 route.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
    214             data += tagged('r', route)
    215         elif k == 't':
    216             pubkey, feebase, feerate, cltv = v
    217             route = bitstring.BitArray(pubkey) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv)
    218             data += tagged('t', route)
    219         elif k == 'f':
    220             data += encode_fallback(v, addr.currency)
    221         elif k == 'd':
    222             # truncate to max length: 1024*5 bits = 639 bytes
    223             data += tagged_bytes('d', v.encode()[0:639])
    224         elif k == 'x':
    225             expirybits = bitstring.pack('intbe:64', v)
    226             expirybits = trim_to_min_length(expirybits)
    227             data += tagged('x', expirybits)
    228         elif k == 'h':
    229             data += tagged_bytes('h', sha256(v.encode('utf-8')).digest())
    230         elif k == 'n':
    231             data += tagged_bytes('n', v)
    232         elif k == 'c':
    233             finalcltvbits = bitstring.pack('intbe:64', v)
    234             finalcltvbits = trim_to_min_length(finalcltvbits)
    235             data += tagged('c', finalcltvbits)
    236         elif k == '9':
    237             if v == 0:
    238                 continue
    239             feature_bits = bitstring.BitArray(uint=v, length=v.bit_length())
    240             feature_bits = trim_to_min_length(feature_bits)
    241             data += tagged('9', feature_bits)
    242         else:
    243             # FIXME: Support unknown tags?
    244             raise ValueError("Unknown tag {}".format(k))
    245 
    246         tags_set.add(k)
    247 
    248     # BOLT #11:
    249     #
    250     # A writer MUST include either a `d` or `h` field, and MUST NOT include
    251     # both.
    252     if 'd' in tags_set and 'h' in tags_set:
    253         raise ValueError("Cannot include both 'd' and 'h'")
    254     if not 'd' in tags_set and not 'h' in tags_set:
    255         raise ValueError("Must include either 'd' or 'h'")
    256 
    257     # We actually sign the hrp, then data (padded to 8 bits with zeroes).
    258     msg = hrp.encode("ascii") + data.tobytes()
    259     privkey = ecc.ECPrivkey(privkey)
    260     sig = privkey.sign_message(msg, is_compressed=False, algo=lambda x:sha256(x).digest())
    261     recovery_flag = bytes([sig[0] - 27])
    262     sig = bytes(sig[1:]) + recovery_flag
    263     data += sig
    264 
    265     return bech32_encode(hrp, bitarray_to_u5(data))
    266 
    267 class LnAddr(object):
    268     def __init__(self, *, paymenthash: bytes = None, amount=None, currency=None, tags=None, date=None,
    269                  payment_secret: bytes = None):
    270         self.date = int(time.time()) if not date else int(date)
    271         self.tags = [] if not tags else tags
    272         self.unknown_tags = []
    273         self.paymenthash = paymenthash
    274         self.payment_secret = payment_secret
    275         self.signature = None
    276         self.pubkey = None
    277         self.currency = constants.net.SEGWIT_HRP if currency is None else currency
    278         self._amount = amount  # type: Optional[Decimal]  # in bitcoins
    279         self._min_final_cltv_expiry = 18
    280 
    281     @property
    282     def amount(self) -> Optional[Decimal]:
    283         return self._amount
    284 
    285     @amount.setter
    286     def amount(self, value):
    287         if not (isinstance(value, Decimal) or value is None):
    288             raise ValueError(f"amount must be Decimal or None, not {value!r}")
    289         if value is None:
    290             self._amount = None
    291             return
    292         assert isinstance(value, Decimal)
    293         if value.is_nan() or not (0 <= value <= TOTAL_COIN_SUPPLY_LIMIT_IN_BTC):
    294             raise ValueError(f"amount is out-of-bounds: {value!r} BTC")
    295         if value * 10**12 % 10:
    296             # max resolution is millisatoshi
    297             raise ValueError(f"Cannot encode {value!r}: too many decimal places")
    298         self._amount = value
    299 
    300     def get_amount_sat(self) -> Optional[Decimal]:
    301         # note that this has msat resolution potentially
    302         if self.amount is None:
    303             return None
    304         return self.amount * COIN
    305 
    306     def get_routing_info(self, tag):
    307         # note: tag will be 't' for trampoline
    308         r_tags = list(filter(lambda x: x[0] == tag, self.tags))
    309         # strip the tag type, it's implicitly 'r' now
    310         r_tags = list(map(lambda x: x[1], r_tags))
    311         # if there are multiple hints, we will use the first one that works,
    312         # from a random permutation
    313         random.shuffle(r_tags)
    314         return r_tags
    315 
    316     def get_amount_msat(self) -> Optional[int]:
    317         if self.amount is None:
    318             return None
    319         return int(self.amount * COIN * 1000)
    320 
    321     def get_features(self) -> 'LnFeatures':
    322         from .lnutil import LnFeatures
    323         return LnFeatures(self.get_tag('9') or 0)
    324 
    325     def __str__(self):
    326         return "LnAddr[{}, amount={}{} tags=[{}]]".format(
    327             hexlify(self.pubkey.serialize()).decode('utf-8') if self.pubkey else None,
    328             self.amount, self.currency,
    329             ", ".join([k + '=' + str(v) for k, v in self.tags])
    330         )
    331 
    332     def get_min_final_cltv_expiry(self) -> int:
    333         return self._min_final_cltv_expiry
    334 
    335     def get_tag(self, tag):
    336         for k, v in self.tags:
    337             if k == tag:
    338                 return v
    339         return None
    340 
    341     def get_description(self) -> str:
    342         return self.get_tag('d') or ''
    343 
    344     def get_expiry(self) -> int:
    345         exp = self.get_tag('x')
    346         if exp is None:
    347             exp = 3600
    348         return int(exp)
    349 
    350     def is_expired(self) -> bool:
    351         now = time.time()
    352         # BOLT-11 does not specify what expiration of '0' means.
    353         # we treat it as 0 seconds here (instead of never)
    354         return now > self.get_expiry() + self.date
    355 
    356 
    357 class LnDecodeException(Exception): pass
    358 
    359 class SerializableKey:
    360     def __init__(self, pubkey):
    361         self.pubkey = pubkey
    362     def serialize(self):
    363         return self.pubkey.get_public_key_bytes(True)
    364 
    365 def lndecode(invoice: str, *, verbose=False, expected_hrp=None) -> LnAddr:
    366     if expected_hrp is None:
    367         expected_hrp = constants.net.SEGWIT_HRP
    368     hrp, data = bech32_decode(invoice, ignore_long_length=True)
    369     if not hrp:
    370         raise ValueError("Bad bech32 checksum")
    371 
    372     # BOLT #11:
    373     #
    374     # A reader MUST fail if it does not understand the `prefix`.
    375     if not hrp.startswith('ln'):
    376         raise ValueError("Does not start with ln")
    377 
    378     if not hrp[2:].startswith(expected_hrp):
    379         raise ValueError("Wrong Lightning invoice HRP " + hrp[2:] + ", should be " + expected_hrp)
    380 
    381     data = u5_to_bitarray(data)
    382 
    383     # Final signature 65 bytes, split it off.
    384     if len(data) < 65*8:
    385         raise ValueError("Too short to contain signature")
    386     sigdecoded = data[-65*8:].tobytes()
    387     data = bitstring.ConstBitStream(data[:-65*8])
    388 
    389     addr = LnAddr()
    390     addr.pubkey = None
    391 
    392     m = re.search("[^\\d]+", hrp[2:])
    393     if m:
    394         addr.currency = m.group(0)
    395         amountstr = hrp[2+m.end():]
    396         # BOLT #11:
    397         #
    398         # A reader SHOULD indicate if amount is unspecified, otherwise it MUST
    399         # multiply `amount` by the `multiplier` value (if any) to derive the
    400         # amount required for payment.
    401         if amountstr != '':
    402             addr.amount = unshorten_amount(amountstr)
    403 
    404     addr.date = data.read(35).uint
    405 
    406     while data.pos != data.len:
    407         tag, tagdata, data = pull_tagged(data)
    408 
    409         # BOLT #11:
    410         #
    411         # A reader MUST skip over unknown fields, an `f` field with unknown
    412         # `version`, or a `p`, `h`, or `n` field which does not have
    413         # `data_length` 52, 52, or 53 respectively.
    414         data_length = len(tagdata) / 5
    415 
    416         if tag == 'r':
    417             # BOLT #11:
    418             #
    419             # * `r` (3): `data_length` variable.  One or more entries
    420             # containing extra routing information for a private route;
    421             # there may be more than one `r` field, too.
    422             #    * `pubkey` (264 bits)
    423             #    * `short_channel_id` (64 bits)
    424             #    * `feebase` (32 bits, big-endian)
    425             #    * `feerate` (32 bits, big-endian)
    426             #    * `cltv_expiry_delta` (16 bits, big-endian)
    427             route=[]
    428             s = bitstring.ConstBitStream(tagdata)
    429             while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
    430                 route.append((s.read(264).tobytes(),
    431                               s.read(64).tobytes(),
    432                               s.read(32).uintbe,
    433                               s.read(32).uintbe,
    434                               s.read(16).uintbe))
    435             addr.tags.append(('r',route))
    436         elif tag == 't':
    437             s = bitstring.ConstBitStream(tagdata)
    438             e = (s.read(264).tobytes(),
    439                  s.read(32).uintbe,
    440                  s.read(32).uintbe,
    441                  s.read(16).uintbe)
    442             addr.tags.append(('t', e))
    443         elif tag == 'f':
    444             fallback = parse_fallback(tagdata, addr.currency)
    445             if fallback:
    446                 addr.tags.append(('f', fallback))
    447             else:
    448                 # Incorrect version.
    449                 addr.unknown_tags.append((tag, tagdata))
    450                 continue
    451 
    452         elif tag == 'd':
    453             addr.tags.append(('d', trim_to_bytes(tagdata).decode('utf-8')))
    454 
    455         elif tag == 'h':
    456             if data_length != 52:
    457                 addr.unknown_tags.append((tag, tagdata))
    458                 continue
    459             addr.tags.append(('h', trim_to_bytes(tagdata)))
    460 
    461         elif tag == 'x':
    462             addr.tags.append(('x', tagdata.uint))
    463 
    464         elif tag == 'p':
    465             if data_length != 52:
    466                 addr.unknown_tags.append((tag, tagdata))
    467                 continue
    468             addr.paymenthash = trim_to_bytes(tagdata)
    469 
    470         elif tag == 's':
    471             if data_length != 52:
    472                 addr.unknown_tags.append((tag, tagdata))
    473                 continue
    474             addr.payment_secret = trim_to_bytes(tagdata)
    475 
    476         elif tag == 'n':
    477             if data_length != 53:
    478                 addr.unknown_tags.append((tag, tagdata))
    479                 continue
    480             pubkeybytes = trim_to_bytes(tagdata)
    481             addr.pubkey = pubkeybytes
    482 
    483         elif tag == 'c':
    484             addr._min_final_cltv_expiry = tagdata.uint
    485 
    486         elif tag == '9':
    487             features = tagdata.uint
    488             addr.tags.append(('9', features))
    489             from .lnutil import validate_features
    490             validate_features(features)
    491 
    492         else:
    493             addr.unknown_tags.append((tag, tagdata))
    494 
    495     if verbose:
    496         print('hex of signature data (32 byte r, 32 byte s): {}'
    497               .format(hexlify(sigdecoded[0:64])))
    498         print('recovery flag: {}'.format(sigdecoded[64]))
    499         print('hex of data for signing: {}'
    500               .format(hexlify(hrp.encode("ascii") + data.tobytes())))
    501         print('SHA256 of above: {}'.format(sha256(hrp.encode("ascii") + data.tobytes()).hexdigest()))
    502 
    503     # BOLT #11:
    504     #
    505     # A reader MUST check that the `signature` is valid (see the `n` tagged
    506     # field specified below).
    507     addr.signature = sigdecoded[:65]
    508     hrp_hash = sha256(hrp.encode("ascii") + data.tobytes()).digest()
    509     if addr.pubkey: # Specified by `n`
    510         # BOLT #11:
    511         #
    512         # A reader MUST use the `n` field to validate the signature instead of
    513         # performing signature recovery if a valid `n` field is provided.
    514         ecc.ECPubkey(addr.pubkey).verify_message_hash(sigdecoded[:64], hrp_hash)
    515         pubkey_copy = addr.pubkey
    516         class WrappedBytesKey:
    517             serialize = lambda: pubkey_copy
    518         addr.pubkey = WrappedBytesKey
    519     else: # Recover pubkey from signature.
    520         addr.pubkey = SerializableKey(ecc.ECPubkey.from_sig_string(sigdecoded[:64], sigdecoded[64], hrp_hash))
    521 
    522     return addr
    523 
    524 
    525 
    526 
    527 if __name__ == '__main__':
    528     # run using
    529     # python3 -m electrum.lnaddr <invoice> <expected hrp>
    530     # python3 -m electrum.lnaddr lntb1n1pdlcakepp5e7rn0knl0gm46qqp9eqdsza2c942d8pjqnwa5903n39zu28sgk3sdq423jhxapqv3hkuct5d9hkucqp2rzjqwyx8nu2hygyvgc02cwdtvuxe0lcxz06qt3lpsldzcdr46my5epmj9vk9sqqqlcqqqqqqqlgqqqqqqgqjqdhnmkgahfaynuhe9md8k49xhxuatnv6jckfmsjq8maxta2l0trh5sdrqlyjlwutdnpd5gwmdnyytsl9q0dj6g08jacvthtpeg383k0sq542rz2 tb1n
    531     import sys
    532     print(lndecode(sys.argv[1], expected_hrp=sys.argv[2]))