commit 680df7d6b60ffcf66f6c47eb73697da1a8613405
parent c79de3ab3c7531ab6d45c8c92b6ada0c620daf09
Author: SomberNight <somber.night@protonmail.com>
Date: Fri, 16 Mar 2018 23:19:52 +0100
trezor: move the transport-related reimplemented parts into a separate module. disable the bridge transport.
The bridge transport uses requests.post, which uses socket.getaddrinfo under the hood, which on some OSes (MacOS, Windows) in CPython takes a lock. The enumerate method for the bridge transport can block for 10-30 seconds while waiting for this lock.
Diffstat:
2 files changed, 99 insertions(+), 74 deletions(-)
diff --git a/plugins/trezor/transport.py b/plugins/trezor/transport.py
@@ -0,0 +1,95 @@
+from electrum.util import PrintError
+
+
+class TrezorTransport(PrintError):
+
+ @staticmethod
+ def all_transports():
+ """Reimplemented trezorlib.transport.all_transports so that we can
+ enable/disable specific transports.
+ """
+ try:
+ # only to detect trezorlib version
+ from trezorlib.transport import all_transports
+ except ImportError:
+ # old trezorlib. compat for trezorlib < 0.9.2
+ transports = []
+ #try:
+ # from trezorlib.transport_bridge import BridgeTransport
+ # transports.append(BridgeTransport)
+ #except BaseException:
+ # pass
+ try:
+ from trezorlib.transport_hid import HidTransport
+ transports.append(HidTransport)
+ except BaseException:
+ pass
+ try:
+ from trezorlib.transport_udp import UdpTransport
+ transports.append(UdpTransport)
+ except BaseException:
+ pass
+ try:
+ from trezorlib.transport_webusb import WebUsbTransport
+ transports.append(WebUsbTransport)
+ except BaseException:
+ pass
+ else:
+ # new trezorlib.
+ transports = []
+ #try:
+ # from trezorlib.transport.bridge import BridgeTransport
+ # transports.append(BridgeTransport)
+ #except BaseException:
+ # pass
+ try:
+ from trezorlib.transport.hid import HidTransport
+ transports.append(HidTransport)
+ except BaseException:
+ pass
+ try:
+ from trezorlib.transport.udp import UdpTransport
+ transports.append(UdpTransport)
+ except BaseException:
+ pass
+ try:
+ from trezorlib.transport.webusb import WebUsbTransport
+ transports.append(WebUsbTransport)
+ except BaseException:
+ pass
+ return transports
+ return transports
+
+ def enumerate_devices(self):
+ """Just like trezorlib.transport.enumerate_devices,
+ but with exception catching, so that transports can fail separately.
+ """
+ devices = []
+ for transport in self.all_transports():
+ try:
+ new_devices = transport.enumerate()
+ except BaseException as e:
+ self.print_error('enumerate failed for {}. error {}'
+ .format(transport.__name__, str(e)))
+ else:
+ devices.extend(new_devices)
+ return devices
+
+ def get_transport(self, path=None):
+ """Reimplemented trezorlib.transport.get_transport,
+ (1) for old trezorlib
+ (2) to be able to disable specific transports
+ (3) to call our own enumerate_devices that catches exceptions
+ """
+ if path is None:
+ try:
+ return self.enumerate_devices()[0]
+ except IndexError:
+ raise Exception("No TREZOR device found") from None
+
+ def match_prefix(a, b):
+ return a.startswith(b) or b.startswith(a)
+ transports = [t for t in self.all_transports() if match_prefix(path, t.PATH_PREFIX)]
+ if transports:
+ return transports[0].find_by_path(path)
+ raise Exception("Unknown path prefix '%s'" % path)
diff --git a/plugins/trezor/trezor.py b/plugins/trezor/trezor.py
@@ -117,6 +117,7 @@ class TrezorPlugin(HW_PluginBase):
return
from . import client
+ from . import transport
import trezorlib.ckd_public
import trezorlib.messages
self.client_class = client.TrezorClient
@@ -124,88 +125,17 @@ class TrezorPlugin(HW_PluginBase):
self.types = trezorlib.messages
self.DEVICE_IDS = ('TREZOR',)
+ self.transport_handler = transport.TrezorTransport()
self.device_manager().register_enumerate_func(self.enumerate)
- @staticmethod
- def _all_transports():
- """Reimplemented trezorlib.transport.all_transports for old trezorlib.
- Remove this when we start to require trezorlib 0.9.2
- """
- try:
- from trezorlib.transport import all_transports
- except ImportError:
- # compat for trezorlib < 0.9.2
- def all_transports():
- transports = []
- try:
- from trezorlib.transport_bridge import BridgeTransport
- transports.append(BridgeTransport)
- except BaseException:
- pass
- try:
- from trezorlib.transport_hid import HidTransport
- transports.append(HidTransport)
- except BaseException:
- pass
- try:
- from trezorlib.transport_udp import UdpTransport
- transports.append(UdpTransport)
- except BaseException:
- pass
- try:
- from trezorlib.transport_webusb import WebUsbTransport
- transports.append(WebUsbTransport)
- except BaseException:
- pass
- return transports
- return all_transports()
-
- def _enumerate_devices(self):
- """Just like trezorlib.transport.enumerate_devices,
- but with exception catching, so that transports can fail separately.
- """
- devices = []
- for transport in self._all_transports():
- try:
- new_devices = transport.enumerate()
- except BaseException as e:
- self.print_error('enumerate failed for {}. error {}'
- .format(transport.__name__, str(e)))
- else:
- devices.extend(new_devices)
- return devices
-
def enumerate(self):
- devices = self._enumerate_devices()
+ devices = self.transport_handler.enumerate_devices()
return [Device(d.get_path(), -1, d.get_path(), 'TREZOR', 0) for d in devices]
- def _get_transport(self, path=None):
- """Reimplemented trezorlib.transport.get_transport for old trezorlib.
- Remove this when we start to require trezorlib 0.9.2
- """
- try:
- from trezorlib.transport import get_transport
- except ImportError:
- # compat for trezorlib < 0.9.2
- def get_transport(path=None, prefix_search=False):
- if path is None:
- try:
- return self._enumerate_devices()[0]
- except IndexError:
- raise Exception("No TREZOR device found") from None
-
- def match_prefix(a, b):
- return a.startswith(b) or b.startswith(a)
- transports = [t for t in self._all_transports() if match_prefix(path, t.PATH_PREFIX)]
- if transports:
- return transports[0].find_by_path(path)
- raise Exception("Unknown path prefix '%s'" % path)
- return get_transport(path)
-
def create_client(self, device, handler):
try:
self.print_error("connecting to device at", device.path)
- transport = self._get_transport(device.path)
+ transport = self.transport_handler.get_transport(device.path)
except BaseException as e:
self.print_error("cannot connect at", device.path, str(e))
return None