electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

keepkey.py (20387B)


      1 from binascii import hexlify, unhexlify
      2 import traceback
      3 import sys
      4 from typing import NamedTuple, Any, Optional, Dict, Union, List, Tuple, TYPE_CHECKING
      5 
      6 from electrum.util import bfh, bh2u, UserCancelled, UserFacingException
      7 from electrum.bip32 import BIP32Node
      8 from electrum import constants
      9 from electrum.i18n import _
     10 from electrum.transaction import Transaction, PartialTransaction, PartialTxInput, PartialTxOutput
     11 from electrum.keystore import Hardware_KeyStore
     12 from electrum.plugin import Device, runs_in_hwd_thread
     13 from electrum.base_wizard import ScriptTypeNotSupported
     14 
     15 from ..hw_wallet import HW_PluginBase
     16 from ..hw_wallet.plugin import (is_any_tx_output_on_change_branch, trezor_validate_op_return_output_and_get_data,
     17                                 get_xpubs_and_der_suffixes_from_txinout)
     18 
     19 if TYPE_CHECKING:
     20     import usb1
     21     from .client import KeepKeyClient
     22 
     23 
     24 # TREZOR initialization methods
     25 TIM_NEW, TIM_RECOVER, TIM_MNEMONIC, TIM_PRIVKEY = range(0, 4)
     26 
     27 
     28 class KeepKey_KeyStore(Hardware_KeyStore):
     29     hw_type = 'keepkey'
     30     device = 'KeepKey'
     31 
     32     plugin: 'KeepKeyPlugin'
     33 
     34     def get_client(self, force_pair=True):
     35         return self.plugin.get_client(self, force_pair)
     36 
     37     def decrypt_message(self, sequence, message, password):
     38         raise UserFacingException(_('Encryption and decryption are not implemented by {}').format(self.device))
     39 
     40     @runs_in_hwd_thread
     41     def sign_message(self, sequence, message, password):
     42         client = self.get_client()
     43         address_path = self.get_derivation_prefix() + "/%d/%d"%sequence
     44         address_n = client.expand_path(address_path)
     45         msg_sig = client.sign_message(self.plugin.get_coin_name(), address_n, message)
     46         return msg_sig.signature
     47 
     48     @runs_in_hwd_thread
     49     def sign_transaction(self, tx, password):
     50         if tx.is_complete():
     51             return
     52         # previous transactions used as inputs
     53         prev_tx = {}
     54         for txin in tx.inputs():
     55             tx_hash = txin.prevout.txid.hex()
     56             if txin.utxo is None and not txin.is_segwit():
     57                 raise UserFacingException(_('Missing previous tx for legacy input.'))
     58             prev_tx[tx_hash] = txin.utxo
     59 
     60         self.plugin.sign_transaction(self, tx, prev_tx)
     61 
     62 
     63 class KeepKeyPlugin(HW_PluginBase):
     64     # Derived classes provide:
     65     #
     66     #  class-static variables: client_class, firmware_URL, handler_class,
     67     #     libraries_available, libraries_URL, minimum_firmware,
     68     #     wallet_class, ckd_public, types, HidTransport
     69 
     70     firmware_URL = 'https://www.keepkey.com'
     71     libraries_URL = 'https://github.com/keepkey/python-keepkey'
     72     minimum_firmware = (1, 0, 0)
     73     keystore_class = KeepKey_KeyStore
     74     SUPPORTED_XTYPES = ('standard', 'p2wpkh-p2sh', 'p2wpkh', 'p2wsh-p2sh', 'p2wsh')
     75 
     76     MAX_LABEL_LEN = 32
     77 
     78     def __init__(self, parent, config, name):
     79         HW_PluginBase.__init__(self, parent, config, name)
     80 
     81         try:
     82             from . import client
     83             import keepkeylib
     84             import keepkeylib.ckd_public
     85             import keepkeylib.transport_hid
     86             import keepkeylib.transport_webusb
     87             self.client_class = client.KeepKeyClient
     88             self.ckd_public = keepkeylib.ckd_public
     89             self.types = keepkeylib.client.types
     90             self.DEVICE_IDS = (keepkeylib.transport_hid.DEVICE_IDS +
     91                                keepkeylib.transport_webusb.DEVICE_IDS)
     92             # only "register" hid device id:
     93             self.device_manager().register_devices(keepkeylib.transport_hid.DEVICE_IDS, plugin=self)
     94             # for webusb transport, use custom enumerate function:
     95             self.device_manager().register_enumerate_func(self.enumerate)
     96             self.libraries_available = True
     97         except ImportError:
     98             self.libraries_available = False
     99 
    100     @runs_in_hwd_thread
    101     def enumerate(self):
    102         from keepkeylib.transport_webusb import WebUsbTransport
    103         results = []
    104         for dev in WebUsbTransport.enumerate():
    105             path = self._dev_to_str(dev)
    106             results.append(Device(path=path,
    107                                   interface_number=-1,
    108                                   id_=path,
    109                                   product_key=(dev.getVendorID(), dev.getProductID()),
    110                                   usage_page=0,
    111                                   transport_ui_string=f"webusb:{path}"))
    112         return results
    113 
    114     @staticmethod
    115     def _dev_to_str(dev: "usb1.USBDevice") -> str:
    116         return ":".join(str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList())
    117 
    118     @runs_in_hwd_thread
    119     def hid_transport(self, pair):
    120         from keepkeylib.transport_hid import HidTransport
    121         return HidTransport(pair)
    122 
    123     @runs_in_hwd_thread
    124     def webusb_transport(self, device):
    125         from keepkeylib.transport_webusb import WebUsbTransport
    126         for dev in WebUsbTransport.enumerate():
    127             if device.path == self._dev_to_str(dev):
    128                 return WebUsbTransport(dev)
    129 
    130     @runs_in_hwd_thread
    131     def _try_hid(self, device):
    132         self.logger.info("Trying to connect over USB...")
    133         if device.interface_number == 1:
    134             pair = [None, device.path]
    135         else:
    136             pair = [device.path, None]
    137 
    138         try:
    139             return self.hid_transport(pair)
    140         except BaseException as e:
    141             # see fdb810ba622dc7dbe1259cbafb5b28e19d2ab114
    142             # raise
    143             self.logger.info(f"cannot connect at {device.path} {e}")
    144             return None
    145 
    146     @runs_in_hwd_thread
    147     def _try_webusb(self, device):
    148         self.logger.info("Trying to connect over WebUSB...")
    149         try:
    150             return self.webusb_transport(device)
    151         except BaseException as e:
    152             self.logger.info(f"cannot connect at {device.path} {e}")
    153             return None
    154 
    155     @runs_in_hwd_thread
    156     def create_client(self, device, handler):
    157         if device.product_key[1] == 2:
    158             transport = self._try_webusb(device)
    159         else:
    160             transport = self._try_hid(device)
    161 
    162         if not transport:
    163             self.logger.info("cannot connect to device")
    164             return
    165 
    166         self.logger.info(f"connected to device at {device.path}")
    167 
    168         client = self.client_class(transport, handler, self)
    169 
    170         # Try a ping for device sanity
    171         try:
    172             client.ping('t')
    173         except BaseException as e:
    174             self.logger.info(f"ping failed {e}")
    175             return None
    176 
    177         if not client.atleast_version(*self.minimum_firmware):
    178             msg = (_('Outdated {} firmware for device labelled {}. Please '
    179                      'download the updated firmware from {}')
    180                    .format(self.device, client.label(), self.firmware_URL))
    181             self.logger.info(msg)
    182             if handler:
    183                 handler.show_error(msg)
    184             else:
    185                 raise UserFacingException(msg)
    186             return None
    187 
    188         return client
    189 
    190     @runs_in_hwd_thread
    191     def get_client(self, keystore, force_pair=True, *,
    192                    devices=None, allow_user_interaction=True) -> Optional['KeepKeyClient']:
    193         client = super().get_client(keystore, force_pair,
    194                                     devices=devices,
    195                                     allow_user_interaction=allow_user_interaction)
    196         # returns the client for a given keystore. can use xpub
    197         if client:
    198             client.used()
    199         return client
    200 
    201     def get_coin_name(self):
    202         return "Testnet" if constants.net.TESTNET else "Bitcoin"
    203 
    204     def initialize_device(self, device_id, wizard, handler):
    205         # Initialization method
    206         msg = _("Choose how you want to initialize your {}.\n\n"
    207                 "The first two methods are secure as no secret information "
    208                 "is entered into your computer.\n\n"
    209                 "For the last two methods you input secrets on your keyboard "
    210                 "and upload them to your {}, and so you should "
    211                 "only do those on a computer you know to be trustworthy "
    212                 "and free of malware."
    213         ).format(self.device, self.device)
    214         choices = [
    215             # Must be short as QT doesn't word-wrap radio button text
    216             (TIM_NEW, _("Let the device generate a completely new seed randomly")),
    217             (TIM_RECOVER, _("Recover from a seed you have previously written down")),
    218             (TIM_MNEMONIC, _("Upload a BIP39 mnemonic to generate the seed")),
    219             (TIM_PRIVKEY, _("Upload a master private key"))
    220         ]
    221         def f(method):
    222             import threading
    223             settings = self.request_trezor_init_settings(wizard, method, self.device)
    224             t = threading.Thread(target=self._initialize_device_safe, args=(settings, method, device_id, wizard, handler))
    225             t.setDaemon(True)
    226             t.start()
    227             exit_code = wizard.loop.exec_()
    228             if exit_code != 0:
    229                 # this method (initialize_device) was called with the expectation
    230                 # of leaving the device in an initialized state when finishing.
    231                 # signal that this is not the case:
    232                 raise UserCancelled()
    233         wizard.choice_dialog(title=_('Initialize Device'), message=msg, choices=choices, run_next=f)
    234 
    235     def _initialize_device_safe(self, settings, method, device_id, wizard, handler):
    236         exit_code = 0
    237         try:
    238             self._initialize_device(settings, method, device_id, wizard, handler)
    239         except UserCancelled:
    240             exit_code = 1
    241         except BaseException as e:
    242             self.logger.exception('')
    243             handler.show_error(repr(e))
    244             exit_code = 1
    245         finally:
    246             wizard.loop.exit(exit_code)
    247 
    248     @runs_in_hwd_thread
    249     def _initialize_device(self, settings, method, device_id, wizard, handler):
    250         item, label, pin_protection, passphrase_protection = settings
    251 
    252         language = 'english'
    253         devmgr = self.device_manager()
    254         client = devmgr.client_by_id(device_id)
    255         if not client:
    256             raise Exception(_("The device was disconnected."))
    257 
    258         if method == TIM_NEW:
    259             strength = 64 * (item + 2)  # 128, 192 or 256
    260             client.reset_device(True, strength, passphrase_protection,
    261                                 pin_protection, label, language)
    262         elif method == TIM_RECOVER:
    263             word_count = 6 * (item + 2)  # 12, 18 or 24
    264             client.step = 0
    265             client.recovery_device(word_count, passphrase_protection,
    266                                        pin_protection, label, language)
    267         elif method == TIM_MNEMONIC:
    268             pin = pin_protection  # It's the pin, not a boolean
    269             client.load_device_by_mnemonic(str(item), pin,
    270                                            passphrase_protection,
    271                                            label, language)
    272         else:
    273             pin = pin_protection  # It's the pin, not a boolean
    274             client.load_device_by_xprv(item, pin, passphrase_protection,
    275                                        label, language)
    276 
    277     def _make_node_path(self, xpub, address_n):
    278         bip32node = BIP32Node.from_xkey(xpub)
    279         node = self.types.HDNodeType(
    280             depth=bip32node.depth,
    281             fingerprint=int.from_bytes(bip32node.fingerprint, 'big'),
    282             child_num=int.from_bytes(bip32node.child_number, 'big'),
    283             chain_code=bip32node.chaincode,
    284             public_key=bip32node.eckey.get_public_key_bytes(compressed=True),
    285         )
    286         return self.types.HDNodePathType(node=node, address_n=address_n)
    287 
    288     def setup_device(self, device_info, wizard, purpose):
    289         device_id = device_info.device.id_
    290         client = self.scan_and_create_client_for_device(device_id=device_id, wizard=wizard)
    291         if not device_info.initialized:
    292             self.initialize_device(device_id, wizard, client.handler)
    293         wizard.run_task_without_blocking_gui(
    294             task=lambda: client.get_xpub("m", 'standard'))
    295         client.used()
    296         return client
    297 
    298     def get_xpub(self, device_id, derivation, xtype, wizard):
    299         if xtype not in self.SUPPORTED_XTYPES:
    300             raise ScriptTypeNotSupported(_('This type of script is not supported with {}.').format(self.device))
    301         client = self.scan_and_create_client_for_device(device_id=device_id, wizard=wizard)
    302         xpub = client.get_xpub(derivation, xtype)
    303         client.used()
    304         return xpub
    305 
    306     def get_keepkey_input_script_type(self, electrum_txin_type: str):
    307         if electrum_txin_type in ('p2wpkh', 'p2wsh'):
    308             return self.types.SPENDWITNESS
    309         if electrum_txin_type in ('p2wpkh-p2sh', 'p2wsh-p2sh'):
    310             return self.types.SPENDP2SHWITNESS
    311         if electrum_txin_type in ('p2pkh', ):
    312             return self.types.SPENDADDRESS
    313         if electrum_txin_type in ('p2sh', ):
    314             return self.types.SPENDMULTISIG
    315         raise ValueError('unexpected txin type: {}'.format(electrum_txin_type))
    316 
    317     def get_keepkey_output_script_type(self, electrum_txin_type: str):
    318         if electrum_txin_type in ('p2wpkh', 'p2wsh'):
    319             return self.types.PAYTOWITNESS
    320         if electrum_txin_type in ('p2wpkh-p2sh', 'p2wsh-p2sh'):
    321             return self.types.PAYTOP2SHWITNESS
    322         if electrum_txin_type in ('p2pkh', ):
    323             return self.types.PAYTOADDRESS
    324         if electrum_txin_type in ('p2sh', ):
    325             return self.types.PAYTOMULTISIG
    326         raise ValueError('unexpected txin type: {}'.format(electrum_txin_type))
    327 
    328     @runs_in_hwd_thread
    329     def sign_transaction(self, keystore, tx: PartialTransaction, prev_tx):
    330         self.prev_tx = prev_tx
    331         client = self.get_client(keystore)
    332         inputs = self.tx_inputs(tx, for_sig=True, keystore=keystore)
    333         outputs = self.tx_outputs(tx, keystore=keystore)
    334         signatures = client.sign_tx(self.get_coin_name(), inputs, outputs,
    335                                     lock_time=tx.locktime, version=tx.version)[0]
    336         signatures = [(bh2u(x) + '01') for x in signatures]
    337         tx.update_signatures(signatures)
    338 
    339     @runs_in_hwd_thread
    340     def show_address(self, wallet, address, keystore=None):
    341         if keystore is None:
    342             keystore = wallet.get_keystore()
    343         if not self.show_address_helper(wallet, address, keystore):
    344             return
    345         client = self.get_client(keystore)
    346         if not client.atleast_version(1, 3):
    347             keystore.handler.show_error(_("Your device firmware is too old"))
    348             return
    349         deriv_suffix = wallet.get_address_index(address)
    350         derivation = keystore.get_derivation_prefix()
    351         address_path = "%s/%d/%d"%(derivation, *deriv_suffix)
    352         address_n = client.expand_path(address_path)
    353         script_type = self.get_keepkey_input_script_type(wallet.txin_type)
    354 
    355         # prepare multisig, if available:
    356         xpubs = wallet.get_master_public_keys()
    357         if len(xpubs) > 1:
    358             pubkeys = wallet.get_public_keys(address)
    359             # sort xpubs using the order of pubkeys
    360             sorted_pairs = sorted(zip(pubkeys, xpubs))
    361             multisig = self._make_multisig(
    362                 wallet.m,
    363                 [(xpub, deriv_suffix) for pubkey, xpub in sorted_pairs])
    364         else:
    365             multisig = None
    366 
    367         client.get_address(self.get_coin_name(), address_n, True, multisig=multisig, script_type=script_type)
    368 
    369     def tx_inputs(self, tx: Transaction, *, for_sig=False, keystore: 'KeepKey_KeyStore' = None):
    370         inputs = []
    371         for txin in tx.inputs():
    372             txinputtype = self.types.TxInputType()
    373             if txin.is_coinbase_input():
    374                 prev_hash = b"\x00"*32
    375                 prev_index = 0xffffffff  # signed int -1
    376             else:
    377                 if for_sig:
    378                     assert isinstance(tx, PartialTransaction)
    379                     assert isinstance(txin, PartialTxInput)
    380                     assert keystore
    381                     if len(txin.pubkeys) > 1:
    382                         xpubs_and_deriv_suffixes = get_xpubs_and_der_suffixes_from_txinout(tx, txin)
    383                         multisig = self._make_multisig(txin.num_sig, xpubs_and_deriv_suffixes)
    384                     else:
    385                         multisig = None
    386                     script_type = self.get_keepkey_input_script_type(txin.script_type)
    387                     txinputtype = self.types.TxInputType(
    388                         script_type=script_type,
    389                         multisig=multisig)
    390                     my_pubkey, full_path = keystore.find_my_pubkey_in_txinout(txin)
    391                     if full_path:
    392                         txinputtype.address_n.extend(full_path)
    393 
    394                 prev_hash = txin.prevout.txid
    395                 prev_index = txin.prevout.out_idx
    396 
    397             if txin.value_sats() is not None:
    398                 txinputtype.amount = txin.value_sats()
    399             txinputtype.prev_hash = prev_hash
    400             txinputtype.prev_index = prev_index
    401 
    402             if txin.script_sig is not None:
    403                 txinputtype.script_sig = txin.script_sig
    404 
    405             txinputtype.sequence = txin.nsequence
    406 
    407             inputs.append(txinputtype)
    408 
    409         return inputs
    410 
    411     def _make_multisig(self, m, xpubs):
    412         if len(xpubs) == 1:
    413             return None
    414         pubkeys = [self._make_node_path(xpub, deriv) for xpub, deriv in xpubs]
    415         return self.types.MultisigRedeemScriptType(
    416             pubkeys=pubkeys,
    417             signatures=[b''] * len(pubkeys),
    418             m=m)
    419 
    420     def tx_outputs(self, tx: PartialTransaction, *, keystore: 'KeepKey_KeyStore'):
    421 
    422         def create_output_by_derivation():
    423             script_type = self.get_keepkey_output_script_type(txout.script_type)
    424             if len(txout.pubkeys) > 1:
    425                 xpubs_and_deriv_suffixes = get_xpubs_and_der_suffixes_from_txinout(tx, txout)
    426                 multisig = self._make_multisig(txout.num_sig, xpubs_and_deriv_suffixes)
    427             else:
    428                 multisig = None
    429             my_pubkey, full_path = keystore.find_my_pubkey_in_txinout(txout)
    430             assert full_path
    431             txoutputtype = self.types.TxOutputType(
    432                 multisig=multisig,
    433                 amount=txout.value,
    434                 address_n=full_path,
    435                 script_type=script_type)
    436             return txoutputtype
    437 
    438         def create_output_by_address():
    439             txoutputtype = self.types.TxOutputType()
    440             txoutputtype.amount = txout.value
    441             if address:
    442                 txoutputtype.script_type = self.types.PAYTOADDRESS
    443                 txoutputtype.address = address
    444             else:
    445                 txoutputtype.script_type = self.types.PAYTOOPRETURN
    446                 txoutputtype.op_return_data = trezor_validate_op_return_output_and_get_data(txout)
    447             return txoutputtype
    448 
    449         outputs = []
    450         has_change = False
    451         any_output_on_change_branch = is_any_tx_output_on_change_branch(tx)
    452 
    453         for txout in tx.outputs():
    454             address = txout.address
    455             use_create_by_derivation = False
    456 
    457             if txout.is_mine and not has_change:
    458                 # prioritise hiding outputs on the 'change' branch from user
    459                 # because no more than one change address allowed
    460                 if txout.is_change == any_output_on_change_branch:
    461                     use_create_by_derivation = True
    462                     has_change = True
    463 
    464             if use_create_by_derivation:
    465                 txoutputtype = create_output_by_derivation()
    466             else:
    467                 txoutputtype = create_output_by_address()
    468             outputs.append(txoutputtype)
    469 
    470         return outputs
    471 
    472     def electrum_tx_to_txtype(self, tx: Optional[Transaction]):
    473         t = self.types.TransactionType()
    474         if tx is None:
    475             # probably for segwit input and we don't need this prev txn
    476             return t
    477         tx.deserialize()
    478         t.version = tx.version
    479         t.lock_time = tx.locktime
    480         inputs = self.tx_inputs(tx)
    481         t.inputs.extend(inputs)
    482         for out in tx.outputs():
    483             o = t.bin_outputs.add()
    484             o.amount = out.value
    485             o.script_pubkey = out.scriptpubkey
    486         return t
    487 
    488     # This function is called from the TREZOR libraries (via tx_api)
    489     def get_tx(self, tx_hash):
    490         tx = self.prev_tx[tx_hash]
    491         return self.electrum_tx_to_txtype(tx)