electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 2eec7e16004449d70a081e80f782fdc6c90f6b02
parent 9385d2dae371075c73d4919557ab8935339e1350
Author: SomberNight <somber.night@protonmail.com>
Date:   Sun, 21 Jun 2020 11:31:54 +0200

network: smarter switch_unwanted_fork_interface

Previously this function would not switch to a different chain if the
current chain contained the preferred block. This was not the intended
behaviour: if there is a *stronger* chain that *also* contains the
preferred block, we should jump to that.

Note that with this commit there will now always be a preferred block
(defaults to genesis). Previously, it might seem that often there was none,
but actually in practice if the user used the GUI context menu to switch
servers even once, there was one (usually genesis).

Hence, with the old code, if an attacker mined a single header which
then got reorged, auto_connect clients which were connected to the
attacker's server would never switch servers (jump chains) even
without the user explicitly configuring preference for the stale branch.

Diffstat:
Melectrum/blockchain.py | 13+++++++++++++
Melectrum/network.py | 64+++++++++++++++++++++++++++++++---------------------------------
Melectrum/tests/test_blockchain.py | 48++++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 92 insertions(+), 33 deletions(-)

diff --git a/electrum/blockchain.py b/electrum/blockchain.py @@ -646,6 +646,7 @@ class Blockchain(Logger): def check_header(header: dict) -> Optional[Blockchain]: + """Returns any Blockchain that contains header, or None.""" if type(header) is not dict: return None with blockchains_lock: chains = list(blockchains.values()) @@ -656,8 +657,20 @@ def check_header(header: dict) -> Optional[Blockchain]: def can_connect(header: dict) -> Optional[Blockchain]: + """Returns the Blockchain that has a tip that directly links up + with header, or None. + """ with blockchains_lock: chains = list(blockchains.values()) for b in chains: if b.can_connect(header): return b return None + + +def get_chains_that_contain_header(height: int, header_hash: str) -> Sequence[Blockchain]: + """Returns a list of Blockchains that contain header, best chain first.""" + with blockchains_lock: chains = list(blockchains.values()) + chains = [chain for chain in chains + if chain.check_hash(height=height, header_hash=header_hash)] + chains = sorted(chains, key=lambda x: x.get_chainwork(), reverse=True) + return chains diff --git a/electrum/network.py b/electrum/network.py @@ -32,7 +32,7 @@ import socket import json import sys import asyncio -from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set +from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set, Any import traceback import concurrent from concurrent import futures @@ -276,7 +276,9 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): blockchain.read_blockchains(self.config) blockchain.init_headers_file_for_best_chain() self.logger.info(f"blockchains {list(map(lambda b: b.forkpoint, blockchain.blockchains.values()))}") - self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Optional[Dict] + self._blockchain_preferred_block = self.config.get('blockchain_preferred_block', None) # type: Dict[str, Any] + if self._blockchain_preferred_block is None: + self._set_preferred_chain(None) self._blockchain = blockchain.get_best_chain() self._allowed_protocols = {PREFERRED_NETWORK_PROTOCOL} @@ -624,7 +626,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): await self.switch_to_interface(random.choice(servers)) async def switch_lagging_interface(self): - '''If auto_connect and lagging, switch interface''' + """If auto_connect and lagging, switch interface (only within fork).""" if self.auto_connect and await self._server_is_lagging(): # switch to one that has the correct header (not height) best_header = self.blockchain().header_at_tip() @@ -634,40 +636,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) - async def switch_unwanted_fork_interface(self): - """If auto_connect and main interface is not on preferred fork, - try to switch to preferred fork. - """ + async def switch_unwanted_fork_interface(self) -> None: + """If auto_connect, maybe switch to another fork/chain.""" if not self.auto_connect or not self.interface: return with self.interfaces_lock: interfaces = list(self.interfaces.values()) - # try to switch to preferred fork - if self._blockchain_preferred_block: - pref_height = self._blockchain_preferred_block['height'] - pref_hash = self._blockchain_preferred_block['hash'] - if self.interface.blockchain.check_hash(pref_height, pref_hash): - return # already on preferred fork - filtered = list(filter(lambda iface: iface.blockchain.check_hash(pref_height, pref_hash), - interfaces)) + pref_height = self._blockchain_preferred_block['height'] + pref_hash = self._blockchain_preferred_block['hash'] + # shortcut for common case + if pref_height == 0: + return + # maybe try switching chains; starting with most desirable first + matching_chains = blockchain.get_chains_that_contain_header(pref_height, pref_hash) + chains_to_try = list(matching_chains) + [blockchain.get_best_chain()] + for rank, chain in enumerate(chains_to_try): + # check if main interface is already on this fork + if self.interface.blockchain == chain: + return + # switch to another random interface that is on this fork, if any + filtered = [iface for iface in interfaces + if iface.blockchain == chain] if filtered: - self.logger.info("switching to preferred fork") + self.logger.info(f"switching to (more) preferred fork (rank {rank})") chosen_iface = random.choice(filtered) await self.switch_to_interface(chosen_iface.server) return - else: - self.logger.info("tried to switch to preferred fork but no interfaces are on it") - # try to switch to best chain - if self.blockchain().parent is None: - return # already on best chain - filtered = list(filter(lambda iface: iface.blockchain.parent is None, - interfaces)) - if filtered: - self.logger.info("switching to best chain") - chosen_iface = random.choice(filtered) - await self.switch_to_interface(chosen_iface.server) - else: - # FIXME switch to best available? - self.logger.info("tried to switch to best chain but no interfaces are on it") + self.logger.info("tried to switch to (more) preferred fork but no interfaces are on any") async def switch_to_interface(self, server: ServerAddr): """Switch to server as our main interface. If no connection exists, @@ -1083,9 +1077,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): out[chain_id] = r return out - def _set_preferred_chain(self, chain: Blockchain): - height = chain.get_max_forkpoint() - header_hash = chain.get_hash(height) + def _set_preferred_chain(self, chain: Optional[Blockchain]): + if chain: + height = chain.get_max_forkpoint() + header_hash = chain.get_hash(height) + else: + height = 0 + header_hash = constants.net.GENESIS self._blockchain_preferred_block = { 'height': height, 'hash': header_hash, diff --git a/electrum/tests/test_blockchain.py b/electrum/tests/test_blockchain.py @@ -336,6 +336,54 @@ class TestBlockchain(ElectrumTestCase): for b in (chain_u, chain_l, chain_z): self.assertTrue(all([b.can_connect(b.read_header(i), False) for i in range(b.height())])) + def get_chains_that_contain_header_helper(self, header: dict): + height = header['block_height'] + header_hash = hash_header(header) + return blockchain.get_chains_that_contain_header(height, header_hash) + + def test_get_chains_that_contain_header(self): + blockchain.blockchains[constants.net.GENESIS] = chain_u = Blockchain( + config=self.config, forkpoint=0, parent=None, + forkpoint_hash=constants.net.GENESIS, prev_hash=None) + open(chain_u.path(), 'w+').close() + self._append_header(chain_u, self.HEADERS['A']) + self._append_header(chain_u, self.HEADERS['B']) + self._append_header(chain_u, self.HEADERS['C']) + self._append_header(chain_u, self.HEADERS['D']) + self._append_header(chain_u, self.HEADERS['E']) + self._append_header(chain_u, self.HEADERS['F']) + self._append_header(chain_u, self.HEADERS['O']) + self._append_header(chain_u, self.HEADERS['P']) + self._append_header(chain_u, self.HEADERS['Q']) + + chain_l = chain_u.fork(self.HEADERS['G']) + self._append_header(chain_l, self.HEADERS['H']) + self._append_header(chain_l, self.HEADERS['I']) + self._append_header(chain_l, self.HEADERS['J']) + self._append_header(chain_l, self.HEADERS['K']) + self._append_header(chain_l, self.HEADERS['L']) + + chain_z = chain_l.fork(self.HEADERS['M']) + + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A'])) + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C'])) + self.assertEqual([chain_l, chain_z, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F'])) + self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['G'])) + self.assertEqual([chain_l, chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['I'])) + self.assertEqual([chain_z], self.get_chains_that_contain_header_helper(self.HEADERS['M'])) + self.assertEqual([chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['K'])) + + self._append_header(chain_z, self.HEADERS['N']) + self._append_header(chain_z, self.HEADERS['X']) + self._append_header(chain_z, self.HEADERS['Y']) + self._append_header(chain_z, self.HEADERS['Z']) + + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['A'])) + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['C'])) + self.assertEqual([chain_z, chain_l, chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['F'])) + self.assertEqual([chain_u], self.get_chains_that_contain_header_helper(self.HEADERS['O'])) + self.assertEqual([chain_z, chain_l], self.get_chains_that_contain_header_helper(self.HEADERS['I'])) + class TestVerifyHeader(ElectrumTestCase):