electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit f08e5541aeda9524eb456477f93732de7357c792
parent 0a395fefbc885904e0b14ddea593d3cf6c579ce8
Author: ThomasV <thomasv@electrum.org>
Date:   Fri, 20 Sep 2019 17:15:49 +0200

Refactor invoices in lnworker.
 - use InvoiceInfo (NamedTuple) for normal operations,
   because lndecode operations can be very slow.
 - all invoices/requests are stored in wallet
 - invoice expiration detection is performed in wallet
 - CLI commands: list_invoices, add_request, add_lightning_request
 - revert 0062c6d69561991d5918163946344c1b10ed9588 because it forbids self-payments

Diffstat:
Melectrum/commands.py | 36+++++++++++-------------------------
Melectrum/gui/kivy/main_window.py | 18++++++------------
Melectrum/gui/kivy/uix/screens.py | 7+------
Melectrum/gui/qt/main_window.py | 12+++---------
Melectrum/lnpeer.py | 6+++---
Melectrum/lnworker.py | 215++++++++++++++++++++++++++++++++++++++-----------------------------------------
Melectrum/tests/regtest/regtest.sh | 18+++++++++---------
Melectrum/tests/test_lnpeer.py | 22++++++++++++++--------
Melectrum/wallet.py | 120+++++++++++++++++++++++++++++++++++++++++--------------------------------------
9 files changed, 212 insertions(+), 242 deletions(-)

diff --git a/electrum/commands.py b/electrum/commands.py @@ -716,7 +716,7 @@ class Commands: def _format_request(self, out): from .util import get_request_status out['amount_BTC'] = format_satoshis(out.get('amount')) - out['status'] = get_request_status(out) + out['status_str'] = get_request_status(out) return out @command('w') @@ -733,7 +733,7 @@ class Commands: # pass @command('w') - async def listrequests(self, pending=False, expired=False, paid=False, wallet: Abstract_Wallet = None): + async def list_requests(self, pending=False, expired=False, paid=False, wallet: Abstract_Wallet = None): """List the payment requests you made.""" out = wallet.get_sorted_requests() if pending: @@ -760,7 +760,7 @@ class Commands: return wallet.get_unused_address() @command('w') - async def addrequest(self, amount, memo='', expiration=None, force=False, wallet: Abstract_Wallet = None): + async def add_request(self, amount, memo='', expiration=3600, force=False, wallet: Abstract_Wallet = None): """Create a payment request, using the first unused address of the wallet. The address will be considered as used after this operation. If no payment is received, the address will be considered as unused if the payment request is deleted from the wallet.""" @@ -777,6 +777,12 @@ class Commands: out = wallet.get_request(addr) return self._format_request(out) + @command('wn') + async def add_lightning_request(self, amount, memo='', expiration=3600, wallet: Abstract_Wallet = None): + amount_sat = int(satoshis(amount)) + key = await wallet.lnworker._add_request_coro(amount_sat, memo, expiration) + return wallet.get_request(key)['invoice'] + @command('w') async def addtransaction(self, tx, wallet: Abstract_Wallet = None): """ Add a transaction to the wallet history """ @@ -894,13 +900,6 @@ class Commands: async def lnpay(self, invoice, attempts=1, timeout=10, wallet: Abstract_Wallet = None): return await wallet.lnworker._pay(invoice, attempts=attempts) - @command('wn') - async def addinvoice(self, requested_amount, message, expiration=3600, wallet: Abstract_Wallet = None): - # using requested_amount because it is documented in param_descriptions - payment_hash = await wallet.lnworker._add_invoice_coro(satoshis(requested_amount), message, expiration) - invoice, direction, is_paid = wallet.lnworker.invoices[bh2u(payment_hash)] - return invoice - @command('w') async def nodeid(self, wallet: Abstract_Wallet = None): listen_addr = self.config.get('lightning_listen') @@ -925,21 +924,8 @@ class Commands: self.network.path_finder.blacklist.clear() @command('w') - async def lightning_invoices(self, wallet: Abstract_Wallet = None): - from .util import pr_tooltips - out = [] - for payment_hash, (preimage, invoice, is_received, timestamp) in wallet.lnworker.invoices.items(): - status = wallet.lnworker.get_invoice_status(payment_hash) - item = { - 'date':timestamp_to_datetime(timestamp), - 'direction': 'received' if is_received else 'sent', - 'payment_hash':payment_hash, - 'invoice':invoice, - 'preimage':preimage, - 'status':pr_tooltips[status] - } - out.append(item) - return out + async def list_invoices(self, wallet: Abstract_Wallet = None): + return wallet.get_invoices() @command('w') async def lightning_history(self, wallet: Abstract_Wallet = None): diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py @@ -428,14 +428,11 @@ class ElectrumWindow(App): def show_request(self, is_lightning, key): from .uix.dialogs.request_dialog import RequestDialog - if is_lightning: - request, direction, is_paid = self.wallet.lnworker.invoices.get(key) or (None, None, None) - status = self.wallet.lnworker.get_invoice_status(key) - else: - request = self.wallet.get_request_URI(key) - status, conf = self.wallet.get_request_status(key) - self.request_popup = RequestDialog('Request', request, key) - self.request_popup.set_status(status) + request = self.wallet.get_request(key) + status = request['status'] + data = request['invoice'] if is_lightning else request['URI'] + self.request_popup = RequestDialog('Request', data, key) + self.request_popup.set_status(request['status']) self.request_popup.open() def show_invoice(self, is_lightning, key): @@ -444,10 +441,7 @@ class ElectrumWindow(App): if not invoice: return status = invoice['status'] - if is_lightning: - data = invoice['invoice'] - else: - data = key + data = invoice['invoice'] if is_lightning else key self.invoice_popup = InvoiceDialog('Invoice', data, key) self.invoice_popup.open() diff --git a/electrum/gui/kivy/uix/screens.py b/electrum/gui/kivy/uix/screens.py @@ -304,12 +304,7 @@ class SendScreen(CScreen): return message = self.screen.message if self.screen.is_lightning: - return { - 'type': PR_TYPE_LN, - 'invoice': address, - 'amount': amount, - 'message': message, - } + return self.app.wallet.lnworker.parse_bech32_invoice(address) else: if not bitcoin.is_address(address): self.app.show_error(_('Invalid Bitcoin Address') + ':\n' + address) diff --git a/electrum/gui/qt/main_window.py b/electrum/gui/qt/main_window.py @@ -1073,8 +1073,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger): message = self.receive_message_e.text() expiry = self.config.get('request_expiry', 3600) if is_lightning: - payment_hash = self.wallet.lnworker.add_invoice(amount, message, expiry) - key = bh2u(payment_hash) + key = self.wallet.lnworker.add_request(amount, message, expiry) else: key = self.create_bitcoin_request(amount, message, expiry) self.address_list.update() @@ -1698,12 +1697,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger): message = self.message_e.text() amount = self.amount_e.get_amount() if not self.is_onchain: - return { - 'type': PR_TYPE_LN, - 'invoice': self.payto_e.lightning_invoice, - 'amount': amount, - 'message': message, - } + return self.wallet.lnworker.parse_bech32_invoice(self.payto_e.lightning_invoice) else: outputs = self.read_outputs() if self.check_send_tab_outputs_and_show_errors(outputs): @@ -1733,7 +1727,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger): def do_pay_invoice(self, invoice, preview=False): if invoice['type'] == PR_TYPE_LN: - self.pay_lightning_invoice(self.payto_e.lightning_invoice) + self.pay_lightning_invoice(invoice['invoice']) return elif invoice['type'] == PR_TYPE_ONCHAIN: message = invoice['message'] diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -1402,13 +1402,13 @@ class Peer(Logger): await self.await_local(chan, local_ctn) await self.await_remote(chan, remote_ctn) try: - invoice = self.lnworker.get_invoice(htlc.payment_hash) + info = self.lnworker.get_invoice_info(htlc.payment_hash) preimage = self.lnworker.get_preimage(htlc.payment_hash) except UnknownPaymentHash: reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return - expected_received_msat = int(invoice.amount * bitcoin.COIN * 1000) if invoice.amount is not None else None + expected_received_msat = int(info.amount * 1000) if info.amount is not None else None if expected_received_msat is not None and \ not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat): reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') @@ -1431,7 +1431,7 @@ class Peer(Logger): data=htlc.amount_msat.to_bytes(8, byteorder="big")) await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return - self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED) + #self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED) await asyncio.sleep(self.network.config.lightning_settle_delay) await self._fulfill_htlc(chan, htlc.htlc_id, preimage) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -82,6 +82,16 @@ FALLBACK_NODE_LIST_MAINNET = [ encoder = ChannelJsonEncoder() + +from typing import NamedTuple + +class InvoiceInfo(NamedTuple): + payment_hash: bytes + amount: int + direction: int + status: int + + class LNWorker(Logger): def __init__(self, xprv): @@ -313,7 +323,7 @@ class LNWallet(LNWorker): LNWorker.__init__(self, xprv) self.ln_keystore = keystore.from_xprv(xprv) self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ - self.invoices = self.storage.get('lightning_invoices', {}) # RHASH -> (invoice, direction, is_paid) + self.invoices = self.storage.get('lightning_invoices2', {}) # RHASH -> amount, direction, is_paid self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage self.sweep_address = wallet.get_receiving_address() self.lock = threading.RLock() @@ -409,16 +419,6 @@ class LNWallet(LNWorker): timestamp = int(time.time()) self.network.trigger_callback('ln_payment_completed', timestamp, direction, htlc, preimage, chan_id) - def get_invoice_status(self, key): - if key not in self.invoices: - return PR_UNKNOWN - invoice, direction, status = self.invoices[key] - lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - if status == PR_UNPAID and lnaddr.is_expired(): - return PR_EXPIRED - else: - return status - def get_payments(self): # return one item per payment_hash # note: with AMP we will have several channels per payment @@ -431,6 +431,20 @@ class LNWallet(LNWorker): out[k].append(v) return out + def parse_bech32_invoice(self, invoice): + lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) + amount = int(lnaddr.amount * COIN) if lnaddr.amount else None + return { + 'type': PR_TYPE_LN, + 'invoice': invoice, + 'amount': amount, + 'message': lnaddr.get_description(), + 'time': lnaddr.date, + 'exp': lnaddr.get_expiry(), + 'pubkey': bh2u(lnaddr.pubkey.serialize()), + 'rhash': lnaddr.paymenthash.hex(), + } + def get_unsettled_payments(self): out = [] for payment_hash, plist in self.get_payments().items(): @@ -455,7 +469,7 @@ class LNWallet(LNWorker): def get_history(self): out = [] - for payment_hash, plist in self.get_payments().items(): + for key, plist in self.get_payments().items(): plist = list(filter(lambda x: x[3] == 'settled', plist)) if len(plist) == 0: continue @@ -464,11 +478,13 @@ class LNWallet(LNWorker): direction = 'sent' if _direction == SENT else 'received' amount_msat = int(_direction) * htlc.amount_msat timestamp = htlc.timestamp - label = self.wallet.get_label(payment_hash) - req = self.get_request(payment_hash) - if req and _direction == SENT: - req_amount_msat = -req['amount']*1000 - fee_msat = req_amount_msat - amount_msat + label = self.wallet.get_label(key) + if _direction == SENT: + try: + inv = self.get_invoice_info(bfh(key)) + fee_msat = inv.amount*1000 - amount_msat if inv.amount else None + except UnknownPaymentHash: + fee_msat = None else: fee_msat = None else: @@ -489,7 +505,7 @@ class LNWallet(LNWorker): 'status': status, 'amount_msat': amount_msat, 'fee_msat': fee_msat, - 'payment_hash': payment_hash + 'payment_hash': key, } out.append(item) # add funding events @@ -831,20 +847,23 @@ class LNWallet(LNWorker): @log_exceptions async def _pay(self, invoice, amount_sat=None, attempts=1): - addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - key = bh2u(addr.paymenthash) - if key in self.preimages: + lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) + key = bh2u(lnaddr.paymenthash) + amount = int(lnaddr.amount * COIN) if lnaddr.amount else None + status = self.get_invoice_status(lnaddr.paymenthash) + if status == PR_PAID: raise PaymentFailure(_("This invoice has been paid already")) + info = InvoiceInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID) + self.save_invoice_info(info) self._check_invoice(invoice, amount_sat) - self.save_invoice(addr.paymenthash, invoice, SENT, PR_INFLIGHT) - self.wallet.set_label(key, addr.get_description()) + self.wallet.set_label(key, lnaddr.get_description()) for i in range(attempts): - route = await self._create_route_from_invoice(decoded_invoice=addr) + route = await self._create_route_from_invoice(decoded_invoice=lnaddr) if not self.get_channel_by_short_id(route[0].short_channel_id): scid = route[0].short_channel_id raise Exception(f"Got route with unknown first channel: {scid}") self.network.trigger_callback('payment_status', key, 'progress', i) - if await self._pay_to_route(route, addr, invoice): + if await self._pay_to_route(route, lnaddr, invoice): return True return False @@ -854,10 +873,12 @@ class LNWallet(LNWorker): if not chan: raise Exception(f"PathFinder returned path with short_channel_id " f"{short_channel_id} that is not in channel list") + self.set_invoice_status(addr.paymenthash, PR_INFLIGHT) peer = self.peers[route[0].node_id] htlc = await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry()) self.network.trigger_callback('htlc_added', htlc, addr, SENT) success = await self.pending_payments[(short_channel_id, htlc.htlc_id)] + self.set_invoice_status(addr.paymenthash, (PR_PAID if success else PR_UNPAID)) return success @staticmethod @@ -933,119 +954,89 @@ class LNWallet(LNWorker): raise PaymentFailure(_("No path found")) return route - def add_invoice(self, amount_sat, message, expiry): - coro = self._add_invoice_coro(amount_sat, message, expiry) + def add_request(self, amount_sat, message, expiry): + coro = self._add_request_coro(amount_sat, message, expiry) fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) try: return fut.result(timeout=5) except concurrent.futures.TimeoutError: - raise Exception(_("add_invoice timed out")) + raise Exception(_("add invoice timed out")) @log_exceptions - async def _add_invoice_coro(self, amount_sat, message, expiry): - payment_preimage = os.urandom(32) - payment_hash = sha256(payment_preimage) - amount_btc = amount_sat/Decimal(COIN) if amount_sat else None + async def _add_request_coro(self, amount_sat, message, expiry): + timestamp = int(time.time()) routing_hints = await self._calc_routing_hints_for_invoice(amount_sat) if not routing_hints: self.logger.info("Warning. No routing hints added to invoice. " "Other clients will likely not be able to send to us.") - invoice = lnencode(LnAddr(payment_hash, amount_btc, - tags=[('d', message), - ('c', MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), - ('x', expiry)] - + routing_hints), - self.node_keypair.privkey) - self.save_invoice(payment_hash, invoice, RECEIVED, PR_UNPAID) + payment_preimage = os.urandom(32) + payment_hash = sha256(payment_preimage) + info = InvoiceInfo(payment_hash, amount_sat, RECEIVED, PR_UNPAID) + amount_btc = amount_sat/Decimal(COIN) if amount_sat else None + lnaddr = LnAddr(payment_hash, amount_btc, + tags=[('d', message), + ('c', MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), + ('x', expiry)] + + routing_hints, + date = timestamp) + invoice = lnencode(lnaddr, self.node_keypair.privkey) + key = bh2u(lnaddr.paymenthash) + req = { + 'type': PR_TYPE_LN, + 'amount': amount_sat, + 'time': lnaddr.date, + 'exp': expiry, + 'message': message, + 'rhash': key, + 'invoice': invoice + } self.save_preimage(payment_hash, payment_preimage) - self.wallet.set_label(bh2u(payment_hash), message) - return payment_hash + self.save_invoice_info(info) + self.wallet.add_payment_request(req) + self.wallet.set_label(key, message) + return key def save_preimage(self, payment_hash: bytes, preimage: bytes): assert sha256(preimage) == payment_hash - key = bh2u(payment_hash) - self.preimages[key] = bh2u(preimage) + self.preimages[bh2u(payment_hash)] = bh2u(preimage) self.storage.put('lightning_preimages', self.preimages) self.storage.write() def get_preimage(self, payment_hash: bytes) -> bytes: - try: - preimage = bfh(self.preimages[bh2u(payment_hash)]) - assert sha256(preimage) == payment_hash - return preimage - except KeyError as e: - raise UnknownPaymentHash(payment_hash) from e + return bfh(self.preimages.get(bh2u(payment_hash))) - def save_new_invoice(self, invoice): - addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - self.save_invoice(addr.paymenthash, invoice, SENT, PR_UNPAID) + def get_invoice_info(self, payment_hash: bytes) -> bytes: + key = payment_hash.hex() + with self.lock: + if key not in self.invoices: + raise UnknownPaymentHash(payment_hash) + amount, direction, status = self.invoices[key] + return InvoiceInfo(payment_hash, amount, direction, status) - def save_invoice(self, payment_hash:bytes, invoice, direction, status): - key = bh2u(payment_hash) + def save_invoice_info(self, info): + key = info.payment_hash.hex() with self.lock: - self.invoices[key] = invoice, direction, status - self.storage.put('lightning_invoices', self.invoices) + self.invoices[key] = info.amount, info.direction, info.status + self.storage.put('lightning_invoices2', self.invoices) self.storage.write() - def set_invoice_status(self, payment_hash, status): - key = bh2u(payment_hash) - if key not in self.invoices: - # if we are forwarding - return - invoice, direction, _ = self.invoices[key] - self.save_invoice(payment_hash, invoice, direction, status) - if direction == RECEIVED and status == PR_PAID: - self.network.trigger_callback('payment_received', self.wallet, key, PR_PAID) - - def get_invoice(self, payment_hash: bytes) -> LnAddr: + def get_invoice_status(self, payment_hash): try: - invoice, direction, is_paid = self.invoices[bh2u(payment_hash)] - return lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - except KeyError as e: - raise UnknownPaymentHash(payment_hash) from e + info = self.get_invoice_info(payment_hash) + return info.status + except UnknownPaymentHash: + return PR_UNKNOWN - def get_request(self, key): - if key not in self.invoices: + def set_invoice_status(self, payment_hash: bytes, status): + try: + info = self.get_invoice_info(payment_hash) + except UnknownPaymentHash: + # if we are forwarding return - # todo: parse invoices when saving - invoice, direction, is_paid = self.invoices[key] - status = self.get_invoice_status(key) - lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - amount_sat = int(lnaddr.amount*COIN) if lnaddr.amount else None - description = lnaddr.get_description() - timestamp = lnaddr.date - return { - 'type': PR_TYPE_LN, - 'status': status, - 'amount': amount_sat, - 'time': timestamp, - 'exp': lnaddr.get_expiry(), - 'message': description, - 'rhash': key, - 'invoice': invoice - } - - @profiler - def get_invoices(self): - # invoices = outgoing - out = [] - with self.lock: - invoice_items = list(self.invoices.items()) - for key, (invoice, direction, status) in invoice_items: - if direction == SENT and status != PR_PAID: - out.append(self.get_request(key)) - return out - - @profiler - def get_requests(self): - # requests = incoming - out = [] - with self.lock: - invoice_items = list(self.invoices.items()) - for key, (invoice, direction, status) in invoice_items: - if direction == RECEIVED and status != PR_PAID: - out.append(self.get_request(key)) - return out + info = info._replace(status=status) + self.save_invoice_info(info) + if info.direction == RECEIVED and info.status == PR_PAID: + self.network.trigger_callback('payment_received', self.wallet, bh2u(payment_hash), PR_PAID) async def _calc_routing_hints_for_invoice(self, amount_sat): """calculate routing hints (BOLT-11 'r' field)""" diff --git a/electrum/tests/regtest/regtest.sh b/electrum/tests/regtest/regtest.sh @@ -114,7 +114,7 @@ if [[ $1 == "open" ]]; then fi if [[ $1 == "alice_pays_carol" ]]; then - request=$($carol addinvoice 0.0001 "blah") + request=$($carol add_lightning_request 0.0001 -m "blah") $alice lnpay $request carol_balance=$($carol list_channels | jq -r '.[0].local_balance') echo "carol balance: $carol_balance" @@ -140,12 +140,12 @@ if [[ $1 == "breach" ]]; then channel=$($alice open_channel $bob_node 0.15) new_blocks 3 wait_until_channel_open alice - request=$($bob addinvoice 0.01 "blah") + request=$($bob add_lightning_request 0.01 -m "blah") echo "alice pays" $alice lnpay $request sleep 2 ctx=$($alice get_channel_ctx $channel | jq '.hex' | tr -d '"') - request=$($bob addinvoice 0.01 "blah2") + request=$($bob add_lightning_request 0.01 -m "blah2") echo "alice pays again" $alice lnpay $request echo "alice broadcasts old ctx" @@ -168,7 +168,7 @@ if [[ $1 == "redeem_htlcs" ]]; then new_blocks 6 sleep 10 # alice pays bob - invoice=$($bob addinvoice 0.05 "test") + invoice=$($bob add_lightning_request 0.05 -m "test") $alice lnpay $invoice --timeout=1 || true sleep 1 settled=$($alice list_channels | jq '.[] | .local_htlcs | .settles | length') @@ -214,7 +214,7 @@ if [[ $1 == "breach_with_unspent_htlc" ]]; then new_blocks 3 wait_until_channel_open alice echo "alice pays bob" - invoice=$($bob addinvoice 0.05 "test") + invoice=$($bob add_lightning_request 0.05 -m "test") $alice lnpay $invoice --timeout=1 || true settled=$($alice list_channels | jq '.[] | .local_htlcs | .settles | length') if [[ "$settled" != "0" ]]; then @@ -246,7 +246,7 @@ if [[ $1 == "breach_with_spent_htlc" ]]; then new_blocks 3 wait_until_channel_open alice echo "alice pays bob" - invoice=$($bob addinvoice 0.05 "test") + invoice=$($bob add_lightning_request 0.05 -m "test") $alice lnpay $invoice --timeout=1 || true ctx=$($alice get_channel_ctx $channel | jq '.hex' | tr -d '"') settled=$($alice list_channels | jq '.[] | .local_htlcs | .settles | length') @@ -310,11 +310,11 @@ if [[ $1 == "watchtower" ]]; then new_blocks 3 wait_until_channel_open alice echo "alice pays bob" - invoice1=$($bob addinvoice 0.05 "invoice1") + invoice1=$($bob add_lightning_request 0.05 -m "invoice1") $alice lnpay $invoice1 - invoice2=$($bob addinvoice 0.05 "invoice2") + invoice2=$($bob add_lightning_request 0.05 -m "invoice2") $alice lnpay $invoice2 - invoice3=$($bob addinvoice 0.05 "invoice3") + invoice3=$($bob add_lightning_request 0.05 -m "invoice3") $alice lnpay $invoice3 fi diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -21,6 +21,7 @@ from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet from electrum.lnmsg import encode_msg, decode_msg from electrum.logging import console_stderr_handler +from electrum.lnworker import InvoiceInfo, RECEIVED, PR_UNPAID from .test_lnchannel import create_test_channels from . import SequentialTestCase @@ -80,6 +81,7 @@ class MockWallet: pass class MockLNWallet: + storage = MockStorage() def __init__(self, remote_keypair, local_keypair, chan, tx_queue): self.chan = chan self.remote_keypair = remote_keypair @@ -87,7 +89,6 @@ class MockLNWallet: self.network = MockNetwork(tx_queue) self.channels = {self.chan.channel_id: self.chan} self.invoices = {} - self.preimages = {} self.inflight = {} self.wallet = MockWallet() self.localfeatures = LnLocalFeatures(0) @@ -122,7 +123,11 @@ class MockLNWallet: def save_invoice(*args, is_paid=False): pass - get_invoice = LNWallet.get_invoice + preimages = {} + get_invoice_info = LNWallet.get_invoice_info + save_invoice_info = LNWallet.save_invoice_info + set_invoice_status = LNWallet.set_invoice_status + save_preimage = LNWallet.save_preimage get_preimage = LNWallet.get_preimage _create_route_from_invoice = LNWallet._create_route_from_invoice _check_invoice = staticmethod(LNWallet._check_invoice) @@ -207,19 +212,20 @@ class TestPeer(SequentialTestCase): @staticmethod def prepare_invoice(w2 # receiver ): - amount_btc = 100000/Decimal(COIN) + amount_sat = 100000 + amount_btc = amount_sat/Decimal(COIN) payment_preimage = os.urandom(32) RHASH = sha256(payment_preimage) - addr = LnAddr( + info = InvoiceInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID) + w2.save_preimage(RHASH, payment_preimage) + w2.save_invoice_info(info) + lnaddr = LnAddr( RHASH, amount_btc, tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), ('d', 'coffee') ]) - pay_req = lnencode(addr, w2.node_keypair.privkey) - w2.preimages[bh2u(RHASH)] = bh2u(payment_preimage) - w2.invoices[bh2u(RHASH)] = (pay_req, True, False) - return pay_req + return lnencode(lnaddr, w2.node_keypair.privkey) def test_payment(self): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers() diff --git a/electrum/wallet.py b/electrum/wallet.py @@ -541,16 +541,16 @@ class Abstract_Wallet(AddressSynchronizer): def save_invoice(self, invoice): invoice_type = invoice['type'] if invoice_type == PR_TYPE_LN: - self.lnworker.save_new_invoice(invoice['invoice']) + key = invoice['rhash'] elif invoice_type == PR_TYPE_ONCHAIN: key = bh2u(sha256(repr(invoice))[0:16]) invoice['id'] = key invoice['txid'] = None - self.invoices[key] = invoice - self.storage.put('invoices', self.invoices) - self.storage.write() else: raise Exception('Unsupported invoice type') + self.invoices[key] = invoice + self.storage.put('invoices', self.invoices) + self.storage.write() def clear_invoices(self): self.invoices = {} @@ -560,29 +560,26 @@ class Abstract_Wallet(AddressSynchronizer): def get_invoices(self): out = [self.get_invoice(key) for key in self.invoices.keys()] out = [x for x in out if x and x.get('status') != PR_PAID] - if self.lnworker: - out += self.lnworker.get_invoices() out.sort(key=operator.itemgetter('time')) return out + def check_if_expired(self, item): + if item['status'] == PR_UNPAID and 'exp' in item and item['time'] + item['exp'] < time.time(): + item['status'] = PR_EXPIRED + def get_invoice(self, key): - if key in self.invoices: - item = copy.copy(self.invoices[key]) - request_type = item.get('type') - if request_type is None: - # todo: convert old bip70 invoices - return - # add status - if item.get('txid'): - status = PR_PAID - elif 'exp' in item and item['time'] + item['exp'] < time.time(): - status = PR_EXPIRED - else: - status = PR_UNPAID - item['status'] = status - return item - if self.lnworker: - return self.lnworker.get_request(key) + if key not in self.invoices: + return + item = copy.copy(self.invoices[key]) + request_type = item.get('type') + if request_type == PR_TYPE_ONCHAIN: + item['status'] = PR_PAID if item.get('txid') is not None else PR_UNPAID + elif request_type == PR_TYPE_LN: + item['status'] = self.lnworker.get_invoice_status(bfh(item['rhash'])) + else: + return + self.check_if_expired(item) + return item @profiler def get_full_history(self, fx=None, *, onchain_domain=None, include_lightning=True): @@ -1319,19 +1316,6 @@ class Abstract_Wallet(AddressSynchronizer): return True, conf return False, None - def get_payment_request(self, addr): - r = self.receive_requests.get(addr) - if not r: - return - out = copy.copy(r) - out['type'] = PR_TYPE_ONCHAIN - out['URI'] = self.get_request_URI(addr) - status, conf = self.get_request_status(addr) - out['status'] = status - if conf is not None: - out['confirmations'] = conf - return out - def get_request_URI(self, addr): req = self.receive_requests[addr] message = self.labels.get(addr, '') @@ -1349,11 +1333,10 @@ class Abstract_Wallet(AddressSynchronizer): uri = create_bip21_uri(addr, amount, message, extra_query_params=extra_query_params) return str(uri) - def get_request_status(self, key): - r = self.receive_requests.get(key) + def get_request_status(self, address): + r = self.receive_requests.get(address) if r is None: return PR_UNKNOWN - address = r['address'] amount = r.get('amount', 0) or 0 timestamp = r.get('time', 0) if timestamp and type(timestamp) != int: @@ -1372,14 +1355,23 @@ class Abstract_Wallet(AddressSynchronizer): return status, conf def get_request(self, key): - if key in self.receive_requests: - req = self.get_payment_request(key) - elif self.lnworker: - req = self.lnworker.get_request(key) - else: - req = None + req = self.receive_requests.get(key) if not req: return + req = copy.copy(req) + if req['type'] == PR_TYPE_ONCHAIN: + addr = req['address'] + req['URI'] = self.get_request_URI(addr) + status, conf = self.get_request_status(addr) + req['status'] = status + if conf is not None: + req['confirmations'] = conf + elif req['type'] == PR_TYPE_LN: + req['status'] = self.lnworker.get_invoice_status(bfh(key)) + else: + return + self.check_if_expired(req) + # add URL if we are running a payserver if self.config.get('payserver_port'): host = self.config.get('payserver_host', 'localhost') port = self.config.get('payserver_port') @@ -1405,8 +1397,16 @@ class Abstract_Wallet(AddressSynchronizer): from .bitcoin import TYPE_ADDRESS timestamp = int(time.time()) _id = bh2u(sha256d(addr + "%d"%timestamp))[0:10] - r = {'time':timestamp, 'amount':amount, 'exp':expiration, 'address':addr, 'memo':message, 'id':_id, 'outputs': [(TYPE_ADDRESS, addr, amount)]} - return r + return { + 'type': PR_TYPE_ONCHAIN, + 'time':timestamp, + 'amount':amount, + 'exp':expiration, + 'address':addr, + 'memo':message, + 'id':_id, + 'outputs': [(TYPE_ADDRESS, addr, amount)] + } def sign_payment_request(self, key, alias, alias_addr, password): req = self.receive_requests.get(key) @@ -1419,17 +1419,23 @@ class Abstract_Wallet(AddressSynchronizer): self.storage.put('payment_requests', self.receive_requests) def add_payment_request(self, req): - addr = req['address'] - if not bitcoin.is_address(addr): - raise Exception(_('Invalid Bitcoin address.')) - if not self.is_mine(addr): - raise Exception(_('Address not in wallet.')) - + if req['type'] == PR_TYPE_ONCHAIN: + addr = req['address'] + if not bitcoin.is_address(addr): + raise Exception(_('Invalid Bitcoin address.')) + if not self.is_mine(addr): + raise Exception(_('Address not in wallet.')) + key = addr + message = req['memo'] + elif req['type'] == PR_TYPE_LN: + key = req['rhash'] + message = req['message'] + else: + raise Exception('Unknown request type') amount = req.get('amount') - message = req.get('memo') - self.receive_requests[addr] = req + self.receive_requests[key] = req self.storage.put('payment_requests', self.receive_requests) - self.set_label(addr, message) # should be a default label + self.set_label(key, message) # should be a default label return req def delete_request(self, key): @@ -1457,8 +1463,6 @@ class Abstract_Wallet(AddressSynchronizer): def get_sorted_requests(self): """ sorted by timestamp """ out = [self.get_request(x) for x in self.receive_requests.keys()] - if self.lnworker: - out += self.lnworker.get_requests() out.sort(key=operator.itemgetter('time')) return out