commit aafbe74a28a48a50369ab81133b341941f1077d1
parent 1f6646fa256516e72b5840d7fe8c746b9a38e1aa
Author: ThomasV <thomasv@electrum.org>
Date: Tue, 29 May 2018 11:30:38 +0200
fix channel_reestablish
Diffstat:
5 files changed, 21 insertions(+), 50 deletions(-)
diff --git a/gui/qt/lightning_channels_list.py b/gui/qt/lightning_channels_list.py
@@ -45,7 +45,7 @@ class LightningChannelsList(QtWidgets.QWidget):
push_amt = int(push_amt_inp.text())
assert local_amt >= 200000
assert local_amt >= push_amt
- obj = self.lnworker.open_channel_from_other_thread(node_id, local_amt, push_amt, self.update_rows.emit, password)
+ obj = self.lnworker.open_channel(node_id, local_amt, push_amt, password)
@QtCore.pyqtSlot(dict)
def do_update_single_row(self, new):
diff --git a/lib/lnbase.py b/lib/lnbase.py
@@ -568,8 +568,7 @@ def is_synced(network):
class Peer(PrintError):
- def __init__(self, host, port, pubkey, privkey, network, channel_db, path_finder, channel_state, handle_channel_reestablish, request_initial_sync=False):
- self.handle_channel_reestablish = handle_channel_reestablish
+ def __init__(self, host, port, pubkey, privkey, network, channel_db, path_finder, channel_state, channels, request_initial_sync=False):
self.update_add_htlc_event = asyncio.Event()
self.channel_update_event = asyncio.Event()
self.host = host
@@ -594,7 +593,6 @@ class Peer(PrintError):
self.local_funding_locked = defaultdict(asyncio.Future)
self.remote_funding_locked = defaultdict(asyncio.Future)
self.revoke_and_ack = defaultdict(asyncio.Future)
- self.channel_reestablish = defaultdict(asyncio.Future)
self.update_fulfill_htlc = defaultdict(asyncio.Future)
self.commitment_signed = defaultdict(asyncio.Future)
self.initialized = asyncio.Future()
@@ -602,6 +600,7 @@ class Peer(PrintError):
self.unfulfilled_htlcs = []
self.channel_state = channel_state
self.nodes = {}
+ self.channels = channels
def diagnostic_name(self):
return self.host
@@ -714,13 +713,6 @@ class Peer(PrintError):
l = int.from_bytes(payload['num_pong_bytes'], 'big')
self.send_message(gen_msg('pong', byteslen=l))
- def on_channel_reestablish(self, payload):
- chan_id = int.from_bytes(payload["channel_id"], 'big')
- if chan_id in self.channel_reestablish:
- self.channel_reestablish[chan_id].set_result(payload)
- else:
- asyncio.run_coroutine_threadsafe(self.handle_channel_reestablish(chan_id, payload), self.network.asyncio_loop).result()
-
def on_accept_channel(self, payload):
temp_chan_id = payload["temporary_channel_id"]
if temp_chan_id not in self.channel_accepted: raise Exception("Got unknown accept_channel")
@@ -795,6 +787,8 @@ class Peer(PrintError):
self.process_message(msg)
# initialized
self.initialized.set_result(msg)
+ # reestablish channels
+ [await self.reestablish_channel(c) for c in self.channels]
# loop
while True:
self.ping_if_required()
@@ -963,33 +957,29 @@ class Peer(PrintError):
async def reestablish_channel(self, chan):
assert chan.channel_id not in self.channel_state
-
- await self.initialized
self.send_message(gen_msg("channel_reestablish",
channel_id=chan.channel_id,
next_local_commitment_number=chan.local_state.ctn+1,
next_remote_revocation_number=chan.remote_state.ctn
))
- channel_reestablish_msg = await self.channel_reestablish[chan.channel_id]
- print(channel_reestablish_msg)
- # {
- # 'channel_id': b'\xfa\xce\x0b\x8cjZ6\x03\xd2\x99k\x12\x86\xc7\xed\xe5\xec\x80\x85F\xf2\x1bzn\xa1\xd30I\xf9_V\xfa',
- # 'next_local_commitment_number': b'\x00\x00\x00\x00\x00\x00\x00\x01',
- # 'next_remote_revocation_number': b'\x00\x00\x00\x00\x00\x00\x00\x00',
- # 'your_last_per_commitment_secret': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
- # 'my_current_per_commitment_point': b'\x03\x18\xb9\x1b\x99\xd4\xc3\xf1\x92\x0f\xfe\xe4c\x9e\xae\xa4\xf1\xdeX\xcf4\xa9[\xd1\tAh\x80\x88\x01b*['
- # }
+
+ def on_channel_reestablish(self, payload):
+ chan_id = int.from_bytes(payload["channel_id"], 'big')
+ for chan in self.channels:
+ if chan.channel_id == chan_id:
+ break
+ else:
+ print("Warning: received unknown channel_reestablish", chan_id, list(self.channels))
+ return
+ channel_reestablish_msg = payload
remote_ctn = int.from_bytes(channel_reestablish_msg["next_local_commitment_number"], 'big')
if remote_ctn != chan.remote_state.ctn + 1:
raise Exception("expected remote ctn {}, got {}".format(chan.remote_state.ctn + 1, remote_ctn))
-
local_ctn = int.from_bytes(channel_reestablish_msg["next_remote_revocation_number"], 'big')
if local_ctn != chan.local_state.ctn:
raise Exception("expected local ctn {}, got {}".format(chan.local_state.ctn, local_ctn))
-
if channel_reestablish_msg["my_current_per_commitment_point"] != chan.remote_state.last_per_commitment_point:
raise Exception("Remote PCP mismatch")
-
self.channel_state[chan.channel_id] = "OPEN"
async def funding_locked(self, chan):
@@ -1009,9 +999,7 @@ class Peer(PrintError):
finally:
del self.remote_funding_locked[channel_id]
self.print_error('Done waiting for remote_funding_locked', remote_funding_locked_msg)
-
self.channel_state[chan.channel_id] = "OPEN"
-
return chan._replace(short_channel_id=short_channel_id, remote_state=chan.remote_state._replace(next_per_commitment_point=remote_funding_locked_msg["next_per_commitment_point"]))
def on_update_fail_htlc(self, payload):
diff --git a/lib/lnrouter.py b/lib/lnrouter.py
@@ -120,7 +120,7 @@ class ChannelDB(PrintError):
try:
channel_info = self._id_to_channel_info[short_channel_id]
except KeyError:
- print("could not find", short_channel_id)
+ self.print_error("could not find", short_channel_id)
else:
channel_info.on_channel_update(msg_payload)
diff --git a/lib/lnworker.py b/lib/lnworker.py
@@ -98,7 +98,6 @@ class LNWorker:
self.nodes = {} # received node announcements
self.channel_db = lnrouter.ChannelDB()
self.path_finder = lnrouter.LNPathFinder(self.channel_db)
-
self.channels = [reconstruct_namedtuples(x) for x in wallet.storage.get("channels", {})]
peer_list = network.config.get('lightning_peers', node_list)
self.channel_state = {}
@@ -109,15 +108,11 @@ class LNWorker:
self.on_network_update('updated') # shortcut (don't block) if funding tx locked and verified
def add_peer(self, host, port, pubkey):
- peer = Peer(host, int(port), binascii.unhexlify(pubkey), self.privkey,
- self.network, self.channel_db, self.path_finder, self.channel_state, self.handle_channel_reestablish)
+ node_id = bfh(pubkey)
+ channels = list(filter(lambda x: x.node_id == node_id, self.channels))
+ peer = Peer(host, int(port), node_id, self.privkey, self.network, self.channel_db, self.path_finder, self.channel_state, channels)
self.network.futures.append(asyncio.run_coroutine_threadsafe(peer.main_loop(), asyncio.get_event_loop()))
- self.peers[bfh(pubkey)] = peer
-
- async def handle_channel_reestablish(self, chan_id, payload):
- chans = [x for x in self.channels if x.channel_id == chan_id ]
- chan = chans[0]
- await self.peers[chan.node_id].reestablish_channel(chan)
+ self.peers[node_id] = peer
def save_channel(self, openchannel):
self.channels = [openchannel] # TODO multiple channels
@@ -179,17 +174,6 @@ class LNWorker:
def list_channels(self):
return serialize_channels(self.channels)
- def reestablish_channels(self):
- coro = self._reestablish_channels_coroutine()
- return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop).result()
-
- # not aiosafe because we call .result() which will propagate an exception
- async def _reestablish_channels_coroutine(self):
- if self.channels is None or len(self.channels) < 1:
- raise Exception("Can't reestablish: No channel saved")
- peer = self.peers[self.channels[0].node_id]
- await peer.reestablish_channel(self.channels[0])
-
# not aiosafe because we call .result() which will propagate an exception
async def _pay_coroutine(self, invoice):
openchannel = self.channels[0]
@@ -216,7 +200,6 @@ class LNWorker:
openchannel = await peer.receive_commitment_revoke_ack(openchannel, expected_received_msat, payment_preimage)
self.save_channel(openchannel)
-
def subscribe_payment_received_from_other_thread(self, emit_function):
pass
diff --git a/lib/tests/test_lnbase.py b/lib/tests/test_lnbase.py
@@ -256,7 +256,7 @@ class Test_LNBase(unittest.TestCase):
def test_find_path_for_payment(self):
channel_db = lnrouter.ChannelDB()
path_finder = lnrouter.LNPathFinder(channel_db)
- p = Peer('', 0, 'a', bitcoin.sha256('privkeyseed'), None, channel_db, path_finder, {}, lambda x, y: None)
+ p = Peer('', 0, 'a', bitcoin.sha256('privkeyseed'), None, channel_db, path_finder, {}, [])
p.on_channel_announcement({'node_id_1': b'b', 'node_id_2': b'c', 'short_channel_id': bfh('0000000000000001')})
p.on_channel_announcement({'node_id_1': b'b', 'node_id_2': b'e', 'short_channel_id': bfh('0000000000000002')})
p.on_channel_announcement({'node_id_1': b'a', 'node_id_2': b'b', 'short_channel_id': bfh('0000000000000003')})