commit 2ec548dda3664834d4a50d4f323886130285257a
parent 9a803cd1d683b72a2246ed70d76f77562f8d9981
Author: SomberNight <somber.night@protonmail.com>
Date: Sat, 9 Jan 2021 19:56:05 +0100
ChannelDB: avoid duplicate (host,port) entries in ChannelDB._addresses
before:
node_id -> set of (host, port, ts)
after:
node_id -> NetAddress -> timestamp
Look at e.g. add_recent_peer; we only want to store
the last connection time, not all of them.
Diffstat:
2 files changed, 34 insertions(+), 26 deletions(-)
diff --git a/electrum/channel_db.py b/electrum/channel_db.py
@@ -34,6 +34,7 @@ import asyncio
import threading
from enum import IntEnum
+from aiorpcx import NetAddress
from .sql_db import SqlDB, sql
from . import constants, util
@@ -53,14 +54,6 @@ FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
-class NodeAddress(NamedTuple):
- """Holds address information of Lightning nodes
- and how up to date this info is."""
- host: str
- port: int
- timestamp: int
-
-
class ChannelInfo(NamedTuple):
short_channel_id: ShortChannelID
node1_id: bytes
@@ -295,8 +288,8 @@ class ChannelDB(SqlDB):
self._channels = {} # type: Dict[ShortChannelID, ChannelInfo]
self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
- # node_id -> (host, port, ts)
- self._addresses = defaultdict(set) # type: Dict[bytes, Set[NodeAddress]]
+ # node_id -> NetAddress -> timestamp
+ self._addresses = defaultdict(dict) # type: Dict[bytes, Dict[NetAddress, int]]
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
self._recent_peers = [] # type: List[bytes] # list of node_ids
self._chans_with_0_policies = set() # type: Set[ShortChannelID]
@@ -321,7 +314,7 @@ class ChannelDB(SqlDB):
now = int(time.time())
node_id = peer.pubkey
with self.lock:
- self._addresses[node_id].add(NodeAddress(peer.host, peer.port, now))
+ self._addresses[node_id][peer.net_addr()] = now
# list is ordered
if node_id in self._recent_peers:
self._recent_peers.remove(node_id)
@@ -336,12 +329,12 @@ class ChannelDB(SqlDB):
def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
"""Returns latest address we successfully connected to, for given node."""
- r = self._addresses.get(node_id)
- if not r:
+ addr_to_ts = self._addresses.get(node_id)
+ if not addr_to_ts:
return None
- addr = sorted(list(r), key=lambda x: x.timestamp, reverse=True)[0]
+ addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
try:
- return LNPeerAddr(addr.host, addr.port, node_id)
+ return LNPeerAddr(str(addr.host), addr.port, node_id)
except ValueError:
return None
@@ -583,7 +576,8 @@ class ChannelDB(SqlDB):
self._db_save_node_info(node_id, msg_payload['raw'])
with self.lock:
for addr in node_addresses:
- self._addresses[node_id].add(NodeAddress(addr.host, addr.port, 0))
+ net_addr = NetAddress(addr.host, addr.port)
+ self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
self._db_save_node_addresses(node_addresses)
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
@@ -634,8 +628,13 @@ class ChannelDB(SqlDB):
# delete from database
self._db_delete_channel(short_channel_id)
- def get_node_addresses(self, node_id):
- return self._addresses.get(node_id)
+ def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
+ """Returns list of (host, port, timestamp)."""
+ addr_to_ts = self._addresses.get(node_id)
+ if not addr_to_ts:
+ return []
+ return [(str(net_addr.host), net_addr.port, ts)
+ for net_addr, ts in addr_to_ts.items()]
@sql
@profiler
@@ -643,17 +642,19 @@ class ChannelDB(SqlDB):
if self.data_loaded.is_set():
return
# Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
- # I believe lnmsg (and lightning.json) will need a rewrite anyway, so instead of tweaking
- # load_data() here, that should be done. see #6006
c = self.conn.cursor()
c.execute("""SELECT * FROM address""")
for x in c:
node_id, host, port, timestamp = x
- self._addresses[node_id].add(NodeAddress(str(host), int(port), int(timestamp or 0)))
+ try:
+ net_addr = NetAddress(host, port)
+ except Exception:
+ continue
+ self._addresses[node_id][net_addr] = int(timestamp or 0)
def newest_ts_for_node_id(node_id):
newest_ts = 0
- for addr in self._addresses[node_id]:
- newest_ts = max(newest_ts, addr.timestamp)
+ for addr, ts in self._addresses[node_id].items():
+ newest_ts = max(newest_ts, ts)
return newest_ts
sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
@@ -791,7 +792,10 @@ class ChannelDB(SqlDB):
graph['nodes'].append(
nodeinfo._asdict(),
)
- graph['nodes'][-1]['addresses'] = [addr._asdict() for addr in self._addresses[pk]]
+ graph['nodes'][-1]['addresses'] = [
+ {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
+ for addr, ts in self._addresses[pk].items()
+ ]
# gather channels
for cid, channelinfo in self._channels.items():
diff --git a/electrum/lnutil.py b/electrum/lnutil.py
@@ -1106,6 +1106,7 @@ def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> byte
class LNPeerAddr:
+ # note: while not programmatically enforced, this class is meant to be *immutable*
def __init__(self, host: str, port: int, pubkey: bytes):
assert isinstance(host, str), repr(host)
@@ -1120,7 +1121,7 @@ class LNPeerAddr:
self.host = host
self.port = port
self.pubkey = pubkey
- self._net_addr_str = str(net_addr)
+ self._net_addr = net_addr
def __str__(self):
return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str())
@@ -1128,8 +1129,11 @@ class LNPeerAddr:
def __repr__(self):
return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'
+ def net_addr(self) -> NetAddress:
+ return self._net_addr
+
def net_addr_str(self) -> str:
- return self._net_addr_str
+ return str(self._net_addr)
def __eq__(self, other):
if not isinstance(other, LNPeerAddr):