commit 97393d05aa55a2095b7e6323aa7d4fc5a6014723
parent 029ec5a5ab4546f8ecd03119cde567208952edd0
Author: SomberNight <somber.night@protonmail.com>
Date: Mon, 8 Oct 2018 20:36:46 +0200
use 'r' field in invoice when making payments (routing hints)
Diffstat:
3 files changed, 49 insertions(+), 22 deletions(-)
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -1018,19 +1018,18 @@ class Peer(PrintError):
await self.receive_commitment(chan)
self.revoke(chan)
- async def pay(self, path, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry):
+ async def pay(self, route: List[RouteEdge], chan, amount_msat, payment_hash, min_final_cltv_expiry):
assert chan.get_state() == "OPEN", chan.get_state()
assert amount_msat > 0, "amount_msat is not greater zero"
height = self.network.get_local_height()
- route = self.network.path_finder.create_route_from_path(path, self.lnworker.node_keypair.pubkey)
hops_data = []
- sum_of_deltas = sum(route_edge.channel_policy.cltv_expiry_delta for route_edge in route[1:])
+ sum_of_deltas = sum(route_edge.cltv_expiry_delta for route_edge in route[1:])
total_fee = 0
final_cltv_expiry_without_deltas = (height + min_final_cltv_expiry)
final_cltv_expiry_with_deltas = final_cltv_expiry_without_deltas + sum_of_deltas
for idx, route_edge in enumerate(route[1:]):
hops_data += [OnionHopsDataSingle(OnionPerHop(route_edge.short_channel_id, amount_msat.to_bytes(8, "big"), final_cltv_expiry_without_deltas.to_bytes(4, "big")))]
- total_fee += route_edge.channel_policy.fee_base_msat + ( amount_msat * route_edge.channel_policy.fee_proportional_millionths // 1000000 )
+ total_fee += route_edge.fee_base_msat + ( amount_msat * route_edge.fee_proportional_millionths // 1000000 )
associated_data = payment_hash
secret_key = os.urandom(32)
hops_data += [OnionHopsDataSingle(OnionPerHop(b"\x00"*8, amount_msat.to_bytes(8, "big"), (final_cltv_expiry_without_deltas).to_bytes(4, "big")))]
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -28,7 +28,7 @@ import os
import json
import threading
from collections import namedtuple, defaultdict
-from typing import Sequence, Union, Tuple, Optional
+from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple
import binascii
import base64
import asyncio
@@ -478,14 +478,13 @@ class ChannelDB(JsonDB):
direction))
-class RouteEdge:
-
- def __init__(self, node_id: bytes, short_channel_id: bytes,
- channel_policy: ChannelInfoDirectedPolicy):
- # "if you travel through short_channel_id, you will reach node_id"
- self.node_id = node_id
- self.short_channel_id = short_channel_id
- self.channel_policy = channel_policy
+class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
+ ('short_channel_id', bytes),
+ ('fee_base_msat', int),
+ ('fee_proportional_millionths', int),
+ ('cltv_expiry_delta', int)])):
+ """if you travel through short_channel_id, you will reach node_id"""
+ pass
class LNPathFinder(PrintError):
@@ -578,7 +577,7 @@ class LNPathFinder(PrintError):
path.reverse()
return path
- def create_route_from_path(self, path, from_node_id: bytes) -> Sequence[RouteEdge]:
+ def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:
assert type(from_node_id) is bytes
if path is None:
raise Exception('cannot create route from None path')
@@ -591,6 +590,10 @@ class LNPathFinder(PrintError):
channel_policy = channel_info.get_policy_for_node(prev_node_id)
if channel_policy is None:
raise Exception('cannot find channel policy for short_channel_id: {}'.format(bh2u(short_channel_id)))
- route.append(RouteEdge(node_id, short_channel_id, channel_policy))
+ route.append(RouteEdge(node_id,
+ short_channel_id,
+ channel_policy.fee_base_msat,
+ channel_policy.fee_proportional_millionths,
+ channel_policy.cltv_expiry_delta))
prev_node_id = node_id
return route
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -27,6 +27,7 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
from .lnutil import LOCAL, REMOTE
from .lnaddr import lndecode
from .i18n import _
+from .lnrouter import RouteEdge
NUM_PEERS_TARGET = 4
@@ -237,16 +238,12 @@ class LNWorker(PrintError):
def pay(self, invoice, amount_sat=None):
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
payment_hash = addr.paymenthash
- invoice_pubkey = addr.pubkey.serialize()
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
if amount_sat is None:
raise InvoiceError(_("Missing amount"))
amount_msat = int(amount_sat * 1000)
- # TODO use 'r' field from invoice
- path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat)
- if path is None:
- raise PaymentFailure(_("No path found"))
- node_id, short_channel_id = path[0]
+ route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
+ node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
peer = self.peers[node_id]
with self.lock:
channels = list(self.channels.values())
@@ -255,9 +252,37 @@ class LNWorker(PrintError):
break
else:
raise Exception("ChannelDB returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
- coro = peer.pay(path, chan, amount_msat, payment_hash, invoice_pubkey, addr.min_final_cltv_expiry)
+ coro = peer.pay(route, chan, amount_msat, payment_hash, addr.min_final_cltv_expiry)
return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
+ def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
+ invoice_pubkey = decoded_invoice.pubkey.serialize()
+ # use 'r' field from invoice
+ route = None # type: List[RouteEdge]
+ for tag_type, data in decoded_invoice.tags:
+ if tag_type != 'r': continue
+ private_route = data
+ if len(private_route) == 0: continue
+ border_node_pubkey = private_route[0][0]
+ path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat)
+ if path is None: continue
+ route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
+ # we need to shift the node pubkey by one towards the destination:
+ private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
+ private_route_rest = [edge[1:] for edge in private_route]
+ for node_pubkey, edge_rest in zip(private_route_nodes, private_route_rest):
+ short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest
+ route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths,
+ cltv_expiry_delta))
+ break
+ # if could not find route using any hint; try without hint now
+ if route is None:
+ path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat)
+ if path is None:
+ raise PaymentFailure(_("No path found"))
+ route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
+ return route
+
def add_invoice(self, amount_sat, message):
payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage)