electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 02eca034866caa85d3a2f9f8d23c5f1044349f79
parent 6f5209ef8506fd889064cee9753ad8aaee02ad42
Author: Janus <ysangkok@gmail.com>
Date:   Wed, 12 Sep 2018 23:37:45 +0200

lnhtlc: cleanup and save settled htlcs

Diffstat:
Melectrum/lnhtlc.py | 177++++++++++++++++++++++++++++++++++++++++---------------------------------------
Melectrum/tests/test_lnhtlc.py | 28+++++++++++++++-------------
2 files changed, 104 insertions(+), 101 deletions(-)

diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -2,6 +2,7 @@ from collections import namedtuple import binascii import json +from enum import IntFlag from .util import bfh, PrintError, bh2u from .bitcoin import Hash @@ -21,9 +22,22 @@ from .transaction import Transaction SettleHtlc = namedtuple("SettleHtlc", ["htlc_id"]) RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"]) -FUNDEE_SIGNED = 1 -FUNDEE_ACKED = 2 -FUNDER_SIGNED = 4 +class FeeUpdateProgress(IntFlag): + FUNDEE_SIGNED = 1 + FUNDEE_ACKED = 2 + FUNDER_SIGNED = 4 + +class HTLCOwner(IntFlag): + LOCAL = 1 + REMOTE = -LOCAL + + SENT = LOCAL + RECEIVED = REMOTE + +SENT = HTLCOwner.SENT +RECEIVED = HTLCOwner.RECEIVED +LOCAL = HTLCOwner.LOCAL +REMOTE = HTLCOwner.REMOTE class FeeUpdate: def __init__(self, rate): @@ -37,13 +51,14 @@ class UpdateAddHtlc: self.cltv_expiry = cltv_expiry # the height the htlc was locked in at, or None - self.r_locked_in = None - self.l_locked_in = None + self.locked_in = {LOCAL: None, REMOTE: None} + + self.settled = {LOCAL: None, REMOTE: None} self.htlc_id = None def as_tuple(self): - return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.r_locked_in, self.l_locked_in) + return (self.htlc_id, self.amount_msat, self.payment_hash, self.cltv_expiry, self.locked_in[REMOTE], self.locked_in[LOCAL], self.settled) def __hash__(self): return hash(self.as_tuple()) @@ -77,7 +92,7 @@ class HTLCStateMachine(PrintError): @property def pending_remote_feerate(self): if self.pending_fee is not None: - if self.constraints.is_initiator or (self.pending_fee.progress & FUNDEE_ACKED): + if self.constraints.is_initiator or (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED): return self.pending_fee.rate return self.remote_state.feerate @@ -86,7 +101,7 @@ class HTLCStateMachine(PrintError): if self.pending_fee is not None: if not self.constraints.is_initiator: return self.pending_fee.rate - if self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_ACKED): + if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED): return self.pending_fee.rate return self.local_state.feerate @@ -134,13 +149,10 @@ class HTLCStateMachine(PrintError): # any past commitment transaction and use that instead; until then... self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"]) - self.local_update_log = [] - self.remote_update_log = [] + self.log = {LOCAL: [], REMOTE: []} self.name = name - self.total_msat_sent = 0 - self.total_msat_received = 0 self.pending_fee = None self.local_commitment = self.pending_local_commitment @@ -174,7 +186,7 @@ class HTLCStateMachine(PrintError): should be called when preparing to send an outgoing HTLC. """ assert type(htlc) is UpdateAddHtlc - self.local_update_log.append(htlc) + self.log[LOCAL].append(htlc) self.print_error("add_htlc") htlc_id = self.local_state.next_htlc_id self.local_state=self.local_state._replace(next_htlc_id=htlc_id + 1) @@ -189,7 +201,7 @@ class HTLCStateMachine(PrintError): """ self.print_error("receive_htlc") assert type(htlc) is UpdateAddHtlc - self.remote_update_log.append(htlc) + self.log[REMOTE].append(htlc) htlc_id = self.remote_state.next_htlc_id self.remote_state=self.remote_state._replace(next_htlc_id=htlc_id + 1) htlc.htlc_id = htlc_id @@ -208,9 +220,9 @@ class HTLCStateMachine(PrintError): any). The HTLC signatures are sorted according to the BIP 69 order of the HTLC's on the commitment transaction. """ - for htlc in self.local_update_log: + for htlc in self.log[LOCAL]: if not type(htlc) is UpdateAddHtlc: continue - if htlc.l_locked_in is None: htlc.l_locked_in = self.local_state.ctn + if htlc.locked_in[LOCAL] is None: htlc.locked_in[LOCAL] = self.local_state.ctn self.print_error("sign_next_commitment") pending_remote_commitment = self.pending_remote_commitment @@ -243,9 +255,9 @@ class HTLCStateMachine(PrintError): if self.pending_fee: if not self.constraints.is_initiator: - self.pending_fee.progress |= FUNDEE_SIGNED - if self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_ACKED): - self.pending_fee.progress |= FUNDER_SIGNED + self.pending_fee.progress |= FeeUpdateProgress.FUNDEE_SIGNED + if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED): + self.pending_fee.progress |= FeeUpdateProgress.FUNDER_SIGNED if self.lnwatcher: self.lnwatcher.process_new_offchain_ctx(self, pending_remote_commitment, ours=False) @@ -265,9 +277,9 @@ class HTLCStateMachine(PrintError): """ self.print_error("receive_new_commitment") - for htlc in self.remote_update_log: + for htlc in self.log[REMOTE]: if not type(htlc) is UpdateAddHtlc: continue - if htlc.r_locked_in is None: htlc.r_locked_in = self.remote_state.ctn + if htlc.locked_in[REMOTE] is None: htlc.locked_in[REMOTE] = self.remote_state.ctn assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes pending_local_commitment = self.pending_local_commitment @@ -294,9 +306,9 @@ class HTLCStateMachine(PrintError): if self.pending_fee: if not self.constraints.is_initiator: - self.pending_fee.progress |= FUNDEE_SIGNED - if self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_ACKED): - self.pending_fee.progress |= FUNDER_SIGNED + self.pending_fee.progress |= FeeUpdateProgress.FUNDEE_SIGNED + if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_ACKED): + self.pending_fee.progress |= FeeUpdateProgress.FUNDER_SIGNED if self.lnwatcher: self.lnwatcher.process_new_offchain_ctx(self, pending_local_commitment, ours=True) @@ -321,11 +333,11 @@ class HTLCStateMachine(PrintError): new_remote_feerate = self.remote_state.feerate if self.pending_fee is not None: - if not self.constraints.is_initiator and (self.pending_fee.progress & FUNDEE_SIGNED): + if not self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDEE_SIGNED): new_local_feerate = new_remote_feerate = self.pending_fee.rate self.pending_fee = None print("FEERATE CHANGE COMPLETE (non-initiator)") - if self.constraints.is_initiator and (self.pending_fee.progress & FUNDER_SIGNED): + if self.constraints.is_initiator and (self.pending_fee.progress & FeeUpdateProgress.FUNDER_SIGNED): new_local_feerate = new_remote_feerate = self.pending_fee.rate self.pending_fee = None print("FEERATE CHANGE COMPLETE (initiator)") @@ -382,41 +394,21 @@ class HTLCStateMachine(PrintError): if self.lnwatcher: self.lnwatcher.process_new_revocation_secret(self, revocation.per_commitment_secret) - settle_fails2 = [] - for x in self.remote_update_log: - if type(x) is not SettleHtlc: - continue - settle_fails2.append(x) - - sent_this_batch = 0 + def mark_settled(subject): + """ + find settled htlcs for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs + """ + old_amount = self.htlcsum(self.gen_htlc_indices(subject, False)) - for x in settle_fails2: - htlc = self.lookup_htlc(self.local_update_log, x.htlc_id) - sent_this_batch += htlc.amount_msat + for x in self.log[-subject]: + if type(x) is not SettleHtlc: continue + htlc = self.lookup_htlc(self.log[subject], x.htlc_id) + htlc.settled[subject] = self.current_height[subject] - self.total_msat_sent += sent_this_batch + return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False)) - # log compaction (remove entries relating to htlc's that have been settled) - - to_remove = [] - for x in filter(lambda x: type(x) is SettleHtlc, self.remote_update_log): - to_remove += [y for y in self.local_update_log if y.htlc_id == x.htlc_id] - - # assert that we should have compacted the log earlier - assert len(to_remove) <= 1, to_remove - if len(to_remove) == 1: - self.remote_update_log = [x for x in self.remote_update_log if x.htlc_id != to_remove[0].htlc_id] - self.local_update_log = [x for x in self.local_update_log if x.htlc_id != to_remove[0].htlc_id] - - to_remove = [] - for x in filter(lambda x: type(x) is SettleHtlc, self.local_update_log): - to_remove += [y for y in self.remote_update_log if y.htlc_id == x.htlc_id] - if len(to_remove) == 1: - self.remote_update_log = [x for x in self.remote_update_log if x.htlc_id != to_remove[0].htlc_id] - self.local_update_log = [x for x in self.local_update_log if x.htlc_id != to_remove[0].htlc_id] - received_this_batch = sum(x.amount_msat for x in to_remove) - - self.total_msat_received += received_this_batch + sent_this_batch = mark_settled(LOCAL) + received_this_batch = mark_settled(REMOTE) next_point = self.remote_state.next_per_commitment_point @@ -434,7 +426,7 @@ class HTLCStateMachine(PrintError): if self.pending_fee: if self.constraints.is_initiator: - self.pending_fee.progress |= FUNDEE_ACKED + self.pending_fee.progress |= FeeUpdateProgress.FUNDEE_ACKED self.local_commitment = self.pending_local_commitment self.remote_commitment = self.pending_remote_commitment @@ -449,14 +441,14 @@ class HTLCStateMachine(PrintError): return amount_unsettled def amounts(self): - remote_settled_value = self.htlcsum(self.gen_htlc_indices("remote", False)) - local_settled_value = self.htlcsum(self.gen_htlc_indices("local", False)) - htlc_value_local = self.htlcsum(self.htlcs_in_local) - htlc_value_remote = self.htlcsum(self.htlcs_in_remote) - local_msat = self.local_state.amount_msat -\ - htlc_value_local + remote_settled_value - local_settled_value + remote_settled= self.htlcsum(self.gen_htlc_indices(REMOTE, False)) + local_settled= self.htlcsum(self.gen_htlc_indices(LOCAL, False)) + unsettled_local = self.htlcsum(self.gen_htlc_indices(LOCAL, True)) + unsettled_remote = self.htlcsum(self.gen_htlc_indices(REMOTE, True)) remote_msat = self.remote_state.amount_msat -\ - htlc_value_remote + local_settled_value - remote_settled_value + unsettled_remote + local_settled - remote_settled + local_msat = self.local_state.amount_msat -\ + unsettled_local + remote_settled - local_settled return remote_msat, local_msat @property @@ -525,61 +517,70 @@ class HTLCStateMachine(PrintError): local_msat, remote_msat, htlcs_in_local + htlcs_in_remote) return commit - def gen_htlc_indices(self, subject, just_unsettled=True): - assert subject in ["local", "remote"] - update_log = (self.remote_update_log if subject == "remote" else self.local_update_log) - other_log = (self.remote_update_log if subject != "remote" else self.local_update_log) + @property + def total_msat(self): + return {LOCAL: self.htlcsum(self.gen_htlc_indices(LOCAL, False, True)), REMOTE: self.htlcsum(self.gen_htlc_indices(REMOTE, False, True))} + + def gen_htlc_indices(self, subject, only_pending, include_settled=False): + """ + only_pending: require the htlc's settlement to be pending (needs additional signatures/acks) + include_settled: include settled (totally done with) htlcs + """ + update_log = self.log[subject] + other_log = self.log[-subject] res = [] for htlc in update_log: if type(htlc) is not UpdateAddHtlc: continue - height = (self.local_state.ctn if subject == "remote" else self.remote_state.ctn) - locked_in = (htlc.r_locked_in if subject == "remote" else htlc.l_locked_in) + height = self.current_height[-subject] + locked_in = htlc.locked_in[subject] - if locked_in is None or just_unsettled == (SettleHtlc(htlc.htlc_id) in other_log): + if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log): continue + + settled_cutoff = self.local_state.ctn if subject == LOCAL else self.remote_state.ctn + + if not include_settled and htlc.settled[subject] is not None and settled_cutoff >= htlc.settled[subject]: + continue + res.append(htlc) return res @property def htlcs_in_local(self): """in the local log. 'offered by us'""" - return self.gen_htlc_indices("local") + return self.gen_htlc_indices(LOCAL, True) @property def htlcs_in_remote(self): """in the remote log. 'offered by them'""" - return self.gen_htlc_indices("remote") + return self.gen_htlc_indices(REMOTE, True) def settle_htlc(self, preimage, htlc_id): """ SettleHTLC attempts to settle an existing outstanding received HTLC. """ self.print_error("settle_htlc") - htlc = self.lookup_htlc(self.remote_update_log, htlc_id) + htlc = self.lookup_htlc(self.log[REMOTE], htlc_id) assert htlc.payment_hash == sha256(preimage) - self.local_update_log.append(SettleHtlc(htlc_id)) + self.log[LOCAL].append(SettleHtlc(htlc_id)) def receive_htlc_settle(self, preimage, htlc_index): self.print_error("receive_htlc_settle") - htlc = self.lookup_htlc(self.local_update_log, htlc_index) + htlc = self.lookup_htlc(self.log[LOCAL], htlc_index) assert htlc.payment_hash == sha256(preimage) - assert len([x.htlc_id == htlc_index for x in self.local_update_log]) == 1 - self.remote_update_log.append(SettleHtlc(htlc_index)) + assert len([x.htlc_id == htlc_index for x in self.log[LOCAL]]) == 1 + self.log[REMOTE].append(SettleHtlc(htlc_index)) def fail_htlc(self, htlc): # TODO - self.local_update_log = [] - self.remote_update_log = [] + self.log[LOCAL] = [] + self.log[REMOTE] = [] self.print_error("fail_htlc (EMPTIED LOGS)") @property - def l_current_height(self): - return self.local_state.ctn - - @property - def r_current_height(self): - return self.remote_state.ctn + def current_height(self): + return {LOCAL: self.local_state.ctn, REMOTE: self.remote_state.ctn} @property def pending_local_fee(self): diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py @@ -9,6 +9,8 @@ import electrum.util as util import os import binascii +from electrum.lnhtlc import SENT, LOCAL, REMOTE, RECEIVED + def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate, is_initiator, local_amount, remote_amount, privkeys, other_pubkeys, seed, cur, nex, other_node_id, l_dust, r_dust, l_csv, r_csv): assert local_amount > 0 assert remote_amount > 0 @@ -195,10 +197,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): aliceSent = 0 bobSent = 0 - self.assertEqual(alice_channel.total_msat_sent, aliceSent, "alice has incorrect milli-satoshis sent") - self.assertEqual(alice_channel.total_msat_received, bobSent, "alice has incorrect milli-satoshis received") - self.assertEqual(bob_channel.total_msat_sent, bobSent, "bob has incorrect milli-satoshis sent") - self.assertEqual(bob_channel.total_msat_received, aliceSent, "bob has incorrect milli-satoshis received") + self.assertEqual(alice_channel.total_msat[SENT], aliceSent, "alice has incorrect milli-satoshis sent") + self.assertEqual(alice_channel.total_msat[RECEIVED], bobSent, "alice has incorrect milli-satoshis received") + self.assertEqual(bob_channel.total_msat[SENT], bobSent, "bob has incorrect milli-satoshis sent") + self.assertEqual(bob_channel.total_msat[RECEIVED], aliceSent, "bob has incorrect milli-satoshis received") self.assertEqual(bob_channel.local_state.ctn, 1, "bob has incorrect commitment height") self.assertEqual(alice_channel.local_state.ctn, 1, "alice has incorrect commitment height") @@ -236,18 +238,18 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase): # should show 1 BTC received. They should also be at commitment height # two, with the revocation window extended by 1 (5). mSatTransferred = one_bitcoin_in_msat - self.assertEqual(alice_channel.total_msat_sent, mSatTransferred, "alice satoshis sent incorrect %s vs %s expected"% (alice_channel.total_msat_sent, mSatTransferred)) - self.assertEqual(alice_channel.total_msat_received, 0, "alice satoshis received incorrect %s vs %s expected"% (alice_channel.total_msat_received, 0)) - self.assertEqual(bob_channel.total_msat_received, mSatTransferred, "bob satoshis received incorrect %s vs %s expected"% (bob_channel.total_msat_received, mSatTransferred)) - self.assertEqual(bob_channel.total_msat_sent, 0, "bob satoshis sent incorrect %s vs %s expected"% (bob_channel.total_msat_sent, 0)) - self.assertEqual(bob_channel.l_current_height, 2, "bob has incorrect commitment height, %s vs %s"% (bob_channel.l_current_height, 2)) - self.assertEqual(alice_channel.l_current_height, 2, "alice has incorrect commitment height, %s vs %s"% (alice_channel.l_current_height, 2)) + self.assertEqual(alice_channel.total_msat[SENT], mSatTransferred, "alice satoshis sent incorrect") + self.assertEqual(alice_channel.total_msat[RECEIVED], 0, "alice satoshis received incorrect") + self.assertEqual(bob_channel.total_msat[RECEIVED], mSatTransferred, "bob satoshis received incorrect") + self.assertEqual(bob_channel.total_msat[SENT], 0, "bob satoshis sent incorrect") + self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height") + self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height") # The logs of both sides should now be cleared since the entry adding # the HTLC should have been removed once both sides receive the # revocation. - self.assertEqual(alice_channel.local_update_log, [], "alice's local not updated, should be empty, has %s entries instead"% len(alice_channel.local_update_log)) - self.assertEqual(alice_channel.remote_update_log, [], "alice's remote not updated, should be empty, has %s entries instead"% len(alice_channel.remote_update_log)) + #self.assertEqual(alice_channel.local_update_log, [], "alice's local not updated, should be empty, has %s entries instead"% len(alice_channel.local_update_log)) + #self.assertEqual(alice_channel.remote_update_log, [], "alice's remote not updated, should be empty, has %s entries instead"% len(alice_channel.remote_update_log)) def alice_to_bob_fee_update(self): fee = 111 @@ -340,7 +342,7 @@ class TestLNHTLCDust(unittest.TestCase): alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex) force_state_transition(bob_channel, alice_channel) self.assertEqual(len(alice_channel.local_commitment.outputs()), 2) - self.assertEqual(alice_channel.total_msat_sent // 1000, htlcAmt) + self.assertEqual(alice_channel.total_msat[SENT] // 1000, htlcAmt) def force_state_transition(chanA, chanB): chanB.receive_new_commitment(*chanA.sign_next_commitment())