import os
import csv
import io
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
from collections import OrderedDict

from .lnutil import OnionFailureCodeMetaFlag


class MalformedMsg(Exception):
    """Base class for all lightning wire (de)serialization errors."""
class UnknownMsgFieldType(MalformedMsg):
    """A scheme row names a field type this module does not know."""
class UnexpectedEndOfStream(MalformedMsg):
    """Stream ended before the expected number of bytes could be read."""
class FieldEncodingNotMinimal(MalformedMsg):
    """An integer field was encoded with redundant leading bytes."""
class UnknownMandatoryTLVRecordType(MalformedMsg):
    """Encountered an unknown even ("mandatory") TLV record type."""
class MsgTrailingGarbage(MalformedMsg):
    """A TLV record contained bytes beyond its declared fields."""
class MsgInvalidFieldOrder(MalformedMsg):
    """TLV records were not sorted strictly by increasing type."""
class UnexpectedFieldSizeForEncoder(MalformedMsg):
    """Value to be written does not match the field's fixed size."""


def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
    """Return how many bytes remain between the current position and EOF.

    The stream position is restored before returning.
    """
    here = fd.tell()
    end = fd.seek(0, io.SEEK_END)
    fd.seek(here)
    return end - here


def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
    """Raise UnexpectedEndOfStream unless at least n bytes remain in fd."""
    # note: it's faster to read n bytes and then check if we read n, than
    #       to assert we can read at least n and then read n bytes.
    navailable = _num_remaining_bytes_to_read(fd)
    if navailable < n:
        raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {navailable} bytes left")


def write_bigsize_int(i: int) -> bytes:
    """Serialize a non-negative int using the minimal BigSize encoding."""
    assert i >= 0, i
    # each tier: (one-byte discriminator prefix, payload width, exclusive upper bound)
    for prefix, width, upper_bound in (
            (b"", 1, 0xfd),
            (b"\xfd", 2, 0x1_0000),
            (b"\xfe", 4, 0x1_0000_0000),
    ):
        if i < upper_bound:
            return prefix + i.to_bytes(width, byteorder="big", signed=False)
    return b"\xff" + i.to_bytes(8, byteorder="big", signed=False)


def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
    """Read a BigSize-encoded int from fd.

    Returns None if fd is already at EOF.
    Raises UnexpectedEndOfStream if the payload is truncated, and
    FieldEncodingNotMinimal if a shorter encoding would have sufficed.
    """
    discriminator_byte = fd.read(1)
    if not discriminator_byte:
        return None  # end of file
    discriminator = discriminator_byte[0]
    if discriminator < 0xfd:
        return discriminator  # single-byte value
    # payload width and the smallest value that *requires* this width
    width, minimal_lower_bound = {
        0xfd: (2, 0xfd),
        0xfe: (4, 0x1_0000),
        0xff: (8, 0x1_0000_0000),
    }[discriminator]
    payload = fd.read(width)
    if len(payload) != width:
        raise UnexpectedEndOfStream()
    val = int.from_bytes(payload, byteorder="big", signed=False)
    if val < minimal_lower_bound:
        raise FieldEncodingNotMinimal()
    return val


# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
#       if field_type is a numeric, we could return a list of ints?
# Fixed-width unsigned integer field types -> byte length.
_INT_FIELD_LEN = {'u8': 1, 'u16': 2, 'u32': 4, 'u64': 8}
# Truncated ("minimally encoded") integer field types -> maximum byte length.
_TRUNC_INT_FIELD_LEN = {'tu16': 2, 'tu32': 4, 'tu64': 8}
# Field types handled as raw byte strings -> length of a single item.
_RAW_FIELD_LEN = {
    'byte': 1,
    'chain_hash': 32,
    'channel_id': 32,
    'sha256': 32,
    'signature': 64,
    'point': 33,
    'short_channel_id': 8,
}


def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]:
    """Read one field (or `count` consecutive items of it) from fd.

    Integer-typed fields (u*/tu*/varint) are returned as int; every other
    type is returned as raw bytes. `count` may be "..." to consume all
    remaining bytes of the stream.
    """
    if not fd: raise Exception()
    if isinstance(count, int):
        assert count >= 0, f"{count!r} must be non-neg int"
    elif count != "...":
        raise Exception(f"unexpected field count: {count!r}")
    if count == 0:
        return b""
    if field_type in _INT_FIELD_LEN:
        assert count == 1, count
        width = _INT_FIELD_LEN[field_type]
        buf = fd.read(width)
        if len(buf) != width:
            raise UnexpectedEndOfStream()
        return int.from_bytes(buf, byteorder="big", signed=False)
    if field_type in _TRUNC_INT_FIELD_LEN:
        assert count == 1, count
        # truncated ints may legitimately be shorter than their nominal
        # width, but a leading zero byte means the encoding is not minimal
        raw = fd.read(_TRUNC_INT_FIELD_LEN[field_type])
        if len(raw) > 0 and raw[0] == 0x00:
            raise FieldEncodingNotMinimal()
        return int.from_bytes(raw, byteorder="big", signed=False)
    if field_type == 'varint':
        assert count == 1, count
        val = read_bigsize_int(fd)
        if val is None:
            raise UnexpectedEndOfStream()
        return val
    item_len = _RAW_FIELD_LEN.get(field_type)
    if count == "...":
        total_len = -1  # read all
    else:
        if item_len is None:
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
        total_len = count * item_len
    buf = fd.read(total_len)
    if total_len >= 0 and len(buf) != total_len:
        raise UnexpectedEndOfStream()
    return buf


# TODO: maybe for "value" we could accept a list with len "count" of appropriate items
def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
                 value: Union[bytes, int]) -> None:
    """Serialize `value` as one field (or `count` items of it) into fd.

    ints are converted to their big-endian encoding where the type allows;
    everything else must already be bytes of the exact expected length.
    """
    if not fd: raise Exception()
    if isinstance(count, int):
        assert count >= 0, f"{count!r} must be non-neg int"
    elif count != "...":
        raise Exception(f"unexpected field count: {count!r}")
    if count == 0:
        return

    def _put(buf) -> None:
        # write buf to fd and sanity-check the number of bytes written
        nbytes_written = fd.write(buf)
        if nbytes_written != len(buf):
            raise Exception(f"tried to write {len(buf)} bytes, but only wrote {nbytes_written}!?")

    if field_type in _TRUNC_INT_FIELD_LEN:
        assert count == 1, count
        if isinstance(value, int):
            value = int.to_bytes(value, length=_TRUNC_INT_FIELD_LEN[field_type],
                                 byteorder="big", signed=False)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        # minimal encoding: strip all leading zero bytes before writing
        while len(value) > 0 and value[0] == 0x00:
            value = value[1:]
        _put(value)
        return
    if field_type == 'varint':
        assert count == 1, count
        if isinstance(value, int):
            value = write_bigsize_int(value)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        _put(value)
        return
    item_len = _INT_FIELD_LEN.get(field_type) or _RAW_FIELD_LEN.get(field_type)
    total_len = -1
    if count != "...":
        if item_len is None:
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
        total_len = count * item_len
    if isinstance(value, int) and (count == 1 or field_type == 'byte'):
        value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
    if not isinstance(value, (bytes, bytearray)):
        raise Exception(f"can only write bytes into fd. got: {value!r}")
    if count != "..." and total_len != len(value):
        raise UnexpectedFieldSizeForEncoder(f"expected: {total_len}, got {len(value)}")
    _put(value)


def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
    """Read one TLV record from fd; the length prefix is consumed internally.

    Returns (type, value-bytes).
    """
    if not fd: raise Exception()
    record_type = _read_field(fd=fd, field_type="varint", count=1)
    record_len = _read_field(fd=fd, field_type="varint", count=1)
    record_val = _read_field(fd=fd, field_type="byte", count=record_len)
    return record_type, record_val


def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
    """Write one TLV record (type, length, value) into fd."""
    if not fd: raise Exception()
    record_len = len(tlv_val)
    _write_field(fd=fd, field_type="varint", count=1, value=tlv_type)
    _write_field(fd=fd, field_type="varint", count=1, value=record_len)
    _write_field(fd=fd, field_type="byte", count=record_len, value=tlv_val)
def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]:
    """Returns an evaluated field count, typically an int.
    If allow_any is True, the return value can be a str with value=="...".

    An empty string means a count of 1; a non-numeric string is looked up
    as the name of a previously parsed field in vars_dict (a byte-valued
    field is interpreted as a big-endian int).
    """
    if field_count_str == "":
        field_count = 1
    elif field_count_str == "...":
        if not allow_any:
            raise Exception("field count is '...' but allow_any is False")
        return field_count_str
    else:
        try:
            field_count = int(field_count_str)
        except ValueError:
            # count refers to another field by name (e.g. a preceding "len" field)
            field_count = vars_dict[field_count_str]
            if isinstance(field_count, (bytes, bytearray)):
                field_count = int.from_bytes(field_count, byteorder="big")
    assert isinstance(field_count, int)
    return field_count


def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
    """Evaluate a msgtype value like "PERM|10" by OR-ing numeric components
    with named OnionFailureCodeMetaFlag flags."""
    msg_type_int = 0
    for component in value.split("|"):
        try:
            msg_type_int |= int(component)
        except ValueError:
            msg_type_int |= OnionFailureCodeMetaFlag[component]
    return msg_type_int


class LNSerializer:
    """Encodes/decodes Lightning wire messages based on a CSV scheme file.

    The scheme is loaded from lnwire/peer_wire.csv (or onion_wire.csv when
    for_onion_wire=True) shipped next to this module; the row formats are
    noted inline below. Presumably these CSVs follow the BOLT wire-format
    extraction convention — confirm against the lnwire directory.
    """

    def __init__(self, *, for_onion_wire: bool = False):
        # TODO msg_type could be 'int' everywhere...
        self.msg_scheme_from_type = {}  # type: Dict[bytes, List[Sequence[str]]]
        self.msg_type_from_name = {}  # type: Dict[str, bytes]

        self.in_tlv_stream_get_tlv_record_scheme_from_type = {}  # type: Dict[str, Dict[int, List[Sequence[str]]]]
        self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
        self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]

        if for_onion_wire:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
        else:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
        with open(path, newline='') as f:
            csvreader = csv.reader(f)
            # NOTE: msg_type_name/msg_type_bytes/tlv_stream_name/... deliberately
            # persist across loop iterations: "msgdata"/"tlvdata" rows attach to
            # the most recent "msgtype"/"tlvtype" row, so row order matters.
            for row in csvreader:
                #print(f">>> {row!r}")
                if row[0] == "msgtype":
                    # msgtype,<msgname>,<value>[,<option>]
                    msg_type_name = row[1]
                    if for_onion_wire:
                        msg_type_int = _parse_msgtype_intvalue_for_onion_wire(str(row[2]))
                    else:
                        msg_type_int = int(row[2])
                    msg_type_bytes = msg_type_int.to_bytes(2, 'big')
                    assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}"
                    assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}"
                    row[2] = msg_type_int
                    self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
                    self.msg_type_from_name[msg_type_name] = msg_type_bytes
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert msg_type_name == row[1]
                    self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
                elif row[0] == "tlvtype":
                    # tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
                    tlv_stream_name = row[1]
                    tlv_record_name = row[2]
                    tlv_record_type = int(row[3])
                    row[3] = tlv_record_type
                    if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
                        # OrderedDict so write_tlv_stream can rely on insertion
                        # (== record type) order when serializing
                        self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
                        self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
                        self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
                    assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
                    self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
                    self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
                    if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
                        raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
                                        f"stream={tlv_stream_name}")
                elif row[0] == "tlvdata":
                    # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert tlv_stream_name == row[1]
                    assert tlv_record_name == row[2]
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
                else:
                    pass  # TODO

    def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
        """Serialize the TLV records given in kwargs into fd.

        Each kwarg maps a record name to a dict of its field values; records
        not present in kwargs are simply omitted from the stream.
        """
        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        for tlv_record_type, scheme in scheme_map.items():  # note: tlv_record_type is monotonically increasing
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            if tlv_record_name not in kwargs:
                continue
            with io.BytesIO() as tlv_record_fd:
                for row in scheme:
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        field_count = _resolve_field_count(field_count_str,
                                                           vars_dict=kwargs[tlv_record_name],
                                                           allow_any=True)
                        field_value = kwargs[tlv_record_name][field_name]
                        _write_field(fd=tlv_record_fd,
                                     field_type=field_type,
                                     count=field_count,
                                     value=field_value)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
                _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())

    def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
        """Parse a TLV stream from fd until EOF.

        Returns {record_name: {field_name: value}}. Unknown odd record types
        are skipped; unknown even ("mandatory") types raise
        UnknownMandatoryTLVRecordType. Records must appear in strictly
        increasing type order, and each record must be fully consumed by its
        scheme (no trailing garbage).
        """
        parsed = {}  # type: Dict[str, Dict[str, Any]]
        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        last_seen_tlv_record_type = -1  # type: int
        while _num_remaining_bytes_to_read(fd) > 0:
            tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
            if not (tlv_record_type > last_seen_tlv_record_type):
                raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
                                           f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
            last_seen_tlv_record_type = tlv_record_type
            try:
                scheme = scheme_map[tlv_record_type]
            except KeyError:
                if tlv_record_type % 2 == 0:
                    # unknown "even" type: hard fail
                    raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
                else:
                    # unknown "odd" type: skip it
                    continue
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            parsed[tlv_record_name] = {}
            with io.BytesIO(tlv_record_val) as tlv_record_fd:
                for row in scheme:
                    #print(f"row: {row!r}")
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        field_count = _resolve_field_count(field_count_str,
                                                           vars_dict=parsed[tlv_record_name],
                                                           allow_any=True)
                        #print(f">> count={field_count}. parsed={parsed}")
                        parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
                                                                          field_type=field_type,
                                                                          count=field_count)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
                if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
                    raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
        return parsed

    def encode_msg(self, msg_type: str, **kwargs) -> bytes:
        """
        Encode kwargs into a Lightning message (bytes)
        of the type given in the msg_type string

        Missing mandatory fields default to zero; a missing optional field
        (scheme row has an <option> column) truncates the message there.
        A field named "tlvs" is serialized as an embedded TLV stream.
        """
        #print(f">>> encode_msg. msg_type={msg_type}, payload={kwargs!r}")
        msg_type_bytes = self.msg_type_from_name[msg_type]
        scheme = self.msg_scheme_from_type[msg_type_bytes]
        with io.BytesIO() as fd:
            fd.write(msg_type_bytes)
            for row in scheme:
                if row[0] == "msgtype":
                    pass
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    field_name = row[2]
                    field_type = row[3]
                    field_count_str = row[4]
                    #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
                    field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
                    if field_name == "tlvs":
                        tlv_stream_name = field_type
                        if tlv_stream_name in kwargs:
                            self.write_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name, **(kwargs[tlv_stream_name]))
                        continue
                    try:
                        field_value = kwargs[field_name]
                    except KeyError:
                        if len(row) > 5:
                            break  # optional feature field not present
                        else:
                            field_value = 0  # default mandatory fields to zero
                    #print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
                    _write_field(fd=fd,
                                 field_type=field_type,
                                 count=field_count,
                                 value=field_value)
                    #print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
                else:
                    raise Exception(f"unexpected row in scheme: {row!r}")
            return fd.getvalue()

    def decode_msg(self, data: bytes) -> Tuple[str, dict]:
        """
        Decode Lightning message by reading the first
        two bytes to determine message type.

        Returns message type string and parsed message contents dict

        If the stream ends at an optional field (scheme row has an <option>
        column), parsing stops there instead of raising.
        """
        #print(f"decode_msg >>> {data.hex()}")
        assert len(data) >= 2
        msg_type_bytes = data[:2]
        msg_type_int = int.from_bytes(msg_type_bytes, byteorder="big", signed=False)
        scheme = self.msg_scheme_from_type[msg_type_bytes]
        assert scheme[0][2] == msg_type_int
        msg_type_name = scheme[0][1]
        parsed = {}
        with io.BytesIO(data[2:]) as fd:
            for row in scheme:
                #print(f"row: {row!r}")
                if row[0] == "msgtype":
                    pass
                elif row[0] == "msgdata":
                    field_name = row[2]
                    field_type = row[3]
                    field_count_str = row[4]
                    field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
                    if field_name == "tlvs":
                        tlv_stream_name = field_type
                        d = self.read_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name)
                        parsed[tlv_stream_name] = d
                        continue
                    #print(f">> count={field_count}. parsed={parsed}")
                    try:
                        parsed[field_name] = _read_field(fd=fd,
                                                         field_type=field_type,
                                                         count=field_count)
                    except UnexpectedEndOfStream as e:
                        if len(row) > 5:
                            break  # optional feature field not present
                        else:
                            raise
                else:
                    raise Exception(f"unexpected row in scheme: {row!r}")
        return msg_type_name, parsed


# module-level convenience singletons (peer wire by default)
_inst = LNSerializer()
encode_msg = _inst.encode_msg
decode_msg = _inst.decode_msg


OnionWireSerializer = LNSerializer(for_onion_wire=True)