keepkey.py (20387B)
"""KeepKey hardware-wallet plugin.

Provides a keystore (:class:`KeepKey_KeyStore`) and a plugin
(:class:`KeepKeyPlugin`) that talk to the device through ``keepkeylib``,
over either HID or WebUSB.
"""
from binascii import hexlify, unhexlify
import traceback
import sys
from typing import NamedTuple, Any, Optional, Dict, Union, List, Tuple, TYPE_CHECKING

from electrum.util import bfh, bh2u, UserCancelled, UserFacingException
from electrum.bip32 import BIP32Node
from electrum import constants
from electrum.i18n import _
from electrum.transaction import Transaction, PartialTransaction, PartialTxInput, PartialTxOutput
from electrum.keystore import Hardware_KeyStore
from electrum.plugin import Device, runs_in_hwd_thread
from electrum.base_wizard import ScriptTypeNotSupported

from ..hw_wallet import HW_PluginBase
from ..hw_wallet.plugin import (is_any_tx_output_on_change_branch, trezor_validate_op_return_output_and_get_data,
                                get_xpubs_and_der_suffixes_from_txinout)

if TYPE_CHECKING:
    import usb1
    from .client import KeepKeyClient


# TREZOR initialization methods
TIM_NEW, TIM_RECOVER, TIM_MNEMONIC, TIM_PRIVKEY = range(0, 4)


class KeepKey_KeyStore(Hardware_KeyStore):
    """Hardware keystore whose signing operations are delegated to a KeepKey
    device via :class:`KeepKeyPlugin`."""
    hw_type = 'keepkey'
    device = 'KeepKey'

    plugin: 'KeepKeyPlugin'

    def get_client(self, force_pair=True):
        """Return the device client for this keystore (delegates to the plugin)."""
        return self.plugin.get_client(self, force_pair)

    def decrypt_message(self, sequence, message, password):
        """Not supported by this plugin.

        Raises:
            UserFacingException: always.
        """
        raise UserFacingException(_('Encryption and decryption are not implemented by {}').format(self.device))

    @runs_in_hwd_thread
    def sign_message(self, sequence, message, password):
        """Sign ``message`` on the device with the key at
        ``<derivation prefix>/<sequence[0]>/<sequence[1]>``.

        Returns:
            The ``signature`` field of the device's response.
        """
        client = self.get_client()
        address_path = self.get_derivation_prefix() + "/%d/%d" % sequence
        address_n = client.expand_path(address_path)
        msg_sig = client.sign_message(self.plugin.get_coin_name(), address_n, message)
        return msg_sig.signature

    @runs_in_hwd_thread
    def sign_transaction(self, tx, password):
        """Collect the previous transactions spent by ``tx`` and ask the plugin
        to sign it. Does nothing if ``tx`` is already complete.

        Raises:
            UserFacingException: if a non-segwit input has no previous tx (utxo).
        """
        if tx.is_complete():
            return
        # previous transactions used as inputs, keyed by hex txid
        prev_tx = {}
        for txin in tx.inputs():
            tx_hash = txin.prevout.txid.hex()
            # legacy inputs need the full previous tx; segwit inputs may omit it
            if txin.utxo is None and not txin.is_segwit():
                raise UserFacingException(_('Missing previous tx for legacy input.'))
            prev_tx[tx_hash] = txin.utxo

        self.plugin.sign_transaction(self, tx, prev_tx)


class KeepKeyPlugin(HW_PluginBase):
    """Electrum hardware-wallet plugin for KeepKey devices (HID and WebUSB)."""
    # Derived classes provide:
    #
    #  class-static variables: client_class, firmware_URL, handler_class,
    #     libraries_available, libraries_URL, minimum_firmware,
    #     wallet_class, ckd_public, types, HidTransport

    firmware_URL = 'https://www.keepkey.com'
    libraries_URL = 'https://github.com/keepkey/python-keepkey'
    minimum_firmware = (1, 0, 0)
    keystore_class = KeepKey_KeyStore
    SUPPORTED_XTYPES = ('standard', 'p2wpkh-p2sh', 'p2wpkh', 'p2wsh-p2sh', 'p2wsh')

    MAX_LABEL_LEN = 32

    def __init__(self, parent, config, name):
        """Load ``keepkeylib`` and register device discovery with the device
        manager. If any of the imports fails, ``libraries_available`` is set to
        False and nothing is registered."""
        HW_PluginBase.__init__(self, parent, config, name)

        try:
            from . import client
            import keepkeylib
            import keepkeylib.ckd_public
            import keepkeylib.transport_hid
            import keepkeylib.transport_webusb
            self.client_class = client.KeepKeyClient
            self.ckd_public = keepkeylib.ckd_public
            self.types = keepkeylib.client.types
            self.DEVICE_IDS = (keepkeylib.transport_hid.DEVICE_IDS +
                               keepkeylib.transport_webusb.DEVICE_IDS)
            # only "register" hid device id:
            self.device_manager().register_devices(keepkeylib.transport_hid.DEVICE_IDS, plugin=self)
            # for webusb transport, use custom enumerate function:
            self.device_manager().register_enumerate_func(self.enumerate)
            self.libraries_available = True
        except ImportError:
            self.libraries_available = False

    @runs_in_hwd_thread
    def enumerate(self):
        """List WebUSB-attached KeepKeys as :class:`Device` records whose path
        and id are the ``_dev_to_str`` form of the USB location."""
        from keepkeylib.transport_webusb import WebUsbTransport
        results = []
        for dev in WebUsbTransport.enumerate():
            path = self._dev_to_str(dev)
            results.append(Device(path=path,
                                  interface_number=-1,
                                  id_=path,
                                  product_key=(dev.getVendorID(), dev.getProductID()),
                                  usage_page=0,
                                  transport_ui_string=f"webusb:{path}"))
        return results

    @staticmethod
    def _dev_to_str(dev: "usb1.USBDevice") -> str:
        """Format a USB device location as ``"BBB:p1:p2:..."``: the bus number
        zero-padded to 3 digits followed by its port numbers."""
        return ":".join(str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList())

    @runs_in_hwd_thread
    def hid_transport(self, pair):
        """Open a keepkeylib HID transport for the given path pair."""
        from keepkeylib.transport_hid import HidTransport
        return HidTransport(pair)

    @runs_in_hwd_thread
    def webusb_transport(self, device):
        """Open a WebUSB transport for ``device``, matched by path; returns
        None if no enumerated WebUSB device has that path."""
        from keepkeylib.transport_webusb import WebUsbTransport
        for dev in WebUsbTransport.enumerate():
            if device.path == self._dev_to_str(dev):
                return WebUsbTransport(dev)

    @runs_in_hwd_thread
    def _try_hid(self, device):
        """Try to open an HID transport; log and return None on any failure."""
        self.logger.info("Trying to connect over USB...")
        # the path goes in the slot matching the interface number
        if device.interface_number == 1:
            pair = [None, device.path]
        else:
            pair = [device.path, None]

        try:
            return self.hid_transport(pair)
        except BaseException as e:
            # Deliberately swallowed instead of re-raised;
            # see fdb810ba622dc7dbe1259cbafb5b28e19d2ab114
            self.logger.info(f"cannot connect at {device.path} {e}")
            return None

    @runs_in_hwd_thread
    def _try_webusb(self, device):
        """Try to open a WebUSB transport; log and return None on any failure."""
        self.logger.info("Trying to connect over WebUSB...")
        try:
            return self.webusb_transport(device)
        except BaseException as e:
            self.logger.info(f"cannot connect at {device.path} {e}")
            return None

    @runs_in_hwd_thread
    def create_client(self, device, handler):
        """Connect to ``device`` and return a ping-checked client.

        Returns None if no transport can be opened, the ping fails, or the
        firmware is older than ``minimum_firmware`` (in which case the error is
        shown via ``handler``).

        Raises:
            UserFacingException: firmware too old and no ``handler`` to report to.
        """
        # NOTE(review): product id 2 is assumed to be the WebUSB KeepKey; confirm
        if device.product_key[1] == 2:
            transport = self._try_webusb(device)
        else:
            transport = self._try_hid(device)

        if not transport:
            self.logger.info("cannot connect to device")
            return

        self.logger.info(f"connected to device at {device.path}")

        client = self.client_class(transport, handler, self)

        # Try a ping for device sanity
        try:
            client.ping('t')
        except BaseException as e:
            self.logger.info(f"ping failed {e}")
            return None

        if not client.atleast_version(*self.minimum_firmware):
            msg = (_('Outdated {} firmware for device labelled {}. Please '
                     'download the updated firmware from {}')
                   .format(self.device, client.label(), self.firmware_URL))
            self.logger.info(msg)
            if handler:
                handler.show_error(msg)
            else:
                raise UserFacingException(msg)
            return None

        return client

    @runs_in_hwd_thread
    def get_client(self, keystore, force_pair=True, *,
                   devices=None, allow_user_interaction=True) -> Optional['KeepKeyClient']:
        """Return the client for ``keystore`` (or None), marking it as used."""
        client = super().get_client(keystore, force_pair,
                                    devices=devices,
                                    allow_user_interaction=allow_user_interaction)
        # returns the client for a given keystore. can use xpub
        if client:
            client.used()
        return client

    def get_coin_name(self):
        """Coin name passed to the device: "Testnet" on testnet, else "Bitcoin"."""
        return "Testnet" if constants.net.TESTNET else "Bitcoin"

    def initialize_device(self, device_id, wizard, handler):
        """Let the user pick an initialization method, then initialize the
        device in a background thread while the wizard's event loop runs.

        Raises:
            UserCancelled: if the wizard loop exits with a non-zero code, i.e.
                the device was not left initialized.
        """
        # Initialization method
        msg = _("Choose how you want to initialize your {}.\n\n"
                "The first two methods are secure as no secret information "
                "is entered into your computer.\n\n"
                "For the last two methods you input secrets on your keyboard "
                "and upload them to your {}, and so you should "
                "only do those on a computer you know to be trustworthy "
                "and free of malware."
        ).format(self.device, self.device)
        choices = [
            # Must be short as QT doesn't word-wrap radio button text
            (TIM_NEW, _("Let the device generate a completely new seed randomly")),
            (TIM_RECOVER, _("Recover from a seed you have previously written down")),
            (TIM_MNEMONIC, _("Upload a BIP39 mnemonic to generate the seed")),
            (TIM_PRIVKEY, _("Upload a master private key"))
        ]
        def f(method):
            import threading
            settings = self.request_trezor_init_settings(wizard, method, self.device)
            # `daemon=True` replaces Thread.setDaemon(), deprecated since Python 3.10
            t = threading.Thread(target=self._initialize_device_safe,
                                 args=(settings, method, device_id, wizard, handler),
                                 daemon=True)
            t.start()
            exit_code = wizard.loop.exec_()
            if exit_code != 0:
                # this method (initialize_device) was called with the expectation
                # of leaving the device in an initialized state when finishing.
                # signal that this is not the case:
                raise UserCancelled()
        wizard.choice_dialog(title=_('Initialize Device'), message=msg, choices=choices, run_next=f)

    def _initialize_device_safe(self, settings, method, device_id, wizard, handler):
        """Thread target: run ``_initialize_device`` and always exit the wizard
        loop, with code 0 on success and 1 on cancellation or error (errors are
        logged and shown via ``handler``)."""
        exit_code = 0
        try:
            self._initialize_device(settings, method, device_id, wizard, handler)
        except UserCancelled:
            exit_code = 1
        except BaseException as e:
            self.logger.exception('')
            handler.show_error(repr(e))
            exit_code = 1
        finally:
            wizard.loop.exit(exit_code)

    @runs_in_hwd_thread
    def _initialize_device(self, settings, method, device_id, wizard, handler):
        """Perform the chosen initialization method on the device.

        ``settings`` is ``(item, label, pin_protection, passphrase_protection)``;
        ``item`` is a strength/word-count index for TIM_NEW/TIM_RECOVER, the
        mnemonic for TIM_MNEMONIC, or the xprv otherwise.

        Raises:
            Exception: if the device is no longer connected.
        """
        item, label, pin_protection, passphrase_protection = settings

        language = 'english'
        devmgr = self.device_manager()
        client = devmgr.client_by_id(device_id)
        if not client:
            raise Exception(_("The device was disconnected."))

        if method == TIM_NEW:
            strength = 64 * (item + 2)  # 128, 192 or 256
            client.reset_device(True, strength, passphrase_protection,
                                pin_protection, label, language)
        elif method == TIM_RECOVER:
            word_count = 6 * (item + 2)  # 12, 18 or 24
            client.step = 0
            client.recovery_device(word_count, passphrase_protection,
                                   pin_protection, label, language)
        elif method == TIM_MNEMONIC:
            pin = pin_protection  # It's the pin, not a boolean
            client.load_device_by_mnemonic(str(item), pin,
                                           passphrase_protection,
                                           label, language)
        else:
            pin = pin_protection  # It's the pin, not a boolean
            client.load_device_by_xprv(item, pin, passphrase_protection,
                                       label, language)

    def _make_node_path(self, xpub, address_n):
        """Build an ``HDNodePathType`` from an xpub plus a relative path."""
        bip32node = BIP32Node.from_xkey(xpub)
        node = self.types.HDNodeType(
            depth=bip32node.depth,
            fingerprint=int.from_bytes(bip32node.fingerprint, 'big'),
            child_num=int.from_bytes(bip32node.child_number, 'big'),
            chain_code=bip32node.chaincode,
            public_key=bip32node.eckey.get_public_key_bytes(compressed=True),
        )
        return self.types.HDNodePathType(node=node, address_n=address_n)

    def setup_device(self, device_info, wizard, purpose):
        """Create a client for the device, initializing the device first if
        needed, then fetch the root xpub without blocking the GUI."""
        device_id = device_info.device.id_
        client = self.scan_and_create_client_for_device(device_id=device_id, wizard=wizard)
        if not device_info.initialized:
            self.initialize_device(device_id, wizard, client.handler)
        wizard.run_task_without_blocking_gui(
            task=lambda: client.get_xpub("m", 'standard'))
        client.used()
        return client

    def get_xpub(self, device_id, derivation, xtype, wizard):
        """Return the xpub at ``derivation`` for script type ``xtype``.

        Raises:
            ScriptTypeNotSupported: if ``xtype`` is not in SUPPORTED_XTYPES.
        """
        if xtype not in self.SUPPORTED_XTYPES:
            raise ScriptTypeNotSupported(_('This type of script is not supported with {}.').format(self.device))
        client = self.scan_and_create_client_for_device(device_id=device_id, wizard=wizard)
        xpub = client.get_xpub(derivation, xtype)
        client.used()
        return xpub

    def get_keepkey_input_script_type(self, electrum_txin_type: str):
        """Map an Electrum txin type to the keepkeylib input script type.

        Raises:
            ValueError: for an unrecognized type.
        """
        if electrum_txin_type in ('p2wpkh', 'p2wsh'):
            return self.types.SPENDWITNESS
        if electrum_txin_type in ('p2wpkh-p2sh', 'p2wsh-p2sh'):
            return self.types.SPENDP2SHWITNESS
        if electrum_txin_type in ('p2pkh', ):
            return self.types.SPENDADDRESS
        if electrum_txin_type in ('p2sh', ):
            return self.types.SPENDMULTISIG
        raise ValueError('unexpected txin type: {}'.format(electrum_txin_type))

    def get_keepkey_output_script_type(self, electrum_txin_type: str):
        """Map an Electrum txin type to the keepkeylib output script type.

        Raises:
            ValueError: for an unrecognized type.
        """
        if electrum_txin_type in ('p2wpkh', 'p2wsh'):
            return self.types.PAYTOWITNESS
        if electrum_txin_type in ('p2wpkh-p2sh', 'p2wsh-p2sh'):
            return self.types.PAYTOP2SHWITNESS
        if electrum_txin_type in ('p2pkh', ):
            return self.types.PAYTOADDRESS
        if electrum_txin_type in ('p2sh', ):
            return self.types.PAYTOMULTISIG
        raise ValueError('unexpected txin type: {}'.format(electrum_txin_type))

    @runs_in_hwd_thread
    def sign_transaction(self, keystore, tx: PartialTransaction, prev_tx):
        """Have the device sign ``tx`` and store the signatures in it.

        ``prev_tx`` (hex txid -> previous tx) is kept on ``self`` so that
        ``get_tx`` can serve it to the library during signing.
        """
        self.prev_tx = prev_tx
        client = self.get_client(keystore)
        inputs = self.tx_inputs(tx, for_sig=True, keystore=keystore)
        outputs = self.tx_outputs(tx, keystore=keystore)
        signatures = client.sign_tx(self.get_coin_name(), inputs, outputs,
                                    lock_time=tx.locktime, version=tx.version)[0]
        # append the sighash byte 0x01 (SIGHASH_ALL) to each hex signature
        signatures = [(bh2u(x) + '01') for x in signatures]
        tx.update_signatures(signatures)

    @runs_in_hwd_thread
    def show_address(self, wallet, address, keystore=None):
        """Display ``address`` on the device screen (firmware >= 1.3), including
        multisig data when the wallet has more than one xpub."""
        if keystore is None:
            keystore = wallet.get_keystore()
        if not self.show_address_helper(wallet, address, keystore):
            return
        client = self.get_client(keystore)
        if not client.atleast_version(1, 3):
            keystore.handler.show_error(_("Your device firmware is too old"))
            return
        deriv_suffix = wallet.get_address_index(address)
        derivation = keystore.get_derivation_prefix()
        address_path = "%s/%d/%d" % (derivation, *deriv_suffix)
        address_n = client.expand_path(address_path)
        script_type = self.get_keepkey_input_script_type(wallet.txin_type)

        # prepare multisig, if available:
        xpubs = wallet.get_master_public_keys()
        if len(xpubs) > 1:
            pubkeys = wallet.get_public_keys(address)
            # sort xpubs using the order of pubkeys
            sorted_pairs = sorted(zip(pubkeys, xpubs))
            multisig = self._make_multisig(
                wallet.m,
                [(xpub, deriv_suffix) for pubkey, xpub in sorted_pairs])
        else:
            multisig = None

        client.get_address(self.get_coin_name(), address_n, True, multisig=multisig, script_type=script_type)

    def tx_inputs(self, tx: Transaction, *, for_sig=False, keystore: 'KeepKey_KeyStore' = None):
        """Convert the inputs of ``tx`` into keepkeylib ``TxInputType`` messages.

        With ``for_sig=True`` (requires a PartialTransaction and ``keystore``),
        each non-coinbase input also carries its script type, optional
        multisig data and the keystore's derivation path.
        """
        inputs = []
        for txin in tx.inputs():
            txinputtype = self.types.TxInputType()
            if txin.is_coinbase_input():
                prev_hash = b"\x00"*32
                prev_index = 0xffffffff  # signed int -1
            else:
                if for_sig:
                    assert isinstance(tx, PartialTransaction)
                    assert isinstance(txin, PartialTxInput)
                    assert keystore
                    if len(txin.pubkeys) > 1:
                        xpubs_and_deriv_suffixes = get_xpubs_and_der_suffixes_from_txinout(tx, txin)
                        multisig = self._make_multisig(txin.num_sig, xpubs_and_deriv_suffixes)
                    else:
                        multisig = None
                    script_type = self.get_keepkey_input_script_type(txin.script_type)
                    txinputtype = self.types.TxInputType(
                        script_type=script_type,
                        multisig=multisig)
                    _, full_path = keystore.find_my_pubkey_in_txinout(txin)
                    if full_path:
                        txinputtype.address_n.extend(full_path)

                prev_hash = txin.prevout.txid
                prev_index = txin.prevout.out_idx

            if txin.value_sats() is not None:
                txinputtype.amount = txin.value_sats()
            txinputtype.prev_hash = prev_hash
            txinputtype.prev_index = prev_index

            if txin.script_sig is not None:
                txinputtype.script_sig = txin.script_sig

            txinputtype.sequence = txin.nsequence

            inputs.append(txinputtype)

        return inputs

    def _make_multisig(self, m, xpubs):
        """Build an m-of-n ``MultisigRedeemScriptType`` from (xpub, suffix)
        pairs, or return None when there is only one xpub."""
        if len(xpubs) == 1:
            return None
        pubkeys = [self._make_node_path(xpub, deriv) for xpub, deriv in xpubs]
        return self.types.MultisigRedeemScriptType(
            pubkeys=pubkeys,
            signatures=[b''] * len(pubkeys),
            m=m)

    def tx_outputs(self, tx: PartialTransaction, *, keystore: 'KeepKey_KeyStore'):
        """Convert the outputs of ``tx`` into keepkeylib ``TxOutputType`` messages.

        At most one of our own outputs is described by derivation path, so the
        device can treat it as change. Outputs on the change branch are preferred
        when any exist. All other outputs are described by address, or as
        OP_RETURN data when they have no address.
        """

        def create_output_by_derivation(txout: PartialTxOutput):
            # Takes `txout` explicitly rather than closing over the loop variable.
            script_type = self.get_keepkey_output_script_type(txout.script_type)
            if len(txout.pubkeys) > 1:
                xpubs_and_deriv_suffixes = get_xpubs_and_der_suffixes_from_txinout(tx, txout)
                multisig = self._make_multisig(txout.num_sig, xpubs_and_deriv_suffixes)
            else:
                multisig = None
            _, full_path = keystore.find_my_pubkey_in_txinout(txout)
            assert full_path
            txoutputtype = self.types.TxOutputType(
                multisig=multisig,
                amount=txout.value,
                address_n=full_path,
                script_type=script_type)
            return txoutputtype

        def create_output_by_address(txout: PartialTxOutput):
            address = txout.address
            txoutputtype = self.types.TxOutputType()
            txoutputtype.amount = txout.value
            if address:
                txoutputtype.script_type = self.types.PAYTOADDRESS
                txoutputtype.address = address
            else:
                txoutputtype.script_type = self.types.PAYTOOPRETURN
                txoutputtype.op_return_data = trezor_validate_op_return_output_and_get_data(txout)
            return txoutputtype

        outputs = []
        has_change = False
        any_output_on_change_branch = is_any_tx_output_on_change_branch(tx)

        for txout in tx.outputs():
            use_create_by_derivation = False

            if txout.is_mine and not has_change:
                # prioritise hiding outputs on the 'change' branch from user
                # because no more than one change address allowed
                if txout.is_change == any_output_on_change_branch:
                    use_create_by_derivation = True
                    has_change = True

            if use_create_by_derivation:
                txoutputtype = create_output_by_derivation(txout)
            else:
                txoutputtype = create_output_by_address(txout)
            outputs.append(txoutputtype)

        return outputs

    def electrum_tx_to_txtype(self, tx: Optional[Transaction]):
        """Convert a previous transaction into a keepkeylib ``TransactionType``.

        An empty ``TransactionType`` is returned when ``tx`` is None.
        """
        t = self.types.TransactionType()
        if tx is None:
            # probably for segwit input and we don't need this prev txn
            return t
        tx.deserialize()
        t.version = tx.version
        t.lock_time = tx.locktime
        inputs = self.tx_inputs(tx)
        t.inputs.extend(inputs)
        for out in tx.outputs():
            o = t.bin_outputs.add()
            o.amount = out.value
            o.script_pubkey = out.scriptpubkey
        return t

    # This function is called from the TREZOR libraries (via tx_api)
    def get_tx(self, tx_hash):
        """Return the previous tx ``tx_hash`` (as stored by ``sign_transaction``)
        converted to a ``TransactionType``."""
        tx = self.prev_tx[tx_hash]
        return self.electrum_tx_to_txtype(tx)