Skip to content

Commit a705e7a

Browse files
authored
Update parameters (#22)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 74eb96f commit a705e7a

4 files changed

Lines changed: 40 additions & 41 deletions

File tree

alphatrion/artifact/artifact.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def __init__(self, runtime: Runtime, insecure: bool = False):
2020
def push(
2121
self,
2222
experiment_name: str,
23-
files: list[str] | None = None,
24-
folder: str | None = None,
23+
paths: str | list[str],
2524
version: str = "latest",
2625
):
2726
"""
@@ -37,27 +36,24 @@ def push(
3736
:param version: the version (tag) to push the files under
3837
"""
3938

40-
if folder and files:
41-
# Let's be strict here to simplify the implementation.
42-
raise ValueError("Cannot specify both folder and files.")
39+
if paths is None or not paths:
40+
raise ValueError("no files specified to push")
4341

44-
if not folder and not files:
45-
raise ValueError("Either folder or files must be specified.")
46-
47-
url = self._url if self._url.endswith("/") else f"{self._url}/"
48-
target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}"
49-
50-
files_to_push = files
51-
if folder:
52-
if not os.path.isdir(folder):
53-
raise ValueError(f"{folder} is not a valid directory.")
54-
55-
os.chdir(folder)
56-
files_to_push = [f for f in os.listdir(".") if os.path.isfile(f)]
42+
if isinstance(paths, str):
43+
if os.path.isdir(paths):
44+
os.chdir(paths)
45+
files_to_push = [f for f in os.listdir(".") if os.path.isfile(f)]
46+
else:
47+
files_to_push = [paths]
48+
else:
49+
files_to_push = paths
5750

5851
if not files_to_push:
5952
raise ValueError("No files to push.")
6053

54+
url = self._url if self._url.endswith("/") else f"{self._url}/"
55+
target = f"{url}{self._runtime._project_id}/{experiment_name}:{version}"
56+
6157
try:
6258
self._client.push(target, files=files_to_push)
6359
except Exception as e:

alphatrion/experiment/base.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,28 @@ def running_time(self) -> int:
237237
def log_artifact(
238238
self,
239239
exp_id: int,
240-
files: list[str] | None = None,
241-
folder: str | None = None,
240+
paths: str | list[str],
242241
version: str = "latest",
243242
):
243+
"""
244+
Log artifacts (files) to the artifact registry.
245+
:param exp_id: the experiment ID
246+
:param paths: list of file paths to log.
247+
Support one or multiple files or a folder.
248+
If a folder is provided, all files in the folder will be logged.
249+
Don't support nested folders currently.
250+
Only files in the first level of the folder will be logged.
251+
:param version: the version (tag) to log the files under
252+
"""
253+
254+
if not paths:
255+
raise ValueError("no files specified to log")
256+
244257
exp = self._runtime._metadb.get_exp(exp_id=exp_id)
245258
if exp is None:
246259
raise ValueError(f"Experiment with id {exp_id} does not exist.")
247260

248-
self._artifact.push(
249-
experiment_name=exp.name, files=files, folder=folder, version=version
250-
)
261+
self._artifact.push(experiment_name=exp.name, paths=paths, version=version)
251262

252263

253264
class RunContext:

tests/integration/artifact/test_artifact.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_push_with_files(artifact):
3333
f.write("This is file2.")
3434

3535
artifact.push(
36-
experiment_name="test_experiment", files=[file1, file2], version="v1"
36+
experiment_name="test_experiment", paths=[file1, file2], version="v1"
3737
)
3838

3939
tags = artifact.list_versions("test_experiment")
@@ -55,7 +55,7 @@ def test_push_with_folder(artifact):
5555
with open(file2, "w") as f:
5656
f.write("This is a new file2.")
5757

58-
artifact.push(experiment_name="test_experiment", folder=tmpdir, version="v1")
58+
artifact.push(experiment_name="test_experiment", paths=tmpdir, version="v1")
5959

6060
tags = artifact.list_versions("test_experiment")
6161
assert "v1" in tags
@@ -81,14 +81,15 @@ def test_save_checkpoint():
8181
with open(file1, "w") as f:
8282
f.write("This is file1.")
8383

84-
exp.log_artifact(1, files=["file1.txt"], version="v1")
84+
exp.log_artifact(1, paths="file1.txt", version="v1")
8585
versions = exp._artifact.list_versions("context_exp")
8686
assert "v1" in versions
8787

8888
with open("file1.txt", "w") as f:
8989
f.write("This is modified file1.")
9090

91-
exp.log_artifact(1, files=["file1.txt"], version="v2")
91+
# push folder instead
92+
exp.log_artifact(1, paths=["file1.txt"], version="v2")
9293
versions = exp._artifact.list_versions("context_exp")
9394
assert "v2" in versions
9495

tests/unit/artifact/test_artifact.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,19 @@ def artifact():
1717
yield artifact
1818

1919

20-
def test_push_with_both_files_and_folder(artifact):
21-
with pytest.raises(ValueError):
22-
artifact.push(
23-
experiment_name="test_experiment",
24-
files=["file1.txt"],
25-
folder="some_folder",
26-
version="v1",
27-
)
28-
29-
3020
def test_push_with_error_folder(artifact):
31-
with pytest.raises(ValueError):
21+
with pytest.raises(RuntimeError):
3222
artifact.push(
3323
experiment_name="test_experiment",
34-
folder="non_existent_folder.txt",
24+
paths="non_existent_folder.txt",
3525
version="v1",
3626
)
3727

3828

39-
def test_push_with_no_files_and_no_folder(artifact):
40-
with pytest.raises(ValueError):
29+
def test_push_with_empty_folder(artifact):
30+
with pytest.raises(RuntimeError):
4131
artifact.push(
4232
experiment_name="test_experiment",
33+
paths="empty_folder",
4334
version="v1",
4435
)

0 commit comments

Comments
 (0)