electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit b524460fdf264aa0a4f321974c948e139b717eff
parent ea0981ebebe978290111c6942ecc52f30cee6604
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue, 17 Mar 2020 18:02:51 +0100

lnpeer: implement basic handling of "update_fail_malformed_htlc"

Diffstat:
Melectrum/lnchannel.py | 8+++++---
Melectrum/lnhtlc.py | 44++++++++++++++++++++++++++++++++------------
Melectrum/lnonion.py | 7++++++-
Melectrum/lnpeer.py | 124+++++++++++++++++++++++++++++++++++++++++++++++++------------------------------
Melectrum/lnworker.py | 5++++-
5 files changed, 124 insertions(+), 64 deletions(-)

diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -41,7 +41,7 @@ from .bitcoin import redeem_script_to_address from .crypto import sha256, sha256d from .transaction import Transaction, PartialTransaction from .logging import Logger -from .lnonion import decode_onion_error +from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailureMessage from . import lnutil from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints, get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx, @@ -59,6 +59,7 @@ from .lnmsg import encode_msg, decode_msg if TYPE_CHECKING: from .lnworker import LNWallet from .json_db import StoredDict + from .lnrouter import RouteEdge # lightning channel states @@ -769,7 +770,8 @@ class Channel(Logger): htlc = log['adds'][htlc_id] return htlc.payment_hash - def decode_onion_error(self, reason, route, htlc_id): + def decode_onion_error(self, reason: bytes, route: Sequence['RouteEdge'], + htlc_id: int) -> Tuple[OnionRoutingFailureMessage, int]: failure_msg, sender_idx = decode_onion_error( reason, [x.node_id for x in route], @@ -791,7 +793,7 @@ class Channel(Logger): with self.db_lock: self.hm.send_fail(htlc_id) - def receive_fail_htlc(self, htlc_id, reason): + def receive_fail_htlc(self, htlc_id: int, reason: bytes): self.logger.info("receive_fail_htlc") with self.db_lock: self.hm.recv_fail(htlc_id) diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -88,16 +88,28 @@ class HTLCManager: self._maybe_active_htlc_ids[REMOTE].add(htlc_id) def send_settle(self, htlc_id: int) -> None: - self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} + next_ctn = self.ctn_latest(REMOTE) + 1 + if not self.is_htlc_active_at_ctn(ctx_owner=REMOTE, ctn=next_ctn, htlc_proposer=REMOTE, htlc_id=htlc_id): + raise Exception(f"(local) cannot remove htlc that is not there...") + self.log[REMOTE]['settles'][htlc_id] = {LOCAL: None, REMOTE: next_ctn} def recv_settle(self, htlc_id: int) -> None: - self.log[LOCAL]['settles'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None} + next_ctn = self.ctn_latest(LOCAL) + 1 + if not self.is_htlc_active_at_ctn(ctx_owner=LOCAL, ctn=next_ctn, htlc_proposer=LOCAL, htlc_id=htlc_id): + raise Exception(f"(remote) cannot remove htlc that is not there...") + self.log[LOCAL]['settles'][htlc_id] = {LOCAL: next_ctn, REMOTE: None} def send_fail(self, htlc_id: int) -> None: - self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1} + next_ctn = self.ctn_latest(REMOTE) + 1 + if not self.is_htlc_active_at_ctn(ctx_owner=REMOTE, ctn=next_ctn, htlc_proposer=REMOTE, htlc_id=htlc_id): + raise Exception(f"(local) cannot remove htlc that is not there...") + self.log[REMOTE]['fails'][htlc_id] = {LOCAL: None, REMOTE: next_ctn} def recv_fail(self, htlc_id: int) -> None: - self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None} + next_ctn = self.ctn_latest(LOCAL) + 1 + if not self.is_htlc_active_at_ctn(ctx_owner=LOCAL, ctn=next_ctn, htlc_proposer=LOCAL, htlc_id=htlc_id): + raise Exception(f"(remote) cannot remove htlc that is not there...") + self.log[LOCAL]['fails'][htlc_id] = {LOCAL: next_ctn, REMOTE: None} def send_update_fee(self, feerate: int) -> None: fee_update = FeeUpdate(rate=feerate, @@ -249,6 +261,20 @@ class HTLCManager: ##### Queries re HTLCs: + def is_htlc_active_at_ctn(self, *, ctx_owner: HTLCOwner, ctn: int, + htlc_proposer: HTLCOwner, htlc_id: int) -> bool: + if htlc_id >= self.get_next_htlc_id(htlc_proposer): + return False + settles = self.log[htlc_proposer]['settles'] + fails = self.log[htlc_proposer]['fails'] + ctns = self.log[htlc_proposer]['locked_in'][htlc_id] + if ctns[ctx_owner] is not None and ctns[ctx_owner] <= ctn: + not_settled = htlc_id not in settles or settles[htlc_id][ctx_owner] is None or settles[htlc_id][ctx_owner] > ctn + not_failed = htlc_id not in fails or fails[htlc_id][ctx_owner] is None or fails[htlc_id][ctx_owner] > ctn + if not_settled and not_failed: + return True + return False + def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, ctn: int = None) -> Dict[int, UpdateAddHtlc]: """Return the dict of received or sent (depending on direction) HTLCs @@ -264,19 +290,13 @@ class HTLCManager: # subject's ctx # party is the proposer of the HTLCs party = subject if direction == SENT else subject.inverted() - settles = self.log[party]['settles'] - fails = self.log[party]['fails'] if ctn >= self.ctn_oldest_unrevoked(subject): considered_htlc_ids = self._maybe_active_htlc_ids[party] else: # ctn is too old; need to consider full log (slow...) considered_htlc_ids = self.log[party]['locked_in'] for htlc_id in considered_htlc_ids: - ctns = self.log[party]['locked_in'][htlc_id] - if ctns[subject] is not None and ctns[subject] <= ctn: - not_settled = htlc_id not in settles or settles[htlc_id][subject] is None or settles[htlc_id][subject] > ctn - not_failed = htlc_id not in fails or fails[htlc_id][subject] is None or fails[htlc_id][subject] > ctn - if not_settled and not_failed: - d[htlc_id] = self.log[party]['adds'][htlc_id] + if self.is_htlc_active_at_ctn(ctx_owner=subject, ctn=ctn, htlc_proposer=party, htlc_id=htlc_id): + d[htlc_id] = self.log[party]['adds'][htlc_id] return d def htlcs(self, subject: HTLCOwner, ctn: int = None) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: diff --git a/electrum/lnonion.py b/electrum/lnonion.py @@ -45,6 +45,7 @@ PER_HOP_HMAC_SIZE = 32 class UnsupportedOnionPacketVersion(Exception): pass class InvalidOnionMac(Exception): pass +class InvalidOnionPubkey(Exception): pass class OnionPerHop: @@ -109,6 +110,8 @@ class OnionPacket: self.public_key = public_key self.hops_data = hops_data # also called RoutingInfo in bolt-04 self.hmac = hmac + if not ecc.ECPubkey.is_pubkey_bytes(public_key): + raise InvalidOnionPubkey() def to_bytes(self) -> bytes: ret = bytes([self.version]) @@ -243,6 +246,8 @@ class ProcessedOnionPacket(NamedTuple): # TODO replay protection def process_onion_packet(onion_packet: OnionPacket, associated_data: bytes, our_onion_private_key: bytes) -> ProcessedOnionPacket: + if not ecc.ECPubkey.is_pubkey_bytes(onion_packet.public_key): + raise InvalidOnionPubkey() shared_secret = get_ecdh(our_onion_private_key, onion_packet.public_key) # check message integrity @@ -322,7 +327,7 @@ def construct_onion_error(reason: OnionRoutingFailureMessage, def _decode_onion_error(error_packet: bytes, payment_path_pubkeys: Sequence[bytes], - session_key: bytes) -> (bytes, int): + session_key: bytes) -> Tuple[bytes, int]: """Returns the decoded error bytes, and the index of the sender of the error.""" num_hops = len(payment_path_pubkeys) hop_shared_secrets = get_shared_secrets_along_route(payment_path_pubkeys, session_key) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -30,7 +30,8 @@ from .transaction import Transaction, TxOutput, PartialTxOutput, match_script_ag from .logging import Logger from .lnonion import (new_onion_packet, decode_onion_error, OnionFailureCode, calc_hops_data_for_payment, process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage, - ProcessedOnionPacket) + ProcessedOnionPacket, UnsupportedOnionPacketVersion, InvalidOnionMac, InvalidOnionPubkey, + OnionFailureCodeMetaFlag) from .lnchannel import Channel, RevokeAndAck, htlcsum, RemoteCtnTooFarInFuture, channel_states, peer_states from . import lnutil from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, @@ -229,7 +230,7 @@ class Peer(Logger): chan.set_remote_update(payload['raw']) self.logger.info("saved remote_update") - def on_announcement_signatures(self, chan, payload): + def on_announcement_signatures(self, chan: Channel, payload): if chan.config[LOCAL].was_announced: h, local_node_sig, local_bitcoin_sig = self.send_announcement_signatures(chan) else: @@ -900,7 +901,7 @@ class Peer(Logger): if chan.config[LOCAL].funding_locked_received and chan.short_channel_id: self.mark_open(chan) - def on_funding_locked(self, chan, payload): + def on_funding_locked(self, chan: Channel, payload): self.logger.info(f"on_funding_locked. channel: {bh2u(chan.channel_id)}") if not chan.config[LOCAL].funding_locked_received: their_next_point = payload["next_per_commitment_point"] @@ -926,7 +927,7 @@ class Peer(Logger): asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) @log_exceptions - async def handle_announcements(self, chan): + async def handle_announcements(self, chan: Channel): h, local_node_sig, local_bitcoin_sig = self.send_announcement_signatures(chan) announcement_signatures_msg = await self.announcement_signatures[chan.channel_id].get() remote_node_sig = announcement_signatures_msg["node_signature"] @@ -1002,11 +1003,11 @@ class Peer(Logger): ) return msg_hash, node_signature, bitcoin_signature - def on_update_fail_htlc(self, chan, payload): + def on_update_fail_htlc(self, chan: Channel, payload): htlc_id = int.from_bytes(payload["id"], "big") reason = payload["reason"] self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") - chan.receive_fail_htlc(htlc_id, reason) + chan.receive_fail_htlc(htlc_id, reason) # TODO handle exc and maybe fail channel (e.g. bad htlc_id) self.maybe_send_commitment(chan) def maybe_send_commitment(self, chan: Channel): @@ -1057,7 +1058,7 @@ class Peer(Logger): next_per_commitment_point=rev.next_per_commitment_point) self.maybe_send_commitment(chan) - def on_commitment_signed(self, chan, payload): + def on_commitment_signed(self, chan: Channel, payload): if chan.peer_state == peer_states.BAD: return self.logger.info(f'on_commitment_signed. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(LOCAL)}.') @@ -1075,19 +1076,29 @@ class Peer(Logger): chan.receive_new_commitment(payload["signature"], htlc_sigs) self.send_revoke_and_ack(chan) - def on_update_fulfill_htlc(self, chan, payload): + def on_update_fulfill_htlc(self, chan: Channel, payload): preimage = payload["payment_preimage"] payment_hash = sha256(preimage) htlc_id = int.from_bytes(payload["id"], "big") self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") - chan.receive_htlc_settle(preimage, htlc_id) + chan.receive_htlc_settle(preimage, htlc_id) # TODO handle exc and maybe fail channel (e.g. bad htlc_id) self.lnworker.save_preimage(payment_hash, preimage) self.maybe_send_commitment(chan) - def on_update_fail_malformed_htlc(self, chan, payload): - self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}") + def on_update_fail_malformed_htlc(self, chan: Channel, payload): + htlc_id = payload["id"] + failure_code = payload["failure_code"] + self.logger.info(f"on_update_fail_malformed_htlc. chan {chan.get_id_for_log()}. " + f"htlc_id {htlc_id}. failure_code={failure_code}") + if failure_code & OnionFailureCodeMetaFlag.BADONION == 0: + asyncio.ensure_future(self.lnworker.try_force_closing(chan.channel_id)) + raise RemoteMisbehaving(f"received update_fail_malformed_htlc with unexpected failure code: {failure_code}") + reason = b'' # TODO somehow propagate "failure_code" ? + chan.receive_fail_htlc(htlc_id, reason) # TODO handle exc and maybe fail channel (e.g. bad htlc_id) + self.maybe_send_commitment(chan) + # TODO when forwarding, we need to propagate this "update_fail_malformed_htlc" downstream - def on_update_add_htlc(self, chan, payload): + def on_update_add_htlc(self, chan: Channel, payload): payment_hash = payload["payment_hash"] htlc_id = int.from_bytes(payload["id"], 'big') self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") @@ -1211,19 +1222,27 @@ class Peer(Logger): id=htlc_id, payment_preimage=preimage) - def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: OnionPacket, + def fail_htlc(self, chan: Channel, htlc_id: int, onion_packet: Optional[OnionPacket], reason: OnionRoutingFailureMessage): self.logger.info(f"fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}. reason: {reason}") assert chan.can_send_ctx_updates(), f"cannot send updates: {chan.short_channel_id}" chan.fail_htlc(htlc_id) - error_packet = construct_onion_error(reason, onion_packet, our_onion_private_key=self.privkey) - self.send_message("update_fail_htlc", - channel_id=chan.channel_id, - id=htlc_id, - len=len(error_packet), - reason=error_packet) - - def on_revoke_and_ack(self, chan, payload): + if onion_packet: + error_packet = construct_onion_error(reason, onion_packet, our_onion_private_key=self.privkey) + self.send_message("update_fail_htlc", + channel_id=chan.channel_id, + id=htlc_id, + len=len(error_packet), + reason=error_packet) + else: + assert len(reason.data) == 32, f"unexpected reason when sending 'update_fail_malformed_htlc': {reason!r}" + self.send_message("update_fail_malformed_htlc", + channel_id=chan.channel_id, + id=htlc_id, + sha256_of_onion=reason.data, + failure_code=reason.code) + + def on_revoke_and_ack(self, chan: Channel, payload): if chan.peer_state == peer_states.BAD: return self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}') @@ -1232,7 +1251,7 @@ class Peer(Logger): self.lnworker.save_channel(chan) self.maybe_send_commitment(chan) - def on_update_fee(self, chan, payload): + def on_update_fee(self, chan: Channel, payload): feerate = int.from_bytes(payload["feerate_per_kw"], "big") chan.update_fee(feerate, False) @@ -1282,7 +1301,7 @@ class Peer(Logger): return txid @log_exceptions - async def on_shutdown(self, chan, payload): + async def on_shutdown(self, chan: Channel, payload): their_scriptpubkey = payload['scriptpubkey'] # BOLT-02 restrict the scriptpubkey to some templates: if not (match_script_against_template(their_scriptpubkey, transaction.SCRIPTPUBKEY_TEMPLATE_WITNESS_V0) @@ -1404,33 +1423,44 @@ class Peer(Logger): if chan.get_oldest_unrevoked_ctn(REMOTE) <= remote_ctn: continue chan.logger.info(f'found unfulfilled htlc: {htlc_id}') - onion_packet = OnionPacket.from_bytes(bytes.fromhex(onion_packet_hex)) htlc = chan.hm.log[REMOTE]['adds'][htlc_id] payment_hash = htlc.payment_hash - processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey) - preimage, error = None, None - if processed_onion.are_we_final: - preimage, error = self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - onion_packet=onion_packet, - processed_onion=processed_onion) - elif not forwarded: - error = self.maybe_forward_htlc( - chan=chan, - htlc=htlc, - onion_packet=onion_packet, - processed_onion=processed_onion) - if not error: - unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, True + error = None # type: Optional[OnionRoutingFailureMessage] + preimage = None + onion_packet_bytes = bytes.fromhex(onion_packet_hex) + onion_packet = None + try: + onion_packet = OnionPacket.from_bytes(onion_packet_bytes) + processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey) + except UnsupportedOnionPacketVersion: + error = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_VERSION, data=sha256(onion_packet_bytes)) + except InvalidOnionPubkey: + error = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_KEY, data=sha256(onion_packet_bytes)) + except InvalidOnionMac: + error = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_HMAC, data=sha256(onion_packet_bytes)) else: - f = self.lnworker.pending_payments[payment_hash] - if f.done(): - success, preimage, error = f.result() - if preimage: - await self.lnworker.enable_htlc_settle.wait() - self.fulfill_htlc(chan, htlc.htlc_id, preimage) - done.add(htlc_id) + if processed_onion.are_we_final: + preimage, error = self.maybe_fulfill_htlc( + chan=chan, + htlc=htlc, + onion_packet=onion_packet, + processed_onion=processed_onion) + elif not forwarded: + error = self.maybe_forward_htlc( + chan=chan, + htlc=htlc, + onion_packet=onion_packet, + processed_onion=processed_onion) + if not error: + unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, True + else: + f = self.lnworker.pending_payments[payment_hash] + if f.done(): + success, preimage, error = f.result() + if preimage: + await self.lnworker.enable_htlc_settle.wait() + self.fulfill_htlc(chan, htlc.htlc_id, preimage) + done.add(htlc_id) if error: self.fail_htlc(chan, htlc.htlc_id, onion_packet, error) done.add(htlc_id) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -958,6 +958,9 @@ class LNWallet(LNWorker): if success: failure_log = None else: + # TODO this blacklisting is fragile, consider (who to ban/penalize?): + # - we might not be able to decode "reason" (coming from update_fail_htlc). + # - handle update_fail_malformed_htlc case, where there is (kinda) no "reason" failure_msg, sender_idx = chan.decode_onion_error(reason, route, htlc.htlc_id) blacklist = self.handle_error_code_from_failed_htlc(failure_msg, sender_idx, route, peer) if blacklist: @@ -1216,7 +1219,7 @@ class LNWallet(LNWorker): info = info._replace(status=status) self.save_payment_info(info) - def payment_failed(self, chan, payment_hash: bytes, reason): + def payment_failed(self, chan, payment_hash: bytes, reason: bytes): self.set_payment_status(payment_hash, PR_UNPAID) key = payment_hash.hex() f = self.pending_payments.get(payment_hash)