11from sqlalchemy import create_engine
22from sqlalchemy .orm import sessionmaker
33
4- from alphatrion .metadata .sql_models import Base , Experiment
54from alphatrion .metadata .base import MetaStore
5+ from alphatrion .metadata .sql_models import Base , Experiment
66
77
88# SQL-like metadata implementation, it could be SQLite, PostgreSQL, MySQL, etc.
99class SQLStore (MetaStore ):
1010 def __init__ (self , db_url : str , init_tables : bool = False ):
11- super ().__init__ ()
12-
1311 self ._engine = create_engine (db_url )
1412 self ._session = sessionmaker (bind = self ._engine )
1513 if init_tables :
1614 # create tables if not exist, will not affect existing tables.
1715 # In production, use migrations instead.
1816 Base .metadata .create_all (self ._engine )
1917
20-
21- def create_exp (self , name : str , project_id : str , description : str | None , meta : dict | None ):
18+ def create_exp (
19+ self ,
20+ name : str ,
21+ project_id : str ,
22+ description : str | None ,
23+ meta : dict | None ,
24+ labels : dict | None = None ,
25+ ):
2226 session = self ._session ()
23- new_exp = Experiment (name = name , description = description , project_id = project_id , meta = meta )
27+ new_exp = Experiment (
28+ name = name ,
29+ description = description ,
30+ project_id = project_id ,
31+ meta = meta ,
32+ labels = labels ,
33+ )
2434 session .add (new_exp )
2535 session .commit ()
2636 session .close ()
@@ -34,6 +44,7 @@ def delete_exp(self, exp_id: int):
3444 session .commit ()
3545 session .close ()
3646
47+ # We don't support append-only update, the complete fields should be provided.
3748 def update_exp (self , exp_id : int , ** kwargs ):
3849 session = self ._session ()
3950 exp = session .query (Experiment ).filter (Experiment .id == exp_id ).first ()
@@ -46,13 +57,23 @@ def update_exp(self, exp_id: int, **kwargs):
4657 # get_exp will ignore the deleted experiments.
4758 def get_exp (self , exp_id : int ) -> Experiment | None :
4859 session = self ._session ()
49- exp = session .query (Experiment ).filter (Experiment .id == exp_id , Experiment .is_del == 0 ).first ()
60+ exp = (
61+ session .query (Experiment )
62+ .filter (Experiment .id == exp_id , Experiment .is_del == 0 )
63+ .first ()
64+ )
5065 session .close ()
5166 return exp
5267
5368 # paginate the experiments in case of too many experiments.
5469 def list_exps (self , project_id : str , page : int , page_size : int ) -> list [Experiment ]:
5570 session = self ._session ()
56- exps = session .query (Experiment ).filter (Experiment .project_id == project_id ).offset (page * page_size ).limit (page_size ).all ()
71+ exps = (
72+ session .query (Experiment )
73+ .filter (Experiment .project_id == project_id )
74+ .offset (page * page_size )
75+ .limit (page_size )
76+ .all ()
77+ )
5778 session .close ()
5879 return exps
0 commit comments