obelisk

Electrum server using libbitcoin as its backend
git clone https://git.parazyd.org/obelisk
Log | Files | Refs | README | LICENSE

zeromq.py (16384B)


      1 #!/usr/bin/env python3
      2 # Copyright (C) 2020-2021 Ivan J. <parazyd@dyne.org>
      3 #
      4 # This file is part of obelisk
      5 #
      6 # This program is free software: you can redistribute it and/or modify
      7 # it under the terms of the GNU Affero General Public License version 3
      8 # as published by the Free Software Foundation.
      9 #
     10 # This program is distributed in the hope that it will be useful,
     11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 # GNU Affero General Public License for more details.
     14 #
     15 # You should have received a copy of the GNU Affero General Public License
     16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 """ZeroMQ implementation for libbitcoin"""
     18 import asyncio
     19 import functools
     20 import struct
     21 from binascii import unhexlify
     22 from random import randint
     23 
     24 import zmq
     25 import zmq.asyncio
     26 
     27 from electrumobelisk.libbitcoin_errors import make_error_code, ErrorCode
     28 from electrumobelisk.util import bh2u
     29 
     30 
     31 def create_random_id():
     32     """Generate a random request ID"""
     33     max_uint32 = 4294967295
     34     return randint(0, max_uint32)
     35 
     36 
     37 def pack_block_index(index):
     38     """struct.pack given index"""
     39     if isinstance(index, str):
     40         index = unhexlify(index)
     41         assert len(index) == 32
     42         return index
     43     if isinstance(index, int):
     44         return struct.pack("<I", index)
     45 
     46     raise ValueError(
     47         f"Unknown index type {type(index)} v:{index}, should be int or bytearray"
     48     )
     49 
     50 
     51 def to_int(xbytes):
     52     """Make little-endian integer from given bytes"""
     53     return int.from_bytes(xbytes, byteorder="little")
     54 
     55 
     56 def checksum(xhash, index):
     57     """
     58     This method takes a transaction hash and an index and returns a checksum.
     59 
     60     This checksum is based on 49 bits starting from the 12th byte of the
     61     reversed hash. Combined with the last 15 bits of the 4 byte index.
     62     """
     63     mask = 0xFFFFFFFFFFFF8000
     64     magic_start_position = 12
     65 
     66     hash_bytes = bytes.fromhex(xhash)[::-1]
     67     last_20_bytes = hash_bytes[magic_start_position:]
     68 
     69     assert len(hash_bytes) == 32
     70     assert index < 2**32
     71 
     72     hash_upper_49_bits = to_int(last_20_bytes) & mask
     73     index_lower_15_bits = index & ~mask
     74     return hash_upper_49_bits | index_lower_15_bits
     75 
     76 
     77 def unpack_table(row_fmt, data):
     78     """Function to unpack table received from libbitcoin"""
     79     # Get the number of rows
     80     row_size = struct.calcsize(row_fmt)
     81     nrows = len(data) // row_size
     82     # Unpack
     83     rows = []
     84     for idx in range(nrows):
     85         offset = idx * row_size
     86         row = struct.unpack_from(row_fmt, data, offset)
     87         rows.append(row)
     88     return rows
     89 
     90 
     91 class ClientSettings:
     92     """Class implementing ZMQ client settings"""
     93     def __init__(self, timeout=10, context=None, loop=None):
     94         self._timeout = timeout
     95         self._context = context
     96         self._loop = loop
     97 
     98     @property
     99     def context(self):
    100         """ZMQ context property"""
    101         if not self._context:
    102             ctx = zmq.asyncio.Context()
    103             ctx.linger = 500  # in milliseconds
    104             self._context = ctx
    105         return self._context
    106 
    107     @context.setter
    108     def context(self, context):
    109         self._context = context
    110 
    111     @property
    112     def timeout(self):
    113         """Set to None for no timeout"""
    114         return self._timeout
    115 
    116     @timeout.setter
    117     def timeout(self, timeout):
    118         self._timeout = timeout
    119 
    120 
    121 class Request:
    122     """Class implementing a _send_ request.
    123     This is either a simple request/response affair or a subscription.
    124     """
    125     def __init__(self, socket, command, data):
    126         self.id_ = create_random_id()
    127         self.socket = socket
    128         self.command = command
    129         self.data = data
    130         self.future = asyncio.Future()
    131         self.queue = None
    132 
    133     async def send(self):
    134         """Send the ZMQ request"""
    135         request = [self.command, struct.pack("<I", self.id_), self.data]
    136         await self.socket.send_multipart(request)
    137 
    138     def is_subscription(self):
    139         """If the request is a subscription, then the response to this
    140         request is a notification.
    141         """
    142         return self.queue is not None
    143 
    144     def __str__(self):
    145         return "Request(command, ID) {}, {:d}".format(self.command, self.id_)
    146 
    147 
    148 class InvalidServerResponseException(Exception):
    149     """Exception for invalid server responses"""
    150 
    151 
    152 class Response:
    153     """Class implementing a request response"""
    154     def __init__(self, frame):
    155         if len(frame) != 3:
    156             raise InvalidServerResponseException(
    157                 f"Length of the frame was not 3: {len(frame)}")
    158 
    159         self.command = frame[0]
    160         self.request_id = struct.unpack("<I", frame[1])[0]
    161         error_code = struct.unpack("<I", frame[2][:4])[0]
    162         self.error_code = make_error_code(error_code)
    163         self.data = frame[2][4:]
    164 
    165     def is_bound_for_queue(self):
    166         return len(self.data) > 0
    167 
    168     def __str__(self):
    169         return (
    170             "Response(command, request ID, error code, data):" +
    171             f" {self.command}, {self.request_id}, {self.error_code}, {self.data}"
    172         )
    173 
    174 
    175 class RequestCollection:
    176     """RequestCollection carries a list of Requests and matches incoming
    177     responses to them.
    178     """
    179     def __init__(self, socket, loop):
    180         self._socket = socket
    181         self._requests = {}
    182         self._task = asyncio.ensure_future(self._run(), loop=loop)
    183 
    184     async def _run(self):
    185         while True:
    186             await self._receive()
    187 
    188     async def stop(self):
    189         """Stops listening for incoming responses (or subscription) messages.
    190         Returns the number of _responses_ expected but which are now dropped
    191         on the floor.
    192         """
    193         self._task.cancel()
    194         try:
    195             await self._task
    196         except asyncio.CancelledError:
    197             return len(self._requests)
    198 
    199     async def _receive(self):
    200         frame = await self._socket.recv_multipart()
    201         response = Response(frame)
    202 
    203         if response.request_id in self._requests:
    204             self._handle_response(response)
    205         else:
    206             print(
    207                 f"DEBUG: RequestCollection unhandled response {response.command}:{response.request_id}"  # pylint: disable=C0301
    208             )
    209 
    210     def _handle_response(self, response):
    211         request = self._requests[response.request_id]
    212 
    213         if request.is_subscription():
    214             if response.is_bound_for_queue():
    215                 # TODO: decode the data into something usable
    216                 request.queue.put_nowait(response.data)
    217             else:
    218                 request.future.set_result(response)
    219         else:
    220             self.delete_request(request)
    221             request.future.set_result(response)
    222 
    223     def add_request(self, request):
    224         # TODO: we should maybe check if the request.id_ is unique
    225         self._requests[request.id_] = request
    226 
    227     def delete_request(self, request):
    228         del self._requests[request.id_]
    229 
    230 
    231 class Client:
    232     """This class represents a connection to a libbitcoin server."""
    233     def __init__(self, log, endpoints, loop):
    234         self.log = log
    235         self._endpoints = endpoints
    236         self._settings = ClientSettings(loop=loop)
    237         self._query_socket = self._create_query_socket()
    238         self._block_socket = self._create_block_socket()
    239         self._request_collection = RequestCollection(self._query_socket,
    240                                                      self._settings._loop)
    241 
    242     async def stop(self):
    243         self.log.debug("zmq Client.stop()")
    244         self._query_socket.close()
    245         self._block_socket.close()
    246         return await self._request_collection.stop()
    247 
    248     def _create_block_socket(self):
    249         socket = self._settings.context.socket(
    250             zmq.SUB,  # pylint: disable=E1101
    251             io_loop=self._settings._loop,  # pylint: disable=W0212
    252         )
    253         socket.connect(self._endpoints["block"])
    254         socket.setsockopt_string(zmq.SUBSCRIBE, "")  # pylint: disable=E1101
    255         return socket
    256 
    257     def _create_query_socket(self):
    258         socket = self._settings.context.socket(
    259             zmq.DEALER,  # pylint: disable=E1101
    260             io_loop=self._settings._loop,  # pylint: disable=W0212
    261         )
    262         socket.connect(self._endpoints["query"])
    263         return socket
    264 
    265     async def _subscription_request(self, command, data):
    266         request = await self._request(command, data)
    267         request.queue = asyncio.Queue(loop=self._settings._loop)  # pylint: disable=W0212
    268         error_code, _ = await self._wait_for_response(request)
    269         return error_code, request.queue
    270 
    271     async def _simple_request(self, command, data):
    272         return await self._wait_for_response(await
    273                                              self._request(command, data))
    274 
    275     async def _request(self, command, data):
    276         """Make a generic request. Both options are byte objects specified
    277         like b'blockchain.fetch_block_header' as an example.
    278         """
    279         request = Request(self._query_socket, command, data)
    280         await request.send()
    281         self._request_collection.add_request(request)
    282         return request
    283 
    284     async def _wait_for_response(self, request):
    285         try:
    286             response = await asyncio.wait_for(request.future,
    287                                               self._settings.timeout)
    288         except asyncio.TimeoutError:
    289             self._request_collection.delete_request(request)
    290             return ErrorCode.channel_timeout, None
    291 
    292         assert response.command == request.command
    293         assert response.request_id == request.id_
    294         return response.error_code, response.data
    295 
    296     async def fetch_last_height(self):
    297         """Fetch the blockchain tip and return integer height"""
    298         command = b"blockchain.fetch_last_height"
    299         error_code, data = await self._simple_request(command, b"")
    300         if error_code:
    301             return error_code, None
    302         return error_code, struct.unpack("<I", data)[0]
    303 
    304     async def fetch_block_header(self, index):
    305         """Fetch a block header by its height or integer index"""
    306         command = b"blockchain.fetch_block_header"
    307         data = pack_block_index(index)
    308         return await self._simple_request(command, data)
    309 
    310     async def fetch_block_transaction_hashes(self, index):
    311         """Fetch transaction hashes in a block at height index"""
    312         command = b"blockchain.fetch_block_transaction_hashes"
    313         data = pack_block_index(index)
    314         error_code, data = await self._simple_request(command, data)
    315         if error_code:
    316             return error_code, None
    317         return error_code, unpack_table("32s", data)
    318 
    319     async def fetch_blockchain_transaction(self, txid):
    320         """Fetch transaction by txid (not including mempool)"""
    321         command = b"blockchain.fetch_transaction2"
    322         error_code, data = await self._simple_request(
    323             command,
    324             bytes.fromhex(txid)[::-1])
    325         if error_code:
    326             return error_code, None
    327         return error_code, data
    328 
    329     async def fetch_mempool_transaction(self, txid):
    330         """Fetch transaction by txid (including mempool)"""
    331         command = b"transaction_pool.fetch_transaction2"
    332         error_code, data = await self._simple_request(
    333             command,
    334             bytes.fromhex(txid)[::-1])
    335         if error_code:
    336             return error_code, None
    337         return error_code, data
    338 
    339     async def subscribe_scripthash(self, scripthash):
    340         """Subscribe to scripthash"""
    341         command = b"subscribe.key"
    342         decoded_address = unhexlify(scripthash)
    343         return await self._subscription_request(command, decoded_address)
    344 
    345     async def unsubscribe_scripthash(self, scripthash):
    346         """Unsubscribe scripthash"""
    347         # TODO: This call should ideally also remove the subscription
    348         # request from the RequestCollection.
    349         # This call solicits a final call from the server with an
    350         # `error::service_stopped` error code.
    351         command = b"unsubscribe.key"
    352         decoded_address = unhexlify(scripthash)
    353         return await self._simple_request(command, decoded_address)
    354 
    355     async def fetch_history4(self, scripthash, height=0):
    356         """Fetch history for given scripthash"""
    357         command = b"blockchain.fetch_history4"
    358         decoded_address = unhexlify(scripthash)
    359         error_code, raw_points = await self._simple_request(
    360             command, decoded_address + struct.pack("<I", height))
    361         if error_code:
    362             return error_code, None
    363 
    364         def make_tuple(row):
    365             kind, height, tx_hash, index, value = row
    366             return (
    367                 kind,
    368                 {
    369                     "hash": tx_hash,
    370                     "index": index
    371                 },
    372                 height,
    373                 value,
    374                 checksum(tx_hash[::-1].hex(), index),
    375             )
    376 
    377         rows = unpack_table("<BI32sIQ", raw_points)
    378         points = [make_tuple(row) for row in rows]
    379         correlated_points = Client.__correlate(points)
    380         # self.log.debug("history points: %s", points)
    381         # self.log.debug("history correlated: %s", correlated_points)
    382         return error_code, self._sort_correlated_points(correlated_points)
    383 
    384     @staticmethod
    385     def _sort_correlated_points(points):
    386         """Sort by ascending height"""
    387         if len(points) < 2:
    388             return points
    389         return sorted(points, key=lambda x: list(x.values())[0]["height"])
    390 
    391     async def broadcast_transaction(self, rawtx):
    392         """Broadcast given raw transaction"""
    393         command = b"transaction_pool.broadcast"
    394         return await self._simple_request(command, rawtx)
    395 
    396     async def fetch_balance(self, scripthash):
    397         """Fetch balance for given scripthash"""
    398         error_code, history = await self.fetch_history4(scripthash)
    399         if error_code:
    400             return error_code, None
    401 
    402         utxo = Client.__receives_without_spends(history)
    403         return error_code, functools.reduce(
    404             lambda accumulator, point: accumulator + point["value"], utxo, 0)
    405 
    406     async def fetch_utxo(self, scripthash):
    407         """Find UTXO for given scripthash"""
    408         error_code, history = await self.fetch_history4(scripthash)
    409         if error_code:
    410             return error_code, None
    411         return error_code, Client.__receives_without_spends(history)
    412 
    413     async def subscribe_to_blocks(self, queue):
    414         asyncio.ensure_future(self._listen_for_blocks(queue))
    415         return queue
    416 
    417     async def _listen_for_blocks(self, queue):
    418         """Infinite loop for block subscription.
    419         Returns raw blocks as they're received.
    420         """
    421         while True:
    422             frame = await self._block_socket.recv_multipart()
    423             seq = struct.unpack("<H", frame[0])[0]
    424             height = struct.unpack("<I", frame[1])[0]
    425             block_data = frame[2]
    426             queue.put_nowait((seq, height, block_data))
    427 
    428     @staticmethod
    429     def __receives_without_spends(history):
    430         return (point for point in history if "spent" not in point)
    431 
    432     @staticmethod
    433     def __correlate(points):
    434         transfers, checksum_to_index = Client.__find_receives(points)
    435         transfers = Client.__correlate_spends_to_receives(
    436             points, transfers, checksum_to_index)
    437         return transfers
    438 
    439     @staticmethod
    440     def __correlate_spends_to_receives(points, transfers, checksum_to_index):
    441         for point in points:
    442             if point[0] == 1:  # receive
    443                 continue
    444 
    445             spent = {
    446                 "hash": point[1]["hash"],
    447                 "height": point[2],
    448                 "index": point[1]["index"],
    449             }
    450             if point[3] not in checksum_to_index:
    451                 transfers.append({"spent": spent})
    452             else:
    453                 transfers[checksum_to_index[point[3]]]["spent"] = spent
    454 
    455         return transfers
    456 
    457     @staticmethod
    458     def __find_receives(points):
    459         transfers = []
    460         checksum_to_index = {}
    461 
    462         for point in points:
    463             if point[0] == 0:  # spent
    464                 continue
    465 
    466             transfers.append({
    467                 "received": {
    468                     "hash": point[1]["hash"],
    469                     "height": point[2],
    470                     "index": point[1]["index"],
    471                 },
    472                 "value": point[3],
    473             })
    474 
    475             checksum_to_index[point[4]] = len(transfers) - 1
    476 
    477         return transfers, checksum_to_index