electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 869a72831781002546758e906da1c157d967186b
parent f08796fe68955e7d2c6381b30a64873e74a3f464
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue, 10 Dec 2019 19:34:44 +0100

wallet: use abstract base classes

Diffstat:
Melectrum/json_db.py | 6+++---
Melectrum/keystore.py | 2++
Melectrum/wallet.py | 141++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------------
3 files changed, 95 insertions(+), 54 deletions(-)

diff --git a/electrum/json_db.py b/electrum/json_db.py @@ -860,11 +860,11 @@ class JsonDB(Logger): self.imported_addresses.pop(addr) @locked - def has_imported_address(self, addr): + def has_imported_address(self, addr: str) -> bool: return addr in self.imported_addresses @locked - def get_imported_addresses(self): + def get_imported_addresses(self) -> Sequence[str]: return list(sorted(self.imported_addresses.keys())) @locked @@ -874,7 +874,7 @@ class JsonDB(Logger): def load_addresses(self, wallet_type): """ called from Abstract_Wallet.__init__ """ if wallet_type == 'imported': - self.imported_addresses = self.get_data_ref('addresses') + self.imported_addresses = self.get_data_ref('addresses') # type: Dict[str, dict] else: self.get_data_ref('addresses') for name in ['receiving', 'change']: diff --git a/electrum/keystore.py b/electrum/keystore.py @@ -624,6 +624,8 @@ class Old_KeyStore(MasterPublicKeyMixin, Deterministic_KeyStore): return public_key.get_public_key_hex(compressed=False) def derive_pubkey(self, for_change, n) -> str: + for_change = int(for_change) + assert for_change in (0, 1) return self.get_pubkey_from_mpk(self.mpk, for_change, n) def _get_private_key_from_stretched_exponent(self, for_change, n, secexp): diff --git a/electrum/wallet.py b/electrum/wallet.py @@ -22,9 +22,9 @@ # SOFTWARE. # Wallet classes: -# - Imported_Wallet: imported address, no keystore -# - Standard_Wallet: one keystore, P2PKH -# - Multisig_Wallet: several keystores, P2SH +# - Imported_Wallet: imported addresses or single keys, 0 or 1 keystore +# - Standard_Wallet: one HD keystore, P2PKH-like scripts +# - Multisig_Wallet: several HD keystores, M-of-N OP_CHECKMULTISIG scripts import os import sys @@ -40,6 +40,8 @@ from collections import defaultdict from numbers import Number from decimal import Decimal from typing import TYPE_CHECKING, List, Optional, Tuple, Union, NamedTuple, Sequence, Dict, Any, Set +from abc import ABC, abstractmethod +import itertools from .i18n import _ from .bip32 import BIP32Node @@ -210,7 +212,7 @@ class TxWalletDetails(NamedTuple): mempool_depth_bytes: Optional[int] -class Abstract_Wallet(AddressSynchronizer): +class Abstract_Wallet(AddressSynchronizer, ABC): """ Wallet classes are created to handle various address generation methods. Completion states (watching-only, single account, no seed, etc) are handled inside classes. @@ -314,8 +316,9 @@ class Abstract_Wallet(AddressSynchronizer): self.test_addresses_sanity() super().load_and_cleanup() + @abstractmethod def load_keystore(self) -> None: - raise NotImplementedError() # implemented by subclasses + pass def diagnostic_name(self): return self.basename() @@ -332,7 +335,7 @@ class Abstract_Wallet(AddressSynchronizer): def basename(self) -> str: return os.path.basename(self.storage.path) - def test_addresses_sanity(self): + def test_addresses_sanity(self) -> None: addrs = self.get_receiving_addresses() if len(addrs) > 0: addr = str(addrs[0]) @@ -350,7 +353,7 @@ class Abstract_Wallet(AddressSynchronizer): self._unused_change_addresses = [addr for addr in addrs if not self.is_used(addr)] return list(self._unused_change_addresses) - def is_deterministic(self): + def is_deterministic(self) -> bool: return self.keystore.is_deterministic() def set_label(self, name, text = None): @@ -417,26 +420,22 @@ class Abstract_Wallet(AddressSynchronizer): return False return self.get_address_index(address)[0] == 1 + @abstractmethod def get_address_index(self, address): - raise NotImplementedError() + pass + @abstractmethod def get_redeem_script(self, address: str) -> Optional[str]: - txin_type = self.get_txin_type(address) - if txin_type in ('p2pkh', 'p2wpkh', 'p2pk'): - return None - if txin_type == 'p2wpkh-p2sh': - pubkey = self.get_public_key(address) - return bitcoin.p2wpkh_nested_script(pubkey) - if txin_type == 'address': - return None - raise UnknownTxinType(f'unexpected txin_type {txin_type}') + pass + @abstractmethod def get_witness_script(self, address: str) -> Optional[str]: - return None + pass + @abstractmethod def get_txin_type(self, address: str) -> str: """Return script type of wallet address.""" - raise NotImplementedError() + pass def export_private_key(self, address, password) -> str: if self.is_watching_only(): @@ -451,17 +450,14 @@ class Abstract_Wallet(AddressSynchronizer): serialized_privkey = bitcoin.serialize_privkey(pk, compressed, txin_type) return serialized_privkey - def get_public_keys(self, address): - return [self.get_public_key(address)] + @abstractmethod + def get_public_keys(self, address: str) -> Sequence[str]: + pass def get_public_keys_with_deriv_info(self, address: str) -> Dict[str, Tuple[KeyStoreWithMPK, Sequence[int]]]: """Returns a map: pubkey_hex -> (keystore, derivation_suffix)""" return {} - def is_found(self): - return True - #return self.history.values() != [[]] * len(self.history) - def get_tx_info(self, tx) -> TxWalletDetails: is_relevant, is_mine, v, fee = self.get_wallet_delta(tx) if fee is None and isinstance(tx, PartialTransaction): @@ -536,11 +532,13 @@ class Abstract_Wallet(AddressSynchronizer): utxos = [utxo for utxo in utxos if not self.is_frozen_coin(utxo)] return utxos + @abstractmethod def get_receiving_addresses(self, *, slice_start=None, slice_stop=None) -> Sequence[str]: - raise NotImplementedError() # implemented by subclasses + pass + @abstractmethod def get_change_addresses(self, *, slice_start=None, slice_stop=None) -> Sequence[str]: - raise NotImplementedError() # implemented by subclasses + pass def dummy_address(self): # first receiving address @@ -1304,8 +1302,9 @@ class Abstract_Wallet(AddressSynchronizer): locktime = get_locktime_for_new_transaction(self.network) return PartialTransaction.from_io(inputs, outputs, locktime=locktime) + @abstractmethod def _add_input_sig_info(self, txin: PartialTxInput, address: str, *, only_der_suffix: bool = True) -> None: - raise NotImplementedError() # implemented by subclasses + pass def _add_txinout_derivation_info(self, txinout: Union[PartialTxInput, PartialTxOutput], address: str, *, only_der_suffix: bool = True) -> None: @@ -1439,10 +1438,10 @@ class Abstract_Wallet(AddressSynchronizer): tx.add_info_from_wallet(self, include_xpubs_and_full_paths=False) return tx - def try_detecting_internal_addresses_corruption(self): + def try_detecting_internal_addresses_corruption(self) -> None: pass - def check_address(self, addr): + def check_address(self, addr: str) -> None: pass def check_returned_address(func): @@ -1479,7 +1478,7 @@ class Abstract_Wallet(AddressSynchronizer): choice = addr return choice - def create_new_address(self, for_change=False): + def create_new_address(self, for_change: bool = False): raise Exception("this wallet cannot generate new addresses") def get_payment_status(self, address, amount): @@ -1650,8 +1649,9 @@ class Abstract_Wallet(AddressSynchronizer): out.sort(key=operator.itemgetter('time')) return out + @abstractmethod def get_fingerprint(self): - raise NotImplementedError() + pass def can_import_privkey(self): return False @@ -1722,15 +1722,23 @@ class Abstract_Wallet(AddressSynchronizer): self.storage.set_keystore_encryption(bool(new_pw) and encrypt_keystore) self.storage.write() + @abstractmethod + def _update_password_for_keystore(self, old_pw: Optional[str], new_pw: Optional[str]) -> None: + pass + def sign_message(self, address, message, password): index = self.get_address_index(address) return self.keystore.sign_message(index, message, password) def decrypt_message(self, pubkey, message, password) -> bytes: - addr = self.pubkeys_to_address(pubkey) + addr = self.pubkeys_to_address([pubkey]) index = self.get_address_index(addr) return self.keystore.decrypt_message(index, message, password) + @abstractmethod + def pubkeys_to_address(self, pubkeys: Sequence[str]) -> Optional[str]: + pass + def txin_value(self, txin: TxInput) -> Optional[int]: if isinstance(txin, PartialTxInput): v = txin.value_sats() @@ -1799,8 +1807,9 @@ class Abstract_Wallet(AddressSynchronizer): # overridden for TrustedCoin wallets return False + @abstractmethod def is_watching_only(self) -> bool: - raise NotImplementedError() + pass def get_keystore(self) -> Optional[KeyStore]: return self.keystore @@ -1808,14 +1817,17 @@ class Abstract_Wallet(AddressSynchronizer): def get_keystores(self) -> Sequence[KeyStore]: return [self.keystore] if self.keystore else [] + @abstractmethod def save_keystore(self): - raise NotImplementedError() + pass + @abstractmethod def has_seed(self) -> bool: - raise NotImplementedError() + pass + @abstractmethod def is_beyond_limit(self, address: str) -> bool: - raise NotImplementedError() + pass class Simple_Wallet(Abstract_Wallet): @@ -1832,6 +1844,27 @@ class Simple_Wallet(Abstract_Wallet): def save_keystore(self): self.storage.put('keystore', self.keystore.dump()) + @abstractmethod + def get_public_key(self, address: str) -> Optional[str]: + pass + + def get_public_keys(self, address: str) -> Sequence[str]: + return [self.get_public_key(address)] + + def get_redeem_script(self, address: str) -> Optional[str]: + txin_type = self.get_txin_type(address) + if txin_type in ('p2pkh', 'p2wpkh', 'p2pk'): + return None + if txin_type == 'p2wpkh-p2sh': + pubkey = self.get_public_key(address) + return bitcoin.p2wpkh_nested_script(pubkey) + if txin_type == 'address': + return None + raise UnknownTxinType(f'unexpected txin_type {txin_type}') + + def get_witness_script(self, address: str) -> Optional[str]: + return None + class Imported_Wallet(Simple_Wallet): # wallet made of imported addresses @@ -2005,10 +2038,13 @@ class Imported_Wallet(Simple_Wallet): raise Exception(f'Unexpected script type: {txin.script_type}. ' f'Imported wallets are not implemented to handle this.') - def pubkeys_to_address(self, pubkey): + def pubkeys_to_address(self, pubkeys): + pubkey = pubkeys[0] for addr in self.db.get_imported_addresses(): if self.db.get_imported_address(addr)['pubkey'] == pubkey: return addr + return None + class Deterministic_Wallet(Abstract_Wallet): @@ -2047,7 +2083,7 @@ class Deterministic_Wallet(Abstract_Wallet): # sample2: a few more randomly selected addresses_rand = addresses_all[10:] addresses_sample2 = random.sample(addresses_rand, min(len(addresses_rand), 10)) - for addr_found in addresses_sample1 + addresses_sample2: + for addr_found in itertools.chain(addresses_sample1, addresses_sample2): self.check_address(addr_found) def check_address(self, addr): @@ -2058,9 +2094,6 @@ class Deterministic_Wallet(Abstract_Wallet): def get_seed(self, password): return self.keystore.get_seed(password) - def add_seed(self, seed, pw): - self.keystore.add_seed(seed, pw) - def change_gap_limit(self, value): '''This method is not called in the code, it is kept for console use''' if value >= self.min_acceptable_gap(): @@ -2093,9 +2126,14 @@ class Deterministic_Wallet(Abstract_Wallet): nmax = max(nmax, n) return nmax + 1 - def derive_address(self, for_change, n): - x = self.derive_pubkeys(for_change, n) - return self.pubkeys_to_address(x) + @abstractmethod + def derive_pubkeys(self, c: int, i: int) -> Sequence[str]: + pass + + def derive_address(self, for_change: int, n: int) -> str: + for_change = int(for_change) + pubkeys = self.derive_pubkeys(for_change, n) + return self.pubkeys_to_address(pubkeys) def get_public_keys_with_deriv_info(self, address: str): der_suffix = self.get_address_index(address) @@ -2117,11 +2155,11 @@ class Deterministic_Wallet(Abstract_Wallet): only_der_suffix=only_der_suffix) txinout.bip32_paths[bfh(pubkey_hex)] = (fp_bytes, der_full) - def create_new_address(self, for_change=False): + def create_new_address(self, for_change: bool = False): assert type(for_change) is bool with self.lock: n = self.db.num_change_addresses() if for_change else self.db.num_receiving_addresses() - address = self.derive_address(for_change, n) + address = self.derive_address(int(for_change), n) self.db.add_change_address(address) if for_change else self.db.add_receiving_address(address) self.add_address(address) if for_change: @@ -2197,8 +2235,8 @@ class Simple_Deterministic_Wallet(Simple_Wallet, Deterministic_Wallet): def get_public_key(self, address): sequence = self.get_address_index(address) - pubkey = self.derive_pubkeys(*sequence) - return pubkey + pubkeys = self.derive_pubkeys(*sequence) + return pubkeys[0] def load_keystore(self): self.keystore = load_keystore(self.storage, 'keystore') @@ -2212,7 +2250,7 @@ class Simple_Deterministic_Wallet(Simple_Wallet, Deterministic_Wallet): return self.keystore.get_master_public_key() def derive_pubkeys(self, c, i): - return self.keystore.derive_pubkey(c, i) + return [self.keystore.derive_pubkey(c, i)] @@ -2222,7 +2260,8 @@ class Simple_Deterministic_Wallet(Simple_Wallet, Deterministic_Wallet): class Standard_Wallet(Simple_Deterministic_Wallet): wallet_type = 'standard' - def pubkeys_to_address(self, pubkey): + def pubkeys_to_address(self, pubkeys): + pubkey = pubkeys[0] return bitcoin.pubkey_to_address(self.txin_type, pubkey)