electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit a97e7bae056cb4ef42d82fe4bcfd0cd635045161
parent 387c2a1acda7e4de708e2a3cc0ed2b58c6305736
Author: SomberNight <somber.night@protonmail.com>
Date:   Mon,  2 Mar 2020 16:56:15 +0100

ChannelDB: make gossip sync progress updates cheaper

get_num_channels_partitioned_by_policy_count() was too slow

Diffstat:
Melectrum/channel_db.py | 81++++++++++++++++++++++++++++++++++++++++++++++---------------------------------
Melectrum/lnutil.py | 4+++-
Melectrum/lnworker.py | 2+-
3 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -260,13 +260,16 @@ class ChannelDB(SqlDB): # initialized in load_data # note: modify/iterate needs self.lock - self._channels = {} # type: Dict[bytes, ChannelInfo] - self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy + self._channels = {} # type: Dict[ShortChannelID, ChannelInfo] + self._policies = {} # type: Dict[Tuple[bytes, ShortChannelID], Policy] # (node_id, scid) -> Policy self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo # node_id -> (host, port, ts) self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]] self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]] self._recent_peers = [] # type: List[bytes] # list of node_ids + self._chans_with_0_policies = set() # type: Set[ShortChannelID] + self._chans_with_1_policies = set() # type: Set[ShortChannelID] + self._chans_with_2_policies = set() # type: Set[ShortChannelID] self.data_loaded = asyncio.Event() self.network = network # only for callback @@ -357,6 +360,7 @@ class ChannelDB(SqlDB): self._channels[channel_info.short_channel_id] = channel_info self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) + self._update_num_policies_for_chan(channel_info.short_channel_id) if 'raw' in msg: self.save_channel(channel_info.short_channel_id, msg['raw']) @@ -417,6 +421,7 @@ class ChannelDB(SqlDB): policy = Policy.from_msg(payload) with self.lock: self._policies[key] = policy + self._update_num_policies_for_chan(short_channel_id) if 'raw' in payload: self.save_policy(policy.key, payload['raw']) # @@ -523,34 +528,32 @@ class ChannelDB(SqlDB): self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) self.update_counts() - def get_old_policies(self, delta): + def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]: with self.lock: _policies = self._policies.copy() now = int(time.time()) return list(k for k, v in _policies.items() if v.timestamp <= now - delta) def prune_old_policies(self, delta): - l = self.get_old_policies(delta) - if l: - for k in l: + old_policies = self.get_old_policies(delta) + if old_policies: + for key in old_policies: + node_id, scid = key with self.lock: - self._policies.pop(k) - self.delete_policy(*k) + self._policies.pop(key) + self.delete_policy(*key) + self._update_num_policies_for_chan(scid) self.update_counts() - self.logger.info(f'Deleting {len(l)} old policies') - - def get_orphaned_channels(self): - with self.lock: - ids = set(x[1] for x in self._policies.keys()) - return list(x for x in self._channels.keys() if x not in ids) + self.logger.info(f'Deleting {len(old_policies)} old policies') def prune_orphaned_channels(self): - l = self.get_orphaned_channels() - if l: - for short_channel_id in l: + with self.lock: + orphaned_chans = self._chans_with_0_policies.copy() + if orphaned_chans: + for short_channel_id in orphaned_chans: self.remove_channel(short_channel_id) self.update_counts() - self.logger.info(f'Deleting {len(l)} orphaned channels') + self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels') def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): if not verify_sig_for_channel_update(msg_payload, start_node_id): @@ -560,11 +563,13 @@ class ChannelDB(SqlDB): self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload def remove_channel(self, short_channel_id: ShortChannelID): + # FIXME what about rm-ing policies? with self.lock: channel_info = self._channels.pop(short_channel_id, None) if channel_info: self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id) self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id) + self._update_num_policies_for_chan(short_channel_id) # delete from database self.delete_channel(short_channel_id) @@ -589,7 +594,7 @@ class ChannelDB(SqlDB): c.execute("""SELECT * FROM channel_info""") for short_channel_id, msg in c: ci = ChannelInfo.from_raw_msg(msg) - self._channels[short_channel_id] = ci + self._channels[ShortChannelID.normalize(short_channel_id)] = ci c.execute("""SELECT * FROM node_info""") for node_id, msg in c: node_info, node_addresses = NodeInfo.from_raw_msg(msg) @@ -602,6 +607,7 @@ class ChannelDB(SqlDB): for channel_info in self._channels.values(): self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) + self._update_num_policies_for_chan(channel_info.short_channel_id) self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}') self.update_counts() (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count() @@ -609,24 +615,31 @@ class ChannelDB(SqlDB): f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}') self.data_loaded.set() - def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]: - chans_with_zero_policies = set() - chans_with_one_policies = set() - chans_with_two_policies = set() + def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None: + channel_info = self.get_channel_info(short_channel_id) + if channel_info is None: + with self.lock: + self._chans_with_0_policies.discard(short_channel_id) + self._chans_with_1_policies.discard(short_channel_id) + self._chans_with_2_policies.discard(short_channel_id) + return + p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id) + p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id) with self.lock: - _channels = self._channels.copy() - for short_channel_id, ci in _channels.items(): - p1 = self.get_policy_for_node(short_channel_id, ci.node1_id) - p2 = self.get_policy_for_node(short_channel_id, ci.node2_id) + self._chans_with_0_policies.discard(short_channel_id) + self._chans_with_1_policies.discard(short_channel_id) + self._chans_with_2_policies.discard(short_channel_id) if p1 is not None and p2 is not None: - chans_with_two_policies.add(short_channel_id) + self._chans_with_2_policies.add(short_channel_id) elif p1 is None and p2 is None: - chans_with_zero_policies.add(short_channel_id) + self._chans_with_0_policies.add(short_channel_id) else: - chans_with_one_policies.add(short_channel_id) - nchans_with_0p = len(chans_with_zero_policies) - nchans_with_1p = len(chans_with_one_policies) - nchans_with_2p = len(chans_with_two_policies) + self._chans_with_1_policies.add(short_channel_id) + + def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]: + nchans_with_0p = len(self._chans_with_0_policies) + nchans_with_1p = len(self._chans_with_1_policies) + nchans_with_2p = len(self._chans_with_2_policies) return nchans_with_0p, nchans_with_1p, nchans_with_2p def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *, @@ -660,7 +673,7 @@ class ChannelDB(SqlDB): local_update_decoded['start_node'] = node_id return Policy.from_msg(local_update_decoded) - def get_channel_info(self, short_channel_id: bytes, *, + def get_channel_info(self, short_channel_id: ShortChannelID, *, my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]: ret = self._channels.get(short_channel_id) if ret: diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -825,8 +825,10 @@ class ShortChannelID(bytes): if isinstance(data, ShortChannelID) or data is None: return data if isinstance(data, str): + assert len(data) == 16 return ShortChannelID.fromhex(data) - if isinstance(data, bytes): + if isinstance(data, (bytes, bytearray)): + assert len(data) == 8 return ShortChannelID(data) @property diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -372,8 +372,8 @@ class LNGossip(LNWorker): def get_sync_progress_estimate(self) -> Tuple[Optional[int], Optional[int]]: if self.num_peers() == 0: return None, None - num_db_channels = self.channel_db.num_channels nchans_with_0p, nchans_with_1p, nchans_with_2p = self.channel_db.get_num_channels_partitioned_by_policy_count() + num_db_channels = nchans_with_0p + nchans_with_1p + nchans_with_2p # some channels will never have two policies (only one is in gossip?...) # so if we have at least 1 policy for a channel, we consider that channel "complete" here current_est = num_db_channels - nchans_with_0p