commit 8e8ab775ebf019eabd1640ab4a7d9c820b08a384
parent 821431a23913349ae1c698b7653f42fc91ccb241
Author: SomberNight <somber.night@protonmail.com>
Date: Mon, 13 Apr 2020 15:57:53 +0200
lnchannel: make AbstractChannel inherit ABC
and add some type annotations, clean up method signatures
Diffstat:
6 files changed, 134 insertions(+), 71 deletions(-)
diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
@@ -27,13 +27,14 @@ from typing import (Optional, Dict, List, Tuple, NamedTuple, Set, Callable,
Iterable, Sequence, TYPE_CHECKING, Iterator, Union)
import time
import threading
+from abc import ABC, abstractmethod
from aiorpcx import NetAddress
import attr
from . import ecc
from . import constants
-from .util import bfh, bh2u, chunks
+from .util import bfh, bh2u, chunks, TxMinedInfo
from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d
from .transaction import Transaction, PartialTransaction
@@ -113,7 +114,9 @@ state_transitions = [
del cs # delete as name is ambiguous without context
-RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])
+class RevokeAndAck(NamedTuple):
+ per_commitment_secret: bytes
+ next_per_commitment_point: bytes
class RemoteCtnTooFarInFuture(Exception): pass
@@ -123,7 +126,16 @@ def htlcsum(htlcs):
return sum([x.amount_msat for x in htlcs])
-class AbstractChannel(Logger):
+class AbstractChannel(Logger, ABC):
+ storage: Union['StoredDict', dict]
+ config: Dict[HTLCOwner, Union[LocalConfig, RemoteConfig]]
+ _sweep_info: Dict[str, Dict[str, 'SweepInfo']]
+ lnworker: Optional['LNWallet']
+ sweep_address: str
+ channel_id: bytes
+ funding_outpoint: Outpoint
+ node_id: bytes
+ _state: channel_states
def set_short_channel_id(self, short_id: ShortChannelID) -> None:
self.short_channel_id = short_id
@@ -168,7 +180,7 @@ class AbstractChannel(Logger):
def is_redeemed(self):
return self.get_state() == channel_states.REDEEMED
- def save_funding_height(self, txid, height, timestamp):
+ def save_funding_height(self, *, txid: str, height: int, timestamp: Optional[int]) -> None:
self.storage['funding_height'] = txid, height, timestamp
def get_funding_height(self):
@@ -177,7 +189,7 @@ class AbstractChannel(Logger):
def delete_funding_height(self):
self.storage.pop('funding_height', None)
- def save_closing_height(self, txid, height, timestamp):
+ def save_closing_height(self, *, txid: str, height: int, timestamp: Optional[int]) -> None:
self.storage['closing_height'] = txid, height, timestamp
def get_closing_height(self):
@@ -197,30 +209,34 @@ class AbstractChannel(Logger):
def sweep_ctx(self, ctx: Transaction) -> Dict[str, SweepInfo]:
txid = ctx.txid()
- if self.sweep_info.get(txid) is None:
+ if self._sweep_info.get(txid) is None:
our_sweep_info = self.create_sweeptxs_for_our_ctx(ctx)
their_sweep_info = self.create_sweeptxs_for_their_ctx(ctx)
if our_sweep_info is not None:
- self.sweep_info[txid] = our_sweep_info
+ self._sweep_info[txid] = our_sweep_info
self.logger.info(f'we force closed')
elif their_sweep_info is not None:
- self.sweep_info[txid] = their_sweep_info
+ self._sweep_info[txid] = their_sweep_info
self.logger.info(f'they force closed.')
else:
- self.sweep_info[txid] = {}
+ self._sweep_info[txid] = {}
self.logger.info(f'not sure who closed.')
- return self.sweep_info[txid]
+ return self._sweep_info[txid]
- # ancestor for Channel and ChannelBackup
- def update_onchain_state(self, funding_txid, funding_height, closing_txid, closing_height, keep_watching):
+ def update_onchain_state(self, *, funding_txid: str, funding_height: TxMinedInfo,
+ closing_txid: str, closing_height: TxMinedInfo, keep_watching: bool) -> None:
# note: state transitions are irreversible, but
# save_funding_height, save_closing_height are reversible
if funding_height.height == TX_HEIGHT_LOCAL:
self.update_unfunded_state()
elif closing_height.height == TX_HEIGHT_LOCAL:
- self.update_funded_state(funding_txid, funding_height)
+ self.update_funded_state(funding_txid=funding_txid, funding_height=funding_height)
else:
- self.update_closed_state(funding_txid, funding_height, closing_txid, closing_height, keep_watching)
+ self.update_closed_state(funding_txid=funding_txid,
+ funding_height=funding_height,
+ closing_txid=closing_txid,
+ closing_height=closing_height,
+ keep_watching=keep_watching)
def update_unfunded_state(self):
self.delete_funding_height()
@@ -249,8 +265,8 @@ class AbstractChannel(Logger):
if now - self.storage.get('init_timestamp', 0) > CHANNEL_OPENING_TIMEOUT:
self.lnworker.remove_channel(self.channel_id)
- def update_funded_state(self, funding_txid, funding_height):
- self.save_funding_height(funding_txid, funding_height.height, funding_height.timestamp)
+ def update_funded_state(self, *, funding_txid: str, funding_height: TxMinedInfo) -> None:
+ self.save_funding_height(txid=funding_txid, height=funding_height.height, timestamp=funding_height.timestamp)
self.delete_closing_height()
if funding_height.conf>0:
self.set_short_channel_id(ShortChannelID.from_components(
@@ -259,9 +275,10 @@ class AbstractChannel(Logger):
if self.is_funding_tx_mined(funding_height):
self.set_state(channel_states.FUNDED)
- def update_closed_state(self, funding_txid, funding_height, closing_txid, closing_height, keep_watching):
- self.save_funding_height(funding_txid, funding_height.height, funding_height.timestamp)
- self.save_closing_height(closing_txid, closing_height.height, closing_height.timestamp)
+ def update_closed_state(self, *, funding_txid: str, funding_height: TxMinedInfo,
+ closing_txid: str, closing_height: TxMinedInfo, keep_watching: bool) -> None:
+ self.save_funding_height(txid=funding_txid, height=funding_height.height, timestamp=funding_height.timestamp)
+ self.save_closing_height(txid=closing_txid, height=closing_height.height, timestamp=closing_height.timestamp)
if self.get_state() < channel_states.CLOSED:
conf = closing_height.conf
if conf > 0:
@@ -273,6 +290,66 @@ class AbstractChannel(Logger):
if self.get_state() == channel_states.CLOSED and not keep_watching:
self.set_state(channel_states.REDEEMED)
+ @abstractmethod
+ def is_initiator(self) -> bool:
+ pass
+
+ @abstractmethod
+ def is_funding_tx_mined(self, funding_height: TxMinedInfo) -> bool:
+ pass
+
+ @abstractmethod
+ def get_funding_address(self) -> str:
+ pass
+
+ @abstractmethod
+ def get_state_for_GUI(self) -> str:
+ pass
+
+ @abstractmethod
+ def get_oldest_unrevoked_ctn(self, subject: HTLCOwner) -> int:
+ pass
+
+ @abstractmethod
+ def included_htlcs(self, subject: HTLCOwner, direction: Direction, ctn: int = None) -> Sequence[UpdateAddHtlc]:
+ pass
+
+ @abstractmethod
+ def funding_txn_minimum_depth(self) -> int:
+ pass
+
+ @abstractmethod
+ def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
+ """This balance (in msat) only considers HTLCs that have been settled by ctn.
+ It disregards reserve, fees, and pending HTLCs (in both directions).
+ """
+ pass
+
+ @abstractmethod
+ def balance_minus_outgoing_htlcs(self, whose: HTLCOwner, *,
+ ctx_owner: HTLCOwner = HTLCOwner.LOCAL,
+ ctn: int = None) -> int:
+ """This balance (in msat), which includes the value of
+ pending outgoing HTLCs, is used in the UI.
+ """
+ pass
+
+ @abstractmethod
+ def is_frozen_for_sending(self) -> bool:
+ """Whether the user has marked this channel as frozen for sending.
+ Frozen channels are not supposed to be used for new outgoing payments.
+ (note that payment-forwarding ignores this option)
+ """
+ pass
+
+ @abstractmethod
+ def is_frozen_for_receiving(self) -> bool:
+ """Whether the user has marked this channel as frozen for receiving.
+ Frozen channels are not supposed to be used for new incoming payments.
+ (note that payment-forwarding ignores this option)
+ """
+ pass
+
class ChannelBackup(AbstractChannel):
"""
@@ -288,7 +365,7 @@ class ChannelBackup(AbstractChannel):
self.name = None
Logger.__init__(self)
self.cb = cb
- self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
+ self._sweep_info = {}
self.sweep_address = sweep_address
self.storage = {} # dummy storage
self._state = channel_states.OPENING
@@ -351,7 +428,7 @@ class ChannelBackup(AbstractChannel):
def get_oldest_unrevoked_ctn(self, who):
return -1
- def included_htlcs(self, subject, direction, ctn):
+ def included_htlcs(self, subject, direction, ctn=None):
return []
def funding_txn_minimum_depth(self):
@@ -381,16 +458,16 @@ class Channel(AbstractChannel):
def __init__(self, state: 'StoredDict', *, sweep_address=None, name=None, lnworker=None, initial_feerate=None):
self.name = name
Logger.__init__(self)
- self.lnworker = lnworker # type: Optional[LNWallet]
+ self.lnworker = lnworker
self.sweep_address = sweep_address
self.storage = state
self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
- self.config = {} # type: Dict[HTLCOwner, Union[LocalConfig, RemoteConfig]]
+ self.config = {}
self.config[LOCAL] = state["local_config"]
self.config[REMOTE] = state["remote_config"]
self.channel_id = bfh(state["channel_id"])
self.constraints = state["constraints"] # type: ChannelConstraints
- self.funding_outpoint = state["funding_outpoint"] # type: Outpoint
+ self.funding_outpoint = state["funding_outpoint"]
self.node_id = bfh(state["node_id"])
self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
self.onion_keys = state['onion_keys'] # type: Dict[int, bytes]
@@ -398,7 +475,7 @@ class Channel(AbstractChannel):
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
self._state = channel_states[state['state']]
self.peer_state = peer_states.DISCONNECTED
- self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
+ self._sweep_info = {}
self._outgoing_channel_update = None # type: Optional[bytes]
self._chan_ann_without_sigs = None # type: Optional[bytes]
self.revocation_store = RevocationStore(state["revocation_store"])
@@ -596,10 +673,6 @@ class Channel(AbstractChannel):
return self.can_send_ctx_updates() and not self.is_closing()
def is_frozen_for_sending(self) -> bool:
- """Whether the user has marked this channel as frozen for sending.
- Frozen channels are not supposed to be used for new outgoing payments.
- (note that payment-forwarding ignores this option)
- """
return self.storage.get('frozen_for_sending', False)
def set_frozen_for_sending(self, b: bool) -> None:
@@ -608,10 +681,6 @@ class Channel(AbstractChannel):
self.lnworker.network.trigger_callback('channel', self)
def is_frozen_for_receiving(self) -> bool:
- """Whether the user has marked this channel as frozen for receiving.
- Frozen channels are not supposed to be used for new incoming payments.
- (note that payment-forwarding ignores this option)
- """
return self.storage.get('frozen_for_receiving', False)
def set_frozen_for_receiving(self, b: bool) -> None:
@@ -880,9 +949,6 @@ class Channel(AbstractChannel):
self.lnworker.payment_failed(self, htlc.payment_hash, payment_attempt)
def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
- """This balance (in msat) only considers HTLCs that have been settled by ctn.
- It disregards reserve, fees, and pending HTLCs (in both directions).
- """
assert type(whose) is HTLCOwner
initial = self.config[whose].initial_msat
return self.hm.get_balance_msat(whose=whose,
@@ -891,10 +957,7 @@ class Channel(AbstractChannel):
initial_balance_msat=initial)
def balance_minus_outgoing_htlcs(self, whose: HTLCOwner, *, ctx_owner: HTLCOwner = HTLCOwner.LOCAL,
- ctn: int = None):
- """This balance (in msat), which includes the value of
- pending outgoing HTLCs, is used in the UI.
- """
+ ctn: int = None) -> int:
assert type(whose) is HTLCOwner
if ctn is None:
ctn = self.get_next_ctn(ctx_owner)
@@ -1282,11 +1345,6 @@ class Channel(AbstractChannel):
return total_value_sat > min_value_worth_closing_channel_over_sat
def is_funding_tx_mined(self, funding_height):
- """
- Checks if Funding TX has been mined. If it has, save the short channel ID in chan;
- if it's also deep enough, also save to disk.
- Returns tuple (mined_deep_enough, num_confirmations).
- """
funding_txid = self.funding_outpoint.txid
funding_idx = self.funding_outpoint.output_index
conf = funding_height.conf
diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py
@@ -21,7 +21,7 @@ from .simple_config import SimpleConfig
from .logging import get_logger, Logger
if TYPE_CHECKING:
- from .lnchannel import Channel
+ from .lnchannel import Channel, AbstractChannel
_logger = get_logger(__name__)
@@ -169,7 +169,7 @@ def create_sweeptx_for_their_revoked_htlc(chan: 'Channel', ctx: Transaction, htl
-def create_sweeptxs_for_our_ctx(*, chan: 'Channel', ctx: Transaction,
+def create_sweeptxs_for_our_ctx(*, chan: 'AbstractChannel', ctx: Transaction,
sweep_address: str) -> Optional[Dict[str, SweepInfo]]:
"""Handle the case where we force close unilaterally with our latest ctx.
Construct sweep txns for 'to_local', and for all HTLCs (2 txns each).
diff --git a/electrum/lntransport.py b/electrum/lntransport.py
@@ -89,6 +89,7 @@ def create_ephemeral_key() -> (bytes, bytes):
class LNTransportBase:
reader: StreamReader
writer: StreamWriter
+ privkey: bytes
def name(self) -> str:
raise NotImplementedError()
diff --git a/electrum/lnutil.py b/electrum/lnutil.py
@@ -27,7 +27,7 @@ from .bip32 import BIP32Node, BIP32_PRIME
from .transaction import BCDataStream
if TYPE_CHECKING:
- from .lnchannel import Channel
+ from .lnchannel import Channel, AbstractChannel
from .lnrouter import LNPaymentRoute
from .lnonion import OnionRoutingFailureMessage
@@ -504,8 +504,8 @@ def make_htlc_output_witness_script(is_received_htlc: bool, remote_revocation_pu
payment_hash=payment_hash)
-def get_ordered_channel_configs(chan: 'Channel', for_us: bool) -> Tuple[Union[LocalConfig, RemoteConfig],
- Union[LocalConfig, RemoteConfig]]:
+def get_ordered_channel_configs(chan: 'AbstractChannel', for_us: bool) -> Tuple[Union[LocalConfig, RemoteConfig],
+ Union[LocalConfig, RemoteConfig]]:
conf = chan.config[LOCAL] if for_us else chan.config[REMOTE]
other_conf = chan.config[LOCAL] if not for_us else chan.config[REMOTE]
return conf, other_conf
@@ -781,7 +781,7 @@ def extract_ctn_from_tx(tx: Transaction, txin_index: int, funder_payment_basepoi
obs = ((sequence & 0xffffff) << 24) + (locktime & 0xffffff)
return get_obscured_ctn(obs, funder_payment_basepoint, fundee_payment_basepoint)
-def extract_ctn_from_tx_and_chan(tx: Transaction, chan: 'Channel') -> int:
+def extract_ctn_from_tx_and_chan(tx: Transaction, chan: 'AbstractChannel') -> int:
funder_conf = chan.config[LOCAL] if chan.is_initiator() else chan.config[REMOTE]
fundee_conf = chan.config[LOCAL] if not chan.is_initiator() else chan.config[REMOTE]
return extract_ctn_from_tx(tx, txin_index=0,
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
@@ -4,20 +4,13 @@
from typing import NamedTuple, Iterable, TYPE_CHECKING
import os
-import queue
-import threading
-import concurrent
-from collections import defaultdict
import asyncio
from enum import IntEnum, auto
from typing import NamedTuple, Dict
from .sql_db import SqlDB, sql
from .wallet_db import WalletDB
-from .util import bh2u, bfh, log_exceptions, ignore_exceptions
-from .lnutil import Outpoint
-from . import wallet
-from .storage import WalletStorage
+from .util import bh2u, bfh, log_exceptions, ignore_exceptions, TxMinedInfo
from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED
from .transaction import Transaction
@@ -199,17 +192,22 @@ class LNWatcher(AddressSynchronizer):
else:
keep_watching = True
await self.update_channel_state(
- funding_outpoint, funding_txid,
- funding_height, closing_txid,
- closing_height, keep_watching)
+ funding_outpoint=funding_outpoint,
+ funding_txid=funding_txid,
+ funding_height=funding_height,
+ closing_txid=closing_txid,
+ closing_height=closing_height,
+ keep_watching=keep_watching)
if not keep_watching:
await self.unwatch_channel(address, funding_outpoint)
- async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders):
- raise NotImplementedError() # implemented by subclasses
+ async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders) -> bool:
+ raise NotImplementedError() # implemented by subclasses
- async def update_channel_state(self, *args):
- raise NotImplementedError() # implemented by subclasses
+ async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
+ funding_height: TxMinedInfo, closing_txid: str,
+ closing_height: TxMinedInfo, keep_watching: bool) -> None:
+ raise NotImplementedError() # implemented by subclasses
def inspect_tx_candidate(self, outpoint, n):
prev_txid, index = outpoint.split(':')
@@ -325,7 +323,7 @@ class WatchTower(LNWatcher):
if funding_outpoint in self.tx_progress:
self.tx_progress[funding_outpoint].all_done.set()
- async def update_channel_state(self, *args):
+ async def update_channel_state(self, *args, **kwargs):
pass
@@ -340,17 +338,23 @@ class LNWalletWatcher(LNWatcher):
@ignore_exceptions
@log_exceptions
- async def update_channel_state(self, funding_outpoint, funding_txid, funding_height, closing_txid, closing_height, keep_watching):
+ async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
+ funding_height: TxMinedInfo, closing_txid: str,
+ closing_height: TxMinedInfo, keep_watching: bool) -> None:
chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan:
return
- chan.update_onchain_state(funding_txid, funding_height, closing_txid, closing_height, keep_watching)
+ chan.update_onchain_state(funding_txid=funding_txid,
+ funding_height=funding_height,
+ closing_txid=closing_txid,
+ closing_height=closing_height,
+ keep_watching=keep_watching)
await self.lnworker.on_channel_update(chan)
async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders):
chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan:
- return
+ return False
# detect who closed and set sweep_info
sweep_info_dict = chan.sweep_ctx(closing_tx)
keep_watching = False if sweep_info_dict else not self.is_deeply_mined(closing_tx.txid())
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -432,7 +432,7 @@ class LNWallet(LNWorker):
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
self.sweep_address = wallet.get_receiving_address()
self.lock = threading.RLock()
- self.logs = defaultdict(list) # (not persisted) type: Dict[str, List[PaymentAttemptLog]] # key is RHASH
+ self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted)
self.is_routing = set() # (not persisted) keys of invoices that are in PR_ROUTING state
# used in tests
self.enable_htlc_settle = asyncio.Event()