electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 11452722aff3877b7033c2e873f73e906a110a78
parent cb88a3b6e45e1261d4926d6d385f8f779a84a2d9
Author: SomberNight <somber.night@protonmail.com>
Date:   Wed,  1 Jan 2020 07:21:08 +0100

network dns hacks: split from network.py into its own file

Diffstat:
Aelectrum/dns_hacks.py | 100+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Melectrum/network.py | 77+++--------------------------------------------------------------------------
2 files changed, 103 insertions(+), 74 deletions(-)

diff --git a/electrum/dns_hacks.py b/electrum/dns_hacks.py @@ -0,0 +1,100 @@ +# Copyright (C) 2020 The Electrum developers +# Distributed under the MIT software license, see the accompanying +# file LICENCE or http://www.opensource.org/licenses/mit-license.php + +import sys +import socket +import concurrent +from concurrent import futures +import ipaddress +from typing import Optional + +import dns +import dns.resolver + +from .logging import get_logger + + +_logger = get_logger(__name__) + +_dns_threads_executor = None # type: Optional[concurrent.futures.Executor] + + +def configure_dns_depending_on_proxy(is_proxy: bool) -> None: + # Store this somewhere so we can un-monkey-patch: + if not hasattr(socket, "_getaddrinfo"): + socket._getaddrinfo = socket.getaddrinfo + if is_proxy: + # prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy + socket.getaddrinfo = lambda *args: [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))] + else: + if sys.platform == 'win32': + # On Windows, socket.getaddrinfo takes a mutex, and might hold it for up to 10 seconds + # when dns-resolving. To speed it up drastically, we resolve dns ourselves, outside that lock. + # See https://github.com/spesmilo/electrum/issues/4421 + _prepare_windows_dns_hack() + socket.getaddrinfo = _fast_getaddrinfo + else: + socket.getaddrinfo = socket._getaddrinfo + + +def _prepare_windows_dns_hack(): + # enable dns cache + resolver = dns.resolver.get_default_resolver() + if resolver.cache is None: + resolver.cache = dns.resolver.Cache() + # prepare threads + global _dns_threads_executor + if _dns_threads_executor is None: + _dns_threads_executor = concurrent.futures.ThreadPoolExecutor(max_workers=20, + thread_name_prefix='dns_resolver') + + +def _fast_getaddrinfo(host, *args, **kwargs): + def needs_dns_resolving(host): + try: + ipaddress.ip_address(host) + return False # already valid IP + except ValueError: + pass # not an IP + if str(host) in ('localhost', 'localhost.',): + return False + return True + + def resolve_with_dnspython(host): + addrs = [] + expected_errors = (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, + concurrent.futures.CancelledError, concurrent.futures.TimeoutError) + ipv6_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.AAAA) + ipv4_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.A) + # try IPv6 + try: + answers = ipv6_fut.result() + addrs += [str(answer) for answer in answers] + except expected_errors as e: + pass + except BaseException as e: + _logger.info(f'dnspython failed to resolve dns (AAAA) for {repr(host)} with error: {repr(e)}') + # try IPv4 + try: + answers = ipv4_fut.result() + addrs += [str(answer) for answer in answers] + except expected_errors as e: + # dns failed for some reason, e.g. dns.resolver.NXDOMAIN this is normal. + # Simply report back failure; except if we already have some results. + if not addrs: + raise socket.gaierror(11001, 'getaddrinfo failed') from e + except BaseException as e: + # Possibly internal error in dnspython :( see #4483 and #5638 + _logger.info(f'dnspython failed to resolve dns (A) for {repr(host)} with error: {repr(e)}') + if addrs: + return addrs + # Fall back to original socket.getaddrinfo to resolve dns. + return [host] + + addrs = [host] + if needs_dns_resolving(host): + addrs = resolve_with_dnspython(host) + list_of_list_of_socketinfos = [socket._getaddrinfo(addr, *args, **kwargs) for addr in addrs] + list_of_socketinfos = [item for lst in list_of_list_of_socketinfos for item in lst] + return list_of_socketinfos diff --git a/electrum/network.py b/electrum/network.py @@ -31,15 +31,12 @@ import threading import socket import json import sys -import ipaddress import asyncio from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable import traceback import concurrent from concurrent import futures -import dns -import dns.resolver import aiorpcx from aiorpcx import TaskGroup from aiohttp import ClientResponse @@ -53,6 +50,7 @@ from .bitcoin import COIN from . import constants from . import blockchain from . import bitcoin +from . import dns_hacks from .transaction import Transaction from .blockchain import Blockchain, HEADER_SIZE from .interface import (Interface, serialize_server, deserialize_server, @@ -228,8 +226,6 @@ class UntrustedServerReturnedError(NetworkException): return f"<UntrustedServerReturnedError original_exception: {repr(self.original_exception)}>" -_dns_threads_executor = None # type: Optional[concurrent.futures.Executor] - _INSTANCE = None @@ -557,77 +553,10 @@ class Network(Logger): def _set_proxy(self, proxy: Optional[dict]): self.proxy = proxy - # Store these somewhere so we can un-monkey-patch - if not hasattr(socket, "_getaddrinfo"): - socket._getaddrinfo = socket.getaddrinfo - if proxy: - self.logger.info(f'setting proxy {proxy}') - # prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy - socket.getaddrinfo = lambda *args: [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))] - else: - if sys.platform == 'win32': - # On Windows, socket.getaddrinfo takes a mutex, and might hold it for up to 10 seconds - # when dns-resolving. To speed it up drastically, we resolve dns ourselves, outside that lock. - # see #4421 - resolver = dns.resolver.get_default_resolver() - if resolver.cache is None: - resolver.cache = dns.resolver.Cache() - global _dns_threads_executor - if _dns_threads_executor is None: - _dns_threads_executor = concurrent.futures.ThreadPoolExecutor(max_workers=20) - socket.getaddrinfo = self._fast_getaddrinfo - else: - socket.getaddrinfo = socket._getaddrinfo + dns_hacks.configure_dns_depending_on_proxy(bool(proxy)) + self.logger.info(f'setting proxy {proxy}') self.trigger_callback('proxy_set', self.proxy) - @staticmethod - def _fast_getaddrinfo(host, *args, **kwargs): - def needs_dns_resolving(host): - try: - ipaddress.ip_address(host) - return False # already valid IP - except ValueError: - pass # not an IP - if str(host) in ('localhost', 'localhost.',): - return False - return True - def resolve_with_dnspython(host): - addrs = [] - expected_errors = (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, - concurrent.futures.CancelledError, concurrent.futures.TimeoutError) - ipv6_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.AAAA) - ipv4_fut = _dns_threads_executor.submit(dns.resolver.query, host, dns.rdatatype.A) - # try IPv6 - try: - answers = ipv6_fut.result() - addrs += [str(answer) for answer in answers] - except expected_errors as e: - pass - except BaseException as e: - _logger.info(f'dnspython failed to resolve dns (AAAA) for {repr(host)} with error: {repr(e)}') - # try IPv4 - try: - answers = ipv4_fut.result() - addrs += [str(answer) for answer in answers] - except expected_errors as e: - # dns failed for some reason, e.g. dns.resolver.NXDOMAIN this is normal. - # Simply report back failure; except if we already have some results. - if not addrs: - raise socket.gaierror(11001, 'getaddrinfo failed') from e - except BaseException as e: - # Possibly internal error in dnspython :( see #4483 and #5638 - _logger.info(f'dnspython failed to resolve dns (A) for {repr(host)} with error: {repr(e)}') - if addrs: - return addrs - # Fall back to original socket.getaddrinfo to resolve dns. - return [host] - addrs = [host] - if needs_dns_resolving(host): - addrs = resolve_with_dnspython(host) - list_of_list_of_socketinfos = [socket._getaddrinfo(addr, *args, **kwargs) for addr in addrs] - list_of_socketinfos = [item for lst in list_of_list_of_socketinfos for item in lst] - return list_of_socketinfos - @log_exceptions async def set_parameters(self, net_params: NetworkParameters): proxy = net_params.proxy