commit 39fa13b93861d324406ac58cf54047dea961ce5d
parent 72187a43416831306fd13b068b7341a3d0c05003
Author: Janus <ysangkok@gmail.com>
Date: Fri, 26 Oct 2018 18:46:33 +0200
lnchan: use NamedTuple for logs instead of dict with static keys (adds, locked_in, settles, fails)
Diffstat:
2 files changed, 69 insertions(+), 38 deletions(-)
diff --git a/electrum/lnchan.py b/electrum/lnchan.py
@@ -26,7 +26,7 @@ from collections import namedtuple, defaultdict
import binascii
import json
from enum import Enum, auto
-from typing import Optional, Dict, List, Tuple
+from typing import Optional, Dict, List, Tuple, NamedTuple, Set
from copy import deepcopy
from .util import bfh, PrintError, bh2u
@@ -121,6 +121,20 @@ def str_bytes_dict_from_save(x):
def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()}
+class HtlcChanges(NamedTuple):
+ # ints are htlc ids
+ adds: Dict[int, UpdateAddHtlc]
+ settles: Set[int]
+ fails: Set[int]
+ locked_in: Set[int]
+
+ @staticmethod
+ def new():
+ """
+ Since we can't use default arguments for these types (they would be shared among instances)
+ """
+ return HtlcChanges({}, set(), set(), set())
+
class Channel(PrintError):
def diagnostic_name(self):
if self.name:
@@ -158,18 +172,12 @@ class Channel(PrintError):
# any past commitment transaction and use that instead; until then...
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
- template = lambda: {
- 'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
- 'settles': [], # List[HTLC_ID]
- 'fails': [], # List[HTLC_ID]
- 'locked_in': [], # List[HTLC_ID]
- }
- self.log = {LOCAL: template(), REMOTE: template()}
+ self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
if strname not in state: continue
for y in state[strname]:
htlc = UpdateAddHtlc(**y)
- self.log[subject]['adds'][htlc.htlc_id] = htlc
+ self.log[subject].adds[htlc.htlc_id] = htlc
self.name = name
@@ -185,6 +193,9 @@ class Channel(PrintError):
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
+ for sub in (LOCAL, REMOTE):
+ self.log[sub].locked_in.update(self.log[sub].adds.keys())
+
def set_state(self, state: str):
self._state = state
@@ -232,7 +243,7 @@ class Channel(PrintError):
assert type(htlc) is dict
self._check_can_pay(htlc['amount_msat'])
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
- self.log[LOCAL]['adds'][htlc.htlc_id] = htlc
+ self.log[LOCAL].adds[htlc.htlc_id] = htlc
self.print_error("add_htlc")
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@@ -251,7 +262,7 @@ class Channel(PrintError):
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
f' Available at remote: {self.available_to_spend(REMOTE)},' +\
f' HTLC amount: {htlc.amount_msat}')
- adds = self.log[REMOTE]['adds']
+ adds = self.log[REMOTE].adds
adds[htlc.htlc_id] = htlc
self.print_error("receive_htlc")
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
@@ -309,11 +320,11 @@ class Channel(PrintError):
for sub in (LOCAL, REMOTE):
log = self.log[sub]
yield (sub, deepcopy(log))
- for htlc_id in log['fails']:
- log['adds'].pop(htlc_id)
- log['fails'].clear()
+ for htlc_id in log.fails:
+ log.adds.pop(htlc_id)
+ log.fails.clear()
- self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
+ self.log[subject].locked_in.update(self.log[subject].adds.keys())
def receive_new_commitment(self, sig, htlc_sigs):
"""
@@ -474,11 +485,11 @@ class Channel(PrintError):
"""
old_amount = htlcsum(self.htlcs(subject, False))
- for htlc_id in self.log[subject]['settles']:
- adds = self.log[subject]['adds']
+ for htlc_id in self.log[subject].settles:
+ adds = self.log[subject].adds
htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat)
- self.log[subject]['settles'].clear()
+ self.log[subject].settles.clear()
return old_amount - htlcsum(self.htlcs(subject, False))
@@ -533,7 +544,7 @@ class Channel(PrintError):
pending outgoing HTLCs, is used in the UI.
"""
return self.balance(subject)\
- - htlcsum(self.log[subject]['adds'].values())
+ - htlcsum(self.log[subject].adds.values())
def available_to_spend(self, subject):
"""
@@ -541,7 +552,7 @@ class Channel(PrintError):
not be used in the UI cause it fluctuates (commit fee)
"""
return self.balance_minus_outgoing_htlcs(subject)\
- - htlcsum(self.log[subject]['adds'].values())\
+ - htlcsum(self.log[subject].adds.values())\
- self.config[-subject].reserve_sat * 1000\
- calc_onchain_fees(
# TODO should we include a potential new htlc, when we are called from receive_htlc?
@@ -601,10 +612,10 @@ class Channel(PrintError):
"""
update_log = self.log[subject]
res = []
- for htlc in update_log['adds'].values():
- locked_in = htlc.htlc_id in update_log['locked_in']
- settled = htlc.htlc_id in update_log['settles']
- failed = htlc.htlc_id in update_log['fails']
+ for htlc in update_log.adds.values():
+ locked_in = htlc.htlc_id in update_log.locked_in
+ settled = htlc.htlc_id in update_log.settles
+ failed = htlc.htlc_id in update_log.fails
if not locked_in:
continue
if only_pending == (settled or failed):
@@ -617,25 +628,33 @@ class Channel(PrintError):
SettleHTLC attempts to settle an existing outstanding received HTLC.
"""
self.print_error("settle_htlc")
- htlc = self.log[REMOTE]['adds'][htlc_id]
+ log = self.log[REMOTE]
+ htlc = log.adds[htlc_id]
assert htlc.payment_hash == sha256(preimage)
- self.log[REMOTE]['settles'].append(htlc_id)
+ assert htlc_id not in log.settles
+ log.settles.add(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle")
- htlc = self.log[LOCAL]['adds'][htlc_id]
+ log = self.log[LOCAL]
+ htlc = log.adds[htlc_id]
assert htlc.payment_hash == sha256(preimage)
- self.log[LOCAL]['settles'].append(htlc_id)
+ assert htlc_id not in log.settles
+ log.settles.add(htlc_id)
# we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id):
self.print_error("fail_htlc")
- self.log[REMOTE]['fails'].append(htlc_id)
+ log = self.log[REMOTE]
+ assert htlc_id not in log.fails
+ log.fails.add(htlc_id)
def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc")
- self.log[LOCAL]['fails'].append(htlc_id)
+ log = self.log[LOCAL]
+ assert htlc_id not in log.fails
+ log.fails.add(htlc_id)
@property
def current_height(self):
@@ -666,8 +685,8 @@ class Channel(PrintError):
removed = []
htlcs = []
log = self.log[subject]
- for htlc_id, i in log['adds'].items():
- locked_in = htlc_id in log['locked_in']
+ for i in log.adds.values():
+ locked_in = i.htlc_id in log.locked_in
if locked_in:
htlcs.append(i._asdict())
else:
@@ -710,18 +729,26 @@ class Channel(PrintError):
def serialize(self):
namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
- serialized_channel = {k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in self.to_save().items()}
+ serialized_channel = {}
+ to_save_ref = self.to_save()
+ for k, v in to_save_ref.items():
+ if isinstance(v, tuple):
+ serialized_channel[k] = namedtuples_to_dict(v)
+ else:
+ serialized_channel[k] = v
dumped = ChannelJsonEncoder().encode(serialized_channel)
roundtripped = json.loads(dumped)
reconstructed = Channel(roundtripped)
- if reconstructed.to_save() != self.to_save():
- from pprint import pformat
+ to_save_new = reconstructed.to_save()
+ if to_save_new != to_save_ref:
+ from pprint import PrettyPrinter
+ pp = PrettyPrinter(indent=168)
try:
from deepdiff import DeepDiff
except ImportError:
- raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(reconstructed.to_save()) + "\n" + pformat(self.to_save()))
+ raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new))
else:
- raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(DeepDiff(reconstructed.to_save(), self.to_save())))
+ raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new)))
return roundtripped
def __str__(self):
diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py
@@ -183,7 +183,7 @@ class TestChannel(unittest.TestCase):
self.bob_pending_remote_balance = after
- self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0]
+ self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel
@@ -217,6 +217,10 @@ class TestChannel(unittest.TestCase):
# forward since she's sending an outgoing HTLC.
alice_channel.receive_revocation(bobRevocation)
+ # test serializing with locked_in htlc
+ self.assertEqual(len(alice_channel.to_save()['local_log']), 1)
+ alice_channel.serialize()
+
# Alice then processes bob's signature, and since she just received
# the revocation, she expect this signature to cover everything up to
# the point where she sent her signature, including the HTLC.