commit 238f3c949ca2d23cbe4cc4e5a1ccd15371050663
parent 0eab1692d6a64eacdc13f2b3568a1453ce1c3761
Author: ThomasV <thomasv@electrum.org>
Date: Thu, 27 Jun 2019 09:03:34 +0200
get rid of sql_alchemy
Diffstat:
4 files changed, 153 insertions(+), 147 deletions(-)
diff --git a/contrib/requirements/requirements.txt b/contrib/requirements/requirements.txt
@@ -11,4 +11,3 @@ aiohttp_socks
certifi
bitstring
pycryptodomex>=3.7
-sqlalchemy>=1.3.0b3
diff --git a/electrum/channel_db.py b/electrum/channel_db.py
@@ -36,10 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
-from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
-from sqlalchemy.orm.query import Query
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from . import constants
@@ -66,7 +62,6 @@ def validate_features(features : int):
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
raise UnknownEvenFeatureBits()
-Base = declarative_base()
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
@@ -193,57 +188,45 @@ class Address(NamedTuple):
port: int
last_connected_date: int
-
-class ChannelInfoBase(Base):
- __tablename__ = 'channel_info'
- short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
- node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
- node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
- capacity_sat = Column(Integer)
- def to_nametuple(self):
- return ChannelInfo(
- short_channel_id=self.short_channel_id,
- node1_id=self.node1_id,
- node2_id=self.node2_id,
- capacity_sat=self.capacity_sat
- )
-
-class PolicyBase(Base):
- __tablename__ = 'policy'
- key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
- cltv_expiry_delta = Column(Integer, nullable=False)
- htlc_minimum_msat = Column(Integer, nullable=False)
- htlc_maximum_msat = Column(Integer)
- fee_base_msat = Column(Integer, nullable=False)
- fee_proportional_millionths = Column(Integer, nullable=False)
- channel_flags = Column(Integer, nullable=False)
- timestamp = Column(Integer, nullable=False)
-
- def to_nametuple(self):
- return Policy(
- key=self.key,
- cltv_expiry_delta=self.cltv_expiry_delta,
- htlc_minimum_msat=self.htlc_minimum_msat,
- htlc_maximum_msat=self.htlc_maximum_msat,
- fee_base_msat= self.fee_base_msat,
- fee_proportional_millionths = self.fee_proportional_millionths,
- channel_flags=self.channel_flags,
- timestamp=self.timestamp
- )
-
-class NodeInfoBase(Base):
- __tablename__ = 'node_info'
- node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
- features = Column(Integer, nullable=False)
- timestamp = Column(Integer, nullable=False)
- alias = Column(String(64), nullable=False)
-
-class AddressBase(Base):
- __tablename__ = 'address'
- node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
- host = Column(String(256))
- port = Column(Integer)
- last_connected_date = Column(Integer(), nullable=True)
+create_channel_info = """
+CREATE TABLE IF NOT EXISTS channel_info (
+short_channel_id VARCHAR(64),
+node1_id VARCHAR(66),
+node2_id VARCHAR(66),
+capacity_sat INTEGER,
+PRIMARY KEY(short_channel_id)
+)"""
+
+create_policy = """
+CREATE TABLE IF NOT EXISTS policy (
+key VARCHAR(66),
+cltv_expiry_delta INTEGER NOT NULL,
+htlc_minimum_msat INTEGER NOT NULL,
+htlc_maximum_msat INTEGER,
+fee_base_msat INTEGER NOT NULL,
+fee_proportional_millionths INTEGER NOT NULL,
+channel_flags INTEGER NOT NULL,
+timestamp INTEGER NOT NULL,
+PRIMARY KEY(key)
+)"""
+
+create_address = """
+CREATE TABLE IF NOT EXISTS address (
+node_id VARCHAR(66),
+host STRING(256),
+port INTEGER NOT NULL,
+timestamp INTEGER,
+PRIMARY KEY(node_id, host, port)
+)"""
+
+create_node_info = """
+CREATE TABLE IF NOT EXISTS node_info (
+node_id VARCHAR(66),
+features INTEGER NOT NULL,
+timestamp INTEGER NOT NULL,
+alias STRING(64),
+PRIMARY KEY(node_id)
+)"""
class ChannelDB(SqlDB):
@@ -252,7 +235,7 @@ class ChannelDB(SqlDB):
def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'channel_db')
- super().__init__(network, path, Base, commit_interval=100)
+ super().__init__(network, path, commit_interval=100)
self.num_nodes = 0
self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
@@ -276,16 +259,7 @@ class ChannelDB(SqlDB):
now = int(time.time())
node_id = peer.pubkey
self._addresses[node_id].add((peer.host, peer.port, now))
- self.save_address(node_id, peer, now)
-
- @sql
- def save_address(self, node_id, peer, now):
- addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
- if addr:
- addr.last_connected_date = now
- else:
- addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
- self.DBSession.add(addr)
+ self.save_node_address(node_id, peer, now)
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
unshuffled = set(self._nodes.keys()) - node_ids
@@ -394,17 +368,47 @@ class ChannelDB(SqlDB):
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
assert len(good) == 1
+ def create_database(self):
+ c = self.conn.cursor()
+ c.execute(create_node_info)
+ c.execute(create_address)
+ c.execute(create_policy)
+ c.execute(create_channel_info)
+ self.conn.commit()
+
@sql
def save_policy(self, policy):
- self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
+ c = self.conn.cursor()
+ c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy))
@sql
def delete_policy(self, short_channel_id, node_id):
- self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
+ c = self.conn.cursor()
+ c.execute("""DELETE FROM policy WHERE key=?""", (key,))
@sql
def save_channel(self, channel_info):
- self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
+ c = self.conn.cursor()
+ c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))
+
+ @sql
+ def save_node(self, node_info):
+ c = self.conn.cursor()
+ c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))
+
+ @sql
+ def save_node_address(self, node_id, peer, now):
+ c = self.conn.cursor()
+ c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))
+
+ @sql
+ def save_node_addresses(self, node_id, node_addresses):
+ c = self.conn.cursor()
+ for addr in node_addresses:
+ c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
+ r = c.fetchall()
+ if r == []:
+ c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))
def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id']
@@ -418,7 +422,6 @@ class ChannelDB(SqlDB):
msg_payloads = [msg_payloads]
old_addr = None
new_nodes = {}
- new_addresses = {}
for msg_payload in msg_payloads:
try:
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
@@ -445,17 +448,6 @@ class ChannelDB(SqlDB):
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
- @sql
- def save_node_addresses(self, node_if, node_addresses):
- for new_addr in node_addresses:
- old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
- if not old_addr:
- self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
-
- @sql
- def save_node(self, node_info):
- self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
-
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[bytes]:
if not start_node_id or not short_channel_id: return None
@@ -506,12 +498,18 @@ class ChannelDB(SqlDB):
@sql
@profiler
def load_data(self):
- for x in self.DBSession.query(AddressBase).all():
- self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
- for x in self.DBSession.query(ChannelInfoBase).all():
- self._channels[x.short_channel_id] = x.to_nametuple()
- for x in self.DBSession.query(PolicyBase).filter_by().all():
- p = x.to_nametuple()
+ c = self.conn.cursor()
+ c.execute("""SELECT * FROM address""")
+ for x in c:
+ node_id, host, port, timestamp = x
+ self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
+ c.execute("""SELECT * FROM channel_info""")
+ for x in c:
+ ci = ChannelInfo(*x)
+ self._channels[ci.short_channel_id] = ci
+ c.execute("""SELECT * FROM policy""")
+ for x in c:
+ p = Policy(*x)
self._policies[(p.start_node, p.short_channel_id)] = p
for channel_info in self._channels.values():
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
@@ -13,12 +13,7 @@ from enum import IntEnum, auto
from typing import NamedTuple, Dict
import jsonrpclib
-from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
-from sqlalchemy.orm.query import Query
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
-
from .util import bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet
from .storage import WalletStorage
@@ -42,80 +37,105 @@ class TxMinedDepth(IntEnum):
FREE = auto()
-Base = declarative_base()
-
-class SweepTx(Base):
- __tablename__ = 'sweep_txs'
- funding_outpoint = Column(String(34), primary_key=True)
- index = Column(Integer(), primary_key=True)
- prevout = Column(String(34))
- tx = Column(String())
-
-class ChannelInfo(Base):
- __tablename__ = 'channel_info'
- outpoint = Column(String(34), primary_key=True)
- address = Column(String(32))
+create_sweep_txs="""
+CREATE TABLE IF NOT EXISTS sweep_txs (
+funding_outpoint VARCHAR(34) NOT NULL,
+"index" INTEGER NOT NULL,
+prevout VARCHAR(34),
+tx VARCHAR,
+PRIMARY KEY(funding_outpoint, "index")
+)"""
+create_channel_info="""
+CREATE TABLE IF NOT EXISTS channel_info (
+outpoint VARCHAR(34) NOT NULL,
+address VARCHAR(32),
+PRIMARY KEY(outpoint)
+)"""
class SweepStore(SqlDB):
def __init__(self, path, network):
- super().__init__(network, path, Base)
+ super().__init__(network, path)
+
+ def create_database(self):
+ c = self.conn.cursor()
+ c.execute(create_channel_info)
+ c.execute(create_sweep_txs)
+ self.conn.commit()
@sql
def get_sweep_tx(self, funding_outpoint, prevout):
- return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prevout==prevout).all()]
+ c = self.conn.cursor()
+ c.execute("SELECT tx FROM sweep_txs WHERE funding_outpoint=? AND prevout=?", (funding_outpoint, prevout))
+ return [Transaction(bh2u(r[0])) for r in c.fetchall()]
@sql
def get_tx_by_index(self, funding_outpoint, index):
- r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
- return str(r.prevout), bh2u(r.tx)
+ c = self.conn.cursor()
+ c.execute("""SELECT prevout, tx FROM sweep_txs WHERE funding_outpoint=? AND "index"=?""", (funding_outpoint, index))
+ r = c.fetchone()[0]
+ return str(r[0]), bh2u(r[1])
@sql
def list_sweep_tx(self):
- return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
+ c = self.conn.cursor()
+ c.execute("SELECT funding_outpoint FROM sweep_txs")
+ return set([r[0] for r in c.fetchall()])
@sql
def add_sweep_tx(self, funding_outpoint, prevout, tx):
- n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
- self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prevout=prevout, tx=bfh(tx)))
- self.DBSession.commit()
+ c = self.conn.cursor()
+ c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
+ n = int(c.fetchone()[0])
+ c.execute("""INSERT INTO sweep_txs (funding_outpoint, "index", prevout, tx) VALUES (?,?,?,?)""", (funding_outpoint, n, prevout, bfh(str(tx))))
+ self.conn.commit()
@sql
def get_num_tx(self, funding_outpoint):
- return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count())
+ c = self.conn.cursor()
+ c.execute("SELECT count(*) FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
+ return int(c.fetchone()[0])
@sql
def remove_sweep_tx(self, funding_outpoint):
- r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
- for x in r:
- self.DBSession.delete(x)
- self.DBSession.commit()
+ c = self.conn.cursor()
+ c.execute("DELETE FROM sweep_txs WHERE funding_outpoint=?", (funding_outpoint,))
+ self.conn.commit()
@sql
def add_channel(self, outpoint, address):
- self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
- self.DBSession.commit()
+ c = self.conn.cursor()
+ c.execute("INSERT INTO channel_info (address, outpoint) VALUES (?,?)", (address, outpoint))
+ self.conn.commit()
@sql
def remove_channel(self, outpoint):
- v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
- self.DBSession.delete(v)
- self.DBSession.commit()
+ c = self.conn.cursor()
+ c.execute("DELETE FROM channel_info WHERE outpoint=?", (outpoint,))
+ self.conn.commit()
@sql
def has_channel(self, outpoint):
- return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none())
+ c = self.conn.cursor()
+ c.execute("SELECT * FROM channel_info WHERE outpoint=?", (outpoint,))
+ r = c.fetchone()
+ return r is not None
@sql
def get_address(self, outpoint):
- r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
- return str(r.address) if r else None
+ c = self.conn.cursor()
+ c.execute("SELECT address FROM channel_info WHERE outpoint=?", (outpoint,))
+ r = c.fetchone()
+ return r[0] if r else None
@sql
def list_channel_info(self):
- return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
+ c = self.conn.cursor()
+ c.execute("SELECT address, outpoint FROM channel_info")
+ return [(r[0], r[1]) for r in c.fetchall()]
+
class LNWatcher(AddressSynchronizer):
diff --git a/electrum/sql_db.py b/electrum/sql_db.py
@@ -3,18 +3,11 @@ import concurrent
import queue
import threading
import asyncio
-
-from sqlalchemy import create_engine
-from sqlalchemy.pool import StaticPool
-from sqlalchemy.orm import sessionmaker
+import sqlite3
from .logging import Logger
-# https://stackoverflow.com/questions/26971050/sqlalchemy-sqlite-too-many-sql-variables
-SQLITE_LIMIT_VARIABLE_NUMBER = 999
-
-
def sql(func):
"""wrapper for sql methods"""
def wrapper(self, *args, **kwargs):
@@ -26,9 +19,8 @@ def sql(func):
class SqlDB(Logger):
- def __init__(self, network, path, base, commit_interval=None):
+ def __init__(self, network, path, commit_interval=None):
Logger.__init__(self)
- self.base = base
self.network = network
self.path = path
self.commit_interval = commit_interval
@@ -37,13 +29,10 @@ class SqlDB(Logger):
self.sql_thread.start()
def run_sql(self):
- #return
self.logger.info("SQL thread started")
- engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
- DBSession = sessionmaker(bind=engine, autoflush=False)
- if not os.path.exists(self.path):
- self.base.metadata.create_all(engine)
- self.DBSession = DBSession()
+ self.conn = sqlite3.connect(self.path)
+ self.logger.info("Creating database")
+ self.create_database()
i = 0
while self.network.asyncio_loop.is_running():
try:
@@ -62,7 +51,7 @@ class SqlDB(Logger):
if self.commit_interval:
i = (i + 1) % self.commit_interval
if i == 0:
- self.DBSession.commit()
+ self.conn.commit()
# write
- self.DBSession.commit()
+ self.conn.commit()
self.logger.info("SQL thread terminated")