commit b9b53e7f76eb054773a9a495593396b6b18af85a
parent f5eb91900ab0a1c3bb776d14f417e566932e596e
Author: SomberNight <somber.night@protonmail.com>
Date: Thu, 30 Apr 2020 21:08:26 +0200
lnworker: fix threading issues for .channels attribute
external code (commands/gui) did not always take lock when iterating lnworker.channels.
instead of exposing lock, let's take a copy internally (as with .peers)
Diffstat:
3 files changed, 43 insertions(+), 44 deletions(-)
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -107,7 +107,7 @@ class Peer(Logger):
if not (message_name.startswith("update_") or is_commitment_signed):
return
assert channel_id
- chan = self.lnworker.channels[channel_id] # type: Channel
+ chan = self.channels[channel_id]
chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
if is_commitment_signed:
# saving now, to ensure replaying updates works (in case of channel reestablishment)
@@ -139,6 +139,8 @@ class Peer(Logger):
@property
def channels(self) -> Dict[bytes, Channel]:
+ # FIXME this iterates over all channels in lnworker,
+ # so if we just want to lookup a channel by channel_id, it's wasteful
return self.lnworker.channels_for_peer(self.pubkey)
def diagnostic_name(self):
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -491,22 +491,26 @@ class LNWallet(LNWorker):
self.enable_htlc_settle.set()
# note: accessing channels (besides simple lookup) needs self.lock!
- self.channels = {}
+ self._channels = {} # type: Dict[bytes, Channel]
channels = self.db.get_dict("channels")
for channel_id, c in channels.items():
- self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
+ self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]]
+ @property
+ def channels(self) -> Mapping[bytes, Channel]:
+ """Returns a read-only copy of channels."""
+ with self.lock:
+ return self._channels.copy()
+
@ignore_exceptions
@log_exceptions
async def sync_with_local_watchtower(self):
watchtower = self.network.local_watchtower
if watchtower:
while True:
- with self.lock:
- channels = list(self.channels.values())
- for chan in channels:
+ for chan in self.channels.values():
await self.sync_channel_with_watchtower(chan, watchtower.sweepstore)
await asyncio.sleep(5)
@@ -524,12 +528,10 @@ class LNWallet(LNWorker):
watchtower_url = self.config.get('watchtower_url')
if not watchtower_url:
continue
- with self.lock:
- channels = list(self.channels.values())
try:
async with make_aiohttp_session(proxy=self.network.proxy) as session:
watchtower = myAiohttpClient(session, watchtower_url)
- for chan in channels:
+ for chan in self.channels.values():
await self.sync_channel_with_watchtower(chan, watchtower)
except aiohttp.client_exceptions.ClientConnectorError:
self.logger.info(f'could not contact remote watchtower {watchtower_url}')
@@ -574,9 +576,7 @@ class LNWallet(LNWorker):
# return one item per payment_hash
# note: with AMP we will have several channels per payment
out = defaultdict(list)
- with self.lock:
- channels = list(self.channels.values())
- for chan in channels:
+ for chan in self.channels.values():
d = chan.get_settled_payments()
for k, v in d.items():
out[k].append(v)
@@ -628,9 +628,7 @@ class LNWallet(LNWorker):
def get_onchain_history(self):
out = {}
# add funding events
- with self.lock:
- channels = list(self.channels.values())
- for chan in channels:
+ for chan in self.channels.values():
item = chan.get_funding_height()
if item is None:
continue
@@ -693,8 +691,7 @@ class LNWallet(LNWorker):
def channels_for_peer(self, node_id):
assert type(node_id) is bytes
- with self.lock:
- return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
+ return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
def channel_state_changed(self, chan):
self.save_channel(chan)
@@ -708,9 +705,7 @@ class LNWallet(LNWorker):
util.trigger_callback('channel', chan)
def channel_by_txo(self, txo):
- with self.lock:
- channels = list(self.channels.values())
- for chan in channels:
+ for chan in self.channels.values():
if chan.funding_outpoint.to_str() == txo:
return chan
@@ -762,7 +757,7 @@ class LNWallet(LNWorker):
def add_channel(self, chan):
with self.lock:
- self.channels[chan.channel_id] = chan
+ self._channels[chan.channel_id] = chan
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
def add_new_channel(self, chan):
@@ -805,10 +800,9 @@ class LNWallet(LNWorker):
success = fut.result()
def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel:
- with self.lock:
- for chan in self.channels.values():
- if chan.short_channel_id == short_channel_id:
- return chan
+ for chan in self.channels.values():
+ if chan.short_channel_id == short_channel_id:
+ return chan
async def _pay(self, invoice, amount_sat=None, attempts=1) -> bool:
lnaddr = self._check_invoice(invoice, amount_sat)
@@ -981,8 +975,7 @@ class LNWallet(LNWorker):
# if there are multiple hints, we will use the first one that works,
# from a random permutation
random.shuffle(r_tags)
- with self.lock:
- channels = list(self.channels.values())
+ channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
for private_route in r_tags:
@@ -1196,8 +1189,7 @@ class LNWallet(LNWorker):
async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)"""
routing_hints = []
- with self.lock:
- channels = list(self.channels.values())
+ channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
ignore_min_htlc_value = False
@@ -1251,24 +1243,27 @@ class LNWallet(LNWorker):
def get_balance(self):
with self.lock:
- return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000
+ return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0
+ for chan in self.channels.values())) / 1000
def num_sats_can_send(self) -> Union[Decimal, int]:
with self.lock:
- return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0
+ return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0
+ for chan in self.channels.values()))/1000 if self.channels else 0
def num_sats_can_receive(self) -> Union[Decimal, int]:
with self.lock:
- return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0
+ return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0
+ for chan in self.channels.values()))/1000 if self.channels else 0
async def close_channel(self, chan_id):
- chan = self.channels[chan_id]
+ chan = self._channels[chan_id]
peer = self._peers[chan.node_id]
return await peer.close_channel(chan_id)
async def force_close_channel(self, chan_id):
# returns txid or raises
- chan = self.channels[chan_id]
+ chan = self._channels[chan_id]
tx = chan.force_close_tx()
await self.network.broadcast_transaction(tx)
chan.set_state(ChannelState.FORCE_CLOSING)
@@ -1276,16 +1271,16 @@ class LNWallet(LNWorker):
async def try_force_closing(self, chan_id):
# fails silently but sets the state, so that we will retry later
- chan = self.channels[chan_id]
+ chan = self._channels[chan_id]
tx = chan.force_close_tx()
chan.set_state(ChannelState.FORCE_CLOSING)
await self.network.try_broadcasting(tx, 'force-close')
def remove_channel(self, chan_id):
- chan = self.channels[chan_id]
+ chan = self._channels[chan_id]
assert chan.get_state() == ChannelState.REDEEMED
with self.lock:
- self.channels.pop(chan_id)
+ self._channels.pop(chan_id)
self.db.get('channels').pop(chan_id.hex())
util.trigger_callback('channels_updated', self.wallet)
@@ -1316,9 +1311,7 @@ class LNWallet(LNWorker):
async def reestablish_peers_and_channels(self):
while True:
await asyncio.sleep(1)
- with self.lock:
- channels = list(self.channels.values())
- for chan in channels:
+ for chan in self.channels.values():
if chan.is_closed():
continue
# reestablish
@@ -1340,7 +1333,7 @@ class LNWallet(LNWorker):
return max(253, feerate_per_kvbyte // 4)
def create_channel_backup(self, channel_id):
- chan = self.channels[channel_id]
+ chan = self._channels[channel_id]
peer_addresses = list(chan.get_peer_addresses())
peer_addr = peer_addresses[0]
return ChannelBackupStorage(
diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
@@ -102,7 +102,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.remote_keypair = remote_keypair
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
- self.channels = {chan.channel_id: chan}
+ self._channels = {chan.channel_id: chan}
self.payments = {}
self.logs = defaultdict(list)
self.wallet = MockWallet()
@@ -123,6 +123,10 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return noop_lock()
@property
+ def channels(self):
+ return self._channels
+
+ @property
def peers(self):
return self._peers
@@ -131,11 +135,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return {self.remote_keypair.pubkey: self.peer}
def channels_for_peer(self, pubkey):
- return self.channels
+ return self._channels
def get_channel_by_short_id(self, short_channel_id):
with self.lock:
- for chan in self.channels.values():
+ for chan in self._channels.values():
if chan.short_channel_id == short_channel_id:
return chan