electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 9f188c087c379078443cfb589a1b5fbd0146dd21
parent 95a217478932b75503732ce4864621b7112629c1
Author: ThomasV <thomasv@electrum.org>
Date:   Tue,  5 Mar 2019 11:22:00 +0100

Flatten the structure of lnrouter, so that DBSession is not used outside of ChannelDB

Diffstat:
Melectrum/lnrouter.py | 162+++++++++++++++++++++++++++++++++++++++----------------------------------------
Melectrum/lnworker.py | 6+++---
2 files changed, 82 insertions(+), 86 deletions(-)

diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -70,7 +70,6 @@ def validate_features(features : int): Base = declarative_base() session_factory = sessionmaker() -DBSession = scoped_session(session_factory) FLAG_DISABLE = 1 << 1 FLAG_DIRECTION = 1 << 0 @@ -88,16 +87,12 @@ class ChannelInfo(Base): def from_msg(channel_announcement_payload): features = int.from_bytes(channel_announcement_payload['features'], 'big') validate_features(features) - channel_id = channel_announcement_payload['short_channel_id'].hex() node_id_1 = channel_announcement_payload['node_id_1'].hex() node_id_2 = channel_announcement_payload['node_id_2'].hex() assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] - msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex() - capacity_sat = None - return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, trusted = False) @@ -106,42 +101,6 @@ class ChannelInfo(Base): def msg_payload(self): return bytes.fromhex(self.msg_payload_hex) - def on_channel_update(self, msg: dict, trusted=False): - assert self.short_channel_id == msg['short_channel_id'].hex() - flags = int.from_bytes(msg['channel_flags'], 'big') - direction = flags & FLAG_DIRECTION - if direction == 0: - node_id = self.node1_id - else: - node_id = self.node2_id - new_policy = Policy.from_msg(msg, node_id, self.short_channel_id) - old_policy = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node=node_id).one_or_none() - if not old_policy: - DBSession.add(new_policy) - return - if old_policy.timestamp >= new_policy.timestamp: - return # ignore - if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)): - return # ignore - old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta - old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat - old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat - old_policy.fee_base_msat = new_policy.fee_base_msat - old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths - old_policy.channel_flags = new_policy.channel_flags - old_policy.timestamp = new_policy.timestamp - - def get_policy_for_node(self, node) -> Optional['Policy']: - """ - raises when initiator/non-initiator both unequal node - """ - if node.hex() not in (self.node1_id, self.node2_id): - raise Exception("the given node is not a party in this channel") - n1 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none() - if n1: - return n1 - n2 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none() - return n2 class Policy(Base): __tablename__ = 'policy' @@ -193,9 +152,6 @@ class NodeInfo(Base): timestamp = Column(Integer, nullable=False) alias = Column(String(64), nullable=False) - def get_addresses(self): - return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all() - @staticmethod def from_msg(node_announcement_payload, addresses_already_parsed=False): node_id = node_announcement_payload['node_id'].hex() @@ -281,27 +237,28 @@ class ChannelDB: the lnpeer loop is running from, which will do call in here """ engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True) - DBSession.remove() - DBSession.configure(bind=engine, autoflush=False) + self.DBSession = scoped_session(session_factory) + self.DBSession.remove() + self.DBSession.configure(bind=engine, autoflush=False) Base.metadata.drop_all(engine) Base.metadata.create_all(engine) def update_counts(self): - self.num_channels = DBSession.query(ChannelInfo).count() - self.num_nodes = DBSession.query(NodeInfo).count() + self.num_channels = self.DBSession.query(ChannelInfo).count() + self.num_nodes = self.DBSession.query(NodeInfo).count() def add_recent_peer(self, peer : LNPeerAddr): - addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() + addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() if addr is None: addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) else: addr.last_connected_date = datetime.datetime.now() - DBSession.add(addr) - DBSession.commit() + self.DBSession.add(addr) + self.DBSession.commit() def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): - unshuffled = DBSession \ + unshuffled = self.DBSession \ .query(NodeInfo) \ .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ .limit(200) \ @@ -312,13 +269,13 @@ class ChannelDB: return self.network.run_from_another_thread(self._nodes_get(node_id)) async def _nodes_get(self, node_id): - return DBSession \ + return self.DBSession \ .query(NodeInfo) \ .filter_by(node_id = node_id.hex()) \ .one_or_none() def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: - adr_db = DBSession \ + adr_db = self.DBSession \ .query(Address) \ .filter_by(node_id = node_id.hex()) \ .order_by(Address.last_connected_date.desc()) \ @@ -328,7 +285,7 @@ class ChannelDB: return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id)) def get_recent_peers(self): - return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \ + return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \ .query(Address) \ .select_from(NodeInfo) \ .order_by(Address.last_connected_date.desc()) \ @@ -342,21 +299,21 @@ class ChannelDB: condition = or_( ChannelInfo.node1_id == node_id.hex(), ChannelInfo.node2_id == node_id.hex()) - rows = DBSession.query(ChannelInfo).filter(condition).all() + rows = self.DBSession.query(ChannelInfo).filter(condition).all() return [bytes.fromhex(x.short_channel_id) for x in rows] def missing_short_chan_ids(self) -> Set[int]: - expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id))) - chan_ids_from_policy = set(x[0] for x in DBSession.query(Policy.short_channel_id).filter(expr).all()) + expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) + chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) if chan_ids_from_policy: return chan_ids_from_policy # fetch channels for node_ids missing in node_info. that will also give us node_announcement - expr = not_(ChannelInfo.node1_id.in_(DBSession.query(NodeInfo.node_id))) - chan_ids_from_id1 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) + expr = not_(ChannelInfo.node1_id.in_(self.DBSession.query(NodeInfo.node_id))) + chan_ids_from_id1 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) if chan_ids_from_id1: return chan_ids_from_id1 - expr = not_(ChannelInfo.node2_id.in_(DBSession.query(NodeInfo.node_id))) - chan_ids_from_id2 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) + expr = not_(ChannelInfo.node2_id.in_(self.DBSession.query(NodeInfo.node_id))) + chan_ids_from_id2 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) if chan_ids_from_id2: return chan_ids_from_id2 return set() @@ -366,7 +323,7 @@ class ChannelDB: channel_info = self.get_channel_info(short_id) channel_info.trusted = True channel_info.capacity = capacity - DBSession.commit() + self.DBSession.commit() @profiler def on_channel_announcement(self, msg_payloads, trusted=False): @@ -374,7 +331,7 @@ class ChannelDB: msg_payloads = [msg_payloads] for msg in msg_payloads: short_channel_id = msg['short_channel_id'] - if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): + if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): continue if constants.net.rev_genesis_bytes() != msg['chain_hash']: #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) @@ -384,9 +341,9 @@ class ChannelDB: except UnknownEvenFeatureBits: continue channel_info.trusted = trusted - DBSession.add(channel_info) + self.DBSession.add(channel_info) if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) - DBSession.commit() + self.DBSession.commit() self.network.trigger_callback('ln_status') self.update_counts() @@ -395,7 +352,7 @@ class ChannelDB: if type(msg_payloads) is dict: msg_payloads = [msg_payloads] short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] - channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() + channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} for msg_payload in msg_payloads: short_channel_id = msg_payload['short_channel_id'] @@ -404,19 +361,19 @@ class ChannelDB: channel_info = channel_infos.get(short_channel_id) if not channel_info: continue - channel_info.on_channel_update(msg_payload, trusted=trusted) - DBSession.commit() + self._update_channel_info(channel_info, msg_payload, trusted=trusted) + self.DBSession.commit() @profiler def on_node_announcement(self, msg_payloads): if type(msg_payloads) is dict: msg_payloads = [msg_payloads] - addresses = DBSession.query(Address).all() + addresses = self.DBSession.query(Address).all() have_addr = {} for addr in addresses: have_addr[(addr.node_id, addr.host, addr.port)] = addr - nodes = DBSession.query(NodeInfo).all() + nodes = self.DBSession.query(NodeInfo).all() timestamps = {} for node in nodes: no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] @@ -434,7 +391,7 @@ class ChannelDB: continue if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: continue # ignore - DBSession.add(new_node_info) + self.DBSession.add(new_node_info) for new_addr in addresses: key = (new_addr.node_id, new_addr.host, new_addr.port) old_addr = have_addr.get(key) @@ -444,7 +401,7 @@ class ChannelDB: old_addr.last_connected_date = new_addr.last_connected_date del new_addr else: - DBSession.add(new_addr) + self.DBSession.add(new_addr) have_addr[key] = new_addr # TODO if this message is for a new node, and if we have no associated # channels for this node, we should ignore the message and return here, @@ -453,7 +410,7 @@ class ChannelDB: del nodes, addresses if old_addr: del old_addr - DBSession.commit() + self.DBSession.commit() self.network.trigger_callback('ln_status') self.update_counts() @@ -462,9 +419,10 @@ class ChannelDB: if not start_node_id or not short_channel_id: return None channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: - return channel_info.get_policy_for_node(start_node_id) + return self.get_policy_for_node(channel_info, start_node_id) msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) - if not msg: return None + if not msg: + return None return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): @@ -475,10 +433,10 @@ class ChannelDB: def remove_channel(self, short_channel_id): self.chan_query_for_id(short_channel_id).delete('evaluate') - DBSession.commit() + self.DBSession.commit() def chan_query_for_id(self, short_channel_id) -> Query: - return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) + return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()) def print_graph(self, full_ids=False): # used for debugging. @@ -492,15 +450,15 @@ class ChannelDB: return other if full_ids else other[-4:] self.print_msg('nodes') - for node in DBSession.query(NodeInfo).all(): + for node in self.DBSession.query(NodeInfo).all(): self.print_msg(node) self.print_msg('channels') - for channel_info in DBSession.query(ChannelInfo).all(): + for channel_info in self.DBSession.query(ChannelInfo).all(): node1 = channel_info.node1_id node2 = channel_info.node2_id - direction1 = channel_info.get_policy_for_node(node1) is not None - direction2 = channel_info.get_policy_for_node(node2) is not None + direction1 = self.get_policy_for_node(channel_info, node1) is not None + direction2 = self.get_policy_for_node(channel_info, node2) is not None if direction1 and direction2: direction = 'both' elif direction1: @@ -515,6 +473,44 @@ class ChannelDB: bh2u(node2) if full_ids else bh2u(node2[-4:]), direction)) + def _update_channel_info(self, channel_info, msg: dict, trusted=False): + assert channel_info.short_channel_id == msg['short_channel_id'].hex() + flags = int.from_bytes(msg['channel_flags'], 'big') + direction = flags & FLAG_DIRECTION + node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id + new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id) + old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none() + if not old_policy: + self.DBSession.add(new_policy) + return + if old_policy.timestamp >= new_policy.timestamp: + return # ignore + if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)): + return # ignore + old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta + old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat + old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat + old_policy.fee_base_msat = new_policy.fee_base_msat + old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths + old_policy.channel_flags = new_policy.channel_flags + old_policy.timestamp = new_policy.timestamp + + def get_policy_for_node(self, node) -> Optional['Policy']: + """ + raises when initiator/non-initiator both unequal node + """ + if node.hex() not in (self.node1_id, self.node2_id): + raise Exception("the given node is not a party in this channel") + n1 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none() + if n1: + return n1 + n2 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none() + return n2 + + def get_node_addresses(self, node_info): + return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() + + class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes), ('short_channel_id', bytes), @@ -596,7 +592,7 @@ class LNPathFinder(PrintError): if channel_info is None: return float('inf'), 0 - channel_policy = channel_info.get_policy_for_node(start_node) + channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node) if channel_policy is None: return float('inf'), 0 if channel_policy.is_disabled(): return float('inf'), 0 route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node) diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -444,7 +444,7 @@ class LNWorker(PrintError): else: if not node_info: raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id)) - addrs = node_info.get_addresses() + addrs = self.channel_db.get_node_addresses(node_info) if len(addrs) == 0: raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id)) host, port = self.choose_preferred_address(addrs) @@ -710,7 +710,7 @@ class LNWorker(PrintError): unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys()) if unconnected_nodes: for node in unconnected_nodes: - addrs = node.get_addresses() + addrs = self.channel_db.get_node_addresses(node) if not addrs: continue host, port = self.choose_preferred_address(addrs) @@ -776,7 +776,7 @@ class LNWorker(PrintError): # try random address for node_id node_info = await self.channel_db._nodes_get(chan.node_id) if not node_info: return - addresses = node_info.get_addresses() + addresses = self.channel_db.get_node_addresses(node_info) if not addresses: return adr_obj = random.choice(addresses) host, port = adr_obj.host, adr_obj.port