commit d8e9a9a49e38fb9353fb80b3d72debc0ccd711ce
parent b861e2e955c4a790d8e2b4ce262b894a67c3b470
Author: ThomasV <thomasv@electrum.org>
Date: Wed, 6 Mar 2019 09:56:22 +0100
create parent class for sql databases
Diffstat:
3 files changed, 72 insertions(+), 82 deletions(-)
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
-from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
-from sqlalchemy.pool import StaticPool
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
-from sqlalchemy.orm import scoped_session
+from .sql_db import SqlDB, sql
from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
@@ -212,50 +210,25 @@ class Address(Base):
last_connected_date = Column(DateTime(), nullable=False)
-class ChannelDB(PrintError):
+
+
+class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'):
- self.network = network
+ path = os.path.join(get_headers_dir(network.config), 'channel_db')
+ super().__init__(network, path, Base)
+ print(Base)
self.num_nodes = 0
self.num_channels = 0
- self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
- self.db_requests = queue.Queue()
- threading.Thread(target=self.sql_thread).start()
-
- def sql_thread(self):
- self.sql_thread = threading.currentThread()
- engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
- DBSession = sessionmaker(bind=engine, autoflush=False)
- self.DBSession = DBSession()
- if not os.path.exists(self.path):
- Base.metadata.create_all(engine)
+ self.update_counts()
+
+ @sql
+ def update_counts(self):
self._update_counts()
- while self.network.asyncio_loop.is_running():
- try:
- future, func, args, kwargs = self.db_requests.get(timeout=0.1)
- except queue.Empty:
- continue
- try:
- result = func(self, *args, **kwargs)
- except BaseException as e:
- future.set_exception(e)
- continue
- future.set_result(result)
- # write
- self.DBSession.commit()
- self.print_error("SQL thread terminated")
-
- def sql(func):
- def wrapper(self, *args, **kwargs):
- assert threading.currentThread() != self.sql_thread
- f = concurrent.futures.Future()
- self.db_requests.put((f, func, args, kwargs))
- return f.result(timeout=10)
- return wrapper
def _update_counts(self):
self.num_channels = self.DBSession.query(ChannelInfo).count()
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
@@ -11,9 +11,14 @@ from collections import defaultdict
import asyncio
from enum import IntEnum, auto
from typing import NamedTuple, Dict
-
import jsonrpclib
+from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
+from sqlalchemy.orm.query import Query
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.sql import not_, or_
+from .sql_db import SqlDB, sql
+
from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet
from .storage import WalletStorage
@@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum):
FREE = auto()
-from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
-from sqlalchemy.pool import StaticPool
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy.orm.query import Query
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.sql import not_, or_
-from sqlalchemy.orm import scoped_session
-
Base = declarative_base()
class SweepTx(Base):
@@ -60,42 +57,11 @@ class ChannelInfo(Base):
outpoint = Column(String(34))
-class SweepStore(PrintError):
- def __init__(self, path, network):
- PrintError.__init__(self)
- self.path = path
- self.network = network
- self.db_requests = queue.Queue()
- threading.Thread(target=self.sql_thread).start()
-
- def sql_thread(self):
- engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
- DBSession = sessionmaker(bind=engine, autoflush=False)
- self.DBSession = DBSession()
- if not os.path.exists(self.path):
- Base.metadata.create_all(engine)
- while self.network.asyncio_loop.is_running():
- try:
- future, func, args, kwargs = self.db_requests.get(timeout=0.1)
- except queue.Empty:
- continue
- try:
- result = func(self, *args, **kwargs)
- except BaseException as e:
- future.set_exception(e)
- continue
- future.set_result(result)
- # write
- self.DBSession.commit()
- self.print_error("SQL thread terminated")
+class SweepStore(SqlDB):
- def sql(func):
- def wrapper(self, *args, **kwargs):
- f = concurrent.futures.Future()
- self.db_requests.put((f, func, args, kwargs))
- return f.result(timeout=10)
- return wrapper
+ def __init__(self, path, network):
+ super().__init__(network, path, Base)
@sql
def get_sweep_tx(self, funding_outpoint, prev_txid):
diff --git a/electrum/sql_db.py b/electrum/sql_db.py
@@ -0,0 +1,51 @@
+import os
+import concurrent
+import queue
+import threading
+
+from sqlalchemy import create_engine
+from sqlalchemy.pool import StaticPool
+from sqlalchemy.orm import sessionmaker
+
+from .util import PrintError
+
+
+def sql(func):
+ """wrapper for sql methods"""
+ def wrapper(self, *args, **kwargs):
+ assert threading.currentThread() != self.sql_thread
+ f = concurrent.futures.Future()
+ self.db_requests.put((f, func, args, kwargs))
+ return f.result(timeout=10)
+ return wrapper
+
+class SqlDB(PrintError):
+
+ def __init__(self, network, path, base):
+ self.base = base
+ self.network = network
+ self.path = path
+ self.db_requests = queue.Queue()
+ self.sql_thread = threading.Thread(target=self.run_sql)
+ self.sql_thread.start()
+
+ def run_sql(self):
+ engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
+ DBSession = sessionmaker(bind=engine, autoflush=False)
+ self.DBSession = DBSession()
+ if not os.path.exists(self.path):
+ self.base.metadata.create_all(engine)
+ while self.network.asyncio_loop.is_running():
+ try:
+ future, func, args, kwargs = self.db_requests.get(timeout=0.1)
+ except queue.Empty:
+ continue
+ try:
+ result = func(self, *args, **kwargs)
+ except BaseException as e:
+ future.set_exception(e)
+ continue
+ future.set_result(result)
+ # write
+ self.DBSession.commit()
+ self.print_error("SQL thread terminated")