electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 6004a047053e97b9f613d8329f690d1fd6d920fb
parent 05e58671c92940620aab36cf5946a0f1dd24c013
Author: ThomasV <thomasv@electrum.org>
Date:   Fri, 12 Mar 2021 11:01:07 +0100

Merge pull request #7099 from SomberNight/202103_fail_pending_htlcs_on_shutdown

fail pending htlcs on shutdown
Diffstat:
M electrum/lnhtlc.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
M electrum/lnpeer.py | 45 ++++++++++++++++++++++++++++++++++++++++++---
M electrum/lnworker.py | 41 ++++++++++++++++++++++++++++++++++++-----
M electrum/tests/test_lnpeer.py | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
4 files changed, 182 insertions(+), 15 deletions(-)

diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -360,6 +360,55 @@ class HTLCManager: return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner) @with_lock + def is_htlc_irrevocably_removed_yet( + self, + *, + ctx_owner: HTLCOwner = None, + htlc_proposer: HTLCOwner, + htlc_id: int, + ) -> bool: + """Returns whether the removal of an htlc was irrevocably committed to `ctx_owner's` ctx. + The removal can either be a fulfill/settle or a fail; they are not distinguished. + If `ctx_owner` is None, both parties' ctxs are checked. + """ + in_local = self._is_htlc_irrevocably_removed_yet( + ctx_owner=LOCAL, htlc_proposer=htlc_proposer, htlc_id=htlc_id) + in_remote = self._is_htlc_irrevocably_removed_yet( + ctx_owner=REMOTE, htlc_proposer=htlc_proposer, htlc_id=htlc_id) + if ctx_owner is None: + return in_local and in_remote + elif ctx_owner == LOCAL: + return in_local + elif ctx_owner == REMOTE: + return in_remote + else: + raise Exception(f"unexpected ctx_owner: {ctx_owner!r}") + + @with_lock + def _is_htlc_irrevocably_removed_yet( + self, + *, + ctx_owner: HTLCOwner, + htlc_proposer: HTLCOwner, + htlc_id: int, + ) -> bool: + htlc_id = int(htlc_id) + if htlc_id >= self.get_next_htlc_id(htlc_proposer): + return False + if htlc_id in self.log[htlc_proposer]['settles']: + ctn_of_settle = self.log[htlc_proposer]['settles'][htlc_id][ctx_owner] + else: + ctn_of_settle = None + if htlc_id in self.log[htlc_proposer]['fails']: + ctn_of_fail = self.log[htlc_proposer]['fails'][htlc_id][ctx_owner] + else: + ctn_of_fail = None + ctn_of_rm = ctn_of_settle or ctn_of_fail or None + if ctn_of_rm is None: + return False + return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner) + + @with_lock def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, ctn: int = None) -> Dict[int, UpdateAddHtlc]: """Return the dict of received or sent (depending on direction) HTLCs diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -9,11 +9,12 @@ from collections import 
OrderedDict, defaultdict import asyncio import os import time -from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union +from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set from datetime import datetime import functools import aiorpcx +from aiorpcx import TaskGroup from .crypto import sha256, sha256d from . import bitcoin, util @@ -74,6 +75,7 @@ class Peer(Logger): self._sent_init = False # type: bool self._received_init = False # type: bool self.initialized = asyncio.Future() + self.got_disconnected = asyncio.Event() self.querying = asyncio.Event() self.transport = transport self.pubkey = pubkey # remote pubkey @@ -98,6 +100,11 @@ class Peer(Logger): self.orphan_channel_updates = OrderedDict() Logger.__init__(self) self.taskgroup = SilentTaskGroup() + # HTLCs offered by REMOTE, that we started removing but are still active: + self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]] + self.received_htlc_removed_event = asyncio.Event() + self._htlc_switch_iterstart_event = asyncio.Event() + self._htlc_switch_iterdone_event = asyncio.Event() def send_message(self, message_name: str, **kwargs): assert type(message_name) is str @@ -492,6 +499,7 @@ class Peer(Logger): except: pass self.lnworker.peer_closed(self) + self.got_disconnected.set() def is_static_remotekey(self): return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT) @@ -1575,6 +1583,7 @@ class Peer(Logger): self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id) + self.received_htlcs_pending_removal.add((chan, htlc_id)) chan.settle_htlc(preimage, htlc_id) self.send_message( "update_fulfill_htlc", @@ -1585,6 +1594,7 @@ class Peer(Logger): def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes): self.logger.info(f"fail_htlc. 
chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" + self.received_htlcs_pending_removal.add((chan, htlc_id)) chan.fail_htlc(htlc_id) self.send_message( "update_fail_htlc", @@ -1596,9 +1606,10 @@ class Peer(Logger): def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure): self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.") assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" - chan.fail_htlc(htlc_id) if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32): raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}") + self.received_htlcs_pending_removal.add((chan, htlc_id)) + chan.fail_htlc(htlc_id) self.send_message( "update_fail_malformed_htlc", channel_id=chan.channel_id, @@ -1800,8 +1811,13 @@ class Peer(Logger): async def htlc_switch(self): await self.initialized while True: - await asyncio.sleep(0.1) + self._htlc_switch_iterdone_event.set() + self._htlc_switch_iterdone_event.clear() + await asyncio.sleep(0.1) # TODO maybe make this partly event-driven + self._htlc_switch_iterstart_event.set() + self._htlc_switch_iterstart_event.clear() self.ping_if_required() + self._maybe_cleanup_received_htlcs_pending_removal() for chan_id, chan in self.channels.items(): if not chan.can_send_ctx_updates(): continue @@ -1853,6 +1869,29 @@ class Peer(Logger): for htlc_id in done: unfulfilled.pop(htlc_id) + def _maybe_cleanup_received_htlcs_pending_removal(self) -> None: + done = set() + for chan, htlc_id in self.received_htlcs_pending_removal: + if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): + done.add((chan, htlc_id)) + if done: + for key in done: + self.received_htlcs_pending_removal.remove(key) + self.received_htlc_removed_event.set() + self.received_htlc_removed_event.clear() + + async def 
wait_one_htlc_switch_iteration(self) -> None: + """Waits until the HTLC switch does a full iteration or the peer disconnects, + whichever happens first. + """ + async def htlc_switch_iteration(): + await self._htlc_switch_iterstart_event.wait() + await self._htlc_switch_iterdone_event.wait() + + async with TaskGroup(wait=any) as group: + await group.spawn(htlc_switch_iteration()) + await group.spawn(self.got_disconnected.wait()) + async def process_unfulfilled_htlc( self, *, chan: Channel, diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -22,7 +22,7 @@ import urllib.parse import dns.resolver import dns.exception -from aiorpcx import run_in_thread, TaskGroup, NetAddress +from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after from . import constants, util from . import keystore @@ -195,6 +195,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]): self.features = features self.network = None # type: Optional[Network] self.config = None # type: Optional[SimpleConfig] + self.stopping_soon = False # whether we are being shut down util.register_callback(self.on_proxy_changed, ['proxy_set']) @@ -268,6 +269,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]): async def _maintain_connectivity(self): while True: await asyncio.sleep(1) + if self.stopping_soon: + return now = time.time() if len(self._peers) >= NUM_PEERS_TARGET: continue @@ -575,6 +578,7 @@ class LNWallet(LNWorker): lnwatcher: Optional['LNWalletWatcher'] MPP_EXPIRY = 120 + TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds def __init__(self, wallet: 'Abstract_Wallet', xprv): self.wallet = wallet @@ -707,9 +711,32 @@ class LNWallet(LNWorker): asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) async def stop(self): - await super().stop() - await self.lnwatcher.stop() - self.lnwatcher = None + self.stopping_soon = True + if self.listen_server: # stop accepting new peers + self.listen_server.close() + async with 
ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS): + await self.wait_for_received_pending_htlcs_to_get_removed() + await LNWorker.stop(self) + if self.lnwatcher: + await self.lnwatcher.stop() + self.lnwatcher = None + + async def wait_for_received_pending_htlcs_to_get_removed(self): + assert self.stopping_soon is True + # We try to fail pending MPP HTLCs, and wait a bit for them to get removed. + # Note: even without MPP, if we just failed/fulfilled an HTLC, it is good + # to wait a bit for it to become irrevocably removed. + # Note: we don't wait for *all htlcs* to get removed, only for those + # that we can already fail/fulfill. e.g. forwarded htlcs cannot be removed + async with TaskGroup() as group: + for peer in self.peers.values(): + await group.spawn(peer.wait_one_htlc_switch_iteration()) + while True: + if all(not peer.received_htlcs_pending_removal for peer in self.peers.values()): + break + async with TaskGroup(wait=any) as group: + for peer in self.peers.values(): + await group.spawn(peer.received_htlc_removed_event.wait()) def peer_closed(self, peer): for chan in self.channels_for_peer(peer.pubkey).values(): @@ -1635,7 +1662,9 @@ class LNWallet(LNWorker): if not is_accepted and not is_expired: total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) - if time.time() - first_timestamp > self.MPP_EXPIRY: + if self.stopping_soon: + is_expired = True # try to time out pending HTLCs before shutting down + elif time.time() - first_timestamp > self.MPP_EXPIRY: is_expired = True elif total == expected_msat: is_accepted = True @@ -1897,6 +1926,8 @@ class LNWallet(LNWorker): async def reestablish_peers_and_channels(self): while True: await asyncio.sleep(1) + if self.stopping_soon: + return for chan in self.channels.values(): if chan.is_closed(): continue diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -10,7 +10,7 @@ from concurrent import futures import 
unittest from typing import Iterable, NamedTuple, Tuple, List -from aiorpcx import TaskGroup +from aiorpcx import TaskGroup, timeout_after, TaskTimeout from electrum import bitcoin from electrum import constants @@ -113,7 +113,8 @@ class MockWallet: class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): - MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 + MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 + TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0 def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name): self.name = name @@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) self.node_keypair = local_keypair self.network = MockNetwork(tx_queue) + self.taskgroup = TaskGroup() + self.lnwatcher = None + self.listen_server = None self._channels = {chan.channel_id: chan for chan in chans} self.payments = {} self.logs = defaultdict(list) @@ -147,6 +151,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): self.trampoline_forwarding_failures = {} self.inflight_payments = set() self.preimages = {} + self.stopping_soon = False def get_invoice_status(self, key): pass @@ -183,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): return self.name async def stop(self): + await LNWallet.stop(self) if self.channel_db: self.channel_db.stop() await self.channel_db.stopped_event.wait() @@ -215,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc is_trampoline_peer = LNWallet.is_trampoline_peer + wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed + on_proxy_changed = LNWallet.on_proxy_changed class MockTransport: @@ -290,13 +298,9 @@ class 
SquareGraph(NamedTuple): def all_lnworkers(self) -> Iterable[MockLNWallet]: return self.w_a, self.w_b, self.w_c, self.w_d - async def stop_and_cleanup(self): - async with TaskGroup() as group: - for lnworker in self.all_lnworkers(): - await group.spawn(lnworker.stop()) - class PaymentDone(Exception): pass +class TestSuccess(Exception): pass class TestPeer(ElectrumTestCase): @@ -837,6 +841,50 @@ class TestPeer(ElectrumTestCase): self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3}) @needs_test_with_all_chacha20_implementations + def test_fail_pending_htlcs_on_shutdown(self): + """Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all. + Dave shuts down (stops wallet). + We test if Dave fails the pending HTLCs during shutdown. + """ + graph = self.prepare_chans_and_peers_in_square() + self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) + self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) + amount_to_pay = 600_000_000_000 + peers = graph.all_peers() + graph.w_d.MPP_EXPIRY = 120 + graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 + async def pay(): + graph.w_d.features |= LnFeatures.BASIC_MPP_OPT + graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs + assert graph.w_a.network.channel_db is not None + lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) + try: + async with timeout_after(0.5): + result, log = await graph.w_a.pay_invoice(pay_req, attempts=1) + except TaskTimeout: + # by now Dave hopefully received some HTLCs: + self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0) + self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0) + else: + self.fail(f"pay_invoice finished but was not supposed to. 
result={result}") + await graph.w_d.stop() + # Dave is supposed to have failed the pending incomplete MPP HTLCs + self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL))) + self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE))) + raise TestSuccess() + + async def f(): + async with TaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + await group.spawn(pay()) + + with self.assertRaises(TestSuccess): + run(f()) + + @needs_test_with_all_chacha20_implementations def test_close(self): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)