Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions alphatrion/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ def __init__(self, runtime: Runtime, insecure: bool = False):
def push(
self,
experiment_name: str,
files: list[str] | None = None,
folder: str | None = None,
paths: str | list[str],
version: str = "latest",
):
"""
Expand All @@ -37,27 +36,24 @@ def push(
:param version: the version (tag) to push the files under
"""

if folder and files:
# Let's be strict here to simplify the implementation.
raise ValueError("Cannot specify both folder and files.")
if paths is None or not paths:
raise ValueError("no files specified to push")

if not folder and not files:
raise ValueError("Either folder or files must be specified.")

url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}"

files_to_push = files
if folder:
if not os.path.isdir(folder):
raise ValueError(f"{folder} is not a valid directory.")

os.chdir(folder)
files_to_push = [f for f in os.listdir(".") if os.path.isfile(f)]
if isinstance(paths, str):
if os.path.isdir(paths):
os.chdir(paths)
files_to_push = [f for f in os.listdir(".") if os.path.isfile(f)]
else:
files_to_push = [paths]
else:
files_to_push = paths

if not files_to_push:
raise ValueError("No files to push.")

url = self._url if self._url.endswith("/") else f"{self._url}/"
target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}"

try:
self._client.push(target, files=files_to_push)
except Exception as e:
Expand Down
21 changes: 16 additions & 5 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,28 @@ def running_time(self) -> int:
def log_artifact(
self,
exp_id: int,
files: list[str] | None = None,
folder: str | None = None,
paths: str | list[str],
version: str = "latest",
):
"""
Log artifacts (files) to the artifact registry.
:param exp_id: the experiment ID
:param paths: list of file paths to log.
Support one or multiple files or a folder.
If a folder is provided, all files in the folder will be logged.
Don't support nested folders currently.
Only files in the first level of the folder will be logged.
:param version: the version (tag) to log the files under
"""

if not paths:
raise ValueError("no files specified to log")

exp = self._runtime._metadb.get_exp(exp_id=exp_id)
if exp is None:
raise ValueError(f"Experiment with id {exp_id} does not exist.")

self._artifact.push(
experiment_name=exp.name, files=files, folder=folder, version=version
)
self._artifact.push(experiment_name=exp.name, paths=paths, version=version)


class RunContext:
Expand Down
9 changes: 5 additions & 4 deletions tests/integration/artifact/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_push_with_files(artifact):
f.write("This is file2.")

artifact.push(
experiment_name="test_experiment", files=[file1, file2], version="v1"
experiment_name="test_experiment", paths=[file1, file2], version="v1"
)

tags = artifact.list_versions("test_experiment")
Expand All @@ -55,7 +55,7 @@ def test_push_with_folder(artifact):
with open(file2, "w") as f:
f.write("This is a new file2.")

artifact.push(experiment_name="test_experiment", folder=tmpdir, version="v1")
artifact.push(experiment_name="test_experiment", paths=tmpdir, version="v1")

tags = artifact.list_versions("test_experiment")
assert "v1" in tags
Expand All @@ -81,14 +81,15 @@ def test_save_checkpoint():
with open(file1, "w") as f:
f.write("This is file1.")

exp.log_artifact(1, files=["file1.txt"], version="v1")
exp.log_artifact(1, paths="file1.txt", version="v1")
versions = exp._artifact.list_versions("context_exp")
assert "v1" in versions

with open("file1.txt", "w") as f:
f.write("This is modified file1.")

exp.log_artifact(1, files=["file1.txt"], version="v2")
# push folder instead
exp.log_artifact(1, paths=["file1.txt"], version="v2")
versions = exp._artifact.list_versions("context_exp")
assert "v2" in versions

Expand Down
19 changes: 5 additions & 14 deletions tests/unit/artifact/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,19 @@ def artifact():
yield artifact


def test_push_with_both_files_and_folder(artifact):
with pytest.raises(ValueError):
artifact.push(
experiment_name="test_experiment",
files=["file1.txt"],
folder="some_folder",
version="v1",
)


def test_push_with_error_folder(artifact):
with pytest.raises(ValueError):
with pytest.raises(RuntimeError):
artifact.push(
experiment_name="test_experiment",
folder="non_existent_folder.txt",
paths="non_existent_folder.txt",
version="v1",
)


def test_push_with_no_files_and_no_folder(artifact):
with pytest.raises(ValueError):
def test_push_with_empty_folder(artifact):
with pytest.raises(RuntimeError):
artifact.push(
experiment_name="test_experiment",
paths="empty_folder",
version="v1",
)
Loading