electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit cc4029c335dea48a93c3a78fa3b27262d34458d9
parent 7153e753d1829ccdf1f0ee6821082c1f13ba0f21
Author: SomberNight <somber.night@protonmail.com>
Date:   Wed,  6 May 2020 11:00:58 +0200

test_lnpeer: add some multi-hop payment unit tests

Diffstat:
Melectrum/lnpeer.py | 4+++-
Melectrum/lnworker.py | 1+
Melectrum/tests/test_lnpeer.py | 184++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
3 files changed, 186 insertions(+), 3 deletions(-)

diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -1510,7 +1510,9 @@ class Peer(Logger): self.logger.info(f"error processing onion packet: {e!r}") error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') else: - if processed_onion.are_we_final: + if self.lnworker._fail_htlcs_with_temp_node_failure: + error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') + elif processed_onion.are_we_final: preimage, error_reason = self.maybe_fulfill_htlc( chan=chan, htlc=htlc, diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -494,6 +494,7 @@ class LNWallet(LNWorker): # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() + self._fail_htlcs_with_temp_node_failure = False # note: accessing channels (besides simple lookup) needs self.lock! self._channels = {} # type: Dict[bytes, Channel] diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -8,7 +8,7 @@ import logging import concurrent from concurrent import futures import unittest -from typing import Iterable +from typing import Iterable, NamedTuple from aiorpcx import TaskGroup @@ -24,12 +24,13 @@ from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner from electrum.lnchannel import ChannelState, PeerState, Channel -from electrum.lnrouter import LNPathFinder +from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet, NoPathFound from electrum.lnmsg import encode_msg, decode_msg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID +from electrum.lnonion import OnionFailureCode from .test_lnchannel import create_test_channels from .test_bitcoin import needs_test_with_all_chacha20_implementations @@ -117,6 +118,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() + self._fail_htlcs_with_temp_node_failure = False def get_invoice_status(self, key): pass @@ -212,6 +214,37 @@ def transport_pair(k1, k2, name1, name2): return t1, t2 +class DiamondGraph(NamedTuple): + # A + # / \ + # B C + # \ / + # D + w_a: MockLNWallet + w_b: MockLNWallet + w_c: MockLNWallet + w_d: MockLNWallet + peer_ab: Peer + peer_ac: Peer + peer_ba: Peer + peer_bd: Peer + peer_ca: Peer + peer_cd: Peer + peer_db: Peer + peer_dc: Peer + chan_ab: Channel + chan_ac: Channel + chan_ba: Channel + chan_bd: Channel + chan_ca: Channel + chan_cd: Channel + chan_db: Channel + chan_dc: Channel + + def all_peers(self) -> Iterable[Peer]: + return self.peer_ab, self.peer_ac, self.peer_ba, self.peer_bd, self.peer_ca, self.peer_cd, self.peer_db, self.peer_dc + + class PaymentDone(Exception): pass @@ -252,6 +285,77 @@ class TestPeer(ElectrumTestCase): p2.mark_open(bob_channel) return p1, p2, w1, w2, q1, q2 + def prepare_chans_and_peers_in_diamond(self) -> DiamondGraph: + key_a, key_b, key_c, key_d = [keypair() for i in range(4)] + chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey) + chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey) + chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey) + chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey) + trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name) + trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name) + trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name) + trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name) + txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)] + w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a) + w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b) + w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c) + w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d) + peer_ab = Peer(w_a, key_b.pubkey, trans_ab) + peer_ac = Peer(w_a, key_c.pubkey, trans_ac) + peer_ba = Peer(w_b, key_a.pubkey, trans_ba) + peer_bd = Peer(w_b, key_d.pubkey, trans_bd) + peer_ca = Peer(w_c, key_a.pubkey, trans_ca) + peer_cd = Peer(w_c, key_d.pubkey, trans_cd) + peer_db = Peer(w_d, key_b.pubkey, trans_db) + peer_dc = Peer(w_d, key_c.pubkey, trans_dc) + w_a._peers[peer_ab.pubkey] = peer_ab + w_a._peers[peer_ac.pubkey] = peer_ac + w_b._peers[peer_ba.pubkey] = peer_ba + w_b._peers[peer_bd.pubkey] = peer_bd + w_c._peers[peer_ca.pubkey] = peer_ca + w_c._peers[peer_cd.pubkey] = peer_cd + w_d._peers[peer_db.pubkey] = peer_db + w_d._peers[peer_dc.pubkey] = peer_dc + + w_b.network.config.set_key('lightning_forward_payments', True) + w_c.network.config.set_key('lightning_forward_payments', True) + + # mark_open won't work if state is already OPEN. + # so set it to FUNDED + for chan in [chan_ab, chan_ac, chan_ba, chan_bd, chan_ca, chan_cd, chan_db, chan_dc]: + chan._state = ChannelState.FUNDED + # this populates the channel graph: + peer_ab.mark_open(chan_ab) + peer_ac.mark_open(chan_ac) + peer_ba.mark_open(chan_ba) + peer_bd.mark_open(chan_bd) + peer_ca.mark_open(chan_ca) + peer_cd.mark_open(chan_cd) + peer_db.mark_open(chan_db) + peer_dc.mark_open(chan_dc) + return DiamondGraph( + w_a=w_a, + w_b=w_b, + w_c=w_c, + w_d=w_d, + peer_ab=peer_ab, + peer_ac=peer_ac, + peer_ba=peer_ba, + peer_bd=peer_bd, + peer_ca=peer_ca, + peer_cd=peer_cd, + peer_db=peer_db, + peer_dc=peer_dc, + chan_ab=chan_ab, + chan_ac=chan_ac, + chan_ba=chan_ba, + chan_bd=chan_bd, + chan_ca=chan_ca, + chan_cd=chan_cd, + chan_db=chan_db, + chan_dc=chan_dc, + ) + @staticmethod async def prepare_invoice( w2: MockLNWallet, # receiver @@ -383,6 +487,82 @@ class TestPeer(ElectrumTestCase): self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.REMOTE)) @needs_test_with_all_chacha20_implementations + def test_payment_multihop(self): + graph = self.prepare_chans_and_peers_in_diamond() + peers = graph.all_peers() + async def pay(pay_req): + result, log = await graph.w_a._pay(pay_req) + self.assertTrue(result) + raise PaymentDone() + async def f(): + async with TaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) + await group.spawn(pay(pay_req)) + with self.assertRaises(PaymentDone): + run(f()) + + @needs_test_with_all_chacha20_implementations + def test_payment_multihop_with_preselected_path(self): + graph = self.prepare_chans_and_peers_in_diamond() + peers = graph.all_peers() + async def pay(pay_req): + with self.subTest(msg="bad path: edges do not chain together"): + path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), + PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] + result, log = await graph.w_a._pay(pay_req, full_path=path) + self.assertFalse(result) + self.assertTrue(isinstance(log[0].exception, LNPathInconsistent)) + with self.subTest(msg="bad path: last node id differs from invoice pubkey"): + path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)] + result, log = await graph.w_a._pay(pay_req, full_path=path) + self.assertFalse(result) + self.assertTrue(isinstance(log[0].exception, LNPathInconsistent)) + with self.subTest(msg="good path"): + path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), + PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] + result, log = await graph.w_a._pay(pay_req, full_path=path) + self.assertTrue(result) + self.assertEqual([edge.short_channel_id for edge in path], + [edge.short_channel_id for edge in log[0].route]) + raise PaymentDone() + async def f(): + async with TaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) + await group.spawn(pay(pay_req)) + with self.assertRaises(PaymentDone): + run(f()) + + @needs_test_with_all_chacha20_implementations + def test_payment_multihop_temp_node_failure(self): + graph = self.prepare_chans_and_peers_in_diamond() + graph.w_b._fail_htlcs_with_temp_node_failure = True + graph.w_c._fail_htlcs_with_temp_node_failure = True + peers = graph.all_peers() + async def pay(pay_req): + result, log = await graph.w_a._pay(pay_req) + self.assertFalse(result) + self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_details.failure_msg.code) + raise PaymentDone() + async def f(): + async with TaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True) + await group.spawn(pay(pay_req)) + with self.assertRaises(PaymentDone): + run(f()) + + @needs_test_with_all_chacha20_implementations def test_close(self): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)