electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 509df9ddaf411a3fe2f6d96cee59d4ab0373de4c
parent 251db638af2f9c892a41af941b587729ba1c6771
Author: SomberNight <somber.night@protonmail.com>
Date:   Fri,  6 Sep 2019 18:09:05 +0200

create class for ShortChannelID and use it

Diffstat:
Melectrum/channel_db.py | 30++++++++++++++++--------------
Melectrum/lnchannel.py | 5+++--
Melectrum/lnonion.py | 5+++--
Melectrum/lnpeer.py | 10+++++-----
Melectrum/lnrouter.py | 28+++++++++++++++-------------
Melectrum/lnutil.py | 49++++++++++++++++++++++++++++++++++++++-----------
Melectrum/lnverifier.py | 35+++++++++++++++++------------------
Melectrum/lnworker.py | 24+++++++++++++-----------
8 files changed, 110 insertions(+), 76 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -37,7 +37,7 @@ from .sql_db import SqlDB, sql from . import constants from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .logging import Logger -from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id +from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update if TYPE_CHECKING: @@ -57,10 +57,10 @@ FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 class ChannelInfo(NamedTuple): - short_channel_id: bytes + short_channel_id: ShortChannelID node1_id: bytes node2_id: bytes - capacity_sat: int + capacity_sat: Optional[int] @staticmethod def from_msg(payload): @@ -72,10 +72,11 @@ class ChannelInfo(NamedTuple): assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] capacity_sat = None return ChannelInfo( - short_channel_id = channel_id, + short_channel_id = ShortChannelID.normalize(channel_id), node1_id = node_id_1, node2_id = node_id_2, - capacity_sat = capacity_sat) + capacity_sat = capacity_sat + ) class Policy(NamedTuple): @@ -107,8 +108,8 @@ class Policy(NamedTuple): return self.channel_flags & FLAG_DISABLE @property - def short_channel_id(self): - return self.key[0:8] + def short_channel_id(self) -> ShortChannelID: + return ShortChannelID.normalize(self.key[0:8]) @property def start_node(self): @@ -290,7 +291,7 @@ class ChannelDB(SqlDB): msg_payloads = [msg_payloads] added = 0 for msg in msg_payloads: - short_channel_id = msg['short_channel_id'] + short_channel_id = ShortChannelID(msg['short_channel_id']) if short_channel_id in self._channels: continue if constants.net.rev_genesis_bytes() != msg['chain_hash']: @@ -339,7 +340,7 @@ class ChannelDB(SqlDB): known = [] now = int(time.time()) for payload in payloads: - short_channel_id = payload['short_channel_id'] + short_channel_id = ShortChannelID(payload['short_channel_id']) timestamp = int.from_bytes(payload['timestamp'], "big") if max_age and now - timestamp > max_age: expired.append(payload) @@ -357,7 +358,7 @@ class ChannelDB(SqlDB): for payload in known: timestamp = int.from_bytes(payload['timestamp'], "big") start_node = payload['start_node'] - short_channel_id = payload['short_channel_id'] + short_channel_id = ShortChannelID(payload['short_channel_id']) key = (start_node, short_channel_id) old_policy = self._policies.get(key) if old_policy and timestamp <= old_policy.timestamp: @@ -434,11 +435,11 @@ class ChannelDB(SqlDB): def verify_channel_update(self, payload): short_channel_id = payload['short_channel_id'] - scid = format_short_channel_id(short_channel_id) + short_channel_id = ShortChannelID(short_channel_id) if constants.net.rev_genesis_bytes() != payload['chain_hash']: raise Exception('wrong chain hash') if not verify_sig_for_channel_update(payload, payload['start_node']): - raise Exception(f'failed verifying channel update for {scid}') + raise Exception(f'failed verifying channel update for {short_channel_id}') def add_node_announcement(self, msg_payloads): if type(msg_payloads) is dict: @@ -510,11 +511,11 @@ class ChannelDB(SqlDB): def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): if not verify_sig_for_channel_update(msg_payload, start_node_id): return # ignore - short_channel_id = msg_payload['short_channel_id'] + short_channel_id = ShortChannelID(msg_payload['short_channel_id']) msg_payload['start_node'] = start_node_id self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload - def remove_channel(self, short_channel_id): + def remove_channel(self, short_channel_id: ShortChannelID): channel_info = self._channels.pop(short_channel_id, None) if channel_info: self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) @@ -533,6 +534,7 @@ class ChannelDB(SqlDB): self._addresses[node_id].add((str(host), int(port), int(timestamp or 0))) c.execute("""SELECT * FROM channel_info""") for x in c: + x = (ShortChannelID.normalize(x[0]), *x[1:]) ci = ChannelInfo(*x) self._channels[ci.short_channel_id] = ci c.execute("""SELECT * FROM node_info""") diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -45,7 +45,8 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc, HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc, funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs, - ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script) + ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script, + ShortChannelID) from .lnutil import FeeUpdate from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx from .lnsweep import create_sweeptx_for_their_revoked_htlc @@ -130,7 +131,7 @@ class Channel(Logger): self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"] self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"] self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"] # type: bytes - self.short_channel_id = bfh(state["short_channel_id"]) if type(state["short_channel_id"]) not in (bytes, type(None)) else state["short_channel_id"] + self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"]) self.short_channel_id_predicted = self.short_channel_id self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {})) self.force_closed = state.get('force_closed') diff --git a/electrum/lnonion.py b/electrum/lnonion.py @@ -32,7 +32,8 @@ from Cryptodome.Cipher import ChaCha20 from . import ecc from .crypto import sha256, hmac_oneshot from .util import bh2u, profiler, xor_bytes, bfh -from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH +from .lnutil import (get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, + NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID) if TYPE_CHECKING: from .lnrouter import RouteEdge @@ -51,7 +52,7 @@ class InvalidOnionMac(Exception): pass class OnionPerHop: def __init__(self, short_channel_id: bytes, amt_to_forward: bytes, outgoing_cltv_value: bytes): - self.short_channel_id = short_channel_id + self.short_channel_id = ShortChannelID(short_channel_id) self.amt_to_forward = amt_to_forward self.outgoing_cltv_value = outgoing_cltv_value diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -41,7 +41,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, LightningPeerConnectionClosed, HandshakeFailed, NotFoundChanAnnouncementForUpdate, MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY, - NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id) + NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id, ShortChannelID) from .lnutil import FeeUpdate from .lntransport import LNTransport, LNTransportBase from .lnmsg import encode_msg, decode_msg @@ -283,7 +283,7 @@ class Peer(Logger): # as it might be for our own direct channel with this peer # (and we might not yet know the short channel id for that) for chan_upd_payload in orphaned: - short_channel_id = chan_upd_payload['short_channel_id'] + short_channel_id = ShortChannelID(chan_upd_payload['short_channel_id']) self.orphan_channel_updates[short_channel_id] = chan_upd_payload while len(self.orphan_channel_updates) > 25: self.orphan_channel_updates.popitem(last=False) @@ -959,7 +959,7 @@ class Peer(Logger): def mark_open(self, chan: Channel): assert chan.short_channel_id is not None - scid = format_short_channel_id(chan.short_channel_id) + scid = chan.short_channel_id # only allow state transition to "OPEN" from "OPENING" if chan.get_state() != "OPENING": return @@ -1096,7 +1096,7 @@ class Peer(Logger): chan = self.channels[channel_id] key = (channel_id, htlc_id) try: - route = self.attempted_route[key] + route = self.attempted_route[key] # type: List[RouteEdge] except KeyError: # the remote might try to fail an htlc after we restarted... # attempted_route is not persisted, so we will get here then @@ -1310,7 +1310,7 @@ class Peer(Logger): return dph = processed_onion.hop_data.per_hop next_chan = self.lnworker.get_channel_by_short_id(dph.short_channel_id) - next_chan_scid = format_short_channel_id(dph.short_channel_id) + next_chan_scid = dph.short_channel_id next_peer = self.lnworker.peers[next_chan.node_id] local_height = self.network.get_local_height() if next_chan is None: diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -29,7 +29,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK from .util import bh2u, profiler from .logging import Logger -from .lnutil import NUM_MAX_EDGES_IN_PAYMENT_PATH +from .lnutil import NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID from .channel_db import ChannelDB, Policy if TYPE_CHECKING: @@ -38,7 +38,8 @@ if TYPE_CHECKING: class NoChannelPolicy(Exception): def __init__(self, short_channel_id: bytes): - super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}') + short_channel_id = ShortChannelID.normalize(short_channel_id) + super().__init__(f'cannot find channel policy for short_channel_id: {short_channel_id}') def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int: @@ -46,12 +47,13 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor + (forwarded_amount_msat * fee_proportional_millionths // 1_000_000) -class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), - ('short_channel_id', bytes), - ('fee_base_msat', int), - ('fee_proportional_millionths', int), - ('cltv_expiry_delta', int)])): +class RouteEdge(NamedTuple): """if you travel through short_channel_id, you will reach node_id""" + node_id: bytes + short_channel_id: ShortChannelID + fee_base_msat: int + fee_proportional_millionths: int + cltv_expiry_delta: int def fee_for_edge(self, amount_msat: int) -> int: return fee_for_edge_msat(forwarded_amount_msat=amount_msat, @@ -61,10 +63,10 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), @classmethod def from_channel_policy(cls, channel_policy: 'Policy', short_channel_id: bytes, end_node: bytes) -> 'RouteEdge': - assert type(short_channel_id) is bytes + assert isinstance(short_channel_id, bytes) assert type(end_node) is bytes return RouteEdge(end_node, - short_channel_id, + ShortChannelID.normalize(short_channel_id), channel_policy.fee_base_msat, channel_policy.fee_proportional_millionths, channel_policy.cltv_expiry_delta) @@ -119,8 +121,8 @@ class LNPathFinder(Logger): self.channel_db = channel_db self.blacklist = set() - def add_to_blacklist(self, short_channel_id): - self.logger.info(f'blacklisting channel {bh2u(short_channel_id)}') + def add_to_blacklist(self, short_channel_id: ShortChannelID): + self.logger.info(f'blacklisting channel {short_channel_id}') self.blacklist.add(short_channel_id) def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, @@ -218,7 +220,7 @@ class LNPathFinder(Logger): # so there are duplicates in the queue, that we discard now: continue for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode): - assert type(edge_channel_id) is bytes + assert isinstance(edge_channel_id, bytes) if edge_channel_id in self.blacklist: continue channel_info = self.channel_db.get_channel_info(edge_channel_id) @@ -237,7 +239,7 @@ class LNPathFinder(Logger): return path def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]: - assert type(from_node_id) is bytes + assert isinstance(from_node_id, bytes) if path is None: raise Exception('cannot create route from None path') route = [] diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -546,17 +546,6 @@ def funding_output_script_from_keys(pubkey1: bytes, pubkey2: bytes) -> str: pubkeys = sorted([bh2u(pubkey1), bh2u(pubkey2)]) return transaction.multisig_script(pubkeys, 2) -def calc_short_channel_id(block_height: int, tx_pos_in_block: int, output_index: int) -> bytes: - bh = block_height.to_bytes(3, byteorder='big') - tpos = tx_pos_in_block.to_bytes(3, byteorder='big') - oi = output_index.to_bytes(2, byteorder='big') - return bh + tpos + oi - -def invert_short_channel_id(short_channel_id: bytes) -> (int, int, int): - bh = int.from_bytes(short_channel_id[:3], byteorder='big') - tpos = int.from_bytes(short_channel_id[3:6], byteorder='big') - oi = int.from_bytes(short_channel_id[6:8], byteorder='big') - return bh, tpos, oi def get_obscured_ctn(ctn: int, funder: bytes, fundee: bytes) -> int: mask = int.from_bytes(sha256(funder + fundee)[-6:], 'big') @@ -705,6 +694,44 @@ def generate_keypair(ln_keystore: BIP32_KeyStore, key_family: LnKeyFamily, index NUM_MAX_HOPS_IN_PAYMENT_PATH = 20 NUM_MAX_EDGES_IN_PAYMENT_PATH = NUM_MAX_HOPS_IN_PAYMENT_PATH + 1 + +class ShortChannelID(bytes): + + def __repr__(self): + return f"<ShortChannelID: {format_short_channel_id(self)}>" + + def __str__(self): + return format_short_channel_id(self) + + @classmethod + def from_components(cls, block_height: int, tx_pos_in_block: int, output_index: int) -> 'ShortChannelID': + bh = block_height.to_bytes(3, byteorder='big') + tpos = tx_pos_in_block.to_bytes(3, byteorder='big') + oi = output_index.to_bytes(2, byteorder='big') + return ShortChannelID(bh + tpos + oi) + + @classmethod + def normalize(cls, data: Union[None, str, bytes, 'ShortChannelID']) -> Optional['ShortChannelID']: + if isinstance(data, ShortChannelID) or data is None: + return data + if isinstance(data, str): + return ShortChannelID.fromhex(data) + if isinstance(data, bytes): + return ShortChannelID(data) + + @property + def block_height(self) -> int: + return int.from_bytes(self[:3], byteorder='big') + + @property + def txpos(self) -> int: + return int.from_bytes(self[3:6], byteorder='big') + + @property + def output_index(self) -> int: + return int.from_bytes(self[6:8], byteorder='big') + + def format_short_channel_id(short_channel_id: Optional[bytes]): if not short_channel_id: return _('Not yet available') diff --git a/electrum/lnverifier.py b/electrum/lnverifier.py @@ -25,7 +25,7 @@ import asyncio import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Set import aiorpcx @@ -33,7 +33,7 @@ from . import bitcoin from . import ecc from . import constants from .util import bh2u, bfh, NetworkJobOnDefaultServer -from .lnutil import invert_short_channel_id, funding_output_script_from_keys +from .lnutil import funding_output_script_from_keys, ShortChannelID from .verifier import verify_tx_is_in_block, MerkleVerificationFailure from .transaction import Transaction from .interface import GracefulDisconnect @@ -56,17 +56,16 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): NetworkJobOnDefaultServer.__init__(self, network) self.channel_db = channel_db self.lock = threading.Lock() - self.unverified_channel_info = {} # short_channel_id -> msg_payload + self.unverified_channel_info = {} # type: Dict[ShortChannelID, dict] # scid -> msg_payload # channel announcements that seem to be invalid: - self.blacklist = set() # short_channel_id + self.blacklist = set() # type: Set[ShortChannelID] def _reset(self): super()._reset() - self.started_verifying_channel = set() # short_channel_id + self.started_verifying_channel = set() # type: Set[ShortChannelID] # TODO make async; and rm self.lock completely - def add_new_channel_info(self, short_channel_id_hex, msg_payload): - short_channel_id = bfh(short_channel_id_hex) + def add_new_channel_info(self, short_channel_id: ShortChannelID, msg_payload): if short_channel_id in self.unverified_channel_info: return if short_channel_id in self.blacklist: @@ -93,7 +92,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): for short_channel_id in unverified_channel_info: if short_channel_id in self.started_verifying_channel: continue - block_height, tx_pos, output_idx = invert_short_channel_id(short_channel_id) + block_height = short_channel_id.block_height # only resolve short_channel_id if headers are available. if block_height <= 0 or block_height > local_height: continue @@ -103,16 +102,17 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): await self.group.spawn(self.network.request_chunk(block_height, None, can_return_early=True)) continue self.started_verifying_channel.add(short_channel_id) - await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id)) + await self.group.spawn(self.verify_channel(block_height, short_channel_id)) #self.logger.info(f'requested short_channel_id {bh2u(short_channel_id)}') - async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes): + async def verify_channel(self, block_height: int, short_channel_id: ShortChannelID): # we are verifying channel announcements as they are from untrusted ln peers. # we use electrum servers to do this. however we don't trust electrum servers either... try: - result = await self.network.get_txid_from_txpos(block_height, tx_pos, True) + result = await self.network.get_txid_from_txpos( + block_height, short_channel_id.txpos, True) except aiorpcx.jsonrpc.RPCError: - # the electrum server is complaining about the tx_pos for given block. + # the electrum server is complaining about the txpos for given block. # it is not clear what to do now, but let's believe the server. self._blacklist_short_channel_id(short_channel_id) return @@ -122,7 +122,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): async with self.network.bhi_lock: header = self.network.blockchain().read_header(block_height) try: - verify_tx_is_in_block(tx_hash, merkle_branch, tx_pos, header, block_height) + verify_tx_is_in_block(tx_hash, merkle_branch, short_channel_id.txpos, header, block_height) except MerkleVerificationFailure as e: # the electrum server sent an incorrect proof. blame is on server, not the ln peer raise GracefulDisconnect(e) from e @@ -151,28 +151,27 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): assert msg_type == 'channel_announcement' redeem_script = funding_output_script_from_keys(chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']) expected_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script) - output_idx = invert_short_channel_id(short_channel_id)[2] try: - actual_output = tx.outputs()[output_idx] + actual_output = tx.outputs()[short_channel_id.output_index] except IndexError: self._blacklist_short_channel_id(short_channel_id) return if expected_address != actual_output.address: # FIXME what now? best would be to ban the originating ln peer. - self.logger.info(f"funding output script mismatch for {bh2u(short_channel_id)}") + self.logger.info(f"funding output script mismatch for {short_channel_id}") self._remove_channel_from_unverified_db(short_channel_id) return # put channel into channel DB self.channel_db.add_verified_channel_info(short_channel_id, actual_output.value) self._remove_channel_from_unverified_db(short_channel_id) - def _remove_channel_from_unverified_db(self, short_channel_id: bytes): + def _remove_channel_from_unverified_db(self, short_channel_id: ShortChannelID): with self.lock: self.unverified_channel_info.pop(short_channel_id, None) try: self.started_verifying_channel.remove(short_channel_id) except KeyError: pass - def _blacklist_short_channel_id(self, short_channel_id: bytes) -> None: + def _blacklist_short_channel_id(self, short_channel_id: ShortChannelID) -> None: self.blacklist.add(short_channel_id) with self.lock: self.unverified_channel_info.pop(short_channel_id, None) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -39,13 +39,14 @@ from .ecc import der_sig_from_sig_string from .ecc_fast import is_using_fast_ecc from .lnchannel import Channel, ChannelJsonEncoder from . import lnutil -from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, +from .lnutil import (Outpoint, LNPeerAddr, get_compressed_pubkey_from_bech32, extract_nodeid, PaymentFailure, split_host_port, ConnStringFormatError, generate_keypair, LnKeyFamily, LOCAL, REMOTE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner, - UpdateAddHtlc, Direction, LnLocalFeatures, format_short_channel_id) + UpdateAddHtlc, Direction, LnLocalFeatures, format_short_channel_id, + ShortChannelID) from .i18n import _ from .lnrouter import RouteEdge, is_route_sane_to_use from .address_synchronizer import TX_HEIGHT_LOCAL @@ -553,10 +554,11 @@ class LNWallet(LNWorker): if conf > 0: block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid) assert tx_pos >= 0 - chan.short_channel_id_predicted = calc_short_channel_id(block_height, tx_pos, chan.funding_outpoint.output_index) + chan.short_channel_id_predicted = ShortChannelID.from_components( + block_height, tx_pos, chan.funding_outpoint.output_index) if conf >= chan.constraints.funding_txn_minimum_depth > 0: - self.logger.info(f"save_short_channel_id") chan.short_channel_id = chan.short_channel_id_predicted + self.logger.info(f"save_short_channel_id: {chan.short_channel_id}") self.save_channel(chan) self.on_channels_updated() else: @@ -795,7 +797,7 @@ class LNWallet(LNWorker): else: self.network.trigger_callback('payment_status', key, 'failure') - def get_channel_by_short_id(self, short_channel_id): + def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel: with self.lock: for chan in self.channels.values(): if chan.short_channel_id == short_channel_id: @@ -815,7 +817,7 @@ class LNWallet(LNWorker): for i in range(attempts): route = await self._create_route_from_invoice(decoded_invoice=addr) if not self.get_channel_by_short_id(route[0].short_channel_id): - scid = format_short_channel_id(route[0].short_channel_id) + scid = route[0].short_channel_id raise Exception(f"Got route with unknown first channel: {scid}") self.network.trigger_callback('payment_status', key, 'progress', i) if await self._pay_to_route(route, addr, invoice): @@ -826,8 +828,8 @@ class LNWallet(LNWorker): short_channel_id = route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) if not chan: - scid = format_short_channel_id(short_channel_id) - raise Exception(f"PathFinder returned path with short_channel_id {scid} that is not in channel list") + raise Exception(f"PathFinder returned path with short_channel_id " + f"{short_channel_id} that is not in channel list") peer = self.peers[route[0].node_id] htlc = await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry()) self.network.trigger_callback('htlc_added', htlc, addr, SENT) @@ -879,6 +881,7 @@ class LNWallet(LNWorker): prev_node_id = border_node_pubkey for node_pubkey, edge_rest in zip(private_route_nodes, private_route_rest): short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest + short_channel_id = ShortChannelID(short_channel_id) # if we have a routing policy for this edge in the db, that takes precedence, # as it is likely from a previous failure channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) @@ -1030,7 +1033,7 @@ class LNWallet(LNWorker): if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: continue chan_id = chan.short_channel_id - assert type(chan_id) is bytes, chan_id + assert isinstance(chan_id, bytes), chan_id channel_info = self.channel_db.get_channel_info(chan_id) # note: as a fallback, if we don't have a channel update for the # incoming direction of our private channel, we fill the invoice with garbage. @@ -1048,8 +1051,7 @@ class LNWallet(LNWorker): cltv_expiry_delta = policy.cltv_expiry_delta missing_info = False if missing_info: - scid = format_short_channel_id(chan_id) - self.logger.info(f"Warning. Missing channel update for our channel {scid}; " + self.logger.info(f"Warning. Missing channel update for our channel {chan_id}; " f"filling invoice with incorrect data.") routing_hints.append(('r', [(chan.node_id, chan_id,