commit cd5453e4779c773a6b295a6483db6bdc3ab6a777
parent 150e27608b2d5eaef438f32bc16adc3805b5e331
Author: ThomasV <thomasv@electrum.org>
Date: Wed, 10 Oct 2018 20:46:17 +0200
Merge pull request #4753 from SomberNight/synchronizer_rewrite
restructure synchronizer
Diffstat:
9 files changed, 249 insertions(+), 178 deletions(-)
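For orientation, the class relationships this restructuring introduces, as reconstructed from the hunks below (SynchronizerBase and Notifier are new; SPV and Synchronizer are re-parented onto the new base class):

    NetworkJobOnDefaultServer          (new, electrum/network.py)
      +-- SPV                          (electrum/verifier.py)
      +-- SynchronizerBase             (new, electrum/synchronizer.py)
            +-- Synchronizer           (electrum/synchronizer.py)
            +-- Notifier               (new, electrum/synchronizer.py)
            +-- BalanceMonitor         (new, electrum/websockets.py, replaces WsClientThread)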
diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py
@@ -56,11 +56,9 @@ class AddressSynchronizer(PrintError):
def __init__(self, storage):
self.storage = storage
self.network = None
- # verifier (SPV) and synchronizer are started in start_threads
- self.synchronizer = None
- self.verifier = None
- self.sync_restart_lock = asyncio.Lock()
- self.group = None
+ # verifier (SPV) and synchronizer are started in start_network
+ self.synchronizer = None # type: Synchronizer
+ self.verifier = None # type: SPV
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
self.lock = threading.RLock()
self.transaction_lock = threading.RLock()
@@ -143,45 +141,20 @@ class AddressSynchronizer(PrintError):
# add it in case it was previously unconfirmed
self.add_unverified_tx(tx_hash, tx_height)
- @aiosafe
- async def on_default_server_changed(self, event):
- async with self.sync_restart_lock:
- self.stop_threads(write_to_disk=False)
- await self._start_threads()
-
def start_network(self, network):
self.network = network
if self.network is not None:
- self.network.register_callback(self.on_default_server_changed, ['default_server_changed'])
- asyncio.run_coroutine_threadsafe(self._start_threads(), network.asyncio_loop)
-
- async def _start_threads(self):
- interface = self.network.interface
- if interface is None:
- return # we should get called again soon
-
- self.verifier = SPV(self.network, self)
- self.synchronizer = synchronizer = Synchronizer(self)
- assert self.group is None, 'group already exists'
- self.group = SilentTaskGroup()
-
- async def job():
- async with self.group as group:
- await group.spawn(self.verifier.main(group))
- await group.spawn(self.synchronizer.send_subscriptions(group))
- await group.spawn(self.synchronizer.handle_status(group))
- await group.spawn(self.synchronizer.main())
- # we are being cancelled now
- interface.session.unsubscribe(synchronizer.status_queue)
- await interface.group.spawn(job)
+ self.synchronizer = Synchronizer(self)
+ self.verifier = SPV(self.network, self)
def stop_threads(self, write_to_disk=True):
if self.network:
- self.synchronizer = None
- self.verifier = None
- if self.group:
- asyncio.run_coroutine_threadsafe(self.group.cancel_remaining(), self.network.asyncio_loop)
- self.group = None
+ if self.synchronizer:
+ asyncio.run_coroutine_threadsafe(self.synchronizer.stop(), self.network.asyncio_loop)
+ self.synchronizer = None
+ if self.verifier:
+ asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop)
+ self.verifier = None
self.storage.put('stored_height', self.get_local_height())
if write_to_disk:
self.save_transactions()
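The key behavioural change in this file is that there is no explicit start step any more: constructing Synchronizer or SPV registers the job with the network and schedules its first _restart() on the event loop (see NetworkJobOnDefaultServer in network.py below), so start_network only assigns the instances, and stop_threads schedules their stop() coroutines from the wallet thread. A condensed sketch of that contract, assuming this commit is applied and electrum is importable:

    import asyncio
    from electrum.synchronizer import Synchronizer
    from electrum.verifier import SPV

    def start_network(wallet, network):
        # constructors register with the network and schedule their own _restart()
        wallet.synchronizer = Synchronizer(wallet)   # address subscriptions + history
        wallet.verifier = SPV(network, wallet)       # SPV merkle proofs

    def stop_threads(wallet, network):
        # stop() cancels the job's task group; schedule it from the wallet thread
        asyncio.run_coroutine_threadsafe(wallet.synchronizer.stop(), network.asyncio_loop)
        asyncio.run_coroutine_threadsafe(wallet.verifier.stop(), network.asyncio_loop)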
diff --git a/electrum/commands.py b/electrum/commands.py
@@ -40,7 +40,7 @@ from .bitcoin import is_address, hash_160, COIN, TYPE_ADDRESS
from .i18n import _
from .transaction import Transaction, multisig_script, TxOutput
from .paymentrequest import PR_PAID, PR_UNPAID, PR_UNKNOWN, PR_EXPIRED
-from .plugin import run_hook
+from .synchronizer import Notifier
known_commands = {}
@@ -635,21 +635,11 @@ class Commands:
self.wallet.remove_payment_request(k, self.config)
@command('n')
- def notify(self, address, URL):
+ def notify(self, address: str, URL: str):
"""Watch an address. Every time the address changes, a http POST is sent to the URL."""
- raise NotImplementedError() # TODO this method is currently broken
- def callback(x):
- import urllib.request
- headers = {'content-type':'application/json'}
- data = {'address':address, 'status':x.get('result')}
- serialized_data = util.to_bytes(json.dumps(data))
- try:
- req = urllib.request.Request(URL, serialized_data, headers)
- response_stream = urllib.request.urlopen(req, timeout=5)
- util.print_error('Got Response for %s' % address)
- except BaseException as e:
- util.print_error(str(e))
- self.network.subscribe_to_addresses([address], callback)
+ if not hasattr(self, "_notifier"):
+ self._notifier = Notifier(self.network)
+ self.network.run_from_another_thread(self._notifier.start_watching_queue.put((address, URL)))
return True
@command('wn')
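The command itself stays synchronous; it only hands (address, URL) to the Notifier's queue on the event-loop thread via network.run_from_another_thread. A self-contained sketch of that handoff pattern (hypothetical names, not Electrum's API; it assumes an event loop already running in a background thread):

    import asyncio
    import threading
    import time

    async def watcher(queue: asyncio.Queue) -> None:
        while True:
            addr, url = await queue.get()
            print(f"now watching {addr}; status changes will be POSTed to {url}")

    async def setup() -> asyncio.Queue:
        queue = asyncio.Queue()          # created on the loop thread, like Notifier's queues
        asyncio.ensure_future(watcher(queue))
        return queue

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    queue = asyncio.run_coroutine_threadsafe(setup(), loop).result()

    # what notify() effectively does: block until the item has been enqueued
    asyncio.run_coroutine_threadsafe(
        queue.put(("bc1q-example-address", "https://example.com/hook")), loop).result()
    time.sleep(0.5)    # give the watcher a moment to print before the script exits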
diff --git a/electrum/interface.py b/electrum/interface.py
@@ -28,7 +28,7 @@ import ssl
import sys
import traceback
import asyncio
-from typing import Tuple, Union
+from typing import Tuple, Union, List
from collections import defaultdict
import aiorpcx
@@ -57,7 +57,7 @@ class NotificationSession(ClientSession):
# will catch the exception, count errors, and at some point disconnect
if isinstance(request, Notification):
params, result = request.args[:-1], request.args[-1]
- key = self.get_index(request.method, params)
+ key = self.get_hashable_key_for_rpc_call(request.method, params)
if key in self.subscriptions:
self.cache[key] = result
for queue in self.subscriptions[key]:
@@ -78,10 +78,10 @@ class NotificationSession(ClientSession):
except asyncio.TimeoutError as e:
raise RequestTimedOut('request timed out: {}'.format(args)) from e
- async def subscribe(self, method, params, queue):
+ async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
# note: until the cache is written for the first time,
# each 'subscribe' call might make a request on the network.
- key = self.get_index(method, params)
+ key = self.get_hashable_key_for_rpc_call(method, params)
self.subscriptions[key].append(queue)
if key in self.cache:
result = self.cache[key]
@@ -99,7 +99,7 @@ class NotificationSession(ClientSession):
v.remove(queue)
@classmethod
- def get_index(cls, method, params):
+ def get_hashable_key_for_rpc_call(cls, method, params):
"""Hashable index for subscriptions and cache"""
return str(method) + repr(params)
@@ -141,7 +141,7 @@ class Interface(PrintError):
self._requested_chunks = set()
self.network = network
self._set_proxy(proxy)
- self.session = None
+ self.session = None # type: NotificationSession
self.tip_header = None
self.tip = 0
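A condensed sketch of the bookkeeping the renamed helper serves (not the real NotificationSession, which inherits from aiorpcx's ClientSession): one string key per (method, params) pair indexes both the subscriber queues and the last cached result, so a late subscriber is immediately replayed the cached value instead of triggering another request.

    import asyncio
    from collections import defaultdict

    class SubscriptionBook:
        def __init__(self):
            self.subscriptions = defaultdict(list)   # key -> [asyncio.Queue, ...]
            self.cache = {}                          # key -> last notified result

        @classmethod
        def get_hashable_key_for_rpc_call(cls, method, params):
            """Hashable index for subscriptions and cache (same scheme as above)."""
            return str(method) + repr(params)

        async def subscribe(self, method: str, params: list, queue: asyncio.Queue):
            key = self.get_hashable_key_for_rpc_call(method, params)
            self.subscriptions[key].append(queue)
            if key in self.cache:
                # replay the cached result; otherwise the caller sends the request
                await queue.put(params + [self.cache[key]])

        async def handle_notification(self, method, args):
            params, result = args[:-1], args[-1]
            key = self.get_hashable_key_for_rpc_call(method, params)
            if key in self.subscriptions:
                self.cache[key] = result
                for q in self.subscriptions[key]:
                    await q.put(params + [result])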
diff --git a/electrum/network.py b/electrum/network.py
@@ -847,3 +847,54 @@ class Network(PrintError):
await self.interface.group.spawn(self._request_fee_estimates, self.interface)
await asyncio.sleep(0.1)
+
+
+class NetworkJobOnDefaultServer(PrintError):
+ """An abstract base class for a job that runs on the main network
+ interface. Every time the main interface changes, the job is
+ restarted, and some of its internals are reset.
+ """
+ def __init__(self, network: Network):
+ asyncio.set_event_loop(network.asyncio_loop)
+ self.network = network
+ self.interface = None # type: Interface
+ self._restart_lock = asyncio.Lock()
+ self._reset()
+ asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop)
+ network.register_callback(self._restart, ['default_server_changed'])
+
+ def _reset(self):
+ """Initialise fields. Called every time the underlying
+ server connection changes.
+ """
+ self.group = SilentTaskGroup()
+
+ async def _start(self, interface):
+ self.interface = interface
+ await interface.group.spawn(self._start_tasks)
+
+ async def _start_tasks(self):
+ """Start tasks in self.group. Called every time the underlying
+ server connection changes.
+ """
+ raise NotImplementedError() # implemented by subclasses
+
+ async def stop(self):
+ await self.group.cancel_remaining()
+
+ @aiosafe
+ async def _restart(self, *args):
+ interface = self.network.interface
+ if interface is None:
+ return # we should get called again soon
+
+ async with self._restart_lock:
+ await self.stop()
+ self._reset()
+ await self._start(interface)
+
+ @property
+ def session(self):
+ s = self.interface.session
+ assert s is not None
+ return s
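A minimal subclass sketch (hypothetical ExampleJob, assuming this commit is applied so NetworkJobOnDefaultServer is importable from electrum.network): _reset() re-creates per-connection state and the task group, and _start_tasks() spawns the job's coroutines into self.group, which _restart() cancels wholesale whenever the default server changes.

    import asyncio
    from electrum.network import NetworkJobOnDefaultServer

    class ExampleJob(NetworkJobOnDefaultServer):
        def _reset(self):
            super()._reset()      # recreates self.group (a fresh SilentTaskGroup)
            self.ticks = 0        # hypothetical per-connection state, wiped on restart

        async def _start_tasks(self):
            async with self.group as group:
                await group.spawn(self.main())

        async def main(self):
            while True:
                # self.session is the NotificationSession of the current interface
                await asyncio.sleep(1)
                self.ticks += 1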
diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py
@@ -24,12 +24,15 @@
# SOFTWARE.
import asyncio
import hashlib
+from typing import Dict, List
+from collections import defaultdict
from aiorpcx import TaskGroup, run_in_thread
from .transaction import Transaction
-from .util import bh2u, PrintError
+from .util import bh2u, make_aiohttp_session
from .bitcoin import address_to_scripthash
+from .network import NetworkJobOnDefaultServer
def history_status(h):
@@ -41,7 +44,68 @@ def history_status(h):
return bh2u(hashlib.sha256(status.encode('ascii')).digest())
-class Synchronizer(PrintError):
+class SynchronizerBase(NetworkJobOnDefaultServer):
+ """Subscribe over the network to a set of addresses, and monitor their statuses.
+ Every time a status changes, run a coroutine provided by the subclass.
+ """
+ def __init__(self, network):
+ NetworkJobOnDefaultServer.__init__(self, network)
+ self.asyncio_loop = network.asyncio_loop
+
+ def _reset(self):
+ super()._reset()
+ self.requested_addrs = set()
+ self.scripthash_to_address = {}
+ self._processed_some_notifications = False # so that we don't miss them
+ # Queues
+ self.add_queue = asyncio.Queue()
+ self.status_queue = asyncio.Queue()
+
+ async def _start_tasks(self):
+ try:
+ async with self.group as group:
+ await group.spawn(self.send_subscriptions())
+ await group.spawn(self.handle_status())
+ await group.spawn(self.main())
+ finally:
+ # we are being cancelled now
+ self.session.unsubscribe(self.status_queue)
+
+ def add(self, addr):
+ asyncio.run_coroutine_threadsafe(self._add_address(addr), self.asyncio_loop)
+
+ async def _add_address(self, addr):
+ if addr in self.requested_addrs: return
+ self.requested_addrs.add(addr)
+ await self.add_queue.put(addr)
+
+ async def _on_address_status(self, addr, status):
+ """Handle the change of the status of an address."""
+ raise NotImplementedError() # implemented by subclasses
+
+ async def send_subscriptions(self):
+ async def subscribe_to_address(addr):
+ h = address_to_scripthash(addr)
+ self.scripthash_to_address[h] = addr
+ await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue)
+ self.requested_addrs.remove(addr)
+
+ while True:
+ addr = await self.add_queue.get()
+ await self.group.spawn(subscribe_to_address, addr)
+
+ async def handle_status(self):
+ while True:
+ h, status = await self.status_queue.get()
+ addr = self.scripthash_to_address[h]
+ await self.group.spawn(self._on_address_status, addr, status)
+ self._processed_some_notifications = True
+
+ async def main(self):
+ raise NotImplementedError() # implemented by subclasses
+
+
+class Synchronizer(SynchronizerBase):
'''The synchronizer keeps the wallet up-to-date with its set of
addresses and their transactions. It subscribes over the network
to wallet addresses, gets the wallet to generate new addresses
@@ -51,16 +115,12 @@ class Synchronizer(PrintError):
'''
def __init__(self, wallet):
self.wallet = wallet
- self.network = wallet.network
- self.asyncio_loop = wallet.network.asyncio_loop
+ SynchronizerBase.__init__(self, wallet.network)
+
+ def _reset(self):
+ super()._reset()
self.requested_tx = {}
self.requested_histories = {}
- self.requested_addrs = set()
- self.scripthash_to_address = {}
- self._processed_some_notifications = False # so that we don't miss them
- # Queues
- self.add_queue = asyncio.Queue()
- self.status_queue = asyncio.Queue()
def diagnostic_name(self):
return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name())
@@ -70,14 +130,6 @@ class Synchronizer(PrintError):
and not self.requested_histories
and not self.requested_tx)
- def add(self, addr):
- asyncio.run_coroutine_threadsafe(self._add(addr), self.asyncio_loop)
-
- async def _add(self, addr):
- if addr in self.requested_addrs: return
- self.requested_addrs.add(addr)
- await self.add_queue.put(addr)
-
async def _on_address_status(self, addr, status):
history = self.wallet.history.get(addr, [])
if history_status(history) == status:
@@ -144,30 +196,6 @@ class Synchronizer(PrintError):
# callbacks
self.wallet.network.trigger_callback('new_transaction', self.wallet, tx)
- async def send_subscriptions(self, group: TaskGroup):
- async def subscribe_to_address(addr):
- h = address_to_scripthash(addr)
- self.scripthash_to_address[h] = addr
- await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue)
- self.requested_addrs.remove(addr)
-
- while True:
- addr = await self.add_queue.get()
- await group.spawn(subscribe_to_address, addr)
-
- async def handle_status(self, group: TaskGroup):
- while True:
- h, status = await self.status_queue.get()
- addr = self.scripthash_to_address[h]
- await group.spawn(self._on_address_status, addr, status)
- self._processed_some_notifications = True
-
- @property
- def session(self):
- s = self.wallet.network.interface.session
- assert s is not None
- return s
-
async def main(self):
self.wallet.set_up_to_date(False)
# request missing txns, if any
@@ -178,7 +206,7 @@ class Synchronizer(PrintError):
await self._request_missing_txs(history)
# add addresses to bootstrap
for addr in self.wallet.get_addresses():
- await self._add(addr)
+ await self._add_address(addr)
# main loop
while True:
await asyncio.sleep(0.1)
@@ -189,3 +217,37 @@ class Synchronizer(PrintError):
self._processed_some_notifications = False
self.wallet.set_up_to_date(up_to_date)
self.wallet.network.trigger_callback('wallet_updated', self.wallet)
+
+
+class Notifier(SynchronizerBase):
+ """Watch addresses. Every time the status of an address changes,
+ an HTTP POST is sent to the corresponding URL.
+ """
+ def __init__(self, network):
+ SynchronizerBase.__init__(self, network)
+ self.watched_addresses = defaultdict(list) # type: Dict[str, List[str]]
+ self.start_watching_queue = asyncio.Queue()
+
+ async def main(self):
+ # resend existing subscriptions if we were restarted
+ for addr in self.watched_addresses:
+ await self._add_address(addr)
+ # main loop
+ while True:
+ addr, url = await self.start_watching_queue.get()
+ self.watched_addresses[addr].append(url)
+ await self._add_address(addr)
+
+ async def _on_address_status(self, addr, status):
+ self.print_error('new status for addr {}'.format(addr))
+ headers = {'content-type': 'application/json'}
+ data = {'address': addr, 'status': status}
+ for url in self.watched_addresses[addr]:
+ try:
+ async with make_aiohttp_session(proxy=self.network.proxy, headers=headers) as session:
+ async with session.post(url, json=data, headers=headers) as resp:
+ await resp.text()
+ except Exception as e:
+ self.print_error(str(e))
+ else:
+ self.print_error('Got Response for {}'.format(addr))
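For reference, the status that arrives on status_queue is the Electrum-protocol scripthash status. A sketch of the hashing that history_status (near the top of this file) performs, so _on_address_status can compare it against the locally known history and skip the round trip when nothing changed; the 'txid:height:' serialization follows the protocol's status definition, and this is an illustration rather than a verbatim copy of the helper:

    import hashlib

    def history_status(h):
        if not h:
            return None
        status = ''.join(f'{tx_hash}:{height}:' for tx_hash, height in h)
        return hashlib.sha256(status.encode('ascii')).hexdigest()

    # Synchronizer._on_address_status does essentially:
    #   if history_status(self.wallet.history.get(addr, [])) == status:
    #       return   # nothing new for this address, no history request needed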
diff --git a/electrum/util.py b/electrum/util.py
@@ -879,7 +879,12 @@ VerifiedTxInfo = NamedTuple("VerifiedTxInfo", [("height", int),
("txpos", int),
("header_hash", str)])
-def make_aiohttp_session(proxy):
+
+def make_aiohttp_session(proxy: dict, headers=None, timeout=None):
+ if headers is None:
+ headers = {'User-Agent': 'Electrum'}
+ if timeout is None:
+ timeout = aiohttp.ClientTimeout(total=10)
if proxy:
connector = SocksConnector(
socks_ver=SocksVer.SOCKS5 if proxy['mode'] == 'socks5' else SocksVer.SOCKS4,
@@ -889,9 +894,9 @@ def make_aiohttp_session(proxy):
password=proxy.get('password', None),
rdns=True
)
- return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10), connector=connector)
+ return aiohttp.ClientSession(headers=headers, timeout=timeout, connector=connector)
else:
- return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10))
+ return aiohttp.ClientSession(headers=headers, timeout=timeout)
class SilentTaskGroup(TaskGroup):
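Example use of the widened signature, mirroring how Notifier calls it above (proxy is either None or Electrum's proxy dict, e.g. {'mode': 'socks5', 'host': ..., 'port': ...}; this assumes electrum and aiohttp are installed):

    import asyncio
    from electrum.util import make_aiohttp_session

    async def post_json(proxy, url: str, payload: dict) -> str:
        headers = {'content-type': 'application/json'}
        async with make_aiohttp_session(proxy=proxy, headers=headers) as session:
            async with session.post(url, json=payload) as resp:
                return await resp.text()

    # asyncio.get_event_loop().run_until_complete(
    #     post_json(None, 'https://example.com/hook', {'status': 'paid'}))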
diff --git a/electrum/verifier.py b/electrum/verifier.py
@@ -25,14 +25,14 @@ import asyncio
from typing import Sequence, Optional
import aiorpcx
-from aiorpcx import TaskGroup
-from .util import PrintError, bh2u, VerifiedTxInfo
+from .util import bh2u, VerifiedTxInfo
from .bitcoin import Hash, hash_decode, hash_encode
from .transaction import Transaction
from .blockchain import hash_header
from .interface import GracefulDisconnect
from . import constants
+from .network import NetworkJobOnDefaultServer
class MerkleVerificationFailure(Exception): pass
@@ -41,26 +41,33 @@ class MerkleRootMismatch(MerkleVerificationFailure): pass
class InnerNodeOfSpvProofIsValidTx(MerkleVerificationFailure): pass
-class SPV(PrintError):
+class SPV(NetworkJobOnDefaultServer):
""" Simple Payment Verification """
def __init__(self, network, wallet):
+ NetworkJobOnDefaultServer.__init__(self, network)
self.wallet = wallet
- self.network = network
+
+ def _reset(self):
+ super()._reset()
self.merkle_roots = {} # txid -> merkle root (once it has been verified)
self.requested_merkle = set() # txid set of pending requests
+ async def _start_tasks(self):
+ async with self.group as group:
+ await group.spawn(self.main)
+
def diagnostic_name(self):
return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name())
- async def main(self, group: TaskGroup):
+ async def main(self):
self.blockchain = self.network.blockchain()
while True:
await self._maybe_undo_verifications()
- await self._request_proofs(group)
+ await self._request_proofs()
await asyncio.sleep(0.1)
- async def _request_proofs(self, group: TaskGroup):
+ async def _request_proofs(self):
local_height = self.blockchain.height()
unverified = self.wallet.get_unverified_txs()
@@ -75,12 +82,12 @@ class SPV(PrintError):
header = self.blockchain.read_header(tx_height)
if header is None:
if tx_height < constants.net.max_checkpoint():
- await group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True))
+ await self.group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True))
continue
# request now
self.print_error('requested merkle', tx_hash)
self.requested_merkle.add(tx_hash)
- await group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height)
+ await self.group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height)
async def _request_and_verify_single_proof(self, tx_hash, tx_height):
try:
diff --git a/electrum/websockets.py b/electrum/websockets.py
@@ -22,44 +22,49 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-import queue
-import threading, os, json
+import threading
+import os
+import json
from collections import defaultdict
+import asyncio
+from typing import Dict, List, Tuple
+import traceback
+import sys
+
try:
from SimpleWebSocketServer import WebSocket, SimpleSSLWebSocketServer
except ImportError:
- import sys
sys.exit("install SimpleWebSocketServer")
-from . import util
+from .util import PrintError
from . import bitcoin
+from .synchronizer import SynchronizerBase
+
+request_queue = asyncio.Queue()
-request_queue = queue.Queue()
-class ElectrumWebSocket(WebSocket):
+class ElectrumWebSocket(WebSocket, PrintError):
def handleMessage(self):
assert self.data[0:3] == 'id:'
- util.print_error("message received", self.data)
+ self.print_error("message received", self.data)
request_id = self.data[3:]
- request_queue.put((self, request_id))
+ asyncio.run_coroutine_threadsafe(
+ request_queue.put((self, request_id)), asyncio.get_event_loop())
def handleConnected(self):
- util.print_error("connected", self.address)
+ self.print_error("connected", self.address)
def handleClose(self):
- util.print_error("closed", self.address)
+ self.print_error("closed", self.address)
-
-class WsClientThread(util.DaemonThread):
+class BalanceMonitor(SynchronizerBase):
def __init__(self, config, network):
- util.DaemonThread.__init__(self)
- self.network = network
+ SynchronizerBase.__init__(self, network)
self.config = config
- self.response_queue = queue.Queue()
- self.subscriptions = defaultdict(list)
+ self.expected_payments = defaultdict(list) # type: Dict[str, List[Tuple[WebSocket, int]]]
def make_request(self, request_id):
# read json file
@@ -72,69 +77,47 @@ class WsClientThread(util.DaemonThread):
amount = d.get('amount')
return addr, amount
- def reading_thread(self):
- while self.is_running():
- try:
- ws, request_id = request_queue.get()
- except queue.Empty:
- continue
+ async def main(self):
+ # resend existing subscriptions if we were restarted
+ for addr in self.expected_payments:
+ await self._add_address(addr)
+ # main loop
+ while True:
+ ws, request_id = await request_queue.get()
try:
addr, amount = self.make_request(request_id)
- except:
+ except Exception:
+ traceback.print_exc(file=sys.stderr)
continue
- l = self.subscriptions.get(addr, [])
- l.append((ws, amount))
- self.subscriptions[addr] = l
- self.network.subscribe_to_addresses([addr], self.response_queue.put)
-
- def run(self):
- threading.Thread(target=self.reading_thread).start()
- while self.is_running():
- try:
- r = self.response_queue.get(timeout=0.1)
- except queue.Empty:
- continue
- util.print_error('response', r)
- method = r.get('method')
- result = r.get('result')
- if result is None:
- continue
- if method == 'blockchain.scripthash.subscribe':
- addr = r.get('params')[0]
- scripthash = bitcoin.address_to_scripthash(addr)
- self.network.get_balance_for_scripthash(
- scripthash, self.response_queue.put)
- elif method == 'blockchain.scripthash.get_balance':
- scripthash = r.get('params')[0]
- addr = self.network.h2addr.get(scripthash, None)
- if addr is None:
- util.print_error(
- "can't find address for scripthash: %s" % scripthash)
- l = self.subscriptions.get(addr, [])
- for ws, amount in l:
- if not ws.closed:
- if sum(result.values()) >=amount:
- ws.sendMessage('paid')
+ self.expected_payments[addr].append((ws, amount))
+ await self._add_address(addr)
+ async def _on_address_status(self, addr, status):
+ self.print_error('new status for addr {}'.format(addr))
+ sh = bitcoin.address_to_scripthash(addr)
+ balance = await self.network.get_balance_for_scripthash(sh)
+ for ws, amount in self.expected_payments[addr]:
+ if not ws.closed:
+ if sum(balance.values()) >= amount:
+ ws.sendMessage('paid')
class WebSocketServer(threading.Thread):
- def __init__(self, config, ns):
+ def __init__(self, config, network):
threading.Thread.__init__(self)
self.config = config
- self.net_server = ns
+ self.network = network
+ asyncio.set_event_loop(network.asyncio_loop)
self.daemon = True
+ self.balance_monitor = BalanceMonitor(self.config, self.network)
+ self.start()
def run(self):
- t = WsClientThread(self.config, self.net_server)
- t.start()
-
+ asyncio.set_event_loop(self.network.asyncio_loop)
host = self.config.get('websocket_server')
port = self.config.get('websocket_port', 9999)
certfile = self.config.get('ssl_chain')
keyfile = self.config.get('ssl_privkey')
self.server = SimpleSSLWebSocketServer(host, port, ElectrumWebSocket, certfile, keyfile)
self.server.serveforever()
-
-
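A hypothetical client-side check against the new BalanceMonitor flow (this is not part of Electrum; it assumes the third-party websocket-client package and the daemon's self-signed certificate): the client sends 'id:<request_id>', the server looks up the matching payment request, subscribes to its address, and pushes a single 'paid' message once the watched balance reaches the requested amount.

    import ssl
    import websocket   # third-party 'websocket-client' package (assumption)

    ws = websocket.create_connection(
        'wss://localhost:9999/',                 # websocket_port defaults to 9999 above
        sslopt={'cert_reqs': ssl.CERT_NONE},     # accept the self-signed ssl_chain cert
    )
    ws.send('id:some-request-id')                # must match a request in requests_dir
    print(ws.recv())                             # blocks until BalanceMonitor sends 'paid'
    ws.close()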
diff --git a/run_electrum b/run_electrum
@@ -438,7 +438,7 @@ if __name__ == '__main__':
d = daemon.Daemon(config, fd)
if config.get('websocket_server'):
from electrum import websockets
- websockets.WebSocketServer(config, d.network).start()
+ websockets.WebSocketServer(config, d.network)
if config.get('requests_dir'):
path = os.path.join(config.get('requests_dir'), 'index.html')
if not os.path.exists(path):