commit 509df9ddaf411a3fe2f6d96cee59d4ab0373de4c
parent 251db638af2f9c892a41af941b587729ba1c6771
Author: SomberNight <somber.night@protonmail.com>
Date: Fri, 6 Sep 2019 18:09:05 +0200
create class for ShortChannelID and use it
Diffstat:
8 files changed, 110 insertions(+), 76 deletions(-)
diff --git a/electrum/channel_db.py b/electrum/channel_db.py
@@ -37,7 +37,7 @@ from .sql_db import SqlDB, sql
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .logging import Logger
-from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id
+from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
if TYPE_CHECKING:
@@ -57,10 +57,10 @@ FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
class ChannelInfo(NamedTuple):
- short_channel_id: bytes
+ short_channel_id: ShortChannelID
node1_id: bytes
node2_id: bytes
- capacity_sat: int
+ capacity_sat: Optional[int]
@staticmethod
def from_msg(payload):
@@ -72,10 +72,11 @@ class ChannelInfo(NamedTuple):
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
capacity_sat = None
return ChannelInfo(
- short_channel_id = channel_id,
+ short_channel_id = ShortChannelID.normalize(channel_id),
node1_id = node_id_1,
node2_id = node_id_2,
- capacity_sat = capacity_sat)
+ capacity_sat = capacity_sat
+ )
class Policy(NamedTuple):
@@ -107,8 +108,8 @@ class Policy(NamedTuple):
return self.channel_flags & FLAG_DISABLE
@property
- def short_channel_id(self):
- return self.key[0:8]
+ def short_channel_id(self) -> ShortChannelID:
+ return ShortChannelID.normalize(self.key[0:8])
@property
def start_node(self):
@@ -290,7 +291,7 @@ class ChannelDB(SqlDB):
msg_payloads = [msg_payloads]
added = 0
for msg in msg_payloads:
- short_channel_id = msg['short_channel_id']
+ short_channel_id = ShortChannelID(msg['short_channel_id'])
if short_channel_id in self._channels:
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
@@ -339,7 +340,7 @@ class ChannelDB(SqlDB):
known = []
now = int(time.time())
for payload in payloads:
- short_channel_id = payload['short_channel_id']
+ short_channel_id = ShortChannelID(payload['short_channel_id'])
timestamp = int.from_bytes(payload['timestamp'], "big")
if max_age and now - timestamp > max_age:
expired.append(payload)
@@ -357,7 +358,7 @@ class ChannelDB(SqlDB):
for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node']
- short_channel_id = payload['short_channel_id']
+ short_channel_id = ShortChannelID(payload['short_channel_id'])
key = (start_node, short_channel_id)
old_policy = self._policies.get(key)
if old_policy and timestamp <= old_policy.timestamp:
@@ -434,11 +435,11 @@ class ChannelDB(SqlDB):
def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id']
- scid = format_short_channel_id(short_channel_id)
+ short_channel_id = ShortChannelID(short_channel_id)
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['start_node']):
- raise Exception(f'failed verifying channel update for {scid}')
+ raise Exception(f'failed verifying channel update for {short_channel_id}')
def add_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
@@ -510,11 +511,11 @@ class ChannelDB(SqlDB):
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore
- short_channel_id = msg_payload['short_channel_id']
+ short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
- def remove_channel(self, short_channel_id):
+ def remove_channel(self, short_channel_id: ShortChannelID):
channel_info = self._channels.pop(short_channel_id, None)
if channel_info:
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
@@ -533,6 +534,7 @@ class ChannelDB(SqlDB):
self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
c.execute("""SELECT * FROM channel_info""")
for x in c:
+ x = (ShortChannelID.normalize(x[0]), *x[1:])
ci = ChannelInfo(*x)
self._channels[ci.short_channel_id] = ci
c.execute("""SELECT * FROM node_info""")
diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
@@ -45,7 +45,8 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc,
HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc,
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
- ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script)
+ ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
+ ShortChannelID)
from .lnutil import FeeUpdate
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
from .lnsweep import create_sweeptx_for_their_revoked_htlc
@@ -130,7 +131,7 @@ class Channel(Logger):
self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"]
self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"]
self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"] # type: bytes
- self.short_channel_id = bfh(state["short_channel_id"]) if type(state["short_channel_id"]) not in (bytes, type(None)) else state["short_channel_id"]
+ self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
self.short_channel_id_predicted = self.short_channel_id
self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {}))
self.force_closed = state.get('force_closed')
diff --git a/electrum/lnonion.py b/electrum/lnonion.py
@@ -32,7 +32,8 @@ from Cryptodome.Cipher import ChaCha20
from . import ecc
from .crypto import sha256, hmac_oneshot
from .util import bh2u, profiler, xor_bytes, bfh
-from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH
+from .lnutil import (get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH,
+ NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID)
if TYPE_CHECKING:
from .lnrouter import RouteEdge
@@ -51,7 +52,7 @@ class InvalidOnionMac(Exception): pass
class OnionPerHop:
def __init__(self, short_channel_id: bytes, amt_to_forward: bytes, outgoing_cltv_value: bytes):
- self.short_channel_id = short_channel_id
+ self.short_channel_id = ShortChannelID(short_channel_id)
self.amt_to_forward = amt_to_forward
self.outgoing_cltv_value = outgoing_cltv_value
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -41,7 +41,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
LightningPeerConnectionClosed, HandshakeFailed, NotFoundChanAnnouncementForUpdate,
MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED,
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY,
- NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id)
+ NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id, ShortChannelID)
from .lnutil import FeeUpdate
from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg
@@ -283,7 +283,7 @@ class Peer(Logger):
# as it might be for our own direct channel with this peer
# (and we might not yet know the short channel id for that)
for chan_upd_payload in orphaned:
- short_channel_id = chan_upd_payload['short_channel_id']
+ short_channel_id = ShortChannelID(chan_upd_payload['short_channel_id'])
self.orphan_channel_updates[short_channel_id] = chan_upd_payload
while len(self.orphan_channel_updates) > 25:
self.orphan_channel_updates.popitem(last=False)
@@ -959,7 +959,7 @@ class Peer(Logger):
def mark_open(self, chan: Channel):
assert chan.short_channel_id is not None
- scid = format_short_channel_id(chan.short_channel_id)
+ scid = chan.short_channel_id
# only allow state transition to "OPEN" from "OPENING"
if chan.get_state() != "OPENING":
return
@@ -1096,7 +1096,7 @@ class Peer(Logger):
chan = self.channels[channel_id]
key = (channel_id, htlc_id)
try:
- route = self.attempted_route[key]
+ route = self.attempted_route[key] # type: List[RouteEdge]
except KeyError:
# the remote might try to fail an htlc after we restarted...
# attempted_route is not persisted, so we will get here then
@@ -1310,7 +1310,7 @@ class Peer(Logger):
return
dph = processed_onion.hop_data.per_hop
next_chan = self.lnworker.get_channel_by_short_id(dph.short_channel_id)
- next_chan_scid = format_short_channel_id(dph.short_channel_id)
+ next_chan_scid = dph.short_channel_id
next_peer = self.lnworker.peers[next_chan.node_id]
local_height = self.network.get_local_height()
if next_chan is None:
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -29,7 +29,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
from .util import bh2u, profiler
from .logging import Logger
-from .lnutil import NUM_MAX_EDGES_IN_PAYMENT_PATH
+from .lnutil import NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID
from .channel_db import ChannelDB, Policy
if TYPE_CHECKING:
@@ -38,7 +38,8 @@ if TYPE_CHECKING:
class NoChannelPolicy(Exception):
def __init__(self, short_channel_id: bytes):
- super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
+ short_channel_id = ShortChannelID.normalize(short_channel_id)
+ super().__init__(f'cannot find channel policy for short_channel_id: {short_channel_id}')
def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
@@ -46,12 +47,13 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor
+ (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
-class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
- ('short_channel_id', bytes),
- ('fee_base_msat', int),
- ('fee_proportional_millionths', int),
- ('cltv_expiry_delta', int)])):
+class RouteEdge(NamedTuple):
"""if you travel through short_channel_id, you will reach node_id"""
+ node_id: bytes
+ short_channel_id: ShortChannelID
+ fee_base_msat: int
+ fee_proportional_millionths: int
+ cltv_expiry_delta: int
def fee_for_edge(self, amount_msat: int) -> int:
return fee_for_edge_msat(forwarded_amount_msat=amount_msat,
@@ -61,10 +63,10 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
@classmethod
def from_channel_policy(cls, channel_policy: 'Policy',
short_channel_id: bytes, end_node: bytes) -> 'RouteEdge':
- assert type(short_channel_id) is bytes
+ assert isinstance(short_channel_id, bytes)
assert type(end_node) is bytes
return RouteEdge(end_node,
- short_channel_id,
+ ShortChannelID.normalize(short_channel_id),
channel_policy.fee_base_msat,
channel_policy.fee_proportional_millionths,
channel_policy.cltv_expiry_delta)
@@ -119,8 +121,8 @@ class LNPathFinder(Logger):
self.channel_db = channel_db
self.blacklist = set()
- def add_to_blacklist(self, short_channel_id):
- self.logger.info(f'blacklisting channel {bh2u(short_channel_id)}')
+ def add_to_blacklist(self, short_channel_id: ShortChannelID):
+ self.logger.info(f'blacklisting channel {short_channel_id}')
self.blacklist.add(short_channel_id)
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
@@ -218,7 +220,7 @@ class LNPathFinder(Logger):
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
- assert type(edge_channel_id) is bytes
+ assert isinstance(edge_channel_id, bytes)
if edge_channel_id in self.blacklist:
continue
channel_info = self.channel_db.get_channel_info(edge_channel_id)
@@ -237,7 +239,7 @@ class LNPathFinder(Logger):
return path
def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:
- assert type(from_node_id) is bytes
+ assert isinstance(from_node_id, bytes)
if path is None:
raise Exception('cannot create route from None path')
route = []
diff --git a/electrum/lnutil.py b/electrum/lnutil.py
@@ -546,17 +546,6 @@ def funding_output_script_from_keys(pubkey1: bytes, pubkey2: bytes) -> str:
pubkeys = sorted([bh2u(pubkey1), bh2u(pubkey2)])
return transaction.multisig_script(pubkeys, 2)
-def calc_short_channel_id(block_height: int, tx_pos_in_block: int, output_index: int) -> bytes:
- bh = block_height.to_bytes(3, byteorder='big')
- tpos = tx_pos_in_block.to_bytes(3, byteorder='big')
- oi = output_index.to_bytes(2, byteorder='big')
- return bh + tpos + oi
-
-def invert_short_channel_id(short_channel_id: bytes) -> (int, int, int):
- bh = int.from_bytes(short_channel_id[:3], byteorder='big')
- tpos = int.from_bytes(short_channel_id[3:6], byteorder='big')
- oi = int.from_bytes(short_channel_id[6:8], byteorder='big')
- return bh, tpos, oi
def get_obscured_ctn(ctn: int, funder: bytes, fundee: bytes) -> int:
mask = int.from_bytes(sha256(funder + fundee)[-6:], 'big')
@@ -705,6 +694,44 @@ def generate_keypair(ln_keystore: BIP32_KeyStore, key_family: LnKeyFamily, index
NUM_MAX_HOPS_IN_PAYMENT_PATH = 20
NUM_MAX_EDGES_IN_PAYMENT_PATH = NUM_MAX_HOPS_IN_PAYMENT_PATH + 1
+
+class ShortChannelID(bytes):
+
+ def __repr__(self):
+ return f"<ShortChannelID: {format_short_channel_id(self)}>"
+
+ def __str__(self):
+ return format_short_channel_id(self)
+
+ @classmethod
+ def from_components(cls, block_height: int, tx_pos_in_block: int, output_index: int) -> 'ShortChannelID':
+ bh = block_height.to_bytes(3, byteorder='big')
+ tpos = tx_pos_in_block.to_bytes(3, byteorder='big')
+ oi = output_index.to_bytes(2, byteorder='big')
+ return ShortChannelID(bh + tpos + oi)
+
+ @classmethod
+ def normalize(cls, data: Union[None, str, bytes, 'ShortChannelID']) -> Optional['ShortChannelID']:
+ if isinstance(data, ShortChannelID) or data is None:
+ return data
+ if isinstance(data, str):
+ return ShortChannelID.fromhex(data)
+ if isinstance(data, bytes):
+ return ShortChannelID(data)
+
+ @property
+ def block_height(self) -> int:
+ return int.from_bytes(self[:3], byteorder='big')
+
+ @property
+ def txpos(self) -> int:
+ return int.from_bytes(self[3:6], byteorder='big')
+
+ @property
+ def output_index(self) -> int:
+ return int.from_bytes(self[6:8], byteorder='big')
+
+
def format_short_channel_id(short_channel_id: Optional[bytes]):
if not short_channel_id:
return _('Not yet available')
diff --git a/electrum/lnverifier.py b/electrum/lnverifier.py
@@ -25,7 +25,7 @@
import asyncio
import threading
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, Set
import aiorpcx
@@ -33,7 +33,7 @@ from . import bitcoin
from . import ecc
from . import constants
from .util import bh2u, bfh, NetworkJobOnDefaultServer
-from .lnutil import invert_short_channel_id, funding_output_script_from_keys
+from .lnutil import funding_output_script_from_keys, ShortChannelID
from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
from .transaction import Transaction
from .interface import GracefulDisconnect
@@ -56,17 +56,16 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
NetworkJobOnDefaultServer.__init__(self, network)
self.channel_db = channel_db
self.lock = threading.Lock()
- self.unverified_channel_info = {} # short_channel_id -> msg_payload
+ self.unverified_channel_info = {} # type: Dict[ShortChannelID, dict] # scid -> msg_payload
# channel announcements that seem to be invalid:
- self.blacklist = set() # short_channel_id
+ self.blacklist = set() # type: Set[ShortChannelID]
def _reset(self):
super()._reset()
- self.started_verifying_channel = set() # short_channel_id
+ self.started_verifying_channel = set() # type: Set[ShortChannelID]
# TODO make async; and rm self.lock completely
- def add_new_channel_info(self, short_channel_id_hex, msg_payload):
- short_channel_id = bfh(short_channel_id_hex)
+ def add_new_channel_info(self, short_channel_id: ShortChannelID, msg_payload):
if short_channel_id in self.unverified_channel_info:
return
if short_channel_id in self.blacklist:
@@ -93,7 +92,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
for short_channel_id in unverified_channel_info:
if short_channel_id in self.started_verifying_channel:
continue
- block_height, tx_pos, output_idx = invert_short_channel_id(short_channel_id)
+ block_height = short_channel_id.block_height
# only resolve short_channel_id if headers are available.
if block_height <= 0 or block_height > local_height:
continue
@@ -103,16 +102,17 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
await self.group.spawn(self.network.request_chunk(block_height, None, can_return_early=True))
continue
self.started_verifying_channel.add(short_channel_id)
- await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
+ await self.group.spawn(self.verify_channel(block_height, short_channel_id))
#self.logger.info(f'requested short_channel_id {bh2u(short_channel_id)}')
- async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes):
+ async def verify_channel(self, block_height: int, short_channel_id: ShortChannelID):
# we are verifying channel announcements as they are from untrusted ln peers.
# we use electrum servers to do this. however we don't trust electrum servers either...
try:
- result = await self.network.get_txid_from_txpos(block_height, tx_pos, True)
+ result = await self.network.get_txid_from_txpos(
+ block_height, short_channel_id.txpos, True)
except aiorpcx.jsonrpc.RPCError:
- # the electrum server is complaining about the tx_pos for given block.
+ # the electrum server is complaining about the txpos for given block.
# it is not clear what to do now, but let's believe the server.
self._blacklist_short_channel_id(short_channel_id)
return
@@ -122,7 +122,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
async with self.network.bhi_lock:
header = self.network.blockchain().read_header(block_height)
try:
- verify_tx_is_in_block(tx_hash, merkle_branch, tx_pos, header, block_height)
+ verify_tx_is_in_block(tx_hash, merkle_branch, short_channel_id.txpos, header, block_height)
except MerkleVerificationFailure as e:
# the electrum server sent an incorrect proof. blame is on server, not the ln peer
raise GracefulDisconnect(e) from e
@@ -151,28 +151,27 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
assert msg_type == 'channel_announcement'
redeem_script = funding_output_script_from_keys(chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2'])
expected_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script)
- output_idx = invert_short_channel_id(short_channel_id)[2]
try:
- actual_output = tx.outputs()[output_idx]
+ actual_output = tx.outputs()[short_channel_id.output_index]
except IndexError:
self._blacklist_short_channel_id(short_channel_id)
return
if expected_address != actual_output.address:
# FIXME what now? best would be to ban the originating ln peer.
- self.logger.info(f"funding output script mismatch for {bh2u(short_channel_id)}")
+ self.logger.info(f"funding output script mismatch for {short_channel_id}")
self._remove_channel_from_unverified_db(short_channel_id)
return
# put channel into channel DB
self.channel_db.add_verified_channel_info(short_channel_id, actual_output.value)
self._remove_channel_from_unverified_db(short_channel_id)
- def _remove_channel_from_unverified_db(self, short_channel_id: bytes):
+ def _remove_channel_from_unverified_db(self, short_channel_id: ShortChannelID):
with self.lock:
self.unverified_channel_info.pop(short_channel_id, None)
try: self.started_verifying_channel.remove(short_channel_id)
except KeyError: pass
- def _blacklist_short_channel_id(self, short_channel_id: bytes) -> None:
+ def _blacklist_short_channel_id(self, short_channel_id: ShortChannelID) -> None:
self.blacklist.add(short_channel_id)
with self.lock:
self.unverified_channel_info.pop(short_channel_id, None)
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -39,13 +39,14 @@ from .ecc import der_sig_from_sig_string
from .ecc_fast import is_using_fast_ecc
from .lnchannel import Channel, ChannelJsonEncoder
from . import lnutil
-from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
+from .lnutil import (Outpoint, LNPeerAddr,
get_compressed_pubkey_from_bech32, extract_nodeid,
PaymentFailure, split_host_port, ConnStringFormatError,
generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner,
- UpdateAddHtlc, Direction, LnLocalFeatures, format_short_channel_id)
+ UpdateAddHtlc, Direction, LnLocalFeatures, format_short_channel_id,
+ ShortChannelID)
from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use
from .address_synchronizer import TX_HEIGHT_LOCAL
@@ -553,10 +554,11 @@ class LNWallet(LNWorker):
if conf > 0:
block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid)
assert tx_pos >= 0
- chan.short_channel_id_predicted = calc_short_channel_id(block_height, tx_pos, chan.funding_outpoint.output_index)
+ chan.short_channel_id_predicted = ShortChannelID.from_components(
+ block_height, tx_pos, chan.funding_outpoint.output_index)
if conf >= chan.constraints.funding_txn_minimum_depth > 0:
- self.logger.info(f"save_short_channel_id")
chan.short_channel_id = chan.short_channel_id_predicted
+ self.logger.info(f"save_short_channel_id: {chan.short_channel_id}")
self.save_channel(chan)
self.on_channels_updated()
else:
@@ -795,7 +797,7 @@ class LNWallet(LNWorker):
else:
self.network.trigger_callback('payment_status', key, 'failure')
- def get_channel_by_short_id(self, short_channel_id):
+ def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel:
with self.lock:
for chan in self.channels.values():
if chan.short_channel_id == short_channel_id:
@@ -815,7 +817,7 @@ class LNWallet(LNWorker):
for i in range(attempts):
route = await self._create_route_from_invoice(decoded_invoice=addr)
if not self.get_channel_by_short_id(route[0].short_channel_id):
- scid = format_short_channel_id(route[0].short_channel_id)
+ scid = route[0].short_channel_id
raise Exception(f"Got route with unknown first channel: {scid}")
self.network.trigger_callback('payment_status', key, 'progress', i)
if await self._pay_to_route(route, addr, invoice):
@@ -826,8 +828,8 @@ class LNWallet(LNWorker):
short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
if not chan:
- scid = format_short_channel_id(short_channel_id)
- raise Exception(f"PathFinder returned path with short_channel_id {scid} that is not in channel list")
+ raise Exception(f"PathFinder returned path with short_channel_id "
+ f"{short_channel_id} that is not in channel list")
peer = self.peers[route[0].node_id]
htlc = await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
self.network.trigger_callback('htlc_added', htlc, addr, SENT)
@@ -879,6 +881,7 @@ class LNWallet(LNWorker):
prev_node_id = border_node_pubkey
for node_pubkey, edge_rest in zip(private_route_nodes, private_route_rest):
short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest
+ short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
@@ -1030,7 +1033,7 @@ class LNWallet(LNWorker):
if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat:
continue
chan_id = chan.short_channel_id
- assert type(chan_id) is bytes, chan_id
+ assert isinstance(chan_id, bytes), chan_id
channel_info = self.channel_db.get_channel_info(chan_id)
# note: as a fallback, if we don't have a channel update for the
# incoming direction of our private channel, we fill the invoice with garbage.
@@ -1048,8 +1051,7 @@ class LNWallet(LNWorker):
cltv_expiry_delta = policy.cltv_expiry_delta
missing_info = False
if missing_info:
- scid = format_short_channel_id(chan_id)
- self.logger.info(f"Warning. Missing channel update for our channel {scid}; "
+ self.logger.info(f"Warning. Missing channel update for our channel {chan_id}; "
f"filling invoice with incorrect data.")
routing_hints.append(('r', [(chan.node_id,
chan_id,