lntransport.py (9838B)
# Copyright (C) 2018 Adam Gibson (waxwing)
# Copyright (C) 2018 The Electrum developers
# Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php

# Derived from https://gist.github.com/AdamISZ/046d05c156aaeb56cc897f85eecb3eb8

import hashlib
import asyncio
from asyncio import StreamReader, StreamWriter
from typing import AsyncIterator, Optional, Tuple

from .crypto import sha256, hmac_oneshot, chacha20_poly1305_encrypt, chacha20_poly1305_decrypt
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
                     HandshakeFailed, LNPeerAddr)
from . import ecc
from .util import bh2u, MySocksProxy


class HandshakeState(object):
    """Running state of the Noise_XK handshake: hash ``h`` and chaining key ``ck``.

    Both start as SHA256(protocol_name); the prologue and then the
    responder's static public key are mixed into ``h``.
    """
    prologue = b"lightning"
    protocol_name = b"Noise_XK_secp256k1_ChaChaPoly_SHA256"
    handshake_version = b"\x00"

    def __init__(self, responder_pub):
        self.responder_pub = responder_pub
        self.h = sha256(self.protocol_name)
        self.ck = self.h
        self.update(self.prologue)
        self.update(self.responder_pub)

    def update(self, data):
        """Mix ``data`` into the handshake hash (h = SHA256(h || data)) and return the new h."""
        self.h = sha256(self.h + data)
        return self.h


def get_nonce_bytes(n):
    """BOLT 8 requires the nonce to be 12 bytes, 4 bytes leading
    zeroes and 8 bytes little endian encoded 64 bit integer.
    """
    return b"\x00"*4 + n.to_bytes(8, 'little')


def aead_encrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
    """ChaCha20-Poly1305 encrypt ``data`` under ``key`` with the BOLT 8 nonce encoding of ``nonce``."""
    nonce_bytes = get_nonce_bytes(nonce)
    return chacha20_poly1305_encrypt(key=key,
                                     nonce=nonce_bytes,
                                     associated_data=associated_data,
                                     data=data)


def aead_decrypt(key: bytes, nonce: int, associated_data: bytes, data: bytes) -> bytes:
    """ChaCha20-Poly1305 decrypt ``data`` under ``key`` with the BOLT 8 nonce encoding of ``nonce``.

    NOTE(review): presumably raises if the authentication tag does not
    verify -- the handshake code relies on that; confirm in .crypto.
    """
    nonce_bytes = get_nonce_bytes(nonce)
    return chacha20_poly1305_decrypt(key=key,
                                     nonce=nonce_bytes,
                                     associated_data=associated_data,
                                     data=data)


def get_bolt8_hkdf(salt, ikm):
    """RFC5869 HKDF instantiated in the specific form
    used in Lightning BOLT 8:
    Extract and expand to 64 bytes using HMAC-SHA256,
    with info field set to a zero length string as per BOLT8
    Return as two 32 byte fields.
    """
    # Extract
    prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
    assert len(prk) == 32
    # Expand
    info = b""
    T0 = b""
    T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
    T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
    assert len(T1 + T2) == 64
    return T1, T2


def act1_initiator_message(hs, epriv, epub):
    """Build a 50-byte act message: version byte || epub (33 bytes) || 16-byte tag.

    The initiator uses it for act 1. The responder reuses it for act 2 after
    setting ``hs.responder_pub`` to the initiator's ephemeral key and
    ``hs.ck`` to its current chaining key.

    Mutates ``hs`` (``ck`` and ``h``). Returns ``(message, temp_key)``, where
    ``temp_key`` is the key derived for this act.
    """
    ss = get_ecdh(epriv, hs.responder_pub)
    ck2, temp_k1 = get_bolt8_hkdf(hs.ck, ss)
    hs.ck = ck2
    # Empty plaintext: the ciphertext is just the 16-byte tag over h || epub.
    c = aead_encrypt(temp_k1, 0, hs.update(epub), b"")
    # for next step if we do it
    hs.update(c)
    msg = hs.handshake_version + epub + c
    assert len(msg) == 50
    return msg, temp_k1


def create_ephemeral_key() -> Tuple[bytes, bytes]:
    """Generate a random key pair; return (secret bytes, public key bytes)."""
    privkey = ecc.ECPrivkey.generate_random_key()
    return privkey.get_secret_bytes(), privkey.get_public_key_bytes()


class LNTransportBase:
    """Encrypted message framing shared by both handshake roles.

    After a subclass's ``handshake()`` sets ``sk``/``rk`` and calls
    ``init_counters``, messages are sent with ``send_bytes`` and received
    via ``read_messages``.
    """
    reader: StreamReader
    writer: StreamWriter
    privkey: bytes

    def name(self) -> str:
        raise NotImplementedError()

    def send_bytes(self, msg: bytes) -> None:
        """Encrypt ``msg`` and write it to the peer as two AEAD frames.

        The first frame is the 2-byte big-endian length plus its MAC (18
        bytes); the second is the body plus its MAC. Each frame consumes one
        send nonce. ``len(msg)`` must fit in 2 bytes; ``int.to_bytes`` raises
        OverflowError otherwise.
        """
        l = len(msg).to_bytes(2, 'big')
        lc = aead_encrypt(self.sk, self.sn(), b'', l)
        c = aead_encrypt(self.sk, self.sn(), b'', msg)
        assert len(lc) == 18
        assert len(c) == len(msg) + 16
        self.writer.write(lc+c)

    async def read_messages(self) -> AsyncIterator[bytes]:
        """Yield decrypted messages received from the peer until the connection ends.

        Mirrors ``send_bytes``: an 18-byte encrypted length frame followed by
        an encrypted body of ``length + 16`` bytes. Two receive nonces are
        taken per message (length, then body).

        Raises LightningPeerConnectionClosed on EOF or on any read error
        other than cancellation.
        """
        read_buffer = b''
        while True:
            rn_l, rk_l = self.rn()
            rn_m, rk_m = self.rn()
            # Plaintext body length; decrypted only once per frame, as soon as
            # the 18-byte length header is fully buffered.
            length = None
            while True:
                if length is None and len(read_buffer) >= 18:
                    l = aead_decrypt(rk_l, rn_l, b'', read_buffer[:18])
                    length = int.from_bytes(l, 'big')
                if length is not None:
                    offset = 18 + length + 16
                    if len(read_buffer) >= offset:
                        c = read_buffer[18:offset]
                        read_buffer = read_buffer[offset:]
                        msg = aead_decrypt(rk_m, rn_m, b'', c)
                        yield msg
                        break
                try:
                    s = await self.reader.read(2**10)
                except asyncio.CancelledError:
                    raise
                except Exception:
                    # Any read failure is reported below as a closed connection.
                    s = None
                if not s:
                    raise LightningPeerConnectionClosed()
                read_buffer += s

    def rn(self):
        """Return the current (receive nonce, receive key), then advance the nonce.

        Once the nonce reaches 1000, the receive key is rotated via
        HKDF(r_ck, rk) and the nonce resets to 0.
        """
        o = self._rn, self.rk
        self._rn += 1
        if self._rn == 1000:
            self.r_ck, self.rk = get_bolt8_hkdf(self.r_ck, self.rk)
            self._rn = 0
        return o

    def sn(self):
        """Return the current send nonce, then advance it.

        Once the nonce reaches 1000, the send key is rotated via
        HKDF(s_ck, sk) and the nonce resets to 0.
        """
        o = self._sn
        self._sn += 1
        if self._sn == 1000:
            self.s_ck, self.sk = get_bolt8_hkdf(self.s_ck, self.sk)
            self._sn = 0
        return o

    def init_counters(self, ck):
        """Reset both nonces to 0 and seed both key-rotation chaining keys with ``ck``."""
        # init counters
        self._sn = 0
        self._rn = 0
        self.r_ck = ck
        self.s_ck = ck

    def close(self):
        self.writer.close()


class LNResponderTransport(LNTransportBase):
    """Transport initiated by remote party."""

    def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
        LNTransportBase.__init__(self)
        self.reader = reader
        self.writer = writer
        self.privkey = privkey

    def name(self):
        return "responder"

    async def handshake(self, **kwargs):
        """Run the BOLT 8 handshake as responder on the already-open streams.

        Reads act 1 (50 bytes), sends act 2 and reads act 3 (66 bytes). On
        success, sets the receive/send keys and nonce counters and returns the
        initiator's static public key decrypted from act 3.

        kwargs: an optional ``epriv`` supplies a fixed ephemeral private key
        instead of a randomly generated one.

        Raises HandshakeFailed on disconnect or a bad version byte.
        """
        hs = HandshakeState(privkey_to_pubkey(self.privkey))
        act1 = b''
        while len(act1) < 50:
            buf = await self.reader.read(50 - len(act1))
            if not buf:
                raise HandshakeFailed('responder disconnected')
            act1 += buf
        if len(act1) != 50:
            raise HandshakeFailed('responder: short act 1 read, length is ' + str(len(act1)))
        if bytes([act1[0]]) != HandshakeState.handshake_version:
            raise HandshakeFailed('responder: bad handshake version in act 1')
        c = act1[-16:]
        re = act1[1:34]  # initiator's ephemeral public key
        h = hs.update(re)
        ss = get_ecdh(self.privkey, re)
        # sha256(protocol_name) equals the initial hs.ck.
        ck, temp_k1 = get_bolt8_hkdf(sha256(HandshakeState.protocol_name), ss)
        # Result discarded: the call only authenticates the act 1 tag.
        _p = aead_decrypt(temp_k1, 0, h, c)
        hs.update(c)

        # act 2
        if 'epriv' not in kwargs:
            epriv, epub = create_ephemeral_key()
        else:
            epriv = kwargs['epriv']
            epub = ecc.ECPrivkey(epriv).get_public_key_bytes()
        hs.ck = ck
        # act1_initiator_message does ECDH with hs.responder_pub; for act 2
        # that must be the initiator's ephemeral key.
        hs.responder_pub = re

        msg, temp_k2 = act1_initiator_message(hs, epriv, epub)
        self.writer.write(msg)

        # act 3
        act3 = b''
        while len(act3) < 66:
            buf = await self.reader.read(66 - len(act3))
            if not buf:
                raise HandshakeFailed('responder disconnected')
            act3 += buf
        if len(act3) != 66:
            raise HandshakeFailed('responder: short act 3 read, length is ' + str(len(act3)))
        if bytes([act3[0]]) != HandshakeState.handshake_version:
            raise HandshakeFailed('responder: bad handshake version in act 3')
        c = act3[1:50]  # encrypted initiator static pubkey (33 bytes + tag)
        t = act3[-16:]
        rs = aead_decrypt(temp_k2, 1, hs.h, c)
        ss = get_ecdh(epriv, rs)
        ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
        # Result discarded: the call only authenticates the act 3 tag.
        _p = aead_decrypt(temp_k3, 0, hs.update(c), t)
        # Responder order is (rk, sk); the initiator derives (sk, rk).
        self.rk, self.sk = get_bolt8_hkdf(ck, b'')
        self.init_counters(ck)
        return rs


class LNTransport(LNTransportBase):
    """Transport initiated by local party."""

    def __init__(self, privkey: bytes, peer_addr: LNPeerAddr, *,
                 proxy: Optional[dict]):
        LNTransportBase.__init__(self)
        assert type(privkey) is bytes and len(privkey) == 32
        self.privkey = privkey
        self.peer_addr = peer_addr
        self.proxy = MySocksProxy.from_proxy_dict(proxy)

    def name(self):
        return self.peer_addr.net_addr_str()

    async def handshake(self):
        """Connect to the peer and run the BOLT 8 handshake as initiator.

        Opens the connection (through the SOCKS proxy when one is
        configured), sends act 1, processes the responder's 50-byte act 2 and
        sends act 3. On success, sets the send/receive keys and nonce
        counters.

        Raises HandshakeFailed if act 2 is short or missing (for example, the
        peer closed the connection) or carries an unexpected version byte.
        """
        if not self.proxy:
            self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
        else:
            self.reader, self.writer = await self.proxy.open_connection(self.peer_addr.host, self.peer_addr.port)
        hs = HandshakeState(self.peer_addr.pubkey)
        # Get a new ephemeral key
        epriv, epub = create_ephemeral_key()

        msg, _temp_k1 = act1_initiator_message(hs, epriv, epub)
        # act 1
        self.writer.write(msg)
        # A single read() may return only part of the reply, because TCP does
        # not preserve message boundaries. Accumulate until 50 bytes or EOF.
        rspns = b''
        while len(rspns) < 50:
            buf = await self.reader.read(50 - len(rspns))
            if not buf:
                break
            rspns += buf
        if len(rspns) != 50:
            raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, "
                                  f"are you sure this is the right pubkey? {self.peer_addr}")
        hver, responder_epub, tag = rspns[0], rspns[1:34], rspns[34:]
        if bytes([hver]) != hs.handshake_version:
            raise HandshakeFailed("unexpected handshake version: {}".format(hver))
        # act 2
        hs.update(responder_epub)
        ss = get_ecdh(epriv, responder_epub)
        ck, temp_k2 = get_bolt8_hkdf(hs.ck, ss)
        hs.ck = ck
        # Result discarded (empty plaintext): the call only authenticates the tag.
        _p = aead_decrypt(temp_k2, 0, hs.h, tag)
        hs.update(tag)
        # act 3
        my_pubkey = privkey_to_pubkey(self.privkey)
        c = aead_encrypt(temp_k2, 1, hs.h, my_pubkey)
        hs.update(c)
        ss = get_ecdh(self.privkey[:32], responder_epub)
        ck, temp_k3 = get_bolt8_hkdf(hs.ck, ss)
        hs.ck = ck
        t = aead_encrypt(temp_k3, 0, hs.h, b'')
        msg = hs.handshake_version + c + t
        self.writer.write(msg)
        # Initiator order is (sk, rk); the responder derives (rk, sk).
        self.sk, self.rk = get_bolt8_hkdf(hs.ck, b'')
        self.init_counters(ck)