electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit ea0981ebebe978290111c6942ecc52f30cee6604
parent 444610452e4c3344f262680e779b208647d134d5
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue, 17 Mar 2020 20:28:59 +0100

lnutil.UpdateAddHtlc: use attrs instead of old-style namedtuple

Diffstat:
Melectrum/lnchannel.py | 5+++--
Melectrum/lnutil.py | 34+++++++++++++++++-----------------
Melectrum/util.py | 3+++
Melectrum/wallet_db.py | 2+-
4 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -32,6 +32,7 @@ import time import threading from aiorpcx import NetAddress +import attr from . import ecc from . import constants @@ -434,7 +435,7 @@ class Channel(Logger): assert isinstance(htlc, UpdateAddHtlc) self._check_can_pay(htlc.amount_msat) if htlc.htlc_id is None: - htlc = htlc._replace(htlc_id=self.hm.get_next_htlc_id(LOCAL)) + htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(LOCAL)) with self.db_lock: self.hm.send_htlc(htlc) self.logger.info("add_htlc") @@ -452,7 +453,7 @@ class Channel(Logger): htlc = UpdateAddHtlc(**htlc) assert isinstance(htlc, UpdateAddHtlc) if htlc.htlc_id is None: # used in unit tests - htlc = htlc._replace(htlc_id=self.hm.get_next_htlc_id(REMOTE)) + htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(REMOTE)) if 0 <= self.available_to_spend(REMOTE) < htlc.amount_msat: raise RemoteMisbehaving('Remote dipped below channel reserve.' +\ f' Available at remote: {self.available_to_spend(REMOTE)},' +\ diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -878,21 +878,21 @@ def format_short_channel_id(short_channel_id: Optional[bytes]): + 'x' + str(int.from_bytes(short_channel_id[6:], 'big')) -class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id', 'timestamp'])): - # note: typing.NamedTuple cannot be used because we are overriding __new__ - - __slots__ = () - def __new__(cls, *args, **kwargs): - # if you pass a hex-string as payment_hash, it is decoded to bytes. - # Bytes can't be saved to disk, so we save hex-strings. - if len(args) > 0: - args = list(args) - if type(args[1]) is str: - args[1] = bfh(args[1]) - return super().__new__(cls, *args) - if type(kwargs['payment_hash']) is str: - kwargs['payment_hash'] = bfh(kwargs['payment_hash']) - if len(args) < 4 and 'htlc_id' not in kwargs: - kwargs['htlc_id'] = None - return super().__new__(cls, **kwargs) +@attr.s(frozen=True) +class UpdateAddHtlc: + amount_msat = attr.ib(type=int, kw_only=True) + payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes) + cltv_expiry = attr.ib(type=int, kw_only=True) + timestamp = attr.ib(type=int, kw_only=True) + htlc_id = attr.ib(type=int, kw_only=True, default=None) + @classmethod + def from_tuple(cls, amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc': + return cls(amount_msat=amount_msat, + payment_hash=payment_hash, + cltv_expiry=cltv_expiry, + htlc_id=htlc_id, + timestamp=timestamp) + + def to_tuple(self): + return (self.amount_msat, self.payment_hash, self.cltv_expiry, self.htlc_id, self.timestamp) diff --git a/electrum/util.py b/electrum/util.py @@ -277,6 +277,9 @@ class MyEncoder(json.JSONEncoder): def default(self, obj): # note: this does not get called for namedtuples :( https://bugs.python.org/issue30343 from .transaction import Transaction, TxOutput + from .lnutil import UpdateAddHtlc + if isinstance(obj, UpdateAddHtlc): + return obj.to_tuple() if isinstance(obj, Transaction): return obj.serialize() if isinstance(obj, TxOutput): diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py @@ -1079,7 +1079,7 @@ class WalletDB(JsonDB): # note: for performance, "deserialize=False" so that we will deserialize these on-demand v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items()) elif key == 'adds': - v = dict((k, UpdateAddHtlc(*x)) for k, x in v.items()) + v = dict((k, UpdateAddHtlc.from_tuple(*x)) for k, x in v.items()) elif key == 'fee_updates': v = dict((k, FeeUpdate(**x)) for k, x in v.items()) elif key == 'tx_fees':