commit b861e2e955c4a790d8e2b4ce262b894a67c3b470
parent bfdf0a7e8823b8f250df6d7fb8d691dea689b2a5
Author: ThomasV <thomasv@electrum.org>
Date: Tue, 5 Mar 2019 17:28:24 +0100
lnwatcher: save sweepstore in sqlite database
Diffstat:
2 files changed, 124 insertions(+), 43 deletions(-)
diff --git a/electrum/gui/qt/watchtower_window.py b/electrum/gui/qt/watchtower_window.py
@@ -52,9 +52,11 @@ class WatcherList(MyTreeView):
def update(self):
self.model().clear()
self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
- for outpoint, sweep_dict in self.parent.lnwatcher.sweepstore.items():
+ sweepstore = self.parent.lnwatcher.sweepstore
+ for outpoint in sweepstore.list_sweep_tx():
+ n = sweepstore.num_sweep_tx(outpoint)
status = self.parent.lnwatcher.get_channel_status(outpoint)
- items = [QStandardItem(e) for e in [outpoint, "%d"%len(sweep_dict), status]]
+ items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
self.model().insertRow(self.model().rowCount(), items)
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
@@ -2,9 +2,11 @@
# Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php
-import threading
from typing import NamedTuple, Iterable, TYPE_CHECKING
import os
+import queue
+import threading
+import concurrent
from collections import defaultdict
import asyncio
from enum import IntEnum, auto
@@ -35,27 +37,125 @@ class TxMinedDepth(IntEnum):
FREE = auto()
+from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
+from sqlalchemy.pool import StaticPool
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm.query import Query
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.sql import not_, or_
+from sqlalchemy.orm import scoped_session
+
+Base = declarative_base()
+
+class SweepTx(Base):
+ __tablename__ = 'sweep_txs'
+ funding_outpoint = Column(String(34))
+ prev_txid = Column(String(32))
+ tx = Column(String())
+ txid = Column(String(32), primary_key=True) # txid of tx
+
+class ChannelInfo(Base):
+ __tablename__ = 'channel_info'
+ address = Column(String(32), primary_key=True)
+ outpoint = Column(String(34))
+
+
+class SweepStore(PrintError):
+
+ def __init__(self, path, network):
+ PrintError.__init__(self)
+ self.path = path
+ self.network = network
+ self.db_requests = queue.Queue()
+ threading.Thread(target=self.sql_thread).start()
+
+ def sql_thread(self):
+ engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
+ DBSession = sessionmaker(bind=engine, autoflush=False)
+ self.DBSession = DBSession()
+ if not os.path.exists(self.path):
+ Base.metadata.create_all(engine)
+ while self.network.asyncio_loop.is_running():
+ try:
+ future, func, args, kwargs = self.db_requests.get(timeout=0.1)
+ except queue.Empty:
+ continue
+ try:
+ result = func(self, *args, **kwargs)
+ except BaseException as e:
+ future.set_exception(e)
+ continue
+ future.set_result(result)
+ # write
+ self.DBSession.commit()
+ self.print_error("SQL thread terminated")
+
+ def sql(func):
+ def wrapper(self, *args, **kwargs):
+ f = concurrent.futures.Future()
+ self.db_requests.put((f, func, args, kwargs))
+ return f.result(timeout=10)
+ return wrapper
+
+ @sql
+ def get_sweep_tx(self, funding_outpoint, prev_txid):
+ return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]
+
+ @sql
+ def list_sweep_tx(self):
+ return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
+
+ @sql
+ def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
+ self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid()))
+ self.DBSession.commit()
+
+ @sql
+ def num_sweep_tx(self, funding_outpoint):
+ return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
+
+ @sql
+ def remove_sweep_tx(self, funding_outpoint):
+ v = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all()
+ self.DBSession.delete(v)
+ self.DBSession.commit()
+
+ @sql
+ def add_channel_info(self, address, outpoint):
+ self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
+ self.DBSession.commit()
+
+ @sql
+ def remove_channel_info(self, address):
+ v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
+ self.DBSession.delete(v)
+ self.DBSession.commit()
+
+ @sql
+ def has_channel_info(self, address):
+ return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())
+
+ @sql
+ def get_channel_info(self, address):
+ r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
+ return r.outpoint if r else None
+
+ @sql
+ def list_channel_info(self):
+ return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
+
+
class LNWatcher(AddressSynchronizer):
verbosity_filter = 'W'
def __init__(self, network: 'Network'):
- path = os.path.join(network.config.path, "watcher_db")
+ path = os.path.join(network.config.path, "watchtower_wallet")
storage = WalletStorage(path)
AddressSynchronizer.__init__(self, storage)
self.config = network.config
self.start_network(network)
self.lock = threading.RLock()
- self.channel_info = storage.get('channel_info', {}) # access with 'lock'
- # [funding_outpoint_str][prev_txid] -> set of Transaction
- # prev_txid is the txid of a tx that is watched for confirmations
- # access with 'lock'
- self.sweepstore = defaultdict(lambda: defaultdict(set))
- for funding_outpoint, ctxs in storage.get('sweepstore', {}).items():
- for txid, set_of_txns in ctxs.items():
- for tx in set_of_txns:
- tx2 = Transaction.from_dict(tx)
- self.sweepstore[funding_outpoint][txid].add(tx2)
-
+ self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network)
self.network.register_callback(self.on_network_update,
['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
self.set_remote_watchtower()
@@ -97,34 +197,18 @@ class LNWatcher(AddressSynchronizer):
await asyncio.sleep(5)
await self.watchtower_queue.put((name, args, kwargs))
- def write_to_disk(self):
- # FIXME: json => every update takes linear instead of constant disk write
- with self.lock:
- storage = self.storage
- storage.put('channel_info', self.channel_info)
- # self.sweepstore
- sweepstore = {}
- for funding_outpoint, ctxs in self.sweepstore.items():
- sweepstore[funding_outpoint] = {}
- for prev_txid, set_of_txns in ctxs.items():
- sweepstore[funding_outpoint][prev_txid] = [tx.as_dict() for tx in set_of_txns]
- storage.put('sweepstore', sweepstore)
- storage.write()
@with_watchtower
def watch_channel(self, address, outpoint):
self.add_address(address)
with self.lock:
- if address not in self.channel_info:
- self.channel_info[address] = outpoint
- self.write_to_disk()
+ if not self.sweepstore.has_channel_info(address):
+ self.sweepstore.add_channel_info(address, outpoint)
def unwatch_channel(self, address, funding_outpoint):
self.print_error('unwatching', funding_outpoint)
- with self.lock:
- self.channel_info.pop(address)
- self.sweepstore.pop(funding_outpoint)
- self.write_to_disk()
+ self.sweepstore.remove_sweep_tx(funding_outpoint)
+ self.sweepstore.remove_channel_info(address)
if funding_outpoint in self.tx_progress:
self.tx_progress[funding_outpoint].all_done.set()
@@ -138,9 +222,7 @@ class LNWatcher(AddressSynchronizer):
return
if not self.synchronizer.is_up_to_date():
return
- with self.lock:
- channel_info_items = list(self.channel_info.items())
- for address, outpoint in channel_info_items:
+ for address, outpoint in self.sweepstore.list_channel_info():
await self.check_onchain_situation(address, outpoint)
async def check_onchain_situation(self, address, funding_outpoint):
@@ -192,8 +274,7 @@ class LNWatcher(AddressSynchronizer):
if spender is not None:
continue
prev_txid, prev_n = prevout.split(':')
- with self.lock:
- sweep_txns = self.sweepstore[funding_outpoint][prev_txid]
+ sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid)
for tx in sweep_txns:
if not await self.broadcast_or_log(funding_outpoint, tx):
self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
@@ -215,9 +296,7 @@ class LNWatcher(AddressSynchronizer):
@with_watchtower
def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
tx = Transaction.from_dict(tx_dict)
- with self.lock:
- self.sweepstore[funding_outpoint][prev_txid].add(tx)
- self.write_to_disk()
+ self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
def get_tx_mined_depth(self, txid: str):
if not txid: