electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 46d8080c76e79670e8abaaaa0eb2d4d4a74544c1
parent 7d65fe1ba32200ae7e46841b7e0e4b6397bf7a2b
Author: SomberNight <somber.night@protonmail.com>
Date:   Mon, 17 Feb 2020 20:38:41 +0100

ln gossip: don't put own channels into db; always pass them to fn calls

Previously we would put fake chan announcement and fake outgoing chan upd
for own channels into db (to make path finding work). See Peer.add_own_channel().
Now, instead of above, we pass a "my_channels" param to the relevant ChannelDB methods.

Diffstat:
Melectrum/channel_db.py | 88+++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
Melectrum/lnchannel.py | 76+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
Melectrum/lnpeer.py | 111+++++++++----------------------------------------------------------------------
Melectrum/lnrouter.py | 32+++++++++++++++++++-------------
Melectrum/lnworker.py | 25++++++++++++++++++-------
Melectrum/tests/test_lnpeer.py | 9++++++---
6 files changed, 190 insertions(+), 151 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -39,9 +39,11 @@ from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enab from .logging import Logger from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update +from .lnmsg import decode_msg if TYPE_CHECKING: from .network import Network + from .lnchannel import Channel class UnknownEvenFeatureBits(Exception): pass @@ -63,7 +65,7 @@ class ChannelInfo(NamedTuple): capacity_sat: Optional[int] @staticmethod - def from_msg(payload): + def from_msg(payload: dict) -> 'ChannelInfo': features = int.from_bytes(payload['features'], 'big') validate_features(features) channel_id = payload['short_channel_id'] @@ -78,6 +80,11 @@ class ChannelInfo(NamedTuple): capacity_sat = capacity_sat ) + @staticmethod + def from_raw_msg(raw: bytes) -> 'ChannelInfo': + payload_dict = decode_msg(raw)[1] + return ChannelInfo.from_msg(payload_dict) + class Policy(NamedTuple): key: bytes @@ -91,7 +98,7 @@ class Policy(NamedTuple): timestamp: int @staticmethod - def from_msg(payload): + def from_msg(payload: dict) -> 'Policy': return Policy( key = payload['short_channel_id'] + payload['start_node'], cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"), @@ -248,11 +255,11 @@ class ChannelDB(SqlDB): self.ca_verifier = LNChannelVerifier(network, self) # initialized in load_data self._channels = {} # type: Dict[bytes, ChannelInfo] - self._policies = {} + self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy self._nodes = {} # node_id -> (host, port, ts) self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]] - self._channels_for_node = defaultdict(set) + self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]] self.data_loaded = asyncio.Event() self.network = network # only for callback @@ -495,17 +502,6 @@ class ChannelDB(SqlDB): self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) self.update_counts() - def get_routing_policy_for_channel(self, start_node_id: bytes, - short_channel_id: bytes) -> Optional[Policy]: - if not start_node_id or not short_channel_id: return None - channel_info = self.get_channel_info(short_channel_id) - if channel_info is not None: - return self.get_policy_for_node(short_channel_id, start_node_id) - msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) - if not msg: - return None - return Policy.from_msg(msg) # won't actually be written to DB - def get_old_policies(self, delta): now = int(time.time()) return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta) @@ -587,12 +583,56 @@ class ChannelDB(SqlDB): out.add(short_channel_id) self.logger.info(f'semi-orphaned: {len(out)}') - def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']: - return self._policies.get((node_id, short_channel_id)) - - def get_channel_info(self, channel_id: bytes) -> ChannelInfo: - return self._channels.get(channel_id) - - def get_channels_for_node(self, node_id) -> Set[bytes]: - """Returns the set of channels that have node_id as one of the endpoints.""" - return self._channels_for_node.get(node_id) or set() + def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *, + my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: + channel_info = self.get_channel_info(short_channel_id) + if channel_info is not None: # publicly announced channel + policy = self._policies.get((node_id, short_channel_id)) + if policy: + return policy + else: # private channel + chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id)) + if chan_upd_dict: + return Policy.from_msg(chan_upd_dict) + # check if it's one of our own channels + if not my_channels: + return + chan = my_channels.get(short_channel_id) # type: Optional[Channel] + if not chan: + return + if node_id == chan.node_id: # incoming direction (to us) + remote_update_raw = chan.get_remote_update() + if not remote_update_raw: + return + now = int(time.time()) + remote_update_decoded = decode_msg(remote_update_raw)[1] + remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big") + remote_update_decoded['start_node'] = node_id + return Policy.from_msg(remote_update_decoded) + elif node_id == chan.get_local_pubkey(): # outgoing direction (from us) + local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1] + local_update_decoded['start_node'] = node_id + return Policy.from_msg(local_update_decoded) + + def get_channel_info(self, short_channel_id: bytes, *, + my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]: + ret = self._channels.get(short_channel_id) + if ret: + return ret + # check if it's one of our own channels + if not my_channels: + return + chan = my_channels.get(short_channel_id) # type: Optional[Channel] + ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs()) + return ci._replace(capacity_sat=chan.constraints.capacity) + + def get_channels_for_node(self, node_id: bytes, *, + my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]: + """Returns the set of short channel IDs where node_id is one of the channel participants.""" + relevant_channels = self._channels_for_node.get(node_id) or set() + relevant_channels = set(relevant_channels) # copy + # add our own channels # TODO maybe slow? + for chan in (my_channels.values() or []): + if node_id in (chan.node_id, chan.get_local_pubkey()): + relevant_channels.add(chan.short_channel_id) + return relevant_channels diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -32,13 +32,14 @@ import time import threading from . import ecc +from . import constants from .util import bfh, bh2u from .bitcoin import redeem_script_to_address from .crypto import sha256, sha256d from .transaction import Transaction, PartialTransaction from .logging import Logger - from .lnonion import decode_onion_error +from . import lnutil from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints, get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx, sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey, @@ -47,10 +48,10 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs, ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script, ShortChannelID, map_htlcs_to_ctx_output_idxs) -from .lnutil import FeeUpdate from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo from .lnhtlc import HTLCManager +from .lnmsg import encode_msg, decode_msg if TYPE_CHECKING: from .lnworker import LNWallet @@ -136,7 +137,6 @@ class Channel(Logger): self.funding_outpoint = state["funding_outpoint"] self.node_id = bfh(state["node_id"]) self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"]) - self.short_channel_id_predicted = self.short_channel_id self.onion_keys = state['onion_keys'] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) @@ -144,6 +144,7 @@ class Channel(Logger): self.peer_state = peer_states.DISCONNECTED self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]] self._outgoing_channel_update = None # type: Optional[bytes] + self._chan_ann_without_sigs = None # type: Optional[bytes] self.revocation_store = RevocationStore(state["revocation_store"]) def set_onion_key(self, key, value): @@ -158,12 +159,77 @@ class Channel(Logger): def get_data_loss_protect_remote_pcp(self, key): return self.data_loss_protect_remote_pcp.get(key) - def set_remote_update(self, raw): + def get_local_pubkey(self) -> bytes: + if not self.lnworker: + raise Exception('lnworker not set for channel!') + return self.lnworker.node_keypair.pubkey + + def set_remote_update(self, raw: bytes) -> None: self.storage['remote_update'] = raw.hex() - def get_remote_update(self): + def get_remote_update(self) -> Optional[bytes]: return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None + def get_outgoing_gossip_channel_update(self) -> bytes: + if self._outgoing_channel_update is not None: + return self._outgoing_channel_update + if not self.lnworker: + raise Exception('lnworker not set for channel!') + sorted_node_ids = list(sorted([self.node_id, self.get_local_pubkey()])) + channel_flags = b'\x00' if sorted_node_ids[0] == self.get_local_pubkey() else b'\x01' + now = int(time.time()) + htlc_maximum_msat = min(self.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * self.constraints.capacity) + + chan_upd = encode_msg( + "channel_update", + short_channel_id=self.short_channel_id, + channel_flags=channel_flags, + message_flags=b'\x01', + cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"), + htlc_minimum_msat=self.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"), + htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"), + fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"), + fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"), + chain_hash=constants.net.rev_genesis_bytes(), + timestamp=now.to_bytes(4, byteorder="big"), + ) + sighash = sha256d(chan_upd[2 + 64:]) + sig = ecc.ECPrivkey(self.lnworker.node_keypair.privkey).sign(sighash, ecc.sig_string_from_r_and_s) + message_type, payload = decode_msg(chan_upd) + payload['signature'] = sig + chan_upd = encode_msg(message_type, **payload) + + self._outgoing_channel_update = chan_upd + return chan_upd + + def construct_channel_announcement_without_sigs(self) -> bytes: + if self._chan_ann_without_sigs is not None: + return self._chan_ann_without_sigs + if not self.lnworker: + raise Exception('lnworker not set for channel!') + + bitcoin_keys = [self.config[REMOTE].multisig_key.pubkey, + self.config[LOCAL].multisig_key.pubkey] + node_ids = [self.node_id, self.get_local_pubkey()] + sorted_node_ids = list(sorted(node_ids)) + if sorted_node_ids != node_ids: + node_ids = sorted_node_ids + bitcoin_keys.reverse() + + chan_ann = encode_msg("channel_announcement", + len=0, + features=b'', + chain_hash=constants.net.rev_genesis_bytes(), + short_channel_id=self.short_channel_id, + node_id_1=node_ids[0], + node_id_2=node_ids[1], + bitcoin_key_1=bitcoin_keys[0], + bitcoin_key_2=bitcoin_keys[1] + ) + + self._chan_ann_without_sigs = chan_ann + return chan_ann + def set_short_channel_id(self, short_id): self.short_channel_id = short_id self.storage["short_channel_id"] = short_id diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -953,112 +953,25 @@ class Peer(Logger): assert chan.config[LOCAL].funding_locked_received chan.set_state(channel_states.OPEN) self.network.trigger_callback('channel', chan) - self.add_own_channel(chan) + # peer may have sent us a channel update for the incoming direction previously + pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id) + if pending_channel_update: + chan.set_remote_update(pending_channel_update['raw']) self.logger.info(f"CHANNEL OPENING COMPLETED for {scid}") forwarding_enabled = self.network.config.get('lightning_forward_payments', False) if forwarding_enabled: # send channel_update of outgoing edge to peer, # so that channel can be used to to receive payments self.logger.info(f"sending channel update for outgoing edge of {scid}") - chan_upd = self.get_outgoing_gossip_channel_update_for_chan(chan) + chan_upd = chan.get_outgoing_gossip_channel_update() self.transport.send_bytes(chan_upd) - def add_own_channel(self, chan): - # add channel to database - bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey] - sorted_node_ids = list(sorted(self.node_ids)) - if sorted_node_ids != self.node_ids: - bitcoin_keys.reverse() - # note: we inject a channel announcement, and a channel update (for outgoing direction) - # This is atm needed for - # - finding routes - # - the ChanAnn is needed so that we can anchor to it a future ChanUpd - # that the remote sends, even if the channel was not announced - # (from BOLT-07: "MAY create a channel_update to communicate the channel - # parameters to the final node, even though the channel has not yet been announced") - self.channel_db.add_channel_announcement( - { - "short_channel_id": chan.short_channel_id, - "node_id_1": sorted_node_ids[0], - "node_id_2": sorted_node_ids[1], - 'chain_hash': constants.net.rev_genesis_bytes(), - 'len': b'\x00\x00', - 'features': b'', - 'bitcoin_key_1': bitcoin_keys[0], - 'bitcoin_key_2': bitcoin_keys[1] - }, - trusted=True) - # only inject outgoing direction: - chan_upd_bytes = self.get_outgoing_gossip_channel_update_for_chan(chan) - chan_upd_payload = decode_msg(chan_upd_bytes)[1] - self.channel_db.add_channel_update(chan_upd_payload) - # peer may have sent us a channel update for the incoming direction previously - pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id) - if pending_channel_update: - chan.set_remote_update(pending_channel_update['raw']) - # add remote update with a fresh timestamp - if chan.get_remote_update(): - now = int(time.time()) - remote_update_decoded = decode_msg(chan.get_remote_update())[1] - remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big") - self.channel_db.add_channel_update(remote_update_decoded) - - def get_outgoing_gossip_channel_update_for_chan(self, chan: Channel) -> bytes: - if chan._outgoing_channel_update is not None: - return chan._outgoing_channel_update - sorted_node_ids = list(sorted(self.node_ids)) - channel_flags = b'\x00' if sorted_node_ids[0] == privkey_to_pubkey(self.privkey) else b'\x01' - now = int(time.time()) - htlc_maximum_msat = min(chan.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * chan.constraints.capacity) - - chan_upd = encode_msg( - "channel_update", - short_channel_id=chan.short_channel_id, - channel_flags=channel_flags, - message_flags=b'\x01', - cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"), - htlc_minimum_msat=chan.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"), - htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"), - fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"), - fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"), - chain_hash=constants.net.rev_genesis_bytes(), - timestamp=now.to_bytes(4, byteorder="big"), - ) - sighash = sha256d(chan_upd[2 + 64:]) - sig = ecc.ECPrivkey(self.privkey).sign(sighash, sig_string_from_r_and_s) - message_type, payload = decode_msg(chan_upd) - payload['signature'] = sig - chan_upd = encode_msg(message_type, **payload) - - chan._outgoing_channel_update = chan_upd - return chan_upd - def send_announcement_signatures(self, chan: Channel): - - bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, - chan.config[LOCAL].multisig_key.pubkey] - - sorted_node_ids = list(sorted(self.node_ids)) - if sorted_node_ids != self.node_ids: - node_ids = sorted_node_ids - bitcoin_keys.reverse() - else: - node_ids = self.node_ids - - chan_ann = encode_msg("channel_announcement", - len=0, - #features not set (defaults to zeros) - chain_hash=constants.net.rev_genesis_bytes(), - short_channel_id=chan.short_channel_id, - node_id_1=node_ids[0], - node_id_2=node_ids[1], - bitcoin_key_1=bitcoin_keys[0], - bitcoin_key_2=bitcoin_keys[1] - ) - to_hash = chan_ann[256+2:] - h = sha256d(to_hash) - bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(h, sig_string_from_r_and_s) - node_signature = ecc.ECPrivkey(self.privkey).sign(h, sig_string_from_r_and_s) + chan_ann = chan.construct_channel_announcement_without_sigs() + preimage = chan_ann[256+2:] + msg_hash = sha256d(preimage) + bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(msg_hash, sig_string_from_r_and_s) + node_signature = ecc.ECPrivkey(self.privkey).sign(msg_hash, sig_string_from_r_and_s) self.send_message("announcement_signatures", channel_id=chan.channel_id, short_channel_id=chan.short_channel_id, @@ -1066,7 +979,7 @@ class Peer(Logger): bitcoin_signature=bitcoin_signature ) - return h, node_signature, bitcoin_signature + return msg_hash, node_signature, bitcoin_signature def on_update_fail_htlc(self, payload): channel_id = payload["channel_id"] @@ -1255,7 +1168,7 @@ class Peer(Logger): reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return - outgoing_chan_upd = self.get_outgoing_gossip_channel_update_for_chan(next_chan)[2:] + outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:] outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big") if next_chan.get_state() != channel_states.OPEN: self.logger.info(f"cannot forward htlc. next_chan not OPEN: {next_chan_scid} in state {next_chan.get_state()}") diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -129,18 +129,20 @@ class LNPathFinder(Logger): self.blacklist.add(short_channel_id) def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, - payment_amt_msat: int, ignore_costs=False, is_mine=False) -> Tuple[float, int]: + payment_amt_msat: int, ignore_costs=False, is_mine=False, *, + my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]: """Heuristic cost of going through a channel. Returns (heuristic_cost, fee_for_edge_msat). """ - channel_info = self.channel_db.get_channel_info(short_channel_id) + channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels) if channel_info is None: return float('inf'), 0 - channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node) + channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels) if channel_policy is None: return float('inf'), 0 # channels that did not publish both policies often return temporary channel failure - if self.channel_db.get_policy_for_node(short_channel_id, end_node) is None and not is_mine: + if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \ + and not is_mine: return float('inf'), 0 if channel_policy.is_disabled(): return float('inf'), 0 @@ -164,8 +166,9 @@ class LNPathFinder(Logger): @profiler def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, - invoice_amount_msat: int, - my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]: + invoice_amount_msat: int, *, + my_channels: Dict[ShortChannelID, 'Channel'] = None) \ + -> Optional[Sequence[Tuple[bytes, bytes]]]: """Return a path from nodeA to nodeB. Returns a list of (node_id, short_channel_id) representing a path. @@ -175,8 +178,7 @@ class LNPathFinder(Logger): assert type(nodeA) is bytes assert type(nodeB) is bytes assert type(invoice_amount_msat) is int - if my_channels is None: my_channels = [] - my_channels = {chan.short_channel_id: chan for chan in my_channels} + if my_channels is None: my_channels = {} # FIXME paths cannot be longer than 20 edges (onion packet)... @@ -204,7 +206,8 @@ class LNPathFinder(Logger): end_node=edge_endnode, payment_amt_msat=amount_msat, ignore_costs=(edge_startnode == nodeA), - is_mine=is_mine) + is_mine=is_mine, + my_channels=my_channels) alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost if alt_dist_to_neighbour < distance_from_start[edge_startnode]: distance_from_start[edge_startnode] = alt_dist_to_neighbour @@ -222,11 +225,11 @@ class LNPathFinder(Logger): # so instead of decreasing priorities, we add items again into the queue. # so there are duplicates in the queue, that we discard now: continue - for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode): + for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels): assert isinstance(edge_channel_id, bytes) if edge_channel_id in self.blacklist: continue - channel_info = self.channel_db.get_channel_info(edge_channel_id) + channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels) edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id inspect_edge() else: @@ -241,14 +244,17 @@ class LNPathFinder(Logger): edge_startnode = edge_endnode return path - def create_route_from_path(self, path, from_node_id: bytes) -> LNPaymentRoute: + def create_route_from_path(self, path, from_node_id: bytes, *, + my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute: assert isinstance(from_node_id, bytes) if path is None: raise Exception('cannot create route from None path') route = [] prev_node_id = from_node_id for node_id, short_channel_id in path: - channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) + channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id, + node_id=prev_node_id, + my_channels=my_channels) if channel_policy is None: raise NoChannelPolicy(short_channel_id) route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id)) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -942,16 +942,20 @@ class LNWallet(LNWorker): random.shuffle(r_tags) with self.lock: channels = list(self.channels.values()) + scid_to_my_channels = {chan.short_channel_id: chan for chan in channels + if chan.short_channel_id is not None} for private_route in r_tags: if len(private_route) == 0: continue if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: continue border_node_pubkey = private_route[0][0] - path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels) + path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, + my_channels=scid_to_my_channels) if not path: continue - route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) + route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey, + my_channels=scid_to_my_channels) # we need to shift the node pubkey by one towards the destination: private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey] private_route_rest = [edge[1:] for edge in private_route] @@ -961,7 +965,9 @@ class LNWallet(LNWorker): short_channel_id = ShortChannelID(short_channel_id) # if we have a routing policy for this edge in the db, that takes precedence, # as it is likely from a previous failure - channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) + channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id, + node_id=prev_node_id, + my_channels=scid_to_my_channels) if channel_policy: fee_base_msat = channel_policy.fee_base_msat fee_proportional_millionths = channel_policy.fee_proportional_millionths @@ -977,10 +983,12 @@ class LNWallet(LNWorker): break # if could not find route using any hint; try without hint now if route is None: - path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat, channels) + path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat, + my_channels=scid_to_my_channels) if not path: raise NoPathFound() - route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) + route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey, + my_channels=scid_to_my_channels) if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): self.logger.info(f"rejecting insane route {route}") raise NoPathFound() @@ -1099,6 +1107,8 @@ class LNWallet(LNWorker): routing_hints = [] with self.lock: channels = list(self.channels.values()) + scid_to_my_channels = {chan.short_channel_id: chan for chan in channels + if chan.short_channel_id is not None} # note: currently we add *all* our channels; but this might be a privacy leak? for chan in channels: # check channel is open @@ -1110,7 +1120,7 @@ class LNWallet(LNWorker): continue chan_id = chan.short_channel_id assert isinstance(chan_id, bytes), chan_id - channel_info = self.channel_db.get_channel_info(chan_id) + channel_info = self.channel_db.get_channel_info(chan_id, my_channels=scid_to_my_channels) # note: as a fallback, if we don't have a channel update for the # incoming direction of our private channel, we fill the invoice with garbage. # the sender should still be able to pay us, but will incur an extra round trip @@ -1120,7 +1130,8 @@ class LNWallet(LNWorker): cltv_expiry_delta = 1 # lnd won't even try with zero missing_info = True if channel_info: - policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id) + policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id, + my_channels=scid_to_my_channels) if policy: fee_base_msat = policy.fee_base_msat fee_proportional_millionths = policy.fee_proportional_millionths diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -18,7 +18,7 @@ from electrum.lnpeer import Peer from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import PaymentFailure, LnLocalFeatures -from electrum.lnchannel import channel_states, peer_states +from electrum.lnchannel import channel_states, peer_states, Channel from electrum.lnrouter import LNPathFinder from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet, NoPathFound @@ -77,7 +77,7 @@ class MockWallet: return False class MockLNWallet: - def __init__(self, remote_keypair, local_keypair, chan, tx_queue): + def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue): self.remote_keypair = remote_keypair self.node_keypair = local_keypair self.network = MockNetwork(tx_queue) @@ -88,6 +88,8 @@ class MockLNWallet: self.localfeatures = LnLocalFeatures(0) self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT self.pending_payments = defaultdict(asyncio.Future) + chan.lnworker = self + chan.node_id = remote_keypair.pubkey def get_invoice_status(self, key): pass @@ -127,6 +129,7 @@ class MockLNWallet: _pay_to_route = LNWallet._pay_to_route force_close_channel = LNWallet.force_close_channel get_first_timestamp = lambda self: 0 + payment_completed = LNWallet.payment_completed class MockTransport: def __init__(self, name): @@ -264,7 +267,7 @@ class TestPeer(ElectrumTestCase): pay_req = self.prepare_invoice(w2) async def pay(): result = await LNWallet._pay(w1, pay_req) - self.assertEqual(result, True) + self.assertTrue(result) gath.cancel() gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) async def f():