|
| 1 | +"""Compare raw-feature PRDC vs learned-embedding PRDC on the stage-1 methods. |
| 2 | +
|
| 3 | +The scale-up-protocol doc flagged that PRDC in ~50 dimensions may be |
| 4 | +degenerate (curse of dimensionality: k-NN distances concentrate and the |
| 5 | +metric becomes noise-dominated). This script settles the question. |
| 6 | +
|
| 7 | +Procedure: |
| 8 | +
|
| 9 | +1. Fit each of (ZI-QRF, ZI-MAF, ZI-QDNN) on 40k x 50 real ECPS. |
| 10 | +2. Generate synthetic records from each. |
| 11 | +3. Train a 16-dim autoencoder on the holdout's standardized features only. |
| 12 | +4. Compute PRDC in the raw 50-dim feature space (unchanged from stage 1). |
| 13 | +5. Compute PRDC in the 16-dim learned latent space. |
| 14 | +6. Report both side-by-side. If the ordering changes, the stage-1 |
| 15 | + finding was metric-driven not method-driven; if it's preserved, the |
| 16 | + finding is robust. |
| 17 | +
|
| 18 | +Usage: |
| 19 | + uv run python scripts/embedding_prdc_compare.py \ |
| 20 | + --output artifacts/embedding_prdc_compare.json |
| 21 | +
|
| 22 | +Runs in ~5 minutes on 40 k rows x 50 cols (driven by ZI-MAF fit time). |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import argparse |
| 28 | +import json |
| 29 | +import logging |
| 30 | +import time |
| 31 | +from pathlib import Path |
| 32 | + |
| 33 | +import numpy as np |
| 34 | +import pandas as pd |
| 35 | +import torch |
| 36 | +import torch.nn as nn |
| 37 | +from prdc import compute_prdc |
| 38 | +from sklearn.preprocessing import StandardScaler |
| 39 | + |
| 40 | +from microplex.eval.benchmark import ZIMAFMethod, ZIQDNNMethod, ZIQRFMethod |
| 41 | +from microplex_us.bakeoff import ( |
| 42 | + DEFAULT_CONDITION_COLS, |
| 43 | + DEFAULT_TARGET_COLS, |
| 44 | + ScaleUpRunner, |
| 45 | + ScaleUpStageConfig, |
| 46 | + stage1_config, |
| 47 | +) |
| 48 | + |
| 49 | +LOGGER = logging.getLogger(__name__) |
| 50 | + |
| 51 | + |
class Autoencoder(nn.Module):
    """Small symmetric MLP autoencoder for compressing tabular rows.

    The encoder goes ``n_features -> hidden -> hidden -> latent_dim`` with
    ReLU activations between the linear layers; the decoder mirrors that
    path back up to ``n_features``.
    """

    def __init__(self, n_features: int, latent_dim: int = 16, hidden: int = 64) -> None:
        super().__init__()
        # Encoder is constructed before the decoder so the order in which
        # layers are created (and therefore initialised) stays fixed.
        self.encoder = self._three_layer_mlp(n_features, hidden, latent_dim)
        self.decoder = self._three_layer_mlp(latent_dim, hidden, n_features)

    @staticmethod
    def _three_layer_mlp(n_in: int, width: int, n_out: int) -> nn.Sequential:
        """Build Linear-ReLU-Linear-ReLU-Linear mapping ``n_in`` to ``n_out``."""
        layers = [
            nn.Linear(n_in, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, n_out),
        ]
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the reconstruction of ``x`` (encode, then decode)."""
        latent = self.encoder(x)
        return self.decoder(latent)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent representation of ``x``."""
        return self.encoder(x)
| 77 | + |
| 78 | + |
def fit_autoencoder(
    x: np.ndarray,
    latent_dim: int = 16,
    epochs: int = 200,
    lr: float = 1e-3,
    batch_size: int = 256,
    seed: int = 42,
) -> Autoencoder:
    """Fit an autoencoder on already-standardized features.

    Args:
        x: 2-D array of shape ``(n_rows, n_features)``; converted to float32.
        latent_dim: Width of the bottleneck layer.
        epochs: Number of full passes over ``x``.
        lr: Adam learning rate.
        batch_size: Mini-batch size used by the shuffled loader.
        seed: Seeds both the weight initialisation and the batch shuffling
            so repeated calls with the same inputs start from the same state.

    Returns:
        The trained model, switched to eval mode.

    Raises:
        ValueError: If ``x`` is not 2-D or has no rows.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D feature matrix, got ndim={x.ndim}")
    n_rows, n_features = x.shape
    if n_rows == 0:
        raise ValueError("cannot fit an autoencoder on zero rows")

    # Layer initialisation draws from the global torch RNG. Seed it inside a
    # forked RNG scope so the fit is reproducible and the caller's CPU
    # generator state is restored afterwards.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        model = Autoencoder(n_features=n_features, latent_dim=latent_dim)

    features = torch.tensor(x, dtype=torch.float32)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # A dedicated generator keeps the shuffle order independent of any other
    # consumer of the global RNG.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features),
        batch_size=batch_size,
        shuffle=True,
        generator=shuffle_rng,
    )

    model.train()
    for epoch in range(epochs):
        # Sum of per-row squared error so the logged value is a true
        # per-row mean even when the last batch is short.
        weighted_loss = 0.0
        for (batch,) in loader:
            optimizer.zero_grad()
            recon = model(batch)
            loss = ((recon - batch) ** 2).mean()
            loss.backward()
            optimizer.step()
            weighted_loss += loss.item() * len(batch)
        if (epoch + 1) % 50 == 0:
            LOGGER.info(" AE epoch %d loss=%.4f", epoch + 1, weighted_loss / n_rows)
    model.eval()
    return model
| 107 | + |
| 108 | + |
def encode(model: Autoencoder, x: np.ndarray) -> np.ndarray:
    """Map rows of ``x`` to latent vectors with ``model.encode``.

    ``x`` is converted to a float32 tensor, pushed through the encoder with
    gradient tracking disabled, and the result is returned as a NumPy array.
    """
    inputs = torch.tensor(x, dtype=torch.float32)
    with torch.no_grad():
        latent = model.encode(inputs)
        return latent.numpy()
| 112 | + |
| 113 | + |
def compute_prdc_both_spaces(
    real: pd.DataFrame,
    synthetic: pd.DataFrame,
    encoder: Autoencoder,
    scaler: StandardScaler,
    k: int = 5,
    max_samples: int = 15_000,
    seed: int = 42,
) -> dict:
    """Score ``synthetic`` against ``real`` with PRDC in two feature spaces.

    Both frames are restricted to the columns they share (in ``real``'s
    order), each is randomly subsampled without replacement to at most
    ``max_samples`` rows, and both are passed through ``scaler.transform``.
    PRDC is computed on the scaled features (``"raw"``) and again on the
    output of ``encoder`` applied to them (``"embed"``).

    Returns:
        ``{"raw": {...}, "embed": {...}}`` where each inner dict holds the
        values returned by ``compute_prdc`` converted to ``float``.
    """
    rng = np.random.default_rng(seed)
    shared_cols = [name for name in real.columns if name in synthetic.columns]

    def _capped_values(frame: pd.DataFrame) -> np.ndarray:
        """Return the shared columns as float64, subsampled if too long."""
        values = frame[shared_cols].to_numpy(dtype=np.float64)
        if len(values) <= max_samples:
            return values
        picked = rng.choice(len(values), size=max_samples, replace=False)
        return values[picked]

    # Real first, then synthetic: both draw from the same generator, so the
    # order fixes which rows are sampled.
    real_values = _capped_values(real)
    synth_values = _capped_values(synthetic)

    real_scaled = scaler.transform(real_values)
    synth_scaled = scaler.transform(synth_values)
    raw_scores = compute_prdc(real_scaled, synth_scaled, nearest_k=k)

    real_latent = encode(encoder, real_scaled.astype(np.float32))
    synth_latent = encode(encoder, synth_scaled.astype(np.float32))
    embed_scores = compute_prdc(real_latent, synth_latent, nearest_k=k)

    report = {}
    for space, scores in (("raw", raw_scores), ("embed", embed_scores)):
        report[space] = {metric: float(value) for metric, value in scores.items()}
    return report
| 145 | + |
| 146 | + |
def build_method(name: str):
    """Instantiate the benchmark method registered under ``name``.

    Recognised names are ``"ZI-QRF"``, ``"ZI-MAF"`` and ``"ZI-QDNN"``; any
    other name raises ``KeyError``. The class is called with no arguments.
    """
    method_classes = dict(
        zip(
            ("ZI-QRF", "ZI-MAF", "ZI-QDNN"),
            (ZIQRFMethod, ZIMAFMethod, ZIQDNNMethod),
        )
    )
    method_cls = method_classes[name]
    return method_cls()
| 154 | + |
| 155 | + |
def main(argv: list[str] | None = None) -> int:
    """Run the raw-vs-embedding PRDC comparison and write a JSON report.

    Steps: parse CLI options, load and split the data through
    ``ScaleUpRunner``, fit a ``StandardScaler`` and an autoencoder on the
    holdout, then for every requested method fit on the train split,
    generate ``len(train)`` synthetic rows and score them against the
    holdout in both feature spaces. Results are written to ``--output`` and
    printed as two tables sorted by descending coverage.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        ``0`` on success. Unknown ``--methods`` values are rejected by
        argparse before any data is loaded.
    """
    # Single source of truth for values that were previously repeated.
    method_names = ("ZI-QRF", "ZI-MAF", "ZI-QDNN")
    nearest_k = 5
    prdc_max_samples = 15_000

    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--n-rows", type=int, default=40_000)
    parser.add_argument(
        "--methods",
        nargs="+",
        default=list(method_names),
        # Fail fast on typos instead of raising KeyError after the data
        # load and autoencoder fit have already run.
        choices=method_names,
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("artifacts/embedding_prdc_compare.json"),
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--latent-dim", type=int, default=16)
    parser.add_argument("--ae-epochs", type=int, default=200)
    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
    )

    base = stage1_config()
    cfg = ScaleUpStageConfig(
        stage="embedding_prdc",
        n_rows=args.n_rows,
        methods=tuple(args.methods),
        condition_cols=DEFAULT_CONDITION_COLS,
        target_cols=DEFAULT_TARGET_COLS,
        holdout_frac=0.2,
        seed=args.seed,
        k=nearest_k,
        data_path=base.data_path,
        year=base.year,
        rare_cell_checks=(),
        prdc_max_samples=prdc_max_samples,
    )

    runner = ScaleUpRunner(cfg)
    df = runner.load_frame()
    train, holdout = runner.split(df)
    LOGGER.info(
        "loaded: train=%d holdout=%d cols=%d", len(train), len(holdout), len(df.columns)
    )

    # Convert the holdout once; the same matrix feeds the scaler fit and
    # the autoencoder's training input.
    holdout_values = holdout.to_numpy(dtype=np.float64)
    scaler = StandardScaler().fit(holdout_values)

    LOGGER.info("fitting autoencoder on holdout...")
    started = time.perf_counter()
    encoder = fit_autoencoder(
        scaler.transform(holdout_values).astype(np.float32),
        latent_dim=args.latent_dim,
        epochs=args.ae_epochs,
    )
    LOGGER.info(" autoencoder fit=%.1fs", time.perf_counter() - started)

    results = []
    for method_name in args.methods:
        LOGGER.info("== %s ==", method_name)
        method = build_method(method_name)

        started = time.perf_counter()
        method.fit(sources={"ecps": train.copy()}, shared_cols=list(DEFAULT_CONDITION_COLS))
        fit_s = time.perf_counter() - started

        started = time.perf_counter()
        synth = method.generate(len(train), seed=args.seed)
        gen_s = time.perf_counter() - started

        metrics = compute_prdc_both_spaces(
            holdout,
            synth,
            encoder,
            scaler,
            k=nearest_k,
            max_samples=prdc_max_samples,
            seed=args.seed,
        )
        LOGGER.info(
            " raw: prec=%.3f dens=%.3f cov=%.3f",
            metrics["raw"]["precision"],
            metrics["raw"]["density"],
            metrics["raw"]["coverage"],
        )
        LOGGER.info(
            " embed: prec=%.3f dens=%.3f cov=%.3f (fit=%.1fs gen=%.1fs)",
            metrics["embed"]["precision"],
            metrics["embed"]["density"],
            metrics["embed"]["coverage"],
            fit_s,
            gen_s,
        )
        results.append(
            {
                "method": method_name,
                "fit_wall_seconds": fit_s,
                "generate_wall_seconds": gen_s,
                **metrics,
            }
        )

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(results, indent=2, default=str))

    def _print_table(title: str, space: str) -> None:
        """Print one PRDC table for ``space``, best coverage first."""
        print()
        print(title)
        for row in sorted(results, key=lambda item: -item[space]["coverage"]):
            scores = row[space]
            print(
                f" {row['method']:8s}: cov={scores['coverage']:.3f} "
                f"prec={scores['precision']:.3f} dens={scores['density']:.3f}"
            )

    _print_table("== Raw-feature PRDC (50-dim) ==", "raw")
    _print_table(f"== Learned-embedding PRDC ({args.latent_dim}-dim) ==", "embed")
    return 0
| 266 | + |
| 267 | + |
| 268 | +if __name__ == "__main__": |
| 269 | + raise SystemExit(main()) |
0 commit comments