commit 112b0e0544dc25fce718e66ef678cd7b360d405b
parent a7589a97ad04a5dbb9889d06e958a7ea55d8ac8e
Author: ThomasV <thomasv@electrum.org>
Date: Fri, 22 Jun 2018 13:05:40 +0200
Merge pull request #4453 from SomberNight/network_locks
locks in network.py
Diffstat:
M | lib/network.py | | | 203 | ++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------- |
1 file changed, 135 insertions(+), 68 deletions(-)
diff --git a/lib/network.py b/lib/network.py
@@ -171,7 +171,7 @@ class Network(util.DaemonThread):
util.DaemonThread.__init__(self)
self.config = SimpleConfig(config) if isinstance(config, dict) else config
self.num_server = 10 if not self.config.get('oneserver') else 0
- self.blockchains = blockchain.read_blockchains(self.config)
+ self.blockchains = blockchain.read_blockchains(self.config) # note: needs self.blockchains_lock
self.print_error("blockchains", self.blockchains.keys())
self.blockchain_index = config.get('blockchain_index', 0)
if self.blockchain_index not in self.blockchains.keys():
@@ -187,27 +187,35 @@ class Network(util.DaemonThread):
self.default_server = None
if not self.default_server:
self.default_server = pick_random_server()
- self.lock = threading.Lock()
+
+ # locks: if you need to take multiple ones, acquire them in the order they are defined here!
+ self.interface_lock = threading.RLock() # <- re-entrant
+ self.callback_lock = threading.Lock()
+ self.pending_sends_lock = threading.Lock()
+ self.recent_servers_lock = threading.RLock() # <- re-entrant
+ self.subscribed_addresses_lock = threading.Lock()
+ self.blockchains_lock = threading.Lock()
+
self.pending_sends = []
self.message_id = 0
self.debug = False
self.irc_servers = {} # returned by interface (list from irc)
- self.recent_servers = self.read_recent_servers()
+ self.recent_servers = self.read_recent_servers() # note: needs self.recent_servers_lock
self.banner = ''
self.donation_address = ''
self.relay_fee = None
# callbacks passed with subscriptions
- self.subscriptions = defaultdict(list)
- self.sub_cache = {}
+ self.subscriptions = defaultdict(list) # note: needs self.callback_lock
+ self.sub_cache = {} # note: needs self.interface_lock
# callbacks set by the GUI
- self.callbacks = defaultdict(list)
+ self.callbacks = defaultdict(list) # note: needs self.callback_lock
dir_path = os.path.join( self.config.path, 'certs')
util.make_dir(dir_path)
# subscriptions and requests
- self.subscribed_addresses = set()
+ self.subscribed_addresses = set() # note: needs self.subscribed_addresses_lock
self.h2addr = {}
# Requests from client we've not seen a response to
self.unanswered_requests = {}
@@ -217,8 +225,8 @@ class Network(util.DaemonThread):
# kick off the network. interface is the main server we are currently
# communicating with. interfaces is the set of servers we are connecting
# to or have an ongoing connection with
- self.interface = None
- self.interfaces = {}
+ self.interface = None # note: needs self.interface_lock
+ self.interfaces = {} # note: needs self.interface_lock
self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set()
self.requested_chunks = set()
@@ -226,19 +234,31 @@ class Network(util.DaemonThread):
self.start_network(deserialize_server(self.default_server)[2],
deserialize_proxy(self.config.get('proxy')))
+ def with_interface_lock(func):
+ def func_wrapper(self, *args, **kwargs):
+ with self.interface_lock:
+ return func(self, *args, **kwargs)
+ return func_wrapper
+
+ def with_recent_servers_lock(func):
+ def func_wrapper(self, *args, **kwargs):
+ with self.recent_servers_lock:
+ return func(self, *args, **kwargs)
+ return func_wrapper
+
def register_callback(self, callback, events):
- with self.lock:
+ with self.callback_lock:
for event in events:
self.callbacks[event].append(callback)
def unregister_callback(self, callback):
- with self.lock:
+ with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
def trigger_callback(self, event, *args):
- with self.lock:
+ with self.callback_lock:
callbacks = self.callbacks[event][:]
[callback(event, *args) for callback in callbacks]
@@ -253,6 +273,7 @@ class Network(util.DaemonThread):
except:
return []
+ @with_recent_servers_lock
def save_recent_servers(self):
if not self.config.path:
return
@@ -264,6 +285,7 @@ class Network(util.DaemonThread):
except:
pass
+ @with_interface_lock
def get_server_height(self):
return self.interface.tip if self.interface else 0
@@ -291,11 +313,15 @@ class Network(util.DaemonThread):
def is_up_to_date(self):
return self.unanswered_requests == {}
+ @with_interface_lock
def queue_request(self, method, params, interface=None):
# If you want to queue a request on any interface it must go
# through this function so message ids are properly tracked
if interface is None:
interface = self.interface
+ if interface is None:
+ self.print_error('warning: dropping request', method, params)
+ return
message_id = self.message_id
self.message_id += 1
if self.debug:
@@ -303,7 +329,9 @@ class Network(util.DaemonThread):
interface.queue_request(method, params, message_id)
return message_id
+ @with_interface_lock
def send_subscriptions(self):
+ assert self.interface
self.print_error('sending subscriptions to', self.interface.server, len(self.unanswered_requests), len(self.subscribed_addresses))
self.sub_cache.clear()
# Resend unanswered requests
@@ -317,8 +345,9 @@ class Network(util.DaemonThread):
self.queue_request('server.peers.subscribe', [])
self.request_fee_estimates()
self.queue_request('blockchain.relayfee', [])
- for h in list(self.subscribed_addresses):
- self.queue_request('blockchain.scripthash.subscribe', [h])
+ with self.subscribed_addresses_lock:
+ for h in self.subscribed_addresses:
+ self.queue_request('blockchain.scripthash.subscribe', [h])
def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS
@@ -358,10 +387,12 @@ class Network(util.DaemonThread):
if self.is_connected():
return self.donation_address
+ @with_interface_lock
def get_interfaces(self):
'''The interfaces that are in connected state'''
return list(self.interfaces.keys())
+ @with_recent_servers_lock
def get_servers(self):
out = constants.net.DEFAULT_SERVERS
if self.irc_servers:
@@ -376,6 +407,7 @@ class Network(util.DaemonThread):
out[host] = { protocol:port }
return out
+ @with_interface_lock
def start_interface(self, server):
if (not server in self.interfaces and not server in self.connecting):
if server == self.default_server:
@@ -385,7 +417,8 @@ class Network(util.DaemonThread):
c = Connection(server, self.socket_queue, self.config.path)
def start_random_interface(self):
- exclude_set = self.disconnected_servers.union(set(self.interfaces))
+ with self.interface_lock:
+ exclude_set = self.disconnected_servers.union(set(self.interfaces))
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
if server:
self.start_interface(server)
@@ -433,15 +466,17 @@ class Network(util.DaemonThread):
else:
socket.getaddrinfo = socket._getaddrinfo
+ @with_interface_lock
def start_network(self, protocol, proxy):
assert not self.interface and not self.interfaces
assert not self.connecting and self.socket_queue.empty()
self.print_error('starting network')
- self.disconnected_servers = set([])
+ self.disconnected_servers = set([]) # note: needs self.interface_lock
self.protocol = protocol
self.set_proxy(proxy)
self.start_interfaces()
+ @with_interface_lock
def stop_network(self):
self.print_error("stopping network")
for interface in list(self.interfaces.values()):
@@ -491,6 +526,7 @@ class Network(util.DaemonThread):
if servers:
self.switch_to_interface(random.choice(servers))
+ @with_interface_lock
def switch_lagging_interface(self):
'''If auto_connect and lagging, switch interface'''
if self.server_is_lagging() and self.auto_connect:
@@ -501,6 +537,7 @@ class Network(util.DaemonThread):
choice = random.choice(filtered)
self.switch_to_interface(choice)
+ @with_interface_lock
def switch_to_interface(self, server):
'''Switch to server as our interface. If no connection exists nor
being opened, start a thread to connect. The actual switch will
@@ -522,6 +559,7 @@ class Network(util.DaemonThread):
self.set_status('connected')
self.notify('updated')
+ @with_interface_lock
def close_interface(self, interface):
if interface:
if interface.server in self.interfaces:
@@ -530,6 +568,7 @@ class Network(util.DaemonThread):
self.interface = None
interface.close()
+ @with_recent_servers_lock
def add_recent_server(self, server):
# list is ordered
if server in self.recent_servers:
@@ -587,7 +626,8 @@ class Network(util.DaemonThread):
for callback in callbacks:
callback(response)
- def get_index(self, method, params):
+ @classmethod
+ def get_index(cls, method, params):
""" hashable index for subscriptions and cache"""
return str(method) + (':' + str(params[0]) if params else '')
@@ -602,12 +642,15 @@ class Network(util.DaemonThread):
# and are placed in the unanswered_requests dictionary
client_req = self.unanswered_requests.pop(message_id, None)
if client_req:
- assert interface == self.interface
+ if interface != self.interface:
+ # we probably changed the current interface
+ # in the meantime; drop this.
+ return
callbacks = [client_req[2]]
else:
# fixme: will only work for subscriptions
k = self.get_index(method, params)
- callbacks = self.subscriptions.get(k, [])
+ callbacks = list(self.subscriptions.get(k, []))
# Copy the request method and params to the response
response['method'] = method
@@ -615,7 +658,8 @@ class Network(util.DaemonThread):
# Only once we've received a response to an addr subscription
# add it to the list; avoids double-sends on reconnection
if method == 'blockchain.scripthash.subscribe':
- self.subscribed_addresses.add(params[0])
+ with self.subscribed_addresses_lock:
+ self.subscribed_addresses.add(params[0])
else:
if not response: # Closed remotely / misbehaving
self.connection_down(interface.server)
@@ -630,27 +674,29 @@ class Network(util.DaemonThread):
elif method == 'blockchain.scripthash.subscribe':
response['params'] = [params[0]] # addr
response['result'] = params[1]
- callbacks = self.subscriptions.get(k, [])
+ callbacks = list(self.subscriptions.get(k, []))
# update cache if it's a subscription
if method.endswith('.subscribe'):
- self.sub_cache[k] = response
+ with self.interface_lock:
+ self.sub_cache[k] = response
# Response is now in canonical form
self.process_response(interface, response, callbacks)
def send(self, messages, callback):
'''Messages is a list of (method, params) tuples'''
messages = list(messages)
- with self.lock:
+ with self.pending_sends_lock:
self.pending_sends.append((messages, callback))
+ @with_interface_lock
def process_pending_sends(self):
# Requests needs connectivity. If we don't have an interface,
# we cannot process them.
if not self.interface:
return
- with self.lock:
+ with self.pending_sends_lock:
sends = self.pending_sends
self.pending_sends = []
@@ -660,10 +706,11 @@ class Network(util.DaemonThread):
if method.endswith('.subscribe'):
k = self.get_index(method, params)
# add callback to list
- l = self.subscriptions.get(k, [])
+ l = list(self.subscriptions.get(k, []))
if callback not in l:
l.append(callback)
- self.subscriptions[k] = l
+ with self.callback_lock:
+ self.subscriptions[k] = l
# check cached response for subscriptions
r = self.sub_cache.get(k)
@@ -679,11 +726,12 @@ class Network(util.DaemonThread):
# Note: we can't unsubscribe from the server, so if we receive
# subsequent notifications process_response() will emit a harmless
# "received unexpected notification" warning
- with self.lock:
+ with self.callback_lock:
for v in self.subscriptions.values():
if callback in v:
v.remove(callback)
+ @with_interface_lock
def connection_down(self, server):
'''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.'''
@@ -693,9 +741,10 @@ class Network(util.DaemonThread):
if server in self.interfaces:
self.close_interface(self.interfaces[server])
self.notify('interfaces')
- for b in self.blockchains.values():
- if b.catch_up == server:
- b.catch_up = None
+ with self.blockchains_lock:
+ for b in self.blockchains.values():
+ if b.catch_up == server:
+ b.catch_up = None
def new_interface(self, server, socket):
# todo: get tip first, then decide which checkpoint to use.
@@ -706,7 +755,8 @@ class Network(util.DaemonThread):
interface.tip = 0
interface.mode = 'default'
interface.request = None
- self.interfaces[server] = interface
+ with self.interface_lock:
+ self.interfaces[server] = interface
# server.version should be the first message
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
self.queue_request('server.version', params, interface)
@@ -729,7 +779,9 @@ class Network(util.DaemonThread):
# Send pings and shut down stale interfaces
# must use copy of values
- for interface in list(self.interfaces.values()):
+ with self.interface_lock:
+ interfaces = list(self.interfaces.values())
+ for interface in interfaces:
if interface.has_timed_out():
self.connection_down(interface.server)
elif interface.ping_required():
@@ -737,28 +789,30 @@ class Network(util.DaemonThread):
now = time.time()
# nodes
- if len(self.interfaces) + len(self.connecting) < self.num_server:
- self.start_random_interface()
- if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
- self.print_error('network: retrying connections')
- self.disconnected_servers = set([])
- self.nodes_retry_time = now
+ with self.interface_lock:
+ if len(self.interfaces) + len(self.connecting) < self.num_server:
+ self.start_random_interface()
+ if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
+ self.print_error('network: retrying connections')
+ self.disconnected_servers = set([])
+ self.nodes_retry_time = now
# main interface
- if not self.is_connected():
- if self.auto_connect:
- if not self.is_connecting():
- self.switch_to_random_interface()
- else:
- if self.default_server in self.disconnected_servers:
- if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
- self.disconnected_servers.remove(self.default_server)
- self.server_retry_time = now
+ with self.interface_lock:
+ if not self.is_connected():
+ if self.auto_connect:
+ if not self.is_connecting():
+ self.switch_to_random_interface()
else:
- self.switch_to_interface(self.default_server)
- else:
- if self.config.is_fee_estimates_update_required():
- self.request_fee_estimates()
+ if self.default_server in self.disconnected_servers:
+ if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
+ self.disconnected_servers.remove(self.default_server)
+ self.server_retry_time = now
+ else:
+ self.switch_to_interface(self.default_server)
+ else:
+ if self.config.is_fee_estimates_update_required():
+ self.request_fee_estimates()
def request_chunk(self, interface, index):
if index in self.requested_chunks:
@@ -876,7 +930,8 @@ class Network(util.DaemonThread):
if bh > interface.good:
if not interface.blockchain.check_header(interface.bad_header):
b = interface.blockchain.fork(interface.bad_header)
- self.blockchains[interface.bad] = b
+ with self.blockchains_lock:
+ self.blockchains[interface.bad] = b
interface.blockchain = b
interface.print_error("new chain", b.checkpoint)
interface.mode = 'catch_up'
@@ -928,7 +983,9 @@ class Network(util.DaemonThread):
self.notify('interfaces')
def maintain_requests(self):
- for interface in list(self.interfaces.values()):
+ with self.interface_lock:
+ interfaces = list(self.interfaces.values())
+ for interface in interfaces:
if interface.request and time.time() - interface.request_time > 20:
interface.print_error("blockchain request timed out")
self.connection_down(interface.server)
@@ -940,14 +997,14 @@ class Network(util.DaemonThread):
if not self.interfaces:
time.sleep(0.1)
return
- rin = [i for i in self.interfaces.values()]
- win = [i for i in self.interfaces.values() if i.num_requests()]
+ with self.interface_lock:
+ interfaces = list(self.interfaces.values())
+ rin = [i for i in interfaces]
+ win = [i for i in interfaces if i.num_requests()]
try:
rout, wout, xout = select.select(rin, win, [], 0.1)
except socket.error as e:
- # TODO: py3, get code from e
- code = None
- if code == errno.EINTR:
+ if e.errno == errno.EINTR:
return
raise
assert not xout
@@ -1004,7 +1061,8 @@ class Network(util.DaemonThread):
self.notify('updated')
self.notify('interfaces')
return
- tip = max([x.height() for x in self.blockchains.values()])
+ with self.blockchains_lock:
+ tip = max([x.height() for x in self.blockchains.values()])
if tip >=0:
interface.mode = 'backward'
interface.bad = height
@@ -1016,19 +1074,24 @@ class Network(util.DaemonThread):
chain.catch_up = interface
interface.mode = 'catch_up'
interface.blockchain = chain
- self.print_error("switching to catchup mode", tip, self.blockchains)
+ with self.blockchains_lock:
+ self.print_error("switching to catchup mode", tip, self.blockchains)
self.request_header(interface, 0)
else:
self.print_error("chain already catching up with", chain.catch_up.server)
+ @with_interface_lock
def blockchain(self):
if self.interface and self.interface.blockchain is not None:
self.blockchain_index = self.interface.blockchain.checkpoint
return self.blockchains[self.blockchain_index]
+ @with_interface_lock
def get_blockchains(self):
out = {}
- for k, b in self.blockchains.items():
+ with self.blockchains_lock:
+ blockchain_items = list(self.blockchains.items())
+ for k, b in blockchain_items:
r = list(filter(lambda i: i.blockchain==b, list(self.interfaces.values())))
if r:
out[k] = r
@@ -1039,18 +1102,21 @@ class Network(util.DaemonThread):
if blockchain:
self.blockchain_index = index
self.config.set_key('blockchain_index', index)
- for i in self.interfaces.values():
+ with self.interface_lock:
+ interfaces = list(self.interfaces.values())
+ for i in interfaces:
if i.blockchain == blockchain:
self.switch_to_interface(i.server)
break
else:
raise Exception('blockchain not found', index)
- if self.interface:
- server = self.interface.server
- host, port, protocol, proxy, auto_connect = self.get_parameters()
- host, port, protocol = server.split(':')
- self.set_parameters(host, port, protocol, proxy, auto_connect)
+ with self.interface_lock:
+ if self.interface:
+ server = self.interface.server
+ host, port, protocol, proxy, auto_connect = self.get_parameters()
+ host, port, protocol = server.split(':')
+ self.set_parameters(host, port, protocol, proxy, auto_connect)
def get_local_height(self):
return self.blockchain().height()
@@ -1189,5 +1255,6 @@ class Network(util.DaemonThread):
with open(path, 'w', encoding='utf-8') as f:
f.write(json.dumps(cp, indent=4))
- def max_checkpoint(self):
+ @classmethod
+ def max_checkpoint(cls):
return max(0, len(constants.net.CHECKPOINTS) * 2016 - 1)