22
33import pytest
44
5+ TEST_APP_RELEASE_BUNDLE = {
6+ "app_name" : "policyengine-simulation-py4-10-0" ,
7+ "policyengine_version" : "4.10.0" ,
8+ "us" : {
9+ "model_version" : "1.500.0" ,
10+ "data_version" : "1.110.12" ,
11+ "default_dataset" : "enhanced_cps_2024" ,
12+ "default_dataset_uri" : "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" ,
13+ "dataset_uris" : {
14+ "enhanced_cps_2024" : "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.110.12" ,
15+ "cps_2023" : "hf://policyengine/policyengine-us-data/cps_2023.h5@1.110.12" ,
16+ "pooled_3_year_cps_2023" : "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.110.12" ,
17+ },
18+ "dataset_aliases" : {
19+ "enhanced_cps" : "enhanced_cps_2024" ,
20+ "enhanced_cps_2024" : "enhanced_cps_2024" ,
21+ "cps" : "cps_2023" ,
22+ "cps_2023" : "cps_2023" ,
23+ "pooled_cps" : "pooled_3_year_cps_2023" ,
24+ "pooled_3_year_cps_2023" : "pooled_3_year_cps_2023" ,
25+ },
26+ },
27+ "uk" : {
28+ "model_version" : "2.66.0" ,
29+ "data_version" : "1.40.3" ,
30+ "default_dataset" : "enhanced_frs_2023_24" ,
31+ "default_dataset_uri" : "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3" ,
32+ "dataset_uris" : {
33+ "enhanced_frs_2023_24" : "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3" ,
34+ "frs_2023_24" : "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3" ,
35+ },
36+ "dataset_aliases" : {
37+ "enhanced_frs" : "enhanced_frs_2023_24" ,
38+ "enhanced_frs_2023_24" : "enhanced_frs_2023_24" ,
39+ "frs" : "frs_2023_24" ,
40+ "frs_2023_24" : "frs_2023_24" ,
41+ },
42+ },
43+ }
44+
45+ TEST_APP_NAMES = (
46+ "policyengine-simulation-py4-10-0" ,
47+ "policyengine-simulation-py3-9-0" ,
48+ )
49+
50+
51+ def resolve_test_dataset_uri (country : str , dataset : str | None ) -> str | None :
52+ if dataset is None :
53+ return None
54+ if "://" in dataset :
55+ return dataset
56+ country_bundle = TEST_APP_RELEASE_BUNDLE [country ]
57+ dataset_name , revision = (
58+ dataset .rsplit ("@" , maxsplit = 1 ) if "@" in dataset else (dataset , None )
59+ )
60+ dataset_name = country_bundle ["dataset_aliases" ].get (dataset_name , dataset_name )
61+ dataset_uri = country_bundle ["dataset_uris" ].get (dataset_name , dataset_name )
62+ if revision is not None and dataset_uri == dataset_name :
63+ return dataset
64+ if revision is not None and dataset_uri .startswith ("hf://" ):
65+ dataset_uri = f"{ dataset_uri .rsplit ('@' , maxsplit = 1 )[0 ]} @{ revision } "
66+ return dataset_uri
67+
568
669class MockDict :
770 """Mock for Modal.Dict to simulate version registry."""
@@ -107,7 +170,11 @@ def mock_modal(monkeypatch):
107170 from src .modal .gateway import endpoints
108171
109172 mock_func = MockFunction ()
110- mock_dicts = {}
173+ mock_dicts = {
174+ "simulation-api-app-release-bundles" : {
175+ app_name : TEST_APP_RELEASE_BUNDLE for app_name in TEST_APP_NAMES
176+ }
177+ }
111178 MockFunctionCall .registry = {}
112179 MockFunctionCall .from_id_errors = {}
113180
@@ -134,6 +201,20 @@ class MockModal:
134201
135202 monkeypatch .setattr (endpoints , "modal" , MockModal )
136203 monkeypatch .setattr (budget_window_state , "modal" , MockModal )
204+ monkeypatch .setattr (
205+ endpoints ,
206+ "with_hf_revision" ,
207+ lambda dataset_uri , revision : (
208+ f"{ dataset_uri .rsplit ('@' , maxsplit = 1 )[0 ]} @{ revision } "
209+ if dataset_uri .startswith ("hf://" )
210+ else dataset_uri
211+ ),
212+ )
213+ monkeypatch .setattr (
214+ endpoints ,
215+ "validate_hf_dataset_uri" ,
216+ lambda dataset_uri : dataset_uri ,
217+ )
137218
138219 return {
139220 "func" : mock_func ,
0 commit comments