commit 6004a047053e97b9f613d8329f690d1fd6d920fb
parent 05e58671c92940620aab36cf5946a0f1dd24c013
Author: ThomasV <thomasv@electrum.org>
Date: Fri, 12 Mar 2021 11:01:07 +0100
Merge pull request #7099 from SomberNight/202103_fail_pending_htlcs_on_shutdown
fail pending htlcs on shutdown
Diffstat:
4 files changed, 182 insertions(+), 15 deletions(-)
diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
@@ -360,6 +360,55 @@ class HTLCManager:
return ctns[ctx_owner] <= self.ctn_oldest_unrevoked(ctx_owner)
@with_lock
+    def is_htlc_irrevocably_removed_yet(
+            self,
+            *,
+            ctx_owner: HTLCOwner = None,
+            htlc_proposer: HTLCOwner,
+            htlc_id: int,
+    ) -> bool:
+        """Tell whether the removal of an htlc is irrevocably committed to a ctx.
+
+        A removal is either a fulfill/settle or a fail; the two are not told apart.
+        With `ctx_owner` left as None, the removal has to be irrevocable in the
+        ctxs of both LOCAL and REMOTE; otherwise only the ctx of `ctx_owner`
+        is considered.
+        Raises Exception if `ctx_owner` is neither None, LOCAL nor REMOTE.
+        """
+        # Both sides are always evaluated (LOCAL first), whichever one is asked for.
+        removed_local, removed_remote = (
+            self._is_htlc_irrevocably_removed_yet(
+                ctx_owner=owner, htlc_proposer=htlc_proposer, htlc_id=htlc_id)
+            for owner in (LOCAL, REMOTE)
+        )
+        if ctx_owner is None:
+            return removed_local and removed_remote
+        if ctx_owner == LOCAL:
+            return removed_local
+        if ctx_owner == REMOTE:
+            return removed_remote
+        raise Exception(f"unexpected ctx_owner: {ctx_owner!r}")
+
+ @with_lock
+    def _is_htlc_irrevocably_removed_yet(
+            self,
+            *,
+            ctx_owner: HTLCOwner,
+            htlc_proposer: HTLCOwner,
+            htlc_id: int,
+    ) -> bool:
+        """Return whether the removal (settle or fail) of the htlc `htlc_id`
+        offered by `htlc_proposer` is irrevocably committed to the ctx of
+        `ctx_owner`.
+
+        The ctn recorded for the removal in `self.log` is compared against
+        `ctx_owner`'s oldest unrevoked ctn. Returns False if `htlc_id` is not
+        below the proposer's next htlc id, or if no removal ctn is recorded
+        for `ctx_owner`.
+        """
+        htlc_id = int(htlc_id)
+        if htlc_id >= self.get_next_htlc_id(htlc_proposer):
+            # id not handed out yet, so presumably the htlc was never added
+            return False
+        proposer_log = self.log[htlc_proposer]
+        ctn_of_rm = None
+        for action in ('settles', 'fails'):
+            if htlc_id in proposer_log[action]:
+                ctn_of_rm = proposer_log[action][htlc_id][ctx_owner]
+            # Test against None explicitly: None is the "no ctn recorded" sentinel,
+            # while 0 is a falsy-but-valid ctn that `a or b` would discard.
+            if ctn_of_rm is not None:
+                break
+        if ctn_of_rm is None:
+            return False
+        return ctn_of_rm <= self.ctn_oldest_unrevoked(ctx_owner)
+
+ @with_lock
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Dict[int, UpdateAddHtlc]:
"""Return the dict of received or sent (depending on direction) HTLCs
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -9,11 +9,12 @@ from collections import OrderedDict, defaultdict
import asyncio
import os
import time
-from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union
+from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
from datetime import datetime
import functools
import aiorpcx
+from aiorpcx import TaskGroup
from .crypto import sha256, sha256d
from . import bitcoin, util
@@ -74,6 +75,7 @@ class Peer(Logger):
self._sent_init = False # type: bool
self._received_init = False # type: bool
self.initialized = asyncio.Future()
+ self.got_disconnected = asyncio.Event()
self.querying = asyncio.Event()
self.transport = transport
self.pubkey = pubkey # remote pubkey
@@ -98,6 +100,11 @@ class Peer(Logger):
self.orphan_channel_updates = OrderedDict()
Logger.__init__(self)
self.taskgroup = SilentTaskGroup()
+ # HTLCs offered by REMOTE, that we started removing but are still active:
+ self.received_htlcs_pending_removal = set() # type: Set[Tuple[Channel, int]]
+ self.received_htlc_removed_event = asyncio.Event()
+ self._htlc_switch_iterstart_event = asyncio.Event()
+ self._htlc_switch_iterdone_event = asyncio.Event()
def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str
@@ -492,6 +499,7 @@ class Peer(Logger):
except:
pass
self.lnworker.peer_closed(self)
+ self.got_disconnected.set()
def is_static_remotekey(self):
return self.features.supports(LnFeatures.OPTION_STATIC_REMOTEKEY_OPT)
@@ -1575,6 +1583,7 @@ class Peer(Logger):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
assert chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
+ self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.settle_htlc(preimage, htlc_id)
self.send_message(
"update_fulfill_htlc",
@@ -1585,6 +1594,7 @@ class Peer(Logger):
def fail_htlc(self, *, chan: Channel, htlc_id: int, error_bytes: bytes):
self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
+ self.received_htlcs_pending_removal.add((chan, htlc_id))
chan.fail_htlc(htlc_id)
self.send_message(
"update_fail_htlc",
@@ -1596,9 +1606,10 @@ class Peer(Logger):
def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRoutingFailure):
self.logger.info(f"fail_malformed_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}.")
assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}"
- chan.fail_htlc(htlc_id)
if not (reason.code & OnionFailureCodeMetaFlag.BADONION and len(reason.data) == 32):
raise Exception(f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}")
+ self.received_htlcs_pending_removal.add((chan, htlc_id))
+ chan.fail_htlc(htlc_id)
self.send_message(
"update_fail_malformed_htlc",
channel_id=chan.channel_id,
@@ -1800,8 +1811,13 @@ class Peer(Logger):
async def htlc_switch(self):
await self.initialized
while True:
- await asyncio.sleep(0.1)
+ self._htlc_switch_iterdone_event.set()
+ self._htlc_switch_iterdone_event.clear()
+ await asyncio.sleep(0.1) # TODO maybe make this partly event-driven
+ self._htlc_switch_iterstart_event.set()
+ self._htlc_switch_iterstart_event.clear()
self.ping_if_required()
+ self._maybe_cleanup_received_htlcs_pending_removal()
for chan_id, chan in self.channels.items():
if not chan.can_send_ctx_updates():
continue
@@ -1853,6 +1869,29 @@ class Peer(Logger):
for htlc_id in done:
unfulfilled.pop(htlc_id)
+    def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
+        """Forget the received HTLCs whose removal has become irrevocable.
+
+        Entries of `received_htlcs_pending_removal` for which the channel's
+        HTLC manager reports an irrevocable removal are dropped from the set.
+        If at least one entry was dropped, `received_htlc_removed_event` is
+        set and immediately cleared again.
+        """
+        removed = {
+            (chan, htlc_id)
+            for chan, htlc_id in self.received_htlcs_pending_removal
+            if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id)
+        }
+        if not removed:
+            return
+        self.received_htlcs_pending_removal.difference_update(removed)
+        # pulse the event: notify whoever is waiting right now, then re-arm it
+        self.received_htlc_removed_event.set()
+        self.received_htlc_removed_event.clear()
+
+    async def wait_one_htlc_switch_iteration(self) -> None:
+        """Return once the HTLC switch has gone through one complete iteration,
+        or once the peer got disconnected -- whichever comes first.
+        """
+        async def full_iteration():
+            # Wait for the start marker first, so that the end marker seen
+            # afterwards belongs to an iteration that began after this call.
+            await self._htlc_switch_iterstart_event.wait()
+            await self._htlc_switch_iterdone_event.wait()
+
+        # wait=any: the group is done as soon as either of the two waits finishes
+        async with TaskGroup(wait=any) as group:
+            for make_coro in (full_iteration, self.got_disconnected.wait):
+                await group.spawn(make_coro())
+
async def process_unfulfilled_htlc(
self, *,
chan: Channel,
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -22,7 +22,7 @@ import urllib.parse
import dns.resolver
import dns.exception
-from aiorpcx import run_in_thread, TaskGroup, NetAddress
+from aiorpcx import run_in_thread, TaskGroup, NetAddress, ignore_after
from . import constants, util
from . import keystore
@@ -195,6 +195,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self.features = features
self.network = None # type: Optional[Network]
self.config = None # type: Optional[SimpleConfig]
+ self.stopping_soon = False # whether we are being shut down
util.register_callback(self.on_proxy_changed, ['proxy_set'])
@@ -268,6 +269,8 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
async def _maintain_connectivity(self):
while True:
await asyncio.sleep(1)
+ if self.stopping_soon:
+ return
now = time.time()
if len(self._peers) >= NUM_PEERS_TARGET:
continue
@@ -575,6 +578,7 @@ class LNWallet(LNWorker):
lnwatcher: Optional['LNWalletWatcher']
MPP_EXPIRY = 120
+ TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds
def __init__(self, wallet: 'Abstract_Wallet', xprv):
self.wallet = wallet
@@ -707,9 +711,32 @@ class LNWallet(LNWorker):
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
async def stop(self):
- await super().stop()
- await self.lnwatcher.stop()
- self.lnwatcher = None
+ self.stopping_soon = True  # checked by the periodic loops (they return) and by the MPP logic (pending sets are expired)
+ if self.listen_server: # stop accepting new peers
+ self.listen_server.close()
+ async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS):  # bounded wait: on timeout, shutdown simply continues
+ await self.wait_for_received_pending_htlcs_to_get_removed()
+ await LNWorker.stop(self)  # explicit class instead of super(): presumably so MockLNWallet in the tests can call LNWallet.stop(self)
+ if self.lnwatcher:  # may be None, e.g. MockLNWallet sets lnwatcher = None
+ await self.lnwatcher.stop()
+ self.lnwatcher = None
+
+    async def wait_for_received_pending_htlcs_to_get_removed(self):
+        """Wait until no peer has received HTLCs left whose removal is pending.
+
+        Must only be called once `stopping_soon` is set. There is no timeout
+        in here; the caller is expected to bound the wait.
+        """
+        assert self.stopping_soon is True
+        # We try to fail pending MPP HTLCs, and wait a bit for them to get removed.
+        # Note: even without MPP, an HTLC we just failed/fulfilled deserves a short
+        #       wait so that its removal can become irrevocable.
+        # Note: this does not wait for *all htlcs* to get removed, only for those
+        #       we can already fail/fulfill. e.g. forwarded htlcs cannot be removed
+        # Step 1: let the htlc switch of every peer run once
+        # (or see the peer disconnect).
+        async with TaskGroup() as group:
+            for peer in self.peers.values():
+                await group.spawn(peer.wait_one_htlc_switch_iteration())
+        # Step 2: as long as some peer still tracks a pending removal, sleep until
+        # any peer signals a completed removal, then look again.
+        while any(peer.received_htlcs_pending_removal for peer in self.peers.values()):
+            async with TaskGroup(wait=any) as group:
+                for peer in self.peers.values():
+                    await group.spawn(peer.received_htlc_removed_event.wait())
def peer_closed(self, peer):
for chan in self.channels_for_peer(peer.pubkey).values():
@@ -1635,7 +1662,9 @@ class LNWallet(LNWorker):
if not is_accepted and not is_expired:
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
- if time.time() - first_timestamp > self.MPP_EXPIRY:
+ if self.stopping_soon:
+ is_expired = True # try to time out pending HTLCs before shutting down
+ elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True
elif total == expected_msat:
is_accepted = True
@@ -1897,6 +1926,8 @@ class LNWallet(LNWorker):
async def reestablish_peers_and_channels(self):
while True:
await asyncio.sleep(1)
+ if self.stopping_soon:
+ return
for chan in self.channels.values():
if chan.is_closed():
continue
diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
@@ -10,7 +10,7 @@ from concurrent import futures
import unittest
from typing import Iterable, NamedTuple, Tuple, List
-from aiorpcx import TaskGroup
+from aiorpcx import TaskGroup, timeout_after, TaskTimeout
from electrum import bitcoin
from electrum import constants
@@ -113,7 +113,8 @@ class MockWallet:
class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
- MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
+ MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1
+ TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0
def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
self.name = name
@@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
+ self.taskgroup = TaskGroup()
+ self.lnwatcher = None
+ self.listen_server = None
self._channels = {chan.channel_id: chan for chan in chans}
self.payments = {}
self.logs = defaultdict(list)
@@ -147,6 +151,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.trampoline_forwarding_failures = {}
self.inflight_payments = set()
self.preimages = {}
+ self.stopping_soon = False
def get_invoice_status(self, key):
pass
@@ -183,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return self.name
async def stop(self):
+ await LNWallet.stop(self)
if self.channel_db:
self.channel_db.stop()
await self.channel_db.stopped_event.wait()
@@ -215,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
_calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
is_trampoline_peer = LNWallet.is_trampoline_peer
+ wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
+ on_proxy_changed = LNWallet.on_proxy_changed
class MockTransport:
@@ -290,13 +298,9 @@ class SquareGraph(NamedTuple):
def all_lnworkers(self) -> Iterable[MockLNWallet]:
return self.w_a, self.w_b, self.w_c, self.w_d
- async def stop_and_cleanup(self):
- async with TaskGroup() as group:
- for lnworker in self.all_lnworkers():
- await group.spawn(lnworker.stop())
-
class PaymentDone(Exception): pass
+class TestSuccess(Exception): pass
class TestPeer(ElectrumTestCase):
@@ -837,6 +841,50 @@ class TestPeer(ElectrumTestCase):
self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})
@needs_test_with_all_chacha20_implementations
+ def test_fail_pending_htlcs_on_shutdown(self):
+ """Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
+ Dave shuts down (stops wallet).
+ We test if Dave fails the pending HTLCs during shutdown.
+ """
+ graph = self.prepare_chans_and_peers_in_square()
+ self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
+ self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
+ amount_to_pay = 600_000_000_000  # more than either balance asserted above, so one channel alone cannot carry it
+ peers = graph.all_peers()
+ graph.w_d.MPP_EXPIRY = 120  # mock default is 2; presumably raised so the regular MPP timeout cannot fire during the test
+ graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3  # mock default is 0; give stop() time to wait for the removals
+ async def pay():
+ graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
+ graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs
+ assert graph.w_a.network.channel_db is not None
+ lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
+ try:
+ async with timeout_after(0.5):  # the payment is expected to still be in flight when this fires
+ result, log = await graph.w_a.pay_invoice(pay_req, attempts=1)
+ except TaskTimeout:
+ # by now Dave hopefully received some HTLCs:
+ self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0)
+ self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0)
+ else:
+ self.fail(f"pay_invoice finished but was not supposed to. result={result}")
+ await graph.w_d.stop()
+ # Dave is supposed to have failed the pending incomplete MPP HTLCs
+ self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL)))
+ self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE)))
+ raise TestSuccess()  # propagates out of the TaskGroup in f(); matched by assertRaises below
+
+ async def f():
+ async with TaskGroup() as group:
+ for peer in peers:
+ await group.spawn(peer._message_loop())
+ await group.spawn(peer.htlc_switch())
+ await asyncio.sleep(0.2)  # presumably lets the peer loops get going before the payment starts
+ await group.spawn(pay())
+
+ with self.assertRaises(TestSuccess):
+ run(f())
+
+ @needs_test_with_all_chacha20_implementations
def test_close(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)