electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 549b9a95df6fedbf0880df139a99f419931150e1
parent d4de25a8cde35ee47a7594194eef794a1b4f7b2b
Author: ThomasV <thomasv@electrum.org>
Date:   Wed, 10 Mar 2021 17:09:07 +0100

test_lnpeer: add test for mpp_timeout

Diffstat:
Melectrum/lnpeer.py | 6++++--
Melectrum/lnworker.py | 6++++--
Melectrum/tests/test_lnpeer.py | 46+++++++++++++++++++++++++++++++---------------
3 files changed, 39 insertions(+), 19 deletions(-)

diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -1820,7 +1820,7 @@ class Peer(Logger): error_reason = e else: try: - preimage, fw_info, error_bytes = self.process_unfulfilled_htlc( + preimage, fw_info, error_bytes = await self.process_unfulfilled_htlc( chan=chan, htlc=htlc, forwarding_info=forwarding_info, @@ -1849,7 +1849,7 @@ class Peer(Logger): for htlc_id in done: unfulfilled.pop(htlc_id) - def process_unfulfilled_htlc( + async def process_unfulfilled_htlc( self, *, chan: Channel, htlc: UpdateAddHtlc, @@ -1885,6 +1885,7 @@ class Peer(Logger): processed_onion=trampoline_onion, is_trampoline=True) else: + await self.lnworker.enable_htlc_forwarding.wait() self.maybe_forward_trampoline( chan=chan, htlc=htlc, @@ -1899,6 +1900,7 @@ class Peer(Logger): raise error_reason elif not forwarding_info: + await self.lnworker.enable_htlc_forwarding.wait() next_chan_id, next_htlc_id = self.maybe_forward_htlc( htlc=htlc, processed_onion=processed_onion) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -91,7 +91,6 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID] # status that are persisted NUM_PEERS_TARGET = 4 -MPP_EXPIRY = 120 FALLBACK_NODE_LIST_TESTNET = ( @@ -575,6 +574,7 @@ class LNGossip(LNWorker): class LNWallet(LNWorker): lnwatcher: Optional['LNWalletWatcher'] + MPP_EXPIRY = 120 def __init__(self, wallet: 'Abstract_Wallet', xprv): self.wallet = wallet @@ -592,6 +592,8 @@ class LNWallet(LNWorker): # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() + self.enable_htlc_forwarding = asyncio.Event() + self.enable_htlc_forwarding.set() # note: accessing channels (besides simple lookup) needs self.lock! self._channels = {} # type: Dict[bytes, Channel] @@ -1633,7 +1635,7 @@ class LNWallet(LNWorker): if not is_accepted and not is_expired: total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) - if time.time() - first_timestamp > MPP_EXPIRY: + if time.time() - first_timestamp > self.MPP_EXPIRY: is_expired = True elif total == expected_msat: is_accepted = True diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -113,6 +113,7 @@ class MockWallet: class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): + MPP_EXPIRY = 1 def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name): self.name = name Logger.__init__(self) @@ -136,6 +137,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() + self.enable_htlc_forwarding = asyncio.Event() + self.enable_htlc_forwarding.set() self.received_htlcs = dict() self.sent_htlcs = defaultdict(asyncio.Queue) self.sent_htlcs_routes = dict() @@ -747,7 +750,7 @@ class TestPeer(ElectrumTestCase): with self.assertRaises(PaymentDone): run(f()) - def _test_multipart_payment(self, graph, *, attempts): + async def _run_mpp(self, graph, *, attempts): self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) amount_to_pay = 600_000_000_000 @@ -761,32 +764,45 @@ class TestPeer(ElectrumTestCase): raise PaymentDone() else: raise NoPathFound() - async def f(): - async with TaskGroup() as group: - for peer in peers: - await group.spawn(peer._message_loop()) - await group.spawn(peer.htlc_switch()) - await asyncio.sleep(0.2) - await group.spawn(pay()) - self.assertFalse(graph.w_d.features.supports(LnFeatures.BASIC_MPP_OPT)) - with self.assertRaises(NoPathFound): - run(f()) + async with TaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + await group.spawn(pay()) + + @needs_test_with_all_chacha20_implementations + def test_multipart_payment_with_timeout(self): + graph = self.prepare_chans_and_peers_in_square() graph.w_d.features |= LnFeatures.BASIC_MPP_OPT + graph.w_b.enable_htlc_forwarding.clear() + with self.assertRaises(NoPathFound): + run(self._run_mpp(graph, attempts=1)) + graph.w_b.enable_htlc_forwarding.set() with self.assertRaises(PaymentDone): - run(f()) + run(self._run_mpp(graph, attempts=1)) @needs_test_with_all_chacha20_implementations def test_multipart_payment(self): graph = self.prepare_chans_and_peers_in_square() - self._test_multipart_payment(graph, attempts=1) + self.assertFalse(graph.w_d.features.supports(LnFeatures.BASIC_MPP_OPT)) + with self.assertRaises(NoPathFound): + run(self._run_mpp(graph, attempts=1)) + graph.w_d.features |= LnFeatures.BASIC_MPP_OPT + with self.assertRaises(PaymentDone): + run(self._run_mpp(graph, attempts=1)) @needs_test_with_all_chacha20_implementations def test_multipart_payment_with_trampoline(self): graph = self.prepare_chans_and_peers_in_square() + graph.w_d.features |= LnFeatures.BASIC_MPP_OPT graph.w_a.network.channel_db.stop() graph.w_a.network.channel_db = None - # Note: first attempt will fail with insufficient trampoline fee - self._test_multipart_payment(graph, attempts=3) + # Note: single attempt will fail with insufficient trampoline fee + with self.assertRaises(NoPathFound): + run(self._run_mpp(graph, attempts=1)) + with self.assertRaises(PaymentDone): + run(self._run_mpp(graph, attempts=3)) @needs_test_with_all_chacha20_implementations def test_close(self):