electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit e477a433854dc935ec0d117e07b5e75d9817860f
parent 1fb0c28d0ab316bc07ced9f531787b54db3d57e0
Author: ThomasV <thomasv@electrum.org>
Date:   Mon,  1 Feb 2021 14:17:04 +0100

PaymentInfo: use msat precision

Diffstat:
Melectrum/lnpeer.py | 7++++---
Melectrum/lnworker.py | 35+++++++++++++++++------------------
Melectrum/tests/test_lnpeer.py | 20++++++++++----------
Melectrum/wallet_db.py | 14+++++++++++++-
4 files changed, 44 insertions(+), 32 deletions(-)

diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -1388,7 +1388,7 @@ class Peer(Logger): if payment_secret_from_onion != derive_payment_secret_from_payment_preimage(preimage): reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') return None, reason - expected_received_msat = int(info.amount * 1000) if info.amount is not None else None + expected_received_msat = info.amount_msat if expected_received_msat is not None and \ not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat): reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') @@ -1410,8 +1410,9 @@ class Peer(Logger): reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') return None, reason if cltv_from_onion != htlc.cltv_expiry: - reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY, - data=htlc.cltv_expiry.to_bytes(4, byteorder="big")) + reason = OnionRoutingFailureMessage( + code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY, + data=htlc.cltv_expiry.to_bytes(4, byteorder="big")) return None, reason try: amount_from_onion = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"] diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -138,7 +138,7 @@ FALLBACK_NODE_LIST_MAINNET = [ class PaymentInfo(NamedTuple): payment_hash: bytes - amount: Optional[int] # in satoshis # TODO make it msat and rename to amount_msat + amount_msat: Optional[int] direction: int status: int @@ -564,7 +564,7 @@ class LNWallet(LNWorker): self.config = wallet.config self.lnwatcher = None self.lnrater: LNRater = None - self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid # FIXME amt should be msat + self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage # note: this sweep_address is only used as fallback; as it might result in address-reuse self.sweep_address = wallet.get_new_sweep_address_for_channel() @@ -687,8 +687,8 @@ class LNWallet(LNWorker): fee_msat = None for chan_id, htlc, _direction in plist: amount_msat += int(_direction) * htlc.amount_msat - if _direction == SENT and info and info.amount: - fee_msat = (fee_msat or 0) - info.amount*1000 - amount_msat + if _direction == SENT and info and info.amount_msat: + fee_msat = (fee_msat or 0) - info.amount_msat - amount_msat timestamp = min([htlc.timestamp for chan_id, htlc, _direction in plist]) return amount_msat, fee_msat, timestamp @@ -948,13 +948,13 @@ class LNWallet(LNWorker): lnaddr = self._check_invoice(invoice, amount_msat=amount_msat) payment_hash = lnaddr.paymenthash key = payment_hash.hex() - amount = int(lnaddr.amount * COIN) + amount_msat = lnaddr.get_amount_msat() status = self.get_payment_status(payment_hash) if status == PR_PAID: raise PaymentFailure(_("This invoice has been paid already")) if status == PR_INFLIGHT: raise PaymentFailure(_("A payment was already initiated for this invoice")) - info = PaymentInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID) + info = PaymentInfo(payment_hash, amount_msat, SENT, PR_UNPAID) self.save_payment_info(info) self.wallet.set_label(key, lnaddr.get_description()) self.logs[key] = log = [] @@ -1217,16 +1217,16 @@ class LNWallet(LNWorker): raise Exception(_("add invoice timed out")) @log_exceptions - async def create_invoice(self, amount_sat: Optional[int], message, expiry: int): + async def create_invoice(self, amount_msat: Optional[int], message, expiry: int): timestamp = int(time.time()) - routing_hints = await self._calc_routing_hints_for_invoice(amount_sat) + routing_hints = await self._calc_routing_hints_for_invoice(amount_msat) if not routing_hints: self.logger.info("Warning. No routing hints added to invoice. " "Other clients will likely not be able to send to us.") payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) - info = PaymentInfo(payment_hash, amount_sat, RECEIVED, PR_UNPAID) - amount_btc = amount_sat/Decimal(COIN) if amount_sat else None + info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID) + amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None if expiry == 0: expiry = LN_EXPIRY_NEVER lnaddr = LnAddr(paymenthash=payment_hash, @@ -1244,7 +1244,8 @@ class LNWallet(LNWorker): return lnaddr, invoice async def _add_request_coro(self, amount_sat: Optional[int], message, expiry: int) -> str: - lnaddr, invoice = await self.create_invoice(amount_sat, message, expiry) + amount_msat = amount_sat * 1000 if amount_sat is not None else None + lnaddr, invoice = await self.create_invoice(amount_msat, message, expiry) key = bh2u(lnaddr.paymenthash) req = LNInvoice.from_bech32(invoice) self.wallet.add_payment_request(req) @@ -1265,14 +1266,14 @@ class LNWallet(LNWorker): key = payment_hash.hex() with self.lock: if key in self.payments: - amount, direction, status = self.payments[key] - return PaymentInfo(payment_hash, amount, direction, status) + amount_msat, direction, status = self.payments[key] + return PaymentInfo(payment_hash, amount_msat, direction, status) def save_payment_info(self, info: PaymentInfo) -> None: key = info.payment_hash.hex() assert info.status in SAVED_PR_STATUS with self.lock: - self.payments[key] = info.amount, info.direction, info.status + self.payments[key] = info.amount_msat, info.direction, info.status self.wallet.save_db() def get_payment_status(self, payment_hash): @@ -1355,16 +1356,14 @@ class LNWallet(LNWorker): util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID) util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id) - async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]): + async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]): """calculate routing hints (BOLT-11 'r' field)""" routing_hints = [] channels = list(self.channels.values()) random.shuffle(channels) # not sure this has any benefit but let's not leak channel order scid_to_my_channels = {chan.short_channel_id: chan for chan in channels if chan.short_channel_id is not None} - if amount_sat: - amount_msat = 1000 * amount_sat - else: + if not amount_msat: # for no amt invoices, check if channel can receive at least 1 msat amount_msat = 1 # note: currently we add *all* our channels; but this might be a privacy leak? diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -373,17 +373,17 @@ class TestPeer(ElectrumTestCase): async def prepare_invoice( w2: MockLNWallet, # receiver *, - amount_sat=100_000, + amount_msat=100_000_000, include_routing_hints=False, ): - amount_btc = amount_sat/Decimal(COIN) + amount_btc = amount_msat/Decimal(COIN*1000) payment_preimage = os.urandom(32) RHASH = sha256(payment_preimage) - info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID) + info = PaymentInfo(RHASH, amount_msat, RECEIVED, PR_UNPAID) w2.save_preimage(RHASH, payment_preimage) w2.save_payment_info(info) if include_routing_hints: - routing_hints = await w2._calc_routing_hints_for_invoice(amount_sat) + routing_hints = await w2._calc_routing_hints_for_invoice(amount_msat) else: routing_hints = [] lnaddr = LnAddr( @@ -541,14 +541,14 @@ class TestPeer(ElectrumTestCase): alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL) bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL) num_payments = 50 - payment_value_sat = 10000 # make it large enough so that there are actually HTLCs on the ctx + payment_value_msat = 10_000_000 # make it large enough so that there are actually HTLCs on the ctx max_htlcs_in_flight = asyncio.Semaphore(5) async def single_payment(pay_req): async with max_htlcs_in_flight: await w1._pay(pay_req) async def many_payments(): async with TaskGroup() as group: - pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_sat=payment_value_sat)) + pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat)) for i in range(num_payments)] async with TaskGroup() as group: for pay_req_task in pay_reqs_tasks: @@ -560,10 +560,10 @@ class TestPeer(ElectrumTestCase): await gath with self.assertRaises(concurrent.futures.CancelledError): run(f()) - self.assertEqual(alice_init_balance_msat - num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.LOCAL)) - self.assertEqual(alice_init_balance_msat - num_payments * payment_value_sat * 1000, bob_channel.balance(HTLCOwner.REMOTE)) - self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, bob_channel.balance(HTLCOwner.LOCAL)) - self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.REMOTE)) + self.assertEqual(alice_init_balance_msat - num_payments * payment_value_msat, alice_channel.balance(HTLCOwner.LOCAL)) + self.assertEqual(alice_init_balance_msat - num_payments * payment_value_msat, bob_channel.balance(HTLCOwner.REMOTE)) + self.assertEqual(bob_init_balance_msat + num_payments * payment_value_msat, bob_channel.balance(HTLCOwner.LOCAL)) + self.assertEqual(bob_init_balance_msat + num_payments * payment_value_msat, alice_channel.balance(HTLCOwner.REMOTE)) @needs_test_with_all_chacha20_implementations def test_payment_multihop(self): diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py @@ -52,7 +52,7 @@ if TYPE_CHECKING: OLD_SEED_VERSION = 4 # electrum versions < 2.0 NEW_SEED_VERSION = 11 # electrum versions >= 2.0 -FINAL_SEED_VERSION = 36 # electrum >= 2.7 will set this to prevent +FINAL_SEED_VERSION = 37 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format @@ -184,6 +184,7 @@ class WalletDB(JsonDB): self._convert_version_34() self._convert_version_35() self._convert_version_36() + self._convert_version_37() self.put('seed_version', FINAL_SEED_VERSION) # just to be sure self._after_upgrade_tasks() @@ -740,6 +741,17 @@ class WalletDB(JsonDB): self.data['frozen_coins'] = new_frozen_coins self.data['seed_version'] = 36 + def _convert_version_37(self): + if not self._is_upgrade_method_needed(36, 36): + return + payments = self.data.get('lightning_payments', {}) + for k, v in list(payments.items()): + amount_sat, direction, status = v + amount_msat = amount_sat * 1000 if amount_sat is not None else None + payments[k] = amount_msat, direction, status + self.data['lightning_payments'] = payments + self.data['seed_version'] = 37 + def _convert_imported(self): if not self._is_upgrade_method_needed(0, 13): return