electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit c5da22a9dd4459972a85d116fe281b8564c20834
parent c70484455c561b53664ba349cb88587bc349efc9
Author: SomberNight <somber.night@protonmail.com>
Date:   Fri, 16 Oct 2020 19:30:42 +0200

network: tighten checks of server responses for type/sanity

Diffstat:
Melectrum/interface.py | 76++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------
Melectrum/network.py | 21+++++----------------
Melectrum/simple_config.py | 4++--
Melectrum/tests/test_util.py | 87++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
Melectrum/util.py | 40+++++++++++++---------------------------
5 files changed, 174 insertions(+), 54 deletions(-)

diff --git a/electrum/interface.py b/electrum/interface.py @@ -29,7 +29,7 @@ import sys import traceback import asyncio import socket -from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any +from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence from collections import defaultdict from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address import itertools @@ -46,13 +46,14 @@ import certifi from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy, is_integer, is_non_negative_integer, is_hash256_str, is_hex_str, - is_real_number) + is_int_or_float, is_non_negative_int_or_float) from . import util from . import x509 from . import pem from . import version from . import blockchain from .blockchain import Blockchain, HEADER_SIZE +from . import bitcoin from . import constants from .i18n import _ from .logging import Logger @@ -96,9 +97,14 @@ def assert_integer(val: Any) -> None: raise RequestCorrupted(f'{val!r} should be an integer') -def assert_real_number(val: Any, *, as_str: bool = False) -> None: - if not is_real_number(val, as_str=as_str): - raise RequestCorrupted(f'{val!r} should be a number') +def assert_int_or_float(val: Any) -> None: + if not is_int_or_float(val): + raise RequestCorrupted(f'{val!r} should be int or float') + + +def assert_non_negative_int_or_float(val: Any) -> None: + if not is_non_negative_int_or_float(val): + raise RequestCorrupted(f'{val!r} should be a non-negative int or float') def assert_hash256_str(val: Any) -> None: @@ -656,14 +662,13 @@ class Interface(Logger): async def request_fee_estimates(self): from .simple_config import FEE_ETA_TARGETS - from .bitcoin import COIN while True: async with TaskGroup() as group: fee_tasks = [] for i in FEE_ETA_TARGETS: - fee_tasks.append((i, await group.spawn(self.session.send_request('blockchain.estimatefee', [i])))) + fee_tasks.append((i, await group.spawn(self.get_estimatefee(i)))) for nblock_target, task in fee_tasks: - fee = int(task.result() * COIN) + fee = task.result() if fee < 0: continue self.fee_estimates_eta[nblock_target] = fee self.network.update_fee_estimates() @@ -983,6 +988,61 @@ class Interface(Logger): assert_hash256_str(res) return res + async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]: + # do request + res = await self.session.send_request('mempool.get_fee_histogram') + # check response + assert_list_or_tuple(res) + for fee, s in res: + assert_non_negative_int_or_float(fee) + assert_non_negative_integer(s) + return res + + async def get_server_banner(self) -> str: + # do request + res = await self.session.send_request('server.banner') + # check response + if not isinstance(res, str): + raise RequestCorrupted(f'{res!r} should be a str') + return res + + async def get_donation_address(self) -> str: + # do request + res = await self.session.send_request('server.donation_address') + # check response + if not res: # ignore empty string + return '' + if not bitcoin.is_address(res): + # note: do not hard-fail -- allow server to use future-type + # bitcoin address we do not recognize + self.logger.info(f"invalid donation address from server: {repr(res)}") + res = '' + return res + + async def get_relay_fee(self) -> int: + """Returns the min relay feerate in sat/kbyte.""" + # do request + res = await self.session.send_request('blockchain.relayfee') + # check response + assert_non_negative_int_or_float(res) + relayfee = int(res * bitcoin.COIN) + relayfee = max(0, relayfee) + return relayfee + + async def get_estimatefee(self, num_blocks: int) -> int: + """Returns a feerate estimate for getting confirmed within + num_blocks blocks, in sat/kbyte. + """ + if not is_non_negative_integer(num_blocks): + raise Exception(f"{repr(num_blocks)} is not a num_blocks") + # do request + res = await self.session.send_request('blockchain.estimatefee', [num_blocks]) + # check response + if res != -1: + assert_non_negative_int_or_float(res) + res = int(res * bitcoin.COIN) + return res + def _assert_header_does_not_check_against_any_chain(header: dict) -> None: chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) diff --git a/electrum/network.py b/electrum/network.py @@ -418,20 +418,15 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): def is_connecting(self): return self.connection_status == 'connecting' - async def _request_server_info(self, interface): + async def _request_server_info(self, interface: 'Interface'): await interface.ready session = interface.session async def get_banner(): - self.banner = await session.send_request('server.banner') + self.banner = await interface.get_server_banner() self.notify('banner') async def get_donation_address(): - addr = await session.send_request('server.donation_address') - if not bitcoin.is_address(addr): - if addr: # ignore empty string - self.logger.info(f"invalid donation address from server: {repr(addr)}") - addr = '' - self.donation_address = addr + self.donation_address = await interface.get_donation_address() async def get_server_peers(): server_peers = await session.send_request('server.peers.subscribe') random.shuffle(server_peers) @@ -441,12 +436,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): self.server_peers = parse_servers(server_peers) self.notify('servers') async def get_relay_fee(): - relayfee = await session.send_request('blockchain.relayfee') - if relayfee is None: - self.relay_fee = None - else: - relayfee = int(relayfee * COIN) - self.relay_fee = max(0, relayfee) + self.relay_fee = await interface.get_relay_fee() async with TaskGroup() as group: await group.spawn(get_banner) @@ -456,9 +446,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): await group.spawn(self._request_fee_estimates(interface)) async def _request_fee_estimates(self, interface): - session = interface.session self.config.requested_fee_estimates() - histogram = await session.send_request('mempool.get_fee_histogram') + histogram = await interface.get_fee_histogram() self.config.mempool_fees = histogram self.logger.info(f'fee_histogram {histogram}') self.notify('fee_histogram') diff --git a/electrum/simple_config.py b/electrum/simple_config.py @@ -5,7 +5,7 @@ import os import stat import ssl from decimal import Decimal -from typing import Union, Optional, Dict +from typing import Union, Optional, Dict, Sequence, Tuple from numbers import Real from copy import deepcopy @@ -65,7 +65,7 @@ class SimpleConfig(Logger): # a thread-safe way. self.lock = threading.RLock() - self.mempool_fees = {} # type: Dict[Union[float, int], int] + self.mempool_fees = [] # type: Sequence[Tuple[Union[float, int], int]] self.fee_estimates = {} self.fee_estimates_last_updated = {} self.last_time_fee_estimates_requested = 0 # zero ensures immediate fees diff --git a/electrum/tests/test_util.py b/electrum/tests/test_util.py @@ -2,7 +2,9 @@ from decimal import Decimal from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI, is_hash256_str, chunks, is_ip_address, list_enabled_bits, - format_satoshis_plain, is_private_netaddress) + format_satoshis_plain, is_private_netaddress, is_hex_str, + is_integer, is_non_negative_integer, is_int_or_float, + is_non_negative_int_or_float) from . import ElectrumTestCase @@ -121,6 +123,89 @@ class TestUtil(ElectrumTestCase): self.assertFalse(is_hash256_str(None)) self.assertFalse(is_hash256_str(7)) + def test_is_hex_str(self): + self.assertTrue(is_hex_str('09a4')) + self.assertTrue(is_hex_str('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA')) + self.assertTrue(is_hex_str('00' * 33)) + + self.assertFalse(is_hex_str('000')) + self.assertFalse(is_hex_str('qweqwe')) + self.assertFalse(is_hex_str(None)) + self.assertFalse(is_hex_str(7)) + + def test_is_integer(self): + self.assertTrue(is_integer(7)) + self.assertTrue(is_integer(0)) + self.assertTrue(is_integer(-1)) + self.assertTrue(is_integer(-7)) + + self.assertFalse(is_integer(Decimal("2.0"))) + self.assertFalse(is_integer(Decimal(2.0))) + self.assertFalse(is_integer(Decimal(2))) + self.assertFalse(is_integer(0.72)) + self.assertFalse(is_integer(2.0)) + self.assertFalse(is_integer(-2.0)) + self.assertFalse(is_integer('09a4')) + self.assertFalse(is_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA')) + self.assertFalse(is_integer('000')) + self.assertFalse(is_integer('qweqwe')) + self.assertFalse(is_integer(None)) + + def test_is_non_negative_integer(self): + self.assertTrue(is_non_negative_integer(7)) + self.assertTrue(is_non_negative_integer(0)) + + self.assertFalse(is_non_negative_integer(Decimal("2.0"))) + self.assertFalse(is_non_negative_integer(Decimal(2.0))) + self.assertFalse(is_non_negative_integer(Decimal(2))) + self.assertFalse(is_non_negative_integer(0.72)) + self.assertFalse(is_non_negative_integer(2.0)) + self.assertFalse(is_non_negative_integer(-2.0)) + self.assertFalse(is_non_negative_integer(-1)) + self.assertFalse(is_non_negative_integer(-7)) + self.assertFalse(is_non_negative_integer('09a4')) + self.assertFalse(is_non_negative_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA')) + self.assertFalse(is_non_negative_integer('000')) + self.assertFalse(is_non_negative_integer('qweqwe')) + self.assertFalse(is_non_negative_integer(None)) + + def test_is_int_or_float(self): + self.assertTrue(is_int_or_float(7)) + self.assertTrue(is_int_or_float(0)) + self.assertTrue(is_int_or_float(-1)) + self.assertTrue(is_int_or_float(-7)) + self.assertTrue(is_int_or_float(0.72)) + self.assertTrue(is_int_or_float(2.0)) + self.assertTrue(is_int_or_float(-2.0)) + + self.assertFalse(is_int_or_float(Decimal("2.0"))) + self.assertFalse(is_int_or_float(Decimal(2.0))) + self.assertFalse(is_int_or_float(Decimal(2))) + self.assertFalse(is_int_or_float('09a4')) + self.assertFalse(is_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA')) + self.assertFalse(is_int_or_float('000')) + self.assertFalse(is_int_or_float('qweqwe')) + self.assertFalse(is_int_or_float(None)) + + def test_is_non_negative_int_or_float(self): + self.assertTrue(is_non_negative_int_or_float(7)) + self.assertTrue(is_non_negative_int_or_float(0)) + self.assertTrue(is_non_negative_int_or_float(0.0)) + self.assertTrue(is_non_negative_int_or_float(0.72)) + self.assertTrue(is_non_negative_int_or_float(2.0)) + + self.assertFalse(is_non_negative_int_or_float(-1)) + self.assertFalse(is_non_negative_int_or_float(-7)) + self.assertFalse(is_non_negative_int_or_float(-2.0)) + self.assertFalse(is_non_negative_int_or_float(Decimal("2.0"))) + self.assertFalse(is_non_negative_int_or_float(Decimal(2.0))) + self.assertFalse(is_non_negative_int_or_float(Decimal(2))) + self.assertFalse(is_non_negative_int_or_float('09a4')) + self.assertFalse(is_non_negative_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA')) + self.assertFalse(is_non_negative_int_or_float('000')) + self.assertFalse(is_non_negative_int_or_float('qweqwe')) + self.assertFalse(is_non_negative_int_or_float(None)) + def test_chunks(self): self.assertEqual([[1, 2], [3, 4], [5]], list(chunks([1, 2, 3, 4, 5], 2))) diff --git a/electrum/util.py b/electrum/util.py @@ -588,38 +588,24 @@ def is_hex_str(text: Any) -> bool: return True -def is_non_negative_integer(val) -> bool: - try: - val = int(val) - if val >= 0: - return True - except: - pass +def is_integer(val: Any) -> bool: + return isinstance(val, int) + + +def is_non_negative_integer(val: Any) -> bool: + if is_integer(val): + return val >= 0 return False -def is_integer(val) -> bool: - try: - int(val) - except: - return False - else: - return True +def is_int_or_float(val: Any) -> bool: + return isinstance(val, (int, float)) -def is_real_number(val, *, as_str: bool = False) -> bool: - if as_str: # only accept str - if not isinstance(val, str): - return False - else: # only accept int/float/etc. - if isinstance(val, str): - return False - try: - Decimal(val) - except: - return False - else: - return True +def is_non_negative_int_or_float(val: Any) -> bool: + if is_int_or_float(val): + return val >= 0 + return False def chunks(items, size: int):