commit 2c80996fbf23b612e58e73ff727c2bae1d782e34
parent e7218d798dd1c8556401b1f20d86f9e5474056b6
Author: ThomasV <thomasv@electrum.org>
Date: Fri, 15 Mar 2019 15:47:31 +0100
lnrouter: fix primary key conflict in Policy update
Diffstat:
1 file changed, 22 insertions(+), 22 deletions(-)
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -347,6 +347,7 @@ class ChannelDB(SqlDB):
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
+ new_policies = {}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id']
if constants.net.rev_genesis_bytes() != msg_payload['chain_hash']:
@@ -354,7 +355,27 @@ class ChannelDB(SqlDB):
channel_info = channel_infos.get(short_channel_id)
if not channel_info:
continue
- self._update_channel_info(channel_info, msg_payload, trusted=trusted)
+ flags = int.from_bytes(msg_payload['channel_flags'], 'big')
+ direction = flags & FLAG_DIRECTION
+ node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
+ if not trusted and not verify_sig_for_channel_update(msg_payload, bytes.fromhex(node_id)):
+ continue
+ short_channel_id = channel_info.short_channel_id
+ new_policy = Policy.from_msg(msg_payload, node_id, channel_info.short_channel_id)
+ old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
+ if old_policy:
+ if old_policy.timestamp >= new_policy.timestamp:
+ continue
+ self.DBSession.delete(old_policy)
+ p = new_policies.get((short_channel_id, node_id))
+ if p and p.timestamp >= new_policy.timestamp:
+ continue
+ new_policies[(short_channel_id, node_id)] = new_policy
+ # commit pending removals
+ self.DBSession.commit()
+ # add and commit new policies
+ for new_policy in new_policies.values():
+ self.DBSession.add(new_policy)
self.DBSession.commit()
@sql
@@ -468,27 +489,6 @@ class ChannelDB(SqlDB):
bh2u(node2) if full_ids else bh2u(node2[-4:]),
direction))
- def _update_channel_info(self, channel_info, msg: dict, trusted=False):
- assert channel_info.short_channel_id == msg['short_channel_id'].hex()
- flags = int.from_bytes(msg['channel_flags'], 'big')
- direction = flags & FLAG_DIRECTION
- node_id = channel_info.node1_id if direction == 0 else channel_info.node2_id
- new_policy = Policy.from_msg(msg, node_id, channel_info.short_channel_id)
- old_policy = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node=node_id).one_or_none()
- if not old_policy:
- self.DBSession.add(new_policy)
- return
- if old_policy.timestamp >= new_policy.timestamp:
- return # ignore
- if not trusted and not verify_sig_for_channel_update(msg, bytes.fromhex(node_id)):
- return # ignore
- old_policy.cltv_expiry_delta = new_policy.cltv_expiry_delta
- old_policy.htlc_minimum_msat = new_policy.htlc_minimum_msat
- old_policy.htlc_maximum_msat = new_policy.htlc_maximum_msat
- old_policy.fee_base_msat = new_policy.fee_base_msat
- old_policy.fee_proportional_millionths = new_policy.fee_proportional_millionths
- old_policy.channel_flags = new_policy.channel_flags
- old_policy.timestamp = new_policy.timestamp
@sql
def get_policy_for_node(self, channel_info, node) -> Optional['Policy']: