2323COUNTY_FIPS = "37183"
2424STATE_FIPS = 37
2525N_CLONES = 1
26+ DEFAULT_CATALOG_N_CLONES = 2
2627SEED = 42
2728VERSION = "0.0.0"
2829
@@ -36,6 +37,8 @@ class SeededCase:
3637 expected_district_name : str = DISTRICT_NAME
3738 n_clones : int = N_CLONES
3839 seed : int = SEED
40+ expected_output_count : int | None = None
41+ expected_output_paths : tuple [str , ...] = ()
3942
4043
4144def _sha256 (path : Path ) -> str :
@@ -68,20 +71,25 @@ def _write_dataset(artifact_dir: Path) -> Path:
6871 return dataset_path
6972
7073
def _write_weights(
    artifact_dir: Path,
    *,
    n_records: int,
    n_clones: int = N_CLONES,
) -> Path:
    """Save a vector of unit calibration weights and return where it was written.

    Args:
        artifact_dir: Directory that receives ``calibration_weights.npy``.
        n_records: Number of base records.
        n_clones: Number of clones per record; the vector holds
            ``n_records * n_clones`` entries.

    Returns:
        Path of the saved ``.npy`` file.
    """
    import numpy as np

    # One float32 weight of 1.0 per (record, clone) row.
    total_rows = n_records * n_clones
    unit_weights = np.ones(total_rows, dtype=np.float32)

    destination = artifact_dir / "calibration_weights.npy"
    np.save(destination, unit_weights)
    return destination
7785
7886
79- def _base_geography (* , n_records : int ):
87+ def _base_geography (* , n_records : int , n_clones : int = N_CLONES ):
8088 import numpy as np
8189
8290 from policyengine_us_data .calibration .clone_and_assign import GeographyAssignment
8391
84- total_rows = n_records * N_CLONES
92+ total_rows = n_records * n_clones
8593 block_geoids = np .array (
8694 [f"{ COUNTY_FIPS } { i :06d} { i :04d} " [:15 ] for i in range (total_rows )],
8795 dtype = "U15" ,
@@ -92,21 +100,93 @@ def _base_geography(*, n_records: int):
92100 county_fips = np .full (total_rows , COUNTY_FIPS , dtype = "U5" ),
93101 state_fips = np .full (total_rows , STATE_FIPS , dtype = np .int32 ),
94102 n_records = n_records ,
95- n_clones = N_CLONES ,
103+ n_clones = n_clones ,
104+ )
105+
106+
def _default_catalog_geography(*, n_records: int, n_clones: int):
    """Build a geography assignment whose rows cycle through every state code.

    Row ``i`` receives the ``i % len(states)``-th entry of the sorted
    ``STATE_CODES`` iteration, so each state appears at least once. The
    fixture's home state (``STATE_FIPS``) and state 36 get dedicated county
    and district identifiers; every other state gets ``<state>001`` as its
    county and ``<state>01`` as its district.

    Args:
        n_records: Number of base records.
        n_clones: Number of clones per record.

    Returns:
        A ``GeographyAssignment`` with ``n_records * n_clones`` rows.

    Raises:
        ValueError: If ``n_records * n_clones`` is smaller than the number of
            state codes, so not every state could be covered.
    """
    import numpy as np

    from policyengine_us_data.calibration.calibration_utils import STATE_CODES
    from policyengine_us_data.calibration.clone_and_assign import GeographyAssignment

    total_rows = n_records * n_clones
    # NOTE(review): presumably STATE_CODES is keyed by integer state FIPS --
    # the int32 cast and ``:02d`` formatting below rely on that.
    state_universe = tuple(sorted(STATE_CODES))
    if total_rows < len(state_universe):
        raise ValueError(
            f"Need at least {len(state_universe)} clone rows to cover default states"
        )

    def _county_for(state) -> str:
        # Home state reuses the module-level county constant so this fixture
        # cannot drift from the one used by the base geography.
        if state == STATE_FIPS:
            return COUNTY_FIPS
        if state == 36:
            return "36061"
        return f"{state:02d}001"

    def _district_for(state) -> str:
        if state == STATE_FIPS:
            return DISTRICT_GEOID
        if state == 36:
            return "3601"
        # District ids are intentionally not zero-padded on the state part.
        return f"{state}01"

    state_fips = np.array(
        [state_universe[index % len(state_universe)] for index in range(total_rows)],
        dtype=np.int32,
    )
    county_fips = np.fromiter(
        (_county_for(state) for state in state_fips),
        dtype="U5",
        count=total_rows,
    )
    cd_geoid = np.fromiter(
        (_district_for(state) for state in state_fips),
        dtype="U4",
        count=total_rows,
    )
    # County prefix plus two zero-padded copies of the row index, clipped to
    # the 15 characters the array dtype holds.
    block_geoid = np.array(
        [
            f"{county}{index:06d}{index:04d}"[:15]
            for index, county in enumerate(county_fips)
        ],
        dtype="U15",
    )
    return GeographyAssignment(
        block_geoid=block_geoid,
        cd_geoid=cd_geoid,
        county_fips=county_fips,
        state_fips=state_fips,
        n_records=n_records,
        n_clones=n_clones,
    )
97163
98164
def _write_saved_geography(
    artifact_dir: Path,
    *,
    n_records: int,
    n_clones: int = N_CLONES,
    geography=None,
) -> Path:
    """Persist a geography assignment under ``artifact_dir`` and return its path.

    Args:
        artifact_dir: Directory that receives ``geography_assignment.npz``.
        n_records: Number of base records, used when building the default
            geography.
        n_clones: Number of clones per record, used when building the default
            geography.
        geography: Geography to save. When ``None``, the result of
            ``_base_geography(n_records=..., n_clones=...)`` is saved instead.

    Returns:
        Path of the saved geography file.
    """
    from policyengine_us_data.calibration.clone_and_assign import save_geography

    # Compare against None explicitly: a caller-supplied geography must never
    # be swapped for the default just because the object evaluates as falsy.
    if geography is None:
        geography = _base_geography(n_records=n_records, n_clones=n_clones)

    geography_path = artifact_dir / "geography_assignment.npz"
    save_geography(geography, geography_path)
    return geography_path
105180
106181
107- def _write_calibration_package (artifact_dir : Path , * , n_records : int ) -> Path :
182+ def _write_calibration_package (
183+ artifact_dir : Path ,
184+ * ,
185+ n_records : int ,
186+ n_clones : int = N_CLONES ,
187+ ) -> Path :
108188 package_path = artifact_dir / "calibration_package.pkl"
109- geography = _base_geography (n_records = n_records )
189+ geography = _base_geography (n_records = n_records , n_clones = n_clones )
110190 payload = {
111191 "block_geoid" : geography .block_geoid ,
112192 "cd_geoid" : geography .cd_geoid ,
@@ -122,9 +202,14 @@ def _write_calibration_package(artifact_dir: Path, *, n_records: int) -> Path:
122202 return package_path
123203
124204
125- def _write_misnamed_package (artifact_dir : Path , * , n_records : int ) -> Path :
205+ def _write_misnamed_package (
206+ artifact_dir : Path ,
207+ * ,
208+ n_records : int ,
209+ n_clones : int = N_CLONES ,
210+ ) -> Path :
126211 wrong_path = artifact_dir / "calibration_package_typo.pkl"
127- geography = _base_geography (n_records = n_records )
212+ geography = _base_geography (n_records = n_records , n_clones = n_clones )
128213 payload = {
129214 "block_geoid" : geography .block_geoid ,
130215 "cd_geoid" : geography .cd_geoid ,
@@ -192,25 +277,71 @@ def seed_case(
192277 _reset_dir (staging_dir )
193278
194279 n_records = _fixture_n_households ()
280+ n_clones = (
281+ DEFAULT_CATALOG_N_CLONES if case_name == "default_catalog_success" else N_CLONES
282+ )
195283 dataset_path = _write_dataset (artifact_dir )
196- weights_path = _write_weights (artifact_dir , n_records = n_records )
284+ weights_path = _write_weights (
285+ artifact_dir ,
286+ n_records = n_records ,
287+ n_clones = n_clones ,
288+ )
197289 db_path = _write_db (artifact_dir )
198290
199291 geography_path = None
200292 package_path = None
293+ expected_output_count = None
294+ expected_output_paths : tuple [str , ...] = ()
201295
202296 if case_name == "saved_geography_success" :
203- geography_path = _write_saved_geography (artifact_dir , n_records = n_records )
297+ geography_path = _write_saved_geography (
298+ artifact_dir ,
299+ n_records = n_records ,
300+ n_clones = n_clones ,
301+ )
302+ _write_run_config (
303+ artifact_dir ,
304+ weights_path = weights_path ,
305+ geography_path = geography_path ,
306+ )
307+ elif case_name == "default_catalog_success" :
308+ from policyengine_us_data .calibration .calibration_utils import STATE_CODES
309+
310+ geography_path = _write_saved_geography (
311+ artifact_dir ,
312+ n_records = n_records ,
313+ n_clones = n_clones ,
314+ geography = _default_catalog_geography (
315+ n_records = n_records ,
316+ n_clones = n_clones ,
317+ ),
318+ )
319+ expected_output_count = len (STATE_CODES ) + 2
320+ expected_output_paths = (
321+ "states/AL.h5" ,
322+ "states/NC.h5" ,
323+ "states/NY.h5" ,
324+ "districts/NC-01.h5" ,
325+ "cities/NYC.h5" ,
326+ )
204327 _write_run_config (
205328 artifact_dir ,
206329 weights_path = weights_path ,
207330 geography_path = geography_path ,
208331 )
209332 elif case_name == "package_fallback_success" :
210- package_path = _write_calibration_package (artifact_dir , n_records = n_records )
333+ package_path = _write_calibration_package (
334+ artifact_dir ,
335+ n_records = n_records ,
336+ n_clones = n_clones ,
337+ )
211338 _write_run_config (artifact_dir , weights_path = weights_path )
212339 elif case_name == "misnamed_package" :
213- _write_misnamed_package (artifact_dir , n_records = n_records )
340+ _write_misnamed_package (
341+ artifact_dir ,
342+ n_records = n_records ,
343+ n_clones = n_clones ,
344+ )
214345 _write_run_config (artifact_dir , weights_path = weights_path )
215346 else :
216347 raise ValueError (f"Unknown H5 test case: { case_name } " )
@@ -223,7 +354,10 @@ def seed_case(
223354 database_path = db_path ,
224355 geography_path = geography_path ,
225356 calibration_package_path = package_path ,
226- n_clones = N_CLONES ,
357+ n_clones = n_clones ,
227358 seed = SEED ,
228359 ).to_wire_dict (),
360+ n_clones = n_clones ,
361+ expected_output_count = expected_output_count ,
362+ expected_output_paths = expected_output_paths ,
229363 )
0 commit comments