electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 684e69763a5d1091887793e560609efd55d58d0d
parent cd5152a02d0cc3e544d2c0143baf4d3b7405c226
Author: ThomasV <thomasv@electrum.org>
Date:   Fri, 12 Oct 2018 10:50:47 +0200

Merge pull request #4767 from SomberNight/auto_jump_forks

network: auto-switch servers to preferred fork (or longest chain)
Diffstat:
Melectrum/blockchain.py | 24+++++++++++++++++-------
Melectrum/gui/kivy/main_window.py | 4++--
Melectrum/gui/qt/network_dialog.py | 11++++-------
Melectrum/interface.py | 1+
Melectrum/network.py | 120+++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------
Melectrum/verifier.py | 2+-
6 files changed, 103 insertions(+), 59 deletions(-)

diff --git a/electrum/blockchain.py b/electrum/blockchain.py @@ -22,7 +22,7 @@ # SOFTWARE. import os import threading -from typing import Optional +from typing import Optional, Dict from . import util from .bitcoin import Hash, hash_encode, int_to_hex, rev_hex @@ -73,7 +73,7 @@ def hash_header(header: dict) -> str: return hash_encode(Hash(bfh(serialize_header(header)))) -blockchains = {} +blockchains = {} # type: Dict[int, Blockchain] blockchains_lock = threading.Lock() @@ -100,7 +100,7 @@ class Blockchain(util.PrintError): Manages blockchain headers and their verification """ - def __init__(self, config, forkpoint: int, parent_id: int): + def __init__(self, config, forkpoint: int, parent_id: Optional[int]): self.config = config self.forkpoint = forkpoint self.checkpoints = constants.net.CHECKPOINTS @@ -124,22 +124,32 @@ class Blockchain(util.PrintError): children = list(filter(lambda y: y.parent_id==self.forkpoint, chains)) return max([x.forkpoint for x in children]) if children else None - def get_forkpoint(self) -> int: + def get_max_forkpoint(self) -> int: + """Returns the max height where there is a fork + related to this chain. + """ mc = self.get_max_child() return mc if mc is not None else self.forkpoint def get_branch_size(self) -> int: - return self.height() - self.get_forkpoint() + 1 + return self.height() - self.get_max_forkpoint() + 1 def get_name(self) -> str: - return self.get_hash(self.get_forkpoint()).lstrip('00')[0:10] + return self.get_hash(self.get_max_forkpoint()).lstrip('00')[0:10] def check_header(self, header: dict) -> bool: header_hash = hash_header(header) height = header.get('block_height') + return self.check_hash(height, header_hash) + + def check_hash(self, height: int, header_hash: str) -> bool: + """Returns whether the hash of the block at given height + is the given hash. + """ + assert isinstance(header_hash, str) and len(header_hash) == 64, header_hash # hex try: return header_hash == self.get_hash(height) - except MissingHeader: + except Exception: return False def fork(parent, header: dict) -> 'Blockchain': diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py @@ -120,7 +120,7 @@ class ElectrumWindow(App): with blockchain.blockchains_lock: blockchain_items = list(blockchain.blockchains.items()) for index, b in blockchain_items: if name == b.get_name(): - self.network.run_from_another_thread(self.network.follow_chain(index)) + self.network.run_from_another_thread(self.network.follow_chain_given_id(index)) names = [blockchain.blockchains[b].get_name() for b in chains] if len(names) > 1: cur_chain = self.network.blockchain().get_name() @@ -664,7 +664,7 @@ class ElectrumWindow(App): self.num_nodes = len(self.network.get_interfaces()) self.num_chains = len(self.network.get_blockchains()) chain = self.network.blockchain() - self.blockchain_forkpoint = chain.get_forkpoint() + self.blockchain_forkpoint = chain.get_max_forkpoint() self.blockchain_name = chain.get_name() interface = self.network.interface if interface: diff --git a/electrum/gui/qt/network_dialog.py b/electrum/gui/qt/network_dialog.py @@ -107,7 +107,7 @@ class NodesListWidget(QTreeWidget): b = blockchain.blockchains[k] name = b.get_name() if n_chains >1: - x = QTreeWidgetItem([name + '@%d'%b.get_forkpoint(), '%d'%b.height()]) + x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()]) x.setData(0, Qt.UserRole, 1) x.setData(1, Qt.UserRole, b.forkpoint) else: @@ -364,7 +364,7 @@ class NetworkChoiceLayout(object): chains = self.network.get_blockchains() if len(chains) > 1: chain = self.network.blockchain() - forkpoint = chain.get_forkpoint() + forkpoint = chain.get_max_forkpoint() name = chain.get_name() msg = _('Chain split detected at block {0}').format(forkpoint) + '\n' msg += (_('You are following branch') if auto_connect else _('Your server is on branch'))+ ' ' + name @@ -411,14 +411,11 @@ class NetworkChoiceLayout(object): self.set_server() def follow_branch(self, index): - self.network.run_from_another_thread(self.network.follow_chain(index)) + self.network.run_from_another_thread(self.network.follow_chain_given_id(index)) self.update() def follow_server(self, server): - net_params = self.network.get_parameters() - host, port, protocol = deserialize_server(server) - net_params = net_params._replace(host=host, port=port, protocol=protocol) - self.network.run_from_another_thread(self.network.set_parameters(net_params)) + self.network.run_from_another_thread(self.network.follow_chain_given_server(server)) self.update() def server_changed(self, x): diff --git a/electrum/interface.py b/electrum/interface.py @@ -384,6 +384,7 @@ class Interface(PrintError): self.mark_ready() await self._process_header_at_tip() self.network.trigger_callback('network_updated') + await self.network.switch_unwanted_fork_interface() await self.network.switch_lagging_interface() async def _process_header_at_tip(self): diff --git a/electrum/network.py b/electrum/network.py @@ -32,7 +32,7 @@ import json import sys import ipaddress import asyncio -from typing import NamedTuple, Optional, Sequence, List +from typing import NamedTuple, Optional, Sequence, List, Dict import traceback import dns @@ -172,10 +172,9 @@ class Network(PrintError): self.config = SimpleConfig(config) if isinstance(config, dict) else config self.num_server = 10 if not self.config.get('oneserver') else 0 blockchain.blockchains = blockchain.read_blockchains(self.config) - self.print_error("blockchains", list(blockchain.blockchains.keys())) - self.blockchain_index = config.get('blockchain_index', 0) - if self.blockchain_index not in blockchain.blockchains.keys(): - self.blockchain_index = 0 + self.print_error("blockchains", list(blockchain.blockchains)) + self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict] + self._blockchain_index = 0 # Server for addresses and transactions self.default_server = self.config.get('server', None) # Sanitize default server @@ -213,11 +212,10 @@ class Network(PrintError): # retry times self.server_retry_time = time.time() self.nodes_retry_time = time.time() - # kick off the network. interface is the main server we are currently - # communicating with. interfaces is the set of servers we are connecting - # to or have an ongoing connection with + # the main server we are currently communicating with self.interface = None # type: Interface - self.interfaces = {} + # set of servers we have an ongoing connection with + self.interfaces = {} # type: Dict[str, Interface] self.auto_connect = self.config.get('auto_connect', True) self.connecting = set() self.server_queue = None @@ -227,8 +225,8 @@ class Network(PrintError): #self.asyncio_loop.set_debug(1) self._run_forever = asyncio.Future() self._thread = threading.Thread(target=self.asyncio_loop.run_until_complete, - args=(self._run_forever,), - name='Network') + args=(self._run_forever,), + name='Network') self._thread.start() def run_from_another_thread(self, coro): @@ -523,20 +521,40 @@ class Network(PrintError): async def switch_lagging_interface(self): '''If auto_connect and lagging, switch interface''' - if await self._server_is_lagging() and self.auto_connect: + if self.auto_connect and await self._server_is_lagging(): # switch to one that has the correct header (not height) - header = self.blockchain().read_header(self.get_local_height()) - def filt(x): - a = x[1].tip_header - b = header - assert type(a) is type(b) - return a == b - - with self.interfaces_lock: interfaces_items = list(self.interfaces.items()) - filtered = list(map(lambda x: x[0], filter(filt, interfaces_items))) + best_header = self.blockchain().read_header(self.get_local_height()) + with self.interfaces_lock: interfaces = list(self.interfaces.values()) + filtered = list(filter(lambda iface: iface.tip_header == best_header, interfaces)) if filtered: - choice = random.choice(filtered) - await self.switch_to_interface(choice) + chosen_iface = random.choice(filtered) + await self.switch_to_interface(chosen_iface.server) + + async def switch_unwanted_fork_interface(self): + """If auto_connect and main interface is not on preferred fork, + try to switch to preferred fork. + """ + if not self.auto_connect: + return + with self.interfaces_lock: interfaces = list(self.interfaces.values()) + # try to switch to preferred fork + if self._blockchain_preferred_block: + pref_height = self._blockchain_preferred_block['height'] + pref_hash = self._blockchain_preferred_block['hash'] + filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash), + interfaces)) + if filtered: + chosen_iface = random.choice(filtered) + await self.switch_to_interface(chosen_iface.server) + return + # try to switch to longest chain + if self.blockchain().parent_id is None: + return # already on longest chain + filtered = list(filter(lambda iface: iface.blockchain.parent_id is None, + interfaces)) + if filtered: + chosen_iface = random.choice(filtered) + await self.switch_to_interface(chosen_iface.server) async def switch_to_interface(self, server: str): """Switch to server as our main interface. If no connection exists, @@ -704,8 +722,8 @@ class Network(PrintError): def blockchain(self) -> Blockchain: interface = self.interface if interface and interface.blockchain is not None: - self.blockchain_index = interface.blockchain.forkpoint - return blockchain.blockchains[self.blockchain_index] + self._blockchain_index = interface.blockchain.forkpoint + return blockchain.blockchains[self._blockchain_index] def get_blockchains(self): out = {} # blockchain_id -> list(interfaces) @@ -724,24 +742,42 @@ class Network(PrintError): await self.connection_down(interface.server) return ifaces - async def follow_chain(self, chain_id): - bc = blockchain.blockchains.get(chain_id) - if bc: - self.blockchain_index = chain_id - self.config.set_key('blockchain_index', chain_id) - with self.interfaces_lock: interfaces_values = list(self.interfaces.values()) - for iface in interfaces_values: - if iface.blockchain == bc: - await self.switch_to_interface(iface.server) - break - else: - raise Exception('blockchain not found', chain_id) + def _set_preferred_chain(self, chain: Blockchain): + height = chain.get_max_forkpoint() + header_hash = chain.get_hash(height) + self._blockchain_preferred_block = { + 'height': height, + 'hash': header_hash, + } + self.config.set_key('blockchain_preferred_block', self._blockchain_preferred_block) - if self.interface: - net_params = self.get_parameters() - host, port, protocol = deserialize_server(self.interface.server) - net_params = net_params._replace(host=host, port=port, protocol=protocol) - await self.set_parameters(net_params) + async def follow_chain_given_id(self, chain_id: int) -> None: + bc = blockchain.blockchains.get(chain_id) + if not bc: + raise Exception('blockchain {} not found'.format(chain_id)) + self._set_preferred_chain(bc) + # select server on this chain + with self.interfaces_lock: interfaces = list(self.interfaces.values()) + interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces)) + if len(interfaces_on_selected_chain) == 0: return + chosen_iface = random.choice(interfaces_on_selected_chain) + # switch to server (and save to config) + net_params = self.get_parameters() + host, port, protocol = deserialize_server(chosen_iface.server) + net_params = net_params._replace(host=host, port=port, protocol=protocol) + await self.set_parameters(net_params) + + async def follow_chain_given_server(self, server_str: str) -> None: + # note that server_str should correspond to a connected interface + iface = self.interfaces.get(server_str) + if iface is None: + return + self._set_preferred_chain(iface.blockchain) + # switch to server (and save to config) + net_params = self.get_parameters() + host, port, protocol = deserialize_server(server_str) + net_params = net_params._replace(host=host, port=port, protocol=protocol) + await self.set_parameters(net_params) def get_local_height(self): return self.blockchain().height() diff --git a/electrum/verifier.py b/electrum/verifier.py @@ -156,7 +156,7 @@ class SPV(NetworkJobOnDefaultServer): async def _maybe_undo_verifications(self): def undo_verifications(): - height = self.blockchain.get_forkpoint() + height = self.blockchain.get_max_forkpoint() self.print_error("undoing verifications back to height {}".format(height)) tx_hashes = self.wallet.undo_verifications(self.blockchain, height) for tx_hash in tx_hashes: