electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 815079efe025d33bbecc17ba0959f233f0e54376
parent 5b1a5e8786664eba0476453de166243073b62fd9
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue, 17 Apr 2018 20:01:51 +0200

refactor storage of channels, path finding

Diffstat:
Mlib/lnbase.py | 261+++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------
Mlib/tests/test_lnbase.py | 4++--
2 files changed, 170 insertions(+), 95 deletions(-)

diff --git a/lib/lnbase.py b/lib/lnbase.py @@ -419,10 +419,8 @@ class Peer(PrintError): self.localfeatures = (0x08 if request_initial_sync else 0) # view of the network self.nodes = {} # received node announcements - self.channels = {} # received channel announcements - self.channel_u_origin = {} - self.channel_u_final = {} - self.graph_of_payment_channels = defaultdict(set) # node -> short_channel_id + self.channel_db = ChannelDB() + self.path_finder = LNPathFinder(self.channel_db) def diagnostic_name(self): return self.host @@ -541,8 +539,8 @@ class Peer(PrintError): def on_funding_signed(self, payload): sig = payload['signature'] channel_id = payload['channel_id'] - tx = self.channels[channel_id] - self.network.broadcast(tx) + #tx = self.channels[channel_id] # FIXME + #self.network.broadcast(tx) def on_funding_signed(self, payload): self.funding_signed[payload["temporary_channel_id"]].set_result(payload) @@ -588,99 +586,14 @@ class Peer(PrintError): pass def on_channel_update(self, payload): - flags = int.from_bytes(payload['flags'], byteorder="big") - direction = bool(flags & 1) - short_channel_id = payload['short_channel_id'] - if direction == 0: - self.channel_u_origin[short_channel_id] = payload - else: - self.channel_u_final[short_channel_id] = payload - self.print_error('channel update', binascii.hexlify(short_channel_id), flags) + self.channel_db.on_channel_update(payload) def on_channel_announcement(self, payload): - short_channel_id = payload['short_channel_id'] - self.print_error('channel announcement', binascii.hexlify(short_channel_id)) - self.channels[short_channel_id] = payload - self.add_channel_to_graph(payload) - - def add_channel_to_graph(self, payload): - node1 = payload['node_id_1'] - node2 = payload['node_id_2'] - channel_id = payload['short_channel_id'] - self.graph_of_payment_channels[node1].add(channel_id) - self.graph_of_payment_channels[node2].add(channel_id) + self.channel_db.on_channel_announcement(payload) #def open_channel(self, funding_sat, push_msat): # self.send_message(gen_msg('open_channel', funding_satoshis=funding_sat, push_msat=push_msat)) - @profiler - def find_route_for_payment(self, from_node_id, to_node_id, amount_msat=None): - """Return a route between from_node_id and to_node_id. - - Returns a list of (node_id, short_channel_id) representing a path. - To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1] - """ - # TODO find multiple paths?? - - def edge_cost(short_channel_id, direction): - """Heuristic cost of going through a channel. - direction: 0 or 1. --- 0 means node_id_1 -> node_id_2 - """ - channel_updates = self.channel_u_origin if direction == 0 else self.channel_u_final - try: - cltv_expiry_delta = channel_updates[short_channel_id]['cltv_expiry_delta'] - htlc_minimum_msat = channel_updates[short_channel_id]['htlc_minimum_msat'] - fee_base_msat = channel_updates[short_channel_id]['fee_base_msat'] - fee_proportional_millionths = channel_updates[short_channel_id]['fee_proportional_millionths'] - except KeyError: - return float('inf') # can't use this channel - if amount_msat is not None and amount_msat < htlc_minimum_msat: - return float('inf') # can't use this channel - amt = amount_msat or 50000 * 1000 # guess for typical payment amount - fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000 - # TODO revise - # paying 10 more satoshis ~ waiting one more block - fee_cost = fee_msat / 1000 / 10 - cltv_cost = cltv_expiry_delta - return cltv_cost + fee_cost + 1 - - # run Dijkstra - distance_from_start = defaultdict(lambda: float('inf')) - distance_from_start[from_node_id] = 0 - prev_node = {} - nodes_to_explore = queue.PriorityQueue() - nodes_to_explore.put((0, from_node_id)) - - while nodes_to_explore.qsize() > 0: - dist_to_cur_node, cur_node = nodes_to_explore.get() - if cur_node == to_node_id: - break - if dist_to_cur_node != distance_from_start[cur_node]: - # queue.PriorityQueue does not implement decrease_priority, - # so instead of decreasing priorities, we add items again into the queue. - # so there are duplicates in the queue, that we discard now: - continue - for edge in self.graph_of_payment_channels[cur_node]: - node1 = self.channels[edge]['node_id_1'] - node2 = self.channels[edge]['node_id_2'] - neighbour, direction = (node1, 1) if node1 != cur_node else (node2, 0) - alt_dist_to_neighbour = distance_from_start[cur_node] + edge_cost(edge, direction) - if alt_dist_to_neighbour < distance_from_start[neighbour]: - distance_from_start[neighbour] = alt_dist_to_neighbour - prev_node[neighbour] = cur_node, edge - nodes_to_explore.put((alt_dist_to_neighbour, neighbour)) - else: - return None # no path found - - # backtrack from end to start - cur_node = to_node_id - path = [(cur_node, None)] - while cur_node != from_node_id: - cur_node, edge_taken = prev_node[cur_node] - path += [(cur_node, edge_taken)] - path.reverse() - return path - @aiosafe async def main_loop(self): self.reader, self.writer = await asyncio.open_connection(self.host, self.port) @@ -792,3 +705,165 @@ class LNWorker: # todo: get utxo from wallet # submit coro to asyncio main loop self.peer.open_channel() + + +class ChannelInfo(PrintError): + + def __init__(self, channel_announcement_payload): + self.channel_id = channel_announcement_payload['short_channel_id'] + self.node_id_1 = channel_announcement_payload['node_id_1'] + self.node_id_2 = channel_announcement_payload['node_id_2'] + + self.capacity_sat = None + self.policy_node1 = None + self.policy_node2 = None + + def set_capacity(self, capacity): + # TODO call this after looking up UTXO for funding txn on chain + self.capacity_sat = capacity + + def on_channel_update(self, msg_payload): + assert self.channel_id == msg_payload['short_channel_id'] + flags = int.from_bytes(msg_payload['flags'], byteorder="big") + direction = bool(flags & 1) + if direction == 0: + self.policy_node1 = ChannelInfoDirectedPolicy(msg_payload) + else: + self.policy_node2 = ChannelInfoDirectedPolicy(msg_payload) + self.print_error('channel update', binascii.hexlify(self.channel_id), flags) + + def get_policy_for_node(self, node_id): + if node_id == self.node_id_1: + return self.policy_node1 + elif node_id == self.node_id_2: + return self.policy_node2 + else: + raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id)) + + +class ChannelInfoDirectedPolicy: + + def __init__(self, channel_update_payload): + self.cltv_expiry_delta = channel_update_payload['cltv_expiry_delta'] + self.htlc_minimum_msat = channel_update_payload['htlc_minimum_msat'] + self.fee_base_msat = channel_update_payload['fee_base_msat'] + self.fee_proportional_millionths = channel_update_payload['fee_proportional_millionths'] + + +class ChannelDB(PrintError): + + def __init__(self): + self._id_to_channel_info = {} + self._channels_for_node = defaultdict(set) # node -> set(short_channel_id) + + def get_channel_info(self, channel_id): + return self._id_to_channel_info.get(channel_id, None) + + def get_channels_for_node(self, node_id): + """Returns the set of channels that have node_id as one of the endpoints.""" + return self._channels_for_node[node_id] + + def on_channel_announcement(self, msg_payload): + short_channel_id = msg_payload['short_channel_id'] + self.print_error('channel announcement', binascii.hexlify(short_channel_id)) + channel_info = ChannelInfo(msg_payload) + self._id_to_channel_info[short_channel_id] = channel_info + self._channels_for_node[channel_info.node_id_1].add(short_channel_id) + self._channels_for_node[channel_info.node_id_2].add(short_channel_id) + + def on_channel_update(self, msg_payload): + short_channel_id = msg_payload['short_channel_id'] + self._id_to_channel_info[short_channel_id].on_channel_update(msg_payload) + + def remove_channel(self, short_channel_id): + try: + channel_info = self._id_to_channel_info[short_channel_id] + except KeyError: + self.print_error('cannot find channel {}'.format(short_channel_id)) + return + self._id_to_channel_info.pop(short_channel_id, None) + for node in (channel_info.node_id_1, channel_info.node_id_2): + try: + self._channels_for_node[node].remove(short_channel_id) + except KeyError: + pass + + +class LNPathFinder(PrintError): + + def __init__(self, channel_db): + self.channel_db = channel_db + + def _edge_cost(self, short_channel_id, start_node, payment_amt_msat): + """Heuristic cost of going through a channel. + direction: 0 or 1. --- 0 means node_id_1 -> node_id_2 + """ + channel_info = self.channel_db.get_channel_info(short_channel_id) + if channel_info is None: + return float('inf') + + channel_policy = channel_info.get_policy_for_node(start_node) + cltv_expiry_delta = channel_policy.cltv_expiry_delta + htlc_minimum_msat = channel_policy.htlc_minimum_msat + fee_base_msat = channel_policy.fee_base_msat + fee_proportional_millionths = channel_policy.fee_proportional_millionths + if payment_amt_msat is not None: + if payment_amt_msat < htlc_minimum_msat: + return float('inf') # payment amount too little + if channel_info.capacity_sat is not None and \ + payment_amt_msat // 1000 > channel_info.capacity_sat: + return float('inf') # payment amount too large + amt = payment_amt_msat or 50000 * 1000 # guess for typical payment amount + fee_msat = fee_base_msat + amt * fee_proportional_millionths / 1000000 + # TODO revise + # paying 10 more satoshis ~ waiting one more block + fee_cost = fee_msat / 1000 / 10 + cltv_cost = cltv_expiry_delta + return cltv_cost + fee_cost + 1 + + @profiler + def find_path_for_payment(self, from_node_id, to_node_id, amount_msat=None): + """Return a path between from_node_id and to_node_id. + + Returns a list of (node_id, short_channel_id) representing a path. + To get from node ret[n][0] to ret[n+1][0], use channel ret[n][1] + """ + # TODO find multiple paths?? + + # run Dijkstra + distance_from_start = defaultdict(lambda: float('inf')) + distance_from_start[from_node_id] = 0 + prev_node = {} + nodes_to_explore = queue.PriorityQueue() + nodes_to_explore.put((0, from_node_id)) + + while nodes_to_explore.qsize() > 0: + dist_to_cur_node, cur_node = nodes_to_explore.get() + if cur_node == to_node_id: + break + if dist_to_cur_node != distance_from_start[cur_node]: + # queue.PriorityQueue does not implement decrease_priority, + # so instead of decreasing priorities, we add items again into the queue. + # so there are duplicates in the queue, that we discard now: + continue + for edge_channel_id in self.channel_db.get_channels_for_node(cur_node): + channel_info = self.channel_db.get_channel_info(edge_channel_id) + node1, node2 = channel_info.node_id_1, channel_info.node_id_2 + neighbour = node2 if node1 == cur_node else node1 + alt_dist_to_neighbour = distance_from_start[cur_node] \ + + self._edge_cost(edge_channel_id, cur_node, amount_msat) + if alt_dist_to_neighbour < distance_from_start[neighbour]: + distance_from_start[neighbour] = alt_dist_to_neighbour + prev_node[neighbour] = cur_node, edge_channel_id + nodes_to_explore.put((alt_dist_to_neighbour, neighbour)) + else: + return None # no path found + + # backtrack from end to start + cur_node = to_node_id + path = [(cur_node, None)] + while cur_node != from_node_id: + cur_node, edge_taken = prev_node[cur_node] + path += [(cur_node, edge_taken)] + path.reverse() + return path diff --git a/lib/tests/test_lnbase.py b/lib/tests/test_lnbase.py @@ -181,7 +181,7 @@ class Test_LNBase(unittest.TestCase): # local_signature = 30440220549e80b4496803cbc4a1d09d46df50109f546d43fbbf86cd90b174b1484acd5402205f12a4f995cb9bded597eabfee195a285986aa6d93ae5bb72507ebc6a4e2349e output_htlc_success_tx_4 = "020000000001018154ecccf11a5fb56c39654c4deb4d2296f83c69268280b94d021370c94e219704000000000000000001a00f0000000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e050047304402207e0410e45454b0978a623f36a10626ef17b27d9ad44e2760f98cfa3efb37924f0220220bd8acd43ecaa916a80bd4f919c495a2c58982ce7c8625153f8596692a801d014730440220549e80b4496803cbc4a1d09d46df50109f546d43fbbf86cd90b174b1484acd5402205f12a4f995cb9bded597eabfee195a285986aa6d93ae5bb72507ebc6a4e2349e012004040404040404040404040404040404040404040404040404040404040404048a76a91414011f7254d96b819c76986c277d115efce6f7b58763ac67210394854aa6eab5b2a8122cc726e9dded053a2184d88256816826d6231c068d4a5b7c8201208763a91418bc1a114ccf9c052d3d23e28d3b0a9d1227434288527c21030d417a46946384f88d5f3337267c5e579765875dc4daca813e21734b140639e752ae677502f801b175ac686800000000" - def test_find_route_for_payment(self): + def test_find_path_for_payment(self): p = Peer('', 0, 'a') p.on_channel_announcement({'node_id_1': 'b', 'node_id_2': 'c', 'short_channel_id': bfh('0000000000000001')}) p.on_channel_announcement({'node_id_1': 'b', 'node_id_2': 'e', 'short_channel_id': bfh('0000000000000002')}) @@ -201,7 +201,7 @@ class Test_LNBase(unittest.TestCase): p.on_channel_update({'short_channel_id': bfh('0000000000000005'), 'flags': b'1', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 999}) p.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'flags': b'0', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 99999999}) p.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'flags': b'1', 'cltv_expiry_delta': 10, 'htlc_minimum_msat': 250, 'fee_base_msat': 100, 'fee_proportional_millionths': 150}) - print(p.find_route_for_payment('a', 'e', 100000)) + print(p.path_finder.find_path_for_payment('a', 'e', 100000)) def test_key_derivation(self): # BOLT3, Appendix E