|
1 | 1 | import hashlib |
2 | 2 | from io import BytesIO |
| 3 | +from importlib import metadata |
3 | 4 | from pathlib import Path |
4 | 5 | from unittest.mock import MagicMock, patch |
5 | 6 |
|
| 7 | +import pytest |
6 | 8 | from huggingface_hub import CommitOperationAdd |
| 9 | +from huggingface_hub.errors import EntryNotFoundError |
7 | 10 |
|
8 | | -from policyengine_uk_data.utils.data_upload import upload_files_to_hf |
| 11 | +from policyengine_uk_data.utils.data_upload import ( |
| 12 | + _get_model_package_version, |
| 13 | + load_release_manifest_from_hf, |
| 14 | + upload_files_to_hf, |
| 15 | +) |
9 | 16 | from policyengine_uk_data.utils.release_manifest import ( |
10 | 17 | RELEASE_MANIFEST_SCHEMA_VERSION, |
11 | 18 | build_release_manifest, |
@@ -81,6 +88,89 @@ def test_build_release_manifest_tracks_uk_release_artifacts(tmp_path): |
81 | 88 | assert manifest["artifacts"]["local_authority_weights"]["kind"] == "weights" |
82 | 89 |
|
83 | 90 |
|
def test_build_release_manifest_refreshes_compatible_model_packages_for_draft_retry(
    tmp_path,
):
    """Rebuilding over an existing manifest re-pins the model package.

    The existing manifest for data version 1.40.4 pins ``policyengine-uk`` to
    ``==1.0.0``. Building again for the same data version with
    ``model_package_version="9.99.9"`` must produce a manifest whose
    ``compatible_model_packages`` entry carries ``==9.99.9`` instead.
    """
    h5_path = _write_file(tmp_path / "enhanced_frs_2023_24.h5", b"enhanced-frs")

    # Manifest left over from an earlier attempt at the same data version,
    # still carrying the old model package pin.
    stale_manifest = {
        "schema_version": RELEASE_MANIFEST_SCHEMA_VERSION,
        "data_package": {
            "name": "policyengine-uk-data",
            "version": "1.40.4",
        },
        "compatible_model_packages": [
            {
                "name": "policyengine-uk",
                "specifier": "==1.0.0",
            }
        ],
        "default_datasets": {},
        "created_at": "2026-04-10T12:00:00Z",
        "artifacts": {},
    }

    rebuilt = build_release_manifest(
        files_with_repo_paths=[(h5_path, "enhanced_frs_2023_24.h5")],
        version="1.40.4",
        repo_id="policyengine/policyengine-uk-data-private",
        model_package_version="9.99.9",
        existing_manifest=stale_manifest,
    )

    expected_packages = [{"name": "policyengine-uk", "specifier": "==9.99.9"}]
    assert rebuilt["compatible_model_packages"] == expected_packages
| 125 | + |
| 126 | + |
def test_load_release_manifest_from_hf_raises_non_missing_download_errors():
    """A ``RuntimeError`` raised by the download must reach the caller.

    ``hf_hub_download`` is replaced with a mock that raises
    ``RuntimeError("boom")``; the loader is expected to let it propagate.
    """
    failing_download = patch(
        "policyengine_uk_data.utils.data_upload.hf_hub_download",
        side_effect=RuntimeError("boom"),
    )
    with failing_download, pytest.raises(RuntimeError, match="boom"):
        load_release_manifest_from_hf(version="1.40.4")
| 134 | + |
| 135 | + |
def test_load_release_manifest_from_hf_continues_on_missing_entry(tmp_path):
    """An ``EntryNotFoundError`` on one download does not stop the loader.

    The mocked ``hf_hub_download`` raises ``EntryNotFoundError`` on its first
    call and returns the path of a manifest file written under ``tmp_path``
    on its second; the loader must return that manifest's contents.
    """
    manifest_file = tmp_path / "release_manifest.json"
    manifest_file.write_text('{"data_package": {"version": "1.40.4"}}')

    # Consumed in order by the mock: first call raises, second call succeeds.
    download_outcomes = [EntryNotFoundError("missing"), str(manifest_file)]

    with patch(
        "policyengine_uk_data.utils.data_upload.hf_hub_download",
        side_effect=download_outcomes,
    ):
        loaded = load_release_manifest_from_hf(version="1.40.4")

    assert loaded["data_package"]["version"] == "1.40.4"
| 150 | + |
| 151 | + |
def test_get_model_package_version_prefers_imported_checkout(tmp_path):
    """The version is taken from the ``pyproject.toml`` beside the package.

    A fake checkout is laid out under ``tmp_path``: a ``policyengine_uk``
    package directory and a sibling ``pyproject.toml`` declaring version
    2.78.0. ``find_spec`` is mocked to point at that package and
    ``metadata.version`` is mocked to raise ``PackageNotFoundError``; the
    helper must still return ``"2.78.0"``.
    """
    pkg_dir = tmp_path / "policyengine_uk"
    pkg_dir.mkdir()
    init_file = pkg_dir / "__init__.py"
    init_file.write_text("")
    (tmp_path / "pyproject.toml").write_text(
        '[project]\nname = "policyengine-uk"\nversion = "2.78.0"\n'
    )
    checkout_spec = MagicMock(origin=str(init_file))

    find_spec_patch = patch(
        "policyengine_uk_data.utils.data_upload.find_spec",
        return_value=checkout_spec,
    )
    # No installed distribution metadata is available for the lookup.
    metadata_patch = patch(
        "policyengine_uk_data.utils.data_upload.metadata.version",
        side_effect=metadata.PackageNotFoundError,
    )
    with find_spec_patch, metadata_patch:
        assert _get_model_package_version() == "2.78.0"
| 172 | + |
| 173 | + |
84 | 174 | def test_upload_files_to_hf_adds_uk_release_manifest_operations(tmp_path): |
85 | 175 | dataset_path = _write_file( |
86 | 176 | tmp_path / "enhanced_frs_2023_24.h5", |
|
0 commit comments