electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit dd7c4b3bab7232366bfd6be9e222be7ae3b357ba
parent d94e40d2be2774aeec63a71add57d6071f94f3d4
Author: Janus <ysangkok@gmail.com>
Date:   Fri,  1 Feb 2019 20:59:59 +0100

sqlite in lnrouter

Diffstat:
Melectrum/gui/qt/channels_list.py | 10+++-------
Melectrum/lnpeer.py | 33++++++++++++++++++---------------
Melectrum/lnrouter.py | 691+++++++++++++++++++++++++++++++++++++++----------------------------------------
Melectrum/lnverifier.py | 28+++++++++++++---------------
Melectrum/lnworker.py | 61++++++++++++++++++++++++++++++++-----------------------------
Melectrum/network.py | 1-
Melectrum/tests/test_lnpeer.py | 8++++----
Melectrum/tests/test_lnrouter.py | 14++++++++++----
8 files changed, 424 insertions(+), 422 deletions(-)

diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py @@ -105,25 +105,21 @@ class ChannelsList(MyTreeView): def update_status(self): channel_db = self.parent.network.channel_db - num_nodes = len(channel_db.nodes) - num_channels = len(channel_db) num_peers = len(self.parent.wallet.lnworker.peers) - msg = _('{} peers, {} nodes, {} channels.').format(num_peers, num_nodes, num_channels) + msg = _('{} peers, {} nodes, {} channels.').format(num_peers, channel_db.num_nodes, channel_db.num_channels) self.status.setText(msg) def statistics_dialog(self): channel_db = self.parent.network.channel_db - num_nodes = len(channel_db.nodes) - num_channels = len(channel_db) capacity = self.parent.format_amount(channel_db.capacity()) + ' '+ self.parent.base_unit() d = WindowModalDialog(self.parent, _('Lightning Network Statistics')) d.setMinimumWidth(400) vbox = QVBoxLayout(d) h = QGridLayout() h.addWidget(QLabel(_('Nodes') + ':'), 0, 0) - h.addWidget(QLabel('{}'.format(num_nodes)), 0, 1) + h.addWidget(QLabel('{}'.format(channel_db.num_nodes)), 0, 1) h.addWidget(QLabel(_('Channels') + ':'), 1, 0) - h.addWidget(QLabel('{}'.format(num_channels)), 1, 1) + h.addWidget(QLabel('{}'.format(channel_db.num_channels)), 1, 1) h.addWidget(QLabel(_('Capacity') + ':'), 2, 0) h.addWidget(QLabel(capacity), 2, 1) vbox.addLayout(h) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -55,6 +55,10 @@ class Peer(PrintError): def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase, request_initial_sync=False): self.initialized = asyncio.Event() + self.node_anns = [] + self.chan_anns = [] + self.chan_upds = [] + self.last_chan_db_upd = time.time() self.transport = transport self.pubkey = pubkey self.lnworker = lnworker @@ -152,10 +156,6 @@ class Peer(PrintError): if channel_id not in self.funding_created: raise Exception("Got unknown funding_created") self.funding_created[channel_id].put_nowait(payload) - def on_node_announcement(self, payload): - self.channel_db.on_node_announcement(payload) - self.network.trigger_callback('ln_status') - def on_init(self, payload): if self.initialized.is_set(): self.print_error("ALREADY INITIALIZED BUT RECEIVED INIT") @@ -175,20 +175,14 @@ class Peer(PrintError): self.send_message('gossip_timestamp_filter', chain_hash=constants.net.rev_genesis_bytes(), first_timestamp=first_timestamp, timestamp_range=b"\xff"*4) self.initialized.set() + def on_node_announcement(self, payload): + self.node_anns.append(payload) + def on_channel_update(self, payload): - try: - self.channel_db.on_channel_update(payload) - except NotFoundChanAnnouncementForUpdate: - # If it's for a direct channel with this peer, save it for later, as it might be - # for our own channel (and we might not yet know the short channel id for that) - short_channel_id = payload['short_channel_id'] - self.print_error("not found channel announce for channel update in db", bh2u(short_channel_id)) - self.orphan_channel_updates[short_channel_id] = payload - while len(self.orphan_channel_updates) > 10: - self.orphan_channel_updates.popitem(last=False) + self.chan_upds.append(payload) def on_channel_announcement(self, payload): - self.channel_db.on_channel_announcement(payload) + self.chan_anns.append(payload) def on_announcement_signatures(self, payload): channel_id = payload['channel_id'] @@ -230,6 +224,15 @@ class Peer(PrintError): # loop async for msg in self.transport.read_messages(): self.process_message(msg) + await asyncio.sleep(.01) + if time.time() - self.last_chan_db_upd > 5: + self.last_chan_db_upd = time.time() + self.channel_db.on_node_announcement(self.node_anns) + self.node_anns = [] + self.channel_db.on_channel_announcement(self.chan_anns) + self.chan_anns = [] + self.channel_db.on_channel_update(self.chan_upds) + self.chan_upds = [] self.ping_if_required() def close_and_cleanup(self): diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -23,6 +23,8 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import datetime +import random import queue import os import json @@ -33,6 +35,14 @@ import binascii import base64 import asyncio +from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm.query import Query +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import not_, or_ +from sqlalchemy.orm import scoped_session + from . import constants from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .storage import JsonDB @@ -41,112 +51,113 @@ from .crypto import sha256d from . import ecc from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH, NotFoundChanAnnouncementForUpdate) +from .lnmsg import encode_msg if TYPE_CHECKING: from .lnchannel import Channel from .network import Network - class UnknownEvenFeatureBits(Exception): pass - - - -class ChannelInfo(PrintError): - - def __init__(self, channel_announcement_payload): - self.features_len = channel_announcement_payload['len'] - self.features = channel_announcement_payload['features'] - enabled_features = list_enabled_bits(int.from_bytes(self.features, "big")) - for fbit in enabled_features: - if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: - raise UnknownEvenFeatureBits() - - self.channel_id = channel_announcement_payload['short_channel_id'] - self.node_id_1 = channel_announcement_payload['node_id_1'] - self.node_id_2 = channel_announcement_payload['node_id_2'] - assert type(self.node_id_1) is bytes - assert type(self.node_id_2) is bytes - assert list(sorted([self.node_id_1, self.node_id_2])) == [self.node_id_1, self.node_id_2] - - self.bitcoin_key_1 = channel_announcement_payload['bitcoin_key_1'] - self.bitcoin_key_2 = channel_announcement_payload['bitcoin_key_2'] - - # this field does not get persisted - self.msg_payload = channel_announcement_payload - - self.capacity_sat = None - self.policy_node1 = None - self.policy_node2 = None - - def to_json(self) -> dict: - d = {} - d['short_channel_id'] = bh2u(self.channel_id) - d['node_id_1'] = bh2u(self.node_id_1) - d['node_id_2'] = bh2u(self.node_id_2) - d['len'] = bh2u(self.features_len) - d['features'] = bh2u(self.features) - d['bitcoin_key_1'] = bh2u(self.bitcoin_key_1) - d['bitcoin_key_2'] = bh2u(self.bitcoin_key_2) - d['policy_node1'] = self.policy_node1 - d['policy_node2'] = self.policy_node2 - d['capacity_sat'] = self.capacity_sat - return d - - @classmethod - def from_json(cls, d: dict): - d2 = {} - d2['short_channel_id'] = bfh(d['short_channel_id']) - d2['node_id_1'] = bfh(d['node_id_1']) - d2['node_id_2'] = bfh(d['node_id_2']) - d2['len'] = bfh(d['len']) - d2['features'] = bfh(d['features']) - d2['bitcoin_key_1'] = bfh(d['bitcoin_key_1']) - d2['bitcoin_key_2'] = bfh(d['bitcoin_key_2']) - ci = ChannelInfo(d2) - ci.capacity_sat = d['capacity_sat'] - ci.policy_node1 = ChannelInfoDirectedPolicy.from_json(d['policy_node1']) - ci.policy_node2 = ChannelInfoDirectedPolicy.from_json(d['policy_node2']) - return ci - - def set_capacity(self, capacity): - self.capacity_sat = capacity - - def on_channel_update(self, msg_payload, trusted=False): - assert self.channel_id == msg_payload['short_channel_id'] - flags = int.from_bytes(msg_payload['channel_flags'], 'big') - direction = flags & ChannelInfoDirectedPolicy.FLAG_DIRECTION - new_policy = ChannelInfoDirectedPolicy(msg_payload) +class NoChannelPolicy(Exception): + def __init__(self, short_channel_id: bytes): + super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}') + +def validate_features(features : int): + enabled_features = list_enabled_bits(features) + for fbit in enabled_features: + if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: + raise UnknownEvenFeatureBits() + +Base = declarative_base() +session_factory = sessionmaker() +DBSession = scoped_session(session_factory) +engine = None + +FLAG_DISABLE = 1 << 1 +FLAG_DIRECTION = 1 << 0 + +class ChannelInfoInDB(Base): + __tablename__ = 'channel_info' + short_channel_id = Column(String(64), primary_key=True) + node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) + node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) + capacity_sat = Column(Integer) + msg_payload_hex = Column(String(1024), nullable=False) + trusted = Column(Boolean, nullable=False) + + @staticmethod + def from_msg(channel_announcement_payload): + features = int.from_bytes(channel_announcement_payload['features'], 'big') + validate_features(features) + + channel_id = channel_announcement_payload['short_channel_id'].hex() + node_id_1 = channel_announcement_payload['node_id_1'].hex() + node_id_2 = channel_announcement_payload['node_id_2'].hex() + assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] + + msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex() + + capacity_sat = None + + return ChannelInfoInDB(short_channel_id = channel_id, node1_id = node_id_1, + node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, + trusted = False) + + @property + def msg_payload(self): + return bytes.fromhex(self.msg_payload_hex) + + def on_channel_update(self, msg: dict, trusted=False): + assert self.short_channel_id == msg['short_channel_id'].hex() + flags = int.from_bytes(msg['channel_flags'], 'big') + direction = flags & FLAG_DIRECTION if direction == 0: - old_policy = self.policy_node1 - node_id = self.node_id_1 + node_id = self.node1_id else: - old_policy = self.policy_node2 - node_id = self.node_id_2 - if old_policy and old_policy.timestamp >= new_policy.timestamp: + node_id = self.node2_id + new_policy = Policy.from_msg(msg, node_id, self.short_channel_id) + old_policy = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node=node_id).one_or_none() + if not old_policy: + DBSession.add(new_policy) + return + if old_policy.timestamp >= new_policy.timestamp: return # ignore - if not trusted and not verify_sig_for_channel_update(msg_payload, node_id): + if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)): return # ignore - # save new policy - if direction == 0: - self.policy_node1 = new_policy - else: - self.policy_node2 = new_policy - - def get_policy_for_node(self, node_id: bytes) -> Optional['ChannelInfoDirectedPolicy']: - if node_id == self.node_id_1: - return self.policy_node1 - elif node_id == self.node_id_2: - return self.policy_node2 - else: - raise Exception('node_id {} not in channel {}'.format(node_id, self.channel_id)) - - -class ChannelInfoDirectedPolicy: - - FLAG_DIRECTION = 1 << 0 - FLAG_DISABLE = 1 << 1 - - def __init__(self, channel_update_payload): + old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta + old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat + old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat + old_policy.fee_base_msat = new_policy.fee_base_msat + old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths + old_policy.channel_flags = new_policy.channel_flags + old_policy.timestamp = new_policy.timestamp + + def get_policy_for_node(self, node) -> Optional['Policy']: + """ + raises when initiator/non-initiator both unequal node + """ + if node.hex() not in (self.node1_id, self.node2_id): + raise Exception("the given node is not a party in this channel") + n1 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none() + if n1: + return n1 + n2 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none() + return n2 + +class Policy(Base): + __tablename__ = 'policy' + start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) + short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True) + cltv_expiry_delta = Column(Integer, nullable=False) + htlc_minimum_msat = Column(Integer, nullable=False) + htlc_maximum_msat = Column(Integer) + fee_base_msat = Column(Integer, nullable=False) + fee_proportional_millionths = Column(Integer, nullable=False) + channel_flags = Column(Integer, nullable=False) + timestamp = Column(DateTime, nullable=False) + + @staticmethod + def from_msg(channel_update_payload, start_node, short_channel_id): cltv_expiry_delta = channel_update_payload['cltv_expiry_delta'] htlc_minimum_msat = channel_update_payload['htlc_minimum_msat'] fee_base_msat = channel_update_payload['fee_base_msat'] @@ -155,61 +166,52 @@ class ChannelInfoDirectedPolicy: timestamp = channel_update_payload['timestamp'] htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional - self.cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big") - self.htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big") - self.htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None - self.fee_base_msat = int.from_bytes(fee_base_msat, "big") - self.fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big") - self.channel_flags = int.from_bytes(channel_flags, "big") - self.timestamp = int.from_bytes(timestamp, "big") - - self.disabled = self.channel_flags & self.FLAG_DISABLE - - def to_json(self) -> dict: - d = {} - d['cltv_expiry_delta'] = self.cltv_expiry_delta - d['htlc_minimum_msat'] = self.htlc_minimum_msat - d['fee_base_msat'] = self.fee_base_msat - d['fee_proportional_millionths'] = self.fee_proportional_millionths - d['channel_flags'] = self.channel_flags - d['timestamp'] = self.timestamp - if self.htlc_maximum_msat: - d['htlc_maximum_msat'] = self.htlc_maximum_msat - return d - - @classmethod - def from_json(cls, d: dict): - if d is None: return None - d2 = {} - d2['cltv_expiry_delta'] = d['cltv_expiry_delta'].to_bytes(2, "big") - d2['htlc_minimum_msat'] = d['htlc_minimum_msat'].to_bytes(8, "big") - d2['htlc_maximum_msat'] = d['htlc_maximum_msat'].to_bytes(8, "big") if d.get('htlc_maximum_msat') else None - d2['fee_base_msat'] = d['fee_base_msat'].to_bytes(4, "big") - d2['fee_proportional_millionths'] = d['fee_proportional_millionths'].to_bytes(4, "big") - d2['channel_flags'] = d['channel_flags'].to_bytes(1, "big") - d2['timestamp'] = d['timestamp'].to_bytes(4, "big") - return ChannelInfoDirectedPolicy(d2) - - -class NodeInfo(PrintError): - - def __init__(self, node_announcement_payload, addresses_already_parsed=False): - self.pubkey = node_announcement_payload['node_id'] - self.features_len = node_announcement_payload['flen'] - self.features = node_announcement_payload['features'] - enabled_features = list_enabled_bits(int.from_bytes(self.features, "big")) - for fbit in enabled_features: - if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0: - raise UnknownEvenFeatureBits() + cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big") + htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big") + htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None + fee_base_msat = int.from_bytes(fee_base_msat, "big") + fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big") + channel_flags = int.from_bytes(channel_flags, "big") + timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big")) + + return Policy(start_node=start_node, + short_channel_id=short_channel_id, + cltv_expiry_delta=cltv_expiry_delta, + htlc_minimum_msat=htlc_minimum_msat, + fee_base_msat=fee_base_msat, + fee_proportional_millionths=fee_proportional_millionths, + channel_flags=channel_flags, + timestamp=timestamp, + htlc_maximum_msat=htlc_maximum_msat) + + def is_disabled(self): + return self.channel_flags & FLAG_DISABLE + +class NodeInfoInDB(Base): + __tablename__ = 'node_info' + node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') + features = Column(Integer, nullable=False) + timestamp = Column(Integer, nullable=False) + alias = Column(String(64), nullable=False) + + def get_addresses(self): + return DBSession.query(AddressInDB).join(NodeInfoInDB).filter_by(node_id = self.node_id).all() + + @staticmethod + def from_msg(node_announcement_payload, addresses_already_parsed=False): + node_id = node_announcement_payload['node_id'].hex() + features = int.from_bytes(node_announcement_payload['features'], "big") + validate_features(features) if not addresses_already_parsed: - self.addresses = self.parse_addresses_field(node_announcement_payload['addresses']) + addresses = NodeInfoInDB.parse_addresses_field(node_announcement_payload['addresses']) else: - self.addresses = node_announcement_payload['addresses'] - self.alias = node_announcement_payload['alias'].rstrip(b'\x00') - self.timestamp = int.from_bytes(node_announcement_payload['timestamp'], "big") + addresses = node_announcement_payload['addresses'] + alias = node_announcement_payload['alias'].rstrip(b'\x00').hex() + timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big")) + return NodeInfoInDB(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [AddressInDB(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses] - @classmethod - def parse_addresses_field(cls, addresses_field): + @staticmethod + def parse_addresses_field(addresses_field): buf = addresses_field def read(n): nonlocal buf @@ -248,243 +250,233 @@ class NodeInfo(PrintError): break return addresses - def to_json(self) -> dict: - d = {} - d['node_id'] = bh2u(self.pubkey) - d['flen'] = bh2u(self.features_len) - d['features'] = bh2u(self.features) - d['addresses'] = self.addresses - d['alias'] = bh2u(self.alias) - d['timestamp'] = self.timestamp - return d - - @classmethod - def from_json(cls, d: dict): - if d is None: return None - d2 = {} - d2['node_id'] = bfh(d['node_id']) - d2['flen'] = bfh(d['flen']) - d2['features'] = bfh(d['features']) - d2['addresses'] = d['addresses'] - d2['alias'] = bfh(d['alias']) - d2['timestamp'] = d['timestamp'].to_bytes(4, "big") - return NodeInfo(d2, addresses_already_parsed=True) - +class AddressInDB(Base): + __tablename__ = 'address' + node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) + host = Column(String(256), primary_key=True) + port = Column(Integer, primary_key=True) + last_connected_date = Column(DateTime(), nullable=False) -class ChannelDB(JsonDB): +class ChannelDB: NUM_MAX_RECENT_PEERS = 20 def __init__(self, network: 'Network'): + global engine self.network = network - path = os.path.join(get_headers_dir(network.config), 'channel_db') - JsonDB.__init__(self, path) + self.num_nodes = 0 + self.num_channels = 0 + + self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3') + engine = create_engine('sqlite:///' + self.path)#, echo=True) + DBSession.remove() + DBSession.configure(bind=engine, autoflush=False) + + Base.metadata.drop_all(engine) + Base.metadata.create_all(engine) self.lock = threading.RLock() - self._id_to_channel_info = {} # type: Dict[bytes, ChannelInfo] - self._channels_for_node = defaultdict(set) # node -> set(short_channel_id) - self.nodes = {} # node_id -> NodeInfo - self._recent_peers = [] - self._last_good_address = {} # node_id -> LNPeerAddr # (intentionally not persisted) - self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], ChannelInfoDirectedPolicy] + self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self.ca_verifier = LNChannelVerifier(network, self) - self.load_data() - - def load_data(self): - if os.path.exists(self.path): - with open(self.path, "r", encoding='utf-8') as f: - raw = f.read() - self.data = json.loads(raw) - # channels - channel_infos = self.get('channel_infos', {}) - for short_channel_id, channel_info_d in channel_infos.items(): - channel_info = ChannelInfo.from_json(channel_info_d) - short_channel_id = bfh(short_channel_id) - self.add_verified_channel_info(short_channel_id, channel_info) - # nodes - node_infos = self.get('node_infos', {}) - for node_id, node_info_d in node_infos.items(): - node_info = NodeInfo.from_json(node_info_d) - node_id = bfh(node_id) - self.nodes[node_id] = node_info - # recent peers - recent_peers = self.get('recent_peers', {}) - for host, port, pubkey in recent_peers: - peer = LNPeerAddr(str(host), int(port), bfh(pubkey)) - self._recent_peers.append(peer) - # last good address - last_good_addr = self.get('last_good_address', {}) - for node_id, host_and_port in last_good_addr.items(): - host, port = host_and_port - self._last_good_address[bfh(node_id)] = LNPeerAddr(str(host), int(port), bfh(node_id)) - - def save_data(self): - with self.lock: - # channels - channel_infos = {} - for short_channel_id, channel_info in self._id_to_channel_info.items(): - channel_infos[bh2u(short_channel_id)] = channel_info - self.put('channel_infos', channel_infos) - # nodes - node_infos = {} - for node_id, node_info in self.nodes.items(): - node_infos[bh2u(node_id)] = node_info - self.put('node_infos', node_infos) - # recent peers - recent_peers = [] - for peer in self._recent_peers: - recent_peers.append( - [str(peer.host), int(peer.port), bh2u(peer.pubkey)]) - self.put('recent_peers', recent_peers) - # last good address - last_good_addr = {} - for node_id, peer in self._last_good_address.items(): - last_good_addr[bh2u(node_id)] = [str(peer.host), int(peer.port)] - self.put('last_good_address', last_good_addr) - self.write() - - def __len__(self): - # number of channels - return len(self._id_to_channel_info) - - def capacity(self): - # capacity of the network - return sum(c.capacity_sat for c in self._id_to_channel_info.values() if c.capacity_sat is not None) - - def get_channel_info(self, channel_id: bytes) -> Optional[ChannelInfo]: - return self._id_to_channel_info.get(channel_id, None) + def update_counts(self): + self.num_channels = DBSession.query(ChannelInfoInDB).count() + self.num_nodes = DBSession.query(NodeInfoInDB).count() + + def add_recent_peer(self, peer : LNPeerAddr): + addr = DBSession.query(AddressInDB).filter_by(node_id = peer.pubkey.hex()).one_or_none() + if addr is None: + addr = AddressInDB(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) + else: + addr.last_connected_date = datetime.datetime.now() + DBSession.add(addr) + DBSession.commit() + + def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): + unshuffled = DBSession \ + .query(NodeInfoInDB) \ + .filter(not_(NodeInfoInDB.node_id.in_(x.hex() for x in node_ids_bytes))) \ + .limit(200) \ + .all() + return random.sample(unshuffled, len(unshuffled)) + + def nodes_get(self, node_id): + return self.network.run_from_another_thread(self._nodes_get(node_id)) + + async def _nodes_get(self, node_id): + return DBSession \ + .query(NodeInfoInDB) \ + .filter_by(node_id = node_id.hex()) \ + .one_or_none() + + def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: + adr_db = DBSession \ + .query(AddressInDB) \ + .filter_by(node_id = node_id.hex()) \ + .order_by(AddressInDB.last_connected_date.desc()) \ + .one_or_none() + if not adr_db: + return None + return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id)) + + def get_recent_peers(self): + return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \ + .query(AddressInDB) \ + .select_from(NodeInfoInDB) \ + .order_by(AddressInDB.last_connected_date.desc()) \ + .limit(self.NUM_MAX_RECENT_PEERS)] + + def get_channel_info(self, channel_id: bytes): + return self.chan_query_for_id(channel_id).one_or_none() def get_channels_for_node(self, node_id): """Returns the set of channels that have node_id as one of the endpoints.""" - return self._channels_for_node[node_id] + condition = or_( + ChannelInfoInDB.node1_id == node_id.hex(), + ChannelInfoInDB.node2_id == node_id.hex()) + rows = DBSession.query(ChannelInfoInDB).filter(condition).all() + return [bytes.fromhex(x.short_channel_id) for x in rows] + + def add_verified_channel_info(self, short_id, capacity): + # called from lnchannelverifier + channel_info = self.get_channel_info(short_id) + channel_info.trusted = True + channel_info.capacity = capacity + DBSession.commit() - def add_verified_channel_info(self, short_channel_id: bytes, channel_info: ChannelInfo): - with self.lock: - self._id_to_channel_info[short_channel_id] = channel_info - self._channels_for_node[channel_info.node_id_1].add(short_channel_id) - self._channels_for_node[channel_info.node_id_2].add(short_channel_id) + @profiler + def on_channel_announcement(self, msg_payloads, trusted=False): + if type(msg_payloads) is dict: + msg_payloads = [msg_payloads] + for msg in msg_payloads: + short_channel_id = msg['short_channel_id'] + if DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = bh2u(short_channel_id)).count(): + continue + if constants.net.rev_genesis_bytes() != msg['chain_hash']: + #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) + continue + try: + channel_info = ChannelInfoInDB.from_msg(msg) + except UnknownEvenFeatureBits: + continue + channel_info.trusted = trusted + DBSession.add(channel_info) + if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) + DBSession.commit() self.network.trigger_callback('ln_status') + self.update_counts() - def get_recent_peers(self): - with self.lock: - return list(self._recent_peers) - - def add_recent_peer(self, peer: LNPeerAddr): - with self.lock: - # list is ordered - if peer in self._recent_peers: - self._recent_peers.remove(peer) - self._recent_peers.insert(0, peer) - self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS] - self._last_good_address[peer.pubkey] = peer - - def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]: - return self._last_good_address.get(node_id, None) - - def on_channel_announcement(self, msg_payload, trusted=False): - short_channel_id = msg_payload['short_channel_id'] - if short_channel_id in self._id_to_channel_info: - return - if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: - #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) - return - try: - channel_info = ChannelInfo(msg_payload) - except UnknownEvenFeatureBits: - return - if trusted: - self.add_verified_channel_info(short_channel_id, channel_info) - else: - self.ca_verifier.add_new_channel_info(channel_info) + @profiler + def on_channel_update(self, msg_payloads, trusted=False): + if type(msg_payloads) is dict: + msg_payloads = [msg_payloads] + short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] + channel_infos_list = DBSession.query(ChannelInfoInDB).filter(ChannelInfoInDB.short_channel_id.in_(short_channel_ids)).all() + channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} + for msg_payload in msg_payloads: + short_channel_id = msg_payload['short_channel_id'] + if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: + continue + channel_info = channel_infos.get(short_channel_id) + channel_info.on_channel_update(msg_payload, trusted=trusted) + DBSession.commit() - def on_channel_update(self, msg_payload, trusted=False): - short_channel_id = msg_payload['short_channel_id'] - if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: - return - # try finding channel in pending db - channel_info = self.ca_verifier.get_pending_channel_info(short_channel_id) - if channel_info is None: - # try finding channel in verified db - channel_info = self._id_to_channel_info.get(short_channel_id, None) - if channel_info is None: - self.print_error("could not find", short_channel_id) - raise NotFoundChanAnnouncementForUpdate() - channel_info.on_channel_update(msg_payload, trusted=trusted) - - def on_node_announcement(self, msg_payload): - pubkey = msg_payload['node_id'] - signature = msg_payload['signature'] - h = sha256d(msg_payload['raw'][66:]) - if not ecc.verify_signature(pubkey, signature, h): - return - old_node_info = self.nodes.get(pubkey, None) - try: - new_node_info = NodeInfo(msg_payload) - except UnknownEvenFeatureBits: - return - # TODO if this message is for a new node, and if we have no associated - # channels for this node, we should ignore the message and return here, - # to mitigate DOS. but race condition: the channels we have for this - # node, might be under verification in self.ca_verifier, what then? - if old_node_info and old_node_info.timestamp >= new_node_info.timestamp: - return # ignore - self.nodes[pubkey] = new_node_info + @profiler + def on_node_announcement(self, msg_payloads): + if type(msg_payloads) is dict: + msg_payloads = [msg_payloads] + addresses = DBSession.query(AddressInDB).all() + have_addr = {} + for addr in addresses: + have_addr[(addr.node_id, addr.host, addr.port)] = addr + + nodes = DBSession.query(NodeInfoInDB).all() + timestamps = {} + for node in nodes: + no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] + timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S") + old_addr = None + for msg_payload in msg_payloads: + pubkey = msg_payload['node_id'] + signature = msg_payload['signature'] + h = sha256d(msg_payload['raw'][66:]) + if not ecc.verify_signature(pubkey, signature, h): + continue + try: + new_node_info, addresses = NodeInfoInDB.from_msg(msg_payload) + except UnknownEvenFeatureBits: + continue + if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: + continue # ignore + DBSession.add(new_node_info) + for new_addr in addresses: + key = (new_addr.node_id, new_addr.host, new_addr.port) + old_addr = have_addr.get(key) + if old_addr: + # since old_addr is embedded in have_addr, + # it will still live when commmit is called + old_addr.last_connected_date = new_addr.last_connected_date + del new_addr + else: + DBSession.add(new_addr) + have_addr[key] = new_addr + # TODO if this message is for a new node, and if we have no associated + # channels for this node, we should ignore the message and return here, + # to mitigate DOS. but race condition: the channels we have for this + # node, might be under verification in self.ca_verifier, what then? + del nodes, addresses + if old_addr: + del old_addr + DBSession.commit() + self.network.trigger_callback('ln_status') + self.update_counts() def get_routing_policy_for_channel(self, start_node_id: bytes, - short_channel_id: bytes) -> Optional[ChannelInfoDirectedPolicy]: + short_channel_id: bytes) -> Optional[bytes]: if not start_node_id or not short_channel_id: return None channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: return channel_info.get_policy_for_node(start_node_id) - return self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) + msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) + if not msg: return None + return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): if not verify_sig_for_channel_update(msg_payload, start_node_id): return # ignore short_channel_id = msg_payload['short_channel_id'] - policy = ChannelInfoDirectedPolicy(msg_payload) - self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = policy + self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload def remove_channel(self, short_channel_id): - try: - channel_info = self._id_to_channel_info[short_channel_id] - except KeyError: - self.print_error(f'remove_channel: cannot find channel {bh2u(short_channel_id)}') - return - self._id_to_channel_info.pop(short_channel_id, None) - for node in (channel_info.node_id_1, channel_info.node_id_2): - try: - self._channels_for_node[node].remove(short_channel_id) - except KeyError: - pass + self.chan_query_for_id(short_channel_id).delete('evaluate') + DBSession.commit() + + def chan_query_for_id(self, short_channel_id) -> Query: + return DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = short_channel_id.hex()) def print_graph(self, full_ids=False): # used for debugging. # FIXME there is a race here - iterables could change size from another thread def other_node_id(node_id, channel_id): - channel_info = self._id_to_channel_info[channel_id] - if node_id == channel_info.node_id_1: - other = channel_info.node_id_2 + channel_info = self.get_channel_info(channel_id) + if node_id == channel_info.node1_id: + other = channel_info.node2_id else: - other = channel_info.node_id_1 + other = channel_info.node1_id return other if full_ids else other[-4:] - self.print_msg('node: {(channel, other_node), ...}') - for node_id, short_channel_ids in list(self._channels_for_node.items()): - short_channel_ids = {(bh2u(cid), bh2u(other_node_id(node_id, cid))) - for cid in short_channel_ids} - node_id = bh2u(node_id) if full_ids else bh2u(node_id[-4:]) - self.print_msg('{}: {}'.format(node_id, short_channel_ids)) - - self.print_msg('channel: node1, node2, direction') - for short_channel_id, channel_info in list(self._id_to_channel_info.items()): - node1 = channel_info.node_id_1 - node2 = channel_info.node_id_2 + self.print_msg('nodes') + for node in DBSession.query(NodeInfoInDB).all(): + self.print_msg(node) + + self.print_msg('channels') + for channel_info in DBSession.query(ChannelInfoInDB).all(): + node1 = channel_info.node1_id + node2 = channel_info.node2_id direction1 = channel_info.get_policy_for_node(node1) is not None direction2 = channel_info.get_policy_for_node(node2) is not None if direction1 and direction2: @@ -514,8 +506,10 @@ class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), + (amount_msat * self.fee_proportional_millionths // 1_000_000) @classmethod - def from_channel_policy(cls, channel_policy: ChannelInfoDirectedPolicy, + def from_channel_policy(cls, channel_policy: 'Policy', short_channel_id: bytes, end_node: bytes) -> 'RouteEdge': + assert type(short_channel_id) is bytes + assert type(end_node) is bytes return RouteEdge(end_node, short_channel_id, channel_policy.fee_base_msat, @@ -582,7 +576,7 @@ class LNPathFinder(PrintError): channel_policy = channel_info.get_policy_for_node(start_node) if channel_policy is None: return float('inf'), 0 - if channel_policy.disabled: return float('inf'), 0 + if channel_policy.is_disabled(): return float('inf'), 0 route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node) if payment_amt_msat < channel_policy.htlc_minimum_msat: return float('inf'), 0 # payment amount too little @@ -611,6 +605,8 @@ class LNPathFinder(PrintError): To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; i.e. an element reads as, "to get to node_id, travel through short_channel_id" """ + assert type(nodeA) is bytes + assert type(nodeB) is bytes assert type(invoice_amount_msat) is int if my_channels is None: my_channels = [] my_channels = {chan.short_channel_id: chan for chan in my_channels} @@ -657,9 +653,10 @@ class LNPathFinder(PrintError): # so there are duplicates in the queue, that we discard now: continue for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode): + assert type(edge_channel_id) is bytes if edge_channel_id in self.blacklist: continue channel_info = self.channel_db.get_channel_info(edge_channel_id) - edge_startnode = channel_info.node_id_2 if channel_info.node_id_1 == edge_endnode else channel_info.node_id_1 + edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id) inspect_edge() else: return None # no path found @@ -682,7 +679,7 @@ class LNPathFinder(PrintError): for node_id, short_channel_id in path: channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) if channel_policy is None: - raise Exception(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}') + raise NoChannelPolicy(short_channel_id) route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id)) prev_node_id = node_id return route diff --git a/electrum/lnverifier.py b/electrum/lnverifier.py @@ -38,7 +38,7 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure from .transaction import Transaction from .interface import GracefulDisconnect from .crypto import sha256d -from .lnmsg import encode_msg +from .lnmsg import decode_msg, encode_msg if TYPE_CHECKING: from .network import Network @@ -56,7 +56,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): NetworkJobOnDefaultServer.__init__(self, network) self.channel_db = channel_db self.lock = threading.Lock() - self.unverified_channel_info = {} # short_channel_id -> channel_info + self.unverified_channel_info = {} # short_channel_id -> msg_payload # channel announcements that seem to be invalid: self.blacklist = set() # short_channel_id @@ -65,19 +65,16 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): self.started_verifying_channel = set() # short_channel_id # TODO make async; and rm self.lock completely - def add_new_channel_info(self, channel_info): - short_channel_id = channel_info.channel_id + def add_new_channel_info(self, short_channel_id_hex, msg_payload): + short_channel_id = bfh(short_channel_id_hex) if short_channel_id in self.unverified_channel_info: return if short_channel_id in self.blacklist: return - if not verify_sigs_for_channel_announcement(channel_info.msg_payload): + if not verify_sigs_for_channel_announcement(msg_payload): return with self.lock: - self.unverified_channel_info[short_channel_id] = channel_info - - def get_pending_channel_info(self, short_channel_id): - return self.unverified_channel_info.get(short_channel_id, None) + self.unverified_channel_info[short_channel_id] = msg_payload async def _start_tasks(self): async with self.group as group: @@ -151,8 +148,9 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): self.print_error(f"received tx does not match expected txid ({tx_hash} != {tx.txid()})") return # check funding output - channel_info = self.unverified_channel_info[short_channel_id] - chan_ann = channel_info.msg_payload + msg_payload = self.unverified_channel_info[short_channel_id] + msg_type, chan_ann = decode_msg(msg_payload) + assert msg_type == 'channel_announcement' redeem_script = funding_output_script_from_keys(chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']) expected_address = bitcoin.redeem_script_to_address('p2wsh', redeem_script) output_idx = invert_short_channel_id(short_channel_id)[2] @@ -167,8 +165,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): self._remove_channel_from_unverified_db(short_channel_id) return # put channel into channel DB - channel_info.set_capacity(actual_output.value) - self.channel_db.add_verified_channel_info(short_channel_id, channel_info) + self.channel_db.add_verified_channel_info(short_channel_id, actual_output.value) self._remove_channel_from_unverified_db(short_channel_id) def _remove_channel_from_unverified_db(self, short_channel_id: bytes): @@ -183,8 +180,9 @@ class LNChannelVerifier(NetworkJobOnDefaultServer): self.unverified_channel_info.pop(short_channel_id, None) -def verify_sigs_for_channel_announcement(chan_ann: dict) -> bool: - msg_bytes = encode_msg('channel_announcement', **chan_ann) +def verify_sigs_for_channel_announcement(msg_bytes: bytes) -> bool: + msg_type, chan_ann = decode_msg(msg_bytes) + assert msg_type == 'channel_announcement' pre_hash = msg_bytes[2+256:] h = sha256d(pre_hash) pubkeys = [chan_ann['node_id_1'], chan_ann['node_id_2'], chan_ann['bitcoin_key_1'], chan_ann['bitcoin_key_2']] diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -56,7 +56,8 @@ GRAPH_DOWNLOAD_SECONDS = 600 FALLBACK_NODE_LIST_TESTNET = ( LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')), - LNPeerAddr('180.181.208.42', 9735, bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')), + LNPeerAddr('148.251.87.112', 9735, bfh('021a8bd8d8f1f2e208992a2eb755cdc74d44e66b6a0c924d3a3cce949123b9ce40')), # janus test server + LNPeerAddr('122.199.61.90', 9735, bfh('038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9')), # popular node https://1ml.com/testnet/node/038863cf8ab91046230f561cd5b386cbff8309fa02e3f0c3ed161a3aeb64a643b9 ) FALLBACK_NODE_LIST_MAINNET = ( LNPeerAddr('104.198.32.198', 9735, bfh('02f6725f9c1c40333b67faea92fd211c183050f28df32cac3f9d69685fe9665432')), # Blockstream @@ -420,26 +421,33 @@ class LNWorker(PrintError): @staticmethod def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]: + assert len(addr_list) >= 1 # choose first one that is an IP - for host, port in addr_list: + for addr_in_db in addr_list: + host = addr_in_db.host + port = addr_in_db.port if is_ip_address(host): return host, port # otherwise choose one at random # TODO maybe filter out onion if not on tor? - return random.choice(addr_list) + choice = random.choice(addr_list) + return choice.host, choice.port def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=5): node_id, rest = extract_nodeid(connect_contents) peer = self.peers.get(node_id) if not peer: - all_nodes = self.network.channel_db.nodes - node_info = all_nodes.get(node_id, None) + nodes_get = self.network.channel_db.nodes_get + node_info = nodes_get(node_id) if rest is not None: host, port = split_host_port(rest) - elif node_info and len(node_info.addresses) > 0: - host, port = self.choose_preferred_address(node_info.addresses) else: - raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id)) + if not node_info: + raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id)) + addrs = node_info.get_addresses() + if len(addrs) == 0: + raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id)) + host, port = self.choose_preferred_address(addrs) try: socket.getaddrinfo(host, int(port)) except socket.gaierror: @@ -457,7 +465,7 @@ class LNWorker(PrintError): This is not merged with _pay so that we can run the test with one thread only. """ - addr, peer, coro = self._pay(invoice, amount_sat) + addr, peer, coro = self.network.run_from_another_thread(self._pay(invoice, amount_sat)) fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) return addr, peer, fut @@ -467,9 +475,9 @@ class LNWorker(PrintError): if chan.short_channel_id == short_channel_id: return chan - def _pay(self, invoice, amount_sat=None): + async def _pay(self, invoice, amount_sat=None, same_thread=False): addr = self._check_invoice(invoice, amount_sat) - route = self._create_route_from_invoice(decoded_invoice=addr) + route = await self._create_route_from_invoice(decoded_invoice=addr) peer = self.peers[route[0].node_id] if not self.get_channel_by_short_id(route[0].short_channel_id): assert False, 'Found route with short channel ID we don\'t have: ' + repr(route[0].short_channel_id) @@ -498,7 +506,7 @@ class LNWorker(PrintError): f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}")) return addr - def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]: + async def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]: amount_msat = int(decoded_invoice.amount * COIN * 1000) invoice_pubkey = decoded_invoice.pubkey.serialize() # use 'r' field from invoice @@ -699,20 +707,14 @@ class LNWorker(PrintError): if peer in self._last_tried_peer: continue return [peer] # try random peer from graph - all_nodes = self.channel_db.nodes - if all_nodes: - #self.print_error('trying to get ln peers from channel db') - node_ids = list(all_nodes) - max_tries = min(200, len(all_nodes)) - for i in range(max_tries): - node_id = random.choice(node_ids) - node = all_nodes.get(node_id) - if node is None: continue - addresses = node.addresses - if not addresses: continue - host, port = self.choose_preferred_address(addresses) - peer = LNPeerAddr(host, port, node_id) - if peer.pubkey in self.peers: continue + unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys()) + if unconnected_nodes: + for node in unconnected_nodes: + addrs = node.get_addresses() + if not addrs: + continue + host, port = self.choose_preferred_address(addrs) + peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id)) if peer in self._last_tried_peer: continue self.print_error('taking random ln peer from our channel db') return [peer] @@ -772,11 +774,12 @@ class LNWorker(PrintError): await self.add_peer(peer.host, peer.port, peer.pubkey) return # try random address for node_id - node_info = self.channel_db.nodes.get(chan.node_id, None) + node_info = await self.channel_db._nodes_get(chan.node_id) if not node_info: return - addresses = node_info.addresses + addresses = node_info.get_addresses() if not addresses: return - host, port = random.choice(addresses) + adr_obj = random.choice(addresses) + host, port = adr_obj.host, adr_obj.port peer = LNPeerAddr(host, port, chan.node_id) last_tried = self._last_tried_peer.get(peer, 0) if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now: diff --git a/electrum/network.py b/electrum/network.py @@ -1181,7 +1181,6 @@ class Network(Logger): def stop(self): assert self._loop_thread != threading.current_thread(), 'must not be called from network thread' fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop) - self.channel_db.save_data() try: fut.result(timeout=2) except (asyncio.TimeoutError, asyncio.CancelledError): pass diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -226,7 +226,7 @@ class TestPeer(unittest.TestCase): fut = self.prepare_ln_message_future(w2) async def pay(): - addr, peer, coro = LNWorker._pay(w1, pay_req) + addr, peer, coro = await LNWorker._pay(w1, pay_req, same_thread=True) await coro print("HTLC ADDED") self.assertEqual(await fut, 'Payment received') @@ -240,14 +240,14 @@ class TestPeer(unittest.TestCase): pay_req = self.prepare_invoice(w2) addr = w1._check_invoice(pay_req) - route = w1._create_route_from_invoice(decoded_invoice=addr) + route = run(w1._create_route_from_invoice(decoded_invoice=addr)) run(w1.force_close_channel(self.alice_channel.channel_id)) # check if a tx (commitment transaction) was broadcasted: assert q1.qsize() == 1 with self.assertRaises(PaymentFailure) as e: - w1._create_route_from_invoice(decoded_invoice=addr) + run(w1._create_route_from_invoice(decoded_invoice=addr)) self.assertEqual(str(e.exception), 'No path found') peer = w1.peers[route[0].node_id] @@ -257,4 +257,4 @@ class TestPeer(unittest.TestCase): run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._main_loop(), p2._main_loop())) def run(coro): - asyncio.get_event_loop().run_until_complete(coro) + return asyncio.get_event_loop().run_until_complete(coro) diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py @@ -45,15 +45,17 @@ class Test_LNRouter(TestCaseForTestnet): asyncio_loop = asyncio.get_event_loop() trigger_callback = lambda *args: None register_callback = lambda *args: None - async def add_job(self, *args): return None + interface = None fake_network.channel_db = lnrouter.ChannelDB(fake_network()) cdb = fake_network.channel_db path_finder = lnrouter.LNPathFinder(cdb) + self.assertEqual(cdb.num_channels, 0) cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc', 'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc', 'short_channel_id': bfh('0000000000000001'), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'len': b'\x00\x00', 'features': b''}, trusted=True) + self.assertEqual(cdb.num_channels, 1) cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 'short_channel_id': bfh('0000000000000002'), @@ -92,12 +94,16 @@ class Test_LNRouter(TestCaseForTestnet): cdb.on_channel_update({'short_channel_id': bfh('0000000000000005'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True) cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x00', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(99999999), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True) cdb.on_channel_update({'short_channel_id': bfh('0000000000000006'), 'message_flags': b'\x00', 'channel_flags': b'\x01', 'cltv_expiry_delta': o(10), 'htlc_minimum_msat': o(250), 'fee_base_msat': o(100), 'fee_proportional_millionths': o(150), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'timestamp': b'\x00\x00\x00\x00'}, trusted=True) - self.assertNotEqual(None, path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)) + path = path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000) self.assertEqual([(b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', b'\x00\x00\x00\x00\x00\x00\x00\x03'), (b'\x02cccccccccccccccccccccccccccccccc', b'\x00\x00\x00\x00\x00\x00\x00\x01'), (b'\x02dddddddddddddddddddddddddddddddd', b'\x00\x00\x00\x00\x00\x00\x00\x04'), - (b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x05')], - path_finder.find_path_for_payment(b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 100000)) + (b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', b'\x00\x00\x00\x00\x00\x00\x00\x05') + ], path) + start_node = b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb' + route = path_finder.create_route_from_path(path, start_node) + self.assertEqual(route[0].node_id, start_node) + self.assertEqual(route[0].short_channel_id, bfh('0000000000000003'))