commit 62be1cc367cef5838605ca5d7fd714a7f6333972
parent 5c05c06bf08221d0fa9f9868ef68a35ddb85d1bf
Author: SomberNight <somber.night@protonmail.com>
Date: Wed, 6 May 2020 03:15:20 +0200
small clean-up re "extract preimage from on-chain htlc_tx"
related: #6122
Diffstat:
3 files changed, 49 insertions(+), 40 deletions(-)
diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
@@ -28,6 +28,7 @@ from typing import (Optional, Dict, List, Tuple, NamedTuple, Set, Callable,
import time
import threading
from abc import ABC, abstractmethod
+import itertools
from aiorpcx import NetAddress
import attr
@@ -37,7 +38,7 @@ from . import constants, util
from .util import bfh, bh2u, chunks, TxMinedInfo, PR_PAID
from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d
-from .transaction import Transaction, PartialTransaction
+from .transaction import Transaction, PartialTransaction, TxInput
from .logging import Logger
from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailureMessage
from . import lnutil
@@ -971,46 +972,37 @@ class Channel(AbstractChannel):
failure_message = OnionRoutingFailureMessage.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
return error_bytes, failure_message
- def extract_preimage_from_htlc_tx(self, tx):
- for _input in tx.inputs():
- witness = _input.witness_elements()
- if len(witness) == 5:
- preimage = witness[3]
- elif len(witness) == 3:
- preimage = witness[1]
- else:
- continue
- payment_hash = sha256(preimage)
- for direction, htlc in self.hm.get_htlcs_in_oldest_unrevoked_ctx(REMOTE):
+ def extract_preimage_from_htlc_txin(self, txin: TxInput) -> None:
+ witness = txin.witness_elements()
+ if len(witness) == 5: # HTLC success tx
+ preimage = witness[3]
+ elif len(witness) == 3: # spending offered HTLC directly from ctx
+ preimage = witness[1]
+ else:
+ return
+ payment_hash = sha256(preimage)
+ for direction, htlc in itertools.chain(self.hm.get_htlcs_in_oldest_unrevoked_ctx(REMOTE),
+ self.hm.get_htlcs_in_latest_ctx(REMOTE)):
+ if htlc.payment_hash == payment_hash:
+ is_sent = direction == RECEIVED
+ break
+ else:
+ for direction, htlc in itertools.chain(self.hm.get_htlcs_in_oldest_unrevoked_ctx(LOCAL),
+ self.hm.get_htlcs_in_latest_ctx(LOCAL)):
if htlc.payment_hash == payment_hash:
- is_sent = direction == RECEIVED
+ is_sent = direction == SENT
break
else:
- for direction, htlc in self.hm.get_htlcs_in_latest_ctx(REMOTE):
- if htlc.payment_hash == payment_hash:
- is_sent = direction == RECEIVED
- break
- else:
- for direction, htlc in self.hm.get_htlcs_in_oldest_unrevoked_ctx(LOCAL):
- if htlc.payment_hash == payment_hash:
- is_sent = direction == SENT
- break
- else:
- for direction, htlc in self.hm.get_htlcs_in_latest_ctx(LOCAL):
- if htlc.payment_hash == payment_hash:
- is_sent = direction == SENT
- break
- else:
- continue
- if self.lnworker.get_preimage(payment_hash) is None:
- self.logger.info(f'found preimage for {payment_hash.hex()} in witness of length {len(witness)}')
- self.lnworker.save_preimage(payment_hash, preimage)
- info = self.lnworker.get_payment_info(payment_hash)
- if info is not None and info.status != PR_PAID:
- if is_sent:
- self.lnworker.payment_sent(self, payment_hash)
- else:
- self.lnworker.payment_received(self, payment_hash)
+ return
+ if self.lnworker.get_preimage(payment_hash) is None:
+ self.logger.info(f'found preimage for {payment_hash.hex()} in witness of length {len(witness)}')
+ self.lnworker.save_preimage(payment_hash, preimage)
+ info = self.lnworker.get_payment_info(payment_hash)
+ if info is not None and info.status != PR_PAID:
+ if is_sent:
+ self.lnworker.payment_sent(self, payment_hash)
+ else:
+ self.lnworker.payment_received(self, payment_hash)
def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
assert type(whose) is HTLCOwner
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
@@ -13,7 +13,7 @@ from .sql_db import SqlDB, sql
from .wallet_db import WalletDB
from .util import bh2u, bfh, log_exceptions, ignore_exceptions, TxMinedInfo
from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED
-from .transaction import Transaction
+from .transaction import Transaction, TxOutpoint
if TYPE_CHECKING:
from .network import Network
@@ -387,7 +387,10 @@ class LNWalletWatcher(LNWatcher):
else:
self.logger.info(f'(chan {chan.get_id_for_log()}) outpoint already spent {name}: {prevout}')
keep_watching |= not self.is_deeply_mined(spender_txid)
- chan.extract_preimage_from_htlc_tx(spender_tx)
+ txin_idx = spender_tx.get_input_idx_that_spent_prevout(TxOutpoint.from_str(prevout))
+ assert txin_idx is not None
+ spender_txin = spender_tx.inputs()[txin_idx]
+ chan.extract_preimage_from_htlc_txin(spender_txin)
else:
self.logger.info(f'(chan {chan.get_id_for_log()}) trying to redeem {name}: {prevout}')
await self.try_redeem(prevout, sweep_info, name)
diff --git a/electrum/transaction.py b/electrum/transaction.py
@@ -953,6 +953,20 @@ class Transaction:
else:
raise Exception('output not found', addr)
+ def get_input_idx_that_spent_prevout(self, prevout: TxOutpoint) -> Optional[int]:
+ # build cache if there isn't one yet
+ # note: can become stale and return incorrect data
+ # if the tx is modified later; that's out of scope.
+ if not hasattr(self, '_prevout_to_input_idx'):
+ d = {} # type: Dict[TxOutpoint, int]
+ for i, txin in enumerate(self.inputs()):
+ d[txin.prevout] = i
+ self._prevout_to_input_idx = d
+ idx = self._prevout_to_input_idx.get(prevout)
+ if idx is not None:
+ assert self.inputs()[idx].prevout == prevout
+ return idx
+
def convert_raw_tx_to_hex(raw: Union[str, bytes]) -> str:
"""Sanitizes tx-describing input (hex/base43/base64) into