electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit eb4e6bb0de493877c0c7553928b4408787d79c92
parent f4b3d7627d99777d9c33758fe5874d1360cac1ab
Author: ThomasV <thomasv@electrum.org>
Date:   Thu, 16 May 2019 19:00:44 +0200

improve filter_channel_updates
blacklist channels that do not really get updated

Diffstat:
Melectrum/lnpeer.py | 38++++++++++++++++++++++++--------------
Melectrum/lnrouter.py | 142++++++++++++++++++++++++++++++++++++++++++++++---------------------------------
2 files changed, 106 insertions(+), 74 deletions(-)

diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -241,12 +241,14 @@ class Peer(Logger): self.verify_node_announcements(node_anns) self.channel_db.on_node_announcement(node_anns) # channel updates - good, bad = self.channel_db.filter_channel_updates(chan_upds) - if bad: - self.logger.info(f'adding {len(bad)} unknown channel ids') - self.network.lngossip.add_new_ids(bad) - self.verify_channel_updates(good) - self.channel_db.on_channel_update(good) + orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds, max_age=self.network.lngossip.max_age) + if orphaned: + self.logger.info(f'adding {len(orphaned)} unknown channel ids') + self.network.lngossip.add_new_ids(orphaned) + if good: + self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds)}') + self.verify_channel_updates(good) + self.channel_db.update_policies(good, to_delete) # refresh gui if chan_anns or node_anns or chan_upds: self.network.lngossip.refresh_gui() @@ -273,7 +275,7 @@ class Peer(Logger): short_channel_id = payload['short_channel_id'] if constants.net.rev_genesis_bytes() != payload['chain_hash']: raise Exception('wrong chain hash') - if not verify_sig_for_channel_update(payload, payload['node_id']): + if not verify_sig_for_channel_update(payload, payload['start_node']): raise BaseException('verify error') @log_exceptions @@ -990,21 +992,29 @@ class Peer(Logger): OnionFailureCode.EXPIRY_TOO_SOON: 2, OnionFailureCode.CHANNEL_DISABLED: 4, } - offset = failure_codes.get(code) - if offset: + if code in failure_codes: + offset = failure_codes[code] channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:] message_type, payload = decode_msg(channel_update) payload['raw'] = channel_update - try: - self.logger.info(f"trying to apply channel update on our db {payload}") - self.channel_db.add_channel_update(payload) - self.logger.info("successfully applied channel update on our db") - except NotFoundChanAnnouncementForUpdate: + orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload]) + if good: + self.verify_channel_updates(good) + self.channel_db.update_policies(good, to_delete) + self.logger.info("applied channel update on our db") + elif orphaned: # maybe it is a private channel (and data in invoice was outdated) self.logger.info("maybe channel update is for private channel?") start_node_id = route[sender_idx].node_id self.channel_db.add_channel_update_for_private_channel(payload, start_node_id) + elif expired: + blacklist = True + elif deprecated: + self.logger.info(f'channel update is not more recent. blacklisting channel') + blacklist = True else: + blacklist = True + if blacklist: # blacklist channel after reporter node # TODO this should depend on the error (even more granularity) # also, we need finer blacklisting (directed edges; nodes) diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py @@ -114,22 +114,16 @@ class Policy(Base): timestamp = Column(Integer, nullable=False) @staticmethod - def from_msg(payload, start_node, short_channel_id): - cltv_expiry_delta = payload['cltv_expiry_delta'] - htlc_minimum_msat = payload['htlc_minimum_msat'] - fee_base_msat = payload['fee_base_msat'] - fee_proportional_millionths = payload['fee_proportional_millionths'] - channel_flags = payload['channel_flags'] - timestamp = payload['timestamp'] - htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional - - cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big") - htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big") - htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None - fee_base_msat = int.from_bytes(fee_base_msat, "big") - fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big") - channel_flags = int.from_bytes(channel_flags, "big") - timestamp = int.from_bytes(timestamp, "big") + def from_msg(payload): + cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big") + htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big") + htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None + fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big") + fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big") + channel_flags = int.from_bytes(payload['channel_flags'], "big") + timestamp = int.from_bytes(payload['timestamp'], "big") + start_node = payload['start_node'].hex() + short_channel_id = payload['short_channel_id'].hex() return Policy(start_node=start_node, short_channel_id=short_channel_id, @@ -341,71 +335,98 @@ class ChannelDB(SqlDB): r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one() return r.max_timestamp or 0 + def print_change(self, old_policy, new_policy): + # print what changed between policies + if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta: + self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}') + if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat: + self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}') + if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat: + self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}') + if old_policy.fee_base_msat != new_policy.fee_base_msat: + self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}') + if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths: + self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}') + if old_policy.channel_flags != new_policy.channel_flags: + self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}') + @sql - def get_info_for_updates(self, msg_payloads): - short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] + def get_info_for_updates(self, payloads): + short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads] channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} return channel_infos + @sql + def get_policies_for_updates(self, payloads): + out = {} + for payload in payloads: + short_channel_id = payload['short_channel_id'].hex() + start_node = payload['start_node'].hex() + policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none() + if policy: + out[short_channel_id+start_node] = policy + return out + @profiler - def filter_channel_updates(self, payloads): - # add 'node_id' to payload - channel_infos = self.get_info_for_updates(payloads) + def filter_channel_updates(self, payloads, max_age=None): + orphaned = [] # no channel announcement for channel update + expired = [] # update older than two weeks + deprecated = [] # update older than database entry + good = [] # good updates + to_delete = [] # database entries to delete + # filter orphaned and expired first known = [] - unknown = [] + now = int(time.time()) + channel_infos = self.get_info_for_updates(payloads) for payload in payloads: short_channel_id = payload['short_channel_id'] + timestamp = int.from_bytes(payload['timestamp'], "big") + if max_age and now - timestamp > max_age: + expired.append(short_channel_id) + continue channel_info = channel_infos.get(short_channel_id) if not channel_info: - unknown.append(short_channel_id) + orphaned.append(short_channel_id) continue flags = int.from_bytes(payload['channel_flags'], 'big') direction = flags & FLAG_DIRECTION - node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id) - payload['node_id'] = node_id + start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id + payload['start_node'] = bfh(start_node) known.append(payload) - return known, unknown + # compare updates to existing database entries + old_policies = self.get_policies_for_updates(known) + for payload in known: + timestamp = int.from_bytes(payload['timestamp'], "big") + start_node = payload['start_node'].hex() + short_channel_id = payload['short_channel_id'].hex() + old_policy = old_policies.get(short_channel_id+start_node) + if old_policy: + if timestamp <= old_policy.timestamp: + deprecated.append(short_channel_id) + else: + good.append(payload) + to_delete.append(old_policy) + else: + good.append(payload) + return orphaned, expired, deprecated, good, to_delete def add_channel_update(self, payload): - # called in tests/test_lnrouter - good, bad = self.filter_channel_updates([payload]) - assert len(bad) == 0 - self.on_channel_update(good) + orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload]) + assert len(good) == 1 + self.update_policies(good, to_delete) @sql @profiler - def on_channel_update(self, msg_payloads): - now = int(time.time()) - if type(msg_payloads) is dict: - msg_payloads = [msg_payloads] - new_policies = {} - for msg_payload in msg_payloads: - short_channel_id = msg_payload['short_channel_id'].hex() - node_id = msg_payload['node_id'].hex() - new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id) - # must not be older than two weeks - if new_policy.timestamp < now - 14*24*3600: - continue - old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none() - if old_policy: - if old_policy.timestamp >= new_policy.timestamp: - continue - self.DBSession.delete(old_policy) - p = new_policies.get((short_channel_id, node_id)) - if p and p.timestamp >= new_policy.timestamp: - continue - new_policies[(short_channel_id, node_id)] = new_policy - # commit pending removals + def update_policies(self, to_add, to_delete): + for policy in to_delete: + self.DBSession.delete(policy) self.DBSession.commit() - # add and commit new policies - for new_policy in new_policies.values(): - self.DBSession.add(new_policy) + for payload in to_add: + policy = Policy.from_msg(payload) + self.DBSession.add(policy) self.DBSession.commit() - if new_policies: - self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}') - #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}') - self._update_counts() + self._update_counts() @sql @profiler @@ -454,7 +475,7 @@ class ChannelDB(SqlDB): msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) if not msg: return None - return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB + return Policy.from_msg(msg) # won't actually be written to DB @sql @profiler @@ -496,6 +517,7 @@ class ChannelDB(SqlDB): if not verify_sig_for_channel_update(msg_payload, start_node_id): return # ignore short_channel_id = msg_payload['short_channel_id'] + msg_payload['start_node'] = start_node_id self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload @sql