commit 691ebaf4f816b45ca10eabceae3068b7465e6bc5
parent d800f88bfcdc83a2cb9d359791c0c7e8cf6fcaff
Author: SomberNight <somber.night@protonmail.com>
Date: Wed, 24 Feb 2021 20:03:12 +0100
lnworker/lnpeer: add some type hints, force some kwargs
Diffstat:
5 files changed, 160 insertions(+), 71 deletions(-)
diff --git a/electrum/lnonion.py b/electrum/lnonion.py
@@ -437,9 +437,12 @@ class OnionRoutingFailure(Exception):
return str(self.code.name)
return f"Unknown error ({self.code!r})"
-def construct_onion_error(reason: OnionRoutingFailure,
- onion_packet: OnionPacket,
- our_onion_private_key: bytes) -> bytes:
+
+def construct_onion_error(
+ reason: OnionRoutingFailure,
+ onion_packet: OnionPacket,
+ our_onion_private_key: bytes,
+) -> bytes:
# create payload
failure_msg = reason.to_bytes()
failure_len = len(failure_msg)
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -1373,9 +1373,12 @@ class Peer(Logger):
chan.receive_htlc(htlc, onion_packet)
util.trigger_callback('htlc_added', chan, htlc, RECEIVED)
- def maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
- onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket
- ) -> Tuple[Optional[bytes], Optional[int], Optional[OnionRoutingFailure]]:
+ def maybe_forward_htlc(
+ self,
+ *,
+ htlc: UpdateAddHtlc,
+ processed_onion: ProcessedOnionPacket,
+ ) -> Tuple[bytes, int]:
# Forward HTLC
# FIXME: there are critical safety checks MISSING here
forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
@@ -1662,7 +1665,7 @@ class Peer(Logger):
self.shutdown_received[chan_id] = asyncio.Future()
await self.send_shutdown(chan)
payload = await self.shutdown_received[chan_id]
- txid = await self._shutdown(chan, payload, True)
+ txid = await self._shutdown(chan, payload, is_local=True)
self.logger.info(f'({chan.get_id_for_log()}) Channel closed {txid}')
return txid
@@ -1686,10 +1689,10 @@ class Peer(Logger):
else:
chan = self.channels[chan_id]
await self.send_shutdown(chan)
- txid = await self._shutdown(chan, payload, False)
+ txid = await self._shutdown(chan, payload, is_local=False)
self.logger.info(f'({chan.get_id_for_log()}) Channel closed by remote peer {txid}')
- def can_send_shutdown(self, chan):
+ def can_send_shutdown(self, chan: Channel):
if chan.get_state() >= ChannelState.OPENING:
return True
if chan.constraints.is_initiator and chan.channel_id in self.funding_created_sent:
@@ -1718,7 +1721,7 @@ class Peer(Logger):
chan.set_can_send_ctx_updates(True)
@log_exceptions
- async def _shutdown(self, chan: Channel, payload, is_local):
+ async def _shutdown(self, chan: Channel, payload, *, is_local: bool):
# wait until no HTLCs remain in either commitment transaction
while len(chan.hm.htlcs(LOCAL)) + len(chan.hm.htlcs(REMOTE)) > 0:
self.logger.info(f'(chan: {chan.short_channel_id}) waiting for htlcs to settle...')
@@ -1826,7 +1829,12 @@ class Peer(Logger):
error_reason = e
else:
try:
- preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet)
+ preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(
+ chan=chan,
+ htlc=htlc,
+ forwarding_info=forwarding_info,
+ onion_packet_bytes=onion_packet_bytes,
+ onion_packet=onion_packet)
except OnionRoutingFailure as e:
error_bytes = construct_onion_error(e, onion_packet, our_onion_private_key=self.privkey)
if fw_info:
@@ -1850,13 +1858,24 @@ class Peer(Logger):
for htlc_id in done:
unfulfilled.pop(htlc_id)
- def process_unfulfilled_htlc(self, chan, htlc_id, htlc, forwarding_info, onion_packet_bytes, onion_packet):
+ def process_unfulfilled_htlc(
+ self,
+ *,
+ chan: Channel,
+ htlc: UpdateAddHtlc,
+ forwarding_info: Tuple[str, int],
+ onion_packet_bytes: bytes,
+ onion_packet: OnionPacket,
+ ) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]:
"""
returns either preimage or fw_info or error_bytes or (None, None, None)
raise an OnionRoutingFailure if we need to fail the htlc
"""
payment_hash = htlc.payment_hash
- processed_onion = self.process_onion_packet(onion_packet, payment_hash, onion_packet_bytes)
+ processed_onion = self.process_onion_packet(
+ onion_packet,
+ payment_hash=payment_hash,
+ onion_packet_bytes=onion_packet_bytes)
if processed_onion.are_we_final:
preimage = self.maybe_fulfill_htlc(
chan=chan,
@@ -1867,8 +1886,8 @@ class Peer(Logger):
if not forwarding_info:
trampoline_onion = self.process_onion_packet(
processed_onion.trampoline_onion_packet,
- htlc.payment_hash,
- onion_packet_bytes,
+ payment_hash=htlc.payment_hash,
+ onion_packet_bytes=onion_packet_bytes,
is_trampoline=True)
if trampoline_onion.are_we_final:
preimage = self.maybe_fulfill_htlc(
@@ -1892,13 +1911,10 @@ class Peer(Logger):
elif not forwarding_info:
next_chan_id, next_htlc_id = self.maybe_forward_htlc(
- chan=chan,
htlc=htlc,
- onion_packet=onion_packet,
processed_onion=processed_onion)
- if next_chan_id:
- fw_info = (next_chan_id.hex(), next_htlc_id)
- return None, fw_info, None
+ fw_info = (next_chan_id.hex(), next_htlc_id)
+ return None, fw_info, None
else:
preimage = self.lnworker.get_preimage(payment_hash)
next_chan_id_hex, htlc_id = forwarding_info
@@ -1913,7 +1929,14 @@ class Peer(Logger):
return preimage, None, None
return None, None, None
- def process_onion_packet(self, onion_packet, payment_hash, onion_packet_bytes, is_trampoline=False):
+ def process_onion_packet(
+ self,
+ onion_packet: OnionPacket,
+ *,
+ payment_hash: bytes,
+ onion_packet_bytes: bytes,
+ is_trampoline: bool = False,
+ ) -> ProcessedOnionPacket:
failure_data = sha256(onion_packet_bytes)
try:
processed_onion = process_onion_packet(
diff --git a/electrum/lnrater.py b/electrum/lnrater.py
@@ -268,7 +268,10 @@ class LNRater(Logger):
return pk, self._node_stats[pk]
- def suggest_peer(self):
+ def suggest_peer(self) -> Optional[bytes]:
+ """Suggests a LN node to open a channel with.
+ Returns a node ID (pubkey).
+ """
self.maybe_analyze_graph()
if self._node_ratings:
return self.suggest_node_channel_open()[0]
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -7,7 +7,8 @@ import os
from decimal import Decimal
import random
import time
-from typing import Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any
+from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
+ NamedTuple, Union, Mapping, Any, Iterable)
import threading
import socket
import aiohttp
@@ -266,10 +267,10 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
with self.lock:
return self._peers.copy()
- def channels_for_peer(self, node_id):
+ def channels_for_peer(self, node_id: bytes) -> Dict[bytes, Channel]:
return {}
- def get_node_alias(self, node_id):
+ def get_node_alias(self, node_id: bytes) -> str:
if self.channel_db:
node_info = self.channel_db.get_node_info_for_node_id(node_id)
node_alias = (node_info.alias if node_info else '') or node_id.hex()
@@ -380,7 +381,7 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self._add_peer(host, int(port), bfh(pubkey)),
self.network.asyncio_loop)
- def is_good_peer(self, peer):
+ def is_good_peer(self, peer: LNPeerAddr) -> bool:
# the purpose of this method is to filter peers that advertise the desired feature bits
# it is disabled for now, because feature bits published in node announcements seem to be unreliable
return True
@@ -566,7 +567,7 @@ class LNGossip(LNWorker):
self.channel_db.prune_orphaned_channels()
await asyncio.sleep(120)
- async def add_new_ids(self, ids):
+ async def add_new_ids(self, ids: Iterable[bytes]):
known = self.channel_db.get_channel_ids()
new = set(ids) - set(known)
self.unknown_ids.update(new)
@@ -574,7 +575,7 @@ class LNGossip(LNWorker):
util.trigger_callback('gossip_peers', self.num_peers())
util.trigger_callback('ln_gossip_sync_progress')
- def get_ids_to_query(self):
+ def get_ids_to_query(self) -> Sequence[bytes]:
N = 500
l = list(self.unknown_ids)
self.unknown_ids = set(l[N:])
@@ -910,7 +911,7 @@ class LNWallet(LNWorker):
if chan.funding_outpoint.to_str() == txo:
return chan
- async def on_channel_update(self, chan):
+ async def on_channel_update(self, chan: Channel):
if chan.get_state() == ChannelState.OPEN and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height()):
self.logger.info(f"force-closing due to expiring htlcs")
@@ -938,10 +939,14 @@ class LNWallet(LNWorker):
@log_exceptions
async def _open_channel_coroutine(
- self, *, connect_str: str,
+ self,
+ *,
+ connect_str: str,
funding_tx: PartialTransaction,
- funding_sat: int, push_sat: int,
- password: Optional[str]) -> Tuple[Channel, PartialTransaction]:
+ funding_sat: int,
+ push_sat: int,
+ password: Optional[str],
+ ) -> Tuple[Channel, PartialTransaction]:
peer = await self.add_peer(connect_str)
coro = peer.channel_establishment_flow(
funding_tx=funding_tx,
@@ -1006,7 +1011,7 @@ class LNWallet(LNWorker):
if chan.short_channel_id == short_channel_id:
return chan
- def create_routes_from_invoice(self, amount_msat, decoded_invoice, *, full_path=None):
+ def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
return self.create_routes_for_payment(
amount_msat=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
@@ -1051,9 +1056,16 @@ class LNWallet(LNWorker):
util.trigger_callback('invoice_status', self.wallet, key)
try:
await self.pay_to_node(
- invoice_pubkey, payment_hash, payment_secret, amount_to_pay,
- min_cltv_expiry, r_tags, t_tags, invoice_features,
- attempts=attempts, full_path=full_path)
+ node_pubkey=invoice_pubkey,
+ payment_hash=payment_hash,
+ payment_secret=payment_secret,
+ amount_to_pay=amount_to_pay,
+ min_cltv_expiry=min_cltv_expiry,
+ r_tags=r_tags,
+ t_tags=t_tags,
+ invoice_features=invoice_features,
+ attempts=attempts,
+ full_path=full_path)
success = True
except PaymentFailure as e:
self.logger.exception('')
@@ -1068,12 +1080,23 @@ class LNWallet(LNWorker):
log = self.logs[key]
return success, log
-
async def pay_to_node(
- self, node_pubkey, payment_hash, payment_secret, amount_to_pay,
- min_cltv_expiry, r_tags, t_tags, invoice_features, *,
- attempts: int = 1, full_path: LNPaymentPath=None,
- trampoline_onion=None, trampoline_fee=None, trampoline_cltv_delta=None):
+ self,
+ *,
+ node_pubkey: bytes,
+ payment_hash: bytes,
+ payment_secret: Optional[bytes],
+ amount_to_pay: int, # in msat
+ min_cltv_expiry: int,
+ r_tags,
+ t_tags,
+ invoice_features: int,
+ attempts: int = 1,
+ full_path: LNPaymentPath = None,
+ trampoline_onion=None,
+ trampoline_fee=None,
+ trampoline_cltv_delta=None,
+ ) -> None:
if trampoline_onion:
# todo: compare to the fee of the actual route we found
@@ -1095,7 +1118,14 @@ class LNWallet(LNWorker):
min_cltv_expiry, r_tags, t_tags, invoice_features, full_path=full_path))
# 2. send htlcs
for route, amount_msat in routes:
- await self.pay_to_route(route, amount_msat, amount_to_pay, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion)
+ await self.pay_to_route(
+ route,
+ amount_msat=amount_msat,
+ total_msat=amount_to_pay,
+ payment_hash=payment_hash,
+ payment_secret=payment_secret,
+ min_cltv_expiry=min_cltv_expiry,
+ trampoline_onion=trampoline_onion)
amount_inflight += amount_msat
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex())
# 3. await a queue
@@ -1111,9 +1141,17 @@ class LNWallet(LNWorker):
# if we get a channel update, we might retry the same route and amount
self.handle_error_code_from_failed_htlc(htlc_log)
- async def pay_to_route(self, route: LNPaymentRoute, amount_msat: int,
- total_msat: int, payment_hash: bytes, payment_secret: bytes,
- min_cltv_expiry: int, trampoline_onion: bytes=None):
+ async def pay_to_route(
+ self,
+ route: LNPaymentRoute,
+ *,
+ amount_msat: int,
+ total_msat: int,
+ payment_hash: bytes,
+ payment_secret: Optional[bytes],
+ min_cltv_expiry: int,
+ trampoline_onion: bytes = None,
+ ) -> None:
# send a single htlc
short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
@@ -1267,7 +1305,7 @@ class LNWallet(LNWorker):
result.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv))
return result.tobytes()
- def is_trampoline_peer(self, node_id):
+ def is_trampoline_peer(self, node_id: bytes) -> bool:
# until trampoline is advertised in lnfeatures, check against hardcoded list
if is_hardcoded_trampoline(node_id):
return True
@@ -1276,8 +1314,11 @@ class LNWallet(LNWorker):
return True
return False
- def suggest_peer(self):
- return self.lnrater.suggest_peer() if self.channel_db else random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
+ def suggest_peer(self) -> Optional[bytes]:
+ if self.channel_db:
+ return self.lnrater.suggest_peer()
+ else:
+ return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey
def create_trampoline_route(
self, amount_msat:int,
@@ -1400,8 +1441,10 @@ class LNWallet(LNWorker):
invoice_pubkey,
min_cltv_expiry,
r_tags, t_tags,
- invoice_features,
- *, full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]:
+ invoice_features: int,
+ *,
+ full_path: LNPaymentPath = None,
+ ) -> Sequence[Tuple[LNPaymentRoute, int]]:
"""Creates multiple routes for splitting a payment over the available
private channels.
@@ -1411,13 +1454,14 @@ class LNWallet(LNWorker):
# try to send over a single channel
try:
routes = [self.create_route_for_payment(
- amount_msat,
- invoice_pubkey,
- min_cltv_expiry,
- r_tags, t_tags,
- invoice_features,
- None,
- full_path=full_path
+ amount_msat=amount_msat,
+ invoice_pubkey=invoice_pubkey,
+ min_cltv_expiry=min_cltv_expiry,
+ r_tags=r_tags,
+ t_tags=t_tags,
+ invoice_features=invoice_features,
+ outgoing_channel=None,
+ full_path=full_path,
)]
except NoPathFound:
if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
@@ -1439,12 +1483,13 @@ class LNWallet(LNWorker):
# its capacity. This could be dealt with by temporarily
# iteratively blacklisting channels for this mpp attempt.
route, amt = self.create_route_for_payment(
- part_amount_msat,
- invoice_pubkey,
- min_cltv_expiry,
- r_tags, t_tags,
- invoice_features,
- channel,
+ amount_msat=part_amount_msat,
+ invoice_pubkey=invoice_pubkey,
+ min_cltv_expiry=min_cltv_expiry,
+ r_tags=r_tags,
+ t_tags=t_tags,
+ invoice_features=invoice_features,
+ outgoing_channel=channel,
full_path=None)
routes.append((route, amt))
self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}")
@@ -1457,13 +1502,16 @@ class LNWallet(LNWorker):
def create_route_for_payment(
self,
+ *,
amount_msat: int,
- invoice_pubkey,
- min_cltv_expiry,
- r_tags, t_tags,
- invoice_features,
+ invoice_pubkey: bytes,
+ min_cltv_expiry: int,
+ r_tags,
+ t_tags,
+ invoice_features: int,
outgoing_channel: Channel = None,
- *, full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]:
+ full_path: Optional[LNPaymentPath],
+ ) -> Tuple[LNPaymentRoute, int]:
channels = [outgoing_channel] if outgoing_channel else list(self.channels.values())
if not self.channel_db:
@@ -1554,7 +1602,13 @@ class LNWallet(LNWorker):
raise Exception(_("add invoice timed out"))
@log_exceptions
- async def create_invoice(self, *, amount_msat: Optional[int], message, expiry: int):
+ async def create_invoice(
+ self,
+ *,
+ amount_msat: Optional[int],
+ message,
+ expiry: int,
+ ) -> Tuple[LnAddr, str]:
timestamp = int(time.time())
routing_hints = await self._calc_routing_hints_for_invoice(amount_msat)
if not routing_hints:
@@ -1628,7 +1682,7 @@ class LNWallet(LNWorker):
self.payments[key] = info.amount_msat, info.direction, info.status
self.wallet.save_db()
- def htlc_received(self, short_channel_id, htlc, expected_msat):
+ def htlc_received(self, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int):
status = self.get_payment_status(htlc.payment_hash)
if status == PR_PAID:
return True, None
diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
@@ -775,7 +775,13 @@ class TestPeer(ElectrumTestCase):
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
payment_hash = lnaddr.paymenthash
payment_secret = lnaddr.payment_secret
- pay = w1.pay_to_route(route, amount_msat, amount_msat, payment_hash, payment_secret, min_cltv_expiry)
+ pay = w1.pay_to_route(
+ route,
+ amount_msat=amount_msat,
+ total_msat=amount_msat,
+ payment_hash=payment_hash,
+ payment_secret=payment_secret,
+ min_cltv_expiry=min_cltv_expiry)
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(PaymentFailure):
run(f())