electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

channel_db.py (37455B)


      1 # -*- coding: utf-8 -*-
      2 #
      3 # Electrum - lightweight Bitcoin client
      4 # Copyright (C) 2018 The Electrum developers
      5 #
      6 # Permission is hereby granted, free of charge, to any person
      7 # obtaining a copy of this software and associated documentation files
      8 # (the "Software"), to deal in the Software without restriction,
      9 # including without limitation the rights to use, copy, modify, merge,
     10 # publish, distribute, sublicense, and/or sell copies of the Software,
     11 # and to permit persons to whom the Software is furnished to do so,
     12 # subject to the following conditions:
     13 #
     14 # The above copyright notice and this permission notice shall be
     15 # included in all copies or substantial portions of the Software.
     16 #
     17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     24 # SOFTWARE.
     25 
     26 import time
     27 import random
     28 import os
     29 from collections import defaultdict
     30 from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
     31 import binascii
     32 import base64
     33 import asyncio
     34 import threading
     35 from enum import IntEnum
     36 
     37 from aiorpcx import NetAddress
     38 
     39 from .sql_db import SqlDB, sql
     40 from . import constants, util
     41 from .util import bh2u, profiler, get_headers_dir, is_ip_address, json_normalize
     42 from .logging import Logger
     43 from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
     44                      validate_features, IncompatibleOrInsaneFeatures)
     45 from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
     46 from .lnmsg import decode_msg
     47 
     48 if TYPE_CHECKING:
     49     from .network import Network
     50     from .lnchannel import Channel
     51     from .lnrouter import RouteEdge
     52 
     53 
     54 FLAG_DISABLE   = 1 << 1
     55 FLAG_DIRECTION = 1 << 0
     56 
     57 
     58 class ChannelInfo(NamedTuple):
     59     short_channel_id: ShortChannelID
     60     node1_id: bytes
     61     node2_id: bytes
     62     capacity_sat: Optional[int]
     63 
     64     @staticmethod
     65     def from_msg(payload: dict) -> 'ChannelInfo':
     66         features = int.from_bytes(payload['features'], 'big')
     67         validate_features(features)
     68         channel_id = payload['short_channel_id']
     69         node_id_1 = payload['node_id_1']
     70         node_id_2 = payload['node_id_2']
     71         assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
     72         capacity_sat = None
     73         return ChannelInfo(
     74             short_channel_id = ShortChannelID.normalize(channel_id),
     75             node1_id = node_id_1,
     76             node2_id = node_id_2,
     77             capacity_sat = capacity_sat
     78         )
     79 
     80     @staticmethod
     81     def from_raw_msg(raw: bytes) -> 'ChannelInfo':
     82         payload_dict = decode_msg(raw)[1]
     83         return ChannelInfo.from_msg(payload_dict)
     84 
     85     @staticmethod
     86     def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
     87         node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
     88         return ChannelInfo(
     89             short_channel_id=route_edge.short_channel_id,
     90             node1_id=node1_id,
     91             node2_id=node2_id,
     92             capacity_sat=None,
     93         )
     94 
     95 
     96 class Policy(NamedTuple):
     97     key: bytes
     98     cltv_expiry_delta: int
     99     htlc_minimum_msat: int
    100     htlc_maximum_msat: Optional[int]
    101     fee_base_msat: int
    102     fee_proportional_millionths: int
    103     channel_flags: int
    104     message_flags: int
    105     timestamp: int
    106 
    107     @staticmethod
    108     def from_msg(payload: dict) -> 'Policy':
    109         return Policy(
    110             key                         = payload['short_channel_id'] + payload['start_node'],
    111             cltv_expiry_delta           = payload['cltv_expiry_delta'],
    112             htlc_minimum_msat           = payload['htlc_minimum_msat'],
    113             htlc_maximum_msat           = payload.get('htlc_maximum_msat', None),
    114             fee_base_msat               = payload['fee_base_msat'],
    115             fee_proportional_millionths = payload['fee_proportional_millionths'],
    116             message_flags               = int.from_bytes(payload['message_flags'], "big"),
    117             channel_flags               = int.from_bytes(payload['channel_flags'], "big"),
    118             timestamp                   = payload['timestamp'],
    119         )
    120 
    121     @staticmethod
    122     def from_raw_msg(key:bytes, raw: bytes) -> 'Policy':
    123         payload = decode_msg(raw)[1]
    124         payload['start_node'] = key[8:]
    125         return Policy.from_msg(payload)
    126 
    127     @staticmethod
    128     def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
    129         return Policy(
    130             key=route_edge.short_channel_id + route_edge.start_node,
    131             cltv_expiry_delta=route_edge.cltv_expiry_delta,
    132             htlc_minimum_msat=0,
    133             htlc_maximum_msat=None,
    134             fee_base_msat=route_edge.fee_base_msat,
    135             fee_proportional_millionths=route_edge.fee_proportional_millionths,
    136             channel_flags=0,
    137             message_flags=0,
    138             timestamp=0,
    139         )
    140 
    141     def is_disabled(self):
    142         return self.channel_flags & FLAG_DISABLE
    143 
    144     @property
    145     def short_channel_id(self) -> ShortChannelID:
    146         return ShortChannelID.normalize(self.key[0:8])
    147 
    148     @property
    149     def start_node(self) -> bytes:
    150         return self.key[8:]
    151 
    152 
    153 class NodeInfo(NamedTuple):
    154     node_id: bytes
    155     features: int
    156     timestamp: int
    157     alias: str
    158 
    159     @staticmethod
    160     def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
    161         node_id = payload['node_id']
    162         features = int.from_bytes(payload['features'], "big")
    163         validate_features(features)
    164         addresses = NodeInfo.parse_addresses_field(payload['addresses'])
    165         peer_addrs = []
    166         for host, port in addresses:
    167             try:
    168                 peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id))
    169             except ValueError:
    170                 pass
    171         alias = payload['alias'].rstrip(b'\x00')
    172         try:
    173             alias = alias.decode('utf8')
    174         except:
    175             alias = ''
    176         timestamp = payload['timestamp']
    177         node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
    178         return node_info, peer_addrs
    179 
    180     @staticmethod
    181     def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
    182         payload_dict = decode_msg(raw)[1]
    183         return NodeInfo.from_msg(payload_dict)
    184 
    185     @staticmethod
    186     def parse_addresses_field(addresses_field):
    187         buf = addresses_field
    188         def read(n):
    189             nonlocal buf
    190             data, buf = buf[0:n], buf[n:]
    191             return data
    192         addresses = []
    193         while buf:
    194             atype = ord(read(1))
    195             if atype == 0:
    196                 pass
    197             elif atype == 1:  # IPv4
    198                 ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
    199                 port = int.from_bytes(read(2), 'big')
    200                 if is_ip_address(ipv4_addr) and port != 0:
    201                     addresses.append((ipv4_addr, port))
    202             elif atype == 2:  # IPv6
    203                 ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
    204                 ipv6_addr = ipv6_addr.decode('ascii')
    205                 port = int.from_bytes(read(2), 'big')
    206                 if is_ip_address(ipv6_addr) and port != 0:
    207                     addresses.append((ipv6_addr, port))
    208             elif atype == 3:  # onion v2
    209                 host = base64.b32encode(read(10)) + b'.onion'
    210                 host = host.decode('ascii').lower()
    211                 port = int.from_bytes(read(2), 'big')
    212                 addresses.append((host, port))
    213             elif atype == 4:  # onion v3
    214                 host = base64.b32encode(read(35)) + b'.onion'
    215                 host = host.decode('ascii').lower()
    216                 port = int.from_bytes(read(2), 'big')
    217                 addresses.append((host, port))
    218             else:
    219                 # unknown address type
    220                 # we don't know how long it is -> have to escape
    221                 # if there are other addresses we could have parsed later, they are lost.
    222                 break
    223         return addresses
    224 
    225 
    226 class UpdateStatus(IntEnum):
    227     ORPHANED   = 0
    228     EXPIRED    = 1
    229     DEPRECATED = 2
    230     UNCHANGED  = 3
    231     GOOD       = 4
    232 
    233 class CategorizedChannelUpdates(NamedTuple):
    234     orphaned: List    # no channel announcement for channel update
    235     expired: List     # update older than two weeks
    236     deprecated: List  # update older than database entry
    237     unchanged: List   # unchanged policies
    238     good: List        # good updates
    239 
    240 
    241 def get_mychannel_info(short_channel_id: ShortChannelID,
    242                        my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
    243     chan = my_channels.get(short_channel_id)
    244     if not chan:
    245         return
    246     ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
    247     return ci._replace(capacity_sat=chan.constraints.capacity)
    248 
    249 def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
    250                          my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]:
    251     chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
    252     if not chan:
    253         return
    254     if node_id == chan.node_id:  # incoming direction (to us)
    255         remote_update_raw = chan.get_remote_update()
    256         if not remote_update_raw:
    257             return
    258         now = int(time.time())
    259         remote_update_decoded = decode_msg(remote_update_raw)[1]
    260         remote_update_decoded['timestamp'] = now
    261         remote_update_decoded['start_node'] = node_id
    262         return Policy.from_msg(remote_update_decoded)
    263     elif node_id == chan.get_local_pubkey():  # outgoing direction (from us)
    264         local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
    265         local_update_decoded['start_node'] = node_id
    266         return Policy.from_msg(local_update_decoded)
    267 
    268 
# SQL schema of the gossip database (executed by ChannelDB.create_database).

# channel_info: one row per channel, keyed by the 8-byte short channel id;
# 'msg' holds the raw channel_announcement (see ChannelDB._db_save_channel).
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id BLOB(8),
msg BLOB,
PRIMARY KEY(short_channel_id)
)"""

# policy: one row per channel direction. 'key' is the short channel id
# (8 bytes) followed by the 33-byte start node id, i.e. Policy.key;
# 'msg' holds the raw channel_update (see ChannelDB._db_save_policy).
create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key BLOB(41),
msg BLOB,
PRIMARY KEY(key)
)"""

# address: known (host, port) pairs per node. add_recent_peer stores the
# current time as 'timestamp'; _db_save_node_addresses inserts new rows with 0.
create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id BLOB(33),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""

# node_info: one row per node; 'msg' holds the raw node_announcement
# (see ChannelDB._db_save_node_info).
create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id BLOB(33),
msg BLOB,
PRIMARY KEY(node_id)
)"""
    298 
    299 
    300 class ChannelDB(SqlDB):
    301 
    302     NUM_MAX_RECENT_PEERS = 20
    303 
    def __init__(self, network: 'Network'):
        """Set up the SQL-backed gossip DB at <headers dir>/gossip_db and the
        in-memory indexes, which start empty and are filled by load_data.
        """
        path = os.path.join(get_headers_dir(network.config), 'gossip_db')
        super().__init__(network.asyncio_loop, path, commit_interval=100)
        self.lock = threading.RLock()
        self.num_nodes = 0
        self.num_channels = 0
        # also set by update_counts(); initialized here so the attribute exists
        # from construction, like num_nodes / num_channels
        self.num_policies = 0
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)

        # initialized in load_data
        # note: modify/iterate needs self.lock
        self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
        self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
        self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
        # node_id -> NetAddress -> timestamp
        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
        self._recent_peers = []  # type: List[bytes]  # list of node_ids
        self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_2_policies = set()  # type: Set[ShortChannelID]

        self.data_loaded = asyncio.Event()
        self.network = network # only for callback
    328 
    def update_counts(self):
        """Refresh the cached node/channel/policy counts and fire the gossip callbacks."""
        self.num_nodes, self.num_channels, self.num_policies = (
            len(self._nodes), len(self._channels), len(self._policies))
        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
        util.trigger_callback('ln_gossip_sync_progress')
    335 
    def get_channel_ids(self):
        """Return a snapshot set of all known short channel ids."""
        with self.lock:
            snapshot = set(self._channels)
        return snapshot
    339 
    def add_recent_peer(self, peer: LNPeerAddr):
        """Record `peer` as the most recent peer and its address as just used.

        The recent-peers list is ordered newest first and capped at
        NUM_MAX_RECENT_PEERS; the address timestamp is also persisted.
        """
        now = int(time.time())
        pubkey = peer.pubkey
        with self.lock:
            self._addresses[pubkey][peer.net_addr()] = now
            recent = self._recent_peers
            if pubkey in recent:
                recent.remove(pubkey)
            recent.insert(0, pubkey)
            self._recent_peers = recent[:self.NUM_MAX_RECENT_PEERS]
        self._db_save_node_address(peer, now)
    351 
    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
        """Return up to 200 known node ids, in random order, that are not in `node_ids`.

        `node_ids` may be any iterable of node ids (a set, as before, or e.g. a list).
        """
        with self.lock:
            unshuffled = set(self._nodes.keys()).difference(node_ids)
        # random.sample() needs a sequence: sampling from a set was deprecated
        # in Python 3.9 and raises TypeError since 3.11.
        candidates = list(unshuffled)
        return random.sample(candidates, min(200, len(candidates)))
    356 
    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
        """Return the address of `node_id` with the newest timestamp, or None.

        Timestamps are set to "now" by add_recent_peer, so this is the address we
        most recently connected to. Addresses only learned from announcements
        have timestamp 0. On ties, the address inserted first wins. Returns None
        if no address is known or LNPeerAddr rejects it.
        """
        # the class invariant is that iterating _addresses needs self.lock (an RLock,
        # so this is also safe when called from get_recent_peers under the lock)
        with self.lock:
            addr_to_ts = self._addresses.get(node_id)
            if not addr_to_ts:
                return None
            addr = max(addr_to_ts, key=addr_to_ts.__getitem__)
        try:
            return LNPeerAddr(str(addr.host), addr.port, node_id)
        except ValueError:
            return None
    367 
    def get_recent_peers(self):
        """Return get_last_good_address() for each recent peer, newest first.

        Entries may be None. Raises if data_loaded has not been set yet.
        """
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        with self.lock:
            return list(map(self.get_last_good_address, self._recent_peers))
    375 
    # note: currently channel announcements are trusted by default (trusted=True);
    #       they are not SPV-verified. Verifying them would make the gossip sync
    #       even slower; especially as servers will start throttling us.
    #       It would probably put significant strain on servers if all clients
    #       verified the complete gossip.
    def add_channel_announcement(self, msg_payloads, *, trusted=True):
        """Add channel_announcement payload(s): one dict or a list of dicts.

        Announcements are skipped if the channel is already known, if they are
        for another chain, or if their feature bits are rejected. Trusted
        announcements are stored right away; otherwise they go to ca_verifier.
        """
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        num_added = 0
        for msg in msg_payloads:
            scid = ShortChannelID(msg['short_channel_id'])
            if scid in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                ChannelInfo.from_msg(msg)  # parsed only to validate the announcement
            except IncompatibleOrInsaneFeatures as e:
                self.logger.info(f"unknown or insane feature bits: {e!r}")
                continue
            if not trusted:
                num_added += self.ca_verifier.add_new_channel_info(scid, msg)
                continue
            num_added += 1
            self.add_verified_channel_info(msg)

        self.update_counts()
        self.logger.debug('add_channel_announcement: %d/%d'%(num_added, len(msg_payloads)))
    406 
    def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
        """Store an accepted channel_announcement.

        It goes into memory, and into the DB if the payload has 'raw' bytes.
        Announcements with rejected feature bits are silently dropped.
        """
        try:
            parsed = ChannelInfo.from_msg(msg)
        except IncompatibleOrInsaneFeatures:
            return
        info = parsed._replace(capacity_sat=capacity_sat)
        scid = info.short_channel_id
        with self.lock:
            self._channels[scid] = info
            for endpoint in (info.node1_id, info.node2_id):
                self._channels_for_node[endpoint].add(scid)
        self._update_num_policies_for_chan(scid)
        if 'raw' in msg:
            self._db_save_channel(scid, msg['raw'])
    420 
    def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
        """Return True if any routing field differs between the two policies.

        The key and timestamp are not compared. With `verbose`, each difference
        is logged, or a single 'policy unchanged' line if there is none.
        """
        compared_fields = (
            'cltv_expiry_delta',
            'htlc_minimum_msat',
            'htlc_maximum_msat',
            'fee_base_msat',
            'fee_proportional_millionths',
            'channel_flags',
            'message_flags',
        )
        changed = False
        for field in compared_fields:
            before = getattr(old_policy, field)
            after = getattr(new_policy, field)
            if before == after:
                continue
            changed = True
            if verbose:
                self.logger.info(f'{field}: {before} -> {after}')
        if verbose and not changed:
            self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}')
        return changed
    454 
    def add_channel_update(self, payload, max_age=None, verify=False, verbose=True):
        """Add a channel_update for a known public channel.

        Sets payload['start_node'] from the channel's direction bit. Returns:
          EXPIRED    - older than `max_age` seconds (only if max_age is given)
          DEPRECATED - timestamp more than 60s in the future, or not more than
                       60s newer than the stored policy for this direction
          ORPHANED   - no channel_announcement known for the channel
          UNCHANGED  - stored, but routing fields equal the previous policy
          GOOD       - stored, and new or changed
        With verify=True, verify_channel_update raises on a wrong chain hash
        or a bad signature.
        """
        now = int(time.time())
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        timestamp = payload['timestamp']
        if max_age and now - timestamp > max_age:
            return UpdateStatus.EXPIRED
        if timestamp - now > 60:
            return UpdateStatus.DEPRECATED
        channel_info = self._channels.get(short_channel_id)
        if not channel_info:
            return UpdateStatus.ORPHANED
        flags = int.from_bytes(payload['channel_flags'], 'big')
        direction = flags & FLAG_DIRECTION
        start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
        payload['start_node'] = start_node
        # compare updates to existing database entries
        # (timestamp, start_node and short_channel_id were computed above;
        #  they used to be needlessly re-derived from the payload here)
        key = (start_node, short_channel_id)
        old_policy = self._policies.get(key)
        if old_policy and timestamp <= old_policy.timestamp + 60:
            return UpdateStatus.DEPRECATED
        if verify:
            self.verify_channel_update(payload)
        policy = Policy.from_msg(payload)
        with self.lock:
            self._policies[key] = policy
        self._update_num_policies_for_chan(short_channel_id)
        if 'raw' in payload:
            self._db_save_policy(policy.key, payload['raw'])
        if old_policy and not self.policy_changed(old_policy, policy, verbose):
            return UpdateStatus.UNCHANGED
        else:
            return UpdateStatus.GOOD
    490 
    def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
        """Add many channel_updates quietly and group the payloads by outcome.

        Counts are refreshed once at the end.
        """
        buckets = {status: [] for status in UpdateStatus}
        for payload in payloads:
            status = self.add_channel_update(payload, max_age=max_age, verbose=False)
            if status in buckets:
                buckets[status].append(payload)
        self.update_counts()
        return CategorizedChannelUpdates(
            orphaned=buckets[UpdateStatus.ORPHANED],
            expired=buckets[UpdateStatus.EXPIRED],
            deprecated=buckets[UpdateStatus.DEPRECATED],
            unchanged=buckets[UpdateStatus.UNCHANGED],
            good=buckets[UpdateStatus.GOOD])
    516 
    517 
    def create_database(self):
        """Create the gossip tables if missing, then commit."""
        cursor = self.conn.cursor()
        for statement in (create_node_info, create_address, create_policy, create_channel_info):
            cursor.execute(statement)
        self.conn.commit()
    525 
    @sql
    def _db_save_policy(self, key: bytes, msg: bytes):
        """Insert or overwrite the raw channel_update `msg` stored under `key`."""
        cursor = self.conn.cursor()
        cursor.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg])
    531 
    @sql
    def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID):
        """Delete the stored policy of `node_id` for `short_channel_id`."""
        cursor = self.conn.cursor()
        # policy rows are keyed by short_channel_id followed by the node id
        cursor.execute("""DELETE FROM policy WHERE key=?""", (short_channel_id + node_id,))
    537 
    @sql
    def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes):
        """Insert or overwrite the raw channel_announcement `msg` of a channel."""
        cursor = self.conn.cursor()
        cursor.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)",
                       [short_channel_id, msg])
    543 
    @sql
    def _db_delete_channel(self, short_channel_id: ShortChannelID):
        """Delete the channel_info row of `short_channel_id`."""
        self.conn.cursor().execute("""DELETE FROM channel_info WHERE short_channel_id=?""",
                                   (short_channel_id,))
    548 
    @sql
    def _db_save_node_info(self, node_id: bytes, msg: bytes):
        """Insert or overwrite the raw node_announcement `msg` of `node_id`."""
        cursor = self.conn.cursor()
        cursor.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg])
    554 
    @sql
    def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int):
        """Insert or overwrite the address row of `peer` with `timestamp`."""
        cursor = self.conn.cursor()
        row = (peer.pubkey, peer.host, peer.port, timestamp)
        cursor.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", row)
    560 
    @sql
    def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]):
        """Insert each address with timestamp 0 unless its (node_id, host, port)
        row already exists. Existing rows, and their timestamps, are untouched.
        """
        c = self.conn.cursor()
        # (node_id, host, port) is the table's PRIMARY KEY, so OR IGNORE skips
        # exactly the rows the previous per-address SELECT-then-INSERT skipped,
        # but without one extra query per address.
        c.executemany(
            "INSERT OR IGNORE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
            [(addr.pubkey, addr.host, addr.port, 0) for addr in node_addresses])
    569 
    def verify_channel_update(self, payload):
        """Raise if a channel_update is for another chain, or if its signature
        does not verify against payload['start_node'].
        """
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise Exception('wrong chain hash')
        signature_ok = verify_sig_for_channel_update(payload, payload['start_node'])
        if not signature_ok:
            raise Exception(f'failed verifying channel update for {short_channel_id}')
    577 
    def add_node_announcement(self, msg_payloads):
        """Add node_announcement payload(s): one dict or a list of dicts.

        An announcement is skipped if:
          - its feature bits are rejected,
          - the node has no known channel (DoS protection), or
          - it is not newer than the stored or in-batch announcement for that node.
        Accepted announcements are kept in memory and saved to the DB (when
        they have 'raw' bytes). Their addresses are recorded with timestamp 0
        unless already known.
        """
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        new_nodes = {}  # node_id -> NodeInfo accepted during this call
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except IncompatibleOrInsaneFeatures:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            if node_id not in self._channels_for_node:
                #self.logger.info('ignoring orphan node_announcement')
                continue
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            node = new_nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            # save
            # (new_nodes was previously never filled, so the check above was
            #  dead and the debug log below always reported 0 accepted nodes)
            new_nodes[node_id] = node_info
            with self.lock:
                self._nodes[node_id] = node_info
            if 'raw' in msg_payload:
                self._db_save_node_info(node_id, msg_payload['raw'])
            with self.lock:
                for addr in node_addresses:
                    net_addr = NetAddress(addr.host, addr.port)
                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
            self._db_save_node_addresses(node_addresses)

        self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
        self.update_counts()
    612 
    def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]:
        """Return the keys of policies whose timestamp is at least `delta` seconds old."""
        with self.lock:
            snapshot = list(self._policies.items())
        cutoff = int(time.time()) - delta
        return [key for key, policy in snapshot if policy.timestamp <= cutoff]
    618 
    def prune_old_policies(self, delta):
        """Remove policies at least `delta` seconds old from memory and the DB."""
        stale = self.get_old_policies(delta)
        if not stale:
            return
        for node_id, scid in stale:
            with self.lock:
                self._policies.pop((node_id, scid))
            self._db_delete_policy(node_id, scid)
            self._update_num_policies_for_chan(scid)
        self.update_counts()
        self.logger.info(f'Deleting {len(stale)} old policies')
    630 
    def prune_orphaned_channels(self):
        """Remove every channel currently tracked in _chans_with_0_policies."""
        with self.lock:
            orphans = set(self._chans_with_0_policies)
        if not orphans:
            return
        for scid in orphans:
            self.remove_channel(scid)
        self.update_counts()
        self.logger.info(f'Deleting {len(orphans)} orphaned channels')
    639 
    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool:
        """Returns True iff the channel update was successfully added and it was different than
        what we had before (if any).

        The signature is checked against `start_node_id`, and msg_payload gains
        a 'start_node' entry.
        """
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return False  # ignore
        scid = ShortChannelID(msg_payload['short_channel_id'])
        msg_payload['start_node'] = start_node_id
        key = (start_node_id, scid)
        is_duplicate = self._channel_updates_for_private_channels.get(key) == msg_payload
        if is_duplicate:
            return False
        self._channel_updates_for_private_channels[key] = msg_payload
        return True
    654 
    def remove_channel(self, short_channel_id: ShortChannelID):
        """Forget a channel: drop it from memory, from the per-node index and from the DB.

        An unknown `short_channel_id` is tolerated. The in-memory removal is
        then a no-op, but the policy-count bookkeeping refresh and the
        database delete still run. The channel's policies are NOT removed.
        """
        # FIXME what about rm-ing policies?
        with self.lock:
            channel_info = self._channels.pop(short_channel_id, None)
            if channel_info:
                for node_id in (channel_info.node1_id, channel_info.node2_id):
                    # .get() + discard() instead of [...] + remove():
                    # - don't raise KeyError when the index is already missing
                    #   this scid (e.g. both endpoints being the same node id
                    #   would make a second remove() fail);
                    # - don't auto-create an empty entry for an unknown node.
                    node_chans = self._channels_for_node.get(node_id)
                    if node_chans is not None:
                        node_chans.discard(channel_info.short_channel_id)
        self._update_num_policies_for_chan(short_channel_id)
        # delete from database
        self._db_delete_channel(short_channel_id)
    665 
    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
        """Returns list of (host, port, timestamp) for every known address of `node_id`.

        The host is converted to str. An empty list is returned when no
        address is known for the node.
        """
        known = self._addresses.get(node_id)
        result = []
        if known:
            for net_addr, last_seen in known.items():
                result.append((str(net_addr.host), net_addr.port, last_seen))
        return result
    673 
    @sql
    @profiler
    def load_data(self):
        """Load the gossip data persisted in the SQL database into memory.

        Reads the `address`, `channel_info`, `node_info` and `policy` tables
        into self._addresses, self._channels, self._nodes and self._policies.
        It then derives self._recent_peers, self._channels_for_node and the
        per-channel policy-count sets, sets self.data_loaded and triggers the
        'gossip_db_loaded' callback. Does nothing if self.data_loaded is
        already set.
        """
        # NOTE(review): @sql presumably runs this on the SqlDB worker thread
        # that owns self.conn -- confirm against sql_db.py.
        if self.data_loaded.is_set():
            return
        # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            # address rows are (node_id, host, port, timestamp)
            node_id, host, port, timestamp = x
            try:
                net_addr = NetAddress(host, port)
            except Exception:
                # skip rows that NetAddress refuses to parse
                continue
            # a missing (None) timestamp is recorded as 0
            self._addresses[node_id][net_addr] = int(timestamp or 0)
        def newest_ts_for_node_id(node_id):
            # latest timestamp among this node's addresses, never below 0
            newest_ts = 0
            for addr, ts in self._addresses[node_id].items():
                newest_ts = max(newest_ts, ts)
            return newest_ts
        # most recently seen nodes first, capped at NUM_MAX_RECENT_PEERS
        sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
        self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
        c.execute("""SELECT * FROM channel_info""")
        for short_channel_id, msg in c:
            try:
                ci = ChannelInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                # skip entries whose decoding raises IncompatibleOrInsaneFeatures
                continue
            self._channels[ShortChannelID.normalize(short_channel_id)] = ci
        c.execute("""SELECT * FROM node_info""")
        for node_id, msg in c:
            try:
                node_info, node_addresses = NodeInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            # don't load node_addresses because they don't have timestamps
            self._nodes[node_id] = node_info
        c.execute("""SELECT * FROM policy""")
        for key, msg in c:
            p = Policy.from_raw_msg(key, msg)
            self._policies[(p.start_node, p.short_channel_id)] = p
        # Build the derived indexes. This must come after channels AND policies
        # are loaded: _update_num_policies_for_chan reads both.
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self._update_num_policies_for_chan(channel_info.short_channel_id)
        self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
        self.update_counts()
        (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count()
        self.logger.info(f'num_channels_partitioned_by_policy_count. '
                         f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}')
        # mark as loaded before notifying listeners
        self.data_loaded.set()
        util.trigger_callback('gossip_db_loaded')
    726 
    727     def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None:
    728         channel_info = self.get_channel_info(short_channel_id)
    729         if channel_info is None:
    730             with self.lock:
    731                 self._chans_with_0_policies.discard(short_channel_id)
    732                 self._chans_with_1_policies.discard(short_channel_id)
    733                 self._chans_with_2_policies.discard(short_channel_id)
    734             return
    735         p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id)
    736         p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id)
    737         with self.lock:
    738             self._chans_with_0_policies.discard(short_channel_id)
    739             self._chans_with_1_policies.discard(short_channel_id)
    740             self._chans_with_2_policies.discard(short_channel_id)
    741             if p1 is not None and p2 is not None:
    742                 self._chans_with_2_policies.add(short_channel_id)
    743             elif p1 is None and p2 is None:
    744                 self._chans_with_0_policies.add(short_channel_id)
    745             else:
    746                 self._chans_with_1_policies.add(short_channel_id)
    747 
    def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]:
        """Return how many channels have 0, 1 and 2 known policies, in that order."""
        return tuple(
            len(bucket) for bucket in (
                self._chans_with_0_policies,
                self._chans_with_1_policies,
                self._chans_with_2_policies,
            )
        )
    753 
    def get_policy_for_node(
            self,
            short_channel_id: bytes,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional['Policy']:
        """Return the policy `node_id` advertises for `short_channel_id`, or None.

        Lookup order:
        1. if get_channel_info() knows the channel: the stored policy in
           self._policies; otherwise: a channel_update recorded via
           add_channel_update_for_private_channel(), converted with Policy.from_msg;
        2. our own channels (`my_channels`), via get_mychannel_policy();
        3. `private_route_edges`, via Policy.from_route_edge().
        """
        key = (node_id, short_channel_id)
        if self.get_channel_info(short_channel_id) is None:
            # not in the public channel map: look for a private channel_update
            raw_update = self._channel_updates_for_private_channels.get(key)
            if raw_update:
                return Policy.from_msg(raw_update)
        else:
            stored = self._policies.get(key)
            if stored:
                return stored
        # check if it's one of our own channels
        if my_channels:
            own = get_mychannel_policy(short_channel_id, node_id, my_channels)
            if own:
                return own
        if private_route_edges:
            edge = private_route_edges.get(short_channel_id, None)
            if edge:
                return Policy.from_route_edge(edge)
        return None
    780 
    def get_channel_info(
            self,
            short_channel_id: ShortChannelID,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional[ChannelInfo]:
        """Return the ChannelInfo for `short_channel_id`, or None if it is unknown.

        Sources, in order: the channels in self._channels, then our own
        channels (`my_channels`, via get_mychannel_info()), then
        `private_route_edges` (via ChannelInfo.from_route_edge()).
        """
        public_info = self._channels.get(short_channel_id)
        if public_info:
            return public_info
        # check if it's one of our own channels
        if my_channels:
            own_info = get_mychannel_info(short_channel_id, my_channels)
            if own_info:
                return own_info
        edge = private_route_edges.get(short_channel_id) if private_route_edges else None
        if edge:
            return ChannelInfo.from_route_edge(edge)
        return None
    800 
    def get_channels_for_node(
            self,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Set[bytes]:
        """Returns the set of short channel IDs where node_id is one of the channel participants.

        Combines the self._channels_for_node index with our own channels
        (matched on chan.node_id or chan.get_local_pubkey()) and with
        `private_route_edges` (matched on start_node or end_node). The result
        is a fresh set; mutating it does not affect the index.

        Raises Exception if self.data_loaded is not set yet.
        """
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        scids = set(self._channels_for_node.get(node_id) or set())
        # add our own channels  # TODO maybe slow?
        if my_channels:
            scids.update(
                chan.short_channel_id
                for chan in my_channels.values()
                if node_id in (chan.node_id, chan.get_local_pubkey())
            )
        # add private channels  # TODO maybe slow?
        if private_route_edges:
            scids.update(
                edge.short_channel_id
                for edge in private_route_edges.values()
                if node_id in (edge.start_node, edge.end_node)
            )
        return scids
    824 
    def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
        """Return the two endpoint node ids of a channel, or None if unknown.

        For a channel known to get_channel_info() this is (node1_id, node2_id).
        Otherwise, if the channel is in `my_channels`, it is
        (our local pubkey, the remote node_id).
        """
        info = self.get_channel_info(short_channel_id)
        if info is not None:  # publicly announced channel
            return info.node1_id, info.node2_id
        # check if it's one of our own channels
        own_chan = my_channels.get(short_channel_id) if my_channels else None  # type: Optional[Channel]
        if not own_chan:
            return None
        return own_chan.get_local_pubkey(), own_chan.node_id
    837 
    def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
        """Return the NodeInfo stored for `node_id`, or None if there is none."""
        nodes = self._nodes
        return nodes.get(node_id, None)
    840 
    def get_node_infos(self) -> Dict[bytes, NodeInfo]:
        """Return a shallow copy of all known NodeInfos, keyed by node id.

        The copy is taken while holding self.lock.
        """
        with self.lock:
            snapshot = self._nodes.copy()
        return snapshot
    844 
    def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]:
        """Return a shallow copy of all stored policies.

        Keys are (start node id, short channel id) tuples. The copy is taken
        while holding self.lock.
        """
        with self.lock:
            snapshot = self._policies.copy()
        return snapshot
    848 
    def to_dict(self) -> dict:
        """ Generates a graph representation in terms of a dictionary.

        The result has two keys:
        - 'nodes': one dict per entry of self._nodes (the NodeInfo fields),
          plus an 'addresses' list of {'host', 'port', 'timestamp'} dicts;
        - 'channels': one dict per entry of self._channels (the ChannelInfo
          fields), plus 'policy1'/'policy2': the stored policy of node1/node2
          for that channel as a dict, or None.

        The dictionary contains only native python types and can be encoded
        to json.
        """
        nodes = []
        channels = []
        with self.lock:
            # gather nodes
            for pk, nodeinfo in self._nodes.items():
                # use _asdict() to convert NamedTuples to json encodable dicts
                node_dict = nodeinfo._asdict()
                # .get() rather than [...]: load_data indexes _addresses[node_id]
                # without initialising it, i.e. it auto-creates entries, so
                # indexing here would insert an empty entry for every node
                # without addresses as a side effect of this read-only export.
                addresses = self._addresses.get(pk, {})
                node_dict['addresses'] = [
                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
                    for addr, ts in addresses.items()
                ]
                nodes.append(node_dict)

            # gather channels
            for channelinfo in self._channels.values():
                chan_dict = channelinfo._asdict()
                scid = channelinfo.short_channel_id
                policy1 = self._policies.get((channelinfo.node1_id, scid))
                policy2 = self._policies.get((channelinfo.node2_id, scid))
                chan_dict['policy1'] = policy1._asdict() if policy1 else None
                chan_dict['policy2'] = policy2._asdict() if policy2 else None
                channels.append(chan_dict)

        graph = {'nodes': nodes, 'channels': channels}
        # need to use json_normalize otherwise json encoding in rpc server fails
        return json_normalize(graph)