electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit f3e5ba6ac16e8e30879079a8a18f7ad931a6017a
parent 362a3a5a442d63374ea9d3daf358f63cd18e88f9
Author: SomberNight <somber.night@protonmail.com>
Date:   Mon, 30 Jul 2018 13:51:03 +0200

more reliable peer and channel re-establishing

Diffstat:
Melectrum/gui/qt/channels_list.py | 2+-
Melectrum/lnbase.py | 48++++++++++++++++++++++++++++++++++--------------
Melectrum/lnhtlc.py | 16+++++++++++++++-
Melectrum/lnrouter.py | 15+++++++++++++++
Melectrum/lnworker.py | 90+++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
5 files changed, 130 insertions(+), 41 deletions(-)

diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py @@ -24,7 +24,7 @@ class ChannelsList(MyTreeWidget): bh2u(chan.node_id), self.parent.format_amount(chan.local_state.amount_msat//1000), self.parent.format_amount(chan.remote_state.amount_msat//1000), - chan.state + chan.get_state() ] def create_menu(self, position): diff --git a/electrum/lnbase.py b/electrum/lnbase.py @@ -207,6 +207,10 @@ class HandshakeState(object): self.h = sha256(self.h + data) return self.h + +class HandshakeFailed(Exception): pass + + def get_nonce_bytes(n): """BOLT 8 requires the nonce to be 12 bytes, 4 bytes leading zeroes and 8 bytes little endian encoded 64 bit integer. @@ -285,6 +289,7 @@ class Peer(PrintError): self.host = host self.port = port self.pubkey = pubkey + self.peer_addr = LNPeerAddr(host, port, pubkey) self.lnworker = lnworker self.privkey = lnworker.privkey self.network = lnworker.network @@ -340,7 +345,10 @@ class Peer(PrintError): self.read_buffer = self.read_buffer[offset:] msg = aead_decrypt(rk_m, rn_m, b'', c) return msg - s = await self.reader.read(2**10) + try: + s = await self.reader.read(2**10) + except: + s = None if not s: raise LightningPeerConnectionClosed() self.read_buffer += s @@ -354,9 +362,11 @@ class Peer(PrintError): # act 1 self.writer.write(msg) rspns = await self.reader.read(2**10) - assert len(rspns) == 50, "Lightning handshake act 1 response has bad length, are you sure this is the right pubkey? " + str(bh2u(self.pubkey)) + if len(rspns) != 50: + raise HandshakeFailed("Lightning handshake act 1 response has bad length, are you sure this is the right pubkey? " + str(bh2u(self.pubkey))) hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:] - assert bytes([hver]) == hs.handshake_version + if bytes([hver]) != hs.handshake_version: + raise HandshakeFailed("unexpected handshake version: {}".format(hver)) # act 2 hs.update(alice_epub) ss = get_ecdh(epriv, alice_epub) @@ -461,15 +471,21 @@ class Peer(PrintError): @aiosafe async def main_loop(self): await asyncio.wait_for(self.initialize(), 5) - self.channel_db.add_recent_peer(LNPeerAddr(self.host, self.port, self.pubkey)) + self.channel_db.add_recent_peer(self.peer_addr) # loop while True: self.ping_if_required() msg = await self.read_message() self.process_message(msg) - # close socket - self.print_error('closing lnbase') - self.writer.close() + + def close_and_cleanup(self): + try: + self.writer.close() + except: + pass + for chan in self.channels.values(): + chan.set_state('DISCONNECTED') + self.network.trigger_callback('channel', chan) @aiosafe async def channel_establishment_flow(self, wallet, config, password, funding_sat, push_msat, temp_channel_id): @@ -601,14 +617,18 @@ class Peer(PrintError): assert success, success m.remote_state = m.remote_state._replace(ctn=0) m.local_state = m.local_state._replace(ctn=0, current_commitment_signature=remote_sig) - m.state = 'OPENING' + m.set_state('OPENING') return m @aiosafe async def reestablish_channel(self, chan): await self.initialized chan_id = chan.channel_id - chan.state = 'REESTABLISHING' + if chan.get_state() != 'DISCONNECTED': + self.print_error('reestablish_channel was called but channel {} already in state {}' + .format(chan_id, chan.get_state())) + return + chan.set_state('REESTABLISHING') self.network.trigger_callback('channel', chan) self.send_message(gen_msg("channel_reestablish", channel_id=chan_id, @@ -616,7 +636,7 @@ class Peer(PrintError): next_remote_revocation_number=chan.remote_state.ctn )) await self.channel_reestablished[chan_id] - chan.state = 'OPENING' + chan.set_state('OPENING') if chan.local_state.funding_locked_received and chan.short_channel_id: self.mark_open(chan) self.network.trigger_callback('channel', chan) @@ -727,10 +747,10 @@ class Peer(PrintError): print("SENT CHANNEL ANNOUNCEMENT") def mark_open(self, chan): - if chan.state == "OPEN": + if chan.get_state() == "OPEN": return assert chan.local_state.funding_locked_received - chan.state = "OPEN" + chan.set_state("OPEN") self.network.trigger_callback('channel', chan) # add channel to database node_ids = [self.pubkey, self.lnworker.pubkey] @@ -820,7 +840,7 @@ class Peer(PrintError): @aiosafe async def pay(self, path, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry): - assert chan.state == "OPEN" + assert chan.get_state() == "OPEN" assert amount_msat > 0, "amount_msat is not greater zero" height = self.network.get_local_height() route = self.network.path_finder.create_route_from_path(path, self.lnworker.pubkey) @@ -911,7 +931,7 @@ class Peer(PrintError): htlc_id = int.from_bytes(htlc["id"], 'big') assert htlc_id == chan.remote_state.next_htlc_id, (htlc_id, chan.remote_state.next_htlc_id) - assert chan.state == "OPEN" + assert chan.get_state() == "OPEN" cltv_expiry = int.from_bytes(htlc["cltv_expiry"], 'big') # TODO verify sanity of their cltv expiry diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -138,7 +138,21 @@ class HTLCStateMachine(PrintError): self.local_commitment = self.pending_local_commitment self.remote_commitment = self.pending_remote_commitment - self.state = 'DISCONNECTED' + self._is_funding_txo_spent = None # "don't know" + self.set_state('DISCONNECTED') + + def set_state(self, state: str): + self._state = state + + def get_state(self): + return self._state + + def set_funding_txo_spentness(self, is_spent: bool): + assert isinstance(is_spent, bool) + self._is_funding_txo_spent = is_spent + + def should_try_to_reestablish_peer(self) -> bool: + return self._is_funding_txo_spent is False and self._state == 'DISCONNECTED' def get_funding_address(self): script = funding_output_script(self.local_config, self.remote_config) diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -269,6 +269,7 @@ class ChannelDB(JsonDB): self._channels_for_node = defaultdict(set) # node -> set(short_channel_id) self.nodes = {} # node_id -> NodeInfo self._recent_peers = [] + self._last_good_address = {} # node_id -> LNPeerAddr self.ca_verifier = LNChanAnnVerifier(network, self) self.network.add_jobs([self.ca_verifier]) @@ -297,6 +298,11 @@ class ChannelDB(JsonDB): for host, port, pubkey in recent_peers: peer = LNPeerAddr(str(host), int(port), bfh(pubkey)) self._recent_peers.append(peer) + # last good address + last_good_addr = self.get('last_good_address', {}) + for node_id, host_and_port in last_good_addr.items(): + host, port = host_and_port + self._last_good_address[bfh(node_id)] = LNPeerAddr(str(host), int(port), bfh(node_id)) def save_data(self): with self.lock: @@ -316,6 +322,11 @@ class ChannelDB(JsonDB): recent_peers.append( [str(peer.host), int(peer.port), bh2u(peer.pubkey)]) self.put('recent_peers', recent_peers) + # last good address + last_good_addr = {} + for node_id, peer in self._last_good_address.items(): + last_good_addr[bh2u(node_id)] = [str(peer.host), int(peer.port)] + self.put('last_good_address', last_good_addr) self.write() def __len__(self): @@ -347,6 +358,10 @@ class ChannelDB(JsonDB): self._recent_peers.remove(peer) self._recent_peers.insert(0, peer) self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS] + self._last_good_address[peer.pubkey] = peer + + def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]: + return self._last_good_address.get(node_id, None) def on_channel_announcement(self, msg_payload, trusted=False): short_channel_id = msg_payload['short_channel_id'] diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -4,6 +4,7 @@ from decimal import Decimal import random import time from typing import Optional, Sequence +import threading import dns.resolver import dns.exception @@ -22,6 +23,7 @@ from .i18n import _ NUM_PEERS_TARGET = 4 PEER_RETRY_INTERVAL = 600 # seconds +PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds FALLBACK_NODE_LIST = ( LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')), @@ -34,6 +36,7 @@ class LNWorker(PrintError): self.wallet = wallet self.network = network self.channel_db = self.network.channel_db + self.lock = threading.RLock() pk = wallet.storage.get('lightning_privkey') if pk is None: pk = bh2u(os.urandom(32)) @@ -48,9 +51,6 @@ class LNWorker(PrintError): for chan_id, chan in self.channels.items(): self.network.lnwatcher.watch_channel(chan, self.on_channel_utxos) self._last_tried_peer = {} # LNPeerAddr -> unix timestamp - # TODO peers that we have channels with should also be added now - # but we don't store their IP/port yet.. also what if it changes? - # need to listen for node_announcements and save the new IP/port self._add_peers_from_config() # wait until we see confirmations self.network.register_callback(self.on_network_update, ['updated', 'verified', 'fee_histogram']) # thread safe @@ -72,15 +72,14 @@ class LNWorker(PrintError): def channels_for_peer(self, node_id): assert type(node_id) is bytes - return {x: y for (x, y) in self.channels.items() if y.node_id == node_id} + with self.lock: + return {x: y for (x, y) in self.channels.items() if y.node_id == node_id} def add_peer(self, host, port, node_id): port = int(port) peer_addr = LNPeerAddr(host, port, node_id) if node_id in self.peers: return - if peer_addr in self._last_tried_peer: - return self._last_tried_peer[peer_addr] = time.time() self.print_error("adding peer", peer_addr) peer = Peer(self, host, port, node_id, request_initial_sync=self.config.get("request_initial_sync", True)) @@ -90,10 +89,11 @@ class LNWorker(PrintError): def save_channel(self, openchannel): assert type(openchannel) is HTLCStateMachine - self.channels[openchannel.channel_id] = openchannel if openchannel.remote_state.next_per_commitment_point == openchannel.remote_state.current_per_commitment_point: raise Exception("Tried to save channel with next_point == current_point, this should not happen") - dumped = [x.serialize() for x in self.channels.values()] + with self.lock: + self.channels[openchannel.channel_id] = openchannel + dumped = [x.serialize() for x in self.channels.values()] self.wallet.storage.put("channels", dumped) self.wallet.storage.write() self.network.trigger_callback('channel', openchannel) @@ -104,7 +104,7 @@ class LNWorker(PrintError): If the Funding TX has not been mined, return None """ - assert chan.state in ["OPEN", "OPENING"] + assert chan.get_state() in ["OPEN", "OPENING"] peer = self.peers[chan.node_id] conf = self.wallet.get_tx_height(chan.funding_outpoint.txid)[1] if conf >= chan.constraints.funding_txn_minimum_depth: @@ -121,16 +121,12 @@ class LNWorker(PrintError): def on_channel_utxos(self, chan, utxos): outpoints = [Outpoint(x["tx_hash"], x["tx_pos"]) for x in utxos] if chan.funding_outpoint not in outpoints: - chan.state = "CLOSED" + chan.set_funding_txo_spentness(True) + chan.set_state("CLOSED") # FIXME is this properly GC-ed? (or too soon?) LNChanCloseHandler(self.network, self.wallet, chan) - elif chan.state == 'DISCONNECTED': - if chan.node_id not in self.peers: - self.print_error("received channel_utxos for channel which does not have peer (errored?)") - return - peer = self.peers[chan.node_id] - coro = peer.reestablish_channel(chan) - asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) + else: + chan.set_funding_txo_spentness(False) self.network.trigger_callback('channel', chan) def on_network_update(self, event, *args): @@ -139,8 +135,10 @@ class LNWorker(PrintError): # since short_channel_id could be changed while saving. # Mitigated by posting to loop: async def network_jobs(): - for chan in self.channels.values(): - if chan.state == "OPENING": + with self.lock: + channels = list(self.channels.values()) + for chan in channels: + if chan.get_state() == "OPENING": res = self.save_short_chan_id(chan) if not res: self.print_error("network update but funding tx is still not at sufficient depth") @@ -148,7 +146,7 @@ class LNWorker(PrintError): # this results in the channel being marked OPEN peer = self.peers[chan.node_id] peer.funding_locked(chan) - elif chan.state == "OPEN": + elif chan.get_state() == "OPEN": peer = self.peers.get(chan.node_id) if peer is None: self.print_error("peer not found for {}".format(bh2u(chan.node_id))) @@ -177,6 +175,7 @@ class LNWorker(PrintError): return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) def pay(self, invoice, amount_sat=None): + # TODO try some number of paths (e.g. 10) in case of failures addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) payment_hash = addr.paymenthash invoice_pubkey = addr.pubkey.serialize() @@ -189,7 +188,9 @@ class LNWorker(PrintError): raise Exception("No path found") node_id, short_channel_id = path[0] peer = self.peers[node_id] - for chan in self.channels.values(): + with self.lock: + channels = list(self.channels.values()) + for chan in channels: if chan.short_channel_id == short_channel_id: break else: @@ -216,7 +217,8 @@ class LNWorker(PrintError): self.wallet.storage.write() def list_channels(self): - return [str(x) for x in self.channels] + with self.lock: + return [str(x) for x in self.channels] def close_channel(self, chan_id): chan = self.channels[chan_id] @@ -250,7 +252,7 @@ class LNWorker(PrintError): # try random peer from graph all_nodes = self.channel_db.nodes if all_nodes: - self.print_error('trying to get ln peers from channel db') + #self.print_error('trying to get ln peers from channel db') node_ids = list(all_nodes) max_tries = min(200, len(all_nodes)) for i in range(max_tries): @@ -259,7 +261,7 @@ class LNWorker(PrintError): if node is None: continue addresses = node.addresses if not addresses: continue - host, port = addresses[0] + host, port = random.choice(addresses) peer = LNPeerAddr(host, port, node_id) if peer.pubkey in self.peers: continue if peer in self._last_tried_peer: continue @@ -309,16 +311,54 @@ class LNWorker(PrintError): self.print_error('got {} ln peers from dns seed'.format(len(peers))) return peers + def reestablish_peers_and_channels(self): + def reestablish_peer_for_given_channel(): + # try last good address first + peer = self.channel_db.get_last_good_address(chan.node_id) + if peer: + last_tried = self._last_tried_peer.get(peer, 0) + if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now: + self.add_peer(peer.host, peer.port, peer.pubkey) + return + # try random address for node_id + node_info = self.channel_db.nodes.get(chan.node_id, None) + if not node_info: return + addresses = node_info.addresses + if not addresses: return + host, port = random.choice(addresses) + peer = LNPeerAddr(host, port, chan.node_id) + last_tried = self._last_tried_peer.get(peer, 0) + if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now: + self.add_peer(host, port, chan.node_id) + + with self.lock: + channels = list(self.channels.values()) + now = time.time() + for chan in channels: + if not chan.should_try_to_reestablish_peer(): + continue + peer = self.peers.get(chan.node_id, None) + if peer is None: + reestablish_peer_for_given_channel() + else: + coro = peer.reestablish_channel(chan) + asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) + @aiosafe async def main_loop(self): while True: await asyncio.sleep(1) + now = time.time() for node_id, peer in list(self.peers.items()): if peer.exception: self.print_error("removing peer", peer.host) + peer.close_and_cleanup() self.peers.pop(node_id) + self.reestablish_peers_and_channels() if len(self.peers) >= NUM_PEERS_TARGET: continue peers = self._get_next_peers_to_try() for peer in peers: - self.add_peer(peer.host, peer.port, peer.pubkey) + last_tried = self._last_tried_peer.get(peer, 0) + if last_tried + PEER_RETRY_INTERVAL < now: + self.add_peer(peer.host, peer.port, peer.pubkey)