electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit cf1f2ba4dca51f15f485211a530022165a89c4c4
parent ef2ff11926343016004160215348c54ebe4ffd1e
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue, 14 Apr 2020 16:56:17 +0200

network: replace "server" strings with ServerAddr objects

Diffstat:
Melectrum/daemon.py | 2++
Melectrum/exchange_rate.py | 2+-
Melectrum/gui/qt/network_dialog.py | 40+++++++++++++++++++++-------------------
Melectrum/gui/text.py | 24+++++++++++++++++-------
Melectrum/interface.py | 83+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------
Melectrum/network.py | 124+++++++++++++++++++++++++++++++++++++++++++------------------------------------
6 files changed, 172 insertions(+), 103 deletions(-)

diff --git a/electrum/daemon.py b/electrum/daemon.py @@ -270,6 +270,8 @@ class AuthenticationCredentialsInvalid(AuthenticationError): class Daemon(Logger): + network: Optional[Network] + @profiler def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True): Logger.__init__(self) diff --git a/electrum/exchange_rate.py b/electrum/exchange_rate.py @@ -453,7 +453,7 @@ def get_exchanges_by_ccy(history=True): class FxThread(ThreadJob): - def __init__(self, config: SimpleConfig, network: Network): + def __init__(self, config: SimpleConfig, network: Optional[Network]): ThreadJob.__init__(self) self.config = config self.network = network diff --git a/electrum/gui/qt/network_dialog.py b/electrum/gui/qt/network_dialog.py @@ -36,7 +36,7 @@ from PyQt5.QtGui import QFontMetrics from electrum.i18n import _ from electrum import constants, blockchain, util -from electrum.interface import serialize_server, deserialize_server +from electrum.interface import ServerAddr from electrum.network import Network from electrum.logging import get_logger @@ -72,10 +72,13 @@ class NetworkDialog(QDialog): class NodesListWidget(QTreeWidget): + SERVER_ADDR_ROLE = Qt.UserRole + 100 + CHAIN_ID_ROLE = Qt.UserRole + 101 + IS_SERVER_ROLE = Qt.UserRole + 102 def __init__(self, parent): QTreeWidget.__init__(self) - self.parent = parent + self.parent = parent # type: NetworkChoiceLayout self.setHeaderLabels([_('Connected node'), _('Height')]) self.setContextMenuPolicy(Qt.CustomContextMenu) self.customContextMenuRequested.connect(self.create_menu) @@ -84,13 +87,13 @@ class NodesListWidget(QTreeWidget): item = self.currentItem() if not item: return - is_server = not bool(item.data(0, Qt.UserRole)) + is_server = bool(item.data(0, self.IS_SERVER_ROLE)) menu = QMenu() if is_server: - server = item.data(1, Qt.UserRole) + server = item.data(0, self.SERVER_ADDR_ROLE) # type: ServerAddr menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server)) else: - chain_id = item.data(1, Qt.UserRole) + chain_id = item.data(0, self.CHAIN_ID_ROLE) menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id)) menu.exec_(self.viewport().mapToGlobal(position)) @@ -117,15 +120,15 @@ class NodesListWidget(QTreeWidget): name = b.get_name() if n_chains > 1: x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()]) - x.setData(0, Qt.UserRole, 1) - x.setData(1, Qt.UserRole, b.get_id()) + x.setData(0, self.IS_SERVER_ROLE, 0) + x.setData(0, self.CHAIN_ID_ROLE, b.get_id()) else: x = self for i in interfaces: star = ' *' if i == network.interface else '' item = QTreeWidgetItem([i.host + star, '%d'%i.tip]) - item.setData(0, Qt.UserRole, 0) - item.setData(1, Qt.UserRole, i.server) + item.setData(0, self.IS_SERVER_ROLE, 1) + item.setData(0, self.SERVER_ADDR_ROLE, i.server) x.addChild(item) if n_chains > 1: self.addTopLevelItem(x) @@ -144,11 +147,11 @@ class ServerListWidget(QTreeWidget): HOST = 0 PORT = 1 - SERVER_STR_ROLE = Qt.UserRole + 100 + SERVER_ADDR_ROLE = Qt.UserRole + 100 def __init__(self, parent): QTreeWidget.__init__(self) - self.parent = parent + self.parent = parent # type: NetworkChoiceLayout self.setHeaderLabels([_('Host'), _('Port')]) self.setContextMenuPolicy(Qt.CustomContextMenu) self.customContextMenuRequested.connect(self.create_menu) @@ -158,14 +161,13 @@ class ServerListWidget(QTreeWidget): if not item: return menu = QMenu() - server = item.data(self.Columns.HOST, self.SERVER_STR_ROLE) + server = item.data(self.Columns.HOST, self.SERVER_ADDR_ROLE) menu.addAction(_("Use as server"), lambda: self.set_server(server)) menu.exec_(self.viewport().mapToGlobal(position)) - def set_server(self, s): - host, port, protocol = deserialize_server(s) - self.parent.server_host.setText(host) - self.parent.server_port.setText(port) + def set_server(self, server: ServerAddr): + self.parent.server_host.setText(server.host) + self.parent.server_port.setText(str(server.port)) self.parent.set_server() def keyPressEvent(self, event): @@ -188,8 +190,8 @@ class ServerListWidget(QTreeWidget): port = d.get(protocol) if port: x = QTreeWidgetItem([_host, port]) - server = serialize_server(_host, port, protocol) - x.setData(self.Columns.HOST, self.SERVER_STR_ROLE, server) + server = ServerAddr(_host, port, protocol=protocol) + x.setData(self.Columns.HOST, self.SERVER_ADDR_ROLE, server) self.addTopLevelItem(x) h = self.header() @@ -431,7 +433,7 @@ class NetworkChoiceLayout(object): self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id)) self.update() - def follow_server(self, server): + def follow_server(self, server: ServerAddr): self.network.run_from_another_thread(self.network.follow_chain_given_server(server)) self.update() diff --git a/electrum/gui/text.py b/electrum/gui/text.py @@ -6,6 +6,7 @@ import locale from decimal import Decimal import getpass import logging +from typing import TYPE_CHECKING import electrum from electrum import util @@ -15,15 +16,21 @@ from electrum.transaction import PartialTxOutput from electrum.wallet import Wallet from electrum.storage import WalletStorage from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed -from electrum.interface import deserialize_server +from electrum.interface import ServerAddr from electrum.logging import console_stderr_handler +if TYPE_CHECKING: + from electrum.daemon import Daemon + from electrum.simple_config import SimpleConfig + from electrum.plugin import Plugins + + _ = lambda x:x # i18n class ElectrumGui: - def __init__(self, config, daemon, plugins): + def __init__(self, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'): self.config = config self.network = daemon.network @@ -404,21 +411,24 @@ class ElectrumGui: net_params = self.network.get_parameters() host, port, protocol = net_params.host, net_params.port, net_params.protocol proxy_config, auto_connect = net_params.proxy, net_params.auto_connect - srv = 'auto-connect' if auto_connect else self.network.default_server + srv = 'auto-connect' if auto_connect else str(self.network.default_server) out = self.run_dialog('Network', [ {'label':'server', 'type':'str', 'value':srv}, {'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')}, ], buttons = 1) if out: if out.get('server'): - server = out.get('server') - auto_connect = server == 'auto-connect' + server_str = out.get('server') + auto_connect = server_str == 'auto-connect' if not auto_connect: try: - host, port, protocol = deserialize_server(server) + server_addr = ServerAddr.from_str(server_str) except Exception: - self.show_message("Error:" + server + "\nIn doubt, type \"auto-connect\"") + self.show_message("Error:" + server_str + "\nIn doubt, type \"auto-connect\"") return False + host = server_addr.host + port = str(server_addr.port) + protocol = server_addr.protocol if out.get('server') or out.get('proxy'): proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config net_params = NetworkParameters(host, port, protocol, proxy, auto_connect) diff --git a/electrum/interface.py b/electrum/interface.py @@ -29,7 +29,7 @@ import sys import traceback import asyncio import socket -from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set +from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple from collections import defaultdict from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address import itertools @@ -198,22 +198,57 @@ class _RSClient(RSClient): raise ConnectError(e) from e -def deserialize_server(server_str: str) -> Tuple[str, str, str]: - # host might be IPv6 address, hence do rsplit: - host, port, protocol = str(server_str).rsplit(':', 2) - if not host: - raise ValueError('host must not be empty') - if host[0] == '[' and host[-1] == ']': # IPv6 - host = host[1:-1] - if protocol not in ('s', 't'): - raise ValueError('invalid network protocol: {}'.format(protocol)) - net_addr = NetAddress(host, port) # this validates host and port - host = str(net_addr.host) # canonical form (if e.g. IPv6 address) - return host, port, protocol +class ServerAddr: + def __init__(self, host: str, port: Union[int, str], *, protocol: str = None): + assert isinstance(host, str), repr(host) + if protocol is None: + protocol = 's' + if not host: + raise ValueError('host must not be empty') + if host[0] == '[' and host[-1] == ']': # IPv6 + host = host[1:-1] + try: + net_addr = NetAddress(host, port) # this validates host and port + except Exception as e: + raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e + if protocol not in ('s', 't'): + raise ValueError(f"invalid network protocol: {protocol}") + self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address) + self.port = int(net_addr.port) + self.protocol = protocol + self._net_addr_str = str(net_addr) + + @classmethod + def from_str(cls, s: str) -> 'ServerAddr': + # host might be IPv6 address, hence do rsplit: + host, port, protocol = str(s).rsplit(':', 2) + return ServerAddr(host=host, port=port, protocol=protocol) -def serialize_server(host: str, port: Union[str, int], protocol: str) -> str: - return str(':'.join([host, str(port), protocol])) + def __str__(self): + return '{}:{}'.format(self.net_addr_str(), self.protocol) + + def to_json(self) -> str: + return str(self) + + def __repr__(self): + return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>' + + def net_addr_str(self) -> str: + return self._net_addr_str + + def __eq__(self, other): + if not isinstance(other, ServerAddr): + return False + return (self.host == other.host + and self.port == other.port + and self.protocol == other.protocol) + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash((self.host, self.port, self.protocol)) def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str: @@ -232,12 +267,10 @@ class Interface(Logger): LOGGING_SHORTCUT = 'i' - def __init__(self, network: 'Network', server: str, proxy: Optional[dict]): + def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]): self.ready = asyncio.Future() self.got_disconnected = asyncio.Future() self.server = server - self.host, self.port, self.protocol = deserialize_server(self.server) - self.port = int(self.port) Logger.__init__(self) assert network.config.path self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host) @@ -259,8 +292,20 @@ class Interface(Logger): self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop) self.taskgroup = SilentTaskGroup() + @property + def host(self): + return self.server.host + + @property + def port(self): + return self.server.port + + @property + def protocol(self): + return self.server.protocol + def diagnostic_name(self): - return str(NetAddress(self.host, self.port)) + return self.server.net_addr_str() def __str__(self): return f"<Interface {self.diagnostic_name()}>" diff --git a/electrum/network.py b/electrum/network.py @@ -32,7 +32,7 @@ import socket import json import sys import asyncio -from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable +from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set import traceback import concurrent from concurrent import futures @@ -44,7 +44,7 @@ from aiohttp import ClientResponse from . import util from .util import (log_exceptions, ignore_exceptions, bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter, - is_hash256_str, is_non_negative_integer) + is_hash256_str, is_non_negative_integer, MyEncoder) from .bitcoin import COIN from . import constants @@ -53,9 +53,9 @@ from . import bitcoin from . import dns_hacks from .transaction import Transaction from .blockchain import Blockchain, HEADER_SIZE -from .interface import (Interface, serialize_server, deserialize_server, +from .interface import (Interface, RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS, - NetworkException, RequestCorrupted) + NetworkException, RequestCorrupted, ServerAddr) from .version import PROTOCOL_VERSION from .simple_config import SimpleConfig from .i18n import _ @@ -117,18 +117,18 @@ def filter_noonion(servers): return {k: v for k, v in servers.items() if not k.endswith('.onion')} -def filter_protocol(hostmap, protocol='s'): - '''Filters the hostmap for those implementing protocol. - The result is a list in serialized form.''' +def filter_protocol(hostmap, protocol='s') -> Sequence[ServerAddr]: + """Filters the hostmap for those implementing protocol.""" eligible = [] for host, portmap in hostmap.items(): port = portmap.get(protocol) if port: - eligible.append(serialize_server(host, port, protocol)) + eligible.append(ServerAddr(host, port, protocol=protocol)) return eligible -def pick_random_server(hostmap=None, protocol='s', exclude_set=None): +def pick_random_server(hostmap=None, *, protocol='s', + exclude_set: Set[ServerAddr] = None) -> Optional[ServerAddr]: if hostmap is None: hostmap = constants.net.DEFAULT_SERVERS if exclude_set is None: @@ -240,6 +240,14 @@ class Network(Logger): LOGGING_SHORTCUT = 'n' + taskgroup: Optional[TaskGroup] + interface: Optional[Interface] + interfaces: Dict[ServerAddr, Interface] + connecting: Set[ServerAddr] + server_queue: 'Optional[queue.Queue[ServerAddr]]' + disconnected_servers: Set[ServerAddr] + default_server: ServerAddr + def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None): global _INSTANCE assert _INSTANCE is None, "Network is a singleton!" @@ -266,14 +274,15 @@ class Network(Logger): # Sanitize default server if self.default_server: try: - deserialize_server(self.default_server) + self.default_server = ServerAddr.from_str(self.default_server) except: self.logger.warning('failed to parse server-string; falling back to localhost.') - self.default_server = "localhost:50002:s" - if not self.default_server: + self.default_server = ServerAddr.from_str("localhost:50002:s") + else: self.default_server = pick_random_server() + assert isinstance(self.default_server, ServerAddr), f"invalid type for default_server: {self.default_server!r}" - self.taskgroup = None # type: TaskGroup + self.taskgroup = None # locks self.restart_lock = asyncio.Lock() @@ -295,10 +304,10 @@ class Network(Logger): self.server_retry_time = time.time() self.nodes_retry_time = time.time() # the main server we are currently communicating with - self.interface = None # type: Optional[Interface] + self.interface = None self.default_server_changed_event = asyncio.Event() # set of servers we have an ongoing connection with - self.interfaces = {} # type: Dict[str, Interface] + self.interfaces = {} self.auto_connect = self.config.get('auto_connect', True) self.connecting = set() self.server_queue = None @@ -347,14 +356,15 @@ class Network(Logger): return func(self, *args, **kwargs) return func_wrapper - def _read_recent_servers(self): + def _read_recent_servers(self) -> List[ServerAddr]: if not self.config.path: return [] path = os.path.join(self.config.path, "recent_servers") try: with open(path, "r", encoding='utf-8') as f: data = f.read() - return json.loads(data) + servers_list = json.loads(data) + return [ServerAddr.from_str(s) for s in servers_list] except: return [] @@ -363,7 +373,7 @@ class Network(Logger): if not self.config.path: return path = os.path.join(self.config.path, "recent_servers") - s = json.dumps(self.recent_servers, indent=4, sort_keys=True) + s = json.dumps(self.recent_servers, indent=4, sort_keys=True, cls=MyEncoder) try: with open(path, "w", encoding='utf-8') as f: f.write(s) @@ -462,10 +472,10 @@ class Network(Logger): util.trigger_callback(key, self.get_status_value(key)) def get_parameters(self) -> NetworkParameters: - host, port, protocol = deserialize_server(self.default_server) - return NetworkParameters(host=host, - port=port, - protocol=protocol, + server = self.default_server + return NetworkParameters(host=server.host, + port=str(server.port), + protocol=server.protocol, proxy=self.proxy, auto_connect=self.auto_connect, oneserver=self.oneserver) @@ -474,7 +484,7 @@ class Network(Logger): if self.is_connected(): return self.donation_address - def get_interfaces(self) -> List[str]: + def get_interfaces(self) -> List[ServerAddr]: """The list of servers for the connected interfaces.""" with self.interfaces_lock: return list(self.interfaces) @@ -516,21 +526,18 @@ class Network(Logger): # hardcoded servers out.update(constants.net.DEFAULT_SERVERS) # add recent servers - for s in self.recent_servers: - try: - host, port, protocol = deserialize_server(s) - except: - continue - if host in out: - out[host].update({protocol: port}) + for server in self.recent_servers: + port = str(server.port) + if server.host in out: + out[server.host].update({server.protocol: port}) else: - out[host] = {protocol: port} + out[server.host] = {server.protocol: port} # potentially filter out some if self.config.get('noonion'): out = filter_noonion(out) return out - def _start_interface(self, server: str): + def _start_interface(self, server: ServerAddr): if server not in self.interfaces and server not in self.connecting: if server == self.default_server: self.logger.info(f"connecting to {server} as new interface") @@ -538,10 +545,10 @@ class Network(Logger): self.connecting.add(server) self.server_queue.put(server) - def _start_random_interface(self): + def _start_random_interface(self) -> Optional[ServerAddr]: with self.interfaces_lock: exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting - server = pick_random_server(self.get_servers(), self.protocol, exclude_set) + server = pick_random_server(self.get_servers(), protocol=self.protocol, exclude_set=exclude_set) if server: self._start_interface(server) return server @@ -557,10 +564,9 @@ class Network(Logger): proxy = net_params.proxy proxy_str = serialize_proxy(proxy) host, port, protocol = net_params.host, net_params.port, net_params.protocol - server_str = serialize_server(host, port, protocol) # sanitize parameters try: - deserialize_server(serialize_server(host, port, protocol)) + server = ServerAddr(host, port, protocol=protocol) if proxy: proxy_modes.index(proxy['mode']) + 1 int(proxy['port']) @@ -569,9 +575,9 @@ class Network(Logger): self.config.set_key('auto_connect', net_params.auto_connect, False) self.config.set_key('oneserver', net_params.oneserver, False) self.config.set_key('proxy', proxy_str, False) - self.config.set_key('server', server_str, True) + self.config.set_key('server', str(server), True) # abort if changes were not allowed by config - if self.config.get('server') != server_str \ + if self.config.get('server') != str(server) \ or self.config.get('proxy') != proxy_str \ or self.config.get('oneserver') != net_params.oneserver: return @@ -581,10 +587,10 @@ class Network(Logger): if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver: # Restart the network defaulting to the given server await self._stop() - self.default_server = server_str + self.default_server = server await self._start() - elif self.default_server != server_str: - await self.switch_to_interface(server_str) + elif self.default_server != server: + await self.switch_to_interface(server) else: await self.switch_lagging_interface() @@ -646,7 +652,7 @@ class Network(Logger): # FIXME switch to best available? self.logger.info("tried to switch to best chain but no interfaces are on it") - async def switch_to_interface(self, server: str): + async def switch_to_interface(self, server: ServerAddr): """Switch to server as our main interface. If no connection exists, queue interface to be started. The actual switch will happen when the interface becomes ready. @@ -722,8 +728,8 @@ class Network(Logger): @ignore_exceptions # do not kill main_taskgroup @log_exceptions - async def _run_new_interface(self, server): - interface = Interface(self, server, self.proxy) + async def _run_new_interface(self, server: ServerAddr): + interface = Interface(network=self, server=server, proxy=self.proxy) # note: using longer timeouts here as DNS can sometimes be slow! timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic) try: @@ -1070,23 +1076,26 @@ class Network(Logger): with self.interfaces_lock: interfaces = list(self.interfaces.values()) interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces)) if len(interfaces_on_selected_chain) == 0: return - chosen_iface = random.choice(interfaces_on_selected_chain) + chosen_iface = random.choice(interfaces_on_selected_chain) # type: Interface # switch to server (and save to config) net_params = self.get_parameters() - host, port, protocol = deserialize_server(chosen_iface.server) - net_params = net_params._replace(host=host, port=port, protocol=protocol) + server = chosen_iface.server + net_params = net_params._replace(host=server.host, + port=str(server.port), + protocol=server.protocol) await self.set_parameters(net_params) - async def follow_chain_given_server(self, server_str: str) -> None: + async def follow_chain_given_server(self, server: ServerAddr) -> None: # note that server_str should correspond to a connected interface - iface = self.interfaces.get(server_str) + iface = self.interfaces.get(server) if iface is None: return self._set_preferred_chain(iface.blockchain) # switch to server (and save to config) net_params = self.get_parameters() - host, port, protocol = deserialize_server(server_str) - net_params = net_params._replace(host=host, port=port, protocol=protocol) + net_params = net_params._replace(host=server.host, + port=str(server.port), + protocol=server.protocol) await self.set_parameters(net_params) def get_local_height(self): @@ -1107,7 +1116,7 @@ class Network(Logger): assert not self.connecting and not self.server_queue self.logger.info('starting network') self.disconnected_servers = set([]) - self.protocol = deserialize_server(self.default_server)[2] + self.protocol = self.default_server.protocol self.server_queue = queue.Queue() self._set_proxy(deserialize_proxy(self.config.get('proxy'))) self._set_oneserver(self.config.get('oneserver', False)) @@ -1147,9 +1156,9 @@ class Network(Logger): await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2) except (asyncio.TimeoutError, asyncio.CancelledError) as e: self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}") - self.taskgroup = None # type: TaskGroup - self.interface = None # type: Interface - self.interfaces = {} # type: Dict[str, Interface] + self.taskgroup = None + self.interface = None + self.interfaces = {} self.connecting.clear() self.server_queue = None if not full_shutdown: @@ -1268,8 +1277,8 @@ class Network(Logger): async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence): responses = dict() - async def get_response(server): - interface = Interface(self, server, self.proxy) + async def get_response(server: ServerAddr): + interface = Interface(network=self, server=server, proxy=self.proxy) timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent) try: await asyncio.wait_for(interface.ready, timeout) @@ -1283,5 +1292,6 @@ class Network(Logger): responses[interface.server] = res async with TaskGroup() as group: for server in servers: + server = ServerAddr.from_str(server) await group.spawn(get_response(server)) return responses