Skip to content

Commit 6f50b0a

Browse files
committed
Add async run artifact downloads
1 parent 30ff288 commit 6f50b0a

2 files changed

Lines changed: 306 additions & 9 deletions

File tree

cli/polyaxon/_client/run.py

Lines changed: 149 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3687,16 +3687,159 @@ async def get_artifact(
36873687
)
36883688

36893689
@async_client_handler(check_no_op=True, check_offline=True)
3690-
async def download_artifact_for_lineage(self, *args, **kwargs):
3691-
self._raise_sync_only("download_artifact_for_lineage")
3690+
async def download_artifact_for_lineage(
3691+
self,
3692+
lineage: V1RunArtifact,
3693+
force: bool = False,
3694+
path_to: Optional[str] = None,
3695+
):
3696+
if not self.run_uuid:
3697+
return
3698+
3699+
if not self.settings:
3700+
await self.refresh_data()
3701+
await self._use_agent_host()
3702+
3703+
lineage_path = lineage.path or ""
3704+
summary = lineage.summary or {}
3705+
is_event = summary.get("is_event")
3706+
has_step = summary.get("step")
3707+
3708+
if self.run_uuid in lineage_path:
3709+
lineage_path = os.path.relpath(lineage_path, self.run_uuid)
3710+
3711+
if V1ArtifactKind.is_single_file_event(lineage.kind):
3712+
return await self.download_artifact(
3713+
path=lineage_path,
3714+
force=force,
3715+
path_to=path_to,
3716+
)
3717+
3718+
if V1ArtifactKind.is_single_or_multi_file_event(lineage.kind):
3719+
if is_event or has_step:
3720+
if not self.settings:
3721+
await self.refresh_data()
3722+
url = get_proxy_run_url(
3723+
service=STREAMS_V1_LOCATION,
3724+
namespace=self.namespace,
3725+
owner=self.owner,
3726+
project=self.project,
3727+
run_uuid=self.run_uuid,
3728+
subpath="events/{}".format(lineage.kind),
3729+
)
3730+
url = absolute_uri(url=url, host=self.client.config.host)
3731+
params = get_streams_params(
3732+
connection=self.artifacts_store,
3733+
force=force,
3734+
)
3735+
params.update({"names": lineage.name, "pkg_assets": True})
3736+
3737+
# TODO: Update with AsyncPolyaxonStore is done
3738+
return await asyncio.to_thread(
3739+
self.store.download_file,
3740+
url=url,
3741+
path=self.run_uuid,
3742+
use_filepath=False,
3743+
extract_path=path_to,
3744+
path_to=path_to,
3745+
params=params,
3746+
untar=True,
3747+
)
3748+
if V1ArtifactKind.is_file_or_dir(lineage.kind):
3749+
return await self.download_artifacts(
3750+
path=lineage_path,
3751+
path_to=path_to,
3752+
check_path=True,
3753+
)
3754+
return await self.download_artifact(
3755+
path=lineage_path,
3756+
force=force,
3757+
path_to=path_to,
3758+
)
3759+
3760+
if V1ArtifactKind.is_file(lineage.kind):
3761+
return await self.download_artifact(
3762+
path=lineage_path,
3763+
force=force,
3764+
path_to=path_to,
3765+
)
3766+
3767+
if V1ArtifactKind.is_dir(lineage.kind):
3768+
return await self.download_artifacts(path=lineage_path, path_to=path_to)
3769+
3770+
if V1ArtifactKind.is_file_or_dir(lineage.kind):
3771+
return await self.download_artifacts(
3772+
path=lineage_path,
3773+
path_to=path_to,
3774+
check_path=True,
3775+
)
36923776

36933777
@async_client_handler(check_no_op=True, check_offline=True)
3694-
async def download_artifact(self, *args, **kwargs):
3695-
self._raise_sync_only("download_artifact")
3778+
async def download_artifact(
3779+
self,
3780+
path: str,
3781+
force: bool = False,
3782+
path_to: Optional[str] = None,
3783+
):
3784+
if not self.settings:
3785+
await self.refresh_data()
3786+
await self._use_agent_host()
3787+
3788+
url = get_proxy_run_url(
3789+
service=STREAMS_V1_LOCATION,
3790+
namespace=self.namespace,
3791+
owner=self.owner,
3792+
project=self.project,
3793+
run_uuid=self.run_uuid,
3794+
subpath="artifact",
3795+
)
3796+
url = absolute_uri(url=url, host=self.client.config.host)
3797+
params = get_streams_params(connection=self.artifacts_store, force=force)
3798+
return await asyncio.to_thread(
3799+
self.store.download_file,
3800+
url=url,
3801+
path=path,
3802+
path_to=path_to,
3803+
params=params,
3804+
)
36963805

36973806
@async_client_handler(check_no_op=True, check_offline=True)
3698-
async def download_artifacts(self, *args, **kwargs):
3699-
self._raise_sync_only("download_artifacts")
3807+
async def download_artifacts(
3808+
self,
3809+
path: str = "",
3810+
path_to: Optional[str] = None,
3811+
untar: bool = True,
3812+
delete_tar: bool = True,
3813+
extract_path: Optional[str] = None,
3814+
check_path: bool = False,
3815+
):
3816+
if not self.settings:
3817+
await self.refresh_data()
3818+
await self._use_agent_host()
3819+
3820+
url = get_proxy_run_url(
3821+
service=STREAMS_V1_LOCATION,
3822+
namespace=self.namespace,
3823+
owner=self.owner,
3824+
project=self.project,
3825+
run_uuid=self.run_uuid,
3826+
subpath="artifacts",
3827+
)
3828+
url = absolute_uri(url=url, host=self.client.config.host)
3829+
params = get_streams_params(connection=self.artifacts_store)
3830+
if check_path:
3831+
params["check_path"] = True
3832+
3833+
return await asyncio.to_thread(
3834+
self.store.download_file,
3835+
url=url,
3836+
path=path,
3837+
untar=untar,
3838+
path_to=path_to,
3839+
delete_tar=delete_tar and untar,
3840+
extract_path=extract_path,
3841+
params=params,
3842+
)
37003843

37013844
@async_client_handler(check_no_op=True, check_offline=True)
37023845
async def upload_artifact(self, *args, **kwargs):

cli/tests/test_client/test_async_run_client.py

Lines changed: 157 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,163 @@ async def test_get_artifact_awaits_refresh_when_settings_missing():
575575
)
576576

577577

578+
@pytest.mark.asyncio
579+
async def test_download_artifact_methods_call_store_download_file():
580+
patch_settings()
581+
sdk_client = AsyncPolyaxonClientMock()
582+
sdk_client.config = mock.Mock(host="http://api", no_op=None, is_offline=None)
583+
client = make_client(sdk_client)
584+
client._run_data = make_run()
585+
client._store = mock.Mock()
586+
client._store.download_file = mock.Mock(side_effect=["/tmp/file", "/tmp/archive"])
587+
588+
file_path = await client.download_artifact(
589+
"outputs/model.pkl",
590+
force=True,
591+
path_to="/tmp/downloads",
592+
)
593+
archive_path = await client.download_artifacts(
594+
"outputs",
595+
path_to="/tmp/downloads",
596+
untar=False,
597+
extract_path="/tmp/extract",
598+
check_path=True,
599+
)
600+
601+
assert file_path == "/tmp/file"
602+
assert archive_path == "/tmp/archive"
603+
file_kwargs = client._store.download_file.call_args_list[0][1]
604+
archive_kwargs = client._store.download_file.call_args_list[1][1]
605+
assert file_kwargs["path"] == "outputs/model.pkl"
606+
assert file_kwargs["path_to"] == "/tmp/downloads"
607+
assert file_kwargs["params"] == {"force": True}
608+
assert file_kwargs["url"].endswith(
609+
"/streams/v1/test-namespace/test-owner/test-project/runs/{}/artifact".format(
610+
RUN_UUID
611+
)
612+
)
613+
assert archive_kwargs["path"] == "outputs"
614+
assert archive_kwargs["path_to"] == "/tmp/downloads"
615+
assert archive_kwargs["untar"] is False
616+
assert archive_kwargs["delete_tar"] is False
617+
assert archive_kwargs["extract_path"] == "/tmp/extract"
618+
assert archive_kwargs["params"] == {"check_path": True}
619+
assert archive_kwargs["url"].endswith(
620+
"/streams/v1/test-namespace/test-owner/test-project/runs/{}/artifacts".format(
621+
RUN_UUID
622+
)
623+
)
624+
625+
626+
@pytest.mark.asyncio
627+
@pytest.mark.parametrize(
628+
"lineage,expected_method,expected_kwargs",
629+
[
630+
(
631+
V1RunArtifact.model_construct(
632+
kind=V1ArtifactKind.METRIC,
633+
path=f"{RUN_UUID}/events/metric/loss",
634+
),
635+
"download_artifact",
636+
{
637+
"path": "events/metric/loss",
638+
"force": True,
639+
"path_to": "/tmp/downloads",
640+
},
641+
),
642+
(
643+
V1RunArtifact.model_construct(
644+
kind=V1ArtifactKind.MODEL,
645+
path="models/model.pkl",
646+
),
647+
"download_artifacts",
648+
{
649+
"path": "models/model.pkl",
650+
"path_to": "/tmp/downloads",
651+
"check_path": True,
652+
},
653+
),
654+
(
655+
V1RunArtifact.model_construct(
656+
kind=V1ArtifactKind.DIR,
657+
path="outputs",
658+
),
659+
"download_artifacts",
660+
{
661+
"path": "outputs",
662+
"path_to": "/tmp/downloads",
663+
},
664+
),
665+
],
666+
)
667+
async def test_download_artifact_for_lineage_routes_to_download_helpers(
668+
lineage,
669+
expected_method,
670+
expected_kwargs,
671+
):
672+
patch_settings()
673+
sdk_client = AsyncPolyaxonClientMock()
674+
client = make_client(sdk_client)
675+
client._run_data = make_run()
676+
client.download_artifact = AsyncMock(return_value="file")
677+
client.download_artifacts = AsyncMock(return_value="artifacts")
678+
679+
result = await client.download_artifact_for_lineage(
680+
lineage,
681+
force=True,
682+
path_to="/tmp/downloads",
683+
)
684+
685+
if expected_method == "download_artifact":
686+
assert result == "file"
687+
client.download_artifact.assert_called_once_with(**expected_kwargs)
688+
assert client.download_artifacts.call_count == 0
689+
else:
690+
assert result == "artifacts"
691+
client.download_artifacts.assert_called_once_with(**expected_kwargs)
692+
assert client.download_artifact.call_count == 0
693+
694+
695+
@pytest.mark.asyncio
696+
async def test_download_artifact_for_lineage_downloads_event_package():
697+
patch_settings()
698+
sdk_client = AsyncPolyaxonClientMock()
699+
sdk_client.config = mock.Mock(host="http://api", no_op=None, is_offline=None)
700+
client = make_client(sdk_client)
701+
client._run_data = make_run()
702+
client._store = mock.Mock()
703+
client._store.download_file = mock.Mock(return_value="/tmp/events")
704+
lineage = V1RunArtifact.model_construct(
705+
kind=V1ArtifactKind.MODEL,
706+
name="model",
707+
path="events/model",
708+
summary={"is_event": True},
709+
)
710+
711+
result = await client.download_artifact_for_lineage(
712+
lineage,
713+
force=True,
714+
path_to="/tmp/downloads",
715+
)
716+
717+
assert result == "/tmp/events"
718+
kwargs = client._store.download_file.call_args[1]
719+
assert kwargs["path"] == RUN_UUID
720+
assert kwargs["use_filepath"] is False
721+
assert kwargs["extract_path"] == "/tmp/downloads"
722+
assert kwargs["path_to"] == "/tmp/downloads"
723+
assert kwargs["untar"] is True
724+
assert kwargs["params"] == {
725+
"force": True,
726+
"names": "model",
727+
"pkg_assets": True,
728+
}
729+
assert kwargs["url"].endswith(
730+
"/streams/v1/test-namespace/test-owner/test-project/runs/"
731+
"{}/events/model".format(RUN_UUID)
732+
)
733+
734+
578735
@pytest.mark.asyncio
579736
async def test_run_action_methods_await_api_without_async_req():
580737
patch_settings()
@@ -840,9 +997,6 @@ async def test_promote_methods_use_async_project_client_with_injected_client():
840997
("upload_artifact", ("file.txt",)),
841998
("upload_artifacts_dir", ("outputs",)),
842999
("upload_artifacts", (["file.txt"],)),
843-
("download_artifact_for_lineage", (V1RunArtifact.model_construct(),)),
844-
("download_artifact", ("file.txt",)),
845-
("download_artifacts", ()),
8461000
("persist_run", ("/tmp/run",)),
8471001
("push_offline_run", ("/tmp/run",)),
8481002
("get_runs_as_hiplot", ()),

0 commit comments

Comments
 (0)