Skip to content

Commit e3361b1

Browse files
authored
Add labels to Experiment (#11)
* Add Roadmap Signed-off-by: kerthcet <kerthcet@gmail.com> * Add labels Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent e2f92fa commit e3361b1

15 files changed

Lines changed: 166 additions & 84 deletions

File tree

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,9 @@ launch:
1414

1515
.PHONY: test
1616
test:
17-
$(POETRY) run pytest
17+
$(POETRY) run pytest
18+
19+
.PHONY: format
20+
format:
21+
ruff format .
22+
ruff check --fix .

alphatrion/consts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
METADATA_DB_URL = "METADATA_DB_URL"
1+
METADATA_DB_URL = "METADATA_DB_URL"

alphatrion/experiment/base.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,40 @@
1-
2-
3-
from abc import ABC
1+
from abc import ABC, abstractmethod
42

53
from alphatrion.runtime.runtime import Runtime
64

75

86
class Experiment(ABC):
97
"""Base class for all experiments."""
8+
109
def __init__(self, runtime: Runtime):
1110
self._runtime = runtime
1211

13-
def create(self, name: str, description: str | None = None, meta: dict | None = None):
12+
@abstractmethod
13+
def create(
14+
self, name: str, description: str | None = None, meta: dict | None = None, labels: dict | None = None
15+
):
1416
raise NotImplementedError("Subclasses must implement this method.")
1517

18+
@abstractmethod
1619
def delete(self, exp_id: int):
1720
raise NotImplementedError("Subclasses must implement this method.")
1821

22+
@abstractmethod
1923
def get(self, exp_id: int):
2024
raise NotImplementedError("Subclasses must implement this method.")
2125

26+
@abstractmethod
27+
def update_labels(self, exp_id: int, labels: dict):
28+
raise NotImplementedError("Subclasses must implement this method.")
29+
30+
@abstractmethod
2231
def start(self, exp_id: int):
2332
raise NotImplementedError("Subclasses must implement this method.")
2433

34+
@abstractmethod
2535
def stop(self, exp_id: int, status: str = "finished"):
2636
raise NotImplementedError("Subclasses must implement this method.")
2737

38+
@abstractmethod
2839
def status(self, exp_id: int) -> str:
2940
raise NotImplementedError("Subclasses must implement this method.")

alphatrion/experiment/custom_exp.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,47 @@
1-
from alphatrion.metadata.sql_models import ExperimentStatus
1+
from datetime import datetime
2+
23
from alphatrion.experiment.base import Experiment
4+
from alphatrion.metadata.sql_models import COMPLETED_STATUS, ExperimentStatus
35
from alphatrion.runtime.runtime import Runtime
46

57

68
class CustomExperiment(Experiment):
79
def __init__(self, runtime: Runtime):
810
super().__init__(runtime)
911

10-
def create(self, name: str, description: str | None = None, meta: dict | None = None):
11-
self._runtime._metadb.create_exp(name=name, description=description, project_id=self._runtime._project_id, meta=meta)
12+
def create(
13+
self, name: str, description: str | None = None, meta: dict | None = None, labels: dict | None = None
14+
):
15+
self._runtime._metadb.create_exp(
16+
name=name,
17+
description=description,
18+
project_id=self._runtime._project_id,
19+
meta=meta,
20+
labels=labels,
21+
)
1222

1323
def delete(self, exp_id: int):
1424
self._runtime._metadb.delete_exp(exp_id=exp_id)
1525

1626
def get(self, exp_id: int):
1727
return self._runtime._metadb.get_exp(exp_id=exp_id)
1828

29+
# Please provide all the labels to update, or it will overwrite the existing labels.
30+
def update_labels(self, exp_id: int, labels: dict):
31+
self._runtime._metadb.update_exp(exp_id=exp_id, labels=labels)
32+
1933
# start for experiment usually means update the status to running.
2034
def start(self, exp_id: int):
2135
self._runtime._metadb.update_exp(exp_id=exp_id, status=ExperimentStatus.RUNNING)
2236

2337
# stop for experiment usually means update the status to finished or failed.
2438
def stop(self, exp_id: int, status: ExperimentStatus = ExperimentStatus.FINISHED):
25-
self._runtime._metadb.update_exp(exp_id=exp_id, status=status)
39+
exp = self._runtime._metadb.get_exp(exp_id=exp_id)
40+
if exp is not None and exp.status not in COMPLETED_STATUS:
41+
duration = (datetime.now() - exp.created_at).total_seconds()
42+
self._runtime._metadb.update_exp(
43+
exp_id=exp_id, status=status, duration=duration
44+
)
2645

2746
def status(self, exp_id: int) -> ExperimentStatus:
2847
exp = self._runtime._metadb.get_exp(exp_id=exp_id)

alphatrion/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
load_dotenv()
44

55
if __name__ == "__main__":
6-
print("Hello, AlphaTrion!")
6+
print("Hello, AlphaTrion!")

alphatrion/metadata/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
1-
from abc import ABC
1+
from abc import ABC, abstractmethod
2+
23

34
class MetaStore(ABC):
45
"""Base class for all metadata storage backends."""
5-
def __init__(self):
6-
pass
76

8-
def create_exp(self, name: str, project_id: str, description: str | None, meta: dict | None):
7+
@abstractmethod
8+
def create_exp(
9+
self, name: str, project_id: str, description: str | None, meta: dict | None
10+
):
911
raise NotImplementedError("Subclasses must implement this method.")
1012

13+
@abstractmethod
1114
def delete_exp(self, exp_id: int):
1215
raise NotImplementedError("Subclasses must implement this method.")
1316

17+
@abstractmethod
1418
def update_exp(self, exp_id: int, **kwargs):
1519
raise NotImplementedError("Subclasses must implement this method.")
1620

21+
@abstractmethod
1722
def get_exp(self, exp_id: int):
1823
raise NotImplementedError("Subclasses must implement this method.")
1924

25+
@abstractmethod
2026
def list_exps(self, project_id: str, page: int, page_size: int):
2127
raise NotImplementedError("Subclasses must implement this method.")

alphatrion/metadata/sql.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,36 @@
11
from sqlalchemy import create_engine
22
from sqlalchemy.orm import sessionmaker
33

4-
from alphatrion.metadata.sql_models import Base, Experiment
54
from alphatrion.metadata.base import MetaStore
5+
from alphatrion.metadata.sql_models import Base, Experiment
66

77

88
# SQL-like metadata implementation, it could be SQLite, PostgreSQL, MySQL, etc.
99
class SQLStore(MetaStore):
1010
def __init__(self, db_url: str, init_tables: bool = False):
11-
super().__init__()
12-
1311
self._engine = create_engine(db_url)
1412
self._session = sessionmaker(bind=self._engine)
1513
if init_tables:
1614
# create tables if not exist, will not affect existing tables.
1715
# In production, use migrations instead.
1816
Base.metadata.create_all(self._engine)
1917

20-
21-
def create_exp(self, name: str, project_id: str, description: str | None, meta: dict | None):
18+
def create_exp(
19+
self,
20+
name: str,
21+
project_id: str,
22+
description: str | None,
23+
meta: dict | None,
24+
labels: dict | None = None,
25+
):
2226
session = self._session()
23-
new_exp = Experiment(name=name, description=description, project_id=project_id, meta=meta)
27+
new_exp = Experiment(
28+
name=name,
29+
description=description,
30+
project_id=project_id,
31+
meta=meta,
32+
labels=labels,
33+
)
2434
session.add(new_exp)
2535
session.commit()
2636
session.close()
@@ -34,6 +44,7 @@ def delete_exp(self, exp_id: int):
3444
session.commit()
3545
session.close()
3646

47+
# We don't support append-only update, the complete fields should be provided.
3748
def update_exp(self, exp_id: int, **kwargs):
3849
session = self._session()
3950
exp = session.query(Experiment).filter(Experiment.id == exp_id).first()
@@ -46,13 +57,23 @@ def update_exp(self, exp_id: int, **kwargs):
4657
# get_exp will ignore the deleted experiments.
4758
def get_exp(self, exp_id: int) -> Experiment | None:
4859
session = self._session()
49-
exp = session.query(Experiment).filter(Experiment.id == exp_id, Experiment.is_del == 0).first()
60+
exp = (
61+
session.query(Experiment)
62+
.filter(Experiment.id == exp_id, Experiment.is_del == 0)
63+
.first()
64+
)
5065
session.close()
5166
return exp
5267

5368
# paginate the experiments in case of too many experiments.
5469
def list_exps(self, project_id: str, page: int, page_size: int) -> list[Experiment]:
5570
session = self._session()
56-
exps = session.query(Experiment).filter(Experiment.project_id == project_id).offset(page * page_size).limit(page_size).all()
71+
exps = (
72+
session.query(Experiment)
73+
.filter(Experiment.project_id == project_id)
74+
.offset(page * page_size)
75+
.limit(page_size)
76+
.all()
77+
)
5778
session.close()
5879
return exps

alphatrion/metadata/sql_models.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from datetime import datetime, timezone
21
import enum
2+
from datetime import UTC, datetime
33

4-
from sqlalchemy import Column, Integer, String, Enum, DateTime, JSON
4+
from sqlalchemy import JSON, Column, DateTime, Enum, Integer, String
55
from sqlalchemy.orm import declarative_base
66

77
Base = declarative_base()
@@ -14,6 +14,9 @@ class ExperimentStatus(enum.Enum):
1414
FAILED = "failed"
1515

1616

17+
COMPLETED_STATUS = [ExperimentStatus.FINISHED, ExperimentStatus.FAILED]
18+
19+
1720
# Define the Experiment model for SQLAlchemy
1821
class Experiment(Base):
1922
__tablename__ = "experiments"
@@ -22,9 +25,15 @@ class Experiment(Base):
2225
name = Column(String, nullable=False, unique=True)
2326
description = Column(String, nullable=True)
2427
project_id = Column(String, nullable=False)
25-
status = Column(Enum(ExperimentStatus), nullable=False, default=ExperimentStatus.PENDING)
28+
status = Column(
29+
Enum(ExperimentStatus), nullable=False, default=ExperimentStatus.PENDING
30+
)
2631
meta = Column(JSON, nullable=True, comment="Additional metadata for the experiment")
32+
labels = Column(JSON, nullable=True, comment="Labels for the experiment")
33+
duration = Column(Integer, default=0, comment="Duration in seconds")
2734

28-
created_at = Column(DateTime(timezone=True), default=datetime.now(timezone.utc))
29-
updated_at = Column(DateTime(timezone=True), default=datetime.now(timezone.utc), onupdate=datetime.now(timezone.utc))
35+
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
36+
updated_at = Column(
37+
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
38+
)
3039
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")

poetry.lock

Lines changed: 30 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,31 @@ dependencies = [
1313

1414
[tool.poetry.group.dev.dependencies]
1515
pytest = ">=8.4.2,<9.0.0"
16+
ruff = "^0.12.12"
1617

1718
[build-system]
1819
requires = ["poetry-core>=2.0.0,<3.0.0"]
1920
build-backend = "poetry.core.masonry.api"
21+
22+
# Configuration for ruff linter and formatter
23+
24+
[tool.ruff]
25+
line-length = 88
26+
target-version = "py312"
27+
lint.select = [
28+
"E", # pycodestyle,
29+
"F", # Pyflakes
30+
"UP", # pyupgrade
31+
"B", # flake8-bugbear
32+
"SIM", # flake8-simplify
33+
"I", # isort
34+
"PL", # PyLint,
35+
]
36+
exclude = ["venv", "migrations", "__pycache__"]
37+
38+
[tool.ruff.format]
39+
quote-style = "double"
40+
indent-style = "space"
41+
42+
[tool.ruff.lint.per-file-ignores]
43+
"tests/*" = ["PLR2004"]

0 commit comments

Comments
 (0)