electrum

Electrum Bitcoin wallet
git clone https://git.parazyd.org/electrum
Log | Files | Refs | Submodules

commit dbceed26474ba4a1ef7da22cd0fa0e5fde8e6a3c
parent 0a9e7cb04e45981acdd191f2143b993073254af8
Author: ThomasV <thomasv@electrum.org>
Date:   Tue,  4 Feb 2020 13:35:58 +0100

Restructure wallet storage:
 - Perform json deserializations in wallet_db
 - use StoredDict class that keeps tracks of its modifications

Diffstat:
Melectrum/json_db.py | 97+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
Melectrum/lnchannel.py | 138+++++++++++++++++++++----------------------------------------------------------
Melectrum/lnhtlc.py | 62+++++++++++++++++++-------------------------------------------
Melectrum/lnpeer.py | 21+++++++++++++--------
Melectrum/lnsweep.py | 4++--
Melectrum/lnutil.py | 26++++++++++----------------
Melectrum/lnworker.py | 50++++++++++++++++++++++----------------------------
Melectrum/plugins/labels/labels.py | 2--
Melectrum/tests/test_lnchannel.py | 27+++++++--------------------
Melectrum/tests/test_lnhtlc.py | 22+++++++++++-----------
Melectrum/tests/test_lnutil.py | 14++++++++++----
Melectrum/util.py | 2++
Melectrum/wallet.py | 32+++++++++-----------------------
Melectrum/wallet_db.py | 97++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------
14 files changed, 303 insertions(+), 291 deletions(-)

diff --git a/electrum/json_db.py b/electrum/json_db.py @@ -45,6 +45,101 @@ def locked(func): return wrapper +class StoredObject: + + db = None + + def __setattr__(self, key, value): + if self.db: + self.db.set_modified(True) + object.__setattr__(self, key, value) + + def set_db(self, db): + self.db = db + + def to_json(self): + d = dict(vars(self)) + d.pop('db', None) + return d + + +_RaiseKeyError = object() # singleton for no-default behavior + +class StoredDict(dict): + + def __init__(self, data, db, path): + self.db = db + self.lock = self.db.lock if self.db else threading.RLock() + self.path = path + # recursively convert dicts to StoredDict + for k, v in list(data.items()): + self.__setitem__(k, v) + + def convert_key(self, key): + # convert int, HTLCOwner to str + return str(int(key)) if isinstance(key, int) else key + + @locked + def __setitem__(self, key, v): + key = self.convert_key(key) + is_new = key not in self + # early return to prevent unnecessary disk writes + if not is_new and self[key] == v: + return + # recursively convert dict to StoredDict. + # _convert_dict is called breadth-first + if isinstance(v, dict): + if self.db: + v = self.db._convert_dict(self.path, key, v) + v = StoredDict(v, self.db, self.path + [key]) + # convert_value is called depth-first + if isinstance(v, dict) or isinstance(v, str): + if self.db: + v = self.db._convert_value(self.path, key, v) + # set parent of StoredObject + if isinstance(v, StoredObject): + v.set_db(self.db) + # set item + dict.__setitem__(self, key, v) + if self.db: + self.db.set_modified(True) + + @locked + def __delitem__(self, key): + key = self.convert_key(key) + dict.__delitem__(self, key) + if self.db: + self.db.set_modified(True) + + @locked + def __getitem__(self, key): + key = self.convert_key(key) + return dict.__getitem__(self, key) + + @locked + def __contains__(self, key): + key = self.convert_key(key) + return dict.__contains__(self, key) + + @locked + def pop(self, key, v=_RaiseKeyError): + key = self.convert_key(key) + if v is _RaiseKeyError: + r = dict.pop(self, key) + else: + r = dict.pop(self, key, v) + if self.db: + self.db.set_modified(True) + return r + + @locked + def get(self, key, default=None): + key = self.convert_key(key) + return dict.get(self, key, default) + + + + class JsonDB(Logger): def __init__(self, data): @@ -65,8 +160,6 @@ class JsonDB(Logger): v = self.data.get(key) if v is None: v = default - else: - v = copy.deepcopy(v) return v @modifier diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py @@ -54,6 +54,7 @@ from .lnhtlc import HTLCManager if TYPE_CHECKING: from .lnworker import LNWallet + from .json_db import StoredDict # lightning channel states @@ -92,17 +93,6 @@ state_transitions = [ (cs.CLOSED, cs.REDEEMED), ] -class ChannelJsonEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, bytes): - return binascii.hexlify(o).decode("ascii") - if isinstance(o, RevocationStore): - return o.serialize() - if isinstance(o, set): - return list(o) - if hasattr(o, 'to_json') and callable(o.to_json): - return o.to_json() - return super().default(o) RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"]) @@ -110,31 +100,9 @@ RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_co class RemoteCtnTooFarInFuture(Exception): pass -def decodeAll(d, local): - for k, v in d.items(): - if k == 'revocation_store': - yield (k, RevocationStore(v)) - elif k.endswith("_basepoint") or k.endswith("_key"): - if local: - yield (k, Keypair(**dict(decodeAll(v, local)))) - else: - yield (k, OnlyPubkeyKeypair(**dict(decodeAll(v, local)))) - elif k in ["node_id", "channel_id", "short_channel_id", "pubkey", "privkey", "current_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed", "current_commitment_signature", "current_htlc_signatures"] and v is not None: - yield (k, binascii.unhexlify(v)) - else: - yield (k, v) - def htlcsum(htlcs): return sum([x.amount_msat for x in htlcs]) -# following two functions are used because json -# doesn't store int keys and byte string values -def str_bytes_dict_from_save(x) -> Dict[int, bytes]: - return {int(k): bfh(v) for k,v in x.items()} - -def str_bytes_dict_to_save(x) -> Dict[str, str]: - return {str(k): bh2u(v) for k, v in x.items()} - class Channel(Logger): # note: try to avoid naming ctns/ctxs/etc as "current" and "pending". @@ -149,44 +117,53 @@ class Channel(Logger): except: return super().diagnostic_name() - def __init__(self, state, *, sweep_address=None, name=None, lnworker=None, initial_feerate=None): + def __init__(self, state: 'StoredDict', *, sweep_address=None, name=None, lnworker=None, initial_feerate=None): self.name = name Logger.__init__(self) self.lnworker = lnworker # type: Optional[LNWallet] self.sweep_address = sweep_address - assert 'local_state' not in state - self.db_lock = self.lnworker.wallet.storage.db.lock if self.lnworker else threading.RLock() + self.storage = state + self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock() self.config = {} self.config[LOCAL] = state["local_config"] - if type(self.config[LOCAL]) is not LocalConfig: - conf = dict(decodeAll(self.config[LOCAL], True)) - self.config[LOCAL] = LocalConfig(**conf) - assert type(self.config[LOCAL].htlc_basepoint.privkey) is bytes - self.config[REMOTE] = state["remote_config"] - if type(self.config[REMOTE]) is not RemoteConfig: - conf = dict(decodeAll(self.config[REMOTE], False)) - self.config[REMOTE] = RemoteConfig(**conf) - assert type(self.config[REMOTE].htlc_basepoint.pubkey) is bytes - - self.channel_id = bfh(state["channel_id"]) if type(state["channel_id"]) not in (bytes, type(None)) else state["channel_id"] - self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"] - self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"] - self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"] # type: bytes + self.channel_id = bfh(state["channel_id"]) + self.constraints = state["constraints"] + self.funding_outpoint = state["funding_outpoint"] + self.node_id = bfh(state["node_id"]) self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"]) self.short_channel_id_predicted = self.short_channel_id - self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {})) - self.data_loss_protect_remote_pcp = str_bytes_dict_from_save(state.get('data_loss_protect_remote_pcp', {})) - self.remote_update = bfh(state.get('remote_update')) if state.get('remote_update') else None - - log = state.get('log') - self.hm = HTLCManager(log=log, initial_feerate=initial_feerate) + self.onion_keys = state['onion_keys'] + self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] + self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) self._state = channel_states[state['state']] self.peer_state = peer_states.DISCONNECTED self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]] self._outgoing_channel_update = None # type: Optional[bytes] self.revocation_store = RevocationStore(state["revocation_store"]) + def set_onion_key(self, key, value): + self.onion_keys[key] = value + + def get_onion_key(self, key): + return self.onion_keys.get(key) + + def set_data_loss_protect_remote_pcp(self, key, value): + self.data_loss_protect_remote_pcp[key] = value + + def get_data_loss_protect_remote_pcp(self, key): + self.data_loss_protect_remote_pcp.get(key) + + def set_remote_update(self, raw): + self.storage['remote_update'] = raw.hex() + + def get_remote_update(self): + return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None + + def set_short_channel_id(self, short_id): + self.short_channel_id = short_id + self.storage["short_channel_id"] = short_id + def get_feerate(self, subject, ctn): return self.hm.get_feerate(subject, ctn) @@ -229,8 +206,10 @@ class Channel(Logger): old_state = self._state if (old_state, state) not in state_transitions: raise Exception(f"Transition not allowed: {old_state.name} -> {state.name}") - self._state = state self.logger.debug(f'Setting channel state: {old_state.name} -> {state.name}') + self._state = state + self.storage['state'] = self._state.name + if self.lnworker: self.lnworker.save_channel(self) self.lnworker.network.trigger_callback('channel', self) @@ -656,51 +635,6 @@ class Channel(Logger): else: self.hm.recv_update_fee(feerate) - def to_save(self): - to_save = { - "local_config": self.config[LOCAL], - "remote_config": self.config[REMOTE], - "channel_id": self.channel_id, - "short_channel_id": self.short_channel_id, - "constraints": self.constraints, - "funding_outpoint": self.funding_outpoint, - "node_id": self.node_id, - "log": self.hm.to_save(), - "revocation_store": self.revocation_store, - "onion_keys": str_bytes_dict_to_save(self.onion_keys), - "state": self._state.name, - "data_loss_protect_remote_pcp": str_bytes_dict_to_save(self.data_loss_protect_remote_pcp), - "remote_update": self.remote_update.hex() if self.remote_update else None - } - return to_save - - def serialize(self): - namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()} - serialized_channel = {} - to_save_ref = self.to_save() - for k, v in to_save_ref.items(): - if isinstance(v, tuple): - serialized_channel[k] = namedtuples_to_dict(v) - else: - serialized_channel[k] = v - dumped = ChannelJsonEncoder().encode(serialized_channel) - roundtripped = json.loads(dumped) - reconstructed = Channel(roundtripped) - to_save_new = reconstructed.to_save() - if to_save_new != to_save_ref: - from pprint import PrettyPrinter - pp = PrettyPrinter(indent=168) - try: - from deepdiff import DeepDiff - except ImportError: - raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new)) - else: - raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new))) - return roundtripped - - def __str__(self): - return str(self.serialize()) - def make_commitment(self, subject, this_point, ctn) -> PartialTransaction: assert type(subject) is HTLCOwner feerate = self.get_feerate(subject, ctn) diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py @@ -1,14 +1,17 @@ from copy import deepcopy -from typing import Optional, Sequence, Tuple, List, Dict +from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate from .util import bh2u, bfh +if TYPE_CHECKING: + from .json_db import StoredDict class HTLCManager: - def __init__(self, *, log=None, initial_feerate=None): - if log is None: + def __init__(self, log:'StoredDict', *, initial_feerate=None): + + if len(log) == 0: initial = { 'adds': {}, 'locked_in': {}, @@ -17,33 +20,18 @@ class HTLCManager: 'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates 'revack_pending': False, 'next_htlc_id': 0, - 'ctn': -1, # oldest unrevoked ctx of sub + 'ctn': -1, # oldest unrevoked ctx of sub } - log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)} - else: - assert type(log) is dict - log = {(HTLCOwner(int(k)) if k in ("-1", "1") else k): v - for k, v in deepcopy(log).items()} - for sub in (LOCAL, REMOTE): - log[sub]['adds'] = {int(htlc_id): UpdateAddHtlc(*htlc) for htlc_id, htlc in log[sub]['adds'].items()} - coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()} - # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn - log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()} - log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()} - log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()} - # "side who initiated fee update" -> action -> list of FeeUpdates - log[sub]['fee_updates'] = { int(x): FeeUpdate(**fee_upd) for x,fee_upd in log[sub]['fee_updates'].items() } - - if 'unacked_local_updates2' not in log: + log[LOCAL] = deepcopy(initial) + log[REMOTE] = deepcopy(initial) log['unacked_local_updates2'] = {} - log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages] - for ctn, messages in log['unacked_local_updates2'].items()} + # maybe bootstrap fee_updates if initial_feerate was provided if initial_feerate is not None: assert type(initial_feerate) is int for sub in (LOCAL, REMOTE): if not log[sub]['fee_updates']: - log[sub]['fee_updates'][0] = FeeUpdate(initial_feerate, ctn_local=0, ctn_remote=0) + log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0) self.log = log def ctn_latest(self, sub: HTLCOwner) -> int: @@ -66,20 +54,6 @@ class HTLCManager: def get_next_htlc_id(self, sub: HTLCOwner) -> int: return self.log[sub]['next_htlc_id'] - def to_save(self): - log = deepcopy(self.log) - for sub in (LOCAL, REMOTE): - # adds - d = {} - for htlc_id, htlc in log[sub]['adds'].items(): - d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:] - log[sub]['adds'] = d - # fee_updates - log[sub]['fee_updates'] = { x:fee_upd.to_json() for x, fee_upd in self.log[sub]['fee_updates'].items() } - log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages] - for ctn, messages in log['unacked_local_updates2'].items()} - return log - ##### Actions on channel: def channel_open_finished(self): @@ -132,7 +106,7 @@ class HTLCManager: def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None: # overwrite last fee update if not yet committed to by anyone; otherwise append d = self.log[subject]['fee_updates'] - assert type(d) is dict + #assert type(d) is StoredDict n = len(d) last_fee_update = d[n-1] if (last_fee_update.ctn_local is None or last_fee_update.ctn_local > self.ctn_latest(LOCAL)) \ @@ -194,7 +168,7 @@ class HTLCManager: del self.log[REMOTE]['locked_in'][htlc_id] del self.log[REMOTE]['adds'][htlc_id] if self.log[REMOTE]['locked_in']: - self.log[REMOTE]['next_htlc_id'] = max(self.log[REMOTE]['locked_in']) + 1 + self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1 else: self.log[REMOTE]['next_htlc_id'] = 0 # htlcs removed @@ -217,12 +191,14 @@ class HTLCManager: ctn_idx = self.ctn_latest(REMOTE) else: ctn_idx = self.ctn_latest(REMOTE) + 1 - if ctn_idx not in self.log['unacked_local_updates2']: - self.log['unacked_local_updates2'][ctn_idx] = [] - self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg) + l = self.log['unacked_local_updates2'].get(ctn_idx, []) + l.append(raw_update_msg.hex()) + self.log['unacked_local_updates2'][ctn_idx] = l def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]: - return self.log['unacked_local_updates2'] + #return self.log['unacked_local_updates2'] + return {int(ctn): [bfh(msg) for msg in messages] + for ctn, messages in self.log['unacked_local_updates2'].items()} ##### Queries re HTLCs: diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py @@ -221,7 +221,7 @@ class Peer(Logger): def maybe_save_remote_update(self, payload): for chan in self.channels.values(): if chan.short_channel_id == payload['short_channel_id']: - chan.remote_update = payload['raw'] + chan.set_remote_update(payload['raw']) self.logger.info("saved remote_update") def on_announcement_signatures(self, payload): @@ -611,9 +611,15 @@ class Peer(Logger): "constraints": constraints, "remote_update": None, "state": channel_states.PREOPENING.name, + 'onion_keys': {}, + 'data_loss_protect_remote_pcp': {}, + "log": {}, "revocation_store": {}, } - return chan_dict + channel_id = chan_dict.get('channel_id') + channels = self.lnworker.storage.db.get_dict('channels') + channels[channel_id] = chan_dict + return channels.get(channel_id) async def on_open_channel(self, payload): # payload['channel_flags'] @@ -684,7 +690,7 @@ class Peer(Logger): signature=sig_64, ) chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig) - self.lnworker.save_channel(chan) + self.lnworker.add_channel(chan) self.lnworker.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) def validate_remote_reserve(self, payload_field: bytes, dust_limit: int, funding_sat: int) -> int: @@ -850,7 +856,7 @@ class Peer(Logger): else: if dlp_enabled and should_close_they_are_ahead: self.logger.warning(f"channel_reestablish: remote is ahead of us! luckily DLP is enabled. remote PCP: {bh2u(their_local_pcp)}") - chan.data_loss_protect_remote_pcp[their_next_local_ctn - 1] = their_local_pcp + chan.set_data_loss_protect_remote_pcp(their_next_local_ctn - 1, their_local_pcp) self.lnworker.save_channel(chan) if should_close_they_are_ahead: self.logger.warning(f"channel_reestablish: remote is ahead of us! trying to get them to force-close.") @@ -885,7 +891,6 @@ class Peer(Logger): self.logger.info(f"on_funding_locked. channel: {bh2u(channel_id)}") chan = self.channels.get(channel_id) if not chan: - print(self.channels) raise Exception("Got unknown funding_locked", channel_id) if not chan.config[LOCAL].funding_locked_received: our_next_point = chan.config[REMOTE].next_per_commitment_point @@ -1004,11 +1009,11 @@ class Peer(Logger): # peer may have sent us a channel update for the incoming direction previously pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id) if pending_channel_update: - chan.remote_update = pending_channel_update['raw'] + chan.set_remote_update(pending_channel_update['raw']) # add remote update with a fresh timestamp - if chan.remote_update: + if chan.get_remote_update(): now = int(time.time()) - remote_update_decoded = decode_msg(chan.remote_update)[1] + remote_update_decoded = decode_msg(chan.get_remote_update())[1] remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big") self.channel_db.add_channel_update(remote_update_decoded) diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py @@ -299,8 +299,8 @@ def analyze_ctx(chan: 'Channel', ctx: Transaction): their_pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True) is_revocation = True #_logger.info(f'tx for revoked: {list(txs.keys())}') - elif ctn in chan.data_loss_protect_remote_pcp: - their_pcp = chan.data_loss_protect_remote_pcp[ctn] + elif chan.get_data_loss_protect_remote_pcp(ctn): + their_pcp = chan.get_data_loss_protect_remote_pcp(ctn) is_revocation = False else: return diff --git a/electrum/lnutil.py b/electrum/lnutil.py @@ -38,12 +38,7 @@ LN_MAX_FUNDING_SAT = pow(2, 24) - 1 def ln_dummy_address(): return redeem_script_to_address('p2wsh', '') - -class StoredObject: - - def to_json(self): - return dict(vars(self)) - +from .json_db import StoredObject @attr.s class OnlyPubkeyKeypair(StoredObject): @@ -180,21 +175,23 @@ class RevocationStore: START_INDEX = 2 ** 48 - 1 def __init__(self, storage): - self.index = storage.get('index', self.START_INDEX) - buckets = storage.get('buckets', {}) - decode = lambda to_decode: ShachainElement(bfh(to_decode[0]), int(to_decode[1])) - self.buckets = dict((int(k), decode(v)) for k, v in buckets.items()) + if len(storage) == 0: + storage['index'] = self.START_INDEX + storage['buckets'] = {} + self.storage = storage + self.buckets = storage['buckets'] def add_next_entry(self, hsh): - new_element = ShachainElement(index=self.index, secret=hsh) - bucket = count_trailing_zeros(self.index) + index = self.storage['index'] + new_element = ShachainElement(index=index, secret=hsh) + bucket = count_trailing_zeros(index) for i in range(0, bucket): this_bucket = self.buckets[i] e = shachain_derive(new_element, this_bucket.index) if e != this_bucket: raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index)) self.buckets[bucket] = new_element - self.index -= 1 + self.storage['index'] = index - 1 def retrieve_secret(self, index: int) -> bytes: assert index <= self.START_INDEX, index @@ -209,9 +206,6 @@ class RevocationStore: return element.secret raise UnableToDeriveSecret() - def serialize(self): - return {"index": self.index, "buckets": dict( (k, [bh2u(v.secret), v.index]) for k, v in self.buckets.items()) } - def __eq__(self, o): return type(o) is RevocationStore and self.serialize() == o.serialize() diff --git a/electrum/lnworker.py b/electrum/lnworker.py @@ -34,13 +34,14 @@ from .bip32 import BIP32Node from .util import bh2u, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions from .util import ignore_exceptions, make_aiohttp_session from .util import timestamp_to_datetime +from .util import MyEncoder from .logging import Logger from .lntransport import LNTransport, LNResponderTransport from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT from .lnaddr import lnencode, LnAddr, lndecode from .ecc import der_sig_from_sig_string from .ecc_fast import is_using_fast_ecc -from .lnchannel import Channel, ChannelJsonEncoder +from .lnchannel import Channel from .lnchannel import channel_states, peer_states from . import lnutil from .lnutil import funding_output_script @@ -106,8 +107,6 @@ FALLBACK_NODE_LIST_MAINNET = [ LNPeerAddr(host='3.124.63.44', port=9735, pubkey=bfh('0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3')), ] -encoder = ChannelJsonEncoder() - from typing import NamedTuple @@ -347,19 +346,20 @@ class LNWallet(LNWorker): LNWorker.__init__(self, xprv) self.ln_keystore = keystore.from_xprv(xprv) self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ - self.payments = self.storage.get('lightning_payments', {}) # RHASH -> amount, direction, is_paid - self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage + self.payments = self.storage.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid + self.preimages = self.storage.db.get_dict('lightning_preimages') # RHASH -> preimage self.sweep_address = wallet.get_receiving_address() self.lock = threading.RLock() self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # note: accessing channels (besides simple lookup) needs self.lock! - self.channels = {} # type: Dict[bytes, Channel] - for x in wallet.storage.get("channels", {}).values(): - c = Channel(x, sweep_address=self.sweep_address, lnworker=self) - self.channels[c.channel_id] = c + self.channels = {} + channels = self.storage.db.get_dict("channels") + for channel_id, c in channels.items(): + self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) + # timestamps of opening and closing transactions - self.channel_timestamps = self.storage.get('lightning_channel_timestamps', {}) + self.channel_timestamps = self.storage.db.get_dict('lightning_channel_timestamps') self.pending_payments = defaultdict(asyncio.Future) @ignore_exceptions @@ -610,17 +610,9 @@ class LNWallet(LNWorker): assert type(chan) is Channel if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point: raise Exception("Tried to save channel with next_point == current_point, this should not happen") - with self.lock: - self.channels[chan.channel_id] = chan - self.save_channels() + self.wallet.storage.write() self.network.trigger_callback('channel', chan) - def save_channels(self): - with self.lock: - dumped = dict( (k.hex(), c.serialize()) for k, c in self.channels.items() ) - self.storage.put("channels", dumped) - self.storage.write() - def save_short_chan_id(self, chan): """ Checks if Funding TX has been mined. If it has, save the short channel ID in chan; @@ -648,8 +640,8 @@ class LNWallet(LNWorker): return block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid) assert tx_pos >= 0 - chan.short_channel_id = ShortChannelID.from_components( - block_height, tx_pos, chan.funding_outpoint.output_index) + chan.set_short_channel_id(ShortChannelID.from_components( + block_height, tx_pos, chan.funding_outpoint.output_index)) self.logger.info(f"save_short_channel_id: {chan.short_channel_id}") self.save_channel(chan) @@ -669,7 +661,6 @@ class LNWallet(LNWorker): # save timestamp regardless of state, so that funding tx is returned in get_history self.channel_timestamps[bh2u(chan.channel_id)] = chan.funding_outpoint.txid, funding_height.height, funding_height.timestamp, None, None, None - self.storage.put('lightning_channel_timestamps', self.channel_timestamps) if chan.get_state() == channel_states.OPEN and self.should_channel_be_closed_due_to_expiring_htlcs(chan): self.logger.info(f"force-closing due to expiring htlcs") @@ -714,7 +705,6 @@ class LNWallet(LNWorker): # fixme: this is wasteful self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, closing_txid, closing_height.height, closing_height.timestamp - self.storage.put('lightning_channel_timestamps', self.channel_timestamps) # remove from channel_db if chan.short_channel_id is not None: @@ -836,7 +826,7 @@ class LNWallet(LNWorker): funding_sat=funding_sat, push_msat=push_sat * 1000, temp_channel_id=os.urandom(32)) - self.save_channel(chan) + self.add_channel(chan) self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.network.trigger_callback('channels_updated', self.wallet) self.wallet.add_transaction(funding_tx) # save tx as local into the wallet @@ -846,6 +836,10 @@ class LNWallet(LNWorker): await asyncio.wait_for(self.network.broadcast_transaction(funding_tx), LN_P2P_NETWORK_TIMEOUT) return chan, funding_tx + def add_channel(self, chan): + with self.lock: + self.channels[chan.channel_id] = chan + @log_exceptions async def add_peer(self, connect_str: str) -> Peer: node_id, rest = extract_nodeid(connect_str) @@ -1133,7 +1127,6 @@ class LNWallet(LNWorker): def save_preimage(self, payment_hash: bytes, preimage: bytes): assert sha256(preimage) == payment_hash self.preimages[bh2u(payment_hash)] = bh2u(preimage) - self.storage.put('lightning_preimages', self.preimages) self.storage.write() def get_preimage(self, payment_hash: bytes) -> bytes: @@ -1152,7 +1145,6 @@ class LNWallet(LNWorker): assert info.status in [PR_PAID, PR_UNPAID, PR_INFLIGHT] with self.lock: self.payments[key] = info.amount, info.direction, info.status - self.storage.put('lightning_payments', self.payments) self.storage.write() def get_payment_status(self, payment_hash): @@ -1238,7 +1230,6 @@ class LNWallet(LNWorker): del self.payments[payment_hash_hex] except KeyError: return - self.storage.put('lightning_payments', self.payments) self.storage.write() def get_balance(self): @@ -1246,6 +1237,7 @@ class LNWallet(LNWorker): return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000 def list_channels(self): + encoder = MyEncoder() with self.lock: # we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels for channel_id, chan in self.channels.items(): @@ -1283,7 +1275,9 @@ class LNWallet(LNWorker): assert chan.is_closed() with self.lock: self.channels.pop(chan_id) - self.save_channels() + self.channel_timestamps.pop(chan_id.hex()) + self.storage.get('channels').pop(chan_id.hex()) + self.network.trigger_callback('channels_updated', self.wallet) self.network.trigger_callback('wallet_updated', self.wallet) diff --git a/electrum/plugins/labels/labels.py b/electrum/plugins/labels/labels.py @@ -149,8 +149,6 @@ class LabelsPlugin(BasePlugin): wallet.labels[key] = value self.logger.info(f"received {len(response)} labels") - # do not write to disk because we're in a daemon thread - wallet.storage.put('labels', wallet.labels) self.set_nonce(wallet, response["nonce"] + 1) self.on_pulled(wallet) diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py @@ -35,6 +35,7 @@ from electrum.lnutil import FeeUpdate from electrum.ecc import sig_string_from_der_sig from electrum.logging import console_stderr_handler from electrum.lnchannel import channel_states +from electrum.json_db import StoredDict from . import ElectrumTestCase @@ -45,9 +46,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, assert local_amount > 0 assert remote_amount > 0 channel_id, _ = lnpeer.channel_id_from_funding_tx(funding_txid, funding_index) - - return { - "channel_id":channel_id, + state = { + "channel_id":channel_id.hex(), "short_channel_id":channel_id[:8], "funding_outpoint":lnpeer.Outpoint(funding_txid, funding_index), "remote_config":lnpeer.RemoteConfig( @@ -63,7 +63,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, initial_msat=remote_amount, reserve_sat=0, htlc_minimum_msat=1, - next_per_commitment_point=nex, current_per_commitment_point=cur, ), @@ -79,7 +78,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, max_accepted_htlcs=5, initial_msat=local_amount, reserve_sat=0, - per_commitment_secret_seed=seed, funding_locked_received=True, was_announced=False, @@ -91,11 +89,14 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, is_initiator=is_initiator, funding_txn_minimum_depth=3, ), - "node_id":other_node_id, + "node_id":other_node_id.hex(), 'onion_keys': {}, + 'data_loss_protect_remote_pcp': {}, 'state': 'PREOPENING', + 'log': {}, 'revocation_store': {}, } + return StoredDict(state, None, []) def bip32(sequence): node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence) @@ -317,7 +318,6 @@ class TestChannel(ElectrumTestCase): # Bob revokes his prior commitment given to him by Alice, since he now # has a valid signature for a newer commitment. bobRevocation, _ = bob_channel.revoke_current_commitment() - bob_channel.serialize() self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL))) # Bob finally sends a signature for Alice's commitment transaction. @@ -341,18 +341,14 @@ class TestChannel(ElectrumTestCase): # her prior commitment transaction. Alice shouldn't have any HTLCs to # forward since she's sending an outgoing HTLC. alice_channel.receive_revocation(bobRevocation) - alice_channel.serialize() self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL))) - alice_channel.serialize() self.assertEqual(len(alice_channel.get_latest_commitment(LOCAL).outputs()), 2) self.assertEqual(len(alice_channel.get_latest_commitment(REMOTE).outputs()), 3) self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2) self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1) - alice_channel.serialize() - self.assertEqual(alice_channel.get_next_commitment(LOCAL).outputs(), bob_channel.get_latest_commitment(REMOTE).outputs()) @@ -365,14 +361,12 @@ class TestChannel(ElectrumTestCase): self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3) self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1) - alice_channel.serialize() tx1 = str(alice_channel.force_close_tx()) self.assertNotEqual(tx0, tx1) # Alice then generates a revocation for bob. aliceRevocation, _ = alice_channel.revoke_current_commitment() - alice_channel.serialize() tx2 = str(alice_channel.force_close_tx()) # since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one) @@ -384,7 +378,6 @@ class TestChannel(ElectrumTestCase): # into both commitment transactions. self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL))) bob_channel.receive_revocation(aliceRevocation) - bob_channel.serialize() # At this point, both sides should have the proper number of satoshis # sent, and commitment height updated within their local channel @@ -450,20 +443,16 @@ class TestChannel(ElectrumTestCase): self.assertEqual(1, alice_channel.get_oldest_unrevoked_ctn(LOCAL)) self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0) aliceRevocation2, _ = alice_channel.revoke_current_commitment() - alice_channel.serialize() aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment() self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures") self.assertEqual(len(bob_channel.get_latest_commitment(LOCAL).outputs()), 3) bob_channel.receive_revocation(aliceRevocation2) - bob_channel.serialize() bob_channel.receive_new_commitment(aliceSig2, aliceHtlcSigs2) bobRevocation2, (received, sent) = bob_channel.revoke_current_commitment() self.assertEqual(one_bitcoin_in_msat, received) - bob_channel.serialize() alice_channel.receive_revocation(bobRevocation2) - alice_channel.serialize() # At this point, Bob should have 6 BTC settled, with Alice still having # 4 BTC. Alice's channel should show 1 BTC sent and Bob's channel @@ -509,8 +498,6 @@ class TestChannel(ElectrumTestCase): self.assertEqual(bob_channel.total_msat(RECEIVED), one_bitcoin_in_msat, "bob satoshis received incorrect") self.assertEqual(bob_channel.total_msat(SENT), 5 * one_bitcoin_in_msat, "bob satoshis sent incorrect") - alice_channel.serialize() - def alice_to_bob_fee_update(self, fee=111): aoldctx = self.alice_channel.get_next_commitment(REMOTE).outputs() diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py @@ -4,18 +4,18 @@ from typing import NamedTuple from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner, Direction from electrum.lnhtlc import HTLCManager +from electrum.json_db import StoredDict from . import ElectrumTestCase - class H(NamedTuple): owner : str htlc_id : int class TestHTLCManager(ElectrumTestCase): def test_adding_htlcs_race(self): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StoredDict({}, None, [])) + B = HTLCManager(StoredDict({}, None, [])) A.channel_open_finished() B.channel_open_finished() ah0, bh0 = H('A', 0), H('B', 0) @@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase): def test_single_htlc_full_lifecycle(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StoredDict({}, None, [])) + B = HTLCManager(StoredDict({}, None, [])) A.channel_open_finished() B.channel_open_finished() B.recv_htlc(A.send_htlc(H('A', 0))) @@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase): def test_remove_htlc_while_owing_commitment(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StoredDict({}, None, [])) + B = HTLCManager(StoredDict({}, None, [])) A.channel_open_finished() B.channel_open_finished() ah0 = H('A', 0) @@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase): htlc_lifecycle(htlc_success=False) def test_adding_htlc_between_send_ctx_and_recv_rev(self): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StoredDict({}, None, [])) + B = HTLCManager(StoredDict({}, None, [])) A.channel_open_finished() B.channel_open_finished() A.send_ctx() @@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase): self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) def test_unacked_local_updates(self): - A = HTLCManager() - B = HTLCManager() + A = HTLCManager(StoredDict({}, None, [])) + B = HTLCManager(StoredDict({}, None, [])) A.channel_open_finished() B.channel_open_finished() self.assertEqual({}, A.get_unacked_local_updates()) diff --git a/electrum/tests/test_lnutil.py b/electrum/tests/test_lnutil.py @@ -2,13 +2,14 @@ import unittest import json from electrum import bitcoin +from electrum.json_db import StoredDict from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_seed, make_offered_htlc, make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output, make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey, derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret, get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError, ScriptHtlc, extract_nodeid, calc_onchain_fees, UpdateAddHtlc) -from electrum.util import bh2u, bfh +from electrum.util import bh2u, bfh, MyEncoder from electrum.transaction import Transaction, PartialTransaction from . import ElectrumTestCase @@ -422,7 +423,7 @@ class TestLNUtil(ElectrumTestCase): ] for test in tests: - receiver = RevocationStore({}) + receiver = RevocationStore(StoredDict({}, None, [])) for insert in test["inserts"]: secret = bytes.fromhex(insert["secret"]) @@ -445,14 +446,19 @@ class TestLNUtil(ElectrumTestCase): def test_shachain_produce_consume(self): seed = bitcoin.sha256(b"shachaintest") - consumer = RevocationStore({}) + consumer = RevocationStore(StoredDict({}, None, [])) for i in range(10000): secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i) try: consumer.add_next_entry(secret) except Exception as e: raise Exception("iteration " + str(i) + ": " + str(e)) - if i % 1000 == 0: self.assertEqual(consumer.serialize(), RevocationStore(json.loads(json.dumps(consumer.serialize()))).serialize()) + if i % 1000 == 0: + c1 = consumer + s1 = json.dumps(c1.storage, cls=MyEncoder) + c2 = RevocationStore(StoredDict(json.loads(s1), None, [])) + s2 = json.dumps(c2.storage, cls=MyEncoder) + self.assertEqual(s1, s2) def test_commitment_tx_with_all_five_HTLCs_untrimmed_minimum_feerate(self): to_local_msat = 6988000000 diff --git a/electrum/util.py b/electrum/util.py @@ -280,6 +280,8 @@ class MyEncoder(json.JSONEncoder): return obj.isoformat(' ')[:-3] if isinstance(obj, set): return list(obj) + if isinstance(obj, bytes): # for nametuples in lnchannel + return obj.hex() if hasattr(obj, 'to_json') and callable(obj.to_json): return obj.to_json() return super(MyEncoder, self).default(obj) diff --git a/electrum/wallet.py b/electrum/wallet.py @@ -240,21 +240,13 @@ class Abstract_Wallet(AddressSynchronizer, ABC): # saved fields self.use_change = storage.get('use_change', True) self.multiple_change = storage.get('multiple_change', False) - self.labels = storage.get('labels', {}) + self.labels = storage.db.get_dict('labels') self.frozen_addresses = set(storage.get('frozen_addresses', [])) self.frozen_coins = set(storage.get('frozen_coins', [])) # set of txid:vout strings - self.fiat_value = storage.get('fiat_value', {}) - self.receive_requests = storage.get('payment_requests', {}) - self.invoices = storage.get('invoices', {}) - # convert invoices - # TODO invoices being these contextual dicts even internally, - # where certain keys are only present depending on values of other keys... - # it's horrible. we need to change this, at least for the internal representation, - # to something that can be typed. - for invoice_key, invoice in self.invoices.items(): - if invoice.get('type') == PR_TYPE_ONCHAIN: - outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')] - invoice['outputs'] = outputs + self.fiat_value = storage.db.get_dict('fiat_value') + self.receive_requests = storage.db.get_dict('payment_requests') + self.invoices = storage.db.get_dict('invoices') + self._prepare_onchain_invoice_paid_detection() self.calc_unused_change_addresses() # save wallet type the first time @@ -372,7 +364,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): changed = True if changed: run_hook('set_label', self, name, text) - self.storage.put('labels', self.labels) return changed def set_fiat_value(self, txid, ccy, text, fx, value_sat): @@ -404,7 +395,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): if ccy not in self.fiat_value: self.fiat_value[ccy] = {} self.fiat_value[ccy][txid] = text - self.storage.put('fiat_value', self.fiat_value) return reset def get_fiat_value(self, txid, ccy): @@ -625,12 +615,10 @@ class Abstract_Wallet(AddressSynchronizer, ABC): else: raise Exception('Unsupported invoice type') self.invoices[key] = invoice - self.storage.put('invoices', self.invoices) self.storage.write() def clear_invoices(self): self.invoices = {} - self.storage.put('invoices', self.invoices) self.storage.write() def get_invoices(self): @@ -642,7 +630,8 @@ class Abstract_Wallet(AddressSynchronizer, ABC): def get_invoice(self, key): if key not in self.invoices: return - item = copy.copy(self.invoices[key]) + # convert StoredDict to dict + item = dict(self.invoices[key]) request_type = item.get('type') if request_type == PR_TYPE_ONCHAIN: item['status'] = PR_PAID if self.is_onchain_invoice_paid(item) else PR_UNPAID @@ -1553,7 +1542,8 @@ class Abstract_Wallet(AddressSynchronizer, ABC): req = self.receive_requests.get(key) if not req: return - req = copy.copy(req) + # convert StoredDict to dict + req = dict(req) _type = req.get('type') if _type == PR_TYPE_ONCHAIN: addr = req['address'] @@ -1610,7 +1600,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): req['name'] = pr.pki_data req['sig'] = bh2u(pr.signature) self.receive_requests[key] = req - self.storage.put('payment_requests', self.receive_requests) def add_payment_request(self, req): if req['type'] == PR_TYPE_ONCHAIN: @@ -1628,7 +1617,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): raise Exception('Unknown request type') amount = req.get('amount') self.receive_requests[key] = req - self.storage.put('payment_requests', self.receive_requests) self.set_label(key, message) # should be a default label return req @@ -1643,7 +1631,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): """ lightning or on-chain """ if key in self.invoices: self.invoices.pop(key) - self.storage.put('invoices', self.invoices) elif self.lnworker: self.lnworker.delete_payment(key) @@ -1651,7 +1638,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): if addr not in self.receive_requests: return False self.receive_requests.pop(addr) - self.storage.put('payment_requests', self.receive_requests) return True def get_sorted_requests(self): diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py @@ -29,12 +29,16 @@ import copy import threading from collections import defaultdict from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence +import binascii from . import util, bitcoin -from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh +from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, PR_TYPE_ONCHAIN from .keystore import bip44_derivation -from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction -from .json_db import JsonDB, locked, modifier +from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput +from .logging import Logger +from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, RevocationStore +from .lnutil import ChannelConstraints, Outpoint, ShachainElement +from .json_db import StoredDict, JsonDB, locked, modifier # seed_version is now used for the version of the wallet file @@ -44,17 +48,12 @@ FINAL_SEED_VERSION = 24 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format - - class TxFeesValue(NamedTuple): fee: Optional[int] = None is_calculated_by_us: bool = False num_inputs: Optional[int] = None - - - class WalletDB(JsonDB): def __init__(self, raw, *, manual_upgrades: bool): @@ -67,7 +66,6 @@ class WalletDB(JsonDB): self.put('seed_version', FINAL_SEED_VERSION) self._after_upgrade_tasks() - def load_data(self, s): try: self.data = json.loads(s) @@ -833,7 +831,7 @@ class WalletDB(JsonDB): self.tx_fees.pop(txid, None) @locked - def get_data_ref(self, name): + def get_dict(self, name): # Warning: interacts un-intuitively with 'put': certain parts # of 'data' will have pointers saved as separate variables. if name not in self.data: @@ -895,9 +893,9 @@ class WalletDB(JsonDB): def load_addresses(self, wallet_type): """ called from Abstract_Wallet.__init__ """ if wallet_type == 'imported': - self.imported_addresses = self.get_data_ref('addresses') # type: Dict[str, dict] + self.imported_addresses = self.get_dict('addresses') # type: Dict[str, dict] else: - self.get_data_ref('addresses') + self.get_dict('addresses') for name in ['receiving', 'change']: if name not in self.data['addresses']: self.data['addresses'][name] = [] @@ -911,26 +909,20 @@ class WalletDB(JsonDB): @profiler def _load_transactions(self): + self.data = StoredDict(self.data, self, []) # references in self.data # TODO make all these private # txid -> address -> set of (prev_outpoint, value) - self.txi = self.get_data_ref('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]] + self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]] # txid -> address -> set of (output_index, value, is_coinbase) - self.txo = self.get_data_ref('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]] - self.transactions = self.get_data_ref('transactions') # type: Dict[str, Transaction] - self.spent_outpoints = self.get_data_ref('spent_outpoints') # txid -> output_index -> next_txid - self.history = self.get_data_ref('addr_history') # address -> list of (txid, height) - self.verified_tx = self.get_data_ref('verified_tx3') # txid -> (height, timestamp, txpos, header_hash) - self.tx_fees = self.get_data_ref('tx_fees') # type: Dict[str, TxFeesValue] + self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]] + self.transactions = self.get_dict('transactions') # type: Dict[str, Transaction] + self.spent_outpoints = self.get_dict('spent_outpoints') # txid -> output_index -> next_txid + self.history = self.get_dict('addr_history') # address -> list of (txid, height) + self.verified_tx = self.get_dict('verified_tx3') # txid -> (height, timestamp, txpos, header_hash) + self.tx_fees = self.get_dict('tx_fees') # type: Dict[str, TxFeesValue] # scripthash -> set of (outpoint, value) - self._prevouts_by_scripthash = self.get_data_ref('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]] - # convert raw transactions to Transaction objects - for tx_hash, raw_tx in self.transactions.items(): - # note: for performance, "deserialize=False" so that we will deserialize these on-demand - self.transactions[tx_hash] = tx_from_any(raw_tx, deserialize=False) - # convert prevouts_by_scripthash: list to set, list to tuple - for scripthash, lst in self._prevouts_by_scripthash.items(): - self._prevouts_by_scripthash[scripthash] = {(prevout, value) for prevout, value in lst} + self._prevouts_by_scripthash = self.get_dict('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]] # remove unreferenced tx for tx_hash in list(self.transactions.keys()): if not self.get_txi_addresses(tx_hash) and not self.get_txo_addresses(tx_hash): @@ -943,9 +935,15 @@ class WalletDB(JsonDB): if spending_txid not in self.transactions: self.logger.info("removing unreferenced spent outpoint") d.pop(prevout_n) - # convert tx_fees tuples to NamedTuples - for tx_hash, tuple_ in self.tx_fees.items(): - self.tx_fees[tx_hash] = TxFeesValue(*tuple_) + # convert invoices + # TODO invoices being these contextual dicts even internally, + # where certain keys are only present depending on values of other keys... + # it's horrible. we need to change this, at least for the internal representation, + # to something that can be typed. + self.invoices = self.get_dict('invoices') + for invoice_key, invoice in self.invoices.items(): + if invoice.get('type') == PR_TYPE_ONCHAIN: + invoice['outputs'] = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')] @modifier def clear_history(self): @@ -956,3 +954,42 @@ class WalletDB(JsonDB): self.history.clear() self.verified_tx.clear() self.tx_fees.clear() + + def _convert_dict(self, path, key, v): + if key == 'transactions': + # note: for performance, "deserialize=False" so that we will deserialize these on-demand + v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items()) + elif key == 'adds': + v = dict((k, UpdateAddHtlc(*x)) for k, x in v.items()) + elif key == 'fee_updates': + v = dict((k, FeeUpdate(**x)) for k, x in v.items()) + elif key == 'tx_fees': + v = dict((k, TxFeesValue(*x)) for k, x in v.items()) + elif key == 'prevouts_by_scripthash': + v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items()) + elif key == 'buckets': + v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items()) + return v + + def _convert_value(self, path, key, v): + if key == 'local_config': + v = LocalConfig(**v) + elif key == 'remote_config': + v = RemoteConfig(**v) + elif key == 'constraints': + v = ChannelConstraints(**v) + elif key == 'funding_outpoint': + v = Outpoint(**v) + elif key.endswith("_basepoint") or key.endswith("_key"): + v = Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v) + elif key in [ + "short_channel_id", + "current_per_commitment_point", + "next_per_commitment_point", + "per_commitment_secret_seed", + "current_commitment_signature", + "current_htlc_signatures"]: + v = binascii.unhexlify(v) if v is not None else None + elif len(path) > 2 and path[-2] in ['local_config', 'remote_config'] and key in ["pubkey", "privkey"]: + v = binascii.unhexlify(v) if v is not None else None + return v