electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit 2c80996fbf23b612e58e73ff727c2bae1d782e34
parent e7218d798dd1c8556401b1f20d86f9e5474056b6
Author: ThomasV <thomasv@electrum.org>
Date:   Fri, 15 Mar 2019 15:47:31 +0100

lnrouter: fix primary key conflict in Policy update

Diffstat:
Melectrum/lnrouter.py | 44++++++++++++++++++++++----------------------
1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -347,6 +347,7 @@ class ChannelDB(SqlDB): short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} + new_policies = {} for msg_payload in msg_payloads: short_channel_id = msg_payload['short_channel_id'] if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']: @@ -354,7 +355,27 @@ class ChannelDB(SqlDB): channel_info = channel_infos.get(short_channel_id) if not channel_info: continue - self._update_channel_info(channel_info, msg_payload, trusted=trusted) + flags = int.from_bytes(msg_payload['channel_flags'], 'big') + direction = flags & FLAG_DIRECTION + node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id + if not trusted and not verify_sig_for_channel_update(msg_payload, bytes.fromhex(node_id)): + continue + short_channel_id = channel_info.short_channel_id + new_policy = Policy.from_msg(msg_payload, node_id, channel_info.short_channel_id) + old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none() + if old_policy: + if old_policy.timestamp >= new_policy.timestamp: + continue + self.DBSession.delete(old_policy) + p = new_policies.get((short_channel_id, node_id)) + if p and p.timestamp >= new_policy.timestamp: + continue + new_policies[(short_channel_id, node_id)] = new_policy + # commit pending removals + self.DBSession.commit() + # add and commit new policies + for new_policy in new_policies.values(): + self.DBSession.add(new_policy) self.DBSession.commit() @sql @@ -468,27 +489,6 @@ class ChannelDB(SqlDB): bh2u(node2) if full_ids else bh2u(node2[-4:]), direction)) - def _update_channel_info(self, channel_info, msg: dict, trusted=False): - assert channel_info.short_channel_id == msg['short_channel_id'].hex() - flags = int.from_bytes(msg['channel_flags'], 'big') - direction = flags & FLAG_DIRECTION - node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id - new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id) - old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none() - if not old_policy: - self.DBSession.add(new_policy) - return - if old_policy.timestamp >= new_policy.timestamp: - return # ignore - if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)): - return # ignore - old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta - old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat - old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat - old_policy.fee_base_msat = new_policy.fee_base_msat - old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths - old_policy.channel_flags = new_policy.channel_flags - old_policy.timestamp = new_policy.timestamp @sql def get_policy_for_node(self, channel_info, node) -> Optional['Policy']: