|
2 | 2 | from mock import mock |
3 | 3 | import pytest |
4 | 4 |
|
| 5 | +from clipped.utils.hashing import hash_file, hash_value |
5 | 6 | from polyaxon import settings |
6 | 7 | from polyaxon._client.run import AsyncRunClient, RunClient |
7 | 8 | from polyaxon._schemas.lifecycle import ( |
@@ -147,6 +148,13 @@ def make_run(**kwargs): |
147 | 148 | return V1Run.model_construct(**data) |
148 | 149 |
|
149 | 150 |
|
| 151 | +def get_logged_lineage_artifact(sdk_client, index=0): |
| 152 | + body = sdk_client.runs_v1.create_run_artifacts_lineage.call_args_list[index][1][ |
| 153 | + "body" |
| 154 | + ] |
| 155 | + return body.artifacts[0] |
| 156 | + |
| 157 | + |
150 | 158 | def test_async_run_client_public_export(): |
151 | 159 | from polyaxon.client import AsyncRunClient as Exported |
152 | 160 |
|
@@ -581,6 +589,147 @@ async def test_log_artifact_lineage_and_run_edges_await_api(): |
581 | 589 | assert "async_req" not in sdk_client.runs_v1.set_run_edges_lineage.call_args[1] |
582 | 590 |
|
583 | 591 |
|
| 592 | +@pytest.mark.asyncio |
| 593 | +async def test_log_code_ref_detects_code_ref_in_thread(monkeypatch): |
| 594 | + patch_settings() |
| 595 | + monkeypatch.setattr( |
| 596 | + "polyaxon._client.run.get_code_reference", |
| 597 | + lambda: {"commit": "abc123", "branch": "main"}, |
| 598 | + ) |
| 599 | + sdk_client = AsyncPolyaxonClientMock() |
| 600 | + sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None) |
| 601 | + client = make_client(sdk_client) |
| 602 | + |
| 603 | + await client.log_code_ref(is_input=False) |
| 604 | + |
| 605 | + artifact = get_logged_lineage_artifact(sdk_client) |
| 606 | + assert artifact.name == "abc123" |
| 607 | + assert artifact.kind == V1ArtifactKind.CODEREF |
| 608 | + assert artifact.summary == {"commit": "abc123", "branch": "main"} |
| 609 | + assert artifact.is_input is False |
| 610 | + |
| 611 | + |
| 612 | +@pytest.mark.asyncio |
| 613 | +async def test_log_data_ref_hashes_content_and_awaits_lineage_api(): |
| 614 | + patch_settings() |
| 615 | + sdk_client = AsyncPolyaxonClientMock() |
| 616 | + sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None) |
| 617 | + client = make_client(sdk_client) |
| 618 | + |
| 619 | + await client.log_data_ref( |
| 620 | + name="dataset", |
| 621 | + content={"x": 1}, |
| 622 | + summary={"rows": 10}, |
| 623 | + ) |
| 624 | + |
| 625 | + artifact = get_logged_lineage_artifact(sdk_client) |
| 626 | + assert artifact.name == "dataset" |
| 627 | + assert artifact.kind == V1ArtifactKind.DATA |
| 628 | + assert artifact.summary == {"rows": 10, "hash": hash_value({"x": 1})} |
| 629 | + assert artifact.is_input is True |
| 630 | + |
| 631 | + |
| 632 | +@pytest.mark.asyncio |
| 633 | +async def test_log_artifact_ref_hashes_existing_file(tmp_path): |
| 634 | + patch_settings() |
| 635 | + asset = tmp_path / "result.json" |
| 636 | + asset.write_text("payload") |
| 637 | + sdk_client = AsyncPolyaxonClientMock() |
| 638 | + sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None) |
| 639 | + client = make_client(sdk_client) |
| 640 | + |
| 641 | + await client.log_artifact_ref( |
| 642 | + path=str(asset), |
| 643 | + kind=V1ArtifactKind.ARTIFACT, |
| 644 | + ) |
| 645 | + |
| 646 | + artifact = get_logged_lineage_artifact(sdk_client) |
| 647 | + assert artifact.name == "result" |
| 648 | + assert artifact.kind == V1ArtifactKind.ARTIFACT |
| 649 | + assert artifact.path == str(asset) |
| 650 | + assert artifact.summary == {"path": str(asset), "hash": hash_file(str(asset))} |
| 651 | + |
| 652 | + |
| 653 | +@pytest.mark.asyncio |
| 654 | +async def test_log_model_ref_updates_meta_and_awaits_lineage_api(): |
| 655 | + patch_settings() |
| 656 | + sdk_client = AsyncPolyaxonClientMock() |
| 657 | + sdk_client.runs_v1.patch_run = AsyncMock(return_value=make_run()) |
| 658 | + sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None) |
| 659 | + client = make_client(sdk_client) |
| 660 | + |
| 661 | + await client.log_model_ref( |
| 662 | + path="models/model.pt", |
| 663 | + name="model", |
| 664 | + framework="pytorch", |
| 665 | + summary={"hash": "hash2"}, |
| 666 | + rel_path="models/model.pt", |
| 667 | + ) |
| 668 | + |
| 669 | + assert client.run_data.meta_info == {"has_model": True} |
| 670 | + sdk_client.runs_v1.patch_run.assert_called_once() |
| 671 | + artifact = get_logged_lineage_artifact(sdk_client) |
| 672 | + assert artifact.name == "model" |
| 673 | + assert artifact.kind == V1ArtifactKind.MODEL |
| 674 | + assert artifact.path == "models/model.pt" |
| 675 | + assert artifact.summary == { |
| 676 | + "hash": "hash2", |
| 677 | + "framework": "pytorch", |
| 678 | + "path": "models/model.pt", |
| 679 | + } |
| 680 | + |
| 681 | + |
| 682 | +@pytest.mark.asyncio |
| 683 | +async def test_log_file_and_dir_refs_hash_paths_and_await_lineage_api(tmp_path): |
| 684 | + patch_settings() |
| 685 | + file_path = tmp_path / "file.txt" |
| 686 | + file_path.write_text("payload") |
| 687 | + dir_path = tmp_path / "outputs" |
| 688 | + dir_path.mkdir() |
| 689 | + (dir_path / "data.txt").write_text("data") |
| 690 | + sdk_client = AsyncPolyaxonClientMock() |
| 691 | + sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None) |
| 692 | + client = make_client(sdk_client) |
| 693 | + |
| 694 | + await client.log_file_ref(str(file_path), is_input=True) |
| 695 | + await client.log_dir_ref(str(dir_path), name="outputs") |
| 696 | + |
| 697 | + file_artifact = get_logged_lineage_artifact(sdk_client, 0) |
| 698 | + dir_artifact = get_logged_lineage_artifact(sdk_client, 1) |
| 699 | + assert file_artifact.name == "file" |
| 700 | + assert file_artifact.kind == V1ArtifactKind.FILE |
| 701 | + assert file_artifact.is_input is True |
| 702 | + assert file_artifact.summary == { |
| 703 | + "path": str(file_path), |
| 704 | + "hash": hash_file(str(file_path)), |
| 705 | + } |
| 706 | + assert dir_artifact.name == "outputs" |
| 707 | + assert dir_artifact.kind == V1ArtifactKind.DIR |
| 708 | + assert dir_artifact.summary["path"] == str(dir_path) |
| 709 | + assert dir_artifact.summary["hash"] |
| 710 | + |
| 711 | + |
| 712 | +@pytest.mark.asyncio |
| 713 | +async def test_log_tensorboard_ref_logs_once_and_sets_meta(): |
| 714 | + patch_settings() |
| 715 | + sdk_client = AsyncPolyaxonClientMock() |
| 716 | + sdk_client.runs_v1.patch_run = AsyncMock(return_value=make_run()) |
| 717 | + sdk_client.runs_v1.create_run_artifacts_lineage = AsyncMock(return_value=None) |
| 718 | + client = make_client(sdk_client) |
| 719 | + |
| 720 | + await client.log_tensorboard_ref("tensorboard", rel_path="tensorboard") |
| 721 | + await client.log_tensorboard_ref("tensorboard", rel_path="tensorboard") |
| 722 | + |
| 723 | + assert client.run_data.meta_info == {"has_tensorboard": True} |
| 724 | + assert sdk_client.runs_v1.patch_run.call_count == 1 |
| 725 | + assert sdk_client.runs_v1.create_run_artifacts_lineage.call_count == 1 |
| 726 | + artifact = get_logged_lineage_artifact(sdk_client) |
| 727 | + assert artifact.name == "tensorboard" |
| 728 | + assert artifact.kind == V1ArtifactKind.TENSORBOARD |
| 729 | + assert artifact.path == "tensorboard" |
| 730 | + assert artifact.summary == {"path": "tensorboard"} |
| 731 | + |
| 732 | + |
584 | 733 | @pytest.mark.asyncio |
585 | 734 | async def test_promote_methods_use_async_project_client_with_injected_client(): |
586 | 735 | patch_settings() |
@@ -615,7 +764,6 @@ async def test_promote_methods_use_async_project_client_with_injected_client(): |
615 | 764 | ("download_artifacts", ()), |
616 | 765 | ("persist_run", ("/tmp/run",)), |
617 | 766 | ("push_offline_run", ("/tmp/run",)), |
618 | | - ("log_file_ref", ("file.txt",)), |
619 | 767 | ("get_runs_as_hiplot", ()), |
620 | 768 | ], |
621 | 769 | ) |
|
0 commit comments