electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit f2d58d0e3f97975d4dcfcbcacc96d7e206190ef6
parent 180f6d34bec2f3e443488f922591d51c11cab1f6
Author: ThomasV <thomasv@electrum.org>
Date:   Tue, 18 Jun 2019 13:49:31 +0200

optimize channel_db:
 - use python objects mirrored by sql database
 - write sql to file asynchronously
 - the sql decorator is awaited in sweepstore, not in channel_db

Diffstat:
Melectrum/channel_db.py | 534+++++++++++++++++++++++++++++++++++++------------------------------------------
Melectrum/gui/qt/lightning_dialog.py | 9+++++----
Melectrum/lnaddr.py | 11+++++++++--
Melectrum/lnchannel.py | 12+++++-------
Melectrum/lnpeer.py | 87+++++++++++++++++++++++++++++++++++--------------------------------------------
Melectrum/lnrouter.py | 19++++++++++---------
Melectrum/lnwatcher.py | 49++++++++++++++++++++++++++++---------------------
Melectrum/lnworker.py | 131++++++++++++++++++++++++++++++++++++++++---------------------------------------
Melectrum/sql_db.py | 22+++++++++++++++++-----
Melectrum/tests/test_lnpeer.py | 3++-
Melectrum/tests/test_lnrouter.py | 12++++++------
11 files changed, 435 insertions(+), 454 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -51,6 +51,7 @@ from .crypto import sha256d from . import ecc from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH, NotFoundChanAnnouncementForUpdate) +from .lnverifier import verify_sig_for_channel_update from .lnmsg import encode_msg if TYPE_CHECKING: @@ -70,85 +71,83 @@ Base = declarative_base() FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 -class ChannelInfo(Base): - __tablename__ = 'channel_info' - short_channel_id = Column(String(64), primary_key=True) - node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) - node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) - capacity_sat = Column(Integer) - msg_payload_hex = Column(String(1024), nullable=False) - trusted = Column(Boolean, nullable=False) +class ChannelInfo(NamedTuple): + short_channel_id: bytes + node1_id: bytes + node2_id: bytes + capacity_sat: int + msg_payload: bytes + trusted: bool @staticmethod def from_msg(payload): features = int.from_bytes(payload['features'], 'big') validate_features(features) - channel_id = payload['short_channel_id'].hex() - node_id_1 = payload['node_id_1'].hex() - node_id_2 = payload['node_id_2'].hex() + channel_id = payload['short_channel_id'] + node_id_1 = payload['node_id_1'] + node_id_2 = payload['node_id_2'] assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] - msg_payload_hex = encode_msg('channel_announcement', **payload).hex() + msg_payload = encode_msg('channel_announcement', **payload) capacity_sat = None - return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, - node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, - trusted = False) - - @property - def msg_payload(self): - return bytes.fromhex(self.msg_payload_hex) - - -class Policy(Base): - __tablename__ = 'policy' - start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) - short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True) - cltv_expiry_delta = Column(Integer, nullable=False) - htlc_minimum_msat = Column(Integer, nullable=False) - htlc_maximum_msat = Column(Integer) - fee_base_msat = Column(Integer, nullable=False) - fee_proportional_millionths = Column(Integer, nullable=False) - channel_flags = Column(Integer, nullable=False) - timestamp = Column(Integer, nullable=False) + return ChannelInfo( + short_channel_id = channel_id, + node1_id = node_id_1, + node2_id = node_id_2, + capacity_sat = capacity_sat, + msg_payload = msg_payload, + trusted = False) + + + +class Policy(NamedTuple): + key: bytes + cltv_expiry_delta: int + htlc_minimum_msat: int + htlc_maximum_msat: int + fee_base_msat: int + fee_proportional_millionths: int + channel_flags: int + timestamp: int @staticmethod def from_msg(payload): - cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big") - htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big") - htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None - fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big") - fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big") - channel_flags = int.from_bytes(payload['channel_flags'], "big") - timestamp = int.from_bytes(payload['timestamp'], "big") - start_node = payload['start_node'].hex() - short_channel_id = payload['short_channel_id'].hex() - - return Policy(start_node=start_node, - short_channel_id=short_channel_id, - cltv_expiry_delta=cltv_expiry_delta, - htlc_minimum_msat=htlc_minimum_msat, - fee_base_msat=fee_base_msat, - fee_proportional_millionths=fee_proportional_millionths, - channel_flags=channel_flags, - timestamp=timestamp, - htlc_maximum_msat=htlc_maximum_msat) + return Policy( + key = payload['short_channel_id'] + payload['start_node'], + cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"), + htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big"), + htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None, + fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big"), + fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big"), + channel_flags = int.from_bytes(payload['channel_flags'], "big"), + timestamp = int.from_bytes(payload['timestamp'], "big") + ) def is_disabled(self): return self.channel_flags & FLAG_DISABLE -class NodeInfo(Base): - __tablename__ = 'node_info' - node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') - features = Column(Integer, nullable=False) - timestamp = Column(Integer, nullable=False) - alias = Column(String(64), nullable=False) + @property + def short_channel_id(self): + return self.key[0:8] + + @property + def start_node(self): + return self.key[8:] + + + +class NodeInfo(NamedTuple): + node_id: bytes + features: int + timestamp: int + alias: str @staticmethod def from_msg(payload): - node_id = payload['node_id'].hex() + node_id = payload['node_id'] features = int.from_bytes(payload['features'], "big") validate_features(features) addresses = NodeInfo.parse_addresses_field(payload['addresses']) - alias = payload['alias'].rstrip(b'\x00').hex() + alias = payload['alias'].rstrip(b'\x00') timestamp = int.from_bytes(payload['timestamp'], "big") return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [ Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses] @@ -193,110 +192,136 @@ class NodeInfo(Base): break return addresses -class Address(Base): + +class Address(NamedTuple): + node_id: bytes + host: str + port: int + last_connected_date: int + + +class ChannelInfoBase(Base): + __tablename__ = 'channel_info' + short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') + node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) + node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False) + capacity_sat = Column(Integer) + msg_payload = Column(String(1024), nullable=False) + trusted = Column(Boolean, nullable=False) + + def to_nametuple(self): + return ChannelInfo( + short_channel_id=self.short_channel_id, + node1_id=self.node1_id, + node2_id=self.node2_id, + capacity_sat=self.capacity_sat, + msg_payload=self.msg_payload, + trusted=self.trusted + ) + +class PolicyBase(Base): + __tablename__ = 'policy' + key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') + cltv_expiry_delta = Column(Integer, nullable=False) + htlc_minimum_msat = Column(Integer, nullable=False) + htlc_maximum_msat = Column(Integer) + fee_base_msat = Column(Integer, nullable=False) + fee_proportional_millionths = Column(Integer, nullable=False) + channel_flags = Column(Integer, nullable=False) + timestamp = Column(Integer, nullable=False) + + def to_nametuple(self): + return Policy( + key=self.key, + cltv_expiry_delta=self.cltv_expiry_delta, + htlc_minimum_msat=self.htlc_minimum_msat, + htlc_maximum_msat=self.htlc_maximum_msat, + fee_base_msat= self.fee_base_msat, + fee_proportional_millionths = self.fee_proportional_millionths, + channel_flags=self.channel_flags, + timestamp=self.timestamp + ) + +class NodeInfoBase(Base): + __tablename__ = 'node_info' + node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') + features = Column(Integer, nullable=False) + timestamp = Column(Integer, nullable=False) + alias = Column(String(64), nullable=False) + +class AddressBase(Base): __tablename__ = 'address' - node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) + node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE') host = Column(String(256), primary_key=True) port = Column(Integer, primary_key=True) last_connected_date = Column(Integer(), nullable=True) - class ChannelDB(SqlDB): NUM_MAX_RECENT_PEERS = 20 def __init__(self, network: 'Network'): path = os.path.join(get_headers_dir(network.config), 'channel_db') - super().__init__(network, path, Base) + super().__init__(network, path, Base, commit_interval=100) self.num_nodes = 0 self.num_channels = 0 self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self.ca_verifier = LNChannelVerifier(network, self) - self.update_counts() + # initialized in load_data + self._channels = {} + self._policies = {} + self._nodes = {} + self._addresses = defaultdict(set) + self._channels_for_node = defaultdict(set) - @sql def update_counts(self): - self._update_counts() + self.num_channels = len(self._channels) + self.num_policies = len(self._policies) + self.num_nodes = len(self._nodes) - def _update_counts(self): - self.num_channels = self.DBSession.query(ChannelInfo).count() - self.num_policies = self.DBSession.query(Policy).count() - self.num_nodes = self.DBSession.query(NodeInfo).count() + def get_channel_ids(self): + return set(self._channels.keys()) - @sql - def known_ids(self): - known = self.DBSession.query(ChannelInfo.short_channel_id).all() - return set(bfh(r.short_channel_id) for r in known) - - @sql def add_recent_peer(self, peer: LNPeerAddr): now = int(time.time()) - node_id = peer.pubkey.hex() - addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() + node_id = peer.pubkey + self._addresses[node_id].add((peer.host, peer.port, now)) + self.save_address(node_id, peer, now) + + @sql + def save_address(self, node_id, peer, now): + addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() if addr: addr.last_connected_date = now else: - addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) + addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) self.DBSession.add(addr) - self.DBSession.commit() - - @sql - def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): - unshuffled = self.DBSession \ - .query(NodeInfo) \ - .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ - .limit(200) \ - .all() - return random.sample(unshuffled, len(unshuffled)) - @sql - def nodes_get(self, node_id): - return self.DBSession \ - .query(NodeInfo) \ - .filter_by(node_id = node_id.hex()) \ - .one_or_none() + def get_200_randomly_sorted_nodes_not_in(self, node_ids): + unshuffled = set(self._nodes.keys()) - node_ids + return random.sample(unshuffled, min(200, len(unshuffled))) - @sql def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: - r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all() + r = self._addresses.get(node_id) if not r: return None - addr = r[0] - return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id)) + addr = sorted(list(r), key=lambda x: x[2])[0] + host, port, timestamp = addr + return LNPeerAddr(host, port, node_id) - @sql def get_recent_peers(self): - r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all() - return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r] - - @sql - def missing_channel_announcements(self) -> Set[int]: - expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) - return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) + r = [self.get_last_good_address(x) for x in self._addresses.keys()] + r = r[-self.NUM_MAX_RECENT_PEERS:] + return r - @sql - def missing_channel_updates(self) -> Set[int]: - expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id))) - return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) - - @sql - def add_verified_channel_info(self, short_id, capacity): - # called from lnchannelverifier - channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none() - channel_info.trusted = True - channel_info.capacity = capacity - self.DBSession.commit() - - @sql - @profiler - def on_channel_announcement(self, msg_payloads, trusted=True): + def add_channel_announcement(self, msg_payloads, trusted=True): if type(msg_payloads) is dict: msg_payloads = [msg_payloads] - new_channels = {} + added = 0 for msg in msg_payloads: - short_channel_id = bh2u(msg['short_channel_id']) - if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count(): + short_channel_id = msg['short_channel_id'] + if short_channel_id in self._channels: continue if constants.net.rev_genesis_bytes() != msg['chain_hash']: self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash']))) @@ -306,24 +331,24 @@ class ChannelDB(SqlDB): except UnknownEvenFeatureBits: self.logger.info("unknown feature bits") continue - channel_info.trusted = trusted - new_channels[short_channel_id] = channel_info + #channel_info.trusted = trusted + added += 1 + self._channels[short_channel_id] = channel_info + self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) + self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) + self.save_channel(channel_info) if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) - for channel_info in new_channels.values(): - self.DBSession.add(channel_info) - self.DBSession.commit() - self._update_counts() - self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads))) - @sql - def get_last_timestamp(self): - return self._get_last_timestamp() + self.update_counts() + self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads))) - def _get_last_timestamp(self): - from sqlalchemy.sql import func - r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one() - return r.max_timestamp or 0 + + #def add_verified_channel_info(self, short_id, capacity): + # # called from lnchannelverifier + # channel_info = self.DBSession.query(ChannelInfoBase).filter_by(short_channel_id = short_id).one_or_none() + # channel_info.trusted = True + # channel_info.capacity = capacity def print_change(self, old_policy, new_policy): # print what changed between policies @@ -340,89 +365,74 @@ class ChannelDB(SqlDB): if old_policy.channel_flags != new_policy.channel_flags: self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}') - @sql - def get_info_for_updates(self, payloads): - short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads] - channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() - channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} - return channel_infos - - @sql - def get_policies_for_updates(self, payloads): - out = {} - for payload in payloads: - short_channel_id = payload['short_channel_id'].hex() - start_node = payload['start_node'].hex() - policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none() - if policy: - out[short_channel_id+start_node] = policy - return out - - @profiler - def filter_channel_updates(self, payloads, max_age=None): + def add_channel_updates(self, payloads, max_age=None, verify=True): orphaned = [] # no channel announcement for channel update expired = [] # update older than two weeks deprecated = [] # update older than database entry - good = {} # good updates + good = [] # good updates to_delete = [] # database entries to delete # filter orphaned and expired first known = [] now = int(time.time()) - channel_infos = self.get_info_for_updates(payloads) for payload in payloads: short_channel_id = payload['short_channel_id'] timestamp = int.from_bytes(payload['timestamp'], "big") if max_age and now - timestamp > max_age: expired.append(short_channel_id) continue - channel_info = channel_infos.get(short_channel_id) + channel_info = self._channels.get(short_channel_id) if not channel_info: orphaned.append(short_channel_id) continue flags = int.from_bytes(payload['channel_flags'], 'big') direction = flags & FLAG_DIRECTION start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id - payload['start_node'] = bfh(start_node) + payload['start_node'] = start_node known.append(payload) # compare updates to existing database entries - old_policies = self.get_policies_for_updates(known) for payload in known: timestamp = int.from_bytes(payload['timestamp'], "big") start_node = payload['start_node'] short_channel_id = payload['short_channel_id'] - key = (short_channel_id+start_node).hex() - old_policy = old_policies.get(key) - if old_policy: - if timestamp <= old_policy.timestamp: - deprecated.append(short_channel_id) - else: - good[key] = payload - to_delete.append(old_policy) - else: - good[key] = payload - good = list(good.values()) + key = (start_node, short_channel_id) + old_policy = self._policies.get(key) + if old_policy and timestamp <= old_policy.timestamp: + deprecated.append(short_channel_id) + continue + good.append(payload) + if verify: + self.verify_channel_update(payload) + policy = Policy.from_msg(payload) + self._policies[key] = policy + self.save_policy(policy) + # + self.update_counts() return orphaned, expired, deprecated, good, to_delete def add_channel_update(self, payload): - orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload]) + orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False) assert len(good) == 1 - self.update_policies(good, to_delete) @sql - @profiler - def update_policies(self, to_add, to_delete): - for policy in to_delete: - self.DBSession.delete(policy) - self.DBSession.commit() - for payload in to_add: - policy = Policy.from_msg(payload) - self.DBSession.add(policy) - self.DBSession.commit() - self._update_counts() + def save_policy(self, policy): + self.DBSession.execute(PolicyBase.__table__.insert().values(policy)) @sql - @profiler - def on_node_announcement(self, msg_payloads): + def delete_policy(self, short_channel_id, node_id): + self.DBSession.execute(PolicyBase.__table__.delete().values(policy)) + + @sql + def save_channel(self, channel_info): + self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info)) + + def verify_channel_update(self, payload): + short_channel_id = payload['short_channel_id'] + if constants.net.rev_genesis_bytes() != payload['chain_hash']: + raise Exception('wrong chain hash') + if not verify_sig_for_channel_update(payload, payload['start_node']): + raise BaseException('verify error') + + def add_node_announcement(self, msg_payloads): if type(msg_payloads) is dict: msg_payloads = [msg_payloads] old_addr = None @@ -435,29 +445,35 @@ class ChannelDB(SqlDB): continue node_id = node_info.node_id # Ignore node if it has no associated channel (DoS protection) - # FIXME this is slow - expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id) - if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0: + if node_id not in self._channels_for_node: #self.logger.info('ignoring orphan node_announcement') continue - node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none() + node = self._nodes.get(node_id) if node and node.timestamp >= node_info.timestamp: continue node = new_nodes.get(node_id) if node and node.timestamp >= node_info.timestamp: continue - new_nodes[node_id] = node_info + # save + self._nodes[node_id] = node_info + self.save_node(node_info) for addr in node_addresses: - new_addresses[(addr.node_id,addr.host,addr.port)] = addr + self._addresses[node_id].add((addr.host, addr.port, 0)) + self.save_node_addresses(node_id, node_addresses) + self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) - for node_info in new_nodes.values(): - self.DBSession.add(node_info) - for new_addr in new_addresses.values(): - old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() + self.update_counts() + + @sql + def save_node_addresses(self, node_if, node_addresses): + for new_addr in node_addresses: + old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() if not old_addr: - self.DBSession.add(new_addr) - self.DBSession.commit() - self._update_counts() + self.DBSession.execute(AddressBase.__table__.insert().values(new_addr)) + + @sql + def save_node(self, node_info): + self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info)) def get_routing_policy_for_channel(self, start_node_id: bytes, short_channel_id: bytes) -> Optional[bytes]: @@ -470,41 +486,28 @@ class ChannelDB(SqlDB): return None return Policy.from_msg(msg) # won't actually be written to DB - @sql - @profiler def get_old_policies(self, delta): - timestamp = int(time.time()) - delta - old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) - return old_policies.distinct().count() + now = int(time.time()) + return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta) - @sql - @profiler def prune_old_policies(self, delta): - # note: delete queries are order sensitive - timestamp = int(time.time()) - delta - old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) - delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies)) - delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp) - self.DBSession.execute(delete_old_channels) - self.DBSession.execute(delete_old_policies) - self.DBSession.commit() - self._update_counts() + l = self.get_old_policies(delta) + for k in l: + self._policies.pop(k) + if l: + self.logger.info(f'Deleting {len(l)} old policies') - @sql - @profiler def get_orphaned_channels(self): - subquery = self.DBSession.query(Policy.short_channel_id) - orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery))) - return orphaned.count() + ids = set(x[1] for x in self._policies.keys()) + return list(x for x in self._channels.keys() if x not in ids) - @sql - @profiler def prune_orphaned_channels(self): - subquery = self.DBSession.query(Policy.short_channel_id) - delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery))) - self.DBSession.execute(delete_orphaned) - self.DBSession.commit() - self._update_counts() + l = self.get_orphaned_channels() + for k in l: + self._channels.pop(k) + self.update_counts() + if l: + self.logger.info(f'Deleting {len(l)} orphaned channels') def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): if not verify_sig_for_channel_update(msg_payload, start_node_id): @@ -513,67 +516,27 @@ class ChannelDB(SqlDB): msg_payload['start_node'] = start_node_id self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload - @sql def remove_channel(self, short_channel_id): - r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none() - if not r: - return - self.DBSession.delete(r) - self.DBSession.commit() - - def print_graph(self, full_ids=False): - # used for debugging. - # FIXME there is a race here - iterables could change size from another thread - def other_node_id(node_id, channel_id): - channel_info = self.get_channel_info(channel_id) - if node_id == channel_info.node1_id: - other = channel_info.node2_id - else: - other = channel_info.node1_id - return other if full_ids else other[-4:] - - print_msg('nodes') - for node in self.DBSession.query(NodeInfo).all(): - print_msg(node) - - print_msg('channels') - for channel_info in self.DBSession.query(ChannelInfo).all(): - short_channel_id = channel_info.short_channel_id - node1 = channel_info.node1_id - node2 = channel_info.node2_id - direction1 = self.get_policy_for_node(channel_info, node1) is not None - direction2 = self.get_policy_for_node(channel_info, node2) is not None - if direction1 and direction2: - direction = 'both' - elif direction1: - direction = 'forward' - elif direction2: - direction = 'backward' - else: - direction = 'none' - print_msg('{}: {}, {}, {}' - .format(bh2u(short_channel_id), - bh2u(node1) if full_ids else bh2u(node1[-4:]), - bh2u(node2) if full_ids else bh2u(node2[-4:]), - direction)) + self._channels.pop(short_channel_id, None) - - @sql - def get_node_addresses(self, node_info): - return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() + def get_node_addresses(self, node_id): + return self._addresses.get(node_id) @sql @profiler def load_data(self): - r = self.DBSession.query(ChannelInfo).all() - self._channels = dict([(bfh(x.short_channel_id), x) for x in r]) - r = self.DBSession.query(Policy).filter_by().all() - self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r]) - self._channels_for_node = defaultdict(set) + for x in self.DBSession.query(AddressBase).all(): + self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0))) + for x in self.DBSession.query(ChannelInfoBase).all(): + self._channels[x.short_channel_id] = x.to_nametuple() + for x in self.DBSession.query(PolicyBase).filter_by().all(): + p = x.to_nametuple() + self._policies[(p.start_node, p.short_channel_id)] = p for channel_info in self._channels.values(): - self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id)) - self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id)) + self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) + self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}') + self.update_counts() def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']: return self._policies.get((node_id, short_channel_id)) @@ -584,6 +547,3 @@ class ChannelDB(SqlDB): def get_channels_for_node(self, node_id) -> Set[bytes]: """Returns the set of channels that have node_id as one of the endpoints.""" return self._channels_for_node.get(node_id) or set() - - - diff --git a/electrum/gui/qt/lightning_dialog.py b/electrum/gui/qt/lightning_dialog.py @@ -56,10 +56,11 @@ class WatcherList(MyTreeView): return self.model().clear() self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')}) - sweepstore = self.parent.lnwatcher.sweepstore - for outpoint in sweepstore.list_sweep_tx(): - n = sweepstore.get_num_tx(outpoint) - status = self.parent.lnwatcher.get_channel_status(outpoint) + lnwatcher = self.parent.lnwatcher + l = lnwatcher.list_sweep_tx() + for outpoint in l: + n = lnwatcher.get_num_tx(outpoint) + status = lnwatcher.get_channel_status(outpoint) items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]] self.model().insertRow(self.model().rowCount(), items) diff --git a/electrum/lnaddr.py b/electrum/lnaddr.py @@ -258,14 +258,21 @@ class LnAddr(object): def get_min_final_cltv_expiry(self) -> int: return self._min_final_cltv_expiry - def get_description(self): + def get_tag(self, tag): description = '' for k,v in self.tags: - if k == 'd': + if k == tag: description = v break return description + def get_description(self): + return self.get_tag('d') + + def get_expiry(self): + return int(self.get_tag('x') or '3600') + + def lndecode(a, verbose=False, expected_hrp=None): if expected_hrp is None: diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -163,8 +163,6 @@ class Channel(Logger): self._is_funding_txo_spent = None # "don't know" self._state = None self.set_state('DISCONNECTED') - self.lnwatcher = None - self.local_commitment = None self.remote_commitment = None self.sweep_info = None @@ -453,13 +451,10 @@ class Channel(Logger): return secret, point def process_new_revocation_secret(self, per_commitment_secret: bytes): - if not self.lnwatcher: - return outpoint = self.funding_outpoint.to_str() ctx = self.remote_commitment_to_be_revoked # FIXME can't we just reconstruct it? sweeptxs = create_sweeptxs_for_their_revoked_ctx(self, ctx, per_commitment_secret, self.sweep_address) - for tx in sweeptxs: - self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx)) + return sweeptxs def receive_revocation(self, revocation: RevokeAndAck): self.logger.info("receive_revocation") @@ -477,9 +472,10 @@ class Channel(Logger): # be robust to exceptions raised in lnwatcher try: - self.process_new_revocation_secret(revocation.per_commitment_secret) + sweeptxs = self.process_new_revocation_secret(revocation.per_commitment_secret) except Exception as e: self.logger.info("Could not process revocation secret: {}".format(repr(e))) + sweeptxs = [] ##### start applying fee/htlc changes @@ -505,6 +501,8 @@ class Channel(Logger): self.set_remote_commitment() self.remote_commitment_to_be_revoked = prev_remote_commitment + # return sweep transactions for watchtower + return sweeptxs def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None): """ diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -42,7 +42,6 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY) from .lntransport import LNTransport, LNTransportBase from .lnmsg import encode_msg, decode_msg -from .lnverifier import verify_sig_for_channel_update from .interface import GracefulDisconnect if TYPE_CHECKING: @@ -242,22 +241,20 @@ class Peer(Logger): # channel announcements for chan_anns_chunk in chunks(chan_anns, 300): self.verify_channel_announcements(chan_anns_chunk) - self.channel_db.on_channel_announcement(chan_anns_chunk) + self.channel_db.add_channel_announcement(chan_anns_chunk) # node announcements for node_anns_chunk in chunks(node_anns, 100): self.verify_node_announcements(node_anns_chunk) - self.channel_db.on_node_announcement(node_anns_chunk) + self.channel_db.add_node_announcement(node_anns_chunk) # channel updates for chan_upds_chunk in chunks(chan_upds, 1000): - orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds_chunk, - max_age=self.network.lngossip.max_age) + orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates( + chan_upds_chunk, max_age=self.network.lngossip.max_age) if orphaned: self.logger.info(f'adding {len(orphaned)} unknown channel ids') - self.network.lngossip.add_new_ids(orphaned) + await self.network.lngossip.add_new_ids(orphaned) if good: self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds_chunk)}') - self.verify_channel_updates(good) - self.channel_db.update_policies(good, to_delete) # refresh gui if chan_anns or node_anns or chan_upds: self.network.lngossip.refresh_gui() @@ -279,14 +276,6 @@ class Peer(Logger): if not ecc.verify_signature(pubkey, signature, h): raise Exception('signature failed') - def verify_channel_updates(self, chan_upds): - for payload in chan_upds: - short_channel_id = payload['short_channel_id'] - if constants.net.rev_genesis_bytes() != payload['chain_hash']: - raise Exception('wrong chain hash') - if not verify_sig_for_channel_update(payload, payload['start_node']): - raise BaseException('verify error') - async def query_gossip(self): try: await asyncio.wait_for(self.initialized.wait(), 10) @@ -298,7 +287,7 @@ class Peer(Logger): except asyncio.TimeoutError as e: raise GracefulDisconnect("query_channel_range timed out") from e self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete)) - self.lnworker.add_new_ids(ids) + await self.lnworker.add_new_ids(ids) while True: todo = self.lnworker.get_ids_to_query() if not todo: @@ -658,7 +647,7 @@ class Peer(Logger): ) chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig) self.lnworker.save_channel(chan) - self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) + await self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.lnworker.on_channels_updated() while True: try: @@ -862,8 +851,6 @@ class Peer(Logger): bitcoin_key_2=bitcoin_keys[1] ) - print("SENT CHANNEL ANNOUNCEMENT") - def mark_open(self, chan: Channel): assert chan.short_channel_id is not None if chan.get_state() == "OPEN": @@ -872,6 +859,10 @@ class Peer(Logger): assert chan.config[LOCAL].funding_locked_received chan.set_state("OPEN") self.network.trigger_callback('channel', chan) + asyncio.ensure_future(self.add_own_channel(chan)) + self.logger.info("CHANNEL OPENING COMPLETED") + + async def add_own_channel(self, chan): # add channel to database bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey] sorted_node_ids = list(sorted(self.node_ids)) @@ -887,7 +878,7 @@ class Peer(Logger): # that the remote sends, even if the channel was not announced # (from BOLT-07: "MAY create a channel_update to communicate the channel # parameters to the final node, even though the channel has not yet been announced") - self.channel_db.on_channel_announcement( + self.channel_db.add_channel_announcement( { "short_channel_id": chan.short_channel_id, "node_id_1": node_ids[0], @@ -922,8 +913,6 @@ class Peer(Logger): if pending_channel_update: self.channel_db.add_channel_update(pending_channel_update) - self.logger.info("CHANNEL OPENING COMPLETED") - def send_announcement_signatures(self, chan: Channel): bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, @@ -962,36 +951,34 @@ class Peer(Logger): def on_update_fail_htlc(self, payload): channel_id = payload["channel_id"] htlc_id = int.from_bytes(payload["id"], "big") - key = (channel_id, htlc_id) - try: - route = self.attempted_route[key] - except KeyError: - # the remote might try to fail an htlc after we restarted... - # attempted_route is not persisted, so we will get here then - self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key)) - else: - try: - self._handle_error_code_from_failed_htlc(payload["reason"], route, channel_id, htlc_id) - except Exception: - # exceptions are suppressed as failing to handle an error code - # should not block us from removing the htlc - traceback.print_exc(file=sys.stderr) - # process update_fail_htlc on channel chan = self.channels[channel_id] chan.receive_fail_htlc(htlc_id) local_ctn = chan.get_current_ctn(LOCAL) - asyncio.ensure_future(self._on_update_fail_htlc(chan, htlc_id, local_ctn)) + asyncio.ensure_future(self._handle_error_code_from_failed_htlc(payload, channel_id, htlc_id)) + asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn)) @log_exceptions - async def _on_update_fail_htlc(self, chan, htlc_id, local_ctn): + async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn): + chan = self.channels[channel_id] await self.await_local(chan, local_ctn) self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False) - def _handle_error_code_from_failed_htlc(self, error_reason, route: List['RouteEdge'], channel_id, htlc_id): + @log_exceptions + async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id): chan = self.channels[channel_id] - failure_msg, sender_idx = decode_onion_error(error_reason, - [x.node_id for x in route], - chan.onion_keys[htlc_id]) + key = (channel_id, htlc_id) + try: + route = self.attempted_route[key] + except KeyError: + # the remote might try to fail an htlc after we restarted... + # attempted_route is not persisted, so we will get here then + self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key)) + return + error_reason = payload["reason"] + failure_msg, sender_idx = decode_onion_error( + error_reason, + [x.node_id for x in route], + chan.onion_keys[htlc_id]) code, data = failure_msg.code, failure_msg.data self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}") self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}") @@ -1009,11 +996,9 @@ class Peer(Logger): channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:] message_type, payload = decode_msg(channel_update) payload['raw'] = channel_update - orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload]) + orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates([payload]) blacklist = False if good: - self.verify_channel_updates(good) - self.channel_db.update_policies(good, to_delete) self.logger.info("applied channel update on our db") elif orphaned: # maybe it is a private channel (and data in invoice was outdated) @@ -1276,11 +1261,17 @@ class Peer(Logger): self.logger.info("on_revoke_and_ack") channel_id = payload["channel_id"] chan = self.channels[channel_id] - chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"])) + sweeptxs = chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"])) self._remote_changed_events[chan.channel_id].set() self._remote_changed_events[chan.channel_id].clear() self.lnworker.save_channel(chan) self.maybe_send_commitment(chan) + asyncio.ensure_future(self._on_revoke_and_ack(chan, sweeptxs)) + + async def _on_revoke_and_ack(self, chan, sweeptxs): + outpoint = chan.funding_outpoint.to_str() + for tx in sweeptxs: + await self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx)) def on_update_fee(self, payload): channel_id = payload["channel_id"] diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -37,7 +37,7 @@ import binascii import base64 from . import constants -from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks +from .util import bh2u, profiler, get_headers_dir, is_ip_address, list_enabled_bits, print_msg, chunks from .logging import Logger from .storage import JsonDB from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update @@ -169,7 +169,6 @@ class LNPathFinder(Logger): To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1]; i.e. an element reads as, "to get to node_id, travel through short_channel_id" """ - self.channel_db.load_data() assert type(nodeA) is bytes assert type(nodeB) is bytes assert type(invoice_amount_msat) is int @@ -195,11 +194,12 @@ class LNPathFinder(Logger): else: # payment incoming, on our channel. (funny business, cycle weirdness) assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode)) pass # TODO? - edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id, - start_node=edge_startnode, - end_node=edge_endnode, - payment_amt_msat=amount_msat, - ignore_costs=(edge_startnode == nodeA)) + edge_cost, fee_for_edge_msat = self._edge_cost( + edge_channel_id, + start_node=edge_startnode, + end_node=edge_endnode, + payment_amt_msat=amount_msat, + ignore_costs=(edge_startnode == nodeA)) alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost if alt_dist_to_neighbour < distance_from_start[edge_startnode]: distance_from_start[edge_startnode] = alt_dist_to_neighbour @@ -219,9 +219,10 @@ class LNPathFinder(Logger): continue for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode): assert type(edge_channel_id) is bytes - if edge_channel_id in self.blacklist: continue + if edge_channel_id in self.blacklist: + continue channel_info = self.channel_db.get_channel_info(edge_channel_id) - edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id) + edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id inspect_edge() else: return None # no path found diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py @@ -70,11 +70,11 @@ class SweepStore(SqlDB): @sql def get_tx_by_index(self, funding_outpoint, index): r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none() - return r.prevout, bh2u(r.tx) + return str(r.prevout), bh2u(r.tx) @sql def list_sweep_tx(self): - return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all()) + return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all()) @sql def add_sweep_tx(self, funding_outpoint, prevout, tx): @@ -84,7 +84,7 @@ class SweepStore(SqlDB): @sql def get_num_tx(self, funding_outpoint): - return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count() + return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()) @sql def remove_sweep_tx(self, funding_outpoint): @@ -111,11 +111,11 @@ class SweepStore(SqlDB): @sql def get_address(self, outpoint): r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none() - return r.address if r else None + return str(r.address) if r else None @sql def list_channel_info(self): - return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()] + return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()] class LNWatcher(AddressSynchronizer): @@ -150,14 +150,21 @@ class LNWatcher(AddressSynchronizer): self.watchtower_queue = asyncio.Queue() def get_num_tx(self, outpoint): - return self.sweepstore.get_num_tx(outpoint) + async def f(): + return await self.sweepstore.get_num_tx(outpoint) + return self.network.run_from_another_thread(f()) + + def list_sweep_tx(self): + async def f(): + return await self.sweepstore.list_sweep_tx() + return self.network.run_from_another_thread(f()) @ignore_exceptions @log_exceptions async def watchtower_task(self): self.logger.info('watchtower task started') # initial check - for address, outpoint in self.sweepstore.list_channel_info(): + for address, outpoint in await self.sweepstore.list_channel_info(): await self.watchtower_queue.put(outpoint) while True: outpoint = await self.watchtower_queue.get() @@ -165,30 +172,30 @@ class LNWatcher(AddressSynchronizer): continue # synchronize with remote try: - local_n = self.sweepstore.get_num_tx(outpoint) + local_n = await self.sweepstore.get_num_tx(outpoint) n = self.watchtower.get_num_tx(outpoint) if n == 0: - address = self.sweepstore.get_address(outpoint) + address = await self.sweepstore.get_address(outpoint) self.watchtower.add_channel(outpoint, address) self.logger.info("sending %d transactions to watchtower"%(local_n - n)) for index in range(n, local_n): - prevout, tx = self.sweepstore.get_tx_by_index(outpoint, index) + prevout, tx = await self.sweepstore.get_tx_by_index(outpoint, index) self.watchtower.add_sweep_tx(outpoint, prevout, tx) except ConnectionRefusedError: self.logger.info('could not reach watchtower, will retry in 5s') await asyncio.sleep(5) await self.watchtower_queue.put(outpoint) - def add_channel(self, outpoint, address): + async def add_channel(self, outpoint, address): self.add_address(address) with self.lock: - if not self.sweepstore.has_channel(outpoint): - self.sweepstore.add_channel(outpoint, address) + if not await self.sweepstore.has_channel(outpoint): + await self.sweepstore.add_channel(outpoint, address) - def unwatch_channel(self, address, funding_outpoint): + async def unwatch_channel(self, address, funding_outpoint): self.logger.info(f'unwatching {funding_outpoint}') - self.sweepstore.remove_sweep_tx(funding_outpoint) - self.sweepstore.remove_channel(funding_outpoint) + await self.sweepstore.remove_sweep_tx(funding_outpoint) + await self.sweepstore.remove_channel(funding_outpoint) if funding_outpoint in self.tx_progress: self.tx_progress[funding_outpoint].all_done.set() @@ -202,7 +209,7 @@ class LNWatcher(AddressSynchronizer): return if not self.synchronizer.is_up_to_date(): return - for address, outpoint in self.sweepstore.list_channel_info(): + for address, outpoint in await self.sweepstore.list_channel_info(): await self.check_onchain_situation(address, outpoint) async def check_onchain_situation(self, address, funding_outpoint): @@ -223,7 +230,7 @@ class LNWatcher(AddressSynchronizer): closing_height, closing_tx) # FIXME sooo many args.. await self.do_breach_remedy(funding_outpoint, spenders) if not keep_watching: - self.unwatch_channel(address, funding_outpoint) + await self.unwatch_channel(address, funding_outpoint) else: #self.logger.info(f'we will keep_watching {funding_outpoint}') pass @@ -260,7 +267,7 @@ class LNWatcher(AddressSynchronizer): for prevout, spender in spenders.items(): if spender is not None: continue - sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prevout) + sweep_txns = await self.sweepstore.get_sweep_tx(funding_outpoint, prevout) for tx in sweep_txns: if not await self.broadcast_or_log(funding_outpoint, tx): self.logger.info(f'{tx.name} could not publish tx: {str(tx)}, prevout: {prevout}') @@ -279,8 +286,8 @@ class LNWatcher(AddressSynchronizer): await self.tx_progress[funding_outpoint].tx_queue.put(tx) return txid - def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str): - self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx) + async def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str): + await self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx) if self.watchtower: self.watchtower_queue.put_nowait(funding_outpoint) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -108,12 +108,14 @@ class LNWorker(Logger): @log_exceptions async def main_loop(self): + # fixme: only lngossip should do that + await self.channel_db.load_data() while True: await asyncio.sleep(1) now = time.time() if len(self.peers) >= NUM_PEERS_TARGET: continue - peers = self._get_next_peers_to_try() + peers = await self._get_next_peers_to_try() for peer in peers: last_tried = self._last_tried_peer.get(peer, 0) if last_tried + PEER_RETRY_INTERVAL < now: @@ -130,7 +132,8 @@ class LNWorker(Logger): peer = Peer(self, node_id, transport) await self.network.main_taskgroup.spawn(peer.main_loop()) self.peers[node_id] = peer - self.network.lngossip.refresh_gui() + #if self.network.lngossip: + # self.network.lngossip.refresh_gui() return peer def start_network(self, network: 'Network'): @@ -148,7 +151,7 @@ class LNWorker(Logger): self._add_peer(host, int(port), bfh(pubkey)), self.network.asyncio_loop) - def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]: + async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]: now = time.time() recent_peers = self.channel_db.get_recent_peers() # maintenance for last tried times @@ -158,19 +161,22 @@ class LNWorker(Logger): del self._last_tried_peer[peer] # first try from recent peers for peer in recent_peers: - if peer.pubkey in self.peers: continue - if peer in self._last_tried_peer: continue + if peer.pubkey in self.peers: + continue + if peer in self._last_tried_peer: + continue return [peer] # try random peer from graph unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys()) if unconnected_nodes: - for node in unconnected_nodes: - addrs = self.channel_db.get_node_addresses(node) + for node_id in unconnected_nodes: + addrs = self.channel_db.get_node_addresses(node_id) if not addrs: continue - host, port = self.choose_preferred_address(addrs) - peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id)) - if peer in self._last_tried_peer: continue + host, port, timestamp = self.choose_preferred_address(addrs) + peer = LNPeerAddr(host, port, node_id) + if peer in self._last_tried_peer: + continue #self.logger.info('taking random ln peer from our channel db') return [peer] @@ -223,15 +229,13 @@ class LNWorker(Logger): def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]: assert len(addr_list) >= 1 # choose first one that is an IP - for addr_in_db in addr_list: - host = addr_in_db.host - port = addr_in_db.port + for host, port, timestamp in addr_list: if is_ip_address(host): - return host, port + return host, port, timestamp # otherwise choose one at random # TODO maybe filter out onion if not on tor? choice = random.choice(addr_list) - return choice.host, choice.port + return choice class LNGossip(LNWorker): @@ -260,26 +264,19 @@ class LNGossip(LNWorker): self.network.trigger_callback('ln_status', num_peers, num_nodes, known, unknown) async def maintain_db(self): - n = self.channel_db.get_orphaned_channels() - if n: - self.logger.info(f'Deleting {n} orphaned channels') - self.channel_db.prune_orphaned_channels() - self.refresh_gui() + self.channel_db.prune_orphaned_channels() while True: - n = self.channel_db.get_old_policies(self.max_age) - if n: - self.logger.info(f'Deleting {n} old channels') - self.channel_db.prune_old_policies(self.max_age) - self.refresh_gui() + self.channel_db.prune_old_policies(self.max_age) + self.refresh_gui() await asyncio.sleep(5) - def add_new_ids(self, ids): - known = self.channel_db.known_ids() + async def add_new_ids(self, ids): + known = self.channel_db.get_channel_ids() new = set(ids) - set(known) self.unknown_ids.update(new) def get_ids_to_query(self): - N = 500 + N = 100 l = list(self.unknown_ids) self.unknown_ids = set(l[N:]) return l[0:N] @@ -324,9 +321,10 @@ class LNWallet(LNWorker): self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee']) # thread safe self.network.register_callback(self.on_channel_open, ['channel_open']) self.network.register_callback(self.on_channel_closed, ['channel_closed']) + for chan_id, chan in self.channels.items(): - self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) - chan.lnwatcher = network.lnwatcher + self.network.lnwatcher.add_address(chan.get_funding_address()) + super().start_network(network) for coro in [ self.maybe_listen(), @@ -494,7 +492,7 @@ class LNWallet(LNWorker): chan = self.channel_by_txo(funding_outpoint) if not chan: return - self.logger.debug(f'on_channel_open {funding_outpoint}') + #self.logger.debug(f'on_channel_open {funding_outpoint}') self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, None, None, None self.storage.put('lightning_channel_timestamps', self.channel_timestamps) chan.set_funding_txo_spentness(False) @@ -606,7 +604,8 @@ class LNWallet(LNWorker): self.logger.info('REBROADCASTING CLOSING TX') await self.force_close_channel(chan.channel_id) - async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password): + async def _open_channel_coroutine(self, connect_str, local_amount_sat, push_sat, password): + peer = await self.add_peer(connect_str) # peer might just have been connected to await asyncio.wait_for(peer.initialized.wait(), 5) chan = await peer.channel_establishment_flow( @@ -615,24 +614,22 @@ class LNWallet(LNWorker): push_msat=push_sat * 1000, temp_channel_id=os.urandom(32)) self.save_channel(chan) - self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) + self.network.lnwatcher.add_address(chan.get_funding_address()) + await self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.on_channels_updated() return chan def on_channels_updated(self): self.network.trigger_callback('channels') - def add_peer(self, connect_str, timeout=20): + async def add_peer(self, connect_str, timeout=20): node_id, rest = extract_nodeid(connect_str) peer = self.peers.get(node_id) if not peer: if rest is not None: host, port = split_host_port(rest) else: - node_info = self.network.channel_db.nodes_get(node_id) - if not node_info: - raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id)) - addrs = self.channel_db.get_node_addresses(node_info) + addrs = self.channel_db.get_node_addresses(node_id) if len(addrs) == 0: raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id)) host, port = self.choose_preferred_address(addrs) @@ -640,18 +637,12 @@ class LNWallet(LNWorker): socket.getaddrinfo(host, int(port)) except socket.gaierror: raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)')) - peer_future = asyncio.run_coroutine_threadsafe( - self._add_peer(host, port, node_id), - self.network.asyncio_loop) - try: - peer = peer_future.result(timeout) - except concurrent.futures.TimeoutError: - raise Exception(_("add_peer timed out")) + # add peer + peer = await self._add_peer(host, port, node_id) return peer def open_channel(self, connect_str, local_amt_sat, push_amt_sat, password=None, timeout=20): - peer = self.add_peer(connect_str, timeout) - coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, password) + coro = self._open_channel_coroutine(connect_str, local_amt_sat, push_amt_sat, password) fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) try: chan = fut.result(timeout=timeout) @@ -664,6 +655,9 @@ class LNWallet(LNWorker): Can be called from other threads Raises timeout exception if htlc is not fulfilled """ + addr = self._check_invoice(invoice, amount_sat) + self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False) + self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description()) fut = asyncio.run_coroutine_threadsafe( self._pay(invoice, attempts, amount_sat), self.network.asyncio_loop) @@ -680,8 +674,6 @@ class LNWallet(LNWorker): async def _pay(self, invoice, attempts=1, amount_sat=None): addr = self._check_invoice(invoice, amount_sat) - self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False) - self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description()) for i in range(attempts): route = await self._create_route_from_invoice(decoded_invoice=addr) if not self.get_channel_by_short_id(route[0].short_channel_id): @@ -691,7 +683,7 @@ class LNWallet(LNWorker): return True return False - async def _pay_to_route(self, route, addr, pay_req): + async def _pay_to_route(self, route, addr, invoice): short_channel_id = route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) if not chan: @@ -713,6 +705,9 @@ class LNWallet(LNWorker): raise InvoiceError("{}\n{}".format( _("Invoice wants us to risk locking funds for unreasonably long."), f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}")) + #now = int(time.time()) + #if addr.date + addr.get_expiry() > now: + # raise InvoiceError(_('Invoice expired')) return addr async def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]: @@ -730,11 +725,14 @@ class LNWallet(LNWorker): with self.lock: channels = list(self.channels.values()) for private_route in r_tags: - if len(private_route) == 0: continue - if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: continue + if len(private_route) == 0: + continue + if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: + continue border_node_pubkey = private_route[0][0] path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels) - if not path: continue + if not path: + continue route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) # we need to shift the node pubkey by one towards the destination: private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey] @@ -770,10 +768,18 @@ class LNWallet(LNWorker): return route def add_invoice(self, amount_sat, message): + coro = self._add_invoice_coro(amount_sat, message) + fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) + try: + return fut.result(timeout=5) + except concurrent.futures.TimeoutError: + raise Exception(_("add_invoice timed out")) + + async def _add_invoice_coro(self, amount_sat, message): payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) amount_btc = amount_sat/Decimal(COIN) if amount_sat else None - routing_hints = self._calc_routing_hints_for_invoice(amount_sat) + routing_hints = await self._calc_routing_hints_for_invoice(amount_sat) if not routing_hints: self.logger.info("Warning. No routing hints added to invoice. " "Other clients will likely not be able to send to us.") @@ -847,19 +853,20 @@ class LNWallet(LNWorker): }) return out - def _calc_routing_hints_for_invoice(self, amount_sat): + async def _calc_routing_hints_for_invoice(self, amount_sat): """calculate routing hints (BOLT-11 'r' field)""" - self.channel_db.load_data() routing_hints = [] with self.lock: channels = list(self.channels.values()) # note: currently we add *all* our channels; but this might be a privacy leak? for chan in channels: # check channel is open - if chan.get_state() != "OPEN": continue + if chan.get_state() != "OPEN": + continue # check channel has sufficient balance # FIXME because of on-chain fees of ctx, this check is insufficient - if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: continue + if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: + continue chan_id = chan.short_channel_id assert type(chan_id) is bytes, chan_id channel_info = self.channel_db.get_channel_info(chan_id) @@ -949,14 +956,10 @@ class LNWallet(LNWorker): await self._add_peer(peer.host, peer.port, peer.pubkey) return # try random address for node_id - node_info = self.channel_db.nodes_get(chan.node_id) - if not node_info: - return - addresses = self.channel_db.get_node_addresses(node_info) + addresses = self.channel_db.get_node_addresses(chan.node_id) if not addresses: return - adr_obj = random.choice(addresses) - host, port = adr_obj.host, adr_obj.port + host, port, t = random.choice(list(addresses)) peer = LNPeerAddr(host, port, chan.node_id) last_tried = self._last_tried_peer.get(peer, 0) if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now: diff --git a/electrum/sql_db.py b/electrum/sql_db.py @@ -2,6 +2,7 @@ import os import concurrent import queue import threading +import asyncio from sqlalchemy import create_engine from sqlalchemy.pool import StaticPool @@ -18,28 +19,32 @@ def sql(func): """wrapper for sql methods""" def wrapper(self, *args, **kwargs): assert threading.currentThread() != self.sql_thread - f = concurrent.futures.Future() + f = asyncio.Future() self.db_requests.put((f, func, args, kwargs)) - return f.result(timeout=10) + return f return wrapper class SqlDB(Logger): - def __init__(self, network, path, base): + def __init__(self, network, path, base, commit_interval=None): Logger.__init__(self) self.base = base self.network = network self.path = path + self.commit_interval = commit_interval self.db_requests = queue.Queue() self.sql_thread = threading.Thread(target=self.run_sql) self.sql_thread.start() def run_sql(self): + #return + self.logger.info("SQL thread started") engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) DBSession = sessionmaker(bind=engine, autoflush=False) - self.DBSession = DBSession() if not os.path.exists(self.path): self.base.metadata.create_all(engine) + self.DBSession = DBSession() + i = 0 while self.network.asyncio_loop.is_running(): try: future, func, args, kwargs = self.db_requests.get(timeout=0.1) @@ -50,7 +55,14 @@ class SqlDB(Logger): except BaseException as e: future.set_exception(e) continue - future.set_result(result) + if not future.cancelled(): + future.set_result(result) + # note: in sweepstore session.commit() is called inside + # the sql-decorated methods, so commiting to disk is awaited + if self.commit_interval: + i = (i + 1) % self.commit_interval + if i == 0: + self.DBSession.commit() # write self.DBSession.commit() self.logger.info("SQL thread terminated") diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py @@ -16,7 +16,8 @@ from electrum.lnpeer import Peer from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import PaymentFailure, LnLocalFeatures -from electrum.lnrouter import ChannelDB, LNPathFinder +from electrum.lnrouter import LNPathFinder +from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet from electrum.lnmsg import encode_msg, decode_msg from electrum.logging import console_stderr_handler diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py @@ -59,33 +59,33 @@ class Test_LNRouter(TestCaseForTestnet): cdb = fake_network.channel_db path_finder = lnrouter.LNPathFinder(cdb) self.assertEqual(cdb.num_channels, 0) - cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc', + cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc', 'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc', 'short_channel_id': bfh('0000000000000001'), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'len': b'\x00\x00', 'features': b''}, trusted=True) self.assertEqual(cdb.num_channels, 1) - cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', + cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 'short_channel_id': bfh('0000000000000002'), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'len': b'\x00\x00', 'features': b''}, trusted=True) - cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', + cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'short_channel_id': bfh('0000000000000003'), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'len': b'\x00\x00', 'features': b''}, trusted=True) - cdb.on_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd', + cdb.add_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd', 'bitcoin_key_1': b'\x02cccccccccccccccccccccccccccccccc', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd', 'short_channel_id': bfh('0000000000000004'), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'len': b'\x00\x00', 'features': b''}, trusted=True) - cdb.on_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', + cdb.add_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 'bitcoin_key_1': b'\x02dddddddddddddddddddddddddddddddd', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', 'short_channel_id': bfh('0000000000000005'), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'), 'len': b'\x00\x00', 'features': b''}, trusted=True) - cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd', + cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd', 'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd', 'short_channel_id': bfh('0000000000000006'), 'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),