electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 2ec548dda3664834d4a50d4f323886130285257a
parent 9a803cd1d683b72a2246ed70d76f77562f8d9981
Author: SomberNight <somber.night@protonmail.com>
Date:   Sat,  9 Jan 2021 19:56:05 +0100

ChannelDB: avoid duplicate (host,port) entries in ChannelDB._addresses

before:
node_id -> set of (host, port, ts)
after:
node_id -> NetAddress -> timestamp

Look at e.g. add_recent_peer; we only want to store
the last connection time, not all of them.

Diffstat:
Melectrum/channel_db.py | 52++++++++++++++++++++++++++++------------------------
Melectrum/lnutil.py | 8++++++--
2 files changed, 34 insertions(+), 26 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -34,6 +34,7 @@ import asyncio import threading from enum import IntEnum +from aiorpcx import NetAddress from .sql_db import SqlDB, sql from . import constants, util @@ -53,14 +54,6 @@ FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 -class NodeAddress(NamedTuple): - """Holds address information of Lightning nodes - and how up to date this info is.""" - host: str - port: int - timestamp: int - - class ChannelInfo(NamedTuple): short_channel_id: ShortChannelID node1_id: bytes @@ -295,8 +288,8 @@ class ChannelDB(SqlDB): self._channels = {} # type: Dict[ShortChannelID, ChannelInfo] self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo - # node_id -> (host, port, ts) - self._addresses = defaultdict(set) # type: Dict[bytes, Set[NodeAddress]] + # node_id -> NetAddress -> timestamp + self._addresses = defaultdict(dict) # type: Dict[bytes, Dict[NetAddress, int]] self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]] self._recent_peers = [] # type: List[bytes] # list of node_ids self._chans_with_0_policies = set() # type: Set[ShortChannelID] @@ -321,7 +314,7 @@ class ChannelDB(SqlDB): now = int(time.time()) node_id = peer.pubkey with self.lock: - self._addresses[node_id].add(NodeAddress(peer.host, peer.port, now)) + self._addresses[node_id][peer.net_addr()] = now # list is ordered if node_id in self._recent_peers: self._recent_peers.remove(node_id) @@ -336,12 +329,12 @@ class ChannelDB(SqlDB): def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]: """Returns latest address we successfully connected to, for given node.""" - r = self._addresses.get(node_id) - if not r: + addr_to_ts = self._addresses.get(node_id) + if not addr_to_ts: return None - addr = sorted(list(r), key=lambda x: x.timestamp, reverse=True)[0] + addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0] try: - return LNPeerAddr(addr.host, addr.port, node_id) + return LNPeerAddr(str(addr.host), addr.port, node_id) except ValueError: return None @@ -583,7 +576,8 @@ class ChannelDB(SqlDB): self._db_save_node_info(node_id, msg_payload['raw']) with self.lock: for addr in node_addresses: - self._addresses[node_id].add(NodeAddress(addr.host, addr.port, 0)) + net_addr = NetAddress(addr.host, addr.port) + self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0 self._db_save_node_addresses(node_addresses) self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) @@ -634,8 +628,13 @@ class ChannelDB(SqlDB): # delete from database self._db_delete_channel(short_channel_id) - def get_node_addresses(self, node_id): - return self._addresses.get(node_id) + def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]: + """Returns list of (host, port, timestamp).""" + addr_to_ts = self._addresses.get(node_id) + if not addr_to_ts: + return [] + return [(str(net_addr.host), net_addr.port, ts) + for net_addr, ts in addr_to_ts.items()] @sql @profiler @@ -643,17 +642,19 @@ class ChannelDB(SqlDB): if self.data_loaded.is_set(): return # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow. - # I believe lnmsg (and lightning.json) will need a rewrite anyway, so instead of tweaking - # load_data() here, that should be done. see #6006 c = self.conn.cursor() c.execute("""SELECT * FROM address""") for x in c: node_id, host, port, timestamp = x - self._addresses[node_id].add(NodeAddress(str(host), int(port), int(timestamp or 0))) + try: + net_addr = NetAddress(host, port) + except Exception: + continue + self._addresses[node_id][net_addr] = int(timestamp or 0) def newest_ts_for_node_id(node_id): newest_ts = 0 - for addr in self._addresses[node_id]: - newest_ts = max(newest_ts, addr.timestamp) + for addr, ts in self._addresses[node_id].items(): + newest_ts = max(newest_ts, ts) return newest_ts sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True) self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS] @@ -791,7 +792,10 @@ class ChannelDB(SqlDB): graph['nodes'].append( nodeinfo._asdict(), ) - graph['nodes'][-1]['addresses'] = [addr._asdict() for addr in self._addresses[pk]] + graph['nodes'][-1]['addresses'] = [ + {'host': str(addr.host), 'port': addr.port, 'timestamp': ts} + for addr, ts in self._addresses[pk].items() + ] # gather channels for cid, channelinfo in self._channels.items(): diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -1106,6 +1106,7 @@ def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> byte class LNPeerAddr: + # note: while not programmatically enforced, this class is meant to be *immutable* def __init__(self, host: str, port: int, pubkey: bytes): assert isinstance(host, str), repr(host) @@ -1120,7 +1121,7 @@ class LNPeerAddr: self.host = host self.port = port self.pubkey = pubkey - self._net_addr_str = str(net_addr) + self._net_addr = net_addr def __str__(self): return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str()) @@ -1128,8 +1129,11 @@ class LNPeerAddr: def __repr__(self): return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>' + def net_addr(self) -> NetAddress: + return self._net_addr + def net_addr_str(self) -> str: - return self._net_addr_str + return str(self._net_addr) def __eq__(self, other): if not isinstance(other, LNPeerAddr):