electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 7f61f22857e54954dfe1132ceb91f8006e055a57
parent 0ce6adffcc271a8e31705e16b7b24a0e93a80ca0
Author: ThomasV <thomasv@electrum.org>
Date:   Sat, 27 Feb 2021 11:48:14 +0100

MPP receive: allow payer to retry after mpp timeout

Diffstat:
Melectrum/lnpeer.py | 6+++---
Melectrum/lnworker.py | 44+++++++++++++++++++++++++-------------------
Melectrum/tests/test_lnpeer.py | 4++--
3 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -1576,10 +1576,10 @@ class Peer(Logger): invoice_msat = info.amount_msat if not (invoice_msat is None or invoice_msat <= total_msat <= 2 * invoice_msat): raise exc_incorrect_or_unknown_pd - accepted, expired = self.lnworker.htlc_received(chan.short_channel_id, htlc, total_msat) - if accepted: + mpp_status = self.lnworker.add_received_htlc(chan.short_channel_id, htlc, total_msat) + if mpp_status == True: return preimage - elif expired: + elif mpp_status == False: raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') else: return None diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -657,7 +657,7 @@ class LNWallet(LNWorker): self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] - self.received_htlcs = defaultdict(set) # type: Dict[bytes, set] + self.received_htlcs = dict() # RHASH -> mpp_status, htlc_set self.htlc_routes = dict() self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) @@ -1682,24 +1682,30 @@ class LNWallet(LNWorker): self.payments[key] = info.amount_msat, info.direction, info.status self.wallet.save_db() - def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int): - status = self.get_payment_status(htlc.payment_hash) - if status == PR_PAID: - return True, None - s = self.received_htlcs[htlc.payment_hash] - if (short_channel_id, htlc) not in s: - s.add((short_channel_id, htlc)) - total = sum([htlc.amount_msat for scid, htlc in s]) - first_timestamp = min([htlc.timestamp for scid, htlc in s]) - expired = time.time() - first_timestamp > MPP_EXPIRY - if total == expected_msat and not expired: - # status must be persisted - self.set_payment_status(htlc.payment_hash, PR_PAID) - util.trigger_callback('request_status', self.wallet, htlc.payment_hash.hex(), PR_PAID) - return True, None - if expired: - return None, True - return None, None + def add_received_htlc(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: + """ return MPP status: True (accepted), False (expired) or None """ + payment_hash = htlc.payment_hash + mpp_status, htlc_set = self.received_htlcs.get(payment_hash, (None, set())) + key = (short_channel_id, htlc) + if key not in htlc_set: + htlc_set.add(key) + if mpp_status is None: + total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) + first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) + expired = time.time() - first_timestamp > MPP_EXPIRY + if expired: + mpp_status = False + elif total == expected_msat: + mpp_status = True + self.set_payment_status(payment_hash, PR_PAID) + util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) + if mpp_status is not None: + htlc_set.remove(key) + if len(htlc_set) > 0: + self.received_htlcs[payment_hash] = mpp_status, htlc_set + elif payment_hash in self.received_htlcs: + self.received_htlcs.pop(payment_hash) + return mpp_status def get_payment_status(self, payment_hash): info = self.get_payment_info(payment_hash) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -132,7 +132,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() - self.received_htlcs = defaultdict(set) + self.received_htlcs = dict() self.sent_htlcs = defaultdict(asyncio.Queue) self.htlc_routes = defaultdict(list) @@ -170,7 +170,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): set_invoice_status = LNWallet.set_invoice_status set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status - htlc_received = LNWallet.htlc_received + add_received_htlc = LNWallet.add_received_htlc htlc_fulfilled = LNWallet.htlc_fulfilled htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage