electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 8d7370d897314d8542906aecc6a45cc949651f77
parent 3393ff757e05af9df8fcfc6d56caa8b72c474fda
Author: ThomasV <thomasv@electrum.org>
Date:   Thu,  2 Jul 2020 18:00:21 +0200

Merge pull request #6315 from SomberNight/202007_interface_check_server_response

interface: check server response for some methods
Diffstat:
Melectrum/interface.py | 164+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
Melectrum/network.py | 48++++++++++++++----------------------------------
Melectrum/synchronizer.py | 7++-----
Melectrum/util.py | 24++++++++++++++++++++++++
4 files changed, 200 insertions(+), 43 deletions(-)

diff --git a/electrum/interface.py b/electrum/interface.py @@ -29,7 +29,7 @@ import sys import traceback import asyncio import socket -from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple +from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any from collections import defaultdict from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address import itertools @@ -44,16 +44,19 @@ from aiorpcx.jsonrpc import JSONRPC, CodeMessageError from aiorpcx.rawsocket import RSClient import certifi -from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy +from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy, + is_integer, is_non_negative_integer, is_hash256_str, is_hex_str, + is_real_number) from . import util from . import x509 from . import pem from . import version from . import blockchain -from .blockchain import Blockchain +from .blockchain import Blockchain, HEADER_SIZE from . import constants from .i18n import _ from .logging import Logger +from .transaction import Transaction if TYPE_CHECKING: from .network import Network @@ -82,6 +85,45 @@ class NetworkTimeout: RELAXED = 20 MOST_RELAXED = 60 + +def assert_non_negative_integer(val: Any) -> None: + if not is_non_negative_integer(val): + raise RequestCorrupted(f'{val!r} should be a non-negative integer') + + +def assert_integer(val: Any) -> None: + if not is_integer(val): + raise RequestCorrupted(f'{val!r} should be an integer') + + +def assert_real_number(val: Any, *, as_str: bool = False) -> None: + if not is_real_number(val, as_str=as_str): + raise RequestCorrupted(f'{val!r} should be a number') + + +def assert_hash256_str(val: Any) -> None: + if not is_hash256_str(val): + raise RequestCorrupted(f'{val!r} should be a hash256 str') + + +def assert_hex_str(val: Any) -> None: + if not is_hex_str(val): + raise RequestCorrupted(f'{val!r} should be a hex str') + + +def assert_dict_contains_field(d: Any, *, field_name: str) -> Any: + if not isinstance(d, dict): + raise RequestCorrupted(f'{d!r} should be a dict') + if field_name not in d: + raise RequestCorrupted(f'required field {field_name!r} missing from dict') + return d[field_name] + + +def assert_list_or_tuple(val: Any) -> None: + if not isinstance(val, (list, tuple)): + raise RequestCorrupted(f'{val!r} should be a list or tuple') + + class NotificationSession(RPCSession): def __init__(self, *args, **kwargs): @@ -187,7 +229,7 @@ class RequestTimedOut(GracefulDisconnect): return _("Network request timed out.") -class RequestCorrupted(GracefulDisconnect): pass +class RequestCorrupted(Exception): pass class ErrorParsingSSLCert(Exception): pass class ErrorGettingSSLCertFromServer(Exception): pass @@ -529,6 +571,8 @@ class Interface(Logger): return blockchain.deserialize_header(bytes.fromhex(res), height) async def request_chunk(self, height: int, tip=None, *, can_return_early=False): + if not is_non_negative_integer(height): + raise Exception(f"{repr(height)} is not a block height") index = height // 2016 if can_return_early and index in self._requested_chunks: return @@ -542,6 +586,16 @@ class Interface(Logger): res = await self.session.send_request('blockchain.block.headers', [index * 2016, size]) finally: self._requested_chunks.discard(index) + assert_dict_contains_field(res, field_name='count') + assert_dict_contains_field(res, field_name='hex') + assert_dict_contains_field(res, field_name='max') + assert_non_negative_integer(res['count']) + assert_non_negative_integer(res['max']) + assert_hex_str(res['hex']) + if len(res['hex']) != HEADER_SIZE * 2 * res['count']: + raise RequestCorrupted('inconsistent chunk hex and count') + if res['count'] != size: + raise RequestCorrupted(f"expected {size} headers but only got {res['count']}") conn = self.blockchain.connect_chunk(index, res['hex']) if not conn: return conn, 0 @@ -819,6 +873,108 @@ class Interface(Logger): self._ipaddr_bucket = do_bucket() return self._ipaddr_bucket + async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict: + if not is_hash256_str(tx_hash): + raise Exception(f"{repr(tx_hash)} is not a txid") + if not is_non_negative_integer(tx_height): + raise Exception(f"{repr(tx_height)} is not a block height") + # do request + res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height]) + # check response + block_height = assert_dict_contains_field(res, field_name='block_height') + merkle = assert_dict_contains_field(res, field_name='merkle') + pos = assert_dict_contains_field(res, field_name='pos') + # note: tx_height was just a hint to the server, don't enforce the response to match it + assert_non_negative_integer(block_height) + assert_non_negative_integer(pos) + assert_list_or_tuple(merkle) + for item in merkle: + assert_hash256_str(item) + return res + + async def get_transaction(self, tx_hash: str, *, timeout=None) -> str: + if not is_hash256_str(tx_hash): + raise Exception(f"{repr(tx_hash)} is not a txid") + raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout) + # validate response + tx = Transaction(raw) + try: + tx.deserialize() # see if raises + except Exception as e: + raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e + if tx.txid() != tx_hash: + raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})") + return raw + + async def get_history_for_scripthash(self, sh: str) -> List[dict]: + if not is_hash256_str(sh): + raise Exception(f"{repr(sh)} is not a scripthash") + # do request + res = await self.session.send_request('blockchain.scripthash.get_history', [sh]) + # check response + assert_list_or_tuple(res) + for tx_item in res: + assert_dict_contains_field(tx_item, field_name='height') + assert_dict_contains_field(tx_item, field_name='tx_hash') + assert_integer(tx_item['height']) + assert_hash256_str(tx_item['tx_hash']) + if tx_item['height'] in (-1, 0): + assert_dict_contains_field(tx_item, field_name='fee') + assert_non_negative_integer(tx_item['fee']) + return res + + async def listunspent_for_scripthash(self, sh: str) -> List[dict]: + if not is_hash256_str(sh): + raise Exception(f"{repr(sh)} is not a scripthash") + # do request + res = await self.session.send_request('blockchain.scripthash.listunspent', [sh]) + # check response + assert_list_or_tuple(res) + for utxo_item in res: + assert_dict_contains_field(utxo_item, field_name='tx_pos') + assert_dict_contains_field(utxo_item, field_name='value') + assert_dict_contains_field(utxo_item, field_name='tx_hash') + assert_dict_contains_field(utxo_item, field_name='height') + assert_non_negative_integer(utxo_item['tx_pos']) + assert_non_negative_integer(utxo_item['value']) + assert_non_negative_integer(utxo_item['height']) + assert_hash256_str(utxo_item['tx_hash']) + return res + + async def get_balance_for_scripthash(self, sh: str) -> dict: + if not is_hash256_str(sh): + raise Exception(f"{repr(sh)} is not a scripthash") + # do request + res = await self.session.send_request('blockchain.scripthash.get_balance', [sh]) + # check response + assert_dict_contains_field(res, field_name='confirmed') + assert_dict_contains_field(res, field_name='unconfirmed') + assert_non_negative_integer(res['confirmed']) + assert_non_negative_integer(res['unconfirmed']) + return res + + async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool): + if not is_non_negative_integer(tx_height): + raise Exception(f"{repr(tx_height)} is not a block height") + if not is_non_negative_integer(tx_pos): + raise Exception(f"{repr(tx_pos)} should be non-negative integer") + # do request + res = await self.session.send_request( + 'blockchain.transaction.id_from_pos', + [tx_height, tx_pos, merkle], + ) + # check response + if merkle: + assert_dict_contains_field(res, field_name='tx_hash') + assert_dict_contains_field(res, field_name='merkle') + assert_hash256_str(res['tx_hash']) + assert_list_or_tuple(res['merkle']) + for node_hash in res['merkle']: + assert_hash256_str(node_hash) + else: + assert_hash256_str(res) + return res + def _assert_header_does_not_check_against_any_chain(header: dict) -> None: chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) diff --git a/electrum/network.py b/electrum/network.py @@ -816,7 +816,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): if success_fut.exception(): try: raise success_fut.exception() - except (RequestTimedOut, RequestCorrupted): + except RequestTimedOut: + await iface.close() + await iface.got_disconnected + continue # try again + except RequestCorrupted as e: + # TODO ban server? + iface.logger.exception(f"RequestCorrupted: {e}") await iface.close() await iface.got_disconnected continue # try again @@ -836,11 +842,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): @best_effort_reliable @catch_server_exceptions async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict: - if not is_hash256_str(tx_hash): - raise Exception(f"{repr(tx_hash)} is not a txid") - if not is_non_negative_integer(tx_height): - raise Exception(f"{repr(tx_height)} is not a block height") - return await self.interface.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height]) + return await self.interface.get_merkle_for_transaction(tx_hash=tx_hash, tx_height=tx_height) @best_effort_reliable async def broadcast_transaction(self, tx: 'Transaction', *, timeout=None) -> None: @@ -1012,54 +1014,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): @best_effort_reliable @catch_server_exceptions async def request_chunk(self, height: int, tip=None, *, can_return_early=False): - if not is_non_negative_integer(height): - raise Exception(f"{repr(height)} is not a block height") return await self.interface.request_chunk(height, tip=tip, can_return_early=can_return_early) @best_effort_reliable @catch_server_exceptions async def get_transaction(self, tx_hash: str, *, timeout=None) -> str: - if not is_hash256_str(tx_hash): - raise Exception(f"{repr(tx_hash)} is not a txid") - iface = self.interface - raw = await iface.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout) - # validate response - tx = Transaction(raw) - try: - tx.deserialize() # see if raises - except Exception as e: - self.logger.warning(f"cannot deserialize received transaction (txid {tx_hash}). from {str(iface)}") - raise RequestCorrupted() from e # TODO ban server? - if tx.txid() != tx_hash: - self.logger.warning(f"received tx does not match expected txid {tx_hash} (got {tx.txid()}). from {str(iface)}") - raise RequestCorrupted() # TODO ban server? - return raw + return await self.interface.get_transaction(tx_hash=tx_hash, timeout=timeout) @best_effort_reliable @catch_server_exceptions async def get_history_for_scripthash(self, sh: str) -> List[dict]: - if not is_hash256_str(sh): - raise Exception(f"{repr(sh)} is not a scripthash") - return await self.interface.session.send_request('blockchain.scripthash.get_history', [sh]) + return await self.interface.get_history_for_scripthash(sh) @best_effort_reliable @catch_server_exceptions async def listunspent_for_scripthash(self, sh: str) -> List[dict]: - if not is_hash256_str(sh): - raise Exception(f"{repr(sh)} is not a scripthash") - return await self.interface.session.send_request('blockchain.scripthash.listunspent', [sh]) + return await self.interface.listunspent_for_scripthash(sh) @best_effort_reliable @catch_server_exceptions async def get_balance_for_scripthash(self, sh: str) -> dict: - if not is_hash256_str(sh): - raise Exception(f"{repr(sh)} is not a scripthash") - return await self.interface.session.send_request('blockchain.scripthash.get_balance', [sh]) + return await self.interface.get_balance_for_scripthash(sh) @best_effort_reliable + @catch_server_exceptions async def get_txid_from_txpos(self, tx_height, tx_pos, merkle): - command = 'blockchain.transaction.id_from_pos' - return await self.interface.session.send_request(command, [tx_height, tx_pos, merkle]) + return await self.interface.get_txid_from_txpos(tx_height, tx_pos, merkle) def blockchain(self) -> Blockchain: interface = self.interface diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py @@ -168,15 +168,12 @@ class Synchronizer(SynchronizerBase): self.requested_histories.add((addr, status)) h = address_to_scripthash(addr) self._requests_sent += 1 - result = await self.network.get_history_for_scripthash(h) + result = await self.interface.get_history_for_scripthash(h) self._requests_answered += 1 self.logger.info(f"receiving history {addr} {len(result)}") hashes = set(map(lambda item: item['tx_hash'], result)) hist = list(map(lambda item: (item['tx_hash'], item['height']), result)) # tx_fees - for item in result: - if item['height'] in (-1, 0) and 'fee' not in item: - raise Exception("server response to get_history contains unconfirmed tx without fee") tx_fees = [(item['tx_hash'], item.get('fee')) for item in result] tx_fees = dict(filter(lambda x:x[1] is not None, tx_fees)) # Check that txids are unique @@ -214,7 +211,7 @@ class Synchronizer(SynchronizerBase): async def _get_transaction(self, tx_hash, *, allow_server_not_finding_tx=False): self._requests_sent += 1 try: - raw_tx = await self.network.get_transaction(tx_hash) + raw_tx = await self.interface.get_transaction(tx_hash) except UntrustedServerReturnedError as e: # most likely, "No such mempool or blockchain transaction" if allow_server_not_finding_tx: diff --git a/electrum/util.py b/electrum/util.py @@ -582,6 +582,30 @@ def is_non_negative_integer(val) -> bool: return False +def is_integer(val) -> bool: + try: + int(val) + except: + return False + else: + return True + + +def is_real_number(val, *, as_str: bool = False) -> bool: + if as_str: # only accept str + if not isinstance(val, str): + return False + else: # only accept int/float/etc. + if isinstance(val, str): + return False + try: + Decimal(val) + except: + return False + else: + return True + + def chunks(items, size: int): """Break up items, an iterable, into chunks of length size.""" if size < 1: