electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 72187a43416831306fd13b068b7341a3d0c05003
parent 001bb4ca0983b7f420f156781c4918ffc024ecbf
Author: Janus <ysangkok@gmail.com>
Date:   Fri, 26 Oct 2018 17:05:03 +0200

lnchan: make sign_next_commitment revert state

Diffstat:
Melectrum/lnchan.py | 80+++++++++++++++++++++++++++++++++++++++++++++----------------------------------
Melectrum/tests/test_lnchan.py | 17+++++++++++++++++
2 files changed, 63 insertions(+), 34 deletions(-)

diff --git a/electrum/lnchan.py b/electrum/lnchan.py @@ -27,6 +27,7 @@ import binascii import json from enum import Enum, auto from typing import Optional, Dict, List, Tuple +from copy import deepcopy from .util import bfh, PrintError, bh2u from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS @@ -79,21 +80,20 @@ class FeeUpdate(defaultdict): return self.rate # implicit return None -class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])): +class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])): + """ + This whole class body is so that if you pass a hex-string as payment_hash, + it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings. + """ __slots__ = () def __new__(cls, *args, **kwargs): if len(args) > 0: args = list(args) if type(args[1]) is str: args[1] = bfh(args[1]) - args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()} return super().__new__(cls, *args) if type(kwargs['payment_hash']) is str: kwargs['payment_hash'] = bfh(kwargs['payment_hash']) - if 'locked_in' not in kwargs: - kwargs['locked_in'] = {LOCAL: None, REMOTE: None} - else: - kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in'].items()} return super().__new__(cls, **kwargs) def decodeAll(d, local): @@ -162,6 +162,7 @@ class Channel(PrintError): 'adds': {}, # Dict[HTLC_ID, UpdateAddHtlc] 'settles': [], # List[HTLC_ID] 'fails': [], # List[HTLC_ID] + 'locked_in': [], # List[HTLC_ID] } self.log = {LOCAL: template(), REMOTE: template()} for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]: @@ -269,7 +270,8 @@ class Channel(PrintError): This docstring was adapted from LND. """ self.print_error("sign_next_commitment") - self.lock_in_htlc_changes(LOCAL) + + old_logs = dict(self.lock_in_htlc_changes(LOCAL)) pending_remote_commitment = self.pending_remote_commitment sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE]) @@ -290,29 +292,28 @@ class Channel(PrintError): htlc_sig = ecc.sig_string_from_der_sig(sig[:-1]) htlcsigs.append((pending_remote_commitment.htlc_output_indices[htlc.payment_hash], htlc_sig)) - for pending_fee in self.fee_mgr: - if not self.constraints.is_initiator: - pending_fee[FUNDEE_SIGNED] = True - if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]: - pending_fee[FUNDER_SIGNED] = True - self.process_new_offchain_ctx(pending_remote_commitment, ours=False) htlcsigs.sort() htlcsigs = [x[1] for x in htlcsigs] + # we can't know if this message arrives. + # since we shouldn't actually throw away + # failed htlcs yet (or mark htlc locked in), + # roll back the changes that were made + self.log = old_logs + return sig_64, htlcsigs def lock_in_htlc_changes(self, subject): for sub in (LOCAL, REMOTE): - for htlc_id in self.log[-sub]['fails']: - adds = self.log[sub]['adds'] - htlc = adds.pop(htlc_id) - self.log[-sub]['fails'].clear() + log = self.log[sub] + yield (sub, deepcopy(log)) + for htlc_id in log['fails']: + log['adds'].pop(htlc_id) + log['fails'].clear() - for htlc in self.log[subject]['adds'].values(): - if htlc.locked_in[subject] is None: - htlc.locked_in[subject] = self.config[subject].ctn + self.log[subject]['locked_in'] |= self.log[subject]['adds'].keys() def receive_new_commitment(self, sig, htlc_sigs): """ @@ -328,7 +329,9 @@ class Channel(PrintError): This docstring is from LND. """ self.print_error("receive_new_commitment") - self.lock_in_htlc_changes(REMOTE) + + for _ in self.lock_in_htlc_changes(REMOTE): pass + assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes pending_local_commitment = self.pending_local_commitment @@ -443,11 +446,20 @@ class Channel(PrintError): def receive_revocation(self, revocation) -> Tuple[int, int]: self.print_error("receive_revocation") + old_logs = dict(self.lock_in_htlc_changes(LOCAL)) + cur_point = self.config[REMOTE].current_per_commitment_point derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True) if cur_point != derived_point: + self.log = old_logs raise Exception('revoked secret not for current point') + for pending_fee in self.fee_mgr: + if not self.constraints.is_initiator: + pending_fee[FUNDEE_SIGNED] = True + if self.constraints.is_initiator and pending_fee[FUNDEE_ACKED]: + pending_fee[FUNDER_SIGNED] = True + # FIXME not sure this is correct... but it seems to work # if there are update_add_htlc msgs between commitment_signed and rev_ack, # this might break @@ -462,11 +474,11 @@ class Channel(PrintError): """ old_amount = htlcsum(self.htlcs(subject, False)) - for htlc_id in self.log[-subject]['settles']: + for htlc_id in self.log[subject]['settles']: adds = self.log[subject]['adds'] htlc = adds.pop(htlc_id) self.settled[subject].append(htlc.amount_msat) - self.log[-subject]['settles'].clear() + self.log[subject]['settles'].clear() return old_amount - htlcsum(self.htlcs(subject, False)) @@ -588,13 +600,12 @@ class Channel(PrintError): only_pending: require the htlc's settlement to be pending (needs additional signatures/acks) """ update_log = self.log[subject] - other_log = self.log[-subject] res = [] for htlc in update_log['adds'].values(): - locked_in = htlc.locked_in[subject] - settled = htlc.htlc_id in other_log['settles'] - failed = htlc.htlc_id in other_log['fails'] - if locked_in is None: + locked_in = htlc.htlc_id in update_log['locked_in'] + settled = htlc.htlc_id in update_log['settles'] + failed = htlc.htlc_id in update_log['fails'] + if not locked_in: continue if only_pending == (settled or failed): continue @@ -608,23 +619,23 @@ class Channel(PrintError): self.print_error("settle_htlc") htlc = self.log[REMOTE]['adds'][htlc_id] assert htlc.payment_hash == sha256(preimage) - self.log[LOCAL]['settles'].append(htlc_id) + self.log[REMOTE]['settles'].append(htlc_id) # not saving preimage because it's already saved in LNWorker.invoices def receive_htlc_settle(self, preimage, htlc_id): self.print_error("receive_htlc_settle") htlc = self.log[LOCAL]['adds'][htlc_id] assert htlc.payment_hash == sha256(preimage) - self.log[REMOTE]['settles'].append(htlc_id) + self.log[LOCAL]['settles'].append(htlc_id) # we don't save the preimage because we don't need to forward it anyway def fail_htlc(self, htlc_id): self.print_error("fail_htlc") - self.log[LOCAL]['fails'].append(htlc_id) + self.log[REMOTE]['fails'].append(htlc_id) def receive_fail_htlc(self, htlc_id): self.print_error("receive_fail_htlc") - self.log[REMOTE]['fails'].append(htlc_id) + self.log[LOCAL]['fails'].append(htlc_id) @property def current_height(self): @@ -654,8 +665,9 @@ class Channel(PrintError): """ removed = [] htlcs = [] - for i in self.log[subject]['adds'].values(): - locked_in = i.locked_in[LOCAL] is not None or i.locked_in[REMOTE] is not None + log = self.log[subject] + for htlc_id, i in log['adds'].items(): + locked_in = htlc_id in log['locked_in'] if locked_in: htlcs.append(i._asdict()) else: diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py @@ -396,6 +396,23 @@ class TestChannel(unittest.TestCase): self.alice_channel.add_htlc(new) self.assertIn('Not enough local balance', cm.exception.args[0]) + def test_sign_commitment_is_pure(self): + force_state_transition(self.alice_channel, self.bob_channel) + self.htlc_dict['payment_hash'] = bitcoin.sha256(b'\x02' * 32) + aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict) + before_signing = self.alice_channel.to_save() + self.alice_channel.sign_next_commitment() + after_signing = self.alice_channel.to_save() + try: + self.assertEqual(before_signing, after_signing) + except: + try: + from deepdiff import DeepDiff + from pprint import pformat + except ImportError: + raise + raise Exception(pformat(DeepDiff(before_signing, after_signing))) + class TestAvailableToSpend(unittest.TestCase): def test_DesyncHTLCs(self): alice_channel, bob_channel = create_test_channels()