commit b6b13217b4929c5701e18a1310b20caf01fb61e5
parent 2f223cdf467962cb53fdf43d8cf856ad541920c6
Author: ThomasV <thomasv@electrum.org>
Date: Sat, 27 Feb 2021 20:26:58 +0100
lnworker: keep invoice status INFLIGHT as long as HTLCs are inflight
Diffstat:
2 files changed, 61 insertions(+), 53 deletions(-)
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -657,8 +657,8 @@ class LNWallet(LNWorker):
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
+ self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route
self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set
- self.htlc_routes = dict()
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
# detect inflight payments
@@ -939,14 +939,13 @@ class LNWallet(LNWorker):
@log_exceptions
async def _open_channel_coroutine(
- self,
- *,
+ self, *,
connect_str: str,
funding_tx: PartialTransaction,
funding_sat: int,
push_sat: int,
- password: Optional[str],
- ) -> Tuple[Channel, PartialTransaction]:
+ password: Optional[str]) -> Tuple[Channel, PartialTransaction]:
+
peer = await self.add_peer(connect_str)
coro = peer.channel_establishment_flow(
funding_tx=funding_tx,
@@ -1053,7 +1052,6 @@ class LNWallet(LNWorker):
random.shuffle(self.trampoline2_list)
self.set_invoice_status(key, PR_INFLIGHT)
- util.trigger_callback('invoice_status', self.wallet, key)
try:
await self.pay_to_node(
node_pubkey=invoice_pubkey,
@@ -1071,6 +1069,11 @@ class LNWallet(LNWorker):
self.logger.exception('')
success = False
reason = str(e)
+ # keep invoice status INFLIGHT as long as HTLCs are inflight
+ # maybe we could add an extra state for the waiting time.
+ while payment_hash in self.get_payments(status='inflight'):
+ self.logger.info('waiting for inflight HTLCs...')
+ await self.sent_htlcs[payment_hash].get()
if success:
self.set_invoice_status(key, PR_PAID)
util.trigger_callback('payment_succeeded', self.wallet, key)
@@ -1081,8 +1084,7 @@ class LNWallet(LNWorker):
return success, log
async def pay_to_node(
- self,
- *,
+ self, *,
node_pubkey: bytes,
payment_hash: bytes,
payment_secret: Optional[bytes],
@@ -1095,8 +1097,7 @@ class LNWallet(LNWorker):
full_path: LNPaymentPath = None,
trampoline_onion=None,
trampoline_fee=None,
- trampoline_cltv_delta=None,
- ) -> None:
+ trampoline_cltv_delta=None) -> None:
if trampoline_onion:
# todo: compare to the fee of the actual route we found
@@ -1119,7 +1120,7 @@ class LNWallet(LNWorker):
# 2. send htlcs
for route, amount_msat in routes:
await self.pay_to_route(
- route,
+ route=route,
amount_msat=amount_msat,
total_msat=amount_to_pay,
payment_hash=payment_hash,
@@ -1142,16 +1143,15 @@ class LNWallet(LNWorker):
self.handle_error_code_from_failed_htlc(htlc_log)
async def pay_to_route(
- self,
+ self, *,
route: LNPaymentRoute,
- *,
amount_msat: int,
total_msat: int,
payment_hash: bytes,
payment_secret: Optional[bytes],
min_cltv_expiry: int,
- trampoline_onion: bytes = None,
- ) -> None:
+ trampoline_onion: bytes = None) -> None:
+
# send a single htlc
short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
@@ -1168,7 +1168,7 @@ class LNWallet(LNWorker):
min_final_cltv_expiry=min_cltv_expiry,
payment_secret=payment_secret,
fwd_trampoline_onion=trampoline_onion)
- self.htlc_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route
+ self.sent_htlcs_routes[(payment_hash, short_channel_id, htlc.htlc_id)] = route
util.trigger_callback('htlc_added', chan, htlc, SENT)
def handle_error_code_from_failed_htlc(self, htlc_log):
@@ -1729,6 +1729,7 @@ class LNWallet(LNWorker):
self.inflight_payments.remove(key)
if status in SAVED_PR_STATUS:
self.set_payment_status(bfh(key), status)
+ util.trigger_callback('invoice_status', self.wallet, key)
def set_payment_status(self, payment_hash: bytes, status):
info = self.get_payment_info(payment_hash)
@@ -1739,54 +1740,60 @@ class LNWallet(LNWorker):
self.save_payment_info(info)
def htlc_fulfilled(self, chan, payment_hash: bytes, htlc_id:int, amount_msat:int):
- route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id))
- htlc_log = HtlcLog(
- success=True,
- route=route,
- amount_msat=amount_msat)
- q = self.sent_htlcs[payment_hash]
- q.put_nowait(htlc_log)
util.trigger_callback('htlc_fulfilled', payment_hash, chan.channel_id)
+ q = self.sent_htlcs.get(payment_hash)
+ if q:
+ route = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
+ htlc_log = HtlcLog(
+ success=True,
+ route=route,
+ amount_msat=amount_msat)
+ q.put_nowait(htlc_log)
+ else:
+ if payment_hash not in self.get_payments(status='inflight'):
+ key = payment_hash.hex()
+ self.set_invoice_status(key, PR_PAID)
+ util.trigger_callback('payment_succeeded', self.wallet, key)
def htlc_failed(
self,
- chan,
+ chan: Channel,
payment_hash: bytes,
htlc_id: int,
amount_msat:int,
error_bytes: Optional[bytes],
failure_message: Optional['OnionRoutingFailure']):
- route = self.htlc_routes.get((payment_hash, chan.short_channel_id, htlc_id))
- if not route:
- self.logger.info(f"received unknown htlc_failed, probably from previous session")
- return
- if error_bytes:
- self.logger.info(f" {(error_bytes, route, htlc_id)}")
- # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
- try:
- failure_message, sender_idx = chan.decode_onion_error(error_bytes, route, htlc_id)
- except Exception as e:
+ util.trigger_callback('htlc_failed', payment_hash, chan.channel_id)
+ q = self.sent_htlcs.get(payment_hash)
+ if q:
+ route = self.sent_htlcs_routes[(payment_hash, chan.short_channel_id, htlc_id)]
+ if error_bytes:
+ # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
+ try:
+ failure_message, sender_idx = chan.decode_onion_error(error_bytes, route, htlc_id)
+ except Exception as e:
+ sender_idx = None
+ failure_message = OnionRoutingFailure(-1, str(e))
+ else:
+ # probably got "update_fail_malformed_htlc". well... who to penalise now?
+ assert failure_message is not None
sender_idx = None
- failure_message = OnionRoutingFailure(-1, str(e))
+ self.logger.info(f"htlc_failed {failure_message}")
+ htlc_log = HtlcLog(
+ success=False,
+ route=route,
+ amount_msat=amount_msat,
+ error_bytes=error_bytes,
+ failure_msg=failure_message,
+ sender_idx=sender_idx)
+ q.put_nowait(htlc_log)
else:
- # probably got "update_fail_malformed_htlc". well... who to penalise now?
- assert failure_message is not None
- sender_idx = None
-
- htlc_log = HtlcLog(
- success=False,
- route=route,
- amount_msat=amount_msat,
- error_bytes=error_bytes,
- failure_msg=failure_message,
- sender_idx=sender_idx)
-
- q = self.sent_htlcs[payment_hash]
- q.put_nowait(htlc_log)
- util.trigger_callback('htlc_failed', payment_hash, chan.channel_id)
-
-
+ self.logger.info(f"received unknown htlc_failed, probably from previous session")
+ if payment_hash not in self.get_payments(status='inflight'):
+ key = payment_hash.hex()
+ self.set_invoice_status(key, PR_UNPAID)
+ util.trigger_callback('payment_failed', self.wallet, key, '')
async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)"""
diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
@@ -165,6 +165,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
inflight_payments = set()
preimages = {}
+ get_payments = LNWallet.get_payments
get_payment_info = LNWallet.get_payment_info
save_payment_info = LNWallet.save_payment_info
set_invoice_status = LNWallet.set_invoice_status
@@ -776,7 +777,7 @@ class TestPeer(ElectrumTestCase):
payment_hash = lnaddr.paymenthash
payment_secret = lnaddr.payment_secret
pay = w1.pay_to_route(
- route,
+ route=route,
amount_msat=amount_msat,
total_msat=amount_msat,
payment_hash=payment_hash,