commit 1011245c5e29a128edf4629d024c4cf345190282
parent 95376226e8567adb4b7bd4eea298e623f58c26e4
Author: ThomasV <thomasv@electrum.org>
Date: Mon, 13 May 2019 14:30:02 +0200
LNGossip: sync channel db using query_channel_range
Diffstat:
M | electrum/lnpeer.py | | | 84 | ++++++++++++++++++++++++++++++++++++++++++++++--------------------------------- |
M | electrum/lnrouter.py | | | 57 | +++++++++++++++++++++++++++++++++++++++++++++++++-------- |
M | electrum/lnworker.py | | | 86 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------- |
3 files changed, 168 insertions(+), 59 deletions(-)
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -57,9 +57,7 @@ class Peer(Logger):
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase):
self.initialized = asyncio.Event()
- self.node_anns = []
- self.chan_anns = []
- self.chan_upds = []
+ self.querying_lock = asyncio.Lock()
self.transport = transport
self.pubkey = pubkey
self.lnworker = lnworker
@@ -70,6 +68,7 @@ class Peer(Logger):
self.lnwatcher = lnworker.network.lnwatcher
self.channel_db = lnworker.network.channel_db
self.ping_time = 0
+ self.reply_channel_range = asyncio.Queue()
self.shutdown_received = defaultdict(asyncio.Future)
self.channel_accepted = defaultdict(asyncio.Queue)
self.channel_reestablished = defaultdict(asyncio.Future)
@@ -89,7 +88,7 @@ class Peer(Logger):
def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str
- self.logger.info(f"Sending {message_name.upper()}")
+ self.logger.debug(f"Sending {message_name.upper()}")
self.transport.send_bytes(encode_msg(message_name, **kwargs))
async def initialize(self):
@@ -177,13 +176,13 @@ class Peer(Logger):
self.initialized.set()
def on_node_announcement(self, payload):
- self.node_anns.append(payload)
+ self.channel_db.node_anns.append(payload)
def on_channel_update(self, payload):
- self.chan_upds.append(payload)
+ self.channel_db.chan_upds.append(payload)
def on_channel_announcement(self, payload):
- self.chan_anns.append(payload)
+ self.channel_db.chan_anns.append(payload)
def on_announcement_signatures(self, payload):
channel_id = payload['channel_id']
@@ -207,15 +206,11 @@ class Peer(Logger):
@handle_disconnect
async def main_loop(self):
async with aiorpcx.TaskGroup() as group:
- await group.spawn(self._gossip_loop())
await group.spawn(self._message_loop())
# kill group if the peer times out
await group.spawn(asyncio.wait_for(self.initialized.wait(), 10))
- @log_exceptions
- async def _gossip_loop(self):
- await self.initialized.wait()
- timestamp = self.channel_db.get_last_timestamp()
+ def request_gossip(self, timestamp=0):
if timestamp == 0:
self.logger.info('requesting whole channel graph')
else:
@@ -225,28 +220,47 @@ class Peer(Logger):
chain_hash=constants.net.rev_genesis_bytes(),
first_timestamp=timestamp,
timestamp_range=b'\xff'*4)
- while True:
- await asyncio.sleep(5)
- if self.node_anns:
- self.channel_db.on_node_announcement(self.node_anns)
- self.node_anns = []
- if self.chan_anns:
- self.channel_db.on_channel_announcement(self.chan_anns)
- self.chan_anns = []
- if self.chan_upds:
- self.channel_db.on_channel_update(self.chan_upds)
- self.chan_upds = []
- # todo: enable when db is fixed
- #need_to_get = sorted(self.channel_db.missing_short_chan_ids())
- #if need_to_get and not self.receiving_channels:
- # self.logger.info(f'missing {len(need_to_get)} channels')
- # zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100])))
- # self.send_message(
- # 'query_short_channel_ids',
- # chain_hash=constants.net.rev_genesis_bytes(),
- # len=1+len(zlibencoded),
- # encoded_short_ids=b'\x01' + zlibencoded)
- # self.receiving_channels = True
+
+ def query_channel_range(self, index, num):
+ self.logger.info(f'query channel range')
+ self.send_message(
+ 'query_channel_range',
+ chain_hash=constants.net.rev_genesis_bytes(),
+ first_blocknum=index,
+ number_of_blocks=num)
+
+ def encode_short_ids(self, ids):
+ return chr(1) + zlib.compress(bfh(''.join(ids)))
+
+ def decode_short_ids(self, encoded):
+ if encoded[0] == 0:
+ decoded = encoded[1:]
+ elif encoded[0] == 1:
+ decoded = zlib.decompress(encoded[1:])
+ else:
+ raise BaseException('zlib')
+ ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)]
+ return ids
+
+ def on_reply_channel_range(self, payload):
+ first = int.from_bytes(payload['first_blocknum'], 'big')
+ num = int.from_bytes(payload['number_of_blocks'], 'big')
+ complete = bool(payload['complete'])
+ encoded = payload['encoded_short_ids']
+ ids = self.decode_short_ids(encoded)
+ self.reply_channel_range.put_nowait((first, num, complete, ids))
+
+ async def query_short_channel_ids(self, ids, compressed=True):
+ await self.querying_lock.acquire()
+ #self.logger.info('querying {} short_channel_ids'.format(len(ids)))
+ s = b''.join(ids)
+ encoded = zlib.compress(s) if compressed else s
+ prefix = b'\x01' if compressed else b'\x00'
+ self.send_message(
+ 'query_short_channel_ids',
+ chain_hash=constants.net.rev_genesis_bytes(),
+ len=1+len(encoded),
+ encoded_short_ids=prefix+encoded)
async def _message_loop(self):
try:
@@ -260,7 +274,7 @@ class Peer(Logger):
self.ping_if_required()
def on_reply_short_channel_ids_end(self, payload):
- self.receiving_channels = False
+ self.querying_lock.release()
def close_and_cleanup(self):
try:
diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py
@@ -223,6 +223,20 @@ class ChannelDB(SqlDB):
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
self.update_counts()
+ self.node_anns = []
+ self.chan_anns = []
+ self.chan_upds = []
+
+ def process_gossip(self):
+ if self.node_anns:
+ self.on_node_announcement(self.node_anns)
+ self.node_anns = []
+ if self.chan_anns:
+ self.on_channel_announcement(self.chan_anns)
+ self.chan_anns = []
+ if self.chan_upds:
+ self.on_channel_update(self.chan_upds)
+ self.chan_upds = []
@sql
def update_counts(self):
@@ -232,7 +246,32 @@ class ChannelDB(SqlDB):
self.num_channels = self.DBSession.query(ChannelInfo).count()
self.num_policies = self.DBSession.query(Policy).count()
self.num_nodes = self.DBSession.query(NodeInfo).count()
- self.logger.info(f'update counts {self.num_channels} {self.num_policies}')
+
+ @sql
+ @profiler
+ def purge_unknown_channels(self, channel_ids):
+ ids = [x.hex() for x in channel_ids]
+ missing = self.DBSession \
+ .query(ChannelInfo) \
+ .filter(not_(ChannelInfo.short_channel_id.in_(ids))) \
+ .all()
+ if missing:
+ self.logger.info("deleting {} channels".format(len(missing)))
+ delete_query = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(ids)))
+ self.DBSession.execute(delete_query)
+ self.DBSession.commit()
+
+ @sql
+ @profiler
+ def compare_channels(self, channel_ids):
+ ids = [x.hex() for x in channel_ids]
+ # I need to get the unknown, and also the channels that need refresh
+ known = self.DBSession \
+ .query(ChannelInfo) \
+ .filter(ChannelInfo.short_channel_id.in_(ids)) \
+ .all()
+ known = [bfh(r.short_channel_id) for r in known]
+ return known
@sql
def add_recent_peer(self, peer: LNPeerAddr):
@@ -276,12 +315,14 @@ class ChannelDB(SqlDB):
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
@sql
- def missing_short_chan_ids(self) -> Set[int]:
+ def missing_channel_announcements(self) -> Set[int]:
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
- chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
- if chan_ids_from_policy:
- return chan_ids_from_policy
- return set()
+ return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
+
+ @sql
+ def missing_channel_updates(self) -> Set[int]:
+ expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
+ return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
@sql
def add_verified_channel_info(self, short_id, capacity):
@@ -316,8 +357,8 @@ class ChannelDB(SqlDB):
for channel_info in new_channels.values():
self.DBSession.add(channel_info)
self.DBSession.commit()
- #self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
self._update_counts()
+ self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
self.network.trigger_callback('ln_status')
@sql
@@ -370,7 +411,7 @@ class ChannelDB(SqlDB):
self.DBSession.commit()
if new_policies:
self.logger.info(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
- self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
+ #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
self._update_counts()
@sql
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -133,9 +133,7 @@ class LNWorker(Logger):
self.channel_db = self.network.channel_db
self._last_tried_peer = {} # LNPeerAddr -> unix timestamp
self._add_peers_from_config()
- # wait until we see confirmations
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
- self.first_timestamp_requested = None
def _add_peers_from_config(self):
peer_list = self.config.get('lightning_peers', [])
@@ -215,9 +213,24 @@ class LNWorker(Logger):
self.logger.info('got {} ln peers from dns seed'.format(len(peers)))
return peers
+ @staticmethod
+ def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
+ assert len(addr_list) >= 1
+ # choose first one that is an IP
+ for addr_in_db in addr_list:
+ host = addr_in_db.host
+ port = addr_in_db.port
+ if is_ip_address(host):
+ return host, port
+ # otherwise choose one at random
+ # TODO maybe filter out onion if not on tor?
+ choice = random.choice(addr_list)
+ return choice.host, choice.port
class LNGossip(LNWorker):
+ # height of first channel announcements
+ first_block = 497000
def __init__(self, network):
seed = os.urandom(32)
@@ -226,6 +239,61 @@ class LNGossip(LNWorker):
super().__init__(xprv)
self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ
+ def start_network(self, network: 'Network'):
+ super().start_network(network)
+ asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.gossip_task()), self.network.asyncio_loop)
+
+ async def gossip_task(self):
+ req_index = self.first_block
+ req_num = self.network.get_local_height() - req_index
+ while len(self.peers) == 0:
+ await asyncio.sleep(1)
+ continue
+ # todo: parallelize over peers
+ peer = list(self.peers.values())[0]
+ await peer.initialized.wait()
+ # send channels_range query. peer will reply with several intervals
+ peer.query_channel_range(req_index, req_num)
+ intervals = []
+ ids = set()
+ # wait until requested range is covered
+ while True:
+ index, num, complete, _ids = await peer.reply_channel_range.get()
+ ids.update(_ids)
+ intervals.append((index, index+num))
+ intervals.sort()
+ while len(intervals) > 1:
+ a,b = intervals[0]
+ c,d = intervals[1]
+ if b == c:
+ intervals = [(a,d)] + intervals[2:]
+ else:
+ break
+ if len(intervals) == 1:
+ a, b = intervals[0]
+ if a <= req_index and b >= req_index + req_num:
+ break
+ self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
+ # TODO: filter results by date of last channel update, purge DB
+ #if complete:
+ # self.channel_db.purge_unknown_channels(ids)
+ known = self.channel_db.compare_channels(ids)
+ unknown = list(ids - set(known))
+ total = len(unknown)
+ N = 500
+ while unknown:
+ self.channel_db.process_gossip()
+ await peer.query_short_channel_ids(unknown[0:N])
+ unknown = unknown[N:]
+ self.logger.info(f'Querying channels: {total - len(unknown)}/{total}. Count: {self.channel_db.num_channels}')
+
+ # request gossip fromm current time
+ now = int(time.time())
+ peer.request_gossip(now)
+ while True:
+ await asyncio.sleep(5)
+ self.channel_db.process_gossip()
+
class LNWallet(LNWorker):
@@ -548,20 +616,6 @@ class LNWallet(LNWorker):
def on_channels_updated(self):
self.network.trigger_callback('channels')
- @staticmethod
- def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
- assert len(addr_list) >= 1
- # choose first one that is an IP
- for addr_in_db in addr_list:
- host = addr_in_db.host
- port = addr_in_db.port
- if is_ip_address(host):
- return host, port
- # otherwise choose one at random
- # TODO maybe filter out onion if not on tor?
- choice = random.choice(addr_list)
- return choice.host, choice.port
-
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=20):
node_id, rest = extract_nodeid(connect_contents)
peer = self.peers.get(node_id)