electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 7431aac5cdbdec55295ceedcf38642fc1b3506ea
parent 4ccfa39fddb5c67a8c8104b3db8607c0071d5a0a
Author: SomberNight <somber.night@protonmail.com>
Date:   Sat, 27 Jul 2019 00:59:51 +0200

lnhtlc: (fix) was locking in too many updates during commit/revoke

Diffstat:
Melectrum/lnchannel.py | 10+++++++++-
Melectrum/lnhtlc.py | 106++++++++++++++++++++++++++++++++++---------------------------------------------
Melectrum/lnpeer.py | 10+++++-----
Melectrum/tests/test_lnchannel.py | 17++++++++++-------
Melectrum/tests/test_lnhtlc.py | 153++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------
5 files changed, 171 insertions(+), 125 deletions(-)

diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -113,6 +113,10 @@ def str_bytes_dict_to_save(x): return {str(k): bh2u(v) for k, v in x.items()} class Channel(Logger): + # note: try to avoid naming ctns/ctxs/etc as "current" and "pending". + # they are ambiguous. Use "oldest_unrevoked" or "latest" or "next". + # TODO enforce this ^ + def diagnostic_name(self): if self.name: return str(self.name) @@ -154,7 +158,9 @@ class Channel(Logger): self.remote_commitment_to_be_revoked.deserialize(True) log = state.get('log') - self.hm = HTLCManager(self.config[LOCAL].ctn, self.config[REMOTE].ctn, log) + self.hm = HTLCManager(local_ctn=self.config[LOCAL].ctn, + remote_ctn=self.config[REMOTE].ctn, + log=log) self.name = name Logger.__init__(self) @@ -209,6 +215,7 @@ class Channel(Logger): return self.force_closed or self.get_state() in ['CLOSED', 'CLOSING'] def _check_can_pay(self, amount_msat: int) -> None: + # TODO check if this method uses correct ctns (should use "latest" + 1) if self.is_closed(): raise PaymentFailure('Channel closed') if self.get_state() != 'OPEN': @@ -525,6 +532,7 @@ class Channel(Logger): not be used in the UI cause it fluctuates (commit fee) """ # FIXME whose balance? whose ctx? + # FIXME confusing/mixing ctns (should probably use latest_ctn + 1; not oldest_unrevoked + 1) assert type(subject) is HTLCOwner return self.balance_minus_outgoing_htlcs(subject, ctx_owner=subject)\ - self.config[-subject].reserve_sat * 1000\ diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -1,46 +1,45 @@ from copy import deepcopy -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, List -from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction +from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate from .util import bh2u class HTLCManager: - def __init__(self, local_ctn=0, remote_ctn=0, log=None): + def __init__(self, *, local_ctn=0, remote_ctn=0, log=None): # self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn} - # ctx_pending[sub] is True iff sub has sent commitment_signed but did not receive revoke_and_ack + # ctx_pending[sub] is True iff sub has received commitment_signed but did not send revoke_and_ack (sub has multiple unrevoked ctxs) self.ctx_pending = {LOCAL:False, REMOTE: False} # FIXME does this need to be persisted? - # expect_sig[SENT/RECEIVED] is True iff HTLCs have been sent/received but the corresponding commitment_signed has not been received/sent - self.expect_sig = {SENT: False, RECEIVED: False} if log is None: initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}} log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)} else: assert type(log) is dict - log = {HTLCOwner(int(x)): y for x, y in deepcopy(log).items()} + log = {HTLCOwner(int(sub)): action for sub, action in deepcopy(log).items()} for sub in (LOCAL, REMOTE): log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()} - coerceHtlcOwner2IntMap = lambda x: {HTLCOwner(int(y)): z for y, z in x.items()} + coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()} # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn - log[sub]['locked_in'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['locked_in'].items()} - log[sub]['settles'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['settles'].items()} - log[sub]['fails'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['fails'].items()} + log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()} + log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()} + log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()} self.log = log - def ctn_latest(self, sub): + def ctn_latest(self, sub: HTLCOwner) -> int: """Return the ctn for the latest (newest that has a valid sig) ctx of sub""" return self.ctn[sub] + int(self.ctx_pending[sub]) def to_save(self): - x = deepcopy(self.log) + log = deepcopy(self.log) for sub in (LOCAL, REMOTE): + # adds d = {} - for htlc_id, htlc in x[sub]['adds'].items(): + for htlc_id, htlc in log[sub]['adds'].items(): d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:] - x[sub]['adds'] = d - return x + log[sub]['adds'] = d + return log def channel_open_finished(self): self.ctn = {LOCAL: 0, REMOTE: 0} @@ -48,53 +47,55 @@ class HTLCManager: def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: htlc_id = htlc.htlc_id - adds = self.log[LOCAL]['adds'] - assert type(adds) is not str - adds[htlc_id] = htlc + self.log[LOCAL]['adds'][htlc_id] = htlc self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1} - self.expect_sig[SENT] = True return htlc def recv_htlc(self, htlc: UpdateAddHtlc) -> None: htlc_id = htlc.htlc_id self.log[REMOTE]['adds'][htlc_id] = htlc - l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None} - self.expect_sig[RECEIVED] = True + self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None} + + def send_settle(self, htlc_id: int) -> None: + self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} + + def recv_settle(self, htlc_id: int) -> None: + self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None} + + def send_fail(self, htlc_id: int) -> None: + self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} + + def recv_fail(self, htlc_id: int) -> None: + self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None} def send_ctx(self) -> None: assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE]) self.ctx_pending[REMOTE] = True - for locked_in in self.log[REMOTE]['locked_in'].values(): - if locked_in[REMOTE] is None: - locked_in[REMOTE] = self.ctn_latest(REMOTE) - self.expect_sig[SENT] = False def recv_ctx(self) -> None: assert self.ctn_latest(LOCAL) == self.ctn[LOCAL], (self.ctn_latest(LOCAL), self.ctn[LOCAL]) self.ctx_pending[LOCAL] = True - for locked_in in self.log[LOCAL]['locked_in'].values(): - if locked_in[LOCAL] is None: - locked_in[LOCAL] = self.ctn_latest(LOCAL) - self.expect_sig[RECEIVED] = False def send_rev(self) -> None: self.ctn[LOCAL] += 1 self.ctx_pending[LOCAL] = False + for ctns in self.log[REMOTE]['locked_in'].values(): + if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): + ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 for log_action in ('settles', 'fails'): - for htlc_id, ctns in self.log[LOCAL][log_action].items(): - if ctns[REMOTE] is None: + for ctns in self.log[LOCAL][log_action].values(): + if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 def recv_rev(self) -> None: self.ctn[REMOTE] += 1 self.ctx_pending[REMOTE] = False - for htlc_id, ctns in self.log[LOCAL]['locked_in'].items(): - if ctns[LOCAL] is None: - #assert ctns[REMOTE] == self.ctn[REMOTE] # FIXME I don't think this assert is correct + for ctns in self.log[LOCAL]['locked_in'].values(): + if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 for log_action in ('settles', 'fails'): - for htlc_id, ctns in self.log[REMOTE][log_action].items(): - if ctns[LOCAL] is None: + for ctns in self.log[REMOTE][log_action].values(): + if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, @@ -113,13 +114,7 @@ class HTLCManager: # party is the proposer of the HTLCs party = subject if direction == SENT else subject.inverted() for htlc_id, ctns in self.log[party]['locked_in'].items(): - htlc_height = ctns[subject] - if htlc_height is None: - expect_sig = self.expect_sig[RECEIVED if party != LOCAL else SENT] - include = not expect_sig and ctns[-subject] <= ctn - else: - include = htlc_height <= ctn - if include: + if ctns[subject] is not None and ctns[subject] <= ctn: settles = self.log[party]['settles'] fails = self.log[party]['fails'] not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn @@ -138,23 +133,20 @@ class HTLCManager: l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)] return l - def current_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: - """Return the list of HTLCs in subject's oldest unrevoked ctx.""" + def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: assert type(subject) is HTLCOwner ctn = self.ctn[subject] return self.htlcs(subject, ctn) - def pending_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: - """Return the list of HTLCs in subject's next ctx (one after oldest unrevoked).""" + def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: assert type(subject) is HTLCOwner - ctn = self.ctn[subject] + 1 + ctn = self.ctn_latest(subject) return self.htlcs(subject, ctn) - def send_settle(self, htlc_id: int) -> None: - self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} - - def recv_settle(self, htlc_id: int) -> None: - self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None} + def get_htlcs_in_next_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: + assert type(subject) is HTLCOwner + ctn = self.ctn_latest(subject) + 1 + return self.htlcs(subject, ctn) def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction, ctn: int = None) -> Sequence[UpdateAddHtlc]: @@ -194,9 +186,3 @@ class HTLCManager: return [self.log[LOCAL]['adds'][htlc_id] for htlc_id, ctns in self.log[LOCAL]['settles'].items() if ctns[LOCAL] == ctn] - - def send_fail(self, htlc_id: int) -> None: - self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} - - def recv_fail(self, htlc_id: int) -> None: - self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None} diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -1024,12 +1024,12 @@ class Peer(Logger): def maybe_send_commitment(self, chan: Channel): ctn_to_sign = chan.get_current_ctn(REMOTE) + 1 # if there are no changes, we will not (and must not) send a new commitment - pending, current = chan.hm.pending_htlcs(REMOTE), chan.hm.current_htlcs(REMOTE) - if (pending == current + next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE) + if (next_htlcs == latest_htlcs and chan.pending_feerate(REMOTE) == chan.constraints.feerate) \ or ctn_to_sign == self.sent_commitment_for_ctn_last[chan]: return - self.logger.info(f'send_commitment. old number htlcs: {len(current)}, new number htlcs: {len(pending)}') + self.logger.info(f'send_commitment. old number htlcs: {len(latest_htlcs)}, new number htlcs: {len(next_htlcs)}') sig_64, htlc_sigs = chan.sign_next_commitment() self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) self.sent_commitment_for_ctn_last[chan] = ctn_to_sign @@ -1087,8 +1087,8 @@ class Peer(Logger): channel_id = payload['channel_id'] chan = self.channels[channel_id] # make sure there were changes to the ctx, otherwise the remote peer is misbehaving - if (chan.hm.pending_htlcs(LOCAL) == chan.hm.current_htlcs(LOCAL) - and chan.pending_feerate(LOCAL) == chan.constraints.feerate): + if (chan.hm.get_htlcs_in_next_ctx(LOCAL) == chan.hm.get_htlcs_in_latest_ctx(LOCAL) + and chan.pending_feerate(LOCAL) == chan.constraints.feerate): raise RemoteMisbehaving('received commitment_signed without pending changes') # make sure ctn is new ctn_to_recv = chan.get_current_ctn(LOCAL) + 1 diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py @@ -226,6 +226,8 @@ class TestChannel(unittest.TestCase): self.bob_channel.add_htlc(self.htlc_dict) self.alice_channel.receive_htlc(self.htlc_dict) self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment()) + self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3) + self.alice_channel.revoke_current_commitment() self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4) def test_SimpleAddSettleWorkflow(self): @@ -279,8 +281,8 @@ class TestChannel(unittest.TestCase): self.assertTrue(alice_channel.signature_fits(com())) self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com())) - self.assertEqual(next(iter(alice_channel.hm.pending_htlcs(REMOTE)))[0], RECEIVED) - self.assertEqual(alice_channel.hm.pending_htlcs(REMOTE), bob_channel.hm.pending_htlcs(LOCAL)) + self.assertEqual(next(iter(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE)))[0], RECEIVED) + self.assertEqual(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE), bob_channel.hm.get_htlcs_in_next_ctx(LOCAL)) self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs()) # Bob receives this signature message, and checks that this covers the @@ -291,14 +293,11 @@ class TestChannel(unittest.TestCase): self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL))) self.assertEqual(bob_channel.config[REMOTE].ctn, 0) - self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc]) + self.assertEqual(bob_channel.included_htlcs(LOCAL, RECEIVED, 1), [htlc])# self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 0), []) self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc]) - self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 0), []) - self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc]) - self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 0), []) self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 1), []) @@ -323,7 +322,11 @@ class TestChannel(unittest.TestCase): self.assertTrue(alice_channel.signature_fits(com())) self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com())) - self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 3) + # so far: Alice added htlc, Alice signed. + self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2) + self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2) + self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 2) # oldest unrevoked + self.assertEqual(len(alice_channel.pending_commitment(REMOTE).outputs()), 3) # latest # Alice then processes this revocation, sending her own revocation for # her prior commitment transaction. Alice shouldn't have any HTLCs to diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py @@ -14,42 +14,54 @@ class TestHTLCManager(unittest.TestCase): B = HTLCManager() ah0, bh0 = H('A', 0), H('B', 0) B.recv_htlc(A.send_htlc(ah0)) - self.assertTrue(B.expect_sig[RECEIVED]) - self.assertTrue(A.expect_sig[SENT]) - self.assertFalse(B.expect_sig[SENT]) - self.assertFalse(A.expect_sig[RECEIVED]) self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1) A.recv_htlc(B.send_htlc(bh0)) - self.assertTrue(B.expect_sig[RECEIVED]) - self.assertTrue(A.expect_sig[SENT]) - self.assertTrue(A.expect_sig[SENT]) - self.assertTrue(B.expect_sig[RECEIVED]) - self.assertEqual(B.current_htlcs(LOCAL), []) - self.assertEqual(A.current_htlcs(LOCAL), []) - self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0)]) - self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0)]) + self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), []) + self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), []) + self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, ah0)]) + self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, bh0)]) A.send_ctx() B.recv_ctx() B.send_ctx() A.recv_ctx() - self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) - self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) + self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), []) + self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), []) + self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)]) + self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)]) B.send_rev() A.recv_rev() A.send_rev() B.recv_rev() - self.assertEqual(B.current_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) - self.assertEqual(A.current_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) + self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)]) + self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)]) + self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)]) + self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)]) + A.send_ctx() + B.recv_ctx() + B.send_ctx() + A.recv_ctx() + self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)]) + self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)]) + self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) + self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) + B.send_rev() + A.recv_rev() + A.send_rev() + B.recv_rev() + self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) + self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) + self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) + self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) def test_single_htlc_full_lifecycle(self): def htlc_lifecycle(htlc_success: bool): A = HTLCManager() B = HTLCManager() B.recv_htlc(A.send_htlc(H('A', 0))) - self.assertEqual(len(B.pending_htlcs(REMOTE)), 0) - self.assertEqual(len(A.pending_htlcs(REMOTE)), 1) - self.assertEqual(len(B.pending_htlcs(LOCAL)), 1) - self.assertEqual(len(A.pending_htlcs(LOCAL)), 0) + self.assertEqual(len(B.get_htlcs_in_next_ctx(REMOTE)), 0) + self.assertEqual(len(A.get_htlcs_in_next_ctx(REMOTE)), 1) + self.assertEqual(len(B.get_htlcs_in_next_ctx(LOCAL)), 1) + self.assertEqual(len(A.get_htlcs_in_next_ctx(LOCAL)), 0) A.send_ctx() B.recv_ctx() B.send_rev() @@ -58,8 +70,8 @@ class TestHTLCManager(unittest.TestCase): A.recv_ctx() A.send_rev() B.recv_rev() - self.assertEqual(len(A.current_htlcs(LOCAL)), 1) - self.assertEqual(len(B.current_htlcs(LOCAL)), 1) + self.assertEqual(len(A.get_htlcs_in_latest_ctx(LOCAL)), 1) + self.assertEqual(len(B.get_htlcs_in_latest_ctx(LOCAL)), 1) if htlc_success: B.send_settle(0) A.recv_settle(0) @@ -67,47 +79,47 @@ class TestHTLCManager(unittest.TestCase): B.send_fail(0) A.recv_fail(0) self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)]) - self.assertNotEqual(A.current_htlcs(LOCAL), []) - self.assertNotEqual(B.current_htlcs(REMOTE), []) + self.assertNotEqual(A.get_htlcs_in_latest_ctx(LOCAL), []) + self.assertNotEqual(B.get_htlcs_in_latest_ctx(REMOTE), []) - self.assertEqual(A.pending_htlcs(LOCAL), []) - self.assertNotEqual(A.pending_htlcs(REMOTE), []) - self.assertEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE)) + self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), []) + self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), []) + self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE)) - self.assertEqual(B.pending_htlcs(REMOTE), []) + self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), []) B.send_ctx() A.recv_ctx() A.send_rev() # here pending_htlcs(REMOTE) should become empty - self.assertEqual(A.pending_htlcs(REMOTE), []) + self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), []) B.recv_rev() A.send_ctx() B.recv_ctx() B.send_rev() A.recv_rev() - self.assertEqual(B.current_htlcs(LOCAL), []) - self.assertEqual(A.current_htlcs(LOCAL), []) - self.assertEqual(A.current_htlcs(REMOTE), []) - self.assertEqual(B.current_htlcs(REMOTE), []) + self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), []) + self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), []) + self.assertEqual(A.get_htlcs_in_latest_ctx(REMOTE), []) + self.assertEqual(B.get_htlcs_in_latest_ctx(REMOTE), []) self.assertEqual(len(A.all_settled_htlcs_ever(LOCAL)), int(htlc_success)) self.assertEqual(len(A.sent_in_ctn(2)), int(htlc_success)) self.assertEqual(len(B.received_in_ctn(2)), int(htlc_success)) A.recv_htlc(B.send_htlc(H('B', 0))) - self.assertEqual(A.pending_htlcs(REMOTE), []) - self.assertNotEqual(A.pending_htlcs(LOCAL), []) - self.assertNotEqual(B.pending_htlcs(REMOTE), []) - self.assertEqual(B.pending_htlcs(LOCAL), []) + self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), []) + self.assertNotEqual(A.get_htlcs_in_next_ctx(LOCAL), []) + self.assertNotEqual(B.get_htlcs_in_next_ctx(REMOTE), []) + self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), []) B.send_ctx() A.recv_ctx() A.send_rev() B.recv_rev() - self.assertNotEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE)) - self.assertEqual(A.pending_htlcs(LOCAL), A.current_htlcs(LOCAL)) - self.assertEqual(B.pending_htlcs(REMOTE), B.current_htlcs(REMOTE)) - self.assertNotEqual(B.pending_htlcs(LOCAL), B.pending_htlcs(REMOTE)) + self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), B.get_htlcs_in_latest_ctx(REMOTE)) + self.assertNotEqual(B.get_htlcs_in_next_ctx(LOCAL), B.get_htlcs_in_next_ctx(REMOTE)) htlc_lifecycle(htlc_success=True) htlc_lifecycle(htlc_success=False) @@ -116,7 +128,8 @@ class TestHTLCManager(unittest.TestCase): def htlc_lifecycle(htlc_success: bool): A = HTLCManager() B = HTLCManager() - B.recv_htlc(A.send_htlc(H('A', 0))) + ah0 = H('A', 0) + B.recv_htlc(A.send_htlc(ah0)) A.send_ctx() B.recv_ctx() B.send_rev() @@ -127,11 +140,22 @@ class TestHTLCManager(unittest.TestCase): else: B.send_fail(0) A.recv_fail(0) - self.assertEqual(B.pending_htlcs(REMOTE), []) + self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE)) + self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) B.send_ctx() A.recv_ctx() A.send_rev() B.recv_rev() + self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE)) + self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([], A.get_htlcs_in_next_ctx(REMOTE)) htlc_lifecycle(htlc_success=True) htlc_lifecycle(htlc_success=False) @@ -144,13 +168,38 @@ class TestHTLCManager(unittest.TestCase): B.send_rev() ah0 = H('A', 0) B.recv_htlc(A.send_htlc(ah0)) - self.assertEqual([], A.current_htlcs(LOCAL)) - self.assertEqual([], A.current_htlcs(REMOTE)) - self.assertEqual([], A.pending_htlcs(LOCAL)) - self.assertEqual([], A.pending_htlcs(REMOTE)) + self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) A.recv_rev() - self.assertEqual([], A.current_htlcs(LOCAL)) - self.assertEqual([], A.current_htlcs(REMOTE)) - self.assertEqual([(Direction.SENT, ah0)], A.pending_htlcs(LOCAL)) - self.assertEqual([(Direction.RECEIVED, ah0)], A.pending_htlcs(REMOTE)) - + self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) + A.send_ctx() + B.recv_ctx() + self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) + B.send_rev() + A.recv_rev() + self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) + B.send_ctx() + A.recv_ctx() + self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL)) + self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) + A.send_rev() + B.recv_rev() + self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL)) + self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE)) + self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL)) + self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))