Skip to content

Commit 3628e0c

Browse files
authored
Merge pull request #745 from PolicyEngine/codex/h5-migration-pr1-contracts-partitioning
Add local H5 partitioning seam
2 parents 28fa8e2 + fa0346b commit 3628e0c

10 files changed

Lines changed: 283 additions & 61 deletions

File tree

changelog.d/745.changed.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Split the new `calibration.local_h5` contracts into themed request, input,
2+
validation, and result modules; extract test-only fixtures into dedicated
3+
fixture helpers; and tighten the new request boundary so construction logic
4+
stays outside the value objects.

modal_app/local_area.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
modal run modal_app/local_area.py --branch=main --num-workers=8
1212
"""
1313

14-
import heapq
1514
import json
1615
import os
1716
import subprocess
@@ -30,6 +29,9 @@
3029

3130
from modal_app.images import cpu_image as image # noqa: E402
3231
from modal_app.resilience import reconcile_run_dir_fingerprint # noqa: E402
32+
from policyengine_us_data.calibration.local_h5.partitioning import ( # noqa: E402
33+
partition_weighted_work_items,
34+
)
3335

3436
app = modal.App("policyengine-us-data-local-area")
3537

@@ -309,25 +311,12 @@ def partition_work(
309311
num_workers: int,
310312
completed: set,
311313
) -> List[List[Dict]]:
312-
"""Partition work items across N workers using LPT scheduling."""
313-
remaining = [
314-
item for item in work_items if f"{item['type']}:{item['id']}" not in completed
315-
]
316-
remaining.sort(key=lambda x: -x["weight"])
317-
318-
n_workers = min(num_workers, len(remaining))
319-
if n_workers == 0:
320-
return []
321-
322-
heap = [(0, i) for i in range(n_workers)]
323-
chunks = [[] for _ in range(n_workers)]
324-
325-
for item in remaining:
326-
load, idx = heapq.heappop(heap)
327-
chunks[idx].append(item)
328-
heapq.heappush(heap, (load + item["weight"], idx))
329-
330-
return [c for c in chunks if c]
314+
"""Compatibility wrapper over the extracted pure partitioning seam."""
315+
return partition_weighted_work_items(
316+
work_items=work_items,
317+
num_workers=num_workers,
318+
completed=completed,
319+
)
331320

332321

333322
def get_completed_from_volume(run_dir: Path) -> set:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Internal package for the incremental local H5 migration.
2+
3+
Modules in this package should land only when they become active runtime
4+
seams rather than speculative placeholders. The first migration slice
5+
introduces only ``partitioning.py``.
6+
"""
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Pure helpers for assigning weighted work items to worker chunks."""
2+
3+
from __future__ import annotations
4+
5+
import heapq
6+
from collections.abc import Mapping, Sequence
7+
from typing import Any
8+
9+
10+
def work_item_key(item: Mapping[str, Any]) -> str:
    """Return the stable completion key used by the current H5 workers.

    The key has the form ``"<type>:<id>"`` and is built from the item's
    ``"type"`` and ``"id"`` entries. A missing entry raises ``KeyError``.
    """
    return f"{item['type']}:{item['id']}"
14+
15+
16+
def partition_weighted_work_items(
    work_items: Sequence[Mapping[str, Any]],
    num_workers: int,
    completed: set[str] | None = None,
) -> list[list[Mapping[str, Any]]]:
    """Partition work items across workers using longest-processing-time first.

    Args:
        work_items: Items carrying ``"type"``, ``"id"`` and ``"weight"``
            entries.
        num_workers: Upper bound on the number of chunks to produce.
        completed: Completion keys (as produced by ``work_item_key``) whose
            items are excluded. ``None`` means nothing is excluded.

    Returns:
        The non-empty chunks, at most ``min(num_workers, remaining items)``
        of them. An empty list when ``num_workers <= 0`` or when no item
        remains after filtering.
    """
    if num_workers <= 0:
        return []

    completed = completed or set()
    remaining = [item for item in work_items if work_item_key(item) not in completed]
    # Heaviest first. The sort is stable, so equal weights keep input order.
    remaining.sort(key=lambda item: -item["weight"])

    n_workers = min(num_workers, len(remaining))
    if n_workers == 0:
        return []

    # Min-heap of (accumulated load, chunk index): popping yields the least
    # loaded chunk, and ties on load resolve toward the lowest index.
    heap: list[tuple[int | float, int]] = [(0, idx) for idx in range(n_workers)]
    chunks: list[list[Mapping[str, Any]]] = [[] for _ in range(n_workers)]

    for item in remaining:
        load, idx = heapq.heappop(heap)
        chunks[idx].append(item)
        heapq.heappush(heap, (load + item["weight"], idx))

    return [chunk for chunk in chunks if chunk]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Shared test helpers for unit calibration tests.
2+
3+
Fixture modules in this package hold setup code and reusable test-only
4+
helpers so individual test files stay focused on assertions.
5+
"""
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Fixture helpers for ``test_local_h5_partitioning.py``."""
2+
3+
from __future__ import annotations
4+
5+
import importlib.util
6+
import sys
7+
from pathlib import Path
8+
9+
__test__ = False
10+
11+
12+
def _load_partitioning_module():
    """Load the pure partitioning module directly from disk.

    The module is executed from its file path under the name
    ``local_h5_partitioning`` and registered in ``sys.modules`` under that
    name.

    Returns:
        The loaded module object.

    Raises:
        ImportError: If no import spec or loader can be built for the file.
    """
    # NOTE(review): parents[4] is taken to be the repository root; this
    # assumes the fixture file sits four directories below it.
    repo_root = Path(__file__).resolve().parents[4]
    module_path = (
        repo_root
        / "policyengine_us_data"
        / "calibration"
        / "local_h5"
        / "partitioning.py"
    )
    spec = importlib.util.spec_from_file_location(
        "local_h5_partitioning",
        module_path,
    )
    # Validate before use: module_from_spec(None) would otherwise fail with
    # an opaque AttributeError, and asserts are stripped under ``-O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load partitioning module from {module_path}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing, as in the importlib recipe for importing a
    # source file directly.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        if sys.modules.get(spec.name) is module:
            del sys.modules[spec.name]
        raise
    return module
33+
34+
35+
def flatten_chunks(chunks):
    """Flatten worker chunks into a single item list for assertions.

    Items keep their order: chunk by chunk, then position within each chunk.
    """
    return [item for chunk in chunks for item in chunk]
39+
40+
41+
def load_partitioning_exports():
    """Load the partitioning module and return its public exports.

    Returns:
        A dict with the loaded ``module``, the ``flatten_chunks`` test
        helper, and the module's ``partition_weighted_work_items`` and
        ``work_item_key`` callables.
    """
    module = _load_partitioning_module()
    return {
        "module": module,
        "flatten_chunks": flatten_chunks,
        "partition_weighted_work_items": module.partition_weighted_work_items,
        "work_item_key": module.work_item_key,
    }
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from tests.unit.calibration.fixtures.test_local_h5_partitioning import (
2+
load_partitioning_exports,
3+
)
4+
5+
6+
# Load the module under test once through the fixture helper and bind its
# exports to the module-level names used by the tests in this file.
partitioning = load_partitioning_exports()
flatten_chunks = partitioning["flatten_chunks"]
partition_weighted_work_items = partitioning["partition_weighted_work_items"]
work_item_key = partitioning["work_item_key"]
10+
11+
12+
def test_work_item_key_uses_existing_completion_shape():
    """The completion key is ``"<type>:<id>"``; the weight plays no part."""
    item = {"type": "district", "id": "CA-12", "weight": 1}
    assert work_item_key(item) == "district:CA-12"
15+
16+
17+
def test_partition_filters_completed_items():
    """Items whose completion key is in ``completed`` are never assigned."""
    work_items = [
        {"type": "state", "id": "CA", "weight": 3},
        {"type": "district", "id": "CA-12", "weight": 1},
        {"type": "city", "id": "NYC", "weight": 2},
    ]

    chunks = partition_weighted_work_items(
        work_items,
        num_workers=2,
        completed={"district:CA-12"},
    )

    flattened = flatten_chunks(chunks)
    assert all(item["id"] != "CA-12" for item in flattened)
    assert {item["id"] for item in flattened} == {"CA", "NYC"}
33+
34+
35+
def test_partition_returns_empty_for_zero_workers_or_zero_remaining():
    """No chunks come back without workers or without remaining items."""
    work_items = [{"type": "state", "id": "CA", "weight": 1}]

    assert partition_weighted_work_items(work_items, num_workers=0) == []
    # The implementation guards ``num_workers <= 0``, so negative counts are
    # pinned here alongside zero.
    assert partition_weighted_work_items(work_items, num_workers=-1) == []
    assert (
        partition_weighted_work_items(
            work_items,
            num_workers=3,
            completed={"state:CA"},
        )
        == []
    )
47+
48+
49+
def test_partition_uses_no_more_workers_than_remaining_items():
    """The chunk count is capped by the number of items left to assign."""
    work_items = [
        {"type": "state", "id": "CA", "weight": 5},
        {"type": "state", "id": "NY", "weight": 4},
    ]

    chunks = partition_weighted_work_items(work_items, num_workers=10)

    assert len(chunks) == 2
    assert all(len(chunk) == 1 for chunk in chunks)
59+
60+
61+
def test_partition_is_weight_balancing_and_deterministic_for_equal_weights():
    """Equal weights are assigned in input order and loads end up balanced."""
    work_items = [
        {"type": "district", "id": "A", "weight": 5},
        {"type": "district", "id": "B", "weight": 5},
        {"type": "district", "id": "C", "weight": 2},
        {"type": "district", "id": "D", "weight": 2},
    ]

    chunks = partition_weighted_work_items(work_items, num_workers=2)

    ids_by_chunk = [[item["id"] for item in chunk] for chunk in chunks]
    loads = [sum(item["weight"] for item in chunk) for chunk in chunks]

    # A and B seed chunks 0 and 1; C then D follow in the same order because
    # ties on load are broken by the lower chunk index.
    assert ids_by_chunk == [["A", "C"], ["B", "D"]]
    assert loads == [7, 7]

tests/unit/fixtures/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Shared test helpers for top-level unit tests."""
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Fixture helpers for `test_modal_local_area.py`."""
2+
3+
import importlib
4+
import sys
5+
from contextlib import contextmanager
6+
from types import ModuleType, SimpleNamespace
7+
8+
__test__ = False
9+
10+
11+
@contextmanager
def _patched_module_registry(overrides: dict[str, ModuleType]):
    """Temporarily replace selected `sys.modules` entries for one import.

    On entry every name in ``overrides`` is bound to its replacement module
    and any cached ``modal_app.local_area`` entry is dropped, so the next
    import of it re-executes the module. On exit each touched name, including
    ``modal_app.local_area``, is restored to its previous entry or removed if
    it had none.
    """
    # A unique marker distinguishes "name was absent" from any stored value.
    sentinel = object()
    previous = {
        name: sys.modules.get(name, sentinel)
        for name in [*overrides.keys(), "modal_app.local_area"]
    }

    try:
        for name, module in overrides.items():
            sys.modules[name] = module
        sys.modules.pop("modal_app.local_area", None)
        yield
    finally:
        for name, module in previous.items():
            if module is sentinel:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = module
32+
33+
34+
def load_local_area_module():
    """Import `modal_app.local_area` with scoped fake Modal dependencies.

    Stand-ins are installed for ``modal``, ``modal_app.images``,
    ``modal_app.resilience`` and the ``policyengine_us_data`` package chain
    down to ``calibration.local_h5.partitioning``. They are in place only
    while the import runs; the previous ``sys.modules`` entries are restored
    afterwards.

    Returns:
        The freshly imported ``modal_app.local_area`` module.
    """
    fake_modal = ModuleType("modal")
    fake_policyengine = ModuleType("policyengine_us_data")
    fake_calibration = ModuleType("policyengine_us_data.calibration")
    fake_local_h5 = ModuleType("policyengine_us_data.calibration.local_h5")
    fake_partitioning = ModuleType(
        "policyengine_us_data.calibration.local_h5.partitioning"
    )
    # An (empty) ``__path__`` marks each stand-in as a package so dotted
    # submodule imports are resolved through it.
    fake_policyengine.__path__ = []
    fake_calibration.__path__ = []
    fake_local_h5.__path__ = []

    class _FakeApp:
        """Stand-in for ``modal.App`` whose decorators return functions as-is."""

        def __init__(self, *args, **kwargs):
            pass

        def function(self, *args, **kwargs):
            def decorator(func):
                return func

            return decorator

        def local_entrypoint(self, *args, **kwargs):
            def decorator(func):
                return func

            return decorator

    fake_modal.App = _FakeApp
    fake_modal.Secret = SimpleNamespace(from_name=lambda *args, **kwargs: object())
    fake_modal.Volume = SimpleNamespace(from_name=lambda *args, **kwargs: object())

    fake_images = ModuleType("modal_app.images")
    fake_images.cpu_image = object()

    fake_resilience = ModuleType("modal_app.resilience")
    fake_resilience.reconcile_run_dir_fingerprint = lambda *args, **kwargs: None
    # Stub only: it always returns no chunks.
    fake_partitioning.partition_weighted_work_items = lambda *args, **kwargs: []

    with _patched_module_registry(
        {
            "modal": fake_modal,
            "modal_app.images": fake_images,
            "modal_app.resilience": fake_resilience,
            "policyengine_us_data": fake_policyengine,
            "policyengine_us_data.calibration": fake_calibration,
            "policyengine_us_data.calibration.local_h5": fake_local_h5,
            "policyengine_us_data.calibration.local_h5.partitioning": (
                fake_partitioning
            ),
        }
    ):
        return importlib.import_module("modal_app.local_area")

tests/unit/test_modal_local_area.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,8 @@
1-
import importlib
2-
import sys
3-
from types import ModuleType, SimpleNamespace
4-
5-
6-
def _load_local_area_module():
7-
fake_modal = ModuleType("modal")
8-
9-
class _FakeApp:
10-
def __init__(self, *args, **kwargs):
11-
pass
12-
13-
def function(self, *args, **kwargs):
14-
def decorator(func):
15-
return func
16-
17-
return decorator
18-
19-
def local_entrypoint(self, *args, **kwargs):
20-
def decorator(func):
21-
return func
22-
23-
return decorator
24-
25-
fake_modal.App = _FakeApp
26-
fake_modal.Secret = SimpleNamespace(from_name=lambda *args, **kwargs: object())
27-
fake_modal.Volume = SimpleNamespace(from_name=lambda *args, **kwargs: object())
28-
29-
fake_images = ModuleType("modal_app.images")
30-
fake_images.cpu_image = object()
31-
32-
fake_resilience = ModuleType("modal_app.resilience")
33-
fake_resilience.reconcile_run_dir_fingerprint = lambda *args, **kwargs: None
34-
35-
sys.modules["modal"] = fake_modal
36-
sys.modules["modal_app.images"] = fake_images
37-
sys.modules["modal_app.resilience"] = fake_resilience
38-
sys.modules.pop("modal_app.local_area", None)
39-
return importlib.import_module("modal_app.local_area")
1+
from tests.unit.fixtures.test_modal_local_area import load_local_area_module
402

413

424
def test_build_promote_national_publish_script_imports_version_manifest_helpers():
43-
local_area = _load_local_area_module()
5+
local_area = load_local_area_module()
446

457
script = local_area._build_promote_national_publish_script(
468
version="1.73.0",
@@ -55,7 +17,7 @@ def test_build_promote_national_publish_script_imports_version_manifest_helpers(
5517

5618

5719
def test_build_promote_publish_script_finalizes_complete_release():
58-
local_area = _load_local_area_module()
20+
local_area = load_local_area_module()
5921

6022
script = local_area._build_promote_publish_script(
6123
version="1.73.0",

0 commit comments

Comments
 (0)