commit 43c5657cb60c044beb0db839e33ea53d590c2471
parent cf84068fdb0bcdf92bc6946fd6990bd3fd9da722
Author: ThomasV <thomasv@electrum.org>
Date: Sat, 15 Jul 2017 17:20:06 +0200
blockchain: parent pointer and recursive methods
Diffstat:
2 files changed, 59 insertions(+), 52 deletions(-)
diff --git a/lib/blockchain.py b/lib/blockchain.py
@@ -61,6 +61,29 @@ def hash_header(header):
return hash_encode(Hash(serialize_header(header).decode('hex')))
+blockchains = {}
+
+def read_blockchains(config):
+ blockchains[0] = Blockchain(config, 'blockchain_headers')
+ # fixme: sort
+ for x in os.listdir(config.path):
+ if x.startswith('fork_'):
+ b = Blockchain(config, x)
+ blockchains[b.checkpoint] = b
+ return blockchains
+
+def get_blockchain(header):
+ if type(header) is not dict:
+ return False
+ header_hash = hash_header(header)
+ height = header.get('block_height')
+ for b in blockchains.values():
+ if header_hash == b.get_hash(height):
+ return b
+ return False
+
+
+
class Blockchain(util.PrintError):
'''Manages blockchain headers and their verification'''
@@ -70,26 +93,28 @@ class Blockchain(util.PrintError):
self.filename = filename
self.catch_up = None # interface catching up
self.is_saved = True
- self.checkpoint = int(filename[16:]) if filename.startswith('blockchain_fork_') else 0
self.headers = []
+ if filename == 'blockchain_headers':
+ self.parent = None
+ self.checkpoint = 0
+ elif filename.startswith('fork_'):
+ self.parent = blockchains[int(filename.split('_')[1])]
+ self.checkpoint = int(filename.split('_')[2])
+ else:
+ raise BaseException('')
self.set_local_height()
- def fork(parent, fork_point):
- self = Blockchain(parent.config, parent.filename)
+ def fork(parent, checkpoint):
+ filename = 'fork_%d_%d'%(parent.checkpoint, checkpoint)
+ self = Blockchain(parent.config, filename)
self.is_saved = False
- if parent.is_saved:
- self.checkpoint = fork_point
- else:
- if fork_point > parent.checkpoint:
- self.headers = parent.headers[0: fork_point - parent.checkpoint]
- else:
- self.headers = []
+ self.parent = parent
+ self.checkpoint = checkpoint
return self
def height(self):
- if self.headers:
- return self.checkpoint + len(self.headers) - 1
- return self.local_height
+ local = self.local_height if self.is_saved else len(self.headers) - 1
+ return self.checkpoint + local
def verify_header(self, header, prev_header, bits, target):
prev_hash = hash_header(prev_header)
@@ -139,20 +164,15 @@ class Blockchain(util.PrintError):
self.set_local_height()
def save(self):
- import shutil
- self.print_error("save fork")
- height = self.checkpoint
- filename = "blockchain_fork_%d"%height
- new_path = os.path.join(util.get_headers_dir(self.config), filename)
- shutil.copy(self.path(), new_path)
- with open(new_path, 'rb+') as f:
- f.seek((height) * 80)
- f.truncate()
- self.filename = filename
- self.is_saved = True
+ # recursively save parents if they have not been saved
+ if self.parent and not self.parent.is_saved():
+ self.parent.save()
+ open(self.path(), 'w+').close()
for h in self.headers:
self.write_header(h)
self.headers = []
+ self.is_saved = True
+ self.print_error("saved", self.filename)
def save_header(self, header):
height = header.get('block_height')
@@ -165,12 +185,12 @@ class Blockchain(util.PrintError):
self.write_header(header)
def write_header(self, header):
- height = header.get('block_height')
+ delta = header.get('block_height') - self.checkpoint
data = serialize_header(header).decode('hex')
assert len(data) == 80
filename = self.path()
with open(filename, 'rb+') as f:
- f.seek(height * 80)
+ f.seek(delta * 80)
f.truncate()
h = f.write(data)
self.set_local_height()
@@ -180,25 +200,26 @@ class Blockchain(util.PrintError):
name = self.path()
if os.path.exists(name):
h = os.path.getsize(name)/80 - 1
- if self.local_height != h:
- self.local_height = h
+ self.local_height = h
def read_header(self, height):
- if not self.is_saved and height >= self.checkpoint:
- i = height - self.checkpoint
- if i >= len(self.headers):
+ if height < self.checkpoint:
+ return self.parent.read_header(height)
+ delta = height - self.checkpoint
+ if not self.is_saved:
+ if delta >= len(self.headers):
return None
- header = self.headers[i]
+ header = self.headers[delta]
assert header.get('block_height') == height
return header
name = self.path()
if os.path.exists(name):
f = open(name, 'rb')
- f.seek(height * 80)
+ f.seek(delta * 80)
h = f.read(80)
f.close()
if len(h) == 80:
- h = deserialize_header(h, height)
+ h = deserialize_header(h, delta)
return h
def get_hash(self, height):
diff --git a/lib/network.py b/lib/network.py
@@ -40,7 +40,7 @@ import util
import bitcoin
from bitcoin import *
from interface import Connection, Interface
-from blockchain import Blockchain
+from blockchain import read_blockchains, get_blockchain
from version import ELECTRUM_VERSION, PROTOCOL_VERSION
DEFAULT_PORTS = {'t':'50001', 's':'50002'}
@@ -206,11 +206,7 @@ class Network(util.DaemonThread):
util.DaemonThread.__init__(self)
self.config = SimpleConfig(config) if type(config) == type({}) else config
self.num_server = 10 if not self.config.get('oneserver') else 0
- self.blockchains = { 0:Blockchain(self.config, 'blockchain_headers') }
- for x in os.listdir(self.config.path):
- if x.startswith('blockchain_fork_'):
- b = Blockchain(self.config, x)
- self.blockchains[b.checkpoint] = b
+ self.blockchains = read_blockchains(self.config)
self.print_error("blockchains", self.blockchains.keys())
self.blockchain_index = config.get('blockchain_index', 0)
if self.blockchain_index not in self.blockchains.keys():
@@ -706,18 +702,8 @@ class Network(util.DaemonThread):
def get_checkpoint(self):
return max(self.blockchains.keys())
- def get_blockchain(self, header):
- from blockchain import hash_header
- if type(header) is not dict:
- return False
- header_hash = hash_header(header)
- height = header.get('block_height')
- for b in self.blockchains.values():
- if header_hash == b.get_hash(height):
- return b
- return False
-
def new_interface(self, server, socket):
+ # todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server)
interface = Interface(server, socket)
interface.blockchain = None
@@ -830,7 +816,7 @@ class Network(util.DaemonThread):
def on_header(self, interface, header):
height = header.get('block_height')
if interface.mode == 'checkpoint':
- b = self.get_blockchain(header)
+ b = get_blockchain(header)
if b:
interface.mode = 'default'
interface.blockchain = b