commit 46d8080c76e79670e8abaaaa0eb2d4d4a74544c1
parent 7d65fe1ba32200ae7e46841b7e0e4b6397bf7a2b
Author: SomberNight <somber.night@protonmail.com>
Date: Mon, 17 Feb 2020 20:38:41 +0100
ln gossip: don't put own channels into db; always pass them to fn calls
Previously we would put fake chan announcement and fake outgoing chan upd
for own channels into db (to make path finding work). See Peer.add_own_channel().
Now, instead of above, we pass a "my_channels" param to the relevant ChannelDB methods.
Diffstat:
6 files changed, 190 insertions(+), 151 deletions(-)
diff --git a/electrum/channel_db.py b/electrum/channel_db.py
@@ -39,9 +39,11 @@ from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enab
from .logging import Logger
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
+from .lnmsg import decode_msg
if TYPE_CHECKING:
from .network import Network
+ from .lnchannel import Channel
class UnknownEvenFeatureBits(Exception): pass
@@ -63,7 +65,7 @@ class ChannelInfo(NamedTuple):
capacity_sat: Optional[int]
@staticmethod
- def from_msg(payload):
+ def from_msg(payload: dict) -> 'ChannelInfo':
features = int.from_bytes(payload['features'], 'big')
validate_features(features)
channel_id = payload['short_channel_id']
@@ -78,6 +80,11 @@ class ChannelInfo(NamedTuple):
capacity_sat = capacity_sat
)
+ @staticmethod
+ def from_raw_msg(raw: bytes) -> 'ChannelInfo':
+ payload_dict = decode_msg(raw)[1]
+ return ChannelInfo.from_msg(payload_dict)
+
class Policy(NamedTuple):
key: bytes
@@ -91,7 +98,7 @@ class Policy(NamedTuple):
timestamp: int
@staticmethod
- def from_msg(payload):
+ def from_msg(payload: dict) -> 'Policy':
return Policy(
key = payload['short_channel_id'] + payload['start_node'],
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
@@ -248,11 +255,11 @@ class ChannelDB(SqlDB):
self.ca_verifier = LNChannelVerifier(network, self)
# initialized in load_data
self._channels = {} # type: Dict[bytes, ChannelInfo]
- self._policies = {}
+ self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
self._nodes = {}
# node_id -> (host, port, ts)
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
- self._channels_for_node = defaultdict(set)
+ self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
self.data_loaded = asyncio.Event()
self.network = network # only for callback
@@ -495,17 +502,6 @@ class ChannelDB(SqlDB):
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
- def get_routing_policy_for_channel(self, start_node_id: bytes,
- short_channel_id: bytes) -> Optional[Policy]:
- if not start_node_id or not short_channel_id: return None
- channel_info = self.get_channel_info(short_channel_id)
- if channel_info is not None:
- return self.get_policy_for_node(short_channel_id, start_node_id)
- msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
- if not msg:
- return None
- return Policy.from_msg(msg) # won't actually be written to DB
-
def get_old_policies(self, delta):
now = int(time.time())
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
@@ -587,12 +583,56 @@ class ChannelDB(SqlDB):
out.add(short_channel_id)
self.logger.info(f'semi-orphaned: {len(out)}')
- def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
- return self._policies.get((node_id, short_channel_id))
-
- def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
- return self._channels.get(channel_id)
-
- def get_channels_for_node(self, node_id) -> Set[bytes]:
- """Returns the set of channels that have node_id as one of the endpoints."""
- return self._channels_for_node.get(node_id) or set()
+ def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
+ my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
+ channel_info = self.get_channel_info(short_channel_id)
+ if channel_info is not None: # publicly announced channel
+ policy = self._policies.get((node_id, short_channel_id))
+ if policy:
+ return policy
+ else: # private channel
+ chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
+ if chan_upd_dict:
+ return Policy.from_msg(chan_upd_dict)
+ # check if it's one of our own channels
+ if not my_channels:
+ return
+ chan = my_channels.get(short_channel_id) # type: Optional[Channel]
+ if not chan:
+ return
+ if node_id == chan.node_id: # incoming direction (to us)
+ remote_update_raw = chan.get_remote_update()
+ if not remote_update_raw:
+ return
+ now = int(time.time())
+ remote_update_decoded = decode_msg(remote_update_raw)[1]
+ remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
+ remote_update_decoded['start_node'] = node_id
+ return Policy.from_msg(remote_update_decoded)
+ elif node_id == chan.get_local_pubkey(): # outgoing direction (from us)
+ local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
+ local_update_decoded['start_node'] = node_id
+ return Policy.from_msg(local_update_decoded)
+
+ def get_channel_info(self, short_channel_id: bytes, *,
+ my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
+ ret = self._channels.get(short_channel_id)
+ if ret:
+ return ret
+ # check if it's one of our own channels
+ if not my_channels:
+ return
+ chan = my_channels.get(short_channel_id) # type: Optional[Channel]
+ ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
+ return ci._replace(capacity_sat=chan.constraints.capacity)
+
+ def get_channels_for_node(self, node_id: bytes, *,
+ my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
+ """Returns the set of short channel IDs where node_id is one of the channel participants."""
+ relevant_channels = self._channels_for_node.get(node_id) or set()
+ relevant_channels = set(relevant_channels) # copy
+ # add our own channels # TODO maybe slow?
+ for chan in (my_channels.values() or []):
+ if node_id in (chan.node_id, chan.get_local_pubkey()):
+ relevant_channels.add(chan.short_channel_id)
+ return relevant_channels
diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
@@ -32,13 +32,14 @@ import time
import threading
from . import ecc
+from . import constants
from .util import bfh, bh2u
from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d
from .transaction import Transaction, PartialTransaction
from .logging import Logger
-
from .lnonion import decode_onion_error
+from . import lnutil
from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints,
get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx,
sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey,
@@ -47,10 +48,10 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
ShortChannelID, map_htlcs_to_ctx_output_idxs)
-from .lnutil import FeeUpdate
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo
from .lnhtlc import HTLCManager
+from .lnmsg import encode_msg, decode_msg
if TYPE_CHECKING:
from .lnworker import LNWallet
@@ -136,7 +137,6 @@ class Channel(Logger):
self.funding_outpoint = state["funding_outpoint"]
self.node_id = bfh(state["node_id"])
self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
- self.short_channel_id_predicted = self.short_channel_id
self.onion_keys = state['onion_keys']
self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
@@ -144,6 +144,7 @@ class Channel(Logger):
self.peer_state = peer_states.DISCONNECTED
self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
self._outgoing_channel_update = None # type: Optional[bytes]
+ self._chan_ann_without_sigs = None # type: Optional[bytes]
self.revocation_store = RevocationStore(state["revocation_store"])
def set_onion_key(self, key, value):
@@ -158,12 +159,77 @@ class Channel(Logger):
def get_data_loss_protect_remote_pcp(self, key):
return self.data_loss_protect_remote_pcp.get(key)
- def set_remote_update(self, raw):
+ def get_local_pubkey(self) -> bytes:
+ if not self.lnworker:
+ raise Exception('lnworker not set for channel!')
+ return self.lnworker.node_keypair.pubkey
+
+ def set_remote_update(self, raw: bytes) -> None:
self.storage['remote_update'] = raw.hex()
- def get_remote_update(self):
+ def get_remote_update(self) -> Optional[bytes]:
return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None
+ def get_outgoing_gossip_channel_update(self) -> bytes:
+ if self._outgoing_channel_update is not None:
+ return self._outgoing_channel_update
+ if not self.lnworker:
+ raise Exception('lnworker not set for channel!')
+ sorted_node_ids = list(sorted([self.node_id, self.get_local_pubkey()]))
+ channel_flags = b'\x00' if sorted_node_ids[0] == self.get_local_pubkey() else b'\x01'
+ now = int(time.time())
+ htlc_maximum_msat = min(self.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * self.constraints.capacity)
+
+ chan_upd = encode_msg(
+ "channel_update",
+ short_channel_id=self.short_channel_id,
+ channel_flags=channel_flags,
+ message_flags=b'\x01',
+ cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
+ htlc_minimum_msat=self.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
+ htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
+ fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
+ fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
+ chain_hash=constants.net.rev_genesis_bytes(),
+ timestamp=now.to_bytes(4, byteorder="big"),
+ )
+ sighash = sha256d(chan_upd[2 + 64:])
+ sig = ecc.ECPrivkey(self.lnworker.node_keypair.privkey).sign(sighash, ecc.sig_string_from_r_and_s)
+ message_type, payload = decode_msg(chan_upd)
+ payload['signature'] = sig
+ chan_upd = encode_msg(message_type, **payload)
+
+ self._outgoing_channel_update = chan_upd
+ return chan_upd
+
+ def construct_channel_announcement_without_sigs(self) -> bytes:
+ if self._chan_ann_without_sigs is not None:
+ return self._chan_ann_without_sigs
+ if not self.lnworker:
+ raise Exception('lnworker not set for channel!')
+
+ bitcoin_keys = [self.config[REMOTE].multisig_key.pubkey,
+ self.config[LOCAL].multisig_key.pubkey]
+ node_ids = [self.node_id, self.get_local_pubkey()]
+ sorted_node_ids = list(sorted(node_ids))
+ if sorted_node_ids != node_ids:
+ node_ids = sorted_node_ids
+ bitcoin_keys.reverse()
+
+ chan_ann = encode_msg("channel_announcement",
+ len=0,
+ features=b'',
+ chain_hash=constants.net.rev_genesis_bytes(),
+ short_channel_id=self.short_channel_id,
+ node_id_1=node_ids[0],
+ node_id_2=node_ids[1],
+ bitcoin_key_1=bitcoin_keys[0],
+ bitcoin_key_2=bitcoin_keys[1]
+ )
+
+ self._chan_ann_without_sigs = chan_ann
+ return chan_ann
+
def set_short_channel_id(self, short_id):
self.short_channel_id = short_id
self.storage["short_channel_id"] = short_id
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -953,112 +953,25 @@ class Peer(Logger):
assert chan.config[LOCAL].funding_locked_received
chan.set_state(channel_states.OPEN)
self.network.trigger_callback('channel', chan)
- self.add_own_channel(chan)
+ # peer may have sent us a channel update for the incoming direction previously
+ pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
+ if pending_channel_update:
+ chan.set_remote_update(pending_channel_update['raw'])
self.logger.info(f"CHANNEL OPENING COMPLETED for {scid}")
forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
if forwarding_enabled:
# send channel_update of outgoing edge to peer,
# so that channel can be used to to receive payments
self.logger.info(f"sending channel update for outgoing edge of {scid}")
- chan_upd = self.get_outgoing_gossip_channel_update_for_chan(chan)
+ chan_upd = chan.get_outgoing_gossip_channel_update()
self.transport.send_bytes(chan_upd)
- def add_own_channel(self, chan):
- # add channel to database
- bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
- sorted_node_ids = list(sorted(self.node_ids))
- if sorted_node_ids != self.node_ids:
- bitcoin_keys.reverse()
- # note: we inject a channel announcement, and a channel update (for outgoing direction)
- # This is atm needed for
- # - finding routes
- # - the ChanAnn is needed so that we can anchor to it a future ChanUpd
- # that the remote sends, even if the channel was not announced
- # (from BOLT-07: "MAY create a channel_update to communicate the channel
- # parameters to the final node, even though the channel has not yet been announced")
- self.channel_db.add_channel_announcement(
- {
- "short_channel_id": chan.short_channel_id,
- "node_id_1": sorted_node_ids[0],
- "node_id_2": sorted_node_ids[1],
- 'chain_hash': constants.net.rev_genesis_bytes(),
- 'len': b'\x00\x00',
- 'features': b'',
- 'bitcoin_key_1': bitcoin_keys[0],
- 'bitcoin_key_2': bitcoin_keys[1]
- },
- trusted=True)
- # only inject outgoing direction:
- chan_upd_bytes = self.get_outgoing_gossip_channel_update_for_chan(chan)
- chan_upd_payload = decode_msg(chan_upd_bytes)[1]
- self.channel_db.add_channel_update(chan_upd_payload)
- # peer may have sent us a channel update for the incoming direction previously
- pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
- if pending_channel_update:
- chan.set_remote_update(pending_channel_update['raw'])
- # add remote update with a fresh timestamp
- if chan.get_remote_update():
- now = int(time.time())
- remote_update_decoded = decode_msg(chan.get_remote_update())[1]
- remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
- self.channel_db.add_channel_update(remote_update_decoded)
-
- def get_outgoing_gossip_channel_update_for_chan(self, chan: Channel) -> bytes:
- if chan._outgoing_channel_update is not None:
- return chan._outgoing_channel_update
- sorted_node_ids = list(sorted(self.node_ids))
- channel_flags = b'\x00' if sorted_node_ids[0] == privkey_to_pubkey(self.privkey) else b'\x01'
- now = int(time.time())
- htlc_maximum_msat = min(chan.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * chan.constraints.capacity)
-
- chan_upd = encode_msg(
- "channel_update",
- short_channel_id=chan.short_channel_id,
- channel_flags=channel_flags,
- message_flags=b'\x01',
- cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
- htlc_minimum_msat=chan.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
- htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
- fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
- fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
- chain_hash=constants.net.rev_genesis_bytes(),
- timestamp=now.to_bytes(4, byteorder="big"),
- )
- sighash = sha256d(chan_upd[2 + 64:])
- sig = ecc.ECPrivkey(self.privkey).sign(sighash, sig_string_from_r_and_s)
- message_type, payload = decode_msg(chan_upd)
- payload['signature'] = sig
- chan_upd = encode_msg(message_type, **payload)
-
- chan._outgoing_channel_update = chan_upd
- return chan_upd
-
def send_announcement_signatures(self, chan: Channel):
-
- bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
- chan.config[LOCAL].multisig_key.pubkey]
-
- sorted_node_ids = list(sorted(self.node_ids))
- if sorted_node_ids != self.node_ids:
- node_ids = sorted_node_ids
- bitcoin_keys.reverse()
- else:
- node_ids = self.node_ids
-
- chan_ann = encode_msg("channel_announcement",
- len=0,
- #features not set (defaults to zeros)
- chain_hash=constants.net.rev_genesis_bytes(),
- short_channel_id=chan.short_channel_id,
- node_id_1=node_ids[0],
- node_id_2=node_ids[1],
- bitcoin_key_1=bitcoin_keys[0],
- bitcoin_key_2=bitcoin_keys[1]
- )
- to_hash = chan_ann[256+2:]
- h = sha256d(to_hash)
- bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(h, sig_string_from_r_and_s)
- node_signature = ecc.ECPrivkey(self.privkey).sign(h, sig_string_from_r_and_s)
+ chan_ann = chan.construct_channel_announcement_without_sigs()
+ preimage = chan_ann[256+2:]
+ msg_hash = sha256d(preimage)
+ bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(msg_hash, sig_string_from_r_and_s)
+ node_signature = ecc.ECPrivkey(self.privkey).sign(msg_hash, sig_string_from_r_and_s)
self.send_message("announcement_signatures",
channel_id=chan.channel_id,
short_channel_id=chan.short_channel_id,
@@ -1066,7 +979,7 @@ class Peer(Logger):
bitcoin_signature=bitcoin_signature
)
- return h, node_signature, bitcoin_signature
+ return msg_hash, node_signature, bitcoin_signature
def on_update_fail_htlc(self, payload):
channel_id = payload["channel_id"]
@@ -1255,7 +1168,7 @@ class Peer(Logger):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
- outgoing_chan_upd = self.get_outgoing_gossip_channel_update_for_chan(next_chan)[2:]
+ outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
if next_chan.get_state() != channel_states.OPEN:
self.logger.info(f"cannot forward htlc. next_chan not OPEN: {next_chan_scid} in state {next_chan.get_state()}")
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -129,18 +129,20 @@ class LNPathFinder(Logger):
self.blacklist.add(short_channel_id)
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
- payment_amt_msat: int, ignore_costs=False, is_mine=False) -> Tuple[float, int]:
+ payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
+ my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
"""Heuristic cost of going through a channel.
Returns (heuristic_cost, fee_for_edge_msat).
"""
- channel_info = self.channel_db.get_channel_info(short_channel_id)
+ channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
if channel_info is None:
return float('inf'), 0
- channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node)
+ channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
if channel_policy is None:
return float('inf'), 0
# channels that did not publish both policies often return temporary channel failure
- if self.channel_db.get_policy_for_node(short_channel_id, end_node) is None and not is_mine:
+ if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
+ and not is_mine:
return float('inf'), 0
if channel_policy.is_disabled():
return float('inf'), 0
@@ -164,8 +166,9 @@ class LNPathFinder(Logger):
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
- invoice_amount_msat: int,
- my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
+ invoice_amount_msat: int, *,
+ my_channels: Dict[ShortChannelID, 'Channel'] = None) \
+ -> Optional[Sequence[Tuple[bytes, bytes]]]:
"""Return a path from nodeA to nodeB.
Returns a list of (node_id, short_channel_id) representing a path.
@@ -175,8 +178,7 @@ class LNPathFinder(Logger):
assert type(nodeA) is bytes
assert type(nodeB) is bytes
assert type(invoice_amount_msat) is int
- if my_channels is None: my_channels = []
- my_channels = {chan.short_channel_id: chan for chan in my_channels}
+ if my_channels is None: my_channels = {}
# FIXME paths cannot be longer than 20 edges (onion packet)...
@@ -204,7 +206,8 @@ class LNPathFinder(Logger):
end_node=edge_endnode,
payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA),
- is_mine=is_mine)
+ is_mine=is_mine,
+ my_channels=my_channels)
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour
@@ -222,11 +225,11 @@ class LNPathFinder(Logger):
# so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now:
continue
- for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
+ for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
assert isinstance(edge_channel_id, bytes)
if edge_channel_id in self.blacklist:
continue
- channel_info = self.channel_db.get_channel_info(edge_channel_id)
+ channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
inspect_edge()
else:
@@ -241,14 +244,17 @@ class LNPathFinder(Logger):
edge_startnode = edge_endnode
return path
- def create_route_from_path(self, path, from_node_id: bytes) -> LNPaymentRoute:
+ def create_route_from_path(self, path, from_node_id: bytes, *,
+ my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
assert isinstance(from_node_id, bytes)
if path is None:
raise Exception('cannot create route from None path')
route = []
prev_node_id = from_node_id
for node_id, short_channel_id in path:
- channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
+ channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
+ node_id=prev_node_id,
+ my_channels=my_channels)
if channel_policy is None:
raise NoChannelPolicy(short_channel_id)
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -942,16 +942,20 @@ class LNWallet(LNWorker):
random.shuffle(r_tags)
with self.lock:
channels = list(self.channels.values())
+ scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
+ if chan.short_channel_id is not None}
for private_route in r_tags:
if len(private_route) == 0:
continue
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
continue
border_node_pubkey = private_route[0][0]
- path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
+ path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat,
+ my_channels=scid_to_my_channels)
if not path:
continue
- route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
+ route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
+ my_channels=scid_to_my_channels)
# we need to shift the node pubkey by one towards the destination:
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
private_route_rest = [edge[1:] for edge in private_route]
@@ -961,7 +965,9 @@ class LNWallet(LNWorker):
short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure
- channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
+ channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
+ node_id=prev_node_id,
+ my_channels=scid_to_my_channels)
if channel_policy:
fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths
@@ -977,10 +983,12 @@ class LNWallet(LNWorker):
break
# if could not find route using any hint; try without hint now
if route is None:
- path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat, channels)
+ path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat,
+ my_channels=scid_to_my_channels)
if not path:
raise NoPathFound()
- route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
+ route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
+ my_channels=scid_to_my_channels)
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
self.logger.info(f"rejecting insane route {route}")
raise NoPathFound()
@@ -1099,6 +1107,8 @@ class LNWallet(LNWorker):
routing_hints = []
with self.lock:
channels = list(self.channels.values())
+ scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
+ if chan.short_channel_id is not None}
# note: currently we add *all* our channels; but this might be a privacy leak?
for chan in channels:
# check channel is open
@@ -1110,7 +1120,7 @@ class LNWallet(LNWorker):
continue
chan_id = chan.short_channel_id
assert isinstance(chan_id, bytes), chan_id
- channel_info = self.channel_db.get_channel_info(chan_id)
+ channel_info = self.channel_db.get_channel_info(chan_id, my_channels=scid_to_my_channels)
# note: as a fallback, if we don't have a channel update for the
# incoming direction of our private channel, we fill the invoice with garbage.
# the sender should still be able to pay us, but will incur an extra round trip
@@ -1120,7 +1130,8 @@ class LNWallet(LNWorker):
cltv_expiry_delta = 1 # lnd won't even try with zero
missing_info = True
if channel_info:
- policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id)
+ policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id,
+ my_channels=scid_to_my_channels)
if policy:
fee_base_msat = policy.fee_base_msat
fee_proportional_millionths = policy.fee_proportional_millionths
diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
@@ -18,7 +18,7 @@ from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnLocalFeatures
-from electrum.lnchannel import channel_states, peer_states
+from electrum.lnchannel import channel_states, peer_states, Channel
from electrum.lnrouter import LNPathFinder
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound
@@ -77,7 +77,7 @@ class MockWallet:
return False
class MockLNWallet:
- def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
+ def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
self.remote_keypair = remote_keypair
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
@@ -88,6 +88,8 @@ class MockLNWallet:
self.localfeatures = LnLocalFeatures(0)
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.pending_payments = defaultdict(asyncio.Future)
+ chan.lnworker = self
+ chan.node_id = remote_keypair.pubkey
def get_invoice_status(self, key):
pass
@@ -127,6 +129,7 @@ class MockLNWallet:
_pay_to_route = LNWallet._pay_to_route
force_close_channel = LNWallet.force_close_channel
get_first_timestamp = lambda self: 0
+ payment_completed = LNWallet.payment_completed
class MockTransport:
def __init__(self, name):
@@ -264,7 +267,7 @@ class TestPeer(ElectrumTestCase):
pay_req = self.prepare_invoice(w2)
async def pay():
result = await LNWallet._pay(w1, pay_req)
- self.assertEqual(result, True)
+ self.assertTrue(result)
gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
async def f():