# channel_db.py (37455B)
# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import time
import random
import os
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
import asyncio
import threading
from enum import IntEnum

from aiorpcx import NetAddress

from .sql_db import SqlDB, sql
from . import constants, util
from .util import bh2u, profiler, get_headers_dir, is_ip_address, json_normalize
from .logging import Logger
from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
                     validate_features, IncompatibleOrInsaneFeatures)
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg

if TYPE_CHECKING:
    from .network import Network
    from .lnchannel import Channel
    from .lnrouter import RouteEdge


# bits of the 'channel_flags' field of a channel_update
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0


class ChannelInfo(NamedTuple):
    """A channel as described by a 'channel_announcement' message
    (or synthesized from one of our own channels / a private route edge)."""
    short_channel_id: ShortChannelID
    node1_id: bytes
    node2_id: bytes
    capacity_sat: Optional[int]  # None unless filled in by the caller

    @staticmethod
    def from_msg(payload: dict) -> 'ChannelInfo':
        """Build a ChannelInfo from a decoded 'channel_announcement' payload.

        Feature bits are checked with validate_features; callers in this file
        catch IncompatibleOrInsaneFeatures around this call.
        capacity_sat is always None here.
        """
        features = int.from_bytes(payload['features'], 'big')
        validate_features(features)
        channel_id = payload['short_channel_id']
        node_id_1 = payload['node_id_1']
        node_id_2 = payload['node_id_2']
        # the announcement must list the two node ids in ascending order
        assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
        capacity_sat = None
        return ChannelInfo(
            short_channel_id=ShortChannelID.normalize(channel_id),
            node1_id=node_id_1,
            node2_id=node_id_2,
            capacity_sat=capacity_sat,
        )

    @staticmethod
    def from_raw_msg(raw: bytes) -> 'ChannelInfo':
        """Decode a raw 'channel_announcement' message and build a ChannelInfo."""
        payload_dict = decode_msg(raw)[1]
        return ChannelInfo.from_msg(payload_dict)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
        """Build a ChannelInfo from a RouteEdge; node ids are sorted ascending."""
        node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
        return ChannelInfo(
            short_channel_id=route_edge.short_channel_id,
            node1_id=node1_id,
            node2_id=node2_id,
            capacity_sat=None,
        )


class Policy(NamedTuple):
    """Routing policy of one direction of a channel (from a 'channel_update').

    `key` is short_channel_id (8 bytes) followed by the start node id.
    """
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: Optional[int]
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    message_flags: int
    timestamp: int

    @staticmethod
    def from_msg(payload: dict) -> 'Policy':
        """Build a Policy from a decoded 'channel_update' payload.

        The payload must carry an extra 'start_node' entry (set by the caller).
        'htlc_maximum_msat' is optional and defaults to None.
        """
        return Policy(
            key=payload['short_channel_id'] + payload['start_node'],
            cltv_expiry_delta=payload['cltv_expiry_delta'],
            htlc_minimum_msat=payload['htlc_minimum_msat'],
            htlc_maximum_msat=payload.get('htlc_maximum_msat', None),
            fee_base_msat=payload['fee_base_msat'],
            fee_proportional_millionths=payload['fee_proportional_millionths'],
            message_flags=int.from_bytes(payload['message_flags'], "big"),
            channel_flags=int.from_bytes(payload['channel_flags'], "big"),
            timestamp=payload['timestamp'],
        )

    @staticmethod
    def from_raw_msg(key: bytes, raw: bytes) -> 'Policy':
        """Decode a raw 'channel_update'; the start node is taken from key[8:]."""
        payload = decode_msg(raw)[1]
        payload['start_node'] = key[8:]
        return Policy.from_msg(payload)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
        """Build a Policy from a RouteEdge; fields a RouteEdge lacks are zeroed/None."""
        return Policy(
            key=route_edge.short_channel_id + route_edge.start_node,
            cltv_expiry_delta=route_edge.cltv_expiry_delta,
            htlc_minimum_msat=0,
            htlc_maximum_msat=None,
            fee_base_msat=route_edge.fee_base_msat,
            fee_proportional_millionths=route_edge.fee_proportional_millionths,
            channel_flags=0,
            message_flags=0,
            timestamp=0,
        )

    def is_disabled(self):
        """Truthy iff the 'disable' bit is set in channel_flags."""
        return self.channel_flags & FLAG_DISABLE

    @property
    def short_channel_id(self) -> ShortChannelID:
        return ShortChannelID.normalize(self.key[0:8])

    @property
    def start_node(self) -> bytes:
        return self.key[8:]


class NodeInfo(NamedTuple):
    """A node as described by a 'node_announcement' message."""
    node_id: bytes
    features: int
    timestamp: int
    alias: str

    @staticmethod
    def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Build (NodeInfo, peer addresses) from a decoded 'node_announcement'.

        Addresses that LNPeerAddr rejects (ValueError) are skipped. An alias
        that is not valid UTF-8 becomes ''.
        """
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        peer_addrs = []
        for host, port in addresses:
            try:
                peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id))
            except ValueError:
                pass
        alias = payload['alias'].rstrip(b'\x00')
        try:
            alias = alias.decode('utf8')
        except UnicodeDecodeError:
            alias = ''
        timestamp = payload['timestamp']
        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
        return node_info, peer_addrs

    @staticmethod
    def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Decode a raw 'node_announcement' and build (NodeInfo, peer addresses)."""
        payload_dict = decode_msg(raw)[1]
        return NodeInfo.from_msg(payload_dict)

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Parse the 'addresses' field of a node_announcement into (host, port) pairs.

        Each entry is a 1-byte type followed by a type-specific payload:
        0 = padding, 1 = IPv4 (4 + 2 bytes), 2 = IPv6 (16 + 2 bytes),
        3 = onion v2 (10 + 2 bytes), 4 = onion v3 (35 + 2 bytes).
        IP entries with an invalid address or port 0 are dropped.
        """
        buf = addresses_field

        def read(n):
            # consume and return the next n bytes of buf
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data

        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                pass
            elif atype == 1:  # IPv4
                ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
                ipv6_addr = ipv6_addr.decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            elif atype == 3:  # onion v2
                host = base64.b32encode(read(10)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            elif atype == 4:  # onion v3
                host = base64.b32encode(read(35)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            else:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
        return addresses


class UpdateStatus(IntEnum):
    """Outcome of ChannelDB.add_channel_update."""
    ORPHANED = 0
    EXPIRED = 1
    DEPRECATED = 2
    UNCHANGED = 3
    GOOD = 4


class CategorizedChannelUpdates(NamedTuple):
    orphaned: List  # no channel announcement for channel update
    expired: List  # update older than two weeks
    deprecated: List  # update older than database entry
    unchanged: List  # unchanged policies
    good: List  # good updates


def get_mychannel_info(short_channel_id: ShortChannelID,
                       my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
    """Return a ChannelInfo for one of our own channels, or None if not ours.

    capacity_sat is filled in from the channel's constraints.
    """
    chan = my_channels.get(short_channel_id)
    if not chan:
        return None
    ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
    return ci._replace(capacity_sat=chan.constraints.capacity)


def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
                         my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]:
    """Return the policy of one direction of one of our own channels.

    node_id selects the direction: the remote node's id gives the incoming
    policy (from the remote's channel_update), our local pubkey gives the
    outgoing policy. Returns None if the channel is not ours, the remote
    update is missing, or node_id matches neither endpoint.
    """
    chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
    if not chan:
        return None
    if node_id == chan.node_id:  # incoming direction (to us)
        remote_update_raw = chan.get_remote_update()
        if not remote_update_raw:
            return None
        now = int(time.time())
        remote_update_decoded = decode_msg(remote_update_raw)[1]
        # NOTE(review): timestamp is overwritten with "now", presumably so the
        # policy is never treated as stale — confirm intent.
        remote_update_decoded['timestamp'] = now
        remote_update_decoded['start_node'] = node_id
        return Policy.from_msg(remote_update_decoded)
    elif node_id == chan.get_local_pubkey():  # outgoing direction (from us)
        local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
        local_update_decoded['start_node'] = node_id
        return Policy.from_msg(local_update_decoded)
    return None


create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id BLOB(8),
msg BLOB,
PRIMARY KEY(short_channel_id)
)"""

create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key BLOB(41),
msg BLOB,
PRIMARY KEY(key)
)"""

create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id BLOB(33),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""

create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id BLOB(33),
msg BLOB,
PRIMARY KEY(node_id)
)"""


class ChannelDB(SqlDB):
    """In-memory Lightning gossip graph (channels, policies, nodes, addresses)
    backed by a SQLite database ('gossip_db' in the headers directory).

    The in-memory dicts are populated by load_data(); modifying or iterating
    them should be done while holding self.lock.
    """

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        path = os.path.join(get_headers_dir(network.config), 'gossip_db')
        super().__init__(network.asyncio_loop, path, commit_interval=100)
        self.lock = threading.RLock()
        self.num_nodes = 0
        self.num_channels = 0
        self.num_policies = 0
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)

        # initialized in load_data
        # note: modify/iterate needs self.lock
        self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
        self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
        self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
        # node_id -> NetAddress -> timestamp
        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
        self._recent_peers = []  # type: List[bytes]  # list of node_ids
        self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_2_policies = set()  # type: Set[ShortChannelID]

        self.data_loaded = asyncio.Event()
        self.network = network  # only for callback

    def update_counts(self):
        """Refresh the node/channel/policy counters and notify listeners."""
        self.num_nodes = len(self._nodes)
        self.num_channels = len(self._channels)
        self.num_policies = len(self._policies)
        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
        util.trigger_callback('ln_gossip_sync_progress')

    def get_channel_ids(self):
        """Return a snapshot set of all known public short channel ids."""
        with self.lock:
            return set(self._channels.keys())

    def add_recent_peer(self, peer: LNPeerAddr):
        """Record a successful connection to `peer` (address timestamp + MRU list)."""
        now = int(time.time())
        node_id = peer.pubkey
        with self.lock:
            self._addresses[node_id][peer.net_addr()] = now
            # list is ordered, most recent first
            if node_id in self._recent_peers:
                self._recent_peers.remove(node_id)
            self._recent_peers.insert(0, node_id)
            self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
        self._db_save_node_address(peer, now)

    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
        """Return up to 200 known node ids that are not in `node_ids`, in random order.

        `node_ids` may be any iterable of node ids.
        """
        with self.lock:
            unshuffled = set(self._nodes.keys()) - set(node_ids)
        # random.sample() requires a sequence; passing a set raises TypeError on Python 3.11+
        return random.sample(list(unshuffled), min(200, len(unshuffled)))

    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
        """Returns latest address we successfully connected to, for given node."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return None
        addr = max(addr_to_ts, key=addr_to_ts.get)
        try:
            return LNPeerAddr(str(addr.host), addr.port, node_id)
        except ValueError:
            return None

    def get_recent_peers(self):
        """Return the last good address (or None) of each recently used peer.

        Raises Exception if load_data() has not completed yet.
        """
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        with self.lock:
            ret = [self.get_last_good_address(node_id)
                   for node_id in self._recent_peers]
            return ret

    # note: currently channel announcements are trusted by default (trusted=True);
    #       they are not SPV-verified. Verifying them would make the gossip sync
    #       even slower; especially as servers will start throttling us.
    #       It would probably put significant strain on servers if all clients
    #       verified the complete gossip.
    def add_channel_announcement(self, msg_payloads, *, trusted=True):
        """Add one (dict) or several (list of dicts) decoded channel_announcements.

        Announcements for already-known channels, for another chain, or with
        unacceptable features are skipped. Untrusted ones are handed to the
        channel verifier instead of being added directly.
        """
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        added = 0
        for msg in msg_payloads:
            short_channel_id = ShortChannelID(msg['short_channel_id'])
            if short_channel_id in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                # parse only to validate the feature bits / node ordering
                ChannelInfo.from_msg(msg)
            except IncompatibleOrInsaneFeatures as e:
                self.logger.info(f"unknown or insane feature bits: {e!r}")
                continue
            if trusted:
                added += 1
                self.add_verified_channel_info(msg)
            else:
                added += self.ca_verifier.add_new_channel_info(short_channel_id, msg)

        self.update_counts()
        self.logger.debug('add_channel_announcement: %d/%d' % (added, len(msg_payloads)))

    def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
        """Store a (verified) channel_announcement in memory and, if 'raw' is present, in the DB."""
        try:
            channel_info = ChannelInfo.from_msg(msg)
        except IncompatibleOrInsaneFeatures:
            return
        channel_info = channel_info._replace(capacity_sat=capacity_sat)
        with self.lock:
            self._channels[channel_info.short_channel_id] = channel_info
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self._update_num_policies_for_chan(channel_info.short_channel_id)
        if 'raw' in msg:
            self._db_save_channel(channel_info.short_channel_id, msg['raw'])

    def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
        """Return True if any routing-relevant field differs between the two policies.

        The timestamp is not compared. With `verbose`, each differing field
        (or the absence of any difference) is logged.
        """
        changed = False
        for field in ('cltv_expiry_delta', 'htlc_minimum_msat', 'htlc_maximum_msat',
                      'fee_base_msat', 'fee_proportional_millionths',
                      'channel_flags', 'message_flags'):
            old_value = getattr(old_policy, field)
            new_value = getattr(new_policy, field)
            if old_value != new_value:
                changed = True
                if verbose:
                    self.logger.info(f'{field}: {old_value} -> {new_value}')
        if not changed and verbose:
            self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}')
        return changed

    def add_channel_update(self, payload, max_age=None, verify=False, verbose=True):
        """Add one decoded channel_update and classify it.

        Side effect: sets payload['start_node'] (derived from the direction bit)
        once the channel is known.
        Returns an UpdateStatus:
          EXPIRED    - older than `max_age` seconds (if max_age given),
          DEPRECATED - timestamp more than 60s in the future, or not more than
                       60s newer than the stored policy,
          ORPHANED   - no channel_announcement known for this channel,
          UNCHANGED  - stored, but routing fields equal the previous policy,
          GOOD       - stored, new or changed policy.
        With `verify`, verify_channel_update() is called and may raise.
        """
        now = int(time.time())
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        timestamp = payload['timestamp']
        if max_age and now - timestamp > max_age:
            return UpdateStatus.EXPIRED
        if timestamp - now > 60:
            return UpdateStatus.DEPRECATED
        channel_info = self._channels.get(short_channel_id)
        if not channel_info:
            return UpdateStatus.ORPHANED
        flags = int.from_bytes(payload['channel_flags'], 'big')
        direction = flags & FLAG_DIRECTION
        start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
        payload['start_node'] = start_node
        # compare updates to existing database entries
        key = (start_node, short_channel_id)
        old_policy = self._policies.get(key)
        if old_policy and timestamp <= old_policy.timestamp + 60:
            return UpdateStatus.DEPRECATED
        if verify:
            self.verify_channel_update(payload)
        policy = Policy.from_msg(payload)
        with self.lock:
            self._policies[key] = policy
        self._update_num_policies_for_chan(short_channel_id)
        if 'raw' in payload:
            self._db_save_policy(policy.key, payload['raw'])
        if old_policy and not self.policy_changed(old_policy, policy, verbose):
            return UpdateStatus.UNCHANGED
        else:
            return UpdateStatus.GOOD

    def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
        """Add many channel_updates (non-verbosely) and group the payloads by outcome."""
        orphaned = []
        expired = []
        deprecated = []
        unchanged = []
        good = []
        for payload in payloads:
            r = self.add_channel_update(payload, max_age=max_age, verbose=False)
            if r == UpdateStatus.ORPHANED:
                orphaned.append(payload)
            elif r == UpdateStatus.EXPIRED:
                expired.append(payload)
            elif r == UpdateStatus.DEPRECATED:
                deprecated.append(payload)
            elif r == UpdateStatus.UNCHANGED:
                unchanged.append(payload)
            elif r == UpdateStatus.GOOD:
                good.append(payload)
        self.update_counts()
        return CategorizedChannelUpdates(
            orphaned=orphaned,
            expired=expired,
            deprecated=deprecated,
            unchanged=unchanged,
            good=good)

    def create_database(self):
        """Create the gossip tables if they do not exist yet."""
        c = self.conn.cursor()
        c.execute(create_node_info)
        c.execute(create_address)
        c.execute(create_policy)
        c.execute(create_channel_info)
        self.conn.commit()

    @sql
    def _db_save_policy(self, key: bytes, msg: bytes):
        # 'msg' is a 'channel_update' message
        c = self.conn.cursor()
        c.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg])

    @sql
    def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID):
        # the DB key is short_channel_id followed by node_id (see Policy.key)
        key = short_channel_id + node_id
        c = self.conn.cursor()
        c.execute("""DELETE FROM policy WHERE key=?""", (key,))

    @sql
    def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes):
        # 'msg' is a 'channel_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)", [short_channel_id, msg])

    @sql
    def _db_delete_channel(self, short_channel_id: ShortChannelID):
        c = self.conn.cursor()
        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))

    @sql
    def _db_save_node_info(self, node_id: bytes, msg: bytes):
        # 'msg' is a 'node_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg])

    @sql
    def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int):
        # insert or overwrite the last-success timestamp of this address
        c = self.conn.cursor()
        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
                  (peer.pubkey, peer.host, peer.port, timestamp))

    @sql
    def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]):
        # Add addresses we have never seen, with timestamp 0; existing rows
        # (and their timestamps) are left untouched. (node_id, host, port) is
        # the primary key, so INSERT OR IGNORE skips rows that already exist.
        c = self.conn.cursor()
        for addr in node_addresses:
            c.execute("INSERT OR IGNORE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
                      (addr.pubkey, addr.host, addr.port, 0))

    def verify_channel_update(self, payload):
        """Raise Exception if the update is for another chain or its signature
        does not verify against payload['start_node']."""
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise Exception('wrong chain hash')
        if not verify_sig_for_channel_update(payload, payload['start_node']):
            raise Exception(f'failed verifying channel update for {short_channel_id}')

    def add_node_announcement(self, msg_payloads):
        """Add one (dict) or several (list of dicts) decoded node_announcements.

        Nodes without any known channel, and announcements not newer than what
        we already have, are ignored. New addresses are added with timestamp 0.
        """
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        new_nodes = {}  # node_id -> NodeInfo accepted from this batch
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except IncompatibleOrInsaneFeatures:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            if node_id not in self._channels_for_node:
                #self.logger.info('ignoring orphan node_announcement')
                continue
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            node = new_nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            new_nodes[node_id] = node_info
            # save
            with self.lock:
                self._nodes[node_id] = node_info
            if 'raw' in msg_payload:
                self._db_save_node_info(node_id, msg_payload['raw'])
            with self.lock:
                for addr in node_addresses:
                    net_addr = NetAddress(addr.host, addr.port)
                    # keep an existing last-success timestamp, else 0
                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
            self._db_save_node_addresses(node_addresses)

        self.logger.debug("on_node_announcement: %d/%d" % (len(new_nodes), len(msg_payloads)))
        self.update_counts()

    def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]:
        """Return the (node_id, scid) keys of policies whose timestamp is at least `delta` seconds old."""
        with self.lock:
            _policies = self._policies.copy()
        now = int(time.time())
        return [k for k, v in _policies.items() if v.timestamp <= now - delta]

    def prune_old_policies(self, delta):
        """Delete policies older than `delta` seconds from memory and the DB."""
        old_policies = self.get_old_policies(delta)
        if old_policies:
            for key in old_policies:
                node_id, scid = key
                with self.lock:
                    self._policies.pop(key)
                self._db_delete_policy(*key)
                self._update_num_policies_for_chan(scid)
            self.update_counts()
            self.logger.info(f'Deleting {len(old_policies)} old policies')

    def prune_orphaned_channels(self):
        """Remove all channels that currently have no policy in either direction."""
        with self.lock:
            orphaned_chans = self._chans_with_0_policies.copy()
        if orphaned_chans:
            for short_channel_id in orphaned_chans:
                self.remove_channel(short_channel_id)
            self.update_counts()
            self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels')

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool:
        """Returns True iff the channel update was successfully added and it was different than
        what we had before (if any).
        """
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return False  # ignore
        short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
        msg_payload['start_node'] = start_node_id
        key = (start_node_id, short_channel_id)
        prev_chanupd = self._channel_updates_for_private_channels.get(key)
        if prev_chanupd == msg_payload:
            return False
        self._channel_updates_for_private_channels[key] = msg_payload
        return True

    def remove_channel(self, short_channel_id: ShortChannelID):
        """Remove a channel from memory and from the DB (its policies are kept)."""
        # FIXME what about rm-ing policies?
        with self.lock:
            channel_info = self._channels.pop(short_channel_id, None)
            if channel_info:
                self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
                self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
        self._update_num_policies_for_chan(short_channel_id)
        # delete from database
        self._db_delete_channel(short_channel_id)

    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
        """Returns list of (host, port, timestamp)."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return []
        return [(str(net_addr.host), net_addr.port, ts)
                for net_addr, ts in addr_to_ts.items()]

    @sql
    @profiler
    def load_data(self):
        """Populate the in-memory graph from the database, then set data_loaded.

        Does nothing if the data has already been loaded.
        """
        if self.data_loaded.is_set():
            return
        # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            node_id, host, port, timestamp = x
            try:
                net_addr = NetAddress(host, port)
            except Exception:
                continue
            self._addresses[node_id][net_addr] = int(timestamp or 0)

        def newest_ts_for_node_id(node_id):
            # most recent successful-connection timestamp of any address of the node
            return max(self._addresses[node_id].values(), default=0)

        sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
        self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
        c.execute("""SELECT * FROM channel_info""")
        for short_channel_id, msg in c:
            try:
                ci = ChannelInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            self._channels[ShortChannelID.normalize(short_channel_id)] = ci
        c.execute("""SELECT * FROM node_info""")
        for node_id, msg in c:
            try:
                node_info, node_addresses = NodeInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            # don't load node_addresses because they dont have timestamps
            self._nodes[node_id] = node_info
        c.execute("""SELECT * FROM policy""")
        for key, msg in c:
            p = Policy.from_raw_msg(key, msg)
            self._policies[(p.start_node, p.short_channel_id)] = p
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self._update_num_policies_for_chan(channel_info.short_channel_id)
        self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
        self.update_counts()
        (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count()
        self.logger.info(f'num_channels_partitioned_by_policy_count. '
                         f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}')
        self.data_loaded.set()
        util.trigger_callback('gossip_db_loaded')

    def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None:
        """Move the channel into the 0/1/2-policies bucket matching its current
        state (or drop it from all buckets if the channel is unknown)."""
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is None:
            with self.lock:
                self._chans_with_0_policies.discard(short_channel_id)
                self._chans_with_1_policies.discard(short_channel_id)
                self._chans_with_2_policies.discard(short_channel_id)
            return
        p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id)
        p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id)
        with self.lock:
            self._chans_with_0_policies.discard(short_channel_id)
            self._chans_with_1_policies.discard(short_channel_id)
            self._chans_with_2_policies.discard(short_channel_id)
            if p1 is not None and p2 is not None:
                self._chans_with_2_policies.add(short_channel_id)
            elif p1 is None and p2 is None:
                self._chans_with_0_policies.add(short_channel_id)
            else:
                self._chans_with_1_policies.add(short_channel_id)

    def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]:
        """Return the number of channels with 0, 1 and 2 known policies."""
        nchans_with_0p = len(self._chans_with_0_policies)
        nchans_with_1p = len(self._chans_with_1_policies)
        nchans_with_2p = len(self._chans_with_2_policies)
        return nchans_with_0p, nchans_with_1p, nchans_with_2p

    def get_policy_for_node(
            self,
            short_channel_id: bytes,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional['Policy']:
        """Return the policy for the direction starting at node_id, or None.

        Lookup order: gossip policies (public channels) or stored private
        channel_updates, then our own channels, then private route edges.
        """
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            policy = self._policies.get((node_id, short_channel_id))
            if policy:
                return policy
        else:  # private channel
            chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
            if chan_upd_dict:
                return Policy.from_msg(chan_upd_dict)
        # check if it's one of our own channels
        if my_channels:
            policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
            if policy:
                return policy
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge:
                return Policy.from_route_edge(route_edge)
        return None

    def get_channel_info(
            self,
            short_channel_id: ShortChannelID,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional[ChannelInfo]:
        """Return the ChannelInfo for the channel, or None.

        Lookup order: public gossip, then our own channels, then private route edges.
        """
        ret = self._channels.get(short_channel_id)
        if ret:
            return ret
        # check if it's one of our own channels
        if my_channels:
            channel_info = get_mychannel_info(short_channel_id, my_channels)
            if channel_info:
                return channel_info
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id)
            if route_edge:
                return ChannelInfo.from_route_edge(route_edge)
        return None

    def get_channels_for_node(
            self,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Set[bytes]:
        """Returns the set of short channel IDs where node_id is one of the channel participants."""
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        relevant_channels = self._channels_for_node.get(node_id) or set()
        relevant_channels = set(relevant_channels)  # copy
        # add our own channels  # TODO maybe slow?
        if my_channels:
            for chan in my_channels.values():
                if node_id in (chan.node_id, chan.get_local_pubkey()):
                    relevant_channels.add(chan.short_channel_id)
        # add private channels  # TODO maybe slow?
        if private_route_edges:
            for route_edge in private_route_edges.values():
                if node_id in (route_edge.start_node, route_edge.end_node):
                    relevant_channels.add(route_edge.short_channel_id)
        return relevant_channels

    def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
        """Return the two endpoint node ids of a public channel (node1, node2)
        or of one of our own channels (local, remote); None if unknown."""
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            return channel_info.node1_id, channel_info.node2_id
        # check if it's one of our own channels
        if not my_channels:
            return None
        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
        if not chan:
            return None
        return chan.get_local_pubkey(), chan.node_id

    def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
        return self._nodes.get(node_id)

    def get_node_infos(self) -> Dict[bytes, NodeInfo]:
        """Return a snapshot copy of node_id -> NodeInfo."""
        with self.lock:
            return self._nodes.copy()

    def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]:
        """Return a snapshot copy of (node_id, scid) -> Policy."""
        with self.lock:
            return self._policies.copy()

    def to_dict(self) -> dict:
        """ Generates a graph representation in terms of a dictionary.

        The dictionary contains only native python types and can be encoded
        to json.
        """
        with self.lock:
            graph = {'nodes': [], 'channels': []}

            # gather nodes
            for pk, nodeinfo in self._nodes.items():
                # use _asdict() to convert NamedTuples to json encodable dicts
                graph['nodes'].append(
                    nodeinfo._asdict(),
                )
                # .get() rather than [] so the defaultdict is not mutated
                graph['nodes'][-1]['addresses'] = [
                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
                    for addr, ts in self._addresses.get(pk, {}).items()
                ]

            # gather channels
            for cid, channelinfo in self._channels.items():
                graph['channels'].append(
                    channelinfo._asdict(),
                )
                policy1 = self._policies.get(
                    (channelinfo.node1_id, channelinfo.short_channel_id))
                policy2 = self._policies.get(
                    (channelinfo.node2_id, channelinfo.short_channel_id))
                graph['channels'][-1]['policy1'] = policy1._asdict() if policy1 else None
                graph['channels'][-1]['policy2'] = policy2._asdict() if policy2 else None

        # need to use json_normalize otherwise json encoding in rpc server fails
        graph = json_normalize(graph)
        return graph