obelisk

Electrum server using libbitcoin as its backend
git clone https://git.parazyd.org/obelisk
Log | Files | Refs | README | LICENSE

zeromq.py (17187B)


      1 #!/usr/bin/env python3
      2 # Copyright (C) 2020-2021 Ivan J. <parazyd@dyne.org>
      3 #
      4 # This file is part of obelisk
      5 #
      6 # This program is free software: you can redistribute it and/or modify
      7 # it under the terms of the GNU Affero General Public License version 3
      8 # as published by the Free Software Foundation.
      9 #
     10 # This program is distributed in the hope that it will be useful,
     11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 # GNU Affero General Public License for more details.
     14 #
     15 # You should have received a copy of the GNU Affero General Public License
     16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 """ZeroMQ implementation for libbitcoin"""
     18 import asyncio
     19 import functools
     20 import struct
     21 from binascii import unhexlify
     22 from random import randint
     23 
     24 import zmq
     25 import zmq.asyncio
     26 
     27 from obelisk.errors_libbitcoin import make_error_code, ZMQError
     28 from obelisk.util import hash_to_hex_str
     29 
     30 
     31 def create_random_id():
     32     """Generate a random request ID"""
     33     max_uint32 = 4294967295
     34     return randint(0, max_uint32)
     35 
     36 
     37 def pack_block_index(index):
     38     """struct.pack given index"""
     39     if isinstance(index, str):
     40         index = unhexlify(index)
     41         assert len(index) == 32
     42         return index
     43     if isinstance(index, int):
     44         return struct.pack("<I", index)
     45 
     46     raise ValueError(
     47         f"Unknown index type {type(index)} v:{index}, should be int or bytearray"
     48     )
     49 
     50 
     51 def to_int(xbytes):
     52     """Make little-endian integer from given bytes"""
     53     return int.from_bytes(xbytes, byteorder="little")
     54 
     55 
     56 def checksum(xhash, index):
     57     """
     58     This method takes a transaction hash and an index and returns a checksum.
     59 
     60     This checksum is based on 49 bits starting from the 12th byte of the
     61     reversed hash. Combined with the last 15 bits of the 4 byte index.
     62     """
     63     mask = 0xFFFFFFFFFFFF8000
     64     magic_start_position = 12
     65 
     66     hash_bytes = bytes.fromhex(xhash)[::-1]
     67     last_20_bytes = hash_bytes[magic_start_position:]
     68 
     69     assert len(hash_bytes) == 32
     70     assert index < 2**32
     71 
     72     hash_upper_49_bits = to_int(last_20_bytes) & mask
     73     index_lower_15_bits = index & ~mask
     74     return hash_upper_49_bits | index_lower_15_bits
     75 
     76 
     77 def make_tuple(row):
     78     kind, height, tx_hash, index, value = row
     79     return (
     80         kind,
     81         {
     82             "hash": tx_hash,
     83             "index": index
     84         },
     85         height,
     86         value,
     87         checksum(hash_to_hex_str(tx_hash), index),
     88     )
     89 
     90 
     91 def unpack_table(row_fmt, data):
     92     """Function to unpack table received from libbitcoin"""
     93     # Get the number of rows
     94     row_size = struct.calcsize(row_fmt)
     95     nrows = len(data) // row_size
     96     # Unpack
     97     rows = []
     98     for idx in range(nrows):
     99         offset = idx * row_size
    100         row = struct.unpack_from(row_fmt, data, offset)
    101         rows.append(row)
    102     return rows
    103 
    104 
    105 class ClientSettings:
    106     """Class implementing ZMQ client settings"""
    107 
    108     def __init__(self, timeout=10, context=None, loop=None):
    109         self._timeout = timeout
    110         self._context = context
    111         self._loop = loop
    112 
    113     @property
    114     def context(self):
    115         """ZMQ context property"""
    116         if not self._context:
    117             ctx = zmq.asyncio.Context()
    118             ctx.linger = 500  # in milliseconds
    119             self._context = ctx
    120         return self._context
    121 
    122     @context.setter
    123     def context(self, context):
    124         self._context = context  # pragma: no cover
    125 
    126     @property
    127     def timeout(self):
    128         """Set to None for no timeout"""
    129         return self._timeout
    130 
    131     @timeout.setter
    132     def timeout(self, timeout):
    133         self._timeout = timeout  # pragma: no cover
    134 
    135 
    136 class Request:
    137     """Class implementing a _send_ request.
    138     This is either a simple request/response affair or a subscription.
    139     """
    140 
    141     def __init__(self, socket, command, data):
    142         self.id_ = create_random_id()
    143         self.socket = socket
    144         self.command = command
    145         self.data = data
    146         self.future = asyncio.Future()
    147         self.queue = None
    148 
    149     async def send(self):
    150         """Send the ZMQ request"""
    151         request = [self.command, struct.pack("<I", self.id_), self.data]
    152         await self.socket.send_multipart(request)
    153 
    154     def is_subscription(self):
    155         """If the request is a subscription, then the response to this
    156         request is a notification.
    157         """
    158         return self.queue is not None
    159 
    160     def __str__(self):
    161         return "Request(command, ID) {}, {:d}".format(self.command, self.id_)
    162 
    163 
    164 class InvalidServerResponseException(Exception):
    165     """Exception for invalid server responses"""
    166 
    167 
    168 class Response:
    169     """Class implementing a request response"""
    170 
    171     def __init__(self, frame):
    172         if len(frame) != 3:
    173             raise InvalidServerResponseException(
    174                 f"Length of the frame was not 3: {len(frame)}")
    175 
    176         self.command = frame[0]
    177         self.request_id = struct.unpack("<I", frame[1])[0]
    178         error_code = struct.unpack("<I", frame[2][:4])[0]
    179         self.error_code = make_error_code(error_code)
    180         self.data = frame[2][4:]
    181 
    182     def is_bound_for_queue(self):
    183         return len(self.data) > 0
    184 
    185     def __str__(self):
    186         return (
    187             "Response(command, request ID, error code, data):" +
    188             f" {self.command}, {self.request_id}, {self.error_code}, {self.data}"
    189         )
    190 
    191 
    192 class RequestCollection:
    193     """RequestCollection carries a list of Requests and matches incoming
    194     responses to them.
    195     """
    196 
    197     def __init__(self, socket, loop):
    198         self._socket = socket
    199         self._requests = {}
    200         self._task = asyncio.ensure_future(self._run(), loop=loop)
    201 
    202     async def _run(self):
    203         while True:
    204             await self._receive()
    205 
    206     async def stop(self):
    207         """Stops listening for incoming responses (or subscription) messages.
    208         Returns the number of _responses_ expected but which are now dropped
    209         on the floor.
    210         """
    211         self._task.cancel()
    212         try:
    213             await self._task
    214         except asyncio.CancelledError:
    215             return len(self._requests)
    216 
    217     async def _receive(self):
    218         frame = await self._socket.recv_multipart()
    219         response = Response(frame)
    220 
    221         if response.request_id in self._requests:
    222             self._handle_response(response)
    223         else:
    224             print("DEBUG; RequestCollection unhandled response %s:%s" %
    225                   (response.command, response.request_id))
    226 
    227     def _handle_response(self, response):
    228         request = self._requests[response.request_id]
    229 
    230         if request.is_subscription():
    231             if response.is_bound_for_queue():
    232                 # TODO: decode the data into something usable
    233                 request.queue.put_nowait(response.data)
    234             else:
    235                 request.future.set_result(response)
    236         else:
    237             self.delete_request(request)
    238             request.future.set_result(response)
    239 
    240     def add_request(self, request):
    241         # TODO: we should maybe check if the request.id_ is unique
    242         self._requests[request.id_] = request
    243 
    244     def delete_request(self, request):
    245         del self._requests[request.id_]
    246 
    247 
    248 class Client:
    249     """This class represents a connection to a libbitcoin server."""
    250 
    251     def __init__(self, log, endpoints, loop):
    252         self.log = log
    253         self._endpoints = endpoints
    254         self._settings = ClientSettings(loop=loop)
    255         self._query_socket = self._create_query_socket()
    256         self._block_socket = self._create_block_socket()
    257         self._request_collection = RequestCollection(self._query_socket,
    258                                                      self._settings._loop)
    259 
    260     async def stop(self):
    261         self.log.debug("zmq Client.stop()")
    262         self._query_socket.close()
    263         self._block_socket.close()
    264         return await self._request_collection.stop()
    265 
    266     def _create_block_socket(self):
    267         socket = self._settings.context.socket(
    268             zmq.SUB,  # pylint: disable=E1101
    269             io_loop=self._settings._loop,  # pylint: disable=W0212
    270         )
    271         socket.connect(self._endpoints["block"])
    272         socket.setsockopt_string(zmq.SUBSCRIBE, "")  # pylint: disable=E1101
    273         return socket
    274 
    275     def _create_query_socket(self):
    276         socket = self._settings.context.socket(
    277             zmq.DEALER,  # pylint: disable=E1101
    278             io_loop=self._settings._loop,  # pylint: disable=W0212
    279         )
    280         socket.connect(self._endpoints["query"])
    281         return socket
    282 
    283     async def _subscription_request(self, command, data, queue):
    284         request = await self._request(command, data)
    285         request.queue = queue
    286         error_code, _ = await self._wait_for_response(request)
    287         return error_code
    288 
    289     async def _simple_request(self, command, data):
    290         return await self._wait_for_response(await self._request(command, data))
    291 
    292     async def _request(self, command, data):
    293         """Make a generic request. Both options are byte objects specified
    294         like b'blockchain.fetch_block_header' as an example.
    295         """
    296         request = Request(self._query_socket, command, data)
    297         await request.send()
    298         self._request_collection.add_request(request)
    299         return request
    300 
    301     async def _wait_for_response(self, request):
    302         try:
    303             response = await asyncio.wait_for(request.future,
    304                                               self._settings.timeout)
    305         except asyncio.TimeoutError:
    306             self._request_collection.delete_request(request)
    307             return ZMQError.channel_timeout, None
    308 
    309         assert response.command == request.command
    310         assert response.request_id == request.id_
    311         return response.error_code, response.data
    312 
    313     async def server_version(self):
    314         """Get the libbitcoin-server version"""
    315         command = b"server.version"
    316         error_code, data = await self._simple_request(command, b"")
    317         if error_code:
    318             return error_code, None
    319         return error_code, data
    320 
    321     async def fetch_last_height(self):
    322         """Fetch the blockchain tip and return integer height"""
    323         command = b"blockchain.fetch_last_height"
    324         error_code, data = await self._simple_request(command, b"")
    325         if error_code:
    326             return error_code, None
    327         return error_code, struct.unpack("<I", data)[0]
    328 
    329     async def fetch_block_header(self, index):
    330         """Fetch a block header by its height or integer index"""
    331         command = b"blockchain.fetch_block_header"
    332         data = pack_block_index(index)
    333         return await self._simple_request(command, data)
    334 
    335     async def fetch_block_transaction_hashes(self, index):
    336         """Fetch transaction hashes in a block at height index"""
    337         command = b"blockchain.fetch_block_transaction_hashes"
    338         data = pack_block_index(index)
    339         error_code, data = await self._simple_request(command, data)
    340         if error_code:
    341             return error_code, None
    342         return error_code, unpack_table("32s", data)
    343 
    344     async def fetch_blockchain_transaction(self, txid):
    345         """Fetch transaction by txid (not including mempool)"""
    346         command = b"blockchain.fetch_transaction2"
    347         error_code, data = await self._simple_request(command,
    348                                                       bytes.fromhex(txid)[::-1])
    349         if error_code:
    350             return error_code, None
    351         return error_code, data
    352 
    353     async def fetch_mempool_transaction(self, txid):
    354         """Fetch transaction by txid (including mempool)"""
    355         command = b"transaction_pool.fetch_transaction2"
    356         error_code, data = await self._simple_request(command,
    357                                                       bytes.fromhex(txid)[::-1])
    358         if error_code:
    359             return error_code, None
    360         return error_code, data
    361 
    362     async def subscribe_scripthash(self, scripthash, queue):
    363         """Subscribe to scripthash"""
    364         command = b"subscribe.key"
    365         decoded_address = unhexlify(scripthash)
    366         return await self._subscription_request(command, decoded_address, queue)
    367 
    368     async def unsubscribe_scripthash(self, scripthash):
    369         """Unsubscribe scripthash"""
    370         # This call solicits a final call from the server with an
    371         # `error::service_stopped` error code.
    372         command = b"unsubscribe.key"
    373         decoded_address = unhexlify(scripthash)
    374         return await self._simple_request(command, decoded_address)
    375 
    376     async def fetch_history4(self, scripthash, height=0):
    377         """Fetch history for given scripthash"""
    378         command = b"blockchain.fetch_history4"
    379         decoded_address = unhexlify(scripthash)
    380         error_code, raw_points = await self._simple_request(
    381             command, decoded_address + struct.pack("<I", height))
    382         if error_code:
    383             return error_code, None
    384 
    385         rows = unpack_table("<BI32sIQ", raw_points)
    386         points = [make_tuple(row) for row in rows]
    387         correlated_points = Client.__correlate(points)
    388         # self.log.debug("history points: %s", points)
    389         # self.log.debug("history correlated: %s", correlated_points)
    390 
    391         # BUG: In libbitcoin v4 sometimes transactions mess up and double
    392         # https://github.com/libbitcoin/libbitcoin-server/issues/545
    393         #
    394         # The following is not a very efficient solution
    395         correlated = [
    396             i for n, i in enumerate(correlated_points)
    397             if i not in correlated_points[n + 1:]
    398         ]
    399         return error_code, self._sort_correlated_points(correlated)
    400 
    401     @staticmethod
    402     def _sort_correlated_points(points):
    403         """Sort by ascending height"""
    404         if len(points) < 2:
    405             return points
    406         return sorted(points, key=lambda x: list(x.values())[0]["height"])
    407 
    408     async def broadcast_transaction(self, rawtx):
    409         """Broadcast given raw transaction"""
    410         command = b"transaction_pool.broadcast"
    411         return await self._simple_request(command, rawtx)
    412 
    413     async def fetch_balance(self, scripthash):
    414         """Fetch balance for given scripthash"""
    415         error_code, history = await self.fetch_history4(scripthash)
    416         if error_code:
    417             return error_code, None
    418 
    419         utxo = Client.__receives_without_spends(history)
    420 
    421         return error_code, (
    422             # confirmed
    423             functools.reduce(
    424                 lambda accumulator, point: accumulator + point["value"]
    425                 if point["received"]["height"] != 4294967295 else 0,
    426                 utxo,
    427                 0,
    428             ),
    429             # unconfirmed
    430             functools.reduce(
    431                 lambda accumulator, point: accumulator + point["value"]
    432                 if point["received"]["height"] == 4294967295 else 0,
    433                 utxo,
    434                 0,
    435             ),
    436         )
    437 
    438     async def fetch_utxo(self, scripthash):
    439         """Find UTXO for given scripthash"""
    440         error_code, history = await self.fetch_history4(scripthash)
    441         if error_code:
    442             return error_code, None
    443         return error_code, Client.__receives_without_spends(history)
    444 
    445     async def subscribe_to_blocks(self, queue):
    446         asyncio.ensure_future(self._listen_for_blocks(queue))
    447         return queue
    448 
    449     async def _listen_for_blocks(self, queue):
    450         """Infinite loop for block subscription.
    451         Returns raw blocks as they're received.
    452         """
    453         while True:
    454             frame = await self._block_socket.recv_multipart()
    455             seq = struct.unpack("<H", frame[0])[0]
    456             height = struct.unpack("<I", frame[1])[0]
    457             block_data = frame[2]
    458             queue.put_nowait((seq, height, block_data))
    459 
    460     @staticmethod
    461     def __receives_without_spends(history):
    462         return (point for point in history if "spent" not in point)
    463 
    464     @staticmethod
    465     def __correlate(points):
    466         transfers, checksum_to_index = Client.__find_receives(points)
    467         transfers = Client.__correlate_spends_to_receives(
    468             points, transfers, checksum_to_index)
    469         return transfers
    470 
    471     @staticmethod
    472     def __correlate_spends_to_receives(points, transfers, checksum_to_index):
    473         for point in points:
    474             if point[0] == 1:  # receive
    475                 continue
    476 
    477             spent = {
    478                 "hash": point[1]["hash"],
    479                 "height": point[2],
    480                 "index": point[1]["index"],
    481             }
    482             if point[3] not in checksum_to_index:
    483                 transfers.append({"spent": spent})
    484             else:
    485                 transfers[checksum_to_index[point[3]]]["spent"] = spent
    486 
    487         return transfers
    488 
    489     @staticmethod
    490     def __find_receives(points):
    491         transfers = []
    492         checksum_to_index = {}
    493 
    494         for point in points:
    495             if point[0] == 0:  # spent
    496                 continue
    497 
    498             transfers.append({
    499                 "received": {
    500                     "hash": point[1]["hash"],
    501                     "height": point[2],
    502                     "index": point[1]["index"],
    503                 },
    504                 "value": point[3],
    505             })
    506 
    507             checksum_to_index[point[4]] = len(transfers) - 1
    508 
    509         return transfers, checksum_to_index