commit 3442e51fac536067ed5bd62091d555ecdcae092e
parent 945e1dc4ee8df0b32a92f698f212e05deb539a30
Author: Janus <ysangkok@gmail.com>
Date: Wed, 20 Feb 2019 21:06:37 +0100
sqlite in lnrouter: remove useless InDB suffix
Diffstat:
1 file changed, 32 insertions(+), 32 deletions(-)
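
For context, a minimal, hypothetical usage sketch (not part of this commit) of how the renamed SQLAlchemy models read after the change. It only uses names visible in the diff below (DBSession, ChannelInfo, NodeInfo, Address) and mirrors the queries in ChannelDB.update_counts and get_last_good_address; treat it as an illustration of the rename, not as code from the repository.

# hypothetical sketch: querying the renamed ORM models through the module-level scoped session
from electrum.lnrouter import DBSession, ChannelInfo, NodeInfo, Address

def graph_summary():
    # same counting queries as ChannelDB.update_counts, with the new class names
    num_channels = DBSession.query(ChannelInfo).count()
    num_nodes = DBSession.query(NodeInfo).count()
    return num_channels, num_nodes

def last_seen_address(node_id_hex: str):
    # most recently connected address for a node, along the lines of get_last_good_address
    return DBSession \
        .query(Address) \
        .filter_by(node_id = node_id_hex) \
        .order_by(Address.last_connected_date.desc()) \
        .first()
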
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -75,7 +75,7 @@ DBSession = scoped_session(session_factory)
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
-class ChannelInfoInDB(Base):
+class ChannelInfo(Base):
__tablename__ = 'channel_info'
short_channel_id = Column(String(64), primary_key=True)
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
@@ -98,7 +98,7 @@ class ChannelInfoInDB(Base):
capacity_sat = None
- return ChannelInfoInDB(short_channel_id = channel_id, node1_id = node_id_1,
+ return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
trusted = False)
@@ -186,7 +186,7 @@ class Policy(Base):
def is_disabled(self):
return self.channel_flags & FLAG_DISABLE
-class NodeInfoInDB(Base):
+class NodeInfo(Base):
__tablename__ = 'node_info'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
features = Column(Integer, nullable=False)
@@ -194,7 +194,7 @@ class NodeInfoInDB(Base):
alias = Column(String(64), nullable=False)
def get_addresses(self):
- return DBSession.query(AddressInDB).join(NodeInfoInDB).filter_by(node_id = self.node_id).all()
+ return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all()
@staticmethod
def from_msg(node_announcement_payload, addresses_already_parsed=False):
@@ -202,12 +202,12 @@ class NodeInfoInDB(Base):
features = int.from_bytes(node_announcement_payload['features'], "big")
validate_features(features)
if not addresses_already_parsed:
- addresses = NodeInfoInDB.parse_addresses_field(node_announcement_payload['addresses'])
+ addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses'])
else:
addresses = node_announcement_payload['addresses']
alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
- return NodeInfoInDB(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [AddressInDB(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
+ return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
@staticmethod
def parse_addresses_field(addresses_field):
@@ -249,7 +249,7 @@ class NodeInfoInDB(Base):
break
return addresses
-class AddressInDB(Base):
+class Address(Base):
__tablename__ = 'address'
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
host = Column(String(256), primary_key=True)
@@ -288,13 +288,13 @@ class ChannelDB:
Base.metadata.create_all(engine)
def update_counts(self):
- self.num_channels = DBSession.query(ChannelInfoInDB).count()
- self.num_nodes = DBSession.query(NodeInfoInDB).count()
+ self.num_channels = DBSession.query(ChannelInfo).count()
+ self.num_nodes = DBSession.query(NodeInfo).count()
def add_recent_peer(self, peer : LNPeerAddr):
- addr = DBSession.query(AddressInDB).filter_by(node_id = peer.pubkey.hex()).one_or_none()
+ addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
if addr is None:
- addr = AddressInDB(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
+ addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
else:
addr.last_connected_date = datetime.datetime.now()
DBSession.add(addr)
@@ -302,8 +302,8 @@ class ChannelDB:
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
unshuffled = DBSession \
- .query(NodeInfoInDB) \
- .filter(not_(NodeInfoInDB.node_id.in_(x.hex() for x in node_ids_bytes))) \
+ .query(NodeInfo) \
+ .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
.limit(200) \
.all()
return random.sample(unshuffled, len(unshuffled))
@@ -313,15 +313,15 @@ class ChannelDB:
async def _nodes_get(self, node_id):
return DBSession \
- .query(NodeInfoInDB) \
+ .query(NodeInfo) \
.filter_by(node_id = node_id.hex()) \
.one_or_none()
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
adr_db = DBSession \
- .query(AddressInDB) \
+ .query(Address) \
.filter_by(node_id = node_id.hex()) \
- .order_by(AddressInDB.last_connected_date.desc()) \
+ .order_by(Address.last_connected_date.desc()) \
.one_or_none()
if not adr_db:
return None
@@ -329,9 +329,9 @@ class ChannelDB:
def get_recent_peers(self):
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \
- .query(AddressInDB) \
- .select_from(NodeInfoInDB) \
- .order_by(AddressInDB.last_connected_date.desc()) \
+ .query(Address) \
+ .select_from(NodeInfo) \
+ .order_by(Address.last_connected_date.desc()) \
.limit(self.NUM_MAX_RECENT_PEERS)]
def get_channel_info(self, channel_id: bytes):
@@ -340,13 +340,13 @@ class ChannelDB:
def get_channels_for_node(self, node_id):
"""Returns the set of channels that have node_id as one of the endpoints."""
condition = or_(
- ChannelInfoInDB.node1_id == node_id.hex(),
- ChannelInfoInDB.node2_id == node_id.hex())
- rows = DBSession.query(ChannelInfoInDB).filter(condition).all()
+ ChannelInfo.node1_id == node_id.hex(),
+ ChannelInfo.node2_id == node_id.hex())
+ rows = DBSession.query(ChannelInfo).filter(condition).all()
return [bytes.fromhex(x.short_channel_id) for x in rows]
def missing_short_chan_ids(self) -> Set[int]:
- expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfoInDB.short_channel_id)))
+ expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id)))
return set(DBSession.query(Policy.short_channel_id).filter(expr).all())
def add_verified_channel_info(self, short_id, capacity):
@@ -362,13 +362,13 @@ class ChannelDB:
msg_payloads = [msg_payloads]
for msg in msg_payloads:
short_channel_id = msg['short_channel_id']
- if DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = bh2u(short_channel_id)).count():
+ if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
continue
try:
- channel_info = ChannelInfoInDB.from_msg(msg)
+ channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits:
continue
channel_info.trusted = trusted
@@ -383,7 +383,7 @@ class ChannelDB:
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
- channel_infos_list = DBSession.query(ChannelInfoInDB).filter(ChannelInfoInDB.short_channel_id.in_(short_channel_ids)).all()
+ channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id']
@@ -397,12 +397,12 @@ class ChannelDB:
def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
- addresses = DBSession.query(AddressInDB).all()
+ addresses = DBSession.query(Address).all()
have_addr = {}
for addr in addresses:
have_addr[(addr.node_id, addr.host, addr.port)] = addr
- nodes = DBSession.query(NodeInfoInDB).all()
+ nodes = DBSession.query(NodeInfo).all()
timestamps = {}
for node in nodes:
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
@@ -415,7 +415,7 @@ class ChannelDB:
if not ecc.verify_signature(pubkey, signature, h):
continue
try:
- new_node_info, addresses = NodeInfoInDB.from_msg(msg_payload)
+ new_node_info, addresses = NodeInfo.from_msg(msg_payload)
except UnknownEvenFeatureBits:
continue
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
@@ -464,7 +464,7 @@ class ChannelDB:
DBSession.commit()
def chan_query_for_id(self, short_channel_id) -> Query:
- return DBSession.query(ChannelInfoInDB).filter_by(short_channel_id = short_channel_id.hex())
+ return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
def print_graph(self, full_ids=False):
# used for debugging.
@@ -478,11 +478,11 @@ class ChannelDB:
return other if full_ids else other[-4:]
self.print_msg('nodes')
- for node in DBSession.query(NodeInfoInDB).all():
+ for node in DBSession.query(NodeInfo).all():
self.print_msg(node)
self.print_msg('channels')
- for channel_info in DBSession.query(ChannelInfoInDB).all():
+ for channel_info in DBSession.query(ChannelInfo).all():
node1 = channel_info.node1_id
node2 = channel_info.node2_id
direction1 = channel_info.get_policy_for_node(node1) is not None