electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

lnmsg.py (23442B)


      1 import os
      2 import csv
      3 import io
      4 from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
      5 from collections import OrderedDict
      6 
      7 from .lnutil import OnionFailureCodeMetaFlag
      8 
      9 
     10 class MalformedMsg(Exception): pass
     11 class UnknownMsgFieldType(MalformedMsg): pass
     12 class UnexpectedEndOfStream(MalformedMsg): pass
     13 class FieldEncodingNotMinimal(MalformedMsg): pass
     14 class UnknownMandatoryTLVRecordType(MalformedMsg): pass
     15 class MsgTrailingGarbage(MalformedMsg): pass
     16 class MsgInvalidFieldOrder(MalformedMsg): pass
     17 class UnexpectedFieldSizeForEncoder(MalformedMsg): pass
     18 
     19 
     20 def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
     21     cur_pos = fd.tell()
     22     end_pos = fd.seek(0, io.SEEK_END)
     23     fd.seek(cur_pos)
     24     return end_pos - cur_pos
     25 
     26 
     27 def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
     28     # note: it's faster to read n bytes and then check if we read n, than
     29     #       to assert we can read at least n and then read n bytes.
     30     nremaining = _num_remaining_bytes_to_read(fd)
     31     if nremaining < n:
     32         raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")
     33 
     34 
     35 def write_bigsize_int(i: int) -> bytes:
     36     assert i >= 0, i
     37     if i < 0xfd:
     38         return int.to_bytes(i, length=1, byteorder="big", signed=False)
     39     elif i < 0x1_0000:
     40         return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False)
     41     elif i < 0x1_0000_0000:
     42         return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False)
     43     else:
     44         return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)
     45 
     46 
     47 def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
     48     try:
     49         first = fd.read(1)[0]
     50     except IndexError:
     51         return None  # end of file
     52     if first < 0xfd:
     53         return first
     54     elif first == 0xfd:
     55         buf = fd.read(2)
     56         if len(buf) != 2:
     57             raise UnexpectedEndOfStream()
     58         val = int.from_bytes(buf, byteorder="big", signed=False)
     59         if not (0xfd <= val < 0x1_0000):
     60             raise FieldEncodingNotMinimal()
     61         return val
     62     elif first == 0xfe:
     63         buf = fd.read(4)
     64         if len(buf) != 4:
     65             raise UnexpectedEndOfStream()
     66         val = int.from_bytes(buf, byteorder="big", signed=False)
     67         if not (0x1_0000 <= val < 0x1_0000_0000):
     68             raise FieldEncodingNotMinimal()
     69         return val
     70     elif first == 0xff:
     71         buf = fd.read(8)
     72         if len(buf) != 8:
     73             raise UnexpectedEndOfStream()
     74         val = int.from_bytes(buf, byteorder="big", signed=False)
     75         if not (0x1_0000_0000 <= val):
     76             raise FieldEncodingNotMinimal()
     77         return val
     78     raise Exception()
     79 
     80 
     81 # TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
     82 #       if field_type is a numeric, we could return a list of ints?
     83 def _read_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str]) -> Union[bytes, int]:
     84     if not fd: raise Exception()
     85     if isinstance(count, int):
     86         assert count >= 0, f"{count!r} must be non-neg int"
     87     elif count == "...":
     88         pass
     89     else:
     90         raise Exception(f"unexpected field count: {count!r}")
     91     if count == 0:
     92         return b""
     93     type_len = None
     94     if field_type == 'byte':
     95         type_len = 1
     96     elif field_type in ('u8', 'u16', 'u32', 'u64'):
     97         if field_type == 'u8':
     98             type_len = 1
     99         elif field_type == 'u16':
    100             type_len = 2
    101         elif field_type == 'u32':
    102             type_len = 4
    103         else:
    104             assert field_type == 'u64'
    105             type_len = 8
    106         assert count == 1, count
    107         buf = fd.read(type_len)
    108         if len(buf) != type_len:
    109             raise UnexpectedEndOfStream()
    110         return int.from_bytes(buf, byteorder="big", signed=False)
    111     elif field_type in ('tu16', 'tu32', 'tu64'):
    112         if field_type == 'tu16':
    113             type_len = 2
    114         elif field_type == 'tu32':
    115             type_len = 4
    116         else:
    117             assert field_type == 'tu64'
    118             type_len = 8
    119         assert count == 1, count
    120         raw = fd.read(type_len)
    121         if len(raw) > 0 and raw[0] == 0x00:
    122             raise FieldEncodingNotMinimal()
    123         return int.from_bytes(raw, byteorder="big", signed=False)
    124     elif field_type == 'varint':
    125         assert count == 1, count
    126         val = read_bigsize_int(fd)
    127         if val is None:
    128             raise UnexpectedEndOfStream()
    129         return val
    130     elif field_type == 'chain_hash':
    131         type_len = 32
    132     elif field_type == 'channel_id':
    133         type_len = 32
    134     elif field_type == 'sha256':
    135         type_len = 32
    136     elif field_type == 'signature':
    137         type_len = 64
    138     elif field_type == 'point':
    139         type_len = 33
    140     elif field_type == 'short_channel_id':
    141         type_len = 8
    142 
    143     if count == "...":
    144         total_len = -1  # read all
    145     else:
    146         if type_len is None:
    147             raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
    148         total_len = count * type_len
    149 
    150     buf = fd.read(total_len)
    151     if total_len >= 0 and len(buf) != total_len:
    152         raise UnexpectedEndOfStream()
    153     return buf
    154 
    155 
    156 # TODO: maybe for "value" we could accept a list with len "count" of appropriate items
    157 def _write_field(*, fd: io.BytesIO, field_type: str, count: Union[int, str],
    158                  value: Union[bytes, int]) -> None:
    159     if not fd: raise Exception()
    160     if isinstance(count, int):
    161         assert count >= 0, f"{count!r} must be non-neg int"
    162     elif count == "...":
    163         pass
    164     else:
    165         raise Exception(f"unexpected field count: {count!r}")
    166     if count == 0:
    167         return
    168     type_len = None
    169     if field_type == 'byte':
    170         type_len = 1
    171     elif field_type == 'u8':
    172         type_len = 1
    173     elif field_type == 'u16':
    174         type_len = 2
    175     elif field_type == 'u32':
    176         type_len = 4
    177     elif field_type == 'u64':
    178         type_len = 8
    179     elif field_type in ('tu16', 'tu32', 'tu64'):
    180         if field_type == 'tu16':
    181             type_len = 2
    182         elif field_type == 'tu32':
    183             type_len = 4
    184         else:
    185             assert field_type == 'tu64'
    186             type_len = 8
    187         assert count == 1, count
    188         if isinstance(value, int):
    189             value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
    190         if not isinstance(value, (bytes, bytearray)):
    191             raise Exception(f"can only write bytes into fd. got: {value!r}")
    192         while len(value) > 0 and value[0] == 0x00:
    193             value = value[1:]
    194         nbytes_written = fd.write(value)
    195         if nbytes_written != len(value):
    196             raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
    197         return
    198     elif field_type == 'varint':
    199         assert count == 1, count
    200         if isinstance(value, int):
    201             value = write_bigsize_int(value)
    202         if not isinstance(value, (bytes, bytearray)):
    203             raise Exception(f"can only write bytes into fd. got: {value!r}")
    204         nbytes_written = fd.write(value)
    205         if nbytes_written != len(value):
    206             raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
    207         return
    208     elif field_type == 'chain_hash':
    209         type_len = 32
    210     elif field_type == 'channel_id':
    211         type_len = 32
    212     elif field_type == 'sha256':
    213         type_len = 32
    214     elif field_type == 'signature':
    215         type_len = 64
    216     elif field_type == 'point':
    217         type_len = 33
    218     elif field_type == 'short_channel_id':
    219         type_len = 8
    220     total_len = -1
    221     if count != "...":
    222         if type_len is None:
    223             raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
    224         total_len = count * type_len
    225         if isinstance(value, int) and (count == 1 or field_type == 'byte'):
    226             value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
    227     if not isinstance(value, (bytes, bytearray)):
    228         raise Exception(f"can only write bytes into fd. got: {value!r}")
    229     if count != "..." and total_len != len(value):
    230         raise UnexpectedFieldSizeForEncoder(f"expected: {total_len}, got {len(value)}")
    231     nbytes_written = fd.write(value)
    232     if nbytes_written != len(value):
    233         raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
    234 
    235 
    236 def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
    237     if not fd: raise Exception()
    238     tlv_type = _read_field(fd=fd, field_type="varint", count=1)
    239     tlv_len = _read_field(fd=fd, field_type="varint", count=1)
    240     tlv_val = _read_field(fd=fd, field_type="byte", count=tlv_len)
    241     return tlv_type, tlv_val
    242 
    243 
    244 def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
    245     if not fd: raise Exception()
    246     tlv_len = len(tlv_val)
    247     _write_field(fd=fd, field_type="varint", count=1, value=tlv_type)
    248     _write_field(fd=fd, field_type="varint", count=1, value=tlv_len)
    249     _write_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
    250 
    251 
    252 def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]:
    253     """Returns an evaluated field count, typically an int.
    254     If allow_any is True, the return value can be a str with value=="...".
    255     """
    256     if field_count_str == "":
    257         field_count = 1
    258     elif field_count_str == "...":
    259         if not allow_any:
    260             raise Exception("field count is '...' but allow_any is False")
    261         return field_count_str
    262     else:
    263         try:
    264             field_count = int(field_count_str)
    265         except ValueError:
    266             field_count = vars_dict[field_count_str]
    267             if isinstance(field_count, (bytes, bytearray)):
    268                 field_count = int.from_bytes(field_count, byteorder="big")
    269     assert isinstance(field_count, int)
    270     return field_count
    271 
    272 
    273 def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
    274     msg_type_int = 0
    275     for component in value.split("|"):
    276         try:
    277             msg_type_int |= int(component)
    278         except ValueError:
    279             msg_type_int |= OnionFailureCodeMetaFlag[component]
    280     return msg_type_int
    281 
    282 
    283 class LNSerializer:
    284 
    285     def __init__(self, *, for_onion_wire: bool = False):
    286         # TODO msg_type could be 'int' everywhere...
    287         self.msg_scheme_from_type = {}  # type: Dict[bytes, List[Sequence[str]]]
    288         self.msg_type_from_name = {}  # type: Dict[str, bytes]
    289 
    290         self.in_tlv_stream_get_tlv_record_scheme_from_type = {}  # type: Dict[str, Dict[int, List[Sequence[str]]]]
    291         self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
    292         self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]
    293 
    294         if for_onion_wire:
    295             path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
    296         else:
    297             path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
    298         with open(path, newline='') as f:
    299             csvreader = csv.reader(f)
    300             for row in csvreader:
    301                 #print(f">>> {row!r}")
    302                 if row[0] == "msgtype":
    303                     # msgtype,<msgname>,<value>[,<option>]
    304                     msg_type_name = row[1]
    305                     if for_onion_wire:
    306                         msg_type_int = _parse_msgtype_intvalue_for_onion_wire(str(row[2]))
    307                     else:
    308                         msg_type_int = int(row[2])
    309                     msg_type_bytes = msg_type_int.to_bytes(2, 'big')
    310                     assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}"
    311                     assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}"
    312                     row[2] = msg_type_int
    313                     self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
    314                     self.msg_type_from_name[msg_type_name] = msg_type_bytes
    315                 elif row[0] == "msgdata":
    316                     # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
    317                     assert msg_type_name == row[1]
    318                     self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
    319                 elif row[0] == "tlvtype":
    320                     # tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
    321                     tlv_stream_name = row[1]
    322                     tlv_record_name = row[2]
    323                     tlv_record_type = int(row[3])
    324                     row[3] = tlv_record_type
    325                     if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
    326                         self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
    327                         self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
    328                         self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
    329                     assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
    330                     assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
    331                     assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
    332                     self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
    333                     self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
    334                     self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
    335                     if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
    336                         raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
    337                                         f"stream={tlv_stream_name}")
    338                 elif row[0] == "tlvdata":
    339                     # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
    340                     assert tlv_stream_name == row[1]
    341                     assert tlv_record_name == row[2]
    342                     self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
    343                 else:
    344                     pass  # TODO
    345 
    346     def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
    347         scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
    348         for tlv_record_type, scheme in scheme_map.items():  # note: tlv_record_type is monotonically increasing
    349             tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
    350             if tlv_record_name not in kwargs:
    351                 continue
    352             with io.BytesIO() as tlv_record_fd:
    353                 for row in scheme:
    354                     if row[0] == "tlvtype":
    355                         pass
    356                     elif row[0] == "tlvdata":
    357                         # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
    358                         assert tlv_stream_name == row[1]
    359                         assert tlv_record_name == row[2]
    360                         field_name = row[3]
    361                         field_type = row[4]
    362                         field_count_str = row[5]
    363                         field_count = _resolve_field_count(field_count_str,
    364                                                            vars_dict=kwargs[tlv_record_name],
    365                                                            allow_any=True)
    366                         field_value = kwargs[tlv_record_name][field_name]
    367                         _write_field(fd=tlv_record_fd,
    368                                      field_type=field_type,
    369                                      count=field_count,
    370                                      value=field_value)
    371                     else:
    372                         raise Exception(f"unexpected row in scheme: {row!r}")
    373                 _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())
    374 
    375     def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
    376         parsed = {}  # type: Dict[str, Dict[str, Any]]
    377         scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
    378         last_seen_tlv_record_type = -1  # type: int
    379         while _num_remaining_bytes_to_read(fd) > 0:
    380             tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
    381             if not (tlv_record_type > last_seen_tlv_record_type):
    382                 raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
    383                                            f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
    384             last_seen_tlv_record_type = tlv_record_type
    385             try:
    386                 scheme = scheme_map[tlv_record_type]
    387             except KeyError:
    388                 if tlv_record_type % 2 == 0:
    389                     # unknown "even" type: hard fail
    390                     raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
    391                 else:
    392                     # unknown "odd" type: skip it
    393                     continue
    394             tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
    395             parsed[tlv_record_name] = {}
    396             with io.BytesIO(tlv_record_val) as tlv_record_fd:
    397                 for row in scheme:
    398                     #print(f"row: {row!r}")
    399                     if row[0] == "tlvtype":
    400                         pass
    401                     elif row[0] == "tlvdata":
    402                         # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
    403                         assert tlv_stream_name == row[1]
    404                         assert tlv_record_name == row[2]
    405                         field_name = row[3]
    406                         field_type = row[4]
    407                         field_count_str = row[5]
    408                         field_count = _resolve_field_count(field_count_str,
    409                                                            vars_dict=parsed[tlv_record_name],
    410                                                            allow_any=True)
    411                         #print(f">> count={field_count}. parsed={parsed}")
    412                         parsed[tlv_record_name][field_name] = _read_field(fd=tlv_record_fd,
    413                                                                           field_type=field_type,
    414                                                                           count=field_count)
    415                     else:
    416                         raise Exception(f"unexpected row in scheme: {row!r}")
    417                 if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
    418                     raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
    419         return parsed
    420 
    421     def encode_msg(self, msg_type: str, **kwargs) -> bytes:
    422         """
    423         Encode kwargs into a Lightning message (bytes)
    424         of the type given in the msg_type string
    425         """
    426         #print(f">>> encode_msg. msg_type={msg_type}, payload={kwargs!r}")
    427         msg_type_bytes = self.msg_type_from_name[msg_type]
    428         scheme = self.msg_scheme_from_type[msg_type_bytes]
    429         with io.BytesIO() as fd:
    430             fd.write(msg_type_bytes)
    431             for row in scheme:
    432                 if row[0] == "msgtype":
    433                     pass
    434                 elif row[0] == "msgdata":
    435                     # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
    436                     field_name = row[2]
    437                     field_type = row[3]
    438                     field_count_str = row[4]
    439                     #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
    440                     field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
    441                     if field_name == "tlvs":
    442                         tlv_stream_name = field_type
    443                         if tlv_stream_name in kwargs:
    444                             self.write_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name, **(kwargs[tlv_stream_name]))
    445                         continue
    446                     try:
    447                         field_value = kwargs[field_name]
    448                     except KeyError:
    449                         if len(row) > 5:
    450                             break  # optional feature field not present
    451                         else:
    452                             field_value = 0  # default mandatory fields to zero
    453                     #print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
    454                     _write_field(fd=fd,
    455                                  field_type=field_type,
    456                                  count=field_count,
    457                                  value=field_value)
    458                     #print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
    459                 else:
    460                     raise Exception(f"unexpected row in scheme: {row!r}")
    461             return fd.getvalue()
    462 
    463     def decode_msg(self, data: bytes) -> Tuple[str, dict]:
    464         """
    465         Decode Lightning message by reading the first
    466         two bytes to determine message type.
    467 
    468         Returns message type string and parsed message contents dict
    469         """
    470         #print(f"decode_msg >>> {data.hex()}")
    471         assert len(data) >= 2
    472         msg_type_bytes = data[:2]
    473         msg_type_int = int.from_bytes(msg_type_bytes, byteorder="big", signed=False)
    474         scheme = self.msg_scheme_from_type[msg_type_bytes]
    475         assert scheme[0][2] == msg_type_int
    476         msg_type_name = scheme[0][1]
    477         parsed = {}
    478         with io.BytesIO(data[2:]) as fd:
    479             for row in scheme:
    480                 #print(f"row: {row!r}")
    481                 if row[0] == "msgtype":
    482                     pass
    483                 elif row[0] == "msgdata":
    484                     field_name = row[2]
    485                     field_type = row[3]
    486                     field_count_str = row[4]
    487                     field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
    488                     if field_name == "tlvs":
    489                         tlv_stream_name = field_type
    490                         d = self.read_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name)
    491                         parsed[tlv_stream_name] = d
    492                         continue
    493                     #print(f">> count={field_count}. parsed={parsed}")
    494                     try:
    495                         parsed[field_name] = _read_field(fd=fd,
    496                                                          field_type=field_type,
    497                                                          count=field_count)
    498                     except UnexpectedEndOfStream as e:
    499                         if len(row) > 5:
    500                             break  # optional feature field not present
    501                         else:
    502                             raise
    503                 else:
    504                     raise Exception(f"unexpected row in scheme: {row!r}")
    505         return msg_type_name, parsed
    506 
    507 
    508 _inst = LNSerializer()
    509 encode_msg = _inst.encode_msg
    510 decode_msg = _inst.decode_msg
    511 
    512 
    513 OnionWireSerializer = LNSerializer(for_onion_wire=True)