electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 8e8ab775ebf019eabd1640ab4a7d9c820b08a384
parent 821431a23913349ae1c698b7653f42fc91ccb241
Author: SomberNight <somber.night@protonmail.com>
Date:   Mon, 13 Apr 2020 15:57:53 +0200

lnchannel: make AbstractChannel inherit ABC

and add some type annotations, clean up method signatures

Diffstat:
Melectrum/lnchannel.py | 148+++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------
Melectrum/lnsweep.py | 4++--
Melectrum/lntransport.py | 1+
Melectrum/lnutil.py | 8++++----
Melectrum/lnwatcher.py | 42+++++++++++++++++++++++-------------------
Melectrum/lnworker.py | 2+-
6 files changed, 134 insertions(+), 71 deletions(-)

diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -27,13 +27,14 @@ from typing import (Optional, Dict, List, Tuple, NamedTuple, Set, Callable, Iterable, Sequence, TYPE_CHECKING, Iterator, Union) import time import threading +from abc import ABC, abstractmethod from aiorpcx import NetAddress import attr from . import ecc from . import constants -from .util import bfh, bh2u, chunks +from .util import bfh, bh2u, chunks, TxMinedInfo from .bitcoin import redeem_script_to_address from .crypto import sha256, sha256d from .transaction import Transaction, PartialTransaction @@ -113,7 +114,9 @@ state_transitions = [ del cs # delete as name is ambiguous without context -RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"]) +class RevokeAndAck(NamedTuple): + per_commitment_secret: bytes + next_per_commitment_point: bytes class RemoteCtnTooFarInFuture(Exception): pass @@ -123,7 +126,16 @@ def htlcsum(htlcs): return sum([x.amount_msat for x in htlcs]) -class AbstractChannel(Logger): +class AbstractChannel(Logger, ABC): + storage: Union['StoredDict', dict] + config: Dict[HTLCOwner, Union[LocalConfig, RemoteConfig]] + _sweep_info: Dict[str, Dict[str, 'SweepInfo']] + lnworker: Optional['LNWallet'] + sweep_address: str + channel_id: bytes + funding_outpoint: Outpoint + node_id: bytes + _state: channel_states def set_short_channel_id(self, short_id: ShortChannelID) -> None: self.short_channel_id = short_id @@ -168,7 +180,7 @@ class AbstractChannel(Logger): def is_redeemed(self): return self.get_state() == channel_states.REDEEMED - def save_funding_height(self, txid, height, timestamp): + def save_funding_height(self, *, txid: str, height: int, timestamp: Optional[int]) -> None: self.storage['funding_height'] = txid, height, timestamp def get_funding_height(self): @@ -177,7 +189,7 @@ class AbstractChannel(Logger): def delete_funding_height(self): self.storage.pop('funding_height', None) - def save_closing_height(self, txid, height, timestamp): + def save_closing_height(self, *, txid: str, height: int, timestamp: Optional[int]) -> None: self.storage['closing_height'] = txid, height, timestamp def get_closing_height(self): @@ -197,30 +209,34 @@ class AbstractChannel(Logger): def sweep_ctx(self, ctx: Transaction) -> Dict[str, SweepInfo]: txid = ctx.txid() - if self.sweep_info.get(txid) is None: + if self._sweep_info.get(txid) is None: our_sweep_info = self.create_sweeptxs_for_our_ctx(ctx) their_sweep_info = self.create_sweeptxs_for_their_ctx(ctx) if our_sweep_info is not None: - self.sweep_info[txid] = our_sweep_info + self._sweep_info[txid] = our_sweep_info self.logger.info(f'we force closed') elif their_sweep_info is not None: - self.sweep_info[txid] = their_sweep_info + self._sweep_info[txid] = their_sweep_info self.logger.info(f'they force closed.') else: - self.sweep_info[txid] = {} + self._sweep_info[txid] = {} self.logger.info(f'not sure who closed.') - return self.sweep_info[txid] + return self._sweep_info[txid] - # ancestor for Channel and ChannelBackup - def update_onchain_state(self, funding_txid, funding_height, closing_txid, closing_height, keep_watching): + def update_onchain_state(self, *, funding_txid: str, funding_height: TxMinedInfo, + closing_txid: str, closing_height: TxMinedInfo, keep_watching: bool) -> None: # note: state transitions are irreversible, but # save_funding_height, save_closing_height are reversible if funding_height.height == TX_HEIGHT_LOCAL: self.update_unfunded_state() elif closing_height.height == TX_HEIGHT_LOCAL: - self.update_funded_state(funding_txid, funding_height) + self.update_funded_state(funding_txid=funding_txid, funding_height=funding_height) else: - self.update_closed_state(funding_txid, funding_height, closing_txid, closing_height, keep_watching) + self.update_closed_state(funding_txid=funding_txid, + funding_height=funding_height, + closing_txid=closing_txid, + closing_height=closing_height, + keep_watching=keep_watching) def update_unfunded_state(self): self.delete_funding_height() @@ -249,8 +265,8 @@ class AbstractChannel(Logger): if now - self.storage.get('init_timestamp', 0) > CHANNEL_OPENING_TIMEOUT: self.lnworker.remove_channel(self.channel_id) - def update_funded_state(self, funding_txid, funding_height): - self.save_funding_height(funding_txid, funding_height.height, funding_height.timestamp) + def update_funded_state(self, *, funding_txid: str, funding_height: TxMinedInfo) -> None: + self.save_funding_height(txid=funding_txid, height=funding_height.height, timestamp=funding_height.timestamp) self.delete_closing_height() if funding_height.conf>0: self.set_short_channel_id(ShortChannelID.from_components( @@ -259,9 +275,10 @@ class AbstractChannel(Logger): if self.is_funding_tx_mined(funding_height): self.set_state(channel_states.FUNDED) - def update_closed_state(self, funding_txid, funding_height, closing_txid, closing_height, keep_watching): - self.save_funding_height(funding_txid, funding_height.height, funding_height.timestamp) - self.save_closing_height(closing_txid, closing_height.height, closing_height.timestamp) + def update_closed_state(self, *, funding_txid: str, funding_height: TxMinedInfo, + closing_txid: str, closing_height: TxMinedInfo, keep_watching: bool) -> None: + self.save_funding_height(txid=funding_txid, height=funding_height.height, timestamp=funding_height.timestamp) + self.save_closing_height(txid=closing_txid, height=closing_height.height, timestamp=closing_height.timestamp) if self.get_state() < channel_states.CLOSED: conf = closing_height.conf if conf > 0: @@ -273,6 +290,66 @@ class AbstractChannel(Logger): if self.get_state() == channel_states.CLOSED and not keep_watching: self.set_state(channel_states.REDEEMED) + @abstractmethod + def is_initiator(self) -> bool: + pass + + @abstractmethod + def is_funding_tx_mined(self, funding_height: TxMinedInfo) -> bool: + pass + + @abstractmethod + def get_funding_address(self) -> str: + pass + + @abstractmethod + def get_state_for_GUI(self) -> str: + pass + + @abstractmethod + def get_oldest_unrevoked_ctn(self, subject: HTLCOwner) -> int: + pass + + @abstractmethod + def included_htlcs(self, subject: HTLCOwner, direction: Direction, ctn: int = None) -> Sequence[UpdateAddHtlc]: + pass + + @abstractmethod + def funding_txn_minimum_depth(self) -> int: + pass + + @abstractmethod + def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int: + """This balance (in msat) only considers HTLCs that have been settled by ctn. + It disregards reserve, fees, and pending HTLCs (in both directions). + """ + pass + + @abstractmethod + def balance_minus_outgoing_htlcs(self, whose: HTLCOwner, *, + ctx_owner: HTLCOwner = HTLCOwner.LOCAL, + ctn: int = None) -> int: + """This balance (in msat), which includes the value of + pending outgoing HTLCs, is used in the UI. + """ + pass + + @abstractmethod + def is_frozen_for_sending(self) -> bool: + """Whether the user has marked this channel as frozen for sending. + Frozen channels are not supposed to be used for new outgoing payments. + (note that payment-forwarding ignores this option) + """ + pass + + @abstractmethod + def is_frozen_for_receiving(self) -> bool: + """Whether the user has marked this channel as frozen for receiving. + Frozen channels are not supposed to be used for new incoming payments. + (note that payment-forwarding ignores this option) + """ + pass + class ChannelBackup(AbstractChannel): """ @@ -288,7 +365,7 @@ class ChannelBackup(AbstractChannel): self.name = None Logger.__init__(self) self.cb = cb - self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]] + self._sweep_info = {} self.sweep_address = sweep_address self.storage = {} # dummy storage self._state = channel_states.OPENING @@ -351,7 +428,7 @@ class ChannelBackup(AbstractChannel): def get_oldest_unrevoked_ctn(self, who): return -1 - def included_htlcs(self, subject, direction, ctn): + def included_htlcs(self, subject, direction, ctn=None): return [] def funding_txn_minimum_depth(self): @@ -381,16 +458,16 @@ class Channel(AbstractChannel): def __init__(self, state: 'StoredDict', *, sweep_address=None, name=None, lnworker=None, initial_feerate=None): self.name = name Logger.__init__(self) - self.lnworker = lnworker # type: Optional[LNWallet] + self.lnworker = lnworker self.sweep_address = sweep_address self.storage = state self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock() - self.config = {} # type: Dict[HTLCOwner, Union[LocalConfig, RemoteConfig]] + self.config = {} self.config[LOCAL] = state["local_config"] self.config[REMOTE] = state["remote_config"] self.channel_id = bfh(state["channel_id"]) self.constraints = state["constraints"] # type: ChannelConstraints - self.funding_outpoint = state["funding_outpoint"] # type: Outpoint + self.funding_outpoint = state["funding_outpoint"] self.node_id = bfh(state["node_id"]) self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"]) self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] @@ -398,7 +475,7 @@ class Channel(AbstractChannel): self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) self._state = channel_states[state['state']] self.peer_state = peer_states.DISCONNECTED - self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]] + self._sweep_info = {} self._outgoing_channel_update = None # type: Optional[bytes] self._chan_ann_without_sigs = None # type: Optional[bytes] self.revocation_store = RevocationStore(state["revocation_store"]) @@ -596,10 +673,6 @@ class Channel(AbstractChannel): return self.can_send_ctx_updates() and not self.is_closing() def is_frozen_for_sending(self) -> bool: - """Whether the user has marked this channel as frozen for sending. - Frozen channels are not supposed to be used for new outgoing payments. - (note that payment-forwarding ignores this option) - """ return self.storage.get('frozen_for_sending', False) def set_frozen_for_sending(self, b: bool) -> None: @@ -608,10 +681,6 @@ class Channel(AbstractChannel): self.lnworker.network.trigger_callback('channel', self) def is_frozen_for_receiving(self) -> bool: - """Whether the user has marked this channel as frozen for receiving. - Frozen channels are not supposed to be used for new incoming payments. - (note that payment-forwarding ignores this option) - """ return self.storage.get('frozen_for_receiving', False) def set_frozen_for_receiving(self, b: bool) -> None: @@ -880,9 +949,6 @@ class Channel(AbstractChannel): self.lnworker.payment_failed(self, htlc.payment_hash, payment_attempt) def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int: - """This balance (in msat) only considers HTLCs that have been settled by ctn. - It disregards reserve, fees, and pending HTLCs (in both directions). - """ assert type(whose) is HTLCOwner initial = self.config[whose].initial_msat return self.hm.get_balance_msat(whose=whose, @@ -891,10 +957,7 @@ class Channel(AbstractChannel): initial_balance_msat=initial) def balance_minus_outgoing_htlcs(self, whose: HTLCOwner, *, ctx_owner: HTLCOwner = HTLCOwner.LOCAL, - ctn: int = None): - """This balance (in msat), which includes the value of - pending outgoing HTLCs, is used in the UI. - """ + ctn: int = None) -> int: assert type(whose) is HTLCOwner if ctn is None: ctn = self.get_next_ctn(ctx_owner) @@ -1282,11 +1345,6 @@ class Channel(AbstractChannel): return total_value_sat > min_value_worth_closing_channel_over_sat def is_funding_tx_mined(self, funding_height): - """ - Checks if Funding TX has been mined. If it has, save the short channel ID in chan; - if it's also deep enough, also save to disk. - Returns tuple (mined_deep_enough, num_confirmations). - """ funding_txid = self.funding_outpoint.txid funding_idx = self.funding_outpoint.output_index conf = funding_height.conf diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py @@ -21,7 +21,7 @@ from .simple_config import SimpleConfig from .logging import get_logger, Logger if TYPE_CHECKING: - from .lnchannel import Channel + from .lnchannel import Channel, AbstractChannel _logger = get_logger(__name__) @@ -169,7 +169,7 @@ def create_sweeptx_for_their_revoked_htlc(chan: 'Channel', ctx: Transaction, htl -def create_sweeptxs_for_our_ctx(*, chan: 'Channel', ctx: Transaction, +def create_sweeptxs_for_our_ctx(*, chan: 'AbstractChannel', ctx: Transaction, sweep_address: str) -> Optional[Dict[str, SweepInfo]]: """Handle the case where we force close unilaterally with our latest ctx. Construct sweep txns for 'to_local', and for all HTLCs (2 txns each). diff --git a/electrum/lntransport.py b/electrum/lntransport.py @@ -89,6 +89,7 @@ def create_ephemeral_key() -> (bytes, bytes): class LNTransportBase: reader: StreamReader writer: StreamWriter + privkey: bytes def name(self) -> str: raise NotImplementedError() diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -27,7 +27,7 @@ from .bip32 import BIP32Node, BIP32_PRIME from .transaction import BCDataStream if TYPE_CHECKING: - from .lnchannel import Channel + from .lnchannel import Channel, AbstractChannel from .lnrouter import LNPaymentRoute from .lnonion import OnionRoutingFailureMessage @@ -504,8 +504,8 @@ def make_htlc_output_witness_script(is_received_htlc: bool, remote_revocation_pu payment_hash=payment_hash) -def get_ordered_channel_configs(chan: 'Channel', for_us: bool) -> Tuple[Union[LocalConfig, RemoteConfig], - Union[LocalConfig, RemoteConfig]]: +def get_ordered_channel_configs(chan: 'AbstractChannel', for_us: bool) -> Tuple[Union[LocalConfig, RemoteConfig], + Union[LocalConfig, RemoteConfig]]: conf = chan.config[LOCAL] if for_us else chan.config[REMOTE] other_conf = chan.config[LOCAL] if not for_us else chan.config[REMOTE] return conf, other_conf @@ -781,7 +781,7 @@ def extract_ctn_from_tx(tx: Transaction, txin_index: int, funder_payment_basepoi obs = ((sequence & 0xffffff) << 24) + (locktime & 0xffffff) return get_obscured_ctn(obs, funder_payment_basepoint, fundee_payment_basepoint) -def extract_ctn_from_tx_and_chan(tx: Transaction, chan: 'Channel') -> int: +def extract_ctn_from_tx_and_chan(tx: Transaction, chan: 'AbstractChannel') -> int: funder_conf = chan.config[LOCAL] if chan.is_initiator() else chan.config[REMOTE] fundee_conf = chan.config[LOCAL] if not chan.is_initiator() else chan.config[REMOTE] return extract_ctn_from_tx(tx, txin_index=0, diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py @@ -4,20 +4,13 @@ from typing import NamedTuple, Iterable, TYPE_CHECKING import os -import queue -import threading -import concurrent -from collections import defaultdict import asyncio from enum import IntEnum, auto from typing import NamedTuple, Dict from .sql_db import SqlDB, sql from .wallet_db import WalletDB -from .util import bh2u, bfh, log_exceptions, ignore_exceptions -from .lnutil import Outpoint -from . import wallet -from .storage import WalletStorage +from .util import bh2u, bfh, log_exceptions, ignore_exceptions, TxMinedInfo from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED from .transaction import Transaction @@ -199,17 +192,22 @@ class LNWatcher(AddressSynchronizer): else: keep_watching = True await self.update_channel_state( - funding_outpoint, funding_txid, - funding_height, closing_txid, - closing_height, keep_watching) + funding_outpoint=funding_outpoint, + funding_txid=funding_txid, + funding_height=funding_height, + closing_txid=closing_txid, + closing_height=closing_height, + keep_watching=keep_watching) if not keep_watching: await self.unwatch_channel(address, funding_outpoint) - async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders): - raise NotImplementedError() # implemented by subclasses + async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders) -> bool: + raise NotImplementedError() # implemented by subclasses - async def update_channel_state(self, *args): - raise NotImplementedError() # implemented by subclasses + async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str, + funding_height: TxMinedInfo, closing_txid: str, + closing_height: TxMinedInfo, keep_watching: bool) -> None: + raise NotImplementedError() # implemented by subclasses def inspect_tx_candidate(self, outpoint, n): prev_txid, index = outpoint.split(':') @@ -325,7 +323,7 @@ class WatchTower(LNWatcher): if funding_outpoint in self.tx_progress: self.tx_progress[funding_outpoint].all_done.set() - async def update_channel_state(self, *args): + async def update_channel_state(self, *args, **kwargs): pass @@ -340,17 +338,23 @@ class LNWalletWatcher(LNWatcher): @ignore_exceptions @log_exceptions - async def update_channel_state(self, funding_outpoint, funding_txid, funding_height, closing_txid, closing_height, keep_watching): + async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str, + funding_height: TxMinedInfo, closing_txid: str, + closing_height: TxMinedInfo, keep_watching: bool) -> None: chan = self.lnworker.channel_by_txo(funding_outpoint) if not chan: return - chan.update_onchain_state(funding_txid, funding_height, closing_txid, closing_height, keep_watching) + chan.update_onchain_state(funding_txid=funding_txid, + funding_height=funding_height, + closing_txid=closing_txid, + closing_height=closing_height, + keep_watching=keep_watching) await self.lnworker.on_channel_update(chan) async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders): chan = self.lnworker.channel_by_txo(funding_outpoint) if not chan: - return + return False # detect who closed and set sweep_info sweep_info_dict = chan.sweep_ctx(closing_tx) keep_watching = False if sweep_info_dict else not self.is_deeply_mined(closing_tx.txid()) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -432,7 +432,7 @@ class LNWallet(LNWorker): self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self.sweep_address = wallet.get_receiving_address() self.lock = threading.RLock() - self.logs = defaultdict(list) # (not persisted) type: Dict[str, List[PaymentAttemptLog]] # key is RHASH + self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted) self.is_routing = set() # (not persisted) keys of invoices that are in PR_ROUTING state # used in tests self.enable_htlc_settle = asyncio.Event()