commit 7431aac5cdbdec55295ceedcf38642fc1b3506ea
parent 4ccfa39fddb5c67a8c8104b3db8607c0071d5a0a
Author: SomberNight <somber.night@protonmail.com>
Date: Sat, 27 Jul 2019 00:59:51 +0200
lnhtlc: (fix) was locking in too many updates during commit/revoke
Diffstat:
5 files changed, 171 insertions(+), 125 deletions(-)
diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
@@ -113,6 +113,10 @@ def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()}
class Channel(Logger):
+ # note: try to avoid naming ctns/ctxs/etc as "current" and "pending".
+ # they are ambiguous. Use "oldest_unrevoked" or "latest" or "next".
+ # TODO enforce this ^
+
def diagnostic_name(self):
if self.name:
return str(self.name)
@@ -154,7 +158,9 @@ class Channel(Logger):
self.remote_commitment_to_be_revoked.deserialize(True)
log = state.get('log')
- self.hm = HTLCManager(self.config[LOCAL].ctn, self.config[REMOTE].ctn, log)
+ self.hm = HTLCManager(local_ctn=self.config[LOCAL].ctn,
+ remote_ctn=self.config[REMOTE].ctn,
+ log=log)
self.name = name
Logger.__init__(self)
@@ -209,6 +215,7 @@ class Channel(Logger):
return self.force_closed or self.get_state() in ['CLOSED', 'CLOSING']
def _check_can_pay(self, amount_msat: int) -> None:
+ # TODO check if this method uses correct ctns (should use "latest" + 1)
if self.is_closed():
raise PaymentFailure('Channel closed')
if self.get_state() != 'OPEN':
@@ -525,6 +532,7 @@ class Channel(Logger):
not be used in the UI cause it fluctuates (commit fee)
"""
# FIXME whose balance? whose ctx?
+ # FIXME confusing/mixing ctns (should probably use latest_ctn + 1; not oldest_unrevoked + 1)
assert type(subject) is HTLCOwner
return self.balance_minus_outgoing_htlcs(subject, ctx_owner=subject)\
- self.config[-subject].reserve_sat * 1000\
diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
@@ -1,46 +1,45 @@
from copy import deepcopy
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, List
-from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction
+from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u
class HTLCManager:
- def __init__(self, local_ctn=0, remote_ctn=0, log=None):
+ def __init__(self, *, local_ctn=0, remote_ctn=0, log=None):
# self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub
self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn}
- # ctx_pending[sub] is True iff sub has sent commitment_signed but did not receive revoke_and_ack
+ # ctx_pending[sub] is True iff sub has received commitment_signed but did not send revoke_and_ack (sub has multiple unrevoked ctxs)
self.ctx_pending = {LOCAL:False, REMOTE: False} # FIXME does this need to be persisted?
- # expect_sig[SENT/RECEIVED] is True iff HTLCs have been sent/received but the corresponding commitment_signed has not been received/sent
- self.expect_sig = {SENT: False, RECEIVED: False}
if log is None:
initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}}
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
else:
assert type(log) is dict
- log = {HTLCOwner(int(x)): y for x, y in deepcopy(log).items()}
+ log = {HTLCOwner(int(sub)): action for sub, action in deepcopy(log).items()}
for sub in (LOCAL, REMOTE):
log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()}
- coerceHtlcOwner2IntMap = lambda x: {HTLCOwner(int(y)): z for y, z in x.items()}
+ coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()}
# "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
- log[sub]['locked_in'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['locked_in'].items()}
- log[sub]['settles'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['settles'].items()}
- log[sub]['fails'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['fails'].items()}
+ log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
+ log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
+ log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
self.log = log
- def ctn_latest(self, sub):
+ def ctn_latest(self, sub: HTLCOwner) -> int:
"""Return the ctn for the latest (newest that has a valid sig) ctx of sub"""
return self.ctn[sub] + int(self.ctx_pending[sub])
def to_save(self):
- x = deepcopy(self.log)
+ log = deepcopy(self.log)
for sub in (LOCAL, REMOTE):
+ # adds
d = {}
- for htlc_id, htlc in x[sub]['adds'].items():
+ for htlc_id, htlc in log[sub]['adds'].items():
d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
- x[sub]['adds'] = d
- return x
+ log[sub]['adds'] = d
+ return log
def channel_open_finished(self):
self.ctn = {LOCAL: 0, REMOTE: 0}
@@ -48,53 +47,55 @@ class HTLCManager:
def send_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc:
htlc_id = htlc.htlc_id
- adds = self.log[LOCAL]['adds']
- assert type(adds) is not str
- adds[htlc_id] = htlc
+ self.log[LOCAL]['adds'][htlc_id] = htlc
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE)+1}
- self.expect_sig[SENT] = True
return htlc
def recv_htlc(self, htlc: UpdateAddHtlc) -> None:
htlc_id = htlc.htlc_id
self.log[REMOTE]['adds'][htlc_id] = htlc
- l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None}
- self.expect_sig[RECEIVED] = True
+ self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL)+1, REMOTE: None}
+
+ def send_settle(self, htlc_id: int) -> None:
+ self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
+
+ def recv_settle(self, htlc_id: int) -> None:
+ self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
+
+ def send_fail(self, htlc_id: int) -> None:
+ self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
+
+ def recv_fail(self, htlc_id: int) -> None:
+ self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def send_ctx(self) -> None:
assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE])
self.ctx_pending[REMOTE] = True
- for locked_in in self.log[REMOTE]['locked_in'].values():
- if locked_in[REMOTE] is None:
- locked_in[REMOTE] = self.ctn_latest(REMOTE)
- self.expect_sig[SENT] = False
def recv_ctx(self) -> None:
assert self.ctn_latest(LOCAL) == self.ctn[LOCAL], (self.ctn_latest(LOCAL), self.ctn[LOCAL])
self.ctx_pending[LOCAL] = True
- for locked_in in self.log[LOCAL]['locked_in'].values():
- if locked_in[LOCAL] is None:
- locked_in[LOCAL] = self.ctn_latest(LOCAL)
- self.expect_sig[RECEIVED] = False
def send_rev(self) -> None:
self.ctn[LOCAL] += 1
self.ctx_pending[LOCAL] = False
+ for ctns in self.log[REMOTE]['locked_in'].values():
+ if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
+ ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
for log_action in ('settles', 'fails'):
- for htlc_id, ctns in self.log[LOCAL][log_action].items():
- if ctns[REMOTE] is None:
+ for ctns in self.log[LOCAL][log_action].values():
+ if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
def recv_rev(self) -> None:
self.ctn[REMOTE] += 1
self.ctx_pending[REMOTE] = False
- for htlc_id, ctns in self.log[LOCAL]['locked_in'].items():
- if ctns[LOCAL] is None:
- #assert ctns[REMOTE] == self.ctn[REMOTE] # FIXME I don't think this assert is correct
+ for ctns in self.log[LOCAL]['locked_in'].values():
+ if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
for log_action in ('settles', 'fails'):
- for htlc_id, ctns in self.log[REMOTE][log_action].items():
- if ctns[LOCAL] is None:
+ for ctns in self.log[REMOTE][log_action].values():
+ if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
@@ -113,13 +114,7 @@ class HTLCManager:
# party is the proposer of the HTLCs
party = subject if direction == SENT else subject.inverted()
for htlc_id, ctns in self.log[party]['locked_in'].items():
- htlc_height = ctns[subject]
- if htlc_height is None:
- expect_sig = self.expect_sig[RECEIVED if party != LOCAL else SENT]
- include = not expect_sig and ctns[-subject] <= ctn
- else:
- include = htlc_height <= ctn
- if include:
+ if ctns[subject] is not None and ctns[subject] <= ctn:
settles = self.log[party]['settles']
fails = self.log[party]['fails']
not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn
@@ -138,23 +133,20 @@ class HTLCManager:
l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)]
return l
- def current_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
- """Return the list of HTLCs in subject's oldest unrevoked ctx."""
+ def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner
ctn = self.ctn[subject]
return self.htlcs(subject, ctn)
- def pending_htlcs(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
- """Return the list of HTLCs in subject's next ctx (one after oldest unrevoked)."""
+ def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner
- ctn = self.ctn[subject] + 1
+ ctn = self.ctn_latest(subject)
return self.htlcs(subject, ctn)
- def send_settle(self, htlc_id: int) -> None:
- self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
-
- def recv_settle(self, htlc_id: int) -> None:
- self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
+ def get_htlcs_in_next_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
+ assert type(subject) is HTLCOwner
+ ctn = self.ctn_latest(subject) + 1
+ return self.htlcs(subject, ctn)
def all_settled_htlcs_ever_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]:
@@ -194,9 +186,3 @@ class HTLCManager:
return [self.log[LOCAL]['adds'][htlc_id]
for htlc_id, ctns in self.log[LOCAL]['settles'].items()
if ctns[LOCAL] == ctn]
-
- def send_fail(self, htlc_id: int) -> None:
- self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1}
-
- def recv_fail(self, htlc_id: int) -> None:
- self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -1024,12 +1024,12 @@ class Peer(Logger):
def maybe_send_commitment(self, chan: Channel):
ctn_to_sign = chan.get_current_ctn(REMOTE) + 1
# if there are no changes, we will not (and must not) send a new commitment
- pending, current = chan.hm.pending_htlcs(REMOTE), chan.hm.current_htlcs(REMOTE)
- if (pending == current
+ next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE)
+ if (next_htlcs == latest_htlcs
and chan.pending_feerate(REMOTE) == chan.constraints.feerate) \
or ctn_to_sign == self.sent_commitment_for_ctn_last[chan]:
return
- self.logger.info(f'send_commitment. old number htlcs: {len(current)}, new number htlcs: {len(pending)}')
+ self.logger.info(f'send_commitment. old number htlcs: {len(latest_htlcs)}, new number htlcs: {len(next_htlcs)}')
sig_64, htlc_sigs = chan.sign_next_commitment()
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
self.sent_commitment_for_ctn_last[chan] = ctn_to_sign
@@ -1087,8 +1087,8 @@ class Peer(Logger):
channel_id = payload['channel_id']
chan = self.channels[channel_id]
# make sure there were changes to the ctx, otherwise the remote peer is misbehaving
- if (chan.hm.pending_htlcs(LOCAL) == chan.hm.current_htlcs(LOCAL)
- and chan.pending_feerate(LOCAL) == chan.constraints.feerate):
+ if (chan.hm.get_htlcs_in_next_ctx(LOCAL) == chan.hm.get_htlcs_in_latest_ctx(LOCAL)
+ and chan.pending_feerate(LOCAL) == chan.constraints.feerate):
raise RemoteMisbehaving('received commitment_signed without pending changes')
# make sure ctn is new
ctn_to_recv = chan.get_current_ctn(LOCAL) + 1
diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py
@@ -226,6 +226,8 @@ class TestChannel(unittest.TestCase):
self.bob_channel.add_htlc(self.htlc_dict)
self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
+ self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3)
+ self.alice_channel.revoke_current_commitment()
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4)
def test_SimpleAddSettleWorkflow(self):
@@ -279,8 +281,8 @@ class TestChannel(unittest.TestCase):
self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
- self.assertEqual(next(iter(alice_channel.hm.pending_htlcs(REMOTE)))[0], RECEIVED)
- self.assertEqual(alice_channel.hm.pending_htlcs(REMOTE), bob_channel.hm.pending_htlcs(LOCAL))
+ self.assertEqual(next(iter(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE)))[0], RECEIVED)
+ self.assertEqual(alice_channel.hm.get_htlcs_in_next_ctx(REMOTE), bob_channel.hm.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs())
# Bob receives this signature message, and checks that this covers the
@@ -291,14 +293,11 @@ class TestChannel(unittest.TestCase):
self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL)))
self.assertEqual(bob_channel.config[REMOTE].ctn, 0)
- self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc])
+ self.assertEqual(bob_channel.included_htlcs(LOCAL, RECEIVED, 1), [htlc])#
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 0), [])
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [htlc])
- self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 0), [])
- self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc])
-
self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 0), [])
self.assertEqual(alice_channel.included_htlcs(REMOTE, SENT, 1), [])
@@ -323,7 +322,11 @@ class TestChannel(unittest.TestCase):
self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
- self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 3)
+ # so far: Alice added htlc, Alice signed.
+ self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2)
+ self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2)
+ self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 2) # oldest unrevoked
+ self.assertEqual(len(alice_channel.pending_commitment(REMOTE).outputs()), 3) # latest
# Alice then processes this revocation, sending her own revocation for
# her prior commitment transaction. Alice shouldn't have any HTLCs to
diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py
@@ -14,42 +14,54 @@ class TestHTLCManager(unittest.TestCase):
B = HTLCManager()
ah0, bh0 = H('A', 0), H('B', 0)
B.recv_htlc(A.send_htlc(ah0))
- self.assertTrue(B.expect_sig[RECEIVED])
- self.assertTrue(A.expect_sig[SENT])
- self.assertFalse(B.expect_sig[SENT])
- self.assertFalse(A.expect_sig[RECEIVED])
self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1)
A.recv_htlc(B.send_htlc(bh0))
- self.assertTrue(B.expect_sig[RECEIVED])
- self.assertTrue(A.expect_sig[SENT])
- self.assertTrue(A.expect_sig[SENT])
- self.assertTrue(B.expect_sig[RECEIVED])
- self.assertEqual(B.current_htlcs(LOCAL), [])
- self.assertEqual(A.current_htlcs(LOCAL), [])
- self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0)])
- self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0)])
+ self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [])
+ self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
+ self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, ah0)])
+ self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), [(RECEIVED, bh0)])
A.send_ctx()
B.recv_ctx()
B.send_ctx()
A.recv_ctx()
- self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
- self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
+ self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [])
+ self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [])
+ self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)])
+ self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)])
B.send_rev()
A.recv_rev()
A.send_rev()
B.recv_rev()
- self.assertEqual(B.current_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
- self.assertEqual(A.current_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
+ self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)])
+ self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)])
+ self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0)])
+ self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0)])
+ A.send_ctx()
+ B.recv_ctx()
+ B.send_ctx()
+ A.recv_ctx()
+ self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0)])
+ self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0)])
+ self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
+ self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
+ B.send_rev()
+ A.recv_rev()
+ A.send_rev()
+ B.recv_rev()
+ self.assertEqual(B.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
+ self.assertEqual(A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
+ self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
+ self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
def test_single_htlc_full_lifecycle(self):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager()
B = HTLCManager()
B.recv_htlc(A.send_htlc(H('A', 0)))
- self.assertEqual(len(B.pending_htlcs(REMOTE)), 0)
- self.assertEqual(len(A.pending_htlcs(REMOTE)), 1)
- self.assertEqual(len(B.pending_htlcs(LOCAL)), 1)
- self.assertEqual(len(A.pending_htlcs(LOCAL)), 0)
+ self.assertEqual(len(B.get_htlcs_in_next_ctx(REMOTE)), 0)
+ self.assertEqual(len(A.get_htlcs_in_next_ctx(REMOTE)), 1)
+ self.assertEqual(len(B.get_htlcs_in_next_ctx(LOCAL)), 1)
+ self.assertEqual(len(A.get_htlcs_in_next_ctx(LOCAL)), 0)
A.send_ctx()
B.recv_ctx()
B.send_rev()
@@ -58,8 +70,8 @@ class TestHTLCManager(unittest.TestCase):
A.recv_ctx()
A.send_rev()
B.recv_rev()
- self.assertEqual(len(A.current_htlcs(LOCAL)), 1)
- self.assertEqual(len(B.current_htlcs(LOCAL)), 1)
+ self.assertEqual(len(A.get_htlcs_in_latest_ctx(LOCAL)), 1)
+ self.assertEqual(len(B.get_htlcs_in_latest_ctx(LOCAL)), 1)
if htlc_success:
B.send_settle(0)
A.recv_settle(0)
@@ -67,47 +79,47 @@ class TestHTLCManager(unittest.TestCase):
B.send_fail(0)
A.recv_fail(0)
self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)])
- self.assertNotEqual(A.current_htlcs(LOCAL), [])
- self.assertNotEqual(B.current_htlcs(REMOTE), [])
+ self.assertNotEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
+ self.assertNotEqual(B.get_htlcs_in_latest_ctx(REMOTE), [])
- self.assertEqual(A.pending_htlcs(LOCAL), [])
- self.assertNotEqual(A.pending_htlcs(REMOTE), [])
- self.assertEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE))
+ self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), [])
+ self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
+ self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE))
- self.assertEqual(B.pending_htlcs(REMOTE), [])
+ self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), [])
B.send_ctx()
A.recv_ctx()
A.send_rev() # here pending_htlcs(REMOTE) should become empty
- self.assertEqual(A.pending_htlcs(REMOTE), [])
+ self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
B.recv_rev()
A.send_ctx()
B.recv_ctx()
B.send_rev()
A.recv_rev()
- self.assertEqual(B.current_htlcs(LOCAL), [])
- self.assertEqual(A.current_htlcs(LOCAL), [])
- self.assertEqual(A.current_htlcs(REMOTE), [])
- self.assertEqual(B.current_htlcs(REMOTE), [])
+ self.assertEqual(B.get_htlcs_in_latest_ctx(LOCAL), [])
+ self.assertEqual(A.get_htlcs_in_latest_ctx(LOCAL), [])
+ self.assertEqual(A.get_htlcs_in_latest_ctx(REMOTE), [])
+ self.assertEqual(B.get_htlcs_in_latest_ctx(REMOTE), [])
self.assertEqual(len(A.all_settled_htlcs_ever(LOCAL)), int(htlc_success))
self.assertEqual(len(A.sent_in_ctn(2)), int(htlc_success))
self.assertEqual(len(B.received_in_ctn(2)), int(htlc_success))
A.recv_htlc(B.send_htlc(H('B', 0)))
- self.assertEqual(A.pending_htlcs(REMOTE), [])
- self.assertNotEqual(A.pending_htlcs(LOCAL), [])
- self.assertNotEqual(B.pending_htlcs(REMOTE), [])
- self.assertEqual(B.pending_htlcs(LOCAL), [])
+ self.assertEqual(A.get_htlcs_in_next_ctx(REMOTE), [])
+ self.assertNotEqual(A.get_htlcs_in_next_ctx(LOCAL), [])
+ self.assertNotEqual(B.get_htlcs_in_next_ctx(REMOTE), [])
+ self.assertEqual(B.get_htlcs_in_next_ctx(LOCAL), [])
B.send_ctx()
A.recv_ctx()
A.send_rev()
B.recv_rev()
- self.assertNotEqual(A.pending_htlcs(REMOTE), A.current_htlcs(REMOTE))
- self.assertEqual(A.pending_htlcs(LOCAL), A.current_htlcs(LOCAL))
- self.assertEqual(B.pending_htlcs(REMOTE), B.current_htlcs(REMOTE))
- self.assertNotEqual(B.pending_htlcs(LOCAL), B.pending_htlcs(REMOTE))
+ self.assertNotEqual(A.get_htlcs_in_next_ctx(REMOTE), A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual(A.get_htlcs_in_next_ctx(LOCAL), A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual(B.get_htlcs_in_next_ctx(REMOTE), B.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertNotEqual(B.get_htlcs_in_next_ctx(LOCAL), B.get_htlcs_in_next_ctx(REMOTE))
htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False)
@@ -116,7 +128,8 @@ class TestHTLCManager(unittest.TestCase):
def htlc_lifecycle(htlc_success: bool):
A = HTLCManager()
B = HTLCManager()
- B.recv_htlc(A.send_htlc(H('A', 0)))
+ ah0 = H('A', 0)
+ B.recv_htlc(A.send_htlc(ah0))
A.send_ctx()
B.recv_ctx()
B.send_rev()
@@ -127,11 +140,22 @@ class TestHTLCManager(unittest.TestCase):
else:
B.send_fail(0)
A.recv_fail(0)
- self.assertEqual(B.pending_htlcs(REMOTE), [])
+ self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
B.send_ctx()
A.recv_ctx()
A.send_rev()
B.recv_rev()
+ self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([], A.get_htlcs_in_next_ctx(REMOTE))
htlc_lifecycle(htlc_success=True)
htlc_lifecycle(htlc_success=False)
@@ -144,13 +168,38 @@ class TestHTLCManager(unittest.TestCase):
B.send_rev()
ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0))
- self.assertEqual([], A.current_htlcs(LOCAL))
- self.assertEqual([], A.current_htlcs(REMOTE))
- self.assertEqual([], A.pending_htlcs(LOCAL))
- self.assertEqual([], A.pending_htlcs(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
A.recv_rev()
- self.assertEqual([], A.current_htlcs(LOCAL))
- self.assertEqual([], A.current_htlcs(REMOTE))
- self.assertEqual([(Direction.SENT, ah0)], A.pending_htlcs(LOCAL))
- self.assertEqual([(Direction.RECEIVED, ah0)], A.pending_htlcs(REMOTE))
-
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
+ A.send_ctx()
+ B.recv_ctx()
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
+ B.send_rev()
+ A.recv_rev()
+ self.assertEqual([], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
+ B.send_ctx()
+ A.recv_ctx()
+ self.assertEqual([], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
+ self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
+ A.send_rev()
+ B.recv_rev()
+ self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_oldest_unrevoked_ctx(LOCAL))
+ self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_latest_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
+ self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
+ self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))