diff --git a/alphatrion/artifact/artifact.py b/alphatrion/artifact/artifact.py index 6a3be11f..c128a81f 100644 --- a/alphatrion/artifact/artifact.py +++ b/alphatrion/artifact/artifact.py @@ -20,8 +20,7 @@ def __init__(self, runtime: Runtime, insecure: bool = False): def push( self, experiment_name: str, - files: list[str] | None = None, - folder: str | None = None, + paths: str | list[str], version: str = "latest", ): """ @@ -37,27 +36,24 @@ def push( :param version: the version (tag) to push the files under """ - if folder and files: - # Let's be strict here to simplify the implementation. - raise ValueError("Cannot specify both folder and files.") + if paths is None or not paths: + raise ValueError("no files specified to push") - if not folder and not files: - raise ValueError("Either folder or files must be specified.") - - url = self._url if self._url.endswith("/") else f"{self._url}/" - target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}" - - files_to_push = files - if folder: - if not os.path.isdir(folder): - raise ValueError(f"{folder} is not a valid directory.") - - os.chdir(folder) - files_to_push = [f for f in os.listdir(".") if os.path.isfile(f)] + if isinstance(paths, str): + if os.path.isdir(paths): + os.chdir(paths) + files_to_push = [f for f in os.listdir(".") if os.path.isfile(f)] + else: + files_to_push = [paths] + else: + files_to_push = paths if not files_to_push: raise ValueError("No files to push.") + url = self._url if self._url.endswith("/") else f"{self._url}/" + target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}" + try: self._client.push(target, files=files_to_push) except Exception as e: diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index 84110f4d..f9f3091c 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -237,17 +237,28 @@ def running_time(self) -> int: def log_artifact( self, exp_id: int, - files: list[str] | None = None, - folder: str 
| None = None, + paths: str | list[str], version: str = "latest", ): + """ + Log artifacts (files) to the artifact registry. + :param exp_id: the experiment ID + :param paths: a file path, a list of file paths, or a folder path. + One or multiple files or a single folder are supported. + If a folder is provided, all files in the folder will be logged. + Nested folders are not currently supported. + Only files in the first level of the folder will be logged. + :param version: the version (tag) to log the files under + """ + + if not paths: + raise ValueError("no files specified to log") + exp = self._runtime._metadb.get_exp(exp_id=exp_id) if exp is None: raise ValueError(f"Experiment with id {exp_id} does not exist.") - self._artifact.push( - experiment_name=exp.name, files=files, folder=folder, version=version - ) + self._artifact.push(experiment_name=exp.name, paths=paths, version=version) class RunContext: diff --git a/tests/integration/artifact/test_artifact.py b/tests/integration/artifact/test_artifact.py index 2a2551e5..e15bac9b 100644 --- a/tests/integration/artifact/test_artifact.py +++ b/tests/integration/artifact/test_artifact.py @@ -33,7 +33,7 @@ def test_push_with_files(artifact): f.write("This is file2.") artifact.push( - experiment_name="test_experiment", files=[file1, file2], version="v1" + experiment_name="test_experiment", paths=[file1, file2], version="v1" ) tags = artifact.list_versions("test_experiment") @@ -55,7 +55,7 @@ def test_push_with_folder(artifact): with open(file2, "w") as f: f.write("This is a new file2.") - artifact.push(experiment_name="test_experiment", folder=tmpdir, version="v1") + artifact.push(experiment_name="test_experiment", paths=tmpdir, version="v1") tags = artifact.list_versions("test_experiment") assert "v1" in tags @@ -81,14 +81,15 @@ def test_save_checkpoint(): with open(file1, "w") as f: f.write("This is file1.") - exp.log_artifact(1, files=["file1.txt"], version="v1") + exp.log_artifact(1, paths="file1.txt", version="v1") versions = 
exp._artifact.list_versions("context_exp") assert "v1" in versions with open("file1.txt", "w") as f: f.write("This is modified file1.") - exp.log_artifact(1, files=["file1.txt"], version="v2") + # push a list of files instead + exp.log_artifact(1, paths=["file1.txt"], version="v2") versions = exp._artifact.list_versions("context_exp") assert "v2" in versions diff --git a/tests/unit/artifact/test_artifact.py b/tests/unit/artifact/test_artifact.py index 8a94973a..02a41bae 100644 --- a/tests/unit/artifact/test_artifact.py +++ b/tests/unit/artifact/test_artifact.py @@ -17,28 +17,19 @@ def artifact(): yield artifact -def test_push_with_both_files_and_folder(artifact): - with pytest.raises(ValueError): - artifact.push( - experiment_name="test_experiment", - files=["file1.txt"], - folder="some_folder", - version="v1", - ) - - def test_push_with_error_folder(artifact): - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): artifact.push( experiment_name="test_experiment", - folder="non_existent_folder.txt", + paths="non_existent_folder.txt", version="v1", ) -def test_push_with_no_files_and_no_folder(artifact): - with pytest.raises(ValueError): +def test_push_with_empty_folder(artifact): + with pytest.raises(RuntimeError): artifact.push( experiment_name="test_experiment", + paths="empty_folder", version="v1", )