commit 7e76e821522eb0fcb6aef0fcbd9897d77606b382
parent ce2b572fa5a97a4686e92d9eaca7c3a95ece02ec
Author: Janus <ysangkok@gmail.com>
Date: Thu, 25 Oct 2018 21:59:16 +0200
test_lnbase: add test that pays to another local electrum
Diffstat:
3 files changed, 146 insertions(+), 26 deletions(-)
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -350,6 +350,11 @@ class Peer(PrintError):
@log_exceptions
@handle_disconnect
async def main_loop(self):
+ """
+ This is used in LNWorker and is necessary so that we don't kill the main
+ task group. It is not merged with _main_loop, so that we can test if the
+ correct exceptions are getting thrown using _main_loop.
+ """
await self._main_loop()
async def _main_loop(self):
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -32,7 +32,6 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
NUM_MAX_EDGES_IN_PAYMENT_PATH)
-from .lnaddr import lndecode
from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use
@@ -258,6 +257,15 @@ class LNWorker(PrintError):
return bh2u(chan.node_id)
def pay(self, invoice, amount_sat=None):
+ """
+ This is not merged with _pay so that we can run the test with
+ one thread only.
+ """
+ addr, peer, coro = self._pay(invoice, amount_sat)
+ fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
+ return addr, peer, fut
+
+ def _pay(self, invoice, amount_sat=None):
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
payment_hash = addr.paymenthash
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
@@ -279,7 +287,7 @@ class LNWorker(PrintError):
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
peer = self.peers[node_id]
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
- return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
+ return addr, peer, coro
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
invoice_pubkey = decoded_invoice.pubkey.serialize()
diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
@@ -1,16 +1,40 @@
-from electrum.lnbase import Peer, decode_msg, gen_msg
-from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
-from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
-from electrum.ecc import ECPrivkey
-from electrum.lnrouter import ChannelDB
import unittest
import asyncio
-from electrum import simple_config
import tempfile
+from decimal import Decimal
+import os
+from contextlib import contextmanager
+from collections import defaultdict
+
+from electrum.network import Network
+from electrum.ecc import ECPrivkey
+from electrum import simple_config, lnutil
+from electrum.lnaddr import lnencode, LnAddr, lndecode
+from electrum.bitcoin import COIN, sha256
+from electrum.util import bh2u
+
+from electrum.lnbase import Peer, decode_msg, gen_msg
+from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
+from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
+from electrum.lnrouter import ChannelDB, LNPathFinder
+from electrum.lnworker import LNWorker
+
from .test_lnchan import create_test_channels
+def keypair():
+ priv = ECPrivkey.generate_random_key().get_secret_bytes()
+ k1 = Keypair(
+ pubkey=privkey_to_pubkey(priv),
+ privkey=priv)
+ return k1
+
+@contextmanager
+def noop_lock():
+ yield
+
class MockNetwork:
def __init__(self):
+ self.callbacks = defaultdict(list)
self.lnwatcher = None
user_config = {}
user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
@@ -18,49 +42,132 @@ class MockNetwork:
self.asyncio_loop = asyncio.get_event_loop()
self.channel_db = ChannelDB(self)
self.interface = None
- def register_callback(self, cb, trigger_names):
- print("callback registered", repr(trigger_names))
- def trigger_callback(self, trigger_name, obj):
- print("callback triggered", repr(trigger_name))
+ self.path_finder = LNPathFinder(self.channel_db)
+
+ @property
+ def callback_lock(self):
+ return noop_lock()
+
+ register_callback = Network.register_callback
+ unregister_callback = Network.unregister_callback
+ trigger_callback = Network.trigger_callback
+
+ def get_local_height(self):
+ return 0
class MockLNWorker:
- def __init__(self, remote_peer_pubkey, chan):
+ def __init__(self, remote_keypair, local_keypair, chan):
self.chan = chan
- self.remote_peer_pubkey = remote_peer_pubkey
- priv = ECPrivkey.generate_random_key().get_secret_bytes()
- self.node_keypair = Keypair(
- pubkey=privkey_to_pubkey(priv),
- privkey=priv)
+ self.remote_keypair = remote_keypair
+ self.node_keypair = local_keypair
self.network = MockNetwork()
+ self.channels = {self.chan.channel_id: self.chan}
+ self.invoices = {}
+
+ @property
+ def lock(self):
+ return noop_lock()
+
@property
def peers(self):
- return {self.remote_peer_pubkey: self.peer}
+ return {self.remote_keypair.pubkey: self.peer}
+
def channels_for_peer(self, pubkey):
- return {self.chan.channel_id: self.chan}
+ return self.channels
+
+ def save_channel(self, chan):
+ pass
+
+ get_invoice = LNWorker.get_invoice
+ _create_route_from_invoice = LNWorker._create_route_from_invoice
class MockTransport:
def __init__(self):
self.queue = asyncio.Queue()
+
async def read_messages(self):
while True:
yield await self.queue.get()
-class BadFeaturesTransport(MockTransport):
+class NoFeaturesTransport(MockTransport):
+ """
+ This answers the init message with a init that doesn't signal any features.
+ Used for testing that we require DATA_LOSS_PROTECT.
+ """
def send_bytes(self, data):
decoded = decode_msg(data)
print(decoded)
if decoded[0] == 'init':
self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
+class PutIntoOthersQueueTransport(MockTransport):
+ def __init__(self):
+ super().__init__()
+ self.other_mock_transport = None
+
+ def send_bytes(self, data):
+ self.other_mock_transport.queue.put_nowait(data)
+
+def transport_pair():
+ t1 = PutIntoOthersQueueTransport()
+ t2 = PutIntoOthersQueueTransport()
+ t1.other_mock_transport = t2
+ t2.other_mock_transport = t1
+ return t1, t2
+
class TestPeer(unittest.TestCase):
def setUp(self):
self.alice_channel, self.bob_channel = create_test_channels()
- def test_bad_feature_flags(self):
- # we should require DATA_LOSS_PROTECT
- mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel)
- mock_transport = BadFeaturesTransport()
- p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport)
+
+ def test_require_data_loss_protect(self):
+ mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
+ mock_transport = NoFeaturesTransport()
+ p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed):
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
+ def test_payment(self):
+ k1, k2 = keypair(), keypair()
+ t1, t2 = transport_pair()
+ w1 = MockLNWorker(k1, k2, self.alice_channel)
+ w2 = MockLNWorker(k2, k1, self.bob_channel)
+ p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
+ request_initial_sync=False, transport=t1)
+ p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
+ request_initial_sync=False, transport=t2)
+ w1.peer = p1
+ w2.peer = p2
+ # mark_open won't work if state is already OPEN.
+ # so set it to OPENING
+ self.alice_channel.set_state("OPENING")
+ self.bob_channel.set_state("OPENING")
+ # this populates the channel graph:
+ p1.mark_open(self.alice_channel)
+ p2.mark_open(self.bob_channel)
+ amount_btc = 100000/Decimal(COIN)
+ payment_preimage = os.urandom(32)
+ RHASH = sha256(payment_preimage)
+ addr = LnAddr(
+ RHASH,
+ amount_btc,
+ tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
+ ('d', 'coffee')
+ ])
+ pay_req = lnencode(addr, w2.node_keypair.privkey)
+ w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
+ l = asyncio.get_event_loop()
+ async def pay():
+ fut = asyncio.Future()
+ def evt_set(event, _lnworker, msg):
+ fut.set_result(msg)
+ w2.network.register_callback(evt_set, ['ln_message'])
+
+ addr, peer, coro = LNWorker._pay(w1, pay_req)
+ await coro
+ print("HTLC ADDED")
+ self.assertEqual(await fut, 'Payment received')
+ gath.cancel()
+ gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
+ with self.assertRaises(asyncio.CancelledError):
+ l.run_until_complete(gath)