commit dbceed26474ba4a1ef7da22cd0fa0e5fde8e6a3c
parent 0a9e7cb04e45981acdd191f2143b993073254af8
Author: ThomasV <thomasv@electrum.org>
Date: Tue, 4 Feb 2020 13:35:58 +0100
Restructure wallet storage:
- Perform json deserializations in wallet_db
- use StoredDict class that keeps tracks of its modifications
Diffstat:
14 files changed, 303 insertions(+), 291 deletions(-)
diff --git a/electrum/json_db.py b/electrum/json_db.py
@@ -45,6 +45,101 @@ def locked(func):
return wrapper
+class StoredObject:
+
+ db = None
+
+ def __setattr__(self, key, value):
+ if self.db:
+ self.db.set_modified(True)
+ object.__setattr__(self, key, value)
+
+ def set_db(self, db):
+ self.db = db
+
+ def to_json(self):
+ d = dict(vars(self))
+ d.pop('db', None)
+ return d
+
+
+_RaiseKeyError = object() # singleton for no-default behavior
+
+class StoredDict(dict):
+
+ def __init__(self, data, db, path):
+ self.db = db
+ self.lock = self.db.lock if self.db else threading.RLock()
+ self.path = path
+ # recursively convert dicts to StoredDict
+ for k, v in list(data.items()):
+ self.__setitem__(k, v)
+
+ def convert_key(self, key):
+ # convert int, HTLCOwner to str
+ return str(int(key)) if isinstance(key, int) else key
+
+ @locked
+ def __setitem__(self, key, v):
+ key = self.convert_key(key)
+ is_new = key not in self
+ # early return to prevent unnecessary disk writes
+ if not is_new and self[key] == v:
+ return
+ # recursively convert dict to StoredDict.
+ # _convert_dict is called breadth-first
+ if isinstance(v, dict):
+ if self.db:
+ v = self.db._convert_dict(self.path, key, v)
+ v = StoredDict(v, self.db, self.path + [key])
+ # convert_value is called depth-first
+ if isinstance(v, dict) or isinstance(v, str):
+ if self.db:
+ v = self.db._convert_value(self.path, key, v)
+ # set parent of StoredObject
+ if isinstance(v, StoredObject):
+ v.set_db(self.db)
+ # set item
+ dict.__setitem__(self, key, v)
+ if self.db:
+ self.db.set_modified(True)
+
+ @locked
+ def __delitem__(self, key):
+ key = self.convert_key(key)
+ dict.__delitem__(self, key)
+ if self.db:
+ self.db.set_modified(True)
+
+ @locked
+ def __getitem__(self, key):
+ key = self.convert_key(key)
+ return dict.__getitem__(self, key)
+
+ @locked
+ def __contains__(self, key):
+ key = self.convert_key(key)
+ return dict.__contains__(self, key)
+
+ @locked
+ def pop(self, key, v=_RaiseKeyError):
+ key = self.convert_key(key)
+ if v is _RaiseKeyError:
+ r = dict.pop(self, key)
+ else:
+ r = dict.pop(self, key, v)
+ if self.db:
+ self.db.set_modified(True)
+ return r
+
+ @locked
+ def get(self, key, default=None):
+ key = self.convert_key(key)
+ return dict.get(self, key, default)
+
+
+
+
class JsonDB(Logger):
def __init__(self, data):
@@ -65,8 +160,6 @@ class JsonDB(Logger):
v = self.data.get(key)
if v is None:
v = default
- else:
- v = copy.deepcopy(v)
return v
@modifier
diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py
@@ -54,6 +54,7 @@ from .lnhtlc import HTLCManager
if TYPE_CHECKING:
from .lnworker import LNWallet
+ from .json_db import StoredDict
# lightning channel states
@@ -92,17 +93,6 @@ state_transitions = [
(cs.CLOSED, cs.REDEEMED),
]
-class ChannelJsonEncoder(json.JSONEncoder):
- def default(self, o):
- if isinstance(o, bytes):
- return binascii.hexlify(o).decode("ascii")
- if isinstance(o, RevocationStore):
- return o.serialize()
- if isinstance(o, set):
- return list(o)
- if hasattr(o, 'to_json') and callable(o.to_json):
- return o.to_json()
- return super().default(o)
RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])
@@ -110,31 +100,9 @@ RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_co
class RemoteCtnTooFarInFuture(Exception): pass
-def decodeAll(d, local):
- for k, v in d.items():
- if k == 'revocation_store':
- yield (k, RevocationStore(v))
- elif k.endswith("_basepoint") or k.endswith("_key"):
- if local:
- yield (k, Keypair(**dict(decodeAll(v, local))))
- else:
- yield (k, OnlyPubkeyKeypair(**dict(decodeAll(v, local))))
- elif k in ["node_id", "channel_id", "short_channel_id", "pubkey", "privkey", "current_per_commitment_point", "next_per_commitment_point", "per_commitment_secret_seed", "current_commitment_signature", "current_htlc_signatures"] and v is not None:
- yield (k, binascii.unhexlify(v))
- else:
- yield (k, v)
-
def htlcsum(htlcs):
return sum([x.amount_msat for x in htlcs])
-# following two functions are used because json
-# doesn't store int keys and byte string values
-def str_bytes_dict_from_save(x) -> Dict[int, bytes]:
- return {int(k): bfh(v) for k,v in x.items()}
-
-def str_bytes_dict_to_save(x) -> Dict[str, str]:
- return {str(k): bh2u(v) for k, v in x.items()}
-
class Channel(Logger):
# note: try to avoid naming ctns/ctxs/etc as "current" and "pending".
@@ -149,44 +117,53 @@ class Channel(Logger):
except:
return super().diagnostic_name()
- def __init__(self, state, *, sweep_address=None, name=None, lnworker=None, initial_feerate=None):
+ def __init__(self, state: 'StoredDict', *, sweep_address=None, name=None, lnworker=None, initial_feerate=None):
self.name = name
Logger.__init__(self)
self.lnworker = lnworker # type: Optional[LNWallet]
self.sweep_address = sweep_address
- assert 'local_state' not in state
- self.db_lock = self.lnworker.wallet.storage.db.lock if self.lnworker else threading.RLock()
+ self.storage = state
+ self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
self.config = {}
self.config[LOCAL] = state["local_config"]
- if type(self.config[LOCAL]) is not LocalConfig:
- conf = dict(decodeAll(self.config[LOCAL], True))
- self.config[LOCAL] = LocalConfig(**conf)
- assert type(self.config[LOCAL].htlc_basepoint.privkey) is bytes
-
self.config[REMOTE] = state["remote_config"]
- if type(self.config[REMOTE]) is not RemoteConfig:
- conf = dict(decodeAll(self.config[REMOTE], False))
- self.config[REMOTE] = RemoteConfig(**conf)
- assert type(self.config[REMOTE].htlc_basepoint.pubkey) is bytes
-
- self.channel_id = bfh(state["channel_id"]) if type(state["channel_id"]) not in (bytes, type(None)) else state["channel_id"]
- self.constraints = ChannelConstraints(**state["constraints"]) if type(state["constraints"]) is not ChannelConstraints else state["constraints"]
- self.funding_outpoint = Outpoint(**dict(decodeAll(state["funding_outpoint"], False))) if type(state["funding_outpoint"]) is not Outpoint else state["funding_outpoint"]
- self.node_id = bfh(state["node_id"]) if type(state["node_id"]) not in (bytes, type(None)) else state["node_id"] # type: bytes
+ self.channel_id = bfh(state["channel_id"])
+ self.constraints = state["constraints"]
+ self.funding_outpoint = state["funding_outpoint"]
+ self.node_id = bfh(state["node_id"])
self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
self.short_channel_id_predicted = self.short_channel_id
- self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {}))
- self.data_loss_protect_remote_pcp = str_bytes_dict_from_save(state.get('data_loss_protect_remote_pcp', {}))
- self.remote_update = bfh(state.get('remote_update')) if state.get('remote_update') else None
-
- log = state.get('log')
- self.hm = HTLCManager(log=log, initial_feerate=initial_feerate)
+ self.onion_keys = state['onion_keys']
+ self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
+ self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
self._state = channel_states[state['state']]
self.peer_state = peer_states.DISCONNECTED
self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
self._outgoing_channel_update = None # type: Optional[bytes]
self.revocation_store = RevocationStore(state["revocation_store"])
+ def set_onion_key(self, key, value):
+ self.onion_keys[key] = value
+
+ def get_onion_key(self, key):
+ return self.onion_keys.get(key)
+
+ def set_data_loss_protect_remote_pcp(self, key, value):
+ self.data_loss_protect_remote_pcp[key] = value
+
+ def get_data_loss_protect_remote_pcp(self, key):
+ self.data_loss_protect_remote_pcp.get(key)
+
+ def set_remote_update(self, raw):
+ self.storage['remote_update'] = raw.hex()
+
+ def get_remote_update(self):
+ return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None
+
+ def set_short_channel_id(self, short_id):
+ self.short_channel_id = short_id
+ self.storage["short_channel_id"] = short_id
+
def get_feerate(self, subject, ctn):
return self.hm.get_feerate(subject, ctn)
@@ -229,8 +206,10 @@ class Channel(Logger):
old_state = self._state
if (old_state, state) not in state_transitions:
raise Exception(f"Transition not allowed: {old_state.name} -> {state.name}")
- self._state = state
self.logger.debug(f'Setting channel state: {old_state.name} -> {state.name}')
+ self._state = state
+ self.storage['state'] = self._state.name
+
if self.lnworker:
self.lnworker.save_channel(self)
self.lnworker.network.trigger_callback('channel', self)
@@ -656,51 +635,6 @@ class Channel(Logger):
else:
self.hm.recv_update_fee(feerate)
- def to_save(self):
- to_save = {
- "local_config": self.config[LOCAL],
- "remote_config": self.config[REMOTE],
- "channel_id": self.channel_id,
- "short_channel_id": self.short_channel_id,
- "constraints": self.constraints,
- "funding_outpoint": self.funding_outpoint,
- "node_id": self.node_id,
- "log": self.hm.to_save(),
- "revocation_store": self.revocation_store,
- "onion_keys": str_bytes_dict_to_save(self.onion_keys),
- "state": self._state.name,
- "data_loss_protect_remote_pcp": str_bytes_dict_to_save(self.data_loss_protect_remote_pcp),
- "remote_update": self.remote_update.hex() if self.remote_update else None
- }
- return to_save
-
- def serialize(self):
- namedtuples_to_dict = lambda v: {i: j._asdict() if isinstance(j, tuple) else j for i, j in v._asdict().items()}
- serialized_channel = {}
- to_save_ref = self.to_save()
- for k, v in to_save_ref.items():
- if isinstance(v, tuple):
- serialized_channel[k] = namedtuples_to_dict(v)
- else:
- serialized_channel[k] = v
- dumped = ChannelJsonEncoder().encode(serialized_channel)
- roundtripped = json.loads(dumped)
- reconstructed = Channel(roundtripped)
- to_save_new = reconstructed.to_save()
- if to_save_new != to_save_ref:
- from pprint import PrettyPrinter
- pp = PrettyPrinter(indent=168)
- try:
- from deepdiff import DeepDiff
- except ImportError:
- raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(to_save_ref) + "\n" + pp.pformat(to_save_new))
- else:
- raise Exception("Channels did not roundtrip serialization without changes:\n" + pp.pformat(DeepDiff(to_save_ref, to_save_new)))
- return roundtripped
-
- def __str__(self):
- return str(self.serialize())
-
def make_commitment(self, subject, this_point, ctn) -> PartialTransaction:
assert type(subject) is HTLCOwner
feerate = self.get_feerate(subject, ctn)
diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py
@@ -1,14 +1,17 @@
from copy import deepcopy
-from typing import Optional, Sequence, Tuple, List, Dict
+from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u, bfh
+if TYPE_CHECKING:
+ from .json_db import StoredDict
class HTLCManager:
- def __init__(self, *, log=None, initial_feerate=None):
- if log is None:
+ def __init__(self, log:'StoredDict', *, initial_feerate=None):
+
+ if len(log) == 0:
initial = {
'adds': {},
'locked_in': {},
@@ -17,33 +20,18 @@ class HTLCManager:
'fee_updates': {}, # "side who initiated fee update" -> action -> list of FeeUpdates
'revack_pending': False,
'next_htlc_id': 0,
- 'ctn': -1, # oldest unrevoked ctx of sub
+ 'ctn': -1, # oldest unrevoked ctx of sub
}
- log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
- else:
- assert type(log) is dict
- log = {(HTLCOwner(int(k)) if k in ("-1", "1") else k): v
- for k, v in deepcopy(log).items()}
- for sub in (LOCAL, REMOTE):
- log[sub]['adds'] = {int(htlc_id): UpdateAddHtlc(*htlc) for htlc_id, htlc in log[sub]['adds'].items()}
- coerceHtlcOwner2IntMap = lambda ctns: {HTLCOwner(int(owner)): ctn for owner, ctn in ctns.items()}
- # "side who offered htlc" -> action -> htlc_id -> whose ctx -> ctn
- log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
- log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
- log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
- # "side who initiated fee update" -> action -> list of FeeUpdates
- log[sub]['fee_updates'] = { int(x): FeeUpdate(**fee_upd) for x,fee_upd in log[sub]['fee_updates'].items() }
-
- if 'unacked_local_updates2' not in log:
+ log[LOCAL] = deepcopy(initial)
+ log[REMOTE] = deepcopy(initial)
log['unacked_local_updates2'] = {}
- log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages]
- for ctn, messages in log['unacked_local_updates2'].items()}
+
# maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None:
assert type(initial_feerate) is int
for sub in (LOCAL, REMOTE):
if not log[sub]['fee_updates']:
- log[sub]['fee_updates'][0] = FeeUpdate(initial_feerate, ctn_local=0, ctn_remote=0)
+ log[sub]['fee_updates'][0] = FeeUpdate(rate=initial_feerate, ctn_local=0, ctn_remote=0)
self.log = log
def ctn_latest(self, sub: HTLCOwner) -> int:
@@ -66,20 +54,6 @@ class HTLCManager:
def get_next_htlc_id(self, sub: HTLCOwner) -> int:
return self.log[sub]['next_htlc_id']
- def to_save(self):
- log = deepcopy(self.log)
- for sub in (LOCAL, REMOTE):
- # adds
- d = {}
- for htlc_id, htlc in log[sub]['adds'].items():
- d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
- log[sub]['adds'] = d
- # fee_updates
- log[sub]['fee_updates'] = { x:fee_upd.to_json() for x, fee_upd in self.log[sub]['fee_updates'].items() }
- log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages]
- for ctn, messages in log['unacked_local_updates2'].items()}
- return log
-
##### Actions on channel:
def channel_open_finished(self):
@@ -132,7 +106,7 @@ class HTLCManager:
def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None:
# overwrite last fee update if not yet committed to by anyone; otherwise append
d = self.log[subject]['fee_updates']
- assert type(d) is dict
+ #assert type(d) is StoredDict
n = len(d)
last_fee_update = d[n-1]
if (last_fee_update.ctn_local is None or last_fee_update.ctn_local > self.ctn_latest(LOCAL)) \
@@ -194,7 +168,7 @@ class HTLCManager:
del self.log[REMOTE]['locked_in'][htlc_id]
del self.log[REMOTE]['adds'][htlc_id]
if self.log[REMOTE]['locked_in']:
- self.log[REMOTE]['next_htlc_id'] = max(self.log[REMOTE]['locked_in']) + 1
+ self.log[REMOTE]['next_htlc_id'] = max([int(x) for x in self.log[REMOTE]['locked_in'].keys()]) + 1
else:
self.log[REMOTE]['next_htlc_id'] = 0
# htlcs removed
@@ -217,12 +191,14 @@ class HTLCManager:
ctn_idx = self.ctn_latest(REMOTE)
else:
ctn_idx = self.ctn_latest(REMOTE) + 1
- if ctn_idx not in self.log['unacked_local_updates2']:
- self.log['unacked_local_updates2'][ctn_idx] = []
- self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg)
+ l = self.log['unacked_local_updates2'].get(ctn_idx, [])
+ l.append(raw_update_msg.hex())
+ self.log['unacked_local_updates2'][ctn_idx] = l
def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
- return self.log['unacked_local_updates2']
+ #return self.log['unacked_local_updates2']
+ return {int(ctn): [bfh(msg) for msg in messages]
+ for ctn, messages in self.log['unacked_local_updates2'].items()}
##### Queries re HTLCs:
diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py
@@ -221,7 +221,7 @@ class Peer(Logger):
def maybe_save_remote_update(self, payload):
for chan in self.channels.values():
if chan.short_channel_id == payload['short_channel_id']:
- chan.remote_update = payload['raw']
+ chan.set_remote_update(payload['raw'])
self.logger.info("saved remote_update")
def on_announcement_signatures(self, payload):
@@ -611,9 +611,15 @@ class Peer(Logger):
"constraints": constraints,
"remote_update": None,
"state": channel_states.PREOPENING.name,
+ 'onion_keys': {},
+ 'data_loss_protect_remote_pcp': {},
+ "log": {},
"revocation_store": {},
}
- return chan_dict
+ channel_id = chan_dict.get('channel_id')
+ channels = self.lnworker.storage.db.get_dict('channels')
+ channels[channel_id] = chan_dict
+ return channels.get(channel_id)
async def on_open_channel(self, payload):
# payload['channel_flags']
@@ -684,7 +690,7 @@ class Peer(Logger):
signature=sig_64,
)
chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig)
- self.lnworker.save_channel(chan)
+ self.lnworker.add_channel(chan)
self.lnworker.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
def validate_remote_reserve(self, payload_field: bytes, dust_limit: int, funding_sat: int) -> int:
@@ -850,7 +856,7 @@ class Peer(Logger):
else:
if dlp_enabled and should_close_they_are_ahead:
self.logger.warning(f"channel_reestablish: remote is ahead of us! luckily DLP is enabled. remote PCP: {bh2u(their_local_pcp)}")
- chan.data_loss_protect_remote_pcp[their_next_local_ctn - 1] = their_local_pcp
+ chan.set_data_loss_protect_remote_pcp(their_next_local_ctn - 1, their_local_pcp)
self.lnworker.save_channel(chan)
if should_close_they_are_ahead:
self.logger.warning(f"channel_reestablish: remote is ahead of us! trying to get them to force-close.")
@@ -885,7 +891,6 @@ class Peer(Logger):
self.logger.info(f"on_funding_locked. channel: {bh2u(channel_id)}")
chan = self.channels.get(channel_id)
if not chan:
- print(self.channels)
raise Exception("Got unknown funding_locked", channel_id)
if not chan.config[LOCAL].funding_locked_received:
our_next_point = chan.config[REMOTE].next_per_commitment_point
@@ -1004,11 +1009,11 @@ class Peer(Logger):
# peer may have sent us a channel update for the incoming direction previously
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
if pending_channel_update:
- chan.remote_update = pending_channel_update['raw']
+ chan.set_remote_update(pending_channel_update['raw'])
# add remote update with a fresh timestamp
- if chan.remote_update:
+ if chan.get_remote_update():
now = int(time.time())
- remote_update_decoded = decode_msg(chan.remote_update)[1]
+ remote_update_decoded = decode_msg(chan.get_remote_update())[1]
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
self.channel_db.add_channel_update(remote_update_decoded)
diff --git a/electrum/lnsweep.py b/electrum/lnsweep.py
@@ -299,8 +299,8 @@ def analyze_ctx(chan: 'Channel', ctx: Transaction):
their_pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True)
is_revocation = True
#_logger.info(f'tx for revoked: {list(txs.keys())}')
- elif ctn in chan.data_loss_protect_remote_pcp:
- their_pcp = chan.data_loss_protect_remote_pcp[ctn]
+ elif chan.get_data_loss_protect_remote_pcp(ctn):
+ their_pcp = chan.get_data_loss_protect_remote_pcp(ctn)
is_revocation = False
else:
return
diff --git a/electrum/lnutil.py b/electrum/lnutil.py
@@ -38,12 +38,7 @@ LN_MAX_FUNDING_SAT = pow(2, 24) - 1
def ln_dummy_address():
return redeem_script_to_address('p2wsh', '')
-
-class StoredObject:
-
- def to_json(self):
- return dict(vars(self))
-
+from .json_db import StoredObject
@attr.s
class OnlyPubkeyKeypair(StoredObject):
@@ -180,21 +175,23 @@ class RevocationStore:
START_INDEX = 2 ** 48 - 1
def __init__(self, storage):
- self.index = storage.get('index', self.START_INDEX)
- buckets = storage.get('buckets', {})
- decode = lambda to_decode: ShachainElement(bfh(to_decode[0]), int(to_decode[1]))
- self.buckets = dict((int(k), decode(v)) for k, v in buckets.items())
+ if len(storage) == 0:
+ storage['index'] = self.START_INDEX
+ storage['buckets'] = {}
+ self.storage = storage
+ self.buckets = storage['buckets']
def add_next_entry(self, hsh):
- new_element = ShachainElement(index=self.index, secret=hsh)
- bucket = count_trailing_zeros(self.index)
+ index = self.storage['index']
+ new_element = ShachainElement(index=index, secret=hsh)
+ bucket = count_trailing_zeros(index)
for i in range(0, bucket):
this_bucket = self.buckets[i]
e = shachain_derive(new_element, this_bucket.index)
if e != this_bucket:
raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index))
self.buckets[bucket] = new_element
- self.index -= 1
+ self.storage['index'] = index - 1
def retrieve_secret(self, index: int) -> bytes:
assert index <= self.START_INDEX, index
@@ -209,9 +206,6 @@ class RevocationStore:
return element.secret
raise UnableToDeriveSecret()
- def serialize(self):
- return {"index": self.index, "buckets": dict( (k, [bh2u(v.secret), v.index]) for k, v in self.buckets.items()) }
-
def __eq__(self, o):
return type(o) is RevocationStore and self.serialize() == o.serialize()
diff --git a/electrum/lnworker.py b/electrum/lnworker.py
@@ -34,13 +34,14 @@ from .bip32 import BIP32Node
from .util import bh2u, bfh, InvoiceError, resolve_dns_srv, is_ip_address, log_exceptions
from .util import ignore_exceptions, make_aiohttp_session
from .util import timestamp_to_datetime
+from .util import MyEncoder
from .logging import Logger
from .lntransport import LNTransport, LNResponderTransport
from .lnpeer import Peer, LN_P2P_NETWORK_TIMEOUT
from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string
from .ecc_fast import is_using_fast_ecc
-from .lnchannel import Channel, ChannelJsonEncoder
+from .lnchannel import Channel
from .lnchannel import channel_states, peer_states
from . import lnutil
from .lnutil import funding_output_script
@@ -106,8 +107,6 @@ FALLBACK_NODE_LIST_MAINNET = [
LNPeerAddr(host='3.124.63.44', port=9735, pubkey=bfh('0242a4ae0c5bef18048fbecf995094b74bfb0f7391418d71ed394784373f41e4f3')),
]
-encoder = ChannelJsonEncoder()
-
from typing import NamedTuple
@@ -347,19 +346,20 @@ class LNWallet(LNWorker):
LNWorker.__init__(self, xprv)
self.ln_keystore = keystore.from_xprv(xprv)
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
- self.payments = self.storage.get('lightning_payments', {}) # RHASH -> amount, direction, is_paid
- self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage
+ self.payments = self.storage.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
+ self.preimages = self.storage.db.get_dict('lightning_preimages') # RHASH -> preimage
self.sweep_address = wallet.get_receiving_address()
self.lock = threading.RLock()
self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH
# note: accessing channels (besides simple lookup) needs self.lock!
- self.channels = {} # type: Dict[bytes, Channel]
- for x in wallet.storage.get("channels", {}).values():
- c = Channel(x, sweep_address=self.sweep_address, lnworker=self)
- self.channels[c.channel_id] = c
+ self.channels = {}
+ channels = self.storage.db.get_dict("channels")
+ for channel_id, c in channels.items():
+ self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
+
# timestamps of opening and closing transactions
- self.channel_timestamps = self.storage.get('lightning_channel_timestamps', {})
+ self.channel_timestamps = self.storage.db.get_dict('lightning_channel_timestamps')
self.pending_payments = defaultdict(asyncio.Future)
@ignore_exceptions
@@ -610,17 +610,9 @@ class LNWallet(LNWorker):
assert type(chan) is Channel
if chan.config[REMOTE].next_per_commitment_point == chan.config[REMOTE].current_per_commitment_point:
raise Exception("Tried to save channel with next_point == current_point, this should not happen")
- with self.lock:
- self.channels[chan.channel_id] = chan
- self.save_channels()
+ self.wallet.storage.write()
self.network.trigger_callback('channel', chan)
- def save_channels(self):
- with self.lock:
- dumped = dict( (k.hex(), c.serialize()) for k, c in self.channels.items() )
- self.storage.put("channels", dumped)
- self.storage.write()
-
def save_short_chan_id(self, chan):
"""
Checks if Funding TX has been mined. If it has, save the short channel ID in chan;
@@ -648,8 +640,8 @@ class LNWallet(LNWorker):
return
block_height, tx_pos = self.lnwatcher.get_txpos(chan.funding_outpoint.txid)
assert tx_pos >= 0
- chan.short_channel_id = ShortChannelID.from_components(
- block_height, tx_pos, chan.funding_outpoint.output_index)
+ chan.set_short_channel_id(ShortChannelID.from_components(
+ block_height, tx_pos, chan.funding_outpoint.output_index))
self.logger.info(f"save_short_channel_id: {chan.short_channel_id}")
self.save_channel(chan)
@@ -669,7 +661,6 @@ class LNWallet(LNWorker):
# save timestamp regardless of state, so that funding tx is returned in get_history
self.channel_timestamps[bh2u(chan.channel_id)] = chan.funding_outpoint.txid, funding_height.height, funding_height.timestamp, None, None, None
- self.storage.put('lightning_channel_timestamps', self.channel_timestamps)
if chan.get_state() == channel_states.OPEN and self.should_channel_be_closed_due_to_expiring_htlcs(chan):
self.logger.info(f"force-closing due to expiring htlcs")
@@ -714,7 +705,6 @@ class LNWallet(LNWorker):
# fixme: this is wasteful
self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, closing_txid, closing_height.height, closing_height.timestamp
- self.storage.put('lightning_channel_timestamps', self.channel_timestamps)
# remove from channel_db
if chan.short_channel_id is not None:
@@ -836,7 +826,7 @@ class LNWallet(LNWorker):
funding_sat=funding_sat,
push_msat=push_sat * 1000,
temp_channel_id=os.urandom(32))
- self.save_channel(chan)
+ self.add_channel(chan)
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
self.network.trigger_callback('channels_updated', self.wallet)
self.wallet.add_transaction(funding_tx) # save tx as local into the wallet
@@ -846,6 +836,10 @@ class LNWallet(LNWorker):
await asyncio.wait_for(self.network.broadcast_transaction(funding_tx), LN_P2P_NETWORK_TIMEOUT)
return chan, funding_tx
+ def add_channel(self, chan):
+ with self.lock:
+ self.channels[chan.channel_id] = chan
+
@log_exceptions
async def add_peer(self, connect_str: str) -> Peer:
node_id, rest = extract_nodeid(connect_str)
@@ -1133,7 +1127,6 @@ class LNWallet(LNWorker):
def save_preimage(self, payment_hash: bytes, preimage: bytes):
assert sha256(preimage) == payment_hash
self.preimages[bh2u(payment_hash)] = bh2u(preimage)
- self.storage.put('lightning_preimages', self.preimages)
self.storage.write()
def get_preimage(self, payment_hash: bytes) -> bytes:
@@ -1152,7 +1145,6 @@ class LNWallet(LNWorker):
assert info.status in [PR_PAID, PR_UNPAID, PR_INFLIGHT]
with self.lock:
self.payments[key] = info.amount, info.direction, info.status
- self.storage.put('lightning_payments', self.payments)
self.storage.write()
def get_payment_status(self, payment_hash):
@@ -1238,7 +1230,6 @@ class LNWallet(LNWorker):
del self.payments[payment_hash_hex]
except KeyError:
return
- self.storage.put('lightning_payments', self.payments)
self.storage.write()
def get_balance(self):
@@ -1246,6 +1237,7 @@ class LNWallet(LNWorker):
return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000
def list_channels(self):
+ encoder = MyEncoder()
with self.lock:
# we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels
for channel_id, chan in self.channels.items():
@@ -1283,7 +1275,9 @@ class LNWallet(LNWorker):
assert chan.is_closed()
with self.lock:
self.channels.pop(chan_id)
- self.save_channels()
+ self.channel_timestamps.pop(chan_id.hex())
+ self.storage.get('channels').pop(chan_id.hex())
+
self.network.trigger_callback('channels_updated', self.wallet)
self.network.trigger_callback('wallet_updated', self.wallet)
diff --git a/electrum/plugins/labels/labels.py b/electrum/plugins/labels/labels.py
@@ -149,8 +149,6 @@ class LabelsPlugin(BasePlugin):
wallet.labels[key] = value
self.logger.info(f"received {len(response)} labels")
- # do not write to disk because we're in a daemon thread
- wallet.storage.put('labels', wallet.labels)
self.set_nonce(wallet, response["nonce"] + 1)
self.on_pulled(wallet)
diff --git a/electrum/tests/test_lnchannel.py b/electrum/tests/test_lnchannel.py
@@ -35,6 +35,7 @@ from electrum.lnutil import FeeUpdate
from electrum.ecc import sig_string_from_der_sig
from electrum.logging import console_stderr_handler
from electrum.lnchannel import channel_states
+from electrum.json_db import StoredDict
from . import ElectrumTestCase
@@ -45,9 +46,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
assert local_amount > 0
assert remote_amount > 0
channel_id, _ = lnpeer.channel_id_from_funding_tx(funding_txid, funding_index)
-
- return {
- "channel_id":channel_id,
+ state = {
+ "channel_id":channel_id.hex(),
"short_channel_id":channel_id[:8],
"funding_outpoint":lnpeer.Outpoint(funding_txid, funding_index),
"remote_config":lnpeer.RemoteConfig(
@@ -63,7 +63,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
initial_msat=remote_amount,
reserve_sat=0,
htlc_minimum_msat=1,
-
next_per_commitment_point=nex,
current_per_commitment_point=cur,
),
@@ -79,7 +78,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
max_accepted_htlcs=5,
initial_msat=local_amount,
reserve_sat=0,
-
per_commitment_secret_seed=seed,
funding_locked_received=True,
was_announced=False,
@@ -91,11 +89,14 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
is_initiator=is_initiator,
funding_txn_minimum_depth=3,
),
- "node_id":other_node_id,
+ "node_id":other_node_id.hex(),
'onion_keys': {},
+ 'data_loss_protect_remote_pcp': {},
'state': 'PREOPENING',
+ 'log': {},
'revocation_store': {},
}
+ return StoredDict(state, None, [])
def bip32(sequence):
node = bip32_utils.BIP32Node.from_rootseed(b"9dk", xtype='standard').subkey_at_private_derivation(sequence)
@@ -317,7 +318,6 @@ class TestChannel(ElectrumTestCase):
# Bob revokes his prior commitment given to him by Alice, since he now
# has a valid signature for a newer commitment.
bobRevocation, _ = bob_channel.revoke_current_commitment()
- bob_channel.serialize()
self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
# Bob finally sends a signature for Alice's commitment transaction.
@@ -341,18 +341,14 @@ class TestChannel(ElectrumTestCase):
# her prior commitment transaction. Alice shouldn't have any HTLCs to
# forward since she's sending an outgoing HTLC.
alice_channel.receive_revocation(bobRevocation)
- alice_channel.serialize()
self.assertTrue(alice_channel.signature_fits(alice_channel.get_latest_commitment(LOCAL)))
- alice_channel.serialize()
self.assertEqual(len(alice_channel.get_latest_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.get_latest_commitment(REMOTE).outputs()), 3)
self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2)
self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
- alice_channel.serialize()
-
self.assertEqual(alice_channel.get_next_commitment(LOCAL).outputs(),
bob_channel.get_latest_commitment(REMOTE).outputs())
@@ -365,14 +361,12 @@ class TestChannel(ElectrumTestCase):
self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3)
self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
- alice_channel.serialize()
tx1 = str(alice_channel.force_close_tx())
self.assertNotEqual(tx0, tx1)
# Alice then generates a revocation for bob.
aliceRevocation, _ = alice_channel.revoke_current_commitment()
- alice_channel.serialize()
tx2 = str(alice_channel.force_close_tx())
# since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one)
@@ -384,7 +378,6 @@ class TestChannel(ElectrumTestCase):
# into both commitment transactions.
self.assertTrue(bob_channel.signature_fits(bob_channel.get_latest_commitment(LOCAL)))
bob_channel.receive_revocation(aliceRevocation)
- bob_channel.serialize()
# At this point, both sides should have the proper number of satoshis
# sent, and commitment height updated within their local channel
@@ -450,20 +443,16 @@ class TestChannel(ElectrumTestCase):
self.assertEqual(1, alice_channel.get_oldest_unrevoked_ctn(LOCAL))
self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0)
aliceRevocation2, _ = alice_channel.revoke_current_commitment()
- alice_channel.serialize()
aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment()
self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures")
self.assertEqual(len(bob_channel.get_latest_commitment(LOCAL).outputs()), 3)
bob_channel.receive_revocation(aliceRevocation2)
- bob_channel.serialize()
bob_channel.receive_new_commitment(aliceSig2, aliceHtlcSigs2)
bobRevocation2, (received, sent) = bob_channel.revoke_current_commitment()
self.assertEqual(one_bitcoin_in_msat, received)
- bob_channel.serialize()
alice_channel.receive_revocation(bobRevocation2)
- alice_channel.serialize()
# At this point, Bob should have 6 BTC settled, with Alice still having
# 4 BTC. Alice's channel should show 1 BTC sent and Bob's channel
@@ -509,8 +498,6 @@ class TestChannel(ElectrumTestCase):
self.assertEqual(bob_channel.total_msat(RECEIVED), one_bitcoin_in_msat, "bob satoshis received incorrect")
self.assertEqual(bob_channel.total_msat(SENT), 5 * one_bitcoin_in_msat, "bob satoshis sent incorrect")
- alice_channel.serialize()
-
def alice_to_bob_fee_update(self, fee=111):
aoldctx = self.alice_channel.get_next_commitment(REMOTE).outputs()
diff --git a/electrum/tests/test_lnhtlc.py b/electrum/tests/test_lnhtlc.py
@@ -4,18 +4,18 @@ from typing import NamedTuple
from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner, Direction
from electrum.lnhtlc import HTLCManager
+from electrum.json_db import StoredDict
from . import ElectrumTestCase
-
class H(NamedTuple):
owner : str
htlc_id : int
class TestHTLCManager(ElectrumTestCase):
def test_adding_htlcs_race(self):
- A = HTLCManager()
- B = HTLCManager()
+ A = HTLCManager(StoredDict({}, None, []))
+ B = HTLCManager(StoredDict({}, None, []))
A.channel_open_finished()
B.channel_open_finished()
ah0, bh0 = H('A', 0), H('B', 0)
@@ -61,8 +61,8 @@ class TestHTLCManager(ElectrumTestCase):
def test_single_htlc_full_lifecycle(self):
def htlc_lifecycle(htlc_success: bool):
- A = HTLCManager()
- B = HTLCManager()
+ A = HTLCManager(StoredDict({}, None, []))
+ B = HTLCManager(StoredDict({}, None, []))
A.channel_open_finished()
B.channel_open_finished()
B.recv_htlc(A.send_htlc(H('A', 0)))
@@ -134,8 +134,8 @@ class TestHTLCManager(ElectrumTestCase):
def test_remove_htlc_while_owing_commitment(self):
def htlc_lifecycle(htlc_success: bool):
- A = HTLCManager()
- B = HTLCManager()
+ A = HTLCManager(StoredDict({}, None, []))
+ B = HTLCManager(StoredDict({}, None, []))
A.channel_open_finished()
B.channel_open_finished()
ah0 = H('A', 0)
@@ -171,8 +171,8 @@ class TestHTLCManager(ElectrumTestCase):
htlc_lifecycle(htlc_success=False)
def test_adding_htlc_between_send_ctx_and_recv_rev(self):
- A = HTLCManager()
- B = HTLCManager()
+ A = HTLCManager(StoredDict({}, None, []))
+ B = HTLCManager(StoredDict({}, None, []))
A.channel_open_finished()
B.channel_open_finished()
A.send_ctx()
@@ -217,8 +217,8 @@ class TestHTLCManager(ElectrumTestCase):
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
def test_unacked_local_updates(self):
- A = HTLCManager()
- B = HTLCManager()
+ A = HTLCManager(StoredDict({}, None, []))
+ B = HTLCManager(StoredDict({}, None, []))
A.channel_open_finished()
B.channel_open_finished()
self.assertEqual({}, A.get_unacked_local_updates())
diff --git a/electrum/tests/test_lnutil.py b/electrum/tests/test_lnutil.py
@@ -2,13 +2,14 @@ import unittest
import json
from electrum import bitcoin
+from electrum.json_db import StoredDict
from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_seed, make_offered_htlc,
make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output,
make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey,
derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret,
get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError,
ScriptHtlc, extract_nodeid, calc_onchain_fees, UpdateAddHtlc)
-from electrum.util import bh2u, bfh
+from electrum.util import bh2u, bfh, MyEncoder
from electrum.transaction import Transaction, PartialTransaction
from . import ElectrumTestCase
@@ -422,7 +423,7 @@ class TestLNUtil(ElectrumTestCase):
]
for test in tests:
- receiver = RevocationStore({})
+ receiver = RevocationStore(StoredDict({}, None, []))
for insert in test["inserts"]:
secret = bytes.fromhex(insert["secret"])
@@ -445,14 +446,19 @@ class TestLNUtil(ElectrumTestCase):
def test_shachain_produce_consume(self):
seed = bitcoin.sha256(b"shachaintest")
- consumer = RevocationStore({})
+ consumer = RevocationStore(StoredDict({}, None, []))
for i in range(10000):
secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i)
try:
consumer.add_next_entry(secret)
except Exception as e:
raise Exception("iteration " + str(i) + ": " + str(e))
- if i % 1000 == 0: self.assertEqual(consumer.serialize(), RevocationStore(json.loads(json.dumps(consumer.serialize()))).serialize())
+ if i % 1000 == 0:
+ c1 = consumer
+ s1 = json.dumps(c1.storage, cls=MyEncoder)
+ c2 = RevocationStore(StoredDict(json.loads(s1), None, []))
+ s2 = json.dumps(c2.storage, cls=MyEncoder)
+ self.assertEqual(s1, s2)
def test_commitment_tx_with_all_five_HTLCs_untrimmed_minimum_feerate(self):
to_local_msat = 6988000000
diff --git a/electrum/util.py b/electrum/util.py
@@ -280,6 +280,8 @@ class MyEncoder(json.JSONEncoder):
return obj.isoformat(' ')[:-3]
if isinstance(obj, set):
return list(obj)
+ if isinstance(obj, bytes): # for nametuples in lnchannel
+ return obj.hex()
if hasattr(obj, 'to_json') and callable(obj.to_json):
return obj.to_json()
return super(MyEncoder, self).default(obj)
diff --git a/electrum/wallet.py b/electrum/wallet.py
@@ -240,21 +240,13 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
# saved fields
self.use_change = storage.get('use_change', True)
self.multiple_change = storage.get('multiple_change', False)
- self.labels = storage.get('labels', {})
+ self.labels = storage.db.get_dict('labels')
self.frozen_addresses = set(storage.get('frozen_addresses', []))
self.frozen_coins = set(storage.get('frozen_coins', [])) # set of txid:vout strings
- self.fiat_value = storage.get('fiat_value', {})
- self.receive_requests = storage.get('payment_requests', {})
- self.invoices = storage.get('invoices', {})
- # convert invoices
- # TODO invoices being these contextual dicts even internally,
- # where certain keys are only present depending on values of other keys...
- # it's horrible. we need to change this, at least for the internal representation,
- # to something that can be typed.
- for invoice_key, invoice in self.invoices.items():
- if invoice.get('type') == PR_TYPE_ONCHAIN:
- outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')]
- invoice['outputs'] = outputs
+ self.fiat_value = storage.db.get_dict('fiat_value')
+ self.receive_requests = storage.db.get_dict('payment_requests')
+ self.invoices = storage.db.get_dict('invoices')
+
self._prepare_onchain_invoice_paid_detection()
self.calc_unused_change_addresses()
# save wallet type the first time
@@ -372,7 +364,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
changed = True
if changed:
run_hook('set_label', self, name, text)
- self.storage.put('labels', self.labels)
return changed
def set_fiat_value(self, txid, ccy, text, fx, value_sat):
@@ -404,7 +395,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
if ccy not in self.fiat_value:
self.fiat_value[ccy] = {}
self.fiat_value[ccy][txid] = text
- self.storage.put('fiat_value', self.fiat_value)
return reset
def get_fiat_value(self, txid, ccy):
@@ -625,12 +615,10 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
else:
raise Exception('Unsupported invoice type')
self.invoices[key] = invoice
- self.storage.put('invoices', self.invoices)
self.storage.write()
def clear_invoices(self):
self.invoices = {}
- self.storage.put('invoices', self.invoices)
self.storage.write()
def get_invoices(self):
@@ -642,7 +630,8 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
def get_invoice(self, key):
if key not in self.invoices:
return
- item = copy.copy(self.invoices[key])
+ # convert StoredDict to dict
+ item = dict(self.invoices[key])
request_type = item.get('type')
if request_type == PR_TYPE_ONCHAIN:
item['status'] = PR_PAID if self.is_onchain_invoice_paid(item) else PR_UNPAID
@@ -1553,7 +1542,8 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
req = self.receive_requests.get(key)
if not req:
return
- req = copy.copy(req)
+ # convert StoredDict to dict
+ req = dict(req)
_type = req.get('type')
if _type == PR_TYPE_ONCHAIN:
addr = req['address']
@@ -1610,7 +1600,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
req['name'] = pr.pki_data
req['sig'] = bh2u(pr.signature)
self.receive_requests[key] = req
- self.storage.put('payment_requests', self.receive_requests)
def add_payment_request(self, req):
if req['type'] == PR_TYPE_ONCHAIN:
@@ -1628,7 +1617,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
raise Exception('Unknown request type')
amount = req.get('amount')
self.receive_requests[key] = req
- self.storage.put('payment_requests', self.receive_requests)
self.set_label(key, message) # should be a default label
return req
@@ -1643,7 +1631,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
""" lightning or on-chain """
if key in self.invoices:
self.invoices.pop(key)
- self.storage.put('invoices', self.invoices)
elif self.lnworker:
self.lnworker.delete_payment(key)
@@ -1651,7 +1638,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
if addr not in self.receive_requests:
return False
self.receive_requests.pop(addr)
- self.storage.put('payment_requests', self.receive_requests)
return True
def get_sorted_requests(self):
diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py
@@ -29,12 +29,16 @@ import copy
import threading
from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence
+import binascii
from . import util, bitcoin
-from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh
+from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, bfh, PR_TYPE_ONCHAIN
from .keystore import bip44_derivation
-from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction
-from .json_db import JsonDB, locked, modifier
+from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput
+from .logging import Logger
+from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, RevocationStore
+from .lnutil import ChannelConstraints, Outpoint, ShachainElement
+from .json_db import StoredDict, JsonDB, locked, modifier
# seed_version is now used for the version of the wallet file
@@ -44,17 +48,12 @@ FINAL_SEED_VERSION = 24 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format
-
-
class TxFeesValue(NamedTuple):
fee: Optional[int] = None
is_calculated_by_us: bool = False
num_inputs: Optional[int] = None
-
-
-
class WalletDB(JsonDB):
def __init__(self, raw, *, manual_upgrades: bool):
@@ -67,7 +66,6 @@ class WalletDB(JsonDB):
self.put('seed_version', FINAL_SEED_VERSION)
self._after_upgrade_tasks()
-
def load_data(self, s):
try:
self.data = json.loads(s)
@@ -833,7 +831,7 @@ class WalletDB(JsonDB):
self.tx_fees.pop(txid, None)
@locked
- def get_data_ref(self, name):
+ def get_dict(self, name):
# Warning: interacts un-intuitively with 'put': certain parts
# of 'data' will have pointers saved as separate variables.
if name not in self.data:
@@ -895,9 +893,9 @@ class WalletDB(JsonDB):
def load_addresses(self, wallet_type):
""" called from Abstract_Wallet.__init__ """
if wallet_type == 'imported':
- self.imported_addresses = self.get_data_ref('addresses') # type: Dict[str, dict]
+ self.imported_addresses = self.get_dict('addresses') # type: Dict[str, dict]
else:
- self.get_data_ref('addresses')
+ self.get_dict('addresses')
for name in ['receiving', 'change']:
if name not in self.data['addresses']:
self.data['addresses'][name] = []
@@ -911,26 +909,20 @@ class WalletDB(JsonDB):
@profiler
def _load_transactions(self):
+ self.data = StoredDict(self.data, self, [])
# references in self.data
# TODO make all these private
# txid -> address -> set of (prev_outpoint, value)
- self.txi = self.get_data_ref('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]]
+ self.txi = self.get_dict('txi') # type: Dict[str, Dict[str, Set[Tuple[str, int]]]]
# txid -> address -> set of (output_index, value, is_coinbase)
- self.txo = self.get_data_ref('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]]
- self.transactions = self.get_data_ref('transactions') # type: Dict[str, Transaction]
- self.spent_outpoints = self.get_data_ref('spent_outpoints') # txid -> output_index -> next_txid
- self.history = self.get_data_ref('addr_history') # address -> list of (txid, height)
- self.verified_tx = self.get_data_ref('verified_tx3') # txid -> (height, timestamp, txpos, header_hash)
- self.tx_fees = self.get_data_ref('tx_fees') # type: Dict[str, TxFeesValue]
+ self.txo = self.get_dict('txo') # type: Dict[str, Dict[str, Set[Tuple[int, int, bool]]]]
+ self.transactions = self.get_dict('transactions') # type: Dict[str, Transaction]
+ self.spent_outpoints = self.get_dict('spent_outpoints') # txid -> output_index -> next_txid
+ self.history = self.get_dict('addr_history') # address -> list of (txid, height)
+ self.verified_tx = self.get_dict('verified_tx3') # txid -> (height, timestamp, txpos, header_hash)
+ self.tx_fees = self.get_dict('tx_fees') # type: Dict[str, TxFeesValue]
# scripthash -> set of (outpoint, value)
- self._prevouts_by_scripthash = self.get_data_ref('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]]
- # convert raw transactions to Transaction objects
- for tx_hash, raw_tx in self.transactions.items():
- # note: for performance, "deserialize=False" so that we will deserialize these on-demand
- self.transactions[tx_hash] = tx_from_any(raw_tx, deserialize=False)
- # convert prevouts_by_scripthash: list to set, list to tuple
- for scripthash, lst in self._prevouts_by_scripthash.items():
- self._prevouts_by_scripthash[scripthash] = {(prevout, value) for prevout, value in lst}
+ self._prevouts_by_scripthash = self.get_dict('prevouts_by_scripthash') # type: Dict[str, Set[Tuple[str, int]]]
# remove unreferenced tx
for tx_hash in list(self.transactions.keys()):
if not self.get_txi_addresses(tx_hash) and not self.get_txo_addresses(tx_hash):
@@ -943,9 +935,15 @@ class WalletDB(JsonDB):
if spending_txid not in self.transactions:
self.logger.info("removing unreferenced spent outpoint")
d.pop(prevout_n)
- # convert tx_fees tuples to NamedTuples
- for tx_hash, tuple_ in self.tx_fees.items():
- self.tx_fees[tx_hash] = TxFeesValue(*tuple_)
+ # convert invoices
+ # TODO invoices being these contextual dicts even internally,
+ # where certain keys are only present depending on values of other keys...
+ # it's horrible. we need to change this, at least for the internal representation,
+ # to something that can be typed.
+ self.invoices = self.get_dict('invoices')
+ for invoice_key, invoice in self.invoices.items():
+ if invoice.get('type') == PR_TYPE_ONCHAIN:
+ invoice['outputs'] = [PartialTxOutput.from_legacy_tuple(*output) for output in invoice.get('outputs')]
@modifier
def clear_history(self):
@@ -956,3 +954,42 @@ class WalletDB(JsonDB):
self.history.clear()
self.verified_tx.clear()
self.tx_fees.clear()
+
+ def _convert_dict(self, path, key, v):
+ if key == 'transactions':
+ # note: for performance, "deserialize=False" so that we will deserialize these on-demand
+ v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items())
+ elif key == 'adds':
+ v = dict((k, UpdateAddHtlc(*x)) for k, x in v.items())
+ elif key == 'fee_updates':
+ v = dict((k, FeeUpdate(**x)) for k, x in v.items())
+ elif key == 'tx_fees':
+ v = dict((k, TxFeesValue(*x)) for k, x in v.items())
+ elif key == 'prevouts_by_scripthash':
+ v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items())
+ elif key == 'buckets':
+ v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items())
+ return v
+
+ def _convert_value(self, path, key, v):
+ if key == 'local_config':
+ v = LocalConfig(**v)
+ elif key == 'remote_config':
+ v = RemoteConfig(**v)
+ elif key == 'constraints':
+ v = ChannelConstraints(**v)
+ elif key == 'funding_outpoint':
+ v = Outpoint(**v)
+ elif key.endswith("_basepoint") or key.endswith("_key"):
+ v = Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
+ elif key in [
+ "short_channel_id",
+ "current_per_commitment_point",
+ "next_per_commitment_point",
+ "per_commitment_secret_seed",
+ "current_commitment_signature",
+ "current_htlc_signatures"]:
+ v = binascii.unhexlify(v) if v is not None else None
+ elif len(path) > 2 and path[-2] in ['local_config', 'remote_config'] and key in ["pubkey", "privkey"]:
+ v = binascii.unhexlify(v) if v is not None else None
+ return v