test_lnpeer.py (46112B)
import asyncio
import tempfile
from decimal import Decimal
import os
from contextlib import contextmanager
from collections import defaultdict
import logging
import concurrent
from concurrent import futures
import unittest
from typing import Iterable, NamedTuple, Tuple, List

from aiorpcx import TaskGroup, timeout_after, TaskTimeout

from electrum import bitcoin
from electrum import constants
from electrum.network import Network
from electrum.ecc import ECPrivkey
from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u, create_and_start_event_loop, NetworkRetryManager, bfh
from electrum.lnpeer import Peer, UpfrontShutdownScriptViolation
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound
from electrum.lnmsg import encode_msg, decode_msg
from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED
from electrum.lnonion import OnionFailureCode
from electrum.lnutil import ChannelBlackList, derive_payment_secret_from_payment_preimage
from electrum.lnutil import LOCAL, REMOTE
from electrum.invoices import PR_PAID, PR_UNPAID

from .test_lnchannel import create_test_channels
from .test_bitcoin import needs_test_with_all_chacha20_implementations
from . import ElectrumTestCase

def keypair():
    """Return a Keypair built from a freshly generated random private key."""
    priv = ECPrivkey.generate_random_key().get_secret_bytes()
    k1 = Keypair(
            pubkey=privkey_to_pubkey(priv),
            privkey=priv)
    return k1

@contextmanager
def noop_lock():
    """Context manager that does nothing; stands in for real locks in the mocks."""
    yield

class MockNetwork:
    """Minimal stand-in for the network object used by the peers under test."""

    def __init__(self, tx_queue):
        self.callbacks = defaultdict(list)
        self.lnwatcher = None
        self.interface = None
        user_config = {}
        # every MockNetwork gets its own fresh config directory
        user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
        self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
        self.asyncio_loop = asyncio.get_event_loop()
        self.channel_db = ChannelDB(self)
        # mark the db as loaded right away; nothing is read from disk here
        self.channel_db.data_loaded.set()
        self.path_finder = LNPathFinder(self.channel_db)
        self.tx_queue = tx_queue
        self._blockchain = MockBlockchain()
        self.channel_blacklist = ChannelBlackList()

    @property
    def callback_lock(self):
        return noop_lock()

    def get_local_height(self):
        return 0

    def blockchain(self):
        return self._blockchain

    async def broadcast_transaction(self, tx):
        # "broadcasting" only records the tx in the queue, if one was given
        if self.tx_queue:
            await self.tx_queue.put(tx)

    async def try_broadcasting(self, tx, name):
        await self.broadcast_transaction(tx)


class MockBlockchain:
    """Blockchain stub: fixed height 0 and a tip that is never stale."""

    def height(self):
        return 0

    def is_tip_stale(self):
        return False


class MockWallet:
    """Wallet stub: mutators are no-ops and every address counts as ours."""

    def set_label(self, x, y):
        pass

    def save_db(self):
        pass

    def add_transaction(self, tx):
        pass

    def is_lightning_backup(self):
        return False

    def is_mine(self, addr):
        return True


class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
    """Lightweight LNWallet replacement.

    Holds only the state set up in __init__ and borrows the listed
    methods directly from LNWallet (see the assignments at the end of
    the class body).
    """
    MPP_EXPIRY = 2  # HTLC timestamps are cast to int, so this cannot be 1
    TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0

    def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name):
        # name is set before Logger.__init__; diagnostic_name() returns it
        self.name = name
        Logger.__init__(self)
        NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1)
        self.node_keypair = local_keypair
        self.network = MockNetwork(tx_queue)
        self.taskgroup = TaskGroup()
        self.lnwatcher = None
        self.listen_server = None
        self._channels = {chan.channel_id: chan for chan in chans}
        self.payments = {}
        self.logs = defaultdict(list)
        self.wallet = MockWallet()
        self.features = LnFeatures(0)
        self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT
        self.features |= LnFeatures.OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT
        self.features |= LnFeatures.VAR_ONION_OPT
        self.features |= LnFeatures.PAYMENT_SECRET_OPT
        self.features |= LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT
        self.pending_payments = defaultdict(asyncio.Future)
        # point every channel back at this wallet
        for chan in chans:
            chan.lnworker = self
        self._peers = {}  # bytes -> Peer
        # used in tests
        # both events start out set; tests clear them (e.g. enable_htlc_forwarding.clear())
        self.enable_htlc_settle = asyncio.Event()
        self.enable_htlc_settle.set()
        self.enable_htlc_forwarding = asyncio.Event()
        self.enable_htlc_forwarding.set()
        self.received_mpp_htlcs = dict()
        self.sent_htlcs = defaultdict(asyncio.Queue)
        self.sent_htlcs_routes = dict()
        self.sent_buckets = defaultdict(set)
        self.trampoline_forwarding_failures = {}
        self.inflight_payments = set()
        self.preimages = {}
        self.stopping_soon = False

    def get_invoice_status(self, key):
        pass

    @property
    def lock(self):
        return noop_lock()

    @property
    def channel_db(self):
        # network.channel_db can be set to None by tests (see _run_mpp)
        return self.network.channel_db if self.network else None

    @property
    def channels(self):
        return self._channels

    @property
    def peers(self):
        return self._peers

    def get_channel_by_short_id(self, short_channel_id):
        # implicitly returns None when no channel matches
        with self.lock:
            for chan in self._channels.values():
                if chan.short_channel_id == short_channel_id:
                    return chan

    def channel_state_changed(self, chan):
        pass

    def save_channel(self, chan):
        print("Ignoring channel save")

    def diagnostic_name(self):
        return self.name

    async def stop(self):
        """Run LNWallet.stop, then stop the channel db (if any) and wait for it."""
        await LNWallet.stop(self)
        if self.channel_db:
            self.channel_db.stop()
            await self.channel_db.stopped_event.wait()

    # The real LNWallet implementations are reused unchanged for everything below.
    get_payments = LNWallet.get_payments
    get_payment_info = LNWallet.get_payment_info
    save_payment_info = LNWallet.save_payment_info
    set_invoice_status = LNWallet.set_invoice_status
    set_request_status = LNWallet.set_request_status
    set_payment_status = LNWallet.set_payment_status
    get_payment_status = LNWallet.get_payment_status
    check_received_mpp_htlc = LNWallet.check_received_mpp_htlc
    htlc_fulfilled = LNWallet.htlc_fulfilled
    htlc_failed = LNWallet.htlc_failed
    save_preimage = LNWallet.save_preimage
    get_preimage = LNWallet.get_preimage
    create_route_for_payment = LNWallet.create_route_for_payment
    create_routes_for_payment = LNWallet.create_routes_for_payment
    create_routes_from_invoice = LNWallet.create_routes_from_invoice
    _check_invoice = staticmethod(LNWallet._check_invoice)
    pay_to_route = LNWallet.pay_to_route
    pay_to_node = LNWallet.pay_to_node
    pay_invoice = LNWallet.pay_invoice
    force_close_channel = LNWallet.force_close_channel
    try_force_closing = LNWallet.try_force_closing
    get_first_timestamp = lambda self: 0
    on_peer_successfully_established = LNWallet.on_peer_successfully_established
    get_channel_by_id = LNWallet.get_channel_by_id
    channels_for_peer = LNWallet.channels_for_peer
    _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice
    handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc
    is_trampoline_peer = LNWallet.is_trampoline_peer
    wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed
    on_proxy_changed = LNWallet.on_proxy_changed


class MockTransport:
    """Transport base for tests: incoming messages are read from an asyncio.Queue."""

    def __init__(self, name):
        self.queue = asyncio.Queue()
        self._name = name

    def name(self):
        return self._name

    async def read_messages(self):
        # endless async generator over whatever gets put into the queue
        while True:
            yield await self.queue.get()

class NoFeaturesTransport(MockTransport):
    """
    This answers the init message with a init that doesn't signal any features.
    Used for testing that we require DATA_LOSS_PROTECT.
    """
    def send_bytes(self, data):
        decoded = decode_msg(data)
        print(decoded)
        if decoded[0] == 'init':
            self.queue.put_nowait(encode_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))

class PutIntoOthersQueueTransport(MockTransport):
    """Transport whose outgoing bytes land in the paired transport's queue."""

    def __init__(self, keypair, name):
        super().__init__(name)
        # set by transport_pair() once both ends exist
        self.other_mock_transport = None
        self.privkey = keypair.privkey

    def send_bytes(self, data):
        self.other_mock_transport.queue.put_nowait(data)

def transport_pair(k1, k2, name1, name2):
    """Create two PutIntoOthersQueueTransport objects wired to each other.

    Note that t1 (built from k1) is given name2, and t2 is given name1.
    """
    t1 = PutIntoOthersQueueTransport(k1, name2)
    t2 = PutIntoOthersQueueTransport(k2, name1)
    t1.other_mock_transport = t2
    t2.other_mock_transport = t1
    return t1, t2


class SquareGraph(NamedTuple):
    #                A
    #     high fee /   \ low fee
    #             B     C
    #     high fee \   / low fee
    #                D
    w_a: MockLNWallet
    w_b: MockLNWallet
    w_c: MockLNWallet
    w_d: MockLNWallet
    peer_ab: Peer
    peer_ac: Peer
    peer_ba: Peer
    peer_bd: Peer
    peer_ca: Peer
    peer_cd: Peer
    peer_db: Peer
    peer_dc: Peer
    chan_ab: Channel
    chan_ac: Channel
    chan_ba: Channel
    chan_bd: Channel
    chan_ca: Channel
    chan_cd: Channel
    chan_db: Channel
    chan_dc: Channel

    def all_peers(self) -> Iterable[Peer]:
        return self.peer_ab, self.peer_ac, self.peer_ba, self.peer_bd, self.peer_ca, self.peer_cd, self.peer_db, self.peer_dc

    def all_lnworkers(self) -> Iterable[MockLNWallet]:
        return self.w_a, self.w_b, self.w_c, self.w_d


# Raised by the test coroutines to leave the enclosing TaskGroup once the scenario is done.
class PaymentDone(Exception): pass
class TestSuccess(Exception): pass


class TestPeer(ElectrumTestCase):

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        console_stderr_handler.setLevel(logging.DEBUG)

    def setUp(self):
        super().setUp()
        self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
        # every wallet created by the prepare_* helpers is recorded here for tearDown
        self._lnworkers_created = []  # type: List[MockLNWallet]

    def tearDown(self):
        # stop all wallets concurrently, then shut down the event loop thread
        async def cleanup_lnworkers():
            async with TaskGroup() as group:
                for lnworker in self._lnworkers_created:
                    await group.spawn(lnworker.stop())
            self._lnworkers_created.clear()
        # NOTE(review): `run` is not defined in the visible part of the file;
        # presumably a module-level helper that executes a coroutine on the loop.
        run(cleanup_lnworkers())

        self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
        self._loop_thread.join(timeout=1)
        super().tearDown()

    def prepare_peers(self, alice_channel, bob_channel):
        # Builds two wallets/peers joined by a transport pair, one channel each.
        # Returns (p1, p2, w1, w2, q1, q2).
        k1, k2 = keypair(), keypair()
        alice_channel.node_id = k2.pubkey
        bob_channel.node_id = k1.pubkey
        t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name)
        q1, q2 = asyncio.Queue(), asyncio.Queue()
        w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1, name=bob_channel.name)
        w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2, name=alice_channel.name)
        self._lnworkers_created.extend([w1, w2])
        p1 = Peer(w1, k2.pubkey, t1)
        p2 = Peer(w2, k1.pubkey, t2)
        w1._peers[p1.pubkey] = p1
        w2._peers[p2.pubkey] = p2
        # mark_open won't work if state is already OPEN.
        # so set it to FUNDED
        alice_channel._state = ChannelState.FUNDED
        bob_channel._state = ChannelState.FUNDED
        # this populates the channel graph:
        p1.mark_open(alice_channel)
        p2.mark_open(bob_channel)
        return p1, p2, w1, w2, q1, q2

    def prepare_chans_and_peers_in_square(self) -> SquareGraph:
        # Builds four wallets (alice, bob, carol, dave) with channels
        # A-B, A-C, B-D and C-D, and one Peer object per channel end.
        key_a, key_b, key_c, key_d = [keypair() for i in range(4)]
        chan_ab, chan_ba = create_test_channels(alice_name="alice", bob_name="bob", alice_pubkey=key_a.pubkey, bob_pubkey=key_b.pubkey)
        chan_ac, chan_ca = create_test_channels(alice_name="alice", bob_name="carol", alice_pubkey=key_a.pubkey, bob_pubkey=key_c.pubkey)
        chan_bd, chan_db = create_test_channels(alice_name="bob", bob_name="dave", alice_pubkey=key_b.pubkey, bob_pubkey=key_d.pubkey)
        chan_cd, chan_dc = create_test_channels(alice_name="carol", bob_name="dave", alice_pubkey=key_c.pubkey, bob_pubkey=key_d.pubkey)
        trans_ab, trans_ba = transport_pair(key_a, key_b, chan_ab.name, chan_ba.name)
        trans_ac, trans_ca = transport_pair(key_a, key_c, chan_ac.name, chan_ca.name)
        trans_bd, trans_db = transport_pair(key_b, key_d, chan_bd.name, chan_db.name)
        trans_cd, trans_dc = transport_pair(key_c, key_d, chan_cd.name, chan_dc.name)
        txq_a, txq_b, txq_c, txq_d = [asyncio.Queue() for i in range(4)]
        w_a = MockLNWallet(local_keypair=key_a, chans=[chan_ab, chan_ac], tx_queue=txq_a, name="alice")
        w_b = MockLNWallet(local_keypair=key_b, chans=[chan_ba, chan_bd], tx_queue=txq_b, name="bob")
        w_c = MockLNWallet(local_keypair=key_c, chans=[chan_ca, chan_cd], tx_queue=txq_c, name="carol")
        w_d = MockLNWallet(local_keypair=key_d, chans=[chan_db, chan_dc], tx_queue=txq_d, name="dave")
        # register the wallets so tearDown stops them
        self._lnworkers_created.extend([w_a, w_b, w_c, w_d])
        peer_ab = Peer(w_a, key_b.pubkey, trans_ab)
        peer_ac = Peer(w_a, key_c.pubkey, trans_ac)
        peer_ba = Peer(w_b, key_a.pubkey, trans_ba)
        peer_bd = Peer(w_b, key_d.pubkey, trans_bd)
        peer_ca = Peer(w_c, key_a.pubkey, trans_ca)
        peer_cd = Peer(w_c, key_d.pubkey, trans_cd)
        peer_db = Peer(w_d, key_b.pubkey, trans_db)
        peer_dc = Peer(w_d, key_c.pubkey, trans_dc)
        w_a._peers[peer_ab.pubkey] = peer_ab
        w_a._peers[peer_ac.pubkey] = peer_ac
        w_b._peers[peer_ba.pubkey] = peer_ba
        w_b._peers[peer_bd.pubkey] = peer_bd
        w_c._peers[peer_ca.pubkey] = peer_ca
        w_c._peers[peer_cd.pubkey] = peer_cd
        w_d._peers[peer_db.pubkey] = peer_db
        w_d._peers[peer_dc.pubkey] = peer_dc

        # only the two middle nodes (bob, carol) forward payments
        w_b.network.config.set_key('lightning_forward_payments', True)
        w_c.network.config.set_key('lightning_forward_payments', True)

        # forwarding fees, etc
        # the A-B and B-D channels get 500x fees: the "high fee" side of the square
        chan_ab.forwarding_fee_proportional_millionths *= 500
        chan_ab.forwarding_fee_base_msat *= 500
        chan_ba.forwarding_fee_proportional_millionths *= 500
        chan_ba.forwarding_fee_base_msat *= 500
        chan_bd.forwarding_fee_proportional_millionths *= 500
        chan_bd.forwarding_fee_base_msat *= 500
        chan_db.forwarding_fee_proportional_millionths *= 500
        chan_db.forwarding_fee_base_msat *= 500

        # mark_open won't work if state is already OPEN.
399 # so set it to FUNDED 400 for chan in [chan_ab, chan_ac, chan_ba, chan_bd, chan_ca, chan_cd, chan_db, chan_dc]: 401 chan._state = ChannelState.FUNDED 402 # this populates the channel graph: 403 peer_ab.mark_open(chan_ab) 404 peer_ac.mark_open(chan_ac) 405 peer_ba.mark_open(chan_ba) 406 peer_bd.mark_open(chan_bd) 407 peer_ca.mark_open(chan_ca) 408 peer_cd.mark_open(chan_cd) 409 peer_db.mark_open(chan_db) 410 peer_dc.mark_open(chan_dc) 411 return SquareGraph( 412 w_a=w_a, 413 w_b=w_b, 414 w_c=w_c, 415 w_d=w_d, 416 peer_ab=peer_ab, 417 peer_ac=peer_ac, 418 peer_ba=peer_ba, 419 peer_bd=peer_bd, 420 peer_ca=peer_ca, 421 peer_cd=peer_cd, 422 peer_db=peer_db, 423 peer_dc=peer_dc, 424 chan_ab=chan_ab, 425 chan_ac=chan_ac, 426 chan_ba=chan_ba, 427 chan_bd=chan_bd, 428 chan_ca=chan_ca, 429 chan_cd=chan_cd, 430 chan_db=chan_db, 431 chan_dc=chan_dc, 432 ) 433 434 @staticmethod 435 async def prepare_invoice( 436 w2: MockLNWallet, # receiver 437 *, 438 amount_msat=100_000_000, 439 include_routing_hints=False, 440 ) -> Tuple[LnAddr, str]: 441 amount_btc = amount_msat/Decimal(COIN*1000) 442 payment_preimage = os.urandom(32) 443 RHASH = sha256(payment_preimage) 444 info = PaymentInfo(RHASH, amount_msat, RECEIVED, PR_UNPAID) 445 w2.save_preimage(RHASH, payment_preimage) 446 w2.save_payment_info(info) 447 if include_routing_hints: 448 routing_hints = await w2._calc_routing_hints_for_invoice(amount_msat) 449 else: 450 routing_hints = [] 451 trampoline_hints = [] 452 for r in routing_hints: 453 node_id, short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = r[1][0] 454 if len(r[1])== 1 and w2.is_trampoline_peer(node_id): 455 trampoline_hints.append(('t', (node_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta))) 456 invoice_features = w2.features.for_invoice() 457 if invoice_features.supports(LnFeatures.PAYMENT_SECRET_OPT): 458 payment_secret = derive_payment_secret_from_payment_preimage(payment_preimage) 459 else: 460 payment_secret = 
None 461 lnaddr1 = LnAddr( 462 paymenthash=RHASH, 463 amount=amount_btc, 464 tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), 465 ('d', 'coffee'), 466 ('9', invoice_features), 467 ] + routing_hints + trampoline_hints, 468 payment_secret=payment_secret, 469 ) 470 invoice = lnencode(lnaddr1, w2.node_keypair.privkey) 471 lnaddr2 = lndecode(invoice) # unlike lnaddr1, this now has a pubkey set 472 return lnaddr2, invoice 473 474 def test_reestablish(self): 475 alice_channel, bob_channel = create_test_channels() 476 p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) 477 for chan in (alice_channel, bob_channel): 478 chan.peer_state = PeerState.DISCONNECTED 479 async def reestablish(): 480 await asyncio.gather( 481 p1.reestablish_channel(alice_channel), 482 p2.reestablish_channel(bob_channel)) 483 self.assertEqual(alice_channel.peer_state, PeerState.GOOD) 484 self.assertEqual(bob_channel.peer_state, PeerState.GOOD) 485 gath.cancel() 486 gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p1.htlc_switch()) 487 async def f(): 488 await gath 489 with self.assertRaises(concurrent.futures.CancelledError): 490 run(f()) 491 492 @needs_test_with_all_chacha20_implementations 493 def test_reestablish_with_old_state(self): 494 random_seed = os.urandom(32) 495 alice_channel, bob_channel = create_test_channels(random_seed=random_seed) 496 alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical 497 p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) 498 lnaddr, pay_req = run(self.prepare_invoice(w2)) 499 async def pay(): 500 result, log = await w1.pay_invoice(pay_req) 501 self.assertEqual(result, True) 502 gath.cancel() 503 gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) 504 async def f(): 505 await gath 506 with self.assertRaises(concurrent.futures.CancelledError): 507 run(f()) 508 509 p1, p2, 
w1, w2, _q1, _q2 = self.prepare_peers(alice_channel_0, bob_channel) 510 for chan in (alice_channel_0, bob_channel): 511 chan.peer_state = PeerState.DISCONNECTED 512 async def reestablish(): 513 await asyncio.gather( 514 p1.reestablish_channel(alice_channel_0), 515 p2.reestablish_channel(bob_channel)) 516 self.assertEqual(alice_channel_0.peer_state, PeerState.BAD) 517 self.assertEqual(bob_channel._state, ChannelState.FORCE_CLOSING) 518 # wait so that pending messages are processed 519 #await asyncio.sleep(1) 520 gath.cancel() 521 gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) 522 async def f(): 523 await gath 524 with self.assertRaises(concurrent.futures.CancelledError): 525 run(f()) 526 527 @needs_test_with_all_chacha20_implementations 528 def test_payment(self): 529 """Alice pays Bob a single HTLC via direct channel.""" 530 alice_channel, bob_channel = create_test_channels() 531 p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) 532 async def pay(lnaddr, pay_req): 533 self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) 534 result, log = await w1.pay_invoice(pay_req) 535 self.assertTrue(result) 536 self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) 537 raise PaymentDone() 538 async def f(): 539 async with TaskGroup() as group: 540 await group.spawn(p1._message_loop()) 541 await group.spawn(p1.htlc_switch()) 542 await group.spawn(p2._message_loop()) 543 await group.spawn(p2.htlc_switch()) 544 await asyncio.sleep(0.01) 545 lnaddr, pay_req = await self.prepare_invoice(w2) 546 invoice_features = lnaddr.get_features() 547 self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) 548 await group.spawn(pay(lnaddr, pay_req)) 549 with self.assertRaises(PaymentDone): 550 run(f()) 551 552 @needs_test_with_all_chacha20_implementations 553 def test_payment_race(self): 554 """Alice and Bob pay each other simultaneously. 
555 They both send 'update_add_htlc' and receive each other's update 556 before sending 'commitment_signed'. Neither party should fulfill 557 the respective HTLCs until those are irrevocably committed to. 558 """ 559 alice_channel, bob_channel = create_test_channels() 560 p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) 561 async def pay(): 562 await asyncio.wait_for(p1.initialized, 1) 563 await asyncio.wait_for(p2.initialized, 1) 564 # prep 565 _maybe_send_commitment1 = p1.maybe_send_commitment 566 _maybe_send_commitment2 = p2.maybe_send_commitment 567 lnaddr2, pay_req2 = await self.prepare_invoice(w2) 568 lnaddr1, pay_req1 = await self.prepare_invoice(w1) 569 # create the htlc queues now (side-effecting defaultdict) 570 q1 = w1.sent_htlcs[lnaddr2.paymenthash] 571 q2 = w2.sent_htlcs[lnaddr1.paymenthash] 572 # alice sends htlc BUT NOT COMMITMENT_SIGNED 573 p1.maybe_send_commitment = lambda x: None 574 route1 = w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2)[0][0] 575 amount_msat = lnaddr2.get_amount_msat() 576 await w1.pay_to_route( 577 route=route1, 578 amount_msat=amount_msat, 579 total_msat=amount_msat, 580 amount_receiver_msat=amount_msat, 581 payment_hash=lnaddr2.paymenthash, 582 min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), 583 payment_secret=lnaddr2.payment_secret, 584 ) 585 p1.maybe_send_commitment = _maybe_send_commitment1 586 # bob sends htlc BUT NOT COMMITMENT_SIGNED 587 p2.maybe_send_commitment = lambda x: None 588 route2 = w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1)[0][0] 589 amount_msat = lnaddr1.get_amount_msat() 590 await w2.pay_to_route( 591 route=route2, 592 amount_msat=amount_msat, 593 total_msat=amount_msat, 594 amount_receiver_msat=amount_msat, 595 payment_hash=lnaddr1.paymenthash, 596 min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), 597 payment_secret=lnaddr1.payment_secret, 598 ) 599 p2.maybe_send_commitment = _maybe_send_commitment2 
600 # sleep a bit so that they both receive msgs sent so far 601 await asyncio.sleep(0.2) 602 # now they both send COMMITMENT_SIGNED 603 p1.maybe_send_commitment(alice_channel) 604 p2.maybe_send_commitment(bob_channel) 605 606 htlc_log1 = await q1.get() 607 assert htlc_log1.success 608 htlc_log2 = await q2.get() 609 assert htlc_log2.success 610 raise PaymentDone() 611 612 async def f(): 613 async with TaskGroup() as group: 614 await group.spawn(p1._message_loop()) 615 await group.spawn(p1.htlc_switch()) 616 await group.spawn(p2._message_loop()) 617 await group.spawn(p2.htlc_switch()) 618 await asyncio.sleep(0.01) 619 await group.spawn(pay()) 620 with self.assertRaises(PaymentDone): 621 run(f()) 622 623 #@unittest.skip("too expensive") 624 #@needs_test_with_all_chacha20_implementations 625 def test_payments_stresstest(self): 626 alice_channel, bob_channel = create_test_channels() 627 p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) 628 alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL) 629 bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL) 630 num_payments = 50 631 payment_value_msat = 10_000_000 # make it large enough so that there are actually HTLCs on the ctx 632 max_htlcs_in_flight = asyncio.Semaphore(5) 633 async def single_payment(pay_req): 634 async with max_htlcs_in_flight: 635 await w1.pay_invoice(pay_req) 636 async def many_payments(): 637 async with TaskGroup() as group: 638 pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat)) 639 for i in range(num_payments)] 640 async with TaskGroup() as group: 641 for pay_req_task in pay_reqs_tasks: 642 lnaddr, pay_req = pay_req_task.result() 643 await group.spawn(single_payment(pay_req)) 644 gath.cancel() 645 gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) 646 async def f(): 647 await gath 648 with self.assertRaises(concurrent.futures.CancelledError): 649 run(f()) 
        # both channel objects must agree on the final balances, from either side
        self.assertEqual(alice_init_balance_msat - num_payments * payment_value_msat, alice_channel.balance(HTLCOwner.LOCAL))
        self.assertEqual(alice_init_balance_msat - num_payments * payment_value_msat, bob_channel.balance(HTLCOwner.REMOTE))
        self.assertEqual(bob_init_balance_msat + num_payments * payment_value_msat, bob_channel.balance(HTLCOwner.LOCAL))
        self.assertEqual(bob_init_balance_msat + num_payments * payment_value_msat, alice_channel.balance(HTLCOwner.REMOTE))

    @needs_test_with_all_chacha20_implementations
    def test_payment_multihop(self):
        # Alice pays an invoice created by Dave across the square graph.
        graph = self.prepare_chans_and_peers_in_square()
        peers = graph.all_peers()
        async def pay(lnaddr, pay_req):
            self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
            result, log = await graph.w_a.pay_invoice(pay_req)
            self.assertTrue(result)
            self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
            raise PaymentDone()
        async def f():
            async with TaskGroup() as group:
                for peer in peers:
                    await group.spawn(peer._message_loop())
                    await group.spawn(peer.htlc_switch())
                await asyncio.sleep(0.2)
                lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
                await group.spawn(pay(lnaddr, pay_req))
        with self.assertRaises(PaymentDone):
            run(f())

    @needs_test_with_all_chacha20_implementations
    def test_payment_multihop_with_preselected_path(self):
        # pay_invoice(full_path=...) must reject inconsistent paths and follow a good one.
        graph = self.prepare_chans_and_peers_in_square()
        peers = graph.all_peers()
        async def pay(pay_req):
            with self.subTest(msg="bad path: edges do not chain together"):
                # first edge ends at C, second edge starts at B
                path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
                                 end_node=graph.w_c.node_keypair.pubkey,
                                 short_channel_id=graph.chan_ab.short_channel_id),
                        PathEdge(start_node=graph.w_b.node_keypair.pubkey,
                                 end_node=graph.w_d.node_keypair.pubkey,
                                 short_channel_id=graph.chan_bd.short_channel_id)]
                with self.assertRaises(LNPathInconsistent):
                    await graph.w_a.pay_invoice(pay_req, full_path=path)
            with self.subTest(msg="bad path: last node id differs from invoice pubkey"):
                # path stops at B although the invoice was created by D
                path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
                                 end_node=graph.w_b.node_keypair.pubkey,
                                 short_channel_id=graph.chan_ab.short_channel_id)]
                with self.assertRaises(LNPathInconsistent):
                    await graph.w_a.pay_invoice(pay_req, full_path=path)
            with self.subTest(msg="good path"):
                path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey,
                                 end_node=graph.w_b.node_keypair.pubkey,
                                 short_channel_id=graph.chan_ab.short_channel_id),
                        PathEdge(start_node=graph.w_b.node_keypair.pubkey,
                                 end_node=graph.w_d.node_keypair.pubkey,
                                 short_channel_id=graph.chan_bd.short_channel_id)]
                result, log = await graph.w_a.pay_invoice(pay_req, full_path=path)
                self.assertTrue(result)
                # the route that was used must match the requested path
                self.assertEqual(
                    [edge.short_channel_id for edge in path],
                    [edge.short_channel_id for edge in log[0].route])
            raise PaymentDone()
        async def f():
            async with TaskGroup() as group:
                for peer in peers:
                    await group.spawn(peer._message_loop())
                    await group.spawn(peer.htlc_switch())
                await asyncio.sleep(0.2)
                lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
                await group.spawn(pay(pay_req))
        with self.assertRaises(PaymentDone):
            run(f())

    @needs_test_with_all_chacha20_implementations
    def test_payment_multihop_temp_node_failure(self):
        # Both intermediate nodes are configured to fail HTLCs, so the payment must fail.
        graph = self.prepare_chans_and_peers_in_square()
        graph.w_b.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
        graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
        peers = graph.all_peers()
        async def pay(lnaddr, pay_req):
            self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
            result, log = await graph.w_a.pay_invoice(pay_req)
            self.assertFalse(result)
            self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
            self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
            raise PaymentDone()
        async def f():
            async with TaskGroup() as group:
                for peer in peers:
                    await group.spawn(peer._message_loop())
                    await group.spawn(peer.htlc_switch())
                await asyncio.sleep(0.2)
                lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
                await group.spawn(pay(lnaddr, pay_req))
        with self.assertRaises(PaymentDone):
            run(f())

    @needs_test_with_all_chacha20_implementations
    def test_payment_multihop_route_around_failure(self):
        # Alice will pay Dave. Alice first tries A->C->D route, due to lower fees, but Carol
        # will fail the htlc and get blacklisted. Alice will then try A->B->D and succeed.
        graph = self.prepare_chans_and_peers_in_square()
        graph.w_c.network.config.set_key('test_fail_htlcs_with_temp_node_failure', True)
        peers = graph.all_peers()
        async def pay(lnaddr, pay_req):
            self.assertEqual(500000000000, graph.chan_ab.balance(LOCAL))
            self.assertEqual(500000000000, graph.chan_db.balance(LOCAL))
            self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
            result, log = await graph.w_a.pay_invoice(pay_req, attempts=2)
            # one log entry per attempt: the failed one via C, then the successful one via B
            self.assertEqual(2, len(log))
            self.assertTrue(result)
            self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
            self.assertEqual([graph.chan_ac.short_channel_id, graph.chan_cd.short_channel_id],
                             [edge.short_channel_id for edge in log[0].route])
            self.assertEqual([graph.chan_ab.short_channel_id, graph.chan_bd.short_channel_id],
                             [edge.short_channel_id for edge in log[1].route])
            self.assertEqual(OnionFailureCode.TEMPORARY_NODE_FAILURE, log[0].failure_msg.code)
            self.assertEqual(499899450000, graph.chan_ab.balance(LOCAL))
            await asyncio.sleep(0.2)  # wait for COMMITMENT_SIGNED / REVACK msgs to update balance
            self.assertEqual(500100000000, graph.chan_db.balance(LOCAL))
            raise PaymentDone()
        async def f():
            async with TaskGroup() as group:
                for peer in peers:
                    await group.spawn(peer._message_loop())
                    await group.spawn(peer.htlc_switch())
                await asyncio.sleep(0.2)
                lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True)
                invoice_features = lnaddr.get_features()
                self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
                await group.spawn(pay(lnaddr, pay_req))
        with self.assertRaises(PaymentDone):
            run(f())

    def _run_mpp(self, graph, kwargs1, kwargs2):
        # Runs the same payment scenario twice on `graph`:
        # with kwargs1 it must end in NoPathFound, with kwargs2 in PaymentDone.
        self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
        self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
        # more than either of Alice's channel balances asserted above
        amount_to_pay = 600_000_000_000
        peers = graph.all_peers()
        async def pay(attempts=1,
                      alice_uses_trampoline=False,
                      bob_forwarding=True,
                      mpp_invoice=True):
            if mpp_invoice:
                graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
            if not bob_forwarding:
                graph.w_b.enable_htlc_forwarding.clear()
            if alice_uses_trampoline:
                # stop and drop Alice's channel db (only if it is still present)
                if graph.w_a.network.channel_db:
                    graph.w_a.network.channel_db.stop()
                    await graph.w_a.network.channel_db.stopped_event.wait()
                    graph.w_a.network.channel_db = None
            else:
                assert graph.w_a.network.channel_db is not None
            lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay)
            self.assertEqual(PR_UNPAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
            result, log = await graph.w_a.pay_invoice(pay_req, attempts=attempts)
            if not bob_forwarding:
                # reset to previous state, sleep 2s so that the second htlc can time out
                graph.w_b.enable_htlc_forwarding.set()
                await asyncio.sleep(2)
            if result:
                self.assertEqual(PR_PAID, graph.w_d.get_payment_status(lnaddr.paymenthash))
                raise PaymentDone()
            else:
                raise NoPathFound()

        async def f(kwargs):
            async with TaskGroup() as group:
                for peer in peers:
                    await group.spawn(peer._message_loop())
                    await group.spawn(peer.htlc_switch())
                await asyncio.sleep(0.2)
                await group.spawn(pay(**kwargs))

        with self.assertRaises(NoPathFound):
            run(f(kwargs1))
        with self.assertRaises(PaymentDone):
            run(f(kwargs2))

    @needs_test_with_all_chacha20_implementations
    def test_multipart_payment_with_timeout(self):
        # first run: Bob holds his HTLCs -> failure; second run: Bob forwards -> success
        graph = self.prepare_chans_and_peers_in_square()
        self._run_mpp(graph, {'bob_forwarding':False}, {'bob_forwarding':True})

    @needs_test_with_all_chacha20_implementations
    def test_multipart_payment(self):
        # first run: invoice without the MPP feature -> failure; second run: with it -> success
        graph = self.prepare_chans_and_peers_in_square()
        self._run_mpp(graph, {'mpp_invoice':False}, {'mpp_invoice':True})

    @needs_test_with_all_chacha20_implementations
    def test_multipart_payment_with_trampoline(self):
        # single attempt will fail with insufficient trampoline fee
        graph = self.prepare_chans_and_peers_in_square()
        self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3})

    @needs_test_with_all_chacha20_implementations
    def test_fail_pending_htlcs_on_shutdown(self):
        """Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all.
        Dave shuts down (stops wallet).
        We test if Dave fails the pending HTLCs during shutdown.
848 """ 849 graph = self.prepare_chans_and_peers_in_square() 850 self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) 851 self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) 852 amount_to_pay = 600_000_000_000 853 peers = graph.all_peers() 854 graph.w_d.MPP_EXPIRY = 120 855 graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 856 async def pay(): 857 graph.w_d.features |= LnFeatures.BASIC_MPP_OPT 858 graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs 859 assert graph.w_a.network.channel_db is not None 860 lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) 861 try: 862 async with timeout_after(0.5): 863 result, log = await graph.w_a.pay_invoice(pay_req, attempts=1) 864 except TaskTimeout: 865 # by now Dave hopefully received some HTLCs: 866 self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0) 867 self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0) 868 else: 869 self.fail(f"pay_invoice finished but was not supposed to. 
result={result}") 870 await graph.w_d.stop() 871 # Dave is supposed to have failed the pending incomplete MPP HTLCs 872 self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL))) 873 self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE))) 874 raise TestSuccess() 875 876 async def f(): 877 async with TaskGroup() as group: 878 for peer in peers: 879 await group.spawn(peer._message_loop()) 880 await group.spawn(peer.htlc_switch()) 881 await asyncio.sleep(0.2) 882 await group.spawn(pay()) 883 884 with self.assertRaises(TestSuccess): 885 run(f()) 886 887 @needs_test_with_all_chacha20_implementations 888 def test_close(self): 889 alice_channel, bob_channel = create_test_channels() 890 p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) 891 w1.network.config.set_key('dynamic_fees', False) 892 w2.network.config.set_key('dynamic_fees', False) 893 w1.network.config.set_key('fee_per_kb', 5000) 894 w2.network.config.set_key('fee_per_kb', 1000) 895 w2.enable_htlc_settle.clear() 896 lnaddr, pay_req = run(self.prepare_invoice(w2)) 897 async def pay(): 898 await asyncio.wait_for(p1.initialized, 1) 899 await asyncio.wait_for(p2.initialized, 1) 900 # alice sends htlc 901 route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2] 902 htlc = p1.pay(route=route, 903 chan=alice_channel, 904 amount_msat=lnaddr.get_amount_msat(), 905 total_msat=lnaddr.get_amount_msat(), 906 payment_hash=lnaddr.paymenthash, 907 min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), 908 payment_secret=lnaddr.payment_secret) 909 # alice closes 910 await p1.close_channel(alice_channel.channel_id) 911 gath.cancel() 912 async def set_settle(): 913 await asyncio.sleep(0.1) 914 w2.enable_htlc_settle.set() 915 gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) 916 async def f(): 917 await gath 918 with self.assertRaises(concurrent.futures.CancelledError): 919 run(f()) 920 921 
@needs_test_with_all_chacha20_implementations 922 def test_close_upfront_shutdown_script(self): 923 alice_channel, bob_channel = create_test_channels() 924 925 # create upfront shutdown script for bob, alice doesn't use upfront 926 # shutdown script 927 bob_uss_pub = lnutil.privkey_to_pubkey(os.urandom(32)) 928 bob_uss_addr = bitcoin.pubkey_to_address('p2wpkh', bh2u(bob_uss_pub)) 929 bob_uss = bfh(bitcoin.address_to_script(bob_uss_addr)) 930 931 # bob commits to close to bob_uss 932 alice_channel.config[HTLCOwner.REMOTE].upfront_shutdown_script = bob_uss 933 # but bob closes to some receiving address, which we achieve by not 934 # setting the upfront shutdown script in the channel config 935 bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = b'' 936 937 p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) 938 w1.network.config.set_key('dynamic_fees', False) 939 w2.network.config.set_key('dynamic_fees', False) 940 w1.network.config.set_key('fee_per_kb', 5000) 941 w2.network.config.set_key('fee_per_kb', 1000) 942 943 async def test(): 944 async def close(): 945 await asyncio.wait_for(p1.initialized, 1) 946 await asyncio.wait_for(p2.initialized, 1) 947 # bob closes channel with different shutdown script 948 await p1.close_channel(alice_channel.channel_id) 949 gath.cancel() 950 951 async def main_loop(peer): 952 async with peer.taskgroup as group: 953 await group.spawn(peer._message_loop()) 954 await group.spawn(peer.htlc_switch()) 955 956 coros = [close(), main_loop(p1), main_loop(p2)] 957 gath = asyncio.gather(*coros) 958 await gath 959 960 with self.assertRaises(UpfrontShutdownScriptViolation): 961 run(test()) 962 963 # bob sends the same upfront_shutdown_script has he announced 964 alice_channel.config[HTLCOwner.REMOTE].upfront_shutdown_script = bob_uss 965 bob_channel.config[HTLCOwner.LOCAL].upfront_shutdown_script = bob_uss 966 967 p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) 968 
w1.network.config.set_key('dynamic_fees', False) 969 w2.network.config.set_key('dynamic_fees', False) 970 w1.network.config.set_key('fee_per_kb', 5000) 971 w2.network.config.set_key('fee_per_kb', 1000) 972 973 async def test(): 974 async def close(): 975 await asyncio.wait_for(p1.initialized, 1) 976 await asyncio.wait_for(p2.initialized, 1) 977 await p1.close_channel(alice_channel.channel_id) 978 gath.cancel() 979 980 async def main_loop(peer): 981 async with peer.taskgroup as group: 982 await group.spawn(peer._message_loop()) 983 await group.spawn(peer.htlc_switch()) 984 985 coros = [close(), main_loop(p1), main_loop(p2)] 986 gath = asyncio.gather(*coros) 987 await gath 988 with self.assertRaises(concurrent.futures.CancelledError): 989 run(test()) 990 991 def test_channel_usage_after_closing(self): 992 alice_channel, bob_channel = create_test_channels() 993 p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) 994 lnaddr, pay_req = run(self.prepare_invoice(w2)) 995 996 lnaddr = w1._check_invoice(pay_req) 997 route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2] 998 assert amount_msat == lnaddr.get_amount_msat() 999 1000 run(w1.force_close_channel(alice_channel.channel_id)) 1001 # check if a tx (commitment transaction) was broadcasted: 1002 assert q1.qsize() == 1 1003 1004 with self.assertRaises(NoPathFound) as e: 1005 w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr) 1006 1007 peer = w1.peers[route[0].node_id] 1008 # AssertionError is ok since we shouldn't use old routes, and the 1009 # route finding should fail when channel is closed 1010 async def f(): 1011 min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() 1012 payment_hash = lnaddr.paymenthash 1013 payment_secret = lnaddr.payment_secret 1014 pay = w1.pay_to_route( 1015 route=route, 1016 amount_msat=amount_msat, 1017 total_msat=amount_msat, 1018 amount_receiver_msat=amount_msat, 1019 
payment_hash=payment_hash, 1020 payment_secret=payment_secret, 1021 min_cltv_expiry=min_cltv_expiry) 1022 await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) 1023 with self.assertRaises(PaymentFailure): 1024 run(f()) 1025 1026 1027 def run(coro): 1028 return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()