electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 2fafd01945569cb0ef1cad9320e426874fb40f7d
parent c577df84898c470d45287b39ef484d16cfed4cc4
Author: SomberNight <somber.night@protonmail.com>
Date:   Fri, 19 Oct 2018 21:47:51 +0200

protect against getting robbed through routing fees

Diffstat:
Melectrum/lnonion.py | 10++++++----
Melectrum/lnrouter.py | 105+++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------
Melectrum/lnutil.py | 9++++++---
Melectrum/lnworker.py | 18++++++++++++++++--
4 files changed, 102 insertions(+), 40 deletions(-)

diff --git a/electrum/lnonion.py b/electrum/lnonion.py @@ -33,11 +33,10 @@ from cryptography.hazmat.backends import default_backend from . import ecc from .crypto import sha256, hmac_oneshot from .util import bh2u, profiler, xor_bytes, bfh -from .lnutil import get_ecdh +from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH from .lnrouter import RouteEdge -NUM_MAX_HOPS_IN_PATH = 20 HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04 PER_HOP_FULL_SIZE = 65 # HOPS_DATA_SIZE / 20 NUM_STREAM_BYTES = HOPS_DATA_SIZE + PER_HOP_FULL_SIZE @@ -192,6 +191,9 @@ def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_c """Returns the hops_data to be used for constructing an onion packet, and the amount_msat and cltv to be used on our immediate channel. """ + if len(route) > NUM_MAX_HOPS_IN_PAYMENT_PATH: + raise PaymentFailure(f"too long route ({len(route)} hops)") + amt = amount_msat cltv = final_cltv hops_data = [OnionHopsDataSingle(OnionPerHop(b"\x00" * 8, @@ -209,7 +211,7 @@ def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_c def generate_filler(key_type: bytes, num_hops: int, hop_size: int, shared_secrets: Sequence[bytes]) -> bytes: - filler_size = (NUM_MAX_HOPS_IN_PATH + 1) * hop_size + filler_size = (NUM_MAX_HOPS_IN_PAYMENT_PATH + 1) * hop_size filler = bytearray(filler_size) for i in range(0, num_hops-1): # -1, as last hop does not obfuscate @@ -219,7 +221,7 @@ def generate_filler(key_type: bytes, num_hops: int, hop_size: int, stream_bytes = generate_cipher_stream(stream_key, filler_size) filler = xor_bytes(filler, stream_bytes) - return filler[(NUM_MAX_HOPS_IN_PATH-num_hops+2)*hop_size:] + return filler[(NUM_MAX_HOPS_IN_PAYMENT_PATH-num_hops+2)*hop_size:] def generate_cipher_stream(stream_key: bytes, num_bytes: int) -> bytes: diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -39,7 +39,7 @@ from .storage import JsonDB from .lnchannelverifier import LNChannelVerifier, verify_sig_for_channel_update from .crypto import Hash from . import ecc -from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr +from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_HOPS_IN_PAYMENT_PATH class UnknownEvenFeatureBits(Exception): pass @@ -502,10 +502,61 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), ('cltv_expiry_delta', int)])): """if you travel through short_channel_id, you will reach node_id""" - def fee_for_edge(self, amount_msat): + def fee_for_edge(self, amount_msat: int) -> int: return self.fee_base_msat \ + (amount_msat * self.fee_proportional_millionths // 1_000_000) + @classmethod + def from_channel_policy(cls, channel_policy: ChannelInfoDirectedPolicy, + short_channel_id: bytes, end_node: bytes) -> 'RouteEdge': + return RouteEdge(end_node, + short_channel_id, + channel_policy.fee_base_msat, + channel_policy.fee_proportional_millionths, + channel_policy.cltv_expiry_delta) + + def is_sane_to_use(self, amount_msat: int) -> bool: + # TODO revise ad-hoc heuristics + # cltv cannot be more than 2 weeks + if self.cltv_expiry_delta > 14 * 144: return False + total_fee = self.fee_for_edge(amount_msat) + # fees below 50 sat are fine + if total_fee > 50_000: + # fee cannot be higher than amt + if total_fee > amount_msat: return False + # fee cannot be higher than 5000 sat + if total_fee > 5_000_000: return False + # unless amt is tiny, fee cannot be more than 10% + if amount_msat > 1_000_000 and total_fee > amount_msat/10: return False + return True + + +def is_route_sane_to_use(route: List[RouteEdge], invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool: + """Run some sanity checks on the whole route, before attempting to use it. + called when we are paying; so e.g. lower cltv is better + """ + if len(route) > NUM_MAX_HOPS_IN_PAYMENT_PATH: + return False + amt = invoice_amount_msat + cltv = min_final_cltv_expiry + for route_edge in reversed(route[1:]): + if not route_edge.is_sane_to_use(amt): return False + amt += route_edge.fee_for_edge(amt) + cltv += route_edge.cltv_expiry_delta + total_fee = amt - invoice_amount_msat + # TODO revise ad-hoc heuristics + # cltv cannot be more than 2 months + if cltv > 60 * 144: return False + # fees below 50 sat are fine + if total_fee > 50_000: + # fee cannot be higher than amt + if total_fee > invoice_amount_msat: return False + # fee cannot be higher than 5000 sat + if total_fee > 5_000_000: return False + # unless amt is tiny, fee cannot be more than 10% + if invoice_amount_msat > 1_000_000 and total_fee > invoice_amount_msat/10: return False + return True + class LNPathFinder(PrintError): @@ -513,11 +564,9 @@ class LNPathFinder(PrintError): self.channel_db = channel_db self.blacklist = set() - def _edge_cost(self, short_channel_id: bytes, start_node: bytes, payment_amt_msat: int, - ignore_cltv=False) -> float: - """Heuristic cost of going through a channel. - direction: 0 or 1. --- 0 means node_id_1 -> node_id_2 - """ + def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, + payment_amt_msat: int, ignore_cltv=False) -> float: + """Heuristic cost of going through a channel.""" channel_info = self.channel_db.get_channel_info(short_channel_id) # type: ChannelInfo if channel_info is None: return float('inf') @@ -525,41 +574,39 @@ class LNPathFinder(PrintError): channel_policy = channel_info.get_policy_for_node(start_node) if channel_policy is None: return float('inf') if channel_policy.disabled: return float('inf') - cltv_expiry_delta = channel_policy.cltv_expiry_delta - htlc_minimum_msat = channel_policy.htlc_minimum_msat - fee_base_msat = channel_policy.fee_base_msat - fee_proportional_millionths = channel_policy.fee_proportional_millionths - if payment_amt_msat is not None: - if payment_amt_msat < htlc_minimum_msat: - return float('inf') # payment amount too little - if channel_info.capacity_sat is not None and \ - payment_amt_msat // 1000 > channel_info.capacity_sat: - return float('inf') # payment amount too large - if channel_policy.htlc_maximum_msat is not None and \ - payment_amt_msat > channel_policy.htlc_maximum_msat: - return float('inf') # payment amount too large - amt = payment_amt_msat or 50000 * 1000 # guess for typical payment amount - fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1_000_000 + route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node) + if payment_amt_msat < channel_policy.htlc_minimum_msat: + return float('inf') # payment amount too little + if channel_info.capacity_sat is not None and \ + payment_amt_msat // 1000 > channel_info.capacity_sat: + return float('inf') # payment amount too large + if channel_policy.htlc_maximum_msat is not None and \ + payment_amt_msat > channel_policy.htlc_maximum_msat: + return float('inf') # payment amount too large + if not route_edge.is_sane_to_use(payment_amt_msat): + return float('inf') # thanks but no thanks + fee_msat = route_edge.fee_for_edge(payment_amt_msat) # TODO revise # paying 10 more satoshis ~ waiting one more block fee_cost = fee_msat / 1000 / 10 - cltv_cost = cltv_expiry_delta if not ignore_cltv else 0 + cltv_cost = route_edge.cltv_expiry_delta if not ignore_cltv else 0 return cltv_cost + fee_cost + 1 @profiler def find_path_for_payment(self, from_node_id: bytes, to_node_id: bytes, - amount_msat: int=None, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]: + amount_msat: int, my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]: """Return a path between from_node_id and to_node_id. Returns a list of (node_id, short_channel_id) representing a path. To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; i.e. an element reads as, "to get to node_id, travel through short_channel_id" """ - if amount_msat is not None: assert type(amount_msat) is int + assert type(amount_msat) is int if my_channels is None: my_channels = [] unable_channels = set(map(lambda x: x.short_channel_id, filter(lambda x: not x.can_pay(amount_msat), my_channels))) # TODO find multiple paths?? + # FIXME paths cannot be longer than 20 (onion packet)... # run Dijkstra distance_from_start = defaultdict(lambda: float('inf')) @@ -584,7 +631,7 @@ class LNPathFinder(PrintError): node1, node2 = channel_info.node_id_1, channel_info.node_id_2 neighbour = node2 if node1 == cur_node else node1 ignore_cltv_delta_in_edge_cost = cur_node == from_node_id - edge_cost = self._edge_cost(edge_channel_id, cur_node, amount_msat, + edge_cost = self._edge_cost(edge_channel_id, cur_node, neighbour, amount_msat, ignore_cltv=ignore_cltv_delta_in_edge_cost) alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost if alt_dist_to_neighbour < distance_from_start[neighbour]: @@ -614,10 +661,6 @@ class LNPathFinder(PrintError): channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) if channel_policy is None: raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}') - route.append(RouteEdge(node_id, - short_channel_id, - channel_policy.fee_base_msat, - channel_policy.fee_proportional_millionths, - channel_policy.cltv_expiry_delta)) + route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id)) prev_node_id = node_id return route diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -1,7 +1,7 @@ from enum import IntFlag, IntEnum import json from collections import namedtuple -from typing import NamedTuple, List, Tuple, Mapping +from typing import NamedTuple, List, Tuple, Mapping, Optional import re from .util import bfh, bh2u, inv_dict @@ -16,6 +16,7 @@ from .i18n import _ from .lnaddr import lndecode from .keystore import BIP32_KeyStore + HTLC_TIMEOUT_WEIGHT = 663 HTLC_SUCCESS_WEIGHT = 703 @@ -597,8 +598,6 @@ def generate_keypair(ln_keystore: BIP32_KeyStore, key_family: LnKeyFamily, index return Keypair(*ln_keystore.get_keypair([key_family, 0, index], None)) -from typing import Optional - class EncumberedTransaction(NamedTuple("EncumberedTransaction", [('tx', Transaction), ('csv_delay', Optional[int])])): def to_json(self) -> dict: @@ -612,3 +611,7 @@ class EncumberedTransaction(NamedTuple("EncumberedTransaction", [('tx', Transact d2 = dict(d) d2['tx'] = Transaction(d['tx']) return EncumberedTransaction(**d2) + + +NUM_MAX_HOPS_IN_PAYMENT_PATH = 20 + diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -25,10 +25,11 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, get_compressed_pubkey_from_bech32, extract_nodeid, PaymentFailure, split_host_port, ConnStringFormatError, generate_keypair, LnKeyFamily, LOCAL, REMOTE, - UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE) + UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, + NUM_MAX_HOPS_IN_PAYMENT_PATH) from .lnaddr import lndecode from .i18n import _ -from .lnrouter import RouteEdge +from .lnrouter import RouteEdge, is_route_sane_to_use NUM_PEERS_TARGET = 4 PEER_RETRY_INTERVAL = 600 # seconds @@ -253,6 +254,10 @@ class LNWorker(PrintError): if amount_sat is None: raise InvoiceError(_("Missing amount")) amount_msat = int(amount_sat * 1000) + if addr.get_min_final_cltv_expiry() > 60 * 144: + raise InvoiceError("{}\n{}".format( + _("Invoice wants us to risk locking funds for unreasonably long."), + f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}")) route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat) node_id, short_channel_id = route[0].node_id, route[0].short_channel_id peer = self.peers[node_id] @@ -281,6 +286,7 @@ class LNWorker(PrintError): channels = list(self.channels.values()) for private_route in r_tags: if len(private_route) == 0: continue + if len(private_route) > NUM_MAX_HOPS_IN_PAYMENT_PATH: continue border_node_pubkey = private_route[0][0] path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels) if not path: continue @@ -301,6 +307,11 @@ class LNWorker(PrintError): route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta)) prev_node_id = node_pubkey + # test sanity + if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): + self.print_error(f"rejecting insane route {route}") + route = None + continue break # if could not find route using any hint; try without hint now if route is None: @@ -308,6 +319,9 @@ class LNWorker(PrintError): if not path: raise PaymentFailure(_("No path found")) route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) + if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): + self.print_error(f"rejecting insane route {route}") + raise PaymentFailure(_("No path found")) return route def add_invoice(self, amount_sat, message):