commit c5da22a9dd4459972a85d116fe281b8564c20834
parent c70484455c561b53664ba349cb88587bc349efc9
Author: SomberNight <somber.night@protonmail.com>
Date: Fri, 16 Oct 2020 19:30:42 +0200
network: tighten checks of server responses for type/sanity
Diffstat:
5 files changed, 174 insertions(+), 54 deletions(-)
diff --git a/electrum/interface.py b/electrum/interface.py
@@ -29,7 +29,7 @@ import sys
import traceback
import asyncio
import socket
-from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any
+from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools
@@ -46,13 +46,14 @@ import certifi
from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
- is_real_number)
+ is_int_or_float, is_non_negative_int_or_float)
from . import util
from . import x509
from . import pem
from . import version
from . import blockchain
from .blockchain import Blockchain, HEADER_SIZE
+from . import bitcoin
from . import constants
from .i18n import _
from .logging import Logger
@@ -96,9 +97,14 @@ def assert_integer(val: Any) -> None:
raise RequestCorrupted(f'{val!r} should be an integer')
-def assert_real_number(val: Any, *, as_str: bool = False) -> None:
- if not is_real_number(val, as_str=as_str):
- raise RequestCorrupted(f'{val!r} should be a number')
+def assert_int_or_float(val: Any) -> None:
+ if not is_int_or_float(val):
+ raise RequestCorrupted(f'{val!r} should be int or float')
+
+
+def assert_non_negative_int_or_float(val: Any) -> None:
+ if not is_non_negative_int_or_float(val):
+ raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
def assert_hash256_str(val: Any) -> None:
@@ -656,14 +662,13 @@ class Interface(Logger):
async def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS
- from .bitcoin import COIN
while True:
async with TaskGroup() as group:
fee_tasks = []
for i in FEE_ETA_TARGETS:
- fee_tasks.append((i, await group.spawn(self.session.send_request('blockchain.estimatefee', [i]))))
+ fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
for nblock_target, task in fee_tasks:
- fee = int(task.result() * COIN)
+ fee = task.result()
if fee < 0: continue
self.fee_estimates_eta[nblock_target] = fee
self.network.update_fee_estimates()
@@ -983,6 +988,61 @@ class Interface(Logger):
assert_hash256_str(res)
return res
+ async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
+ # do request
+ res = await self.session.send_request('mempool.get_fee_histogram')
+ # check response
+ assert_list_or_tuple(res)
+ for fee, s in res:
+ assert_non_negative_int_or_float(fee)
+ assert_non_negative_integer(s)
+ return res
+
+ async def get_server_banner(self) -> str:
+ # do request
+ res = await self.session.send_request('server.banner')
+ # check response
+ if not isinstance(res, str):
+ raise RequestCorrupted(f'{res!r} should be a str')
+ return res
+
+ async def get_donation_address(self) -> str:
+ # do request
+ res = await self.session.send_request('server.donation_address')
+ # check response
+ if not res: # ignore empty string
+ return ''
+ if not bitcoin.is_address(res):
+ # note: do not hard-fail -- allow server to use future-type
+ # bitcoin address we do not recognize
+ self.logger.info(f"invalid donation address from server: {repr(res)}")
+ res = ''
+ return res
+
+ async def get_relay_fee(self) -> int:
+ """Returns the min relay feerate in sat/kbyte."""
+ # do request
+ res = await self.session.send_request('blockchain.relayfee')
+ # check response
+ assert_non_negative_int_or_float(res)
+ relayfee = int(res * bitcoin.COIN)
+ relayfee = max(0, relayfee)
+ return relayfee
+
+ async def get_estimatefee(self, num_blocks: int) -> int:
+ """Returns a feerate estimate for getting confirmed within
+ num_blocks blocks, in sat/kbyte.
+ """
+ if not is_non_negative_integer(num_blocks):
+ raise Exception(f"{repr(num_blocks)} is not a num_blocks")
+ # do request
+ res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
+ # check response
+ if res != -1:
+ assert_non_negative_int_or_float(res)
+ res = int(res * bitcoin.COIN)
+ return res
+
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
diff --git a/electrum/network.py b/electrum/network.py
@@ -418,20 +418,15 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
def is_connecting(self):
return self.connection_status == 'connecting'
- async def _request_server_info(self, interface):
+ async def _request_server_info(self, interface: 'Interface'):
await interface.ready
session = interface.session
async def get_banner():
- self.banner = await session.send_request('server.banner')
+ self.banner = await interface.get_server_banner()
self.notify('banner')
async def get_donation_address():
- addr = await session.send_request('server.donation_address')
- if not bitcoin.is_address(addr):
- if addr: # ignore empty string
- self.logger.info(f"invalid donation address from server: {repr(addr)}")
- addr = ''
- self.donation_address = addr
+ self.donation_address = await interface.get_donation_address()
async def get_server_peers():
server_peers = await session.send_request('server.peers.subscribe')
random.shuffle(server_peers)
@@ -441,12 +436,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self.server_peers = parse_servers(server_peers)
self.notify('servers')
async def get_relay_fee():
- relayfee = await session.send_request('blockchain.relayfee')
- if relayfee is None:
- self.relay_fee = None
- else:
- relayfee = int(relayfee * COIN)
- self.relay_fee = max(0, relayfee)
+ self.relay_fee = await interface.get_relay_fee()
async with TaskGroup() as group:
await group.spawn(get_banner)
@@ -456,9 +446,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
await group.spawn(self._request_fee_estimates(interface))
async def _request_fee_estimates(self, interface):
- session = interface.session
self.config.requested_fee_estimates()
- histogram = await session.send_request('mempool.get_fee_histogram')
+ histogram = await interface.get_fee_histogram()
self.config.mempool_fees = histogram
self.logger.info(f'fee_histogram {histogram}')
self.notify('fee_histogram')
diff --git a/electrum/simple_config.py b/electrum/simple_config.py
@@ -5,7 +5,7 @@ import os
import stat
import ssl
from decimal import Decimal
-from typing import Union, Optional, Dict
+from typing import Union, Optional, Dict, Sequence, Tuple
from numbers import Real
from copy import deepcopy
@@ -65,7 +65,7 @@ class SimpleConfig(Logger):
# a thread-safe way.
self.lock = threading.RLock()
- self.mempool_fees = {} # type: Dict[Union[float, int], int]
+ self.mempool_fees = [] # type: Sequence[Tuple[Union[float, int], int]]
self.fee_estimates = {}
self.fee_estimates_last_updated = {}
self.last_time_fee_estimates_requested = 0 # zero ensures immediate fees
diff --git a/electrum/tests/test_util.py b/electrum/tests/test_util.py
@@ -2,7 +2,9 @@ from decimal import Decimal
from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI,
is_hash256_str, chunks, is_ip_address, list_enabled_bits,
- format_satoshis_plain, is_private_netaddress)
+ format_satoshis_plain, is_private_netaddress, is_hex_str,
+ is_integer, is_non_negative_integer, is_int_or_float,
+ is_non_negative_int_or_float)
from . import ElectrumTestCase
@@ -121,6 +123,89 @@ class TestUtil(ElectrumTestCase):
self.assertFalse(is_hash256_str(None))
self.assertFalse(is_hash256_str(7))
+ def test_is_hex_str(self):
+ self.assertTrue(is_hex_str('09a4'))
+ self.assertTrue(is_hex_str('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
+ self.assertTrue(is_hex_str('00' * 33))
+
+ self.assertFalse(is_hex_str('000'))
+ self.assertFalse(is_hex_str('qweqwe'))
+ self.assertFalse(is_hex_str(None))
+ self.assertFalse(is_hex_str(7))
+
+ def test_is_integer(self):
+ self.assertTrue(is_integer(7))
+ self.assertTrue(is_integer(0))
+ self.assertTrue(is_integer(-1))
+ self.assertTrue(is_integer(-7))
+
+ self.assertFalse(is_integer(Decimal("2.0")))
+ self.assertFalse(is_integer(Decimal(2.0)))
+ self.assertFalse(is_integer(Decimal(2)))
+ self.assertFalse(is_integer(0.72))
+ self.assertFalse(is_integer(2.0))
+ self.assertFalse(is_integer(-2.0))
+ self.assertFalse(is_integer('09a4'))
+ self.assertFalse(is_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
+ self.assertFalse(is_integer('000'))
+ self.assertFalse(is_integer('qweqwe'))
+ self.assertFalse(is_integer(None))
+
+ def test_is_non_negative_integer(self):
+ self.assertTrue(is_non_negative_integer(7))
+ self.assertTrue(is_non_negative_integer(0))
+
+ self.assertFalse(is_non_negative_integer(Decimal("2.0")))
+ self.assertFalse(is_non_negative_integer(Decimal(2.0)))
+ self.assertFalse(is_non_negative_integer(Decimal(2)))
+ self.assertFalse(is_non_negative_integer(0.72))
+ self.assertFalse(is_non_negative_integer(2.0))
+ self.assertFalse(is_non_negative_integer(-2.0))
+ self.assertFalse(is_non_negative_integer(-1))
+ self.assertFalse(is_non_negative_integer(-7))
+ self.assertFalse(is_non_negative_integer('09a4'))
+ self.assertFalse(is_non_negative_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
+ self.assertFalse(is_non_negative_integer('000'))
+ self.assertFalse(is_non_negative_integer('qweqwe'))
+ self.assertFalse(is_non_negative_integer(None))
+
+ def test_is_int_or_float(self):
+ self.assertTrue(is_int_or_float(7))
+ self.assertTrue(is_int_or_float(0))
+ self.assertTrue(is_int_or_float(-1))
+ self.assertTrue(is_int_or_float(-7))
+ self.assertTrue(is_int_or_float(0.72))
+ self.assertTrue(is_int_or_float(2.0))
+ self.assertTrue(is_int_or_float(-2.0))
+
+ self.assertFalse(is_int_or_float(Decimal("2.0")))
+ self.assertFalse(is_int_or_float(Decimal(2.0)))
+ self.assertFalse(is_int_or_float(Decimal(2)))
+ self.assertFalse(is_int_or_float('09a4'))
+ self.assertFalse(is_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
+ self.assertFalse(is_int_or_float('000'))
+ self.assertFalse(is_int_or_float('qweqwe'))
+ self.assertFalse(is_int_or_float(None))
+
+ def test_is_non_negative_int_or_float(self):
+ self.assertTrue(is_non_negative_int_or_float(7))
+ self.assertTrue(is_non_negative_int_or_float(0))
+ self.assertTrue(is_non_negative_int_or_float(0.0))
+ self.assertTrue(is_non_negative_int_or_float(0.72))
+ self.assertTrue(is_non_negative_int_or_float(2.0))
+
+ self.assertFalse(is_non_negative_int_or_float(-1))
+ self.assertFalse(is_non_negative_int_or_float(-7))
+ self.assertFalse(is_non_negative_int_or_float(-2.0))
+ self.assertFalse(is_non_negative_int_or_float(Decimal("2.0")))
+ self.assertFalse(is_non_negative_int_or_float(Decimal(2.0)))
+ self.assertFalse(is_non_negative_int_or_float(Decimal(2)))
+ self.assertFalse(is_non_negative_int_or_float('09a4'))
+ self.assertFalse(is_non_negative_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
+ self.assertFalse(is_non_negative_int_or_float('000'))
+ self.assertFalse(is_non_negative_int_or_float('qweqwe'))
+ self.assertFalse(is_non_negative_int_or_float(None))
+
def test_chunks(self):
self.assertEqual([[1, 2], [3, 4], [5]],
list(chunks([1, 2, 3, 4, 5], 2)))
diff --git a/electrum/util.py b/electrum/util.py
@@ -588,38 +588,24 @@ def is_hex_str(text: Any) -> bool:
return True
-def is_non_negative_integer(val) -> bool:
- try:
- val = int(val)
- if val >= 0:
- return True
- except:
- pass
+def is_integer(val: Any) -> bool:
+ return isinstance(val, int)
+
+
+def is_non_negative_integer(val: Any) -> bool:
+ if is_integer(val):
+ return val >= 0
return False
-def is_integer(val) -> bool:
- try:
- int(val)
- except:
- return False
- else:
- return True
+def is_int_or_float(val: Any) -> bool:
+ return isinstance(val, (int, float))
-def is_real_number(val, *, as_str: bool = False) -> bool:
- if as_str: # only accept str
- if not isinstance(val, str):
- return False
- else: # only accept int/float/etc.
- if isinstance(val, str):
- return False
- try:
- Decimal(val)
- except:
- return False
- else:
- return True
+def is_non_negative_int_or_float(val: Any) -> bool:
+ if is_int_or_float(val):
+ return val >= 0
+ return False
def chunks(items, size: int):