diff --git a/Makefile b/Makefile index e22ecac2..9eb445f3 100644 --- a/Makefile +++ b/Makefile @@ -34,3 +34,5 @@ test-integration: lint until docker exec postgres pg_isready -U at_user; do sleep 1; done; \ $(POETRY) run pytest tests/integration; \ ' +.PHONY: test-all +test-all: test test-integration \ No newline at end of file diff --git a/alphatrion/__init__.py b/alphatrion/__init__.py index e69de29b..0f7e1309 100644 --- a/alphatrion/__init__.py +++ b/alphatrion/__init__.py @@ -0,0 +1,3 @@ +from alphatrion.experiment.craft_exp import CraftExperiment as CraftExperiment +from alphatrion.observe.observe import log_artifact as log_artifact +from alphatrion.runtime.runtime import init as init diff --git a/alphatrion/artifact/artifact.py b/alphatrion/artifact/artifact.py index c128a81f..05a6cbc3 100644 --- a/alphatrion/artifact/artifact.py +++ b/alphatrion/artifact/artifact.py @@ -3,14 +3,13 @@ import oras.client from alphatrion import consts -from alphatrion.runtime.runtime import Runtime SUCCESS_CODE = 201 class Artifact: - def __init__(self, runtime: Runtime, insecure: bool = False): - self._runtime = runtime + def __init__(self, project_id: str, insecure: bool = False): + self._project_id = project_id self._url = os.environ.get(consts.ARTIFACT_REGISTRY_URL) self._url = self._url.replace("https://", "").replace("http://", "") self._client = oras.client.OrasClient( @@ -52,16 +51,17 @@ def push( raise ValueError("No files to push.") url = self._url if self._url.endswith("/") else f"{self._url}/" - target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}" + target = f"{url}{self._project_id}/{experiment_name}:{version}" try: self._client.push(target, files=files_to_push) except Exception as e: raise RuntimeError("Failed to push artifacts") from e + # TODO: should we store it in the metadb instead? 
def list_versions(self, experiment_name: str) -> list[str]: url = self._url if self._url.endswith("/") else f"{self._url}/" - target = f"{url}{self._runtime._project_id}/{experiment_name}" + target = f"{url}{self._project_id}/{experiment_name}" try: tags = self._client.get_tags(target) return tags @@ -70,7 +70,7 @@ def list_versions(self, experiment_name: str) -> list[str]: def delete(self, experiment_name: str, versions: str | list[str]): url = self._url if self._url.endswith("/") else f"{self._url}/" - target = f"{url}{self._runtime._project_id}/{experiment_name}" + target = f"{url}{self._project_id}/{experiment_name}" try: self._client.delete_tags(target, tags=versions) diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index f9f3091c..cb4b9aa9 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -3,9 +3,8 @@ from pydantic import BaseModel, Field, field_validator -from alphatrion.artifact.artifact import Artifact from alphatrion.metadata.sql_models import COMPLETED_STATUS, ExperimentStatus -from alphatrion.runtime.runtime import Runtime +from alphatrion.runtime.runtime import global_runtime class CheckpointConfig(BaseModel): @@ -75,9 +74,7 @@ class Experiment: def __init__( self, - runtime: Runtime, config: ExperimentConfig | None = None, - artifact_insecure: bool = False, ): """ :param runtime: the Runtime instance @@ -87,9 +84,8 @@ def __init__( artifact registry. Default is False. 
""" - self._runtime = runtime - self._artifact = Artifact(runtime, insecure=artifact_insecure) self._config = config or ExperimentConfig() + self._runtime = global_runtime() self._steps = 0 self._best_metric_value = None @@ -100,13 +96,11 @@ def __init__( @classmethod def run( cls, - project_id: str, config: ExperimentConfig | None = None, name: str | None = None, description: str | None = None, meta: dict | None = None, labels: dict | None = None, - artifact_insecure: bool = False, ): """ :param project_id: the project ID to run the experiment under @@ -121,12 +115,7 @@ def run( :return: a context manager that yields an Experiment instance """ - runtime = Runtime(project_id=project_id) - exp = Experiment( - runtime=runtime, - config=config, - artifact_insecure=artifact_insecure, - ) + exp = Experiment(config=config) return RunContext( exp, name=name, description=description, meta=meta, labels=labels ) @@ -234,32 +223,6 @@ def running_time(self) -> int: return 0 return int((datetime.now(UTC) - self._start_at).total_seconds()) - def log_artifact( - self, - exp_id: int, - paths: str | list[str], - version: str = "latest", - ): - """ - Log artifacts (files) to the artifact registry. - :param exp_id: the experiment ID - :param paths: list of file paths to log. - Support one or multiple files or a folder. - If a folder is provided, all files in the folder will be logged. - Don't support nested folders currently. - Only files in the first level of the folder will be logged. 
- :param version: the version (tag) to log the files under - """ - - if not paths: - raise ValueError("no files specified to log") - - exp = self._runtime._metadb.get_exp(exp_id=exp_id) - if exp is None: - raise ValueError(f"Experiment with id {exp_id} does not exist.") - - self._artifact.push(experiment_name=exp.name, paths=paths, version=version) - class RunContext: """A context manager for running experiments.""" diff --git a/alphatrion/experiment/craft_exp.py b/alphatrion/experiment/craft_exp.py index df577783..b8d578c9 100644 --- a/alphatrion/experiment/craft_exp.py +++ b/alphatrion/experiment/craft_exp.py @@ -1,5 +1,4 @@ -from alphatrion.experiment.base import Experiment -from alphatrion.runtime.runtime import Runtime +from alphatrion.experiment.base import Experiment, ExperimentConfig class CraftExperiment(Experiment): @@ -10,7 +9,7 @@ class CraftExperiment(Experiment): Opposite to other experiment classes, you need to call all these methods yourself. """ - def __init__(self, runtime: Runtime): - super().__init__(runtime) - # Disable checkpointing by default for CraftExperiment + def __init__(self, config: ExperimentConfig | None = None): + super().__init__(config=config) + # Disable auto-checkpointing by default for CraftExperiment self._config.checkpoint.enabled = False diff --git a/alphatrion/observe/__init__.py b/alphatrion/observe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/alphatrion/observe/observe.py b/alphatrion/observe/observe.py new file mode 100644 index 00000000..5d05fd57 --- /dev/null +++ b/alphatrion/observe/observe.py @@ -0,0 +1,40 @@ +from alphatrion.runtime.runtime import global_runtime + + +def log_artifact( + exp_id: int, + paths: str | list[str], + version: str = "latest", +): + """ + Log artifacts (files) to the artifact registry. + + :param exp_id: the experiment ID + :param paths: list of file paths to log. + Support one or multiple files or a folder. 
+ If a folder is provided, all files in the folder will be logged. + Nested folders are not currently supported. + Only files in the first level of the folder will be logged. + :param version: the version (tag) to log the files under + """ + + if not paths: + raise ValueError("no files specified to log") + + runtime = global_runtime() + if runtime is None: + raise RuntimeError("Runtime is not initialized. Please call init() first.") + + exp = runtime._metadb.get_exp(exp_id=exp_id) + if exp is None: + raise ValueError(f"Experiment with id {exp_id} does not exist.") + + runtime._artifact.push(experiment_name=exp.name, paths=paths, version=version) + + +# def log_params(exp_id: int, params: dict): +# runtime = global_runtime() +# if runtime is None: +# raise RuntimeError("Runtime is not initialized. Please call init() first.") + +# runtime._metadb.log_params(exp_id=exp_id, params=params) diff --git a/alphatrion/runtime/runtime.py b/alphatrion/runtime/runtime.py index 8f7297be..51571914 100644 --- a/alphatrion/runtime/runtime.py +++ b/alphatrion/runtime/runtime.py @@ -1,10 +1,34 @@ +# ruff: noqa: PLW0603 import os from alphatrion import consts +from alphatrion.artifact.artifact import Artifact from alphatrion.metadata.sql import SQLStore +__RUNTIME__ = None + +def init(project_id: str, artifact_insecure: bool = False): + """ + Initialize the AlphaTrion runtime environment. + + :param project_id: the project ID to initialize the environment for + :param artifact_insecure: whether to use an insecure connection to the + artifact registry + """ + global __RUNTIME__ + __RUNTIME__ = Runtime(project_id=project_id, artifact_insecure=artifact_insecure) + + +def global_runtime(): + return __RUNTIME__ + + +# Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc.
class Runtime: - def __init__(self, project_id: str): + def __init__(self, project_id: str, artifact_insecure: bool = False): self._project_id = project_id self._metadb = SQLStore(os.getenv(consts.METADATA_DB_URL), init_tables=True) + self._artifact = Artifact( + project_id=self._project_id, insecure=artifact_insecure + ) diff --git a/tests/integration/artifact/test_artifact.py b/tests/integration/artifact/test_artifact.py index e15bac9b..09ac7f86 100644 --- a/tests/integration/artifact/test_artifact.py +++ b/tests/integration/artifact/test_artifact.py @@ -5,23 +5,22 @@ import pytest -from alphatrion.artifact.artifact import Artifact from alphatrion.experiment.base import Experiment -from alphatrion.runtime.runtime import Runtime +from alphatrion.observe.observe import log_artifact +from alphatrion.runtime.runtime import global_runtime, init @pytest.fixture def artifact(): - # We use a local registry for testing, it doesn't mean - # it will always successfully with cloud registries. - # We may need e2e tests for that. 
- runtime = Runtime(project_id="test_project") - artifact = Artifact(runtime=runtime, insecure=True) + init(project_id="test_project", artifact_insecure=True) + artifact = global_runtime()._artifact + yield artifact def test_push_with_files(artifact): - # Create a temporary directory with some files + init(project_id="test_project", artifact_insecure=True) + with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) @@ -45,6 +44,8 @@ def test_push_with_files(artifact): def test_push_with_folder(artifact): + init(project_id="test_project", artifact_insecure=True) + with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) @@ -66,13 +67,13 @@ def test_push_with_folder(artifact): def test_save_checkpoint(): + init(project_id="test_project", artifact_insecure=True) + with Experiment.run( - project_id="test_project", name="context_exp", description="Context manager test", meta={"key": "value"}, labels={"type": "unit"}, - artifact_insecure=True, ) as exp: with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) @@ -81,18 +82,21 @@ def test_save_checkpoint(): with open(file1, "w") as f: f.write("This is file1.") - exp.log_artifact(1, paths="file1.txt", version="v1") - versions = exp._artifact.list_versions("context_exp") + log_artifact(1, paths="file1.txt", version="v1") + versions = exp._runtime._artifact.list_versions("context_exp") assert "v1" in versions with open("file1.txt", "w") as f: f.write("This is modified file1.") # push folder instead - exp.log_artifact(1, paths=["file1.txt"], version="v2") - versions = exp._artifact.list_versions("context_exp") + log_artifact(1, paths=["file1.txt"], version="v2") + versions = exp._runtime._artifact.list_versions("context_exp") assert "v2" in versions - exp._artifact.delete(experiment_name="context_exp", versions=["v1", "v2"]) - versions = exp._artifact.list_versions("context_exp") + exp._runtime._artifact.delete( + experiment_name="context_exp", + versions=["v1", "v2"], + ) + versions = 
exp._runtime._artifact.list_versions("context_exp") assert len(versions) == 0 diff --git a/tests/integration/test_sdk.py b/tests/integration/test_sdk.py new file mode 100644 index 00000000..f01c16dd --- /dev/null +++ b/tests/integration/test_sdk.py @@ -0,0 +1,26 @@ +import os +import tempfile + +import alphatrion as at + + +def test_sdk(): + at.init(project_id="test_project", artifact_insecure=True) + + with at.CraftExperiment.run( + name="craft_exp", + description="test description", + meta={"key": "value"}, + labels={"type": "unit"}, + ) as exp: + with tempfile.TemporaryDirectory() as tmpdir: + os.chdir(tmpdir) + + file = "file.txt" + with open(file, "w") as f: + f.write("Hello, AlphaTrion!") + + at.log_artifact(2, paths=file, version="v1") + + versions = exp._runtime._artifact.list_versions("craft_exp") + assert "v1" in versions diff --git a/tests/unit/artifact/test_artifact.py b/tests/unit/artifact/test_artifact.py index 02a41bae..489995bf 100644 --- a/tests/unit/artifact/test_artifact.py +++ b/tests/unit/artifact/test_artifact.py @@ -3,17 +3,13 @@ import pytest -from alphatrion.artifact.artifact import Artifact -from alphatrion.runtime.runtime import Runtime +from alphatrion.runtime.runtime import global_runtime, init @pytest.fixture def artifact(): - # We use a local registry for testing, it doesn't mean - # it will always successfully with cloud registries. - # We may need e2e tests for that. 
- runtime = Runtime(project_id="test_project") - artifact = Artifact(runtime=runtime, insecure=True) + init(project_id="test_project", artifact_insecure=True) + artifact = global_runtime()._artifact yield artifact diff --git a/tests/unit/experiment/test_base_exp.py b/tests/unit/experiment/test_base_exp.py index 852cdacd..61ff822e 100644 --- a/tests/unit/experiment/test_base_exp.py +++ b/tests/unit/experiment/test_base_exp.py @@ -2,13 +2,13 @@ from alphatrion.experiment.craft_exp import Experiment from alphatrion.metadata.sql_models import ExperimentStatus -from alphatrion.runtime.runtime import Runtime +from alphatrion.runtime.runtime import init @pytest.fixture def exp(): - runtime = Runtime(project_id="test_project") - exp = Experiment(runtime=runtime) + init(project_id="test_project", artifact_insecure=True) + exp = Experiment() yield exp diff --git a/tests/unit/experiment/test_craft_exp.py b/tests/unit/experiment/test_craft_exp.py index 6261bbcf..a23981c6 100644 --- a/tests/unit/experiment/test_craft_exp.py +++ b/tests/unit/experiment/test_craft_exp.py @@ -2,11 +2,13 @@ from alphatrion.experiment.craft_exp import CraftExperiment from alphatrion.metadata.sql_models import ExperimentStatus +from alphatrion.runtime.runtime import init def test_craft_experiment(): + init(project_id="test_project", artifact_insecure=True) + with CraftExperiment.run( - project_id="test_project", name="context_exp", description="Context manager test", meta={"key": "value"},