commit f6449ea78a20d6ef3d62d7dc00de34ec05bbfc10
parent 0c8ef25aea1b5e0bab605b950f77d93279861c9b
Author: parazyd <parazyd@dyne.org>
Date: Mon, 19 Apr 2021 17:59:10 +0200
protocol: Do task cleanup when peer disconnects.
Diffstat:
3 files changed, 94 insertions(+), 63 deletions(-)
diff --git a/obelisk/protocol.py b/obelisk/protocol.py
@@ -65,13 +65,9 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.endpoints = endpoints
self.server_cfg = server_cfg
self.loop = asyncio.get_event_loop()
- self.chain_tip = 0
- # Consider renaming bx to something else
self.bx = Client(log, endpoints, self.loop)
self.block_queue = None
- # TODO: Clean up on client disconnect
- self.tasks = []
- self.sh_subscriptions = {}
+ self.peers = {}
if chain == "mainnet": # pragma: no cover
self.genesis = "000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f"
@@ -112,28 +108,32 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.log.debug("ElectrumProtocol.stop()")
self.stopped = True
if self.bx:
- # unsub_pool = []
- # for i in self.sh_subscriptions: # pragma: no cover
- # self.log.debug("bx.unsubscribe %s", i)
- # unsub_pool.append(self.bx.unsubscribe_scripthash(i))
- # await asyncio.gather(*unsub_pool, return_exceptions=True)
+ for i in self.peers:
+ await self._peer_cleanup(i)
await self.bx.stop()
- # idxs = []
- # for task in self.tasks:
- # idxs.append(self.tasks.index(task))
- # task.cancel()
- # for i in idxs:
- # del self.tasks[i]
+ async def _peer_cleanup(self, peer):
+ """Cleanup tasks and data for peer"""
+ self.log.debug("Cleaning up data for %s", peer)
+ for i in self.peers[peer]["tasks"]:
+ i.cancel()
+ for i in self.peers[peer]["sh"]:
+ self.peers[peer]["sh"][i]["task"].cancel()
+
+ @staticmethod
+ def _get_peer(writer):
+ peer_t = writer._transport.get_extra_info("peername") # pylint: disable=W0212
+ return f"{peer_t[0]}:{peer_t[1]}"
async def recv(self, reader, writer):
"""Loop ran upon a connection which acts as a JSON-RPC handler"""
recv_buf = bytearray()
+ self.peers[self._get_peer(writer)] = {"tasks": [], "sh": {}}
+
while not self.stopped:
data = await reader.read(4096)
if not data or len(data) == 0:
- self.log.debug("Received EOF, disconnect")
- # TODO: cancel asyncio tasks for this client here?
+ await self._peer_cleanup(self._get_peer(writer))
return
recv_buf.extend(data)
lb = recv_buf.find(b"\n")
@@ -181,12 +181,7 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
async def handle_query(self, writer, query): # pylint: disable=R0915,R0912,R0911
"""Electrum protocol method handler mapper"""
- if "method" not in query:
- self.log.debug("No 'method' in query: %s", query)
- return await self._send_reply(writer, JsonRPCError.invalidrequest(),
- None)
- if "id" not in query:
- self.log.debug("No 'id' in query: %s", query)
+ if "method" not in query or "id" not in query:
return await self._send_reply(writer, JsonRPCError.invalidrequest(),
None)
@@ -304,13 +299,11 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.block_queue = asyncio.Queue()
await self.bx.subscribe_to_blocks(self.block_queue)
while True:
- # item = (seq, height, block_data)
item = await self.block_queue.get()
if len(item) != 3:
self.log.debug("error: item from block queue len != 3")
continue
- self.chain_tip = item[1]
header = block_to_header(item[2])
params = [{"height": item[1], "hex": safe_hexlify(header)}]
await self._send_notification(writer,
@@ -331,8 +324,8 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
self.log.debug("Got error: %s", repr(_ec))
return JsonRPCError.internalerror()
- self.chain_tip = height
- self.tasks.append(asyncio.create_task(self.header_notifier(writer)))
+ self.peers[self._get_peer(writer)]["tasks"].append(
+ asyncio.create_task(self.header_notifier(writer)))
ret = {"height": height, "hex": safe_hexlify(tip_header)}
return {"result": ret}
@@ -428,32 +421,56 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
return {"result": ret}
+ async def scripthash_renewer(self, scripthash, queue):
+ while True:
+ try:
+ self.log.debug("scriphash renewer: %s", scripthash)
+ _ec = await self.bx.subscribe_scripthash(scripthash, queue)
+ if _ec and _ec != 0:
+ self.log.error("bx.subscribe_scripthash failed: %s",
+ repr(_ec))
+ await asyncio.sleep(60)
+ except asyncio.CancelledError:
+ self.log.debug("%s renewer cancelled", scripthash)
+ break
+
async def scripthash_notifier(self, writer, scripthash):
# TODO: Mempool
+ # TODO: This is still flaky and not always notified. Investigate.
+ self.log.debug("notifier")
method = "blockchain.scripthash.subscribe"
- while True:
- _ec, sh_queue = await self.bx.subscribe_scripthash(scripthash)
- if _ec and _ec != 0:
- self.log.error("bx.subscribe_scripthash failed: %s", repr(_ec))
- return
-
- item = await sh_queue.get()
- _ec, height, txid = struct.unpack("<HI32s", item)
+ queue = asyncio.Queue()
+ renew_task = asyncio.create_task(
+ self.scripthash_renewer(scripthash, queue))
- if (_ec == ErrorCode.service_stopped.value and height == 0 and
- not self.stopped):
- # Subscription expired
- continue
-
- self.sh_subscriptions[scripthash]["status"].append(
- (hash_to_hex_str(txid), height))
-
- params = [
- scripthash,
- ElectrumProtocol.__scripthash_status_encode(
- self.sh_subscriptions[scripthash]["status"]),
- ]
- await self._send_notification(writer, method, params)
+ while True:
+ try:
+ item = await queue.get()
+ _ec, height, txid = struct.unpack("<HI32s", item)
+
+ if (_ec == ErrorCode.service_stopped.value and height == 0 and
+ not self.stopped):
+ self.log.debug("subscription expired: %s", scripthash)
+ # Subscription expired
+ continue
+
+ self.peers[self._get_peer(writer)]["sh"]["status"].append(
+ (hash_to_hex_str(txid), height))
+
+ self.log.debug("shnotifier: Got _ec: %d", _ec)
+ self.log.debug("shnotifier: Got height: %d", height)
+ self.log.debug("shnotifier: Got txid: %s",
+ hash_to_hex_str(txid))
+
+ params = [
+ scripthash,
+ ElectrumProtocol.__scripthash_status_encode(self.peers[
+ self._get_peer(writer)]["sh"]["scripthash"]["status"]),
+ ]
+ await self._send_notification(writer, method, params)
+ except asyncio.CancelledError:
+ break
+ renew_task.cancel()
async def scripthash_subscribe(self, writer, query): # pylint: disable=W0613
"""Method: blockchain.scripthash.subscribe
@@ -470,16 +487,17 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
if _ec and _ec != 0:
return JsonRPCError.internalerror()
- if len(history) < 1:
- return {"result": None}
-
# TODO: Check how history4 acts for mempool/unconfirmed
status = ElectrumProtocol.__scripthash_status_from_history(history)
- self.sh_subscriptions[scripthash] = {"status": status}
+ self.peers[self._get_peer(writer)]["sh"][scripthash] = {
+ "status": status
+ }
task = asyncio.create_task(self.scripthash_notifier(writer, scripthash))
- self.sh_subscriptions[scripthash]["task"] = task
+ self.peers[self._get_peer(writer)]["sh"][scripthash]["task"] = task
+ if len(history) < 1:
+ return {"result": None}
return {"result": ElectrumProtocol.__scripthash_status_encode(status)}
@staticmethod
@@ -517,10 +535,11 @@ class ElectrumProtocol(asyncio.Protocol): # pylint: disable=R0904,R0902
if not is_hash256_str(scripthash):
return JsonRPCError.invalidparams()
- if scripthash in self.sh_subscriptions:
- self.sh_subscriptions[scripthash]["task"].cancel()
+ if scripthash in self.peers[self._get_peer(writer)]["sh"]:
+ self.peers[self._get_peer(
+ writer)]["sh"][scripthash]["task"].cancel()
# await self.bx.unsubscribe_scripthash(scripthash)
- del self.sh_subscriptions[scripthash]
+ del self.peers[self._get_peer(writer)]["sh"][scripthash]
return {"result": True}
return {"result": False}
diff --git a/obelisk/zeromq.py b/obelisk/zeromq.py
@@ -266,11 +266,11 @@ class Client:
socket.connect(self._endpoints["query"])
return socket
- async def _subscription_request(self, command, data):
+ async def _subscription_request(self, command, data, queue):
request = await self._request(command, data)
- request.queue = asyncio.Queue()
+ request.queue = queue
error_code, _ = await self._wait_for_response(request)
- return error_code, request.queue
+ return error_code
async def _simple_request(self, command, data):
return await self._wait_for_response(await self._request(command, data))
@@ -345,11 +345,11 @@ class Client:
return error_code, None
return error_code, data
- async def subscribe_scripthash(self, scripthash):
+ async def subscribe_scripthash(self, scripthash, queue):
"""Subscribe to scripthash"""
command = b"subscribe.key"
decoded_address = unhexlify(scripthash)
- return await self._subscription_request(command, decoded_address)
+ return await self._subscription_request(command, decoded_address, queue)
async def unsubscribe_scripthash(self, scripthash):
"""Unsubscribe scripthash"""
diff --git a/tests/test_electrum_protocol.py b/tests/test_electrum_protocol.py
@@ -399,11 +399,21 @@ async def test_send_reply(protocol, writer, method):
assert_equal(writer.mock, expect)
+class MockTransport:
+
+ def __init__(self):
+ self.peername = ("foo", 42)
+
+ def get_extra_info(self, param):
+ return self.peername
+
+
class MockWriter(asyncio.StreamWriter): # pragma: no cover
"""Mock class for StreamWriter"""
def __init__(self):
self.mock = None
+ self._transport = MockTransport()
def write(self, data):
self.mock = data
@@ -455,6 +465,8 @@ async def main():
protocol = ElectrumProtocol(log, "testnet", libbitcoin, {})
writer = MockWriter()
+ protocol.peers[protocol._get_peer(writer)] = {"tasks": [], "sh": {}}
+
for func in orchestration:
try:
await orchestration[func](protocol, writer, func)