electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit f28a2aae73b8a4a2067f41d2f69809e2b1ec6a20
parent 1102ea50e878d02c126cbb27480abb39e74e0534
Author: ThomasV <thomasv@electrum.org>
Date:   Sat, 30 Jan 2021 16:10:51 +0100

Reorganize code so that we can send Multi Part Payments:
 - LNWorker is notified about htlc events and creates payment events.
 - LNWorker._pay is a while loop that calls create_routes_from_invoice.
 - create_route_from_invoices should decide whether to split the payment,
   using graph knowledge and feedback from previous attempts (not in this commit)
 - data structures for payment logs are simplified into a single type, HtlcLog

Diffstat:
Melectrum/gui/qt/channel_details.py | 20++++++++++----------
Melectrum/gui/qt/invoice_list.py | 4++--
Melectrum/lnchannel.py | 8++++----
Melectrum/lnutil.py | 58+++++++++++++++++++++-------------------------------------
Melectrum/lnworker.py | 254++++++++++++++++++++++++++++++++++++++++++++-----------------------------------
Melectrum/tests/test_lnpeer.py | 56++++++++++++++++++++++++++++----------------------------
6 files changed, 206 insertions(+), 194 deletions(-)

diff --git a/electrum/gui/qt/channel_details.py b/electrum/gui/qt/channel_details.py @@ -82,8 +82,8 @@ class ChannelDetailsDialog(QtWidgets.QDialog): dest_mapping = self.keyname_rows[to] dest_mapping[payment_hash] = len(dest_mapping) - ln_payment_completed = QtCore.pyqtSignal(str, bytes, bytes) - ln_payment_failed = QtCore.pyqtSignal(str, bytes, bytes) + htlc_fulfilled = QtCore.pyqtSignal(str, bytes, bytes) + htlc_failed = QtCore.pyqtSignal(str, bytes, bytes) htlc_added = QtCore.pyqtSignal(str, Channel, UpdateAddHtlc, Direction) state_changed = QtCore.pyqtSignal(str, Abstract_Wallet, AbstractChannel) @@ -95,7 +95,7 @@ class ChannelDetailsDialog(QtWidgets.QDialog): self.update() @QtCore.pyqtSlot(str, Channel, UpdateAddHtlc, Direction) - def do_htlc_added(self, evtname, chan, htlc, direction): + def on_htlc_added(self, evtname, chan, htlc, direction): if chan != self.chan: return mapping = self.keyname_rows['inflight'] @@ -103,14 +103,14 @@ class ChannelDetailsDialog(QtWidgets.QDialog): self.folders['inflight'].appendRow(self.make_htlc_item(htlc, direction)) @QtCore.pyqtSlot(str, bytes, bytes) - def do_ln_payment_completed(self, evtname, payment_hash, chan_id): + def on_htlc_fulfilled(self, evtname, payment_hash, chan_id): if chan_id != self.chan.channel_id: return self.move('inflight', 'settled', payment_hash) self.update() @QtCore.pyqtSlot(str, bytes, bytes) - def do_ln_payment_failed(self, evtname, payment_hash, chan_id): + def on_htlc_failed(self, evtname, payment_hash, chan_id): if chan_id != self.chan.channel_id: return self.move('inflight', 'failed', payment_hash) @@ -137,14 +137,14 @@ class ChannelDetailsDialog(QtWidgets.QDialog): self.format_msat = lambda msat: window.format_amount_and_units(msat / 1000) # connect signals with slots - self.ln_payment_completed.connect(self.do_ln_payment_completed) - self.ln_payment_failed.connect(self.do_ln_payment_failed) + self.htlc_fulfilled.connect(self.on_htlc_fulfilled) + self.htlc_failed.connect(self.on_htlc_failed_failed) self.state_changed.connect(self.do_state_changed) - self.htlc_added.connect(self.do_htlc_added) + self.htlc_added.connect(self.on_htlc_added) # register callbacks for updating - util.register_callback(self.ln_payment_completed.emit, ['ln_payment_completed']) - util.register_callback(self.ln_payment_failed.emit, ['ln_payment_failed']) + util.register_callback(self.htlc_fulfilled.emit, ['htlc_fulfilled']) + util.register_callback(self.htlc_failed.emit, ['htlc_failed']) util.register_callback(self.htlc_added.emit, ['htlc_added']) util.register_callback(self.state_changed.emit, ['channel']) diff --git a/electrum/gui/qt/invoice_list.py b/electrum/gui/qt/invoice_list.py @@ -34,7 +34,7 @@ from PyQt5.QtWidgets import QMenu, QVBoxLayout, QTreeWidget, QTreeWidgetItem, QH from electrum.i18n import _ from electrum.util import format_time from electrum.invoices import Invoice, PR_UNPAID, PR_PAID, PR_INFLIGHT, PR_FAILED, PR_TYPE_ONCHAIN, PR_TYPE_LN -from electrum.lnutil import PaymentAttemptLog +from electrum.lnutil import HtlcLog from .util import MyTreeView, read_QIcon, MySortModel, pr_icons from .util import CloseButton, Buttons @@ -173,7 +173,7 @@ class InvoiceList(MyTreeView): menu.addAction(_("Delete"), lambda: self.parent.delete_invoices([key])) menu.exec_(self.viewport().mapToGlobal(position)) - def show_log(self, key, log: Sequence[PaymentAttemptLog]): + def show_log(self, key, log: Sequence[HtlcLog]): d = WindowModalDialog(self, _("Payment log")) d.setMinimumWidth(600) vbox = QVBoxLayout(d) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -991,7 +991,7 @@ class Channel(AbstractChannel): if self.lnworker: sent = self.hm.sent_in_ctn(new_ctn) for htlc in sent: - self.lnworker.payment_sent(self, htlc.payment_hash) + self.lnworker.htlc_fulfilled(self, htlc.payment_hash, htlc.htlc_id, htlc.amount_msat) failed = self.hm.failed_in_ctn(new_ctn) for htlc in failed: try: @@ -1002,7 +1002,7 @@ class Channel(AbstractChannel): if self.lnworker.get_payment_info(htlc.payment_hash) is None: self.save_fail_htlc_reason(htlc.htlc_id, error_bytes, failure_message) else: - self.lnworker.payment_failed(self, htlc.payment_hash, error_bytes, failure_message) + self.lnworker.htlc_failed(self, htlc.payment_hash, htlc.htlc_id, htlc.amount_msat, error_bytes, failure_message) def save_fail_htlc_reason( self, @@ -1048,9 +1048,9 @@ class Channel(AbstractChannel): info = self.lnworker.get_payment_info(payment_hash) if info is not None and info.status != PR_PAID: if is_sent: - self.lnworker.payment_sent(self, payment_hash) + self.lnworker.htlc_fulfilled(self, payment_hash, htlc.htlc_id, htlc.amount_msat) else: - self.lnworker.payment_received(payment_hash) + self.lnworker.htlc_received(self, payment_hash) def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int: assert type(whose) is HTLCOwner diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -249,52 +249,36 @@ class Outpoint(StoredObject): return "{}:{}".format(self.txid, self.output_index) -class PaymentAttemptFailureDetails(NamedTuple): - sender_idx: Optional[int] - failure_msg: 'OnionRoutingFailureMessage' - is_blacklisted: bool - - -class PaymentAttemptLog(NamedTuple): +class HtlcLog(NamedTuple): success: bool + amount_msat: int route: Optional['LNPaymentRoute'] = None preimage: Optional[bytes] = None - failure_details: Optional[PaymentAttemptFailureDetails] = None - exception: Optional[Exception] = None + error_bytes: Optional[bytes] = None + failure_msg: Optional['OnionRoutingFailureMessage'] = None + sender_idx: Optional[int] = None def formatted_tuple(self): - if not self.exception: - route = self.route - route_str = '%d'%len(route) - short_channel_id = None - if not self.success: - sender_idx = self.failure_details.sender_idx - failure_msg = self.failure_details.failure_msg - if sender_idx is not None: - try: - short_channel_id = route[sender_idx + 1].short_channel_id - except IndexError: - # payment destination reported error - short_channel_id = _("Destination node") - message = failure_msg.code_name() - else: - short_channel_id = route[-1].short_channel_id - message = _('Success') - chan_str = str(short_channel_id) if short_channel_id else _("Unknown") + route = self.route + route_str = '%d'%len(route) + short_channel_id = None + if not self.success: + sender_idx = self.sender_idx + failure_msg = self.failure_msg + if sender_idx is not None: + try: + short_channel_id = route[sender_idx + 1].short_channel_id + except IndexError: + # payment destination reported error + short_channel_id = _("Destination node") + message = failure_msg.code_name() else: - route_str = 'None' - chan_str = 'N/A' - message = str(self.exception) + short_channel_id = route[-1].short_channel_id + message = _('Success') + chan_str = str(short_channel_id) if short_channel_id else _("Unknown") return route_str, chan_str, message -class BarePaymentAttemptLog(NamedTuple): - success: bool - preimage: Optional[bytes] = None - error_bytes: Optional[bytes] = None - failure_message: Optional['OnionRoutingFailureMessage'] = None - - class LightningError(Exception): pass class LightningPeerConnectionClosed(LightningError): pass class UnableToDeriveSecret(LightningError): pass diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -55,9 +55,8 @@ from .lnutil import (Outpoint, LNPeerAddr, generate_keypair, LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner, - UpdateAddHtlc, Direction, LnFeatures, - ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails, - BarePaymentAttemptLog, derive_payment_secret_from_payment_preimage) + UpdateAddHtlc, Direction, LnFeatures, ShortChannelID, + HtlcLog, derive_payment_secret_from_payment_preimage) from .lnutil import ln_dummy_address, ln_compare_features, IncompatibleLightningFeatures from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket, OnionRoutingFailureMessage @@ -570,7 +569,7 @@ class LNWallet(LNWorker): self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage # note: this sweep_address is only used as fallback; as it might result in address-reuse self.sweep_address = wallet.get_new_sweep_address_for_channel() - self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted) + self.logs = defaultdict(list) # type: Dict[str, List[HtlcLog]] # key is RHASH # (not persisted) # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() @@ -581,8 +580,11 @@ class LNWallet(LNWorker): for channel_id, c in random_shuffled_copy(channels.items()): self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) - self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]] + self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[HtlcLog]] + self.pending_sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Future[HtlcLog]] + self.pending_htlcs = defaultdict(set) # type: Dict[bytes, set] + self.htlc_routes = defaultdict(list) self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) # detect inflight payments @@ -930,7 +932,7 @@ class LNWallet(LNWorker): return chan, funding_tx - def pay(self, invoice: str, *, amount_msat: int = None, attempts: int = 1) -> Tuple[bool, List[PaymentAttemptLog]]: + def pay(self, invoice: str, *, amount_msat: int = None, attempts: int = 1) -> Tuple[bool, List[HtlcLog]]: """ Can be called from other threads """ @@ -945,13 +947,11 @@ class LNWallet(LNWorker): @log_exceptions async def _pay( - self, - invoice: str, - *, + self, invoice: str, *, amount_msat: int = None, attempts: int = 1, - full_path: LNPaymentPath = None, - ) -> Tuple[bool, List[PaymentAttemptLog]]: + full_path: LNPaymentPath = None) -> Tuple[bool, List[HtlcLog]]: + lnaddr = self._check_invoice(invoice, amount_msat=amount_msat) payment_hash = lnaddr.paymenthash key = payment_hash.hex() @@ -967,84 +967,89 @@ class LNWallet(LNWorker): self.logs[key] = log = [] success = False reason = '' - for i in range(attempts): - try: + amount_to_pay = lnaddr.get_amount_msat() + amount_inflight = 0 # what we sent in htlcs + + self.set_invoice_status(key, PR_INFLIGHT) + util.trigger_callback('invoice_status', self.wallet, key) + while True: + amount_to_send = amount_to_pay - amount_inflight + if amount_to_send > 0: + # 1. create a set of routes for remaining amount. # note: path-finding runs in a separate thread so that we don't block the asyncio loop # graph updates might occur during the computation - self.set_invoice_status(key, PR_INFLIGHT) + try: + routes = await run_in_thread(partial(self.create_routes_from_invoice, amount_to_send, lnaddr, full_path=full_path)) + except NoPathFound: + # catch this exception because we still want to return the htlc log + reason = 'No path found' + break + # 2. send htlcs + for route, amount_msat in routes: + await self.pay_to_route(route, amount_msat, lnaddr) + amount_inflight += amount_msat util.trigger_callback('invoice_status', self.wallet, key) - route = await run_in_thread(partial(self._create_route_from_invoice, lnaddr, full_path=full_path)) - payment_attempt_log = await self._pay_to_route(route, lnaddr) - except Exception as e: - log.append(PaymentAttemptLog(success=False, exception=e)) - reason = str(e) + # 3. await a queue + htlc_log = await self.pending_sent_htlcs[payment_hash].get() + amount_inflight -= htlc_log.amount_msat + log.append(htlc_log) + if htlc_log.success: + success = True break - log.append(payment_attempt_log) - success = payment_attempt_log.success - if success: + # htlc failed + # if we get a tmp channel failure, it might work to split the amount and try more routes + # if we get a channel update, we might retry the same route and amount + if len(log) >= attempts: + reason = 'Giving up after %d attempts'%len(log) break - else: - reason = _('Failed after {} attempts').format(attempts) - self.set_invoice_status(key, PR_PAID if success else PR_UNPAID) - util.trigger_callback('invoice_status', self.wallet, key) + if htlc_log.sender_idx is not None: + # apply channel update here + should_continue = self.handle_error_code_from_failed_htlc(htlc_log) + if not should_continue: + break + else: + # probably got "update_fail_malformed_htlc". well... who to penalise now? + reason = 'sender idx missing' + break + + # MPP: should we await all the inflight htlcs, or have another state? if success: + self.set_invoice_status(key, PR_PAID) util.trigger_callback('payment_succeeded', self.wallet, key) else: + self.set_invoice_status(key, PR_UNPAID) util.trigger_callback('payment_failed', self.wallet, key, reason) + util.trigger_callback('invoice_status', self.wallet, key) return success, log - async def _pay_to_route(self, route: LNPaymentRoute, lnaddr: LnAddr) -> PaymentAttemptLog: + async def pay_to_route(self, route: LNPaymentRoute, amount_msat:int, lnaddr: LnAddr): + # send a single htlc short_channel_id = route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) peer = self._peers.get(route[0].node_id) + payment_hash = lnaddr.paymenthash if not peer: raise Exception('Dropped peer') await peer.initialized htlc = peer.pay( route=route, chan=chan, - amount_msat=lnaddr.get_amount_msat(), - payment_hash=lnaddr.paymenthash, + amount_msat=amount_msat, + payment_hash=payment_hash, min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), payment_secret=lnaddr.payment_secret) + self.htlc_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route util.trigger_callback('htlc_added', chan, htlc, SENT) - payment_attempt = await self.await_payment(lnaddr.paymenthash) - if payment_attempt.success: - failure_log = None - else: - if payment_attempt.error_bytes: - # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? - failure_msg, sender_idx = chan.decode_onion_error(payment_attempt.error_bytes, route, htlc.htlc_id) - is_blacklisted = self.handle_error_code_from_failed_htlc(failure_msg, sender_idx, route, peer) - if is_blacklisted: - # blacklist channel after reporter node - # TODO this should depend on the error (even more granularity) - # also, we need finer blacklisting (directed edges; nodes) - try: - short_chan_id = route[sender_idx + 1].short_channel_id - except IndexError: - self.logger.info("payment destination reported error") - else: - self.logger.info(f'blacklisting channel {short_chan_id}') - self.network.channel_blacklist.add(short_chan_id) - else: - # probably got "update_fail_malformed_htlc". well... who to penalise now? - assert payment_attempt.failure_message is not None - sender_idx = None - failure_msg = payment_attempt.failure_message - is_blacklisted = False - failure_log = PaymentAttemptFailureDetails(sender_idx=sender_idx, - failure_msg=failure_msg, - is_blacklisted=is_blacklisted) - return PaymentAttemptLog(route=route, - success=payment_attempt.success, - preimage=payment_attempt.preimage, - failure_details=failure_log) - - def handle_error_code_from_failed_htlc(self, failure_msg, sender_idx, route, peer): + + def handle_error_code_from_failed_htlc(self, htlc_log): + route = htlc_log.route + sender_idx = htlc_log.sender_idx + failure_msg = htlc_log.failure_msg code, data = failure_msg.code, failure_msg.data self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}") self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}") + if code == OnionFailureCode.MPP_TIMEOUT: + return False # handle some specific error codes failure_codes = { OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: 0, @@ -1067,7 +1072,10 @@ class LNWallet(LNWorker): short_channel_id = ShortChannelID(payload['short_channel_id']) if r == UpdateStatus.GOOD: self.logger.info(f"applied channel update to {short_channel_id}") - peer.maybe_save_remote_update(payload) + # TODO: test this + for chan in self.channels.values(): + if chan.short_channel_id == short_channel_id: + chan.set_remote_update(payload['raw']) elif r == UpdateStatus.ORPHANED: # maybe it is a private channel (and data in invoice was outdated) self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?") @@ -1082,7 +1090,23 @@ class LNWallet(LNWorker): blacklist = True else: blacklist = True - return blacklist + # blacklist channel after reporter node + # TODO this should depend on the error (even more granularity) + # also, we need finer blacklisting (directed edges; nodes) + if blacklist and sender_idx: + try: + short_chan_id = route[sender_idx + 1].short_channel_id + except IndexError: + self.logger.info("payment destination reported error") + short_chan_id = None + else: + # TODO: for MPP we need to save the amount for which + # we saw temporary channel failure + self.logger.info(f'blacklisting channel {short_chan_id}') + self.network.channel_blacklist.add(short_chan_id) + return True + return False + @classmethod def _decode_channel_update_msg(cls, chan_upd_msg: bytes) -> Optional[Dict[str, Any]]: @@ -1123,9 +1147,13 @@ class LNWallet(LNWorker): return addr @profiler - def _create_route_from_invoice(self, decoded_invoice: 'LnAddr', - *, full_path: LNPaymentPath = None) -> LNPaymentRoute: - amount_msat = decoded_invoice.get_amount_msat() + def create_routes_from_invoice( + self, + amount_msat: int, + decoded_invoice: 'LnAddr', + *, full_path: LNPaymentPath = None) -> LNPaymentRoute: + # TODO: return multiples routes if we know that a single one will not work + # initially, try with less htlcs invoice_pubkey = decoded_invoice.pubkey.serialize() # use 'r' field from invoice route = None # type: Optional[LNPaymentRoute] @@ -1211,7 +1239,8 @@ class LNWallet(LNWorker): # add features from invoice invoice_features = decoded_invoice.get_tag('9') or 0 route[-1].node_features |= invoice_features - return route + # return a list of routes + return [(route, amount_msat)] def add_request(self, amount_sat, message, expiry) -> str: coro = self._add_request_coro(amount_sat, message, expiry) @@ -1297,7 +1326,8 @@ class LNWallet(LNWorker): expired = time.time() - first_timestamp > MPP_EXPIRY if total >= expected_msat and not expired: # status must be persisted - self.payment_received(htlc.payment_hash) + self.set_payment_status(htlc.payment_hash, PR_PAID) + util.trigger_callback('request_status', self.wallet, htlc.payment_hash.hex(), PR_PAID) return True, None if expired: return None, True @@ -1326,12 +1356,6 @@ class LNWallet(LNWorker): if status in SAVED_PR_STATUS: self.set_payment_status(bfh(key), status) - async def await_payment(self, payment_hash: bytes) -> BarePaymentAttemptLog: - # note side-effect: Future is created and added here (defaultdict): - payment_attempt = await self.pending_payments[payment_hash] - self.pending_payments.pop(payment_hash) - return payment_attempt - def set_payment_status(self, payment_hash: bytes, status): info = self.get_payment_info(payment_hash) if info is None: @@ -1340,48 +1364,52 @@ class LNWallet(LNWorker): info = info._replace(status=status) self.save_payment_info(info) - def payment_failed( + def htlc_fulfilled(self, chan, payment_hash: bytes, htlc_id:int, amount_msat:int): + route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id)) + htlc_log = HtlcLog( + success=True, + route=route, + amount_msat=amount_msat) + q = self.pending_sent_htlcs[payment_hash] + q.put_nowait(htlc_log) + util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id) + + def htlc_failed( self, - chan: Channel, + chan, payment_hash: bytes, + htlc_id: int, + amount_msat:int, error_bytes: Optional[bytes], - failure_message: Optional['OnionRoutingFailureMessage'], - ): - self.set_payment_status(payment_hash, PR_UNPAID) - f = self.pending_payments.get(payment_hash) - if f and not f.cancelled(): - payment_attempt = BarePaymentAttemptLog( - success=False, - error_bytes=error_bytes, - failure_message=failure_message) - f.set_result(payment_attempt) - else: - chan.logger.info('received unexpected payment_failed, probably from previous session') - key = payment_hash.hex() - util.trigger_callback('invoice_status', self.wallet, key) - util.trigger_callback('payment_failed', self.wallet, key, '') - util.trigger_callback('ln_payment_failed', payment_hash, chan.channel_id) - - def payment_sent(self, chan, payment_hash: bytes): - self.set_payment_status(payment_hash, PR_PAID) - preimage = self.get_preimage(payment_hash) - f = self.pending_payments.get(payment_hash) - if f and not f.cancelled(): - payment_attempt = BarePaymentAttemptLog( - success=True, - preimage=preimage) - f.set_result(payment_attempt) + failure_message: Optional['OnionRoutingFailureMessage']): + + route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id)) + if error_bytes and route: + self.logger.info(f" {(error_bytes, route, htlc_id)}") + # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? + try: + failure_message, sender_idx = chan.decode_onion_error(error_bytes, route, htlc_id) + except Exception as e: + sender_idx = None + failure_message = OnionRoutingFailureMessage(-1, str(e)) else: - chan.logger.info('received unexpected payment_sent, probably from previous session') - key = payment_hash.hex() - util.trigger_callback('invoice_status', self.wallet, key) - util.trigger_callback('payment_succeeded', self.wallet, key) - util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) + # probably got "update_fail_malformed_htlc". well... who to penalise now? + assert failure_message is not None + sender_idx = None + + htlc_log = HtlcLog( + success=False, + route=route, + amount_msat=amount_msat, + error_bytes=error_bytes, + failure_msg=failure_message, + sender_idx=sender_idx) + + q = self.pending_sent_htlcs[payment_hash] + q.put_nowait(htlc_log) + util.trigger_callback('htlc_failed', payment_hash, chan.channel_id) + - def payment_received(self, payment_hash: bytes): - self.set_payment_status(payment_hash, PR_PAID) - util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) - #util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]): """calculate routing hints (BOLT-11 'r' field)""" diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -133,6 +133,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() self.pending_htlcs = defaultdict(set) + self.pending_sent_htlcs = defaultdict(asyncio.Queue) + self.htlc_routes = defaultdict(list) def get_invoice_status(self, key): pass @@ -169,15 +171,13 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status htlc_received = LNWallet.htlc_received - await_payment = LNWallet.await_payment - payment_received = LNWallet.payment_received - payment_sent = LNWallet.payment_sent - payment_failed = LNWallet.payment_failed + htlc_fulfilled = LNWallet.htlc_fulfilled + htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage get_preimage = LNWallet.get_preimage - _create_route_from_invoice = LNWallet._create_route_from_invoice + create_routes_from_invoice = LNWallet.create_routes_from_invoice _check_invoice = staticmethod(LNWallet._check_invoice) - _pay_to_route = LNWallet._pay_to_route + pay_to_route = LNWallet.pay_to_route _pay = LNWallet._pay force_close_channel = LNWallet.force_close_channel try_force_closing = LNWallet.try_force_closing @@ -490,27 +490,27 @@ class TestPeer(ElectrumTestCase): lnaddr1 = lndecode(pay_req1, expected_hrp=constants.net.SEGWIT_HRP) # alice sends htlc BUT NOT COMMITMENT_SIGNED p1.maybe_send_commitment = lambda x: None + route1, amount_msat1 = w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2)[0] p1.pay( - route=w1._create_route_from_invoice(decoded_invoice=lnaddr2), + route=route1, chan=alice_channel, amount_msat=lnaddr2.get_amount_msat(), payment_hash=lnaddr2.paymenthash, min_final_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), payment_secret=lnaddr2.payment_secret, ) - w1.pending_payments[lnaddr2.paymenthash] = asyncio.Future() p1.maybe_send_commitment = _maybe_send_commitment1 # bob sends htlc BUT NOT COMMITMENT_SIGNED p2.maybe_send_commitment = lambda x: None + route2, amount_msat2 = w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1)[0] p2.pay( - route=w2._create_route_from_invoice(decoded_invoice=lnaddr1), + route=route2, chan=bob_channel, amount_msat=lnaddr1.get_amount_msat(), payment_hash=lnaddr1.paymenthash, min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), payment_secret=lnaddr1.payment_secret, ) - w2.pending_payments[lnaddr1.paymenthash] = asyncio.Future() p2.maybe_send_commitment = _maybe_send_commitment2 # sleep a bit so that they both receive msgs sent so far await asyncio.sleep(0.1) @@ -518,10 +518,10 @@ class TestPeer(ElectrumTestCase): p1.maybe_send_commitment(alice_channel) p2.maybe_send_commitment(bob_channel) - payment_attempt1 = await w1.await_payment(lnaddr2.paymenthash) - assert payment_attempt1.success - payment_attempt2 = await w2.await_payment(lnaddr1.paymenthash) - assert payment_attempt2.success + htlc_log1 = await w1.pending_sent_htlcs[lnaddr2.paymenthash].get() + assert htlc_log1.success + htlc_log2 = await w2.pending_sent_htlcs[lnaddr1.paymenthash].get() + assert htlc_log2.success raise PaymentDone() async def f(): @@ -594,21 +594,20 @@ class TestPeer(ElectrumTestCase): with self.subTest(msg="bad path: edges do not chain together"): path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] - result, log = await graph.w_a._pay(pay_req, full_path=path) - self.assertFalse(result) - self.assertTrue(isinstance(log[0].exception, LNPathInconsistent)) + with self.assertRaises(LNPathInconsistent): + await graph.w_a._pay(pay_req, full_path=path) with self.subTest(msg="bad path: last node id differs from invoice pubkey"): path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)] - result, log = await graph.w_a._pay(pay_req, full_path=path) - self.assertFalse(result) - self.assertTrue(isinstance(log[0].exception, LNPathInconsistent)) + with self.assertRaises(LNPathInconsistent): + await graph.w_a._pay(pay_req, full_path=path) with self.subTest(msg="good path"): path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] result, log = await graph.w_a._pay(pay_req, full_path=path) self.assertTrue(result) - self.assertEqual([edge.short_channel_id for edge in path], - [edge.short_channel_id for edge in log[0].route]) + self.assertEqual( + [edge.short_channel_id for edge in path], + [edge.short_channel_id for edge in log[0].route]) raise PaymentDone() async def f(): async with TaskGroup() as group: @@ -630,7 +629,7 @@ class TestPeer(ElectrumTestCase): async def pay(pay_req): result, log = await graph.w_a._pay(pay_req) self.assertFalse(result) - self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_details.failure_msg.code) + self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code) raise PaymentDone() async def f(): async with TaskGroup() as group: @@ -658,7 +657,7 @@ class TestPeer(ElectrumTestCase): await asyncio.wait_for(p1.initialized, 1) await asyncio.wait_for(p2.initialized, 1) # alice sends htlc - route = w1._create_route_from_invoice(decoded_invoice=lnaddr) + route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0] htlc = p1.pay(route=route, chan=alice_channel, amount_msat=lnaddr.get_amount_msat(), @@ -752,21 +751,22 @@ class TestPeer(ElectrumTestCase): p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) pay_req = run(self.prepare_invoice(w2)) - addr = w1._check_invoice(pay_req) - route = w1._create_route_from_invoice(decoded_invoice=addr) + lnaddr = w1._check_invoice(pay_req) + route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0] + assert amount_msat == lnaddr.get_amount_msat() run(w1.force_close_channel(alice_channel.channel_id)) # check if a tx (commitment transaction) was broadcasted: assert q1.qsize() == 1 with self.assertRaises(NoPathFound) as e: - w1._create_route_from_invoice(decoded_invoice=addr) + w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr) peer = w1.peers[route[0].node_id] # AssertionError is ok since we shouldn't use old routes, and the # route finding should fail when channel is closed async def f(): - await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + await asyncio.gather(w1.pay_to_route(route, amount_msat, lnaddr), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) with self.assertRaises(PaymentFailure): run(f())