commit ef5a26544981a426945b1d32b22b170272cb0671
parent c0bf9b4509cfdeace824a9054d9c90449ef46ad6
Author: ThomasV <thomasv@electrum.org>
Date: Wed, 27 Jan 2021 19:27:06 +0100
basic_mpp: receive multi-part payments
Diffstat:
5 files changed, 51 insertions(+), 22 deletions(-)
diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
@@ -969,10 +969,6 @@ class Channel(AbstractChannel):
raise Exception("refusing to revoke as remote sig does not fit")
with self.db_lock:
self.hm.send_rev()
- if self.lnworker:
- received = self.hm.received_in_ctn(new_ctn)
- for htlc in received:
- self.lnworker.payment_received(self, htlc.payment_hash)
last_secret, last_point = self.get_secret_and_point(LOCAL, new_ctn - 1)
next_secret, next_point = self.get_secret_and_point(LOCAL, new_ctn + 1)
return RevokeAndAck(last_secret, next_point)
@@ -1054,7 +1050,7 @@ class Channel(AbstractChannel):
if is_sent:
self.lnworker.payment_sent(self, payment_hash)
else:
- self.lnworker.payment_received(self, payment_hash)
+ self.lnworker.payment_received(payment_hash)
def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
assert type(whose) is HTLCOwner
diff --git a/electrum/lnonion.py b/electrum/lnonion.py
@@ -498,6 +498,7 @@ class OnionFailureCode(IntEnum):
CHANNEL_DISABLED = UPDATE | 20
EXPIRY_TOO_FAR = 21
INVALID_ONION_PAYLOAD = PERM | 22
+ MPP_TIMEOUT = 23
# don't use these elsewhere, the names are ambiguous without context
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -1389,10 +1389,6 @@ class Peer(Logger):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
return None, reason
expected_received_msat = info.amount_msat
- if expected_received_msat is not None and \
- not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat):
- reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
- return None, reason
# Check that our blockchain tip is sufficiently recent so that we have an approx idea of the height.
# We should not release the preimage for an HTLC that its sender could already time out as
# then they might try to force-close and it becomes a race.
@@ -1415,20 +1411,34 @@ class Peer(Logger):
data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
return None, reason
try:
- amount_from_onion = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
+ amt_to_forward = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]
except:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
return None, reason
try:
- amount_from_onion = processed_onion.hop_data.payload["payment_data"]["total_msat"]
+ total_msat = processed_onion.hop_data.payload["payment_data"]["total_msat"]
except:
- pass # fall back to "amt_to_forward"
- if amount_from_onion > htlc.amount_msat:
- reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
- data=htlc.amount_msat.to_bytes(8, byteorder="big"))
+ total_msat = amt_to_forward # fall back to "amt_to_forward"
+
+ if amt_to_forward != htlc.amount_msat:
+ reason = OnionRoutingFailureMessage(
+ code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT,
+ data=total_msat.to_bytes(8, byteorder="big"))
return None, reason
- # all good
- return preimage, None
+ if expected_received_msat is None:
+ return preimage, None
+ if not (expected_received_msat <= total_msat <= 2 * expected_received_msat):
+ reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
+ return None, reason
+ accepted, expired = self.lnworker.htlc_received(chan.short_channel_id, htlc, expected_received_msat)
+ if accepted:
+ return preimage, None
+ elif expired:
+ reason = OnionRoutingFailureMessage(code=OnionFailureCode.MPP_TIMEOUT)
+ return None, reason
+ else:
+ # waiting for more htlcs
+ return None, None
def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
@@ -1669,7 +1679,7 @@ class Peer(Logger):
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items():
if not chan.hm.is_add_htlc_irrevocably_committed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
continue
- chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
+ #chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id)
payment_hash = htlc.payment_hash
error_reason = None # type: Optional[OnionRoutingFailureMessage]
@@ -1694,7 +1704,6 @@ class Peer(Logger):
error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_VERSION, data=sha256(onion_packet_bytes))
if self.network.config.get('test_fail_htlcs_with_temp_node_failure'):
error_reason = OnionRoutingFailureMessage(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
-
if not error_reason:
if processed_onion.are_we_final:
preimage, error_reason = self.maybe_fulfill_htlc(
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -86,6 +86,7 @@ SAVED_PR_STATUS = [PR_PAID, PR_UNPAID] # status that are persisted
NUM_PEERS_TARGET = 4
+MPP_EXPIRY = 120
FALLBACK_NODE_LIST_TESTNET = (
@@ -164,7 +165,8 @@ BASE_FEATURES = LnFeatures(0)\
LNWALLET_FEATURES = BASE_FEATURES\
| LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ\
| LnFeatures.OPTION_STATIC_REMOTEKEY_REQ\
- | LnFeatures.GOSSIP_QUERIES_REQ
+ | LnFeatures.GOSSIP_QUERIES_REQ\
+ | LnFeatures.BASIC_MPP_OPT
LNGOSSIP_FEATURES = BASE_FEATURES\
| LnFeatures.GOSSIP_QUERIES_OPT\
@@ -581,6 +583,7 @@ class LNWallet(LNWorker):
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]]
+ self.pending_htlcs = defaultdict(set) # type: Dict[bytes, set]
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
# detect inflight payments
@@ -1284,6 +1287,24 @@ class LNWallet(LNWorker):
self.payments[key] = info.amount_msat, info.direction, info.status
self.wallet.save_db()
+ def htlc_received(self, short_channel_id, htlc, expected_msat):
+ status = self.get_payment_status(htlc.payment_hash)
+ if status == PR_PAID:
+ return True, None
+ s = self.pending_htlcs[htlc.payment_hash]
+ if (short_channel_id, htlc) not in s:
+ s.add((short_channel_id, htlc))
+ total = sum([htlc.amount_msat for scid, htlc in s])
+ first_timestamp = min([htlc.timestamp for scid, htlc in s])
+ expired = time.time() - first_timestamp > MPP_EXPIRY
+ if total >= expected_msat and not expired:
+ # status must be persisted
+ self.payment_received(htlc.payment_hash)
+ return True, None
+ if expired:
+ return None, True
+ return None, None
+
def get_payment_status(self, payment_hash):
info = self.get_payment_info(payment_hash)
return info.status if info else PR_UNPAID
@@ -1359,10 +1380,10 @@ class LNWallet(LNWorker):
util.trigger_callback('payment_succeeded', self.wallet, key)
util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
- def payment_received(self, chan, payment_hash: bytes):
+ def payment_received(self, payment_hash: bytes):
self.set_payment_status(payment_hash, PR_PAID)
util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID)
- util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
+ #util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)"""
diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py
@@ -132,6 +132,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
# used in tests
self.enable_htlc_settle = asyncio.Event()
self.enable_htlc_settle.set()
+ self.pending_htlcs = defaultdict(set)
def get_invoice_status(self, key):
pass
@@ -167,6 +168,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
set_invoice_status = LNWallet.set_invoice_status
set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status
+ htlc_received = LNWallet.htlc_received
await_payment = LNWallet.await_payment
payment_received = LNWallet.payment_received
payment_sent = LNWallet.payment_sent