Skip to content

Commit e09b1f4

Browse files
authored
Add SDK api (#23)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent a705e7a commit e09b1f4

13 files changed

Lines changed: 138 additions & 79 deletions

File tree

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,5 @@ test-integration: lint
3434
until docker exec postgres pg_isready -U at_user; do sleep 1; done; \
3535
$(POETRY) run pytest tests/integration; \
3636
'
37+
.PHONY: test-all
38+
test-all: test test-integration

alphatrion/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from alphatrion.experiment.craft_exp import CraftExperiment as CraftExperiment
2+
from alphatrion.observe.observe import log_artifact as log_artifact
3+
from alphatrion.runtime.runtime import init as init

alphatrion/artifact/artifact.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import oras.client
44

55
from alphatrion import consts
6-
from alphatrion.runtime.runtime import Runtime
76

87
SUCCESS_CODE = 201
98

109

1110
class Artifact:
12-
def __init__(self, runtime: Runtime, insecure: bool = False):
13-
self._runtime = runtime
11+
def __init__(self, project_id: str, insecure: bool = False):
12+
self._project_id = project_id
1413
self._url = os.environ.get(consts.ARTIFACT_REGISTRY_URL)
1514
self._url = self._url.replace("https://", "").replace("http://", "")
1615
self._client = oras.client.OrasClient(
@@ -52,16 +51,17 @@ def push(
5251
raise ValueError("No files to push.")
5352

5453
url = self._url if self._url.endswith("/") else f"{self._url}/"
55-
target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}"
54+
target = f"{url}{self._project_id}/{experiment_name}:{version}"
5655

5756
try:
5857
self._client.push(target, files=files_to_push)
5958
except Exception as e:
6059
raise RuntimeError("Failed to push artifacts") from e
6160

61+
# TODO: should we store it in the metadb instead?
6262
def list_versions(self, experiment_name: str) -> list[str]:
6363
url = self._url if self._url.endswith("/") else f"{self._url}/"
64-
target = f"{url}{self._runtime._project_id}/{experiment_name}"
64+
target = f"{url}{self._project_id}/{experiment_name}"
6565
try:
6666
tags = self._client.get_tags(target)
6767
return tags
@@ -70,7 +70,7 @@ def list_versions(self, experiment_name: str) -> list[str]:
7070

7171
def delete(self, experiment_name: str, versions: str | list[str]):
7272
url = self._url if self._url.endswith("/") else f"{self._url}/"
73-
target = f"{url}{self._runtime._project_id}/{experiment_name}"
73+
target = f"{url}{self._project_id}/{experiment_name}"
7474

7575
try:
7676
self._client.delete_tags(target, tags=versions)

alphatrion/experiment/base.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
from pydantic import BaseModel, Field, field_validator
55

6-
from alphatrion.artifact.artifact import Artifact
76
from alphatrion.metadata.sql_models import COMPLETED_STATUS, ExperimentStatus
8-
from alphatrion.runtime.runtime import Runtime
7+
from alphatrion.runtime.runtime import global_runtime
98

109

1110
class CheckpointConfig(BaseModel):
@@ -75,9 +74,7 @@ class Experiment:
7574

7675
def __init__(
7776
self,
78-
runtime: Runtime,
7977
config: ExperimentConfig | None = None,
80-
artifact_insecure: bool = False,
8178
):
8279
"""
8380
:param runtime: the Runtime instance
@@ -87,9 +84,8 @@ def __init__(
8784
artifact registry. Default is False.
8885
"""
8986

90-
self._runtime = runtime
91-
self._artifact = Artifact(runtime, insecure=artifact_insecure)
9287
self._config = config or ExperimentConfig()
88+
self._runtime = global_runtime()
9389

9490
self._steps = 0
9591
self._best_metric_value = None
@@ -100,13 +96,11 @@ def __init__(
10096
@classmethod
10197
def run(
10298
cls,
103-
project_id: str,
10499
config: ExperimentConfig | None = None,
105100
name: str | None = None,
106101
description: str | None = None,
107102
meta: dict | None = None,
108103
labels: dict | None = None,
109-
artifact_insecure: bool = False,
110104
):
111105
"""
112106
:param project_id: the project ID to run the experiment under
@@ -121,12 +115,7 @@ def run(
121115
:return: a context manager that yields an Experiment instance
122116
"""
123117

124-
runtime = Runtime(project_id=project_id)
125-
exp = Experiment(
126-
runtime=runtime,
127-
config=config,
128-
artifact_insecure=artifact_insecure,
129-
)
118+
exp = Experiment(config=config)
130119
return RunContext(
131120
exp, name=name, description=description, meta=meta, labels=labels
132121
)
@@ -234,32 +223,6 @@ def running_time(self) -> int:
234223
return 0
235224
return int((datetime.now(UTC) - self._start_at).total_seconds())
236225

237-
def log_artifact(
238-
self,
239-
exp_id: int,
240-
paths: str | list[str],
241-
version: str = "latest",
242-
):
243-
"""
244-
Log artifacts (files) to the artifact registry.
245-
:param exp_id: the experiment ID
246-
:param paths: list of file paths to log.
247-
Support one or multiple files or a folder.
248-
If a folder is provided, all files in the folder will be logged.
249-
Don't support nested folders currently.
250-
Only files in the first level of the folder will be logged.
251-
:param version: the version (tag) to log the files under
252-
"""
253-
254-
if not paths:
255-
raise ValueError("no files specified to log")
256-
257-
exp = self._runtime._metadb.get_exp(exp_id=exp_id)
258-
if exp is None:
259-
raise ValueError(f"Experiment with id {exp_id} does not exist.")
260-
261-
self._artifact.push(experiment_name=exp.name, paths=paths, version=version)
262-
263226

264227
class RunContext:
265228
"""A context manager for running experiments."""

alphatrion/experiment/craft_exp.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from alphatrion.experiment.base import Experiment
2-
from alphatrion.runtime.runtime import Runtime
1+
from alphatrion.experiment.base import Experiment, ExperimentConfig
32

43

54
class CraftExperiment(Experiment):
@@ -10,7 +9,7 @@ class CraftExperiment(Experiment):
109
Opposite to other experiment classes, you need to call all these methods yourself.
1110
"""
1211

13-
def __init__(self, runtime: Runtime):
14-
super().__init__(runtime)
15-
# Disable checkpointing by default for CraftExperiment
12+
def __init__(self, config: ExperimentConfig | None = None):
13+
super().__init__(config=config)
14+
# Disable auto-checkpointing by default for CraftExperiment
1615
self._config.checkpoint.enabled = False

alphatrion/observe/__init__.py

Whitespace-only changes.

alphatrion/observe/observe.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from alphatrion.runtime.runtime import global_runtime
2+
3+
4+
def log_artifact(
5+
exp_id: int,
6+
paths: str | list[str],
7+
version: str = "latest",
8+
):
9+
"""
10+
Log artifacts (files) to the artifact registry.
11+
12+
:param exp_id: the experiment ID
13+
:param paths: list of file paths to log.
14+
Support one or multiple files or a folder.
15+
If a folder is provided, all files in the folder will be logged.
16+
Don't support nested folders currently.
17+
Only files in the first level of the folder will be logged.
18+
:param version: the version (tag) to log the files
19+
"""
20+
21+
if not paths:
22+
raise ValueError("no files specified to log")
23+
24+
runtime = global_runtime()
25+
if runtime is None:
26+
raise RuntimeError("Runtime is not initialized. Please call init() first.")
27+
28+
exp = runtime._metadb.get_exp(exp_id=exp_id)
29+
if exp is None:
30+
raise ValueError(f"Experiment with id {exp_id} does not exist.")
31+
32+
runtime._artifact.push(experiment_name=exp.name, paths=paths, version=version)
33+
34+
35+
# def log_params(exp_id: int, params: dict):
36+
# runtime = global_runtime()
37+
# if runtime is None:
38+
# raise RuntimeError("Runtime is not initialized. Please call init() first.")
39+
40+
# runtime._metadb.log_params(exp_id=exp_id, params=params)

alphatrion/runtime/runtime.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,34 @@
1+
# ruff: noqa: PLW0603
12
import os
23

34
from alphatrion import consts
5+
from alphatrion.artifact.artifact import Artifact
46
from alphatrion.metadata.sql import SQLStore
57

8+
__RUNTIME__ = None
69

10+
11+
def init(project_id: str, artifact_insecure: bool = False):
12+
"""
13+
Initialize the AlphaTrion runtime environment.
14+
15+
:param project_id: the project ID to initialize the environment for
16+
:param artifact_insecure: whether to use insecure connection to the
17+
artifact registry
18+
"""
19+
global __RUNTIME__
20+
__RUNTIME__ = Runtime(project_id=project_id, artifact_insecure=artifact_insecure)
21+
22+
23+
def global_runtime():
24+
return __RUNTIME__
25+
26+
27+
# Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc.
728
class Runtime:
8-
def __init__(self, project_id: str):
29+
def __init__(self, project_id: str, artifact_insecure: bool = False):
930
self._project_id = project_id
1031
self._metadb = SQLStore(os.getenv(consts.METADATA_DB_URL), init_tables=True)
32+
self._artifact = Artifact(
33+
project_id=self._project_id, insecure=artifact_insecure
34+
)

tests/integration/artifact/test_artifact.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,22 @@
55

66
import pytest
77

8-
from alphatrion.artifact.artifact import Artifact
98
from alphatrion.experiment.base import Experiment
10-
from alphatrion.runtime.runtime import Runtime
9+
from alphatrion.observe.observe import log_artifact
10+
from alphatrion.runtime.runtime import global_runtime, init
1111

1212

1313
@pytest.fixture
1414
def artifact():
15-
# We use a local registry for testing, it doesn't mean
16-
# it will always successfully with cloud registries.
17-
# We may need e2e tests for that.
18-
runtime = Runtime(project_id="test_project")
19-
artifact = Artifact(runtime=runtime, insecure=True)
15+
init(project_id="test_project", artifact_insecure=True)
16+
artifact = global_runtime()._artifact
17+
2018
yield artifact
2119

2220

2321
def test_push_with_files(artifact):
24-
# Create a temporary directory with some files
22+
init(project_id="test_project", artifact_insecure=True)
23+
2524
with tempfile.TemporaryDirectory() as tmpdir:
2625
os.chdir(tmpdir)
2726

@@ -45,6 +44,8 @@ def test_push_with_files(artifact):
4544

4645

4746
def test_push_with_folder(artifact):
47+
init(project_id="test_project", artifact_insecure=True)
48+
4849
with tempfile.TemporaryDirectory() as tmpdir:
4950
os.chdir(tmpdir)
5051

@@ -66,13 +67,13 @@ def test_push_with_folder(artifact):
6667

6768

6869
def test_save_checkpoint():
70+
init(project_id="test_project", artifact_insecure=True)
71+
6972
with Experiment.run(
70-
project_id="test_project",
7173
name="context_exp",
7274
description="Context manager test",
7375
meta={"key": "value"},
7476
labels={"type": "unit"},
75-
artifact_insecure=True,
7677
) as exp:
7778
with tempfile.TemporaryDirectory() as tmpdir:
7879
os.chdir(tmpdir)
@@ -81,18 +82,21 @@ def test_save_checkpoint():
8182
with open(file1, "w") as f:
8283
f.write("This is file1.")
8384

84-
exp.log_artifact(1, paths="file1.txt", version="v1")
85-
versions = exp._artifact.list_versions("context_exp")
85+
log_artifact(1, paths="file1.txt", version="v1")
86+
versions = exp._runtime._artifact.list_versions("context_exp")
8687
assert "v1" in versions
8788

8889
with open("file1.txt", "w") as f:
8990
f.write("This is modified file1.")
9091

9192
# push folder instead
92-
exp.log_artifact(1, paths=["file1.txt"], version="v2")
93-
versions = exp._artifact.list_versions("context_exp")
93+
log_artifact(1, paths=["file1.txt"], version="v2")
94+
versions = exp._runtime._artifact.list_versions("context_exp")
9495
assert "v2" in versions
9596

96-
exp._artifact.delete(experiment_name="context_exp", versions=["v1", "v2"])
97-
versions = exp._artifact.list_versions("context_exp")
97+
exp._runtime._artifact.delete(
98+
experiment_name="context_exp",
99+
versions=["v1", "v2"],
100+
)
101+
versions = exp._runtime._artifact.list_versions("context_exp")
98102
assert len(versions) == 0

tests/integration/test_sdk.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
import tempfile
3+
4+
import alphatrion as at
5+
6+
7+
def test_sdk():
8+
at.init(project_id="test_project", artifact_insecure=True)
9+
10+
with at.CraftExperiment.run(
11+
name="craft_exp",
12+
description="test description",
13+
meta={"key": "value"},
14+
labels={"type": "unit"},
15+
) as exp:
16+
with tempfile.TemporaryDirectory() as tmpdir:
17+
os.chdir(tmpdir)
18+
19+
file = "file.txt"
20+
with open(file, "w") as f:
21+
f.write("Hello, AlphaTrion!")
22+
23+
at.log_artifact(2, paths=file, version="v1")
24+
25+
versions = exp._runtime._artifact.list_versions("craft_exp")
26+
assert "v1" in versions

0 commit comments

Comments
 (0)