commit a27b03be6da5a9f4b397830151daf55c682beda6
parent 4fc9f243f7d1c957770589582261389a6a74b9ee
Author: SomberNight <somber.night@protonmail.com>
Date: Mon, 12 Aug 2019 18:37:13 +0200
lnhtlc: local update raw messages must not be deleted before acked
In recv_rev() previously all unacked_local_updates were deleted
as it was assumed that all of them have been acked at that point by
the revoke_and_ack itself. However this is not necessarily the case:
see new test case.
renamed log['unacked_local_updates'] to log['unacked_local_updates2']
to avoid breaking existing wallet files
Diffstat:
4 files changed, 63 insertions(+), 16 deletions(-)
diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
@@ -1,5 +1,5 @@
from copy import deepcopy
-from typing import Optional, Sequence, Tuple, List
+from typing import Optional, Sequence, Tuple, List, Dict
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u, bfh
@@ -33,9 +33,10 @@ class HTLCManager:
log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
# "side who initiated fee update" -> action -> list of FeeUpdates
log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
- if 'unacked_local_updates' not in log:
- log['unacked_local_updates'] = []
- log['unacked_local_updates'] = [bfh(upd) for upd in log['unacked_local_updates']]
+ if 'unacked_local_updates2' not in log:
+ log['unacked_local_updates2'] = {}
+ log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages]
+ for ctn, messages in log['unacked_local_updates2'].items()}
# maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None:
assert type(initial_feerate) is int
@@ -74,7 +75,8 @@ class HTLCManager:
log[sub]['adds'] = d
# fee_updates
log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
- log['unacked_local_updates'] = [bh2u(upd) for upd in log['unacked_local_updates']]
+ log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages]
+ for ctn, messages in log['unacked_local_updates2'].items()}
return log
##### Actions on channel:
@@ -175,7 +177,7 @@ class HTLCManager:
if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE):
fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
# no need to keep local update raw msgs anymore, they have just been ACKed.
- self.log['unacked_local_updates'].clear()
+ self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None)
def discard_unsigned_remote_updates(self):
"""Discard updates sent by the remote, that the remote itself
@@ -200,11 +202,22 @@ class HTLCManager:
if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL):
del self.log[REMOTE]['fee_updates'][i]
- def store_local_update_raw_msg(self, raw_update_msg: bytes):
- self.log['unacked_local_updates'].append(raw_update_msg)
+ def store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_signed: bool) -> None:
+ """We need to be able to replay unacknowledged updates we sent to the remote
+ in case of disconnections. Hence, raw update and commitment_signed messages
+ are stored temporarily (until they are acked)."""
+ # self.log['unacked_local_updates2'][ctn_idx] is a list of raw messages
+ # containing some number of updates and then a single commitment_signed
+ if is_commitment_signed:
+ ctn_idx = self.ctn_latest(REMOTE)
+ else:
+ ctn_idx = self.ctn_latest(REMOTE) + 1
+ if ctn_idx not in self.log['unacked_local_updates2']:
+ self.log['unacked_local_updates2'][ctn_idx] = []
+ self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg)
- def get_unacked_local_updates(self) -> Sequence[bytes]:
- return self.log['unacked_local_updates']
+ def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
+ return self.log['unacked_local_updates2']
##### Queries re HTLCs:
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -96,12 +96,13 @@ class Peer(Logger):
self.transport.send_bytes(raw_msg)
def _store_raw_msg_if_local_update(self, raw_msg: bytes, *, message_name: str, channel_id: Optional[bytes]):
- if not (message_name.startswith("update_") or message_name == "commitment_signed"):
+ is_commitment_signed = message_name == "commitment_signed"
+ if not (message_name.startswith("update_") or is_commitment_signed):
return
assert channel_id
chan = self.lnworker.channels[channel_id] # type: Channel
- chan.hm.store_local_update_raw_msg(raw_msg)
- if message_name == "commitment_signed":
+ chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
+ if is_commitment_signed:
# saving now, to ensure replaying updates works (in case of channel reestablishment)
self.lnworker.save_channel(chan)
@@ -755,8 +756,9 @@ class Peer(Logger):
# Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns,
# e.g. for watchtowers, hence we must ensure these ctxs coincide.
# We replay the local updates even if they were not yet committed.
- for raw_upd_msg in chan.hm.get_unacked_local_updates():
- self.transport.send_bytes(raw_upd_msg)
+ for ctn, messages in chan.hm.get_unacked_local_updates():
+ for raw_upd_msg in messages:
+ self.transport.send_bytes(raw_upd_msg)
should_close_we_are_ahead = False
should_close_they_are_ahead = False
@@ -831,6 +833,7 @@ class Peer(Logger):
self.lnworker.force_close_channel(chan_id)
return
+ # note: chan.short_channel_id being set implies the funding txn is already at sufficient depth
if their_next_local_ctn == next_local_ctn == 1 and chan.short_channel_id:
self.send_funding_locked(chan)
# checks done
diff --git a/electrum/lntransport.py b/electrum/lntransport.py
@@ -88,7 +88,7 @@ def create_ephemeral_key() -> (bytes, bytes):
class LNTransportBase:
- def send_bytes(self, msg):
+ def send_bytes(self, msg: bytes) -> None:
l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l)
c = aead_encrypt(self.sk, self.sn(), b'', msg)
diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py
@@ -211,3 +211,34 @@ class TestHTLCManager(unittest.TestCase):
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
+
+ def test_unacked_local_updates(self):
+ A = HTLCManager()
+ B = HTLCManager()
+ A.channel_open_finished()
+ B.channel_open_finished()
+ self.assertEqual({}, A.get_unacked_local_updates())
+
+ ah0 = H('A', 0)
+ B.recv_htlc(A.send_htlc(ah0))
+ A.store_local_update_raw_msg(b"upd_msg0", is_commitment_signed=False)
+ self.assertEqual({1: [b"upd_msg0"]}, A.get_unacked_local_updates())
+
+ ah1 = H('A', 1)
+ B.recv_htlc(A.send_htlc(ah1))
+ A.store_local_update_raw_msg(b"upd_msg1", is_commitment_signed=False)
+ self.assertEqual({1: [b"upd_msg0", b"upd_msg1"]}, A.get_unacked_local_updates())
+
+ A.send_ctx()
+ B.recv_ctx()
+ A.store_local_update_raw_msg(b"ctx1", is_commitment_signed=True)
+ self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"]}, A.get_unacked_local_updates())
+
+ ah2 = H('A', 2)
+ B.recv_htlc(A.send_htlc(ah2))
+ A.store_local_update_raw_msg(b"upd_msg2", is_commitment_signed=False)
+ self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"], 2: [b"upd_msg2"]}, A.get_unacked_local_updates())
+
+ B.send_rev()
+ A.recv_rev()
+ self.assertEqual({2: [b"upd_msg2"]}, A.get_unacked_local_updates())