Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#### What this PR does / why we need it

#### Which issue(s) this PR fixes
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
_If PR is about `failing-tests or flakes`, please post the related issues/tests in a comment and do not use `Fixes`_*
-->
Fixes #

#### Special notes for your reviewer

#### Does this PR introduce a user-facing change?
<!--
If no, just write "NONE" in the release-note block below.
If yes, a release note is required:
Enter your extended release note in the block below. If the PR requires additional action from users switching to the new release, include the string "action required".
-->
```release-note

```
6 changes: 5 additions & 1 deletion alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ def __init__(self, runtime: Runtime):

@abstractmethod
def create(
self, name: str, description: str | None = None, meta: dict | None = None, labels: dict | None = None
self,
name: str,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
raise NotImplementedError("Subclasses must implement this method.")

Expand Down
6 changes: 5 additions & 1 deletion alphatrion/experiment/custom_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ def __init__(self, runtime: Runtime):
super().__init__(runtime)

def create(
self, name: str, description: str | None = None, meta: dict | None = None, labels: dict | None = None
self,
name: str,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
self._runtime._metadb.create_exp(
name=name,
Expand Down
73 changes: 68 additions & 5 deletions alphatrion/metadata/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy.orm import sessionmaker

from alphatrion.metadata.base import MetaStore
from alphatrion.metadata.sql_models import Base, Experiment
from alphatrion.metadata.sql_models import Base, Experiment, Model


# SQL-like metadata implementation, it could be SQLite, PostgreSQL, MySQL, etc.
Expand Down Expand Up @@ -38,16 +38,24 @@ def create_exp(
# Soft delete the experiment now. In the future, we may implement hard delete.
def delete_exp(self, exp_id: int):
session = self._session()
exp = session.query(Experiment).filter(Experiment.id == exp_id).first()
if exp and exp.is_del == 0:
exp = (
session.query(Experiment)
.filter(Experiment.id == exp_id, Experiment.is_del == 0)
.first()
)
if exp:
exp.is_del = 1
session.commit()
session.close()

# We don't support append-only update, the complete fields should be provided.
def update_exp(self, exp_id: int, **kwargs):
session = self._session()
exp = session.query(Experiment).filter(Experiment.id == exp_id).first()
exp = (
session.query(Experiment)
.filter(Experiment.id == exp_id, Experiment.is_del == 0)
.first()
)
if exp:
for key, value in kwargs.items():
setattr(exp, key, value)
Expand All @@ -70,10 +78,65 @@ def list_exps(self, project_id: str, page: int, page_size: int) -> list[Experime
session = self._session()
exps = (
session.query(Experiment)
.filter(Experiment.project_id == project_id)
.filter(Experiment.project_id == project_id, Experiment.is_del == 0)
.offset(page * page_size)
.limit(page_size)
.all()
)
session.close()
return exps

def create_model(
self,
name: str,
version: str = "latest",
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
session = self._session()
new_model = Model(
name=name,
version=version,
description=description,
meta=meta,
labels=labels,
)
session.add(new_model)
session.commit()
session.close()

def update_model(self, model_id: int, **kwargs):
session = self._session()
model = (
session.query(Model).filter(Model.id == model_id, Model.is_del == 0).first()
)
if model:
for key, value in kwargs.items():
setattr(model, key, value)
session.commit()
session.close()

def get_model(self, model_id: int) -> Model | None:
session = self._session()
model = (
session.query(Model).filter(Model.id == model_id, Model.is_del == 0).first()
)
session.close()
return model

def list_models(self, page: int, page_size: int) -> list[Model]:
session = self._session()
models = session.query(Model).offset(page * page_size).limit(page_size).all()
session.close()
return models

def delete_model(self, model_id: int):
session = self._session()
model = (
session.query(Model).filter(Model.id == model_id, Model.is_del == 0).first()
)
if model:
model.is_del = 1
session.commit()
session.close()
17 changes: 17 additions & 0 deletions alphatrion/metadata/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,20 @@ class Experiment(Base):
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
)
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")


class Model(Base):
__tablename__ = "models"

id = Column(Integer, primary_key=True)
name = Column(String, nullable=False, unique=True)
version = Column(String, nullable=False)
description = Column(String, nullable=True)
meta = Column(JSON, nullable=True, comment="Additional metadata for the model")
labels = Column(JSON, nullable=True, comment="Labels for the model")

created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
)
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")
Empty file added alphatrion/model/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions alphatrion/model/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from alphatrion.runtime.runtime import Runtime


class Model:
def __init__(self, runtime: Runtime):
self._runtime = runtime

def create(
self,
name: str,
description: str | None = None,
meta: dict | None = None,
labels: dict | None = None,
):
self._runtime._metadb.create_model(
name=name,
description=description,
meta=meta,
labels=labels,
)

def update(self, model_id: int, **kwargs):
self._runtime._metadb.update_model(model_id=model_id, **kwargs)

def get(self, model_id: int):
return self._runtime._metadb.get_model(model_id=model_id)

def list(self, page: int = 0, page_size: int = 10):
return self._runtime._metadb.list_models(page=page, page_size=page_size)

def delete(self, model_id: int):
self._runtime._metadb.delete_model(model_id=model_id)
35 changes: 35 additions & 0 deletions tests/experiment/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

import pytest

from alphatrion import consts
from alphatrion.model.model import Model
from alphatrion.runtime.runtime import Runtime


@pytest.fixture
def model():
os.environ[consts.METADATA_DB_URL] = "sqlite:///:memory:"
runtime = Runtime(project_id="test_project")
model = Model(runtime=runtime)
yield model


def test_model(model):
model.create("test_model", "A test model", {"foo": "bar"}, {"env": "test"})
model1 = model.get(1)
assert model1 is not None
assert model1.name == "test_model"
assert model1.description == "A test model"
assert model1.meta == {"foo": "bar"}

model.update(1, labels={"env": "prod"})
model1 = model.get(1)
assert model1.labels == {"env": "prod"}

models = model.list()
assert len(models) == 1

model.delete(1)
model1 = model.get(1)
assert model1 is None
Loading