Skip to content

Commit 225eb36

Browse files
MaxGhenisclaude
andcommitted
Add per-column zero-rate breakdown + embedding-PRDC validation script
ScaleUpResult now includes zero_rate_per_column: for every column, the real zero-rate, synthetic zero-rate, and absolute difference. Lets the stage-1 doc identify which specific columns drive each method's overall zero-rate MAE — the pilot/stage-1 result showed every method drives disabled_ssdi to 0, but aggregate MAE of 0.18+ implies many other columns also diverge. scripts/embedding_prdc_compare.py: one-off validation script that fits a 16-dim autoencoder on the holdout, encodes real and synthetic to latent space, and reports PRDC both in the raw 50-dim feature space and in the learned 16-dim embedding. Settles whether the stage-1 ordering (ZI-QRF > ZI-QDNN > ZI-MAF) is a metric artifact from PRDC-in-high-dimensions or a genuine method difference. Usage: uv run python scripts/embedding_prdc_compare.py --n-rows 40000 Tests still pass (7/7). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6763237 commit 225eb36

2 files changed

Lines changed: 289 additions & 0 deletions

File tree

scripts/embedding_prdc_compare.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
"""Compare raw-feature PRDC vs learned-embedding PRDC on the stage-1 methods.
2+
3+
The scale-up-protocol doc flagged that PRDC in ~50 dimensions may be
4+
degenerate (curse of dimensionality: k-NN distances concentrate and the
5+
metric becomes noise-dominated). This script settles the question.
6+
7+
Procedure:
8+
9+
1. Fit each of (ZI-QRF, ZI-MAF, ZI-QDNN) on 40k x 50 real ECPS.
10+
2. Generate synthetic records from each.
11+
3. Train a 16-dim autoencoder on the holdout's raw features only.
12+
4. Compute PRDC in the raw 50-dim feature space (unchanged from stage 1).
13+
5. Compute PRDC in the 16-dim learned latent space.
14+
6. Report both side-by-side. If the ordering changes, the stage-1
15+
finding was metric-driven not method-driven; if it's preserved, the
16+
finding is robust.
17+
18+
Usage:
19+
uv run python scripts/embedding_prdc_compare.py \
20+
--output artifacts/embedding_prdc_compare.json
21+
22+
Runs in ~5 minutes on 40 k rows x 50 cols (driven by ZI-MAF fit time).
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import argparse
28+
import json
29+
import logging
30+
import time
31+
from pathlib import Path
32+
33+
import numpy as np
34+
import pandas as pd
35+
import torch
36+
import torch.nn as nn
37+
from prdc import compute_prdc
38+
from sklearn.preprocessing import StandardScaler
39+
40+
from microplex.eval.benchmark import ZIMAFMethod, ZIQDNNMethod, ZIQRFMethod
41+
from microplex_us.bakeoff import (
42+
DEFAULT_CONDITION_COLS,
43+
DEFAULT_TARGET_COLS,
44+
ScaleUpRunner,
45+
ScaleUpStageConfig,
46+
stage1_config,
47+
)
48+
49+
LOGGER = logging.getLogger(__name__)
50+
51+
52+
class Autoencoder(nn.Module):
53+
"""Tiny autoencoder for dimensionality reduction on tabular features."""
54+
55+
def __init__(self, n_features: int, latent_dim: int = 16, hidden: int = 64) -> None:
56+
super().__init__()
57+
self.encoder = nn.Sequential(
58+
nn.Linear(n_features, hidden),
59+
nn.ReLU(),
60+
nn.Linear(hidden, hidden),
61+
nn.ReLU(),
62+
nn.Linear(hidden, latent_dim),
63+
)
64+
self.decoder = nn.Sequential(
65+
nn.Linear(latent_dim, hidden),
66+
nn.ReLU(),
67+
nn.Linear(hidden, hidden),
68+
nn.ReLU(),
69+
nn.Linear(hidden, n_features),
70+
)
71+
72+
def forward(self, x: torch.Tensor) -> torch.Tensor:
73+
return self.decoder(self.encoder(x))
74+
75+
def encode(self, x: torch.Tensor) -> torch.Tensor:
76+
return self.encoder(x)
77+
78+
79+
def fit_autoencoder(
80+
x: np.ndarray, latent_dim: int = 16, epochs: int = 200, lr: float = 1e-3
81+
) -> Autoencoder:
82+
"""Fit an autoencoder on standardized features."""
83+
n_features = x.shape[1]
84+
model = Autoencoder(n_features=n_features, latent_dim=latent_dim)
85+
x_t = torch.tensor(x, dtype=torch.float32)
86+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
87+
batch_size = 256
88+
ds = torch.utils.data.TensorDataset(x_t)
89+
g = torch.Generator()
90+
g.manual_seed(42)
91+
loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, generator=g)
92+
93+
model.train()
94+
for epoch in range(epochs):
95+
total = 0.0
96+
for (batch,) in loader:
97+
optimizer.zero_grad()
98+
recon = model(batch)
99+
loss = ((recon - batch) ** 2).mean()
100+
loss.backward()
101+
optimizer.step()
102+
total += loss.item() * len(batch)
103+
if (epoch + 1) % 50 == 0:
104+
LOGGER.info(" AE epoch %d loss=%.4f", epoch + 1, total / len(x))
105+
model.eval()
106+
return model
107+
108+
109+
def encode(model: Autoencoder, x: np.ndarray) -> np.ndarray:
110+
with torch.no_grad():
111+
return model.encode(torch.tensor(x, dtype=torch.float32)).numpy()
112+
113+
114+
def compute_prdc_both_spaces(
115+
real: pd.DataFrame,
116+
synthetic: pd.DataFrame,
117+
encoder: Autoencoder,
118+
scaler: StandardScaler,
119+
k: int = 5,
120+
max_samples: int = 15_000,
121+
seed: int = 42,
122+
) -> dict:
123+
"""Return {raw: ..., embed: ...} PRDC tuples."""
124+
rng = np.random.default_rng(seed)
125+
cols = [c for c in real.columns if c in synthetic.columns]
126+
r = real[cols].to_numpy(dtype=np.float64)
127+
s = synthetic[cols].to_numpy(dtype=np.float64)
128+
if len(r) > max_samples:
129+
r = r[rng.choice(len(r), size=max_samples, replace=False)]
130+
if len(s) > max_samples:
131+
s = s[rng.choice(len(s), size=max_samples, replace=False)]
132+
133+
raw_r = scaler.transform(r)
134+
raw_s = scaler.transform(s)
135+
raw_metrics = compute_prdc(raw_r, raw_s, nearest_k=k)
136+
137+
emb_r = encode(encoder, raw_r.astype(np.float32))
138+
emb_s = encode(encoder, raw_s.astype(np.float32))
139+
emb_metrics = compute_prdc(emb_r, emb_s, nearest_k=k)
140+
141+
return {
142+
"raw": {k: float(v) for k, v in raw_metrics.items()},
143+
"embed": {k: float(v) for k, v in emb_metrics.items()},
144+
}
145+
146+
147+
def build_method(name: str):
148+
registry = {
149+
"ZI-QRF": ZIQRFMethod,
150+
"ZI-MAF": ZIMAFMethod,
151+
"ZI-QDNN": ZIQDNNMethod,
152+
}
153+
return registry[name]()
154+
155+
156+
def main(argv: list[str] | None = None) -> int:
157+
parser = argparse.ArgumentParser(description=__doc__)
158+
parser.add_argument("--n-rows", type=int, default=40_000)
159+
parser.add_argument(
160+
"--methods", nargs="+", default=["ZI-QRF", "ZI-MAF", "ZI-QDNN"]
161+
)
162+
parser.add_argument(
163+
"--output",
164+
type=Path,
165+
default=Path("artifacts/embedding_prdc_compare.json"),
166+
)
167+
parser.add_argument("--seed", type=int, default=42)
168+
parser.add_argument("--latent-dim", type=int, default=16)
169+
parser.add_argument("--ae-epochs", type=int, default=200)
170+
args = parser.parse_args(argv)
171+
172+
logging.basicConfig(
173+
level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
174+
)
175+
176+
base = stage1_config()
177+
cfg = ScaleUpStageConfig(
178+
stage="embedding_prdc",
179+
n_rows=args.n_rows,
180+
methods=tuple(args.methods),
181+
condition_cols=DEFAULT_CONDITION_COLS,
182+
target_cols=DEFAULT_TARGET_COLS,
183+
holdout_frac=0.2,
184+
seed=args.seed,
185+
k=5,
186+
data_path=base.data_path,
187+
year=base.year,
188+
rare_cell_checks=(),
189+
prdc_max_samples=15_000,
190+
)
191+
192+
runner = ScaleUpRunner(cfg)
193+
df = runner.load_frame()
194+
train, holdout = runner.split(df)
195+
LOGGER.info(
196+
"loaded: train=%d holdout=%d cols=%d", len(train), len(holdout), len(df.columns)
197+
)
198+
199+
scaler = StandardScaler().fit(holdout.to_numpy(dtype=np.float64))
200+
201+
LOGGER.info("fitting autoencoder on holdout...")
202+
t0 = time.time()
203+
encoder = fit_autoencoder(
204+
scaler.transform(holdout.to_numpy(dtype=np.float64)).astype(np.float32),
205+
latent_dim=args.latent_dim,
206+
epochs=args.ae_epochs,
207+
)
208+
LOGGER.info(" autoencoder fit=%.1fs", time.time() - t0)
209+
210+
results = []
211+
for method_name in args.methods:
212+
LOGGER.info("== %s ==", method_name)
213+
method = build_method(method_name)
214+
t0 = time.time()
215+
method.fit(sources={"ecps": train.copy()}, shared_cols=list(DEFAULT_CONDITION_COLS))
216+
fit_s = time.time() - t0
217+
218+
t0 = time.time()
219+
synth = method.generate(len(train), seed=args.seed)
220+
gen_s = time.time() - t0
221+
222+
metrics = compute_prdc_both_spaces(
223+
holdout, synth, encoder, scaler, k=5, seed=args.seed
224+
)
225+
LOGGER.info(
226+
" raw: prec=%.3f dens=%.3f cov=%.3f",
227+
metrics["raw"]["precision"],
228+
metrics["raw"]["density"],
229+
metrics["raw"]["coverage"],
230+
)
231+
LOGGER.info(
232+
" embed: prec=%.3f dens=%.3f cov=%.3f (fit=%.1fs gen=%.1fs)",
233+
metrics["embed"]["precision"],
234+
metrics["embed"]["density"],
235+
metrics["embed"]["coverage"],
236+
fit_s,
237+
gen_s,
238+
)
239+
results.append(
240+
{
241+
"method": method_name,
242+
"fit_wall_seconds": fit_s,
243+
"generate_wall_seconds": gen_s,
244+
**metrics,
245+
}
246+
)
247+
248+
args.output.parent.mkdir(parents=True, exist_ok=True)
249+
args.output.write_text(json.dumps(results, indent=2, default=str))
250+
251+
print()
252+
print("== Raw-feature PRDC (50-dim) ==")
253+
for r in sorted(results, key=lambda x: -x["raw"]["coverage"]):
254+
print(
255+
f" {r['method']:8s}: cov={r['raw']['coverage']:.3f} "
256+
f"prec={r['raw']['precision']:.3f} dens={r['raw']['density']:.3f}"
257+
)
258+
print()
259+
print(f"== Learned-embedding PRDC ({args.latent_dim}-dim) ==")
260+
for r in sorted(results, key=lambda x: -x["embed"]["coverage"]):
261+
print(
262+
f" {r['method']:8s}: cov={r['embed']['coverage']:.3f} "
263+
f"prec={r['embed']['precision']:.3f} dens={r['embed']['density']:.3f}"
264+
)
265+
return 0
266+
267+
268+
if __name__ == "__main__":
269+
raise SystemExit(main())

src/microplex_us/bakeoff/scale_up.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ class ScaleUpResult:
202202
coverage: float
203203
rare_cell_ratios: dict[str, float]
204204
zero_rate_mae: float
205+
zero_rate_per_column: dict[str, dict[str, float]] = field(default_factory=dict)
205206
notes: str = ""
206207

207208
def to_dict(self) -> dict[str, Any]:
@@ -407,6 +408,23 @@ def _compute_zero_rate_mae(real: pd.DataFrame, synthetic: pd.DataFrame) -> float
407408
return float(np.mean(errs)) if errs else 0.0
408409

409410

411+
def _compute_zero_rate_per_column(
412+
real: pd.DataFrame, synthetic: pd.DataFrame
413+
) -> dict[str, dict[str, float]]:
414+
"""Per-column {real_zero_rate, synth_zero_rate, abs_diff} breakdown."""
415+
cols = [c for c in real.columns if c in synthetic.columns]
416+
out: dict[str, dict[str, float]] = {}
417+
for c in cols:
418+
r_zero = float((real[c] == 0).mean())
419+
s_zero = float((synthetic[c] == 0).mean())
420+
out[c] = {
421+
"real": r_zero,
422+
"synth": s_zero,
423+
"abs_diff": abs(r_zero - s_zero),
424+
}
425+
return out
426+
427+
410428
def _compute_prdc(
411429
real: pd.DataFrame,
412430
synthetic: pd.DataFrame,
@@ -614,6 +632,7 @@ def run(
614632
holdout, synthetic, self.config.rare_cell_checks
615633
)
616634
zero_mae = _compute_zero_rate_mae(holdout, synthetic)
635+
zero_per_col = _compute_zero_rate_per_column(holdout, synthetic)
617636

618637
result = ScaleUpResult(
619638
stage=self.config.stage,
@@ -630,6 +649,7 @@ def run(
630649
coverage=coverage,
631650
rare_cell_ratios=rare,
632651
zero_rate_mae=zero_mae,
652+
zero_rate_per_column=zero_per_col,
633653
notes="",
634654
)
635655
results.append(result)

0 commit comments

Comments
 (0)