commit 85789d8a09523cd6c5635ec71ec03d099caf0c48
parent a42c1067abcdd1a610f3d26da804703055a34b1a
Author: Janus <ysangkok@gmail.com>
Date: Thu, 25 Oct 2018 18:28:18 +0200
lnbase: mark initialized later, add tests, etc
- consistent node_id sorting
- require OPTION_DATA_LOSS_PROTECT and test it
Diffstat:
2 files changed, 96 insertions(+), 20 deletions(-)
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -201,6 +201,7 @@ class Peer(PrintError):
self.peer_addr = peer_addr
self.lnworker = lnworker
self.privkey = lnworker.node_keypair.privkey
+ self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network
self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db
@@ -218,7 +219,7 @@ class Peer(PrintError):
self.localfeatures = LnLocalFeatures(0)
if request_initial_sync:
self.localfeatures |= LnLocalFeatures.INITIAL_ROUTING_SYNC
- self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
+ self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
self.attempted_route = {}
self.orphan_channel_updates = OrderedDict()
@@ -234,7 +235,6 @@ class Peer(PrintError):
await transport.handshake()
self.transport = transport
self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
- self.initialized.set_result(True)
@property
def channels(self) -> Dict[bytes, Channel]:
@@ -310,6 +310,7 @@ class Peer(PrintError):
raise LightningPeerConnectionClosed("remote does not have even flag {}"
.format(str(LnLocalFeatures(1 << flag))))
self.localfeatures ^= 1 << flag # disable flag
+ self.initialized.set_result(True)
def on_channel_update(self, payload):
try:
@@ -349,6 +350,13 @@ class Peer(PrintError):
@log_exceptions
@handle_disconnect
async def main_loop(self):
+ """
+ This is used from the GUI. It is not merged with the other function,
+ so that we can test if the correct exceptions are getting thrown.
+ """
+ await self._main_loop()
+
+ async def _main_loop(self):
try:
await asyncio.wait_for(self.initialize(), 10)
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
@@ -757,16 +765,17 @@ class Peer(PrintError):
if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h):
raise Exception("node_sig invalid in announcement_signatures")
- node_sigs = [local_node_sig, remote_node_sig]
- bitcoin_sigs = [local_bitcoin_sig, remote_bitcoin_sig]
- node_ids = [privkey_to_pubkey(self.privkey), self.peer_addr.pubkey]
- bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
+ node_sigs = [remote_node_sig, local_node_sig]
+ bitcoin_sigs = [remote_bitcoin_sig, local_bitcoin_sig]
+ bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, chan.config[LOCAL].multisig_key.pubkey]
- if node_ids[0] > node_ids[1]:
+ if self.node_ids[0] > self.node_ids[1]:
node_sigs.reverse()
bitcoin_sigs.reverse()
- node_ids.reverse()
+ node_ids = list(reversed(self.node_ids))
bitcoin_keys.reverse()
+ else:
+ node_ids = self.node_ids
self.send_message("channel_announcement",
node_signatures_1=node_sigs[0],
@@ -793,14 +802,13 @@ class Peer(PrintError):
chan.set_state("OPEN")
self.network.trigger_callback('channel', chan)
# add channel to database
- pubkey_ours = self.lnworker.node_keypair.pubkey
- pubkey_theirs = self.peer_addr.pubkey
- node_ids = [pubkey_theirs, pubkey_ours]
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
- sorted_node_ids = list(sorted(node_ids))
- if sorted_node_ids != node_ids:
+ sorted_node_ids = list(sorted(self.node_ids))
+ if sorted_node_ids != self.node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
+ else:
+ node_ids = self.node_ids
# note: we inject a channel announcement, and a channel update (for outgoing direction)
# This is atm needed for
# - finding routes
@@ -813,7 +821,10 @@ class Peer(PrintError):
'bitcoin_key_1': bitcoin_keys[0], 'bitcoin_key_2': bitcoin_keys[1]},
trusted=True)
# only inject outgoing direction:
- channel_flags = b'\x00' if node_ids[0] == pubkey_ours else b'\x01'
+ if node_ids[0] == privkey_to_pubkey(self.privkey):
+ channel_flags = b'\x00'
+ else:
+ channel_flags = b'\x01'
now = int(time.time()).to_bytes(4, byteorder="big")
self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'channel_flags': channel_flags, 'cltv_expiry_delta': b'\x90',
'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01',
@@ -832,16 +843,15 @@ class Peer(PrintError):
def send_announcement_signatures(self, chan):
- bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey,
- chan.config[REMOTE].multisig_key.pubkey]
-
- node_ids = [privkey_to_pubkey(self.privkey),
- self.peer_addr.pubkey]
+ bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
+ chan.config[LOCAL].multisig_key.pubkey]
- sorted_node_ids = list(sorted(node_ids))
+ sorted_node_ids = list(sorted(self.node_ids))
if sorted_node_ids != node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
+ else:
+ node_ids = self.node_ids
chan_ann = gen_msg("channel_announcement",
len=0,
diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
@@ -0,0 +1,66 @@
+from electrum.lnbase import Peer, decode_msg, gen_msg
+from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
+from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
+from electrum.ecc import ECPrivkey
+from electrum.lnrouter import ChannelDB
+import unittest
+import asyncio
+from electrum import simple_config
+import tempfile
+from .test_lnchan import create_test_channels
+
+class MockNetwork:
+ def __init__(self):
+ self.lnwatcher = None
+ user_config = {}
+ user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
+ self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
+ self.asyncio_loop = asyncio.get_event_loop()
+ self.channel_db = ChannelDB(self)
+ self.interface = None
+ def register_callback(self, cb, trigger_names):
+ print("callback registered", repr(trigger_names))
+ def trigger_callback(self, trigger_name, obj):
+ print("callback triggered", repr(trigger_name))
+
+class MockLNWorker:
+ def __init__(self, remote_peer_pubkey, chan):
+ self.chan = chan
+ self.remote_peer_pubkey = remote_peer_pubkey
+ priv = ECPrivkey.generate_random_key().get_secret_bytes()
+ self.node_keypair = Keypair(
+ pubkey=privkey_to_pubkey(priv),
+ privkey=priv)
+ self.network = MockNetwork()
+ @property
+ def peers(self):
+ return {self.remote_peer_pubkey: self.peer}
+ def channels_for_peer(self, pubkey):
+ return {self.chan.channel_id: self.chan}
+
+class MockTransport:
+ def __init__(self):
+ self.queue = asyncio.Queue()
+ async def read_messages(self):
+ while True:
+ yield await self.queue.get()
+
+class BadFeaturesTransport(MockTransport):
+ def send_bytes(self, data):
+ decoded = decode_msg(data)
+ print(decoded)
+ if decoded[0] == 'init':
+ self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
+
+class TestPeer(unittest.TestCase):
+ def setUp(self):
+ self.alice_channel, self.bob_channel = create_test_channels()
+ def test_bad_feature_flags(self):
+ # we should require DATA_LOSS_PROTECT
+ mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel)
+ mock_transport = BadFeaturesTransport()
+ p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport)
+ mock_lnworker.peer = p1
+ with self.assertRaises(LightningPeerConnectionClosed):
+ asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
+