electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 13d69973555fef8802cce264f7ea324e89e36684
parent edba59ef54b4489ddccdbe70b883bfec42be3d3a
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue, 26 Nov 2019 00:15:33 +0100

LNPeerAddr: validate arguments

no longer subclassing NamedTuple (as it is difficult to do validation then...)

Diffstat:
Melectrum/channel_db.py | 14++++++++++----
Melectrum/lntransport.py | 23++++++++++++++---------
Melectrum/lnutil.py | 29+++++++++++++++++++++++------
Melectrum/lnworker.py | 7++++++-
4 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -281,13 +281,19 @@ class ChannelDB(SqlDB): return None addr = sorted(list(r), key=lambda x: x[2])[0] host, port, timestamp = addr - return LNPeerAddr(host, port, node_id) + try: + return LNPeerAddr(host, port, node_id) + except ValueError: + return None def get_recent_peers(self): assert self.data_loaded.is_set(), "channelDB load_data did not finish yet!" - r = [self.get_last_good_address(x) for x in self._addresses.keys()] - r = r[-self.NUM_MAX_RECENT_PEERS:] - return r + # FIXME this does not reliably return "recent" peers... + # Also, the list() cast over the whole dict (thousands of elements), + # is really inefficient. + r = [self.get_last_good_address(node_id) + for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]] + return list(reversed(r)) def add_channel_announcement(self, msg_payloads, trusted=True): if type(msg_payloads) is dict: diff --git a/electrum/lntransport.py b/electrum/lntransport.py @@ -8,11 +8,12 @@ import hashlib import asyncio from asyncio import StreamReader, StreamWriter + from Cryptodome.Cipher import ChaCha20_Poly1305 from .crypto import sha256, hmac_oneshot from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed, - HandshakeFailed) + HandshakeFailed, LNPeerAddr) from . import ecc from .util import bh2u @@ -86,7 +87,13 @@ def create_ephemeral_key() -> (bytes, bytes): privkey = ecc.ECPrivkey.generate_random_key() return privkey.get_secret_bytes(), privkey.get_public_key_bytes() + class LNTransportBase: + reader: StreamReader + writer: StreamWriter + + def name(self) -> str: + raise NotImplementedError() def send_bytes(self, msg: bytes) -> None: l = len(msg).to_bytes(2, 'big') @@ -207,21 +214,18 @@ class LNResponderTransport(LNTransportBase): class LNTransport(LNTransportBase): - def __init__(self, privkey: bytes, peer_addr): + def __init__(self, privkey: bytes, peer_addr: LNPeerAddr): LNTransportBase.__init__(self) assert type(privkey) is bytes and len(privkey) == 32 self.privkey = privkey - self.remote_pubkey = peer_addr.pubkey - self.host = peer_addr.host - self.port = peer_addr.port self.peer_addr = peer_addr def name(self): - return str(self.host) + ':' + str(self.port) + return self.peer_addr.net_addr_str() async def handshake(self): - self.reader, self.writer = await asyncio.open_connection(self.host, self.port) - hs = HandshakeState(self.remote_pubkey) + self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port) + hs = HandshakeState(self.peer_addr.pubkey) # Get a new ephemeral key epriv, epub = create_ephemeral_key() @@ -230,7 +234,8 @@ class LNTransport(LNTransportBase): self.writer.write(msg) rspns = await self.reader.read(2**10) if len(rspns) != 50: - raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, are you sure this is the right pubkey? {bh2u(self.remote_pubkey)}") + raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, " + f"are you sure this is the right pubkey? {self.peer_addr}") hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:] if bytes([hver]) != hs.handshake_version: raise HandshakeFailed("unexpected handshake version: {}".format(hver)) diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -658,14 +658,31 @@ class LnGlobalFeatures(IntFlag): LN_GLOBAL_FEATURES_KNOWN_SET = set(LnGlobalFeatures) -class LNPeerAddr(NamedTuple): - host: str - port: int - pubkey: bytes +class LNPeerAddr: + + def __init__(self, host: str, port: int, pubkey: bytes): + assert isinstance(host, str), repr(host) + assert isinstance(port, int), repr(port) + assert isinstance(pubkey, bytes), repr(pubkey) + try: + net_addr = NetAddress(host, port) # this validates host and port + except Exception as e: + raise ValueError(f"cannot construct LNPeerAddr: invalid host or port (host={host}, port={port})") from e + # note: not validating pubkey as it would be too expensive: + # if not ECPubkey.is_pubkey_bytes(pubkey): raise ValueError() + self.host = host + self.port = port + self.pubkey = pubkey + self._net_addr_str = str(net_addr) def __str__(self): - host_and_port = str(NetAddress(self.host, self.port)) - return '{}@{}'.format(self.pubkey.hex(), host_and_port) + return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str()) + + def __repr__(self): + return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>' + + def net_addr_str(self) -> str: + return self._net_addr_str def get_compressed_pubkey_from_bech32(bech32_pubkey: str) -> bytes: diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -221,7 +221,10 @@ class LNWorker(Logger): if not addrs: continue host, port, timestamp = self.choose_preferred_address(list(addrs)) - peer = LNPeerAddr(host, port, node_id) + try: + peer = LNPeerAddr(host, port, node_id) + except ValueError: + continue if peer in self._last_tried_peer: continue #self.logger.info('taking random ln peer from our channel db') @@ -1265,6 +1268,8 @@ class LNWallet(LNWorker): self.network.trigger_callback('channels_updated', self.wallet) self.network.trigger_callback('wallet_updated', self.wallet) + @ignore_exceptions + @log_exceptions async def reestablish_peer_for_given_channel(self, chan): now = time.time() # try last good address first