commit eb4e6bb0de493877c0c7553928b4408787d79c92
parent f4b3d7627d99777d9c33758fe5874d1360cac1ab
Author: ThomasV <thomasv@electrum.org>
Date: Thu, 16 May 2019 19:00:44 +0200
improve filter_channel_updates
blacklist channels whose updates are expired or not actually newer
Diffstat:
2 files changed, 106 insertions(+), 74 deletions(-)
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -241,12 +241,14 @@ class Peer(Logger):
self.verify_node_announcements(node_anns)
self.channel_db.on_node_announcement(node_anns)
# channel updates
- good, bad = self.channel_db.filter_channel_updates(chan_upds)
- if bad:
- self.logger.info(f'adding {len(bad)} unknown channel ids')
- self.network.lngossip.add_new_ids(bad)
- self.verify_channel_updates(good)
- self.channel_db.on_channel_update(good)
+ orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds, max_age=self.network.lngossip.max_age)
+ if orphaned:
+ self.logger.info(f'adding {len(orphaned)} unknown channel ids')
+ self.network.lngossip.add_new_ids(orphaned)
+ if good:
+ self.logger.debug(f'filter_channel_updates: {len(good)}/{len(chan_upds)}')
+ self.verify_channel_updates(good)
+ self.channel_db.update_policies(good, to_delete)
# refresh gui
if chan_anns or node_anns or chan_upds:
self.network.lngossip.refresh_gui()
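For reference, the hunk above consumes the new five-bucket return value of filter_channel_updates (defined in the lnrouter hunk further down). A minimal usage sketch with hypothetical inputs; the bucket meanings are taken from the comments in that hunk:

orphaned, expired, deprecated, good, to_delete = channel_db.filter_channel_updates(chan_upds, max_age=max_age)
# orphaned:   no channel_announcement known for the update -> ask gossip for those ids
# expired:    update timestamp older than max_age -> dropped
# deprecated: update not newer than the stored Policy row -> dropped
# good:       updates to verify and write to the database
# to_delete:  stale Policy rows superseded by the entries in `good`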
@@ -273,7 +275,7 @@ class Peer(Logger):
short_channel_id = payload['short_channel_id']
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash')
- if not verify_sig_for_channel_update(payload, payload['node_id']):
+ if not verify_sig_for_channel_update(payload, payload['start_node']):
raise BaseException('verify error')
@log_exceptions
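The 'start_node' key replaces 'node_id': it names the endpoint that signed the update, selected by the direction bit of channel_flags. A minimal sketch of that derivation, mirroring the lnrouter hunk below (payload, channel_info and FLAG_DIRECTION as defined there):

direction = int.from_bytes(payload['channel_flags'], 'big') & FLAG_DIRECTION
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
payload['start_node'] = bfh(start_node)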
@@ -990,21 +992,29 @@ class Peer(Logger):
OnionFailureCode.EXPIRY_TOO_SOON: 2,
OnionFailureCode.CHANNEL_DISABLED: 4,
}
- offset = failure_codes.get(code)
- if offset:
+ if code in failure_codes:
+ offset = failure_codes[code]
channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
message_type, payload = decode_msg(channel_update)
payload['raw'] = channel_update
- try:
- self.logger.info(f"trying to apply channel update on our db {payload}")
- self.channel_db.add_channel_update(payload)
- self.logger.info("successfully applied channel update on our db")
- except NotFoundChanAnnouncementForUpdate:
+ orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
+ if good:
+ self.verify_channel_updates(good)
+ self.channel_db.update_policies(good, to_delete)
+ self.logger.info("applied channel update on our db")
+ elif orphaned:
# maybe it is a private channel (and data in invoice was outdated)
self.logger.info("maybe channel update is for private channel?")
start_node_id = route[sender_idx].node_id
self.channel_db.add_channel_update_for_private_channel(payload, start_node_id)
+ elif expired:
+ blacklist = True
+ elif deprecated:
+ self.logger.info('channel update is not more recent; blacklisting channel')
+ blacklist = True
else:
+ blacklist = True
+ if blacklist:
# blacklist channel after reporter node
# TODO this should depend on the error (even more granularity)
# also, we need finer blacklisting (directed edges; nodes)
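Condensed, the branch above maps each bucket to one of three outcomes when an HTLC failure carries a channel_update. A hedged sketch of that decision, as a hypothetical helper that is not part of the diff:

def classify_failure_update(orphaned, expired, deprecated, good):
    # mirrors the if/elif chain in the hunk above
    if good:
        return 'apply'      # update our channel database
    if orphaned:
        return 'private'    # maybe a private channel; invoice data may be outdated
    return 'blacklist'      # expired, deprecated, or otherwise unusable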
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -114,22 +114,16 @@ class Policy(Base):
timestamp = Column(Integer, nullable=False)
@staticmethod
- def from_msg(payload, start_node, short_channel_id):
- cltv_expiry_delta = payload['cltv_expiry_delta']
- htlc_minimum_msat = payload['htlc_minimum_msat']
- fee_base_msat = payload['fee_base_msat']
- fee_proportional_millionths = payload['fee_proportional_millionths']
- channel_flags = payload['channel_flags']
- timestamp = payload['timestamp']
- htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional
-
- cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
- htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
- htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
- fee_base_msat = int.from_bytes(fee_base_msat, "big")
- fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
- channel_flags = int.from_bytes(channel_flags, "big")
- timestamp = int.from_bytes(timestamp, "big")
+ def from_msg(payload):
+ cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
+ htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
+ htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
+ fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
+ fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
+ channel_flags = int.from_bytes(payload['channel_flags'], "big")
+ timestamp = int.from_bytes(payload['timestamp'], "big")
+ start_node = payload['start_node'].hex()
+ short_channel_id = payload['short_channel_id'].hex()
return Policy(start_node=start_node,
short_channel_id=short_channel_id,
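Policy.from_msg now reads start_node and short_channel_id from the payload itself; filter_channel_updates fills in start_node beforehand. A hedged construction example; field widths follow BOLT 7 and all values are hypothetical:

import time
payload = {
    'cltv_expiry_delta': (144).to_bytes(2, 'big'),
    'htlc_minimum_msat': (1000).to_bytes(8, 'big'),
    'fee_base_msat': (1000).to_bytes(4, 'big'),
    'fee_proportional_millionths': (100).to_bytes(4, 'big'),
    'channel_flags': (0).to_bytes(1, 'big'),
    'timestamp': int(time.time()).to_bytes(4, 'big'),
    'start_node': bytes(33),       # hypothetical 33-byte node pubkey
    'short_channel_id': bytes(8),  # hypothetical 8-byte short channel id
}
policy = Policy.from_msg(payload)  # htlc_maximum_msat is optional and omitted here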
@@ -341,71 +335,98 @@ class ChannelDB(SqlDB):
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
return r.max_timestamp or 0
+ def print_change(self, old_policy, new_policy):
+ # log which policy fields changed between the old and new entry
+ if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
+ self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
+ if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
+ self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
+ if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
+ self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
+ if old_policy.fee_base_msat != new_policy.fee_base_msat:
+ self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
+ if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
+ self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
+ if old_policy.channel_flags != new_policy.channel_flags:
+ self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
+
@sql
- def get_info_for_updates(self, msg_payloads):
- short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
+ def get_info_for_updates(self, payloads):
+ short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
return channel_infos
+ @sql
+ def get_policies_for_updates(self, payloads):
+ out = {}
+ for payload in payloads:
+ short_channel_id = payload['short_channel_id'].hex()
+ start_node = payload['start_node'].hex()
+ policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
+ if policy:
+ out[short_channel_id+start_node] = policy
+ return out
+
@profiler
- def filter_channel_updates(self, payloads):
- # add 'node_id' to payload
- channel_infos = self.get_info_for_updates(payloads)
+ def filter_channel_updates(self, payloads, max_age=None):
+ orphaned = [] # no channel announcement for channel update
+ expired = [] # update older than max_age
+ deprecated = [] # update not newer than our database entry
+ good = [] # good updates
+ to_delete = [] # database entries to delete
+ # filter orphaned and expired first
known = []
- unknown = []
+ now = int(time.time())
+ channel_infos = self.get_info_for_updates(payloads)
for payload in payloads:
short_channel_id = payload['short_channel_id']
+ timestamp = int.from_bytes(payload['timestamp'], "big")
+ if max_age and now - timestamp > max_age:
+ expired.append(short_channel_id)
+ continue
channel_info = channel_infos.get(short_channel_id)
if not channel_info:
- unknown.append(short_channel_id)
+ orphaned.append(short_channel_id)
continue
flags = int.from_bytes(payload['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
- node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id)
- payload['node_id'] = node_id
+ start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
+ payload['start_node'] = bfh(start_node)
known.append(payload)
- return known, unknown
+ # compare updates to existing database entries
+ old_policies = self.get_policies_for_updates(known)
+ for payload in known:
+ timestamp = int.from_bytes(payload['timestamp'], "big")
+ start_node = payload['start_node'].hex()
+ short_channel_id = payload['short_channel_id'].hex()
+ old_policy = old_policies.get(short_channel_id+start_node)
+ if old_policy:
+ if timestamp <= old_policy.timestamp:
+ deprecated.append(short_channel_id)
+ else:
+ good.append(payload)
+ to_delete.append(old_policy)
+ else:
+ good.append(payload)
+ return orphaned, expired, deprecated, good, to_delete
def add_channel_update(self, payload):
- # called in tests/test_lnrouter
- good, bad = self.filter_channel_updates([payload])
- assert len(bad) == 0
- self.on_channel_update(good)
+ orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
+ assert len(good) == 1
+ self.update_policies(good, to_delete)
@sql
@profiler
- def on_channel_update(self, msg_payloads):
- now = int(time.time())
- if type(msg_payloads) is dict:
- msg_payloads = [msg_payloads]
- new_policies = {}
- for msg_payload in msg_payloads:
- short_channel_id = msg_payload['short_channel_id'].hex()
- node_id = msg_payload['node_id'].hex()
- new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
- # must not be older than two weeks
- if new_policy.timestamp < now - 14*24*3600:
- continue
- old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
- if old_policy:
- if old_policy.timestamp >= new_policy.timestamp:
- continue
- self.DBSession.delete(old_policy)
- p = new_policies.get((short_channel_id, node_id))
- if p and p.timestamp >= new_policy.timestamp:
- continue
- new_policies[(short_channel_id, node_id)] = new_policy
- # commit pending removals
+ def update_policies(self, to_add, to_delete):
+ for policy in to_delete:
+ self.DBSession.delete(policy)
self.DBSession.commit()
- # add and commit new policies
- for new_policy in new_policies.values():
- self.DBSession.add(new_policy)
+ for payload in to_add:
+ policy = Policy.from_msg(payload)
+ self.DBSession.add(policy)
self.DBSession.commit()
- if new_policies:
- self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
- #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
- self._update_counts()
+ self._update_counts()
@sql
@profiler
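A worked example of the pipeline the hunk above defines, with three hypothetical updates:

# upd_a: known scid, timestamp newer than the stored Policy row
# upd_b: unknown scid, no channel_announcement seen yet
# upd_c: timestamp older than now - max_age
orphaned, expired, deprecated, good, to_delete = db.filter_channel_updates([upd_a, upd_b, upd_c], max_age=14 * 24 * 3600)
# -> orphaned lists upd_b's scid, expired lists upd_c's scid,
#    good == [upd_a], to_delete holds the row upd_a supersedes
db.update_policies(good, to_delete)  # delete+commit stale rows, then add+commit new ones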
@@ -454,7 +475,7 @@ class ChannelDB(SqlDB):
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if not msg:
return None
- return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
+ return Policy.from_msg(msg) # won't actually be written to DB
@sql
@profiler
@@ -496,6 +517,7 @@ class ChannelDB(SqlDB):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore
short_channel_id = msg_payload['short_channel_id']
+ msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
@sql
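For private channels the update never reaches the Policy table: it stays in memory keyed by (start_node_id, short_channel_id), and 'start_node' is stored on the payload so Policy.from_msg can later build a transient object. A minimal sketch of the read path, per the -454 hunk above:

msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if msg:
    policy = Policy.from_msg(msg)  # transient Policy; never written to the DB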