electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 39fa13b93861d324406ac58cf54047dea961ce5d
parent 72187a43416831306fd13b068b7341a3d0c05003
Author: Janus <ysangkok@gmail.com>
Date:   Fri, 26 Oct 2018 18:46:33 +0200

lnchan: use NamedTuple for logs instead of dict with static keys (adds, locked_in, settles, fails)

Diffstat:
Melectrum/lnchan.py | 101++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------------
Melectrum/tests/test_lnchan.py | 6+++++-
2 files changed, 69 insertions(+), 38 deletions(-)

diff --git a/electrum/lnchan.py b/electrum/lnchan.py @@ -26,7 +26,7 @@ from collections import namedtuple, defaultdict import binascii import json from enum import Enum, auto -from typing import Optional, Dict, List, Tuple +from typing import Optional, Dict, List, Tuple, NamedTuple, Set from copy import deepcopy from .util import bfh, PrintError, bh2u @@ -121,6 +121,20 @@ def str_bytes_dict_from_save(x): def str_bytes_dict_to_save(x): return {str(k): bh2u(v) for k, v in x.items()} +class HtlcChanges(NamedTuple): + # ints are htlc ids + adds: Dict[int, UpdateAddHtlc] + settles: Set[int] + fails: Set[int] + locked_in: Set[int] + + @staticmethod + def new(): + """ + Since we can't use default arguments for these types (they would be shared among instances) + """ + return HtlcChanges({}, set(), set(), set()) + class Channel(PrintError): def diagnostic_name(self): if self.name: @@ -158,18 +172,12 @@ class Channel(PrintError): # any past commitment transaction and use that instead; until then... self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"]) - template = lambda: { - 'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc] - 'settles': [], # List[HTLC_ID] - 'fails': [], # List[HTLC_ID] - 'locked_in': [], # List[HTLC_ID] - } - self.log = {LOCAL: template(), REMOTE: template()} + self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()} for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]: if strname not in state: continue for y in state[strname]: htlc = UpdateAddHtlc(**y) - self.log[subject]['adds'][htlc.htlc_id] = htlc + self.log[subject].adds[htlc.htlc_id] = htlc self.name = name @@ -185,6 +193,9 @@ class Channel(PrintError): self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])} + for sub in (LOCAL, REMOTE): + self.log[sub].locked_in.update(self.log[sub].adds.keys()) + def set_state(self, state: str): self._state = state @@ -232,7 +243,7 @@ class Channel(PrintError): assert type(htlc) is dict self._check_can_pay(htlc['amount_msat']) htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id) - self.log[LOCAL]['adds'][htlc.htlc_id] = htlc + self.log[LOCAL].adds[htlc.htlc_id] = htlc self.print_error("add_htlc") self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1) return htlc.htlc_id @@ -251,7 +262,7 @@ class Channel(PrintError): raise RemoteMisbehaving('Remote dipped below channel reserve.' +\ f' Available at remote: {self.available_to_spend(REMOTE)},' +\ f' HTLC amount: {htlc.amount_msat}') - adds = self.log[REMOTE]['adds'] + adds = self.log[REMOTE].adds adds[htlc.htlc_id] = htlc self.print_error("receive_htlc") self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1) @@ -309,11 +320,11 @@ class Channel(PrintError): for sub in (LOCAL, REMOTE): log = self.log[sub] yield (sub, deepcopy(log)) - for htlc_id in log['fails']: - log['adds'].pop(htlc_id) - log['fails'].clear() + for htlc_id in log.fails: + log.adds.pop(htlc_id) + log.fails.clear() - self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys() + self.log[subject].locked_in.update(self.log[subject].adds.keys()) def receive_new_commitment(self, sig, htlc_sigs): """ @@ -474,11 +485,11 @@ class Channel(PrintError): """ old_amount = htlcsum(self.htlcs(subject, False)) - for htlc_id in self.log[subject]['settles']: - adds = self.log[subject]['adds'] + for htlc_id in self.log[subject].settles: + adds = self.log[subject].adds htlc = adds.pop(htlc_id) self.settled[subject].append(htlc.amount_msat) - self.log[subject]['settles'].clear() + self.log[subject].settles.clear() return old_amount - htlcsum(self.htlcs(subject, False)) @@ -533,7 +544,7 @@ class Channel(PrintError): pending outgoing HTLCs, is used in the UI. """ return self.balance(subject)\ - - htlcsum(self.log[subject]['adds'].values()) + - htlcsum(self.log[subject].adds.values()) def available_to_spend(self, subject): """ @@ -541,7 +552,7 @@ class Channel(PrintError): not be used in the UI cause it fluctuates (commit fee) """ return self.balance_minus_outgoing_htlcs(subject)\ - - htlcsum(self.log[subject]['adds'].values())\ + - htlcsum(self.log[subject].adds.values())\ - self.config[-subject].reserve_sat * 1000\ - calc_onchain_fees( # TODO should we include a potential new htlc, when we are called from receive_htlc? @@ -601,10 +612,10 @@ class Channel(PrintError): """ update_log = self.log[subject] res = [] - for htlc in update_log['adds'].values(): - locked_in = htlc.htlc_id in update_log['locked_in'] - settled = htlc.htlc_id in update_log['settles'] - failed = htlc.htlc_id in update_log['fails'] + for htlc in update_log.adds.values(): + locked_in = htlc.htlc_id in update_log.locked_in + settled = htlc.htlc_id in update_log.settles + failed = htlc.htlc_id in update_log.fails if not locked_in: continue if only_pending == (settled or failed): @@ -617,25 +628,33 @@ class Channel(PrintError): SettleHTLC attempts to settle an existing outstanding received HTLC. """ self.print_error("settle_htlc") - htlc = self.log[REMOTE]['adds'][htlc_id] + log = self.log[REMOTE] + htlc = log.adds[htlc_id] assert htlc.payment_hash == sha256(preimage) - self.log[REMOTE]['settles'].append(htlc_id) + assert htlc_id not in log.settles + log.settles.add(htlc_id) # not saving preimage because it's already saved in LNWorker.invoices def receive_htlc_settle(self, preimage, htlc_id): self.print_error("receive_htlc_settle") - htlc = self.log[LOCAL]['adds'][htlc_id] + log = self.log[LOCAL] + htlc = log.adds[htlc_id] assert htlc.payment_hash == sha256(preimage) - self.log[LOCAL]['settles'].append(htlc_id) + assert htlc_id not in log.settles + log.settles.add(htlc_id) # we don't save the preimage because we don't need to forward it anyway def fail_htlc(self, htlc_id): self.print_error("fail_htlc") - self.log[REMOTE]['fails'].append(htlc_id) + log = self.log[REMOTE] + assert htlc_id not in log.fails + log.fails.add(htlc_id) def receive_fail_htlc(self, htlc_id): self.print_error("receive_fail_htlc") - self.log[LOCAL]['fails'].append(htlc_id) + log = self.log[LOCAL] + assert htlc_id not in log.fails + log.fails.add(htlc_id) @property def current_height(self): @@ -666,8 +685,8 @@ class Channel(PrintError): removed = [] htlcs = [] log = self.log[subject] - for htlc_id, i in log['adds'].items(): - locked_in = htlc_id in log['locked_in'] + for i in log.adds.values(): + locked_in = i.htlc_id in log.locked_in if locked_in: htlcs.append(i._asdict()) else: @@ -710,18 +729,26 @@ class Channel(PrintError): def serialize(self): namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()} - serialized_channel = {k: namedtuples_to_dict(v) if isinstance(v, tuple) else v for k, v in self.to_save().items()} + serialized_channel = {} + to_save_ref = self.to_save() + for k, v in to_save_ref.items(): + if isinstance(v, tuple): + serialized_channel[k] = namedtuples_to_dict(v) + else: + serialized_channel[k] = v dumped = ChannelJsonEncoder().encode(serialized_channel) roundtripped = json.loads(dumped) reconstructed = Channel(roundtripped) - if reconstructed.to_save() != self.to_save(): - from pprint import pformat + to_save_new = reconstructed.to_save() + if to_save_new != to_save_ref: + from pprint import PrettyPrinter + pp = PrettyPrinter(indent=168) try: from deepdiff import DeepDiff except ImportError: - raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(reconstructed.to_save()) + "\n" + pformat(self.to_save())) + raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new)) else: - raise Exception("Channels did not roundtrip serialization without changes:\n" + pformat(DeepDiff(reconstructed.to_save(), self.to_save()))) + raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new))) return roundtripped def __str__(self): diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py @@ -183,7 +183,7 @@ class TestChannel(unittest.TestCase): self.bob_pending_remote_balance = after - self.htlc = self.bob_channel.log[lnutil.REMOTE]['adds'][0] + self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0] def test_SimpleAddSettleWorkflow(self): alice_channel, bob_channel = self.alice_channel, self.bob_channel @@ -217,6 +217,10 @@ class TestChannel(unittest.TestCase): # forward since she's sending an outgoing HTLC. alice_channel.receive_revocation(bobRevocation) + # test serializing with locked_in htlc + self.assertEqual(len(alice_channel.to_save()['local_log']), 1) + alice_channel.serialize() + # Alice then processes bob's signature, and since she just received # the revocation, she expect this signature to cover everything up to # the point where she sent her signature, including the HTLC.