commit 5f1feee33114775f1435ff8cf77bf9e2cb41172c
parent f5cee9ecf6132d42657a7b2e70e4a0b0dab17fb0
Author: Janus <ysangkok@gmail.com>
Date: Tue, 5 Feb 2019 17:56:01 +0100
move lightning message encoding to new lnmsg module
Diffstat:
4 files changed, 164 insertions(+), 153 deletions(-)
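For orientation, the new lnmsg module keeps the old module-level API: encode_msg builds a message from keyword fields and decode_msg parses it back into a dict of raw byte values. A minimal round-trip sketch (the 'init' keyword arguments are the same ones used in test_lnbase.py below; it assumes electrum and its bundled lightning.json are importable):

    from electrum.lnmsg import encode_msg, decode_msg

    # Build an 'init' message with 1-byte, all-zero feature vectors.
    raw = encode_msg('init', gflen=1, lflen=1,
                     globalfeatures=b"\x00", localfeatures=b"\x00")
    assert raw[:2] == b"\x00\x10"      # wire type 16 == init

    # Decoding returns the message name and the fields as big-endian bytes.
    name, fields = decode_msg(raw)
    assert name == 'init'
    assert fields['localfeatures'] == b"\x00"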
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -37,6 +37,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED,
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED)
from .lntransport import LNTransport, LNTransportBase
+from .lnmsg import encode_msg, decode_msg
if TYPE_CHECKING:
from .lnworker import LNWorker
@@ -48,153 +49,6 @@ def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[b
i = int.from_bytes(funding_txid_bytes, 'big') ^ funding_index
return i.to_bytes(32, 'big'), funding_txid_bytes
-
-message_types = {}
-
-def handlesingle(x, ma: dict) -> int:
- """
- Evaluate a term of the simple language used
- to specify lightning message field lengths.
-
- If `x` is an integer, it is returned as is,
- otherwise it is treated as a variable and
- looked up in `ma`.
-
- If the value in `ma` is not an integer, it is
- assumed to be big-endian bytes and decoded.
-
- Returns int
- """
- try:
- x = int(x)
- except ValueError:
- x = ma[x]
- try:
- x = int(x)
- except ValueError:
- x = int.from_bytes(x, byteorder='big')
- return x
-
-def calcexp(exp, ma: dict) -> int:
- """
- Evaluate a simple mathematical expression given
- in `exp` with variables assigned in the dict `ma`
-
- Returns int
- """
- exp = str(exp)
- if "*" in exp:
- assert "+" not in exp
- result = 1
- for term in exp.split("*"):
- result *= handlesingle(term, ma)
- return result
- return sum(handlesingle(x, ma) for x in exp.split("+"))
-
-def make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]:
- """
- Generate a message handler function (taking bytes)
- for message type `k` with specification `v`
-
- Check electrum/lightning.json, `k` could be 'init',
- and `v` could be
-
- { type: 16, payload: { 'gflen': ..., ... }, ... }
-
- Returns function taking bytes
- """
- def handler(data: bytes) -> Tuple[str, dict]:
- nonlocal k, v
- ma = {}
- pos = 0
- for fieldname in v["payload"]:
- poslenMap = v["payload"][fieldname]
- if "feature" in poslenMap and pos == len(data):
- continue
- #print(poslenMap["position"], ma)
- assert pos == calcexp(poslenMap["position"], ma)
- length = poslenMap["length"]
- length = calcexp(length, ma)
- ma[fieldname] = data[pos:pos+length]
- pos += length
- assert pos == len(data), (k, pos, len(data))
- return k, ma
- return handler
-
-path = os.path.join(os.path.dirname(__file__), 'lightning.json')
-with open(path) as f:
- structured = json.loads(f.read(), object_pairs_hook=OrderedDict)
-
-for k in structured:
- v = structured[k]
- # these message types are skipped since their types collide
- # (for example with pong, which also uses type=19)
- # we don't need them yet
- if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]:
- continue
- if len(v["payload"]) == 0:
- continue
- try:
- num = int(v["type"])
- except ValueError:
- #print("skipping", k)
- continue
- byts = num.to_bytes(2, 'big')
- assert byts not in message_types, (byts, message_types[byts].__name__, k)
- names = [x.__name__ for x in message_types.values()]
- assert k + "_handler" not in names, (k, names)
- message_types[byts] = make_handler(k, v)
- message_types[byts].__name__ = k + "_handler"
-
-assert message_types[b"\x00\x10"].__name__ == "init_handler"
-
-def decode_msg(data: bytes) -> Tuple[str, dict]:
- """
- Decode Lightning message by reading the first
- two bytes to determine message type.
-
- Returns message type string and parsed message contents dict
- """
- typ = data[:2]
- k, parsed = message_types[typ](data[2:])
- return k, parsed
-
-def gen_msg(msg_type: str, **kwargs) -> bytes:
- """
- Encode kwargs into a Lightning message (bytes)
- of the type given in the msg_type string
- """
- typ = structured[msg_type]
- data = int(typ["type"]).to_bytes(2, 'big')
- lengths = {}
- for k in typ["payload"]:
- poslenMap = typ["payload"][k]
- if "feature" in poslenMap: continue
- leng = calcexp(poslenMap["length"], lengths)
- try:
- clone = dict(lengths)
- clone.update(kwargs)
- leng = calcexp(poslenMap["length"], clone)
- except KeyError:
- pass
- try:
- param = kwargs[k]
- except KeyError:
- param = 0
- try:
- if not isinstance(param, bytes):
- assert isinstance(param, int), "field {} is neither bytes nor int".format(k)
- param = param.to_bytes(leng, 'big')
- except ValueError:
- raise Exception("{} does not fit in {} bytes".format(k, leng))
- lengths[k] = len(param)
- if lengths[k] != leng:
- raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng))
- data += param
- return data
-
-
-
class Peer(PrintError):
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase,
@@ -229,7 +83,7 @@ class Peer(PrintError):
def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str
self.print_error("Sending '%s'"%message_name.upper())
- self.transport.send_bytes(gen_msg(message_name, **kwargs))
+ self.transport.send_bytes(encode_msg(message_name, **kwargs))
async def initialize(self):
if isinstance(self.transport, LNTransport):
@@ -872,7 +726,7 @@ class Peer(PrintError):
else:
node_ids = self.node_ids
- chan_ann = gen_msg("channel_announcement",
+ chan_ann = encode_msg("channel_announcement",
len=0,
#features not set (defaults to zeros)
chain_hash=constants.net.rev_genesis_bytes(),
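The "#features not set" comment above relies on encode_msg zero-filling omitted fields: any payload field not passed as a keyword defaults to the integer 0 and is serialized as zero bytes of its declared length. An illustrative sketch (an all-zero chain hash stands in for constants.net.rev_genesis_bytes(); the remaining fields are left to default):

    from electrum.lnmsg import encode_msg

    ann = encode_msg('channel_announcement', len=0, chain_hash=bytes(32))
    assert ann[:2] == b"\x01\x00"      # wire type 256 == channel_announcement
    assert ann[2:2+64] == bytes(64)    # node_signature_1 left all-zero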
diff --git a/electrum/lnchannelverifier.py b/electrum/lnchannelverifier.py
@@ -39,6 +39,7 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
from .transaction import Transaction
from .interface import GracefulDisconnect
from .crypto import sha256d
+from .lnmsg import encode_msg
if TYPE_CHECKING:
from .network import Network
@@ -184,7 +185,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool:
- msg_bytes = lnbase.gen_msg('channel_announcement', **chan_ann)
+ msg_bytes = encode_msg('channel_announcement', **chan_ann)
pre_hash = msg_bytes[2+256:]
h = sha256d(pre_hash)
pubkeys = [chan_ann['node_id_1'], chan_ann['node_id_2'], chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']]
@@ -196,7 +197,7 @@ def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool:
def verify_sig_for_channel_update(chan_upd: dict, node_id: bytes) -> bool:
- msg_bytes = lnbase.gen_msg('channel_update', **chan_upd)
+ msg_bytes = encode_msg('channel_update', **chan_upd)
pre_hash = msg_bytes[2+64:]
h = sha256d(pre_hash)
sig = chan_upd['signature']
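For reference, the slicing offsets above follow from the BOLT 7 layouts: encode_msg prepends the 2-byte wire type, a channel_announcement then starts with four 64-byte signatures and a channel_update with a single one, so the portion that is actually hashed and signed begins at 2 + 4*64 = 258 bytes and 2 + 64 = 66 bytes respectively.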
diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py
@@ -0,0 +1,155 @@
+import json
+import os
+from typing import Callable, Tuple
+from collections import OrderedDict
+
+def _eval_length_term(x, ma: dict) -> int:
+ """
+ Evaluate a term of the simple language used
+ to specify lightning message field lengths.
+
+ If `x` is an integer, it is returned as is,
+ otherwise it is treated as a variable and
+ looked up in `ma`.
+
+ If the value in `ma` is not an integer, it is
+ assumed to be big-endian bytes and decoded.
+
+ Returns evaluated result as int
+ """
+ try:
+ x = int(x)
+ except ValueError:
+ x = ma[x]
+ try:
+ x = int(x)
+ except ValueError:
+ x = int.from_bytes(x, byteorder='big')
+ return x
+
+def _eval_exp_with_ctx(exp, ctx: dict) -> int:
+ """
+ Evaluate a simple mathematical expression given
+ in `exp` with context (variables assigned)
+ from the dict `ctx`.
+
+ Returns evaluated result as int
+ """
+ exp = str(exp)
+ if "*" in exp:
+ assert "+" not in exp
+ result = 1
+ for term in exp.split("*"):
+ result *= _eval_length_term(term, ctx)
+ return result
+ return sum(_eval_length_term(x, ctx) for x in exp.split("+"))
+
+def _make_handler(k: str, v: dict) -> Callable[[bytes], Tuple[str, dict]]:
+ """
+ Generate a message handler function (taking bytes)
+ for message type `k` with specification `v`
+
+ Check electrum/lightning.json, `k` could be 'init',
+ and `v` could be
+
+ { type: 16, payload: { 'gflen': ..., ... }, ... }
+
+ Returns function taking bytes
+ """
+ def handler(data: bytes) -> Tuple[str, dict]:
+ nonlocal k, v
+ ma = {}
+ pos = 0
+ for fieldname in v["payload"]:
+ poslenMap = v["payload"][fieldname]
+ if "feature" in poslenMap and pos == len(data):
+ continue
+ assert pos == _eval_exp_with_ctx(poslenMap["position"], ma)
+ length = poslenMap["length"]
+ length = _eval_exp_with_ctx(length, ma)
+ ma[fieldname] = data[pos:pos+length]
+ pos += length
+ assert pos == len(data), (k, pos, len(data))
+ return k, ma
+ return handler
+
+class LNSerializer:
+ def __init__(self):
+ message_types = {}
+ path = os.path.join(os.path.dirname(__file__), 'lightning.json')
+ with open(path) as f:
+ structured = json.loads(f.read(), object_pairs_hook=OrderedDict)
+
+ for k in structured:
+ v = structured[k]
+ # these message types are skipped since their types collide
+ # (for example with pong, which also uses type=19)
+ # we don't need them yet
+ if k in ["final_incorrect_cltv_expiry", "final_incorrect_htlc_amount"]:
+ continue
+ if len(v["payload"]) == 0:
+ continue
+ try:
+ num = int(v["type"])
+ except ValueError:
+ #print("skipping", k)
+ continue
+ byts = num.to_bytes(2, 'big')
+ assert byts not in message_types, (byts, message_types[byts].__name__, k)
+ names = [x.__name__ for x in message_types.values()]
+ assert k + "_handler" not in names, (k, names)
+ message_types[byts] = _make_handler(k, v)
+ message_types[byts].__name__ = k + "_handler"
+
+ assert message_types[b"\x00\x10"].__name__ == "init_handler"
+ self.structured = structured
+ self.message_types = message_types
+
+ def encode_msg(self, msg_type: str, **kwargs) -> bytes:
+ """
+ Encode kwargs into a Lightning message (bytes)
+ of the type given in the msg_type string
+ """
+ typ = self.structured[msg_type]
+ data = int(typ["type"]).to_bytes(2, 'big')
+ lengths = {}
+ for k in typ["payload"]:
+ poslenMap = typ["payload"][k]
+ if "feature" in poslenMap: continue
+ leng = _eval_exp_with_ctx(poslenMap["length"], lengths)
+ try:
+ clone = dict(lengths)
+ clone.update(kwargs)
+ leng = _eval_exp_with_ctx(poslenMap["length"], clone)
+ except KeyError:
+ pass
+ try:
+ param = kwargs[k]
+ except KeyError:
+ param = 0
+ try:
+ if not isinstance(param, bytes):
+ assert isinstance(param, int), "field {} is neither bytes nor int".format(k)
+ param = param.to_bytes(leng, 'big')
+ except ValueError:
+ raise Exception("{} does not fit in {} bytes".format(k, leng))
+ lengths[k] = len(param)
+ if lengths[k] != leng:
+ raise Exception("field {} is {} bytes long, should be {} bytes long".format(k, lengths[k], leng))
+ data += param
+ return data
+
+ def decode_msg(self, data: bytes) -> Tuple[str, dict]:
+ """
+ Decode Lightning message by reading the first
+ two bytes to determine message type.
+
+ Returns message type string and parsed message contents dict
+ """
+ typ = data[:2]
+ k, parsed = self.message_types[typ](data[2:])
+ return k, parsed
+
+_inst = LNSerializer()
+encode_msg = _inst.encode_msg
+decode_msg = _inst.decode_msg
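The "length" and "position" strings handled above come from lightning.json and form a tiny expression language: integer literals, field names, and sums or products of them (the spec never mixes + and * in one expression). A standalone sketch of the same evaluation rules, separate from the module (the example expressions are illustrative):

    def eval_term(x, ctx):
        # Integer literal, e.g. "2".
        try:
            return int(x)
        except ValueError:
            pass
        # Otherwise a field name; previously decoded fields are big-endian bytes.
        val = ctx[x]
        try:
            return int(val)
        except ValueError:
            return int.from_bytes(val, byteorder='big')

    def eval_exp(exp, ctx):
        exp = str(exp)
        if "*" in exp:
            result = 1
            for term in exp.split("*"):
                result *= eval_term(term, ctx)
            return result
        return sum(eval_term(t, ctx) for t in exp.split("+"))

    assert eval_exp(2, {}) == 2
    assert eval_exp("2+gflen", {"gflen": b"\x00\x03"}) == 5    # e.g. position of 'lflen' in 'init'
    assert eval_exp("num_htlcs*64", {"num_htlcs": 2}) == 128   # e.g. length of 'htlc_signature'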
diff --git a/electrum/tests/test_lnbase.py b/electrum/tests/test_lnbase.py
@@ -13,12 +13,13 @@ from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u
-from electrum.lnbase import Peer, decode_msg, gen_msg
+from electrum.lnbase import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure
from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker
+from electrum.lnmsg import encode_msg, decode_msg
from .test_lnchan import create_test_channels
@@ -135,7 +136,7 @@ class NoFeaturesTransport(MockTransport):
decoded = decode_msg(data)
print(decoded)
if decoded[0] == 'init':
- self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
+ self.queue.put_nowait(encode_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
class PutIntoOthersQueueTransport(MockTransport):
def __init__(self):