commit e7888a50bedd8133fec92e960c8a44bedf8c1311
parent eae8f1a139ae5ebc411c9a33352744ee1897bf1c
Author: ThomasV <thomasv@electrum.org>
Date: Sun, 17 Mar 2019 11:54:31 +0100
fix sql conflicts in lnrouter
Diffstat:
M | electrum/lnrouter.py | | | 138 | ++++++++++++++++++++++++++++++++++++++++--------------------------------------- |
1 file changed, 70 insertions(+), 68 deletions(-)
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -23,7 +23,7 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-import datetime
+import time
import random
import queue
import os
@@ -35,7 +35,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii
import base64
-from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
+from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
@@ -81,14 +81,14 @@ class ChannelInfo(Base):
trusted = Column(Boolean, nullable=False)
@staticmethod
- def from_msg(channel_announcement_payload):
- features = int.from_bytes(channel_announcement_payload['features'], 'big')
+ def from_msg(payload):
+ features = int.from_bytes(payload['features'], 'big')
validate_features(features)
- channel_id = channel_announcement_payload['short_channel_id'].hex()
- node_id_1 = channel_announcement_payload['node_id_1'].hex()
- node_id_2 = channel_announcement_payload['node_id_2'].hex()
+ channel_id = payload['short_channel_id'].hex()
+ node_id_1 = payload['node_id_1'].hex()
+ node_id_2 = payload['node_id_2'].hex()
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
- msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex()
+ msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
capacity_sat = None
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
@@ -109,17 +109,17 @@ class Policy(Base):
fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False)
channel_flags = Column(Integer, nullable=False)
- timestamp = Column(DateTime, nullable=False)
+ timestamp = Column(Integer, nullable=False)
@staticmethod
- def from_msg(channel_update_payload, start_node, short_channel_id):
- cltv_expiry_delta = channel_update_payload['cltv_expiry_delta']
- htlc_minimum_msat = channel_update_payload['htlc_minimum_msat']
- fee_base_msat = channel_update_payload['fee_base_msat']
- fee_proportional_millionths = channel_update_payload['fee_proportional_millionths']
- channel_flags = channel_update_payload['channel_flags']
- timestamp = channel_update_payload['timestamp']
- htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional
+ def from_msg(payload, start_node, short_channel_id):
+ cltv_expiry_delta = payload['cltv_expiry_delta']
+ htlc_minimum_msat = payload['htlc_minimum_msat']
+ fee_base_msat = payload['fee_base_msat']
+ fee_proportional_millionths = payload['fee_proportional_millionths']
+ channel_flags = payload['channel_flags']
+ timestamp = payload['timestamp']
+ htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
@@ -127,7 +127,7 @@ class Policy(Base):
fee_base_msat = int.from_bytes(fee_base_msat, "big")
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
channel_flags = int.from_bytes(channel_flags, "big")
- timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big"))
+ timestamp = int.from_bytes(timestamp, "big")
return Policy(start_node=start_node,
short_channel_id=short_channel_id,
@@ -150,17 +150,16 @@ class NodeInfo(Base):
alias = Column(String(64), nullable=False)
@staticmethod
- def from_msg(node_announcement_payload, addresses_already_parsed=False):
- node_id = node_announcement_payload['node_id'].hex()
- features = int.from_bytes(node_announcement_payload['features'], "big")
+ def from_msg(payload):
+ node_id = payload['node_id'].hex()
+ features = int.from_bytes(payload['features'], "big")
validate_features(features)
- if not addresses_already_parsed:
- addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses'])
- else:
- addresses = node_announcement_payload['addresses']
- alias = node_announcement_payload['alias'].rstrip(b'\x00').hex()
- timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big"))
- return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses]
+ addresses = NodeInfo.parse_addresses_field(payload['addresses'])
+ alias = payload['alias'].rstrip(b'\x00').hex()
+ timestamp = int.from_bytes(payload['timestamp'], "big")
+ now = int(time.time())
+ return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
+ Address(host=host, port=port, node_id=node_id, last_connected_date=now) for host, port in addresses]
@staticmethod
def parse_addresses_field(addresses_field):
@@ -207,7 +206,7 @@ class Address(Base):
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
host = Column(String(256), primary_key=True)
port = Column(Integer, primary_key=True)
- last_connected_date = Column(DateTime(), nullable=False)
+ last_connected_date = Column(Integer(), nullable=False)
@@ -235,12 +234,14 @@ class ChannelDB(SqlDB):
@sql
def add_recent_peer(self, peer: LNPeerAddr):
- addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none()
- if addr is None:
- addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now())
+ now = int(time.time())
+ node_id = peer.pubkey.hex()
+ addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
+ if addr:
+ addr.last_connected_date = now
else:
- addr.last_connected_date = datetime.datetime.now()
- self.DBSession.add(addr)
+ addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
+ self.DBSession.add(addr)
self.DBSession.commit()
@sql
@@ -317,25 +318,31 @@ class ChannelDB(SqlDB):
self.DBSession.commit()
@sql
- @profiler
+ #@profiler
def on_channel_announcement(self, msg_payloads, trusted=False):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
+ new_channels = {}
for msg in msg_payloads:
- short_channel_id = msg['short_channel_id']
- if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count():
+ short_channel_id = bh2u(msg['short_channel_id'])
+ if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
- #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
+ self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash'])))
continue
try:
channel_info = ChannelInfo.from_msg(msg)
except UnknownEvenFeatureBits:
+ self.print_error("unknown feature bits")
continue
channel_info.trusted = trusted
+ new_channels[short_channel_id] = channel_info
+ if not trusted:
+ self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
+ for channel_info in new_channels.values():
self.DBSession.add(channel_info)
- if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
self.DBSession.commit()
+ self.print_error('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
self._update_counts()
self.network.trigger_callback('ln_status')
@@ -379,21 +386,13 @@ class ChannelDB(SqlDB):
self.DBSession.commit()
@sql
- @profiler
+ #@profiler
def on_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
- addresses = self.DBSession.query(Address).all()
- have_addr = {}
- for addr in addresses:
- have_addr[(addr.node_id, addr.host, addr.port)] = addr
-
- nodes = self.DBSession.query(NodeInfo).all()
- timestamps = {}
- for node in nodes:
- no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")]
- timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S")
old_addr = None
+ new_nodes = {}
+ new_addresses = {}
for msg_payload in msg_payloads:
pubkey = msg_payload['node_id']
signature = msg_payload['signature']
@@ -401,30 +400,33 @@ class ChannelDB(SqlDB):
if not ecc.verify_signature(pubkey, signature, h):
continue
try:
- new_node_info, addresses = NodeInfo.from_msg(msg_payload)
+ node_info, node_addresses = NodeInfo.from_msg(msg_payload)
except UnknownEvenFeatureBits:
continue
- if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp:
- continue # ignore
- self.DBSession.add(new_node_info)
- for new_addr in addresses:
- key = (new_addr.node_id, new_addr.host, new_addr.port)
- old_addr = have_addr.get(key)
- if old_addr:
- # since old_addr is embedded in have_addr,
- # it will still live when commmit is called
- old_addr.last_connected_date = new_addr.last_connected_date
- del new_addr
- else:
- self.DBSession.add(new_addr)
- have_addr[key] = new_addr
+ node_id = node_info.node_id
+ node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
+ if node and node.timestamp >= node_info.timestamp:
+ continue
+ node = new_nodes.get(node_id)
+ if node and node.timestamp >= node_info.timestamp:
+ continue
+ new_nodes[node_id] = node_info
+ for addr in node_addresses:
+ new_addresses[(addr.node_id,addr.host,addr.port)] = addr
+
+ self.print_error("on_node_announcements: %d/%d"%(len(new_nodes), len(msg_payloads)))
+ for node_info in new_nodes.values():
+ self.DBSession.add(node_info)
+ for new_addr in new_addresses.values():
+ old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
+ if old_addr:
+ old_addr.last_connected_date = new_addr.last_connected_date
+ else:
+ self.DBSession.add(new_addr)
# TODO if this message is for a new node, and if we have no associated
# channels for this node, we should ignore the message and return here,
# to mitigate DOS. but race condition: the channels we have for this
# node, might be under verification in self.ca_verifier, what then?
- del nodes, addresses
- if old_addr:
- del old_addr
self.DBSession.commit()
self._update_counts()
self.network.trigger_callback('ln_status')