electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 180f6d34bec2f3e443488f922591d51c11cab1f6
parent 06b5299b0fe731e4e75555575b7dc3ce90ddd799
Author: ThomasV <thomasv@electrum.org>
Date:   Sat, 22 Jun 2019 09:47:08 +0200

separate channel_db module

Diffstat:
Aelectrum/channel_db.py | 589+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Melectrum/lnrouter.py | 535+------------------------------------------------------------------------------
Melectrum/network.py | 10++++++----
3 files changed, 596 insertions(+), 538 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -0,0 +1,589 @@ +# -*- coding: utf-8 -*- +# +# Electrum - lightweight Bitcoin client +# Copyright (C) 2018 The Electrum developers +# +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from datetime import datetime +import time +import random +import queue +import os +import json +import threading +import concurrent +from collections import defaultdict +from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set +import binascii +import base64 + +from sqlalchemy import Column, ForeignKey, Integer, String, Boolean +from sqlalchemy.orm.query import Query +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import not_, or_ + +from .sql_db import SqlDB, sql +from . import constants +from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks +from .logging import Logger +from .storage import JsonDB +from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update +from .crypto import sha256d +from . import ecc +from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH, + NotFoundChanAnnouncementForUpdate) +from .lnmsg import encode_msg + +if TYPE_CHECKING: + from .lnchannel import Channel + from .network import Network + +class UnknownEvenFeatureBits(Exception): pass + +def validate_features(features : int): + enabled_features = list_enabled_bits(features) + for fbit in enabled_features: + if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: + raise UnknownEvenFeatureBits() + +Base = declarative_base() + +FLAG_DISABLE = 1 << 1 +FLAG_DIRECTION = 1 << 0 + +class ChannelInfo(Base): + __tablename__ = 'channel_info' + short_channel_id = Column(String(64), primary_key=True) + node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) + node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) + capacity_sat = Column(Integer) + msg_payload_hex = Column(String(1024), nullable=False) + trusted = Column(Boolean, nullable=False) + + @staticmethod + def from_msg(payload): + features = int.from_bytes(payload['features'], 'big') + validate_features(features) + channel_id = payload['short_channel_id'].hex() + node_id_1 = payload['node_id_1'].hex() + node_id_2 = payload['node_id_2'].hex() + assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] + msg_payload_hex = encode_msg('channel_announcement', **payload).hex() + capacity_sat = None + return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, + node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, + trusted = False) + + @property + def msg_payload(self): + return bytes.fromhex(self.msg_payload_hex) + + +class Policy(Base): + __tablename__ = 'policy' + start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) + short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True) + cltv_expiry_delta = Column(Integer, nullable=False) + htlc_minimum_msat = Column(Integer, nullable=False) + htlc_maximum_msat = Column(Integer) + fee_base_msat = Column(Integer, nullable=False) + fee_proportional_millionths = Column(Integer, nullable=False) + channel_flags = Column(Integer, nullable=False) + timestamp = Column(Integer, nullable=False) + + @staticmethod + def from_msg(payload): + cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big") + htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big") + htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None + fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big") + fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big") + channel_flags = int.from_bytes(payload['channel_flags'], "big") + timestamp = int.from_bytes(payload['timestamp'], "big") + start_node = payload['start_node'].hex() + short_channel_id = payload['short_channel_id'].hex() + + return Policy(start_node=start_node, + short_channel_id=short_channel_id, + cltv_expiry_delta=cltv_expiry_delta, + htlc_minimum_msat=htlc_minimum_msat, + fee_base_msat=fee_base_msat, + fee_proportional_millionths=fee_proportional_millionths, + channel_flags=channel_flags, + timestamp=timestamp, + htlc_maximum_msat=htlc_maximum_msat) + + def is_disabled(self): + return self.channel_flags & FLAG_DISABLE + +class NodeInfo(Base): + __tablename__ = 'node_info' + node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') + features = Column(Integer, nullable=False) + timestamp = Column(Integer, nullable=False) + alias = Column(String(64), nullable=False) + + @staticmethod + def from_msg(payload): + node_id = payload['node_id'].hex() + features = int.from_bytes(payload['features'], "big") + validate_features(features) + addresses = NodeInfo.parse_addresses_field(payload['addresses']) + alias = payload['alias'].rstrip(b'\x00').hex() + timestamp = int.from_bytes(payload['timestamp'], "big") + return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [ + Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses] + + @staticmethod + def parse_addresses_field(addresses_field): + buf = addresses_field + def read(n): + nonlocal buf + data, buf = buf[0:n], buf[n:] + return data + addresses = [] + while buf: + atype = ord(read(1)) + if atype == 0: + pass + elif atype == 1: # IPv4 + ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4))) + port = int.from_bytes(read(2), 'big') + if is_ip_address(ipv4_addr) and port != 0: + addresses.append((ipv4_addr, port)) + elif atype == 2: # IPv6 + ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)]) + ipv6_addr = ipv6_addr.decode('ascii') + port = int.from_bytes(read(2), 'big') + if is_ip_address(ipv6_addr) and port != 0: + addresses.append((ipv6_addr, port)) + elif atype == 3: # onion v2 + host = base64.b32encode(read(10)) + b'.onion' + host = host.decode('ascii').lower() + port = int.from_bytes(read(2), 'big') + addresses.append((host, port)) + elif atype == 4: # onion v3 + host = base64.b32encode(read(35)) + b'.onion' + host = host.decode('ascii').lower() + port = int.from_bytes(read(2), 'big') + addresses.append((host, port)) + else: + # unknown address type + # we don't know how long it is -> have to escape + # if there are other addresses we could have parsed later, they are lost. + break + return addresses + +class Address(Base): + __tablename__ = 'address' + node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) + host = Column(String(256), primary_key=True) + port = Column(Integer, primary_key=True) + last_connected_date = Column(Integer(), nullable=True) + + + +class ChannelDB(SqlDB): + + NUM_MAX_RECENT_PEERS = 20 + + def __init__(self, network: 'Network'): + path = os.path.join(get_headers_dir(network.config), 'channel_db') + super().__init__(network, path, Base) + self.num_nodes = 0 + self.num_channels = 0 + self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] + self.ca_verifier = LNChannelVerifier(network, self) + self.update_counts() + + @sql + def update_counts(self): + self._update_counts() + + def _update_counts(self): + self.num_channels = self.DBSession.query(ChannelInfo).count() + self.num_policies = self.DBSession.query(Policy).count() + self.num_nodes = self.DBSession.query(NodeInfo).count() + + @sql + def known_ids(self): + known = self.DBSession.query(ChannelInfo.short_channel_id).all() + return set(bfh(r.short_channel_id) for r in known) + + @sql + def add_recent_peer(self, peer: LNPeerAddr): + now = int(time.time()) + node_id = peer.pubkey.hex() + addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() + if addr: + addr.last_connected_date = now + else: + addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) + self.DBSession.add(addr) + self.DBSession.commit() + + @sql + def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): + unshuffled = self.DBSession \ + .query(NodeInfo) \ + .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ + .limit(200) \ + .all() + return random.sample(unshuffled, len(unshuffled)) + + @sql + def nodes_get(self, node_id): + return self.DBSession \ + .query(NodeInfo) \ + .filter_by(node_id = node_id.hex()) \ + .one_or_none() + + @sql + def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: + r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all() + if not r: + return None + addr = r[0] + return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id)) + + @sql + def get_recent_peers(self): + r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all() + return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r] + + @sql + def missing_channel_announcements(self) -> Set[int]: + expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) + return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) + + @sql + def missing_channel_updates(self) -> Set[int]: + expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id))) + return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) + + @sql + def add_verified_channel_info(self, short_id, capacity): + # called from lnchannelverifier + channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none() + channel_info.trusted = True + channel_info.capacity = capacity + self.DBSession.commit() + + @sql + @profiler + def on_channel_announcement(self, msg_payloads, trusted=True): + if type(msg_payloads) is dict: + msg_payloads = [msg_payloads] + new_channels = {} + for msg in msg_payloads: + short_channel_id = bh2u(msg['short_channel_id']) + if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count(): + continue + if constants.net.rev_genesis_bytes() != msg['chain_hash']: + self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash']))) + continue + try: + channel_info = ChannelInfo.from_msg(msg) + except UnknownEvenFeatureBits: + self.logger.info("unknown feature bits") + continue + channel_info.trusted = trusted + new_channels[short_channel_id] = channel_info + if not trusted: + self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) + for channel_info in new_channels.values(): + self.DBSession.add(channel_info) + self.DBSession.commit() + self._update_counts() + self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads))) + + @sql + def get_last_timestamp(self): + return self._get_last_timestamp() + + def _get_last_timestamp(self): + from sqlalchemy.sql import func + r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one() + return r.max_timestamp or 0 + + def print_change(self, old_policy, new_policy): + # print what changed between policies + if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta: + self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}') + if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat: + self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}') + if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat: + self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}') + if old_policy.fee_base_msat != new_policy.fee_base_msat: + self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}') + if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths: + self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}') + if old_policy.channel_flags != new_policy.channel_flags: + self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}') + + @sql + def get_info_for_updates(self, payloads): + short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads] + channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() + channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} + return channel_infos + + @sql + def get_policies_for_updates(self, payloads): + out = {} + for payload in payloads: + short_channel_id = payload['short_channel_id'].hex() + start_node = payload['start_node'].hex() + policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none() + if policy: + out[short_channel_id+start_node] = policy + return out + + @profiler + def filter_channel_updates(self, payloads, max_age=None): + orphaned = [] # no channel announcement for channel update + expired = [] # update older than two weeks + deprecated = [] # update older than database entry + good = {} # good updates + to_delete = [] # database entries to delete + # filter orphaned and expired first + known = [] + now = int(time.time()) + channel_infos = self.get_info_for_updates(payloads) + for payload in payloads: + short_channel_id = payload['short_channel_id'] + timestamp = int.from_bytes(payload['timestamp'], "big") + if max_age and now - timestamp > max_age: + expired.append(short_channel_id) + continue + channel_info = channel_infos.get(short_channel_id) + if not channel_info: + orphaned.append(short_channel_id) + continue + flags = int.from_bytes(payload['channel_flags'], 'big') + direction = flags & FLAG_DIRECTION + start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id + payload['start_node'] = bfh(start_node) + known.append(payload) + # compare updates to existing database entries + old_policies = self.get_policies_for_updates(known) + for payload in known: + timestamp = int.from_bytes(payload['timestamp'], "big") + start_node = payload['start_node'] + short_channel_id = payload['short_channel_id'] + key = (short_channel_id+start_node).hex() + old_policy = old_policies.get(key) + if old_policy: + if timestamp <= old_policy.timestamp: + deprecated.append(short_channel_id) + else: + good[key] = payload + to_delete.append(old_policy) + else: + good[key] = payload + good = list(good.values()) + return orphaned, expired, deprecated, good, to_delete + + def add_channel_update(self, payload): + orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload]) + assert len(good) == 1 + self.update_policies(good, to_delete) + + @sql + @profiler + def update_policies(self, to_add, to_delete): + for policy in to_delete: + self.DBSession.delete(policy) + self.DBSession.commit() + for payload in to_add: + policy = Policy.from_msg(payload) + self.DBSession.add(policy) + self.DBSession.commit() + self._update_counts() + + @sql + @profiler + def on_node_announcement(self, msg_payloads): + if type(msg_payloads) is dict: + msg_payloads = [msg_payloads] + old_addr = None + new_nodes = {} + new_addresses = {} + for msg_payload in msg_payloads: + try: + node_info, node_addresses = NodeInfo.from_msg(msg_payload) + except UnknownEvenFeatureBits: + continue + node_id = node_info.node_id + # Ignore node if it has no associated channel (DoS protection) + # FIXME this is slow + expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id) + if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0: + #self.logger.info('ignoring orphan node_announcement') + continue + node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none() + if node and node.timestamp >= node_info.timestamp: + continue + node = new_nodes.get(node_id) + if node and node.timestamp >= node_info.timestamp: + continue + new_nodes[node_id] = node_info + for addr in node_addresses: + new_addresses[(addr.node_id,addr.host,addr.port)] = addr + self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) + for node_info in new_nodes.values(): + self.DBSession.add(node_info) + for new_addr in new_addresses.values(): + old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() + if not old_addr: + self.DBSession.add(new_addr) + self.DBSession.commit() + self._update_counts() + + def get_routing_policy_for_channel(self, start_node_id: bytes, + short_channel_id: bytes) -> Optional[bytes]: + if not start_node_id or not short_channel_id: return None + channel_info = self.get_channel_info(short_channel_id) + if channel_info is not None: + return self.get_policy_for_node(short_channel_id, start_node_id) + msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) + if not msg: + return None + return Policy.from_msg(msg) # won't actually be written to DB + + @sql + @profiler + def get_old_policies(self, delta): + timestamp = int(time.time()) - delta + old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) + return old_policies.distinct().count() + + @sql + @profiler + def prune_old_policies(self, delta): + # note: delete queries are order sensitive + timestamp = int(time.time()) - delta + old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) + delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies)) + delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp) + self.DBSession.execute(delete_old_channels) + self.DBSession.execute(delete_old_policies) + self.DBSession.commit() + self._update_counts() + + @sql + @profiler + def get_orphaned_channels(self): + subquery = self.DBSession.query(Policy.short_channel_id) + orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery))) + return orphaned.count() + + @sql + @profiler + def prune_orphaned_channels(self): + subquery = self.DBSession.query(Policy.short_channel_id) + delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery))) + self.DBSession.execute(delete_orphaned) + self.DBSession.commit() + self._update_counts() + + def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): + if not verify_sig_for_channel_update(msg_payload, start_node_id): + return # ignore + short_channel_id = msg_payload['short_channel_id'] + msg_payload['start_node'] = start_node_id + self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload + + @sql + def remove_channel(self, short_channel_id): + r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none() + if not r: + return + self.DBSession.delete(r) + self.DBSession.commit() + + def print_graph(self, full_ids=False): + # used for debugging. + # FIXME there is a race here - iterables could change size from another thread + def other_node_id(node_id, channel_id): + channel_info = self.get_channel_info(channel_id) + if node_id == channel_info.node1_id: + other = channel_info.node2_id + else: + other = channel_info.node1_id + return other if full_ids else other[-4:] + + print_msg('nodes') + for node in self.DBSession.query(NodeInfo).all(): + print_msg(node) + + print_msg('channels') + for channel_info in self.DBSession.query(ChannelInfo).all(): + short_channel_id = channel_info.short_channel_id + node1 = channel_info.node1_id + node2 = channel_info.node2_id + direction1 = self.get_policy_for_node(channel_info, node1) is not None + direction2 = self.get_policy_for_node(channel_info, node2) is not None + if direction1 and direction2: + direction = 'both' + elif direction1: + direction = 'forward' + elif direction2: + direction = 'backward' + else: + direction = 'none' + print_msg('{}: {}, {}, {}' + .format(bh2u(short_channel_id), + bh2u(node1) if full_ids else bh2u(node1[-4:]), + bh2u(node2) if full_ids else bh2u(node2[-4:]), + direction)) + + + @sql + def get_node_addresses(self, node_info): + return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() + + @sql + @profiler + def load_data(self): + r = self.DBSession.query(ChannelInfo).all() + self._channels = dict([(bfh(x.short_channel_id), x) for x in r]) + r = self.DBSession.query(Policy).filter_by().all() + self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r]) + self._channels_for_node = defaultdict(set) + for channel_info in self._channels.values(): + self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id)) + self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id)) + self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}') + + def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']: + return self._policies.get((node_id, short_channel_id)) + + def get_channel_info(self, channel_id: bytes): + return self._channels.get(channel_id) + + def get_channels_for_node(self, node_id) -> Set[bytes]: + """Returns the set of channels that have node_id as one of the endpoints.""" + return self._channels_for_node.get(node_id) or set() + + + diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -36,12 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK import binascii import base64 -from sqlalchemy import Column, ForeignKey, Integer, String, Boolean -from sqlalchemy.orm.query import Query -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.sql import not_, or_ - -from .sql_db import SqlDB, sql from . import constants from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks from .logging import Logger @@ -52,543 +46,16 @@ from . import ecc from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH, NotFoundChanAnnouncementForUpdate) from .lnmsg import encode_msg +from .channel_db import ChannelDB if TYPE_CHECKING: from .lnchannel import Channel from .network import Network -class UnknownEvenFeatureBits(Exception): pass class NoChannelPolicy(Exception): def __init__(self, short_channel_id: bytes): super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}') -def validate_features(features : int): - enabled_features = list_enabled_bits(features) - for fbit in enabled_features: - if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: - raise UnknownEvenFeatureBits() - -Base = declarative_base() - -FLAG_DISABLE = 1 << 1 -FLAG_DIRECTION = 1 << 0 - -class ChannelInfo(Base): - __tablename__ = 'channel_info' - short_channel_id = Column(String(64), primary_key=True) - node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) - node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) - capacity_sat = Column(Integer) - msg_payload_hex = Column(String(1024), nullable=False) - trusted = Column(Boolean, nullable=False) - - @staticmethod - def from_msg(payload): - features = int.from_bytes(payload['features'], 'big') - validate_features(features) - channel_id = payload['short_channel_id'].hex() - node_id_1 = payload['node_id_1'].hex() - node_id_2 = payload['node_id_2'].hex() - assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] - msg_payload_hex = encode_msg('channel_announcement', **payload).hex() - capacity_sat = None - return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, - node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, - trusted = False) - - @property - def msg_payload(self): - return bytes.fromhex(self.msg_payload_hex) - - -class Policy(Base): - __tablename__ = 'policy' - start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) - short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True) - cltv_expiry_delta = Column(Integer, nullable=False) - htlc_minimum_msat = Column(Integer, nullable=False) - htlc_maximum_msat = Column(Integer) - fee_base_msat = Column(Integer, nullable=False) - fee_proportional_millionths = Column(Integer, nullable=False) - channel_flags = Column(Integer, nullable=False) - timestamp = Column(Integer, nullable=False) - - @staticmethod - def from_msg(payload): - cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big") - htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big") - htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None - fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big") - fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big") - channel_flags = int.from_bytes(payload['channel_flags'], "big") - timestamp = int.from_bytes(payload['timestamp'], "big") - start_node = payload['start_node'].hex() - short_channel_id = payload['short_channel_id'].hex() - - return Policy(start_node=start_node, - short_channel_id=short_channel_id, - cltv_expiry_delta=cltv_expiry_delta, - htlc_minimum_msat=htlc_minimum_msat, - fee_base_msat=fee_base_msat, - fee_proportional_millionths=fee_proportional_millionths, - channel_flags=channel_flags, - timestamp=timestamp, - htlc_maximum_msat=htlc_maximum_msat) - - def is_disabled(self): - return self.channel_flags & FLAG_DISABLE - -class NodeInfo(Base): - __tablename__ = 'node_info' - node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') - features = Column(Integer, nullable=False) - timestamp = Column(Integer, nullable=False) - alias = Column(String(64), nullable=False) - - @staticmethod - def from_msg(payload): - node_id = payload['node_id'].hex() - features = int.from_bytes(payload['features'], "big") - validate_features(features) - addresses = NodeInfo.parse_addresses_field(payload['addresses']) - alias = payload['alias'].rstrip(b'\x00').hex() - timestamp = int.from_bytes(payload['timestamp'], "big") - return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [ - Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses] - - @staticmethod - def parse_addresses_field(addresses_field): - buf = addresses_field - def read(n): - nonlocal buf - data, buf = buf[0:n], buf[n:] - return data - addresses = [] - while buf: - atype = ord(read(1)) - if atype == 0: - pass - elif atype == 1: # IPv4 - ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4))) - port = int.from_bytes(read(2), 'big') - if is_ip_address(ipv4_addr) and port != 0: - addresses.append((ipv4_addr, port)) - elif atype == 2: # IPv6 - ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)]) - ipv6_addr = ipv6_addr.decode('ascii') - port = int.from_bytes(read(2), 'big') - if is_ip_address(ipv6_addr) and port != 0: - addresses.append((ipv6_addr, port)) - elif atype == 3: # onion v2 - host = base64.b32encode(read(10)) + b'.onion' - host = host.decode('ascii').lower() - port = int.from_bytes(read(2), 'big') - addresses.append((host, port)) - elif atype == 4: # onion v3 - host = base64.b32encode(read(35)) + b'.onion' - host = host.decode('ascii').lower() - port = int.from_bytes(read(2), 'big') - addresses.append((host, port)) - else: - # unknown address type - # we don't know how long it is -> have to escape - # if there are other addresses we could have parsed later, they are lost. - break - return addresses - -class Address(Base): - __tablename__ = 'address' - node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) - host = Column(String(256), primary_key=True) - port = Column(Integer, primary_key=True) - last_connected_date = Column(Integer(), nullable=True) - - - -class ChannelDB(SqlDB): - - NUM_MAX_RECENT_PEERS = 20 - - def __init__(self, network: 'Network'): - path = os.path.join(get_headers_dir(network.config), 'channel_db') - super().__init__(network, path, Base) - self.num_nodes = 0 - self.num_channels = 0 - self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] - self.ca_verifier = LNChannelVerifier(network, self) - self.update_counts() - - @sql - def update_counts(self): - self._update_counts() - - def _update_counts(self): - self.num_channels = self.DBSession.query(ChannelInfo).count() - self.num_policies = self.DBSession.query(Policy).count() - self.num_nodes = self.DBSession.query(NodeInfo).count() - - @sql - def known_ids(self): - known = self.DBSession.query(ChannelInfo.short_channel_id).all() - return set(bfh(r.short_channel_id) for r in known) - - @sql - def add_recent_peer(self, peer: LNPeerAddr): - now = int(time.time()) - node_id = peer.pubkey.hex() - addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() - if addr: - addr.last_connected_date = now - else: - addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) - self.DBSession.add(addr) - self.DBSession.commit() - - @sql - def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): - unshuffled = self.DBSession \ - .query(NodeInfo) \ - .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ - .limit(200) \ - .all() - return random.sample(unshuffled, len(unshuffled)) - - @sql - def nodes_get(self, node_id): - return self.DBSession \ - .query(NodeInfo) \ - .filter_by(node_id = node_id.hex()) \ - .one_or_none() - - @sql - def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: - r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all() - if not r: - return None - addr = r[0] - return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id)) - - @sql - def get_recent_peers(self): - r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all() - return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r] - - @sql - def missing_channel_announcements(self) -> Set[int]: - expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) - return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) - - @sql - def missing_channel_updates(self) -> Set[int]: - expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id))) - return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) - - @sql - def add_verified_channel_info(self, short_id, capacity): - # called from lnchannelverifier - channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none() - channel_info.trusted = True - channel_info.capacity = capacity - self.DBSession.commit() - - @sql - @profiler - def on_channel_announcement(self, msg_payloads, trusted=True): - if type(msg_payloads) is dict: - msg_payloads = [msg_payloads] - new_channels = {} - for msg in msg_payloads: - short_channel_id = bh2u(msg['short_channel_id']) - if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count(): - continue - if constants.net.rev_genesis_bytes() != msg['chain_hash']: - self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash']))) - continue - try: - channel_info = ChannelInfo.from_msg(msg) - except UnknownEvenFeatureBits: - self.logger.info("unknown feature bits") - continue - channel_info.trusted = trusted - new_channels[short_channel_id] = channel_info - if not trusted: - self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) - for channel_info in new_channels.values(): - self.DBSession.add(channel_info) - self.DBSession.commit() - self._update_counts() - self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads))) - - @sql - def get_last_timestamp(self): - return self._get_last_timestamp() - - def _get_last_timestamp(self): - from sqlalchemy.sql import func - r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one() - return r.max_timestamp or 0 - - def print_change(self, old_policy, new_policy): - # print what changed between policies - if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta: - self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}') - if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat: - self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}') - if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat: - self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}') - if old_policy.fee_base_msat != new_policy.fee_base_msat: - self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}') - if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths: - self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}') - if old_policy.channel_flags != new_policy.channel_flags: - self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}') - - @sql - def get_info_for_updates(self, payloads): - short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads] - channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() - channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} - return channel_infos - - @sql - def get_policies_for_updates(self, payloads): - out = {} - for payload in payloads: - short_channel_id = payload['short_channel_id'].hex() - start_node = payload['start_node'].hex() - policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none() - if policy: - out[short_channel_id+start_node] = policy - return out - - @profiler - def filter_channel_updates(self, payloads, max_age=None): - orphaned = [] # no channel announcement for channel update - expired = [] # update older than two weeks - deprecated = [] # update older than database entry - good = {} # good updates - to_delete = [] # database entries to delete - # filter orphaned and expired first - known = [] - now = int(time.time()) - channel_infos = self.get_info_for_updates(payloads) - for payload in payloads: - short_channel_id = payload['short_channel_id'] - timestamp = int.from_bytes(payload['timestamp'], "big") - if max_age and now - timestamp > max_age: - expired.append(short_channel_id) - continue - channel_info = channel_infos.get(short_channel_id) - if not channel_info: - orphaned.append(short_channel_id) - continue - flags = int.from_bytes(payload['channel_flags'], 'big') - direction = flags & FLAG_DIRECTION - start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id - payload['start_node'] = bfh(start_node) - known.append(payload) - # compare updates to existing database entries - old_policies = self.get_policies_for_updates(known) - for payload in known: - timestamp = int.from_bytes(payload['timestamp'], "big") - start_node = payload['start_node'] - short_channel_id = payload['short_channel_id'] - key = (short_channel_id+start_node).hex() - old_policy = old_policies.get(key) - if old_policy: - if timestamp <= old_policy.timestamp: - deprecated.append(short_channel_id) - else: - good[key] = payload - to_delete.append(old_policy) - else: - good[key] = payload - good = list(good.values()) - return orphaned, expired, deprecated, good, to_delete - - def add_channel_update(self, payload): - orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload]) - assert len(good) == 1 - self.update_policies(good, to_delete) - - @sql - @profiler - def update_policies(self, to_add, to_delete): - for policy in to_delete: - self.DBSession.delete(policy) - self.DBSession.commit() - for payload in to_add: - policy = Policy.from_msg(payload) - self.DBSession.add(policy) - self.DBSession.commit() - self._update_counts() - - @sql - @profiler - def on_node_announcement(self, msg_payloads): - if type(msg_payloads) is dict: - msg_payloads = [msg_payloads] - old_addr = None - new_nodes = {} - new_addresses = {} - for msg_payload in msg_payloads: - try: - node_info, node_addresses = NodeInfo.from_msg(msg_payload) - except UnknownEvenFeatureBits: - continue - node_id = node_info.node_id - # Ignore node if it has no associated channel (DoS protection) - # FIXME this is slow - expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id) - if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0: - #self.logger.info('ignoring orphan node_announcement') - continue - node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none() - if node and node.timestamp >= node_info.timestamp: - continue - node = new_nodes.get(node_id) - if node and node.timestamp >= node_info.timestamp: - continue - new_nodes[node_id] = node_info - for addr in node_addresses: - new_addresses[(addr.node_id,addr.host,addr.port)] = addr - self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) - for node_info in new_nodes.values(): - self.DBSession.add(node_info) - for new_addr in new_addresses.values(): - old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() - if not old_addr: - self.DBSession.add(new_addr) - self.DBSession.commit() - self._update_counts() - - def get_routing_policy_for_channel(self, start_node_id: bytes, - short_channel_id: bytes) -> Optional[bytes]: - if not start_node_id or not short_channel_id: return None - channel_info = self.get_channel_info(short_channel_id) - if channel_info is not None: - return self.get_policy_for_node(short_channel_id, start_node_id) - msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) - if not msg: - return None - return Policy.from_msg(msg) # won't actually be written to DB - - @sql - @profiler - def get_old_policies(self, delta): - timestamp = int(time.time()) - delta - old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) - return old_policies.distinct().count() - - @sql - @profiler - def prune_old_policies(self, delta): - # note: delete queries are order sensitive - timestamp = int(time.time()) - delta - old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) - delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies)) - delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp) - self.DBSession.execute(delete_old_channels) - self.DBSession.execute(delete_old_policies) - self.DBSession.commit() - self._update_counts() - - @sql - @profiler - def get_orphaned_channels(self): - subquery = self.DBSession.query(Policy.short_channel_id) - orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery))) - return orphaned.count() - - @sql - @profiler - def prune_orphaned_channels(self): - subquery = self.DBSession.query(Policy.short_channel_id) - delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery))) - self.DBSession.execute(delete_orphaned) - self.DBSession.commit() - self._update_counts() - - def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): - if not verify_sig_for_channel_update(msg_payload, start_node_id): - return # ignore - short_channel_id = msg_payload['short_channel_id'] - msg_payload['start_node'] = start_node_id - self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload - - @sql - def remove_channel(self, short_channel_id): - r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none() - if not r: - return - self.DBSession.delete(r) - self.DBSession.commit() - - def print_graph(self, full_ids=False): - # used for debugging. - # FIXME there is a race here - iterables could change size from another thread - def other_node_id(node_id, channel_id): - channel_info = self.get_channel_info(channel_id) - if node_id == channel_info.node1_id: - other = channel_info.node2_id - else: - other = channel_info.node1_id - return other if full_ids else other[-4:] - - print_msg('nodes') - for node in self.DBSession.query(NodeInfo).all(): - print_msg(node) - - print_msg('channels') - for channel_info in self.DBSession.query(ChannelInfo).all(): - short_channel_id = channel_info.short_channel_id - node1 = channel_info.node1_id - node2 = channel_info.node2_id - direction1 = self.get_policy_for_node(channel_info, node1) is not None - direction2 = self.get_policy_for_node(channel_info, node2) is not None - if direction1 and direction2: - direction = 'both' - elif direction1: - direction = 'forward' - elif direction2: - direction = 'backward' - else: - direction = 'none' - print_msg('{}: {}, {}, {}' - .format(bh2u(short_channel_id), - bh2u(node1) if full_ids else bh2u(node1[-4:]), - bh2u(node2) if full_ids else bh2u(node2[-4:]), - direction)) - - - @sql - def get_node_addresses(self, node_info): - return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() - - @sql - @profiler - def load_data(self): - r = self.DBSession.query(ChannelInfo).all() - self._channels = dict([(bfh(x.short_channel_id), x) for x in r]) - r = self.DBSession.query(Policy).filter_by().all() - self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r]) - self._channels_for_node = defaultdict(set) - for channel_info in self._channels.values(): - self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id)) - self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id)) - self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}') - - def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']: - return self._policies.get((node_id, short_channel_id)) - - def get_channel_info(self, channel_id: bytes): - return self._channels.get(channel_id) - - def get_channels_for_node(self, node_id) -> Set[bytes]: - """Returns the set of channels that have node_id as one of the endpoints.""" - return self._channels_for_node.get(node_id) or set() - - class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), ('short_channel_id', bytes), diff --git a/electrum/network.py b/electrum/network.py @@ -297,15 +297,17 @@ class Network(Logger): self._set_status('disconnected') # lightning network - from . import lnwatcher - from . import lnworker - from . import lnrouter if self.config.get('lightning'): - self.channel_db = lnrouter.ChannelDB(self) + from . import lnwatcher + from . import lnworker + from . import lnrouter + from . import channel_db + self.channel_db = channel_db.ChannelDB(self) self.path_finder = lnrouter.LNPathFinder(self.channel_db) self.lnwatcher = lnwatcher.LNWatcher(self) self.lngossip = lnworker.LNGossip(self) else: + self.channel_db = None self.lnwatcher = None self.lngossip = None