lnrouter.py (18008B)
# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import queue
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import time
import attr

from .util import bh2u, profiler
from .logging import Logger
from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
                     NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE)
from .channel_db import ChannelDB, Policy, NodeInfo

if TYPE_CHECKING:
    from .lnchannel import Channel


class NoChannelPolicy(Exception):
    """Raised when no routing policy can be found for a channel."""

    def __init__(self, short_channel_id: bytes):
        short_channel_id = ShortChannelID.normalize(short_channel_id)
        super().__init__(f'cannot find channel policy for short_channel_id: {short_channel_id}')


class LNPathInconsistent(Exception):
    """Raised when a path's edges disagree with the channel graph or do not chain together."""


def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
    """Return the fee (msat) for forwarding `forwarded_amount_msat` over one edge.

    fee = fee_base_msat + floor(forwarded_amount_msat * fee_proportional_millionths / 1_000_000)
    """
    return fee_base_msat \
           + (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)


@attr.s(slots=True)
class PathEdge:
    """A directed hop of a payment path: channel `short_channel_id` used from start_node to end_node."""
    start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
    end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
    short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val))

    @property
    def node_id(self) -> bytes:
        # legacy compat # TODO rm
        return self.end_node


@attr.s
class RouteEdge(PathEdge):
    """A PathEdge enriched with the forwarding policy needed to build an onion hop."""
    fee_base_msat = attr.ib(type=int, kw_only=True)
    fee_proportional_millionths = attr.ib(type=int, kw_only=True)
    cltv_expiry_delta = attr.ib(type=int, kw_only=True)
    node_features = attr.ib(type=int, kw_only=True, repr=lambda val: str(int(val)))  # note: for end node!

    def fee_for_edge(self, amount_msat: int) -> int:
        """Return the fee (msat) this edge charges for forwarding `amount_msat`."""
        return fee_for_edge_msat(forwarded_amount_msat=amount_msat,
                                 fee_base_msat=self.fee_base_msat,
                                 fee_proportional_millionths=self.fee_proportional_millionths)

    @classmethod
    def from_channel_policy(
            cls,
            *,
            channel_policy: 'Policy',
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            node_info: Optional[NodeInfo],  # for end_node
    ) -> 'RouteEdge':
        """Build a RouteEdge from the channel policy published by start_node.

        node_features is taken from `node_info` (describing end_node), or 0 if unknown.
        """
        assert isinstance(short_channel_id, bytes)
        assert type(start_node) is bytes
        assert type(end_node) is bytes
        return RouteEdge(
            start_node=start_node,
            end_node=end_node,
            short_channel_id=ShortChannelID.normalize(short_channel_id),
            fee_base_msat=channel_policy.fee_base_msat,
            fee_proportional_millionths=channel_policy.fee_proportional_millionths,
            cltv_expiry_delta=channel_policy.cltv_expiry_delta,
            node_features=node_info.features if node_info else 0)

    def is_sane_to_use(self, amount_msat: int) -> bool:
        """Return False if the edge's cltv delta or its fee for `amount_msat` looks unreasonable."""
        # TODO revise ad-hoc heuristics
        # cltv cannot be more than 2 weeks
        if self.cltv_expiry_delta > 14 * 144:
            return False
        total_fee = self.fee_for_edge(amount_msat)
        if not is_fee_sane(total_fee, payment_amount_msat=amount_msat):
            return False
        return True

    def has_feature_varonion(self) -> bool:
        """Return True if end_node advertises the optional VAR_ONION feature."""
        features = LnFeatures(self.node_features)
        return features.supports(LnFeatures.VAR_ONION_OPT)

    def is_trampoline(self) -> bool:
        return False


@attr.s
class TrampolineEdge(RouteEdge):
    """A RouteEdge to a trampoline node, optionally carrying invoice routing data."""
    invoice_routing_info = attr.ib(type=bytes, default=None)
    invoice_features = attr.ib(type=int, default=None)
    # this is re-defined from parent just to specify a default value:
    short_channel_id = attr.ib(default=ShortChannelID(8), repr=lambda val: str(val))

    def is_trampoline(self):
        return True


LNPaymentPath = Sequence[PathEdge]
LNPaymentRoute = Sequence[RouteEdge]


def is_route_sane_to_use(route: LNPaymentRoute, invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
    """Run some sanity checks on the whole route, before attempting to use it.
    called when we are paying; so e.g. lower cltv is better

    The first edge is skipped when accumulating fees/cltv: it is not charged a fee here.
    Returns False if the route is too long, any later edge fails
    RouteEdge.is_sane_to_use, the total cltv exceeds
    NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE, or the total fee is not sane.
    """
    if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
        return False
    amt = invoice_amount_msat
    cltv = min_final_cltv_expiry
    # walk backwards from the recipient, so each edge sees the amount it must forward
    for route_edge in reversed(route[1:]):
        if not route_edge.is_sane_to_use(amt):
            return False
        amt += route_edge.fee_for_edge(amt)
        cltv += route_edge.cltv_expiry_delta
    total_fee = amt - invoice_amount_msat
    # TODO revise ad-hoc heuristics
    if cltv > NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
        return False
    if not is_fee_sane(total_fee, payment_amount_msat=invoice_amount_msat):
        return False
    return True


def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
    """Return True if fee_msat is at most 5 sat, or at most 1% of payment_amount_msat."""
    # fees <= 5 sat are fine
    if fee_msat <= 5_000:
        return True
    # fees <= 1 % of payment are fine
    if 100 * fee_msat <= payment_amount_msat:
        return True
    return False


class LNPathFinder(Logger):
    """Finds payment paths/routes over the graph exposed by a ChannelDB."""

    def __init__(self, channel_db: ChannelDB):
        Logger.__init__(self)
        self.channel_db = channel_db

    def _edge_cost(
            self,
            *,
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            payment_amt_msat: int,
            ignore_costs=False,
            is_mine=False,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Tuple[float, int]:
        """Heuristic cost (distance metric) of going through a channel.
        Returns (heuristic_cost, fee_for_edge_msat).

        The edge is traversed from start_node to end_node, forwarding
        payment_amt_msat. An unusable edge yields (float('inf'), 0): unknown
        channel, no policy for start_node, no policy for end_node (unless
        is_mine or the edge is in private_route_edges), disabled policy,
        amount below htlc_minimum_msat or above capacity / htlc_maximum_msat,
        or an edge failing RouteEdge.is_sane_to_use.
        With ignore_costs, a usable edge costs only the base cost and no fee.
        """
        if private_route_edges is None:
            private_route_edges = {}
        channel_info = self.channel_db.get_channel_info(
            short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_info is None:
            return float('inf'), 0
        channel_policy = self.channel_db.get_policy_for_node(
            short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_policy is None:
            return float('inf'), 0
        # channels that did not publish both policies often return temporary channel failure
        channel_policy_backwards = self.channel_db.get_policy_for_node(
            short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if (channel_policy_backwards is None
                and not is_mine
                and short_channel_id not in private_route_edges):
            return float('inf'), 0
        if channel_policy.is_disabled():
            return float('inf'), 0
        if payment_amt_msat < channel_policy.htlc_minimum_msat:
            return float('inf'), 0  # payment amount too little
        # compare in msat: flooring the amount to sat would let up to 999 msat over capacity through
        if (channel_info.capacity_sat is not None
                and payment_amt_msat > channel_info.capacity_sat * 1000):
            return float('inf'), 0  # payment amount too large
        if (channel_policy.htlc_maximum_msat is not None
                and payment_amt_msat > channel_policy.htlc_maximum_msat):
            return float('inf'), 0  # payment amount too large
        route_edge = private_route_edges.get(short_channel_id, None)
        if route_edge is None:
            node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
            route_edge = RouteEdge.from_channel_policy(
                channel_policy=channel_policy,
                short_channel_id=short_channel_id,
                start_node=start_node,
                end_node=end_node,
                node_info=node_info)
        if not route_edge.is_sane_to_use(payment_amt_msat):
            return float('inf'), 0  # thanks but no thanks

        # Distance metric notes:  # TODO constants are ad-hoc
        # ( somewhat based on https://github.com/lightningnetwork/lnd/pull/1358 )
        # - Edges have a base cost. (more edges -> less likely none will fail)
        # - The larger the payment amount, and the longer the CLTV,
        #   the more irritating it is if the HTLC gets stuck.
        # - Paying lower fees is better. :)
        base_cost = 500  # one more edge ~ paying 500 msat more fees
        if ignore_costs:
            return base_cost, 0
        fee_msat = route_edge.fee_for_edge(payment_amt_msat)
        cltv_cost = route_edge.cltv_expiry_delta * payment_amt_msat * 15 / 1_000_000_000
        overall_cost = base_cost + fee_msat + cltv_cost
        return overall_cost, fee_msat

    def get_distances(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Dict[bytes, PathEdge]:
        """Run Dijkstra from nodeB back towards nodeA.

        Returns a dict mapping each reached node to the PathEdge leaving it
        in the direction of nodeB; nodeA is present iff a path was found.
        Channels in `blacklist` are skipped. The search stops as soon as
        nodeA is popped from the queue.
        """
        # note: we don't lock self.channel_db, so while the path finding runs,
        #       the underlying graph could potentially change... (not good but maybe ~OK?)
        if my_channels is None:
            my_channels = {}  # membership tests below require a mapping

        # run Dijkstra
        # The search is run in the REVERSE direction, from nodeB to nodeA,
        # to properly calculate compound routing fees.
        distance_from_start = defaultdict(lambda: float('inf'))
        distance_from_start[nodeB] = 0
        prev_node = {}  # type: Dict[bytes, PathEdge]
        nodes_to_explore = queue.PriorityQueue()
        # each entry carries the amount that must arrive at that node (invoice amount + downstream fees)
        nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!

        # main loop of search
        while nodes_to_explore.qsize() > 0:
            dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
            if edge_endnode == nodeA:
                break
            if dist_to_edge_endnode != distance_from_start[edge_endnode]:
                # queue.PriorityQueue does not implement decrease_priority,
                # so instead of decreasing priorities, we add items again into the queue.
                # so there are duplicates in the queue, that we discard now:
                continue
            for edge_channel_id in self.channel_db.get_channels_for_node(
                    edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
                assert isinstance(edge_channel_id, bytes)
                if blacklist and edge_channel_id in blacklist:
                    continue
                channel_info = self.channel_db.get_channel_info(
                    edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
                if channel_info is None:
                    continue
                edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
                is_mine = edge_channel_id in my_channels
                if is_mine:
                    if edge_startnode == nodeA:  # payment outgoing, on our channel
                        if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True):
                            continue
                    else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                        assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                        if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
                            continue
                edge_cost, fee_for_edge_msat = self._edge_cost(
                    short_channel_id=edge_channel_id,
                    start_node=edge_startnode,
                    end_node=edge_endnode,
                    payment_amt_msat=amount_msat,
                    ignore_costs=(edge_startnode == nodeA),  # our own first hop charges no fee
                    is_mine=is_mine,
                    my_channels=my_channels,
                    private_route_edges=private_route_edges)
                alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
                if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                    distance_from_start[edge_startnode] = alt_dist_to_neighbour
                    prev_node[edge_startnode] = PathEdge(
                        start_node=edge_startnode,
                        end_node=edge_endnode,
                        short_channel_id=ShortChannelID(edge_channel_id))
                    amount_to_forward_msat = amount_msat + fee_for_edge_msat
                    nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))

        return prev_node

    @profiler
    def find_path_for_payment(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentPath]:
        """Return a path from nodeA to nodeB.

        The path is a list of PathEdge ordered from nodeA to nodeB,
        or None if no path was found.
        """
        assert type(nodeA) is bytes
        assert type(nodeB) is bytes
        assert type(invoice_amount_msat) is int
        if my_channels is None:
            my_channels = {}

        prev_node = self.get_distances(
            nodeA=nodeA,
            nodeB=nodeB,
            invoice_amount_msat=invoice_amount_msat,
            my_channels=my_channels,
            blacklist=blacklist,
            private_route_edges=private_route_edges)

        if nodeA not in prev_node:
            return None  # no path found

        # backtrack from search_end (nodeA) to search_start (nodeB)
        # FIXME paths cannot be longer than 20 edges (onion packet)...
        edge_startnode = nodeA
        path = []
        while edge_startnode != nodeB:
            edge = prev_node[edge_startnode]
            path.append(edge)
            edge_startnode = edge.end_node
        return path

    def create_route_from_path(
            self,
            path: Optional[LNPaymentPath],
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> LNPaymentRoute:
        """Convert a path into a route by attaching each edge's forwarding policy.

        An edge present in private_route_edges is used as-is; otherwise the
        policy published by the edge's start_node is looked up in channel_db.

        Raises:
            Exception: if path is None or empty.
            LNPathInconsistent: if an edge's endpoints disagree with the channel
                graph, or consecutive edges do not chain together.
            NoChannelPolicy: if no policy is found for an edge.
        """
        if path is None:
            raise Exception('cannot create route from None path')
        if len(path) == 0:
            raise Exception('cannot create route from empty path')
        if private_route_edges is None:
            private_route_edges = {}
        route = []
        prev_end_node = path[0].start_node
        for path_edge in path:
            short_channel_id = path_edge.short_channel_id
            _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
            if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
                raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
            if path_edge.start_node != prev_end_node:
                raise LNPathInconsistent("edges do not chain together")
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge is None:
                channel_policy = self.channel_db.get_policy_for_node(
                    short_channel_id=short_channel_id,
                    node_id=path_edge.start_node,
                    my_channels=my_channels)
                if channel_policy is None:
                    raise NoChannelPolicy(short_channel_id)
                node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
                route_edge = RouteEdge.from_channel_policy(
                    channel_policy=channel_policy,
                    short_channel_id=short_channel_id,
                    start_node=path_edge.start_node,
                    end_node=path_edge.end_node,
                    node_info=node_info)
            route.append(route_edge)
            prev_end_node = path_edge.end_node
        return route

    def find_route(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            path: Optional[LNPaymentPath] = None,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            blacklist: Set[ShortChannelID] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentRoute]:
        """Return a route from nodeA to nodeB, or None if no path is found.

        If `path` is falsy, one is searched with find_path_for_payment;
        otherwise the given path is converted into a route.
        """
        route = None
        if not path:
            path = self.find_path_for_payment(
                nodeA=nodeA,
                nodeB=nodeB,
                invoice_amount_msat=invoice_amount_msat,
                my_channels=my_channels,
                blacklist=blacklist,
                private_route_edges=private_route_edges)
        if path:
            route = self.create_route_from_path(
                path, my_channels=my_channels, private_route_edges=private_route_edges)
        return route