commit 72187a43416831306fd13b068b7341a3d0c05003
parent 001bb4ca0983b7f420f156781c4918ffc024ecbf
Author: Janus <ysangkok@gmail.com>
Date: Fri, 26 Oct 2018 17:05:03 +0200
lnchan: make sign_next_commitment revert state
Diffstat:
2 files changed, 63 insertions(+), 34 deletions(-)
diff --git a/electrum/lnchan.py b/electrum/lnchan.py
@@ -27,6 +27,7 @@ import binascii
import json
from enum import Enum, auto
from typing import Optional, Dict, List, Tuple
+from copy import deepcopy
from .util import bfh, PrintError, bh2u
from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS
@@ -79,21 +80,20 @@ class FeeUpdate(defaultdict):
return self.rate
# implicit return None
-class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])):
+class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
+ """
+ This whole class body is so that if you pass a hex-string as payment_hash,
+ it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
+ """
__slots__ = ()
def __new__(cls, *args, **kwargs):
if len(args) > 0:
args = list(args)
if type(args[1]) is str:
args[1] = bfh(args[1])
- args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()}
return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
- if 'locked_in' not in kwargs:
- kwargs['locked_in'] = {LOCAL: None, REMOTE: None}
- else:
- kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in'].items()}
return super().__new__(cls, **kwargs)
def decodeAll(d, local):
@@ -162,6 +162,7 @@ class Channel(PrintError):
'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc]
'settles': [], # List[HTLC_ID]
'fails': [], # List[HTLC_ID]
+ 'locked_in': [], # List[HTLC_ID]
}
self.log = {LOCAL: template(), REMOTE: template()}
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
@@ -269,7 +270,8 @@ class Channel(PrintError):
This docstring was adapted from LND.
"""
self.print_error("sign_next_commitment")
- self.lock_in_htlc_changes(LOCAL)
+
+ old_logs = dict(self.lock_in_htlc_changes(LOCAL))
pending_remote_commitment = self.pending_remote_commitment
sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
@@ -290,29 +292,28 @@ class Channel(PrintError):
htlc_sig = ecc.sig_string_from_der_sig(sig[:-1])
htlcsigs.append((pending_remote_commitment.htlc_output_indices[htlc.payment_hash], htlc_sig))
- for pending_fee in self.fee_mgr:
- if not self.constraints.is_initiator:
- pending_fee[FUNDEE_SIGNED] = True
- if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
- pending_fee[FUNDER_SIGNED] = True
-
self.process_new_offchain_ctx(pending_remote_commitment, ours=False)
htlcsigs.sort()
htlcsigs = [x[1] for x in htlcsigs]
+ # we can't know if this message arrives.
+ # since we shouldn't actually throw away
+ # failed htlcs yet (or mark htlc locked in),
+ # roll back the changes that were made
+ self.log = old_logs
+
return sig_64, htlcsigs
def lock_in_htlc_changes(self, subject):
for sub in (LOCAL, REMOTE):
- for htlc_id in self.log[-sub]['fails']:
- adds = self.log[sub]['adds']
- htlc = adds.pop(htlc_id)
- self.log[-sub]['fails'].clear()
+ log = self.log[sub]
+ yield (sub, deepcopy(log))
+ for htlc_id in log['fails']:
+ log['adds'].pop(htlc_id)
+ log['fails'].clear()
- for htlc in self.log[subject]['adds'].values():
- if htlc.locked_in[subject] is None:
- htlc.locked_in[subject] = self.config[subject].ctn
+ self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys()
def receive_new_commitment(self, sig, htlc_sigs):
"""
@@ -328,7 +329,9 @@ class Channel(PrintError):
This docstring is from LND.
"""
self.print_error("receive_new_commitment")
- self.lock_in_htlc_changes(REMOTE)
+
+ for _ in self.lock_in_htlc_changes(REMOTE): pass
+
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
pending_local_commitment = self.pending_local_commitment
@@ -443,11 +446,20 @@ class Channel(PrintError):
def receive_revocation(self, revocation) -> Tuple[int, int]:
self.print_error("receive_revocation")
+ old_logs = dict(self.lock_in_htlc_changes(LOCAL))
+
cur_point = self.config[REMOTE].current_per_commitment_point
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
if cur_point != derived_point:
+ self.log = old_logs
raise Exception('revoked secret not for current point')
+ for pending_fee in self.fee_mgr:
+ if not self.constraints.is_initiator:
+ pending_fee[FUNDEE_SIGNED] = True
+ if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]:
+ pending_fee[FUNDER_SIGNED] = True
+
# FIXME not sure this is correct... but it seems to work
# if there are update_add_htlc msgs between commitment_signed and rev_ack,
# this might break
@@ -462,11 +474,11 @@ class Channel(PrintError):
"""
old_amount = htlcsum(self.htlcs(subject, False))
- for htlc_id in self.log[-subject]['settles']:
+ for htlc_id in self.log[subject]['settles']:
adds = self.log[subject]['adds']
htlc = adds.pop(htlc_id)
self.settled[subject].append(htlc.amount_msat)
- self.log[-subject]['settles'].clear()
+ self.log[subject]['settles'].clear()
return old_amount - htlcsum(self.htlcs(subject, False))
@@ -588,13 +600,12 @@ class Channel(PrintError):
only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
"""
update_log = self.log[subject]
- other_log = self.log[-subject]
res = []
for htlc in update_log['adds'].values():
- locked_in = htlc.locked_in[subject]
- settled = htlc.htlc_id in other_log['settles']
- failed = htlc.htlc_id in other_log['fails']
- if locked_in is None:
+ locked_in = htlc.htlc_id in update_log['locked_in']
+ settled = htlc.htlc_id in update_log['settles']
+ failed = htlc.htlc_id in update_log['fails']
+ if not locked_in:
continue
if only_pending == (settled or failed):
continue
@@ -608,23 +619,23 @@ class Channel(PrintError):
self.print_error("settle_htlc")
htlc = self.log[REMOTE]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
- self.log[LOCAL]['settles'].append(htlc_id)
+ self.log[REMOTE]['settles'].append(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle")
htlc = self.log[LOCAL]['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage)
- self.log[REMOTE]['settles'].append(htlc_id)
+ self.log[LOCAL]['settles'].append(htlc_id)
# we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id):
self.print_error("fail_htlc")
- self.log[LOCAL]['fails'].append(htlc_id)
+ self.log[REMOTE]['fails'].append(htlc_id)
def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc")
- self.log[REMOTE]['fails'].append(htlc_id)
+ self.log[LOCAL]['fails'].append(htlc_id)
@property
def current_height(self):
@@ -654,8 +665,9 @@ class Channel(PrintError):
"""
removed = []
htlcs = []
- for i in self.log[subject]['adds'].values():
- locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None
+ log = self.log[subject]
+ for htlc_id, i in log['adds'].items():
+ locked_in = htlc_id in log['locked_in']
if locked_in:
htlcs.append(i._asdict())
else:
diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py
@@ -396,6 +396,23 @@ class TestChannel(unittest.TestCase):
self.alice_channel.add_htlc(new)
self.assertIn('Not enough local balance', cm.exception.args[0])
+ def test_sign_commitment_is_pure(self):
+ force_state_transition(self.alice_channel, self.bob_channel)
+ self.htlc_dict['payment_hash'] = bitcoin.sha256(b'\x02' * 32)
+ aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
+ before_signing = self.alice_channel.to_save()
+ self.alice_channel.sign_next_commitment()
+ after_signing = self.alice_channel.to_save()
+ try:
+ self.assertEqual(before_signing, after_signing)
+ except:
+ try:
+ from deepdiff import DeepDiff
+ from pprint import pformat
+ except ImportError:
+ raise
+ raise Exception(pformat(DeepDiff(before_signing, after_signing)))
+
class TestAvailableToSpend(unittest.TestCase):
def test_DesyncHTLCs(self):
alice_channel, bob_channel = create_test_channels()