electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 54257cbcca5a36e0041eb4dabe04e783e94ade4b
parent fa5302bcfb695490c56ace8208fdebc7b1f04953
Author: ThomasV <thomasv@electrum.org>
Date:   Thu, 15 Aug 2019 13:17:16 +0200

Rewrite JsonRPC requests using asyncio.
 - commands are async
 - the asyncio loop is started and stopped from the main script
 - the daemon's main loop runs in the main thread
 - use jsonrpcserver and jsonrpcclient instead of jsonrpclib

Diffstat:
Melectrum/commands.py | 189++++++++++++++++++++++++++++++++++++++++---------------------------------------
Melectrum/daemon.py | 193+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
Delectrum/jsonrpc.py | 126-------------------------------------------------------------------------------
Melectrum/lnworker.py | 24+++++++++++++-----------
Melectrum/simple_config.py | 2+-
Melectrum/tests/regtest/regtest.sh | 11+++++++++--
Melectrum/tests/test_commands.py | 51+++++++++++++++++++++++++++++++++++----------------
Melectrum/tests/test_lnpeer.py | 3+++
Mrun_electrum | 82++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
9 files changed, 322 insertions(+), 359 deletions(-)

diff --git a/electrum/commands.py b/electrum/commands.py @@ -31,6 +31,7 @@ import json import ast import base64 import operator +import asyncio from functools import wraps from decimal import Decimal from typing import Optional, TYPE_CHECKING @@ -112,8 +113,8 @@ class Commands: self._callback = callback self.lnworker = self.wallet.lnworker if self.wallet else None - def _run(self, method, args, password_getter, **kwargs): - """This wrapper is called from the Qt python console.""" + def _run(self, method, args, password_getter=None, **kwargs): + """This wrapper is called from unit tests and the Qt python console.""" cmd = known_commands[method] password = kwargs.get('password', None) if (cmd.requires_password and self.wallet.has_password() @@ -125,19 +126,22 @@ class Commands: f = getattr(self, method) if cmd.requires_password: kwargs['password'] = password - result = f(*args, **kwargs) + + coro = f(*args, **kwargs) + fut = asyncio.run_coroutine_threadsafe(coro, asyncio.get_event_loop()) + result = fut.result() if self._callback: self._callback() return result @command('') - def commands(self): + async def commands(self): """List of commands""" return ' '.join(sorted(known_commands.keys())) @command('') - def create(self, passphrase=None, password=None, encrypt_file=True, seed_type=None): + async def create(self, passphrase=None, password=None, encrypt_file=True, seed_type=None): """Create a new wallet. If you want to be prompted for an argument, type '?' or ':' (concealed) """ @@ -153,7 +157,7 @@ class Commands: } @command('') - def restore(self, text, passphrase=None, password=None, encrypt_file=True): + async def restore(self, text, passphrase=None, password=None, encrypt_file=True): """Restore a wallet from text. Text can be a seed phrase, a master public key, a master private key, a list of bitcoin addresses or bitcoin private keys. @@ -171,7 +175,7 @@ class Commands: } @command('wp') - def password(self, password=None, new_password=None): + async def password(self, password=None, new_password=None): """Change wallet password. """ if self.wallet.storage.is_encrypted_with_hw_device() and new_password: raise Exception("Can't change the password of a wallet encrypted with a hw device.") @@ -181,12 +185,12 @@ class Commands: return {'password':self.wallet.has_password()} @command('w') - def get(self, key): + async def get(self, key): """Return item from wallet storage""" return self.wallet.storage.get(key) @command('') - def getconfig(self, key): + async def getconfig(self, key): """Return a configuration variable. """ return self.config.get(key) @@ -201,29 +205,29 @@ class Commands: return value @command('') - def setconfig(self, key, value): + async def setconfig(self, key, value): """Set a configuration variable. 'value' may be a string or a Python expression.""" value = self._setconfig_normalize_value(key, value) self.config.set_key(key, value) return True @command('') - def make_seed(self, nbits=132, language=None, seed_type=None): + async def make_seed(self, nbits=132, language=None, seed_type=None): """Create a seed""" from .mnemonic import Mnemonic s = Mnemonic(language).make_seed(seed_type, num_bits=nbits) return s @command('n') - def getaddresshistory(self, address): + async def getaddresshistory(self, address): """Return the transaction history of any address. Note: This is a walletless server query, results are not checked by SPV. """ sh = bitcoin.address_to_scripthash(address) - return self.network.run_from_another_thread(self.network.get_history_for_scripthash(sh)) + return await self.network.get_history_for_scripthash(sh) @command('w') - def listunspent(self): + async def listunspent(self): """List unspent outputs. Returns the list of unspent transaction outputs in your wallet.""" l = copy.deepcopy(self.wallet.get_utxos()) @@ -233,15 +237,15 @@ class Commands: return l @command('n') - def getaddressunspent(self, address): + async def getaddressunspent(self, address): """Returns the UTXO list of any address. Note: This is a walletless server query, results are not checked by SPV. """ sh = bitcoin.address_to_scripthash(address) - return self.network.run_from_another_thread(self.network.listunspent_for_scripthash(sh)) + return await self.network.listunspent_for_scripthash(sh) @command('') - def serialize(self, jsontx): + async def serialize(self, jsontx): """Create a transaction from json inputs. Inputs must have a redeemPubkey. Outputs must be a list of {'address':address, 'value':satoshi_amount}. @@ -271,7 +275,7 @@ class Commands: return tx.as_dict() @command('wp') - def signtransaction(self, tx, privkey=None, password=None): + async def signtransaction(self, tx, privkey=None, password=None): """Sign a transaction. The wallet keys will be used unless a private key is provided.""" tx = Transaction(tx) if privkey: @@ -283,20 +287,20 @@ class Commands: return tx.as_dict() @command('') - def deserialize(self, tx): + async def deserialize(self, tx): """Deserialize a serialized transaction""" tx = Transaction(tx) return tx.deserialize(force_full_parse=True) @command('n') - def broadcast(self, tx): + async def broadcast(self, tx): """Broadcast a transaction to the network. """ tx = Transaction(tx) - self.network.run_from_another_thread(self.network.broadcast_transaction(tx)) + await self.network.broadcast_transaction(tx) return tx.txid() @command('') - def createmultisig(self, num, pubkeys): + async def createmultisig(self, num, pubkeys): """Create multisig address""" assert isinstance(pubkeys, list), (type(num), type(pubkeys)) redeem_script = multisig_script(pubkeys, num) @@ -304,17 +308,17 @@ class Commands: return {'address':address, 'redeemScript':redeem_script} @command('w') - def freeze(self, address): + async def freeze(self, address): """Freeze address. Freeze the funds at one of your wallet\'s addresses""" return self.wallet.set_frozen_state_of_addresses([address], True) @command('w') - def unfreeze(self, address): + async def unfreeze(self, address): """Unfreeze address. Unfreeze the funds at one of your wallet\'s address""" return self.wallet.set_frozen_state_of_addresses([address], False) @command('wp') - def getprivatekeys(self, address, password=None): + async def getprivatekeys(self, address, password=None): """Get private keys of addresses. You may pass a single wallet address, or a list of wallet addresses.""" if isinstance(address, str): address = address.strip() @@ -324,27 +328,27 @@ class Commands: return [self.wallet.export_private_key(address, password)[0] for address in domain] @command('w') - def ismine(self, address): + async def ismine(self, address): """Check if address is in wallet. Return true if and only address is in wallet""" return self.wallet.is_mine(address) @command('') - def dumpprivkeys(self): + async def dumpprivkeys(self): """Deprecated.""" return "This command is deprecated. Use a pipe instead: 'electrum listaddresses | electrum getprivatekeys - '" @command('') - def validateaddress(self, address): + async def validateaddress(self, address): """Check that an address is valid. """ return is_address(address) @command('w') - def getpubkeys(self, address): + async def getpubkeys(self, address): """Return the public keys for a wallet address. """ return self.wallet.get_public_keys(address) @command('w') - def getbalance(self): + async def getbalance(self): """Return the balance of your wallet. """ c, u, x = self.wallet.get_balance() l = self.lnworker.get_balance() if self.lnworker else None @@ -358,45 +362,45 @@ class Commands: return out @command('n') - def getaddressbalance(self, address): + async def getaddressbalance(self, address): """Return the balance of any address. Note: This is a walletless server query, results are not checked by SPV. """ sh = bitcoin.address_to_scripthash(address) - out = self.network.run_from_another_thread(self.network.get_balance_for_scripthash(sh)) + out = await self.network.get_balance_for_scripthash(sh) out["confirmed"] = str(Decimal(out["confirmed"])/COIN) out["unconfirmed"] = str(Decimal(out["unconfirmed"])/COIN) return out @command('n') - def getmerkle(self, txid, height): + async def getmerkle(self, txid, height): """Get Merkle branch of a transaction included in a block. Electrum uses this to verify transactions (Simple Payment Verification).""" - return self.network.run_from_another_thread(self.network.get_merkle_for_transaction(txid, int(height))) + return await self.network.get_merkle_for_transaction(txid, int(height)) @command('n') - def getservers(self): + async def getservers(self): """Return the list of available servers""" return self.network.get_servers() @command('') - def version(self): + async def version(self): """Return the version of Electrum.""" from .version import ELECTRUM_VERSION return ELECTRUM_VERSION @command('w') - def getmpk(self): + async def getmpk(self): """Get master public key. Return your wallet\'s master public key""" return self.wallet.get_master_public_key() @command('wp') - def getmasterprivate(self, password=None): + async def getmasterprivate(self, password=None): """Get master private key. Return your wallet\'s master private key""" return str(self.wallet.keystore.get_master_private_key(password)) @command('') - def convert_xkey(self, xkey, xtype): + async def convert_xkey(self, xkey, xtype): """Convert xtype of a master key. e.g. xpub -> ypub""" try: node = BIP32Node.from_xkey(xkey) @@ -405,13 +409,13 @@ class Commands: return node._replace(xtype=xtype).to_xkey() @command('wp') - def getseed(self, password=None): + async def getseed(self, password=None): """Get seed phrase. Print the generation seed of your wallet.""" s = self.wallet.get_seed(password) return s @command('wp') - def importprivkey(self, privkey, password=None): + async def importprivkey(self, privkey, password=None): """Import a private key.""" if not self.wallet.can_import_privkey(): return "Error: This type of wallet cannot import private keys. Try to create a new wallet with that key." @@ -431,7 +435,7 @@ class Commands: return out['address'] @command('n') - def sweep(self, privkey, destination, fee=None, nocheck=False, imax=100): + async def sweep(self, privkey, destination, fee=None, nocheck=False, imax=100): """Sweep private keys. Returns a transaction that spends UTXOs from privkey to a destination address. The transaction is not broadcasted.""" @@ -444,14 +448,14 @@ class Commands: return tx.as_dict() if tx else None @command('wp') - def signmessage(self, address, message, password=None): + async def signmessage(self, address, message, password=None): """Sign a message with a key. Use quotes if your message contains whitespaces""" sig = self.wallet.sign_message(address, message, password) return base64.b64encode(sig).decode('ascii') @command('') - def verifymessage(self, address, signature, message): + async def verifymessage(self, address, signature, message): """Verify a signature.""" sig = base64.b64decode(signature) message = util.to_bytes(message) @@ -480,7 +484,7 @@ class Commands: return tx @command('wp') - def payto(self, destination, amount, fee=None, from_addr=None, change_addr=None, nocheck=False, unsigned=False, rbf=None, password=None, locktime=None): + async def payto(self, destination, amount, fee=None, from_addr=None, change_addr=None, nocheck=False, unsigned=False, rbf=None, password=None, locktime=None): """Create a transaction. """ tx_fee = satoshis(fee) domain = from_addr.split(',') if from_addr else None @@ -488,7 +492,7 @@ class Commands: return tx.as_dict() @command('wp') - def paytomany(self, outputs, fee=None, from_addr=None, change_addr=None, nocheck=False, unsigned=False, rbf=None, password=None, locktime=None): + async def paytomany(self, outputs, fee=None, from_addr=None, change_addr=None, nocheck=False, unsigned=False, rbf=None, password=None, locktime=None): """Create a multi-output transaction. """ tx_fee = satoshis(fee) domain = from_addr.split(',') if from_addr else None @@ -496,7 +500,7 @@ class Commands: return tx.as_dict() @command('w') - def onchain_history(self, year=None, show_addresses=False, show_fiat=False, show_fees=False): + async def onchain_history(self, year=None, show_addresses=False, show_fiat=False, show_fees=False): """Wallet onchain history. Returns the transaction history of your wallet.""" kwargs = { 'show_addresses': show_addresses, @@ -515,29 +519,29 @@ class Commands: return json_encode(self.wallet.get_detailed_history(**kwargs)) @command('w') - def lightning_history(self, show_fiat=False): + async def lightning_history(self, show_fiat=False): """ lightning history """ lightning_history = self.wallet.lnworker.get_history() if self.wallet.lnworker else [] return json_encode(lightning_history) @command('w') - def setlabel(self, key, label): + async def setlabel(self, key, label): """Assign a label to an item. Item may be a bitcoin address or a transaction ID""" self.wallet.set_label(key, label) @command('w') - def listcontacts(self): + async def listcontacts(self): """Show your list of contacts""" return self.wallet.contacts @command('w') - def getalias(self, key): + async def getalias(self, key): """Retrieve alias. Lookup in your list of contacts, and for an OpenAlias DNS record.""" return self.wallet.contacts.resolve(key) @command('w') - def searchcontacts(self, query): + async def searchcontacts(self, query): """Search through contacts, return matching entries. """ results = {} for key, value in self.wallet.contacts.items(): @@ -546,7 +550,7 @@ class Commands: return results @command('w') - def listaddresses(self, receiving=False, change=False, labels=False, frozen=False, unused=False, funded=False, balance=False): + async def listaddresses(self, receiving=False, change=False, labels=False, frozen=False, unused=False, funded=False, balance=False): """List wallet addresses. Returns the list of all addresses in your wallet. Use optional arguments to filter the results.""" out = [] for addr in self.wallet.get_addresses(): @@ -571,13 +575,13 @@ class Commands: return out @command('n') - def gettransaction(self, txid): + async def gettransaction(self, txid): """Retrieve a transaction. """ tx = None if self.wallet: tx = self.wallet.db.get_transaction(txid) if tx is None: - raw = self.network.run_from_another_thread(self.network.get_transaction(txid)) + raw = await self.network.get_transaction(txid) if raw: tx = Transaction(raw) else: @@ -585,7 +589,7 @@ class Commands: return tx.as_dict() @command('') - def encrypt(self, pubkey, message) -> str: + async def encrypt(self, pubkey, message) -> str: """Encrypt a message with a public key. Use quotes if the message contains whitespaces.""" if not is_hex_str(pubkey): raise Exception(f"pubkey must be a hex string instead of {repr(pubkey)}") @@ -598,7 +602,7 @@ class Commands: return encrypted.decode('utf-8') @command('wp') - def decrypt(self, pubkey, encrypted, password=None) -> str: + async def decrypt(self, pubkey, encrypted, password=None) -> str: """Decrypt a message encrypted with a public key.""" if not is_hex_str(pubkey): raise Exception(f"pubkey must be a hex string instead of {repr(pubkey)}") @@ -619,7 +623,7 @@ class Commands: return out @command('w') - def getrequest(self, key): + async def getrequest(self, key): """Return a payment request""" r = self.wallet.get_payment_request(key, self.config) if not r: @@ -627,12 +631,12 @@ class Commands: return self._format_request(r) #@command('w') - #def ackrequest(self, serialized): + #async def ackrequest(self, serialized): # """<Not implemented>""" # pass @command('w') - def listrequests(self, pending=False, expired=False, paid=False): + async def listrequests(self, pending=False, expired=False, paid=False): """List the payment requests you made.""" out = self.wallet.get_sorted_requests(self.config) if pending: @@ -648,18 +652,18 @@ class Commands: return list(map(self._format_request, out)) @command('w') - def createnewaddress(self): + async def createnewaddress(self): """Create a new receiving address, beyond the gap limit of the wallet""" return self.wallet.create_new_address(False) @command('w') - def getunusedaddress(self): + async def getunusedaddress(self): """Returns the first unused address of the wallet, or None if all addresses are used. An address is considered as used if it has received a transaction, or if it is used in a payment request.""" return self.wallet.get_unused_address() @command('w') - def addrequest(self, amount, memo='', expiration=None, force=False): + async def addrequest(self, amount, memo='', expiration=None, force=False): """Create a payment request, using the first unused address of the wallet. The address will be considered as used after this operation. If no payment is received, the address will be considered as unused if the payment request is deleted from the wallet.""" @@ -677,7 +681,7 @@ class Commands: return self._format_request(out) @command('w') - def addtransaction(self, tx): + async def addtransaction(self, tx): """ Add a transaction to the wallet history """ tx = Transaction(tx) if not self.wallet.add_transaction(tx.txid(), tx): @@ -686,7 +690,7 @@ class Commands: return tx.txid() @command('wp') - def signrequest(self, address, password=None): + async def signrequest(self, address, password=None): "Sign payment request with an OpenAlias" alias = self.config.get('alias') if not alias: @@ -695,31 +699,31 @@ class Commands: self.wallet.sign_payment_request(address, alias, alias_addr, password) @command('w') - def rmrequest(self, address): + async def rmrequest(self, address): """Remove a payment request""" return self.wallet.remove_payment_request(address, self.config) @command('w') - def clearrequests(self): + async def clearrequests(self): """Remove all payment requests""" for k in list(self.wallet.receive_requests.keys()): self.wallet.remove_payment_request(k, self.config) @command('n') - def notify(self, address: str, URL: str): + async def notify(self, address: str, URL: str): """Watch an address. Every time the address changes, a http POST is sent to the URL.""" if not hasattr(self, "_notifier"): self._notifier = Notifier(self.network) - self.network.run_from_another_thread(self._notifier.start_watching_queue.put((address, URL))) + await self._notifier.start_watching_queue.put((address, URL)) return True @command('wn') - def is_synchronized(self): + async def is_synchronized(self): """ return wallet synchronization status """ return self.wallet.is_up_to_date() @command('n') - def getfeerate(self, fee_method=None, fee_level=None): + async def getfeerate(self, fee_method=None, fee_level=None): """Return current suggested fee rate (in sat/kvByte), according to config settings or supplied parameters. """ @@ -738,7 +742,7 @@ class Commands: return self.config.fee_per_kb(dyn=dyn, mempool=mempool, fee_level=fee_level) @command('w') - def removelocaltx(self, txid): + async def removelocaltx(self, txid): """Remove a 'local' transaction from the wallet, and its dependent transactions. """ @@ -755,7 +759,7 @@ class Commands: self.wallet.storage.write() @command('wn') - def get_tx_status(self, txid): + async def get_tx_status(self, txid): """Returns some information regarding the tx. For now, only confirmations. The transaction must be related to the wallet. """ @@ -768,58 +772,57 @@ class Commands: } @command('') - def help(self): + async def help(self): # for the python console return sorted(known_commands.keys()) # lightning network commands @command('wn') - def add_peer(self, connection_string, timeout=20): - coro = self.lnworker.add_peer(connection_string) - self.network.run_from_another_thread(coro, timeout=timeout) + async def add_peer(self, connection_string, timeout=20): + await self.lnworker.add_peer(connection_string) return True @command('wpn') - def open_channel(self, connection_string, amount, channel_push=0, password=None): - chan = self.lnworker.open_channel(connection_string, satoshis(amount), satoshis(channel_push), password) + async def open_channel(self, connection_string, amount, channel_push=0, password=None): + chan = await self.lnworker._open_channel_coroutine(connection_string, satoshis(amount), satoshis(channel_push), password) return chan.funding_outpoint.to_str() @command('wn') - def lnpay(self, invoice, attempts=1, timeout=10): - return self.lnworker.pay(invoice, attempts=attempts, timeout=timeout) + async def lnpay(self, invoice, attempts=1, timeout=10): + return await self.lnworker._pay(invoice, attempts=attempts) @command('wn') - def addinvoice(self, requested_amount, message): + async def addinvoice(self, requested_amount, message): # using requested_amount because it is documented in param_descriptions - payment_hash = self.lnworker.add_invoice(satoshis(requested_amount), message) + payment_hash = await self.lnworker._add_invoice_coro(satoshis(requested_amount), message) invoice, direction, is_paid = self.lnworker.invoices[bh2u(payment_hash)] return invoice @command('w') - def nodeid(self): + async def nodeid(self): listen_addr = self.config.get('lightning_listen') return bh2u(self.lnworker.node_keypair.pubkey) + (('@' + listen_addr) if listen_addr else '') @command('w') - def list_channels(self): + async def list_channels(self): return list(self.lnworker.list_channels()) @command('wn') - def dumpgraph(self): + async def dumpgraph(self): return list(map(bh2u, self.lnworker.channel_db.nodes.keys())) @command('n') - def inject_fees(self, fees): + async def inject_fees(self, fees): import ast self.network.config.fee_estimates = ast.literal_eval(fees) self.network.notify('fee') @command('n') - def clear_ln_blacklist(self): + async def clear_ln_blacklist(self): self.network.path_finder.blacklist.clear() @command('w') - def lightning_invoices(self): + async def lightning_invoices(self): from .util import pr_tooltips out = [] for payment_hash, (preimage, invoice, is_received, timestamp) in self.lnworker.invoices.items(): @@ -836,18 +839,18 @@ class Commands: return out @command('w') - def lightning_history(self): + async def lightning_history(self): return self.lnworker.get_history() @command('wn') - def close_channel(self, channel_point, force=False): + async def close_channel(self, channel_point, force=False): txid, index = channel_point.split(':') chan_id, _ = channel_id_from_funding_tx(txid, int(index)) coro = self.lnworker.force_close_channel(chan_id) if force else self.lnworker.close_channel(chan_id) - return self.network.run_from_another_thread(coro) + return await coro @command('wn') - def get_channel_ctx(self, channel_point): + async def get_channel_ctx(self, channel_point): """ return the current commitment transaction of a channel """ txid, index = channel_point.split(':') chan_id, _ = channel_id_from_funding_tx(txid, int(index)) diff --git a/electrum/daemon.py b/electrum/daemon.py @@ -30,18 +30,17 @@ import traceback import sys import threading from typing import Dict, Optional, Tuple - +import aiohttp from aiohttp import web -from jsonrpcserver import async_dispatch as dispatch -from jsonrpcserver.methods import Methods +from base64 import b64decode -import jsonrpclib +import jsonrpcclient +import jsonrpcserver +from jsonrpcclient.clients.aiohttp_client import AiohttpClient -from .jsonrpc import PasswordProtectedJSONRPCServer from .version import ELECTRUM_VERSION from .network import Network -from .util import (json_decode, DaemonThread, to_string, - create_and_start_event_loop, profiler, standardize_path) +from .util import (json_decode, to_bytes, to_string, profiler, standardize_path, constant_time_compare) from .wallet import Wallet, Abstract_Wallet from .storage import WalletStorage from .commands import known_commands, Commands @@ -84,29 +83,31 @@ def get_file_descriptor(config: SimpleConfig): remove_lockfile(lockfile) -def request(config: SimpleConfig, endpoint, *args, **kwargs): + +def request(config: SimpleConfig, endpoint, args=(), timeout=60): lockfile = get_lockfile(config) while True: create_time = None try: with open(lockfile) as f: (host, port), create_time = ast.literal_eval(f.read()) - rpc_user, rpc_password = get_rpc_credentials(config) - if rpc_password == '': - # authentication disabled - server_url = 'http://%s:%d' % (host, port) - else: - server_url = 'http://%s:%s@%s:%d' % ( - rpc_user, rpc_password, host, port) except Exception: raise DaemonNotRunning() - server = jsonrpclib.Server(server_url) + rpc_user, rpc_password = get_rpc_credentials(config) + server_url = 'http://%s:%d' % (host, port) + auth = aiohttp.BasicAuth(login=rpc_user, password=rpc_password) + loop = asyncio.get_event_loop() + async def request_coroutine(): + async with aiohttp.ClientSession(auth=auth, loop=loop) as session: + server = AiohttpClient(session, server_url) + f = getattr(server, endpoint) + response = await f(*args) + return response.data.result try: - # run request - f = getattr(server, endpoint) - return f(*args, **kwargs) - except ConnectionRefusedError: - _logger.info(f"failed to connect to JSON-RPC server") + fut = asyncio.run_coroutine_threadsafe(request_coroutine(), loop) + return fut.result(timeout=timeout) + except aiohttp.client_exceptions.ClientConnectorError as e: + _logger.info(f"failed to connect to JSON-RPC server {e}") if not create_time or create_time < time.time() - 1.0: raise DaemonNotRunning() # Sleep a bit and try again; it might have just been started @@ -141,14 +142,14 @@ class WatchTowerServer(Logger): self.lnwatcher = network.local_watchtower self.app = web.Application() self.app.router.add_post("/", self.handle) - self.methods = Methods() + self.methods = jsonrpcserver.methods.Methods() self.methods.add(self.get_ctn) self.methods.add(self.add_sweep_tx) async def handle(self, request): request = await request.text() self.logger.info(f'{request}') - response = await dispatch(request, methods=self.methods) + response = await jsonrpcserver.async_dispatch(request, methods=self.methods) if response.wanted: return web.json_response(response.deserialized(), status=response.http_status) else: @@ -168,70 +169,98 @@ class WatchTowerServer(Logger): async def add_sweep_tx(self, *args): return await self.lnwatcher.sweepstore.add_sweep_tx(*args) +class AuthenticationError(Exception): + pass -class Daemon(DaemonThread): +class Daemon(Logger): @profiler def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True): - DaemonThread.__init__(self) + Logger.__init__(self) + self.running = False + self.running_lock = threading.Lock() self.config = config if fd is None and listen_jsonrpc: fd = get_file_descriptor(config) if fd is None: raise Exception('failed to lock daemon; already running?') - self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() + self.asyncio_loop = asyncio.get_event_loop() if config.get('offline'): self.network = None else: self.network = Network(config) - self.network._loop_thread = self._loop_thread self.fx = FxThread(config, self.network) - self.gui = None + self.gui_object = None # path -> wallet; make sure path is standardized. self.wallets = {} # type: Dict[str, Abstract_Wallet] + jobs = [self.fx.run] # Setup JSONRPC server - self.server = None if listen_jsonrpc: - self.init_server(config, fd) + jobs.append(self.start_jsonrpc(config, fd)) # server-side watchtower self.watchtower = WatchTowerServer(self.network) if self.config.get('watchtower_host') else None - jobs = [self.fx.run] if self.watchtower: jobs.append(self.watchtower.run) if self.network: self.network.start(jobs) - self.start() - def init_server(self, config: SimpleConfig, fd): - host = config.get('rpchost', '127.0.0.1') - port = config.get('rpcport', 0) - rpc_user, rpc_password = get_rpc_credentials(config) - try: - server = PasswordProtectedJSONRPCServer( - (host, port), logRequests=False, - rpc_user=rpc_user, rpc_password=rpc_password) - except Exception as e: - self.logger.error(f'cannot initialize RPC server on host {host}: {repr(e)}') - self.server = None - os.close(fd) + def authenticate(self, headers): + if self.rpc_password == '': + # RPC authentication is disabled return - os.write(fd, bytes(repr((server.socket.getsockname(), time.time())), 'utf8')) - os.close(fd) - self.server = server - server.timeout = 0.1 - server.register_function(self.ping, 'ping') - server.register_function(self.run_gui, 'gui') - server.register_function(self.run_daemon, 'daemon') + auth_string = headers.get('Authorization', None) + if auth_string is None: + raise AuthenticationError('CredentialsMissing') + basic, _, encoded = auth_string.partition(' ') + if basic != 'Basic': + raise AuthenticationError('UnsupportedType') + encoded = to_bytes(encoded, 'utf8') + credentials = to_string(b64decode(encoded), 'utf8') + username, _, password = credentials.partition(':') + if not (constant_time_compare(username, self.rpc_user) + and constant_time_compare(password, self.rpc_password)): + time.sleep(0.050) + raise AuthenticationError('Invalid Credentials') + + async def handle(self, request): + try: + self.authenticate(request.headers) + except AuthenticationError: + return web.Response(text='Forbidden', status='403') + request = await request.text() + self.logger.info(f'request: {request}') + response = await jsonrpcserver.async_dispatch(request, methods=self.methods) + if response.wanted: + return web.json_response(response.deserialized(), status=response.http_status) + else: + return web.Response() + + async def start_jsonrpc(self, config: SimpleConfig, fd): + self.app = web.Application() + self.app.router.add_post("/", self.handle) + self.rpc_user, self.rpc_password = get_rpc_credentials(config) + self.methods = jsonrpcserver.methods.Methods() + self.methods.add(self.ping) + self.methods.add(self.gui) + self.methods.add(self.daemon) self.cmd_runner = Commands(self.config, None, self.network) for cmdname in known_commands: - server.register_function(getattr(self.cmd_runner, cmdname), cmdname) - server.register_function(self.run_cmdline, 'run_cmdline') + self.methods.add(getattr(self.cmd_runner, cmdname)) + self.methods.add(self.run_cmdline) + self.host = config.get('rpchost', '127.0.0.1') + self.port = config.get('rpcport', 0) + self.runner = web.AppRunner(self.app) + await self.runner.setup() + site = web.TCPSite(self.runner, self.host, self.port) + await site.start() + socket = site._server.sockets[0] + os.write(fd, bytes(repr((socket.getsockname(), time.time())), 'utf8')) + os.close(fd) - def ping(self): + async def ping(self): return True - def run_daemon(self, config_options): - asyncio.set_event_loop(self.asyncio_loop) + async def daemon(self, config_options): config = SimpleConfig(config_options) sub = config.get('subcommand') assert sub in [None, 'start', 'stop', 'status', 'load_wallet', 'close_wallet'] @@ -279,13 +308,13 @@ class Daemon(DaemonThread): response = "Daemon stopped" return response - def run_gui(self, config_options): + async def gui(self, config_options): config = SimpleConfig(config_options) - if self.gui: - if hasattr(self.gui, 'new_window'): + if self.gui_object: + if hasattr(self.gui_object, 'new_window'): config.open_last_wallet() path = config.get_wallet_path() - self.gui.new_window(path, config.get('url')) + self.gui_object.new_window(path, config.get('url')) response = "ok" else: response = "error: current GUI does not support multiple windows" @@ -339,8 +368,7 @@ class Daemon(DaemonThread): if not wallet: return wallet.stop_threads() - def run_cmdline(self, config_options): - asyncio.set_event_loop(self.asyncio_loop) + async def run_cmdline(self, config_options): password = config_options.get('password') new_password = config_options.get('new_password') config = SimpleConfig(config_options) @@ -368,41 +396,50 @@ class Daemon(DaemonThread): cmd_runner = Commands(config, wallet, self.network) func = getattr(cmd_runner, cmd.name) try: - result = func(*args, **kwargs) + result = await func(*args, **kwargs) except TypeError as e: raise Exception("Wrapping TypeError to prevent JSONRPC-Pelix from hiding traceback") from e return result - def run(self): - while self.is_running(): - self.server.handle_request() if self.server else time.sleep(0.1) + def run_daemon(self): + self.running = True + try: + while self.is_running(): + time.sleep(0.1) + except KeyboardInterrupt: + self.running = False + self.on_stop() + + def is_running(self): + with self.running_lock: + return self.running + + def stop(self): + with self.running_lock: + self.running = False + + def on_stop(self): + if self.gui_object: + self.gui_object.stop() # stop network/wallets for k, wallet in self.wallets.items(): wallet.stop_threads() if self.network: self.logger.info("shutting down network") self.network.stop() - # stop event loop - self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) - self._loop_thread.join(timeout=1) - self.on_stop() - - def stop(self): - if self.gui: - self.gui.stop() self.logger.info("stopping, removing lockfile") remove_lockfile(get_lockfile(self.config)) - DaemonThread.stop(self) - def init_gui(self, config, plugins): + def run_gui(self, config, plugins): threading.current_thread().setName('GUI') gui_name = config.get('gui', 'qt') if gui_name in ['lite', 'classic']: gui_name = 'qt' gui = __import__('electrum.gui.' + gui_name, fromlist=['electrum']) - self.gui = gui.ElectrumGui(config, self, plugins) + self.gui_object = gui.ElectrumGui(config, self, plugins) try: - self.gui.main() + self.gui_object.main() except BaseException as e: self.logger.exception('') # app will exit now + self.on_stop() diff --git a/electrum/jsonrpc.py b/electrum/jsonrpc.py @@ -1,126 +0,0 @@ - -#!/usr/bin/env python3 -# -# Electrum - lightweight Bitcoin client -# Copyright (C) 2018 Thomas Voegtlin -# -# Permission is hereby granted, free of charge, to any person -# obtaining a copy of this software and associated documentation files -# (the "Software"), to deal in the Software without restriction, -# including without limitation the rights to use, copy, modify, merge, -# publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, -# subject to the following conditions: -# -# The above copyright notice and this permission notice shall be -# included in all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS -# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN -# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from base64 import b64decode -import time - -from jsonrpclib.SimpleJSONRPCServer import SimpleJSONRPCServer, SimpleJSONRPCRequestHandler - -from . import util -from .logging import Logger - - -class RPCAuthCredentialsInvalid(Exception): - def __str__(self): - return 'Authentication failed (bad credentials)' - - -class RPCAuthCredentialsMissing(Exception): - def __str__(self): - return 'Authentication failed (missing credentials)' - - -class RPCAuthUnsupportedType(Exception): - def __str__(self): - return 'Authentication failed (only basic auth is supported)' - - -# based on http://acooke.org/cute/BasicHTTPA0.html by andrew cooke -class AuthenticatedJSONRPCServer(SimpleJSONRPCServer, Logger): - - def __init__(self, *args, **kargs): - Logger.__init__(self) - - class VerifyingRequestHandler(SimpleJSONRPCRequestHandler): - def parse_request(myself): - # first, call the original implementation which returns - # True if all OK so far - if SimpleJSONRPCRequestHandler.parse_request(myself): - # Do not authenticate OPTIONS-requests - if myself.command.strip() == 'OPTIONS': - return True - try: - self.authenticate(myself.headers) - return True - except (RPCAuthCredentialsInvalid, RPCAuthCredentialsMissing, - RPCAuthUnsupportedType) as e: - myself.send_error(401, repr(e)) - except BaseException as e: - self.logger.exception('') - myself.send_error(500, repr(e)) - return False - SimpleJSONRPCServer.__init__( - self, requestHandler=VerifyingRequestHandler, *args, **kargs) - - def authenticate(self, headers): - raise Exception('undefined') - - -class PasswordProtectedJSONRPCServer(AuthenticatedJSONRPCServer): - - def __init__(self, *args, rpc_user, rpc_password, **kargs): - self.rpc_user = rpc_user - self.rpc_password = rpc_password - AuthenticatedJSONRPCServer.__init__(self, *args, **kargs) - - def authenticate(self, headers): - if self.rpc_password == '': - # RPC authentication is disabled - return - - auth_string = headers.get('Authorization', None) - if auth_string is None: - raise RPCAuthCredentialsMissing() - - (basic, _, encoded) = auth_string.partition(' ') - if basic != 'Basic': - raise RPCAuthUnsupportedType() - - encoded = util.to_bytes(encoded, 'utf8') - credentials = util.to_string(b64decode(encoded), 'utf8') - (username, _, password) = credentials.partition(':') - if not (util.constant_time_compare(username, self.rpc_user) - and util.constant_time_compare(password, self.rpc_password)): - time.sleep(0.050) - raise RPCAuthCredentialsInvalid() - - -class AccountsJSONRPCServer(AuthenticatedJSONRPCServer): - """ user accounts """ - - def __init__(self, *args, **kargs): - self.users = {} - AuthenticatedJSONRPCServer.__init__(self, *args, **kargs) - self.register_function(self.add_user, 'add_user') - - def authenticate(self, headers): - # todo: verify signature - return - - def add_user(self, pubkey): - user_id = len(self.users) - self.users[user_id] = pubkey - return user_id diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -700,6 +700,7 @@ class LNWallet(LNWorker): self.logger.info('REBROADCASTING CLOSING TX') await self.force_close_channel(chan.channel_id) + @log_exceptions async def _open_channel_coroutine(self, connect_str, local_amount_sat, push_sat, password): peer = await self.add_peer(connect_str) # peer might just have been connected to @@ -717,6 +718,7 @@ class LNWallet(LNWorker): def on_channels_updated(self): self.network.trigger_callback('channels') + @log_exceptions async def add_peer(self, connect_str: str) -> Peer: node_id, rest = extract_nodeid(connect_str) peer = self.peers.get(node_id) @@ -750,16 +752,8 @@ class LNWallet(LNWorker): Can be called from other threads Raises exception after timeout """ - addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - status = self.get_invoice_status(bh2u(addr.paymenthash)) - if status == PR_PAID: - raise PaymentFailure(_("This invoice has been paid already")) - self._check_invoice(invoice, amount_sat) - self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False) - self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description()) - fut = asyncio.run_coroutine_threadsafe( - self._pay(invoice, attempts, amount_sat), - self.network.asyncio_loop) + coro = self._pay(invoice, attempts, amount_sat) + fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) try: return fut.result(timeout=timeout) except concurrent.futures.TimeoutError: @@ -771,8 +765,15 @@ class LNWallet(LNWorker): if chan.short_channel_id == short_channel_id: return chan + @log_exceptions async def _pay(self, invoice, attempts=1, amount_sat=None): - addr = self._check_invoice(invoice, amount_sat) + addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) + status = self.get_invoice_status(bh2u(addr.paymenthash)) + if status == PR_PAID: + raise PaymentFailure(_("This invoice has been paid already")) + self._check_invoice(invoice, amount_sat) + self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False) + self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description()) for i in range(attempts): route = await self._create_route_from_invoice(decoded_invoice=addr) if not self.get_channel_by_short_id(route[0].short_channel_id): @@ -875,6 +876,7 @@ class LNWallet(LNWorker): except concurrent.futures.TimeoutError: raise Exception(_("add_invoice timed out")) + @log_exceptions async def _add_invoice_coro(self, amount_sat, message): payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) diff --git a/electrum/simple_config.py b/electrum/simple_config.py @@ -31,7 +31,7 @@ FEERATE_STATIC_VALUES = [1000, 2000, 5000, 10000, 20000, 30000, FEERATE_REGTEST_HARDCODED = 180000 # for eclair compat -config = None +config = {} _logger = get_logger(__name__) diff --git a/electrum/tests/regtest/regtest.sh b/electrum/tests/regtest/regtest.sh @@ -56,10 +56,11 @@ fi # start daemons. Bob is started first because he is listening if [[ $1 == "start" ]]; then $bob daemon -s 127.0.0.1:51001:t start - $bob daemon load_wallet $alice daemon -s 127.0.0.1:51001:t start - $alice daemon load_wallet $carol daemon -s 127.0.0.1:51001:t start + sleep 1 # time to accept commands + $bob daemon load_wallet + $alice daemon load_wallet $carol daemon load_wallet sleep 10 # give time to synchronize fi @@ -130,6 +131,7 @@ fi if [[ $1 == "redeem_htlcs" ]]; then $bob daemon stop ELECTRUM_DEBUG_LIGHTNING_SETTLE_DELAY=10 $bob daemon -s 127.0.0.1:51001:t start + sleep 1 $bob daemon load_wallet sleep 1 # alice opens channel @@ -177,6 +179,7 @@ fi if [[ $1 == "breach_with_unspent_htlc" ]]; then $bob daemon stop ELECTRUM_DEBUG_LIGHTNING_SETTLE_DELAY=3 $bob daemon -s 127.0.0.1:51001:t start + sleep 1 $bob daemon load_wallet wait_until_funded echo "alice opens channel" @@ -233,6 +236,7 @@ fi if [[ $1 == "breach_with_spent_htlc" ]]; then $bob daemon stop ELECTRUM_DEBUG_LIGHTNING_SETTLE_DELAY=3 $bob daemon -s 127.0.0.1:51001:t start + sleep 1 $bob daemon load_wallet wait_until_funded echo "alice opens channel" @@ -274,6 +278,7 @@ if [[ $1 == "breach_with_spent_htlc" ]]; then $alice daemon stop cp /tmp/alice/regtest/wallets/toxic_wallet /tmp/alice/regtest/wallets/default_wallet $alice daemon -s 127.0.0.1:51001:t start + sleep 1 $alice daemon load_wallet # wait until alice has spent both ctx outputs while [[ $($bitcoin_cli gettxout $ctx_id 0) ]]; do @@ -287,6 +292,7 @@ if [[ $1 == "breach_with_spent_htlc" ]]; then new_blocks 1 echo "bob comes back" $bob daemon -s 127.0.0.1:51001:t start + sleep 1 $bob daemon load_wallet while [[ $($bitcoin_cli getmempoolinfo | jq '.size') != "1" ]]; do echo "waiting for bob's transaction" @@ -312,6 +318,7 @@ if [[ $1 == "watchtower" ]]; then $carol setconfig watchtower_port 12345 $carol daemon -s 127.0.0.1:51001:t start $alice daemon -s 127.0.0.1:51001:t start + sleep 1 $alice daemon load_wallet echo "waiting until alice funded" wait_until_funded diff --git a/electrum/tests/test_commands.py b/electrum/tests/test_commands.py @@ -2,6 +2,7 @@ import unittest from unittest import mock from decimal import Decimal +from electrum.util import create_and_start_event_loop from electrum.commands import Commands, eval_bool from electrum import storage from electrum.wallet import restore_wallet_from_text @@ -11,6 +12,15 @@ from . import TestCaseForTestnet class TestCommands(unittest.TestCase): + def setUp(self): + super().setUp() + self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() + + def tearDown(self): + super().tearDown() + self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) + self._loop_thread.join(timeout=1) + def test_setconfig_non_auth_number(self): self.assertEqual(7777, Commands._setconfig_normalize_value('rpcport', "7777")) self.assertEqual(7777, Commands._setconfig_normalize_value('rpcport', '7777')) @@ -54,7 +64,7 @@ class TestCommands(unittest.TestCase): } for xkey1, xtype1 in xpubs: for xkey2, xtype2 in xpubs: - self.assertEqual(xkey2, cmds.convert_xkey(xkey1, xtype2)) + self.assertEqual(xkey2, cmds._run('convert_xkey', (xkey1, xtype2))) xprvs = { ("xprv9yD9r6PJmTgqpGCUf8FUkkAhNTxv4rryiFWkqb5mYQPw8aMDXUzuyJ3tgv5vUqYkdK1E6Q5jKxPss4HkMBYV4q8AfG8t7rxgyS4xQX4ndAm", "standard"), @@ -63,7 +73,7 @@ class TestCommands(unittest.TestCase): } for xkey1, xtype1 in xprvs: for xkey2, xtype2 in xprvs: - self.assertEqual(xkey2, cmds.convert_xkey(xkey1, xtype2)) + self.assertEqual(xkey2, cmds._run('convert_xkey', (xkey1, xtype2))) @mock.patch.object(storage.WalletStorage, '_write') def test_encrypt_decrypt(self, mock_write): @@ -72,8 +82,8 @@ class TestCommands(unittest.TestCase): cmds = Commands(config=None, wallet=wallet, network=None) cleartext = "asdasd this is the message" pubkey = "021f110909ded653828a254515b58498a6bafc96799fb0851554463ed44ca7d9da" - ciphertext = cmds.encrypt(pubkey, cleartext) - self.assertEqual(cleartext, cmds.decrypt(pubkey, ciphertext)) + ciphertext = cmds._run('encrypt', (pubkey, cleartext)) + self.assertEqual(cleartext, cmds._run('decrypt', (pubkey, ciphertext))) @mock.patch.object(storage.WalletStorage, '_write') def test_export_private_key_imported(self, mock_write): @@ -82,16 +92,16 @@ class TestCommands(unittest.TestCase): cmds = Commands(config=None, wallet=wallet, network=None) # single address tests with self.assertRaises(Exception): - cmds.getprivatekeys("asdasd") # invalid addr, though might raise "not in wallet" + cmds._run('getprivatekeys', ("asdasd",)) # invalid addr, though might raise "not in wallet" with self.assertRaises(Exception): - cmds.getprivatekeys("bc1qgfam82qk7uwh5j2xxmcd8cmklpe0zackyj6r23") # not in wallet + cmds._run('getprivatekeys', ("bc1qgfam82qk7uwh5j2xxmcd8cmklpe0zackyj6r23",)) # not in wallet self.assertEqual("p2wpkh:L4jkdiXszG26SUYvwwJhzGwg37H2nLhrbip7u6crmgNeJysv5FHL", - cmds.getprivatekeys("bc1q2ccr34wzep58d4239tl3x3734ttle92a8srmuw")) + cmds._run('getprivatekeys', ("bc1q2ccr34wzep58d4239tl3x3734ttle92a8srmuw",))) # list of addresses tests with self.assertRaises(Exception): - cmds.getprivatekeys(['bc1q2ccr34wzep58d4239tl3x3734ttle92a8srmuw', 'asd']) + cmds._run('getprivatekeys', (['bc1q2ccr34wzep58d4239tl3x3734ttle92a8srmuw', 'asd'], )) self.assertEqual(['p2wpkh:L4jkdiXszG26SUYvwwJhzGwg37H2nLhrbip7u6crmgNeJysv5FHL', 'p2wpkh:L4rYY5QpfN6wJEF4SEKDpcGhTPnCe9zcGs6hiSnhpprZqVywFifN'], - cmds.getprivatekeys(['bc1q2ccr34wzep58d4239tl3x3734ttle92a8srmuw', 'bc1q9pzjpjq4nqx5ycnywekcmycqz0wjp2nq604y2n'])) + cmds._run('getprivatekeys', (['bc1q2ccr34wzep58d4239tl3x3734ttle92a8srmuw', 'bc1q9pzjpjq4nqx5ycnywekcmycqz0wjp2nq604y2n'], ))) @mock.patch.object(storage.WalletStorage, '_write') def test_export_private_key_deterministic(self, mock_write): @@ -101,20 +111,29 @@ class TestCommands(unittest.TestCase): cmds = Commands(config=None, wallet=wallet, network=None) # single address tests with self.assertRaises(Exception): - cmds.getprivatekeys("asdasd") # invalid addr, though might raise "not in wallet" + cmds._run('getprivatekeys', ("asdasd",)) # invalid addr, though might raise "not in wallet" with self.assertRaises(Exception): - cmds.getprivatekeys("bc1qgfam82qk7uwh5j2xxmcd8cmklpe0zackyj6r23") # not in wallet + cmds._run('getprivatekeys', ("bc1qgfam82qk7uwh5j2xxmcd8cmklpe0zackyj6r23",)) # not in wallet self.assertEqual("p2wpkh:L15oxP24NMNAXxq5r2aom24pHPtt3Fet8ZutgL155Bad93GSubM2", - cmds.getprivatekeys("bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af")) + cmds._run('getprivatekeys', ("bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af",))) # list of addresses tests with self.assertRaises(Exception): - cmds.getprivatekeys(['bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af', 'asd']) + cmds._run('getprivatekeys', (['bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af', 'asd'],)) self.assertEqual(['p2wpkh:L15oxP24NMNAXxq5r2aom24pHPtt3Fet8ZutgL155Bad93GSubM2', 'p2wpkh:L4rYY5QpfN6wJEF4SEKDpcGhTPnCe9zcGs6hiSnhpprZqVywFifN'], - cmds.getprivatekeys(['bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af', 'bc1q9pzjpjq4nqx5ycnywekcmycqz0wjp2nq604y2n'])) + cmds._run('getprivatekeys', (['bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af', 'bc1q9pzjpjq4nqx5ycnywekcmycqz0wjp2nq604y2n'], ))) class TestCommandsTestnet(TestCaseForTestnet): + def setUp(self): + super().setUp() + self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() + + def tearDown(self): + super().tearDown() + self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) + self._loop_thread.join(timeout=1) + def test_convert_xkey(self): cmds = Commands(config=None, wallet=None, network=None) xpubs = { @@ -124,7 +143,7 @@ class TestCommandsTestnet(TestCaseForTestnet): } for xkey1, xtype1 in xpubs: for xkey2, xtype2 in xpubs: - self.assertEqual(xkey2, cmds.convert_xkey(xkey1, xtype2)) + self.assertEqual(xkey2, cmds._run('convert_xkey', (xkey1, xtype2))) xprvs = { ("tprv8c83gxdVUcznP8fMx2iNUBbaQgQC7MUbBUDG3c6YU9xgt7Dn5pfcgHUeNZTAvuYmNgVHjyTzYzGWwJr7GvKCm2FkPaaJipyipbfJeB3tdPW", "standard"), @@ -133,4 +152,4 @@ class TestCommandsTestnet(TestCaseForTestnet): } for xkey1, xtype1 in xprvs: for xkey2, xtype2 in xprvs: - self.assertEqual(xkey2, cmds.convert_xkey(xkey1, xtype2)) + self.assertEqual(xkey2, cmds._run('convert_xkey', (xkey1, xtype2))) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -93,6 +93,9 @@ class MockLNWallet: self.localfeatures = LnLocalFeatures(0) self.pending_payments = defaultdict(asyncio.Future) + def get_invoice_status(self, key): + pass + @property def lock(self): return noop_lock() diff --git a/run_electrum b/run_electrum @@ -26,7 +26,7 @@ import os import sys import warnings - +import asyncio MIN_PYTHON_VERSION = "3.6.1" # FIXME duplicated from setup.py _min_python_version_tuple = tuple(map(int, (MIN_PYTHON_VERSION.split(".")))) @@ -90,6 +90,7 @@ from electrum.util import InvalidPassword from electrum.commands import get_parser, known_commands, Commands, config_variables from electrum import daemon from electrum import keystore +from electrum.util import create_and_start_event_loop _logger = get_logger(__name__) @@ -222,7 +223,7 @@ def get_password_for_hw_device_encrypted_storage(plugins): return password -def run_offline_command(config, config_options, plugins): +async def run_offline_command(config, config_options, plugins): cmdname = config.get('cmd') cmd = known_commands[cmdname] password = config_options.get('password') @@ -256,7 +257,7 @@ def run_offline_command(config, config_options, plugins): kwargs[x] = (config_options.get(x) if x in ['password', 'new_password'] else config.get(x)) cmd_runner = Commands(config, wallet, None) func = getattr(cmd_runner, cmd.name) - result = func(*args, **kwargs) + result = await func(*args, **kwargs) # save wallet if wallet: wallet.storage.write() @@ -267,6 +268,11 @@ def init_plugins(config, gui_name): from electrum.plugin import Plugins return Plugins(config, gui_name) +def sys_exit(i): + # stop event loop and exit + loop.call_soon_threadsafe(stop_loop.set_result, 1) + loop_thread.join(timeout=1) + sys.exit(i) if __name__ == '__main__': # The hook will only be used in the Qt GUI right now @@ -345,6 +351,7 @@ if __name__ == '__main__': config = SimpleConfig(config_options) cmdname = config.get('cmd') + subcommand = config.get('subcommand') if config.get('testnet'): constants.set_testnet() @@ -355,19 +362,38 @@ if __name__ == '__main__': elif config.get('lightning') and not config.get('reckless'): raise Exception('lightning branch not available on mainnet') + if cmdname == 'daemon' and subcommand == 'start': + # fork before creating the asyncio event loop + pid = os.fork() + if pid: + print_stderr("starting daemon (PID %d)" % pid) + sys.exit(0) + else: + # redirect standard file descriptors + sys.stdout.flush() + sys.stderr.flush() + si = open(os.devnull, 'r') + so = open(os.devnull, 'w') + se = open(os.devnull, 'w') + os.dup2(si.fileno(), sys.stdin.fileno()) + os.dup2(so.fileno(), sys.stdout.fileno()) + os.dup2(se.fileno(), sys.stderr.fileno()) + + loop, stop_loop, loop_thread = create_and_start_event_loop() + if cmdname == 'gui': configure_logging(config) fd = daemon.get_file_descriptor(config) if fd is not None: plugins = init_plugins(config, config.get('gui', 'qt')) d = daemon.Daemon(config, fd) - d.init_gui(config, plugins) - sys.exit(0) + d.run_gui(config, plugins) + sys_exit(0) else: - result = daemon.request(config, 'gui', config_options) + result = daemon.request(config, 'gui', (config_options,)) elif cmdname == 'daemon': - subcommand = config.get('subcommand') + if subcommand in ['load_wallet']: init_daemon(config_options) @@ -375,20 +401,6 @@ if __name__ == '__main__': configure_logging(config) fd = daemon.get_file_descriptor(config) if fd is not None: - if subcommand == 'start': - pid = os.fork() - if pid: - print_stderr("starting daemon (PID %d)" % pid) - sys.exit(0) - # redirect standard file descriptors - sys.stdout.flush() - sys.stderr.flush() - si = open(os.devnull, 'r') - so = open(os.devnull, 'w') - se = open(os.devnull, 'w') - os.dup2(si.fileno(), sys.stdin.fileno()) - os.dup2(so.fileno(), sys.stdout.fileno()) - os.dup2(se.fileno(), sys.stderr.fileno()) # run daemon init_plugins(config, 'cmdline') d = daemon.Daemon(config, fd) @@ -400,36 +412,42 @@ if __name__ == '__main__': if not os.path.exists(path): print("Requests directory not configured.") print("You can configure it using https://github.com/spesmilo/electrum-merchant") - sys.exit(1) - d.join() - sys.exit(0) + sys_exit(1) + d.run_daemon() + sys_exit(0) else: - result = daemon.request(config, 'daemon', config_options) + result = daemon.request(config, 'daemon', (config_options,)) else: try: - result = daemon.request(config, 'daemon', config_options) + result = daemon.request(config, 'daemon', (config_options,)) except daemon.DaemonNotRunning: print_msg("Daemon not running") - sys.exit(1) + sys_exit(1) else: # command line try: init_cmdline(config_options, True) - result = daemon.request(config, 'run_cmdline', config_options) + timeout = config_options.get('timeout', 60) + if timeout: timeout = int(timeout) + result = daemon.request(config, 'run_cmdline', (config_options,), timeout) except daemon.DaemonNotRunning: cmd = known_commands[cmdname] if cmd.requires_network: print_msg("Daemon not running; try 'electrum daemon start'") - sys.exit(1) + sys_exit(1) else: init_cmdline(config_options, False) plugins = init_plugins(config, 'cmdline') - result = run_offline_command(config, config_options, plugins) - # print result + coro = run_offline_command(config, config_options, plugins) + fut = asyncio.run_coroutine_threadsafe(coro, loop) + result = fut.result(10) + except Exception as e: + print_stderr(e) + sys_exit(1) if isinstance(result, str): print_msg(result) elif type(result) is dict and result.get('error'): print_stderr(result.get('error')) elif result is not None: print_msg(json_encode(result)) - sys.exit(0) + sys_exit(0)