electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

lntransport.py (9838B)


      1 # Copyright (C) 2018 Adam Gibson (waxwing)
      2 # Copyright (C) 2018 The Electrum developers
      3 # Distributed under the MIT software license, see the accompanying
      4 # file LICENCE or http://www.opensource.org/licenses/mit-license.php
      5 
      6 # Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8
      7 
      8 import hashlib
      9 import asyncio
     10 from asyncio import StreamReader, StreamWriter
     11 from typing import Optional
     12 
     13 from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
     14 from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
     15                      HandshakeFailed, LNPeerAddr)
     16 from . import ecc
     17 from .util import bh2u, MySocksProxy
     18 
     19 
     20 class HandshakeState(object):
     21     prologue = b"lightning"
     22     protocol_name = b"Noise_XK_secp256k1_ChaChaPoly_SHA256"
     23     handshake_version = b"\x00"
     24 
     25     def __init__(self, responder_pub):
     26         self.responder_pub = responder_pub
     27         self.h = sha256(self.protocol_name)
     28         self.ck = self.h
     29         self.update(self.prologue)
     30         self.update(self.responder_pub)
     31 
     32     def update(self, data):
     33         self.h = sha256(self.h + data)
     34         return self.h
     35 
     36 def get_nonce_bytes(n):
     37     """BOLT 8 requires the nonce to be 12 bytes, 4 bytes leading
     38     zeroes and 8 bytes little endian encoded 64 bit integer.
     39     """
     40     return b"\x00"*4 + n.to_bytes(8, 'little')
     41 
     42 def aead_encrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
     43     nonce_bytes = get_nonce_bytes(nonce)
     44     return chacha20_poly1305_encrypt(key=key,
     45                                      nonce=nonce_bytes,
     46                                      associated_data=associated_data,
     47                                      data=data)
     48 
     49 def aead_decrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
     50     nonce_bytes = get_nonce_bytes(nonce)
     51     return chacha20_poly1305_decrypt(key=key,
     52                                      nonce=nonce_bytes,
     53                                      associated_data=associated_data,
     54                                      data=data)
     55 
     56 def get_bolt8_hkdf(salt, ikm):
     57     """RFC5869 HKDF instantiated in the specific form
     58     used in Lightning BOLT 8:
     59     Extract and expand to 64 bytes using HMAC-SHA256,
     60     with info field set to a zero length string as per BOLT8
     61     Return as two 32 byte fields.
     62     """
     63     #Extract
     64     prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
     65     assert len(prk) == 32
     66     #Expand
     67     info = b""
     68     T0 = b""
     69     T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
     70     T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
     71     assert len(T1 + T2) == 64
     72     return T1, T2
     73 
     74 def act1_initiator_message(hs, epriv, epub):
     75     ss = get_ecdh(epriv, hs.responder_pub)
     76     ck2, temp_k1 = get_bolt8_hkdf(hs.ck, ss)
     77     hs.ck = ck2
     78     c = aead_encrypt(temp_k1, 0, hs.update(epub), b"")
     79     #for next step if we do it
     80     hs.update(c)
     81     msg = hs.handshake_version + epub + c
     82     assert len(msg) == 50
     83     return msg, temp_k1
     84 
     85 
     86 def create_ephemeral_key() -> (bytes, bytes):
     87     privkey = ecc.ECPrivkey.generate_random_key()
     88     return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
     89 
     90 
     91 class LNTransportBase:
     92     reader: StreamReader
     93     writer: StreamWriter
     94     privkey: bytes
     95 
     96     def name(self) -> str:
     97         raise NotImplementedError()
     98 
     99     def send_bytes(self, msg: bytes) -> None:
    100         l = len(msg).to_bytes(2, 'big')
    101         lc = aead_encrypt(self.sk, self.sn(), b'', l)
    102         c = aead_encrypt(self.sk, self.sn(), b'', msg)
    103         assert len(lc) == 18
    104         assert len(c) == len(msg) + 16
    105         self.writer.write(lc+c)
    106 
    107     async def read_messages(self):
    108         read_buffer = b''
    109         while True:
    110             rn_l, rk_l = self.rn()
    111             rn_m, rk_m = self.rn()
    112             while True:
    113                 if len(read_buffer) >= 18:
    114                     lc = read_buffer[:18]
    115                     l = aead_decrypt(rk_l, rn_l, b'', lc)
    116                     length = int.from_bytes(l, 'big')
    117                     offset = 18 + length + 16
    118                     if len(read_buffer) >= offset:
    119                         c = read_buffer[18:offset]
    120                         read_buffer = read_buffer[offset:]
    121                         msg = aead_decrypt(rk_m, rn_m, b'', c)
    122                         yield msg
    123                         break
    124                 try:
    125                     s = await self.reader.read(2**10)
    126                 except asyncio.CancelledError:
    127                     raise
    128                 except Exception:
    129                     s = None
    130                 if not s:
    131                     raise LightningPeerConnectionClosed()
    132                 read_buffer += s
    133 
    134     def rn(self):
    135         o = self._rn, self.rk
    136         self._rn += 1
    137         if self._rn == 1000:
    138             self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
    139             self._rn = 0
    140         return o
    141 
    142     def sn(self):
    143         o = self._sn
    144         self._sn += 1
    145         if self._sn == 1000:
    146             self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
    147             self._sn = 0
    148         return o
    149 
    150     def init_counters(self, ck):
    151         # init counters
    152         self._sn = 0
    153         self._rn = 0
    154         self.r_ck = ck
    155         self.s_ck = ck
    156 
    157     def close(self):
    158         self.writer.close()
    159 
    160 
    161 class LNResponderTransport(LNTransportBase):
    162     """Transport initiated by remote party."""
    163 
    164     def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
    165         LNTransportBase.__init__(self)
    166         self.reader = reader
    167         self.writer = writer
    168         self.privkey = privkey
    169 
    170     def name(self):
    171         return "responder"
    172 
    173     async def handshake(self, **kwargs):
    174         hs = HandshakeState(privkey_to_pubkey(self.privkey))
    175         act1 = b''
    176         while len(act1) < 50:
    177             buf = await self.reader.read(50 - len(act1))
    178             if not buf:
    179                 raise HandshakeFailed('responder disconnected')
    180             act1 += buf
    181         if len(act1) != 50:
    182             raise HandshakeFailed('responder: short act 1 read, length is ' + str(len(act1)))
    183         if bytes([act1[0]]) != HandshakeState.handshake_version:
    184             raise HandshakeFailed('responder: bad handshake version in act 1')
    185         c = act1[-16:]
    186         re = act1[1:34]
    187         h = hs.update(re)
    188         ss = get_ecdh(self.privkey, re)
    189         ck, temp_k1 = get_bolt8_hkdf(sha256(HandshakeState.protocol_name), ss)
    190         _p = aead_decrypt(temp_k1, 0, h, c)
    191         hs.update(c)
    192 
    193         # act 2
    194         if 'epriv' not in kwargs:
    195             epriv, epub = create_ephemeral_key()
    196         else:
    197             epriv = kwargs['epriv']
    198             epub = ecc.ECPrivkey(epriv).get_public_key_bytes()
    199         hs.ck = ck
    200         hs.responder_pub = re
    201 
    202         msg, temp_k2 = act1_initiator_message(hs, epriv, epub)
    203         self.writer.write(msg)
    204 
    205         # act 3
    206         act3 = b''
    207         while len(act3) < 66:
    208             buf = await self.reader.read(66 - len(act3))
    209             if not buf:
    210                 raise HandshakeFailed('responder disconnected')
    211             act3 += buf
    212         if len(act3) != 66:
    213             raise HandshakeFailed('responder: short act 3 read, length is ' + str(len(act3)))
    214         if bytes([act3[0]]) != HandshakeState.handshake_version:
    215             raise HandshakeFailed('responder: bad handshake version in act 3')
    216         c = act3[1:50]
    217         t = act3[-16:]
    218         rs = aead_decrypt(temp_k2, 1, hs.h, c)
    219         ss = get_ecdh(epriv, rs)
    220         ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
    221         _p = aead_decrypt(temp_k3, 0, hs.update(c), t)
    222         self.rk, self.sk = get_bolt8_hkdf(ck, b'')
    223         self.init_counters(ck)
    224         return rs
    225 
    226 
    227 class LNTransport(LNTransportBase):
    228     """Transport initiated by local party."""
    229 
    230     def __init__(self, privkey: bytes, peer_addr: LNPeerAddr, *,
    231                  proxy: Optional[dict]):
    232         LNTransportBase.__init__(self)
    233         assert type(privkey) is bytes and len(privkey) == 32
    234         self.privkey = privkey
    235         self.peer_addr = peer_addr
    236         self.proxy = MySocksProxy.from_proxy_dict(proxy)
    237 
    238     def name(self):
    239         return self.peer_addr.net_addr_str()
    240 
    241     async def handshake(self):
    242         if not self.proxy:
    243             self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
    244         else:
    245             self.reader, self.writer = await self.proxy.open_connection(self.peer_addr.host, self.peer_addr.port)
    246         hs = HandshakeState(self.peer_addr.pubkey)
    247         # Get a new ephemeral key
    248         epriv, epub = create_ephemeral_key()
    249 
    250         msg, _temp_k1 = act1_initiator_message(hs, epriv, epub)
    251         # act 1
    252         self.writer.write(msg)
    253         rspns = await self.reader.read(2**10)
    254         if len(rspns) != 50:
    255             raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, "
    256                                   f"are you sure this is the right pubkey? {self.peer_addr}")
    257         hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:]
    258         if bytes([hver]) != hs.handshake_version:
    259             raise HandshakeFailed("unexpected handshake version: {}".format(hver))
    260         # act 2
    261         hs.update(alice_epub)
    262         ss = get_ecdh(epriv, alice_epub)
    263         ck, temp_k2 = get_bolt8_hkdf(hs.ck, ss)
    264         hs.ck = ck
    265         p = aead_decrypt(temp_k2, 0, hs.h, tag)
    266         hs.update(tag)
    267         # act 3
    268         my_pubkey = privkey_to_pubkey(self.privkey)
    269         c = aead_encrypt(temp_k2, 1, hs.h, my_pubkey)
    270         hs.update(c)
    271         ss = get_ecdh(self.privkey[:32], alice_epub)
    272         ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
    273         hs.ck = ck
    274         t = aead_encrypt(temp_k3, 0, hs.h, b'')
    275         msg = hs.handshake_version + c + t
    276         self.writer.write(msg)
    277         self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'')
    278         self.init_counters(ck)