electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit a27b03be6da5a9f4b397830151daf55c682beda6
parent 4fc9f243f7d1c957770589582261389a6a74b9ee
Author: SomberNight <somber.night@protonmail.com>
Date:   Mon, 12 Aug 2019 18:37:13 +0200

lnhtlc: local update raw messages must not be deleted before acked

In recv_rev() previously all unacked_local_updates were deleted
as it was assumed that all of them have been acked at that point by
the revoke_and_ack itself. However this is not necessarily the case:
see new test case.

renamed log['unacked_local_updates'] to log['unacked_local_updates2']
to avoid breaking existing wallet files

Diffstat:
Melectrum/lnhtlc.py | 33+++++++++++++++++++++++----------
Melectrum/lnpeer.py | 13++++++++-----
Melectrum/lntransport.py | 2+-
Melectrum/tests/test_lnhtlc.py | 31+++++++++++++++++++++++++++++++
4 files changed, 63 insertions(+), 16 deletions(-)

diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Optional, Sequence, Tuple, List +from typing import Optional, Sequence, Tuple, List, Dict from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate from .util import bh2u, bfh @@ -33,9 +33,10 @@ class HTLCManager: log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()} # "side who initiated fee update" -> action -> list of FeeUpdates log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] - if 'unacked_local_updates' not in log: - log['unacked_local_updates'] = [] - log['unacked_local_updates'] = [bfh(upd) for upd in log['unacked_local_updates']] + if 'unacked_local_updates2' not in log: + log['unacked_local_updates2'] = {} + log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages] + for ctn, messages in log['unacked_local_updates2'].items()} # maybe bootstrap fee_updates if initial_feerate was provided if initial_feerate is not None: assert type(initial_feerate) is int @@ -74,7 +75,8 @@ class HTLCManager: log[sub]['adds'] = d # fee_updates log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] - log['unacked_local_updates'] = [bh2u(upd) for upd in log['unacked_local_updates']] + log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages] + for ctn, messages in log['unacked_local_updates2'].items()} return log ##### Actions on channel: @@ -175,7 +177,7 @@ class HTLCManager: if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE): fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 # no need to keep local update raw msgs anymore, they have just been ACKed. - self.log['unacked_local_updates'].clear() + self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None) def discard_unsigned_remote_updates(self): """Discard updates sent by the remote, that the remote itself @@ -200,11 +202,22 @@ class HTLCManager: if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL): del self.log[REMOTE]['fee_updates'][i] - def store_local_update_raw_msg(self, raw_update_msg: bytes): - self.log['unacked_local_updates'].append(raw_update_msg) + def store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_signed: bool) -> None: + """We need to be able to replay unacknowledged updates we sent to the remote + in case of disconnections. Hence, raw update and commitment_signed messages + are stored temporarily (until they are acked).""" + # self.log['unacked_local_updates2'][ctn_idx] is a list of raw messages + # containing some number of updates and then a single commitment_signed + if is_commitment_signed: + ctn_idx = self.ctn_latest(REMOTE) + else: + ctn_idx = self.ctn_latest(REMOTE) + 1 + if ctn_idx not in self.log['unacked_local_updates2']: + self.log['unacked_local_updates2'][ctn_idx] = [] + self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg) - def get_unacked_local_updates(self) -> Sequence[bytes]: - return self.log['unacked_local_updates'] + def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]: + return self.log['unacked_local_updates2'] ##### Queries re HTLCs: diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -96,12 +96,13 @@ class Peer(Logger): self.transport.send_bytes(raw_msg) def _store_raw_msg_if_local_update(self, raw_msg: bytes, *, message_name: str, channel_id: Optional[bytes]): - if not (message_name.startswith("update_") or message_name == "commitment_signed"): + is_commitment_signed = message_name == "commitment_signed" + if not (message_name.startswith("update_") or is_commitment_signed): return assert channel_id chan = self.lnworker.channels[channel_id] # type: Channel - chan.hm.store_local_update_raw_msg(raw_msg) - if message_name == "commitment_signed": + chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed) + if is_commitment_signed: # saving now, to ensure replaying updates works (in case of channel reestablishment) self.lnworker.save_channel(chan) @@ -755,8 +756,9 @@ class Peer(Logger): # Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns, # e.g. for watchtowers, hence we must ensure these ctxs coincide. # We replay the local updates even if they were not yet committed. - for raw_upd_msg in chan.hm.get_unacked_local_updates(): - self.transport.send_bytes(raw_upd_msg) + for ctn, messages in chan.hm.get_unacked_local_updates(): + for raw_upd_msg in messages: + self.transport.send_bytes(raw_upd_msg) should_close_we_are_ahead = False should_close_they_are_ahead = False @@ -831,6 +833,7 @@ class Peer(Logger): self.lnworker.force_close_channel(chan_id) return + # note: chan.short_channel_id being set implies the funding txn is already at sufficient depth if their_next_local_ctn == next_local_ctn == 1 and chan.short_channel_id: self.send_funding_locked(chan) # checks done diff --git a/electrum/lntransport.py b/electrum/lntransport.py @@ -88,7 +88,7 @@ def create_ephemeral_key() -> (bytes, bytes): class LNTransportBase: - def send_bytes(self, msg): + def send_bytes(self, msg: bytes) -> None: l = len(msg).to_bytes(2, 'big') lc = aead_encrypt(self.sk, self.sn(), b'', l) c = aead_encrypt(self.sk, self.sn(), b'', msg) diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py @@ -211,3 +211,34 @@ class TestHTLCManager(unittest.TestCase): self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE)) self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL)) self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) + + def test_unacked_local_updates(self): + A = HTLCManager() + B = HTLCManager() + A.channel_open_finished() + B.channel_open_finished() + self.assertEqual({}, A.get_unacked_local_updates()) + + ah0 = H('A', 0) + B.recv_htlc(A.send_htlc(ah0)) + A.store_local_update_raw_msg(b"upd_msg0", is_commitment_signed=False) + self.assertEqual({1: [b"upd_msg0"]}, A.get_unacked_local_updates()) + + ah1 = H('A', 1) + B.recv_htlc(A.send_htlc(ah1)) + A.store_local_update_raw_msg(b"upd_msg1", is_commitment_signed=False) + self.assertEqual({1: [b"upd_msg0", b"upd_msg1"]}, A.get_unacked_local_updates()) + + A.send_ctx() + B.recv_ctx() + A.store_local_update_raw_msg(b"ctx1", is_commitment_signed=True) + self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"]}, A.get_unacked_local_updates()) + + ah2 = H('A', 2) + B.recv_htlc(A.send_htlc(ah2)) + A.store_local_update_raw_msg(b"upd_msg2", is_commitment_signed=False) + self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"], 2: [b"upd_msg2"]}, A.get_unacked_local_updates()) + + B.send_rev() + A.recv_rev() + self.assertEqual({2: [b"upd_msg2"]}, A.get_unacked_local_updates())