commit 1d4c113a3524cf012d5a7c40acdc7a61e16b19ed
parent 699368b0b783dd5ade7740a0bffd11ba0fd056f0
Author: Janus <ysangkok@gmail.com>
Date: Wed, 10 Oct 2018 19:52:46 +0200
lnhtlc: remove lookup_htlc, use heterogeneously typed lists
Diffstat:
3 files changed, 28 insertions(+), 66 deletions(-)
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -1178,7 +1178,6 @@ class Peer(PrintError):
chan = self.channels[update_fulfill_htlc_msg["channel_id"]]
preimage = update_fulfill_htlc_msg["payment_preimage"]
htlc_id = int.from_bytes(update_fulfill_htlc_msg["id"], "big")
- htlc = chan.lookup_htlc(chan.log[LOCAL], htlc_id)
chan.receive_htlc_settle(preimage, htlc_id)
await self.receive_commitment(chan)
self.revoke(chan)
diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
@@ -23,8 +23,6 @@ from .transaction import Transaction, TxOutput, construct_witness
from .simple_config import SimpleConfig, FEERATE_FALLBACK_STATIC_FEE
-FailHtlc = namedtuple("FailHtlc", ["htlc_id"])
-SettleHtlc = namedtuple("SettleHtlc", ["htlc_id"])
RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])
class FeeUpdateProgress(Enum):
@@ -100,14 +98,6 @@ def typeWrap(k, v, local):
return v
class HTLCStateMachine(PrintError):
- def lookup_htlc(self, log, htlc_id):
- assert type(htlc_id) is int
- for htlc in log:
- if type(htlc) is not UpdateAddHtlc: continue
- if htlc.htlc_id == htlc_id:
- return htlc
- assert False, self.diagnostic_name() + ": htlc_id {} not found in {}".format(htlc_id, log)
-
def diagnostic_name(self):
return str(self.name)
@@ -146,18 +136,13 @@ class HTLCStateMachine(PrintError):
# any past commitment transaction and use that instead; until then...
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
- self.log = {LOCAL: [], REMOTE: []}
+ template = lambda: {'adds': {}, 'settles': []}
+ self.log = {LOCAL: template(), REMOTE: template()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
if strname not in state: continue
- for typ,y in state[strname]:
- if typ == "UpdateAddHtlc":
- self.log[subject].append(UpdateAddHtlc(*decodeAll(y)))
- elif typ == "SettleHtlc":
- self.log[subject].append(SettleHtlc(*decodeAll(y)))
- elif typ == "FailHtlc":
- self.log[subject].append(FailHtlc(*decodeAll(y)))
- else:
- assert False
+ for y in state[strname]:
+ htlc = UpdateAddHtlc(*decodeAll(y))
+ self.log[subject]['adds'][htlc.htlc_id] = htlc
self.name = name
@@ -197,7 +182,7 @@ class HTLCStateMachine(PrintError):
"""
assert type(htlc) is dict
htlc = UpdateAddHtlc(**htlc, htlc_id=self.local_state.next_htlc_id)
- self.log[LOCAL].append(htlc)
+ self.log[LOCAL]['adds'][htlc.htlc_id] = htlc
self.print_error("add_htlc")
self.local_state=self.local_state._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@@ -210,7 +195,7 @@ class HTLCStateMachine(PrintError):
"""
assert type(htlc) is dict
htlc = UpdateAddHtlc(**htlc, htlc_id = self.remote_state.next_htlc_id)
- self.log[REMOTE].append(htlc)
+ self.log[REMOTE]['adds'][htlc.htlc_id] = htlc
self.print_error("receive_htlc")
self.remote_state=self.remote_state._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@@ -228,9 +213,8 @@ class HTLCStateMachine(PrintError):
any). The HTLC signatures are sorted according to the BIP 69 order of the
HTLC's on the commitment transaction.
"""
- for htlc in self.log[LOCAL]:
- if not type(htlc) is UpdateAddHtlc: continue
- if htlc.locked_in[LOCAL] is None and FailHtlc(htlc.htlc_id) not in self.log[REMOTE]:
+ for htlc in self.log[LOCAL]['adds'].values():
+ if htlc.locked_in[LOCAL] is None:
htlc.locked_in[LOCAL] = self.local_state.ctn
self.print_error("sign_next_commitment")
@@ -279,9 +263,8 @@ class HTLCStateMachine(PrintError):
"""
self.print_error("receive_new_commitment")
- for htlc in self.log[REMOTE]:
- if not type(htlc) is UpdateAddHtlc: continue
- if htlc.locked_in[REMOTE] is None and FailHtlc(htlc.htlc_id) not in self.log[LOCAL]:
+ for htlc in self.log[REMOTE]['adds'].values():
+ if htlc.locked_in[REMOTE] is None:
htlc.locked_in[REMOTE] = self.remote_state.ctn
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
@@ -422,18 +405,15 @@ class HTLCStateMachine(PrintError):
def mark_settled(subject):
"""
- find settled htlcs for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs
+ find pending settlements for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs
"""
old_amount = self.htlcsum(self.gen_htlc_indices(subject, False))
- removed = []
- for x in self.log[-subject]:
- if type(x) is not SettleHtlc: continue
- htlc = self.lookup_htlc(self.log[subject], x.htlc_id)
+ for htlc_id in self.log[-subject]['settles']:
+ adds = self.log[subject]['adds']
+ htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat)
- self.log[subject].remove(htlc)
- removed.append(x)
- for x in removed: self.log[-subject].remove(x)
+ self.log[-subject]['settles'].clear()
return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False))
@@ -533,12 +513,10 @@ class HTLCStateMachine(PrintError):
update_log = self.log[subject]
other_log = self.log[-subject]
res = []
- for htlc in update_log:
- if type(htlc) is not UpdateAddHtlc:
- continue
+ for htlc in update_log['adds'].values():
locked_in = htlc.locked_in[subject]
- if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log):
+ if locked_in is None or only_pending == (htlc.htlc_id in other_log['settles']):
continue
res.append(htlc)
return res
@@ -558,23 +536,19 @@ class HTLCStateMachine(PrintError):
SettleHTLC attempts to settle an existing outstanding received HTLC.
"""
self.print_error("settle_htlc")
- htlc = self.lookup_htlc(self.log[REMOTE], htlc_id)
+ htlc = self.log[REMOTE]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
- self.log[LOCAL].append(SettleHtlc(htlc_id))
+ self.log[LOCAL]['settles'].append(htlc_id)
def receive_htlc_settle(self, preimage, htlc_index):
self.print_error("receive_htlc_settle")
- htlc = self.lookup_htlc(self.log[LOCAL], htlc_index)
+ htlc = self.log[LOCAL]['adds'][htlc_index]
assert htlc.payment_hash == sha256(preimage)
- assert len([x for x in self.log[LOCAL] if x.htlc_id == htlc_index and type(x) is UpdateAddHtlc]) == 1, (self.log[LOCAL], htlc_index)
- self.log[REMOTE].append(SettleHtlc(htlc_index))
+ self.log[REMOTE]['settles'].append(htlc_index)
def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc")
- htlc = self.lookup_htlc(self.log[LOCAL], htlc_id)
- htlc.locked_in[LOCAL] = None
- htlc.locked_in[REMOTE] = None
- self.log[REMOTE].append(FailHtlc(htlc_id))
+ self.log[LOCAL]['adds'].pop(htlc_id)
@property
def current_height(self):
@@ -604,14 +578,9 @@ class HTLCStateMachine(PrintError):
"""
removed = []
htlcs = []
- for i in self.log[subject]:
- if type(i) is not UpdateAddHtlc:
- htlcs.append(i)
- continue
- settled = SettleHtlc(i.htlc_id) in self.log[-subject]
- failed = FailHtlc(i.htlc_id) in self.log[-subject]
+ for i in self.log[subject]['adds'].values():
locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None
- if locked_in or settled or failed:
+ if locked_in:
htlcs.append(i)
else:
removed.append(i.htlc_id)
@@ -634,8 +603,8 @@ class HTLCStateMachine(PrintError):
"funding_outpoint": self.funding_outpoint,
"node_id": self.node_id,
"remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked),
- "remote_log": [(type(x).__name__, x) for x in remote_filtered],
- "local_log": [(type(x).__name__, x) for x in local_filtered],
+ "remote_log": remote_filtered,
+ "local_log": local_filtered,
"onion_keys": {str(k): bh2u(v) for k, v in self.onion_keys.items()},
"settled_local": self.settled[LOCAL],
"settled_remote": self.settled[REMOTE],
@@ -662,12 +631,6 @@ class HTLCStateMachine(PrintError):
return binascii.hexlify(o).decode("ascii")
if isinstance(o, RevocationStore):
return o.serialize()
- if isinstance(o, SettleHtlc):
- return json.dumps(('SettleHtlc', namedtuples_to_dict(o)))
- if isinstance(o, FailHtlc):
- return json.dumps(('FailHtlc', namedtuples_to_dict(o)))
- if isinstance(o, UpdateAddHtlc):
- return json.dumps(('UpdateAddHtlc', namedtuples_to_dict(o)))
return super(MyJsonEncoder, self)
dumped = MyJsonEncoder().encode(serialized_channel)
roundtripped = json.loads(dumped)
diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py
@@ -147,7 +147,7 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc)
self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc)
- self.htlc = self.bob_channel.log[lnutil.REMOTE][0]
+ self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0]
def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel