electrum-personal-server

Maximally lightweight electrum server for a single user
git clone https://git.parazyd.org/electrum-personal-server
Log | Files | Refs | README

commit 76cbd8a2c618bc1b2fe72770fad76a2f324b5a8d
parent b38006736d604f06e3feb9e988bf20362e382c69
Author: chris-belcher <chris-belcher@users.noreply.github.com>
Date:   Sat, 28 Dec 2019 14:51:16 +0000

Add address and header sync to protocol tests

Diffstat:
Melectrumpersonalserver/server/common.py | 18+++++++++++++++---
Melectrumpersonalserver/server/electrumprotocol.py | 20++++++--------------
Mtest/test_electrum_protocol.py | 167++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------
3 files changed, 152 insertions(+), 53 deletions(-)

diff --git a/electrumpersonalserver/server/common.py b/electrumpersonalserver/server/common.py @@ -7,6 +7,7 @@ import os.path import logging import tempfile import platform +import json from configparser import RawConfigParser, NoSectionError, NoOptionError from ipaddress import ip_network, ip_address @@ -126,9 +127,14 @@ def run_electrum_server(rpc, txmonitor, config): except (ConnectionRefusedError, ssl.SSLError): sock.close() sock = None - logger.info('Electrum connected from ' + str(addr[0])) - protocol.set_send_line_fun(lambda l: sock.sendall(l + b'\n')) + + def send_reply_fun(reply): + line = json.dumps(reply) + sock.sendall(line.encode('utf-8') + b'\n') + logger.debug('<= ' + line) + protocol.set_send_reply_fun(send_reply_fun) + sock.settimeout(poll_interval_connected) recv_buffer = bytearray() while True: @@ -144,7 +150,13 @@ def run_electrum_server(rpc, txmonitor, config): line = recv_buffer[:lb].rstrip() recv_buffer = recv_buffer[lb + 1:] lb = recv_buffer.find(b'\n') - protocol.handle_query(line.decode("utf-8")) + line = line.decode("utf-8") + logger.debug("=> " + line) + try: + query = json.loads(line) + except json.decoder.JSONDecodeError as e: + raise IOError(repr(e)) + protocol.handle_query(query) except socket.timeout: on_heartbeat_connected(rpc, txmonitor, protocol) except (IOError, EOFError) as e: diff --git a/electrumpersonalserver/server/electrumprotocol.py b/electrumpersonalserver/server/electrumprotocol.py @@ -123,8 +123,8 @@ class ElectrumProtocol(object): self.are_headers_raw = False self.txid_blockhash_map = {} - def set_send_line_fun(self, send_line_fun): - self.send_line_fun = send_line_fun + def set_send_reply_fun(self, send_reply_fun): + self.send_reply_fun = send_reply_fun def on_blockchain_tip_updated(self, header): if self.subscribed_to_headers: @@ -145,25 +145,17 @@ class ElectrumProtocol(object): def _send_response(self, query, result): response = {"jsonrpc": "2.0", "result": result, "id": query["id"]} - self.send_line_fun(json.dumps(response).encode('utf-8')) - self.logger.debug('<= ' + json.dumps(response)) + self.send_reply_fun(response) def _send_update(self, update): update["jsonrpc"] = "2.0" - self.send_line_fun(json.dumps(update).encode('utf-8')) - self.logger.debug('<= ' + json.dumps(update)) + self.send_reply_fun(update) def _send_error(self, nid, error): payload = {"error": error, "jsonrpc": "2.0", "id": nid} - self.send_line_fun(json.dumps(payload).encode('utf-8')) - self.logger.debug('<= ' + json.dumps(payload)) + self.send_reply_fun(payload) - def handle_query(self, line): - self.logger.debug("=> " + line) - try: - query = json.loads(line) - except json.decoder.JSONDecodeError as e: - raise IOError(e) + def handle_query(self, query): if "method" not in query: raise IOError("Bad client query, no \"method\"") method = query["method"] diff --git a/test/test_electrum_protocol.py b/test/test_electrum_protocol.py @@ -10,22 +10,29 @@ from electrumpersonalserver.server import ( get_block_header, get_current_header, get_block_headers_hex, - JsonRpcError + JsonRpcError, + get_status_electrum ) logger = logging.getLogger('ELECTRUMPERSONALSERVER-TEST') logger.setLevel(logging.DEBUG) +DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT = 100000 + def get_dummy_hash_from_height(height): + if height == 0: + return "00"*32 return str(height) + "a"*(64 - len(str(height))) def get_height_from_dummy_hash(hhash): + if hhash == "00"*32: + return 0 return int(hhash[:hhash.index("a")]) class DummyJsonRpc(object): def __init__(self): self.calls = {} - self.blockchain_height = 100000 + self.blockchain_height = DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT def call(self, method, params): if method not in self.calls: @@ -58,9 +65,15 @@ class DummyJsonRpc(object): + "00000000000000000da", "nTx": 1, } - if height > 0: + if height > 1: header["previousblockhash"] = get_dummy_hash_from_height( height - 1) + elif height == 1: + header["previousblockhash"] = "00"*32 #genesis block + elif height == 0: + pass #no prevblock for genesis + else: + assert 0 if height < self.blockchain_height: header["nextblockhash"] = get_dummy_hash_from_height(height + 1) return header @@ -102,26 +115,29 @@ def test_get_current_header(): assert type(ret[1]) == dict assert len(ret[1]) == 7 -def test_get_block_headers_hex_out_of_bounds(): - rpc = DummyJsonRpc() - ret = get_block_headers_hex(rpc, rpc.blockchain_height + 10, 5) - assert len(ret) == 2 - assert ret[0] == "" - assert ret[1] == 0 - -def test_get_block_headers_hex(): +@pytest.mark.parametrize( + "start_height, count", + [(100, 200), + (DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT + 10, 5), + (DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT - 10, 15), + (0, 250) + ] +) +def test_get_block_headers_hex(start_height, count): rpc = DummyJsonRpc() - count = 200 - ret = get_block_headers_hex(rpc, 100, count) + ret = get_block_headers_hex(rpc, start_height, count) + print("start_height=" + str(start_height) + " count=" + str(count)) assert len(ret) == 2 - assert ret[1] == count - assert len(ret[0]) == count*80*2 #80 bytes per header, 2 chars per byte + available_blocks = -min(0, start_height - DUMMY_JSONRPC_BLOCKCHAIN_HEIGHT + - 1) + expected_count = min(available_blocks, count) + assert len(ret[0]) == expected_count*80*2 #80 bytes/header, 2 chars/byte + assert ret[1] == expected_count @pytest.mark.parametrize( "invalid_json_query", [ - "{\"invalid-json\":}", - "{\"valid-json-no-method\": 5}" + {"valid-json-no-method": 5} ] ) def test_invalid_json_query_line(invalid_json_query): @@ -129,43 +145,122 @@ def test_invalid_json_query_line(invalid_json_query): with pytest.raises(IOError) as e: protocol.handle_query(invalid_json_query) +def create_electrum_protocol_instance(broadcast_method="own-node", + tor_hostport=("127.0.0.1", 9050), + disable_mempool_fee_histogram=False): + protocol = ElectrumProtocol(DummyJsonRpc(), DummyTransactionMonitor(), + logger, broadcast_method, tor_hostport, disable_mempool_fee_histogram) + sent_replies = [] + protocol.set_send_reply_fun(lambda l: sent_replies.append(l)) + assert len(sent_replies) == 0 + return protocol, sent_replies + +def dummy_script_hash_to_history(scrhash): + index = int(scrhash[:scrhash.index("s")]) + tx_count = (index+2) % 5 + height = 500 + return [(index_to_dummy_txid(i), height) for i in range(tx_count)] + +def index_to_dummy_script_hash(index): + return str(index) + "s"*(64 - len(str(index))) + +def index_to_dummy_txid(index): + return str(index) + "t"*(64 - len(str(index))) + +def dummy_txid_to_dummy_tx(txid): + return txid[::-1] * 6 + class DummyTransactionMonitor(object): def __init__(self): self.deterministic_wallets = list(range(5)) self.address_history = list(range(5)) + self.subscribed_addresses = [] + self.history_hashes = {} def get_electrum_history_hash(self, scrhash): - pass + history = dummy_script_hash_to_history(scrhash) + hhash = get_status_electrum(history) + self.history_hashes[scrhash] = history + return hhash def get_electrum_history(self, scrhash): - pass + return self.history_hashes[scrhash] def unsubscribe_all_addresses(self): - pass + self.subscribed_addresses = [] def subscribe_address(self, scrhash): - pass + self.subscribed_addresses.append(scrhash) + return True def get_address_balance(self, scrhash): pass -def create_electrum_protocol_instance(broadcast_method="own-node", - tor_hostport=("127.0.0.01", 9050), - disable_mempool_fee_histogram=False): - protocol = ElectrumProtocol(DummyJsonRpc(), DummyTransactionMonitor(), - logger, broadcast_method, tor_hostport, disable_mempool_fee_histogram) - sent_lines = [] - protocol.set_send_line_fun(lambda l: sent_lines.append(json.loads( - l.decode()))) - return protocol, sent_lines +def test_script_hash_sync(): + protocol, sent_replies = create_electrum_protocol_instance() + scrhash_index = 0 + scrhash = index_to_dummy_script_hash(scrhash_index) + protocol.handle_query({"method": "blockchain.scripthash.subscribe", + "params": [scrhash], "id": 0}) + assert len(sent_replies) == 1 + assert len(protocol.txmonitor.subscribed_addresses) == 1 + assert protocol.txmonitor.subscribed_addresses[0] == scrhash + assert len(sent_replies) == 1 + assert len(sent_replies[0]["result"]) == 64 + history_hash = sent_replies[0]["result"] + + protocol.handle_query({"method": "blockchain.scripthash.get_history", + "params": [scrhash], "id": 0}) + assert len(sent_replies) == 2 + assert get_status_electrum(sent_replies[1]["result"]) == history_hash + + #updated scripthash but actually nothing changed, history_hash unchanged + protocol.on_updated_scripthashes([scrhash]) + assert len(sent_replies) == 3 + assert sent_replies[2]["method"] == "blockchain.scripthash.subscribe" + assert sent_replies[2]["params"][0] == scrhash + assert sent_replies[2]["params"][1] == history_hash + + protocol.on_disconnect() + assert len(protocol.txmonitor.subscribed_addresses) == 0 + +def test_headers_subscribe(): + protocol, sent_replies = create_electrum_protocol_instance() + + protocol.handle_query({"method": "server.version", "params": ["test-code", + 1.4], "id": 0}) #protocol version of 1.4 means only raw headers used + assert len(sent_replies) == 1 + + protocol.handle_query({"method": "blockchain.headers.subscribe", "params": + [], "id": 0}) + assert len(sent_replies) == 2 + assert "height" in sent_replies[1]["result"] + assert sent_replies[1]["result"]["height"] == protocol.rpc.blockchain_height + assert "hex" in sent_replies[1]["result"] + assert len(sent_replies[1]["result"]["hex"]) == 80*2 #80 b/header, 2 b/char + + protocol.rpc.blockchain_height += 1 + new_bestblockhash, header = get_current_header(protocol.rpc, + protocol.are_headers_raw) + protocol.on_blockchain_tip_updated(header) + assert len(sent_replies) == 3 + assert "method" in sent_replies[2] + assert sent_replies[2]["method"] == "blockchain.headers.subscribe" + assert "params" in sent_replies[2] + assert "height" in sent_replies[2]["params"][0] + assert sent_replies[2]["params"][0]["height"]\ + == protocol.rpc.blockchain_height + assert "hex" in sent_replies[2]["params"][0] + assert len(sent_replies[2]["params"][0]["hex"]) == 80*2 #80 b/header, 2 b/c def test_server_ping(): - protocol, sent_lines = create_electrum_protocol_instance() + protocol, sent_replies = create_electrum_protocol_instance() idd = 1 - protocol.handle_query(json.dumps({"method": "server.ping", "id": idd})) - assert len(sent_lines) == 1 - assert sent_lines[0]["result"] == None - assert sent_lines[0]["id"] == idd - + protocol.handle_query({"method": "server.ping", "id": idd}) + assert len(sent_replies) == 1 + assert sent_replies[0]["result"] == None + assert sent_replies[0]["id"] == idd +#test scripthash.subscribe, scripthash.get_history transaction.get +# transaction.get_merkle