commit 9f188c087c379078443cfb589a1b5fbd0146dd21
parent 95a217478932b75503732ce4864621b7112629c1
Author: ThomasV <thomasv@electrum.org>
Date: Tue, 5 Mar 2019 11:22:00 +0100
Flatten the structure of lnrouter, so that DBSession is not used outside of ChannelDB
Diffstat:
2 files changed, 82 insertions(+), 86 deletions(-)
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -70,7 +70,6 @@ def validate_features(features : int):
Base = declarative_base()
session_factory = sessionmaker()
-DBSession = scoped_session(session_factory)
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
@@ -88,16 +87,12 @@ class ChannelInfo(Base):
def from_msg(channel_announcement_payload):
features = int.from_bytes(channel_announcement_payload['features'], 'big')
validate_features(features)
-
channel_id = channel_announcement_payload['short_channel_id'].hex()
node_id_1 = channel_announcement_payload['node_id_1'].hex()
node_id_2 = channel_announcement_payload['node_id_2'].hex()
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
-
msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
-
capacity_sat = None
-
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
trusted = False)
@@ -106,42 +101,6 @@ class ChannelInfo(Base):
def msg_payload(self):
return bytes.fromhex(self.msg_payload_hex)
- def on_channel_update(self, msg: dict, trusted=False):
- assert self.short_channel_id == msg['short_channel_id'].hex()
- flags = int.from_bytes(msg['channel_flags'], 'big')
- direction = flags & FLAG_DIRECTION
- if direction == 0:
- node_id = self.node1_id
- else:
- node_id = self.node2_id
- new_policy = Policy.from_msg(msg, node_id, self.short_channel_id)
- old_policy = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node=node_id).one_or_none()
- if not old_policy:
- DBSession.add(new_policy)
- return
- if old_policy.timestamp >= new_policy.timestamp:
- return # ignore
- if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
- return # ignore
- old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta
- old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat
- old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat
- old_policy.fee_base_msat = new_policy.fee_base_msat
- old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
- old_policy.channel_flags = new_policy.channel_flags
- old_policy.timestamp = new_policy.timestamp
-
- def get_policy_for_node(self, node) -> Optional['Policy']:
- """
- raises when initiator/non-initiator both unequal node
- """
- if node.hex() not in (self.node1_id, self.node2_id):
- raise Exception("the given node is not a party in this channel")
- n1 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none()
- if n1:
- return n1
- n2 = DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
- return n2
class Policy(Base):
__tablename__ = 'policy'
@@ -193,9 +152,6 @@ class NodeInfo(Base):
timestamp = Column(Integer, nullable=False)
alias = Column(String(64), nullable=False)
- def get_addresses(self):
- return DBSession.query(Address).join(NodeInfo).filter_by(node_id = self.node_id).all()
-
@staticmethod
def from_msg(node_announcement_payload, addresses_already_parsed=False):
node_id = node_announcement_payload['node_id'].hex()
@@ -281,27 +237,28 @@ class ChannelDB:
the lnpeer loop is running from, which will do call in here
"""
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
- DBSession.remove()
- DBSession.configure(bind=engine, autoflush=False)
+ self.DBSession = scoped_session(session_factory)
+ self.DBSession.remove()
+ self.DBSession.configure(bind=engine, autoflush=False)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
def update_counts(self):
- self.num_channels = DBSession.query(ChannelInfo).count()
- self.num_nodes = DBSession.query(NodeInfo).count()
+ self.num_channels = self.DBSession.query(ChannelInfo).count()
+ self.num_nodes = self.DBSession.query(NodeInfo).count()
def add_recent_peer(self, peer : LNPeerAddr):
- addr = DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
+ addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
if addr is None:
addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
else:
addr.last_connected_date = datetime.datetime.now()
- DBSession.add(addr)
- DBSession.commit()
+ self.DBSession.add(addr)
+ self.DBSession.commit()
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
- unshuffled = DBSession \
+ unshuffled = self.DBSession \
.query(NodeInfo) \
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
.limit(200) \
@@ -312,13 +269,13 @@ class ChannelDB:
return self.network.run_from_another_thread(self._nodes_get(node_id))
async def _nodes_get(self, node_id):
- return DBSession \
+ return self.DBSession \
.query(NodeInfo) \
.filter_by(node_id = node_id.hex()) \
.one_or_none()
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
- adr_db = DBSession \
+ adr_db = self.DBSession \
.query(Address) \
.filter_by(node_id = node_id.hex()) \
.order_by(Address.last_connected_date.desc()) \
@@ -328,7 +285,7 @@ class ChannelDB:
return LNPeerAddr(adr_db.host, adr_db.port, bytes.fromhex(adr_db.node_id))
def get_recent_peers(self):
- return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in DBSession \
+ return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in self.DBSession \
.query(Address) \
.select_from(NodeInfo) \
.order_by(Address.last_connected_date.desc()) \
@@ -342,21 +299,21 @@ class ChannelDB:
condition = or_(
ChannelInfo.node1_id == node_id.hex(),
ChannelInfo.node2_id == node_id.hex())
- rows = DBSession.query(ChannelInfo).filter(condition).all()
+ rows = self.DBSession.query(ChannelInfo).filter(condition).all()
return [bytes.fromhex(x.short_channel_id) for x in rows]
def missing_short_chan_ids(self) -> Set[int]:
- expr = not_(Policy.short_channel_id.in_(DBSession.query(ChannelInfo.short_channel_id)))
- chan_ids_from_policy = set(x[0] for x in DBSession.query(Policy.short_channel_id).filter(expr).all())
+ expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
+ chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
if chan_ids_from_policy:
return chan_ids_from_policy
# fetch channels for node_ids missing in node_info. that will also give us node_announcement
- expr = not_(ChannelInfo.node1_id.in_(DBSession.query(NodeInfo.node_id)))
- chan_ids_from_id1 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
+ expr = not_(ChannelInfo.node1_id.in_(self.DBSession.query(NodeInfo.node_id)))
+ chan_ids_from_id1 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
if chan_ids_from_id1:
return chan_ids_from_id1
- expr = not_(ChannelInfo.node2_id.in_(DBSession.query(NodeInfo.node_id)))
- chan_ids_from_id2 = set(x[0] for x in DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
+ expr = not_(ChannelInfo.node2_id.in_(self.DBSession.query(NodeInfo.node_id)))
+ chan_ids_from_id2 = set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
if chan_ids_from_id2:
return chan_ids_from_id2
return set()
@@ -366,7 +323,7 @@ class ChannelDB:
channel_info = self.get_channel_info(short_id)
channel_info.trusted = True
channel_info.capacity = capacity
- DBSession.commit()
+ self.DBSession.commit()
@profiler
def on_channel_announcement(self, msg_payloads, trusted=False):
@@ -374,7 +331,7 @@ class ChannelDB:
msg_payloads = [msg_payloads]
for msg in msg_payloads:
short_channel_id = msg['short_channel_id']
- if DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
+ if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
#self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
@@ -384,9 +341,9 @@ class ChannelDB:
except UnknownEvenFeatureBits:
continue
channel_info.trusted = trusted
- DBSession.add(channel_info)
+ self.DBSession.add(channel_info)
if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
- DBSession.commit()
+ self.DBSession.commit()
self.network.trigger_callback('ln_status')
self.update_counts()
@@ -395,7 +352,7 @@ class ChannelDB:
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
- channel_infos_list = DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
+ channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id']
@@ -404,19 +361,19 @@ class ChannelDB:
channel_info = channel_infos.get(short_channel_id)
if not channel_info:
continue
- channel_info.on_channel_update(msg_payload, trusted=trusted)
- DBSession.commit()
+ self._update_channel_info(channel_info, msg_payload, trusted=trusted)
+ self.DBSession.commit()
@profiler
def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
- addresses = DBSession.query(Address).all()
+ addresses = self.DBSession.query(Address).all()
have_addr = {}
for addr in addresses:
have_addr[(addr.node_id, addr.host, addr.port)] = addr
- nodes = DBSession.query(NodeInfo).all()
+ nodes = self.DBSession.query(NodeInfo).all()
timestamps = {}
for node in nodes:
no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
@@ -434,7 +391,7 @@ class ChannelDB:
continue
if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
continue # ignore
- DBSession.add(new_node_info)
+ self.DBSession.add(new_node_info)
for new_addr in addresses:
key = (new_addr.node_id, new_addr.host, new_addr.port)
old_addr = have_addr.get(key)
@@ -444,7 +401,7 @@ class ChannelDB:
old_addr.last_connected_date = new_addr.last_connected_date
del new_addr
else:
- DBSession.add(new_addr)
+ self.DBSession.add(new_addr)
have_addr[key] = new_addr
# TODO if this message is for a new node, and if we have no associated
# channels for this node, we should ignore the message and return here,
@@ -453,7 +410,7 @@ class ChannelDB:
del nodes, addresses
if old_addr:
del old_addr
- DBSession.commit()
+ self.DBSession.commit()
self.network.trigger_callback('ln_status')
self.update_counts()
@@ -462,9 +419,10 @@ class ChannelDB:
if not start_node_id or not short_channel_id: return None
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None:
- return channel_info.get_policy_for_node(start_node_id)
+ return self.get_policy_for_node(channel_info, start_node_id)
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
- if not msg: return None
+ if not msg:
+ return None
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
@@ -475,10 +433,10 @@ class ChannelDB:
def remove_channel(self, short_channel_id):
self.chan_query_for_id(short_channel_id).delete('evaluate')
- DBSession.commit()
+ self.DBSession.commit()
def chan_query_for_id(self, short_channel_id) -> Query:
- return DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
+ return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
def print_graph(self, full_ids=False):
# used for debugging.
@@ -492,15 +450,15 @@ class ChannelDB:
return other if full_ids else other[-4:]
self.print_msg('nodes')
- for node in DBSession.query(NodeInfo).all():
+ for node in self.DBSession.query(NodeInfo).all():
self.print_msg(node)
self.print_msg('channels')
- for channel_info in DBSession.query(ChannelInfo).all():
+ for channel_info in self.DBSession.query(ChannelInfo).all():
node1 = channel_info.node1_id
node2 = channel_info.node2_id
- direction1 = channel_info.get_policy_for_node(node1) is not None
- direction2 = channel_info.get_policy_for_node(node2) is not None
+ direction1 = self.get_policy_for_node(channel_info, node1) is not None
+ direction2 = self.get_policy_for_node(channel_info, node2) is not None
if direction1 and direction2:
direction = 'both'
elif direction1:
@@ -515,6 +473,44 @@ class ChannelDB:
bh2u(node2) if full_ids else bh2u(node2[-4:]),
direction))
+ def _update_channel_info(self, channel_info, msg: dict, trusted=False):
+ assert channel_info.short_channel_id == msg['short_channel_id'].hex()
+ flags = int.from_bytes(msg['channel_flags'], 'big')
+ direction = flags & FLAG_DIRECTION
+ node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
+ new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id)
+ old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none()
+ if not old_policy:
+ self.DBSession.add(new_policy)
+ return
+ if old_policy.timestamp >= new_policy.timestamp:
+ return # ignore
+ if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
+ return # ignore
+ old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta
+ old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat
+ old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat
+ old_policy.fee_base_msat = new_policy.fee_base_msat
+ old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
+ old_policy.channel_flags = new_policy.channel_flags
+ old_policy.timestamp = new_policy.timestamp
+
+ def get_policy_for_node(self, node) -> Optional['Policy']:
+ """
+ raises when initiator/non-initiator both unequal node
+ """
+ if node.hex() not in (self.node1_id, self.node2_id):
+ raise Exception("the given node is not a party in this channel")
+ n1 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node1_id).one_or_none()
+ if n1:
+ return n1
+ n2 = self.DBSession.query(Policy).filter_by(short_channel_id = self.short_channel_id, start_node = self.node2_id).one_or_none()
+ return n2
+
+ def get_node_addresses(self, node_info):
+ return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
+
+
class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
('short_channel_id', bytes),
@@ -596,7 +592,7 @@ class LNPathFinder(PrintError):
if channel_info is None:
return float('inf'), 0
- channel_policy = channel_info.get_policy_for_node(start_node)
+ channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node)
if channel_policy is None: return float('inf'), 0
if channel_policy.is_disabled(): return float('inf'), 0
route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -444,7 +444,7 @@ class LNWorker(PrintError):
else:
if not node_info:
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
- addrs = node_info.get_addresses()
+ addrs = self.channel_db.get_node_addresses(node_info)
if len(addrs) == 0:
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
host, port = self.choose_preferred_address(addrs)
@@ -710,7 +710,7 @@ class LNWorker(PrintError):
unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
if unconnected_nodes:
for node in unconnected_nodes:
- addrs = node.get_addresses()
+ addrs = self.channel_db.get_node_addresses(node)
if not addrs:
continue
host, port = self.choose_preferred_address(addrs)
@@ -776,7 +776,7 @@ class LNWorker(PrintError):
# try random address for node_id
node_info = await self.channel_db._nodes_get(chan.node_id)
if not node_info: return
- addresses = node_info.get_addresses()
+ addresses = self.channel_db.get_node_addresses(node_info)
if not addresses: return
adr_obj = random.choice(addresses)
host, port = adr_obj.host, adr_obj.port