electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

lnrouter.py (18008B)


      1 # -*- coding: utf-8 -*-
      2 #
      3 # Electrum - lightweight Bitcoin client
      4 # Copyright (C) 2018 The Electrum developers
      5 #
      6 # Permission is hereby granted, free of charge, to any person
      7 # obtaining a copy of this software and associated documentation files
      8 # (the "Software"), to deal in the Software without restriction,
      9 # including without limitation the rights to use, copy, modify, merge,
     10 # publish, distribute, sublicense, and/or sell copies of the Software,
     11 # and to permit persons to whom the Software is furnished to do so,
     12 # subject to the following conditions:
     13 #
     14 # The above copyright notice and this permission notice shall be
     15 # included in all copies or substantial portions of the Software.
     16 #
     17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     24 # SOFTWARE.
     25 
     26 import queue
     27 from collections import defaultdict
     28 from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
     29 import time
     30 import attr
     31 
     32 from .util import bh2u, profiler
     33 from .logging import Logger
     34 from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
     35                      NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE)
     36 from .channel_db import ChannelDB, Policy, NodeInfo
     37 
     38 if TYPE_CHECKING:
     39     from .lnchannel import Channel
     40 
     41 
     42 class NoChannelPolicy(Exception):
     43     def __init__(self, short_channel_id: bytes):
     44         short_channel_id = ShortChannelID.normalize(short_channel_id)
     45         super().__init__(f'cannot find channel policy for short_channel_id: {short_channel_id}')
     46 
     47 
     48 class LNPathInconsistent(Exception): pass
     49 
     50 
     51 def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
     52     return fee_base_msat \
     53            + (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
     54 
     55 
@attr.s(slots=True)
class PathEdge:
    """One hop of a payment path: the channel `short_channel_id` traversed
    from `start_node` to `end_node`.

    Carries no fee/cltv policy data (see RouteEdge for that). Node ids are
    bytes and are rendered as hex in repr. All fields are keyword-only.
    """
    start_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
    end_node = attr.ib(type=bytes, kw_only=True, repr=lambda val: val.hex())
    short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=lambda val: str(val))

    @property
    def node_id(self) -> bytes:
        """Alias for `end_node`, kept for older callers."""
        # legacy compat  # TODO rm
        return self.end_node
     66 
     67 @attr.s
     68 class RouteEdge(PathEdge):
     69     fee_base_msat = attr.ib(type=int, kw_only=True)
     70     fee_proportional_millionths = attr.ib(type=int, kw_only=True)
     71     cltv_expiry_delta = attr.ib(type=int, kw_only=True)
     72     node_features = attr.ib(type=int, kw_only=True, repr=lambda val: str(int(val)))  # note: for end node!
     73 
     74     def fee_for_edge(self, amount_msat: int) -> int:
     75         return fee_for_edge_msat(forwarded_amount_msat=amount_msat,
     76                                  fee_base_msat=self.fee_base_msat,
     77                                  fee_proportional_millionths=self.fee_proportional_millionths)
     78 
     79     @classmethod
     80     def from_channel_policy(
     81             cls,
     82             *,
     83             channel_policy: 'Policy',
     84             short_channel_id: bytes,
     85             start_node: bytes,
     86             end_node: bytes,
     87             node_info: Optional[NodeInfo],  # for end_node
     88     ) -> 'RouteEdge':
     89         assert isinstance(short_channel_id, bytes)
     90         assert type(start_node) is bytes
     91         assert type(end_node) is bytes
     92         return RouteEdge(
     93             start_node=start_node,
     94             end_node=end_node,
     95             short_channel_id=ShortChannelID.normalize(short_channel_id),
     96             fee_base_msat=channel_policy.fee_base_msat,
     97             fee_proportional_millionths=channel_policy.fee_proportional_millionths,
     98             cltv_expiry_delta=channel_policy.cltv_expiry_delta,
     99             node_features=node_info.features if node_info else 0)
    100 
    101     def is_sane_to_use(self, amount_msat: int) -> bool:
    102         # TODO revise ad-hoc heuristics
    103         # cltv cannot be more than 2 weeks
    104         if self.cltv_expiry_delta > 14 * 144:
    105             return False
    106         total_fee = self.fee_for_edge(amount_msat)
    107         if not is_fee_sane(total_fee, payment_amount_msat=amount_msat):
    108             return False
    109         return True
    110 
    111     def has_feature_varonion(self) -> bool:
    112         features = LnFeatures(self.node_features)
    113         return features.supports(LnFeatures.VAR_ONION_OPT)
    114 
    115     def is_trampoline(self) -> bool:
    116         return False
    117 
@attr.s
class TrampolineEdge(RouteEdge):
    """A RouteEdge variant whose `is_trampoline()` returns True.

    Adds two optional invoice-related fields and gives `short_channel_id` a
    placeholder default, so it does not have to be supplied.
    """
    # NOTE(review): presumably routing info / feature bits taken from the invoice
    # being paid — confirm with callers. Both default to None despite the type hints.
    invoice_routing_info = attr.ib(type=bytes, default=None)
    invoice_features = attr.ib(type=int, default=None)
    # this is re-defined from parent just to specify a default value:
    # (ShortChannelID(8) is assumed to be an all-zero 8-byte placeholder — TODO confirm)
    short_channel_id = attr.ib(default=ShortChannelID(8), repr=lambda val: str(val))

    def is_trampoline(self) -> bool:
        return True
    127 
    128 
# A path: the sequence of hops only (no fee/cltv policy data).
LNPaymentPath = Sequence[PathEdge]
# A route: the same hops, each carrying the forwarding policy needed to pay through it.
LNPaymentRoute = Sequence[RouteEdge]
    131 
    132 
    133 def is_route_sane_to_use(route: LNPaymentRoute, invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
    134     """Run some sanity checks on the whole route, before attempting to use it.
    135     called when we are paying; so e.g. lower cltv is better
    136     """
    137     if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
    138         return False
    139     amt = invoice_amount_msat
    140     cltv = min_final_cltv_expiry
    141     for route_edge in reversed(route[1:]):
    142         if not route_edge.is_sane_to_use(amt): return False
    143         amt += route_edge.fee_for_edge(amt)
    144         cltv += route_edge.cltv_expiry_delta
    145     total_fee = amt - invoice_amount_msat
    146     # TODO revise ad-hoc heuristics
    147     if cltv > NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
    148         return False
    149     if not is_fee_sane(total_fee, payment_amount_msat=invoice_amount_msat):
    150         return False
    151     return True
    152 
    153 
    154 def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
    155     # fees <= 5 sat are fine
    156     if fee_msat <= 5_000:
    157         return True
    158     # fees <= 1 % of payment are fine
    159     if 100 * fee_msat <= payment_amount_msat:
    160         return True
    161     return False
    162 
    163 
    164 
    165 class LNPathFinder(Logger):
    166 
    167     def __init__(self, channel_db: ChannelDB):
    168         Logger.__init__(self)
    169         self.channel_db = channel_db
    170 
    171     def _edge_cost(
    172             self,
    173             *,
    174             short_channel_id: bytes,
    175             start_node: bytes,
    176             end_node: bytes,
    177             payment_amt_msat: int,
    178             ignore_costs=False,
    179             is_mine=False,
    180             my_channels: Dict[ShortChannelID, 'Channel'] = None,
    181             private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    182     ) -> Tuple[float, int]:
    183         """Heuristic cost (distance metric) of going through a channel.
    184         Returns (heuristic_cost, fee_for_edge_msat).
    185         """
    186         if private_route_edges is None:
    187             private_route_edges = {}
    188         channel_info = self.channel_db.get_channel_info(
    189             short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
    190         if channel_info is None:
    191             return float('inf'), 0
    192         channel_policy = self.channel_db.get_policy_for_node(
    193             short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
    194         if channel_policy is None:
    195             return float('inf'), 0
    196         # channels that did not publish both policies often return temporary channel failure
    197         channel_policy_backwards = self.channel_db.get_policy_for_node(
    198             short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
    199         if (channel_policy_backwards is None
    200                 and not is_mine
    201                 and short_channel_id not in private_route_edges):
    202             return float('inf'), 0
    203         if channel_policy.is_disabled():
    204             return float('inf'), 0
    205         if payment_amt_msat < channel_policy.htlc_minimum_msat:
    206             return float('inf'), 0  # payment amount too little
    207         if channel_info.capacity_sat is not None and \
    208                 payment_amt_msat // 1000 > channel_info.capacity_sat:
    209             return float('inf'), 0  # payment amount too large
    210         if channel_policy.htlc_maximum_msat is not None and \
    211                 payment_amt_msat > channel_policy.htlc_maximum_msat:
    212             return float('inf'), 0  # payment amount too large
    213         route_edge = private_route_edges.get(short_channel_id, None)
    214         if route_edge is None:
    215             node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
    216             route_edge = RouteEdge.from_channel_policy(
    217                 channel_policy=channel_policy,
    218                 short_channel_id=short_channel_id,
    219                 start_node=start_node,
    220                 end_node=end_node,
    221                 node_info=node_info)
    222         if not route_edge.is_sane_to_use(payment_amt_msat):
    223             return float('inf'), 0  # thanks but no thanks
    224 
    225         # Distance metric notes:  # TODO constants are ad-hoc
    226         # ( somewhat based on https://github.com/lightningnetwork/lnd/pull/1358 )
    227         # - Edges have a base cost. (more edges -> less likely none will fail)
    228         # - The larger the payment amount, and the longer the CLTV,
    229         #   the more irritating it is if the HTLC gets stuck.
    230         # - Paying lower fees is better. :)
    231         base_cost = 500  # one more edge ~ paying 500 msat more fees
    232         if ignore_costs:
    233             return base_cost, 0
    234         fee_msat = route_edge.fee_for_edge(payment_amt_msat)
    235         cltv_cost = route_edge.cltv_expiry_delta * payment_amt_msat * 15 / 1_000_000_000
    236         overall_cost = base_cost + fee_msat + cltv_cost
    237         return overall_cost, fee_msat
    238 
    239     def get_distances(
    240             self,
    241             *,
    242             nodeA: bytes,
    243             nodeB: bytes,
    244             invoice_amount_msat: int,
    245             my_channels: Dict[ShortChannelID, 'Channel'] = None,
    246             blacklist: Set[ShortChannelID] = None,
    247             private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    248     ) -> Dict[bytes, PathEdge]:
    249         # note: we don't lock self.channel_db, so while the path finding runs,
    250         #       the underlying graph could potentially change... (not good but maybe ~OK?)
    251 
    252         # run Dijkstra
    253         # The search is run in the REVERSE direction, from nodeB to nodeA,
    254         # to properly calculate compound routing fees.
    255         distance_from_start = defaultdict(lambda: float('inf'))
    256         distance_from_start[nodeB] = 0
    257         prev_node = {}  # type: Dict[bytes, PathEdge]
    258         nodes_to_explore = queue.PriorityQueue()
    259         nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!
    260 
    261         # main loop of search
    262         while nodes_to_explore.qsize() > 0:
    263             dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
    264             if edge_endnode == nodeA:
    265                 break
    266             if dist_to_edge_endnode != distance_from_start[edge_endnode]:
    267                 # queue.PriorityQueue does not implement decrease_priority,
    268                 # so instead of decreasing priorities, we add items again into the queue.
    269                 # so there are duplicates in the queue, that we discard now:
    270                 continue
    271             for edge_channel_id in self.channel_db.get_channels_for_node(
    272                     edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
    273                 assert isinstance(edge_channel_id, bytes)
    274                 if blacklist and edge_channel_id in blacklist:
    275                     continue
    276                 channel_info = self.channel_db.get_channel_info(
    277                     edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
    278                 if channel_info is None:
    279                     continue
    280                 edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
    281                 is_mine = edge_channel_id in my_channels
    282                 if is_mine:
    283                     if edge_startnode == nodeA:  # payment outgoing, on our channel
    284                         if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True):
    285                             continue
    286                     else:  # payment incoming, on our channel. (funny business, cycle weirdness)
    287                         assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
    288                         if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
    289                             continue
    290                 edge_cost, fee_for_edge_msat = self._edge_cost(
    291                     short_channel_id=edge_channel_id,
    292                     start_node=edge_startnode,
    293                     end_node=edge_endnode,
    294                     payment_amt_msat=amount_msat,
    295                     ignore_costs=(edge_startnode == nodeA),
    296                     is_mine=is_mine,
    297                     my_channels=my_channels,
    298                     private_route_edges=private_route_edges)
    299                 alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
    300                 if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
    301                     distance_from_start[edge_startnode] = alt_dist_to_neighbour
    302                     prev_node[edge_startnode] = PathEdge(
    303                         start_node=edge_startnode,
    304                         end_node=edge_endnode,
    305                         short_channel_id=ShortChannelID(edge_channel_id))
    306                     amount_to_forward_msat = amount_msat + fee_for_edge_msat
    307                     nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))
    308 
    309         return prev_node
    310 
    311     @profiler
    312     def find_path_for_payment(
    313             self,
    314             *,
    315             nodeA: bytes,
    316             nodeB: bytes,
    317             invoice_amount_msat: int,
    318             my_channels: Dict[ShortChannelID, 'Channel'] = None,
    319             blacklist: Set[ShortChannelID] = None,
    320             private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    321     ) -> Optional[LNPaymentPath]:
    322         """Return a path from nodeA to nodeB."""
    323         assert type(nodeA) is bytes
    324         assert type(nodeB) is bytes
    325         assert type(invoice_amount_msat) is int
    326         if my_channels is None:
    327             my_channels = {}
    328 
    329         prev_node = self.get_distances(
    330             nodeA=nodeA,
    331             nodeB=nodeB,
    332             invoice_amount_msat=invoice_amount_msat,
    333             my_channels=my_channels,
    334             blacklist=blacklist,
    335             private_route_edges=private_route_edges)
    336 
    337         if nodeA not in prev_node:
    338             return None  # no path found
    339 
    340         # backtrack from search_end (nodeA) to search_start (nodeB)
    341         # FIXME paths cannot be longer than 20 edges (onion packet)...
    342         edge_startnode = nodeA
    343         path = []
    344         while edge_startnode != nodeB:
    345             edge = prev_node[edge_startnode]
    346             path += [edge]
    347             edge_startnode = edge.node_id
    348         return path
    349 
    350     def create_route_from_path(
    351             self,
    352             path: Optional[LNPaymentPath],
    353             *,
    354             my_channels: Dict[ShortChannelID, 'Channel'] = None,
    355             private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    356     ) -> LNPaymentRoute:
    357         if path is None:
    358             raise Exception('cannot create route from None path')
    359         if private_route_edges is None:
    360             private_route_edges = {}
    361         route = []
    362         prev_end_node = path[0].start_node
    363         for path_edge in path:
    364             short_channel_id = path_edge.short_channel_id
    365             _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
    366             if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
    367                 raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
    368             if path_edge.start_node != prev_end_node:
    369                 raise LNPathInconsistent("edges do not chain together")
    370             route_edge = private_route_edges.get(short_channel_id, None)
    371             if route_edge is None:
    372                 channel_policy = self.channel_db.get_policy_for_node(
    373                     short_channel_id=short_channel_id,
    374                     node_id=path_edge.start_node,
    375                     my_channels=my_channels)
    376                 if channel_policy is None:
    377                     raise NoChannelPolicy(short_channel_id)
    378                 node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
    379                 route_edge = RouteEdge.from_channel_policy(
    380                     channel_policy=channel_policy,
    381                     short_channel_id=short_channel_id,
    382                     start_node=path_edge.start_node,
    383                     end_node=path_edge.end_node,
    384                     node_info=node_info)
    385             route.append(route_edge)
    386             prev_end_node = path_edge.end_node
    387         return route
    388 
    389     def find_route(
    390             self,
    391             *,
    392             nodeA: bytes,
    393             nodeB: bytes,
    394             invoice_amount_msat: int,
    395             path = None,
    396             my_channels: Dict[ShortChannelID, 'Channel'] = None,
    397             blacklist: Set[ShortChannelID] = None,
    398             private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    399     ) -> Optional[LNPaymentRoute]:
    400         route = None
    401         if not path:
    402             path = self.find_path_for_payment(
    403                 nodeA=nodeA,
    404                 nodeB=nodeB,
    405                 invoice_amount_msat=invoice_amount_msat,
    406                 my_channels=my_channels,
    407                 blacklist=blacklist,
    408                 private_route_edges=private_route_edges)
    409         if path:
    410             route = self.create_route_from_path(
    411                 path, my_channels=my_channels, private_route_edges=private_route_edges)
    412         return route