electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 42c10c2fecf5cc56d149b7d09ae2dddb36560624
parent 2c2d3f3b306833b16b43789b791e856d7924934e
Author: ThomasV <thomasv@electrum.org>
Date:   Sun,  7 Feb 2021 11:57:20 +0100

Separate pay_to_node logic from pay_invoice:
 - pay_to_node will be needed to forward trampoline onions.
 - pay_to_node either is successful or raises
 - pay_invoice handles invoice status

Diffstat:
Melectrum/lnworker.py | 152++++++++++++++++++++++++++++++++++++++++++++-----------------------------------
Melectrum/tests/test_lnpeer.py | 8+++++++-
2 files changed, 92 insertions(+), 68 deletions(-)

diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -922,12 +922,10 @@ class LNWallet(LNWorker): chan, funding_tx = fut.result(timeout=timeout) except concurrent.futures.TimeoutError: raise Exception(_("open_channel timed out")) - # at this point the channel opening was successful # if this is the first channel that got opened, we start gossiping if self.channels: self.network.start_gossip() - return chan, funding_tx def get_channel_by_short_id(self, short_channel_id: bytes) -> Optional[Channel]: @@ -935,6 +933,15 @@ class LNWallet(LNWorker): if chan.short_channel_id == short_channel_id: return chan + def create_routes_from_invoice(self, amount_msat, decoded_invoice, *, full_path=None): + return self.create_routes_for_payment( + amount_msat=amount_msat, + invoice_pubkey=decoded_invoice.pubkey.serialize(), + min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), + r_tags=decoded_invoice.get_routing_info('r'), + invoice_features=decoded_invoice.get_tag('9') or 0, + full_path=full_path) + @log_exceptions async def pay_invoice( self, invoice: str, *, @@ -943,8 +950,13 @@ class LNWallet(LNWorker): full_path: LNPaymentPath = None) -> Tuple[bool, List[HtlcLog]]: lnaddr = self._check_invoice(invoice, amount_msat=amount_msat) + min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() payment_hash = lnaddr.paymenthash key = payment_hash.hex() + payment_secret = lnaddr.payment_secret + invoice_pubkey = lnaddr.pubkey.serialize() + invoice_features = lnaddr.get_tag('9') or 0 + r_tags = lnaddr.get_routing_info('r') amount_to_pay = lnaddr.get_amount_msat() status = self.get_payment_status(payment_hash) if status == PR_PAID: @@ -954,69 +966,68 @@ class LNWallet(LNWorker): info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID) self.save_payment_info(info) self.wallet.set_label(key, lnaddr.get_description()) - self.logs[key] = log = [] - success = False - reason = '' - amount_inflight = 0 # what we sent in htlcs - self.set_invoice_status(key, PR_INFLIGHT) util.trigger_callback('invoice_status', self.wallet, key) + try: + await self.pay_to_node( + invoice_pubkey, payment_hash, payment_secret, amount_to_pay, + min_cltv_expiry, r_tags, invoice_features, + attempts=attempts, full_path=full_path) + success = True + except PaymentFailure as e: + self.logger.exception('') + success = False + reason = str(e) + if success: + self.set_invoice_status(key, PR_PAID) + util.trigger_callback('payment_succeeded', self.wallet, key) + else: + self.set_invoice_status(key, PR_UNPAID) + util.trigger_callback('payment_failed', self.wallet, key, reason) + log = self.logs[key] + return success, log + + + async def pay_to_node( + self, node_pubkey, payment_hash, payment_secret, amount_to_pay, + min_cltv_expiry, r_tags, invoice_features, *, attempts: int = 1, + full_path: LNPaymentPath = None): + + self.logs[payment_hash.hex()] = log = [] + amount_inflight = 0 # what we sent in htlcs while True: amount_to_send = amount_to_pay - amount_inflight if amount_to_send > 0: # 1. create a set of routes for remaining amount. # note: path-finding runs in a separate thread so that we don't block the asyncio loop # graph updates might occur during the computation - try: - routes = await run_in_thread(partial(self.create_routes_from_invoice, amount_to_send, lnaddr, full_path=full_path)) - except NoPathFound: - # catch this exception because we still want to return the htlc log - reason = 'No path found' - break + routes = await run_in_thread(partial( + self.create_routes_for_payment, amount_to_send, node_pubkey, + min_cltv_expiry, r_tags, invoice_features, full_path=full_path)) # 2. send htlcs for route, amount_msat in routes: - await self.pay_to_route(route, amount_msat, lnaddr) + await self.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry) amount_inflight += amount_msat - util.trigger_callback('invoice_status', self.wallet, key) + util.trigger_callback('invoice_status', self.wallet, payment_hash.hex()) # 3. await a queue htlc_log = await self.sent_htlcs[payment_hash].get() amount_inflight -= htlc_log.amount_msat log.append(htlc_log) if htlc_log.success: - success = True - break + return # htlc failed + if len(log) >= attempts: + raise PaymentFailure('Giving up after %d attempts'%len(log)) # if we get a tmp channel failure, it might work to split the amount and try more routes # if we get a channel update, we might retry the same route and amount - if len(log) >= attempts: - reason = 'Giving up after %d attempts'%len(log) - break - if htlc_log.sender_idx is not None: - # apply channel update here - should_continue = self.handle_error_code_from_failed_htlc(htlc_log) - if not should_continue: - break - else: - # probably got "update_fail_malformed_htlc". well... who to penalise now? - reason = 'sender idx missing' - break + self.handle_error_code_from_failed_htlc(htlc_log) - # MPP: should we await all the inflight htlcs, or have another state? - if success: - self.set_invoice_status(key, PR_PAID) - util.trigger_callback('payment_succeeded', self.wallet, key) - else: - self.set_invoice_status(key, PR_UNPAID) - util.trigger_callback('payment_failed', self.wallet, key, reason) - util.trigger_callback('invoice_status', self.wallet, key) - return success, log - async def pay_to_route(self, route: LNPaymentRoute, amount_msat:int, lnaddr: LnAddr): + async def pay_to_route(self, route: LNPaymentRoute, amount_msat:int, payment_hash:bytes, payment_secret:bytes, min_cltv_expiry:int): # send a single htlc short_channel_id = route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) peer = self._peers.get(route[0].node_id) - payment_hash = lnaddr.paymenthash if not peer: raise Exception('Dropped peer') await peer.initialized @@ -1025,8 +1036,8 @@ class LNWallet(LNWorker): chan=chan, amount_msat=amount_msat, payment_hash=payment_hash, - min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), - payment_secret=lnaddr.payment_secret) + min_final_cltv_expiry=min_cltv_expiry, + payment_secret=payment_secret) self.htlc_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route util.trigger_callback('htlc_added', chan, htlc, SENT) @@ -1037,8 +1048,6 @@ class LNWallet(LNWorker): code, data = failure_msg.code, failure_msg.data self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}") self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}") - if code == OnionFailureCode.MPP_TIMEOUT: - return False # handle some specific error codes failure_codes = { OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: 0, @@ -1048,6 +1057,8 @@ class LNWallet(LNWorker): OnionFailureCode.EXPIRY_TOO_SOON: 0, OnionFailureCode.CHANNEL_DISABLED: 2, } + blacklist = False + update = False if code in failure_codes: offset = failure_codes[code] channel_update_len = int.from_bytes(data[offset:offset+2], byteorder="big") @@ -1058,7 +1069,6 @@ class LNWallet(LNWorker): blacklist = True else: r = self.channel_db.add_channel_update(payload) - blacklist = False short_channel_id = ShortChannelID(payload['short_channel_id']) if r == UpdateStatus.GOOD: self.logger.info(f"applied channel update to {short_channel_id}") @@ -1066,11 +1076,13 @@ class LNWallet(LNWorker): for chan in self.channels.values(): if chan.short_channel_id == short_channel_id: chan.set_remote_update(payload['raw']) + update = True elif r == UpdateStatus.ORPHANED: # maybe it is a private channel (and data in invoice was outdated) self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?") start_node_id = route[sender_idx].node_id self.channel_db.add_channel_update_for_private_channel(payload, start_node_id) + #update = True # FIXME: we need to check if we actually updated something elif r == UpdateStatus.EXPIRED: blacklist = True elif r == UpdateStatus.DEPRECATED: @@ -1080,22 +1092,25 @@ class LNWallet(LNWorker): blacklist = True else: blacklist = True - # blacklist channel after reporter node - # TODO this should depend on the error (even more granularity) - # also, we need finer blacklisting (directed edges; nodes) - if blacklist and sender_idx: + + if blacklist: + # blacklist channel after reporter node + # TODO this should depend on the error (even more granularity) + # also, we need finer blacklisting (directed edges; nodes) + if htlc_log.sender_idx is None: + raise PaymentFailure(htlc_log.failure_msg.code_name()) try: short_chan_id = route[sender_idx + 1].short_channel_id except IndexError: - self.logger.info("payment destination reported error") - short_chan_id = None - else: - # TODO: for MPP we need to save the amount for which - # we saw temporary channel failure - self.logger.info(f'blacklisting channel {short_chan_id}') - self.network.channel_blacklist.add(short_chan_id) - return True - return False + raise PaymentFailure('payment destination reported error') + # TODO: for MPP we need to save the amount for which + # we saw temporary channel failure + self.logger.info(f'blacklisting channel {short_chan_id}') + self.network.channel_blacklist.add(short_chan_id) + + # we should not continue if we did not blacklist or update anything + if not (blacklist or update): + raise PaymentFailure(htlc_log.failure_msg.code_name()) @classmethod @@ -1137,16 +1152,17 @@ class LNWallet(LNWorker): return addr @profiler - def create_routes_from_invoice( + def create_routes_for_payment( self, amount_msat: int, - decoded_invoice: 'LnAddr', + invoice_pubkey, + min_cltv_expiry, + r_tags, + invoice_features, *, full_path: LNPaymentPath = None) -> LNPaymentRoute: # TODO: return multiples routes if we know that a single one will not work # initially, try with less htlcs - invoice_pubkey = decoded_invoice.pubkey.serialize() - r_tags = decoded_invoice.get_routing_info('r') - route = None # type: Optional[LNPaymentRoute] + route = None channels = list(self.channels.values()) scid_to_my_channels = {chan.short_channel_id: chan for chan in channels if chan.short_channel_id is not None} @@ -1201,7 +1217,7 @@ class LNWallet(LNWorker): node_features=node_info.features if node_info else 0)) prev_node_id = node_pubkey # test sanity - if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): + if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry): self.logger.info(f"rejecting insane route {route}") route = None continue @@ -1213,14 +1229,13 @@ class LNWallet(LNWorker): path=full_path, my_channels=scid_to_my_channels, blacklist=blacklist) if not route: raise NoPathFound() - if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): + if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry): self.logger.info(f"rejecting insane route {route}") raise NoPathFound() assert len(route) > 0 if route[-1].node_id != invoice_pubkey: raise LNPathInconsistent("last node_id != invoice pubkey") # add features from invoice - invoice_features = decoded_invoice.get_tag('9') or 0 route[-1].node_features |= invoice_features # return a list of routes return [(route, amount_msat)] @@ -1367,7 +1382,10 @@ class LNWallet(LNWorker): failure_message: Optional['OnionRoutingFailureMessage']): route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id)) - if error_bytes and route: + if not route: + self.logger.info(f"received unknown htlc_failed, probably from previous session") + return + if error_bytes: self.logger.info(f" {(error_bytes, route, htlc_id)}") # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? try: diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -175,9 +175,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage get_preimage = LNWallet.get_preimage + create_routes_for_payment = LNWallet.create_routes_for_payment create_routes_from_invoice = LNWallet.create_routes_from_invoice _check_invoice = staticmethod(LNWallet._check_invoice) pay_to_route = LNWallet.pay_to_route + pay_to_node = LNWallet.pay_to_node pay_invoice = LNWallet.pay_invoice force_close_channel = LNWallet.force_close_channel try_force_closing = LNWallet.try_force_closing @@ -766,7 +768,11 @@ class TestPeer(ElectrumTestCase): # AssertionError is ok since we shouldn't use old routes, and the # route finding should fail when channel is closed async def f(): - await asyncio.gather(w1.pay_to_route(route, amount_msat, lnaddr), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() + payment_hash = lnaddr.paymenthash + payment_secret = lnaddr.payment_secret + pay = w1.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry) + await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) with self.assertRaises(PaymentFailure): run(f())