electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 238f3c949ca2d23cbe4cc4e5a1ccd15371050663
parent 0eab1692d6a64eacdc13f2b3568a1453ce1c3761
Author: ThomasV <thomasv@electrum.org>
Date:   Thu, 27 Jun 2019 09:03:34 +0200

get rid of sql_alchemy

Diffstat:
Mcontrib/requirements/requirements.txt | 1-
Melectrum/channel_db.py | 174+++++++++++++++++++++++++++++++++++++++----------------------------------------
Melectrum/lnwatcher.py | 100+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
Melectrum/sql_db.py | 25+++++++------------------
4 files changed, 153 insertions(+), 147 deletions(-)

diff --git a/contrib/requirements/requirements.txt b/contrib/requirements/requirements.txt @@ -11,4 +11,3 @@ aiohttp_socks certifi bitstring pycryptodomex>=3.7 -sqlalchemy>=1.3.0b3 diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -36,10 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK import binascii import base64 -from sqlalchemy import Column, ForeignKey, Integer, String, Boolean -from sqlalchemy.orm.query import Query -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.sql import not_, or_ from .sql_db import SqlDB, sql from . import constants @@ -66,7 +62,6 @@ def validate_features(features : int): if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: raise UnknownEvenFeatureBits() -Base = declarative_base() FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 @@ -193,57 +188,45 @@ class Address(NamedTuple): port: int last_connected_date: int - -class ChannelInfoBase(Base): - __tablename__ = 'channel_info' - short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') - node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) - node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) - capacity_sat = Column(Integer) - def to_nametuple(self): - return ChannelInfo( - short_channel_id=self.short_channel_id, - node1_id=self.node1_id, - node2_id=self.node2_id, - capacity_sat=self.capacity_sat - ) - -class PolicyBase(Base): - __tablename__ = 'policy' - key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') - cltv_expiry_delta = Column(Integer, nullable=False) - htlc_minimum_msat = Column(Integer, nullable=False) - htlc_maximum_msat = Column(Integer) - fee_base_msat = Column(Integer, nullable=False) - fee_proportional_millionths = Column(Integer, nullable=False) - channel_flags = Column(Integer, nullable=False) - timestamp = Column(Integer, nullable=False) - - def to_nametuple(self): - return Policy( - key=self.key, - cltv_expiry_delta=self.cltv_expiry_delta, - htlc_minimum_msat=self.htlc_minimum_msat, - htlc_maximum_msat=self.htlc_maximum_msat, - fee_base_msat= self.fee_base_msat, - fee_proportional_millionths = self.fee_proportional_millionths, - channel_flags=self.channel_flags, - timestamp=self.timestamp - ) - -class NodeInfoBase(Base): - __tablename__ = 'node_info' - node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') - features = Column(Integer, nullable=False) - timestamp = Column(Integer, nullable=False) - alias = Column(String(64), nullable=False) - -class AddressBase(Base): - __tablename__ = 'address' - node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') - host = Column(String(256)) - port = Column(Integer) - last_connected_date = Column(Integer(), nullable=True) +create_channel_info = """ +CREATE TABLE IF NOT EXISTS channel_info ( +short_channel_id VARCHAR(64), +node1_id VARCHAR(66), +node2_id VARCHAR(66), +capacity_sat INTEGER, +PRIMARY KEY(short_channel_id) +)""" + +create_policy = """ +CREATE TABLE IF NOT EXISTS policy ( +key VARCHAR(66), +cltv_expiry_delta INTEGER NOT NULL, +htlc_minimum_msat INTEGER NOT NULL, +htlc_maximum_msat INTEGER, +fee_base_msat INTEGER NOT NULL, +fee_proportional_millionths INTEGER NOT NULL, +channel_flags INTEGER NOT NULL, +timestamp INTEGER NOT NULL, +PRIMARY KEY(key) +)""" + +create_address = """ +CREATE TABLE IF NOT EXISTS address ( +node_id VARCHAR(66), +host STRING(256), +port INTEGER NOT NULL, +timestamp INTEGER, +PRIMARY KEY(node_id, host, port) +)""" + +create_node_info = """ +CREATE TABLE IF NOT EXISTS node_info ( +node_id VARCHAR(66), +features INTEGER NOT NULL, +timestamp INTEGER NOT NULL, +alias STRING(64), +PRIMARY KEY(node_id) +)""" class ChannelDB(SqlDB): @@ -252,7 +235,7 @@ class ChannelDB(SqlDB): def __init__(self, network: 'Network'): path = os.path.join(get_headers_dir(network.config), 'channel_db') - super().__init__(network, path, Base, commit_interval=100) + super().__init__(network, path, commit_interval=100) self.num_nodes = 0 self.num_channels = 0 self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] @@ -276,16 +259,7 @@ class ChannelDB(SqlDB): now = int(time.time()) node_id = peer.pubkey self._addresses[node_id].add((peer.host, peer.port, now)) - self.save_address(node_id, peer, now) - - @sql - def save_address(self, node_id, peer, now): - addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() - if addr: - addr.last_connected_date = now - else: - addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) - self.DBSession.add(addr) + self.save_node_address(node_id, peer, now) def get_200_randomly_sorted_nodes_not_in(self, node_ids): unshuffled = set(self._nodes.keys()) - node_ids @@ -394,17 +368,47 @@ class ChannelDB(SqlDB): orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False) assert len(good) == 1 + def create_database(self): + c = self.conn.cursor() + c.execute(create_node_info) + c.execute(create_address) + c.execute(create_policy) + c.execute(create_channel_info) + self.conn.commit() + @sql def save_policy(self, policy): - self.DBSession.execute(PolicyBase.__table__.insert().values(policy)) + c = self.conn.cursor() + c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy)) @sql def delete_policy(self, short_channel_id, node_id): - self.DBSession.execute(PolicyBase.__table__.delete().values(policy)) + c = self.conn.cursor() + c.execute("""DELETE FROM policy WHERE key=?""", (key,)) @sql def save_channel(self, channel_info): - self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info)) + c = self.conn.cursor() + c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info)) + + @sql + def save_node(self, node_info): + c = self.conn.cursor() + c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info)) + + @sql + def save_node_address(self, node_id, peer, now): + c = self.conn.cursor() + c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now)) + + @sql + def save_node_addresses(self, node_id, node_addresses): + c = self.conn.cursor() + for addr in node_addresses: + c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port)) + r = c.fetchall() + if r == []: + c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0)) def verify_channel_update(self, payload): short_channel_id = payload['short_channel_id'] @@ -418,7 +422,6 @@ class ChannelDB(SqlDB): msg_payloads = [msg_payloads] old_addr = None new_nodes = {} - new_addresses = {} for msg_payload in msg_payloads: try: node_info, node_addresses = NodeInfo.from_msg(msg_payload) @@ -445,17 +448,6 @@ class ChannelDB(SqlDB): self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) self.update_counts() - @sql - def save_node_addresses(self, node_if, node_addresses): - for new_addr in node_addresses: - old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() - if not old_addr: - self.DBSession.execute(AddressBase.__table__.insert().values(new_addr)) - - @sql - def save_node(self, node_info): - self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info)) - def get_routing_policy_for_channel(self, start_node_id: bytes, short_channel_id: bytes) -> Optional[bytes]: if not start_node_id or not short_channel_id: return None @@ -506,12 +498,18 @@ class ChannelDB(SqlDB): @sql @profiler def load_data(self): - for x in self.DBSession.query(AddressBase).all(): - self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0))) - for x in self.DBSession.query(ChannelInfoBase).all(): - self._channels[x.short_channel_id] = x.to_nametuple() - for x in self.DBSession.query(PolicyBase).filter_by().all(): - p = x.to_nametuple() + c = self.conn.cursor() + c.execute("""SELECT * FROM address""") + for x in c: + node_id, host, port, timestamp = x + self._addresses[node_id].add((str(host), int(port), int(timestamp or 0))) + c.execute("""SELECT * FROM channel_info""") + for x in c: + ci = ChannelInfo(*x) + self._channels[ci.short_channel_id] = ci + c.execute("""SELECT * FROM policy""") + for x in c: + p = Policy(*x) self._policies[(p.start_node, p.short_channel_id)] = p for channel_info in self._channels.values(): self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py @@ -13,12 +13,7 @@ from enum import IntEnum, auto from typing import NamedTuple, Dict import jsonrpclib -from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean -from sqlalchemy.orm.query import Query -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.sql import not_, or_ from .sql_db import SqlDB, sql - from .util import bh2u, bfh, log_exceptions, ignore_exceptions from . import wallet from .storage import WalletStorage @@ -42,80 +37,105 @@ class TxMinedDepth(IntEnum): FREE = auto() -Base = declarative_base() - -class SweepTx(Base): - __tablename__ = 'sweep_txs' - funding_outpoint = Column(String(34), primary_key=True) - index = Column(Integer(), primary_key=True) - prevout = Column(String(34)) - tx = Column(String()) - -class ChannelInfo(Base): - __tablename__ = 'channel_info' - outpoint = Column(String(34), primary_key=True) - address = Column(String(32)) +create_sweep_txs=""" +CREATE TABLE IF NOT EXISTS sweep_txs ( +funding_outpoint VARCHAR(34) NOT NULL, +"index" INTEGER NOT NULL, +prevout VARCHAR(34), +tx VARCHAR, +PRIMARY KEY(funding_outpoint, "index") +)""" +create_channel_info=""" +CREATE TABLE IF NOT EXISTS channel_info ( +outpoint VARCHAR(34) NOT NULL, +address VARCHAR(32), +PRIMARY KEY(outpoint) +)""" class SweepStore(SqlDB): def __init__(self, path, network): - super().__init__(network, path, Base) + super().__init__(network, path) + + def create_database(self): + c = self.conn.cursor() + c.execute(create_channel_info) + c.execute(create_sweep_txs) + self.conn.commit() @sql def get_sweep_tx(self, funding_outpoint, prevout): - return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prevout==prevout).all()] + c = self.conn.cursor() + c.execute("SELECT tx FROM sweep_txs WHERE funding_outpoint=? AND prevout=?", (funding_outpoint, prevout)) + return [Transaction(bh2u(r[0])) for r in c.fetchall()] @sql def get_tx_by_index(self, funding_outpoint, index): - r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none() - return str(r.prevout), bh2u(r.tx) + c = self.conn.cursor() + c.execute("""SELECT prevout, tx FROM sweep_txs WHERE funding_outpoint=? AND "index"=?""", (funding_outpoint, index)) + r = c.fetchone()[0] + return str(r[0]), bh2u(r[1]) @sql def list_sweep_tx(self): - return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all()) + c = self.conn.cursor() + c.execute("SELECT funding_outpoint FROM sweep_txs") + return set([r[0] for r in c.fetchall()]) @sql def add_sweep_tx(self, funding_outpoint, prevout, tx): - n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() - self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prevout=prevout, tx=bfh(tx))) - self.DBSession.commit() + c = self.conn.cursor() + c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,)) + n = int(c.fetchone()[0]) + c.execute("""INSERT INTO sweep_txs (funding_outpoint, "index", prevout, tx) VALUES (?,?,?,?)""", (funding_outpoint, n, prevout, bfh(str(tx)))) + self.conn.commit() @sql def get_num_tx(self, funding_outpoint): - return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()) + c = self.conn.cursor() + c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,)) + return int(c.fetchone()[0]) @sql def remove_sweep_tx(self, funding_outpoint): - r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all() - for x in r: - self.DBSession.delete(x) - self.DBSession.commit() + c = self.conn.cursor() + c.execute("DELETE FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,)) + self.conn.commit() @sql def add_channel(self, outpoint, address): - self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint)) - self.DBSession.commit() + c = self.conn.cursor() + c.execute("INSERT INTO channel_info (address, outpoint) VALUES (?,?)", (address, outpoint)) + self.conn.commit() @sql def remove_channel(self, outpoint): - v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() - self.DBSession.delete(v) - self.DBSession.commit() + c = self.conn.cursor() + c.execute("DELETE FROM channel_info WHERE outpoint=?", (outpoint,)) + self.conn.commit() @sql def has_channel(self, outpoint): - return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()) + c = self.conn.cursor() + c.execute("SELECT * FROM channel_info WHERE outpoint=?", (outpoint,)) + r = c.fetchone() + return r is not None @sql def get_address(self, outpoint): - r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() - return str(r.address) if r else None + c = self.conn.cursor() + c.execute("SELECT address FROM channel_info WHERE outpoint=?", (outpoint,)) + r = c.fetchone() + return r[0] if r else None @sql def list_channel_info(self): - return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()] + c = self.conn.cursor() + c.execute("SELECT address, outpoint FROM channel_info") + return [(r[0], r[1]) for r in c.fetchall()] + class LNWatcher(AddressSynchronizer): diff --git a/electrum/sql_db.py b/electrum/sql_db.py @@ -3,18 +3,11 @@ import concurrent import queue import threading import asyncio - -from sqlalchemy import create_engine -from sqlalchemy.pool import StaticPool -from sqlalchemy.orm import sessionmaker +import sqlite3 from .logging import Logger -# https://stackoverflow.com/questions/26971050/sqlalchemy-sqlite-too-many-sql-variables -SQLITE_LIMIT_VARIABLE_NUMBER = 999 - - def sql(func): """wrapper for sql methods""" def wrapper(self, *args, **kwargs): @@ -26,9 +19,8 @@ def sql(func): class SqlDB(Logger): - def __init__(self, network, path, base, commit_interval=None): + def __init__(self, network, path, commit_interval=None): Logger.__init__(self) - self.base = base self.network = network self.path = path self.commit_interval = commit_interval @@ -37,13 +29,10 @@ class SqlDB(Logger): self.sql_thread.start() def run_sql(self): - #return self.logger.info("SQL thread started") - engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) - DBSession = sessionmaker(bind=engine, autoflush=False) - if not os.path.exists(self.path): - self.base.metadata.create_all(engine) - self.DBSession = DBSession() + self.conn = sqlite3.connect(self.path) + self.logger.info("Creating database") + self.create_database() i = 0 while self.network.asyncio_loop.is_running(): try: @@ -62,7 +51,7 @@ class SqlDB(Logger): if self.commit_interval: i = (i + 1) % self.commit_interval if i == 0: - self.DBSession.commit() + self.conn.commit() # write - self.DBSession.commit() + self.conn.commit() self.logger.info("SQL thread terminated")