electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

daemon.py (22371B)


      1 #!/usr/bin/env python
      2 #
      3 # Electrum - lightweight Bitcoin client
      4 # Copyright (C) 2015 Thomas Voegtlin
      5 #
      6 # Permission is hereby granted, free of charge, to any person
      7 # obtaining a copy of this software and associated documentation files
      8 # (the "Software"), to deal in the Software without restriction,
      9 # including without limitation the rights to use, copy, modify, merge,
     10 # publish, distribute, sublicense, and/or sell copies of the Software,
     11 # and to permit persons to whom the Software is furnished to do so,
     12 # subject to the following conditions:
     13 #
     14 # The above copyright notice and this permission notice shall be
     15 # included in all copies or substantial portions of the Software.
     16 #
     17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
     18 # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     19 # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
     20 # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
     21 # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
     22 # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     23 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     24 # SOFTWARE.
     25 import asyncio
     26 import ast
     27 import os
     28 import time
     29 import traceback
     30 import sys
     31 import threading
     32 from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping, TYPE_CHECKING
     33 from base64 import b64decode, b64encode
     34 from collections import defaultdict
     35 import json
     36 
     37 import aiohttp
     38 from aiohttp import web, client_exceptions
     39 from aiorpcx import TaskGroup, timeout_after, TaskTimeout, ignore_after
     40 
     41 from . import util
     42 from .network import Network
     43 from .util import (json_decode, to_bytes, to_string, profiler, standardize_path, constant_time_compare)
     44 from .invoices import PR_PAID, PR_EXPIRED
     45 from .util import log_exceptions, ignore_exceptions, randrange
     46 from .wallet import Wallet, Abstract_Wallet
     47 from .storage import WalletStorage
     48 from .wallet_db import WalletDB
     49 from .commands import known_commands, Commands
     50 from .simple_config import SimpleConfig
     51 from .exchange_rate import FxThread
     52 from .logging import get_logger, Logger
     53 
     54 if TYPE_CHECKING:
     55     from electrum import gui
     56 
     57 
     58 _logger = get_logger(__name__)
     59 
     60 
     61 class DaemonNotRunning(Exception):
     62     pass
     63 
     64 def get_lockfile(config: SimpleConfig):
     65     return os.path.join(config.path, 'daemon')
     66 
     67 
     68 def remove_lockfile(lockfile):
     69     os.unlink(lockfile)
     70 
     71 
     72 def get_file_descriptor(config: SimpleConfig):
     73     '''Tries to create the lockfile, using O_EXCL to
     74     prevent races.  If it succeeds it returns the FD.
     75     Otherwise try and connect to the server specified in the lockfile.
     76     If this succeeds, the server is returned.  Otherwise remove the
     77     lockfile and try again.'''
     78     lockfile = get_lockfile(config)
     79     while True:
     80         try:
     81             return os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644)
     82         except OSError:
     83             pass
     84         try:
     85             request(config, 'ping')
     86             return None
     87         except DaemonNotRunning:
     88             # Couldn't connect; remove lockfile and try again.
     89             remove_lockfile(lockfile)
     90 
     91 
     92 
     93 def request(config: SimpleConfig, endpoint, args=(), timeout=60):
     94     lockfile = get_lockfile(config)
     95     while True:
     96         create_time = None
     97         try:
     98             with open(lockfile) as f:
     99                 (host, port), create_time = ast.literal_eval(f.read())
    100         except Exception:
    101             raise DaemonNotRunning()
    102         rpc_user, rpc_password = get_rpc_credentials(config)
    103         server_url = 'http://%s:%d' % (host, port)
    104         auth = aiohttp.BasicAuth(login=rpc_user, password=rpc_password)
    105         loop = asyncio.get_event_loop()
    106         async def request_coroutine():
    107             async with aiohttp.ClientSession(auth=auth) as session:
    108                 c = util.JsonRPCClient(session, server_url)
    109                 return await c.request(endpoint, *args)
    110         try:
    111             fut = asyncio.run_coroutine_threadsafe(request_coroutine(), loop)
    112             return fut.result(timeout=timeout)
    113         except aiohttp.client_exceptions.ClientConnectorError as e:
    114             _logger.info(f"failed to connect to JSON-RPC server {e}")
    115             if not create_time or create_time < time.time() - 1.0:
    116                 raise DaemonNotRunning()
    117         # Sleep a bit and try again; it might have just been started
    118         time.sleep(1.0)
    119 
    120 
    121 def get_rpc_credentials(config: SimpleConfig) -> Tuple[str, str]:
    122     rpc_user = config.get('rpcuser', None)
    123     rpc_password = config.get('rpcpassword', None)
    124     if rpc_user == '':
    125         rpc_user = None
    126     if rpc_password == '':
    127         rpc_password = None
    128     if rpc_user is None or rpc_password is None:
    129         rpc_user = 'user'
    130         bits = 128
    131         nbytes = bits // 8 + (bits % 8 > 0)
    132         pw_int = randrange(pow(2, bits))
    133         pw_b64 = b64encode(
    134             pw_int.to_bytes(nbytes, 'big'), b'-_')
    135         rpc_password = to_string(pw_b64, 'ascii')
    136         config.set_key('rpcuser', rpc_user)
    137         config.set_key('rpcpassword', rpc_password, save=True)
    138     return rpc_user, rpc_password
    139 
    140 
    141 class AuthenticationError(Exception):
    142     pass
    143 
    144 class AuthenticationInvalidOrMissing(AuthenticationError):
    145     pass
    146 
    147 class AuthenticationCredentialsInvalid(AuthenticationError):
    148     pass
    149 
    150 class AuthenticatedServer(Logger):
    151 
    152     def __init__(self, rpc_user, rpc_password):
    153         Logger.__init__(self)
    154         self.rpc_user = rpc_user
    155         self.rpc_password = rpc_password
    156         self.auth_lock = asyncio.Lock()
    157         self._methods = {}  # type: Dict[str, Callable]
    158 
    159     def register_method(self, f):
    160         assert f.__name__ not in self._methods, f"name collision for {f.__name__}"
    161         self._methods[f.__name__] = f
    162 
    163     async def authenticate(self, headers):
    164         if self.rpc_password == '':
    165             # RPC authentication is disabled
    166             return
    167         auth_string = headers.get('Authorization', None)
    168         if auth_string is None:
    169             raise AuthenticationInvalidOrMissing('CredentialsMissing')
    170         basic, _, encoded = auth_string.partition(' ')
    171         if basic != 'Basic':
    172             raise AuthenticationInvalidOrMissing('UnsupportedType')
    173         encoded = to_bytes(encoded, 'utf8')
    174         credentials = to_string(b64decode(encoded), 'utf8')
    175         username, _, password = credentials.partition(':')
    176         if not (constant_time_compare(username, self.rpc_user)
    177                 and constant_time_compare(password, self.rpc_password)):
    178             await asyncio.sleep(0.050)
    179             raise AuthenticationCredentialsInvalid('Invalid Credentials')
    180 
    181     async def handle(self, request):
    182         async with self.auth_lock:
    183             try:
    184                 await self.authenticate(request.headers)
    185             except AuthenticationInvalidOrMissing:
    186                 return web.Response(headers={"WWW-Authenticate": "Basic realm=Electrum"},
    187                                     text='Unauthorized', status=401)
    188             except AuthenticationCredentialsInvalid:
    189                 return web.Response(text='Forbidden', status=403)
    190         try:
    191             request = await request.text()
    192             request = json.loads(request)
    193             method = request['method']
    194             _id = request['id']
    195             params = request.get('params', [])  # type: Union[Sequence, Mapping]
    196             if method not in self._methods:
    197                 raise Exception(f"attempting to use unregistered method: {method}")
    198             f = self._methods[method]
    199         except Exception as e:
    200             self.logger.exception("invalid request")
    201             return web.Response(text='Invalid Request', status=500)
    202         response = {
    203             'id': _id,
    204             'jsonrpc': '2.0',
    205         }
    206         try:
    207             if isinstance(params, dict):
    208                 response['result'] = await f(**params)
    209             else:
    210                 response['result'] = await f(*params)
    211         except BaseException as e:
    212             self.logger.exception("internal error while executing RPC")
    213             response['error'] = {
    214                 'code': 1,
    215                 'message': str(e),
    216             }
    217         return web.json_response(response)
    218 
    219 
    220 class CommandsServer(AuthenticatedServer):
    221 
    222     def __init__(self, daemon, fd):
    223         rpc_user, rpc_password = get_rpc_credentials(daemon.config)
    224         AuthenticatedServer.__init__(self, rpc_user, rpc_password)
    225         self.daemon = daemon
    226         self.fd = fd
    227         self.config = daemon.config
    228         self.host = self.config.get('rpchost', '127.0.0.1')
    229         self.port = self.config.get('rpcport', 0)
    230         self.app = web.Application()
    231         self.app.router.add_post("/", self.handle)
    232         self.register_method(self.ping)
    233         self.register_method(self.gui)
    234         self.cmd_runner = Commands(config=self.config, network=self.daemon.network, daemon=self.daemon)
    235         for cmdname in known_commands:
    236             self.register_method(getattr(self.cmd_runner, cmdname))
    237         self.register_method(self.run_cmdline)
    238 
    239     async def run(self):
    240         self.runner = web.AppRunner(self.app)
    241         await self.runner.setup()
    242         site = web.TCPSite(self.runner, self.host, self.port)
    243         await site.start()
    244         socket = site._server.sockets[0]
    245         os.write(self.fd, bytes(repr((socket.getsockname(), time.time())), 'utf8'))
    246         os.close(self.fd)
    247 
    248     async def ping(self):
    249         return True
    250 
    251     async def gui(self, config_options):
    252         if self.daemon.gui_object:
    253             if hasattr(self.daemon.gui_object, 'new_window'):
    254                 path = self.config.get_wallet_path(use_gui_last_wallet=True)
    255                 self.daemon.gui_object.new_window(path, config_options.get('url'))
    256                 response = "ok"
    257             else:
    258                 response = "error: current GUI does not support multiple windows"
    259         else:
    260             response = "Error: Electrum is running in daemon mode. Please stop the daemon first."
    261         return response
    262 
    263     async def run_cmdline(self, config_options):
    264         cmdname = config_options['cmd']
    265         cmd = known_commands[cmdname]
    266         # arguments passed to function
    267         args = [config_options.get(x) for x in cmd.params]
    268         # decode json arguments
    269         args = [json_decode(i) for i in args]
    270         # options
    271         kwargs = {}
    272         for x in cmd.options:
    273             kwargs[x] = config_options.get(x)
    274         if 'wallet_path' in cmd.options:
    275             kwargs['wallet_path'] = config_options.get('wallet_path')
    276         elif 'wallet' in cmd.options:
    277             kwargs['wallet'] = config_options.get('wallet_path')
    278         func = getattr(self.cmd_runner, cmd.name)
    279         # fixme: not sure how to retrieve message in jsonrpcclient
    280         try:
    281             result = await func(*args, **kwargs)
    282         except Exception as e:
    283             result = {'error':str(e)}
    284         return result
    285 
    286 
    287 class WatchTowerServer(AuthenticatedServer):
    288 
    289     def __init__(self, network, netaddress):
    290         self.addr = netaddress
    291         self.config = network.config
    292         self.network = network
    293         watchtower_user = self.config.get('watchtower_user', '')
    294         watchtower_password = self.config.get('watchtower_password', '')
    295         AuthenticatedServer.__init__(self, watchtower_user, watchtower_password)
    296         self.lnwatcher = network.local_watchtower
    297         self.app = web.Application()
    298         self.app.router.add_post("/", self.handle)
    299         self.register_method(self.get_ctn)
    300         self.register_method(self.add_sweep_tx)
    301 
    302     async def run(self):
    303         self.runner = web.AppRunner(self.app)
    304         await self.runner.setup()
    305         site = web.TCPSite(self.runner, host=str(self.addr.host), port=self.addr.port, ssl_context=self.config.get_ssl_context())
    306         await site.start()
    307 
    308     async def get_ctn(self, *args):
    309         return await self.lnwatcher.sweepstore.get_ctn(*args)
    310 
    311     async def add_sweep_tx(self, *args):
    312         return await self.lnwatcher.sweepstore.add_sweep_tx(*args)
    313 
    314 
    315 class PayServer(Logger):
    316 
    317     def __init__(self, daemon: 'Daemon', netaddress):
    318         Logger.__init__(self)
    319         self.addr = netaddress
    320         self.daemon = daemon
    321         self.config = daemon.config
    322         self.pending = defaultdict(asyncio.Event)
    323         util.register_callback(self.on_payment, ['request_status'])
    324 
    325     @property
    326     def wallet(self):
    327         # FIXME specify wallet somehow?
    328         return list(self.daemon.get_wallets().values())[0]
    329 
    330     async def on_payment(self, evt, wallet, key, status):
    331         if status == PR_PAID:
    332             self.pending[key].set()
    333 
    334     @ignore_exceptions
    335     @log_exceptions
    336     async def run(self):
    337         root = self.config.get('payserver_root', '/r')
    338         app = web.Application()
    339         app.add_routes([web.get('/api/get_invoice', self.get_request)])
    340         app.add_routes([web.get('/api/get_status', self.get_status)])
    341         app.add_routes([web.get('/bip70/{key}.bip70', self.get_bip70_request)])
    342         app.add_routes([web.static(root, os.path.join(os.path.dirname(__file__), 'www'))])
    343         if self.config.get('payserver_allow_create_invoice'):
    344             app.add_routes([web.post('/api/create_invoice', self.create_request)])
    345         runner = web.AppRunner(app)
    346         await runner.setup()
    347         site = web.TCPSite(runner, host=str(self.addr.host), port=self.addr.port, ssl_context=self.config.get_ssl_context())
    348         await site.start()
    349 
    350     async def create_request(self, request):
    351         params = await request.post()
    352         wallet = self.wallet
    353         if 'amount_sat' not in params or not params['amount_sat'].isdigit():
    354             raise web.HTTPUnsupportedMediaType()
    355         amount = int(params['amount_sat'])
    356         message = params['message'] or "donation"
    357         payment_hash = wallet.lnworker.add_request(
    358             amount_sat=amount,
    359             message=message,
    360             expiry=3600)
    361         key = payment_hash.hex()
    362         raise web.HTTPFound(self.root + '/pay?id=' + key)
    363 
    364     async def get_request(self, r):
    365         key = r.query_string
    366         request = self.wallet.get_formatted_request(key)
    367         return web.json_response(request)
    368 
    369     async def get_bip70_request(self, r):
    370         from .paymentrequest import make_request
    371         key = r.match_info['key']
    372         request = self.wallet.get_request(key)
    373         if not request:
    374             return web.HTTPNotFound()
    375         pr = make_request(self.config, request)
    376         return web.Response(body=pr.SerializeToString(), content_type='application/bitcoin-paymentrequest')
    377 
    378     async def get_status(self, request):
    379         ws = web.WebSocketResponse()
    380         await ws.prepare(request)
    381         key = request.query_string
    382         info = self.wallet.get_formatted_request(key)
    383         if not info:
    384             await ws.send_str('unknown invoice')
    385             await ws.close()
    386             return ws
    387         if info.get('status') == PR_PAID:
    388             await ws.send_str(f'paid')
    389             await ws.close()
    390             return ws
    391         if info.get('status') == PR_EXPIRED:
    392             await ws.send_str(f'expired')
    393             await ws.close()
    394             return ws
    395         while True:
    396             try:
    397                 await asyncio.wait_for(self.pending[key].wait(), 1)
    398                 break
    399             except asyncio.TimeoutError:
    400                 # send data on the websocket, to keep it alive
    401                 await ws.send_str('waiting')
    402         await ws.send_str('paid')
    403         await ws.close()
    404         return ws
    405 
    406 
    407 
    408 class Daemon(Logger):
    409 
    410     network: Optional[Network]
    411     gui_object: Optional[Union['gui.qt.ElectrumGui', 'gui.kivy.ElectrumGui']]
    412 
    413     @profiler
    414     def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
    415         Logger.__init__(self)
    416         self.running = False
    417         self.running_lock = threading.Lock()
    418         self.config = config
    419         if fd is None and listen_jsonrpc:
    420             fd = get_file_descriptor(config)
    421             if fd is None:
    422                 raise Exception('failed to lock daemon; already running?')
    423         if 'wallet_path' in config.cmdline_options:
    424             self.logger.warning("Ignoring parameter 'wallet_path' for daemon. "
    425                                 "Use the load_wallet command instead.")
    426         self.asyncio_loop = asyncio.get_event_loop()
    427         self.network = None
    428         if not config.get('offline'):
    429             self.network = Network(config, daemon=self)
    430         self.fx = FxThread(config, self.network)
    431         self.gui_object = None
    432         # path -> wallet;   make sure path is standardized.
    433         self._wallets = {}  # type: Dict[str, Abstract_Wallet]
    434         daemon_jobs = []
    435         # Setup commands server
    436         self.commands_server = None
    437         if listen_jsonrpc:
    438             self.commands_server = CommandsServer(self, fd)
    439             daemon_jobs.append(self.commands_server.run())
    440         # pay server
    441         self.pay_server = None
    442         payserver_address = self.config.get_netaddress('payserver_address')
    443         if not config.get('offline') and payserver_address:
    444             self.pay_server = PayServer(self, payserver_address)
    445             daemon_jobs.append(self.pay_server.run())
    446         # server-side watchtower
    447         self.watchtower = None
    448         watchtower_address = self.config.get_netaddress('watchtower_address')
    449         if not config.get('offline') and watchtower_address:
    450             self.watchtower = WatchTowerServer(self.network, watchtower_address)
    451             daemon_jobs.append(self.watchtower.run)
    452         if self.network:
    453             self.network.start(jobs=[self.fx.run])
    454             # prepare lightning functionality, also load channel db early
    455             if self.config.get('use_gossip', False):
    456                 self.network.start_gossip()
    457 
    458         self.taskgroup = TaskGroup()
    459         asyncio.run_coroutine_threadsafe(self._run(jobs=daemon_jobs), self.asyncio_loop)
    460 
    461     @log_exceptions
    462     async def _run(self, jobs: Iterable = None):
    463         if jobs is None:
    464             jobs = []
    465         self.logger.info("starting taskgroup.")
    466         try:
    467             async with self.taskgroup as group:
    468                 [await group.spawn(job) for job in jobs]
    469                 await group.spawn(asyncio.Event().wait)  # run forever (until cancel)
    470         except asyncio.CancelledError:
    471             raise
    472         except Exception as e:
    473             self.logger.exception("taskgroup died.")
    474         finally:
    475             self.logger.info("taskgroup stopped.")
    476 
    477     def load_wallet(self, path, password, *, manual_upgrades=True) -> Optional[Abstract_Wallet]:
    478         path = standardize_path(path)
    479         # wizard will be launched if we return
    480         if path in self._wallets:
    481             wallet = self._wallets[path]
    482             return wallet
    483         storage = WalletStorage(path)
    484         if not storage.file_exists():
    485             return
    486         if storage.is_encrypted():
    487             if not password:
    488                 return
    489             storage.decrypt(password)
    490         # read data, pass it to db
    491         db = WalletDB(storage.read(), manual_upgrades=manual_upgrades)
    492         if db.requires_split():
    493             return
    494         if db.requires_upgrade():
    495             return
    496         if db.get_action():
    497             return
    498         wallet = Wallet(db, storage, config=self.config)
    499         wallet.start_network(self.network)
    500         self._wallets[path] = wallet
    501         return wallet
    502 
    503     def add_wallet(self, wallet: Abstract_Wallet) -> None:
    504         path = wallet.storage.path
    505         path = standardize_path(path)
    506         self._wallets[path] = wallet
    507 
    508     def get_wallet(self, path: str) -> Optional[Abstract_Wallet]:
    509         path = standardize_path(path)
    510         return self._wallets.get(path)
    511 
    512     def get_wallets(self) -> Dict[str, Abstract_Wallet]:
    513         return dict(self._wallets)  # copy
    514 
    515     def delete_wallet(self, path: str) -> bool:
    516         self.stop_wallet(path)
    517         if os.path.exists(path):
    518             os.unlink(path)
    519             return True
    520         return False
    521 
    522     def stop_wallet(self, path: str) -> bool:
    523         """Returns True iff a wallet was found."""
    524         path = standardize_path(path)
    525         wallet = self._wallets.pop(path, None)
    526         if not wallet:
    527             return False
    528         fut = asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop)
    529         fut.result()
    530         return True
    531 
    532     def run_daemon(self):
    533         self.running = True
    534         try:
    535             while self.is_running():
    536                 time.sleep(0.1)
    537         except KeyboardInterrupt:
    538             self.running = False
    539         self.on_stop()
    540 
    541     def is_running(self):
    542         with self.running_lock:
    543             return self.running and not self.taskgroup.closed()
    544 
    545     def stop(self):
    546         with self.running_lock:
    547             self.running = False
    548 
    549     def on_stop(self):
    550         self.logger.info("on_stop() entered. initiating shutdown")
    551         if self.gui_object:
    552             self.gui_object.stop()
    553 
    554         @log_exceptions
    555         async def stop_async():
    556             self.logger.info("stopping all wallets")
    557             async with TaskGroup() as group:
    558                 for k, wallet in self._wallets.items():
    559                     await group.spawn(wallet.stop())
    560             self.logger.info("stopping network and taskgroup")
    561             async with ignore_after(2):
    562                 async with TaskGroup() as group:
    563                     if self.network:
    564                         await group.spawn(self.network.stop(full_shutdown=True))
    565                     await group.spawn(self.taskgroup.cancel_remaining())
    566 
    567         fut = asyncio.run_coroutine_threadsafe(stop_async(), self.asyncio_loop)
    568         fut.result()
    569         self.logger.info("removing lockfile")
    570         remove_lockfile(get_lockfile(self.config))
    571         self.logger.info("stopped")
    572 
    573     def run_gui(self, config, plugins):
    574         threading.current_thread().setName('GUI')
    575         gui_name = config.get('gui', 'qt')
    576         if gui_name in ['lite', 'classic']:
    577             gui_name = 'qt'
    578         self.logger.info(f'launching GUI: {gui_name}')
    579         try:
    580             gui = __import__('electrum.gui.' + gui_name, fromlist=['electrum'])
    581             self.gui_object = gui.ElectrumGui(config, self, plugins)
    582             self.gui_object.main()
    583         except BaseException as e:
    584             self.logger.error(f'GUI raised exception: {repr(e)}. shutting down.')
    585             raise
    586         finally:
    587             # app will exit now
    588             self.on_stop()