diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..8837bece --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,21 @@ +#### What this PR does / why we need it + +#### Which issue(s) this PR fixes + +Fixes # + +#### Special notes for your reviewer + +#### Does this PR introduce a user-facing change? + +```release-note + +``` diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index 1be68857..adf04021 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -11,7 +11,11 @@ def __init__(self, runtime: Runtime): @abstractmethod def create( - self, name: str, description: str | None = None, meta: dict | None = None, labels: dict | None = None + self, + name: str, + description: str | None = None, + meta: dict | None = None, + labels: dict | None = None, ): raise NotImplementedError("Subclasses must implement this method.") diff --git a/alphatrion/experiment/custom_exp.py b/alphatrion/experiment/custom_exp.py index 91760bd1..6611602f 100644 --- a/alphatrion/experiment/custom_exp.py +++ b/alphatrion/experiment/custom_exp.py @@ -10,7 +10,11 @@ def __init__(self, runtime: Runtime): super().__init__(runtime) def create( - self, name: str, description: str | None = None, meta: dict | None = None, labels: dict | None = None + self, + name: str, + description: str | None = None, + meta: dict | None = None, + labels: dict | None = None, ): self._runtime._metadb.create_exp( name=name, diff --git a/alphatrion/metadata/sql.py b/alphatrion/metadata/sql.py index 0e68b64d..d3869097 100644 --- a/alphatrion/metadata/sql.py +++ b/alphatrion/metadata/sql.py @@ -2,7 +2,7 @@ from sqlalchemy.orm import sessionmaker from alphatrion.metadata.base import MetaStore -from alphatrion.metadata.sql_models import Base, Experiment +from alphatrion.metadata.sql_models import Base, Experiment, Model # SQL-like metadata implementation, it could be SQLite, PostgreSQL, MySQL, etc. @@ -38,8 +38,12 @@ def create_exp( # Soft delete the experiment now. In the future, we may implement hard delete. def delete_exp(self, exp_id: int): session = self._session() - exp = session.query(Experiment).filter(Experiment.id == exp_id).first() - if exp and exp.is_del == 0: + exp = ( + session.query(Experiment) + .filter(Experiment.id == exp_id, Experiment.is_del == 0) + .first() + ) + if exp: exp.is_del = 1 session.commit() session.close() @@ -47,7 +51,11 @@ def delete_exp(self, exp_id: int): # We don't support append-only update, the complete fields should be provided. def update_exp(self, exp_id: int, **kwargs): session = self._session() - exp = session.query(Experiment).filter(Experiment.id == exp_id).first() + exp = ( + session.query(Experiment) + .filter(Experiment.id == exp_id, Experiment.is_del == 0) + .first() + ) if exp: for key, value in kwargs.items(): setattr(exp, key, value) @@ -70,10 +78,65 @@ def list_exps(self, project_id: str, page: int, page_size: int) -> list[Experime session = self._session() exps = ( session.query(Experiment) - .filter(Experiment.project_id == project_id) + .filter(Experiment.project_id == project_id, Experiment.is_del == 0) .offset(page * page_size) .limit(page_size) .all() ) session.close() return exps + + def create_model( + self, + name: str, + version: str = "latest", + description: str | None = None, + meta: dict | None = None, + labels: dict | None = None, + ): + session = self._session() + new_model = Model( + name=name, + version=version, + description=description, + meta=meta, + labels=labels, + ) + session.add(new_model) + session.commit() + session.close() + + def update_model(self, model_id: int, **kwargs): + session = self._session() + model = ( + session.query(Model).filter(Model.id == model_id, Model.is_del == 0).first() + ) + if model: + for key, value in kwargs.items(): + setattr(model, key, value) + session.commit() + session.close() + + def get_model(self, model_id: int) -> Model | None: + session = self._session() + model = ( + session.query(Model).filter(Model.id == model_id, Model.is_del == 0).first() + ) + session.close() + return model + + def list_models(self, page: int, page_size: int) -> list[Model]: + session = self._session() + models = session.query(Model).offset(page * page_size).limit(page_size).all() + session.close() + return models + + def delete_model(self, model_id: int): + session = self._session() + model = ( + session.query(Model).filter(Model.id == model_id, Model.is_del == 0).first() + ) + if model: + model.is_del = 1 + session.commit() + session.close() diff --git a/alphatrion/metadata/sql_models.py b/alphatrion/metadata/sql_models.py index 61b73204..2d1d4834 100644 --- a/alphatrion/metadata/sql_models.py +++ b/alphatrion/metadata/sql_models.py @@ -37,3 +37,20 @@ class Experiment(Base): DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC) ) is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") + + +class Model(Base): + __tablename__ = "models" + + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False, unique=True) + version = Column(String, nullable=False) + description = Column(String, nullable=True) + meta = Column(JSON, nullable=True, comment="Additional metadata for the model") + labels = Column(JSON, nullable=True, comment="Labels for the model") + + created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) + updated_at = Column( + DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC) + ) + is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted") diff --git a/alphatrion/model/__init__.py b/alphatrion/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/alphatrion/model/model.py b/alphatrion/model/model.py new file mode 100644 index 00000000..4071f7b5 --- /dev/null +++ b/alphatrion/model/model.py @@ -0,0 +1,32 @@ +from alphatrion.runtime.runtime import Runtime + + +class Model: + def __init__(self, runtime: Runtime): + self._runtime = runtime + + def create( + self, + name: str, + description: str | None = None, + meta: dict | None = None, + labels: dict | None = None, + ): + self._runtime._metadb.create_model( + name=name, + description=description, + meta=meta, + labels=labels, + ) + + def update(self, model_id: int, **kwargs): + self._runtime._metadb.update_model(model_id=model_id, **kwargs) + + def get(self, model_id: int): + return self._runtime._metadb.get_model(model_id=model_id) + + def list(self, page: int = 0, page_size: int = 10): + return self._runtime._metadb.list_models(page=page, page_size=page_size) + + def delete(self, model_id: int): + self._runtime._metadb.delete_model(model_id=model_id) diff --git a/tests/experiment/test_model.py b/tests/experiment/test_model.py new file mode 100644 index 00000000..c92cb625 --- /dev/null +++ b/tests/experiment/test_model.py @@ -0,0 +1,35 @@ +import os + +import pytest + +from alphatrion import consts +from alphatrion.model.model import Model +from alphatrion.runtime.runtime import Runtime + + +@pytest.fixture +def model(): + os.environ[consts.METADATA_DB_URL] = "sqlite:///:memory:" + runtime = Runtime(project_id="test_project") + model = Model(runtime=runtime) + yield model + + +def test_model(model): + model.create("test_model", "A test model", {"foo": "bar"}, {"env": "test"}) + model1 = model.get(1) + assert model1 is not None + assert model1.name == "test_model" + assert model1.description == "A test model" + assert model1.meta == {"foo": "bar"} + + model.update(1, labels={"env": "prod"}) + model1 = model.get(1) + assert model1.labels == {"env": "prod"} + + models = model.list() + assert len(models) == 1 + + model.delete(1) + model1 = model.get(1) + assert model1 is None