electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 3abe30e9d8a383fe88bd9edd23b3be749e9c1c03
parent c155293166a315a3479bfc9af66b05cb8ef22696
Author: ThomasV <thomasv@electrum.org>
Date:   Tue, 12 Mar 2019 18:33:36 +0100

basic watchtower synchronization

Diffstat:
Melectrum/daemon.py | 3++-
Melectrum/gui/qt/watchtower_window.py | 2+-
Melectrum/lnchannel.py | 2+-
Melectrum/lnpeer.py | 2+-
Melectrum/lnwatcher.py | 82++++++++++++++++++++++++++++++++++++++++++++-----------------------------------
Melectrum/lnworker.py | 4++--
6 files changed, 53 insertions(+), 42 deletions(-)

diff --git a/electrum/daemon.py b/electrum/daemon.py @@ -135,7 +135,8 @@ class WatchTower(DaemonThread): port = self.config.get('watchtower_port', 12345) server = SimpleJSONRPCServer((host, port), logRequests=True) server.register_function(self.lnwatcher.add_sweep_tx, 'add_sweep_tx') - server.register_function(self.lnwatcher.watch_channel, 'watch_channel') + server.register_function(self.lnwatcher.add_channel, 'add_channel') + server.register_function(self.lnwatcher.get_num_tx, 'get_num_tx') server.timeout = 0.1 while self.is_running(): server.handle_request() diff --git a/electrum/gui/qt/watchtower_window.py b/electrum/gui/qt/watchtower_window.py @@ -54,7 +54,7 @@ class WatcherList(MyTreeView): self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')}) sweepstore = self.parent.lnwatcher.sweepstore for outpoint in sweepstore.list_sweep_tx(): - n = sweepstore.num_sweep_tx(outpoint) + n = sweepstore.get_num_tx(outpoint) status = self.parent.lnwatcher.get_channel_status(outpoint) items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]] self.model().insertRow(self.model().rowCount(), items) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -464,7 +464,7 @@ class Channel(PrintError): sweeptxs = create_sweeptxs_for_their_just_revoked_ctx(self, ctx, per_commitment_secret, self.sweep_address) for prev_txid, tx in sweeptxs.items(): if tx is not None: - self.lnwatcher.add_sweep_tx(outpoint, prev_txid, tx.as_dict()) + self.lnwatcher.add_sweep_tx(outpoint, prev_txid, str(tx)) def receive_revocation(self, revocation: RevokeAndAck): self.print_error("receive_revocation") diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -490,7 +490,7 @@ class Peer(PrintError): ) chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig) self.lnworker.save_channel(chan) - self.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) + self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.lnworker.on_channels_updated() while True: try: diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py @@ -46,15 +46,15 @@ Base = declarative_base() class SweepTx(Base): __tablename__ = 'sweep_txs' - funding_outpoint = Column(String(34)) + funding_outpoint = Column(String(34), primary_key=True) + index = Column(Integer(), primary_key=True) prev_txid = Column(String(32)) tx = Column(String()) - txid = Column(String(32), primary_key=True) # txid of tx class ChannelInfo(Base): __tablename__ = 'channel_info' - address = Column(String(32), primary_key=True) - outpoint = Column(String(34)) + outpoint = Column(String(34), primary_key=True) + address = Column(String(32)) @@ -68,16 +68,22 @@ class SweepStore(SqlDB): return [Transaction(bh2u(r.tx)) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()] @sql + def get_tx_by_index(self, funding_outpoint, index): + r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none() + return r.prev_txid, bh2u(r.tx) + + @sql def list_sweep_tx(self): return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all()) @sql def add_sweep_tx(self, funding_outpoint, prev_txid, tx): - self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=bfh(str(tx)), txid=tx.txid())) + n = self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() + self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, index=n, prev_txid=prev_txid, tx=bfh(tx))) self.DBSession.commit() @sql - def num_sweep_tx(self, funding_outpoint): + def get_num_tx(self, funding_outpoint): return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() @sql @@ -87,24 +93,24 @@ class SweepStore(SqlDB): self.DBSession.commit() @sql - def add_channel_info(self, address, outpoint): + def add_channel(self, outpoint, address): self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint)) self.DBSession.commit() @sql - def remove_channel_info(self, address): - v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() + def remove_channel(self, outpoint): + v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() self.DBSession.delete(v) self.DBSession.commit() @sql - def has_channel_info(self, address): - return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()) + def has_channel(self, outpoint): + return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()) @sql - def get_channel_info(self, address): - r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none() - return r.outpoint if r else None + def get_address(self, outpoint): + r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() + return r.address if r else None @sql def list_channel_info(self): @@ -139,42 +145,46 @@ class LNWatcher(AddressSynchronizer): self.watchtower = jsonrpclib.Server(watchtower_url) if watchtower_url else None self.watchtower_queue = asyncio.Queue() - def with_watchtower(func): - def wrapper(self, *args, **kwargs): - if self.watchtower: - self.watchtower_queue.put_nowait((func.__name__, args, kwargs)) - return func(self, *args, **kwargs) - return wrapper + def get_num_tx(self, outpoint): + return self.sweepstore.get_num_tx(outpoint) @ignore_exceptions @log_exceptions async def watchtower_task(self): self.print_error('watchtower task started') + # initial check + for address, outpoint in self.sweepstore.list_channel_info(): + await self.watchtower_queue.put(outpoint) while True: - name, args, kwargs = await self.watchtower_queue.get() + outpoint = await self.watchtower_queue.get() if self.watchtower is None: continue - func = getattr(self.watchtower, name) + # synchronize with remote try: - r = func(*args, **kwargs) - self.print_error("watchtower answer", r) - except: - self.print_error('could not reach watchtower, will retry in 5s', name, args) + local_n = self.sweepstore.get_num_tx(outpoint) + n = self.watchtower.get_num_tx(outpoint) + if n == 0: + address = self.sweepstore.get_address(outpoint) + self.watchtower.add_channel(outpoint, address) + self.print_error("sending %d transactions to watchtower"%(local_n - n)) + for index in range(n, local_n): + prev_txid, tx = self.sweepstore.get_tx_by_index(outpoint, index) + self.watchtower.add_sweep_tx(outpoint, prev_txid, tx) + except ConnectionRefusedError: + self.print_error('could not reach watchtower, will retry in 5s') await asyncio.sleep(5) - await self.watchtower_queue.put((name, args, kwargs)) - + await self.watchtower_queue.put(outpoint) - @with_watchtower - def watch_channel(self, address, outpoint): + def add_channel(self, outpoint, address): self.add_address(address) with self.lock: - if not self.sweepstore.has_channel_info(address): - self.sweepstore.add_channel_info(address, outpoint) + if not self.sweepstore.has_channel(outpoint): + self.sweepstore.add_channel(outpoint, address) def unwatch_channel(self, address, funding_outpoint): self.print_error('unwatching', funding_outpoint) self.sweepstore.remove_sweep_tx(funding_outpoint) - self.sweepstore.remove_channel_info(address) + self.sweepstore.remove_channel_info(funding_outpoint) if funding_outpoint in self.tx_progress: self.tx_progress[funding_outpoint].all_done.set() @@ -259,10 +269,10 @@ class LNWatcher(AddressSynchronizer): await self.tx_progress[funding_outpoint].tx_queue.put(tx) return txid - @with_watchtower - def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict): - tx = Transaction.from_dict(tx_dict) + def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx: str): self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx) + if self.watchtower: + self.watchtower_queue.put_nowait(funding_outpoint) def get_tx_mined_depth(self, txid: str): if not txid: diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -92,7 +92,7 @@ class LNWorker(PrintError): self.config = network.config self.channel_db = self.network.channel_db for chan_id, chan in self.channels.items(): - self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) + self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) chan.lnwatcher = network.lnwatcher self._last_tried_peer = {} # LNPeerAddr -> unix timestamp self._add_peers_from_config() @@ -425,7 +425,7 @@ class LNWorker(PrintError): push_msat=push_sat * 1000, temp_channel_id=os.urandom(32)) self.save_channel(chan) - self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) + self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.on_channels_updated() return chan