electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 750d8cfab5bc9b263d9f94b3b59079c8a2cf941b
parent 4445cef033b1c60e05e19ff4dae43d0a026105da
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue,  2 Mar 2021 18:00:31 +0100

lnworker: run create_route_for_payment end-to-end, incl private edges

We pass the private edges to lnrouter, and let it find routes end-to-end.
Previously the edge_cost heuristics didn't apply to the private edges
and we were just randomly picking one of the route hints and use that.
So e.g. cheaper private edges were not preferred, but they are now.

PathEdge now stores both start_node and end_node; not just end_node.

Diffstat:
Melectrum/channel_db.py | 87+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------
Melectrum/lnrouter.py | 206+++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
Melectrum/lnworker.py | 122+++++++++++++++++++++++++++++++++++++------------------------------------------
Melectrum/tests/test_lnpeer.py | 20+++++++++++++++-----
Melectrum/tests/test_lnrouter.py | 12+++++++-----
Melectrum/trampoline.py | 35++++++++++++++++++++++-------------
6 files changed, 323 insertions(+), 159 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -48,6 +48,7 @@ from .lnmsg import decode_msg if TYPE_CHECKING: from .network import Network from .lnchannel import Channel + from .lnrouter import RouteEdge FLAG_DISABLE = 1 << 1 @@ -81,6 +82,16 @@ class ChannelInfo(NamedTuple): payload_dict = decode_msg(raw)[1] return ChannelInfo.from_msg(payload_dict) + @staticmethod + def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo': + node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node]) + return ChannelInfo( + short_channel_id=route_edge.short_channel_id, + node1_id=node1_id, + node2_id=node2_id, + capacity_sat=None, + ) + class Policy(NamedTuple): key: bytes @@ -113,6 +124,20 @@ class Policy(NamedTuple): payload['start_node'] = key[8:] return Policy.from_msg(payload) + @staticmethod + def from_route_edge(route_edge: 'RouteEdge') -> 'Policy': + return Policy( + key=route_edge.short_channel_id + route_edge.start_node, + cltv_expiry_delta=route_edge.cltv_expiry_delta, + htlc_minimum_msat=0, + htlc_maximum_msat=None, + fee_base_msat=route_edge.fee_base_msat, + fee_proportional_millionths=route_edge.fee_proportional_millionths, + channel_flags=0, + message_flags=0, + timestamp=0, + ) + def is_disabled(self): return self.channel_flags & FLAG_DISABLE @@ -216,6 +241,8 @@ class CategorizedChannelUpdates(NamedTuple): def get_mychannel_info(short_channel_id: ShortChannelID, my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]: chan = my_channels.get(short_channel_id) + if not chan: + return ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs()) return ci._replace(capacity_sat=chan.constraints.capacity) @@ -724,8 +751,14 @@ class ChannelDB(SqlDB): nchans_with_2p = len(self._chans_with_2_policies) return nchans_with_0p, nchans_with_1p, nchans_with_2p - def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: + def get_policy_for_node( + self, + short_channel_id: bytes, + node_id: bytes, + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, + ) -> Optional['Policy']: channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: # publicly announced channel policy = self._policies.get((node_id, short_channel_id)) @@ -737,28 +770,56 @@ class ChannelDB(SqlDB): return Policy.from_msg(chan_upd_dict) # check if it's one of our own channels if my_channels: - return get_mychannel_policy(short_channel_id, node_id, my_channels) - - def get_channel_info(self, short_channel_id: ShortChannelID, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]: + policy = get_mychannel_policy(short_channel_id, node_id, my_channels) + if policy: + return policy + if private_route_edges: + route_edge = private_route_edges.get(short_channel_id, None) + if route_edge: + return Policy.from_route_edge(route_edge) + + def get_channel_info( + self, + short_channel_id: ShortChannelID, + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, + ) -> Optional[ChannelInfo]: ret = self._channels.get(short_channel_id) if ret: return ret # check if it's one of our own channels if my_channels: - return get_mychannel_info(short_channel_id, my_channels) - - def get_channels_for_node(self, node_id: bytes, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]: + channel_info = get_mychannel_info(short_channel_id, my_channels) + if channel_info: + return channel_info + if private_route_edges: + route_edge = private_route_edges.get(short_channel_id) + if route_edge: + return ChannelInfo.from_route_edge(route_edge) + + def get_channels_for_node( + self, + node_id: bytes, + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, + ) -> Set[bytes]: """Returns the set of short channel IDs where node_id is one of the channel participants.""" if not self.data_loaded.is_set(): raise Exception("channelDB data not loaded yet!") relevant_channels = self._channels_for_node.get(node_id) or set() relevant_channels = set(relevant_channels) # copy # add our own channels # TODO maybe slow? - for chan in (my_channels.values() or []): - if node_id in (chan.node_id, chan.get_local_pubkey()): - relevant_channels.add(chan.short_channel_id) + if my_channels: + for chan in my_channels.values(): + if node_id in (chan.node_id, chan.get_local_pubkey()): + relevant_channels.add(chan.short_channel_id) + # add private channels # TODO maybe slow? + if private_route_edges: + for route_edge in private_route_edges.values(): + if node_id in (route_edge.start_node, route_edge.end_node): + relevant_channels.add(route_edge.short_channel_id) return relevant_channels def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *, diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -55,10 +55,14 @@ def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_propor @attr.s(slots=True) class PathEdge: - """if you travel through short_channel_id, you will reach node_id""" - node_id = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex()) + start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex()) + end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex()) short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val)) + @property + def node_id(self) -> bytes: + # legacy compat # TODO rm + return self.end_node @attr.s class RouteEdge(PathEdge): @@ -73,17 +77,26 @@ class RouteEdge(PathEdge): fee_proportional_millionths=self.fee_proportional_millionths) @classmethod - def from_channel_policy(cls, channel_policy: 'Policy', - short_channel_id: bytes, end_node: bytes, *, - node_info: Optional[NodeInfo]) -> 'RouteEdge': + def from_channel_policy( + cls, + *, + channel_policy: 'Policy', + short_channel_id: bytes, + start_node: bytes, + end_node: bytes, + node_info: Optional[NodeInfo], # for end_node + ) -> 'RouteEdge': assert isinstance(short_channel_id, bytes) + assert type(start_node) is bytes assert type(end_node) is bytes - return RouteEdge(node_id=end_node, - short_channel_id=ShortChannelID.normalize(short_channel_id), - fee_base_msat=channel_policy.fee_base_msat, - fee_proportional_millionths=channel_policy.fee_proportional_millionths, - cltv_expiry_delta=channel_policy.cltv_expiry_delta, - node_features=node_info.features if node_info else 0) + return RouteEdge( + start_node=start_node, + end_node=end_node, + short_channel_id=ShortChannelID.normalize(short_channel_id), + fee_base_msat=channel_policy.fee_base_msat, + fee_proportional_millionths=channel_policy.fee_proportional_millionths, + cltv_expiry_delta=channel_policy.cltv_expiry_delta, + node_features=node_info.features if node_info else 0) def is_sane_to_use(self, amount_msat: int) -> bool: # TODO revise ad-hoc heuristics @@ -155,21 +168,37 @@ class LNPathFinder(Logger): Logger.__init__(self) self.channel_db = channel_db - def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, - payment_amt_msat: int, ignore_costs=False, is_mine=False, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]: + def _edge_cost( + self, + *, + short_channel_id: bytes, + start_node: bytes, + end_node: bytes, + payment_amt_msat: int, + ignore_costs=False, + is_mine=False, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Tuple[float, int]: """Heuristic cost (distance metric) of going through a channel. Returns (heuristic_cost, fee_for_edge_msat). """ - channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels) + if private_route_edges is None: + private_route_edges = {} + channel_info = self.channel_db.get_channel_info( + short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges) if channel_info is None: return float('inf'), 0 - channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels) + channel_policy = self.channel_db.get_policy_for_node( + short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges) if channel_policy is None: return float('inf'), 0 # channels that did not publish both policies often return temporary channel failure - if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \ - and not is_mine: + channel_policy_backwards = self.channel_db.get_policy_for_node( + short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges) + if (channel_policy_backwards is None + and not is_mine + and short_channel_id not in private_route_edges): return float('inf'), 0 if channel_policy.is_disabled(): return float('inf'), 0 @@ -181,9 +210,15 @@ class LNPathFinder(Logger): if channel_policy.htlc_maximum_msat is not None and \ payment_amt_msat > channel_policy.htlc_maximum_msat: return float('inf'), 0 # payment amount too large - node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) - route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node, - node_info=node_info) + route_edge = private_route_edges.get(short_channel_id, None) + if route_edge is None: + node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) + route_edge = RouteEdge.from_channel_policy( + channel_policy=channel_policy, + short_channel_id=short_channel_id, + start_node=start_node, + end_node=end_node, + node_info=node_info) if not route_edge.is_sane_to_use(payment_amt_msat): return float('inf'), 0 # thanks but no thanks @@ -201,9 +236,16 @@ class LNPathFinder(Logger): overall_cost = base_cost + fee_msat + cltv_cost return overall_cost, fee_msat - def get_distances(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None, - blacklist: Set[ShortChannelID] = None) -> Dict[bytes, PathEdge]: + def get_distances( + self, + *, + nodeA: bytes, + nodeB: bytes, + invoice_amount_msat: int, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Dict[bytes, PathEdge]: # note: we don't lock self.channel_db, so while the path finding runs, # the underlying graph could potentially change... (not good but maybe ~OK?) @@ -226,11 +268,13 @@ class LNPathFinder(Logger): # so instead of decreasing priorities, we add items again into the queue. # so there are duplicates in the queue, that we discard now: continue - for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels): + for edge_channel_id in self.channel_db.get_channels_for_node( + edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges): assert isinstance(edge_channel_id, bytes) if blacklist and edge_channel_id in blacklist: continue - channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels) + channel_info = self.channel_db.get_channel_info( + edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges) edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id is_mine = edge_channel_id in my_channels if is_mine: @@ -242,29 +286,37 @@ class LNPathFinder(Logger): if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True): continue edge_cost, fee_for_edge_msat = self._edge_cost( - edge_channel_id, + short_channel_id=edge_channel_id, start_node=edge_startnode, end_node=edge_endnode, payment_amt_msat=amount_msat, ignore_costs=(edge_startnode == nodeA), is_mine=is_mine, - my_channels=my_channels) + my_channels=my_channels, + private_route_edges=private_route_edges) alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost if alt_dist_to_neighbour < distance_from_start[edge_startnode]: distance_from_start[edge_startnode] = alt_dist_to_neighbour - prev_node[edge_startnode] = PathEdge(node_id=edge_endnode, - short_channel_id=ShortChannelID(edge_channel_id)) + prev_node[edge_startnode] = PathEdge( + start_node=edge_startnode, + end_node=edge_endnode, + short_channel_id=ShortChannelID(edge_channel_id)) amount_to_forward_msat = amount_msat + fee_for_edge_msat nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode)) return prev_node @profiler - def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, - invoice_amount_msat: int, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None, - blacklist: Set[ShortChannelID] = None) \ - -> Optional[LNPaymentPath]: + def find_path_for_payment( + self, + *, + nodeA: bytes, + nodeB: bytes, + invoice_amount_msat: int, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Optional[LNPaymentPath]: """Return a path from nodeA to nodeB.""" assert type(nodeA) is bytes assert type(nodeB) is bytes @@ -272,7 +324,13 @@ class LNPathFinder(Logger): if my_channels is None: my_channels = {} - prev_node = self.get_distances(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) + prev_node = self.get_distances( + nodeA=nodeA, + nodeB=nodeB, + invoice_amount_msat=invoice_amount_msat, + my_channels=my_channels, + blacklist=blacklist, + private_route_edges=private_route_edges) if nodeA not in prev_node: return None # no path found @@ -287,34 +345,66 @@ class LNPathFinder(Logger): edge_startnode = edge.node_id return path - def create_route_from_path(self, path: Optional[LNPaymentPath], from_node_id: bytes, *, - my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute: - assert isinstance(from_node_id, bytes) + def create_route_from_path( + self, + path: Optional[LNPaymentPath], + *, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> LNPaymentRoute: if path is None: raise Exception('cannot create route from None path') + if private_route_edges is None: + private_route_edges = {} route = [] - prev_node_id = from_node_id - for edge in path: - node_id = edge.node_id - short_channel_id = edge.short_channel_id + prev_end_node = path[0].start_node + for path_edge in path: + short_channel_id = path_edge.short_channel_id _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels) - if _endnodes and sorted(_endnodes) != sorted([prev_node_id, node_id]): + if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]): + raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id") + if path_edge.start_node != prev_end_node: raise LNPathInconsistent("edges do not chain together") - channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id, - node_id=prev_node_id, - my_channels=my_channels) - if channel_policy is None: - raise NoChannelPolicy(short_channel_id) - node_info = self.channel_db.get_node_info_for_node_id(node_id=node_id) - route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id, - node_info=node_info)) - prev_node_id = node_id + route_edge = private_route_edges.get(short_channel_id, None) + if route_edge is None: + channel_policy = self.channel_db.get_policy_for_node( + short_channel_id=short_channel_id, + node_id=path_edge.start_node, + my_channels=my_channels) + if channel_policy is None: + raise NoChannelPolicy(short_channel_id) + node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node) + route_edge = RouteEdge.from_channel_policy( + channel_policy=channel_policy, + short_channel_id=short_channel_id, + start_node=path_edge.start_node, + end_node=path_edge.end_node, + node_info=node_info) + route.append(route_edge) + prev_end_node = path_edge.end_node return route - def find_route(self, nodeA: bytes, nodeB: bytes, invoice_amount_msat: int, *, - path = None, my_channels: Dict[ShortChannelID, 'Channel'] = None, - blacklist: Set[ShortChannelID] = None) -> Optional[LNPaymentRoute]: + def find_route( + self, + *, + nodeA: bytes, + nodeB: bytes, + invoice_amount_msat: int, + path = None, + my_channels: Dict[ShortChannelID, 'Channel'] = None, + blacklist: Set[ShortChannelID] = None, + private_route_edges: Dict[ShortChannelID, RouteEdge] = None, + ) -> Optional[LNPaymentRoute]: + route = None if not path: - path = self.find_path_for_payment(nodeA, nodeB, invoice_amount_msat, my_channels=my_channels, blacklist=blacklist) + path = self.find_path_for_payment( + nodeA=nodeA, + nodeB=nodeB, + invoice_amount_msat=invoice_amount_msat, + my_channels=my_channels, + blacklist=blacklist, + private_route_edges=private_route_edges) if path: - return self.create_route_from_path(path, nodeA, my_channels=my_channels) + route = self.create_route_from_path( + path, my_channels=my_channels, private_route_edges=private_route_edges) + return route diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -1320,6 +1320,7 @@ class LNWallet(LNWorker): amount_msat=amount_msat, bucket_amount_msat=amount_msat, min_cltv_expiry=min_cltv_expiry, + my_pubkey=self.node_keypair.pubkey, invoice_pubkey=invoice_pubkey, invoice_features=invoice_features, node_id=chan.node_id, @@ -1336,7 +1337,8 @@ class LNWallet(LNWorker): continue route = [ RouteEdge( - node_id=chan.node_id, + start_node=self.node_keypair.pubkey, + end_node=chan.node_id, short_channel_id=chan.short_channel_id, fee_base_msat=0, fee_proportional_millionths=0, @@ -1383,6 +1385,7 @@ class LNWallet(LNWorker): amount_msat=amount_msat, bucket_amount_msat=bucket_amount_msat, min_cltv_expiry=min_cltv_expiry, + my_pubkey=self.node_keypair.pubkey, invoice_pubkey=invoice_pubkey, invoice_features=invoice_features, node_id=node_id, @@ -1404,7 +1407,8 @@ class LNWallet(LNWorker): trampoline_fee -= delta_fee route = [ RouteEdge( - node_id=node_id, + start_node=self.node_keypair.pubkey, + end_node=node_id, short_channel_id=chan.short_channel_id, fee_base_msat=0, fee_proportional_millionths=0, @@ -1447,77 +1451,65 @@ class LNWallet(LNWorker): full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]: channels = [outgoing_channel] if outgoing_channel else list(self.channels.values()) - route = None scid_to_my_channels = { chan.short_channel_id: chan for chan in channels if chan.short_channel_id is not None } blacklist = self.network.channel_blacklist.get_current_list() - # first try with routing hints, then without - for private_path in r_tags + [[]]: - private_route = [] - amount_for_node = amount_msat - path = full_path - if len(private_path) > NUM_MAX_EDGES_IN_PAYMENT_PATH: - continue - if len(private_path) == 0: - border_node_pubkey = invoice_pubkey - else: - border_node_pubkey = private_path[0][0] - # we need to shift the node pubkey by one towards the destination: - private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey] - private_path_rest = [edge[1:] for edge in private_path] - prev_node_id = border_node_pubkey - for node_pubkey, edge_rest in zip(private_path_nodes, private_path_rest): - short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest - short_channel_id = ShortChannelID(short_channel_id) - # if we have a routing policy for this edge in the db, that takes precedence, - # as it is likely from a previous failure - channel_policy = self.channel_db.get_policy_for_node( + # Collect all private edges from route hints. + # Note: if some route hints are multiple edges long, and these paths cross each other, + # we allow our path finding to cross the paths; i.e. the route hints are not isolated. + private_route_edges = {} # type: Dict[ShortChannelID, RouteEdge] + for private_path in r_tags: + # we need to shift the node pubkey by one towards the destination: + private_path_nodes = [edge[0] for edge in private_path][1:] + [invoice_pubkey] + private_path_rest = [edge[1:] for edge in private_path] + start_node = private_path[0][0] + for end_node, edge_rest in zip(private_path_nodes, private_path_rest): + short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest + short_channel_id = ShortChannelID(short_channel_id) + # if we have a routing policy for this edge in the db, that takes precedence, + # as it is likely from a previous failure + channel_policy = self.channel_db.get_policy_for_node( + short_channel_id=short_channel_id, + node_id=start_node, + my_channels=scid_to_my_channels) + if channel_policy: + fee_base_msat = channel_policy.fee_base_msat + fee_proportional_millionths = channel_policy.fee_proportional_millionths + cltv_expiry_delta = channel_policy.cltv_expiry_delta + node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node) + route_edge = RouteEdge( + start_node=start_node, + end_node=end_node, short_channel_id=short_channel_id, - node_id=prev_node_id, - my_channels=scid_to_my_channels) - if channel_policy: - fee_base_msat = channel_policy.fee_base_msat - fee_proportional_millionths = channel_policy.fee_proportional_millionths - cltv_expiry_delta = channel_policy.cltv_expiry_delta - node_info = self.channel_db.get_node_info_for_node_id(node_id=node_pubkey) - private_route.append( - RouteEdge( - node_id=node_pubkey, - short_channel_id=short_channel_id, - fee_base_msat=fee_base_msat, - fee_proportional_millionths=fee_proportional_millionths, - cltv_expiry_delta=cltv_expiry_delta, - node_features=node_info.features if node_info else 0)) - prev_node_id = node_pubkey - for edge in private_route[::-1]: - amount_for_node += edge.fee_for_edge(amount_for_node) - if full_path: - # user pre-selected path. check that end of given path coincides with private_route: - if [edge.short_channel_id for edge in full_path[-len(private_path):]] != [edge[1] for edge in private_path]: - continue - path = full_path[:-len(private_path)] - if any(edge.short_channel_id in blacklist for edge in private_route): - continue - try: - route = self.network.path_finder.find_route( - self.node_keypair.pubkey, border_node_pubkey, amount_for_node, - path=path, my_channels=scid_to_my_channels, blacklist=blacklist) - except NoChannelPolicy: - continue - if not route: - continue - route = route + private_route - # test sanity - if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry): - self.logger.info(f"rejecting insane route {route}") - continue - break - else: + fee_base_msat=fee_base_msat, + fee_proportional_millionths=fee_proportional_millionths, + cltv_expiry_delta=cltv_expiry_delta, + node_features=node_info.features if node_info else 0) + if route_edge.short_channel_id not in blacklist: + private_route_edges[route_edge.short_channel_id] = route_edge + start_node = end_node + # now find a route, end to end: between us and the recipient + try: + route = self.network.path_finder.find_route( + nodeA=self.node_keypair.pubkey, + nodeB=invoice_pubkey, + invoice_amount_msat=amount_msat, + path=full_path, + my_channels=scid_to_my_channels, + blacklist=blacklist, + private_route_edges=private_route_edges) + except NoChannelPolicy as e: + raise NoPathFound() from e + if not route: + raise NoPathFound() + # test sanity + if not is_route_sane_to_use(route, amount_msat, min_cltv_expiry): + self.logger.info(f"rejecting insane route {route}") raise NoPathFound() assert len(route) > 0 - if route[-1].node_id != invoice_pubkey: + if route[-1].end_node != invoice_pubkey: raise LNPathInconsistent("last node_id != invoice pubkey") # add features from invoice route[-1].node_features |= invoice_features diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -600,17 +600,27 @@ class TestPeer(ElectrumTestCase): peers = graph.all_peers() async def pay(pay_req): with self.subTest(msg="bad path: edges do not chain together"): - path = [PathEdge(node_id=graph.w_c.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), - PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] + path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, + end_node=graph.w_c.node_keypair.pubkey, + short_channel_id=graph.chan_ab.short_channel_id), + PathEdge(start_node=graph.w_b.node_keypair.pubkey, + end_node=graph.w_d.node_keypair.pubkey, + short_channel_id=graph.chan_bd.short_channel_id)] with self.assertRaises(LNPathInconsistent): await graph.w_a.pay_invoice(pay_req, full_path=path) with self.subTest(msg="bad path: last node id differs from invoice pubkey"): - path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id)] + path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, + end_node=graph.w_b.node_keypair.pubkey, + short_channel_id=graph.chan_ab.short_channel_id)] with self.assertRaises(LNPathInconsistent): await graph.w_a.pay_invoice(pay_req, full_path=path) with self.subTest(msg="good path"): - path = [PathEdge(node_id=graph.w_b.node_keypair.pubkey, short_channel_id=graph.chan_ab.short_channel_id), - PathEdge(node_id=graph.w_d.node_keypair.pubkey, short_channel_id=graph.chan_bd.short_channel_id)] + path = [PathEdge(start_node=graph.w_a.node_keypair.pubkey, + end_node=graph.w_b.node_keypair.pubkey, + short_channel_id=graph.chan_ab.short_channel_id), + PathEdge(start_node=graph.w_b.node_keypair.pubkey, + end_node=graph.w_d.node_keypair.pubkey, + short_channel_id=graph.chan_bd.short_channel_id)] result, log = await graph.w_a.pay_invoice(pay_req, full_path=path) self.assertTrue(result) self.assertEqual( diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py @@ -83,12 +83,14 @@ class Test_LNRouter(TestCaseForTestnet): cdb.add_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) cdb.add_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150, 'chain_hash': BitcoinTestnet.rev_genesis_bytes(), 'timestamp': 0}) - path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000) - self.assertEqual([PathEdge(node_id=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')), - PathEdge(node_id=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')), + path = path_finder.find_path_for_payment( + nodeA=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + nodeB=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', + invoice_amount_msat=100000) + self.assertEqual([PathEdge(start_node=b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', end_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', short_channel_id=bfh('0000000000000003')), + PathEdge(start_node=b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', end_node=b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', short_channel_id=bfh('0000000000000002')), ], path) - start_node = b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' - route = path_finder.create_route_from_path(path, start_node) + route = path_finder.create_route_from_path(path) self.assertEqual(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', route[0].node_id) self.assertEqual(bfh('0000000000000003'), route[0].short_channel_id) diff --git a/electrum/trampoline.py b/electrum/trampoline.py @@ -61,11 +61,13 @@ def encode_routing_info(r_tags): def create_trampoline_route( + *, amount_msat:int, bucket_amount_msat:int, min_cltv_expiry:int, invoice_pubkey:bytes, invoice_features:int, + my_pubkey: bytes, trampoline_node_id, r_tags, t_tags, trampoline_fee_level, @@ -106,7 +108,8 @@ def create_trampoline_route( # trampoline hop route.append( TrampolineEdge( - node_id=trampoline_node_id, + start_node=my_pubkey, + end_node=trampoline_node_id, fee_base_msat=params['fee_base_msat'], fee_proportional_millionths=params['fee_proportional_millionths'], cltv_expiry_delta=params['cltv_expiry_delta'], @@ -114,7 +117,8 @@ def create_trampoline_route( if trampoline2: route.append( TrampolineEdge( - node_id=trampoline2, + start_node=trampoline_node_id, + end_node=trampoline2, fee_base_msat=params['fee_base_msat'], fee_proportional_millionths=params['fee_proportional_millionths'], cltv_expiry_delta=params['cltv_expiry_delta'], @@ -130,7 +134,8 @@ def create_trampoline_route( if route[-1].node_id != pubkey: route.append( TrampolineEdge( - node_id=pubkey, + start_node=route[-1].node_id, + end_node=pubkey, fee_base_msat=feebase, fee_proportional_millionths=feerate, cltv_expiry_delta=cltv, @@ -138,7 +143,8 @@ def create_trampoline_route( # Fake edge (not part of actual route, needed by calc_hops_data) route.append( TrampolineEdge( - node_id=invoice_pubkey, + start_node=route[-1].end_node, + end_node=invoice_pubkey, fee_base_msat=0, fee_proportional_millionths=0, cltv_expiry_delta=0, @@ -194,6 +200,7 @@ def create_trampoline_route_and_onion( min_cltv_expiry, invoice_pubkey, invoice_features, + my_pubkey: bytes, node_id, r_tags, t_tags, payment_hash, @@ -203,15 +210,17 @@ def create_trampoline_route_and_onion( trampoline2_list): # create route for the trampoline_onion trampoline_route = create_trampoline_route( - amount_msat, - bucket_amount_msat, - min_cltv_expiry, - invoice_pubkey, - invoice_features, - node_id, - r_tags, t_tags, - trampoline_fee_level, - trampoline2_list) + amount_msat=amount_msat, + bucket_amount_msat=bucket_amount_msat, + min_cltv_expiry=min_cltv_expiry, + my_pubkey=my_pubkey, + invoice_pubkey=invoice_pubkey, + invoice_features=invoice_features, + trampoline_node_id=node_id, + r_tags=r_tags, + t_tags=t_tags, + trampoline_fee_level=trampoline_fee_level, + trampoline2_list=trampoline2_list) # compute onion and fees final_cltv = local_height + min_cltv_expiry trampoline_onion, bucket_amount_with_fees, bucket_cltv = create_trampoline_onion(