electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 652d10aa5f810d6db00ce4870ea7c779f9ff35a4
parent ef661050c8da36b7b2b73bec47dbaeddb38301f5
Author: ThomasV <thomasv@electrum.org>
Date:   Tue,  9 Mar 2021 09:55:55 +0100

Remove LNBackups object: no longer needed since LNWorker is instantiated by default.

Diffstat:
Melectrum/gui/kivy/uix/dialogs/lightning_channels.py | 16++++++++--------
Melectrum/gui/qt/channels_list.py | 9++++-----
Melectrum/lnpeer.py | 13+++++++------
Melectrum/lnworker.py | 86++++++++++++++++++++++++-------------------------------------------------------
Melectrum/wallet.py | 7+------
5 files changed, 46 insertions(+), 85 deletions(-)

diff --git a/electrum/gui/kivy/uix/dialogs/lightning_channels.py b/electrum/gui/kivy/uix/dialogs/lightning_channels.py @@ -410,11 +410,12 @@ Builder.load_string(r''' class ChannelBackupPopup(Popup, Logger): - def __init__(self, chan: AbstractChannel, app: 'ElectrumWindow', **kwargs): + def __init__(self, chan: AbstractChannel, channels_list, **kwargs): Popup.__init__(self, **kwargs) Logger.__init__(self) self.chan = chan - self.app = app + self.channels_list = channels_list + self.app = channels_list.app self.short_id = format_short_channel_id(chan.short_channel_id) self.state = chan.get_state_for_GUI() self.title = _('Channel Backup') @@ -427,7 +428,7 @@ class ChannelBackupPopup(Popup, Logger): if not b: return loop = self.app.wallet.network.asyncio_loop - coro = asyncio.run_coroutine_threadsafe(self.app.wallet.lnbackups.request_force_close(self.chan.channel_id), loop) + coro = asyncio.run_coroutine_threadsafe(self.app.wallet.lnworker.request_force_close_from_backup(self.chan.channel_id), loop) try: coro.result(5) self.app.show_info(_('Channel closed')) @@ -442,7 +443,7 @@ class ChannelBackupPopup(Popup, Logger): def _remove_backup(self, b): if not b: return - self.app.wallet.lnbackups.remove_channel_backup(self.chan.channel_id) + self.app.wallet.lnworker.remove_channel_backup(self.chan.channel_id) self.dismiss() @@ -550,9 +551,9 @@ class LightningChannelsDialog(Factory.Popup): def show_item(self, obj): chan = obj._chan if chan.is_backup(): - p = ChannelBackupPopup(chan, self.app) + p = ChannelBackupPopup(chan, self) else: - p = ChannelDetailsPopup(chan, self.app) + p = ChannelDetailsPopup(chan, self) p.open() def format_fields(self, chan): @@ -587,8 +588,7 @@ class LightningChannelsDialog(Factory.Popup): return lnworker = self.app.wallet.lnworker channels = list(lnworker.channels.values()) if lnworker else [] - lnbackups = self.app.wallet.lnbackups - backups = list(lnbackups.channel_backups.values()) + backups = list(lnworker.channel_backups.values()) if lnworker else [] for i in channels + backups: item = Factory.LightningChannelItem() item.screen = self diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py @@ -65,7 +65,6 @@ class ChannelsList(MyTreeView): self.update_single_row.connect(self.do_update_single_row) self.network = self.parent.network self.lnworker = self.parent.wallet.lnworker - self.lnbackups = self.parent.wallet.lnbackups self.setSortingEnabled(True) def format_fields(self, chan: AbstractChannel) -> Dict['ChannelsList.Columns', str]: @@ -136,7 +135,7 @@ class ChannelsList(MyTreeView): def remove_channel_backup(self, channel_id): if self.main_window.question(_('Remove channel backup?')): - self.lnbackups.remove_channel_backup(channel_id) + self.lnworker.remove_channel_backup(channel_id) def export_channel_backup(self, channel_id): msg = ' '.join([ @@ -150,7 +149,7 @@ class ChannelsList(MyTreeView): def request_force_close(self, channel_id): def task(): - coro = self.lnbackups.request_force_close(channel_id) + coro = self.lnworker.request_force_close_from_backup(channel_id) return self.network.run_from_another_thread(coro) def on_success(b): self.main_window.show_message('success') @@ -185,7 +184,7 @@ class ChannelsList(MyTreeView): if not item: return channel_id = idx.sibling(idx.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID) - if channel_id in self.lnbackups.channel_backups: + if channel_id in self.lnworker.channel_backups: menu.addAction(_("Request force-close"), lambda: self.request_force_close(channel_id)) menu.addAction(_("Delete"), lambda: self.remove_channel_backup(channel_id)) menu.exec_(self.viewport().mapToGlobal(position)) @@ -253,7 +252,7 @@ class ChannelsList(MyTreeView): if wallet != self.parent.wallet: return channels = list(wallet.lnworker.channels.values()) if wallet.lnworker else [] - backups = list(wallet.lnbackups.channel_backups.values()) + backups = list(wallet.lnworker.channel_backups.values()) if wallet.lnworker else [] if wallet.lnworker: self.update_can_send(wallet.lnworker) self.model().clear() diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -52,7 +52,7 @@ from .json_db import StoredDict from .invoices import PR_PAID if TYPE_CHECKING: - from .lnworker import LNGossip, LNWallet, LNBackups + from .lnworker import LNGossip, LNWallet from .lnrouter import LNPaymentRoute from .transaction import PartialTransaction @@ -65,10 +65,12 @@ class Peer(Logger): def __init__( self, - lnworker: Union['LNGossip', 'LNWallet', 'LNBackups'], + lnworker: Union['LNGossip', 'LNWallet'], pubkey: bytes, - transport: LNTransportBase - ): + transport: LNTransportBase, + *, is_channel_backup= False): + + self.is_channel_backup = is_channel_backup self._sent_init = False # type: bool self._received_init = False # type: bool self.initialized = asyncio.Future() @@ -171,8 +173,7 @@ class Peer(Logger): def process_message(self, message): message_type, payload = decode_msg(message) # only process INIT if we are a backup - from .lnworker import LNBackups - if isinstance(self.lnworker, LNBackups) and message_type != 'init': + if self.is_channel_backup is True and message_type != 'init': return if message_type in self.ordered_messages: chan_id = payload.get('channel_id') or payload["temporary_channel_id"] diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -599,6 +599,11 @@ class LNWallet(LNWorker): for channel_id, c in random_shuffled_copy(channels.items()): self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) + self._channel_backups = {} # type: Dict[bytes, Channel] + channel_backups = self.db.get_dict("channel_backups") + for channel_id, cb in random_shuffled_copy(channel_backups.items()): + self._channel_backups[bfh(channel_id)] = ChannelBackup(cb, sweep_address=self.sweep_address, lnworker=self) + self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs_routes = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed) @@ -618,6 +623,12 @@ class LNWallet(LNWorker): with self.lock: return self._channels.copy() + @property + def channel_backups(self) -> Mapping[bytes, Channel]: + """Returns a read-only copy of channels.""" + with self.lock: + return self._channel_backups.copy() + def get_channel_by_id(self, channel_id: bytes) -> Optional[Channel]: return self._channels.get(channel_id, None) @@ -680,6 +691,8 @@ class LNWallet(LNWorker): for chan in self.channels.values(): self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) + for cb in self.channel_backups.values(): + self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) for coro in [ self.maybe_listen(), @@ -843,7 +856,8 @@ class LNWallet(LNWorker): if chan.node_id == node_id} def channel_state_changed(self, chan: Channel): - self.save_channel(chan) + if type(chan) is Channel: + self.save_channel(chan) util.trigger_callback('channel', self.wallet, chan) def save_channel(self, chan: Channel): @@ -857,8 +871,14 @@ class LNWallet(LNWorker): for chan in self.channels.values(): if chan.funding_outpoint.to_str() == txo: return chan + for chan in self.channel_backups.values(): + if chan.funding_outpoint.to_str() == txo: + return chan async def on_channel_update(self, chan: Channel): + if type(chan) is ChannelBackup: + util.trigger_callback('channel', self.wallet, chan) + return if chan.get_state() == ChannelState.OPEN and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height()): self.logger.info(f"force-closing due to expiring htlcs") @@ -1940,61 +1960,6 @@ class LNWallet(LNWorker): peer = await self.add_peer(connect_str) await peer.trigger_force_close(channel_id) - -class LNBackups(Logger): - - lnwatcher: Optional['LNWalletWatcher'] - - def __init__(self, wallet: 'Abstract_Wallet'): - Logger.__init__(self) - self.features = LNWALLET_FEATURES - self.lock = threading.RLock() - self.wallet = wallet - self.db = wallet.db - self.lnwatcher = None - self.channel_backups = {} - for channel_id, cb in random_shuffled_copy(self.db.get_dict("channel_backups").items()): - self.channel_backups[bfh(channel_id)] = ChannelBackup(cb, sweep_address=self.sweep_address, lnworker=self) - - @property - def sweep_address(self) -> str: - # TODO possible address-reuse - return self.wallet.get_new_sweep_address_for_channel() - - def channel_state_changed(self, chan): - util.trigger_callback('channel', self.wallet, chan) - - def peer_closed(self, chan): - pass - - async def on_channel_update(self, chan): - util.trigger_callback('channel', self.wallet, chan) - - def channel_by_txo(self, txo): - with self.lock: - channel_backups = list(self.channel_backups.values()) - for chan in channel_backups: - if chan.funding_outpoint.to_str() == txo: - return chan - - def on_peer_successfully_established(self, peer: Peer) -> None: - pass - - def channels_for_peer(self, node_id): - return {} - - def start_network(self, network: 'Network'): - assert network - self.lnwatcher = LNWalletWatcher(self, network) - self.lnwatcher.start_network(network) - self.network = network - for cb in self.channel_backups.values(): - self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) - - def stop(self): - self.lnwatcher.stop() - self.lnwatcher = None - def import_channel_backup(self, data): assert data.startswith('channel_backup:') encrypted = data[15:] @@ -2015,19 +1980,20 @@ class LNBackups(Logger): d = self.db.get_dict("channel_backups") if channel_id.hex() not in d: raise Exception('Channel not found') - d.pop(channel_id.hex()) - self.channel_backups.pop(channel_id) + with self.lock: + d.pop(channel_id.hex()) + self._channel_backups.pop(channel_id) self.wallet.save_db() util.trigger_callback('channels_updated', self.wallet) @log_exceptions - async def request_force_close(self, channel_id: bytes): + async def request_force_close_from_backup(self, channel_id: bytes): cb = self.channel_backups[channel_id].cb # TODO also try network addresses from gossip db (as it might have changed) peer_addr = LNPeerAddr(cb.host, cb.port, cb.node_id) transport = LNTransport(cb.privkey, peer_addr, proxy=self.network.proxy) - peer = Peer(self, cb.node_id, transport) + peer = Peer(self, cb.node_id, transport, is_channel_backup=True) async with TaskGroup() as group: await group.spawn(peer._message_loop()) await group.spawn(peer.trigger_force_close(channel_id)) diff --git a/electrum/wallet.py b/electrum/wallet.py @@ -80,7 +80,7 @@ from .contacts import Contacts from .interface import NetworkException from .mnemonic import Mnemonic from .logging import get_logger -from .lnworker import LNWallet, LNBackups +from .lnworker import LNWallet from .paymentrequest import PaymentRequest from .util import read_json_file, write_json_file, UserFacingException @@ -273,7 +273,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): txin_type: str wallet_type: str lnworker: Optional['LNWallet'] - lnbackups: Optional['LNBackups'] def __init__(self, db: WalletDB, storage: Optional[WalletStorage], *, config: SimpleConfig): if not db.is_ready_to_be_used_by_wallet(): @@ -310,8 +309,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): self._coin_price_cache = {} self.lnworker = None - # a wallet may have channel backups, regardless of lnworker activation - self.lnbackups = LNBackups(self) def save_db(self): if self.storage: @@ -364,7 +361,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): if self.lnworker: self.lnworker.stop() self.lnworker = None - self.lnbackups.stop() self.save_db() def set_up_to_date(self, b): @@ -383,7 +379,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): # only start gossiping when we already have channels if self.db.get('channels'): self.network.start_gossip() - self.lnbackups.start_network(network) def load_and_cleanup(self): self.load_keystore()