commit b5482e4470dcf2da331bf33b65c8a37ba009aa57
parent 61638664f769fa28d33711115c849267517ede08
Author: ThomasV <thomasv@electrum.org>
Date: Fri, 1 Feb 2019 20:21:59 +0100
create transport and perform handshake before creating Peer
Diffstat:
4 files changed, 43 insertions(+), 43 deletions(-)
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -197,15 +197,14 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
class Peer(PrintError):
- def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr, responding=False,
- request_initial_sync=False, transport: LNTransportBase=None):
- self.responding = responding
+ def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
+ request_initial_sync=False):
self.initialized = asyncio.Event()
self.transport = transport
- self.peer_addr = peer_addr
+ self.pubkey = pubkey
self.lnworker = lnworker
self.privkey = lnworker.node_keypair.privkey
- self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)]
+ self.node_ids = [self.pubkey, privkey_to_pubkey(self.privkey)]
self.network = lnworker.network
self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db
@@ -233,19 +232,14 @@ class Peer(PrintError):
self.transport.send_bytes(gen_msg(message_name, **kwargs))
async def initialize(self):
- if not self.transport:
- reader, writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
- transport = LNTransport(self.privkey, self.peer_addr.pubkey, reader, writer)
- await transport.handshake()
- self.transport = transport
self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
@property
def channels(self) -> Dict[bytes, Channel]:
- return self.lnworker.channels_for_peer(self.peer_addr.pubkey)
+ return self.lnworker.channels_for_peer(self.pubkey)
def diagnostic_name(self):
- return str(self.peer_addr.host) + ':' + str(self.peer_addr.port)
+ return self.transport.name()
def ping_if_required(self):
if time.time() - self.ping_time > 120:
@@ -352,7 +346,7 @@ class Peer(PrintError):
self.print_error("disconnecting gracefully. {}".format(e))
finally:
self.close_and_cleanup()
- self.lnworker.peers.pop(self.peer_addr.pubkey)
+ self.lnworker.peers.pop(self.pubkey)
return wrapper_func
@ignore_exceptions # do not kill main_taskgroup
@@ -373,8 +367,6 @@ class Peer(PrintError):
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
self.print_error('initialize failed, disconnecting: {}'.format(repr(e)))
return
- if not self.responding:
- self.channel_db.add_recent_peer(self.peer_addr)
# loop
async for msg in self.transport.read_messages():
self.process_message(msg)
@@ -513,7 +505,7 @@ class Peer(PrintError):
# remote commitment transaction
channel_id, funding_txid_bytes = channel_id_from_funding_tx(funding_txid, funding_index)
chan_dict = {
- "node_id": self.peer_addr.pubkey,
+ "node_id": self.pubkey,
"channel_id": channel_id,
"short_channel_id": None,
"funding_outpoint": Outpoint(funding_txid, funding_index),
@@ -587,7 +579,7 @@ class Peer(PrintError):
remote_dust_limit_sat = int.from_bytes(payload['dust_limit_satoshis'], byteorder='big') # TODO validate
remote_reserve_sat = self.validate_remote_reserve(payload['channel_reserve_satoshis'], remote_dust_limit_sat, funding_sat)
chan_dict = {
- "node_id": self.peer_addr.pubkey,
+ "node_id": self.pubkey,
"channel_id": channel_id,
"short_channel_id": None,
"funding_outpoint": Outpoint(funding_txid, funding_idx),
@@ -794,7 +786,7 @@ class Peer(PrintError):
remote_bitcoin_sig = announcement_signatures_msg["bitcoin_signature"]
if not ecc.verify_signature(chan.config[REMOTE].multisig_key.pubkey, remote_bitcoin_sig, h):
raise Exception("bitcoin_sig invalid in announcement_signatures")
- if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h):
+ if not ecc.verify_signature(self.pubkey, remote_node_sig, h):
raise Exception("node_sig invalid in announcement_signatures")
node_sigs = [remote_node_sig, local_node_sig]
diff --git a/electrum/lntransport.py b/electrum/lntransport.py
@@ -6,6 +6,7 @@
# Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
import hashlib
+import asyncio
from asyncio import StreamReader, StreamWriter
from Cryptodome.Cipher import ChaCha20_Poly1305
@@ -87,10 +88,6 @@ def create_ephemeral_key() -> (bytes, bytes):
class LNTransportBase:
- def __init__(self, reader: StreamReader, writer: StreamWriter):
- self.reader = reader
- self.writer = writer
-
def send_bytes(self, msg):
l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l)
@@ -153,12 +150,16 @@ class LNTransportBase:
class LNResponderTransport(LNTransportBase):
def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
- LNTransportBase.__init__(self, reader, writer)
+ LNTransportBase.__init__(self)
+ self.reader = reader
+ self.writer = writer
self.privkey = privkey
+ def name(self):
+ return "responder"
+
async def handshake(self, **kwargs):
hs = HandshakeState(privkey_to_pubkey(self.privkey))
-
act1 = b''
while len(act1) < 50:
act1 += await self.reader.read(50 - len(act1))
@@ -205,14 +206,20 @@ class LNResponderTransport(LNTransportBase):
return rs
class LNTransport(LNTransportBase):
- def __init__(self, privkey: bytes, remote_pubkey: bytes,
- reader: StreamReader, writer: StreamWriter):
- LNTransportBase.__init__(self, reader, writer)
+
+ def __init__(self, privkey: bytes, peer_addr):
+ LNTransportBase.__init__(self)
assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey
- self.remote_pubkey = remote_pubkey
+ self.remote_pubkey = peer_addr.pubkey
+ self.host = peer_addr.host
+ self.port = peer_addr.port
+
+ def name(self):
+ return str(self.host) + ':' + str(self.port)
async def handshake(self):
+ self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
hs = HandshakeState(self.remote_pubkey)
# Get a new ephemeral key
epriv, epub = create_ephemeral_key()
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -28,7 +28,7 @@ from .crypto import sha256
from .bip32 import bip32_root
from .util import bh2u, bfh, PrintError, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions
from .util import timestamp_to_datetime
-from .lntransport import LNResponderTransport
+from .lntransport import LNTransport, LNResponderTransport
from .lnbase import Peer
from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string
@@ -244,13 +244,16 @@ class LNWorker(PrintError):
return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
async def add_peer(self, host, port, node_id):
- port = int(port)
- peer_addr = LNPeerAddr(host, port, node_id)
if node_id in self.peers:
return self.peers[node_id]
+ port = int(port)
+ peer_addr = LNPeerAddr(host, port, node_id)
+ transport = LNTransport(self.node_keypair.privkey, peer_addr)
+ await transport.handshake()
+ self.channel_db.add_recent_peer(peer_addr)
self._last_tried_peer[peer_addr] = time.time()
self.print_error("adding peer", peer_addr)
- peer = Peer(self, peer_addr, request_initial_sync=self.config.get("request_initial_sync", True))
+ peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
await self.network.main_taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer
self.network.trigger_callback('ln_status')
@@ -797,16 +800,13 @@ class LNWorker(PrintError):
# ipv6
addr = addr[1:-1]
async def cb(reader, writer):
- t = LNResponderTransport(self.node_keypair.privkey, reader, writer)
+ transport = LNResponderTransport(self.node_keypair.privkey, reader, writer)
try:
- node_id = await t.handshake()
+ node_id = await transport.handshake()
except:
self.print_error('handshake failure from incoming connection')
return
- # FIXME extract host and port from transport
- peer = Peer(self, LNPeerAddr("bogus", 1337, node_id), responding=True,
- request_initial_sync=self.config.get("request_initial_sync", True),
- transport=t)
+ peer = Peer(self, node_id, transport, request_initial_sync=self.config.get("request_initial_sync", True))
self.peers[node_id] = peer
await self.network.main_taskgroup.spawn(peer.main_loop())
self.network.trigger_callback('ln_status')
diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
@@ -113,6 +113,9 @@ class MockTransport:
def __init__(self):
self.queue = asyncio.Queue()
+ def name(self):
+ return ""
+
async def read_messages(self):
while True:
yield await self.queue.get()
@@ -150,7 +153,7 @@ class TestPeer(unittest.TestCase):
def test_require_data_loss_protect(self):
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
mock_transport = NoFeaturesTransport()
- p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
+ p1 = Peer(mock_lnworker, b"\x00" * 33, mock_transport, request_initial_sync=False)
mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed):
run(asyncio.wait_for(p1._main_loop(), 1))
@@ -161,10 +164,8 @@ class TestPeer(unittest.TestCase):
q1, q2 = asyncio.Queue(), asyncio.Queue()
w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
- p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
- request_initial_sync=False, transport=t1)
- p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
- request_initial_sync=False, transport=t2)
+ p1 = Peer(w1, k1.pubkey, t1, request_initial_sync=False)
+ p2 = Peer(w2, k2.pubkey, t2, request_initial_sync=False)
w1.peer = p1
w2.peer = p2
# mark_open won't work if state is already OPEN.