Skip to content

Commit 9814d7c

Browse files
authored
Add project_id to all models (#46)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent f996a05 commit 9814d7c

13 files changed

Lines changed: 106 additions & 46 deletions

File tree

alphatrion/experiment/craft_exp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from alphatrion.experiment.base import Experiment
32
from alphatrion.trial.trial import Trial, TrialConfig
43

alphatrion/log/log.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ async def log_metrics(metrics: dict[str, float]):
6868
runtime._metadb.create_metric(
6969
key=key,
7070
value=value,
71+
project_id=runtime._project_id,
7172
trial_id=trial_id,
7273
step=step,
7374
)

alphatrion/metadata/sql.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Experiment,
1010
Metrics,
1111
Model,
12+
Project,
1213
Run,
1314
Trial,
1415
TrialStatus,
@@ -178,7 +179,8 @@ def delete_model(self, model_id: int):
178179

179180
def create_trial(
180181
self,
181-
exp_id: int,
182+
project_id: uuid.UUID,
183+
exp_id: uuid.UUID,
182184
name: str,
183185
description: str | None = None,
184186
meta: dict | None = None,
@@ -187,6 +189,7 @@ def create_trial(
187189
) -> uuid.UUID:
188190
session = self._session()
189191
new_trial = Trial(
192+
project_id=project_id,
190193
experiment_id=exp_id,
191194
name=name,
192195
description=description,
@@ -217,9 +220,10 @@ def update_trial(self, trial_id: uuid.UUID, **kwargs):
217220
session.commit()
218221
session.close()
219222

220-
def create_run(self, trial_id: uuid.UUID) -> uuid.UUID:
223+
def create_run(self, project_id: uuid.UUID, trial_id: uuid.UUID) -> uuid.UUID:
221224
session = self._session()
222225
new_run = Run(
226+
project_id=project_id,
223227
trial_id=trial_id,
224228
)
225229
session.add(new_run)
@@ -228,9 +232,17 @@ def create_run(self, trial_id: uuid.UUID) -> uuid.UUID:
228232
session.close()
229233
return run_id
230234

231-
def create_metric(self, trial_id: uuid.UUID, key: str, value: float, step: int):
235+
def create_metric(
236+
self,
237+
project_id: uuid.UUID,
238+
trial_id: uuid.UUID,
239+
key: str,
240+
value: float,
241+
step: int,
242+
):
232243
session = self._session()
233244
new_metric = Metrics(
245+
project_id=project_id,
234246
trial_id=trial_id,
235247
key=key,
236248
value=value,
@@ -245,3 +257,13 @@ def list_metrics(self, trial_id: uuid.UUID) -> list[Metrics]:
245257
metrics = session.query(Metrics).filter(Metrics.trial_id == trial_id).all()
246258
session.close()
247259
return metrics
260+
261+
def get_project(self, project_id: str) -> Project | None:
262+
session = self._session()
263+
project = (
264+
session.query(Project)
265+
.filter(Project.uuid == project_id, Project.is_del == 0)
266+
.first()
267+
)
268+
session.close()
269+
return project

alphatrion/metadata/sql_models.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,28 @@ class TrialStatus(enum.Enum):
1919
COMPLETED_STATUS = [TrialStatus.FINISHED, TrialStatus.FAILED]
2020

2121

22+
class Project(Base):
23+
__tablename__ = "projects"
24+
25+
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
26+
name = Column(String, nullable=False)
27+
description = Column(String, nullable=True)
28+
29+
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
30+
updated_at = Column(
31+
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
32+
)
33+
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")
34+
35+
2236
# Define the Experiment model for SQLAlchemy
2337
class Experiment(Base):
2438
__tablename__ = "experiments"
2539

2640
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
2741
name = Column(String, nullable=False)
2842
description = Column(String, nullable=True)
29-
project_id = Column(String, nullable=False)
43+
project_id = Column(UUID(as_uuid=True), nullable=False)
3044
meta = Column(JSON, nullable=True, comment="Additional metadata for the experiment")
3145

3246
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
@@ -40,6 +54,7 @@ class Trial(Base):
4054
__tablename__ = "trials"
4155

4256
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
57+
project_id = Column(UUID(as_uuid=True), nullable=False)
4358
experiment_id = Column(UUID(as_uuid=True), nullable=False)
4459
name = Column(String, nullable=False)
4560
description = Column(String, nullable=True)
@@ -63,8 +78,8 @@ class Run(Base):
6378
__tablename__ = "runs"
6479

6580
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
81+
project_id = Column(UUID(as_uuid=True), nullable=False)
6682
trial_id = Column(UUID(as_uuid=True), nullable=False)
67-
# artifact_path = Column(String, nullable=False)
6883

6984
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
7085
updated_at = Column(
@@ -77,10 +92,10 @@ class Model(Base):
7792

7893
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
7994
name = Column(String, nullable=False, unique=True)
80-
version = Column(String, nullable=False)
8195
description = Column(String, nullable=True)
96+
project_id = Column(UUID(as_uuid=True), nullable=False)
97+
version = Column(String, nullable=False)
8298
meta = Column(JSON, nullable=True, comment="Additional metadata for the model")
83-
project_id = Column(String, nullable=False)
8499

85100
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
86101
updated_at = Column(
@@ -95,6 +110,7 @@ class Metrics(Base):
95110
uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
96111
key = Column(String, nullable=False)
97112
value = Column(Float, nullable=False)
113+
project_id = Column(UUID(as_uuid=True), nullable=False)
98114
trial_id = Column(UUID(as_uuid=True), nullable=False)
99115
step = Column(Integer, nullable=False, default=0)
100116
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))

alphatrion/run/run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ def id(self) -> uuid.UUID:
1313
return self._id
1414

1515
def _start(self):
16-
self._id = self._runtime._metadb.create_run(trial_id=self._trial_id)
16+
self._id = self._runtime._metadb.create_run(
17+
project_id=self._runtime._project_id, trial_id=self._trial_id
18+
)

alphatrion/runtime/runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: PLW0603
22
import os
3+
import uuid
34

45
from alphatrion import consts
56
from alphatrion.artifact.artifact import Artifact
@@ -9,7 +10,7 @@
910

1011

1112
def init(
12-
project_id: str = "alphatrion",
13+
project_id: uuid.UUID,
1314
artifact_insecure: bool = False,
1415
init_tables: bool = False,
1516
):

alphatrion/trial/trial.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def _start(
240240
params: dict | None = None,
241241
):
242242
self._id = self._runtime._metadb.create_trial(
243+
project_id=self._runtime._project_id,
243244
exp_id=self._exp_id,
244245
name=name,
245246
description=description,

tests/integration/test_artifact.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import tempfile
5+
import uuid
56

67
import pytest
78

@@ -10,14 +11,14 @@
1011

1112
@pytest.fixture
1213
def artifact():
13-
init(project_id="test_project", artifact_insecure=True, init_tables=True)
14+
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
1415
artifact = global_runtime()._artifact
1516

1617
yield artifact
1718

1819

1920
def test_push_with_files(artifact):
20-
init(project_id="test_project", artifact_insecure=True, init_tables=True)
21+
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
2122

2223
with tempfile.TemporaryDirectory() as tmpdir:
2324
os.chdir(tmpdir)
@@ -40,7 +41,7 @@ def test_push_with_files(artifact):
4041

4142

4243
def test_push_with_folder(artifact):
43-
init(project_id="test_project", artifact_insecure=True, init_tables=True)
44+
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
4445

4546
with tempfile.TemporaryDirectory() as tmpdir:
4647
os.chdir(tmpdir)

tests/integration/test_log.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import tempfile
44
import time
5+
import uuid
56
from datetime import datetime, timedelta
67

78
import pytest
@@ -13,7 +14,7 @@
1314

1415
@pytest.mark.asyncio
1516
async def test_log_artifact():
16-
alpha.init(project_id="test_project", artifact_insecure=True, init_tables=True)
17+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
1718

1819
async with alpha.CraftExperiment.start(
1920
name="log_artifact_exp",
@@ -65,7 +66,7 @@ async def test_log_artifact():
6566

6667
@pytest.mark.asyncio
6768
async def test_log_params():
68-
alpha.init(project_id="test_project", artifact_insecure=True, init_tables=True)
69+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
6970

7071
async with alpha.CraftExperiment.start(name="log_params_exp") as exp:
7172
trial = exp.start_trial(name="first-trial", params={"param1": 0.1})
@@ -92,7 +93,7 @@ async def test_log_params():
9293

9394
@pytest.mark.asyncio
9495
async def test_log_metrics():
95-
alpha.init(project_id="test_project", artifact_insecure=True, init_tables=True)
96+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
9697

9798
async with alpha.CraftExperiment.start(name="log_metrics_exp") as exp:
9899
trial = exp.start_trial(name="first-trial", params={"param1": 0.1})
@@ -128,7 +129,7 @@ async def test_log_metrics():
128129

129130
@pytest.mark.asyncio
130131
async def test_log_metrics_with_save_on_max():
131-
alpha.init(project_id="test_project", artifact_insecure=True, init_tables=True)
132+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
132133

133134
async with alpha.CraftExperiment.start(
134135
name="log_metrics_with_save_on_max",
@@ -182,7 +183,7 @@ async def test_log_metrics_with_save_on_max():
182183

183184
@pytest.mark.asyncio
184185
async def test_log_metrics_with_save_on_min():
185-
alpha.init(project_id="test_project", artifact_insecure=True, init_tables=True)
186+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
186187

187188
async with alpha.CraftExperiment.start(
188189
name="log_metrics_with_save_on_min",
@@ -236,7 +237,7 @@ async def test_log_metrics_with_save_on_min():
236237

237238
@pytest.mark.asyncio
238239
async def test_log_metrics_with_early_stopping():
239-
alpha.init(project_id="test_project", artifact_insecure=True, init_tables=True)
240+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
240241

241242
async def fake_work(value: float):
242243
await alpha.log_metrics({"accuracy": value})
@@ -273,7 +274,7 @@ async def fake_sleep(value: float):
273274

274275
@pytest.mark.asyncio
275276
async def test_log_metrics_with_early_stopping_never_triggered():
276-
alpha.init(project_id="test_project", artifact_insecure=True, init_tables=True)
277+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
277278

278279
async def fake_work(value: float):
279280
await alpha.log_metrics({"accuracy": value})

tests/unit/artifact/test_artifact.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# Test the Artifact class
22

33

4+
import uuid
5+
46
import pytest
57

68
from alphatrion.runtime.runtime import global_runtime, init
79

810

911
@pytest.fixture
1012
def artifact():
11-
init(project_id="test_project", artifact_insecure=True)
13+
init(project_id=uuid.uuid4(), artifact_insecure=True)
1214
artifact = global_runtime()._artifact
1315
yield artifact
1416

0 commit comments

Comments
 (0)