commit f70e679ababa033ccd99840d6149f9d72fef7f6f
parent f3d1f71e94f3f4e70776352ac1ad889e0bc45e2f
Author: SomberNight <somber.night@protonmail.com>
Date: Mon, 22 Oct 2018 15:35:57 +0200
some more type annotations that needed conditional imports
Diffstat:
8 files changed, 65 insertions(+), 33 deletions(-)
diff --git a/electrum/lnbase.py b/electrum/lnbase.py
@@ -10,7 +10,7 @@ import asyncio
import os
import time
from functools import partial
-from typing import List, Tuple, Dict
+from typing import List, Tuple, Dict, TYPE_CHECKING
import traceback
import sys
@@ -31,10 +31,13 @@ from .lnutil import (Outpoint, LocalConfig, ChannelConfig,
funding_output_script, get_per_commitment_secret_from_seed,
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily,
- get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED)
-from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
+ get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED,
+ LightningPeerConnectionClosed, HandshakeFailed, LNPeerAddr)
from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge
-from .lntransport import LNTransport
+from .lntransport import LNTransport, LNTransportBase
+
+if TYPE_CHECKING:
+ from .lnworker import LNWorker
def channel_id_from_funding_tx(funding_txid, funding_index):
@@ -191,7 +194,8 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
class Peer(PrintError):
- def __init__(self, lnworker, peer_addr, request_initial_sync=False, transport=None):
+ def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr,
+ request_initial_sync=False, transport: LNTransportBase=None):
self.initialized = asyncio.Future()
self.transport = transport
self.peer_addr = peer_addr
@@ -357,7 +361,7 @@ class Peer(PrintError):
def close_and_cleanup(self):
try:
if self.transport:
- self.transport.writer.close()
+ self.transport.close()
except:
pass
for chan in self.channels.values():
diff --git a/electrum/lnchan.py b/electrum/lnchan.py
@@ -3,7 +3,7 @@ from collections import namedtuple, defaultdict
import binascii
import json
from enum import Enum, auto
-from typing import Optional
+from typing import Optional, Mapping, List
from .util import bfh, PrintError, bh2u
from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS
diff --git a/electrum/lnchannelverifier.py b/electrum/lnchannelverifier.py
@@ -25,6 +25,7 @@
import asyncio
import threading
+from typing import TYPE_CHECKING
import aiorpcx
@@ -38,6 +39,10 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
from .transaction import Transaction
from .interface import GracefulDisconnect
+if TYPE_CHECKING:
+ from .network import Network
+ from .lnrouter import ChannelDB
+
class LNChannelVerifier(NetworkJobOnDefaultServer):
""" Verify channel announcements for the Channel DB """
@@ -46,7 +51,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
# will start throttling us, making it even slower. one option would be to
# spread it over multiple servers.
- def __init__(self, network, channel_db):
+ def __init__(self, network: 'Network', channel_db: 'ChannelDB'):
NetworkJobOnDefaultServer.__init__(self, network)
self.channel_db = channel_db
self.lock = threading.Lock()
@@ -105,7 +110,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
#self.print_error('requested short_channel_id', bh2u(short_channel_id))
- async def verify_channel(self, block_height, tx_pos, short_channel_id):
+ async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes):
# we are verifying channel announcements as they are from untrusted ln peers.
# we use electrum servers to do this. however we don't trust electrum servers either...
try:
diff --git a/electrum/lnonion.py b/electrum/lnonion.py
@@ -24,7 +24,7 @@
# SOFTWARE.
import hashlib
-from typing import Sequence, List, Tuple, NamedTuple
+from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING
from enum import IntEnum, IntFlag
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
@@ -34,7 +34,9 @@ from . import ecc
from .crypto import sha256, hmac_oneshot
from .util import bh2u, profiler, xor_bytes, bfh
from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH
-from .lnrouter import RouteEdge
+
+if TYPE_CHECKING:
+ from .lnrouter import RouteEdge
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
@@ -186,7 +188,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes,
hmac=next_hmac)
-def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_cltv: int) \
+def calc_hops_data_for_payment(route: List['RouteEdge'], amount_msat: int, final_cltv: int) \
-> Tuple[List[OnionHopsDataSingle], int, int]:
"""Returns the hops_data to be used for constructing an onion packet,
and the amount_msat and cltv to be used on our immediate channel.
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -27,8 +27,8 @@ import queue
import os
import json
import threading
-from collections import namedtuple, defaultdict
-from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple
+from collections import defaultdict
+from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING
import binascii
import base64
import asyncio
@@ -41,6 +41,10 @@ from .crypto import Hash
from . import ecc
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH
+if TYPE_CHECKING:
+ from .lnchan import Channel
+ from .network import Network
+
class UnknownEvenFeatureBits(Exception): pass
@@ -272,7 +276,7 @@ class ChannelDB(JsonDB):
NUM_MAX_RECENT_PEERS = 20
- def __init__(self, network):
+ def __init__(self, network: 'Network'):
self.network = network
path = os.path.join(get_headers_dir(network.config), 'channel_db')
@@ -597,7 +601,7 @@ class LNPathFinder(PrintError):
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int,
- my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
+ my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
"""Return a path from nodeA to nodeB.
Returns a list of (node_id, short_channel_id) representing a path.
diff --git a/electrum/lntransport.py b/electrum/lntransport.py
@@ -1,10 +1,11 @@
-import hmac
import hashlib
+from asyncio import StreamReader, StreamWriter
+
import cryptography.hazmat.primitives.ciphers.aead as AEAD
-from .crypto import sha256
-from .lnutil import get_ecdh, privkey_to_pubkey
-from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
+from .crypto import sha256, hmac_oneshot
+from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
+ HandshakeFailed)
from . import ecc
from .util import bh2u
@@ -49,13 +50,13 @@ def get_bolt8_hkdf(salt, ikm):
Return as two 32 byte fields.
"""
#Extract
- prk = hmac.new(salt, msg=ikm, digestmod=hashlib.sha256).digest()
+ prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
assert len(prk) == 32
#Expand
info = b""
T0 = b""
- T1 = hmac.new(prk, T0 + info + b"\x01", digestmod=hashlib.sha256).digest()
- T2 = hmac.new(prk, T1 + info + b"\x02", digestmod=hashlib.sha256).digest()
+ T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
+ T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
assert len(T1 + T2) == 64
return T1, T2
@@ -76,6 +77,11 @@ def create_ephemeral_key() -> (bytes, bytes):
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
class LNTransportBase:
+
+ def __init__(self, reader: StreamReader, writer: StreamWriter):
+ self.reader = reader
+ self.writer = writer
+
def send_bytes(self, msg):
l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l)
@@ -132,11 +138,14 @@ class LNTransportBase:
self.r_ck = ck
self.s_ck = ck
+ def close(self):
+ self.writer.close()
+
+
class LNResponderTransport(LNTransportBase):
- def __init__(self, privkey, reader, writer):
+ def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
+ LNTransportBase.__init__(self, reader, writer)
self.privkey = privkey
- self.reader = reader
- self.writer = writer
async def handshake(self, **kwargs):
hs = HandshakeState(privkey_to_pubkey(self.privkey))
@@ -187,12 +196,12 @@ class LNResponderTransport(LNTransportBase):
return rs
class LNTransport(LNTransportBase):
- def __init__(self, privkey, remote_pubkey, reader, writer):
+ def __init__(self, privkey: bytes, remote_pubkey: bytes,
+ reader: StreamReader, writer: StreamWriter):
+ LNTransportBase.__init__(self, reader, writer)
assert type(privkey) is bytes and len(privkey) == 32
self.privkey = privkey
self.remote_pubkey = remote_pubkey
- self.reader = reader
- self.writer = writer
async def handshake(self):
hs = HandshakeState(self.remote_pubkey)
diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py
@@ -1,5 +1,5 @@
import threading
-from typing import NamedTuple, Iterable
+from typing import NamedTuple, Iterable, TYPE_CHECKING
import os
from collections import defaultdict
import asyncio
@@ -11,6 +11,9 @@ from . import wallet
from .storage import WalletStorage
from .address_synchronizer import AddressSynchronizer
+if TYPE_CHECKING:
+ from .network import Network
+
TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4)
@@ -21,7 +24,7 @@ class LNWatcher(PrintError):
# maybe we should disconnect from server in these cases
verbosity_filter = 'W'
- def __init__(self, network):
+ def __init__(self, network: 'Network'):
self.network = network
self.config = network.config
path = os.path.join(network.config.path, "watcher_db")
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -3,7 +3,7 @@ import os
from decimal import Decimal
import random
import time
-from typing import Optional, Sequence, Tuple, List, Dict
+from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
import threading
import socket
@@ -31,6 +31,11 @@ from .lnaddr import lndecode
from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use
+if TYPE_CHECKING:
+ from .network import Network
+ from .wallet import Abstract_Wallet
+
+
NUM_PEERS_TARGET = 4
PEER_RETRY_INTERVAL = 600 # seconds
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
@@ -45,7 +50,7 @@ FALLBACK_NODE_LIST_MAINNET = (
class LNWorker(PrintError):
- def __init__(self, wallet, network):
+ def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'):
self.wallet = wallet
self.sweep_address = wallet.get_receiving_address()
self.network = network