electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 112b0e0544dc25fce718e66ef678cd7b360d405b
parent a7589a97ad04a5dbb9889d06e958a7ea55d8ac8e
Author: ThomasV <thomasv@electrum.org>
Date:   Fri, 22 Jun 2018 13:05:40 +0200

Merge pull request #4453 from SomberNight/network_locks

locks in network.py
Diffstat:
Mlib/network.py | 203++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------
1 file changed, 135 insertions(+), 68 deletions(-)

diff --git a/lib/network.py b/lib/network.py @@ -171,7 +171,7 @@ class Network(util.DaemonThread): util.DaemonThread.__init__(self) self.config = SimpleConfig(config) if isinstance(config, dict) else config self.num_server = 10 if not self.config.get('oneserver') else 0 - self.blockchains = blockchain.read_blockchains(self.config) + self.blockchains = blockchain.read_blockchains(self.config) # note: needs self.blockchains_lock self.print_error("blockchains", self.blockchains.keys()) self.blockchain_index = config.get('blockchain_index', 0) if self.blockchain_index not in self.blockchains.keys(): @@ -187,27 +187,35 @@ class Network(util.DaemonThread): self.default_server = None if not self.default_server: self.default_server = pick_random_server() - self.lock = threading.Lock() + + # locks: if you need to take multiple ones, acquire them in the order they are defined here! + self.interface_lock = threading.RLock() # <- re-entrant + self.callback_lock = threading.Lock() + self.pending_sends_lock = threading.Lock() + self.recent_servers_lock = threading.RLock() # <- re-entrant + self.subscribed_addresses_lock = threading.Lock() + self.blockchains_lock = threading.Lock() + self.pending_sends = [] self.message_id = 0 self.debug = False self.irc_servers = {} # returned by interface (list from irc) - self.recent_servers = self.read_recent_servers() + self.recent_servers = self.read_recent_servers() # note: needs self.recent_servers_lock self.banner = '' self.donation_address = '' self.relay_fee = None # callbacks passed with subscriptions - self.subscriptions = defaultdict(list) - self.sub_cache = {} + self.subscriptions = defaultdict(list) # note: needs self.callback_lock + self.sub_cache = {} # note: needs self.interface_lock # callbacks set by the GUI - self.callbacks = defaultdict(list) + self.callbacks = defaultdict(list) # note: needs self.callback_lock dir_path = os.path.join( self.config.path, 'certs') util.make_dir(dir_path) # subscriptions and requests - self.subscribed_addresses = set() + self.subscribed_addresses = set() # note: needs self.subscribed_addresses_lock self.h2addr = {} # Requests from client we've not seen a response to self.unanswered_requests = {} @@ -217,8 +225,8 @@ class Network(util.DaemonThread): # kick off the network. interface is the main server we are currently # communicating with. interfaces is the set of servers we are connecting # to or have an ongoing connection with - self.interface = None - self.interfaces = {} + self.interface = None # note: needs self.interface_lock + self.interfaces = {} # note: needs self.interface_lock self.auto_connect = self.config.get('auto_connect', True) self.connecting = set() self.requested_chunks = set() @@ -226,19 +234,31 @@ class Network(util.DaemonThread): self.start_network(deserialize_server(self.default_server)[2], deserialize_proxy(self.config.get('proxy'))) + def with_interface_lock(func): + def func_wrapper(self, *args, **kwargs): + with self.interface_lock: + return func(self, *args, **kwargs) + return func_wrapper + + def with_recent_servers_lock(func): + def func_wrapper(self, *args, **kwargs): + with self.recent_servers_lock: + return func(self, *args, **kwargs) + return func_wrapper + def register_callback(self, callback, events): - with self.lock: + with self.callback_lock: for event in events: self.callbacks[event].append(callback) def unregister_callback(self, callback): - with self.lock: + with self.callback_lock: for callbacks in self.callbacks.values(): if callback in callbacks: callbacks.remove(callback) def trigger_callback(self, event, *args): - with self.lock: + with self.callback_lock: callbacks = self.callbacks[event][:] [callback(event, *args) for callback in callbacks] @@ -253,6 +273,7 @@ class Network(util.DaemonThread): except: return [] + @with_recent_servers_lock def save_recent_servers(self): if not self.config.path: return @@ -264,6 +285,7 @@ class Network(util.DaemonThread): except: pass + @with_interface_lock def get_server_height(self): return self.interface.tip if self.interface else 0 @@ -291,11 +313,15 @@ class Network(util.DaemonThread): def is_up_to_date(self): return self.unanswered_requests == {} + @with_interface_lock def queue_request(self, method, params, interface=None): # If you want to queue a request on any interface it must go # through this function so message ids are properly tracked if interface is None: interface = self.interface + if interface is None: + self.print_error('warning: dropping request', method, params) + return message_id = self.message_id self.message_id += 1 if self.debug: @@ -303,7 +329,9 @@ class Network(util.DaemonThread): interface.queue_request(method, params, message_id) return message_id + @with_interface_lock def send_subscriptions(self): + assert self.interface self.print_error('sending subscriptions to', self.interface.server, len(self.unanswered_requests), len(self.subscribed_addresses)) self.sub_cache.clear() # Resend unanswered requests @@ -317,8 +345,9 @@ class Network(util.DaemonThread): self.queue_request('server.peers.subscribe', []) self.request_fee_estimates() self.queue_request('blockchain.relayfee', []) - for h in list(self.subscribed_addresses): - self.queue_request('blockchain.scripthash.subscribe', [h]) + with self.subscribed_addresses_lock: + for h in self.subscribed_addresses: + self.queue_request('blockchain.scripthash.subscribe', [h]) def request_fee_estimates(self): from .simple_config import FEE_ETA_TARGETS @@ -358,10 +387,12 @@ class Network(util.DaemonThread): if self.is_connected(): return self.donation_address + @with_interface_lock def get_interfaces(self): '''The interfaces that are in connected state''' return list(self.interfaces.keys()) + @with_recent_servers_lock def get_servers(self): out = constants.net.DEFAULT_SERVERS if self.irc_servers: @@ -376,6 +407,7 @@ class Network(util.DaemonThread): out[host] = { protocol:port } return out + @with_interface_lock def start_interface(self, server): if (not server in self.interfaces and not server in self.connecting): if server == self.default_server: @@ -385,7 +417,8 @@ class Network(util.DaemonThread): c = Connection(server, self.socket_queue, self.config.path) def start_random_interface(self): - exclude_set = self.disconnected_servers.union(set(self.interfaces)) + with self.interface_lock: + exclude_set = self.disconnected_servers.union(set(self.interfaces)) server = pick_random_server(self.get_servers(), self.protocol, exclude_set) if server: self.start_interface(server) @@ -433,15 +466,17 @@ class Network(util.DaemonThread): else: socket.getaddrinfo = socket._getaddrinfo + @with_interface_lock def start_network(self, protocol, proxy): assert not self.interface and not self.interfaces assert not self.connecting and self.socket_queue.empty() self.print_error('starting network') - self.disconnected_servers = set([]) + self.disconnected_servers = set([]) # note: needs self.interface_lock self.protocol = protocol self.set_proxy(proxy) self.start_interfaces() + @with_interface_lock def stop_network(self): self.print_error("stopping network") for interface in list(self.interfaces.values()): @@ -491,6 +526,7 @@ class Network(util.DaemonThread): if servers: self.switch_to_interface(random.choice(servers)) + @with_interface_lock def switch_lagging_interface(self): '''If auto_connect and lagging, switch interface''' if self.server_is_lagging() and self.auto_connect: @@ -501,6 +537,7 @@ class Network(util.DaemonThread): choice = random.choice(filtered) self.switch_to_interface(choice) + @with_interface_lock def switch_to_interface(self, server): '''Switch to server as our interface. If no connection exists nor being opened, start a thread to connect. The actual switch will @@ -522,6 +559,7 @@ class Network(util.DaemonThread): self.set_status('connected') self.notify('updated') + @with_interface_lock def close_interface(self, interface): if interface: if interface.server in self.interfaces: @@ -530,6 +568,7 @@ class Network(util.DaemonThread): self.interface = None interface.close() + @with_recent_servers_lock def add_recent_server(self, server): # list is ordered if server in self.recent_servers: @@ -587,7 +626,8 @@ class Network(util.DaemonThread): for callback in callbacks: callback(response) - def get_index(self, method, params): + @classmethod + def get_index(cls, method, params): """ hashable index for subscriptions and cache""" return str(method) + (':' + str(params[0]) if params else '') @@ -602,12 +642,15 @@ class Network(util.DaemonThread): # and are placed in the unanswered_requests dictionary client_req = self.unanswered_requests.pop(message_id, None) if client_req: - assert interface == self.interface + if interface != self.interface: + # we probably changed the current interface + # in the meantime; drop this. + return callbacks = [client_req[2]] else: # fixme: will only work for subscriptions k = self.get_index(method, params) - callbacks = self.subscriptions.get(k, []) + callbacks = list(self.subscriptions.get(k, [])) # Copy the request method and params to the response response['method'] = method @@ -615,7 +658,8 @@ class Network(util.DaemonThread): # Only once we've received a response to an addr subscription # add it to the list; avoids double-sends on reconnection if method == 'blockchain.scripthash.subscribe': - self.subscribed_addresses.add(params[0]) + with self.subscribed_addresses_lock: + self.subscribed_addresses.add(params[0]) else: if not response: # Closed remotely / misbehaving self.connection_down(interface.server) @@ -630,27 +674,29 @@ class Network(util.DaemonThread): elif method == 'blockchain.scripthash.subscribe': response['params'] = [params[0]] # addr response['result'] = params[1] - callbacks = self.subscriptions.get(k, []) + callbacks = list(self.subscriptions.get(k, [])) # update cache if it's a subscription if method.endswith('.subscribe'): - self.sub_cache[k] = response + with self.interface_lock: + self.sub_cache[k] = response # Response is now in canonical form self.process_response(interface, response, callbacks) def send(self, messages, callback): '''Messages is a list of (method, params) tuples''' messages = list(messages) - with self.lock: + with self.pending_sends_lock: self.pending_sends.append((messages, callback)) + @with_interface_lock def process_pending_sends(self): # Requests needs connectivity. If we don't have an interface, # we cannot process them. if not self.interface: return - with self.lock: + with self.pending_sends_lock: sends = self.pending_sends self.pending_sends = [] @@ -660,10 +706,11 @@ class Network(util.DaemonThread): if method.endswith('.subscribe'): k = self.get_index(method, params) # add callback to list - l = self.subscriptions.get(k, []) + l = list(self.subscriptions.get(k, [])) if callback not in l: l.append(callback) - self.subscriptions[k] = l + with self.callback_lock: + self.subscriptions[k] = l # check cached response for subscriptions r = self.sub_cache.get(k) @@ -679,11 +726,12 @@ class Network(util.DaemonThread): # Note: we can't unsubscribe from the server, so if we receive # subsequent notifications process_response() will emit a harmless # "received unexpected notification" warning - with self.lock: + with self.callback_lock: for v in self.subscriptions.values(): if callback in v: v.remove(callback) + @with_interface_lock def connection_down(self, server): '''A connection to server either went down, or was never made. We distinguish by whether it is in self.interfaces.''' @@ -693,9 +741,10 @@ class Network(util.DaemonThread): if server in self.interfaces: self.close_interface(self.interfaces[server]) self.notify('interfaces') - for b in self.blockchains.values(): - if b.catch_up == server: - b.catch_up = None + with self.blockchains_lock: + for b in self.blockchains.values(): + if b.catch_up == server: + b.catch_up = None def new_interface(self, server, socket): # todo: get tip first, then decide which checkpoint to use. @@ -706,7 +755,8 @@ class Network(util.DaemonThread): interface.tip = 0 interface.mode = 'default' interface.request = None - self.interfaces[server] = interface + with self.interface_lock: + self.interfaces[server] = interface # server.version should be the first message params = [ELECTRUM_VERSION, PROTOCOL_VERSION] self.queue_request('server.version', params, interface) @@ -729,7 +779,9 @@ class Network(util.DaemonThread): # Send pings and shut down stale interfaces # must use copy of values - for interface in list(self.interfaces.values()): + with self.interface_lock: + interfaces = list(self.interfaces.values()) + for interface in interfaces: if interface.has_timed_out(): self.connection_down(interface.server) elif interface.ping_required(): @@ -737,28 +789,30 @@ class Network(util.DaemonThread): now = time.time() # nodes - if len(self.interfaces) + len(self.connecting) < self.num_server: - self.start_random_interface() - if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: - self.print_error('network: retrying connections') - self.disconnected_servers = set([]) - self.nodes_retry_time = now + with self.interface_lock: + if len(self.interfaces) + len(self.connecting) < self.num_server: + self.start_random_interface() + if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: + self.print_error('network: retrying connections') + self.disconnected_servers = set([]) + self.nodes_retry_time = now # main interface - if not self.is_connected(): - if self.auto_connect: - if not self.is_connecting(): - self.switch_to_random_interface() - else: - if self.default_server in self.disconnected_servers: - if now - self.server_retry_time > SERVER_RETRY_INTERVAL: - self.disconnected_servers.remove(self.default_server) - self.server_retry_time = now + with self.interface_lock: + if not self.is_connected(): + if self.auto_connect: + if not self.is_connecting(): + self.switch_to_random_interface() else: - self.switch_to_interface(self.default_server) - else: - if self.config.is_fee_estimates_update_required(): - self.request_fee_estimates() + if self.default_server in self.disconnected_servers: + if now - self.server_retry_time > SERVER_RETRY_INTERVAL: + self.disconnected_servers.remove(self.default_server) + self.server_retry_time = now + else: + self.switch_to_interface(self.default_server) + else: + if self.config.is_fee_estimates_update_required(): + self.request_fee_estimates() def request_chunk(self, interface, index): if index in self.requested_chunks: @@ -876,7 +930,8 @@ class Network(util.DaemonThread): if bh > interface.good: if not interface.blockchain.check_header(interface.bad_header): b = interface.blockchain.fork(interface.bad_header) - self.blockchains[interface.bad] = b + with self.blockchains_lock: + self.blockchains[interface.bad] = b interface.blockchain = b interface.print_error("new chain", b.checkpoint) interface.mode = 'catch_up' @@ -928,7 +983,9 @@ class Network(util.DaemonThread): self.notify('interfaces') def maintain_requests(self): - for interface in list(self.interfaces.values()): + with self.interface_lock: + interfaces = list(self.interfaces.values()) + for interface in interfaces: if interface.request and time.time() - interface.request_time > 20: interface.print_error("blockchain request timed out") self.connection_down(interface.server) @@ -940,14 +997,14 @@ class Network(util.DaemonThread): if not self.interfaces: time.sleep(0.1) return - rin = [i for i in self.interfaces.values()] - win = [i for i in self.interfaces.values() if i.num_requests()] + with self.interface_lock: + interfaces = list(self.interfaces.values()) + rin = [i for i in interfaces] + win = [i for i in interfaces if i.num_requests()] try: rout, wout, xout = select.select(rin, win, [], 0.1) except socket.error as e: - # TODO: py3, get code from e - code = None - if code == errno.EINTR: + if e.errno == errno.EINTR: return raise assert not xout @@ -1004,7 +1061,8 @@ class Network(util.DaemonThread): self.notify('updated') self.notify('interfaces') return - tip = max([x.height() for x in self.blockchains.values()]) + with self.blockchains_lock: + tip = max([x.height() for x in self.blockchains.values()]) if tip >=0: interface.mode = 'backward' interface.bad = height @@ -1016,19 +1074,24 @@ class Network(util.DaemonThread): chain.catch_up = interface interface.mode = 'catch_up' interface.blockchain = chain - self.print_error("switching to catchup mode", tip, self.blockchains) + with self.blockchains_lock: + self.print_error("switching to catchup mode", tip, self.blockchains) self.request_header(interface, 0) else: self.print_error("chain already catching up with", chain.catch_up.server) + @with_interface_lock def blockchain(self): if self.interface and self.interface.blockchain is not None: self.blockchain_index = self.interface.blockchain.checkpoint return self.blockchains[self.blockchain_index] + @with_interface_lock def get_blockchains(self): out = {} - for k, b in self.blockchains.items(): + with self.blockchains_lock: + blockchain_items = list(self.blockchains.items()) + for k, b in blockchain_items: r = list(filter(lambda i: i.blockchain==b, list(self.interfaces.values()))) if r: out[k] = r @@ -1039,18 +1102,21 @@ class Network(util.DaemonThread): if blockchain: self.blockchain_index = index self.config.set_key('blockchain_index', index) - for i in self.interfaces.values(): + with self.interface_lock: + interfaces = list(self.interfaces.values()) + for i in interfaces: if i.blockchain == blockchain: self.switch_to_interface(i.server) break else: raise Exception('blockchain not found', index) - if self.interface: - server = self.interface.server - host, port, protocol, proxy, auto_connect = self.get_parameters() - host, port, protocol = server.split(':') - self.set_parameters(host, port, protocol, proxy, auto_connect) + with self.interface_lock: + if self.interface: + server = self.interface.server + host, port, protocol, proxy, auto_connect = self.get_parameters() + host, port, protocol = server.split(':') + self.set_parameters(host, port, protocol, proxy, auto_connect) def get_local_height(self): return self.blockchain().height() @@ -1189,5 +1255,6 @@ class Network(util.DaemonThread): with open(path, 'w', encoding='utf-8') as f: f.write(json.dumps(cp, indent=4)) - def max_checkpoint(self): + @classmethod + def max_checkpoint(cls): return max(0, len(constants.net.CHECKPOINTS) * 2016 - 1)