|
| 1 | +""" |
| 2 | +Benchmark harness for the JAX hierarchical turnover model. |
| 3 | +
|
| 4 | +Runs ``RandomByTypeModel`` (and optionally ``ConstantModel``) at several data |
| 5 | +sizes and reports: |
| 6 | +
|
| 7 | +* Wall time for warmup and sampling (seconds) |
| 8 | +* NUTS diagnostics: mean acceptance, divergent count, mean integration steps, |
| 9 | + final step size |
| 10 | +* Per-parameter effective sample size (via ``blackjax.diagnostics``) |
| 11 | +* Log-density at the posterior mean (proxy for MAP) |
| 12 | +
|
| 13 | +Results are dumped to JSON so they can be diffed across commits. Run this |
| 14 | +BEFORE changing the model code to capture a baseline, then again after each |
| 15 | +phase to attribute effects. |
| 16 | +
|
| 17 | +Usage: |
| 18 | + python scripts/exploratory/bench_random_by_type.py \\ |
| 19 | + --sizes small medium [large] [real] \\ |
| 20 | + --out ~/data/openpois/bench/baseline.json \\ |
| 21 | + [--num-draws 250] |
| 22 | +
|
| 23 | +Size presets: |
| 24 | + small — n = 10 000, K = 20 |
| 25 | + medium — n = 1 000 000, K = 91 |
| 26 | + large — n = 4 200 000, K = 91 (slow; ~matches production scale) |
| 27 | + real — reads real osm_observations.csv via config.yaml |
| 28 | +""" |
| 29 | + |
| 30 | +from __future__ import annotations |
| 31 | + |
| 32 | +import argparse |
| 33 | +import json |
| 34 | +import subprocess |
| 35 | +import time |
| 36 | +from datetime import datetime, timezone |
| 37 | +from pathlib import Path |
| 38 | + |
| 39 | +import jax |
| 40 | +import jax.numpy as jnp |
| 41 | +import jax.random as jrd |
| 42 | +import numpy as np |
| 43 | +import pandas as pd |
| 44 | + |
| 45 | +from blackjax.diagnostics import effective_sample_size |
| 46 | + |
| 47 | +from openpois.models.jax_core import jax_rng |
| 48 | +from openpois.models.model_fitter import ModelFitter |
| 49 | +from openpois.models.osm_models import RandomByTypeModel |
| 50 | +from openpois.models.setup import prepare_data_for_model |
| 51 | + |
| 52 | + |
| 53 | +SIZE_PRESETS = { |
| 54 | + "small": dict(n = 10_000, k = 20, min_per_group = 5), |
| 55 | + "medium": dict(n = 1_000_000, k = 91, min_per_group = 5), |
| 56 | + "large": dict(n = 4_200_000, k = 91, min_per_group = 5), |
| 57 | +} |
| 58 | + |
| 59 | + |
| 60 | +def _simulate( |
| 61 | + key: jrd.KeyArray, |
| 62 | + n: int, |
| 63 | + k: int, |
| 64 | + min_per_group: int = 5, |
| 65 | + true_log_lambda_0: float = -5.3, |
| 66 | + true_log_sigma: float = 0.8, |
| 67 | +) -> pd.DataFrame: |
| 68 | + """Simulate an observations DataFrame from the RandomByTypeModel likelihood.""" |
| 69 | + k_eps, k_grp, k_dt, k_y = jrd.split(key, 4) |
| 70 | + # Simulate group epsilons from N(0, exp(log_sigma)) |
| 71 | + eps = np.asarray( |
| 72 | + jrd.normal(k_eps, (k,)) * np.exp(true_log_sigma) |
| 73 | + ) |
| 74 | + log_lam = true_log_lambda_0 + eps |
| 75 | + |
| 76 | + # Group assignment with a power-law-ish imbalance so we test uneven fits. |
| 77 | + # Weights ~ 1/(i+1); renormalised. Then enforce min_per_group per group. |
| 78 | + weights = 1.0 / (np.arange(k) + 1.0) |
| 79 | + weights = weights / weights.sum() |
| 80 | + # Sample (n - k*min_per_group) according to weights, then add min_per_group per group |
| 81 | + assert n > k * min_per_group, "n too small for requested min_per_group" |
| 82 | + n_weighted = n - k * min_per_group |
| 83 | + g_rand = np.asarray( |
| 84 | + jrd.categorical(k_grp, jnp.log(jnp.asarray(weights)), shape = (n_weighted,)) |
| 85 | + ) |
| 86 | + g = np.concatenate([ |
| 87 | + g_rand, |
| 88 | + np.repeat(np.arange(k), min_per_group), |
| 89 | + ]).astype(np.int32) |
| 90 | + rng = np.random.default_rng(int(jrd.randint(k_grp, (), 0, 2**31 - 1))) |
| 91 | + rng.shuffle(g) |
| 92 | + |
| 93 | + # dt ~ Uniform(0.1, 10) |
| 94 | + dt = np.asarray(jrd.uniform(k_dt, (n,), minval = 0.1, maxval = 10.0)) |
| 95 | + |
| 96 | + lam_per_obs = np.exp(log_lam[g]) |
| 97 | + p = 1.0 - np.exp(-lam_per_obs * dt) |
| 98 | + y = np.asarray(jrd.bernoulli(k_y, jnp.asarray(p))).astype(np.int32) |
| 99 | + |
| 100 | + # Use string group labels so the category encoding exercises the real code path. |
| 101 | + group_names = np.array([f"grp_{i:03d}" for i in range(k)]) |
| 102 | + return pd.DataFrame({ |
| 103 | + "tag_years": dt, |
| 104 | + "changed": y, |
| 105 | + "shared_label": group_names[g], |
| 106 | + }) |
| 107 | + |
| 108 | + |
| 109 | +def _load_real_observations() -> pd.DataFrame: |
| 110 | + """Load the real OSM observations via config.yaml.""" |
| 111 | + from config_versioned import Config |
| 112 | + cfg = Config("~/repos/openpois/config.yaml") |
| 113 | + path = cfg.get_file_path("osm_data", "osm_observations") |
| 114 | + min_value_count = cfg.get( |
| 115 | + "osm_turnover_model", "min_value_count", fail_if_none = False |
| 116 | + ) |
| 117 | + group_key = cfg.get( |
| 118 | + "osm_turnover_model", "group_key", fail_if_none = False |
| 119 | + ) |
| 120 | + df = pd.read_csv(path) |
| 121 | + prepared = prepare_data_for_model( |
| 122 | + data = df, |
| 123 | + group_key = group_key, |
| 124 | + group_values = None, |
| 125 | + min_value_count = min_value_count, |
| 126 | + t1_col = "last_tag_timestamp", |
| 127 | + t2_col = "obs_timestamp", |
| 128 | + ) |
| 129 | + return prepared |
| 130 | + |
| 131 | + |
| 132 | +def _ess_per_param(param_draws: dict[str, jnp.ndarray]) -> dict[str, float]: |
| 133 | + """Minimum ESS across elements of each pytree leaf.""" |
| 134 | + out = {} |
| 135 | + for name, arr in param_draws.items(): |
| 136 | + a = jnp.asarray(arr) |
| 137 | + if a.ndim == 1: |
| 138 | + ess = float(effective_sample_size(a[None, :])) |
| 139 | + out[name] = ess |
| 140 | + else: |
| 141 | + # Multiple elements: report min ESS (worst-case) |
| 142 | + flat = a.reshape(a.shape[0], -1).T # (n_elem, n_draws) |
| 143 | + esss = np.asarray( |
| 144 | + jax.vmap(lambda row: effective_sample_size(row[None, :]))(flat) |
| 145 | + ) |
| 146 | + out[f"{name}__min"] = float(esss.min()) |
| 147 | + out[f"{name}__median"] = float(np.median(esss)) |
| 148 | + return out |
| 149 | + |
| 150 | + |
| 151 | +def _log_density_at_mean( |
| 152 | + fitter: ModelFitter, |
| 153 | + param_draws: dict[str, jnp.ndarray], |
| 154 | +) -> float: |
| 155 | + """Evaluate log-density at the element-wise posterior mean.""" |
| 156 | + post_mean = { |
| 157 | + name: jnp.mean(jnp.asarray(arr), axis = 0) |
| 158 | + for name, arr in param_draws.items() |
| 159 | + } |
| 160 | + return float(fitter.calculate_lp(post_mean)) |
| 161 | + |
| 162 | + |
| 163 | +def _git_sha() -> str: |
| 164 | + try: |
| 165 | + return subprocess.check_output( |
| 166 | + ["git", "rev-parse", "--short", "HEAD"], |
| 167 | + cwd = str(Path(__file__).resolve().parents[2]), |
| 168 | + ).decode().strip() |
| 169 | + except Exception: |
| 170 | + return "unknown" |
| 171 | + |
| 172 | + |
| 173 | +def _run_one( |
| 174 | + tag: str, |
| 175 | + df: pd.DataFrame, |
| 176 | + num_draws: int, |
| 177 | + group_key: str = "shared_label", |
| 178 | +) -> dict: |
| 179 | + """Build the model, fit it, collect timings and diagnostics.""" |
| 180 | + n = len(df) |
| 181 | + model = RandomByTypeModel( |
| 182 | + dataset = df, |
| 183 | + metadata = { |
| 184 | + "dt_col": "tag_years", |
| 185 | + "group": group_key, |
| 186 | + "var_prior": (-1.0, 5.0), |
| 187 | + }, |
| 188 | + ) |
| 189 | + k = model.group_lookup.shape[0] |
| 190 | + print(f"[{tag}] n={n:,} k={k} — building fitter") |
| 191 | + |
| 192 | + fitter = ModelFitter( |
| 193 | + event_rate_fun = model.event_rate_fun, |
| 194 | + starting_params = model.starting_params, |
| 195 | + data = model.data, |
| 196 | + target = model.target, |
| 197 | + num_warmup = num_draws, |
| 198 | + num_samples = num_draws, |
| 199 | + param_likelihood = model.param_likelihood, |
| 200 | + derive_draws = model.derive_draws, |
| 201 | + log_likelihood_fun = model.log_likelihood_fun, |
| 202 | + verbose = False, |
| 203 | + ) |
| 204 | + |
| 205 | + t_fit_start = time.perf_counter() |
| 206 | + fitter.fit() |
| 207 | + # Ensure the draws are realised on device before we stop the clock. |
| 208 | + jax.tree_util.tree_map(lambda x: x.block_until_ready(), fitter.param_draws) |
| 209 | + t_fit_end = time.perf_counter() |
| 210 | + |
| 211 | + info = fitter.sampler_info |
| 212 | + mean_accept = float(jnp.mean(info.acceptance_rate)) |
| 213 | + divergences = int(jnp.sum(info.is_divergent)) |
| 214 | + mean_steps = float(jnp.mean(info.num_integration_steps)) |
| 215 | + step_size = float(fitter.warmup_params["step_size"]) |
| 216 | + |
| 217 | + ess = _ess_per_param(fitter.param_draws) |
| 218 | + log_density_at_mean = _log_density_at_mean(fitter, fitter.param_draws) |
| 219 | + |
| 220 | + return { |
| 221 | + "tag": tag, |
| 222 | + "n": int(n), |
| 223 | + "k": int(k), |
| 224 | + "num_draws": int(num_draws), |
| 225 | + "wall_fit_s": round(t_fit_end - t_fit_start, 3), |
| 226 | + "mean_acceptance": round(mean_accept, 4), |
| 227 | + "divergences": divergences, |
| 228 | + "mean_integration_steps": round(mean_steps, 3), |
| 229 | + "final_step_size": round(step_size, 6), |
| 230 | + "log_density_at_post_mean": round(log_density_at_mean, 3), |
| 231 | + "ess": {k: round(v, 2) for k, v in ess.items()}, |
| 232 | + } |
| 233 | + |
| 234 | + |
| 235 | +def main(): |
| 236 | + parser = argparse.ArgumentParser(description = __doc__) |
| 237 | + parser.add_argument( |
| 238 | + "--sizes", |
| 239 | + nargs = "+", |
| 240 | + default = ["small"], |
| 241 | + choices = ["small", "medium", "large", "real"], |
| 242 | + help = "Which size presets to run.", |
| 243 | + ) |
| 244 | + parser.add_argument( |
| 245 | + "--num-draws", |
| 246 | + type = int, |
| 247 | + default = 250, |
| 248 | + help = "Draws for both warmup and sampling (matches current default).", |
| 249 | + ) |
| 250 | + parser.add_argument( |
| 251 | + "--out", |
| 252 | + type = str, |
| 253 | + default = "~/data/openpois/bench/bench_latest.json", |
| 254 | + help = "JSON output path.", |
| 255 | + ) |
| 256 | + parser.add_argument("--seed", type = int, default = 0) |
| 257 | + args = parser.parse_args() |
| 258 | + |
| 259 | + out_path = Path(args.out).expanduser() |
| 260 | + out_path.parent.mkdir(parents = True, exist_ok = True) |
| 261 | + |
| 262 | + rng = jrd.PRNGKey(args.seed) if args.seed else jax_rng() |
| 263 | + |
| 264 | + runs = [] |
| 265 | + for size in args.sizes: |
| 266 | + if size == "real": |
| 267 | + df = _load_real_observations() |
| 268 | + runs.append(_run_one("real", df, num_draws = args.num_draws)) |
| 269 | + continue |
| 270 | + preset = SIZE_PRESETS[size] |
| 271 | + key, rng = jrd.split(rng) |
| 272 | + df = _simulate( |
| 273 | + key, |
| 274 | + n = preset["n"], |
| 275 | + k = preset["k"], |
| 276 | + min_per_group = preset["min_per_group"], |
| 277 | + ) |
| 278 | + runs.append(_run_one(size, df, num_draws = args.num_draws)) |
| 279 | + |
| 280 | + payload = { |
| 281 | + "created_at": datetime.now(timezone.utc).isoformat(), |
| 282 | + "git_sha": _git_sha(), |
| 283 | + "jax_version": jax.__version__, |
| 284 | + "platform": jax.default_backend(), |
| 285 | + "num_draws": args.num_draws, |
| 286 | + "seed": args.seed, |
| 287 | + "runs": runs, |
| 288 | + } |
| 289 | + |
| 290 | + with open(out_path, "w") as f: |
| 291 | + json.dump(payload, f, indent = 2) |
| 292 | + print(f"Wrote {out_path}") |
| 293 | + for r in runs: |
| 294 | + print( |
| 295 | + f" {r['tag']:>6s}: n={r['n']:>9,} k={r['k']:>3} " |
| 296 | + f"fit={r['wall_fit_s']:>7.2f}s " |
| 297 | + f"accept={r['mean_acceptance']:.3f} " |
| 298 | + f"div={r['divergences']:>4} " |
| 299 | + f"step={r['final_step_size']:.4f}" |
| 300 | + ) |
| 301 | + |
| 302 | + |
| 303 | +if __name__ == "__main__": |
| 304 | + main() |
0 commit comments