sql_db.py (2392B)
import os
import concurrent
import contextlib
import functools
import queue
import threading
import asyncio
import sqlite3

from .logging import Logger
from .util import test_read_write_permissions


def sql(func):
    """Wrapper for sql methods: run *func* on the SqlDB's SQL thread.

    Calling the decorated method does not execute *func* directly. It
    enqueues ``(future, func, args, kwargs)`` on ``self.db_requests`` and
    returns an ``asyncio.Future`` created on ``self.asyncio_loop``.
    ``SqlDB.run_sql`` later executes *func* on the SQL thread and resolves
    the future with its return value or with the exception it raised.

    Calling a decorated method from the SQL thread itself is forbidden
    (asserted). Presumably this is because a request issued from inside a
    running sql method cannot be served until that method returns.
    """
    @functools.wraps(func)
    def wrapper(self: 'SqlDB', *args, **kwargs):
        assert threading.current_thread() != self.sql_thread
        # Bind the future to the loop that run_sql() resolves it on, not to
        # whatever loop happens to be current in the calling thread.
        f = self.asyncio_loop.create_future()
        self.db_requests.put((f, func, args, kwargs))
        return f
    return wrapper


def _set_future_result(future: asyncio.Future, result) -> None:
    """Set *result* on *future* unless it is already done (e.g. cancelled).

    Scheduled via ``call_soon_threadsafe``, so it runs on the event-loop
    thread. There the ``done()`` check cannot race with ``cancel()``.
    """
    if not future.done():
        future.set_result(result)


def _set_future_exception(future: asyncio.Future, exc: BaseException) -> None:
    """Set *exc* on *future* unless it is already done (e.g. cancelled).

    Runs on the event-loop thread, like ``_set_future_result``.
    """
    if not future.done():
        future.set_exception(exc)


class SqlDB(Logger):
    """Base class for an SQLite database served by a dedicated worker thread.

    The connection ``self.conn`` is opened on, and used only from,
    ``self.sql_thread`` (see ``run_sql``). Other threads submit work through
    ``@sql``-decorated methods, which return futures resolved on
    ``self.asyncio_loop``. Subclasses must implement ``create_database``.
    """

    def __init__(self, asyncio_loop: asyncio.BaseEventLoop, path, commit_interval=None):
        """Set up state and start the SQL worker thread.

        :param asyncio_loop: event loop on which request futures and
            ``stopped_event`` are resolved. The worker thread exits once
            this loop is no longer running.
        :param path: filesystem path of the SQLite database file. Presumably
            ``test_read_write_permissions`` checks that it is accessible
            before the thread starts (verify in ``.util``).
        :param commit_interval: if truthy, the worker commits after every
            ``commit_interval`` successfully executed requests. Otherwise it
            commits only when it shuts down; sql methods may still commit
            on their own.
        """
        Logger.__init__(self)
        self.asyncio_loop = asyncio_loop
        self.stopping = False
        self.stopped_event = asyncio.Event()
        self.path = path
        test_read_write_permissions(path)
        self.commit_interval = commit_interval
        self.db_requests = queue.Queue()
        self.sql_thread = threading.Thread(target=self.run_sql)
        self.sql_thread.start()

    def stop(self):
        """Ask the SQL thread to exit.

        The worker polls this flag at least every 0.1 s. It then commits,
        closes the connection and sets ``stopped_event``.
        """
        self.stopping = True

    def filesize(self):
        """Return the size in bytes of the database file at ``self.path``."""
        return os.stat(self.path).st_size

    def run_sql(self):
        """SQL thread main function: own the connection and serve requests.

        1. Opens the connection and calls ``create_database``.
        2. Executes queued requests until ``stop()`` is called or the event
           loop stops running, then commits.
        3. Closes the connection in all cases. Uncommitted work is discarded
           if an exception escapes.
        4. Always schedules ``stopped_event.set`` on the event loop, so
           awaiters do not hang when startup fails.
        """
        self.logger.info("SQL thread started")
        try:
            self.conn = sqlite3.connect(self.path)
            with contextlib.closing(self.conn):
                self.logger.info("Creating database")
                self.create_database()
                self._serve_requests()
                # write
                self.conn.commit()
            self.logger.info("SQL thread terminated")
        finally:
            self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set)

    def _serve_requests(self):
        """Execute queued sql requests one at a time until asked to stop."""
        i = 0
        while not self.stopping and self.asyncio_loop.is_running():
            try:
                # Short timeout so the stop conditions are re-checked regularly.
                future, func, args, kwargs = self.db_requests.get(timeout=0.1)
            except queue.Empty:
                continue
            try:
                result = func(self, *args, **kwargs)
            except BaseException as e:
                # Hand the failure to the awaiting caller instead of
                # terminating the SQL thread.
                self.asyncio_loop.call_soon_threadsafe(_set_future_exception, future, e)
                continue
            # The cancelled/done check happens on the loop thread, inside the helper.
            self.asyncio_loop.call_soon_threadsafe(_set_future_result, future, result)
            # note: in sweepstore session.commit() is called inside
            # the sql-decorated methods, so committing to disk is awaited
            if self.commit_interval:
                i = (i + 1) % self.commit_interval
                if i == 0:
                    self.conn.commit()

    def create_database(self):
        """Create the schema on ``self.conn``. Subclasses must override this.

        Called once on the SQL thread, right after the connection is opened
        and before any request is served.
        """
        raise NotImplementedError()