Skip to content

Commit a6f4c43

Browse files
committed
Add async run ref logging parity
* Implement async RunClient ref logging helpers for code, data, artifacts, models, files, directories, and tensorboard refs. Offload blocking local hashing and code-reference detection with asyncio.to_thread while awaiting SDK lineage and metadata updates.
1 parent 93d0a43 commit a6f4c43

2 files changed

Lines changed: 298 additions & 15 deletions

File tree

cli/polyaxon/_client/run.py

Lines changed: 149 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3920,32 +3920,167 @@ async def log_progress(self, value: float):
39203920
await self.log_meta(progress=value)
39213921

39223922
@async_client_handler(check_no_op=True)
3923-
async def log_code_ref(self, *args, **kwargs):
3924-
self._raise_sync_only("log_code_ref")
3923+
async def log_code_ref(
3924+
self, code_ref: Optional[Dict] = None, is_input: bool = True
3925+
):
3926+
code_ref = code_ref or await asyncio.to_thread(get_code_reference)
3927+
if code_ref and "commit" in code_ref:
3928+
artifact_run = V1RunArtifact.model_construct(
3929+
name=code_ref.get("commit"),
3930+
kind=V1ArtifactKind.CODEREF,
3931+
summary=code_ref,
3932+
is_input=is_input,
3933+
)
3934+
await self.log_artifact_lineage(body=artifact_run)
39253935

39263936
@async_client_handler(check_no_op=True)
3927-
async def log_data_ref(self, *args, **kwargs):
3928-
self._raise_sync_only("log_data_ref")
3937+
async def log_data_ref(
3938+
self,
3939+
name: str,
3940+
hash: Optional[str] = None,
3941+
path: Optional[str] = None,
3942+
content=None,
3943+
summary: Optional[Dict] = None,
3944+
is_input: bool = True,
3945+
skip_hash_calculation: bool = False,
3946+
):
3947+
return await self.log_artifact_ref(
3948+
path=path,
3949+
hash=hash,
3950+
content=content,
3951+
kind=V1ArtifactKind.DATA,
3952+
name=name,
3953+
summary=summary,
3954+
is_input=is_input,
3955+
skip_hash_calculation=skip_hash_calculation,
3956+
)
39293957

39303958
@async_client_handler(check_no_op=True)
3931-
async def log_artifact_ref(self, *args, **kwargs):
3932-
self._raise_sync_only("log_artifact_ref")
3959+
async def log_artifact_ref(
3960+
self,
3961+
path: str,
3962+
kind: V1ArtifactKind,
3963+
name: Optional[str] = None,
3964+
hash: Optional[str] = None,
3965+
content=None,
3966+
summary: Optional[Dict] = None,
3967+
is_input: bool = False,
3968+
rel_path: Optional[str] = None,
3969+
skip_hash_calculation: bool = False,
3970+
):
3971+
summary = await asyncio.to_thread(
3972+
self._calculate_summary_for_path_or_content,
3973+
hash=hash,
3974+
path=path,
3975+
content=content,
3976+
summary=summary,
3977+
skip_hash_calculation=skip_hash_calculation,
3978+
)
3979+
if path:
3980+
name = name or get_base_filename(path)
3981+
rel_path = self._sanitize_filepath(filepath=path, rel_path=rel_path)
3982+
if name:
3983+
artifact_run = V1RunArtifact.model_construct(
3984+
name=self._sanitize_filename(name),
3985+
kind=kind,
3986+
path=rel_path,
3987+
summary=summary,
3988+
is_input=is_input,
3989+
)
3990+
await self.log_artifact_lineage(body=artifact_run)
3991+
3992+
async def _log_has_model(self):
3993+
if not self._has_meta_key("has_model"):
3994+
await self.log_meta(has_model=True)
39333995

39343996
@async_client_handler(check_no_op=True)
3935-
async def log_model_ref(self, *args, **kwargs):
3936-
self._raise_sync_only("log_model_ref")
3997+
async def log_model_ref(
3998+
self,
3999+
path: str,
4000+
name: Optional[str] = None,
4001+
framework: Optional[str] = None,
4002+
summary: Optional[Dict] = None,
4003+
is_input: bool = False,
4004+
rel_path: Optional[str] = None,
4005+
skip_hash_calculation: bool = False,
4006+
):
4007+
summary = summary or {}
4008+
summary["framework"] = framework
4009+
await self._log_has_model()
4010+
return await self.log_artifact_ref(
4011+
path=path,
4012+
kind=V1ArtifactKind.MODEL,
4013+
name=name,
4014+
summary=summary,
4015+
is_input=is_input,
4016+
rel_path=rel_path,
4017+
skip_hash_calculation=skip_hash_calculation,
4018+
)
39374019

39384020
@async_client_handler(check_no_op=True)
3939-
async def log_file_ref(self, *args, **kwargs):
3940-
self._raise_sync_only("log_file_ref")
4021+
async def log_file_ref(
4022+
self,
4023+
path: str,
4024+
name: Optional[str] = None,
4025+
hash: Optional[str] = None,
4026+
content=None,
4027+
summary: Optional[Dict] = None,
4028+
is_input: bool = False,
4029+
rel_path: Optional[str] = None,
4030+
skip_hash_calculation: bool = False,
4031+
):
4032+
return await self.log_artifact_ref(
4033+
path=path,
4034+
kind=V1ArtifactKind.FILE,
4035+
name=name,
4036+
hash=hash,
4037+
content=content,
4038+
summary=summary,
4039+
is_input=is_input,
4040+
rel_path=rel_path,
4041+
skip_hash_calculation=skip_hash_calculation,
4042+
)
39414043

39424044
@async_client_handler(check_no_op=True)
3943-
async def log_dir_ref(self, *args, **kwargs):
3944-
self._raise_sync_only("log_dir_ref")
4045+
async def log_dir_ref(
4046+
self,
4047+
path: str,
4048+
name: Optional[str] = None,
4049+
hash: Optional[str] = None,
4050+
summary: Optional[Dict] = None,
4051+
is_input: bool = False,
4052+
rel_path: Optional[str] = None,
4053+
skip_hash_calculation: bool = False,
4054+
):
4055+
return await self.log_artifact_ref(
4056+
path=path,
4057+
kind=V1ArtifactKind.DIR,
4058+
name=name or os.path.basename(path),
4059+
hash=hash,
4060+
summary=summary,
4061+
is_input=is_input,
4062+
rel_path=rel_path,
4063+
skip_hash_calculation=skip_hash_calculation,
4064+
)
39454065

39464066
@async_client_handler(check_no_op=True)
3947-
async def log_tensorboard_ref(self, *args, **kwargs):
3948-
self._raise_sync_only("log_tensorboard_ref")
4067+
async def log_tensorboard_ref(
4068+
self,
4069+
path: str,
4070+
name: str = "tensorboard",
4071+
is_input: bool = False,
4072+
rel_path: Optional[str] = None,
4073+
):
4074+
if not self._has_meta_key("has_tensorboard"):
4075+
await self.log_artifact_ref(
4076+
path=path,
4077+
kind=V1ArtifactKind.TENSORBOARD,
4078+
name=name,
4079+
is_input=is_input,
4080+
rel_path=rel_path,
4081+
skip_hash_calculation=True,
4082+
)
4083+
await self.log_meta(has_tensorboard=True)
39494084

39504085
@async_client_handler(check_no_op=True)
39514086
async def log_artifact_lineage(

cli/tests/test_client/test_async_run_client.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from mock import mock
33
import pytest
44

5+
from clipped.utils.hashing import hash_file, hash_value
56
from polyaxon import settings
67
from polyaxon._client.run import AsyncRunClient, RunClient
78
from polyaxon._schemas.lifecycle import (
@@ -147,6 +148,13 @@ def make_run(**kwargs):
147148
return V1Run.model_construct(**data)
148149

149150

151+
def get_logged_lineage_artifact(sdk_client, index=0):
152+
body = sdk_client.runs_v1.create_run_artifacts_lineage.call_args_list[index][1][
153+
"body"
154+
]
155+
return body.artifacts[0]
156+
157+
150158
def test_async_run_client_public_export():
151159
from polyaxon.client import AsyncRunClient as Exported
152160

@@ -581,6 +589,147 @@ async def test_log_artifact_lineage_and_run_edges_await_api():
581589
assert "async_req" not in sdk_client.runs_v1.set_run_edges_lineage.call_args[1]
582590

583591

592+
@pytest.mark.asyncio
593+
async def test_log_code_ref_detects_code_ref_in_thread(monkeypatch):
594+
patch_settings()
595+
monkeypatch.setattr(
596+
"polyaxon._client.run.get_code_reference",
597+
lambda: {"commit": "abc123", "branch": "main"},
598+
)
599+
sdk_client = AsyncPolyaxonClientMock()
600+
sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None)
601+
client = make_client(sdk_client)
602+
603+
await client.log_code_ref(is_input=False)
604+
605+
artifact = get_logged_lineage_artifact(sdk_client)
606+
assert artifact.name == "abc123"
607+
assert artifact.kind == V1ArtifactKind.CODEREF
608+
assert artifact.summary == {"commit": "abc123", "branch": "main"}
609+
assert artifact.is_input is False
610+
611+
612+
@pytest.mark.asyncio
613+
async def test_log_data_ref_hashes_content_and_awaits_lineage_api():
614+
patch_settings()
615+
sdk_client = AsyncPolyaxonClientMock()
616+
sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None)
617+
client = make_client(sdk_client)
618+
619+
await client.log_data_ref(
620+
name="dataset",
621+
content={"x": 1},
622+
summary={"rows": 10},
623+
)
624+
625+
artifact = get_logged_lineage_artifact(sdk_client)
626+
assert artifact.name == "dataset"
627+
assert artifact.kind == V1ArtifactKind.DATA
628+
assert artifact.summary == {"rows": 10, "hash": hash_value({"x": 1})}
629+
assert artifact.is_input is True
630+
631+
632+
@pytest.mark.asyncio
633+
async def test_log_artifact_ref_hashes_existing_file(tmp_path):
634+
patch_settings()
635+
asset = tmp_path / "result.json"
636+
asset.write_text("payload")
637+
sdk_client = AsyncPolyaxonClientMock()
638+
sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None)
639+
client = make_client(sdk_client)
640+
641+
await client.log_artifact_ref(
642+
path=str(asset),
643+
kind=V1ArtifactKind.ARTIFACT,
644+
)
645+
646+
artifact = get_logged_lineage_artifact(sdk_client)
647+
assert artifact.name == "result"
648+
assert artifact.kind == V1ArtifactKind.ARTIFACT
649+
assert artifact.path == str(asset)
650+
assert artifact.summary == {"path": str(asset), "hash": hash_file(str(asset))}
651+
652+
653+
@pytest.mark.asyncio
654+
async def test_log_model_ref_updates_meta_and_awaits_lineage_api():
655+
patch_settings()
656+
sdk_client = AsyncPolyaxonClientMock()
657+
sdk_client.runs_v1.patch_run = AsyncMock(return_value=make_run())
658+
sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None)
659+
client = make_client(sdk_client)
660+
661+
await client.log_model_ref(
662+
path="models/model.pt",
663+
name="model",
664+
framework="pytorch",
665+
summary={"hash": "hash2"},
666+
rel_path="models/model.pt",
667+
)
668+
669+
assert client.run_data.meta_info == {"has_model": True}
670+
sdk_client.runs_v1.patch_run.assert_called_once()
671+
artifact = get_logged_lineage_artifact(sdk_client)
672+
assert artifact.name == "model"
673+
assert artifact.kind == V1ArtifactKind.MODEL
674+
assert artifact.path == "models/model.pt"
675+
assert artifact.summary == {
676+
"hash": "hash2",
677+
"framework": "pytorch",
678+
"path": "models/model.pt",
679+
}
680+
681+
682+
@pytest.mark.asyncio
683+
async def test_log_file_and_dir_refs_hash_paths_and_await_lineage_api(tmp_path):
684+
patch_settings()
685+
file_path = tmp_path / "file.txt"
686+
file_path.write_text("payload")
687+
dir_path = tmp_path / "outputs"
688+
dir_path.mkdir()
689+
(dir_path / "data.txt").write_text("data")
690+
sdk_client = AsyncPolyaxonClientMock()
691+
sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None)
692+
client = make_client(sdk_client)
693+
694+
await client.log_file_ref(str(file_path), is_input=True)
695+
await client.log_dir_ref(str(dir_path), name="outputs")
696+
697+
file_artifact = get_logged_lineage_artifact(sdk_client, 0)
698+
dir_artifact = get_logged_lineage_artifact(sdk_client, 1)
699+
assert file_artifact.name == "file"
700+
assert file_artifact.kind == V1ArtifactKind.FILE
701+
assert file_artifact.is_input is True
702+
assert file_artifact.summary == {
703+
"path": str(file_path),
704+
"hash": hash_file(str(file_path)),
705+
}
706+
assert dir_artifact.name == "outputs"
707+
assert dir_artifact.kind == V1ArtifactKind.DIR
708+
assert dir_artifact.summary["path"] == str(dir_path)
709+
assert dir_artifact.summary["hash"]
710+
711+
712+
@pytest.mark.asyncio
713+
async def test_log_tensorboard_ref_logs_once_and_sets_meta():
714+
patch_settings()
715+
sdk_client = AsyncPolyaxonClientMock()
716+
sdk_client.runs_v1.patch_run = AsyncMock(return_value=make_run())
717+
sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None)
718+
client = make_client(sdk_client)
719+
720+
await client.log_tensorboard_ref("tensorboard", rel_path="tensorboard")
721+
await client.log_tensorboard_ref("tensorboard", rel_path="tensorboard")
722+
723+
assert client.run_data.meta_info == {"has_tensorboard": True}
724+
assert sdk_client.runs_v1.patch_run.call_count == 1
725+
assert sdk_client.runs_v1.create_run_artifacts_lineage.call_count == 1
726+
artifact = get_logged_lineage_artifact(sdk_client)
727+
assert artifact.name == "tensorboard"
728+
assert artifact.kind == V1ArtifactKind.TENSORBOARD
729+
assert artifact.path == "tensorboard"
730+
assert artifact.summary == {"path": "tensorboard"}
731+
732+
584733
@pytest.mark.asyncio
585734
async def test_promote_methods_use_async_project_client_with_injected_client():
586735
patch_settings()
@@ -615,7 +764,6 @@ async def test_promote_methods_use_async_project_client_with_injected_client():
615764
("download_artifacts", ()),
616765
("persist_run", ("/tmp/run",)),
617766
("push_offline_run", ("/tmp/run",)),
618-
("log_file_ref", ("file.txt",)),
619767
("get_runs_as_hiplot", ()),
620768
],
621769
)

0 commit comments

Comments
 (0)