commit fd56fb918982c7afbf96267e03d7cfb2ca7fc858
parent 1ca6f6f306a1f9f376c7519b9b80662aa39768ee
Author: SomberNight <somber.night@protonmail.com>
Date: Sat, 29 Feb 2020 18:32:47 +0100
ChannelDB: add self.lock and make it thread-safe
Diffstat:
1 file changed, 46 insertions(+), 25 deletions(-)
diff --git a/electrum/channel_db.py b/electrum/channel_db.py
@@ -31,6 +31,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
import asyncio
+import threading
from .sql_db import SqlDB, sql
@@ -247,17 +248,21 @@ class ChannelDB(SqlDB):
def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'gossip_db')
super().__init__(network, path, commit_interval=100)
+ self.lock = threading.RLock()
self.num_nodes = 0
self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
+
# initialized in load_data
+ # note: modify/iterate needs self.lock
self._channels = {} # type: Dict[bytes, ChannelInfo]
self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
# node_id -> (host, port, ts)
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
+
self.data_loaded = asyncio.Event()
self.network = network # only for callback
@@ -268,16 +273,19 @@ class ChannelDB(SqlDB):
self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
def get_channel_ids(self):
- return set(self._channels.keys())
+ with self.lock:
+ return set(self._channels.keys())
def add_recent_peer(self, peer: LNPeerAddr):
now = int(time.time())
node_id = peer.pubkey
- self._addresses[node_id].add((peer.host, peer.port, now))
+ with self.lock:
+ self._addresses[node_id].add((peer.host, peer.port, now))
self.save_node_address(node_id, peer, now)
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
- unshuffled = set(self._nodes.keys()) - node_ids
+ with self.lock:
+ unshuffled = set(self._nodes.keys()) - node_ids
return random.sample(unshuffled, min(200, len(unshuffled)))
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
@@ -296,8 +304,10 @@ class ChannelDB(SqlDB):
# FIXME this does not reliably return "recent" peers...
# Also, the list() cast over the whole dict (thousands of elements),
# is really inefficient.
+ with self.lock:
+ _addresses_keys = list(self._addresses.keys())
r = [self.get_last_good_address(node_id)
- for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]]
+ for node_id in _addresses_keys[-self.NUM_MAX_RECENT_PEERS:]]
return list(reversed(r))
# note: currently channel announcements are trusted by default (trusted=True);
@@ -336,9 +346,10 @@ class ChannelDB(SqlDB):
except UnknownEvenFeatureBits:
return
channel_info = channel_info._replace(capacity_sat=capacity_sat)
- self._channels[channel_info.short_channel_id] = channel_info
- self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
- self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
+ with self.lock:
+ self._channels[channel_info.short_channel_id] = channel_info
+ self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
+ self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
if 'raw' in msg:
self.save_channel(channel_info.short_channel_id, msg['raw'])
@@ -397,7 +408,8 @@ class ChannelDB(SqlDB):
if verify:
self.verify_channel_update(payload)
policy = Policy.from_msg(payload)
- self._policies[key] = policy
+ with self.lock:
+ self._policies[key] = policy
if 'raw' in payload:
self.save_policy(policy.key, payload['raw'])
#
@@ -492,32 +504,38 @@ class ChannelDB(SqlDB):
if node and node.timestamp >= node_info.timestamp:
continue
# save
- self._nodes[node_id] = node_info
+ with self.lock:
+ self._nodes[node_id] = node_info
if 'raw' in msg_payload:
self.save_node_info(node_id, msg_payload['raw'])
- for addr in node_addresses:
- self._addresses[node_id].add((addr.host, addr.port, 0))
+ with self.lock:
+ for addr in node_addresses:
+ self._addresses[node_id].add((addr.host, addr.port, 0))
self.save_node_addresses(node_id, node_addresses)
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
def get_old_policies(self, delta):
+ with self.lock:
+ _policies = self._policies.copy()
now = int(time.time())
- return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
+ return list(k for k, v in _policies.items() if v.timestamp <= now - delta)
def prune_old_policies(self, delta):
l = self.get_old_policies(delta)
if l:
for k in l:
- self._policies.pop(k)
+ with self.lock:
+ self._policies.pop(k)
self.delete_policy(*k)
self.update_counts()
self.logger.info(f'Deleting {len(l)} old policies')
def get_orphaned_channels(self):
- ids = set(x[1] for x in self._policies.keys())
- return list(x for x in self._channels.keys() if x not in ids)
+ with self.lock:
+ ids = set(x[1] for x in self._policies.keys())
+ return list(x for x in self._channels.keys() if x not in ids)
def prune_orphaned_channels(self):
l = self.get_orphaned_channels()
@@ -535,10 +553,11 @@ class ChannelDB(SqlDB):
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
def remove_channel(self, short_channel_id: ShortChannelID):
- channel_info = self._channels.pop(short_channel_id, None)
- if channel_info:
- self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
- self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
+ with self.lock:
+ channel_info = self._channels.pop(short_channel_id, None)
+ if channel_info:
+ self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
+ self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
# delete from database
self.delete_channel(short_channel_id)
@@ -571,17 +590,19 @@ class ChannelDB(SqlDB):
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
self.update_counts()
- self.count_incomplete_channels()
+ self.logger.info(f'semi-orphaned channels: {self.get_num_incomplete_channels()}')
self.data_loaded.set()
- def count_incomplete_channels(self):
- out = set()
- for short_channel_id, ci in self._channels.items():
+ def get_num_incomplete_channels(self) -> int:
+ found = set()
+ with self.lock:
+ _channels = self._channels.copy()
+ for short_channel_id, ci in _channels.items():
p1 = self.get_policy_for_node(short_channel_id, ci.node1_id)
p2 = self.get_policy_for_node(short_channel_id, ci.node2_id)
if p1 is None or p2 is not None:
- out.add(short_channel_id)
- self.logger.info(f'semi-orphaned: {len(out)}')
+ found.add(short_channel_id)
+ return len(found)
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']: