electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 53564f249646ed10d1b75c90caac691c8dd077df
parent cdb72509a7a1db6bdc8305582ffe917c3f59d2a7
Author: SomberNight <somber.night@protonmail.com>
Date:   Tue,  3 Mar 2020 02:15:32 +0100

ChannelDB: rm NodeAddress class, just use LNPeerAddr

Diffstat:
Melectrum/channel_db.py | 36+++++++++++++++++-------------------
1 file changed, 17 insertions(+), 19 deletions(-)

diff --git a/electrum/channel_db.py b/electrum/channel_db.py @@ -138,22 +138,28 @@ class NodeInfo(NamedTuple): alias: str @staticmethod - def from_msg(payload) -> Tuple['NodeInfo', Sequence['NodeAddress']]: + def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]: node_id = payload['node_id'] features = int.from_bytes(payload['features'], "big") validate_features(features) addresses = NodeInfo.parse_addresses_field(payload['addresses']) + peer_addrs = [] + for host, port in addresses: + try: + peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id)) + except ValueError: + pass alias = payload['alias'].rstrip(b'\x00') try: alias = alias.decode('utf8') except: alias = '' timestamp = int.from_bytes(payload['timestamp'], "big") - return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [ - NodeAddress(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses] + node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias) + return node_info, peer_addrs @staticmethod - def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['NodeAddress']]: + def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]: payload_dict = decode_msg(raw)[1] return NodeInfo.from_msg(payload_dict) @@ -198,13 +204,6 @@ class NodeInfo(NamedTuple): return addresses -class NodeAddress(NamedTuple): - node_id: bytes - host: str - port: int - last_connected_date: Optional[int] - - class CategorizedChannelUpdates(NamedTuple): orphaned: List # no channel announcement for channel update expired: List # update older than two weeks @@ -295,7 +294,7 @@ class ChannelDB(SqlDB): self._recent_peers.remove(node_id) self._recent_peers.insert(0, node_id) self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS] - self.save_node_address(node_id, peer, now) + self.save_node_address(peer, now) def get_200_randomly_sorted_nodes_not_in(self, node_ids): with self.lock: @@ -473,18 +472,18 @@ class ChannelDB(SqlDB): c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg]) @sql - def save_node_address(self, node_id, peer, now): + def save_node_address(self, peer: LNPeerAddr, now): c = self.conn.cursor() - c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now)) + c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (peer.pubkey, peer.host, peer.port, now)) @sql - def save_node_addresses(self, node_id, node_addresses): + def save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]): c = self.conn.cursor() for addr in node_addresses: - c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port)) + c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.pubkey, addr.host, addr.port)) r = c.fetchall() if r == []: - c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0)) + c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.pubkey, addr.host, addr.port, 0)) def verify_channel_update(self, payload): short_channel_id = payload['short_channel_id'] @@ -497,7 +496,6 @@ class ChannelDB(SqlDB): def add_node_announcement(self, msg_payloads): if type(msg_payloads) is dict: msg_payloads = [msg_payloads] - old_addr = None new_nodes = {} for msg_payload in msg_payloads: try: @@ -523,7 +521,7 @@ class ChannelDB(SqlDB): with self.lock: for addr in node_addresses: self._addresses[node_id].add((addr.host, addr.port, 0)) - self.save_node_addresses(node_id, node_addresses) + self.save_node_addresses(node_addresses) self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) self.update_counts()