zeromq.py (16384B)
#!/usr/bin/env python3
# Copyright (C) 2020-2021 Ivan J. <parazyd@dyne.org>
#
# This file is part of obelisk
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License version 3
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""ZeroMQ implementation for libbitcoin"""
import asyncio
import functools
import struct
from binascii import unhexlify
from random import randint

import zmq
import zmq.asyncio

from electrumobelisk.libbitcoin_errors import make_error_code, ErrorCode
from electrumobelisk.util import bh2u


def create_random_id():
    """Generate a random request ID in the unsigned 32-bit range"""
    max_uint32 = 4294967295
    return randint(0, max_uint32)


def pack_block_index(index):
    """Serialize a block index for a libbitcoin query.

    A `str` is treated as a hex string and must decode to exactly 32
    bytes; the decoded bytes are returned as-is. An `int` is packed as a
    little-endian unsigned 32-bit integer. Any other type raises
    ValueError.
    """
    if isinstance(index, str):
        index = unhexlify(index)
        assert len(index) == 32
        return index
    if isinstance(index, int):
        return struct.pack("<I", index)

    raise ValueError(
        f"Unknown index type {type(index)} v:{index}, should be int or hex str"
    )


def to_int(xbytes):
    """Make little-endian integer from given bytes"""
    return int.from_bytes(xbytes, byteorder="little")


def checksum(xhash, index):
    """
    This method takes a transaction hash and an index and returns a checksum.

    The hex hash is decoded and byte-reversed, and the bytes from offset 12
    onwards are read as a little-endian integer. The 49 bits selected by
    the mask are kept from that integer and combined with the lowest 15
    bits of the index.
    """
    mask = 0xFFFFFFFFFFFF8000
    magic_start_position = 12

    hash_bytes = bytes.fromhex(xhash)[::-1]
    last_20_bytes = hash_bytes[magic_start_position:]

    assert len(hash_bytes) == 32
    assert index < 2**32

    hash_upper_49_bits = to_int(last_20_bytes) & mask
    index_lower_15_bits = index & ~mask
    return hash_upper_49_bits | index_lower_15_bits


def unpack_table(row_fmt, data):
    """Function to unpack table received from libbitcoin

    `data` is read as consecutive fixed-size rows laid out per the struct
    format `row_fmt`. Trailing bytes that do not fill a whole row are
    ignored. Returns a list with one tuple per row.
    """
    row_size = struct.calcsize(row_fmt)
    nrows = len(data) // row_size
    return [
        struct.unpack_from(row_fmt, data, idx * row_size)
        for idx in range(nrows)
    ]


class ClientSettings:
    """Class implementing ZMQ client settings"""
    def __init__(self, timeout=10, context=None, loop=None):
        self._timeout = timeout
        self._context = context
        self._loop = loop

    @property
    def context(self):
        """ZMQ context property; created on first access if none was given"""
        if self._context is None:
            ctx = zmq.asyncio.Context()
            ctx.linger = 500  # in milliseconds
            self._context = ctx
        return self._context

    @context.setter
    def context(self, context):
        self._context = context

    @property
    def timeout(self):
        """Set to None for no timeout"""
        return self._timeout

    @timeout.setter
    def timeout(self, timeout):
        self._timeout = timeout


class Request:
    """Class implementing a _send_ request.
    This is either a simple request/response affair or a subscription.
    """
    def __init__(self, socket, command, data):
        self.id_ = create_random_id()
        self.socket = socket
        self.command = command
        self.data = data
        # Resolved by RequestCollection with the matching Response
        self.future = asyncio.Future()
        # Left as None for simple requests; a queue marks a subscription
        self.queue = None

    async def send(self):
        """Send the ZMQ request as [command, packed request ID, data]"""
        request = [self.command, struct.pack("<I", self.id_), self.data]
        await self.socket.send_multipart(request)

    def is_subscription(self):
        """If the request is a subscription, then the response to this
        request is a notification.
        """
        return self.queue is not None

    def __str__(self):
        return f"Request(command, ID) {self.command}, {self.id_:d}"


class InvalidServerResponseException(Exception):
    """Exception for invalid server responses"""


class Response:
    """Class implementing a request response

    The frame must have exactly three parts: the command, the request ID
    packed as a little-endian uint32, and a payload whose first four
    bytes are a little-endian uint32 error code followed by the data.
    Raises InvalidServerResponseException if the frame does not have
    three parts or the ID/error code cannot be unpacked.
    """
    def __init__(self, frame):
        if len(frame) != 3:
            raise InvalidServerResponseException(
                f"Length of the frame was not 3: {len(frame)}")

        self.command = frame[0]
        try:
            self.request_id = struct.unpack("<I", frame[1])[0]
            error_code = struct.unpack("<I", frame[2][:4])[0]
        except struct.error as exc:
            # Report short/oversized parts with the module's own exception
            # type instead of leaking a bare struct.error.
            raise InvalidServerResponseException(
                f"Malformed frame: {exc}") from exc
        self.error_code = make_error_code(error_code)
        self.data = frame[2][4:]

    def is_bound_for_queue(self):
        """True if the response carries data after the error code"""
        return len(self.data) > 0

    def __str__(self):
        return (
            "Response(command, request ID, error code, data):" +
            f" {self.command}, {self.request_id}, {self.error_code}, {self.data}"
        )


class RequestCollection:
    """RequestCollection carries a list of Requests and matches incoming
    responses to them.
    """
    def __init__(self, socket, loop):
        self._socket = socket
        self._requests = {}
        # Background task that keeps reading responses from the socket
        self._task = asyncio.ensure_future(self._run(), loop=loop)

    async def _run(self):
        while True:
            await self._receive()

    async def stop(self):
        """Stops listening for incoming responses (or subscription) messages.
        Returns the number of _responses_ expected but which are now dropped
        on the floor.
        """
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        return len(self._requests)

    async def _receive(self):
        frame = await self._socket.recv_multipart()
        try:
            response = Response(frame)
        except InvalidServerResponseException as exc:
            # A malformed frame must not terminate the receive loop,
            # otherwise every later request would silently time out.
            print(
                f"DEBUG: RequestCollection dropped malformed response: {exc}")
            return

        if response.request_id in self._requests:
            self._handle_response(response)
        else:
            print(
                f"DEBUG: RequestCollection unhandled response {response.command}:{response.request_id}"  # pylint: disable=C0301
            )

    def _handle_response(self, response):
        request = self._requests[response.request_id]

        if request.is_subscription():
            if response.is_bound_for_queue():
                # TODO: decode the data into something usable
                request.queue.put_nowait(response.data)
                return
        else:
            self.delete_request(request)

        # A subscription may get more than one data-less response; setting
        # the result on a finished future raises InvalidStateError, which
        # would kill the receive loop.
        if not request.future.done():
            request.future.set_result(response)

    def add_request(self, request):
        """Register a request so its response can be matched by ID"""
        # TODO: we should maybe check if the request.id_ is unique
        self._requests[request.id_] = request

    def delete_request(self, request):
        """Forget a request; a request that is not registered is ignored"""
        self._requests.pop(request.id_, None)


class Client:
    """This class represents a connection to a libbitcoin server."""
    def __init__(self, log, endpoints, loop):
        self.log = log
        self._endpoints = endpoints
        self._settings = ClientSettings(loop=loop)
        self._query_socket = self._create_query_socket()
        self._block_socket = self._create_block_socket()
        self._request_collection = RequestCollection(self._query_socket,
                                                     self._settings._loop)

    async def stop(self):
        """Close both sockets and stop the response listener.
        Returns the number of requests that were still pending.
        """
        self.log.debug("zmq Client.stop()")
        self._query_socket.close()
        self._block_socket.close()
        return await self._request_collection.stop()

    def _create_block_socket(self):
        """Create a SUB socket on the "block" endpoint, subscribed to all"""
        socket = self._settings.context.socket(
            zmq.SUB,  # pylint: disable=E1101
            io_loop=self._settings._loop,  # pylint: disable=W0212
        )
        socket.connect(self._endpoints["block"])
        socket.setsockopt_string(zmq.SUBSCRIBE, "")  # pylint: disable=E1101
        return socket

    def _create_query_socket(self):
        """Create a DEALER socket connected to the "query" endpoint"""
        socket = self._settings.context.socket(
            zmq.DEALER,  # pylint: disable=E1101
            io_loop=self._settings._loop,  # pylint: disable=W0212
        )
        socket.connect(self._endpoints["query"])
        return socket

    async def _subscription_request(self, command, data):
        """Send a subscription request.
        Returns (error_code, queue); notifications are put on the queue.
        """
        # asyncio.Queue no longer accepts a `loop` argument (removed in
        # Python 3.10); the queue is used from the running loop.
        queue = asyncio.Queue()
        request = await self._request(command, data, queue=queue)
        error_code, _ = await self._wait_for_response(request)
        return error_code, request.queue

    async def _simple_request(self, command, data):
        """Send a request and return (error_code, data) from its response"""
        request = await self._request(command, data)
        return await self._wait_for_response(request)

    async def _request(self, command, data, queue=None):
        """Make a generic request. Both options are byte objects specified
        like b'blockchain.fetch_block_header' as an example.

        If `queue` is given, the request is treated as a subscription and
        notifications are put on that queue.
        """
        request = Request(self._query_socket, command, data)
        request.queue = queue
        # Register before sending: if the send suspends, the reply could
        # otherwise arrive while the request is still unknown and be
        # dropped as unhandled.
        self._request_collection.add_request(request)
        try:
            await request.send()
        except BaseException:
            # Do not leave a request registered that was never sent
            self._request_collection.delete_request(request)
            raise
        return request

    async def _wait_for_response(self, request):
        """Wait for the request's response, up to the configured timeout.
        Returns (error_code, data); on timeout the request is forgotten
        and (ErrorCode.channel_timeout, None) is returned.
        """
        try:
            response = await asyncio.wait_for(request.future,
                                              self._settings.timeout)
        except asyncio.TimeoutError:
            self._request_collection.delete_request(request)
            return ErrorCode.channel_timeout, None

        assert response.command == request.command
        assert response.request_id == request.id_
        return response.error_code, response.data

    async def fetch_last_height(self):
        """Fetch the blockchain tip and return integer height"""
        command = b"blockchain.fetch_last_height"
        error_code, data = await self._simple_request(command, b"")
        if error_code:
            return error_code, None
        return error_code, struct.unpack("<I", data)[0]

    async def fetch_block_header(self, index):
        """Fetch a block header by integer height or 32-byte hex index"""
        command = b"blockchain.fetch_block_header"
        data = pack_block_index(index)
        return await self._simple_request(command, data)

    async def fetch_block_transaction_hashes(self, index):
        """Fetch transaction hashes in a block at height index"""
        command = b"blockchain.fetch_block_transaction_hashes"
        data = pack_block_index(index)
        error_code, data = await self._simple_request(command, data)
        if error_code:
            return error_code, None
        return error_code, unpack_table("32s", data)

    async def fetch_blockchain_transaction(self, txid):
        """Fetch transaction by txid (not including mempool)"""
        command = b"blockchain.fetch_transaction2"
        # The hex txid is sent byte-reversed
        error_code, data = await self._simple_request(
            command,
            bytes.fromhex(txid)[::-1])
        if error_code:
            return error_code, None
        return error_code, data

    async def fetch_mempool_transaction(self, txid):
        """Fetch transaction by txid (including mempool)"""
        command = b"transaction_pool.fetch_transaction2"
        # The hex txid is sent byte-reversed
        error_code, data = await self._simple_request(
            command,
            bytes.fromhex(txid)[::-1])
        if error_code:
            return error_code, None
        return error_code, data

    async def subscribe_scripthash(self, scripthash):
        """Subscribe to scripthash"""
        command = b"subscribe.key"
        decoded_address = unhexlify(scripthash)
        return await self._subscription_request(command, decoded_address)

    async def unsubscribe_scripthash(self, scripthash):
        """Unsubscribe scripthash"""
        # TODO: This call should ideally also remove the subscription
        # request from the RequestCollection.
        # This call solicits a final call from the server with an
        # `error::service_stopped` error code.
        command = b"unsubscribe.key"
        decoded_address = unhexlify(scripthash)
        return await self._simple_request(command, decoded_address)

    async def fetch_history4(self, scripthash, height=0):
        """Fetch history for given scripthash

        Returns (error_code, transfers) where transfers are the correlated
        receive/spend records sorted by ascending height, or
        (error_code, None) on error.
        """
        command = b"blockchain.fetch_history4"
        decoded_address = unhexlify(scripthash)
        error_code, raw_points = await self._simple_request(
            command, decoded_address + struct.pack("<I", height))
        if error_code:
            return error_code, None

        def make_tuple(row):
            kind, height, tx_hash, index, value = row
            return (
                kind,
                {
                    "hash": tx_hash,
                    "index": index
                },
                height,
                value,
                checksum(tx_hash[::-1].hex(), index),
            )

        # Row layout: kind (u8), height (u32), hash (32 bytes),
        # index (u32), value (u64), little-endian without padding.
        rows = unpack_table("<BI32sIQ", raw_points)
        points = [make_tuple(row) for row in rows]
        correlated_points = Client.__correlate(points)
        # self.log.debug("history points: %s", points)
        # self.log.debug("history correlated: %s", correlated_points)
        return error_code, self._sort_correlated_points(correlated_points)

    @staticmethod
    def _sort_correlated_points(points):
        """Sort by ascending height"""
        if len(points) < 2:
            return points
        # The height is read from the first entry of each transfer dict
        return sorted(points, key=lambda x: list(x.values())[0]["height"])

    async def broadcast_transaction(self, rawtx):
        """Broadcast given raw transaction"""
        command = b"transaction_pool.broadcast"
        return await self._simple_request(command, rawtx)

    async def fetch_balance(self, scripthash):
        """Fetch balance for given scripthash

        The balance is the sum of "value" over all history entries that
        have no "spent" record.
        """
        error_code, history = await self.fetch_history4(scripthash)
        if error_code:
            return error_code, None

        utxo = Client.__receives_without_spends(history)
        return error_code, functools.reduce(
            lambda accumulator, point: accumulator + point["value"], utxo, 0)

    async def fetch_utxo(self, scripthash):
        """Find UTXO for given scripthash

        The second element of the result is a generator over the history
        entries that have no "spent" record.
        """
        error_code, history = await self.fetch_history4(scripthash)
        if error_code:
            return error_code, None
        return error_code, Client.__receives_without_spends(history)

    async def subscribe_to_blocks(self, queue):
        """Start the block listener feeding `queue` and return the queue"""
        asyncio.ensure_future(self._listen_for_blocks(queue))
        return queue

    async def _listen_for_blocks(self, queue):
        """Infinite loop for block subscription.
        Returns raw blocks as they're received.
        """
        while True:
            frame = await self._block_socket.recv_multipart()
            seq = struct.unpack("<H", frame[0])[0]
            height = struct.unpack("<I", frame[1])[0]
            block_data = frame[2]
            queue.put_nowait((seq, height, block_data))

    @staticmethod
    def __receives_without_spends(history):
        return (point for point in history if "spent" not in point)

    @staticmethod
    def __correlate(points):
        transfers, checksum_to_index = Client.__find_receives(points)
        transfers = Client.__correlate_spends_to_receives(
            points, transfers, checksum_to_index)
        return transfers

    @staticmethod
    def __correlate_spends_to_receives(points, transfers, checksum_to_index):
        # NOTE(review): the kind values (1 = receive, 0 = spend) follow the
        # original comments; confirm against the server's history4 format.
        for point in points:
            if point[0] == 1:  # receive
                continue

            spent = {
                "hash": point[1]["hash"],
                "height": point[2],
                "index": point[1]["index"],
            }
            # For non-receive rows the value field (point[3]) is used as
            # the lookup key into the receive checksums.
            if point[3] not in checksum_to_index:
                transfers.append({"spent": spent})
            else:
                transfers[checksum_to_index[point[3]]]["spent"] = spent

        return transfers

    @staticmethod
    def __find_receives(points):
        transfers = []
        checksum_to_index = {}

        for point in points:
            if point[0] == 0:  # spent
                continue

            transfers.append({
                "received": {
                    "hash": point[1]["hash"],
                    "height": point[2],
                    "index": point[1]["index"],
                },
                "value": point[3],
            })

            # Remember where this receive sits so a spend can find it
            checksum_to_index[point[4]] = len(transfers) - 1

        return transfers, checksum_to_index