commit 0ea87278fb18a535c5ba35eb542ea7dfd5672f18
parent 6211e656a8cbdc9b4ab31b127e94bce6523f92d3
Author: Janus <ysangkok@gmail.com>
Date: Fri, 2 Nov 2018 19:16:42 +0100
move force_close_channel to lnbase, test it, add FORCE_CLOSING state
Diffstat:
4 files changed, 122 insertions(+), 47 deletions(-)
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -19,7 +19,7 @@ import aiorpcx
from .crypto import sha256, sha256d
from . import bitcoin
from . import ecc
-from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string
+from .ecc import sig_string_from_r_and_s, get_r_and_s_from_sig_string, der_sig_from_sig_string
from . import constants
from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabled_bits, ignore_exceptions
from .transaction import Transaction, TxOutput
@@ -1158,6 +1158,25 @@ class Peer(PrintError):
self.print_error('Channel closed', txid)
return txid
+ async def force_close_channel(self, chan_id):
+ chan = self.channels[chan_id]
+ # local_commitment always gives back the next expected local_commitment,
+ # but in this case, we want the current one. So substract one ctn number
+ old_local_state = chan.config[LOCAL]
+ chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
+ tx = chan.pending_local_commitment
+ chan.config[LOCAL] = old_local_state
+ tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
+ remote_sig = chan.config[LOCAL].current_commitment_signature
+ remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
+ none_idx = tx._inputs[0]["signatures"].index(None)
+ tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
+ assert tx.is_complete()
+ # TODO persist FORCE_CLOSING state to disk
+ chan.set_state('FORCE_CLOSING')
+ self.lnworker.save_channel(chan)
+ return await self.network.broadcast_transaction(tx)
+
@log_exceptions
async def on_shutdown(self, payload):
# length of scripts allowed in BOLT-02
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -11,6 +11,7 @@ from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
import threading
import socket
import json
+from decimal import Decimal
import dns.resolver
import dns.exception
@@ -267,18 +268,13 @@ class LNWorker(PrintError):
return addr, peer, fut
def _pay(self, invoice, amount_sat=None):
- addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
- payment_hash = addr.paymenthash
- amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
- if amount_sat is None:
- raise InvoiceError(_("Missing amount"))
- amount_msat = int(amount_sat * 1000)
- if addr.get_min_final_cltv_expiry() > 60 * 144:
- raise InvoiceError("{}\n{}".format(
- _("Invoice wants us to risk locking funds for unreasonably long."),
- f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
- route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
- node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
+ addr = self._check_invoice(invoice, amount_sat)
+ route = self._create_route_from_invoice(decoded_invoice=addr)
+ peer = self.peers[route[0].node_id]
+ return addr, peer, self._pay_to_route(route, addr)
+
+ async def _pay_to_route(self, route, addr):
+ short_channel_id = route[0].short_channel_id
with self.lock:
channels = list(self.channels.values())
for chan in channels:
@@ -286,11 +282,24 @@ class LNWorker(PrintError):
break
else:
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
- peer = self.peers[node_id]
- coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
- return addr, peer, coro
+ peer = self.peers[route[0].node_id]
+ return await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry())
- def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
+ @staticmethod
+ def _check_invoice(invoice, amount_sat=None):
+ addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
+ if amount_sat:
+ addr.amount = Decimal(amount_sat) / COIN
+ if addr.amount is None:
+ raise InvoiceError(_("Missing amount"))
+ if addr.get_min_final_cltv_expiry() > 60 * 144:
+ raise InvoiceError("{}\n{}".format(
+ _("Invoice wants us to risk locking funds for unreasonably long."),
+ f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
+ return addr
+
+ def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
+ amount_msat = int(decoded_invoice.amount * COIN * 1000)
invoice_pubkey = decoded_invoice.pubkey.serialize()
# use 'r' field from invoice
route = None # type: List[RouteEdge]
@@ -441,19 +450,8 @@ class LNWorker(PrintError):
async def force_close_channel(self, chan_id):
chan = self.channels[chan_id]
- # local_commitment always gives back the next expected local_commitment,
- # but in this case, we want the current one. So substract one ctn number
- old_local_state = chan.config[LOCAL]
- chan.config[LOCAL]=chan.config[LOCAL]._replace(ctn=chan.config[LOCAL].ctn - 1)
- tx = chan.pending_local_commitment
- chan.config[LOCAL] = old_local_state
- tx.sign({bh2u(chan.config[LOCAL].multisig_key.pubkey): (chan.config[LOCAL].multisig_key.privkey, True)})
- remote_sig = chan.config[LOCAL].current_commitment_signature
- remote_sig = der_sig_from_sig_string(remote_sig) + b"\x01"
- none_idx = tx._inputs[0]["signatures"].index(None)
- tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
- assert tx.is_complete()
- return await self.network.broadcast_transaction(tx)
+ peer = self.peers[chan.node_id]
+ return await peer.force_close_channel(chan_id)
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time()
diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
@@ -16,6 +16,7 @@ from electrum.util import bh2u
from electrum.lnbase import Peer, decode_msg, gen_msg
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
+from electrum.lnutil import PaymentFailure
from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker
@@ -33,7 +34,7 @@ def noop_lock():
yield
class MockNetwork:
- def __init__(self):
+ def __init__(self, tx_queue):
self.callbacks = defaultdict(list)
self.lnwatcher = None
user_config = {}
@@ -43,6 +44,7 @@ class MockNetwork:
self.channel_db = ChannelDB(self)
self.interface = None
self.path_finder = LNPathFinder(self.channel_db)
+ self.tx_queue = tx_queue
@property
def callback_lock(self):
@@ -55,12 +57,16 @@ class MockNetwork:
def get_local_height(self):
return 0
+ async def broadcast_transaction(self, tx):
+ if self.tx_queue:
+ await self.tx_queue.put(tx)
+
class MockLNWorker:
- def __init__(self, remote_keypair, local_keypair, chan):
+ def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
self.chan = chan
self.remote_keypair = remote_keypair
self.node_keypair = local_keypair
- self.network = MockNetwork()
+ self.network = MockNetwork(tx_queue)
self.channels = {self.chan.channel_id: self.chan}
self.invoices = {}
@@ -76,10 +82,12 @@ class MockLNWorker:
return self.channels
def save_channel(self, chan):
- pass
+ print("Ignoring channel save")
get_invoice = LNWorker.get_invoice
_create_route_from_invoice = LNWorker._create_route_from_invoice
+ _check_invoice = staticmethod(LNWorker._check_invoice)
+ _pay_to_route = LNWorker._pay_to_route
class MockTransport:
def __init__(self):
@@ -120,18 +128,19 @@ class TestPeer(unittest.TestCase):
self.alice_channel, self.bob_channel = create_test_channels()
def test_require_data_loss_protect(self):
- mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
+ mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
mock_transport = NoFeaturesTransport()
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
mock_lnworker.peer = p1
with self.assertRaises(LightningPeerConnectionClosed):
- asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
+ run(asyncio.wait_for(p1._main_loop(), 1))
- def test_payment(self):
+ def prepare_peers(self):
k1, k2 = keypair(), keypair()
t1, t2 = transport_pair()
- w1 = MockLNWorker(k1, k2, self.alice_channel)
- w2 = MockLNWorker(k2, k1, self.bob_channel)
+ q1, q2 = asyncio.Queue(), asyncio.Queue()
+ w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
+ w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
request_initial_sync=False, transport=t1)
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
@@ -145,6 +154,11 @@ class TestPeer(unittest.TestCase):
# this populates the channel graph:
p1.mark_open(self.alice_channel)
p2.mark_open(self.bob_channel)
+ return p1, p2, w1, w2, q1, q2
+
+ @staticmethod
+ def prepare_invoice(w2 # receiver
+ ):
amount_btc = 100000/Decimal(COIN)
payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage)
@@ -156,13 +170,23 @@ class TestPeer(unittest.TestCase):
])
pay_req = lnencode(addr, w2.node_keypair.privkey)
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
- l = asyncio.get_event_loop()
- async def pay():
- fut = asyncio.Future()
- def evt_set(event, _lnworker, msg):
- fut.set_result(msg)
- w2.network.register_callback(evt_set, ['ln_message'])
+ return pay_req
+
+ @staticmethod
+ def prepare_ln_message_future(w2 # receiver
+ ):
+ fut = asyncio.Future()
+ def evt_set(event, _lnworker, msg):
+ fut.set_result(msg)
+ w2.network.register_callback(evt_set, ['ln_message'])
+ return fut
+
+ def test_payment(self):
+ p1, p2, w1, w2, _q1, _q2 = self.prepare_peers()
+ pay_req = self.prepare_invoice(w2)
+ fut = self.prepare_ln_message_future(w2)
+ async def pay():
addr, peer, coro = LNWorker._pay(w1, pay_req)
await coro
print("HTLC ADDED")
@@ -170,4 +194,28 @@ class TestPeer(unittest.TestCase):
gath.cancel()
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
with self.assertRaises(asyncio.CancelledError):
- l.run_until_complete(gath)
+ run(gath)
+
+ def test_channel_usage_after_closing(self):
+ p1, p2, w1, w2, q1, q2 = self.prepare_peers()
+ pay_req = self.prepare_invoice(w2)
+
+ addr = w1._check_invoice(pay_req)
+ route = w1._create_route_from_invoice(decoded_invoice=addr)
+
+ run(p1.force_close_channel(self.alice_channel.channel_id))
+ # check if a tx (commitment transaction) was broadcasted:
+ assert q1.qsize() == 1
+
+ with self.assertRaises(PaymentFailure) as e:
+ w1._create_route_from_invoice(decoded_invoice=addr)
+ self.assertEqual(str(e.exception), 'No path found')
+
+ peer = w1.peers[route[0].node_id]
+ # AssertionError is ok since we shouldn't use old routes, and the
+ # route finding should fail when channel is closed
+ with self.assertRaises(AssertionError):
+ run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop()))
+
+def run(coro):
+ asyncio.get_event_loop().run_until_complete(coro)
diff --git a/electrum/tests/test_lnchan.py b/electrum/tests/test_lnchan.py
@@ -29,6 +29,7 @@ from electrum import lnchan
from electrum import lnutil
from electrum import bip32 as bip32_utils
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
+from electrum.ecc import sig_string_from_der_sig
one_bitcoin_in_msat = bitcoin.COIN * 1000
@@ -81,7 +82,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
per_commitment_secret_seed=seed,
funding_locked_received=True,
was_announced=False,
- current_commitment_signature=None,
+ # just a random signature
+ current_commitment_signature=sig_string_from_der_sig(bytes.fromhex('3046022100c66e112e22b91b96b795a6dd5f4b004f3acccd9a2a31bf104840f256855b7aa3022100e711b868b62d87c7edd95a2370e496b9cb6a38aff13c9f64f9ff2f3b2a0052dd')),
current_htlc_signatures=None,
),
"constraints":lnbase.ChannelConstraints(
@@ -185,6 +187,14 @@ class TestChannel(unittest.TestCase):
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0]
+ def test_concurrent_reversed_payment(self):
+ self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
+ self.htlc_dict['amount_msat'] += 1000
+ bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
+ alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
+ self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
+ self.assertEquals(len(self.alice_channel.pending_remote_commitment.outputs()), 3)
+
def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel
htlc = self.htlc