commit e56e8495059d526b6151839dcf2a7caddb5a8d18
parent ef88bb1c286db6c7ca29754a750c499cb90dbb82
Author: Janus <ysangkok@gmail.com>
Date: Mon, 21 Jan 2019 21:27:27 +0100
lnchan refactor
- replace undoing logic with new HTLCManager class
- separate SENT/RECEIVED
- move UpdateAddHtlc to lnutil
Diffstat:
11 files changed, 705 insertions(+), 397 deletions(-)
diff --git a/electrum/gui/qt/channel_details.py b/electrum/gui/qt/channel_details.py
@@ -5,9 +5,9 @@ import PyQt5.QtWidgets as QtWidgets
import PyQt5.QtCore as QtCore
from electrum.i18n import _
-from electrum.lnchan import UpdateAddHtlc, HTLCOwner
from electrum.util import bh2u, format_time
-from electrum.lnutil import format_short_channel_id, SENT, RECEIVED
+from electrum.lnutil import format_short_channel_id, LOCAL, REMOTE, UpdateAddHtlc, Direction
+from electrum.lnchan import htlcsum
from electrum.lnaddr import LnAddr, lndecode
from electrum.bitcoin import COIN
@@ -30,8 +30,8 @@ class LinkedLabel(QtWidgets.QLabel):
self.linkActivated.connect(on_clicked)
class ChannelDetailsDialog(QtWidgets.QDialog):
- def make_htlc_item(self, i: UpdateAddHtlc, direction: HTLCOwner) -> HTLCItem:
- it = HTLCItem(_('Sent HTLC with ID {}' if SENT == direction else 'Received HTLC with ID {}').format(i.htlc_id))
+ def make_htlc_item(self, i: UpdateAddHtlc, direction: Direction) -> HTLCItem:
+ it = HTLCItem(_('Sent HTLC with ID {}' if Direction.SENT == direction else 'Received HTLC with ID {}').format(i.htlc_id))
it.appendRow([HTLCItem(_('Amount')),HTLCItem(self.format(i.amount_msat))])
it.appendRow([HTLCItem(_('CLTV expiry')),HTLCItem(str(i.cltv_expiry))])
it.appendRow([HTLCItem(_('Payment hash')),HTLCItem(bh2u(i.payment_hash))])
@@ -45,7 +45,7 @@ class ChannelDetailsDialog(QtWidgets.QDialog):
invoice.appendRow([HTLCItem(_('Date')), HTLCItem(format_time(lnaddr.date))])
it.appendRow([invoice])
- def make_inflight(self, lnaddr, i: UpdateAddHtlc, direction: HTLCOwner) -> HTLCItem:
+ def make_inflight(self, lnaddr, i: UpdateAddHtlc, direction: Direction) -> HTLCItem:
it = self.make_htlc_item(i, direction)
self.append_lnaddr(it, lnaddr)
return it
@@ -99,23 +99,23 @@ class ChannelDetailsDialog(QtWidgets.QDialog):
dest_mapping = self.keyname_rows[to]
dest_mapping[payment_hash] = len(dest_mapping)
- ln_payment_completed = QtCore.pyqtSignal(str, float, HTLCOwner, UpdateAddHtlc, bytes, bytes)
- htlc_added = QtCore.pyqtSignal(str, UpdateAddHtlc, LnAddr, HTLCOwner)
+ ln_payment_completed = QtCore.pyqtSignal(str, float, Direction, UpdateAddHtlc, bytes, bytes)
+ htlc_added = QtCore.pyqtSignal(str, UpdateAddHtlc, LnAddr, Direction)
- @QtCore.pyqtSlot(str, UpdateAddHtlc, LnAddr, HTLCOwner)
+ @QtCore.pyqtSlot(str, UpdateAddHtlc, LnAddr, Direction)
def do_htlc_added(self, evtname, htlc, lnaddr, direction):
mapping = self.keyname_rows['inflight']
mapping[htlc.payment_hash] = len(mapping)
self.folders['inflight'].appendRow(self.make_inflight(lnaddr, htlc, direction))
- @QtCore.pyqtSlot(str, float, HTLCOwner, UpdateAddHtlc, bytes, bytes)
+ @QtCore.pyqtSlot(str, float, Direction, UpdateAddHtlc, bytes, bytes)
def do_ln_payment_completed(self, evtname, date, direction, htlc, preimage, chan_id):
self.move('inflight', 'settled', htlc.payment_hash)
self.update_sent_received()
def update_sent_received(self):
- self.sent_label.setText(str(sum(self.chan.settled[SENT])))
- self.received_label.setText(str(sum(self.chan.settled[RECEIVED])))
+ self.sent_label.setText(str(htlcsum(self.hm.settled_htlcs_by(LOCAL))))
+ self.received_label.setText(str(htlcsum(self.hm.settled_htlcs_by(REMOTE))))
@QtCore.pyqtSlot(str)
def show_tx(self, link_text: str):
diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py
@@ -30,8 +30,9 @@ class ChannelsList(MyTreeView):
for subject in (REMOTE, LOCAL):
bal_minus_htlcs = chan.balance_minus_outgoing_htlcs(subject)//1000
label = self.parent.format_amount(bal_minus_htlcs)
- bal_other = chan.balance(-subject)//1000
- bal_minus_htlcs_other = chan.balance_minus_outgoing_htlcs(-subject)//1000
+ other = subject.inverted()
+ bal_other = chan.balance(other)//1000
+ bal_minus_htlcs_other = chan.balance_minus_outgoing_htlcs(other)//1000
if bal_other != bal_minus_htlcs_other:
label += ' (+' + self.parent.format_amount(bal_other - bal_minus_htlcs_other) + ')'
labels[subject] = label
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -25,8 +25,8 @@ from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabl
from .transaction import Transaction, TxOutput
from .lnonion import (new_onion_packet, decode_onion_error, OnionFailureCode, calc_hops_data_for_payment,
process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage)
-from .lnchan import Channel, RevokeAndAck, htlcsum, UpdateAddHtlc
-from .lnutil import (Outpoint, LocalConfig, RECEIVED,
+from .lnchan import Channel, RevokeAndAck, htlcsum
+from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore,
funding_output_script, get_per_commitment_secret_from_seed,
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
@@ -397,20 +397,20 @@ class Peer(PrintError):
htlc_basepoint=keypair_generator(LnKeyFamily.HTLC_BASE),
delayed_basepoint=keypair_generator(LnKeyFamily.DELAY_BASE),
revocation_basepoint=keypair_generator(LnKeyFamily.REVOCATION_BASE),
- to_self_delay=143,
+ to_self_delay=9,
dust_limit_sat=546,
max_htlc_value_in_flight_msat=0xffffffffffffffff,
max_accepted_htlcs=5,
initial_msat=initial_msat,
ctn=-1,
next_htlc_id=0,
- amount_msat=initial_msat,
reserve_sat=546,
per_commitment_secret_seed=keypair_generator(LnKeyFamily.REVOCATION_ROOT).privkey,
funding_locked_received=False,
was_announced=False,
current_commitment_signature=None,
current_htlc_signatures=[],
+ got_sig_for_next=False,
)
return local_config
@@ -472,7 +472,6 @@ class Peer(PrintError):
max_accepted_htlcs=int.from_bytes(payload["max_accepted_htlcs"], 'big'),
initial_msat=push_msat,
ctn = -1,
- amount_msat=push_msat,
next_htlc_id = 0,
reserve_sat = remote_reserve_sat,
@@ -517,9 +516,11 @@ class Peer(PrintError):
# broadcast funding tx
await self.network.broadcast_transaction(funding_tx)
chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE)
- chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0)
- chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig)
+ chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_per_commitment_point, next_per_commitment_point=None)
+ chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig, got_sig_for_next=False)
chan.set_state('OPENING')
+ chan.set_remote_commitment()
+ chan.set_local_commitment(chan.current_commitment(LOCAL))
return chan
async def on_open_channel(self, payload):
@@ -579,7 +580,6 @@ class Peer(PrintError):
max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'),
initial_msat=remote_balance_sat,
ctn = -1,
- amount_msat=remote_balance_sat,
next_htlc_id = 0,
reserve_sat = remote_reserve_sat,
@@ -605,7 +605,7 @@ class Peer(PrintError):
)
chan.set_state('OPENING')
chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE)
- chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0)
+ chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0, current_per_commitment_point=payload['first_per_commitment_point'], next_per_commitment_point=None)
chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig)
self.lnworker.save_channel(chan)
self.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str())
@@ -732,7 +732,7 @@ class Peer(PrintError):
if not chan.config[LOCAL].funding_locked_received:
our_next_point = chan.config[REMOTE].next_per_commitment_point
their_next_point = payload["next_per_commitment_point"]
- new_remote_state = chan.config[REMOTE]._replace(next_per_commitment_point=their_next_point, current_per_commitment_point=our_next_point)
+ new_remote_state = chan.config[REMOTE]._replace(next_per_commitment_point=their_next_point)
new_local_state = chan.config[LOCAL]._replace(funding_locked_received = True)
chan.config[REMOTE]=new_remote_state
chan.config[LOCAL]=new_local_state
diff --git a/electrum/lnchan.py b/electrum/lnchan.py
@@ -27,24 +27,25 @@ import binascii
import json
from enum import Enum, auto
from typing import Optional, Dict, List, Tuple, NamedTuple, Set, Callable, Iterable, Sequence
-from copy import deepcopy
+from . import ecc
from .util import bfh, PrintError, bh2u
from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS
from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d
-from . import ecc
-from .lnutil import Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore
-from .lnutil import get_per_commitment_secret_from_seed
-from .lnutil import secret_to_pubkey, derive_privkey, derive_pubkey, derive_blinded_pubkey
-from .lnutil import sign_and_get_sig_string
-from .lnutil import make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc
-from .lnutil import HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT
-from .lnutil import funding_output_script, LOCAL, REMOTE, HTLCOwner, make_closing_tx, make_commitment_outputs
-from .lnutil import ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script
+from .simple_config import get_config
from .transaction import Transaction
+
+from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints,
+ get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx,
+ sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey,
+ make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc,
+ HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc,
+ funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
+ ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script)
from .lnsweep import create_sweeptxs_for_their_just_revoked_ctx
from .lnsweep import create_sweeptxs_for_our_latest_ctx, create_sweeptxs_for_their_latest_ctx
+from .lnhtlc import HTLCManager
class ChannelJsonEncoder(json.JSONEncoder):
@@ -83,22 +84,6 @@ class FeeUpdate(defaultdict):
return self.rate
# implicit return None
-class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
- """
- This whole class body is so that if you pass a hex-string as payment_hash,
- it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
- """
- __slots__ = ()
- def __new__(cls, *args, **kwargs):
- if len(args) > 0:
- args = list(args)
- if type(args[1]) is str:
- args[1] = bfh(args[1])
- return super().__new__(cls, *args)
- if type(kwargs['payment_hash']) is str:
- kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
- return super().__new__(cls, **kwargs)
-
def decodeAll(d, local):
for k, v in d.items():
if k == 'revocation_store':
@@ -124,20 +109,6 @@ def str_bytes_dict_from_save(x):
def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()}
-class HtlcChanges(NamedTuple):
- # ints are htlc ids
- adds: Dict[int, UpdateAddHtlc]
- settles: Set[int]
- fails: Set[int]
- locked_in: Set[int]
-
- @staticmethod
- def new():
- """
- Since we can't use default arguments for these types (they would be shared among instances)
- """
- return HtlcChanges({}, set(), set(), set())
-
class Channel(PrintError):
def diagnostic_name(self):
if self.name:
@@ -147,7 +118,7 @@ class Channel(PrintError):
except:
return super().diagnostic_name()
- def __init__(self, state, sweep_address = None, name = None, payment_completed : Optional[Callable[[HTLCOwner, UpdateAddHtlc, bytes], None]] = None):
+ def __init__(self, state, sweep_address = None, name = None, payment_completed : Optional[Callable[[Direction, UpdateAddHtlc, bytes], None]] = None):
self.preimages = {}
if not payment_completed:
payment_completed = lambda this, x, y, z: None
@@ -179,13 +150,9 @@ class Channel(PrintError):
# we should not persist txns in this format. we should persist htlcs, and be able to derive
# any past commitment transaction and use that instead; until then...
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
+ self.remote_commitment_to_be_revoked.deserialize(True)
- self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()}
- for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
- if strname not in state: continue
- for y in state[strname]:
- htlc = UpdateAddHtlc(**y)
- self.log[subject].adds[htlc.htlc_id] = htlc
+ self.hm = HTLCManager(state.get('log'))
self.name = name
@@ -200,23 +167,18 @@ class Channel(PrintError):
self.lnwatcher = None
- self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
-
- for sub in (LOCAL, REMOTE):
- self.log[sub].locked_in.update(self.log[sub].adds.keys())
-
- self.set_local_commitment(self.current_commitment(LOCAL))
- self.set_remote_commitment(self.current_commitment(REMOTE))
+ self.local_commitment = None
+ self.remote_commitment = None
def set_local_commitment(self, ctx):
+ ctn = extract_ctn_from_tx_and_chan(ctx, self)
+ assert self.signature_fits(ctx), (self.log[LOCAL])
self.local_commitment = ctx
if self.sweep_address is not None:
self.local_sweeptxs = create_sweeptxs_for_our_latest_ctx(self, self.local_commitment, self.sweep_address)
- self.assert_signature_fits(ctx)
-
- def set_remote_commitment(self, ctx):
- self.remote_commitment = ctx
+ def set_remote_commitment(self):
+ self.remote_commitment = self.current_commitment(REMOTE)
if self.sweep_address is not None:
self.remote_sweeptxs = create_sweeptxs_for_their_latest_ctx(self, self.remote_commitment, self.sweep_address)
@@ -233,9 +195,9 @@ class Channel(PrintError):
raise PaymentFailure('Channel not open')
if self.available_to_spend(LOCAL) < amount_msat:
raise PaymentFailure(f'Not enough local balance. Have: {self.available_to_spend(LOCAL)}, Need: {amount_msat}')
- if len(self.htlcs(LOCAL, only_pending=True)) + 1 > self.config[REMOTE].max_accepted_htlcs:
+ if len(self.hm.htlcs(LOCAL)) + 1 > self.config[REMOTE].max_accepted_htlcs:
raise PaymentFailure('Too many HTLCs already in channel')
- current_htlc_sum = htlcsum(self.htlcs(LOCAL, only_pending=True))
+ current_htlc_sum = htlcsum(self.hm.htlcs_by_direction(LOCAL, SENT)) + htlcsum(self.hm.htlcs_by_direction(LOCAL, RECEIVED))
if current_htlc_sum + amount_msat > self.config[REMOTE].max_htlc_value_in_flight_msat:
raise PaymentFailure(f'HTLC value sum (sum of pending htlcs: {current_htlc_sum/1000} sat plus new htlc: {amount_msat/1000} sat) would exceed max allowed: {self.config[REMOTE].max_htlc_value_in_flight_msat/1000} sat')
if amount_msat <= 0: # FIXME htlc_minimum_msat
@@ -269,7 +231,7 @@ class Channel(PrintError):
assert type(htlc) is dict
self._check_can_pay(htlc['amount_msat'])
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
- self.log[LOCAL].adds[htlc.htlc_id] = htlc
+ self.hm.send_htlc(htlc)
self.print_error("add_htlc")
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@@ -288,8 +250,7 @@ class Channel(PrintError):
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
f' Available at remote: {self.available_to_spend(REMOTE)},' +\
f' HTLC amount: {htlc.amount_msat}')
- adds = self.log[REMOTE].adds
- adds[htlc.htlc_id] = htlc
+ self.hm.recv_htlc(htlc)
self.print_error("receive_htlc")
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id
@@ -308,7 +269,7 @@ class Channel(PrintError):
"""
self.print_error("sign_next_commitment")
- old_logs = dict(self.lock_in_htlc_changes(LOCAL))
+ self.hm.send_ctx()
pending_remote_commitment = self.pending_commitment(REMOTE)
sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
@@ -321,7 +282,8 @@ class Channel(PrintError):
for_us = False
htlcsigs = []
- for we_receive, htlcs in zip([True, False], [self.included_htlcs(REMOTE, REMOTE), self.included_htlcs(REMOTE, LOCAL)]):
+ # they sent => we receive
+ for we_receive, htlcs in zip([True, False], [self.included_htlcs(REMOTE, SENT, ctn=self.config[REMOTE].ctn+1), self.included_htlcs(REMOTE, RECEIVED, ctn=self.config[REMOTE].ctn+1)]):
for htlc in htlcs:
_script, htlc_tx = make_htlc_tx_with_open_channel(chan=self,
pcp=self.config[REMOTE].next_per_commitment_point,
@@ -337,26 +299,11 @@ class Channel(PrintError):
htlcsigs.sort()
htlcsigs = [x[1] for x in htlcsigs]
- self.remote_commitment = self.pending_commitment(REMOTE)
-
- # we can't know if this message arrives.
- # since we shouldn't actually throw away
- # failed htlcs yet (or mark htlc locked in),
- # roll back the changes that were made
- self.log = old_logs
+ # TODO should add remote_commitment here and handle
+ # both valid ctx'es in lnwatcher at the same time...
return sig_64, htlcsigs
- def lock_in_htlc_changes(self, subject):
- for sub in (LOCAL, REMOTE):
- log = self.log[sub]
- yield (sub, deepcopy(log))
- for htlc_id in log.fails:
- log.adds.pop(htlc_id)
- log.fails.clear()
-
- self.log[subject].locked_in.update(self.log[subject].adds.keys())
-
def receive_new_commitment(self, sig, htlc_sigs):
"""
ReceiveNewCommitment process a signature for a new commitment state sent by
@@ -372,7 +319,7 @@ class Channel(PrintError):
"""
self.print_error("receive_new_commitment")
- for _ in self.lock_in_htlc_changes(REMOTE): pass
+ self.hm.recv_ctx()
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
@@ -385,16 +332,18 @@ class Channel(PrintError):
htlc_sigs_string = b''.join(htlc_sigs)
htlc_sigs = htlc_sigs[:] # copy cause we will delete now
- for htlcs, we_receive in [(self.included_htlcs(LOCAL, REMOTE), True), (self.included_htlcs(LOCAL, LOCAL), False)]:
+ ctn = self.config[LOCAL].ctn+1
+ for htlcs, we_receive in [(self.included_htlcs(LOCAL, SENT, ctn=ctn), False), (self.included_htlcs(LOCAL, RECEIVED, ctn=ctn), True)]:
for htlc in htlcs:
- idx = self.verify_htlc(htlc, htlc_sigs, we_receive)
+ idx = self.verify_htlc(htlc, htlc_sigs, we_receive, pending_local_commitment)
del htlc_sigs[idx]
if len(htlc_sigs) != 0: # all sigs should have been popped above
raise Exception('failed verifying HTLC signatures: invalid amount of correct signatures')
self.config[LOCAL]=self.config[LOCAL]._replace(
current_commitment_signature=sig,
- current_htlc_signatures=htlc_sigs_string)
+ current_htlc_signatures=htlc_sigs_string,
+ got_sig_for_next=True)
if self.pending_fee is not None:
if not self.constraints.is_initiator:
@@ -402,15 +351,15 @@ class Channel(PrintError):
if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]:
self.pending_fee[FUNDER_SIGNED] = True
- self.set_local_commitment(self.pending_commitment(LOCAL))
+ self.set_local_commitment(pending_local_commitment)
- def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool) -> int:
- _, this_point, _ = self.points()
+ def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int:
+ _, this_point, _, _ = self.points()
_script, htlc_tx = make_htlc_tx_with_open_channel(chan=self,
pcp=this_point,
for_us=True,
we_receive=we_receive,
- commit=self.pending_commitment(LOCAL),
+ commit=ctx,
htlc=htlc)
pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0)))
remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, this_point)
@@ -418,19 +367,19 @@ class Channel(PrintError):
if ecc.verify_signature(remote_htlc_pubkey, sig, pre_hash):
return idx
else:
- raise Exception(f'failed verifying HTLC signatures: {htlc}')
+ raise Exception(f'failed verifying HTLC signatures: {htlc}, sigs: {len(htlc_sigs)}, we_receive: {we_receive}')
- def get_remote_htlc_sig_for_htlc(self, htlc: UpdateAddHtlc, we_receive: bool) -> bytes:
+ def get_remote_htlc_sig_for_htlc(self, htlc: UpdateAddHtlc, we_receive: bool, ctx) -> bytes:
data = self.config[LOCAL].current_htlc_signatures
htlc_sigs = [data[i:i + 64] for i in range(0, len(data), 64)]
- idx = self.verify_htlc(htlc, htlc_sigs, we_receive=we_receive)
+ idx = self.verify_htlc(htlc, htlc_sigs, we_receive=we_receive, ctx=ctx)
remote_htlc_sig = ecc.der_sig_from_sig_string(htlc_sigs[idx]) + b'\x01'
return remote_htlc_sig
def revoke_current_commitment(self):
self.print_error("revoke_current_commitment")
- last_secret, this_point, next_point = self.points()
+ last_secret, this_point, next_point, _ = self.points()
new_feerate = self.constraints.feerate
@@ -444,16 +393,18 @@ class Channel(PrintError):
self.pending_fee = None
print("FEERATE CHANGE COMPLETE (initiator)")
- self.config[LOCAL]=self.config[LOCAL]._replace(
- ctn=self.config[LOCAL].ctn + 1,
- )
+ assert self.config[LOCAL].got_sig_for_next
self.constraints=self.constraints._replace(
feerate=new_feerate
)
-
- # since we should not revoke our latest commitment tx,
- # we do not update self.local_commitment here,
- # it should instead be updated when we receive a new sig
+ self.set_local_commitment(self.pending_commitment(LOCAL))
+ ctx = self.pending_commitment(LOCAL)
+ self.hm.send_rev()
+ self.config[LOCAL]=self.config[LOCAL]._replace(
+ ctn=self.config[LOCAL].ctn + 1,
+ got_sig_for_next=False,
+ )
+ assert self.signature_fits(ctx)
return RevokeAndAck(last_secret, next_point), "current htlcs"
@@ -466,7 +417,8 @@ class Channel(PrintError):
this_point = secret_to_pubkey(int.from_bytes(this_secret, 'big'))
next_secret = get_per_commitment_secret_from_seed(self.config[LOCAL].per_commitment_secret_seed, RevocationStore.START_INDEX - next_small_num)
next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big'))
- return last_secret, this_point, next_point
+ last_point = secret_to_pubkey(int.from_bytes(last_secret, 'big'))
+ return last_secret, this_point, next_point, last_point
def process_new_revocation_secret(self, per_commitment_secret: bytes):
if not self.lnwatcher:
@@ -481,12 +433,9 @@ class Channel(PrintError):
def receive_revocation(self, revocation) -> Tuple[int, int]:
self.print_error("receive_revocation")
- old_logs = dict(self.lock_in_htlc_changes(LOCAL))
-
cur_point = self.config[REMOTE].current_per_commitment_point
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
if cur_point != derived_point:
- self.log = old_logs
raise Exception('revoked secret not for current point')
# FIXME not sure this is correct... but it seems to work
@@ -505,51 +454,36 @@ class Channel(PrintError):
if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]:
self.pending_fee[FUNDER_SIGNED] = True
- def mark_settled(subject):
- """
- find pending settlements for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs
- """
- old_amount = htlcsum(self.htlcs(subject, False))
-
- for htlc_id in self.log[subject].settles:
- adds = self.log[subject].adds
- htlc = adds.pop(htlc_id)
- self.settled[subject].append(htlc.amount_msat)
- if subject == LOCAL:
- preimage = self.preimages.pop(htlc_id)
- else:
- preimage = None
- self.payment_completed(self, subject, htlc, preimage)
- self.log[subject].settles.clear()
-
- return old_amount - htlcsum(self.htlcs(subject, False))
-
- sent_this_batch = mark_settled(LOCAL)
- received_this_batch = mark_settled(REMOTE)
+ received = self.hm.received_in_ctn(self.config[REMOTE].ctn + 1)
+ sent = self.hm.sent_in_ctn(self.config[REMOTE].ctn + 1)
+ for htlc in received:
+ self.payment_completed(self, RECEIVED, htlc, None)
+ for htlc in sent:
+ preimage = self.preimages.pop(htlc.htlc_id)
+ self.payment_completed(self, SENT, htlc, preimage)
+ received_this_batch = htlcsum(received)
+ sent_this_batch = htlcsum(sent)
next_point = self.config[REMOTE].next_per_commitment_point
- print("RECEIVED", received_this_batch)
- print("SENT", sent_this_batch)
+ self.hm.recv_rev()
+
self.config[REMOTE]=self.config[REMOTE]._replace(
ctn=self.config[REMOTE].ctn + 1,
current_per_commitment_point=next_point,
next_per_commitment_point=revocation.next_per_commitment_point,
- amount_msat=self.config[REMOTE].amount_msat + (sent_this_batch - received_this_batch)
- )
- self.config[LOCAL]=self.config[LOCAL]._replace(
- amount_msat = self.config[LOCAL].amount_msat + (received_this_batch - sent_this_batch)
)
if self.pending_fee is not None:
if self.constraints.is_initiator:
self.pending_fee[FUNDEE_ACKED] = True
- self.set_remote_commitment(self.pending_commitment(REMOTE))
+ self.set_remote_commitment()
self.remote_commitment_to_be_revoked = prev_remote_commitment
+
return received_this_batch, sent_this_batch
- def balance(self, subject):
+ def balance(self, subject, ctn=None):
"""
This balance in mSAT is not including reserve and fees.
So a node cannot actually use it's whole balance.
@@ -560,12 +494,15 @@ class Channel(PrintError):
commited to later when the respective commitment
transaction as been revoked.
"""
+ assert type(subject) is HTLCOwner
initial = self.config[subject].initial_msat
- initial -= sum(self.settled[subject])
- initial += sum(self.settled[-subject])
+ for direction, htlc in self.hm.settled_htlcs(subject, ctn):
+ if direction == SENT:
+ initial -= htlc.amount_msat
+ else:
+ initial += htlc.amount_msat
- assert initial == self.config[subject].amount_msat
return initial
def balance_minus_outgoing_htlcs(self, subject):
@@ -573,48 +510,46 @@ class Channel(PrintError):
This balance in mSAT, which includes the value of
pending outgoing HTLCs, is used in the UI.
"""
- return self.balance(subject)\
- - htlcsum(self.log[subject].adds.values())
+ assert type(subject) is HTLCOwner
+ ctn = self.hm.log[subject]['ctn'] + 1
+ return self.balance(subject, ctn)\
+ - htlcsum(self.hm.htlcs_by_direction(subject, SENT, ctn))
def available_to_spend(self, subject):
"""
This balance in mSAT, while technically correct, can
not be used in the UI cause it fluctuates (commit fee)
"""
+ assert type(subject) is HTLCOwner
return self.balance_minus_outgoing_htlcs(subject)\
- - htlcsum(self.log[subject].adds.values())\
- self.config[-subject].reserve_sat * 1000\
- calc_onchain_fees(
# TODO should we include a potential new htlc, when we are called from receive_htlc?
- len(list(self.included_htlcs(subject, LOCAL)) + list(self.included_htlcs(subject, REMOTE))),
+ len(self.included_htlcs(subject, SENT) + self.included_htlcs(subject, RECEIVED)),
self.pending_feerate(subject),
- True, # for_us
self.constraints.is_initiator,
)[subject]
- def amounts(self):
- remote_settled= htlcsum(self.htlcs(REMOTE, False))
- local_settled= htlcsum(self.htlcs(LOCAL, False))
- unsettled_local = htlcsum(self.htlcs(LOCAL, True))
- unsettled_remote = htlcsum(self.htlcs(REMOTE, True))
- remote_msat = self.config[REMOTE].amount_msat -\
- unsettled_remote + local_settled - remote_settled
- local_msat = self.config[LOCAL].amount_msat -\
- unsettled_local + remote_settled - local_settled
- return remote_msat, local_msat
-
- def included_htlcs(self, subject, htlc_initiator, only_pending=True):
+ def included_htlcs(self, subject, direction, ctn=None):
"""
return filter of non-dust htlcs for subjects commitment transaction, initiated by given party
"""
+ assert type(subject) is HTLCOwner
+ assert type(direction) is Direction
+ if ctn is None:
+ ctn = self.config[subject].ctn
feerate = self.pending_feerate(subject)
conf = self.config[subject]
- weight = HTLC_SUCCESS_WEIGHT if subject != htlc_initiator else HTLC_TIMEOUT_WEIGHT
- htlcs = self.htlcs(htlc_initiator, only_pending=only_pending)
+ if (subject, direction) in [(REMOTE, RECEIVED), (LOCAL, SENT)]:
+ weight = HTLC_SUCCESS_WEIGHT
+ else:
+ weight = HTLC_TIMEOUT_WEIGHT
+ htlcs = self.hm.htlcs_by_direction(subject, direction, ctn=ctn)
fee_for_htlc = lambda htlc: htlc.amount_msat // 1000 - (weight * feerate // 1000)
- return filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs)
+ return list(filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs))
def pending_feerate(self, subject):
+ assert type(subject) is HTLCOwner
candidate = self.constraints.feerate
if self.pending_fee is not None:
x = self.pending_fee.pending_feerate(subject)
@@ -623,81 +558,53 @@ class Channel(PrintError):
return candidate
def pending_commitment(self, subject):
+ assert type(subject) is HTLCOwner
this_point = self.config[REMOTE].next_per_commitment_point if subject == REMOTE else self.points()[1]
- return self.make_commitment(subject, this_point)
+ ctn = self.config[subject].ctn + 1
+ feerate = self.pending_feerate(subject)
+ return self.make_commitment(subject, this_point, ctn, feerate, True)
def current_commitment(self, subject):
- old_local_state = self.config[subject]
- self.config[subject]=self.config[subject]._replace(ctn=self.config[subject].ctn - 1)
- r = self.pending_commitment(subject)
- self.config[subject] = old_local_state
- return r
-
- def total_msat(self, sub):
- return sum(self.settled[sub])
+ assert type(subject) is HTLCOwner
+ this_point = self.config[REMOTE].current_per_commitment_point if subject == REMOTE else self.points()[3]
+ ctn = self.config[subject].ctn
+ feerate = self.constraints.feerate
+ return self.make_commitment(subject, this_point, ctn, feerate, False)
- def htlcs(self, subject, only_pending):
- """
- only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
-
- sets returned with True and False are disjunct
-
- only_pending true:
- skipped if settled or failed
- <=>
- included if not settled and not failed
- only_pending false:
- skipped if not (settled or failed)
- <=>
- included if not not (settled or failed)
- included if settled or failed
- """
- update_log = self.log[subject]
- res = []
- for htlc in update_log.adds.values():
- locked_in = htlc.htlc_id in update_log.locked_in
- settled = htlc.htlc_id in update_log.settles
- failed = htlc.htlc_id in update_log.fails
- if not locked_in:
- continue
- if only_pending == (settled or failed):
- continue
- res.append(htlc)
- return res
+ def total_msat(self, direction):
+ assert type(direction) is Direction
+ sub = LOCAL if direction == SENT else REMOTE
+ return htlcsum(self.hm.settled_htlcs_by(sub, self.config[sub].ctn))
def settle_htlc(self, preimage, htlc_id):
"""
SettleHTLC attempts to settle an existing outstanding received HTLC.
"""
self.print_error("settle_htlc")
- log = self.log[REMOTE]
- htlc = log.adds[htlc_id]
+ log = self.hm.log[REMOTE]
+ htlc = log['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
- assert htlc_id not in log.settles
- log.settles.add(htlc_id)
+ assert htlc_id not in log['settles']
+ self.hm.send_settle(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle")
- log = self.log[LOCAL]
- htlc = log.adds[htlc_id]
+ log = self.hm.log[LOCAL]
+ htlc = log['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
- assert htlc_id not in log.settles
+ assert htlc_id not in log['settles']
+ self.hm.recv_settle(htlc_id)
self.preimages[htlc_id] = preimage
- log.settles.add(htlc_id)
# we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id):
self.print_error("fail_htlc")
- log = self.log[REMOTE]
- assert htlc_id not in log.fails
- log.fails.add(htlc_id)
+ self.hm.send_fail(htlc_id)
def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc")
- log = self.log[LOCAL]
- assert htlc_id not in log.fails
- log.fails.add(htlc_id)
+ self.hm.recv_fail(htlc_id)
@property
def current_height(self):
@@ -713,29 +620,7 @@ class Channel(PrintError):
raise Exception("a fee update is already in progress")
self.pending_fee = FeeUpdate(self, rate=feerate)
- def remove_uncommitted_htlcs_from_log(self, subject):
- """
- returns
- - the htlcs with uncommited (not locked in) htlcs removed
- - a list of htlc_ids that were removed
- """
- removed = []
- htlcs = []
- log = self.log[subject]
- for i in log.adds.values():
- locked_in = i.htlc_id in log.locked_in
- if locked_in:
- htlcs.append(i._asdict())
- else:
- removed.append(i.htlc_id)
- return htlcs, removed
-
def to_save(self):
- # need to forget about uncommited htlcs
- # since we must assume they don't know about it,
- # if it was not acked
- remote_filtered, remote_removed = self.remove_uncommitted_htlcs_from_log(REMOTE)
- local_filtered, local_removed = self.remove_uncommitted_htlcs_from_log(LOCAL)
to_save = {
"local_config": self.config[LOCAL],
"remote_config": self.config[REMOTE],
@@ -745,24 +630,10 @@ class Channel(PrintError):
"funding_outpoint": self.funding_outpoint,
"node_id": self.node_id,
"remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked),
- "remote_log": remote_filtered,
- "local_log": local_filtered,
+ "log": self.hm.to_save(),
"onion_keys": str_bytes_dict_to_save(self.onion_keys),
- "settled_local": self.settled[LOCAL],
- "settled_remote": self.settled[REMOTE],
"force_closed": self.get_state() == 'FORCE_CLOSING',
}
-
- # htlcs number must be monotonically increasing,
- # so we have to decrease the counter
- if len(remote_removed) != 0:
- assert min(remote_removed) < to_save['remote_config'].next_htlc_id
- to_save['remote_config'] = to_save['remote_config']._replace(next_htlc_id = min(remote_removed))
-
- if len(local_removed) != 0:
- assert min(local_removed) < to_save['local_config'].next_htlc_id
- to_save['local_config'] = to_save['local_config']._replace(next_htlc_id = min(local_removed))
-
return to_save
def serialize(self):
@@ -792,33 +663,49 @@ class Channel(PrintError):
def __str__(self):
return str(self.serialize())
- def make_commitment(self, subject, this_point) -> Transaction:
- remote_msat, local_msat = self.amounts()
- assert local_msat >= 0, local_msat
- assert remote_msat >= 0, remote_msat
+ def make_commitment(self, subject, this_point, ctn, feerate, pending) -> Transaction:
+ #if subject == REMOTE and not pending:
+ # ctn -= 1
+ assert type(subject) is HTLCOwner
+ other = REMOTE if LOCAL == subject else LOCAL
+ remote_msat, local_msat = self.balance(other, ctn), self.balance(subject, ctn)
+ received_htlcs = self.hm.htlcs_by_direction(subject, SENT if subject == LOCAL else RECEIVED, ctn)
+ sent_htlcs = self.hm.htlcs_by_direction(subject, RECEIVED if subject == LOCAL else SENT, ctn)
+ if subject != LOCAL:
+ remote_msat -= htlcsum(received_htlcs)
+ local_msat -= htlcsum(sent_htlcs)
+ else:
+ remote_msat -= htlcsum(sent_htlcs)
+ local_msat -= htlcsum(received_htlcs)
+ assert remote_msat >= 0
+ assert local_msat >= 0
+ # same htlcs as before, but now without dust.
+ received_htlcs = self.included_htlcs(subject, SENT if subject == LOCAL else RECEIVED, ctn)
+ sent_htlcs = self.included_htlcs(subject, RECEIVED if subject == LOCAL else SENT, ctn)
+
this_config = self.config[subject]
other_config = self.config[-subject]
other_htlc_pubkey = derive_pubkey(other_config.htlc_basepoint.pubkey, this_point)
this_htlc_pubkey = derive_pubkey(this_config.htlc_basepoint.pubkey, this_point)
other_revocation_pubkey = derive_blinded_pubkey(other_config.revocation_basepoint.pubkey, this_point)
htlcs = [] # type: List[ScriptHtlc]
- def append_htlc(htlc: UpdateAddHtlc, is_received_htlc: bool):
- htlcs.append(ScriptHtlc(make_htlc_output_witness_script(
- is_received_htlc=is_received_htlc,
- remote_revocation_pubkey=other_revocation_pubkey,
- remote_htlc_pubkey=other_htlc_pubkey,
- local_htlc_pubkey=this_htlc_pubkey,
- payment_hash=htlc.payment_hash,
- cltv_expiry=htlc.cltv_expiry), htlc))
- for htlc in self.included_htlcs(subject, -subject):
- append_htlc(htlc, is_received_htlc=True)
- for htlc in self.included_htlcs(subject, subject):
- append_htlc(htlc, is_received_htlc=False)
- if subject != LOCAL:
- remote_msat, local_msat = local_msat, remote_msat
+ for is_received_htlc, htlc_list in zip((subject != LOCAL, subject == LOCAL), (received_htlcs, sent_htlcs)):
+ for htlc in htlc_list:
+ htlcs.append(ScriptHtlc(make_htlc_output_witness_script(
+ is_received_htlc=is_received_htlc,
+ remote_revocation_pubkey=other_revocation_pubkey,
+ remote_htlc_pubkey=other_htlc_pubkey,
+ local_htlc_pubkey=this_htlc_pubkey,
+ payment_hash=htlc.payment_hash,
+ cltv_expiry=htlc.cltv_expiry), htlc))
+ onchain_fees = calc_onchain_fees(
+ len(htlcs),
+ feerate,
+ self.constraints.is_initiator == (subject == LOCAL),
+ )
payment_pubkey = derive_pubkey(other_config.payment_basepoint.pubkey, this_point)
return make_commitment(
- self.config[subject].ctn + 1,
+ ctn,
this_config.multisig_key.pubkey,
other_config.multisig_key.pubkey,
payment_pubkey,
@@ -832,12 +719,7 @@ class Channel(PrintError):
local_msat,
remote_msat,
this_config.dust_limit_sat,
- calc_onchain_fees(
- len(htlcs),
- self.pending_feerate(subject),
- subject == LOCAL,
- self.constraints.is_initiator,
- ),
+ onchain_fees,
htlcs=htlcs)
def get_local_index(self):
@@ -850,8 +732,8 @@ class Channel(PrintError):
LOCAL: fee_sat * 1000 if self.constraints.is_initiator else 0,
REMOTE: fee_sat * 1000 if not self.constraints.is_initiator else 0,
},
- self.config[LOCAL].amount_msat,
- self.config[REMOTE].amount_msat,
+ self.balance(LOCAL),
+ self.balance(REMOTE),
(TYPE_SCRIPT, bh2u(local_script)),
(TYPE_SCRIPT, bh2u(remote_script)),
[], self.config[LOCAL].dust_limit_sat)
@@ -867,38 +749,39 @@ class Channel(PrintError):
sig = ecc.sig_string_from_der_sig(der_sig[:-1])
return sig, closing_tx
- def assert_signature_fits(self, tx):
+ def signature_fits(self, tx):
remote_sig = self.config[LOCAL].current_commitment_signature
- if remote_sig: # only None in test
- preimage_hex = tx.serialize_preimage(0)
- pre_hash = sha256d(bfh(preimage_hex))
- if not ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, remote_sig, pre_hash):
- self.print_error("WARNING: commitment signature inconsistency, cannot force close")
+ preimage_hex = tx.serialize_preimage(0)
+ pre_hash = sha256d(bfh(preimage_hex))
+ assert remote_sig
+ res = ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, remote_sig, pre_hash)
+ return res
def force_close_tx(self):
tx = self.local_commitment
+ assert self.signature_fits(tx)
tx = Transaction(str(tx))
tx.deserialize(True)
- self.assert_signature_fits(tx)
tx.sign({bh2u(self.config[LOCAL].multisig_key.pubkey): (self.config[LOCAL].multisig_key.privkey, True)})
remote_sig = self.config[LOCAL].current_commitment_signature
- if remote_sig: # only None in test
- remote_sig = ecc.der_sig_from_sig_string(remote_sig) + b"\x01"
- sigs = tx._inputs[0]["signatures"]
- none_idx = sigs.index(None)
- tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
- assert tx.is_complete()
+ remote_sig = ecc.der_sig_from_sig_string(remote_sig) + b"\x01"
+ sigs = tx._inputs[0]["signatures"]
+ none_idx = sigs.index(None)
+ tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
+ assert tx.is_complete()
return tx
def included_htlcs_in_their_latest_ctxs(self, htlc_initiator) -> Dict[int, List[UpdateAddHtlc]]:
""" A map from commitment number to list of HTLCs in
their latest two commitment transactions.
The oldest might have been revoked. """
- old_htlcs = list(self.included_htlcs(REMOTE, htlc_initiator, only_pending=False))
+ assert type(htlc_initiator) is HTLCOwner
+ direction = RECEIVED if htlc_initiator == LOCAL else SENT
+ old_ctn = self.config[REMOTE].ctn
+ old_htlcs = self.included_htlcs(REMOTE, direction, ctn=old_ctn)
- old_logs = dict(self.lock_in_htlc_changes(LOCAL))
- new_htlcs = list(self.included_htlcs(REMOTE, htlc_initiator))
- self.log = old_logs
+ new_ctn = self.config[REMOTE].ctn+1
+ new_htlcs = self.included_htlcs(REMOTE, direction, ctn=new_ctn)
- return {self.config[REMOTE].ctn: old_htlcs,
- self.config[REMOTE].ctn+1: new_htlcs, }
+ return {old_ctn: old_htlcs,
+ new_ctn: new_htlcs, }
diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
@@ -0,0 +1,159 @@
+from copy import deepcopy
+from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction
+from .util import bh2u
+
+class HTLCManager:
+ def __init__(self, log=None):
+ self.expect_sig = {SENT: False, RECEIVED: False}
+ if log is None:
+ initial = {'ctn': 0, 'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}}
+ log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
+ else:
+ assert type(log) is dict
+ log = {HTLCOwner(int(x)): y for x, y in deepcopy(log).items()}
+ for sub in (LOCAL, REMOTE):
+ log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()}
+ coerceHtlcOwner2IntMap = lambda x: {HTLCOwner(int(y)): z for y, z in x.items()}
+ log[sub]['locked_in'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['locked_in'].items()}
+ log[sub]['settles'] = {int(x): y for x, y in log[sub]['settles'].items()}
+ log[sub]['fails'] = {int(x): y for x, y in log[sub]['fails'].items()}
+ self.log = log
+
+ def to_save(self):
+ x = deepcopy(self.log)
+ for sub in (LOCAL, REMOTE):
+ d = {}
+ for htlc_id, htlc in x[sub]['adds'].items():
+ d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
+ x[sub]['adds'] = d
+ return x
+
+ def send_htlc(self, htlc):
+ htlc_id = htlc.htlc_id
+ adds = self.log[LOCAL]['adds']
+ assert type(adds) is not str
+ adds[htlc_id] = htlc
+ self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.log[REMOTE]['ctn']+1}
+ self.expect_sig[SENT] = True
+ return htlc
+
+ def recv_htlc(self, htlc):
+ htlc_id = htlc.htlc_id
+ self.log[REMOTE]['htlc_id'] = htlc_id
+ self.log[REMOTE]['adds'][htlc_id] = htlc
+ l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.log[LOCAL]['ctn']+1, REMOTE: None}
+ self.expect_sig[RECEIVED] = True
+
+ def send_ctx(self):
+ next_ctn = self.log[REMOTE]['ctn'] + 1
+
+ for locked_in in self.log[REMOTE]['locked_in'].values():
+ if locked_in[REMOTE] is None:
+ locked_in[REMOTE] = next_ctn
+
+ self.expect_sig[SENT] = False
+
+ #return Sig(self.pending_htlcs(REMOTE), next_ctn)
+
+ def recv_ctx(self):
+ next_ctn = self.log[LOCAL]['ctn'] + 1
+
+ for locked_in in self.log[LOCAL]['locked_in'].values():
+ if locked_in[LOCAL] is None:
+ locked_in[LOCAL] = next_ctn
+
+ self.expect_sig[SENT] = False
+
+ def send_rev(self):
+ self.log[LOCAL]['ctn'] += 1
+
+ def recv_rev(self):
+ self.log[REMOTE]['ctn'] += 1
+ did_set_htlc_height = False
+ for htlc_id, ctnheights in self.log[LOCAL]['locked_in'].items():
+ if ctnheights[LOCAL] is None:
+ did_set_htlc_height = True
+ assert ctnheights[REMOTE] == self.log[REMOTE]['ctn']
+ ctnheights[LOCAL] = ctnheights[REMOTE]
+ return did_set_htlc_height
+
+ def htlcs_by_direction(self, subject, direction, ctn=None):
+ """
+ direction is relative to subject!
+ """
+ assert type(subject) is HTLCOwner
+ assert type(direction) is Direction
+ if ctn is None:
+ ctn = self.log[subject]['ctn']
+ l = []
+ if direction == SENT and subject == LOCAL:
+ party = LOCAL
+ elif direction == RECEIVED and subject == REMOTE:
+ party = LOCAL
+ else:
+ party = REMOTE
+ for htlc_id, ctnheights in self.log[party]['locked_in'].items():
+ htlc_height = ctnheights[subject]
+ if htlc_height is None:
+ include = not self.expect_sig[RECEIVED if party == LOCAL else SENT] and ctnheights[-subject] <= ctn
+ else:
+ include = htlc_height <= ctn
+ if include:
+ settles = self.log[party]['settles']
+ if htlc_id not in settles or settles[htlc_id] > ctn:
+ fails = self.log[party]['fails']
+ if htlc_id not in fails or fails[htlc_id] > ctn:
+ l.append(self.log[party]['adds'][htlc_id])
+ return l
+
+ def htlcs(self, subject, ctn=None):
+ assert type(subject) is HTLCOwner
+ if ctn is None:
+ ctn = self.log[subject]['ctn']
+ l = []
+ l += [(SENT, x) for x in self.htlcs_by_direction(subject, SENT, ctn)]
+ l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)]
+ return l
+
+ def current_htlcs(self, subject):
+ assert type(subject) is HTLCOwner
+ ctn = self.log[subject]['ctn']
+ return self.htlcs(subject, ctn)
+
+ def pending_htlcs(self, subject):
+ assert type(subject) is HTLCOwner
+ ctn = self.log[subject]['ctn'] + 1
+ return self.htlcs(subject, ctn)
+
+ def send_settle(self, htlc_id):
+ self.log[REMOTE]['settles'][htlc_id] = self.log[REMOTE]['ctn'] + 1
+
+ def recv_settle(self, htlc_id):
+ self.log[LOCAL]['settles'][htlc_id] = self.log[LOCAL]['ctn'] + 1
+
+ def settled_htlcs_by(self, subject, ctn=None):
+ assert type(subject) is HTLCOwner
+ if ctn is None:
+ ctn = self.log[subject]['ctn']
+ return [self.log[subject]['adds'][htlc_id] for htlc_id, height in self.log[subject]['settles'].items() if height <= ctn]
+
+ def settled_htlcs(self, subject, ctn=None):
+ assert type(subject) is HTLCOwner
+ if ctn is None:
+ ctn = self.log[subject]['ctn']
+ sent = [(SENT, x) for x in self.settled_htlcs_by(subject, ctn)]
+ other = subject.inverted()
+ received = [(RECEIVED, x) for x in self.settled_htlcs_by(other, ctn)]
+ return sent + received
+
+ def received_in_ctn(self, ctn):
+ return [self.log[REMOTE]['adds'][htlc_id] for htlc_id, height in self.log[REMOTE]['settles'].items() if height == ctn]
+
+ def sent_in_ctn(self, ctn):
+ return [self.log[LOCAL]['adds'][htlc_id] for htlc_id, height in self.log[LOCAL]['settles'].items() if height == ctn]
+
+ def send_fail(self, htlc_id):
+ self.log[REMOTE]['fails'][htlc_id] = self.log[REMOTE]['ctn'] + 1
+
+ def recv_fail(self, htlc_id):
+ self.log[LOCAL]['fails'][htlc_id] = self.log[LOCAL]['ctn'] + 1
diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py
@@ -9,15 +9,15 @@ from .bitcoin import TYPE_ADDRESS, redeem_script_to_address, dust_threshold
from . import ecc
from .lnutil import (make_commitment_output_to_remote_address, make_commitment_output_to_local_witness_script,
derive_privkey, derive_pubkey, derive_blinded_pubkey, derive_blinded_privkey,
- make_htlc_tx_witness, make_htlc_tx_with_open_channel,
+ make_htlc_tx_witness, make_htlc_tx_with_open_channel, UpdateAddHtlc,
LOCAL, REMOTE, make_htlc_output_witness_script, UnknownPaymentHash,
get_ordered_channel_configs, privkey_to_pubkey, get_per_commitment_secret_from_seed,
- RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret)
+ RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret, SENT, RECEIVED)
from .transaction import Transaction, TxOutput, construct_witness
from .simple_config import SimpleConfig, FEERATE_FALLBACK_STATIC_FEE
if TYPE_CHECKING:
- from .lnchan import Channel, UpdateAddHtlc
+ from .lnchan import Channel
def maybe_create_sweeptx_for_their_ctx_to_remote(ctx: Transaction, sweep_address: str,
@@ -106,7 +106,7 @@ def create_sweeptxs_for_their_just_revoked_ctx(chan: 'Channel', ctx: Transaction
ctn = extract_ctn_from_tx_and_chan(ctx, chan)
assert ctn == chan.config[REMOTE].ctn
# received HTLCs, in their ctx
- received_htlcs = chan.included_htlcs(REMOTE, LOCAL, False)
+ received_htlcs = chan.included_htlcs(REMOTE, RECEIVED, ctn)
for htlc in received_htlcs:
direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=True)
if direct_sweep_tx:
@@ -114,7 +114,7 @@ def create_sweeptxs_for_their_just_revoked_ctx(chan: 'Channel', ctx: Transaction
if secondstage_sweep_tx:
txs[htlc_tx.txid()] = secondstage_sweep_tx
# offered HTLCs, in their ctx
- offered_htlcs = chan.included_htlcs(REMOTE, REMOTE, False)
+ offered_htlcs = chan.included_htlcs(REMOTE, SENT, ctn)
for htlc in offered_htlcs:
direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=False)
if direct_sweep_tx:
@@ -181,16 +181,14 @@ def create_sweeptxs_for_our_latest_ctx(chan: 'Channel', ctx: Transaction,
is_revocation=False)
return htlc_tx, to_wallet_tx
# offered HTLCs, in our ctx --> "timeout"
- # TODO consider carefully if "included_htlcs" is what we need here
- offered_htlcs = list(chan.included_htlcs(LOCAL, LOCAL)) # type: List[UpdateAddHtlc]
+ # received HTLCs, in our ctx --> "success"
+ offered_htlcs = chan.included_htlcs(LOCAL, SENT, ctn) # type: List[UpdateAddHtlc]
+ received_htlcs = chan.included_htlcs(LOCAL, RECEIVED, ctn) # type: List[UpdateAddHtlc]
for htlc in offered_htlcs:
htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=False)
if htlc_tx and to_wallet_tx:
txs[to_wallet_tx.prevout(0)] = to_wallet_tx
txs[htlc_tx.prevout(0)] = htlc_tx
- # received HTLCs, in our ctx --> "success"
- # TODO consider carefully if "included_htlcs" is what we need here
- received_htlcs = list(chan.included_htlcs(LOCAL, REMOTE)) # type: List[UpdateAddHtlc]
for htlc in received_htlcs:
htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=True)
if htlc_tx and to_wallet_tx:
@@ -332,7 +330,7 @@ def create_htlctx_that_spends_from_our_ctx(chan: 'Channel', our_pcp: bytes,
htlc=htlc,
name=f'our_ctx_htlc_tx_{bh2u(htlc.payment_hash)}',
cltv_expiry=0 if is_received_htlc else htlc.cltv_expiry)
- remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc, we_receive=is_received_htlc)
+ remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc, we_receive=is_received_htlc, ctx=ctx)
local_htlc_sig = bfh(htlc_tx.sign_txin(0, local_htlc_privkey))
txin = htlc_tx.inputs()[0]
witness_program = bfh(Transaction.get_preimage_script(txin))
diff --git a/electrum/lnutil.py b/electrum/lnutil.py
@@ -21,7 +21,7 @@ from .lnaddr import lndecode
from .keystore import BIP32_KeyStore
if TYPE_CHECKING:
- from .lnchan import Channel, UpdateAddHtlc
+ from .lnchan import Channel
HTLC_TIMEOUT_WEIGHT = 663
@@ -35,7 +35,6 @@ OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"])
class LocalConfig(NamedTuple):
# shared channel config fields (DUPLICATED code!!)
ctn: int
- amount_msat: int
next_htlc_id: int
payment_basepoint: 'Keypair'
multisig_key: 'Keypair'
@@ -54,12 +53,12 @@ class LocalConfig(NamedTuple):
was_announced: bool
current_commitment_signature: Optional[bytes]
current_htlc_signatures: List[bytes]
+ got_sig_for_next: bool
class RemoteConfig(NamedTuple):
# shared channel config fields (DUPLICATED code!!)
ctn: int
- amount_msat: int
next_htlc_id: int
payment_basepoint: 'Keypair'
multisig_key: 'Keypair'
@@ -364,7 +363,7 @@ def make_htlc_tx_with_open_channel(chan: 'Channel', pcp: bytes, for_us: bool,
# FIXME handle htlc_address collision
# also: https://github.com/lightningnetwork/lightning-rfc/issues/448
prevout_idx = commit.get_output_idx_from_address(htlc_address)
- assert prevout_idx is not None
+ assert prevout_idx is not None, (htlc_address, commit.outputs(), extract_ctn_from_tx_and_chan(commit, chan))
htlc_tx_inputs = make_htlc_tx_inputs(
commit.txid(), prevout_idx,
amount_msat=amount_msat,
@@ -395,11 +394,16 @@ class HTLCOwner(IntFlag):
LOCAL = 1
REMOTE = -LOCAL
- SENT = LOCAL
- RECEIVED = REMOTE
+ def inverted(self):
+ return HTLCOwner(-self)
+
+class Direction(IntFlag):
+ SENT = 3
+ RECEIVED = 4
+
+SENT = Direction.SENT
+RECEIVED = Direction.RECEIVED
-SENT = HTLCOwner.SENT
-RECEIVED = HTLCOwner.RECEIVED
LOCAL = HTLCOwner.LOCAL
REMOTE = HTLCOwner.REMOTE
@@ -420,8 +424,7 @@ def make_commitment_outputs(fees_per_participant: Mapping[HTLCOwner, int], local
c_outputs_filtered = list(filter(lambda x: x.value >= dust_limit_sat, non_htlc_outputs + htlc_outputs))
return htlc_outputs, c_outputs_filtered
-def calc_onchain_fees(num_htlcs, feerate, for_us, we_are_initiator):
- we_pay_fee = for_us == we_are_initiator
+def calc_onchain_fees(num_htlcs, feerate, we_pay_fee):
overall_weight = 500 + 172 * num_htlcs + 224
fee = feerate * overall_weight
fee = fee // 1000 * 1000
@@ -451,7 +454,7 @@ def make_commitment(ctn, local_funding_pubkey, remote_funding_pubkey,
htlc_outputs, c_outputs_filtered = make_commitment_outputs(fees_per_participant, local_amount, remote_amount,
(bitcoin.TYPE_ADDRESS, local_address), (bitcoin.TYPE_ADDRESS, remote_address), htlcs, dust_limit_sat)
- assert sum(x.value for x in c_outputs_filtered) <= funding_sat
+ assert sum(x.value for x in c_outputs_filtered) <= funding_sat, (c_outputs_filtered, funding_sat)
# create commitment tx
tx = Transaction.from_io(c_inputs, c_outputs_filtered, locktime=locktime, version=2)
@@ -649,3 +652,20 @@ def format_short_channel_id(short_channel_id: Optional[bytes]):
return str(int.from_bytes(short_channel_id[:3], 'big')) \
+ 'x' + str(int.from_bytes(short_channel_id[3:6], 'big')) \
+ 'x' + str(int.from_bytes(short_channel_id[6:], 'big'))
+
+class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
+ """
+ This whole class body is so that if you pass a hex-string as payment_hash,
+ it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
+ """
+ __slots__ = ()
+ def __new__(cls, *args, **kwargs):
+ if len(args) > 0:
+ args = list(args)
+ if type(args[1]) is str:
+ args[1] = bfh(args[1])
+ return super().__new__(cls, *args)
+ if type(kwargs['payment_hash']) is str:
+ kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
+ return super().__new__(cls, **kwargs)
+
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -29,13 +29,14 @@ from .lntransport import LNResponderTransport
from .lnbase import Peer
from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string
-from .lnchan import Channel, ChannelJsonEncoder, UpdateAddHtlc
+from .lnchan import Channel, ChannelJsonEncoder
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
get_compressed_pubkey_from_bech32, extract_nodeid,
PaymentFailure, split_host_port, ConnStringFormatError,
generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
- NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner)
+ NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner,
+ UpdateAddHtlc, Direction)
from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use
from .address_synchronizer import TX_HEIGHT_LOCAL
@@ -66,7 +67,7 @@ class LNWorker(PrintError):
def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'):
self.wallet = wallet
# invoices we are currently trying to pay (might be pending HTLCs on a commitment transaction)
- self.paying = self.wallet.storage.get('lightning_payments_inflight', {}) # type: Dict[bytes, Tuple[str, Optional[int], bytes]]
+ self.paying = self.wallet.storage.get('lightning_payments_inflight', {}) # type: Dict[bytes, Tuple[str, Optional[int], str]]
self.sweep_address = wallet.get_receiving_address()
self.network = network
self.channel_db = self.network.channel_db
@@ -75,12 +76,15 @@ class LNWorker(PrintError):
self.node_keypair = generate_keypair(self.ln_keystore, LnKeyFamily.NODE_KEY, 0)
self.config = network.config
self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer
+ self.invoices = wallet.storage.get('lightning_invoices', {}) # type: Dict[str, Tuple[str,str]] # RHASH -> (preimage, invoice)
self.channels = {} # type: Dict[bytes, Channel]
for x in wallet.storage.get("channels", []):
c = Channel(x, sweep_address=self.sweep_address, payment_completed=self.payment_completed)
- self.channels[c.channel_id] = c
c.lnwatcher = network.lnwatcher
- self.invoices = wallet.storage.get('lightning_invoices', {}) # type: Dict[str, Tuple[str,str]] # RHASH -> (preimage, invoice)
+ c.get_preimage_and_invoice = self.get_invoice
+ self.channels[c.channel_id] = c
+ c.set_remote_commitment()
+ c.set_local_commitment(c.current_commitment(LOCAL))
for chan_id, chan in self.channels.items():
self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str())
self._last_tried_peer = {} # LNPeerAddr -> unix timestamp
@@ -116,6 +120,7 @@ class LNWorker(PrintError):
self.print_error('saved lightning gossip timestamp')
def payment_completed(self, chan, direction, htlc, preimage):
+ assert type(direction) is Direction
chan_id = chan.channel_id
if direction == SENT:
assert htlc.payment_hash not in self.invoices
@@ -166,6 +171,7 @@ class LNWorker(PrintError):
unsettled = []
inflight = []
for date, direction, htlc, hex_preimage, hex_chan_id in completed:
+ direction = Direction(direction)
if chan_id is not None:
if bfh(hex_chan_id) != chan_id:
continue
@@ -175,12 +181,12 @@ class LNWorker(PrintError):
else:
preimage = bfh(hex_preimage)
# FIXME use fromisoformat when minimum Python is 3.7
- settled.append((datetime.fromtimestamp(date, timezone.utc), HTLCOwner(direction), htlcobj, preimage))
+ settled.append((datetime.fromtimestamp(date, timezone.utc), direction, htlcobj, preimage))
for preimage, pay_req in invoices.values():
addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP)
unsettled.append((addr, bfh(preimage), pay_req))
for pay_req, amount_sat, this_chan_id in self.paying.values():
- if chan_id is not None and this_chan_id != chan_id:
+ if chan_id is not None and bfh(this_chan_id) != chan_id:
continue
addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP)
if amount_sat is not None:
@@ -194,7 +200,7 @@ class LNWorker(PrintError):
def find_htlc_for_addr(self, addr, whitelist=None):
channels = [y for x,y in self.channels.items() if x in whitelist or whitelist is None]
for chan in channels:
- for htlc in chan.log[LOCAL].adds.values():
+ for htlc in chan.hm.log[LOCAL]['adds'].values():
if htlc.payment_hash == addr.paymenthash:
return htlc
@@ -319,7 +325,7 @@ class LNWorker(PrintError):
self.print_error('they force closed', funding_outpoint)
encumbered_sweeptxs = chan.remote_sweeptxs
else:
- self.print_error('not sure who closed', funding_outpoint)
+ self.print_error('not sure who closed', funding_outpoint, txid)
return
# sweep
for prevout, spender in spenders.items():
@@ -456,7 +462,7 @@ class LNWorker(PrintError):
break
else:
assert False, 'Found route with short channel ID we don\'t have: ' + repr(route[0].short_channel_id)
- self.paying[bh2u(addr.paymenthash)] = (invoice, amount_sat, chan_id)
+ self.paying[bh2u(addr.paymenthash)] = (invoice, amount_sat, bh2u(chan_id))
self.wallet.storage.put('lightning_payments_inflight', self.paying)
self.wallet.storage.write()
return addr, peer, self._pay_to_route(route, addr)
@@ -623,8 +629,8 @@ class LNWorker(PrintError):
# we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels
for channel_id, chan in self.channels.items():
yield {
- 'local_htlcs': json.loads(encoder.encode(chan.log[LOCAL ]._asdict())),
- 'remote_htlcs': json.loads(encoder.encode(chan.log[REMOTE]._asdict())),
+ 'local_htlcs': json.loads(encoder.encode(chan.hm.log[LOCAL ])),
+ 'remote_htlcs': json.loads(encoder.encode(chan.hm.log[REMOTE])),
'channel_id': bh2u(chan.short_channel_id),
'channel_point': chan.funding_outpoint.to_str(),
'state': chan.get_state(),
diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py
@@ -22,6 +22,7 @@
import unittest
import os
import binascii
+from pprint import pformat
from electrum import bitcoin
from electrum import lnbase
@@ -30,6 +31,7 @@ from electrum import lnutil
from electrum import bip32 as bip32_utils
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
from electrum.ecc import sig_string_from_der_sig
+from electrum.util import set_verbosity
one_bitcoin_in_msat = bitcoin.COIN * 1000
@@ -54,9 +56,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5,
max_accepted_htlcs=5,
initial_msat=remote_amount,
- ctn = 0,
+ ctn = -1,
next_htlc_id = 0,
- amount_msat=remote_amount,
reserve_sat=0,
next_per_commitment_point=nex,
@@ -76,7 +77,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
initial_msat=local_amount,
ctn = 0,
next_htlc_id = 0,
- amount_msat=local_amount,
reserve_sat=0,
per_commitment_secret_seed=seed,
@@ -84,6 +84,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
was_announced=False,
current_commitment_signature=None,
current_htlc_signatures=None,
+ got_sig_for_next=False,
),
"constraints":lnbase.ChannelConstraints(
capacity=funding_sat,
@@ -105,7 +106,7 @@ def bip32(sequence):
return k
def create_test_channels(feerate=6000, local=None, remote=None):
- funding_txid = binascii.hexlify(os.urandom(32)).decode("ascii")
+ funding_txid = binascii.hexlify(b"\x01"*32).decode("ascii")
funding_index = 0
funding_sat = ((local + remote) // 1000) if local is not None and remote is not None else (bitcoin.COIN * 10)
local_amount = local if local is not None else (funding_sat * 1000 // 2)
@@ -117,23 +118,52 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys]
bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys]
- alice_seed = os.urandom(32)
- bob_seed = os.urandom(32)
+ alice_seed = b"\x01" * 32
+ bob_seed = b"\x02" * 32
- alice_cur = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
- alice_next = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
- bob_cur = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
- bob_next = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
+ alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
+ bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
alice, bob = \
lnchan.Channel(
- create_channel_state(funding_txid, funding_index, funding_sat, feerate, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, bob_cur, bob_next, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), name="alice"), \
+ create_channel_state(funding_txid, funding_index, funding_sat, feerate, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), name="alice"), \
lnchan.Channel(
- create_channel_state(funding_txid, funding_index, funding_sat, feerate, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, alice_cur, alice_next, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), name="bob")
+ create_channel_state(funding_txid, funding_index, funding_sat, feerate, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), name="bob")
alice.set_state('OPEN')
bob.set_state('OPEN')
+ a_out = alice.current_commitment(LOCAL).outputs()
+ b_out = bob.pending_commitment(REMOTE).outputs()
+ assert a_out == b_out, "\n" + pformat((a_out, b_out))
+
+ sig_from_bob, a_htlc_sigs = bob.sign_next_commitment()
+ sig_from_alice, b_htlc_sigs = alice.sign_next_commitment()
+
+ assert len(a_htlc_sigs) == 0
+ assert len(b_htlc_sigs) == 0
+
+ alice.config[LOCAL] = alice.config[LOCAL]._replace(current_commitment_signature=sig_from_bob)
+ bob.config[LOCAL] = bob.config[LOCAL]._replace(current_commitment_signature=sig_from_alice)
+
+ alice.set_local_commitment(alice.current_commitment(LOCAL))
+ bob.set_local_commitment(bob.current_commitment(LOCAL))
+
+ alice_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
+ bob_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
+
+ alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first)
+ bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first)
+
+ alice.set_remote_commitment()
+ bob.set_remote_commitment()
+
+ alice.remote_commitment_to_be_revoked = alice.remote_commitment
+ bob.remote_commitment_to_be_revoked = bob.remote_commitment
+
+ alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0)
+ bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0)
+
return alice, bob
class TestFee(unittest.TestCase):
@@ -141,11 +171,13 @@ class TestFee(unittest.TestCase):
test
https://github.com/lightningnetwork/lightning-rfc/blob/e0c436bd7a3ed6a028e1cb472908224658a14eca/03-transactions.md#requirements-2
"""
- def test_SimpleAddSettleWorkflow(self):
+ def test_fee(self):
alice_channel, bob_channel = create_test_channels(253, 10000000000, 5000000000)
self.assertIn(9999817, [x[2] for x in alice_channel.local_commitment.outputs()])
class TestChannel(unittest.TestCase):
+ maxDiff = 999
+
def assertOutputExistsByValue(self, tx, amt_sat):
for typ, scr, val in tx.outputs():
if val == amt_sat:
@@ -153,6 +185,10 @@ class TestChannel(unittest.TestCase):
else:
self.assertFalse()
+ @staticmethod
+ def setUpClass():
+ set_verbosity(True)
+
def setUp(self):
# Create a test channel which will be used for the duration of this
# unittest. The channel will be funded evenly with Alice having 5 BTC,
@@ -171,12 +207,15 @@ class TestChannel(unittest.TestCase):
# update log. Then Alice sends this wire message over to Bob who adds
# this htlc to his remote state update log.
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
+ self.assertNotEqual(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1), set())
before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE)
beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL)
self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict)
+ self.assertEqual(1, self.bob_channel.hm.log[LOCAL]['ctn'] + 1)
+ self.assertNotEqual(self.bob_channel.hm.htlcs_by_direction(LOCAL, RECEIVED, 1), set())
after = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE)
afterLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL)
@@ -185,7 +224,7 @@ class TestChannel(unittest.TestCase):
self.bob_pending_remote_balance = after
- self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
+ self.htlc = self.bob_channel.hm.log[REMOTE]['adds'][0]
def test_concurrent_reversed_payment(self):
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
@@ -193,32 +232,65 @@ class TestChannel(unittest.TestCase):
bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
- self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3)
+ self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4)
def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel
htlc = self.htlc
+ alice_out = alice_channel.current_commitment(LOCAL).outputs()
+ short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42]
+ long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62]
+ self.assertLess(alice_out[long_idx].value, 5 * 10**8, alice_out)
+ self.assertEqual(alice_out[short_idx].value, 5 * 10**8, alice_out)
+
+ alice_out = alice_channel.current_commitment(REMOTE).outputs()
+ short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42]
+ long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62]
+ self.assertLess(alice_out[short_idx].value, 5 * 10**8)
+ self.assertEqual(alice_out[long_idx].value, 5 * 10**8)
+
+ def com():
+ return alice_channel.local_commitment
+
+ self.assertTrue(alice_channel.signature_fits(com()))
+
+ self.assertNotEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [])
self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
- self.assertEqual({0: [], 1: []}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
+ self.assertNotEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [])
+ self.assertEqual({0: [], 1: [htlc]}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({0: [], 1: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({0: [], 1: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
- # this wouldn't work since we put None in the remote_sig
- # alice_channel.force_close_tx()
+ from electrum.lnutil import extract_ctn_from_tx_and_chan
+ tx0 = str(alice_channel.force_close_tx())
+ self.assertEqual(alice_channel.config[LOCAL].ctn, 0)
+ self.assertEqual(extract_ctn_from_tx_and_chan(alice_channel.force_close_tx(), alice_channel), 0)
+ self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL)))
# Next alice commits this change by sending a signature message. Since
# we expect the messages to be ordered, Bob will receive the HTLC we
# just sent before he receives this signature, so the signature will
# cover the HTLC.
aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment()
-
self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature")
+ self.assertTrue(alice_channel.signature_fits(com()))
+ self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
+
+ self.assertEqual(next(iter(alice_channel.hm.pending_htlcs(REMOTE)))[0], RECEIVED)
+ self.assertEqual(alice_channel.hm.pending_htlcs(REMOTE), bob_channel.hm.pending_htlcs(LOCAL))
+ self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs())
+
# Bob receives this signature message, and checks that this covers the
# state he has in his remote log. This includes the HTLC just sent
# from Alice.
+ self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
bob_channel.receive_new_commitment(aliceSig, aliceHtlcSigs)
+ self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL)))
+
+ self.assertEqual(bob_channel.config[REMOTE].ctn, 0)
+ self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc])
self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
self.assertEqual({0: [], 1: [htlc]}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
@@ -228,31 +300,68 @@ class TestChannel(unittest.TestCase):
# Bob revokes his prior commitment given to him by Alice, since he now
# has a valid signature for a newer commitment.
bobRevocation, _ = bob_channel.revoke_current_commitment()
+ bob_channel.serialize()
+ self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
- # Bob finally send a signature for Alice's commitment transaction.
+ # Bob finally sends a signature for Alice's commitment transaction.
# This signature will cover the HTLC, since Bob will first send the
# revocation just created. The revocation also acks every received
- # HTLC up to the point where Alice sent here signature.
+ # HTLC up to the point where Alice sent her signature.
bobSig, bobHtlcSigs = bob_channel.sign_next_commitment()
+ self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
+
+ self.assertEqual(len(bobHtlcSigs), 1)
+
+ self.assertTrue(alice_channel.signature_fits(com()))
+ self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
+
+ self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 3)
# Alice then processes this revocation, sending her own revocation for
# her prior commitment transaction. Alice shouldn't have any HTLCs to
# forward since she's sending an outgoing HTLC.
alice_channel.receive_revocation(bobRevocation)
+ alice_channel.serialize()
+ self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
+
+ self.assertTrue(alice_channel.signature_fits(com()))
+ self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL)))
+ alice_channel.serialize()
+ self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
- # test serializing with locked_in htlc
- self.assertEqual(len(alice_channel.to_save()['local_log']), 1)
+ self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2)
+ self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3)
+ self.assertEqual(len(com().outputs()), 2)
+ self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2)
+
+ self.assertEqual(alice_channel.hm.log.keys(), set([LOCAL, REMOTE]))
+ self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
alice_channel.serialize()
+ self.assertEqual(alice_channel.pending_commitment(LOCAL).outputs(),
+ bob_channel.pending_commitment(REMOTE).outputs())
+
# Alice then processes bob's signature, and since she just received
# the revocation, she expect this signature to cover everything up to
# the point where she sent her signature, including the HTLC.
alice_channel.receive_new_commitment(bobSig, bobHtlcSigs)
+ self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
+
+ self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3)
+ self.assertEqual(len(com().outputs()), 3)
+ self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3)
+
+ self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
+ alice_channel.serialize()
tx1 = str(alice_channel.force_close_tx())
+ self.assertNotEqual(tx0, tx1)
# Alice then generates a revocation for bob.
+ self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
aliceRevocation, _ = alice_channel.revoke_current_commitment()
+ alice_channel.serialize()
+ #self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
tx2 = str(alice_channel.force_close_tx())
# since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one)
@@ -262,7 +371,9 @@ class TestChannel(unittest.TestCase):
# is fully locked in within both commitment transactions. Bob should
# also be able to forward an HTLC now that the HTLC has been locked
# into both commitment transactions.
+ self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
bob_channel.receive_revocation(aliceRevocation)
+ bob_channel.serialize()
# At this point, both sides should have the proper number of satoshis
# sent, and commitment height updated within their local channel
@@ -279,16 +390,19 @@ class TestChannel(unittest.TestCase):
# Both commitment transactions should have three outputs, and one of
# them should be exactly the amount of the HTLC.
- self.assertEqual(len(alice_channel.local_commitment.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_channel.local_commitment.outputs()))
- self.assertEqual(len(bob_channel.local_commitment.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_channel.local_commitment.outputs()))
- self.assertOutputExistsByValue(alice_channel.local_commitment, htlc.amount_msat // 1000)
- self.assertOutputExistsByValue(bob_channel.local_commitment, htlc.amount_msat // 1000)
+ alice_ctx = alice_channel.pending_commitment(LOCAL)
+ bob_ctx = bob_channel.pending_commitment(LOCAL)
+ self.assertEqual(len(alice_ctx.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_ctx.outputs()))
+ self.assertEqual(len(bob_ctx.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_ctx.outputs()))
+ self.assertOutputExistsByValue(alice_ctx, htlc.amount_msat // 1000)
+ self.assertOutputExistsByValue(bob_ctx, htlc.amount_msat // 1000)
# Now we'll repeat a similar exchange, this time with Bob settling the
# HTLC once he learns of the preimage.
preimage = self.paymentPreimage
bob_channel.settle_htlc(preimage, self.bobHtlcIndex)
+ #self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
alice_channel.receive_htlc_settle(preimage, self.aliceHtlcIndex)
tx3 = str(alice_channel.force_close_tx())
@@ -296,28 +410,43 @@ class TestChannel(unittest.TestCase):
self.assertEqual(tx2, tx3)
bobSig2, bobHtlcSigs2 = bob_channel.sign_next_commitment()
+ self.assertEqual(len(bobHtlcSigs2), 0)
+ self.assertEqual(alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED), [htlc])
+ self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, alice_channel.config[REMOTE].ctn), [htlc])
self.assertEqual({1: [htlc], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
self.assertEqual({1: [htlc], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({1: [], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({1: [], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
+ alice_ctx_bob_version = bob_channel.pending_commitment(REMOTE).outputs()
+ alice_ctx_alice_version = alice_channel.pending_commitment(LOCAL).outputs()
+ self.assertEqual(alice_ctx_alice_version, alice_ctx_bob_version)
+
alice_channel.receive_new_commitment(bobSig2, bobHtlcSigs2)
tx4 = str(alice_channel.force_close_tx())
self.assertNotEqual(tx3, tx4)
+ self.assertEqual(alice_channel.balance(LOCAL), 500000000000)
+ self.assertEqual(1, alice_channel.config[LOCAL].ctn)
+ self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0)
aliceRevocation2, _ = alice_channel.revoke_current_commitment()
+ alice_channel.serialize()
aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment()
self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures")
-
+ self.assertEqual(len(bob_channel.current_commitment(LOCAL).outputs()), 3)
+ self.assertEqual(len(bob_channel.pending_commitment(LOCAL).outputs()), 2)
received, sent = bob_channel.receive_revocation(aliceRevocation2)
+ bob_channel.serialize()
self.assertEqual(received, one_bitcoin_in_msat)
bob_channel.receive_new_commitment(aliceSig2, aliceHtlcSigs2)
bobRevocation2, _ = bob_channel.revoke_current_commitment()
+ bob_channel.serialize()
alice_channel.receive_revocation(bobRevocation2)
+ alice_channel.serialize()
# At this point, Bob should have 6 BTC settled, with Alice still having
# 4 BTC. Alice's channel should show 1 BTC sent and Bob's channel
@@ -331,15 +460,15 @@ class TestChannel(unittest.TestCase):
self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height")
self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height")
- # The logs of both sides should now be cleared since the entry adding
- # the HTLC should have been removed once both sides receive the
- # revocation.
- #self.assertEqual(alice_channel.local_update_log, [], "alice's local not updated, should be empty, has %s entries instead"% len(alice_channel.local_update_log))
- #self.assertEqual(alice_channel.remote_update_log, [], "alice's remote not updated, should be empty, has %s entries instead"% len(alice_channel.remote_update_log))
self.assertEqual(self.bob_pending_remote_balance, self.alice_channel.balance(LOCAL))
alice_channel.update_fee(100000, True)
+ alice_outputs = alice_channel.pending_commitment(REMOTE).outputs()
+ old_outputs = bob_channel.pending_commitment(LOCAL).outputs()
bob_channel.update_fee(100000, False)
+ new_outputs = bob_channel.pending_commitment(LOCAL).outputs()
+ self.assertNotEqual(old_outputs, new_outputs)
+ self.assertEqual(alice_outputs, new_outputs)
tx5 = str(alice_channel.force_close_tx())
# sending a fee update does not change her force close tx
@@ -353,10 +482,17 @@ class TestChannel(unittest.TestCase):
self.htlc_dict['amount_msat'] *= 5
bob_index = bob_channel.add_htlc(self.htlc_dict)
alice_index = alice_channel.receive_htlc(self.htlc_dict)
- force_state_transition(alice_channel, bob_channel)
+
+ bob_channel.pending_commitment(REMOTE)
+ alice_channel.pending_commitment(LOCAL)
+
+ alice_channel.pending_commitment(REMOTE)
+ bob_channel.pending_commitment(LOCAL)
+
+ force_state_transition(bob_channel, alice_channel)
alice_channel.settle_htlc(self.paymentPreimage, alice_index)
bob_channel.receive_htlc_settle(self.paymentPreimage, bob_index)
- force_state_transition(alice_channel, bob_channel)
+ force_state_transition(bob_channel, alice_channel)
self.assertEqual(alice_channel.total_msat(SENT), one_bitcoin_in_msat, "alice satoshis sent incorrect")
self.assertEqual(alice_channel.total_msat(RECEIVED), 5 * one_bitcoin_in_msat, "alice satoshis received incorrect")
self.assertEqual(bob_channel.total_msat(RECEIVED), one_bitcoin_in_msat, "bob satoshis received incorrect")
@@ -366,8 +502,15 @@ class TestChannel(unittest.TestCase):
def alice_to_bob_fee_update(self, fee=111):
+ aoldctx = self.alice_channel.pending_commitment(REMOTE).outputs()
self.alice_channel.update_fee(fee, True)
+ anewctx = self.alice_channel.pending_commitment(REMOTE).outputs()
+ self.assertNotEqual(aoldctx, anewctx)
+ boldctx = self.bob_channel.pending_commitment(LOCAL).outputs()
self.bob_channel.update_fee(fee, False)
+ bnewctx = self.bob_channel.pending_commitment(LOCAL).outputs()
+ self.assertNotEqual(boldctx, bnewctx)
+ self.assertEqual(anewctx, bnewctx)
return fee
def test_UpdateFeeSenderCommits(self):
@@ -444,7 +587,7 @@ class TestChannel(unittest.TestCase):
# value 2 BTC, which should make Alice's balance negative (since she
# has to pay a commitment fee).
new = dict(self.htlc_dict)
- new['amount_msat'] *= 2
+ new['amount_msat'] *= 2.5
new['payment_hash'] = bitcoin.sha256(32 * b'\x04')
with self.assertRaises(lnutil.PaymentFailure) as cm:
self.alice_channel.add_htlc(new)
@@ -462,7 +605,6 @@ class TestChannel(unittest.TestCase):
except:
try:
from deepdiff import DeepDiff
- from pprint import pformat
except ImportError:
raise
raise Exception(pformat(DeepDiff(before_signing, after_signing)))
@@ -549,9 +691,9 @@ class TestChanReserve(unittest.TestCase):
force_state_transition(self.alice_channel, self.bob_channel)
aliceSelfBalance = self.alice_channel.balance(LOCAL)\
- - lnchan.htlcsum(self.alice_channel.htlcs(LOCAL, True))
+ - lnchan.htlcsum(self.alice_channel.hm.htlcs_by_direction(LOCAL, SENT))
bobBalance = self.bob_channel.balance(REMOTE)\
- - lnchan.htlcsum(self.alice_channel.htlcs(REMOTE, True))
+ - lnchan.htlcsum(self.alice_channel.hm.htlcs_by_direction(REMOTE, SENT))
self.assertEqual(aliceSelfBalance, one_bitcoin_in_msat*4.5)
self.assertEqual(bobBalance, one_bitcoin_in_msat*5)
# Now let Bob try to add an HTLC. This should fail, since it will
@@ -647,17 +789,22 @@ class TestDust(unittest.TestCase):
'cltv_expiry' : 5, # also in create_test_channels
}
+ old_values = [x.value for x in bob_channel.current_commitment(LOCAL).outputs() ]
aliceHtlcIndex = alice_channel.add_htlc(htlc)
bobHtlcIndex = bob_channel.receive_htlc(htlc)
force_state_transition(alice_channel, bob_channel)
- self.assertEqual(len(alice_channel.local_commitment.outputs()), 3)
- self.assertEqual(len(bob_channel.local_commitment.outputs()), 2)
+ alice_ctx = alice_channel.current_commitment(LOCAL)
+ bob_ctx = bob_channel.current_commitment(LOCAL)
+ new_values = [x.value for x in bob_ctx.outputs() ]
+ self.assertNotEqual(old_values, new_values)
+ self.assertEqual(len(alice_ctx.outputs()), 3)
+ self.assertEqual(len(bob_ctx.outputs()), 2)
default_fee = calc_static_fee(0)
self.assertEqual(bob_channel.pending_local_fee(), default_fee + htlcAmt)
bob_channel.settle_htlc(paymentPreimage, bobHtlcIndex)
alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex)
force_state_transition(bob_channel, alice_channel)
- self.assertEqual(len(alice_channel.local_commitment.outputs()), 2)
+ self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2)
self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt)
def force_state_transition(chanA, chanB):
diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py
@@ -0,0 +1,95 @@
+import unittest
+from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner
+from electrum.lnhtlc import HTLCManager
+from typing import NamedTuple
+
+class H(NamedTuple):
+ owner : str
+ htlc_id : int
+
+class TestHTLCManager(unittest.TestCase):
+ def test_race(self):
+ A = HTLCManager()
+ B = HTLCManager()
+ ah0, bh0 = H('A', 0), H('B', 0)
+ B.recv_htlc(A.send_htlc(ah0))
+ self.assertTrue(B.expect_sig[RECEIVED])
+ self.assertTrue(A.expect_sig[SENT])
+ self.assertFalse(B.expect_sig[SENT])
+ self.assertFalse(A.expect_sig[RECEIVED])
+ self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1)
+ A.recv_htlc(B.send_htlc(bh0))
+ self.assertTrue(B.expect_sig[RECEIVED])
+ self.assertTrue(A.expect_sig[SENT])
+ self.assertTrue(A.expect_sig[SENT])
+ self.assertTrue(B.expect_sig[RECEIVED])
+ self.assertEqual(B.current_htlcs(LOCAL), [])
+ self.assertEqual(A.current_htlcs(LOCAL), [])
+ self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0)])
+ self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0)])
+ A.send_ctx()
+ B.recv_ctx()
+ B.send_ctx()
+ A.recv_ctx()
+ self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
+ self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
+ B.send_rev()
+ A.recv_rev()
+ A.send_rev()
+ B.recv_rev()
+ self.assertEqual(B.current_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
+ self.assertEqual(A.current_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
+
+ def test_no_race(self):
+ A = HTLCManager()
+ B = HTLCManager()
+ B.recv_htlc(A.send_htlc(H('A', 0)))
+ self.assertEqual(len(B.pending_htlcs(REMOTE)), 1)
+ A.send_ctx()
+ B.recv_ctx()
+ B.send_rev()
+ A.recv_rev()
+ B.send_ctx()
+ A.recv_ctx()
+ A.send_rev()
+ B.recv_rev()
+ self.assertEqual(len(A.current_htlcs(LOCAL)), 1)
+ self.assertEqual(len(B.current_htlcs(LOCAL)), 1)
+ B.send_settle(0)
+ A.recv_settle(0)
+ self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)])
+ self.assertNotEqual(A.current_htlcs(LOCAL), [])
+ self.assertNotEqual(B.current_htlcs(REMOTE), [])
+ self.assertEqual(A.pending_htlcs(LOCAL), [])
+ self.assertEqual(B.pending_htlcs(REMOTE), [])
+ B.send_ctx()
+ A.recv_ctx()
+ A.send_rev()
+ B.recv_rev()
+ A.send_ctx()
+ B.recv_ctx()
+ B.send_rev()
+ A.recv_rev()
+ self.assertEqual(B.current_htlcs(LOCAL), [])
+ self.assertEqual(A.current_htlcs(LOCAL), [])
+ self.assertEqual(A.current_htlcs(REMOTE), [])
+ self.assertEqual(B.current_htlcs(REMOTE), [])
+ self.assertEqual(len(A.settled_htlcs(LOCAL)), 1)
+ self.assertEqual(len(A.sent_in_ctn(2)), 1)
+ self.assertEqual(len(B.received_in_ctn(2)), 1)
+
+ def test_settle_while_owing(self):
+ A = HTLCManager()
+ B = HTLCManager()
+ B.recv_htlc(A.send_htlc(H('A', 0)))
+ A.send_ctx()
+ B.recv_ctx()
+ B.send_rev()
+ A.recv_rev()
+ B.send_settle(0)
+ A.recv_settle(0)
+ self.assertEqual(B.pending_htlcs(REMOTE), [])
+ B.send_ctx()
+ A.recv_ctx()
+ A.send_rev()
+ B.recv_rev()
diff --git a/electrum/tests/test_lnutil.py b/electrum/tests/test_lnutil.py
@@ -6,8 +6,7 @@ from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_see
make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey,
derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret,
get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError,
- ScriptHtlc, extract_nodeid, calc_onchain_fees)
-from electrum import lnchan
+ ScriptHtlc, extract_nodeid, calc_onchain_fees, UpdateAddHtlc)
from electrum.util import bh2u, bfh
from electrum.transaction import Transaction
@@ -496,7 +495,7 @@ class TestLNUtil(unittest.TestCase):
(1, 2000 * 1000),
(3, 3000 * 1000),
(4, 4000 * 1000)]:
- htlc_obj[num] = lnchan.UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None)
+ htlc_obj[num] = UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None)
htlcs = [ScriptHtlc(htlc[x], htlc_obj[x]) for x in range(5)]
our_commit_tx = make_commitment(
@@ -506,7 +505,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi,
- calc_onchain_fees(len(htlcs), local_feerate_per_kw, True, we_are_initiator=True), htlcs=htlcs)
+ calc_onchain_fees(len(htlcs), local_feerate_per_kw, True), htlcs=htlcs)
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
self.assertEqual(str(our_commit_tx), output_commit_tx)
@@ -584,7 +583,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi,
- calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[])
+ calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[])
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
self.assertEqual(str(our_commit_tx), output_commit_tx)
@@ -603,7 +602,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi,
- calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[])
+ calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[])
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
self.assertEqual(str(our_commit_tx), output_commit_tx)
@@ -661,7 +660,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi,
- calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[])
+ calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[])
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
ref_commit_tx_str = '02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8002c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de84311054a56a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e0400473044022051b75c73198c6deee1a875871c3961832909acd297c6b908d59e3319e5185a46022055c419379c5051a78d00dbbce11b5b664a0c22815fbcc6fcef6b1937c383693901483045022100f51d2e566a70ba740fc5d8c0f07b9b93d2ed741c3c0860c613173de7d39e7968022041376d520e9c0e1ad52248ddf4b22e12be8763007df977253ef45a4ca3bdb7c001475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220'
self.assertEqual(str(our_commit_tx), ref_commit_tx_str)