commit 3c019c2f9c4d2fdefe52d84444632b54d421e140
parent ce88b36e81a533810b08ccf8796120951265da8a
Author: SomberNight <somber.night@protonmail.com>
Date: Tue, 9 Mar 2021 17:52:36 +0100
daemon/wallet/network: make stop() methods async
Diffstat:
15 files changed, 123 insertions(+), 70 deletions(-)
diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py
@@ -28,6 +28,8 @@ import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List
+from aiorpcx import TaskGroup
+
from . import bitcoin, util
from .bitcoin import COINBASE_MATURITY
from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException
@@ -197,16 +199,19 @@ class AddressSynchronizer(Logger):
def on_blockchain_updated(self, event, *args):
self._get_addr_balance_cache = {} # invalidate cache
- def stop(self):
+ async def stop(self):
if self.network:
- if self.synchronizer:
- asyncio.run_coroutine_threadsafe(self.synchronizer.stop(), self.network.asyncio_loop)
+ try:
+ async with TaskGroup() as group:
+ if self.synchronizer:
+ await group.spawn(self.synchronizer.stop())
+ if self.verifier:
+ await group.spawn(self.verifier.stop())
+ finally: # even if we get cancelled
self.synchronizer = None
- if self.verifier:
- asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop)
self.verifier = None
- util.unregister_callback(self.on_blockchain_updated)
- self.db.put('stored_height', self.get_local_height())
+ util.unregister_callback(self.on_blockchain_updated)
+ self.db.put('stored_height', self.get_local_height())
def add_address(self, address):
if not self.db.get_addr_history(address):
diff --git a/electrum/daemon.py b/electrum/daemon.py
@@ -29,7 +29,7 @@ import time
import traceback
import sys
import threading
-from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping
+from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping, TYPE_CHECKING
from base64 import b64decode, b64encode
from collections import defaultdict
import concurrent
@@ -38,7 +38,7 @@ import json
import aiohttp
from aiohttp import web, client_exceptions
-from aiorpcx import TaskGroup
+from aiorpcx import TaskGroup, timeout_after, TaskTimeout
from . import util
from .network import Network
@@ -53,6 +53,9 @@ from .simple_config import SimpleConfig
from .exchange_rate import FxThread
from .logging import get_logger, Logger
+if TYPE_CHECKING:
+ from electrum import gui
+
_logger = get_logger(__name__)
@@ -407,6 +410,7 @@ class PayServer(Logger):
class Daemon(Logger):
network: Optional[Network]
+ gui_object: Optional[Union['gui.qt.ElectrumGui', 'gui.kivy.ElectrumGui']]
@profiler
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
@@ -523,7 +527,8 @@ class Daemon(Logger):
wallet = self._wallets.pop(path, None)
if not wallet:
return False
- wallet.stop()
+ fut = asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop)
+ fut.result()
return True
def run_daemon(self):
@@ -544,20 +549,28 @@ class Daemon(Logger):
self.running = False
def on_stop(self):
+ self.logger.info("on_stop() entered. initiating shutdown")
if self.gui_object:
self.gui_object.stop()
- # stop network/wallets
- for k, wallet in self._wallets.items():
- wallet.stop()
- if self.network:
- self.logger.info("shutting down network")
- self.network.stop()
- self.logger.info("stopping taskgroup")
- fut = asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.asyncio_loop)
- try:
- fut.result(timeout=2)
- except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError, asyncio.CancelledError):
- pass
+
+ @log_exceptions
+ async def stop_async():
+ self.logger.info("stopping all wallets")
+ async with TaskGroup() as group:
+ for k, wallet in self._wallets.items():
+ await group.spawn(wallet.stop())
+ self.logger.info("stopping network and taskgroup")
+ try:
+ async with timeout_after(2):
+ async with TaskGroup() as group:
+ if self.network:
+ await group.spawn(self.network.stop(full_shutdown=True))
+ await group.spawn(self.taskgroup.cancel_remaining())
+ except TaskTimeout:
+ pass
+
+ fut = asyncio.run_coroutine_threadsafe(stop_async(), self.asyncio_loop)
+ fut.result()
self.logger.info("removing lockfile")
remove_lockfile(get_lockfile(self.config))
self.logger.info("stopped")
diff --git a/electrum/gui/__init__.py b/electrum/gui/__init__.py
@@ -3,3 +3,9 @@
# The Wallet object is instantiated by the GUI
# Notifications about network events are sent to the GUI by using network.register_callback()
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from . import qt
+ from . import kivy
diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py
@@ -190,7 +190,8 @@ class ElectrumWindow(App, Logger):
if self.use_gossip:
self.network.start_gossip()
else:
- self.network.stop_gossip()
+ self.network.run_from_another_thread(
+ self.network.stop_gossip())
android_backups = BooleanProperty(False)
def on_android_backups(self, instance, x):
diff --git a/electrum/gui/qt/settings_dialog.py b/electrum/gui/qt/settings_dialog.py
@@ -141,7 +141,8 @@ channels graph and compute payment path locally, instead of using trampoline pay
if use_gossip:
self.window.network.start_gossip()
else:
- self.window.network.stop_gossip()
+ self.window.network.run_from_another_thread(
+ self.window.network.stop_gossip())
util.trigger_callback('ln_gossip_sync_progress')
# FIXME: update all wallet windows
util.trigger_callback('channels_updated', self.wallet)
diff --git a/electrum/interface.py b/electrum/interface.py
@@ -695,7 +695,7 @@ class Interface(Logger):
# We give up after a while and just abort the connection.
# Note: specifically if the server is running Fulcrum, waiting seems hopeless,
# the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
- force_after = 2 # seconds
+ force_after = 1 # seconds
if self.session:
await self.session.close(force_after=force_after)
# monitor_connection will cancel tasks
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
@@ -147,8 +147,8 @@ class LNWatcher(AddressSynchronizer):
# status gets populated when we run
self.channel_status = {}
- def stop(self):
- super().stop()
+ async def stop(self):
+ await super().stop()
util.unregister_callback(self.on_network_update)
def get_channel_status(self, outpoint):
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -311,11 +311,11 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
self._add_peers_from_config()
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
- def stop(self):
+ async def stop(self):
if self.listen_server:
- self.network.asyncio_loop.call_soon_threadsafe(self.listen_server.close)
- asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.network.asyncio_loop)
+ self.listen_server.close()
util.unregister_callback(self.on_proxy_changed)
+ await self.taskgroup.cancel_remaining()
def _add_peers_from_config(self):
peer_list = self.config.get('lightning_peers', [])
@@ -704,9 +704,9 @@ class LNWallet(LNWorker):
tg_coro = self.taskgroup.spawn(coro)
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
- def stop(self):
- super().stop()
- self.lnwatcher.stop()
+ async def stop(self):
+ await super().stop()
+ await self.lnwatcher.stop()
self.lnwatcher = None
def peer_closed(self, peer):
diff --git a/electrum/network.py b/electrum/network.py
@@ -252,6 +252,11 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
default_server: ServerAddr
_recent_servers: List[ServerAddr]
+ channel_blacklist: 'ChannelBlackList'
+ channel_db: Optional['ChannelDB'] = None
+ lngossip: Optional['LNGossip'] = None
+ local_watchtower: Optional['WatchTower'] = None
+
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
global _INSTANCE
assert _INSTANCE is None, "Network is a singleton!"
@@ -344,9 +349,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
# lightning network
self.channel_blacklist = ChannelBlackList()
- self.channel_db = None # type: Optional[ChannelDB]
- self.lngossip = None # type: Optional[LNGossip]
- self.local_watchtower = None # type: Optional[WatchTower]
if self.config.get('run_local_watchtower', False):
from . import lnwatcher
self.local_watchtower = lnwatcher.WatchTower(self)
@@ -373,11 +375,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self.lngossip = lnworker.LNGossip()
self.lngossip.start_network(self)
- def stop_gossip(self):
+ async def stop_gossip(self, *, full_shutdown: bool = False):
if self.lngossip:
- self.lngossip.stop()
+ await self.lngossip.stop()
self.lngossip = None
self.channel_db.stop()
+ if full_shutdown:
+ await self.channel_db.stopped_event.wait()
self.channel_db = None
def run_from_another_thread(self, coro, *, timeout=None):
@@ -623,7 +627,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self.auto_connect = net_params.auto_connect
if self.proxy != proxy or self.oneserver != net_params.oneserver:
# Restart the network defaulting to the given server
- await self._stop()
+ await self.stop(full_shutdown=False)
self.default_server = server
await self._start()
elif self.default_server != server:
@@ -1217,13 +1221,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
asyncio.run_coroutine_threadsafe(self._start(), self.asyncio_loop)
@log_exceptions
- async def _stop(self, full_shutdown=False):
+ async def stop(self, *, full_shutdown: bool = True):
self.logger.info("stopping network")
try:
# note: cancel_remaining ~cannot be cancelled, it suppresses CancelledError
- await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
+ await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=1)
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
- self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
+ self.logger.info(f"exc during taskgroup cancellation: {repr(e)}")
self.taskgroup = None
self.interface = None
self.interfaces = {}
@@ -1231,13 +1235,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self._closing_ifaces.clear()
if not full_shutdown:
util.trigger_callback('network_updated')
-
- def stop(self):
- assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
- fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop)
- try:
- fut.result(timeout=2)
- except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError): pass
+ if full_shutdown:
+ await self.stop_gossip(full_shutdown=full_shutdown)
async def _ensure_there_is_a_main_interface(self):
if self.is_connected():
diff --git a/electrum/sql_db.py b/electrum/sql_db.py
@@ -25,6 +25,7 @@ class SqlDB(Logger):
Logger.__init__(self)
self.asyncio_loop = asyncio_loop
self.stopping = False
+ self.stopped_event = asyncio.Event()
self.path = path
test_read_write_permissions(path)
self.commit_interval = commit_interval
@@ -65,6 +66,8 @@ class SqlDB(Logger):
# write
self.conn.commit()
self.conn.close()
+
+ self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set)
self.logger.info("SQL thread terminated")
def create_database(self):
diff --git a/electrum/tests/test_storage_upgrade.py b/electrum/tests/test_storage_upgrade.py
@@ -3,10 +3,12 @@ import tempfile
import os
import json
from typing import Optional
+import asyncio
from electrum.wallet_db import WalletDB
from electrum.wallet import Wallet
from electrum import constants
+from electrum import util
from .test_wallet import WalletTestCase
@@ -15,6 +17,15 @@ from .test_wallet import WalletTestCase
# TODO hw wallet with client version 2.6.x (single-, and multiacc)
class TestStorageUpgrade(WalletTestCase):
+ def setUp(self):
+ super().setUp()
+ self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop()
+
+ def tearDown(self):
+ super().tearDown()
+ self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
+ self._loop_thread.join(timeout=1)
+
def testnet_wallet(func):
# note: it's ok to modify global network constants in subclasses of SequentialTestCase
def wrapper(self, *args, **kwargs):
@@ -281,7 +292,7 @@ class TestStorageUpgrade(WalletTestCase):
# to simulate ks.opportunistically_fill_in_missing_info_from_device():
ks._root_fingerprint = "deadbeef"
ks.is_requesting_to_be_rewritten_to_wallet_file = True
- wallet.stop()
+ asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
def test_upgrade_from_client_2_9_3_importedkeys_keystore_changes(self):
# see #6401
@@ -292,7 +303,7 @@ class TestStorageUpgrade(WalletTestCase):
["p2wpkh:L1cgMEnShp73r9iCukoPE3MogLeueNYRD9JVsfT1zVHyPBR3KqBY"],
password=None
)
- wallet.stop()
+ asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
@testnet_wallet
def test_upgrade_from_client_3_3_8_xpub_with_realistic_history(self):
diff --git a/electrum/tests/test_wallet.py b/electrum/tests/test_wallet.py
@@ -5,8 +5,9 @@ import os
import json
from decimal import Decimal
import time
-
from io import StringIO
+import asyncio
+
from electrum.storage import WalletStorage
from electrum.wallet_db import FINAL_SEED_VERSION
from electrum.wallet import (Abstract_Wallet, Standard_Wallet, create_new_wallet,
@@ -16,6 +17,7 @@ from electrum.util import TxMinedInfo, InvalidPassword
from electrum.bitcoin import COIN
from electrum.wallet_db import WalletDB
from electrum.simple_config import SimpleConfig
+from electrum import util
from . import ElectrumTestCase
@@ -237,6 +239,15 @@ class TestCreateRestoreWallet(WalletTestCase):
class TestWalletPassword(WalletTestCase):
+ def setUp(self):
+ super().setUp()
+ self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop()
+
+ def tearDown(self):
+ super().tearDown()
+ self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
+ self._loop_thread.join(timeout=1)
+
def test_update_password_of_imported_wallet(self):
wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}'
db = WalletDB(wallet_str, manual_upgrades=False)
@@ -273,7 +284,7 @@ class TestWalletPassword(WalletTestCase):
db = WalletDB(wallet_str, manual_upgrades=False)
storage = WalletStorage(self.wallet_path)
wallet = Wallet(db, storage, config=self.config)
- wallet.stop()
+ asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
storage = WalletStorage(self.wallet_path)
# if storage.is_encrypted():
diff --git a/electrum/util.py b/electrum/util.py
@@ -1205,11 +1205,9 @@ class NetworkJobOnDefaultServer(Logger, ABC):
if taskgroup != self.taskgroup:
raise asyncio.CancelledError()
- async def stop(self):
- unregister_callback(self._restart)
- await self._stop()
-
- async def _stop(self):
+ async def stop(self, *, full_shutdown: bool = True):
+ if full_shutdown:
+ unregister_callback(self._restart)
await self.taskgroup.cancel_remaining()
@log_exceptions
@@ -1219,7 +1217,7 @@ class NetworkJobOnDefaultServer(Logger, ABC):
return # we should get called again soon
async with self._restart_lock:
- await self._stop()
+ await self.stop(full_shutdown=False)
self._reset()
await self._start(interface)
diff --git a/electrum/wallet.py b/electrum/wallet.py
@@ -46,7 +46,7 @@ import itertools
import threading
import enum
-from aiorpcx import TaskGroup
+from aiorpcx import TaskGroup, timeout_after, TaskTimeout
from .i18n import _
from .bip32 import BIP32Node, convert_bip32_intpath_to_strpath, convert_bip32_path_to_list_of_uint32
@@ -353,15 +353,21 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
ln_xprv = node.to_xprv()
self.db.put('lightning_privkey2', ln_xprv)
- def stop(self):
- super().stop()
- if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]):
- self.save_keystore()
- if self.network:
- if self.lnworker:
- self.lnworker.stop()
- self.lnworker = None
- self.save_db()
+ async def stop(self):
+ """Stop all networking and save DB to disk."""
+ try:
+ async with timeout_after(5):
+ await super().stop()
+ if self.network:
+ if self.lnworker:
+ await self.lnworker.stop()
+ self.lnworker = None
+ except TaskTimeout:
+ pass
+ finally: # even if we get cancelled
+ if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]):
+ self.save_keystore()
+ self.save_db()
def set_up_to_date(self, b):
super().set_up_to_date(b)
diff --git a/run_electrum b/run_electrum
@@ -345,7 +345,6 @@ def main():
print_stderr('unknown command:', uri)
sys.exit(1)
- # singleton
config = SimpleConfig(config_options)
if config.get('testnet'):