electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 680df7d6b60ffcf66f6c47eb73697da1a8613405
parent c79de3ab3c7531ab6d45c8c92b6ada0c620daf09
Author: SomberNight <somber.night@protonmail.com>
Date:   Fri, 16 Mar 2018 23:19:52 +0100

trezor: move the transport-related reimplemented parts into a separate module. disable the bridge transport.

The bridge transport uses requests.post, which uses socket.getaddrinfo under the hood, which on some OSes (MacOS, Windows) in CPython takes a lock. The enumerate method for the bridge transport can block for 10-30 seconds while waiting for this lock.

Diffstat:
Aplugins/trezor/transport.py | 95+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mplugins/trezor/trezor.py | 78++++--------------------------------------------------------------------------
2 files changed, 99 insertions(+), 74 deletions(-)

diff --git a/plugins/trezor/transport.py b/plugins/trezor/transport.py @@ -0,0 +1,95 @@ +from electrum.util import PrintError + + +class TrezorTransport(PrintError): + + @staticmethod + def all_transports(): + """Reimplemented trezorlib.transport.all_transports so that we can + enable/disable specific transports. + """ + try: + # only to detect trezorlib version + from trezorlib.transport import all_transports + except ImportError: + # old trezorlib. compat for trezorlib < 0.9.2 + transports = [] + #try: + # from trezorlib.transport_bridge import BridgeTransport + # transports.append(BridgeTransport) + #except BaseException: + # pass + try: + from trezorlib.transport_hid import HidTransport + transports.append(HidTransport) + except BaseException: + pass + try: + from trezorlib.transport_udp import UdpTransport + transports.append(UdpTransport) + except BaseException: + pass + try: + from trezorlib.transport_webusb import WebUsbTransport + transports.append(WebUsbTransport) + except BaseException: + pass + else: + # new trezorlib. + transports = [] + #try: + # from trezorlib.transport.bridge import BridgeTransport + # transports.append(BridgeTransport) + #except BaseException: + # pass + try: + from trezorlib.transport.hid import HidTransport + transports.append(HidTransport) + except BaseException: + pass + try: + from trezorlib.transport.udp import UdpTransport + transports.append(UdpTransport) + except BaseException: + pass + try: + from trezorlib.transport.webusb import WebUsbTransport + transports.append(WebUsbTransport) + except BaseException: + pass + return transports + return transports + + def enumerate_devices(self): + """Just like trezorlib.transport.enumerate_devices, + but with exception catching, so that transports can fail separately. + """ + devices = [] + for transport in self.all_transports(): + try: + new_devices = transport.enumerate() + except BaseException as e: + self.print_error('enumerate failed for {}. error {}' + .format(transport.__name__, str(e))) + else: + devices.extend(new_devices) + return devices + + def get_transport(self, path=None): + """Reimplemented trezorlib.transport.get_transport, + (1) for old trezorlib + (2) to be able to disable specific transports + (3) to call our own enumerate_devices that catches exceptions + """ + if path is None: + try: + return self.enumerate_devices()[0] + except IndexError: + raise Exception("No TREZOR device found") from None + + def match_prefix(a, b): + return a.startswith(b) or b.startswith(a) + transports = [t for t in self.all_transports() if match_prefix(path, t.PATH_PREFIX)] + if transports: + return transports[0].find_by_path(path) + raise Exception("Unknown path prefix '%s'" % path) diff --git a/plugins/trezor/trezor.py b/plugins/trezor/trezor.py @@ -117,6 +117,7 @@ class TrezorPlugin(HW_PluginBase): return from . import client + from . import transport import trezorlib.ckd_public import trezorlib.messages self.client_class = client.TrezorClient @@ -124,88 +125,17 @@ class TrezorPlugin(HW_PluginBase): self.types = trezorlib.messages self.DEVICE_IDS = ('TREZOR',) + self.transport_handler = transport.TrezorTransport() self.device_manager().register_enumerate_func(self.enumerate) - @staticmethod - def _all_transports(): - """Reimplemented trezorlib.transport.all_transports for old trezorlib. - Remove this when we start to require trezorlib 0.9.2 - """ - try: - from trezorlib.transport import all_transports - except ImportError: - # compat for trezorlib < 0.9.2 - def all_transports(): - transports = [] - try: - from trezorlib.transport_bridge import BridgeTransport - transports.append(BridgeTransport) - except BaseException: - pass - try: - from trezorlib.transport_hid import HidTransport - transports.append(HidTransport) - except BaseException: - pass - try: - from trezorlib.transport_udp import UdpTransport - transports.append(UdpTransport) - except BaseException: - pass - try: - from trezorlib.transport_webusb import WebUsbTransport - transports.append(WebUsbTransport) - except BaseException: - pass - return transports - return all_transports() - - def _enumerate_devices(self): - """Just like trezorlib.transport.enumerate_devices, - but with exception catching, so that transports can fail separately. - """ - devices = [] - for transport in self._all_transports(): - try: - new_devices = transport.enumerate() - except BaseException as e: - self.print_error('enumerate failed for {}. error {}' - .format(transport.__name__, str(e))) - else: - devices.extend(new_devices) - return devices - def enumerate(self): - devices = self._enumerate_devices() + devices = self.transport_handler.enumerate_devices() return [Device(d.get_path(), -1, d.get_path(), 'TREZOR', 0) for d in devices] - def _get_transport(self, path=None): - """Reimplemented trezorlib.transport.get_transport for old trezorlib. - Remove this when we start to require trezorlib 0.9.2 - """ - try: - from trezorlib.transport import get_transport - except ImportError: - # compat for trezorlib < 0.9.2 - def get_transport(path=None, prefix_search=False): - if path is None: - try: - return self._enumerate_devices()[0] - except IndexError: - raise Exception("No TREZOR device found") from None - - def match_prefix(a, b): - return a.startswith(b) or b.startswith(a) - transports = [t for t in self._all_transports() if match_prefix(path, t.PATH_PREFIX)] - if transports: - return transports[0].find_by_path(path) - raise Exception("Unknown path prefix '%s'" % path) - return get_transport(path) - def create_client(self, device, handler): try: self.print_error("connecting to device at", device.path) - transport = self._get_transport(device.path) + transport = self.transport_handler.get_transport(device.path) except BaseException as e: self.print_error("cannot connect at", device.path, str(e)) return None