Skip to content

Commit b5d9a27

Browse files
authored
Merge pull request #388 from PolicyEngine/feat/district-breakdowns
Orchestrate state results into US national results with congressional district breakdowns
2 parents b4e317d + a5a7545 commit b5d9a27

14 files changed

Lines changed: 1423 additions & 35 deletions

File tree

.github/workflows/pr.yml

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,27 +91,27 @@ jobs:
9191
name: Test integration
9292
needs: docker-build # Only run if docker builds succeed
9393
runs-on: ubuntu-latest
94-
94+
9595
steps:
9696
- uses: actions/checkout@v4
97-
97+
9898
- name: Set up Python
9999
uses: actions/setup-python@v5
100100
with:
101101
python-version: '3.13'
102-
102+
103103
- name: Install uv
104104
uses: astral-sh/setup-uv@v3
105105
with:
106106
enable-cache: true
107-
107+
108108
- name: Set up Docker Buildx
109109
uses: docker/setup-buildx-action@v3
110-
110+
111111
- name: Generate API clients
112112
run: |
113113
./scripts/generate-clients.sh
114-
114+
115115
- name: Start services
116116
run: |
117117
docker compose -f deployment/docker-compose.yml up -d
@@ -126,22 +126,30 @@ jobs:
126126
echo "Waiting for services... (attempt $i/30)"
127127
sleep 2
128128
done
129-
130-
- name: Run integration tests
129+
130+
- name: Run integration tests (local services only)
131131
run: |
132132
cd projects/policyengine-apis-integ
133133
uv sync --extra test
134-
# Run tests that don't require GCP credentials
135-
# Modal staging tests run against deployed staging environment
136-
simulation_integ_test_base_url="https://policyengine-staging--policyengine-simulation-gateway-web-app.modal.run" \
137-
uv run pytest tests/ -v -m "not requires_gcp"
138-
134+
# Run tests against local Docker services only (not Modal staging)
135+
uv run pytest tests/ -v -m "not requires_gcp and not beta_only"
136+
139137
- name: Show service logs on failure
140138
if: failure()
141139
run: |
142140
docker compose -f deployment/docker-compose.yml logs
143-
141+
144142
- name: Stop services
145143
if: always()
146144
run: |
147-
docker compose -f deployment/docker-compose.yml down
145+
docker compose -f deployment/docker-compose.yml down
146+
147+
# Deploy to Modal staging and run Modal-specific integration tests
148+
test-modal-integration:
149+
name: Test Modal integration
150+
needs: docker-build
151+
uses: ./.github/workflows/modal-deploy.reusable.yml
152+
with:
153+
environment: beta
154+
modal_environment: staging
155+
secrets: inherit

projects/policyengine-api-simulation/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ dependencies = [
1616
"pydantic-settings (>=2.7.1,<3.0.0)",
1717
"opentelemetry-instrumentation-fastapi (>=0.51b0,<0.52)",
1818
"policyengine-fastapi",
19-
"policyengine==0.9.0",
19+
"policyengine==0.10.1",
2020
"policyengine-uk>=2.22.8",
2121
"policyengine-us>=1.370.2",
2222
"tables>=3.10.2",
2323
"modal>=0.73.0",
24+
"logfire>=3.0.0",
2425
]
2526

2627
[tool.hatch.build.targets.wheel]

projects/policyengine-api-simulation/src/modal/app.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_app_name(us_version: str, uk_version: str) -> str:
4747
.pip_install(
4848
f"policyengine-us=={US_VERSION}",
4949
f"policyengine-uk=={UK_VERSION}",
50-
"policyengine==0.8.1",
50+
"policyengine==0.10.0",
5151
"tables>=3.10.2",
5252
"logfire",
5353
)
@@ -104,3 +104,53 @@ def run_simulation(params: dict) -> dict:
104104
return result
105105
finally:
106106
logfire.force_flush()
107+
108+
109+
@app.function(
    image=simulation_image,
    cpu=2.0,
    memory=4096,
    timeout=7200,  # 2 hours to wait for all 52 jobs
    retries=0,
    secrets=[gcp_secret, logfire_secret],
)
def run_national_with_breakdowns(params: dict) -> dict:
    """
    Orchestrate parallel simulations and aggregate results.

    Delegates to run_national_orchestration, which spawns:
    - 1 national ECPS simulation (region="us")
    - State-level simulations (all states, or only TEST_STATE_CODES when
      _test_mode is truthy in params)

    Each job is spawned through run_simulation.

    Args:
        params: Simulation parameters. The optional "_test_mode" key is
            consumed here and is not forwarded to the orchestration. The
            caller's dict is left untouched.

    Returns:
        Combined national results with congressional district breakdowns,
        as produced by run_national_orchestration.
    """
    import logfire

    from src.modal.orchestration import run_national_orchestration
    from src.modal.utils.state_codes import TEST_STATE_CODES

    configure_logfire()

    # Work on a shallow copy so stripping the internal flag does not mutate
    # the dict owned by the caller (matters for local/in-process invocation).
    params = dict(params)

    # Check for test mode; None tells the orchestration to run every state.
    test_mode = params.pop("_test_mode", False)
    state_codes = TEST_STATE_CODES if test_mode else None

    try:
        with logfire.span(
            "run_national_with_breakdowns",
            input_params=params,
            test_mode=test_mode,
        ) as span:
            result = run_national_orchestration(params, run_simulation, state_codes)
            span.set_attribute(
                "total_districts",
                len(
                    result.get("congressional_district_impact", {}).get("districts", [])
                ),
            )
            return result
    finally:
        # Flush telemetry on both the success and the failure path.
        logfire.force_flush()

projects/policyengine-api-simulation/src/modal/gateway/endpoints.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def get_app_name(country: str, version: Optional[str]) -> tuple[str, str]:
5050
return app_name, resolved_version
5151

5252

53+
NATIONAL_WITH_BREAKDOWNS = "national-with-breakdowns"
54+
NATIONAL_WITH_BREAKDOWNS_TEST = "national-with-breakdowns-test"
55+
56+
5357
@router.post("/simulate/economy/comparison", response_model=JobSubmitResponse)
5458
async def submit_simulation(request: SimulationRequest):
5559
"""
@@ -58,19 +62,59 @@ async def submit_simulation(request: SimulationRequest):
5862
Matches the existing Cloud Run API endpoint path.
5963
Routes to the appropriate app based on country and version params.
6064
Returns immediately with job_id for polling.
65+
66+
Special handling for data="national-with-breakdowns":
67+
- Only supported for country="us"
68+
- Spawns 52 parallel simulations (1 national + 51 states)
69+
- Returns aggregated results with congressional district breakdowns
70+
71+
Special handling for data="national-with-breakdowns-test":
72+
- Only supported for country="us"
73+
- Spawns 11 parallel simulations (1 national + 10 test states)
74+
- Returns aggregated results with congressional district breakdowns
6175
"""
6276
try:
6377
app_name, resolved_version = get_app_name(request.country, request.version)
6478
except ValueError as e:
6579
raise HTTPException(status_code=400, detail=str(e))
6680

67-
logger.info(f"Routing {request.country}:{resolved_version} to app {app_name}")
81+
# Check for national-with-breakdowns special cases
82+
payload = request.model_dump(exclude={"version"})
83+
data_value = payload.get("data")
84+
is_national_breakdowns = data_value in (
85+
NATIONAL_WITH_BREAKDOWNS,
86+
NATIONAL_WITH_BREAKDOWNS_TEST,
87+
)
88+
89+
if is_national_breakdowns:
90+
if request.country.lower() != "us":
91+
raise HTTPException(
92+
status_code=400,
93+
detail="national-with-breakdowns is only supported for country='us'",
94+
)
95+
96+
# Add test_mode flag to payload for orchestration to use
97+
if data_value == NATIONAL_WITH_BREAKDOWNS_TEST:
98+
payload["_test_mode"] = True
99+
logger.info(
100+
f"Routing {request.country}:{resolved_version} to {app_name} "
101+
f"(national-with-breakdowns-test orchestration - 10 states)"
102+
)
103+
else:
104+
logger.info(
105+
f"Routing {request.country}:{resolved_version} to {app_name} "
106+
f"(national-with-breakdowns orchestration - all states)"
107+
)
108+
109+
func_name = "run_national_with_breakdowns"
110+
else:
111+
logger.info(f"Routing {request.country}:{resolved_version} to app {app_name}")
112+
func_name = "run_simulation"
68113

69114
# Get function reference from the target app
70-
sim_func = modal.Function.from_name(app_name, "run_simulation")
115+
sim_func = modal.Function.from_name(app_name, func_name)
71116

72117
# Spawn the job (returns immediately)
73-
payload = request.model_dump(exclude={"version"})
74118
call = sim_func.spawn(payload)
75119

76120
return JobSubmitResponse(
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""
2+
Orchestration logic for national-with-breakdowns simulations.
3+
4+
This module handles spawning 52 parallel simulations (1 national + 51 states)
5+
and aggregating the results into a single response with congressional district breakdowns.
6+
"""
7+
8+
import logfire
9+
from typing import Any, Callable
10+
11+
from src.modal.utils.state_codes import STATE_CODES, TEST_STATE_CODES
12+
13+
# Re-export for backwards compatibility
14+
__all__ = ["STATE_CODES", "TEST_STATE_CODES", "run_national_orchestration"]
15+
16+
17+
def run_national_orchestration(
    params: dict,
    run_simulation: Callable,
    state_codes: list[str] | None = None,
) -> dict:
    """
    Orchestrate parallel simulations and aggregate results.

    Spawns:
    - 1 national ECPS simulation (region="us")
    - 1 state-level simulation per entry in state_codes (or per entry in
      STATE_CODES if state_codes is None)

    Every job is spawned via run_simulation.spawn() before any result is
    awaited.

    Failure handling:
    - If the national simulation fails, its exception propagates
    - If ALL requested states fail, RuntimeError is raised
    - If SOME states fail, the request succeeds; the districts of the failed
      states are simply absent from the result (no placeholders are added)
      and the states are listed under "failed_states"
    - If no states were requested (empty state_codes), the national result
      is returned with an empty district list

    Args:
        params: Base simulation parameters (reform, baseline, time_period, etc.).
            The "data" key, if present, is not forwarded to the spawned jobs.
        run_simulation: The Modal function to spawn for each simulation
        state_codes: Optional list of state codes to run. If None, runs all
            of STATE_CODES.

    Returns:
        Shallow copy of the national result whose
        "congressional_district_impact" entry is replaced by a dict with
        "districts" (all collected districts), "failed_states" (list, or
        None when nothing failed) and "successful_states".

    Raises:
        RuntimeError: If at least one state was requested and every state
            simulation failed.
    """
    states_to_run = state_codes if state_codes is not None else STATE_CODES

    # Prepare base params (remove the special data flag). With no "data" key
    # each job is expected to fall back to its default dataset for the region.
    base_params = {k: v for k, v in params.items() if k != "data"}

    # 1. Spawn national ECPS simulation
    logfire.info("Spawning national ECPS simulation")
    national_call = run_simulation.spawn({**base_params, "region": "us"})

    # 2. Spawn state simulations
    logfire.info("Spawning state-level simulations", state_count=len(states_to_run))
    state_calls: dict[str, Any] = {
        state_code: run_simulation.spawn(
            {**base_params, "region": f"state/{state_code.lower()}"}
        )
        for state_code in states_to_run
    }

    logfire.info(
        "All simulations spawned, waiting for results",
        total_jobs=len(states_to_run) + 1,
    )

    # 3. Wait for national result first; a failure here aborts the request.
    logfire.info("Waiting for national ECPS result")
    national_result = national_call.get()
    logfire.info("National ECPS simulation complete")

    # 4. Wait for all state results and extract district data
    all_districts: list[dict] = []
    failed_states: list[str] = []
    successful_states: list[str] = []

    for state_code in states_to_run:
        logfire.info("Waiting for state result", state_code=state_code)

        try:
            state_result = state_calls[state_code].get()

            # Extract congressional_district_impact.districts from state result
            district_impact = state_result.get("congressional_district_impact", {})
            districts = district_impact.get("districts", [])
            logfire.info(
                "State result received",
                state_code=state_code,
                districts_extracted=len(districts),
            )
            all_districts.extend(districts)
            successful_states.append(state_code)

        except Exception as e:
            # Deliberate best-effort: one failing state (remote error or a
            # malformed result) must not sink the whole national request.
            logfire.warn(
                "State simulation failed",
                state_code=state_code,
                error=str(e)[:200],
            )
            failed_states.append(state_code)
            # The number of districts in a failed state is unknown here, so
            # no placeholders are added; those districts are just missing.

    logfire.info(
        "State simulations complete",
        successful_count=len(successful_states),
        failed_count=len(failed_states),
    )

    # 5. Check if ALL states failed. The truthiness guard keeps an empty
    # request (nothing to run, nothing failed) from being reported as a
    # total failure.
    if states_to_run and len(failed_states) == len(states_to_run):
        raise RuntimeError(
            f"All {len(states_to_run)} state simulations failed. "
            f"Failed states: {failed_states}"
        )

    if failed_states:
        logfire.warn("Some states failed", failed_states=failed_states)

    logfire.info("Total districts collected", total_districts=len(all_districts))

    # 6. Merge: national result + aggregated districts + metadata
    final_result = national_result.copy()
    final_result["congressional_district_impact"] = {
        "districts": all_districts,
        "failed_states": failed_states if failed_states else None,
        "successful_states": successful_states,
    }

    return final_result
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
"""
22
Utility functions for Modal deployment.
33
"""
4+
5+
from src.modal.utils.state_codes import STATE_CODES, TEST_STATE_CODES
6+
7+
__all__ = ["STATE_CODES", "TEST_STATE_CODES"]

0 commit comments

Comments
 (0)