commit 8d7370d897314d8542906aecc6a45cc949651f77
parent 3393ff757e05af9df8fcfc6d56caa8b72c474fda
Author: ThomasV <thomasv@electrum.org>
Date: Thu, 2 Jul 2020 18:00:21 +0200
Merge pull request #6315 from SomberNight/202007_interface_check_server_response
interface: check server response for some methods
Diffstat:
4 files changed, 200 insertions(+), 43 deletions(-)
diff --git a/electrum/interface.py b/electrum/interface.py
@@ -29,7 +29,7 @@ import sys
import traceback
import asyncio
import socket
-from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
+from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools
@@ -44,16 +44,19 @@ from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
from aiorpcx.rawsocket import RSClient
import certifi
-from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy
+from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
+ is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
+ is_real_number)
from . import util
from . import x509
from . import pem
from . import version
from . import blockchain
-from .blockchain import Blockchain
+from .blockchain import Blockchain, HEADER_SIZE
from . import constants
from .i18n import _
from .logging import Logger
+from .transaction import Transaction
if TYPE_CHECKING:
from .network import Network
@@ -82,6 +85,45 @@ class NetworkTimeout:
RELAXED = 20
MOST_RELAXED = 60
+
+def assert_non_negative_integer(val: Any) -> None:
+ if not is_non_negative_integer(val):
+ raise RequestCorrupted(f'{val!r} should be a non-negative integer')
+
+
+def assert_integer(val: Any) -> None:
+ if not is_integer(val):
+ raise RequestCorrupted(f'{val!r} should be an integer')
+
+
+def assert_real_number(val: Any, *, as_str: bool = False) -> None:
+ if not is_real_number(val, as_str=as_str):
+ raise RequestCorrupted(f'{val!r} should be a number')
+
+
+def assert_hash256_str(val: Any) -> None:
+ if not is_hash256_str(val):
+ raise RequestCorrupted(f'{val!r} should be a hash256 str')
+
+
+def assert_hex_str(val: Any) -> None:
+ if not is_hex_str(val):
+ raise RequestCorrupted(f'{val!r} should be a hex str')
+
+
+def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
+ if not isinstance(d, dict):
+ raise RequestCorrupted(f'{d!r} should be a dict')
+ if field_name not in d:
+ raise RequestCorrupted(f'required field {field_name!r} missing from dict')
+ return d[field_name]
+
+
+def assert_list_or_tuple(val: Any) -> None:
+ if not isinstance(val, (list, tuple)):
+ raise RequestCorrupted(f'{val!r} should be a list or tuple')
+
+
class NotificationSession(RPCSession):
def __init__(self, *args, **kwargs):
@@ -187,7 +229,7 @@ class RequestTimedOut(GracefulDisconnect):
return _("Network request timed out.")
-class RequestCorrupted(GracefulDisconnect): pass
+class RequestCorrupted(Exception): pass
class ErrorParsingSSLCert(Exception): pass
class ErrorGettingSSLCertFromServer(Exception): pass
@@ -529,6 +571,8 @@ class Interface(Logger):
return blockchain.deserialize_header(bytes.fromhex(res), height)
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
+ if not is_non_negative_integer(height):
+ raise Exception(f"{repr(height)} is not a block height")
index = height // 2016
if can_return_early and index in self._requested_chunks:
return
@@ -542,6 +586,16 @@ class Interface(Logger):
res = await self.session.send_request('blockchain.block.headers', [index * 2016, size])
finally:
self._requested_chunks.discard(index)
+ assert_dict_contains_field(res, field_name='count')
+ assert_dict_contains_field(res, field_name='hex')
+ assert_dict_contains_field(res, field_name='max')
+ assert_non_negative_integer(res['count'])
+ assert_non_negative_integer(res['max'])
+ assert_hex_str(res['hex'])
+ if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
+ raise RequestCorrupted('inconsistent chunk hex and count')
+ if res['count'] != size:
+ raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
conn = self.blockchain.connect_chunk(index, res['hex'])
if not conn:
return conn, 0
@@ -819,6 +873,108 @@ class Interface(Logger):
self._ipaddr_bucket = do_bucket()
return self._ipaddr_bucket
+ async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
+ if not is_hash256_str(tx_hash):
+ raise Exception(f"{repr(tx_hash)} is not a txid")
+ if not is_non_negative_integer(tx_height):
+ raise Exception(f"{repr(tx_height)} is not a block height")
+ # do request
+ res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
+ # check response
+ block_height = assert_dict_contains_field(res, field_name='block_height')
+ merkle = assert_dict_contains_field(res, field_name='merkle')
+ pos = assert_dict_contains_field(res, field_name='pos')
+ # note: tx_height was just a hint to the server, don't enforce the response to match it
+ assert_non_negative_integer(block_height)
+ assert_non_negative_integer(pos)
+ assert_list_or_tuple(merkle)
+ for item in merkle:
+ assert_hash256_str(item)
+ return res
+
+ async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
+ if not is_hash256_str(tx_hash):
+ raise Exception(f"{repr(tx_hash)} is not a txid")
+ raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
+ # validate response
+ tx = Transaction(raw)
+ try:
+ tx.deserialize() # see if raises
+ except Exception as e:
+ raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
+ if tx.txid() != tx_hash:
+ raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})")
+ return raw
+
+ async def get_history_for_scripthash(self, sh: str) -> List[dict]:
+ if not is_hash256_str(sh):
+ raise Exception(f"{repr(sh)} is not a scripthash")
+ # do request
+ res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
+ # check response
+ assert_list_or_tuple(res)
+ for tx_item in res:
+ assert_dict_contains_field(tx_item, field_name='height')
+ assert_dict_contains_field(tx_item, field_name='tx_hash')
+ assert_integer(tx_item['height'])
+ assert_hash256_str(tx_item['tx_hash'])
+ if tx_item['height'] in (-1, 0):
+ assert_dict_contains_field(tx_item, field_name='fee')
+ assert_non_negative_integer(tx_item['fee'])
+ return res
+
+ async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
+ if not is_hash256_str(sh):
+ raise Exception(f"{repr(sh)} is not a scripthash")
+ # do request
+ res = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
+ # check response
+ assert_list_or_tuple(res)
+ for utxo_item in res:
+ assert_dict_contains_field(utxo_item, field_name='tx_pos')
+ assert_dict_contains_field(utxo_item, field_name='value')
+ assert_dict_contains_field(utxo_item, field_name='tx_hash')
+ assert_dict_contains_field(utxo_item, field_name='height')
+ assert_non_negative_integer(utxo_item['tx_pos'])
+ assert_non_negative_integer(utxo_item['value'])
+ assert_non_negative_integer(utxo_item['height'])
+ assert_hash256_str(utxo_item['tx_hash'])
+ return res
+
+ async def get_balance_for_scripthash(self, sh: str) -> dict:
+ if not is_hash256_str(sh):
+ raise Exception(f"{repr(sh)} is not a scripthash")
+ # do request
+ res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
+ # check response
+ assert_dict_contains_field(res, field_name='confirmed')
+ assert_dict_contains_field(res, field_name='unconfirmed')
+ assert_non_negative_integer(res['confirmed'])
+ assert_non_negative_integer(res['unconfirmed'])
+ return res
+
+ async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
+ if not is_non_negative_integer(tx_height):
+ raise Exception(f"{repr(tx_height)} is not a block height")
+ if not is_non_negative_integer(tx_pos):
+ raise Exception(f"{repr(tx_pos)} should be non-negative integer")
+ # do request
+ res = await self.session.send_request(
+ 'blockchain.transaction.id_from_pos',
+ [tx_height, tx_pos, merkle],
+ )
+ # check response
+ if merkle:
+ assert_dict_contains_field(res, field_name='tx_hash')
+ assert_dict_contains_field(res, field_name='merkle')
+ assert_hash256_str(res['tx_hash'])
+ assert_list_or_tuple(res['merkle'])
+ for node_hash in res['merkle']:
+ assert_hash256_str(node_hash)
+ else:
+ assert_hash256_str(res)
+ return res
+
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
diff --git a/electrum/network.py b/electrum/network.py
@@ -816,7 +816,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
if success_fut.exception():
try:
raise success_fut.exception()
- except (RequestTimedOut, RequestCorrupted):
+ except RequestTimedOut:
+ await iface.close()
+ await iface.got_disconnected
+ continue # try again
+ except RequestCorrupted as e:
+ # TODO ban server?
+ iface.logger.exception(f"RequestCorrupted: {e}")
await iface.close()
await iface.got_disconnected
continue # try again
@@ -836,11 +842,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
@best_effort_reliable
@catch_server_exceptions
async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
- if not is_hash256_str(tx_hash):
- raise Exception(f"{repr(tx_hash)} is not a txid")
- if not is_non_negative_integer(tx_height):
- raise Exception(f"{repr(tx_height)} is not a block height")
- return await self.interface.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
+ return await self.interface.get_merkle_for_transaction(tx_hash=tx_hash, tx_height=tx_height)
@best_effort_reliable
async def broadcast_transaction(self, tx: 'Transaction', *, timeout=None) -> None:
@@ -1012,54 +1014,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
@best_effort_reliable
@catch_server_exceptions
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
- if not is_non_negative_integer(height):
- raise Exception(f"{repr(height)} is not a block height")
return await self.interface.request_chunk(height, tip=tip, can_return_early=can_return_early)
@best_effort_reliable
@catch_server_exceptions
async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
- if not is_hash256_str(tx_hash):
- raise Exception(f"{repr(tx_hash)} is not a txid")
- iface = self.interface
- raw = await iface.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
- # validate response
- tx = Transaction(raw)
- try:
- tx.deserialize() # see if raises
- except Exception as e:
- self.logger.warning(f"cannot deserialize received transaction (txid {tx_hash}). from {str(iface)}")
- raise RequestCorrupted() from e # TODO ban server?
- if tx.txid() != tx_hash:
- self.logger.warning(f"received tx does not match expected txid {tx_hash} (got {tx.txid()}). from {str(iface)}")
- raise RequestCorrupted() # TODO ban server?
- return raw
+ return await self.interface.get_transaction(tx_hash=tx_hash, timeout=timeout)
@best_effort_reliable
@catch_server_exceptions
async def get_history_for_scripthash(self, sh: str) -> List[dict]:
- if not is_hash256_str(sh):
- raise Exception(f"{repr(sh)} is not a scripthash")
- return await self.interface.session.send_request('blockchain.scripthash.get_history', [sh])
+ return await self.interface.get_history_for_scripthash(sh)
@best_effort_reliable
@catch_server_exceptions
async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
- if not is_hash256_str(sh):
- raise Exception(f"{repr(sh)} is not a scripthash")
- return await self.interface.session.send_request('blockchain.scripthash.listunspent', [sh])
+ return await self.interface.listunspent_for_scripthash(sh)
@best_effort_reliable
@catch_server_exceptions
async def get_balance_for_scripthash(self, sh: str) -> dict:
- if not is_hash256_str(sh):
- raise Exception(f"{repr(sh)} is not a scripthash")
- return await self.interface.session.send_request('blockchain.scripthash.get_balance', [sh])
+ return await self.interface.get_balance_for_scripthash(sh)
@best_effort_reliable
+ @catch_server_exceptions
async def get_txid_from_txpos(self, tx_height, tx_pos, merkle):
- command = 'blockchain.transaction.id_from_pos'
- return await self.interface.session.send_request(command, [tx_height, tx_pos, merkle])
+ return await self.interface.get_txid_from_txpos(tx_height, tx_pos, merkle)
def blockchain(self) -> Blockchain:
interface = self.interface
diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py
@@ -168,15 +168,12 @@ class Synchronizer(SynchronizerBase):
self.requested_histories.add((addr, status))
h = address_to_scripthash(addr)
self._requests_sent += 1
- result = await self.network.get_history_for_scripthash(h)
+ result = await self.interface.get_history_for_scripthash(h)
self._requests_answered += 1
self.logger.info(f"receiving history {addr} {len(result)}")
hashes = set(map(lambda item: item['tx_hash'], result))
hist = list(map(lambda item: (item['tx_hash'], item['height']), result))
# tx_fees
- for item in result:
- if item['height'] in (-1, 0) and 'fee' not in item:
- raise Exception("server response to get_history contains unconfirmed tx without fee")
tx_fees = [(item['tx_hash'], item.get('fee')) for item in result]
tx_fees = dict(filter(lambda x:x[1] is not None, tx_fees))
# Check that txids are unique
@@ -214,7 +211,7 @@ class Synchronizer(SynchronizerBase):
async def _get_transaction(self, tx_hash, *, allow_server_not_finding_tx=False):
self._requests_sent += 1
try:
- raw_tx = await self.network.get_transaction(tx_hash)
+ raw_tx = await self.interface.get_transaction(tx_hash)
except UntrustedServerReturnedError as e:
# most likely, "No such mempool or blockchain transaction"
if allow_server_not_finding_tx:
diff --git a/electrum/util.py b/electrum/util.py
@@ -582,6 +582,30 @@ def is_non_negative_integer(val) -> bool:
return False
+def is_integer(val) -> bool:
+ try:
+ int(val)
+ except:
+ return False
+ else:
+ return True
+
+
+def is_real_number(val, *, as_str: bool = False) -> bool:
+ if as_str: # only accept str
+ if not isinstance(val, str):
+ return False
+ else: # only accept int/float/etc.
+ if isinstance(val, str):
+ return False
+ try:
+ Decimal(val)
+ except:
+ return False
+ else:
+ return True
+
+
def chunks(items, size: int):
"""Break up items, an iterable, into chunks of length size."""
if size < 1: