Skip to content

Commit 3d1ab93

Browse files
MaxGhenisclaude
andcommitted
Add calibrate-on-synthesizer experiment script
Tests whether MicrocalibrateAdapter on top of a weak synthesizer recovers weighted aggregate accuracy. Stage-1 PRDC measured un-weighted coverage; the actual production pipeline is synthesize -> calibrate, so a method that produces biased samples may still produce accurate WEIGHTED aggregates after calibration. Procedure for each method: 1. Fit synthesizer on train, generate synthetic with unit weights. 2. Rescale initial weights so synth totals match holdout-scale (moves gradient descent's starting point close to the target). 3. Build per-target-column sum LinearConstraints with holdout totals. 4. Run MicrocalibrateAdapter. 5. Report pre- and post-calibration relative error per target. Usage: uv run python scripts/calibrate_on_synthesizer.py --n-rows 20000 Interpretation: - If post-cal error converges to near-zero across methods, choice of synthesizer matters less than PRDC alone suggested. The weights carry the accuracy signal. - If ZI-MAF / ZI-QDNN can't be calibrated (gradient descent diverges or leaves huge residuals), the PRDC verdict stands and the synthesizer choice is load-bearing. Output: artifacts/calibrate_on_synthesizer.json with per-target pre/post errors, calibration wall time, weight distribution summary. Not run tonight — deferred to Max's morning after the ZI-MAF tuning job completes (both would contend for CPU otherwise). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent cef213b commit 3d1ab93

1 file changed

Lines changed: 266 additions & 0 deletions

File tree

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
"""Measure whether `microcalibrate` on top of a synthesizer rescues weak synthesis.
2+
3+
Stage-1 PRDC coverage compared synthesizers with uniform unit weights. The
4+
actual production pipeline is synthesize → calibrate. If calibration can
5+
pull a weak synthesizer's weighted aggregates onto the real targets, the
6+
choice of synthesizer matters less than PRDC alone would suggest.
7+
8+
Procedure:
9+
10+
1. Load enhanced_cps_2024 (`ScaleUpRunner.load_frame`), split 80/20.
11+
2. For each method (ZI-QRF / ZI-MAF / ZI-QDNN):
12+
a. Fit method, generate synthetic records with uniform weights.
13+
b. Compute holdout aggregates for each target column
14+
(total, count-of-nonzero).
15+
c. Build `LinearConstraint`s that require the weighted synthetic
16+
aggregates to match the holdout aggregates.
17+
d. Run `MicrocalibrateAdapter.fit_transform`.
18+
e. Report per-target relative error pre- and post-calibration.
19+
20+
Usage:
21+
uv run python scripts/calibrate_on_synthesizer.py --n-rows 20000
22+
23+
~10 minutes on a 48 GB M3 for 20k × 50 × 3 methods.
24+
"""
25+
26+
from __future__ import annotations
27+
28+
import argparse
29+
import json
30+
import logging
31+
import time
32+
from pathlib import Path
33+
34+
import numpy as np
35+
import pandas as pd
36+
from microplex.calibration import LinearConstraint
37+
from microplex.eval.benchmark import ZIMAFMethod, ZIQDNNMethod, ZIQRFMethod
38+
39+
from microplex_us.bakeoff import (
40+
DEFAULT_CONDITION_COLS,
41+
DEFAULT_TARGET_COLS,
42+
ScaleUpRunner,
43+
ScaleUpStageConfig,
44+
stage1_config,
45+
)
46+
from microplex_us.calibration import (
47+
MicrocalibrateAdapter,
48+
MicrocalibrateAdapterConfig,
49+
)
50+
51+
LOGGER = logging.getLogger(__name__)
52+
53+
METHOD_REGISTRY = {
54+
"ZI-QRF": ZIQRFMethod,
55+
"ZI-MAF": ZIMAFMethod,
56+
"ZI-QDNN": ZIQDNNMethod,
57+
}
58+
59+
60+
def build_target_constraints(
61+
holdout: pd.DataFrame,
62+
synthetic: pd.DataFrame,
63+
target_cols: tuple[str, ...],
64+
) -> tuple[LinearConstraint, ...]:
65+
"""One total-sum constraint per target column.
66+
67+
Target = sum of `holdout[col]`; coefficients = `synthetic[col].values`.
68+
After calibration, `(weights * coefficients).sum()` should match target.
69+
"""
70+
constraints: list[LinearConstraint] = []
71+
for col in target_cols:
72+
if col not in synthetic.columns or col not in holdout.columns:
73+
continue
74+
target = float(holdout[col].sum())
75+
coefs = synthetic[col].to_numpy(dtype=float)
76+
constraints.append(
77+
LinearConstraint(
78+
name=f"sum_{col}",
79+
coefficients=coefs,
80+
target=target,
81+
)
82+
)
83+
return tuple(constraints)
84+
85+
86+
def evaluate_aggregates(
87+
holdout: pd.DataFrame,
88+
synthetic: pd.DataFrame,
89+
weights: np.ndarray,
90+
target_cols: tuple[str, ...],
91+
) -> dict[str, dict[str, float]]:
92+
"""Per-target: real total, weighted-synth total, relative error."""
93+
out: dict[str, dict[str, float]] = {}
94+
for col in target_cols:
95+
if col not in synthetic.columns or col not in holdout.columns:
96+
continue
97+
real_total = float(holdout[col].sum())
98+
synth_weighted = float((synthetic[col].to_numpy(dtype=float) * weights).sum())
99+
rel_err = abs(synth_weighted - real_total) / max(abs(real_total), 1.0)
100+
out[col] = {
101+
"real_total": real_total,
102+
"weighted_synth_total": synth_weighted,
103+
"relative_error": rel_err,
104+
}
105+
return out
106+
107+
108+
def main(argv: list[str] | None = None) -> int:
109+
parser = argparse.ArgumentParser(description=__doc__)
110+
parser.add_argument("--n-rows", type=int, default=20_000)
111+
parser.add_argument(
112+
"--methods", nargs="+", default=["ZI-QRF", "ZI-MAF", "ZI-QDNN"]
113+
)
114+
parser.add_argument("--calibration-epochs", type=int, default=100)
115+
parser.add_argument(
116+
"--output",
117+
type=Path,
118+
default=Path("artifacts/calibrate_on_synthesizer.json"),
119+
)
120+
parser.add_argument("--seed", type=int, default=42)
121+
args = parser.parse_args(argv)
122+
123+
logging.basicConfig(
124+
level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
125+
)
126+
127+
base = stage1_config()
128+
cfg = ScaleUpStageConfig(
129+
stage="calibrate_on_synth",
130+
n_rows=args.n_rows,
131+
methods=tuple(args.methods),
132+
condition_cols=DEFAULT_CONDITION_COLS,
133+
target_cols=DEFAULT_TARGET_COLS,
134+
holdout_frac=0.2,
135+
seed=args.seed,
136+
k=5,
137+
data_path=base.data_path,
138+
year=base.year,
139+
rare_cell_checks=(),
140+
prdc_max_samples=15_000,
141+
)
142+
runner = ScaleUpRunner(cfg)
143+
df = runner.load_frame()
144+
train, holdout = runner.split(df)
145+
LOGGER.info(
146+
"loaded %d rows; train=%d holdout=%d", len(df), len(train), len(holdout)
147+
)
148+
149+
results = []
150+
for method_name in args.methods:
151+
LOGGER.info("== %s ==", method_name)
152+
if method_name not in METHOD_REGISTRY:
153+
LOGGER.warning("unknown method %r, skipping", method_name)
154+
continue
155+
method = METHOD_REGISTRY[method_name]()
156+
t0 = time.time()
157+
method.fit(sources={"ecps": train.copy()}, shared_cols=list(DEFAULT_CONDITION_COLS))
158+
fit_s = time.time() - t0
159+
160+
t0 = time.time()
161+
synthetic = method.generate(len(train), seed=args.seed)
162+
gen_s = time.time() - t0
163+
LOGGER.info(" fit=%.1fs gen=%.1fs n_synth=%d", fit_s, gen_s, len(synthetic))
164+
165+
constraints = build_target_constraints(
166+
holdout, synthetic, DEFAULT_TARGET_COLS
167+
)
168+
LOGGER.info(" %d calibration constraints", len(constraints))
169+
170+
synthetic = synthetic.copy()
171+
synthetic["weight"] = 1.0
172+
173+
# Rescale initial weights so synth totals sum to holdout-scale before
174+
# calibration. Otherwise gradient descent has to travel a long way.
175+
for col in DEFAULT_TARGET_COLS:
176+
if col not in holdout.columns or col not in synthetic.columns:
177+
continue
178+
r_sum = float(holdout[col].sum())
179+
s_sum = float(synthetic[col].sum())
180+
if r_sum > 0 and s_sum > 0:
181+
synthetic["weight"] = synthetic["weight"] * (r_sum / s_sum)
182+
break
183+
184+
pre_weights = synthetic["weight"].to_numpy(dtype=float)
185+
pre = evaluate_aggregates(holdout, synthetic, pre_weights, DEFAULT_TARGET_COLS)
186+
187+
adapter = MicrocalibrateAdapter(
188+
MicrocalibrateAdapterConfig(
189+
epochs=args.calibration_epochs,
190+
learning_rate=1e-3,
191+
noise_level=0.0,
192+
seed=args.seed,
193+
)
194+
)
195+
t0 = time.time()
196+
calibrated = adapter.fit_transform(
197+
synthetic,
198+
marginal_targets={},
199+
weight_col="weight",
200+
linear_constraints=constraints,
201+
)
202+
cal_s = time.time() - t0
203+
204+
post_weights = calibrated["weight"].to_numpy(dtype=float)
205+
post = evaluate_aggregates(
206+
holdout, calibrated, post_weights, DEFAULT_TARGET_COLS
207+
)
208+
validation = adapter.validate()
209+
210+
pre_mean_err = float(
211+
np.mean([v["relative_error"] for v in pre.values()])
212+
)
213+
post_mean_err = float(
214+
np.mean([v["relative_error"] for v in post.values()])
215+
)
216+
LOGGER.info(
217+
" pre-cal mean rel err = %.4f; post-cal mean rel err = %.4f; cal=%.1fs",
218+
pre_mean_err,
219+
post_mean_err,
220+
cal_s,
221+
)
222+
223+
results.append(
224+
{
225+
"method": method_name,
226+
"n_train": int(len(train)),
227+
"n_holdout": int(len(holdout)),
228+
"n_synthetic": int(len(synthetic)),
229+
"n_constraints": int(len(constraints)),
230+
"fit_wall_seconds": fit_s,
231+
"generate_wall_seconds": gen_s,
232+
"calibration_wall_seconds": cal_s,
233+
"pre_cal_mean_rel_err": pre_mean_err,
234+
"post_cal_mean_rel_err": post_mean_err,
235+
"calibration_max_error": validation["max_error"],
236+
"calibration_converged": validation["converged"],
237+
"pre_cal_per_target": pre,
238+
"post_cal_per_target": post,
239+
"calibrated_weights_summary": {
240+
"min": float(post_weights.min()),
241+
"max": float(post_weights.max()),
242+
"mean": float(post_weights.mean()),
243+
"std": float(post_weights.std()),
244+
"zero_fraction": float((post_weights == 0).mean()),
245+
},
246+
}
247+
)
248+
249+
args.output.parent.mkdir(parents=True, exist_ok=True)
250+
args.output.write_text(json.dumps(results, indent=2, default=str))
251+
252+
print()
253+
print("== Pre / post mean-relative-error per method ==")
254+
for r in sorted(results, key=lambda x: x["post_cal_mean_rel_err"]):
255+
print(
256+
f" {r['method']:8s}: pre={r['pre_cal_mean_rel_err']:.4f} "
257+
f"post={r['post_cal_mean_rel_err']:.4f} "
258+
f"max={r['calibration_max_error']:.4f} "
259+
f"cal={r['calibration_wall_seconds']:.1f}s"
260+
)
261+
262+
return 0
263+
264+
265+
if __name__ == "__main__":
266+
raise SystemExit(main())

0 commit comments

Comments
 (0)