electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 691ebaf4f816b45ca10eabceae3068b7465e6bc5
parent d800f88bfcdc83a2cb9d359791c0c7e8cf6fcaff
Author: SomberNight <somber.night@protonmail.com>
Date:   Wed, 24 Feb 2021 20:03:12 +0100

lnworker/lnpeer: add some type hints, force some kwargs

Diffstat:
Melectrum/lnonion.py | 9++++++---
Melectrum/lnpeer.py | 59+++++++++++++++++++++++++++++++++++++++++------------------
Melectrum/lnrater.py | 5++++-
Melectrum/lnworker.py | 150++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------
Melectrum/tests/test_lnpeer.py | 8+++++++-
5 files changed, 160 insertions(+), 71 deletions(-)

diff --git a/electrum/lnonion.py b/electrum/lnonion.py @@ -437,9 +437,12 @@ class OnionRoutingFailure(Exception): return str(self.code.name) return f"Unknown error ({self.code!r})" -def construct_onion_error(reason: OnionRoutingFailure, - onion_packet: OnionPacket, - our_onion_private_key: bytes) -> bytes: + +def construct_onion_error( + reason: OnionRoutingFailure, + onion_packet: OnionPacket, + our_onion_private_key: bytes, +) -> bytes: # create payload failure_msg = reason.to_bytes() failure_len = len(failure_msg) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -1373,9 +1373,12 @@ class Peer(Logger): chan.receive_htlc(htlc, onion_packet) util.trigger_callback('htlc_added', chan, htlc, RECEIVED) - def maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, - onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket - ) -> Tuple[Optional[bytes], Optional[int], Optional[OnionRoutingFailure]]: + def maybe_forward_htlc( + self, + *, + htlc: UpdateAddHtlc, + processed_onion: ProcessedOnionPacket, + ) -> Tuple[bytes, int]: # Forward HTLC # FIXME: there are critical safety checks MISSING here forwarding_enabled = self.network.config.get('lightning_forward_payments', False) @@ -1662,7 +1665,7 @@ class Peer(Logger): self.shutdown_received[chan_id] = asyncio.Future() await self.send_shutdown(chan) payload = await self.shutdown_received[chan_id] - txid = await self._shutdown(chan, payload, True) + txid = await self._shutdown(chan, payload, is_local=True) self.logger.info(f'({chan.get_id_for_log()}) Channel closed {txid}') return txid @@ -1686,10 +1689,10 @@ class Peer(Logger): else: chan = self.channels[chan_id] await self.send_shutdown(chan) - txid = await self._shutdown(chan, payload, False) + txid = await self._shutdown(chan, payload, is_local=False) self.logger.info(f'({chan.get_id_for_log()}) Channel closed by remote peer {txid}') - def can_send_shutdown(self, chan): + def can_send_shutdown(self, chan: Channel): if chan.get_state() >= ChannelState.OPENING: return True if chan.constraints.is_initiator and chan.channel_id in self.funding_created_sent: @@ -1718,7 +1721,7 @@ class Peer(Logger): chan.set_can_send_ctx_updates(True) @log_exceptions - async def _shutdown(self, chan: Channel, payload, is_local): + async def _shutdown(self, chan: Channel, payload, *, is_local: bool): # wait until no HTLCs remain in either commitment transaction while len(chan.hm.htlcs(LOCAL)) + len(chan.hm.htlcs(REMOTE)) > 0: self.logger.info(f'(chan: {chan.short_channel_id}) waiting for htlcs to settle...') @@ -1826,7 +1829,12 @@ class Peer(Logger): error_reason = e else: try: - preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet) + preimage, fw_info, error_bytes = self.process_unfulfilled_htlc( + chan=chan, + htlc=htlc, + forwarding_info=forwarding_info, + onion_packet_bytes=onion_packet_bytes, + onion_packet=onion_packet) except OnionRoutingFailure as e: error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey) if fw_info: @@ -1850,13 +1858,24 @@ class Peer(Logger): for htlc_id in done: unfulfilled.pop(htlc_id) - def process_unfulfilled_htlc(self, chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet): + def process_unfulfilled_htlc( + self, + *, + chan: Channel, + htlc: UpdateAddHtlc, + forwarding_info: Tuple[str, int], + onion_packet_bytes: bytes, + onion_packet: OnionPacket, + ) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]: """ returns either preimage or fw_info or error_bytes or (None, None, None) raise an OnionRoutingFailure if we need to fail the htlc """ payment_hash = htlc.payment_hash - processed_onion = self.process_onion_packet(onion_packet, payment_hash, onion_packet_bytes) + processed_onion = self.process_onion_packet( + onion_packet, + payment_hash=payment_hash, + onion_packet_bytes=onion_packet_bytes) if processed_onion.are_we_final: preimage = self.maybe_fulfill_htlc( chan=chan, @@ -1867,8 +1886,8 @@ class Peer(Logger): if not forwarding_info: trampoline_onion = self.process_onion_packet( processed_onion.trampoline_onion_packet, - htlc.payment_hash, - onion_packet_bytes, + payment_hash=htlc.payment_hash, + onion_packet_bytes=onion_packet_bytes, is_trampoline=True) if trampoline_onion.are_we_final: preimage = self.maybe_fulfill_htlc( @@ -1892,13 +1911,10 @@ class Peer(Logger): elif not forwarding_info: next_chan_id, next_htlc_id = self.maybe_forward_htlc( - chan=chan, htlc=htlc, - onion_packet=onion_packet, processed_onion=processed_onion) - if next_chan_id: - fw_info = (next_chan_id.hex(), next_htlc_id) - return None, fw_info, None + fw_info = (next_chan_id.hex(), next_htlc_id) + return None, fw_info, None else: preimage = self.lnworker.get_preimage(payment_hash) next_chan_id_hex, htlc_id = forwarding_info @@ -1913,7 +1929,14 @@ class Peer(Logger): return preimage, None, None return None, None, None - def process_onion_packet(self, onion_packet, payment_hash, onion_packet_bytes, is_trampoline=False): + def process_onion_packet( + self, + onion_packet: OnionPacket, + *, + payment_hash: bytes, + onion_packet_bytes: bytes, + is_trampoline: bool = False, + ) -> ProcessedOnionPacket: failure_data = sha256(onion_packet_bytes) try: processed_onion = process_onion_packet( diff --git a/electrum/lnrater.py b/electrum/lnrater.py @@ -268,7 +268,10 @@ class LNRater(Logger): return pk, self._node_stats[pk] - def suggest_peer(self): + def suggest_peer(self) -> Optional[bytes]: + """Suggests a LN node to open a channel with. + Returns a node ID (pubkey). + """ self.maybe_analyze_graph() if self._node_ratings: return self.suggest_node_channel_open()[0] diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -7,7 +7,8 @@ import os from decimal import Decimal import random import time -from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any +from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, + NamedTuple, Union, Mapping, Any, Iterable) import threading import socket import aiohttp @@ -266,10 +267,10 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]): with self.lock: return self._peers.copy() - def channels_for_peer(self, node_id): + def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]: return {} - def get_node_alias(self, node_id): + def get_node_alias(self, node_id: bytes) -> str: if self.channel_db: node_info = self.channel_db.get_node_info_for_node_id(node_id) node_alias = (node_info.alias if node_info else '') or node_id.hex() @@ -380,7 +381,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]): self._add_peer(host, int(port), bfh(pubkey)), self.network.asyncio_loop) - def is_good_peer(self, peer): + def is_good_peer(self, peer: LNPeerAddr) -> bool: # the purpose of this method is to filter peers that advertise the desired feature bits # it is disabled for now, because feature bits published in node announcements seem to be unreliable return True @@ -566,7 +567,7 @@ class LNGossip(LNWorker): self.channel_db.prune_orphaned_channels() await asyncio.sleep(120) - async def add_new_ids(self, ids): + async def add_new_ids(self, ids: Iterable[bytes]): known = self.channel_db.get_channel_ids() new = set(ids) - set(known) self.unknown_ids.update(new) @@ -574,7 +575,7 @@ class LNGossip(LNWorker): util.trigger_callback('gossip_peers', self.num_peers()) util.trigger_callback('ln_gossip_sync_progress') - def get_ids_to_query(self): + def get_ids_to_query(self) -> Sequence[bytes]: N = 500 l = list(self.unknown_ids) self.unknown_ids = set(l[N:]) @@ -910,7 +911,7 @@ class LNWallet(LNWorker): if chan.funding_outpoint.to_str() == txo: return chan - async def on_channel_update(self, chan): + async def on_channel_update(self, chan: Channel): if chan.get_state() == ChannelState.OPEN and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height()): self.logger.info(f"force-closing due to expiring htlcs") @@ -938,10 +939,14 @@ class LNWallet(LNWorker): @log_exceptions async def _open_channel_coroutine( - self, *, connect_str: str, + self, + *, + connect_str: str, funding_tx: PartialTransaction, - funding_sat: int, push_sat: int, - password: Optional[str]) -> Tuple[Channel, PartialTransaction]: + funding_sat: int, + push_sat: int, + password: Optional[str], + ) -> Tuple[Channel, PartialTransaction]: peer = await self.add_peer(connect_str) coro = peer.channel_establishment_flow( funding_tx=funding_tx, @@ -1006,7 +1011,7 @@ class LNWallet(LNWorker): if chan.short_channel_id == short_channel_id: return chan - def create_routes_from_invoice(self, amount_msat, decoded_invoice, *, full_path=None): + def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None): return self.create_routes_for_payment( amount_msat=amount_msat, invoice_pubkey=decoded_invoice.pubkey.serialize(), @@ -1051,9 +1056,16 @@ class LNWallet(LNWorker): util.trigger_callback('invoice_status', self.wallet, key) try: await self.pay_to_node( - invoice_pubkey, payment_hash, payment_secret, amount_to_pay, - min_cltv_expiry, r_tags, t_tags, invoice_features, - attempts=attempts, full_path=full_path) + node_pubkey=invoice_pubkey, + payment_hash=payment_hash, + payment_secret=payment_secret, + amount_to_pay=amount_to_pay, + min_cltv_expiry=min_cltv_expiry, + r_tags=r_tags, + t_tags=t_tags, + invoice_features=invoice_features, + attempts=attempts, + full_path=full_path) success = True except PaymentFailure as e: self.logger.exception('') @@ -1068,12 +1080,23 @@ class LNWallet(LNWorker): log = self.logs[key] return success, log - async def pay_to_node( - self, node_pubkey, payment_hash, payment_secret, amount_to_pay, - min_cltv_expiry, r_tags, t_tags, invoice_features, *, - attempts: int = 1, full_path: LNPaymentPath=None, - trampoline_onion=None, trampoline_fee=None, trampoline_cltv_delta=None): + self, + *, + node_pubkey: bytes, + payment_hash: bytes, + payment_secret: Optional[bytes], + amount_to_pay: int, # in msat + min_cltv_expiry: int, + r_tags, + t_tags, + invoice_features: int, + attempts: int = 1, + full_path: LNPaymentPath = None, + trampoline_onion=None, + trampoline_fee=None, + trampoline_cltv_delta=None, + ) -> None: if trampoline_onion: # todo: compare to the fee of the actual route we found @@ -1095,7 +1118,14 @@ class LNWallet(LNWorker): min_cltv_expiry, r_tags, t_tags, invoice_features, full_path=full_path)) # 2. send htlcs for route, amount_msat in routes: - await self.pay_to_route(route, amount_msat, amount_to_pay, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion) + await self.pay_to_route( + route, + amount_msat=amount_msat, + total_msat=amount_to_pay, + payment_hash=payment_hash, + payment_secret=payment_secret, + min_cltv_expiry=min_cltv_expiry, + trampoline_onion=trampoline_onion) amount_inflight += amount_msat util.trigger_callback('invoice_status', self.wallet, payment_hash.hex()) # 3. await a queue @@ -1111,9 +1141,17 @@ class LNWallet(LNWorker): # if we get a channel update, we might retry the same route and amount self.handle_error_code_from_failed_htlc(htlc_log) - async def pay_to_route(self, route: LNPaymentRoute, amount_msat: int, - total_msat: int, payment_hash: bytes, payment_secret: bytes, - min_cltv_expiry: int, trampoline_onion: bytes=None): + async def pay_to_route( + self, + route: LNPaymentRoute, + *, + amount_msat: int, + total_msat: int, + payment_hash: bytes, + payment_secret: Optional[bytes], + min_cltv_expiry: int, + trampoline_onion: bytes = None, + ) -> None: # send a single htlc short_channel_id = route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) @@ -1267,7 +1305,7 @@ class LNWallet(LNWorker): result.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv)) return result.tobytes() - def is_trampoline_peer(self, node_id): + def is_trampoline_peer(self, node_id: bytes) -> bool: # until trampoline is advertised in lnfeatures, check against hardcoded list if is_hardcoded_trampoline(node_id): return True @@ -1276,8 +1314,11 @@ class LNWallet(LNWorker): return True return False - def suggest_peer(self): - return self.lnrater.suggest_peer() if self.channel_db else random.choice(list(hardcoded_trampoline_nodes().values())).pubkey + def suggest_peer(self) -> Optional[bytes]: + if self.channel_db: + return self.lnrater.suggest_peer() + else: + return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey def create_trampoline_route( self, amount_msat:int, @@ -1400,8 +1441,10 @@ class LNWallet(LNWorker): invoice_pubkey, min_cltv_expiry, r_tags, t_tags, - invoice_features, - *, full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]: + invoice_features: int, + *, + full_path: LNPaymentPath = None, + ) -> Sequence[Tuple[LNPaymentRoute, int]]: """Creates multiple routes for splitting a payment over the available private channels. @@ -1411,13 +1454,14 @@ class LNWallet(LNWorker): # try to send over a single channel try: routes = [self.create_route_for_payment( - amount_msat, - invoice_pubkey, - min_cltv_expiry, - r_tags, t_tags, - invoice_features, - None, - full_path=full_path + amount_msat=amount_msat, + invoice_pubkey=invoice_pubkey, + min_cltv_expiry=min_cltv_expiry, + r_tags=r_tags, + t_tags=t_tags, + invoice_features=invoice_features, + outgoing_channel=None, + full_path=full_path, )] except NoPathFound: if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT): @@ -1439,12 +1483,13 @@ class LNWallet(LNWorker): # its capacity. This could be dealt with by temporarily # iteratively blacklisting channels for this mpp attempt. route, amt = self.create_route_for_payment( - part_amount_msat, - invoice_pubkey, - min_cltv_expiry, - r_tags, t_tags, - invoice_features, - channel, + amount_msat=part_amount_msat, + invoice_pubkey=invoice_pubkey, + min_cltv_expiry=min_cltv_expiry, + r_tags=r_tags, + t_tags=t_tags, + invoice_features=invoice_features, + outgoing_channel=channel, full_path=None) routes.append((route, amt)) self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}") @@ -1457,13 +1502,16 @@ class LNWallet(LNWorker): def create_route_for_payment( self, + *, amount_msat: int, - invoice_pubkey, - min_cltv_expiry, - r_tags, t_tags, - invoice_features, + invoice_pubkey: bytes, + min_cltv_expiry: int, + r_tags, + t_tags, + invoice_features: int, outgoing_channel: Channel = None, - *, full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]: + full_path: Optional[LNPaymentPath], + ) -> Tuple[LNPaymentRoute, int]: channels = [outgoing_channel] if outgoing_channel else list(self.channels.values()) if not self.channel_db: @@ -1554,7 +1602,13 @@ class LNWallet(LNWorker): raise Exception(_("add invoice timed out")) @log_exceptions - async def create_invoice(self, *, amount_msat: Optional[int], message, expiry: int): + async def create_invoice( + self, + *, + amount_msat: Optional[int], + message, + expiry: int, + ) -> Tuple[LnAddr, str]: timestamp = int(time.time()) routing_hints = await self._calc_routing_hints_for_invoice(amount_msat) if not routing_hints: @@ -1628,7 +1682,7 @@ class LNWallet(LNWorker): self.payments[key] = info.amount_msat, info.direction, info.status self.wallet.save_db() - def htlc_received(self, short_channel_id, htlc, expected_msat): + def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int): status = self.get_payment_status(htlc.payment_hash) if status == PR_PAID: return True, None diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -775,7 +775,13 @@ class TestPeer(ElectrumTestCase): min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() payment_hash = lnaddr.paymenthash payment_secret = lnaddr.payment_secret - pay = w1.pay_to_route(route, amount_msat, amount_msat, payment_hash, payment_secret, min_cltv_expiry) + pay = w1.pay_to_route( + route, + amount_msat=amount_msat, + total_msat=amount_msat, + payment_hash=payment_hash, + payment_secret=payment_secret, + min_cltv_expiry=min_cltv_expiry) await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) with self.assertRaises(PaymentFailure): run(f())