Skip to content

Commit 323f1c8

Browse files
authored
Merge pull request #986 from PolicyEngine/codex/h5-migration-pr9-coordinator
Coordinate local H5 workers with catalog requests
2 parents a5e675f + bf0cb76 commit 323f1c8

24 files changed

Lines changed: 2443 additions & 205 deletions

changelog.d/985.added

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Coordinate local H5 publishing through typed catalog requests and normalize worker responses.

docs/engineering/stages/build_outputs.md

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ a postprocessor.
3434

3535
## Worker Chunk Execution
3636

37+
The Modal coordinator builds canonical typed area requests before spawning
38+
workers. Regional publish reads the target congressional district universe from
39+
the staged target database through `TargetUniverseReader`, reads only
40+
coordinator-needed CD and county geography fields through
41+
`CalibrationGeographyLoader.load_index()`, then asks `USAreaCatalog` to define
42+
the regional release shape: every configured state, every target congressional
43+
district, and the explicitly supported city outputs such as NYC. The
44+
coordinator wraps those requests in `WeightedAreaRequest`, partitions them with
45+
`partition_weighted_area_requests()`, and sends workers typed
46+
`--requests-json` payloads. Completion is measured against the explicit request
47+
keys, not just a raw file count, so stale or unrelated H5 files cannot satisfy a
48+
missing expected area.
49+
3750
`LocalH5WorkerService` is the reusable Stage 4 boundary for executing one
3851
prepared local-H5 worker chunk. It consumes a `WorkerSession`, typed
3952
`AreaBuildRequest` objects, and a `WorkerExecutionConfig`, then returns a
@@ -51,11 +64,18 @@ service. It may parse legacy `--work-items` and typed `--requests-json`, prepare
5164
the worker session, and print the legacy coordinator JSON shape, but it should
5265
not regain build-loop, write-loop, or validation-loop logic.
5366

67+
The legacy `--work-items` input path remains compatibility-only while older
68+
tests and explicit override callers are retired. New coordinator work should
69+
prefer typed `AreaBuildRequest` objects and typed worker payloads.
70+
5471
For now, `WorkerResult.to_legacy_dict()` preserves the existing coordinator
5572
contract with `completed`, `failed`, `errors`, `validation_rows`, and
5673
`validation_summary`. New code should prefer the structured `results` and
57-
`issues` fields. Validation exceptions remain visible in legacy `errors` so the
58-
current coordinator does not drop them before it migrates to structured results.
74+
`issues` fields. Validation exceptions remain visible as structured `issues`
75+
with `severity="validation"` by default; they enter legacy `errors` only when a
76+
worker is configured to fail on validation errors or the area result is otherwise
77+
failed. This keeps nonfatal validation diagnostics visible without treating them
78+
as coordinator-fatal worker failures.
5979
Removing the legacy shape and moving the coordinator off worker subprocess JSON
6080
is a later migration step.
6181

modal_app/fixtures/h5_cases.py

Lines changed: 150 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
COUNTY_FIPS = "37183"
2424
STATE_FIPS = 37
2525
N_CLONES = 1
26+
DEFAULT_CATALOG_N_CLONES = 2
2627
SEED = 42
2728
VERSION = "0.0.0"
2829

@@ -36,6 +37,8 @@ class SeededCase:
3637
expected_district_name: str = DISTRICT_NAME
3738
n_clones: int = N_CLONES
3839
seed: int = SEED
40+
expected_output_count: int | None = None
41+
expected_output_paths: tuple[str, ...] = ()
3942

4043

4144
def _sha256(path: Path) -> str:
@@ -68,20 +71,25 @@ def _write_dataset(artifact_dir: Path) -> Path:
6871
return dataset_path
6972

7073

71-
def _write_weights(artifact_dir: Path, *, n_records: int) -> Path:
74+
def _write_weights(
75+
artifact_dir: Path,
76+
*,
77+
n_records: int,
78+
n_clones: int = N_CLONES,
79+
) -> Path:
7280
import numpy as np
7381

7482
weights_path = artifact_dir / "calibration_weights.npy"
75-
np.save(weights_path, np.ones(n_records * N_CLONES, dtype=np.float32))
83+
np.save(weights_path, np.ones(n_records * n_clones, dtype=np.float32))
7684
return weights_path
7785

7886

79-
def _base_geography(*, n_records: int):
87+
def _base_geography(*, n_records: int, n_clones: int = N_CLONES):
8088
import numpy as np
8189

8290
from policyengine_us_data.calibration.clone_and_assign import GeographyAssignment
8391

84-
total_rows = n_records * N_CLONES
92+
total_rows = n_records * n_clones
8593
block_geoids = np.array(
8694
[f"{COUNTY_FIPS}{i:06d}{i:04d}"[:15] for i in range(total_rows)],
8795
dtype="U15",
@@ -92,21 +100,93 @@ def _base_geography(*, n_records: int):
92100
county_fips=np.full(total_rows, COUNTY_FIPS, dtype="U5"),
93101
state_fips=np.full(total_rows, STATE_FIPS, dtype=np.int32),
94102
n_records=n_records,
95-
n_clones=N_CLONES,
103+
n_clones=n_clones,
104+
)
105+
106+
107+
def _default_catalog_geography(*, n_records: int, n_clones: int):
108+
import numpy as np
109+
110+
from policyengine_us_data.calibration.calibration_utils import STATE_CODES
111+
from policyengine_us_data.calibration.clone_and_assign import GeographyAssignment
112+
113+
total_rows = n_records * n_clones
114+
state_universe = tuple(sorted(STATE_CODES))
115+
if total_rows < len(state_universe):
116+
raise ValueError(
117+
f"Need at least {len(state_universe)} clone rows to cover default states"
118+
)
119+
120+
state_fips = np.array(
121+
[state_universe[index % len(state_universe)] for index in range(total_rows)],
122+
dtype=np.int32,
123+
)
124+
county_fips = np.fromiter(
125+
(
126+
"37183"
127+
if state == STATE_FIPS
128+
else "36061"
129+
if state == 36
130+
else f"{state:02d}001"
131+
for state in state_fips
132+
),
133+
dtype="U5",
134+
count=total_rows,
135+
)
136+
cd_geoid = np.fromiter(
137+
(
138+
DISTRICT_GEOID
139+
if state == STATE_FIPS
140+
else "3601"
141+
if state == 36
142+
else f"{state}01"
143+
for state in state_fips
144+
),
145+
dtype="U4",
146+
count=total_rows,
147+
)
148+
block_geoid = np.array(
149+
[
150+
f"{county}{index:06d}{index:04d}"[:15]
151+
for index, county in enumerate(county_fips)
152+
],
153+
dtype="U15",
154+
)
155+
return GeographyAssignment(
156+
block_geoid=block_geoid,
157+
cd_geoid=cd_geoid,
158+
county_fips=county_fips,
159+
state_fips=state_fips,
160+
n_records=n_records,
161+
n_clones=n_clones,
96162
)
97163

98164

99-
def _write_saved_geography(artifact_dir: Path, *, n_records: int) -> Path:
165+
def _write_saved_geography(
166+
artifact_dir: Path,
167+
*,
168+
n_records: int,
169+
n_clones: int = N_CLONES,
170+
geography=None,
171+
) -> Path:
100172
from policyengine_us_data.calibration.clone_and_assign import save_geography
101173

102174
geography_path = artifact_dir / "geography_assignment.npz"
103-
save_geography(_base_geography(n_records=n_records), geography_path)
175+
save_geography(
176+
geography or _base_geography(n_records=n_records, n_clones=n_clones),
177+
geography_path,
178+
)
104179
return geography_path
105180

106181

107-
def _write_calibration_package(artifact_dir: Path, *, n_records: int) -> Path:
182+
def _write_calibration_package(
183+
artifact_dir: Path,
184+
*,
185+
n_records: int,
186+
n_clones: int = N_CLONES,
187+
) -> Path:
108188
package_path = artifact_dir / "calibration_package.pkl"
109-
geography = _base_geography(n_records=n_records)
189+
geography = _base_geography(n_records=n_records, n_clones=n_clones)
110190
payload = {
111191
"block_geoid": geography.block_geoid,
112192
"cd_geoid": geography.cd_geoid,
@@ -122,9 +202,14 @@ def _write_calibration_package(artifact_dir: Path, *, n_records: int) -> Path:
122202
return package_path
123203

124204

125-
def _write_misnamed_package(artifact_dir: Path, *, n_records: int) -> Path:
205+
def _write_misnamed_package(
206+
artifact_dir: Path,
207+
*,
208+
n_records: int,
209+
n_clones: int = N_CLONES,
210+
) -> Path:
126211
wrong_path = artifact_dir / "calibration_package_typo.pkl"
127-
geography = _base_geography(n_records=n_records)
212+
geography = _base_geography(n_records=n_records, n_clones=n_clones)
128213
payload = {
129214
"block_geoid": geography.block_geoid,
130215
"cd_geoid": geography.cd_geoid,
@@ -192,25 +277,71 @@ def seed_case(
192277
_reset_dir(staging_dir)
193278

194279
n_records = _fixture_n_households()
280+
n_clones = (
281+
DEFAULT_CATALOG_N_CLONES if case_name == "default_catalog_success" else N_CLONES
282+
)
195283
dataset_path = _write_dataset(artifact_dir)
196-
weights_path = _write_weights(artifact_dir, n_records=n_records)
284+
weights_path = _write_weights(
285+
artifact_dir,
286+
n_records=n_records,
287+
n_clones=n_clones,
288+
)
197289
db_path = _write_db(artifact_dir)
198290

199291
geography_path = None
200292
package_path = None
293+
expected_output_count = None
294+
expected_output_paths: tuple[str, ...] = ()
201295

202296
if case_name == "saved_geography_success":
203-
geography_path = _write_saved_geography(artifact_dir, n_records=n_records)
297+
geography_path = _write_saved_geography(
298+
artifact_dir,
299+
n_records=n_records,
300+
n_clones=n_clones,
301+
)
302+
_write_run_config(
303+
artifact_dir,
304+
weights_path=weights_path,
305+
geography_path=geography_path,
306+
)
307+
elif case_name == "default_catalog_success":
308+
from policyengine_us_data.calibration.calibration_utils import STATE_CODES
309+
310+
geography_path = _write_saved_geography(
311+
artifact_dir,
312+
n_records=n_records,
313+
n_clones=n_clones,
314+
geography=_default_catalog_geography(
315+
n_records=n_records,
316+
n_clones=n_clones,
317+
),
318+
)
319+
expected_output_count = len(STATE_CODES) + 2
320+
expected_output_paths = (
321+
"states/AL.h5",
322+
"states/NC.h5",
323+
"states/NY.h5",
324+
"districts/NC-01.h5",
325+
"cities/NYC.h5",
326+
)
204327
_write_run_config(
205328
artifact_dir,
206329
weights_path=weights_path,
207330
geography_path=geography_path,
208331
)
209332
elif case_name == "package_fallback_success":
210-
package_path = _write_calibration_package(artifact_dir, n_records=n_records)
333+
package_path = _write_calibration_package(
334+
artifact_dir,
335+
n_records=n_records,
336+
n_clones=n_clones,
337+
)
211338
_write_run_config(artifact_dir, weights_path=weights_path)
212339
elif case_name == "misnamed_package":
213-
_write_misnamed_package(artifact_dir, n_records=n_records)
340+
_write_misnamed_package(
341+
artifact_dir,
342+
n_records=n_records,
343+
n_clones=n_clones,
344+
)
214345
_write_run_config(artifact_dir, weights_path=weights_path)
215346
else:
216347
raise ValueError(f"Unknown H5 test case: {case_name}")
@@ -223,7 +354,10 @@ def seed_case(
223354
database_path=db_path,
224355
geography_path=geography_path,
225356
calibration_package_path=package_path,
226-
n_clones=N_CLONES,
357+
n_clones=n_clones,
227358
seed=SEED,
228359
).to_wire_dict(),
360+
n_clones=n_clones,
361+
expected_output_count=expected_output_count,
362+
expected_output_paths=expected_output_paths,
229363
)

modal_app/h5_test_harness.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def seed_h5_case(run_id: str, case_name: str) -> dict:
117117
"expected_district_name": seeded.expected_district_name,
118118
"n_clones": seeded.n_clones,
119119
"seed": seeded.seed,
120+
"expected_output_count": seeded.expected_output_count,
121+
"expected_output_paths": list(seeded.expected_output_paths),
120122
}
121123

122124

0 commit comments

Comments
 (0)