electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit b9b53e7f76eb054773a9a495593396b6b18af85a
parent f5eb91900ab0a1c3bb776d14f417e566932e596e
Author: SomberNight <somber.night@protonmail.com>
Date:   Thu, 30 Apr 2020 21:08:26 +0200

lnworker: fix threading issues for .channels attribute

external code (commands/gui) did not always take lock when iterating lnworker.channels.
instead of exposing lock, let's take a copy internally (as with .peers)

Diffstat:
Melectrum/lnpeer.py | 4+++-
Melectrum/lnworker.py | 73+++++++++++++++++++++++++++++++++----------------------------------------
Melectrum/tests/test_lnpeer.py | 10+++++++---
3 files changed, 43 insertions(+), 44 deletions(-)

diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -107,7 +107,7 @@ class Peer(Logger): if not (message_name.startswith("update_") or is_commitment_signed): return assert channel_id - chan = self.lnworker.channels[channel_id] # type: Channel + chan = self.channels[channel_id] chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed) if is_commitment_signed: # saving now, to ensure replaying updates works (in case of channel reestablishment) @@ -139,6 +139,8 @@ class Peer(Logger): @property def channels(self) -> Dict[bytes, Channel]: + # FIXME this iterates over all channels in lnworker, + # so if we just want to lookup a channel by channel_id, it's wasteful return self.lnworker.channels_for_peer(self.pubkey) def diagnostic_name(self): diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -491,22 +491,26 @@ class LNWallet(LNWorker): self.enable_htlc_settle.set() # note: accessing channels (besides simple lookup) needs self.lock! - self.channels = {} + self._channels = {} # type: Dict[bytes, Channel] channels = self.db.get_dict("channels") for channel_id, c in channels.items(): - self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) + self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]] + @property + def channels(self) -> Mapping[bytes, Channel]: + """Returns a read-only copy of channels.""" + with self.lock: + return self._channels.copy() + @ignore_exceptions @log_exceptions async def sync_with_local_watchtower(self): watchtower = self.network.local_watchtower if watchtower: while True: - with self.lock: - channels = list(self.channels.values()) - for chan in channels: + for chan in self.channels.values(): await self.sync_channel_with_watchtower(chan, watchtower.sweepstore) await asyncio.sleep(5) @@ -524,12 +528,10 @@ class LNWallet(LNWorker): watchtower_url = self.config.get('watchtower_url') if not watchtower_url: continue - with self.lock: - channels = list(self.channels.values()) try: async with make_aiohttp_session(proxy=self.network.proxy) as session: watchtower = myAiohttpClient(session, watchtower_url) - for chan in channels: + for chan in self.channels.values(): await self.sync_channel_with_watchtower(chan, watchtower) except aiohttp.client_exceptions.ClientConnectorError: self.logger.info(f'could not contact remote watchtower {watchtower_url}') @@ -574,9 +576,7 @@ class LNWallet(LNWorker): # return one item per payment_hash # note: with AMP we will have several channels per payment out = defaultdict(list) - with self.lock: - channels = list(self.channels.values()) - for chan in channels: + for chan in self.channels.values(): d = chan.get_settled_payments() for k, v in d.items(): out[k].append(v) @@ -628,9 +628,7 @@ class LNWallet(LNWorker): def get_onchain_history(self): out = {} # add funding events - with self.lock: - channels = list(self.channels.values()) - for chan in channels: + for chan in self.channels.values(): item = chan.get_funding_height() if item is None: continue @@ -693,8 +691,7 @@ class LNWallet(LNWorker): def channels_for_peer(self, node_id): assert type(node_id) is bytes - with self.lock: - return {x: y for (x, y) in self.channels.items() if y.node_id == node_id} + return {x: y for (x, y) in self.channels.items() if y.node_id == node_id} def channel_state_changed(self, chan): self.save_channel(chan) @@ -708,9 +705,7 @@ class LNWallet(LNWorker): util.trigger_callback('channel', chan) def channel_by_txo(self, txo): - with self.lock: - channels = list(self.channels.values()) - for chan in channels: + for chan in self.channels.values(): if chan.funding_outpoint.to_str() == txo: return chan @@ -762,7 +757,7 @@ class LNWallet(LNWorker): def add_channel(self, chan): with self.lock: - self.channels[chan.channel_id] = chan + self._channels[chan.channel_id] = chan self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) def add_new_channel(self, chan): @@ -805,10 +800,9 @@ class LNWallet(LNWorker): success = fut.result() def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel: - with self.lock: - for chan in self.channels.values(): - if chan.short_channel_id == short_channel_id: - return chan + for chan in self.channels.values(): + if chan.short_channel_id == short_channel_id: + return chan async def _pay(self, invoice, amount_sat=None, attempts=1) -> bool: lnaddr = self._check_invoice(invoice, amount_sat) @@ -981,8 +975,7 @@ class LNWallet(LNWorker): # if there are multiple hints, we will use the first one that works, # from a random permutation random.shuffle(r_tags) - with self.lock: - channels = list(self.channels.values()) + channels = list(self.channels.values()) scid_to_my_channels = {chan.short_channel_id: chan for chan in channels if chan.short_channel_id is not None} for private_route in r_tags: @@ -1196,8 +1189,7 @@ class LNWallet(LNWorker): async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]): """calculate routing hints (BOLT-11 'r' field)""" routing_hints = [] - with self.lock: - channels = list(self.channels.values()) + channels = list(self.channels.values()) scid_to_my_channels = {chan.short_channel_id: chan for chan in channels if chan.short_channel_id is not None} ignore_min_htlc_value = False @@ -1251,24 +1243,27 @@ class LNWallet(LNWorker): def get_balance(self): with self.lock: - return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000 + return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 + for chan in self.channels.values())) / 1000 def num_sats_can_send(self) -> Union[Decimal, int]: with self.lock: - return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0 + return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0 + for chan in self.channels.values()))/1000 if self.channels else 0 def num_sats_can_receive(self) -> Union[Decimal, int]: with self.lock: - return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0 + return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0 + for chan in self.channels.values()))/1000 if self.channels else 0 async def close_channel(self, chan_id): - chan = self.channels[chan_id] + chan = self._channels[chan_id] peer = self._peers[chan.node_id] return await peer.close_channel(chan_id) async def force_close_channel(self, chan_id): # returns txid or raises - chan = self.channels[chan_id] + chan = self._channels[chan_id] tx = chan.force_close_tx() await self.network.broadcast_transaction(tx) chan.set_state(ChannelState.FORCE_CLOSING) @@ -1276,16 +1271,16 @@ class LNWallet(LNWorker): async def try_force_closing(self, chan_id): # fails silently but sets the state, so that we will retry later - chan = self.channels[chan_id] + chan = self._channels[chan_id] tx = chan.force_close_tx() chan.set_state(ChannelState.FORCE_CLOSING) await self.network.try_broadcasting(tx, 'force-close') def remove_channel(self, chan_id): - chan = self.channels[chan_id] + chan = self._channels[chan_id] assert chan.get_state() == ChannelState.REDEEMED with self.lock: - self.channels.pop(chan_id) + self._channels.pop(chan_id) self.db.get('channels').pop(chan_id.hex()) util.trigger_callback('channels_updated', self.wallet) @@ -1316,9 +1311,7 @@ class LNWallet(LNWorker): async def reestablish_peers_and_channels(self): while True: await asyncio.sleep(1) - with self.lock: - channels = list(self.channels.values()) - for chan in channels: + for chan in self.channels.values(): if chan.is_closed(): continue # reestablish @@ -1340,7 +1333,7 @@ class LNWallet(LNWorker): return max(253, feerate_per_kvbyte // 4) def create_channel_backup(self, channel_id): - chan = self.channels[channel_id] + chan = self._channels[channel_id] peer_addresses = list(chan.get_peer_addresses()) peer_addr = peer_addresses[0] return ChannelBackupStorage( diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -102,7 +102,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): self.remote_keypair = remote_keypair self.node_keypair = local_keypair self.network = MockNetwork(tx_queue) - self.channels = {chan.channel_id: chan} + self._channels = {chan.channel_id: chan} self.payments = {} self.logs = defaultdict(list) self.wallet = MockWallet() @@ -123,6 +123,10 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): return noop_lock() @property + def channels(self): + return self._channels + + @property def peers(self): return self._peers @@ -131,11 +135,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): return {self.remote_keypair.pubkey: self.peer} def channels_for_peer(self, pubkey): - return self.channels + return self._channels def get_channel_by_short_id(self, short_channel_id): with self.lock: - for chan in self.channels.values(): + for chan in self._channels.values(): if chan.short_channel_id == short_channel_id: return chan