electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 04d018cd0f5fafb8abd4b2771fcc1c6206f80279
parent 7951f2ed3b652e12dfaf7f9685d1428a2a9e4deb
Author: SomberNight <somber.night@protonmail.com>
Date:   Wed,  6 May 2020 10:44:38 +0200

test_lnpeer: some clean-up, make it easier to add "num_node>2" tests

Diffstat:
Melectrum/coinchooser.py | 4++--
Melectrum/tests/test_lnchannel.py | 22++++++++++++++--------
Melectrum/tests/test_lnpeer.py | 94+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
3 files changed, 72 insertions(+), 48 deletions(-)

diff --git a/electrum/coinchooser.py b/electrum/coinchooser.py @@ -44,12 +44,12 @@ class PRNG: self.sha = sha256(seed) self.pool = bytearray() - def get_bytes(self, n): + def get_bytes(self, n: int) -> bytes: while len(self.pool) < n: self.pool.extend(self.sha) self.sha = sha256(self.sha) result, self.pool = self.pool[:n], self.pool[n:] - return result + return bytes(result) def randint(self, start, end): # Returns random integer in [start, end) diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py @@ -39,6 +39,7 @@ from electrum.ecc import sig_string_from_der_sig from electrum.logging import console_stderr_handler from electrum.lnchannel import ChannelState from electrum.json_db import StoredDict +from electrum.coinchooser import PRNG from . import ElectrumTestCase @@ -110,8 +111,13 @@ def bip32(sequence): assert type(k) is bytes return k -def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None): - funding_txid = binascii.hexlify(b"\x01"*32).decode("ascii") +def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None, + alice_name="alice", bob_name="bob", + alice_pubkey=b"\x01"*33, bob_pubkey=b"\x02"*33, random_seed=None): + if random_seed is None: # needed for deterministic randomness + random_seed = os.urandom(32) + random_gen = PRNG(random_seed) + funding_txid = binascii.hexlify(random_gen.get_bytes(32)).decode("ascii") funding_index = 0 funding_sat = ((local_msat + remote_msat) // 1000) if local_msat is not None and remote_msat is not None else (bitcoin.COIN * 10) local_amount = local_msat if local_msat is not None else (funding_sat * 1000 // 2) @@ -123,20 +129,20 @@ def create_test_channels(*, feerate=6000, local_msat=None, remote_msat=None): alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys] bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys] - alice_seed = b"\x01" * 32 - bob_seed = b"\x02" * 32 + alice_seed = random_gen.get_bytes(32) + bob_seed = random_gen.get_bytes(32) alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big")) bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big")) alice, bob = ( lnchannel.Channel( - create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), - name="alice", + create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, other_node_id=bob_pubkey, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), + name=bob_name, initial_feerate=feerate), lnchannel.Channel( - create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), - name="bob", + create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, other_node_id=alice_pubkey, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), + name=alice_name, initial_feerate=feerate) ) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -8,6 +8,7 @@ import logging import concurrent from concurrent import futures import unittest +from typing import Iterable from aiorpcx import TaskGroup @@ -96,21 +97,23 @@ class MockWallet: return False class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): - def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue): + def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue): Logger.__init__(self) NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) - self.remote_keypair = remote_keypair self.node_keypair = local_keypair self.network = MockNetwork(tx_queue) - self._channels = {chan.channel_id: chan} + self.channel_db = self.network.channel_db + self._channels = {chan.channel_id: chan + for chan in chans} self.payments = {} self.logs = defaultdict(list) self.wallet = MockWallet() self.features = LnFeatures(0) self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT self.pending_payments = defaultdict(asyncio.Future) - chan.lnworker = self - chan.node_id = remote_keypair.pubkey + for chan in chans: + chan.lnworker = self + self._peers = {} # bytes -> Peer # used in tests self.enable_htlc_settle = asyncio.Event() self.enable_htlc_settle.set() @@ -130,13 +133,6 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): def peers(self): return self._peers - @property - def _peers(self): - return {self.remote_keypair.pubkey: self.peer} - - def channels_for_peer(self, pubkey): - return self._channels - def get_channel_by_short_id(self, short_channel_id): with self.lock: for chan in self._channels.values(): @@ -171,6 +167,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): get_first_timestamp = lambda self: 0 on_peer_successfully_established = LNWallet.on_peer_successfully_established get_channel_by_id = LNWallet.get_channel_by_id + channels_for_peer = LNWallet.channels_for_peer + _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice + handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc class MockTransport: @@ -206,12 +205,16 @@ class PutIntoOthersQueueTransport(MockTransport): self.other_mock_transport.queue.put_nowait(data) def transport_pair(k1, k2, name1, name2): - t1 = PutIntoOthersQueueTransport(k1, name1) - t2 = PutIntoOthersQueueTransport(k2, name2) + t1 = PutIntoOthersQueueTransport(k1, name2) + t2 = PutIntoOthersQueueTransport(k2, name1) t1.other_mock_transport = t2 t2.other_mock_transport = t1 return t1, t2 + +class PaymentDone(Exception): pass + + class TestPeer(ElectrumTestCase): @classmethod @@ -230,14 +233,16 @@ class TestPeer(ElectrumTestCase): def prepare_peers(self, alice_channel, bob_channel): k1, k2 = keypair(), keypair() - t1, t2 = transport_pair(k2, k1, alice_channel.name, bob_channel.name) + alice_channel.node_id = k2.pubkey + bob_channel.node_id = k1.pubkey + t1, t2 = transport_pair(k1, k2, alice_channel.name, bob_channel.name) q1, q2 = asyncio.Queue(), asyncio.Queue() - w1 = MockLNWallet(k1, k2, alice_channel, tx_queue=q1) - w2 = MockLNWallet(k2, k1, bob_channel, tx_queue=q2) - p1 = Peer(w1, k1.pubkey, t1) - p2 = Peer(w2, k2.pubkey, t2) - w1.peer = p1 - w2.peer = p2 + w1 = MockLNWallet(local_keypair=k1, chans=[alice_channel], tx_queue=q1) + w2 = MockLNWallet(local_keypair=k2, chans=[bob_channel], tx_queue=q2) + p1 = Peer(w1, k2.pubkey, t1) + p2 = Peer(w2, k1.pubkey, t2) + w1._peers[p1.pubkey] = p1 + w2._peers[p2.pubkey] = p2 # mark_open won't work if state is already OPEN. # so set it to FUNDED alice_channel._state = ChannelState.FUNDED @@ -248,10 +253,11 @@ class TestPeer(ElectrumTestCase): return p1, p2, w1, w2, q1, q2 @staticmethod - def prepare_invoice( - w2, # receiver + async def prepare_invoice( + w2: MockLNWallet, # receiver *, amount_sat=100_000, + include_routing_hints=False, ): amount_btc = amount_sat/Decimal(COIN) payment_preimage = os.urandom(32) @@ -259,12 +265,16 @@ class TestPeer(ElectrumTestCase): info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID) w2.save_preimage(RHASH, payment_preimage) w2.save_payment_info(info) + if include_routing_hints: + routing_hints = await w2._calc_routing_hints_for_invoice(amount_sat) + else: + routing_hints = [] lnaddr = LnAddr( paymenthash=RHASH, amount=amount_btc, tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), ('d', 'coffee') - ]) + ] + routing_hints) return lnencode(lnaddr, w2.node_keypair.privkey) def test_reestablish(self): @@ -287,10 +297,11 @@ class TestPeer(ElectrumTestCase): @needs_test_with_all_chacha20_implementations def test_reestablish_with_old_state(self): - alice_channel, bob_channel = create_test_channels() - alice_channel_0, bob_channel_0 = create_test_channels() # these are identical + random_seed = os.urandom(32) + alice_channel, bob_channel = create_test_channels(random_seed=random_seed) + alice_channel_0, bob_channel_0 = create_test_channels(random_seed=random_seed) # these are identical p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) - pay_req = self.prepare_invoice(w2) + pay_req = run(self.prepare_invoice(w2)) async def pay(): result, log = await w1._pay(pay_req) self.assertEqual(result, True) @@ -323,15 +334,20 @@ class TestPeer(ElectrumTestCase): def test_payment(self): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) - pay_req = self.prepare_invoice(w2) - async def pay(): + async def pay(pay_req): result, log = await w1._pay(pay_req) self.assertTrue(result) - gath.cancel() - gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) + raise PaymentDone() async def f(): - await gath - with self.assertRaises(concurrent.futures.CancelledError): + async with TaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + pay_req = await self.prepare_invoice(w2) + await group.spawn(pay(pay_req)) + with self.assertRaises(PaymentDone): run(f()) #@unittest.skip("too expensive") @@ -343,15 +359,17 @@ class TestPeer(ElectrumTestCase): bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL) num_payments = 50 payment_value_sat = 10000 # make it large enough so that there are actually HTLCs on the ctx - #pay_reqs1 = [self.prepare_invoice(w1, amount_sat=1) for i in range(num_payments)] - pay_reqs2 = [self.prepare_invoice(w2, amount_sat=payment_value_sat) for i in range(num_payments)] max_htlcs_in_flight = asyncio.Semaphore(5) async def single_payment(pay_req): async with max_htlcs_in_flight: await w1._pay(pay_req) async def many_payments(): async with TaskGroup() as group: - for pay_req in pay_reqs2: + pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_sat=payment_value_sat)) + for i in range(num_payments)] + async with TaskGroup() as group: + for pay_req_task in pay_reqs_tasks: + pay_req = pay_req_task.result() await group.spawn(single_payment(pay_req)) gath.cancel() gath = asyncio.gather(many_payments(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) @@ -373,7 +391,7 @@ class TestPeer(ElectrumTestCase): w1.network.config.set_key('fee_per_kb', 5000) w2.network.config.set_key('fee_per_kb', 1000) w2.enable_htlc_settle.clear() - pay_req = self.prepare_invoice(w2) + pay_req = run(self.prepare_invoice(w2)) lnaddr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP) async def pay(): await asyncio.wait_for(p1.initialized, 1) @@ -401,7 +419,7 @@ class TestPeer(ElectrumTestCase): def test_channel_usage_after_closing(self): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) - pay_req = self.prepare_invoice(w2) + pay_req = run(self.prepare_invoice(w2)) addr = w1._check_invoice(pay_req) route = w1._create_route_from_invoice(decoded_invoice=addr)