electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit e7888a50bedd8133fec92e960c8a44bedf8c1311
parent eae8f1a139ae5ebc411c9a33352744ee1897bf1c
Author: ThomasV <thomasv@electrum.org>
Date:   Sun, 17 Mar 2019 11:54:31 +0100

fix sql conflicts in lnrouter

Diffstat:
Melectrum/lnrouter.py | 138++++++++++++++++++++++++++++++++++++++++---------------------------------------
1 file changed, 70 insertions(+), 68 deletions(-)

diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -23,7 +23,7 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import datetime +import time import random import queue import os @@ -35,7 +35,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK import binascii import base64 -from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean +from sqlalchemy import Column, ForeignKey, Integer, String, Boolean from sqlalchemy.orm.query import Query from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import not_, or_ @@ -81,14 +81,14 @@ class ChannelInfo(Base): trusted = Column(Boolean, nullable=False) @staticmethod - def from_msg(channel_announcement_payload): - features = int.from_bytes(channel_announcement_payload['features'], 'big') + def from_msg(payload): + features = int.from_bytes(payload['features'], 'big') validate_features(features) - channel_id = channel_announcement_payload['short_channel_id'].hex() - node_id_1 = channel_announcement_payload['node_id_1'].hex() - node_id_2 = channel_announcement_payload['node_id_2'].hex() + channel_id = payload['short_channel_id'].hex() + node_id_1 = payload['node_id_1'].hex() + node_id_2 = payload['node_id_2'].hex() assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2] - msg_payload_hex = encode_msg('channel_announcement', **channel_announcement_payload).hex() + msg_payload_hex = encode_msg('channel_announcement', **payload).hex() capacity_sat = None return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1, node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex, @@ -109,17 +109,17 @@ class Policy(Base): fee_base_msat = Column(Integer, nullable=False) fee_proportional_millionths = Column(Integer, nullable=False) channel_flags = Column(Integer, nullable=False) - timestamp = Column(DateTime, nullable=False) + timestamp = Column(Integer, nullable=False) @staticmethod - def from_msg(channel_update_payload, start_node, short_channel_id): - cltv_expiry_delta = channel_update_payload['cltv_expiry_delta'] - htlc_minimum_msat = channel_update_payload['htlc_minimum_msat'] - fee_base_msat = channel_update_payload['fee_base_msat'] - fee_proportional_millionths = channel_update_payload['fee_proportional_millionths'] - channel_flags = channel_update_payload['channel_flags'] - timestamp = channel_update_payload['timestamp'] - htlc_maximum_msat = channel_update_payload.get('htlc_maximum_msat') # optional + def from_msg(payload, start_node, short_channel_id): + cltv_expiry_delta = payload['cltv_expiry_delta'] + htlc_minimum_msat = payload['htlc_minimum_msat'] + fee_base_msat = payload['fee_base_msat'] + fee_proportional_millionths = payload['fee_proportional_millionths'] + channel_flags = payload['channel_flags'] + timestamp = payload['timestamp'] + htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big") htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big") @@ -127,7 +127,7 @@ class Policy(Base): fee_base_msat = int.from_bytes(fee_base_msat, "big") fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big") channel_flags = int.from_bytes(channel_flags, "big") - timestamp = datetime.datetime.fromtimestamp(int.from_bytes(timestamp, "big")) + timestamp = int.from_bytes(timestamp, "big") return Policy(start_node=start_node, short_channel_id=short_channel_id, @@ -150,17 +150,16 @@ class NodeInfo(Base): alias = Column(String(64), nullable=False) @staticmethod - def from_msg(node_announcement_payload, addresses_already_parsed=False): - node_id = node_announcement_payload['node_id'].hex() - features = int.from_bytes(node_announcement_payload['features'], "big") + def from_msg(payload): + node_id = payload['node_id'].hex() + features = int.from_bytes(payload['features'], "big") validate_features(features) - if not addresses_already_parsed: - addresses = NodeInfo.parse_addresses_field(node_announcement_payload['addresses']) - else: - addresses = node_announcement_payload['addresses'] - alias = node_announcement_payload['alias'].rstrip(b'\x00').hex() - timestamp = datetime.datetime.fromtimestamp(int.from_bytes(node_announcement_payload['timestamp'], "big")) - return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [Address(host=host, port=port, node_id=node_id, last_connected_date=datetime.datetime.now()) for host, port in addresses] + addresses = NodeInfo.parse_addresses_field(payload['addresses']) + alias = payload['alias'].rstrip(b'\x00').hex() + timestamp = int.from_bytes(payload['timestamp'], "big") + now = int(time.time()) + return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [ + Address(host=host, port=port, node_id=node_id, last_connected_date=now) for host, port in addresses] @staticmethod def parse_addresses_field(addresses_field): @@ -207,7 +206,7 @@ class Address(Base): node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True) host = Column(String(256), primary_key=True) port = Column(Integer, primary_key=True) - last_connected_date = Column(DateTime(), nullable=False) + last_connected_date = Column(Integer(), nullable=False) @@ -235,12 +234,14 @@ class ChannelDB(SqlDB): @sql def add_recent_peer(self, peer: LNPeerAddr): - addr = self.DBSession.query(Address).filter_by(node_id = peer.pubkey.hex()).one_or_none() - if addr is None: - addr = Address(node_id = peer.pubkey.hex(), host = peer.host, port = peer.port, last_connected_date = datetime.datetime.now()) + now = int(time.time()) + node_id = peer.pubkey.hex() + addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() + if addr: + addr.last_connected_date = now else: - addr.last_connected_date = datetime.datetime.now() - self.DBSession.add(addr) + addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) + self.DBSession.add(addr) self.DBSession.commit() @sql @@ -317,25 +318,31 @@ class ChannelDB(SqlDB): self.DBSession.commit() @sql - @profiler + #@profiler def on_channel_announcement(self, msg_payloads, trusted=False): if type(msg_payloads) is dict: msg_payloads = [msg_payloads] + new_channels = {} for msg in msg_payloads: - short_channel_id = msg['short_channel_id'] - if self.DBSession.query(ChannelInfo).filter_by(short_channel_id = bh2u(short_channel_id)).count(): + short_channel_id = bh2u(msg['short_channel_id']) + if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count(): continue if constants.net.rev_genesis_bytes() != msg['chain_hash']: - #self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) + self.print_error("ChanAnn has unexpected chain_hash {}".format(bh2u(msg_payload['chain_hash']))) continue try: channel_info = ChannelInfo.from_msg(msg) except UnknownEvenFeatureBits: + self.print_error("unknown feature bits") continue channel_info.trusted = trusted + new_channels[short_channel_id] = channel_info + if not trusted: + self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) + for channel_info in new_channels.values(): self.DBSession.add(channel_info) - if not trusted: self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) self.DBSession.commit() + self.print_error('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads))) self._update_counts() self.network.trigger_callback('ln_status') @@ -379,21 +386,13 @@ class ChannelDB(SqlDB): self.DBSession.commit() @sql - @profiler + #@profiler def on_node_announcement(self, msg_payloads): if type(msg_payloads) is dict: msg_payloads = [msg_payloads] - addresses = self.DBSession.query(Address).all() - have_addr = {} - for addr in addresses: - have_addr[(addr.node_id, addr.host, addr.port)] = addr - - nodes = self.DBSession.query(NodeInfo).all() - timestamps = {} - for node in nodes: - no_millisecs = node.timestamp[:len("0000-00-00 00:00:00")] - timestamps[bfh(node.node_id)] = datetime.datetime.strptime(no_millisecs, "%Y-%m-%d %H:%M:%S") old_addr = None + new_nodes = {} + new_addresses = {} for msg_payload in msg_payloads: pubkey = msg_payload['node_id'] signature = msg_payload['signature'] @@ -401,30 +400,33 @@ class ChannelDB(SqlDB): if not ecc.verify_signature(pubkey, signature, h): continue try: - new_node_info, addresses = NodeInfo.from_msg(msg_payload) + node_info, node_addresses = NodeInfo.from_msg(msg_payload) except UnknownEvenFeatureBits: continue - if timestamps.get(pubkey) and timestamps[pubkey] >= new_node_info.timestamp: - continue # ignore - self.DBSession.add(new_node_info) - for new_addr in addresses: - key = (new_addr.node_id, new_addr.host, new_addr.port) - old_addr = have_addr.get(key) - if old_addr: - # since old_addr is embedded in have_addr, - # it will still live when commmit is called - old_addr.last_connected_date = new_addr.last_connected_date - del new_addr - else: - self.DBSession.add(new_addr) - have_addr[key] = new_addr + node_id = node_info.node_id + node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none() + if node and node.timestamp >= node_info.timestamp: + continue + node = new_nodes.get(node_id) + if node and node.timestamp >= node_info.timestamp: + continue + new_nodes[node_id] = node_info + for addr in node_addresses: + new_addresses[(addr.node_id,addr.host,addr.port)] = addr + + self.print_error("on_node_announcements: %d/%d"%(len(new_nodes), len(msg_payloads))) + for node_info in new_nodes.values(): + self.DBSession.add(node_info) + for new_addr in new_addresses.values(): + old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() + if old_addr: + old_addr.last_connected_date = new_addr.last_connected_date + else: + self.DBSession.add(new_addr) # TODO if this message is for a new node, and if we have no associated # channels for this node, we should ignore the message and return here, # to mitigate DOS. but race condition: the channels we have for this # node, might be under verification in self.ca_verifier, what then? - del nodes, addresses - if old_addr: - del old_addr self.DBSession.commit() self._update_counts() self.network.trigger_callback('ln_status')