Skip to content
This repository was archived by the owner on Jun 14, 2026. It is now read-only.

Commit 2337fdf

Browse files
authored
Wire calibration harness to target providers (#7)
1 parent 1f014d8 commit 2337fdf

4 files changed

Lines changed: 187 additions & 5 deletions

File tree

.github/workflows/site-snapshot.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ jobs:
4141
run: |
4242
uv run --extra dev --with pydantic --with-editable ../microplex pytest -q \
4343
tests/test_package_imports.py \
44+
tests/test_calibration_harness.py \
4445
tests/targets/test_supabase.py \
4546
tests/pipelines/test_check_site_snapshot.py \
4647
tests/pipelines/test_imputation_ablation.py \

src/microplex_us/calibration_harness.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass
6+
from typing import Any
67

78
import numpy as np
89
import pandas as pd
@@ -11,6 +12,8 @@
1112
FilterOperator,
1213
TargetAggregation,
1314
TargetFilter,
15+
TargetProvider,
16+
TargetQuery,
1417
TargetSpec,
1518
)
1619

@@ -19,7 +22,10 @@
1922
TargetLevel,
2023
TargetRegistry,
2124
get_registry,
25+
target_available_in_cps,
2226
target_category,
27+
target_group_name,
28+
target_level,
2329
target_requires_imputation,
2430
)
2531

@@ -59,10 +65,50 @@ def summary(self) -> str:
5965
class CalibrationHarness:
6066
"""Harness for calibration experiments over one entity frame at a time."""
6167

62-
def __init__(self, registry: TargetRegistry | None = None):
63-
self.registry = registry or get_registry()
68+
def __init__(
69+
self,
70+
registry: TargetRegistry | None = None,
71+
*,
72+
target_provider: TargetProvider | None = None,
73+
):
74+
if target_provider is None:
75+
self.registry = registry or get_registry()
76+
self.target_provider = self.registry
77+
else:
78+
self.registry = registry
79+
self.target_provider = target_provider
6480
self._results: dict[str, CalibrationResult] = {}
6581

82+
def select_targets(
83+
self,
84+
*,
85+
categories: list[TargetCategory] | None = None,
86+
levels: list[TargetLevel] | None = None,
87+
groups: list[str] | None = None,
88+
only_available: bool = False,
89+
entity: EntityType | str | None = None,
90+
period: int | str | None = None,
91+
provider_filters: dict[str, Any] | None = None,
92+
) -> list[TargetSpec]:
93+
"""Select canonical targets from the configured provider."""
94+
query = TargetQuery(
95+
period=period,
96+
entity=entity,
97+
provider_filters=dict(provider_filters or {}),
98+
)
99+
targets = self.target_provider.load_target_set(query).targets
100+
return [
101+
target
102+
for target in targets
103+
if _matches_us_target_filters(
104+
target,
105+
categories=categories,
106+
levels=levels,
107+
groups=groups,
108+
only_available=only_available,
109+
)
110+
]
111+
66112
def get_target_vector(
67113
self,
68114
df: pd.DataFrame,
@@ -202,15 +248,19 @@ def run_experiment(
202248
groups: list[str] | None = None,
203249
only_available: bool = False,
204250
entity: EntityType | str | None = None,
251+
period: int | str | None = None,
252+
provider_filters: dict[str, Any] | None = None,
205253
**calibrate_kwargs,
206254
) -> CalibrationResult:
207255
"""Run a calibration experiment over a filtered target subset."""
208-
selected = self.registry.select_targets(
256+
selected = self.select_targets(
209257
categories=categories,
210258
levels=levels,
211259
groups=groups,
212260
only_available=only_available,
213261
entity=entity,
262+
period=period,
263+
provider_filters=provider_filters,
214264
)
215265
selected = [
216266
target
@@ -262,7 +312,7 @@ def print_target_coverage(
262312
print("TARGET COVERAGE ANALYSIS")
263313
print("=" * 70)
264314

265-
all_targets = self.registry.select_targets(entity=entity)
315+
all_targets = self.select_targets(entity=entity)
266316
columns = set(df.columns)
267317

268318
available: list[TargetSpec] = []
@@ -388,6 +438,25 @@ def _weight_stats(weights: np.ndarray) -> dict[str, float]:
388438
}
389439

390440

441+
def _matches_us_target_filters(
442+
target: TargetSpec,
443+
*,
444+
categories: list[TargetCategory] | None = None,
445+
levels: list[TargetLevel] | None = None,
446+
groups: list[str] | None = None,
447+
only_available: bool = False,
448+
) -> bool:
449+
if categories and target_category(target) not in categories:
450+
return False
451+
if levels and target_level(target) not in levels:
452+
return False
453+
if groups and target_group_name(target) not in groups:
454+
return False
455+
if only_available and not target_available_in_cps(target):
456+
return False
457+
return True
458+
459+
391460
def _build_constraint_row(df: pd.DataFrame, spec: TargetSpec) -> np.ndarray:
392461
if spec.aggregation is TargetAggregation.MEAN:
393462
raise NotImplementedError("Mean targets are not supported by this harness")

tests/targets/test_supabase.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from dataclasses import dataclass
66
from typing import Any
77

8+
import pandas as pd
89
import pytest
910
from microplex.core import EntityType
1011
from microplex.targets import FilterOperator, TargetAggregation, TargetQuery
1112

13+
from microplex_us.calibration_harness import CalibrationHarness
1214
from microplex_us.supabase_targets import (
1315
SUPABASE_SUPPORTED_BY_COLUMN_MAP_KEY,
1416
SUPABASE_TARGET_TYPE_KEY,
@@ -393,3 +395,52 @@ def test_load_target_set_filters_rows_with_core_query(
393395

394396
assert [target.name for target in target_set.targets] == ["employment_income"]
395397
assert calls[0]["params"]["period"] == "eq.2024"
398+
399+
400+
def test_calibration_harness_can_use_supabase_target_provider(
401+
provider: SupabaseTargetProvider,
402+
request_queue,
403+
) -> None:
404+
request_queue(
405+
[
406+
{
407+
"id": "target-1",
408+
"variable": "employment_income",
409+
"value": 30,
410+
"target_type": "amount",
411+
"period": 2024,
412+
"source": {"name": "IRS SOI", "institution": "IRS"},
413+
"stratum": {"name": "National", "jurisdiction": "us"},
414+
},
415+
{
416+
"id": "target-2",
417+
"variable": "unknown_cash_income",
418+
"value": 100,
419+
"target_type": "amount",
420+
"period": 2024,
421+
"source": {"name": "Unknown", "institution": "Other"},
422+
"stratum": {"name": "National", "jurisdiction": "us"},
423+
},
424+
]
425+
)
426+
harness = CalibrationHarness(target_provider=provider)
427+
frame = pd.DataFrame(
428+
{
429+
"employment_income": [10.0, 20.0],
430+
"weight": [1.0, 1.0],
431+
}
432+
)
433+
434+
result = harness.run_experiment(
435+
frame,
436+
"supabase_income",
437+
categories=[TargetCategory.INCOME],
438+
only_available=True,
439+
period=2024,
440+
provider_filters={"include_unsupported": False},
441+
entity=EntityType.PERSON,
442+
verbose=False,
443+
)
444+
445+
assert result.targets_used == ["employment_income"]
446+
assert result.errors == {"employment_income": 0.0}

tests/test_calibration_harness.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import numpy as np
44
import pandas as pd
55
from microplex.core import EntityType
6-
from microplex.targets import TargetAggregation, TargetFilter, TargetSpec
6+
from microplex.targets import (
7+
StaticTargetProvider,
8+
TargetAggregation,
9+
TargetFilter,
10+
TargetSet,
11+
TargetSpec,
12+
)
713

814
from microplex_us.calibration_harness import CalibrationHarness
915
from microplex_us.target_registry import (
@@ -101,3 +107,58 @@ def test_run_experiment_filters_to_selected_canonical_targets(self):
101107

102108
assert result.targets_used == ["ca_people", "ca_income"]
103109
np.testing.assert_allclose(result.weights, np.ones(3))
110+
111+
def test_run_experiment_can_use_core_target_provider(self):
112+
targets = _make_registry().get_all_targets() + [
113+
TargetSpec(
114+
name="future_income",
115+
entity=EntityType.PERSON,
116+
value=50.0,
117+
period=2025,
118+
measure="employment_income",
119+
aggregation=TargetAggregation.SUM,
120+
metadata={
121+
"us_category": "income",
122+
"us_level": "national",
123+
"us_group": "future",
124+
"available_in_cps": True,
125+
"requires_imputation": False,
126+
},
127+
)
128+
]
129+
provider = StaticTargetProvider(TargetSet(targets))
130+
harness = CalibrationHarness(target_provider=provider)
131+
df = pd.DataFrame(
132+
{
133+
"state_fips": ["06", "06", "08"],
134+
"employment_income": [10.0, 20.0, 5.0],
135+
"weight": [1.0, 1.0, 1.0],
136+
}
137+
)
138+
139+
result = harness.run_experiment(
140+
df,
141+
"provider_people_only",
142+
groups=["people"],
143+
only_available=True,
144+
period=2024,
145+
entity=EntityType.PERSON,
146+
verbose=False,
147+
)
148+
149+
assert result.targets_used == ["ca_people", "ca_income"]
150+
151+
def test_print_target_coverage_can_use_core_target_provider(self, capsys):
152+
provider = StaticTargetProvider(TargetSet(_make_registry().get_all_targets()))
153+
harness = CalibrationHarness(target_provider=provider)
154+
df = pd.DataFrame(
155+
{
156+
"state_fips": ["06", "06", "08"],
157+
"employment_income": [10.0, 20.0, 5.0],
158+
"weight": [1.0, 1.0, 1.0],
159+
}
160+
)
161+
162+
harness.print_target_coverage(df, entity=EntityType.PERSON)
163+
164+
assert "Available (2 targets)" in capsys.readouterr().out

0 commit comments

Comments
 (0)