1- from typing import Optional , List
1+ from typing import Optional , List , Dict , Any
22from datetime import datetime , timezone
33import json
44import os
@@ -94,8 +94,10 @@ def detect_db_version(data_dir: str, max_version: Optional[int] = None) -> Optio
9494class PeeweeStorage (AbstractStorage ):
9595 sid = "peewee"
9696
97- def __init__ (self , testing ) :
97+ def __init__ (self , testing : bool = True , filepath : str = None ) -> None :
9898 data_dir = get_data_dir ("aw-server" )
99+
100+ # TODO: Won't work with custom filepath
99101 current_db_version = detect_db_version (data_dir , max_version = LATEST_VERSION )
100102
101103 if current_db_version is not None and current_db_version < LATEST_VERSION :
@@ -104,26 +106,27 @@ def __init__(self, testing):
104106 logger .info ("Creating database file for new version {}" .format (LATEST_VERSION ))
105107 logger .warning ("ActivityWatch does not currently support database migrations, new database file will be empty" )
106108
107- filename = 'peewee-sqlite' + ('-testing' if testing else '' ) + ".v{}" .format (LATEST_VERSION ) + '.db'
108- filepath = os .path .join (data_dir , filename )
109+ if not filepath :
110+ filename = 'peewee-sqlite' + ('-testing' if testing else '' ) + ".v{}" .format (LATEST_VERSION ) + '.db'
111+ filepath = os .path .join (data_dir , filename )
109112 self .db = _db
110113 self .db .init (filepath )
111114 logger .info ("Using database file: {}" .format (filepath ))
112115
113116 # db.connect()
114117
115- self .bucket_keys = {}
118+ self .bucket_keys = {} # type: Dict[str, int]
116119 if not BucketModel .table_exists ():
117120 BucketModel .create_table ()
118121 if not EventModel .table_exists ():
119122 EventModel .create_table ()
120123 self .update_bucket_keys ()
121124
122- def update_bucket_keys (self ):
125+ def update_bucket_keys (self ) -> None :
123126 buckets = BucketModel .select ()
124127 self .bucket_keys = {bucket .id : bucket .key for bucket in buckets }
125128
126- def buckets (self ):
129+ def buckets (self ) -> Dict [ str , Dict [ str , Any ]] :
127130 buckets = {bucket .id : bucket .json () for bucket in BucketModel .select ()}
128131 return buckets
129132
@@ -214,7 +217,7 @@ def get_events(self, bucket_id: str, limit: int,
214217 return [Event (** e ) for e in list (map (EventModel .json , q .execute ()))]
215218
216219 def get_eventcount (self , bucket_id : str ,
217- starttime : Optional [datetime ] = None , endtime : Optional [datetime ] = None ):
220+ starttime : Optional [datetime ] = None , endtime : Optional [datetime ] = None ):
218221 q = EventModel .select () \
219222 .where (EventModel .bucket == self .bucket_keys [bucket_id ])
220223 if starttime :
0 commit comments