electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit e56e8495059d526b6151839dcf2a7caddb5a8d18
parent ef88bb1c286db6c7ca29754a750c499cb90dbb82
Author: Janus <ysangkok@gmail.com>
Date:   Mon, 21 Jan 2019 21:27:27 +0100

lnchan refactor

- replace undoing logic with new HTLCManager class
- separate SENT/RECEIVED
- move UpdateAddHtlc to lnutil

Diffstat:
Melectrum/gui/qt/channel_details.py | 22+++++++++++-----------
Melectrum/gui/qt/channels_list.py | 5+++--
Melectrum/lnbase.py | 20++++++++++----------
Melectrum/lnchan.py | 465++++++++++++++++++++++++++++++-------------------------------------------------
Aelectrum/lnhtlc.py | 159+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Melectrum/lnsweep.py | 20+++++++++-----------
Melectrum/lnutil.py | 42+++++++++++++++++++++++++++++++-----------
Melectrum/lnworker.py | 30++++++++++++++++++------------
Melectrum/tests/test_lnchan.py | 231++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------
Aelectrum/tests/test_lnhtlc.py | 95+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Melectrum/tests/test_lnutil.py | 13++++++-------
11 files changed, 705 insertions(+), 397 deletions(-)

diff --git a/electrum/gui/qt/channel_details.py b/electrum/gui/qt/channel_details.py @@ -5,9 +5,9 @@ import PyQt5.QtWidgets as QtWidgets import PyQt5.QtCore as QtCore from electrum.i18n import _ -from electrum.lnchan import UpdateAddHtlc, HTLCOwner from electrum.util import bh2u, format_time -from electrum.lnutil import format_short_channel_id, SENT, RECEIVED +from electrum.lnutil import format_short_channel_id, LOCAL, REMOTE, UpdateAddHtlc, Direction +from electrum.lnchan import htlcsum from electrum.lnaddr import LnAddr, lndecode from electrum.bitcoin import COIN @@ -30,8 +30,8 @@ class LinkedLabel(QtWidgets.QLabel): self.linkActivated.connect(on_clicked) class ChannelDetailsDialog(QtWidgets.QDialog): - def make_htlc_item(self, i: UpdateAddHtlc, direction: HTLCOwner) -> HTLCItem: - it = HTLCItem(_('Sent HTLC with ID {}' if SENT == direction else 'Received HTLC with ID {}').format(i.htlc_id)) + def make_htlc_item(self, i: UpdateAddHtlc, direction: Direction) -> HTLCItem: + it = HTLCItem(_('Sent HTLC with ID {}' if Direction.SENT == direction else 'Received HTLC with ID {}').format(i.htlc_id)) it.appendRow([HTLCItem(_('Amount')),HTLCItem(self.format(i.amount_msat))]) it.appendRow([HTLCItem(_('CLTV expiry')),HTLCItem(str(i.cltv_expiry))]) it.appendRow([HTLCItem(_('Payment hash')),HTLCItem(bh2u(i.payment_hash))]) @@ -45,7 +45,7 @@ class ChannelDetailsDialog(QtWidgets.QDialog): invoice.appendRow([HTLCItem(_('Date')), HTLCItem(format_time(lnaddr.date))]) it.appendRow([invoice]) - def make_inflight(self, lnaddr, i: UpdateAddHtlc, direction: HTLCOwner) -> HTLCItem: + def make_inflight(self, lnaddr, i: UpdateAddHtlc, direction: Direction) -> HTLCItem: it = self.make_htlc_item(i, direction) self.append_lnaddr(it, lnaddr) return it @@ -99,23 +99,23 @@ class ChannelDetailsDialog(QtWidgets.QDialog): dest_mapping = self.keyname_rows[to] dest_mapping[payment_hash] = len(dest_mapping) - ln_payment_completed = QtCore.pyqtSignal(str, float, HTLCOwner, UpdateAddHtlc, bytes, bytes) - htlc_added = QtCore.pyqtSignal(str, UpdateAddHtlc, LnAddr, HTLCOwner) + ln_payment_completed = QtCore.pyqtSignal(str, float, Direction, UpdateAddHtlc, bytes, bytes) + htlc_added = QtCore.pyqtSignal(str, UpdateAddHtlc, LnAddr, Direction) - @QtCore.pyqtSlot(str, UpdateAddHtlc, LnAddr, HTLCOwner) + @QtCore.pyqtSlot(str, UpdateAddHtlc, LnAddr, Direction) def do_htlc_added(self, evtname, htlc, lnaddr, direction): mapping = self.keyname_rows['inflight'] mapping[htlc.payment_hash] = len(mapping) self.folders['inflight'].appendRow(self.make_inflight(lnaddr, htlc, direction)) - @QtCore.pyqtSlot(str, float, HTLCOwner, UpdateAddHtlc, bytes, bytes) + @QtCore.pyqtSlot(str, float, Direction, UpdateAddHtlc, bytes, bytes) def do_ln_payment_completed(self, evtname, date, direction, htlc, preimage, chan_id): self.move('inflight', 'settled', htlc.payment_hash) self.update_sent_received() def update_sent_received(self): - self.sent_label.setText(str(sum(self.chan.settled[SENT]))) - self.received_label.setText(str(sum(self.chan.settled[RECEIVED]))) + self.sent_label.setText(str(htlcsum(self.hm.settled_htlcs_by(LOCAL)))) + self.received_label.setText(str(htlcsum(self.hm.settled_htlcs_by(REMOTE)))) @QtCore.pyqtSlot(str) def show_tx(self, link_text: str): diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py @@ -30,8 +30,9 @@ class ChannelsList(MyTreeView): for subject in (REMOTE, LOCAL): bal_minus_htlcs = chan.balance_minus_outgoing_htlcs(subject)//1000 label = self.parent.format_amount(bal_minus_htlcs) - bal_other = chan.balance(-subject)//1000 - bal_minus_htlcs_other = chan.balance_minus_outgoing_htlcs(-subject)//1000 + other = subject.inverted() + bal_other = chan.balance(other)//1000 + bal_minus_htlcs_other = chan.balance_minus_outgoing_htlcs(other)//1000 if bal_other != bal_minus_htlcs_other: label += ' (+' + self.parent.format_amount(bal_other - bal_minus_htlcs_other) + ')' labels[subject] = label diff --git a/electrum/lnbase.py b/electrum/lnbase.py @@ -25,8 +25,8 @@ from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabl from .transaction import Transaction, TxOutput from .lnonion import (new_onion_packet, decode_onion_error, OnionFailureCode, calc_hops_data_for_payment, process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage) -from .lnchan import Channel, RevokeAndAck, htlcsum, UpdateAddHtlc -from .lnutil import (Outpoint, LocalConfig, RECEIVED, +from .lnchan import Channel, RevokeAndAck, htlcsum +from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore, funding_output_script, get_per_commitment_secret_from_seed, secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures, @@ -397,20 +397,20 @@ class Peer(PrintError): htlc_basepoint=keypair_generator(LnKeyFamily.HTLC_BASE), delayed_basepoint=keypair_generator(LnKeyFamily.DELAY_BASE), revocation_basepoint=keypair_generator(LnKeyFamily.REVOCATION_BASE), - to_self_delay=143, + to_self_delay=9, dust_limit_sat=546, max_htlc_value_in_flight_msat=0xffffffffffffffff, max_accepted_htlcs=5, initial_msat=initial_msat, ctn=-1, next_htlc_id=0, - amount_msat=initial_msat, reserve_sat=546, per_commitment_secret_seed=keypair_generator(LnKeyFamily.REVOCATION_ROOT).privkey, funding_locked_received=False, was_announced=False, current_commitment_signature=None, current_htlc_signatures=[], + got_sig_for_next=False, ) return local_config @@ -472,7 +472,6 @@ class Peer(PrintError): max_accepted_htlcs=int.from_bytes(payload["max_accepted_htlcs"], 'big'), initial_msat=push_msat, ctn = -1, - amount_msat=push_msat, next_htlc_id = 0, reserve_sat = remote_reserve_sat, @@ -517,9 +516,11 @@ class Peer(PrintError): # broadcast funding tx await self.network.broadcast_transaction(funding_tx) chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE) - chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0) - chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) + chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_per_commitment_point, next_per_commitment_point=None) + chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig, got_sig_for_next=False) chan.set_state('OPENING') + chan.set_remote_commitment() + chan.set_local_commitment(chan.current_commitment(LOCAL)) return chan async def on_open_channel(self, payload): @@ -579,7 +580,6 @@ class Peer(PrintError): max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'), initial_msat=remote_balance_sat, ctn = -1, - amount_msat=remote_balance_sat, next_htlc_id = 0, reserve_sat = remote_reserve_sat, @@ -605,7 +605,7 @@ class Peer(PrintError): ) chan.set_state('OPENING') chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE) - chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0) + chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0, current_per_commitment_point=payload['first_per_commitment_point'], next_per_commitment_point=None) chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) self.lnworker.save_channel(chan) self.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) @@ -732,7 +732,7 @@ class Peer(PrintError): if not chan.config[LOCAL].funding_locked_received: our_next_point = chan.config[REMOTE].next_per_commitment_point their_next_point = payload["next_per_commitment_point"] - new_remote_state = chan.config[REMOTE]._replace(next_per_commitment_point=their_next_point, current_per_commitment_point=our_next_point) + new_remote_state = chan.config[REMOTE]._replace(next_per_commitment_point=their_next_point) new_local_state = chan.config[LOCAL]._replace(funding_locked_received = True) chan.config[REMOTE]=new_remote_state chan.config[LOCAL]=new_local_state diff --git a/electrum/lnchan.py b/electrum/lnchan.py @@ -27,24 +27,25 @@ import binascii import json from enum import Enum, auto from typing import Optional, Dict, List, Tuple, NamedTuple, Set, Callable, Iterable, Sequence -from copy import deepcopy +from . import ecc from .util import bfh, PrintError, bh2u from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS from .bitcoin import redeem_script_to_address from .crypto import sha256, sha256d -from . import ecc -from .lnutil import Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore -from .lnutil import get_per_commitment_secret_from_seed -from .lnutil import secret_to_pubkey, derive_privkey, derive_pubkey, derive_blinded_pubkey -from .lnutil import sign_and_get_sig_string -from .lnutil import make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc -from .lnutil import HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT -from .lnutil import funding_output_script, LOCAL, REMOTE, HTLCOwner, make_closing_tx, make_commitment_outputs -from .lnutil import ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script +from .simple_config import get_config from .transaction import Transaction + +from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints, + get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx, + sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey, + make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc, + HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc, + funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs, + ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script) from .lnsweep import create_sweeptxs_for_their_just_revoked_ctx from .lnsweep import create_sweeptxs_for_our_latest_ctx, create_sweeptxs_for_their_latest_ctx +from .lnhtlc import HTLCManager class ChannelJsonEncoder(json.JSONEncoder): @@ -83,22 +84,6 @@ class FeeUpdate(defaultdict): return self.rate # implicit return None -class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])): - """ - This whole class body is so that if you pass a hex-string as payment_hash, - it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings. - """ - __slots__ = () - def __new__(cls, *args, **kwargs): - if len(args) > 0: - args = list(args) - if type(args[1]) is str: - args[1] = bfh(args[1]) - return super().__new__(cls, *args) - if type(kwargs['payment_hash']) is str: - kwargs['payment_hash'] = bfh(kwargs['payment_hash']) - return super().__new__(cls, **kwargs) - def decodeAll(d, local): for k, v in d.items(): if k == 'revocation_store': @@ -124,20 +109,6 @@ def str_bytes_dict_from_save(x): def str_bytes_dict_to_save(x): return {str(k): bh2u(v) for k, v in x.items()} -class HtlcChanges(NamedTuple): - # ints are htlc ids - adds: Dict[int, UpdateAddHtlc] - settles: Set[int] - fails: Set[int] - locked_in: Set[int] - - @staticmethod - def new(): - """ - Since we can't use default arguments for these types (they would be shared among instances) - """ - return HtlcChanges({}, set(), set(), set()) - class Channel(PrintError): def diagnostic_name(self): if self.name: @@ -147,7 +118,7 @@ class Channel(PrintError): except: return super().diagnostic_name() - def __init__(self, state, sweep_address = None, name = None, payment_completed : Optional[Callable[[HTLCOwner, UpdateAddHtlc, bytes], None]] = None): + def __init__(self, state, sweep_address = None, name = None, payment_completed : Optional[Callable[[Direction, UpdateAddHtlc, bytes], None]] = None): self.preimages = {} if not payment_completed: payment_completed = lambda this, x, y, z: None @@ -179,13 +150,9 @@ class Channel(PrintError): # we should not persist txns in this format. we should persist htlcs, and be able to derive # any past commitment transaction and use that instead; until then... self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"]) + self.remote_commitment_to_be_revoked.deserialize(True) - self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()} - for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]: - if strname not in state: continue - for y in state[strname]: - htlc = UpdateAddHtlc(**y) - self.log[subject].adds[htlc.htlc_id] = htlc + self.hm = HTLCManager(state.get('log')) self.name = name @@ -200,23 +167,18 @@ class Channel(PrintError): self.lnwatcher = None - self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])} - - for sub in (LOCAL, REMOTE): - self.log[sub].locked_in.update(self.log[sub].adds.keys()) - - self.set_local_commitment(self.current_commitment(LOCAL)) - self.set_remote_commitment(self.current_commitment(REMOTE)) + self.local_commitment = None + self.remote_commitment = None def set_local_commitment(self, ctx): + ctn = extract_ctn_from_tx_and_chan(ctx, self) + assert self.signature_fits(ctx), (self.log[LOCAL]) self.local_commitment = ctx if self.sweep_address is not None: self.local_sweeptxs = create_sweeptxs_for_our_latest_ctx(self, self.local_commitment, self.sweep_address) - self.assert_signature_fits(ctx) - - def set_remote_commitment(self, ctx): - self.remote_commitment = ctx + def set_remote_commitment(self): + self.remote_commitment = self.current_commitment(REMOTE) if self.sweep_address is not None: self.remote_sweeptxs = create_sweeptxs_for_their_latest_ctx(self, self.remote_commitment, self.sweep_address) @@ -233,9 +195,9 @@ class Channel(PrintError): raise PaymentFailure('Channel not open') if self.available_to_spend(LOCAL) < amount_msat: raise PaymentFailure(f'Not enough local balance. Have: {self.available_to_spend(LOCAL)}, Need: {amount_msat}') - if len(self.htlcs(LOCAL, only_pending=True)) + 1 > self.config[REMOTE].max_accepted_htlcs: + if len(self.hm.htlcs(LOCAL)) + 1 > self.config[REMOTE].max_accepted_htlcs: raise PaymentFailure('Too many HTLCs already in channel') - current_htlc_sum = htlcsum(self.htlcs(LOCAL, only_pending=True)) + current_htlc_sum = htlcsum(self.hm.htlcs_by_direction(LOCAL, SENT)) + htlcsum(self.hm.htlcs_by_direction(LOCAL, RECEIVED)) if current_htlc_sum + amount_msat > self.config[REMOTE].max_htlc_value_in_flight_msat: raise PaymentFailure(f'HTLC value sum (sum of pending htlcs: {current_htlc_sum/1000} sat plus new htlc: {amount_msat/1000} sat) would exceed max allowed: {self.config[REMOTE].max_htlc_value_in_flight_msat/1000} sat') if amount_msat <= 0: # FIXME htlc_minimum_msat @@ -269,7 +231,7 @@ class Channel(PrintError): assert type(htlc) is dict self._check_can_pay(htlc['amount_msat']) htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id) - self.log[LOCAL].adds[htlc.htlc_id] = htlc + self.hm.send_htlc(htlc) self.print_error("add_htlc") self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1) return htlc.htlc_id @@ -288,8 +250,7 @@ class Channel(PrintError): raise RemoteMisbehaving('Remote dipped below channel reserve.' +\ f' Available at remote: {self.available_to_spend(REMOTE)},' +\ f' HTLC amount: {htlc.amount_msat}') - adds = self.log[REMOTE].adds - adds[htlc.htlc_id] = htlc + self.hm.recv_htlc(htlc) self.print_error("receive_htlc") self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1) return htlc.htlc_id @@ -308,7 +269,7 @@ class Channel(PrintError): """ self.print_error("sign_next_commitment") - old_logs = dict(self.lock_in_htlc_changes(LOCAL)) + self.hm.send_ctx() pending_remote_commitment = self.pending_commitment(REMOTE) sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE]) @@ -321,7 +282,8 @@ class Channel(PrintError): for_us = False htlcsigs = [] - for we_receive, htlcs in zip([True, False], [self.included_htlcs(REMOTE, REMOTE), self.included_htlcs(REMOTE, LOCAL)]): + # they sent => we receive + for we_receive, htlcs in zip([True, False], [self.included_htlcs(REMOTE, SENT, ctn=self.config[REMOTE].ctn+1), self.included_htlcs(REMOTE, RECEIVED, ctn=self.config[REMOTE].ctn+1)]): for htlc in htlcs: _script, htlc_tx = make_htlc_tx_with_open_channel(chan=self, pcp=self.config[REMOTE].next_per_commitment_point, @@ -337,26 +299,11 @@ class Channel(PrintError): htlcsigs.sort() htlcsigs = [x[1] for x in htlcsigs] - self.remote_commitment = self.pending_commitment(REMOTE) - - # we can't know if this message arrives. - # since we shouldn't actually throw away - # failed htlcs yet (or mark htlc locked in), - # roll back the changes that were made - self.log = old_logs + # TODO should add remote_commitment here and handle + # both valid ctx'es in lnwatcher at the same time... return sig_64, htlcsigs - def lock_in_htlc_changes(self, subject): - for sub in (LOCAL, REMOTE): - log = self.log[sub] - yield (sub, deepcopy(log)) - for htlc_id in log.fails: - log.adds.pop(htlc_id) - log.fails.clear() - - self.log[subject].locked_in.update(self.log[subject].adds.keys()) - def receive_new_commitment(self, sig, htlc_sigs): """ ReceiveNewCommitment process a signature for a new commitment state sent by @@ -372,7 +319,7 @@ class Channel(PrintError): """ self.print_error("receive_new_commitment") - for _ in self.lock_in_htlc_changes(REMOTE): pass + self.hm.recv_ctx() assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes @@ -385,16 +332,18 @@ class Channel(PrintError): htlc_sigs_string = b''.join(htlc_sigs) htlc_sigs = htlc_sigs[:] # copy cause we will delete now - for htlcs, we_receive in [(self.included_htlcs(LOCAL, REMOTE), True), (self.included_htlcs(LOCAL, LOCAL), False)]: + ctn = self.config[LOCAL].ctn+1 + for htlcs, we_receive in [(self.included_htlcs(LOCAL, SENT, ctn=ctn), False), (self.included_htlcs(LOCAL, RECEIVED, ctn=ctn), True)]: for htlc in htlcs: - idx = self.verify_htlc(htlc, htlc_sigs, we_receive) + idx = self.verify_htlc(htlc, htlc_sigs, we_receive, pending_local_commitment) del htlc_sigs[idx] if len(htlc_sigs) != 0: # all sigs should have been popped above raise Exception('failed verifying HTLC signatures: invalid amount of correct signatures') self.config[LOCAL]=self.config[LOCAL]._replace( current_commitment_signature=sig, - current_htlc_signatures=htlc_sigs_string) + current_htlc_signatures=htlc_sigs_string, + got_sig_for_next=True) if self.pending_fee is not None: if not self.constraints.is_initiator: @@ -402,15 +351,15 @@ class Channel(PrintError): if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]: self.pending_fee[FUNDER_SIGNED] = True - self.set_local_commitment(self.pending_commitment(LOCAL)) + self.set_local_commitment(pending_local_commitment) - def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool) -> int: - _, this_point, _ = self.points() + def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int: + _, this_point, _, _ = self.points() _script, htlc_tx = make_htlc_tx_with_open_channel(chan=self, pcp=this_point, for_us=True, we_receive=we_receive, - commit=self.pending_commitment(LOCAL), + commit=ctx, htlc=htlc) pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0))) remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, this_point) @@ -418,19 +367,19 @@ class Channel(PrintError): if ecc.verify_signature(remote_htlc_pubkey, sig, pre_hash): return idx else: - raise Exception(f'failed verifying HTLC signatures: {htlc}') + raise Exception(f'failed verifying HTLC signatures: {htlc}, sigs: {len(htlc_sigs)}, we_receive: {we_receive}') - def get_remote_htlc_sig_for_htlc(self, htlc: UpdateAddHtlc, we_receive: bool) -> bytes: + def get_remote_htlc_sig_for_htlc(self, htlc: UpdateAddHtlc, we_receive: bool, ctx) -> bytes: data = self.config[LOCAL].current_htlc_signatures htlc_sigs = [data[i:i + 64] for i in range(0, len(data), 64)] - idx = self.verify_htlc(htlc, htlc_sigs, we_receive=we_receive) + idx = self.verify_htlc(htlc, htlc_sigs, we_receive=we_receive, ctx=ctx) remote_htlc_sig = ecc.der_sig_from_sig_string(htlc_sigs[idx]) + b'\x01' return remote_htlc_sig def revoke_current_commitment(self): self.print_error("revoke_current_commitment") - last_secret, this_point, next_point = self.points() + last_secret, this_point, next_point, _ = self.points() new_feerate = self.constraints.feerate @@ -444,16 +393,18 @@ class Channel(PrintError): self.pending_fee = None print("FEERATE CHANGE COMPLETE (initiator)") - self.config[LOCAL]=self.config[LOCAL]._replace( - ctn=self.config[LOCAL].ctn + 1, - ) + assert self.config[LOCAL].got_sig_for_next self.constraints=self.constraints._replace( feerate=new_feerate ) - - # since we should not revoke our latest commitment tx, - # we do not update self.local_commitment here, - # it should instead be updated when we receive a new sig + self.set_local_commitment(self.pending_commitment(LOCAL)) + ctx = self.pending_commitment(LOCAL) + self.hm.send_rev() + self.config[LOCAL]=self.config[LOCAL]._replace( + ctn=self.config[LOCAL].ctn + 1, + got_sig_for_next=False, + ) + assert self.signature_fits(ctx) return RevokeAndAck(last_secret, next_point), "current htlcs" @@ -466,7 +417,8 @@ class Channel(PrintError): this_point = secret_to_pubkey(int.from_bytes(this_secret, 'big')) next_secret = get_per_commitment_secret_from_seed(self.config[LOCAL].per_commitment_secret_seed, RevocationStore.START_INDEX - next_small_num) next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big')) - return last_secret, this_point, next_point + last_point = secret_to_pubkey(int.from_bytes(last_secret, 'big')) + return last_secret, this_point, next_point, last_point def process_new_revocation_secret(self, per_commitment_secret: bytes): if not self.lnwatcher: @@ -481,12 +433,9 @@ class Channel(PrintError): def receive_revocation(self, revocation) -> Tuple[int, int]: self.print_error("receive_revocation") - old_logs = dict(self.lock_in_htlc_changes(LOCAL)) - cur_point = self.config[REMOTE].current_per_commitment_point derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True) if cur_point != derived_point: - self.log = old_logs raise Exception('revoked secret not for current point') # FIXME not sure this is correct... but it seems to work @@ -505,51 +454,36 @@ class Channel(PrintError): if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]: self.pending_fee[FUNDER_SIGNED] = True - def mark_settled(subject): - """ - find pending settlements for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs - """ - old_amount = htlcsum(self.htlcs(subject, False)) - - for htlc_id in self.log[subject].settles: - adds = self.log[subject].adds - htlc = adds.pop(htlc_id) - self.settled[subject].append(htlc.amount_msat) - if subject == LOCAL: - preimage = self.preimages.pop(htlc_id) - else: - preimage = None - self.payment_completed(self, subject, htlc, preimage) - self.log[subject].settles.clear() - - return old_amount - htlcsum(self.htlcs(subject, False)) - - sent_this_batch = mark_settled(LOCAL) - received_this_batch = mark_settled(REMOTE) + received = self.hm.received_in_ctn(self.config[REMOTE].ctn + 1) + sent = self.hm.sent_in_ctn(self.config[REMOTE].ctn + 1) + for htlc in received: + self.payment_completed(self, RECEIVED, htlc, None) + for htlc in sent: + preimage = self.preimages.pop(htlc.htlc_id) + self.payment_completed(self, SENT, htlc, preimage) + received_this_batch = htlcsum(received) + sent_this_batch = htlcsum(sent) next_point = self.config[REMOTE].next_per_commitment_point - print("RECEIVED", received_this_batch) - print("SENT", sent_this_batch) + self.hm.recv_rev() + self.config[REMOTE]=self.config[REMOTE]._replace( ctn=self.config[REMOTE].ctn + 1, current_per_commitment_point=next_point, next_per_commitment_point=revocation.next_per_commitment_point, - amount_msat=self.config[REMOTE].amount_msat + (sent_this_batch - received_this_batch) - ) - self.config[LOCAL]=self.config[LOCAL]._replace( - amount_msat = self.config[LOCAL].amount_msat + (received_this_batch - sent_this_batch) ) if self.pending_fee is not None: if self.constraints.is_initiator: self.pending_fee[FUNDEE_ACKED] = True - self.set_remote_commitment(self.pending_commitment(REMOTE)) + self.set_remote_commitment() self.remote_commitment_to_be_revoked = prev_remote_commitment + return received_this_batch, sent_this_batch - def balance(self, subject): + def balance(self, subject, ctn=None): """ This balance in mSAT is not including reserve and fees. So a node cannot actually use it's whole balance. @@ -560,12 +494,15 @@ class Channel(PrintError): commited to later when the respective commitment transaction as been revoked. """ + assert type(subject) is HTLCOwner initial = self.config[subject].initial_msat - initial -= sum(self.settled[subject]) - initial += sum(self.settled[-subject]) + for direction, htlc in self.hm.settled_htlcs(subject, ctn): + if direction == SENT: + initial -= htlc.amount_msat + else: + initial += htlc.amount_msat - assert initial == self.config[subject].amount_msat return initial def balance_minus_outgoing_htlcs(self, subject): @@ -573,48 +510,46 @@ class Channel(PrintError): This balance in mSAT, which includes the value of pending outgoing HTLCs, is used in the UI. """ - return self.balance(subject)\ - - htlcsum(self.log[subject].adds.values()) + assert type(subject) is HTLCOwner + ctn = self.hm.log[subject]['ctn'] + 1 + return self.balance(subject, ctn)\ + - htlcsum(self.hm.htlcs_by_direction(subject, SENT, ctn)) def available_to_spend(self, subject): """ This balance in mSAT, while technically correct, can not be used in the UI cause it fluctuates (commit fee) """ + assert type(subject) is HTLCOwner return self.balance_minus_outgoing_htlcs(subject)\ - - htlcsum(self.log[subject].adds.values())\ - self.config[-subject].reserve_sat * 1000\ - calc_onchain_fees( # TODO should we include a potential new htlc, when we are called from receive_htlc? - len(list(self.included_htlcs(subject, LOCAL)) + list(self.included_htlcs(subject, REMOTE))), + len(self.included_htlcs(subject, SENT) + self.included_htlcs(subject, RECEIVED)), self.pending_feerate(subject), - True, # for_us self.constraints.is_initiator, )[subject] - def amounts(self): - remote_settled= htlcsum(self.htlcs(REMOTE, False)) - local_settled= htlcsum(self.htlcs(LOCAL, False)) - unsettled_local = htlcsum(self.htlcs(LOCAL, True)) - unsettled_remote = htlcsum(self.htlcs(REMOTE, True)) - remote_msat = self.config[REMOTE].amount_msat -\ - unsettled_remote + local_settled - remote_settled - local_msat = self.config[LOCAL].amount_msat -\ - unsettled_local + remote_settled - local_settled - return remote_msat, local_msat - - def included_htlcs(self, subject, htlc_initiator, only_pending=True): + def included_htlcs(self, subject, direction, ctn=None): """ return filter of non-dust htlcs for subjects commitment transaction, initiated by given party """ + assert type(subject) is HTLCOwner + assert type(direction) is Direction + if ctn is None: + ctn = self.config[subject].ctn feerate = self.pending_feerate(subject) conf = self.config[subject] - weight = HTLC_SUCCESS_WEIGHT if subject != htlc_initiator else HTLC_TIMEOUT_WEIGHT - htlcs = self.htlcs(htlc_initiator, only_pending=only_pending) + if (subject, direction) in [(REMOTE, RECEIVED), (LOCAL, SENT)]: + weight = HTLC_SUCCESS_WEIGHT + else: + weight = HTLC_TIMEOUT_WEIGHT + htlcs = self.hm.htlcs_by_direction(subject, direction, ctn=ctn) fee_for_htlc = lambda htlc: htlc.amount_msat // 1000 - (weight * feerate // 1000) - return filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs) + return list(filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs)) def pending_feerate(self, subject): + assert type(subject) is HTLCOwner candidate = self.constraints.feerate if self.pending_fee is not None: x = self.pending_fee.pending_feerate(subject) @@ -623,81 +558,53 @@ class Channel(PrintError): return candidate def pending_commitment(self, subject): + assert type(subject) is HTLCOwner this_point = self.config[REMOTE].next_per_commitment_point if subject == REMOTE else self.points()[1] - return self.make_commitment(subject, this_point) + ctn = self.config[subject].ctn + 1 + feerate = self.pending_feerate(subject) + return self.make_commitment(subject, this_point, ctn, feerate, True) def current_commitment(self, subject): - old_local_state = self.config[subject] - self.config[subject]=self.config[subject]._replace(ctn=self.config[subject].ctn - 1) - r = self.pending_commitment(subject) - self.config[subject] = old_local_state - return r - - def total_msat(self, sub): - return sum(self.settled[sub]) + assert type(subject) is HTLCOwner + this_point = self.config[REMOTE].current_per_commitment_point if subject == REMOTE else self.points()[3] + ctn = self.config[subject].ctn + feerate = self.constraints.feerate + return self.make_commitment(subject, this_point, ctn, feerate, False) - def htlcs(self, subject, only_pending): - """ - only_pending: require the htlc's settlement to be pending (needs additional signatures/acks) - - sets returned with True and False are disjunct - - only_pending true: - skipped if settled or failed - <=> - included if not settled and not failed - only_pending false: - skipped if not (settled or failed) - <=> - included if not not (settled or failed) - included if settled or failed - """ - update_log = self.log[subject] - res = [] - for htlc in update_log.adds.values(): - locked_in = htlc.htlc_id in update_log.locked_in - settled = htlc.htlc_id in update_log.settles - failed = htlc.htlc_id in update_log.fails - if not locked_in: - continue - if only_pending == (settled or failed): - continue - res.append(htlc) - return res + def total_msat(self, direction): + assert type(direction) is Direction + sub = LOCAL if direction == SENT else REMOTE + return htlcsum(self.hm.settled_htlcs_by(sub, self.config[sub].ctn)) def settle_htlc(self, preimage, htlc_id): """ SettleHTLC attempts to settle an existing outstanding received HTLC. """ self.print_error("settle_htlc") - log = self.log[REMOTE] - htlc = log.adds[htlc_id] + log = self.hm.log[REMOTE] + htlc = log['adds'][htlc_id] assert htlc.payment_hash == sha256(preimage) - assert htlc_id not in log.settles - log.settles.add(htlc_id) + assert htlc_id not in log['settles'] + self.hm.send_settle(htlc_id) # not saving preimage because it's already saved in LNWorker.invoices def receive_htlc_settle(self, preimage, htlc_id): self.print_error("receive_htlc_settle") - log = self.log[LOCAL] - htlc = log.adds[htlc_id] + log = self.hm.log[LOCAL] + htlc = log['adds'][htlc_id] assert htlc.payment_hash == sha256(preimage) - assert htlc_id not in log.settles + assert htlc_id not in log['settles'] + self.hm.recv_settle(htlc_id) self.preimages[htlc_id] = preimage - log.settles.add(htlc_id) # we don't save the preimage because we don't need to forward it anyway def fail_htlc(self, htlc_id): self.print_error("fail_htlc") - log = self.log[REMOTE] - assert htlc_id not in log.fails - log.fails.add(htlc_id) + self.hm.send_fail(htlc_id) def receive_fail_htlc(self, htlc_id): self.print_error("receive_fail_htlc") - log = self.log[LOCAL] - assert htlc_id not in log.fails - log.fails.add(htlc_id) + self.hm.recv_fail(htlc_id) @property def current_height(self): @@ -713,29 +620,7 @@ class Channel(PrintError): raise Exception("a fee update is already in progress") self.pending_fee = FeeUpdate(self, rate=feerate) - def remove_uncommitted_htlcs_from_log(self, subject): - """ - returns - - the htlcs with uncommited (not locked in) htlcs removed - - a list of htlc_ids that were removed - """ - removed = [] - htlcs = [] - log = self.log[subject] - for i in log.adds.values(): - locked_in = i.htlc_id in log.locked_in - if locked_in: - htlcs.append(i._asdict()) - else: - removed.append(i.htlc_id) - return htlcs, removed - def to_save(self): - # need to forget about uncommited htlcs - # since we must assume they don't know about it, - # if it was not acked - remote_filtered, remote_removed = self.remove_uncommitted_htlcs_from_log(REMOTE) - local_filtered, local_removed = self.remove_uncommitted_htlcs_from_log(LOCAL) to_save = { "local_config": self.config[LOCAL], "remote_config": self.config[REMOTE], @@ -745,24 +630,10 @@ class Channel(PrintError): "funding_outpoint": self.funding_outpoint, "node_id": self.node_id, "remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked), - "remote_log": remote_filtered, - "local_log": local_filtered, + "log": self.hm.to_save(), "onion_keys": str_bytes_dict_to_save(self.onion_keys), - "settled_local": self.settled[LOCAL], - "settled_remote": self.settled[REMOTE], "force_closed": self.get_state() == 'FORCE_CLOSING', } - - # htlcs number must be monotonically increasing, - # so we have to decrease the counter - if len(remote_removed) != 0: - assert min(remote_removed) < to_save['remote_config'].next_htlc_id - to_save['remote_config'] = to_save['remote_config']._replace(next_htlc_id = min(remote_removed)) - - if len(local_removed) != 0: - assert min(local_removed) < to_save['local_config'].next_htlc_id - to_save['local_config'] = to_save['local_config']._replace(next_htlc_id = min(local_removed)) - return to_save def serialize(self): @@ -792,33 +663,49 @@ class Channel(PrintError): def __str__(self): return str(self.serialize()) - def make_commitment(self, subject, this_point) -> Transaction: - remote_msat, local_msat = self.amounts() - assert local_msat >= 0, local_msat - assert remote_msat >= 0, remote_msat + def make_commitment(self, subject, this_point, ctn, feerate, pending) -> Transaction: + #if subject == REMOTE and not pending: + # ctn -= 1 + assert type(subject) is HTLCOwner + other = REMOTE if LOCAL == subject else LOCAL + remote_msat, local_msat = self.balance(other, ctn), self.balance(subject, ctn) + received_htlcs = self.hm.htlcs_by_direction(subject, SENT if subject == LOCAL else RECEIVED, ctn) + sent_htlcs = self.hm.htlcs_by_direction(subject, RECEIVED if subject == LOCAL else SENT, ctn) + if subject != LOCAL: + remote_msat -= htlcsum(received_htlcs) + local_msat -= htlcsum(sent_htlcs) + else: + remote_msat -= htlcsum(sent_htlcs) + local_msat -= htlcsum(received_htlcs) + assert remote_msat >= 0 + assert local_msat >= 0 + # same htlcs as before, but now without dust. + received_htlcs = self.included_htlcs(subject, SENT if subject == LOCAL else RECEIVED, ctn) + sent_htlcs = self.included_htlcs(subject, RECEIVED if subject == LOCAL else SENT, ctn) + this_config = self.config[subject] other_config = self.config[-subject] other_htlc_pubkey = derive_pubkey(other_config.htlc_basepoint.pubkey, this_point) this_htlc_pubkey = derive_pubkey(this_config.htlc_basepoint.pubkey, this_point) other_revocation_pubkey = derive_blinded_pubkey(other_config.revocation_basepoint.pubkey, this_point) htlcs = [] # type: List[ScriptHtlc] - def append_htlc(htlc: UpdateAddHtlc, is_received_htlc: bool): - htlcs.append(ScriptHtlc(make_htlc_output_witness_script( - is_received_htlc=is_received_htlc, - remote_revocation_pubkey=other_revocation_pubkey, - remote_htlc_pubkey=other_htlc_pubkey, - local_htlc_pubkey=this_htlc_pubkey, - payment_hash=htlc.payment_hash, - cltv_expiry=htlc.cltv_expiry), htlc)) - for htlc in self.included_htlcs(subject, -subject): - append_htlc(htlc, is_received_htlc=True) - for htlc in self.included_htlcs(subject, subject): - append_htlc(htlc, is_received_htlc=False) - if subject != LOCAL: - remote_msat, local_msat = local_msat, remote_msat + for is_received_htlc, htlc_list in zip((subject != LOCAL, subject == LOCAL), (received_htlcs, sent_htlcs)): + for htlc in htlc_list: + htlcs.append(ScriptHtlc(make_htlc_output_witness_script( + is_received_htlc=is_received_htlc, + remote_revocation_pubkey=other_revocation_pubkey, + remote_htlc_pubkey=other_htlc_pubkey, + local_htlc_pubkey=this_htlc_pubkey, + payment_hash=htlc.payment_hash, + cltv_expiry=htlc.cltv_expiry), htlc)) + onchain_fees = calc_onchain_fees( + len(htlcs), + feerate, + self.constraints.is_initiator == (subject == LOCAL), + ) payment_pubkey = derive_pubkey(other_config.payment_basepoint.pubkey, this_point) return make_commitment( - self.config[subject].ctn + 1, + ctn, this_config.multisig_key.pubkey, other_config.multisig_key.pubkey, payment_pubkey, @@ -832,12 +719,7 @@ class Channel(PrintError): local_msat, remote_msat, this_config.dust_limit_sat, - calc_onchain_fees( - len(htlcs), - self.pending_feerate(subject), - subject == LOCAL, - self.constraints.is_initiator, - ), + onchain_fees, htlcs=htlcs) def get_local_index(self): @@ -850,8 +732,8 @@ class Channel(PrintError): LOCAL: fee_sat * 1000 if self.constraints.is_initiator else 0, REMOTE: fee_sat * 1000 if not self.constraints.is_initiator else 0, }, - self.config[LOCAL].amount_msat, - self.config[REMOTE].amount_msat, + self.balance(LOCAL), + self.balance(REMOTE), (TYPE_SCRIPT, bh2u(local_script)), (TYPE_SCRIPT, bh2u(remote_script)), [], self.config[LOCAL].dust_limit_sat) @@ -867,38 +749,39 @@ class Channel(PrintError): sig = ecc.sig_string_from_der_sig(der_sig[:-1]) return sig, closing_tx - def assert_signature_fits(self, tx): + def signature_fits(self, tx): remote_sig = self.config[LOCAL].current_commitment_signature - if remote_sig: # only None in test - preimage_hex = tx.serialize_preimage(0) - pre_hash = sha256d(bfh(preimage_hex)) - if not ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, remote_sig, pre_hash): - self.print_error("WARNING: commitment signature inconsistency, cannot force close") + preimage_hex = tx.serialize_preimage(0) + pre_hash = sha256d(bfh(preimage_hex)) + assert remote_sig + res = ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, remote_sig, pre_hash) + return res def force_close_tx(self): tx = self.local_commitment + assert self.signature_fits(tx) tx = Transaction(str(tx)) tx.deserialize(True) - self.assert_signature_fits(tx) tx.sign({bh2u(self.config[LOCAL].multisig_key.pubkey): (self.config[LOCAL].multisig_key.privkey, True)}) remote_sig = self.config[LOCAL].current_commitment_signature - if remote_sig: # only None in test - remote_sig = ecc.der_sig_from_sig_string(remote_sig) + b"\x01" - sigs = tx._inputs[0]["signatures"] - none_idx = sigs.index(None) - tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig)) - assert tx.is_complete() + remote_sig = ecc.der_sig_from_sig_string(remote_sig) + b"\x01" + sigs = tx._inputs[0]["signatures"] + none_idx = sigs.index(None) + tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig)) + assert tx.is_complete() return tx def included_htlcs_in_their_latest_ctxs(self, htlc_initiator) -> Dict[int, List[UpdateAddHtlc]]: """ A map from commitment number to list of HTLCs in their latest two commitment transactions. The oldest might have been revoked. """ - old_htlcs = list(self.included_htlcs(REMOTE, htlc_initiator, only_pending=False)) + assert type(htlc_initiator) is HTLCOwner + direction = RECEIVED if htlc_initiator == LOCAL else SENT + old_ctn = self.config[REMOTE].ctn + old_htlcs = self.included_htlcs(REMOTE, direction, ctn=old_ctn) - old_logs = dict(self.lock_in_htlc_changes(LOCAL)) - new_htlcs = list(self.included_htlcs(REMOTE, htlc_initiator)) - self.log = old_logs + new_ctn = self.config[REMOTE].ctn+1 + new_htlcs = self.included_htlcs(REMOTE, direction, ctn=new_ctn) - return {self.config[REMOTE].ctn: old_htlcs, - self.config[REMOTE].ctn+1: new_htlcs, } + return {old_ctn: old_htlcs, + new_ctn: new_htlcs, } diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -0,0 +1,159 @@ +from copy import deepcopy +from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction +from .util import bh2u + +class HTLCManager: + def __init__(self, log=None): + self.expect_sig = {SENT: False, RECEIVED: False} + if log is None: + initial = {'ctn': 0, 'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}} + log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)} + else: + assert type(log) is dict + log = {HTLCOwner(int(x)): y for x, y in deepcopy(log).items()} + for sub in (LOCAL, REMOTE): + log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()} + coerceHtlcOwner2IntMap = lambda x: {HTLCOwner(int(y)): z for y, z in x.items()} + log[sub]['locked_in'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['locked_in'].items()} + log[sub]['settles'] = {int(x): y for x, y in log[sub]['settles'].items()} + log[sub]['fails'] = {int(x): y for x, y in log[sub]['fails'].items()} + self.log = log + + def to_save(self): + x = deepcopy(self.log) + for sub in (LOCAL, REMOTE): + d = {} + for htlc_id, htlc in x[sub]['adds'].items(): + d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:] + x[sub]['adds'] = d + return x + + def send_htlc(self, htlc): + htlc_id = htlc.htlc_id + adds = self.log[LOCAL]['adds'] + assert type(adds) is not str + adds[htlc_id] = htlc + self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.log[REMOTE]['ctn']+1} + self.expect_sig[SENT] = True + return htlc + + def recv_htlc(self, htlc): + htlc_id = htlc.htlc_id + self.log[REMOTE]['htlc_id'] = htlc_id + self.log[REMOTE]['adds'][htlc_id] = htlc + l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.log[LOCAL]['ctn']+1, REMOTE: None} + self.expect_sig[RECEIVED] = True + + def send_ctx(self): + next_ctn = self.log[REMOTE]['ctn'] + 1 + + for locked_in in self.log[REMOTE]['locked_in'].values(): + if locked_in[REMOTE] is None: + locked_in[REMOTE] = next_ctn + + self.expect_sig[SENT] = False + + #return Sig(self.pending_htlcs(REMOTE), next_ctn) + + def recv_ctx(self): + next_ctn = self.log[LOCAL]['ctn'] + 1 + + for locked_in in self.log[LOCAL]['locked_in'].values(): + if locked_in[LOCAL] is None: + locked_in[LOCAL] = next_ctn + + self.expect_sig[SENT] = False + + def send_rev(self): + self.log[LOCAL]['ctn'] += 1 + + def recv_rev(self): + self.log[REMOTE]['ctn'] += 1 + did_set_htlc_height = False + for htlc_id, ctnheights in self.log[LOCAL]['locked_in'].items(): + if ctnheights[LOCAL] is None: + did_set_htlc_height = True + assert ctnheights[REMOTE] == self.log[REMOTE]['ctn'] + ctnheights[LOCAL] = ctnheights[REMOTE] + return did_set_htlc_height + + def htlcs_by_direction(self, subject, direction, ctn=None): + """ + direction is relative to subject! + """ + assert type(subject) is HTLCOwner + assert type(direction) is Direction + if ctn is None: + ctn = self.log[subject]['ctn'] + l = [] + if direction == SENT and subject == LOCAL: + party = LOCAL + elif direction == RECEIVED and subject == REMOTE: + party = LOCAL + else: + party = REMOTE + for htlc_id, ctnheights in self.log[party]['locked_in'].items(): + htlc_height = ctnheights[subject] + if htlc_height is None: + include = not self.expect_sig[RECEIVED if party == LOCAL else SENT] and ctnheights[-subject] <= ctn + else: + include = htlc_height <= ctn + if include: + settles = self.log[party]['settles'] + if htlc_id not in settles or settles[htlc_id] > ctn: + fails = self.log[party]['fails'] + if htlc_id not in fails or fails[htlc_id] > ctn: + l.append(self.log[party]['adds'][htlc_id]) + return l + + def htlcs(self, subject, ctn=None): + assert type(subject) is HTLCOwner + if ctn is None: + ctn = self.log[subject]['ctn'] + l = [] + l += [(SENT, x) for x in self.htlcs_by_direction(subject, SENT, ctn)] + l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)] + return l + + def current_htlcs(self, subject): + assert type(subject) is HTLCOwner + ctn = self.log[subject]['ctn'] + return self.htlcs(subject, ctn) + + def pending_htlcs(self, subject): + assert type(subject) is HTLCOwner + ctn = self.log[subject]['ctn'] + 1 + return self.htlcs(subject, ctn) + + def send_settle(self, htlc_id): + self.log[REMOTE]['settles'][htlc_id] = self.log[REMOTE]['ctn'] + 1 + + def recv_settle(self, htlc_id): + self.log[LOCAL]['settles'][htlc_id] = self.log[LOCAL]['ctn'] + 1 + + def settled_htlcs_by(self, subject, ctn=None): + assert type(subject) is HTLCOwner + if ctn is None: + ctn = self.log[subject]['ctn'] + return [self.log[subject]['adds'][htlc_id] for htlc_id, height in self.log[subject]['settles'].items() if height <= ctn] + + def settled_htlcs(self, subject, ctn=None): + assert type(subject) is HTLCOwner + if ctn is None: + ctn = self.log[subject]['ctn'] + sent = [(SENT, x) for x in self.settled_htlcs_by(subject, ctn)] + other = subject.inverted() + received = [(RECEIVED, x) for x in self.settled_htlcs_by(other, ctn)] + return sent + received + + def received_in_ctn(self, ctn): + return [self.log[REMOTE]['adds'][htlc_id] for htlc_id, height in self.log[REMOTE]['settles'].items() if height == ctn] + + def sent_in_ctn(self, ctn): + return [self.log[LOCAL]['adds'][htlc_id] for htlc_id, height in self.log[LOCAL]['settles'].items() if height == ctn] + + def send_fail(self, htlc_id): + self.log[REMOTE]['fails'][htlc_id] = self.log[REMOTE]['ctn'] + 1 + + def recv_fail(self, htlc_id): + self.log[LOCAL]['fails'][htlc_id] = self.log[LOCAL]['ctn'] + 1 diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py @@ -9,15 +9,15 @@ from .bitcoin import TYPE_ADDRESS, redeem_script_to_address, dust_threshold from . import ecc from .lnutil import (make_commitment_output_to_remote_address, make_commitment_output_to_local_witness_script, derive_privkey, derive_pubkey, derive_blinded_pubkey, derive_blinded_privkey, - make_htlc_tx_witness, make_htlc_tx_with_open_channel, + make_htlc_tx_witness, make_htlc_tx_with_open_channel, UpdateAddHtlc, LOCAL, REMOTE, make_htlc_output_witness_script, UnknownPaymentHash, get_ordered_channel_configs, privkey_to_pubkey, get_per_commitment_secret_from_seed, - RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret) + RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret, SENT, RECEIVED) from .transaction import Transaction, TxOutput, construct_witness from .simple_config import SimpleConfig, FEERATE_FALLBACK_STATIC_FEE if TYPE_CHECKING: - from .lnchan import Channel, UpdateAddHtlc + from .lnchan import Channel def maybe_create_sweeptx_for_their_ctx_to_remote(ctx: Transaction, sweep_address: str, @@ -106,7 +106,7 @@ def create_sweeptxs_for_their_just_revoked_ctx(chan: 'Channel', ctx: Transaction ctn = extract_ctn_from_tx_and_chan(ctx, chan) assert ctn == chan.config[REMOTE].ctn # received HTLCs, in their ctx - received_htlcs = chan.included_htlcs(REMOTE, LOCAL, False) + received_htlcs = chan.included_htlcs(REMOTE, RECEIVED, ctn) for htlc in received_htlcs: direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=True) if direct_sweep_tx: @@ -114,7 +114,7 @@ def create_sweeptxs_for_their_just_revoked_ctx(chan: 'Channel', ctx: Transaction if secondstage_sweep_tx: txs[htlc_tx.txid()] = secondstage_sweep_tx # offered HTLCs, in their ctx - offered_htlcs = chan.included_htlcs(REMOTE, REMOTE, False) + offered_htlcs = chan.included_htlcs(REMOTE, SENT, ctn) for htlc in offered_htlcs: direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=False) if direct_sweep_tx: @@ -181,16 +181,14 @@ def create_sweeptxs_for_our_latest_ctx(chan: 'Channel', ctx: Transaction, is_revocation=False) return htlc_tx, to_wallet_tx # offered HTLCs, in our ctx --> "timeout" - # TODO consider carefully if "included_htlcs" is what we need here - offered_htlcs = list(chan.included_htlcs(LOCAL, LOCAL)) # type: List[UpdateAddHtlc] + # received HTLCs, in our ctx --> "success" + offered_htlcs = chan.included_htlcs(LOCAL, SENT, ctn) # type: List[UpdateAddHtlc] + received_htlcs = chan.included_htlcs(LOCAL, RECEIVED, ctn) # type: List[UpdateAddHtlc] for htlc in offered_htlcs: htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=False) if htlc_tx and to_wallet_tx: txs[to_wallet_tx.prevout(0)] = to_wallet_tx txs[htlc_tx.prevout(0)] = htlc_tx - # received HTLCs, in our ctx --> "success" - # TODO consider carefully if "included_htlcs" is what we need here - received_htlcs = list(chan.included_htlcs(LOCAL, REMOTE)) # type: List[UpdateAddHtlc] for htlc in received_htlcs: htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=True) if htlc_tx and to_wallet_tx: @@ -332,7 +330,7 @@ def create_htlctx_that_spends_from_our_ctx(chan: 'Channel', our_pcp: bytes, htlc=htlc, name=f'our_ctx_htlc_tx_{bh2u(htlc.payment_hash)}', cltv_expiry=0 if is_received_htlc else htlc.cltv_expiry) - remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc, we_receive=is_received_htlc) + remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc, we_receive=is_received_htlc, ctx=ctx) local_htlc_sig = bfh(htlc_tx.sign_txin(0, local_htlc_privkey)) txin = htlc_tx.inputs()[0] witness_program = bfh(Transaction.get_preimage_script(txin)) diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -21,7 +21,7 @@ from .lnaddr import lndecode from .keystore import BIP32_KeyStore if TYPE_CHECKING: - from .lnchan import Channel, UpdateAddHtlc + from .lnchan import Channel HTLC_TIMEOUT_WEIGHT = 663 @@ -35,7 +35,6 @@ OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"]) class LocalConfig(NamedTuple): # shared channel config fields (DUPLICATED code!!) ctn: int - amount_msat: int next_htlc_id: int payment_basepoint: 'Keypair' multisig_key: 'Keypair' @@ -54,12 +53,12 @@ class LocalConfig(NamedTuple): was_announced: bool current_commitment_signature: Optional[bytes] current_htlc_signatures: List[bytes] + got_sig_for_next: bool class RemoteConfig(NamedTuple): # shared channel config fields (DUPLICATED code!!) ctn: int - amount_msat: int next_htlc_id: int payment_basepoint: 'Keypair' multisig_key: 'Keypair' @@ -364,7 +363,7 @@ def make_htlc_tx_with_open_channel(chan: 'Channel', pcp: bytes, for_us: bool, # FIXME handle htlc_address collision # also: https://github.com/lightningnetwork/lightning-rfc/issues/448 prevout_idx = commit.get_output_idx_from_address(htlc_address) - assert prevout_idx is not None + assert prevout_idx is not None, (htlc_address, commit.outputs(), extract_ctn_from_tx_and_chan(commit, chan)) htlc_tx_inputs = make_htlc_tx_inputs( commit.txid(), prevout_idx, amount_msat=amount_msat, @@ -395,11 +394,16 @@ class HTLCOwner(IntFlag): LOCAL = 1 REMOTE = -LOCAL - SENT = LOCAL - RECEIVED = REMOTE + def inverted(self): + return HTLCOwner(-self) + +class Direction(IntFlag): + SENT = 3 + RECEIVED = 4 + +SENT = Direction.SENT +RECEIVED = Direction.RECEIVED -SENT = HTLCOwner.SENT -RECEIVED = HTLCOwner.RECEIVED LOCAL = HTLCOwner.LOCAL REMOTE = HTLCOwner.REMOTE @@ -420,8 +424,7 @@ def make_commitment_outputs(fees_per_participant: Mapping[HTLCOwner, int], local c_outputs_filtered = list(filter(lambda x: x.value >= dust_limit_sat, non_htlc_outputs + htlc_outputs)) return htlc_outputs, c_outputs_filtered -def calc_onchain_fees(num_htlcs, feerate, for_us, we_are_initiator): - we_pay_fee = for_us == we_are_initiator +def calc_onchain_fees(num_htlcs, feerate, we_pay_fee): overall_weight = 500 + 172 * num_htlcs + 224 fee = feerate * overall_weight fee = fee // 1000 * 1000 @@ -451,7 +454,7 @@ def make_commitment(ctn, local_funding_pubkey, remote_funding_pubkey, htlc_outputs, c_outputs_filtered = make_commitment_outputs(fees_per_participant, local_amount, remote_amount, (bitcoin.TYPE_ADDRESS, local_address), (bitcoin.TYPE_ADDRESS, remote_address), htlcs, dust_limit_sat) - assert sum(x.value for x in c_outputs_filtered) <= funding_sat + assert sum(x.value for x in c_outputs_filtered) <= funding_sat, (c_outputs_filtered, funding_sat) # create commitment tx tx = Transaction.from_io(c_inputs, c_outputs_filtered, locktime=locktime, version=2) @@ -649,3 +652,20 @@ def format_short_channel_id(short_channel_id: Optional[bytes]): return str(int.from_bytes(short_channel_id[:3], 'big')) \ + 'x' + str(int.from_bytes(short_channel_id[3:6], 'big')) \ + 'x' + str(int.from_bytes(short_channel_id[6:], 'big')) + +class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])): + """ + This whole class body is so that if you pass a hex-string as payment_hash, + it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings. + """ + __slots__ = () + def __new__(cls, *args, **kwargs): + if len(args) > 0: + args = list(args) + if type(args[1]) is str: + args[1] = bfh(args[1]) + return super().__new__(cls, *args) + if type(kwargs['payment_hash']) is str: + kwargs['payment_hash'] = bfh(kwargs['payment_hash']) + return super().__new__(cls, **kwargs) + diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -29,13 +29,14 @@ from .lntransport import LNResponderTransport from .lnbase import Peer from .lnaddr import lnencode, LnAddr, lndecode from .ecc import der_sig_from_sig_string -from .lnchan import Channel, ChannelJsonEncoder, UpdateAddHtlc +from .lnchan import Channel, ChannelJsonEncoder from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, get_compressed_pubkey_from_bech32, extract_nodeid, PaymentFailure, split_host_port, ConnStringFormatError, generate_keypair, LnKeyFamily, LOCAL, REMOTE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, - NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner) + NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner, + UpdateAddHtlc, Direction) from .i18n import _ from .lnrouter import RouteEdge, is_route_sane_to_use from .address_synchronizer import TX_HEIGHT_LOCAL @@ -66,7 +67,7 @@ class LNWorker(PrintError): def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'): self.wallet = wallet # invoices we are currently trying to pay (might be pending HTLCs on a commitment transaction) - self.paying = self.wallet.storage.get('lightning_payments_inflight', {}) # type: Dict[bytes, Tuple[str, Optional[int], bytes]] + self.paying = self.wallet.storage.get('lightning_payments_inflight', {}) # type: Dict[bytes, Tuple[str, Optional[int], str]] self.sweep_address = wallet.get_receiving_address() self.network = network self.channel_db = self.network.channel_db @@ -75,12 +76,15 @@ class LNWorker(PrintError): self.node_keypair = generate_keypair(self.ln_keystore, LnKeyFamily.NODE_KEY, 0) self.config = network.config self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer + self.invoices = wallet.storage.get('lightning_invoices', {}) # type: Dict[str, Tuple[str,str]] # RHASH -> (preimage, invoice) self.channels = {} # type: Dict[bytes, Channel] for x in wallet.storage.get("channels", []): c = Channel(x, sweep_address=self.sweep_address, payment_completed=self.payment_completed) - self.channels[c.channel_id] = c c.lnwatcher = network.lnwatcher - self.invoices = wallet.storage.get('lightning_invoices', {}) # type: Dict[str, Tuple[str,str]] # RHASH -> (preimage, invoice) + c.get_preimage_and_invoice = self.get_invoice + self.channels[c.channel_id] = c + c.set_remote_commitment() + c.set_local_commitment(c.current_commitment(LOCAL)) for chan_id, chan in self.channels.items(): self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) self._last_tried_peer = {} # LNPeerAddr -> unix timestamp @@ -116,6 +120,7 @@ class LNWorker(PrintError): self.print_error('saved lightning gossip timestamp') def payment_completed(self, chan, direction, htlc, preimage): + assert type(direction) is Direction chan_id = chan.channel_id if direction == SENT: assert htlc.payment_hash not in self.invoices @@ -166,6 +171,7 @@ class LNWorker(PrintError): unsettled = [] inflight = [] for date, direction, htlc, hex_preimage, hex_chan_id in completed: + direction = Direction(direction) if chan_id is not None: if bfh(hex_chan_id) != chan_id: continue @@ -175,12 +181,12 @@ class LNWorker(PrintError): else: preimage = bfh(hex_preimage) # FIXME use fromisoformat when minimum Python is 3.7 - settled.append((datetime.fromtimestamp(date, timezone.utc), HTLCOwner(direction), htlcobj, preimage)) + settled.append((datetime.fromtimestamp(date, timezone.utc), direction, htlcobj, preimage)) for preimage, pay_req in invoices.values(): addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP) unsettled.append((addr, bfh(preimage), pay_req)) for pay_req, amount_sat, this_chan_id in self.paying.values(): - if chan_id is not None and this_chan_id != chan_id: + if chan_id is not None and bfh(this_chan_id) != chan_id: continue addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP) if amount_sat is not None: @@ -194,7 +200,7 @@ class LNWorker(PrintError): def find_htlc_for_addr(self, addr, whitelist=None): channels = [y for x,y in self.channels.items() if x in whitelist or whitelist is None] for chan in channels: - for htlc in chan.log[LOCAL].adds.values(): + for htlc in chan.hm.log[LOCAL]['adds'].values(): if htlc.payment_hash == addr.paymenthash: return htlc @@ -319,7 +325,7 @@ class LNWorker(PrintError): self.print_error('they force closed', funding_outpoint) encumbered_sweeptxs = chan.remote_sweeptxs else: - self.print_error('not sure who closed', funding_outpoint) + self.print_error('not sure who closed', funding_outpoint, txid) return # sweep for prevout, spender in spenders.items(): @@ -456,7 +462,7 @@ class LNWorker(PrintError): break else: assert False, 'Found route with short channel ID we don\'t have: ' + repr(route[0].short_channel_id) - self.paying[bh2u(addr.paymenthash)] = (invoice, amount_sat, chan_id) + self.paying[bh2u(addr.paymenthash)] = (invoice, amount_sat, bh2u(chan_id)) self.wallet.storage.put('lightning_payments_inflight', self.paying) self.wallet.storage.write() return addr, peer, self._pay_to_route(route, addr) @@ -623,8 +629,8 @@ class LNWorker(PrintError): # we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels for channel_id, chan in self.channels.items(): yield { - 'local_htlcs': json.loads(encoder.encode(chan.log[LOCAL ]._asdict())), - 'remote_htlcs': json.loads(encoder.encode(chan.log[REMOTE]._asdict())), + 'local_htlcs': json.loads(encoder.encode(chan.hm.log[LOCAL ])), + 'remote_htlcs': json.loads(encoder.encode(chan.hm.log[REMOTE])), 'channel_id': bh2u(chan.short_channel_id), 'channel_point': chan.funding_outpoint.to_str(), 'state': chan.get_state(), diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py @@ -22,6 +22,7 @@ import unittest import os import binascii +from pprint import pformat from electrum import bitcoin from electrum import lnbase @@ -30,6 +31,7 @@ from electrum import lnutil from electrum import bip32 as bip32_utils from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED from electrum.ecc import sig_string_from_der_sig +from electrum.util import set_verbosity one_bitcoin_in_msat = bitcoin.COIN * 1000 @@ -54,9 +56,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5, max_accepted_htlcs=5, initial_msat=remote_amount, - ctn = 0, + ctn = -1, next_htlc_id = 0, - amount_msat=remote_amount, reserve_sat=0, next_per_commitment_point=nex, @@ -76,7 +77,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate initial_msat=local_amount, ctn = 0, next_htlc_id = 0, - amount_msat=local_amount, reserve_sat=0, per_commitment_secret_seed=seed, @@ -84,6 +84,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate was_announced=False, current_commitment_signature=None, current_htlc_signatures=None, + got_sig_for_next=False, ), "constraints":lnbase.ChannelConstraints( capacity=funding_sat, @@ -105,7 +106,7 @@ def bip32(sequence): return k def create_test_channels(feerate=6000, local=None, remote=None): - funding_txid = binascii.hexlify(os.urandom(32)).decode("ascii") + funding_txid = binascii.hexlify(b"\x01"*32).decode("ascii") funding_index = 0 funding_sat = ((local + remote) // 1000) if local is not None and remote is not None else (bitcoin.COIN * 10) local_amount = local if local is not None else (funding_sat * 1000 // 2) @@ -117,23 +118,52 @@ def create_test_channels(feerate=6000, local=None, remote=None): alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys] bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys] - alice_seed = os.urandom(32) - bob_seed = os.urandom(32) + alice_seed = b"\x01" * 32 + bob_seed = b"\x02" * 32 - alice_cur = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big")) - alice_next = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) - bob_cur = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big")) - bob_next = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) + alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big")) + bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big")) alice, bob = \ lnchan.Channel( - create_channel_state(funding_txid, funding_index, funding_sat, feerate, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, bob_cur, bob_next, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), name="alice"), \ + create_channel_state(funding_txid, funding_index, funding_sat, feerate, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), name="alice"), \ lnchan.Channel( - create_channel_state(funding_txid, funding_index, funding_sat, feerate, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, alice_cur, alice_next, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), name="bob") + create_channel_state(funding_txid, funding_index, funding_sat, feerate, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), name="bob") alice.set_state('OPEN') bob.set_state('OPEN') + a_out = alice.current_commitment(LOCAL).outputs() + b_out = bob.pending_commitment(REMOTE).outputs() + assert a_out == b_out, "\n" + pformat((a_out, b_out)) + + sig_from_bob, a_htlc_sigs = bob.sign_next_commitment() + sig_from_alice, b_htlc_sigs = alice.sign_next_commitment() + + assert len(a_htlc_sigs) == 0 + assert len(b_htlc_sigs) == 0 + + alice.config[LOCAL] = alice.config[LOCAL]._replace(current_commitment_signature=sig_from_bob) + bob.config[LOCAL] = bob.config[LOCAL]._replace(current_commitment_signature=sig_from_alice) + + alice.set_local_commitment(alice.current_commitment(LOCAL)) + bob.set_local_commitment(bob.current_commitment(LOCAL)) + + alice_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) + bob_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) + + alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first) + bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first) + + alice.set_remote_commitment() + bob.set_remote_commitment() + + alice.remote_commitment_to_be_revoked = alice.remote_commitment + bob.remote_commitment_to_be_revoked = bob.remote_commitment + + alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0) + bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0) + return alice, bob class TestFee(unittest.TestCase): @@ -141,11 +171,13 @@ class TestFee(unittest.TestCase): test https://github.com/lightningnetwork/lightning-rfc/blob/e0c436bd7a3ed6a028e1cb472908224658a14eca/03-transactions.md#requirements-2 """ - def test_SimpleAddSettleWorkflow(self): + def test_fee(self): alice_channel, bob_channel = create_test_channels(253, 10000000000, 5000000000) self.assertIn(9999817, [x[2] for x in alice_channel.local_commitment.outputs()]) class TestChannel(unittest.TestCase): + maxDiff = 999 + def assertOutputExistsByValue(self, tx, amt_sat): for typ, scr, val in tx.outputs(): if val == amt_sat: @@ -153,6 +185,10 @@ class TestChannel(unittest.TestCase): else: self.assertFalse() + @staticmethod + def setUpClass(): + set_verbosity(True) + def setUp(self): # Create a test channel which will be used for the duration of this # unittest. The channel will be funded evenly with Alice having 5 BTC, @@ -171,12 +207,15 @@ class TestChannel(unittest.TestCase): # update log. Then Alice sends this wire message over to Bob who adds # this htlc to his remote state update log. self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict) + self.assertNotEqual(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1), set()) before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE) beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL) self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict) + self.assertEqual(1, self.bob_channel.hm.log[LOCAL]['ctn'] + 1) + self.assertNotEqual(self.bob_channel.hm.htlcs_by_direction(LOCAL, RECEIVED, 1), set()) after = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE) afterLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL) @@ -185,7 +224,7 @@ class TestChannel(unittest.TestCase): self.bob_pending_remote_balance = after - self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0] + self.htlc = self.bob_channel.hm.log[REMOTE]['adds'][0] def test_concurrent_reversed_payment(self): self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02') @@ -193,32 +232,65 @@ class TestChannel(unittest.TestCase): bob_idx = self.bob_channel.add_htlc(self.htlc_dict) alice_idx = self.alice_channel.receive_htlc(self.htlc_dict) self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment()) - self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3) + self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4) def test_SimpleAddSettleWorkflow(self): alice_channel, bob_channel = self.alice_channel, self.bob_channel htlc = self.htlc + alice_out = alice_channel.current_commitment(LOCAL).outputs() + short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42] + long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62] + self.assertLess(alice_out[long_idx].value, 5 * 10**8, alice_out) + self.assertEqual(alice_out[short_idx].value, 5 * 10**8, alice_out) + + alice_out = alice_channel.current_commitment(REMOTE).outputs() + short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42] + long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62] + self.assertLess(alice_out[short_idx].value, 5 * 10**8) + self.assertEqual(alice_out[long_idx].value, 5 * 10**8) + + def com(): + return alice_channel.local_commitment + + self.assertTrue(alice_channel.signature_fits(com())) + + self.assertNotEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), []) self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) - self.assertEqual({0: [], 1: []}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) + self.assertNotEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), []) + self.assertEqual({0: [], 1: [htlc]}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({0: [], 1: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({0: [], 1: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) - # this wouldn't work since we put None in the remote_sig - # alice_channel.force_close_tx() + from electrum.lnutil import extract_ctn_from_tx_and_chan + tx0 = str(alice_channel.force_close_tx()) + self.assertEqual(alice_channel.config[LOCAL].ctn, 0) + self.assertEqual(extract_ctn_from_tx_and_chan(alice_channel.force_close_tx(), alice_channel), 0) + self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL))) # Next alice commits this change by sending a signature message. Since # we expect the messages to be ordered, Bob will receive the HTLC we # just sent before he receives this signature, so the signature will # cover the HTLC. aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment() - self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature") + self.assertTrue(alice_channel.signature_fits(com())) + self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com())) + + self.assertEqual(next(iter(alice_channel.hm.pending_htlcs(REMOTE)))[0], RECEIVED) + self.assertEqual(alice_channel.hm.pending_htlcs(REMOTE), bob_channel.hm.pending_htlcs(LOCAL)) + self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs()) + # Bob receives this signature message, and checks that this covers the # state he has in his remote log. This includes the HTLC just sent # from Alice. + self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) bob_channel.receive_new_commitment(aliceSig, aliceHtlcSigs) + self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL))) + + self.assertEqual(bob_channel.config[REMOTE].ctn, 0) + self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc]) self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) self.assertEqual({0: [], 1: [htlc]}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) @@ -228,31 +300,68 @@ class TestChannel(unittest.TestCase): # Bob revokes his prior commitment given to him by Alice, since he now # has a valid signature for a newer commitment. bobRevocation, _ = bob_channel.revoke_current_commitment() + bob_channel.serialize() + self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) - # Bob finally send a signature for Alice's commitment transaction. + # Bob finally sends a signature for Alice's commitment transaction. # This signature will cover the HTLC, since Bob will first send the # revocation just created. The revocation also acks every received - # HTLC up to the point where Alice sent here signature. + # HTLC up to the point where Alice sent her signature. bobSig, bobHtlcSigs = bob_channel.sign_next_commitment() + self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) + + self.assertEqual(len(bobHtlcSigs), 1) + + self.assertTrue(alice_channel.signature_fits(com())) + self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com())) + + self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 3) # Alice then processes this revocation, sending her own revocation for # her prior commitment transaction. Alice shouldn't have any HTLCs to # forward since she's sending an outgoing HTLC. alice_channel.receive_revocation(bobRevocation) + alice_channel.serialize() + self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs()) + + self.assertTrue(alice_channel.signature_fits(com())) + self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL))) + alice_channel.serialize() + self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com())) - # test serializing with locked_in htlc - self.assertEqual(len(alice_channel.to_save()['local_log']), 1) + self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2) + self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3) + self.assertEqual(len(com().outputs()), 2) + self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2) + + self.assertEqual(alice_channel.hm.log.keys(), set([LOCAL, REMOTE])) + self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1) alice_channel.serialize() + self.assertEqual(alice_channel.pending_commitment(LOCAL).outputs(), + bob_channel.pending_commitment(REMOTE).outputs()) + # Alice then processes bob's signature, and since she just received # the revocation, she expect this signature to cover everything up to # the point where she sent her signature, including the HTLC. alice_channel.receive_new_commitment(bobSig, bobHtlcSigs) + self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs()) + + self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3) + self.assertEqual(len(com().outputs()), 3) + self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3) + + self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1) + alice_channel.serialize() tx1 = str(alice_channel.force_close_tx()) + self.assertNotEqual(tx0, tx1) # Alice then generates a revocation for bob. + self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs()) aliceRevocation, _ = alice_channel.revoke_current_commitment() + alice_channel.serialize() + #self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs()) tx2 = str(alice_channel.force_close_tx()) # since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one) @@ -262,7 +371,9 @@ class TestChannel(unittest.TestCase): # is fully locked in within both commitment transactions. Bob should # also be able to forward an HTLC now that the HTLC has been locked # into both commitment transactions. + self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL))) bob_channel.receive_revocation(aliceRevocation) + bob_channel.serialize() # At this point, both sides should have the proper number of satoshis # sent, and commitment height updated within their local channel @@ -279,16 +390,19 @@ class TestChannel(unittest.TestCase): # Both commitment transactions should have three outputs, and one of # them should be exactly the amount of the HTLC. - self.assertEqual(len(alice_channel.local_commitment.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_channel.local_commitment.outputs())) - self.assertEqual(len(bob_channel.local_commitment.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_channel.local_commitment.outputs())) - self.assertOutputExistsByValue(alice_channel.local_commitment, htlc.amount_msat // 1000) - self.assertOutputExistsByValue(bob_channel.local_commitment, htlc.amount_msat // 1000) + alice_ctx = alice_channel.pending_commitment(LOCAL) + bob_ctx = bob_channel.pending_commitment(LOCAL) + self.assertEqual(len(alice_ctx.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_ctx.outputs())) + self.assertEqual(len(bob_ctx.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_ctx.outputs())) + self.assertOutputExistsByValue(alice_ctx, htlc.amount_msat // 1000) + self.assertOutputExistsByValue(bob_ctx, htlc.amount_msat // 1000) # Now we'll repeat a similar exchange, this time with Bob settling the # HTLC once he learns of the preimage. preimage = self.paymentPreimage bob_channel.settle_htlc(preimage, self.bobHtlcIndex) + #self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs()) alice_channel.receive_htlc_settle(preimage, self.aliceHtlcIndex) tx3 = str(alice_channel.force_close_tx()) @@ -296,28 +410,43 @@ class TestChannel(unittest.TestCase): self.assertEqual(tx2, tx3) bobSig2, bobHtlcSigs2 = bob_channel.sign_next_commitment() + self.assertEqual(len(bobHtlcSigs2), 0) + self.assertEqual(alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED), [htlc]) + self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, alice_channel.config[REMOTE].ctn), [htlc]) self.assertEqual({1: [htlc], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) self.assertEqual({1: [htlc], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({1: [], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({1: [], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) + alice_ctx_bob_version = bob_channel.pending_commitment(REMOTE).outputs() + alice_ctx_alice_version = alice_channel.pending_commitment(LOCAL).outputs() + self.assertEqual(alice_ctx_alice_version, alice_ctx_bob_version) + alice_channel.receive_new_commitment(bobSig2, bobHtlcSigs2) tx4 = str(alice_channel.force_close_tx()) self.assertNotEqual(tx3, tx4) + self.assertEqual(alice_channel.balance(LOCAL), 500000000000) + self.assertEqual(1, alice_channel.config[LOCAL].ctn) + self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0) aliceRevocation2, _ = alice_channel.revoke_current_commitment() + alice_channel.serialize() aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment() self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures") - + self.assertEqual(len(bob_channel.current_commitment(LOCAL).outputs()), 3) + self.assertEqual(len(bob_channel.pending_commitment(LOCAL).outputs()), 2) received, sent = bob_channel.receive_revocation(aliceRevocation2) + bob_channel.serialize() self.assertEqual(received, one_bitcoin_in_msat) bob_channel.receive_new_commitment(aliceSig2, aliceHtlcSigs2) bobRevocation2, _ = bob_channel.revoke_current_commitment() + bob_channel.serialize() alice_channel.receive_revocation(bobRevocation2) + alice_channel.serialize() # At this point, Bob should have 6 BTC settled, with Alice still having # 4 BTC. Alice's channel should show 1 BTC sent and Bob's channel @@ -331,15 +460,15 @@ class TestChannel(unittest.TestCase): self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height") self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height") - # The logs of both sides should now be cleared since the entry adding - # the HTLC should have been removed once both sides receive the - # revocation. - #self.assertEqual(alice_channel.local_update_log, [], "alice's local not updated, should be empty, has %s entries instead"% len(alice_channel.local_update_log)) - #self.assertEqual(alice_channel.remote_update_log, [], "alice's remote not updated, should be empty, has %s entries instead"% len(alice_channel.remote_update_log)) self.assertEqual(self.bob_pending_remote_balance, self.alice_channel.balance(LOCAL)) alice_channel.update_fee(100000, True) + alice_outputs = alice_channel.pending_commitment(REMOTE).outputs() + old_outputs = bob_channel.pending_commitment(LOCAL).outputs() bob_channel.update_fee(100000, False) + new_outputs = bob_channel.pending_commitment(LOCAL).outputs() + self.assertNotEqual(old_outputs, new_outputs) + self.assertEqual(alice_outputs, new_outputs) tx5 = str(alice_channel.force_close_tx()) # sending a fee update does not change her force close tx @@ -353,10 +482,17 @@ class TestChannel(unittest.TestCase): self.htlc_dict['amount_msat'] *= 5 bob_index = bob_channel.add_htlc(self.htlc_dict) alice_index = alice_channel.receive_htlc(self.htlc_dict) - force_state_transition(alice_channel, bob_channel) + + bob_channel.pending_commitment(REMOTE) + alice_channel.pending_commitment(LOCAL) + + alice_channel.pending_commitment(REMOTE) + bob_channel.pending_commitment(LOCAL) + + force_state_transition(bob_channel, alice_channel) alice_channel.settle_htlc(self.paymentPreimage, alice_index) bob_channel.receive_htlc_settle(self.paymentPreimage, bob_index) - force_state_transition(alice_channel, bob_channel) + force_state_transition(bob_channel, alice_channel) self.assertEqual(alice_channel.total_msat(SENT), one_bitcoin_in_msat, "alice satoshis sent incorrect") self.assertEqual(alice_channel.total_msat(RECEIVED), 5 * one_bitcoin_in_msat, "alice satoshis received incorrect") self.assertEqual(bob_channel.total_msat(RECEIVED), one_bitcoin_in_msat, "bob satoshis received incorrect") @@ -366,8 +502,15 @@ class TestChannel(unittest.TestCase): def alice_to_bob_fee_update(self, fee=111): + aoldctx = self.alice_channel.pending_commitment(REMOTE).outputs() self.alice_channel.update_fee(fee, True) + anewctx = self.alice_channel.pending_commitment(REMOTE).outputs() + self.assertNotEqual(aoldctx, anewctx) + boldctx = self.bob_channel.pending_commitment(LOCAL).outputs() self.bob_channel.update_fee(fee, False) + bnewctx = self.bob_channel.pending_commitment(LOCAL).outputs() + self.assertNotEqual(boldctx, bnewctx) + self.assertEqual(anewctx, bnewctx) return fee def test_UpdateFeeSenderCommits(self): @@ -444,7 +587,7 @@ class TestChannel(unittest.TestCase): # value 2 BTC, which should make Alice's balance negative (since she # has to pay a commitment fee). new = dict(self.htlc_dict) - new['amount_msat'] *= 2 + new['amount_msat'] *= 2.5 new['payment_hash'] = bitcoin.sha256(32 * b'\x04') with self.assertRaises(lnutil.PaymentFailure) as cm: self.alice_channel.add_htlc(new) @@ -462,7 +605,6 @@ class TestChannel(unittest.TestCase): except: try: from deepdiff import DeepDiff - from pprint import pformat except ImportError: raise raise Exception(pformat(DeepDiff(before_signing, after_signing))) @@ -549,9 +691,9 @@ class TestChanReserve(unittest.TestCase): force_state_transition(self.alice_channel, self.bob_channel) aliceSelfBalance = self.alice_channel.balance(LOCAL)\ - - lnchan.htlcsum(self.alice_channel.htlcs(LOCAL, True)) + - lnchan.htlcsum(self.alice_channel.hm.htlcs_by_direction(LOCAL, SENT)) bobBalance = self.bob_channel.balance(REMOTE)\ - - lnchan.htlcsum(self.alice_channel.htlcs(REMOTE, True)) + - lnchan.htlcsum(self.alice_channel.hm.htlcs_by_direction(REMOTE, SENT)) self.assertEqual(aliceSelfBalance, one_bitcoin_in_msat*4.5) self.assertEqual(bobBalance, one_bitcoin_in_msat*5) # Now let Bob try to add an HTLC. This should fail, since it will @@ -647,17 +789,22 @@ class TestDust(unittest.TestCase): 'cltv_expiry' : 5, # also in create_test_channels } + old_values = [x.value for x in bob_channel.current_commitment(LOCAL).outputs() ] aliceHtlcIndex = alice_channel.add_htlc(htlc) bobHtlcIndex = bob_channel.receive_htlc(htlc) force_state_transition(alice_channel, bob_channel) - self.assertEqual(len(alice_channel.local_commitment.outputs()), 3) - self.assertEqual(len(bob_channel.local_commitment.outputs()), 2) + alice_ctx = alice_channel.current_commitment(LOCAL) + bob_ctx = bob_channel.current_commitment(LOCAL) + new_values = [x.value for x in bob_ctx.outputs() ] + self.assertNotEqual(old_values, new_values) + self.assertEqual(len(alice_ctx.outputs()), 3) + self.assertEqual(len(bob_ctx.outputs()), 2) default_fee = calc_static_fee(0) self.assertEqual(bob_channel.pending_local_fee(), default_fee + htlcAmt) bob_channel.settle_htlc(paymentPreimage, bobHtlcIndex) alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex) force_state_transition(bob_channel, alice_channel) - self.assertEqual(len(alice_channel.local_commitment.outputs()), 2) + self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2) self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt) def force_state_transition(chanA, chanB): diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py @@ -0,0 +1,95 @@ +import unittest +from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner +from electrum.lnhtlc import HTLCManager +from typing import NamedTuple + +class H(NamedTuple): + owner : str + htlc_id : int + +class TestHTLCManager(unittest.TestCase): + def test_race(self): + A = HTLCManager() + B = HTLCManager() + ah0, bh0 = H('A', 0), H('B', 0) + B.recv_htlc(A.send_htlc(ah0)) + self.assertTrue(B.expect_sig[RECEIVED]) + self.assertTrue(A.expect_sig[SENT]) + self.assertFalse(B.expect_sig[SENT]) + self.assertFalse(A.expect_sig[RECEIVED]) + self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1) + A.recv_htlc(B.send_htlc(bh0)) + self.assertTrue(B.expect_sig[RECEIVED]) + self.assertTrue(A.expect_sig[SENT]) + self.assertTrue(A.expect_sig[SENT]) + self.assertTrue(B.expect_sig[RECEIVED]) + self.assertEqual(B.current_htlcs(LOCAL), []) + self.assertEqual(A.current_htlcs(LOCAL), []) + self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0)]) + self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0)]) + A.send_ctx() + B.recv_ctx() + B.send_ctx() + A.recv_ctx() + self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) + self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) + B.send_rev() + A.recv_rev() + A.send_rev() + B.recv_rev() + self.assertEqual(B.current_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1]) + self.assertEqual(A.current_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1]) + + def test_no_race(self): + A = HTLCManager() + B = HTLCManager() + B.recv_htlc(A.send_htlc(H('A', 0))) + self.assertEqual(len(B.pending_htlcs(REMOTE)), 1) + A.send_ctx() + B.recv_ctx() + B.send_rev() + A.recv_rev() + B.send_ctx() + A.recv_ctx() + A.send_rev() + B.recv_rev() + self.assertEqual(len(A.current_htlcs(LOCAL)), 1) + self.assertEqual(len(B.current_htlcs(LOCAL)), 1) + B.send_settle(0) + A.recv_settle(0) + self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)]) + self.assertNotEqual(A.current_htlcs(LOCAL), []) + self.assertNotEqual(B.current_htlcs(REMOTE), []) + self.assertEqual(A.pending_htlcs(LOCAL), []) + self.assertEqual(B.pending_htlcs(REMOTE), []) + B.send_ctx() + A.recv_ctx() + A.send_rev() + B.recv_rev() + A.send_ctx() + B.recv_ctx() + B.send_rev() + A.recv_rev() + self.assertEqual(B.current_htlcs(LOCAL), []) + self.assertEqual(A.current_htlcs(LOCAL), []) + self.assertEqual(A.current_htlcs(REMOTE), []) + self.assertEqual(B.current_htlcs(REMOTE), []) + self.assertEqual(len(A.settled_htlcs(LOCAL)), 1) + self.assertEqual(len(A.sent_in_ctn(2)), 1) + self.assertEqual(len(B.received_in_ctn(2)), 1) + + def test_settle_while_owing(self): + A = HTLCManager() + B = HTLCManager() + B.recv_htlc(A.send_htlc(H('A', 0))) + A.send_ctx() + B.recv_ctx() + B.send_rev() + A.recv_rev() + B.send_settle(0) + A.recv_settle(0) + self.assertEqual(B.pending_htlcs(REMOTE), []) + B.send_ctx() + A.recv_ctx() + A.send_rev() + B.recv_rev() diff --git a/electrum/tests/test_lnutil.py b/electrum/tests/test_lnutil.py @@ -6,8 +6,7 @@ from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_see make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey, derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret, get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError, - ScriptHtlc, extract_nodeid, calc_onchain_fees) -from electrum import lnchan + ScriptHtlc, extract_nodeid, calc_onchain_fees, UpdateAddHtlc) from electrum.util import bh2u, bfh from electrum.transaction import Transaction @@ -496,7 +495,7 @@ class TestLNUtil(unittest.TestCase): (1, 2000 * 1000), (3, 3000 * 1000), (4, 4000 * 1000)]: - htlc_obj[num] = lnchan.UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None) + htlc_obj[num] = UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None) htlcs = [ScriptHtlc(htlc[x], htlc_obj[x]) for x in range(5)] our_commit_tx = make_commitment( @@ -506,7 +505,7 @@ class TestLNUtil(unittest.TestCase): local_revocation_pubkey, local_delayedpubkey, local_delay, funding_tx_id, funding_output_index, funding_amount_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi, - calc_onchain_fees(len(htlcs), local_feerate_per_kw, True, we_are_initiator=True), htlcs=htlcs) + calc_onchain_fees(len(htlcs), local_feerate_per_kw, True), htlcs=htlcs) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) self.assertEqual(str(our_commit_tx), output_commit_tx) @@ -584,7 +583,7 @@ class TestLNUtil(unittest.TestCase): local_revocation_pubkey, local_delayedpubkey, local_delay, funding_tx_id, funding_output_index, funding_amount_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi, - calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[]) + calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[]) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) self.assertEqual(str(our_commit_tx), output_commit_tx) @@ -603,7 +602,7 @@ class TestLNUtil(unittest.TestCase): local_revocation_pubkey, local_delayedpubkey, local_delay, funding_tx_id, funding_output_index, funding_amount_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi, - calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[]) + calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[]) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) self.assertEqual(str(our_commit_tx), output_commit_tx) @@ -661,7 +660,7 @@ class TestLNUtil(unittest.TestCase): local_revocation_pubkey, local_delayedpubkey, local_delay, funding_tx_id, funding_output_index, funding_amount_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi, - calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[]) + calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[]) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) ref_commit_tx_str = '02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8002c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de84311054a56a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e0400473044022051b75c73198c6deee1a875871c3961832909acd297c6b908d59e3319e5185a46022055c419379c5051a78d00dbbce11b5b664a0c22815fbcc6fcef6b1937c383693901483045022100f51d2e566a70ba740fc5d8c0f07b9b93d2ed741c3c0860c613173de7d39e7968022041376d520e9c0e1ad52248ddf4b22e12be8763007df977253ef45a4ca3bdb7c001475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220' self.assertEqual(str(our_commit_tx), ref_commit_tx_str)