Skip to content

Commit 03be559

Browse files
committed
JAX performance tuning and diagnostics.
1 parent 15a346e commit 03be559

8 files changed

Lines changed: 1316 additions & 102 deletions

File tree

config.yaml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,13 @@ osm_turnover_model:
138138
group_key: shared_label
139139
group_values: null
140140
min_value_count: 5
141-
n_draws: 250
141+
# NUTS warmup (window adaptation) and retained-sample counts. Warmup should
142+
# generally be >= n_samples for hierarchical models.
143+
n_warmup: 500
144+
n_samples: 500
145+
# Number of independent chains (vmapped in parallel). n_chains > 1 enables
146+
# R-hat and bulk ESS diagnostics at roughly linear wall-time cost on CPU.
147+
n_chains: 4
142148
save_full_model: true
143149

144150
# Directory definitions (used with config.get_dir_path())
@@ -172,6 +178,8 @@ directories:
172178
fitted_params: fitted_params.csv
173179
param_draws: param_draws.csv
174180
predictions: predictions.csv
181+
diagnostics: diagnostics.csv
182+
inference_data: inference_data.nc
175183
fitted_model: fitted_model.pkl
176184
snapshot_foursquare:
177185
versioned: true
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
"""
2+
Benchmark harness for the JAX hierarchical turnover model.
3+
4+
Runs ``RandomByTypeModel`` (and optionally ``ConstantModel``) at several data
5+
sizes and reports:
6+
7+
* Wall time for warmup and sampling (seconds)
8+
* NUTS diagnostics: mean acceptance, divergent count, mean integration steps,
9+
final step size
10+
* Per-parameter effective sample size (via ``blackjax.diagnostics``)
11+
* Log-density at the posterior mean (proxy for MAP)
12+
13+
Results are dumped to JSON so they can be diffed across commits. Run this
14+
BEFORE changing the model code to capture a baseline, then again after each
15+
phase to attribute effects.
16+
17+
Usage:
18+
python scripts/exploratory/bench_random_by_type.py \\
19+
--sizes small medium [large] [real] \\
20+
--out ~/data/openpois/bench/baseline.json \\
21+
[--num-draws 250]
22+
23+
Size presets:
24+
small — n = 10 000, K = 20
25+
medium — n = 1 000 000, K = 91
26+
large — n = 4 200 000, K = 91 (slow; ~matches production scale)
27+
real — reads real osm_observations.csv via config.yaml
28+
"""
29+
30+
from __future__ import annotations
31+
32+
import argparse
33+
import json
34+
import subprocess
35+
import time
36+
from datetime import datetime, timezone
37+
from pathlib import Path
38+
39+
import jax
40+
import jax.numpy as jnp
41+
import jax.random as jrd
42+
import numpy as np
43+
import pandas as pd
44+
45+
from blackjax.diagnostics import effective_sample_size
46+
47+
from openpois.models.jax_core import jax_rng
48+
from openpois.models.model_fitter import ModelFitter
49+
from openpois.models.osm_models import RandomByTypeModel
50+
from openpois.models.setup import prepare_data_for_model
51+
52+
53+
SIZE_PRESETS = {
54+
"small": dict(n = 10_000, k = 20, min_per_group = 5),
55+
"medium": dict(n = 1_000_000, k = 91, min_per_group = 5),
56+
"large": dict(n = 4_200_000, k = 91, min_per_group = 5),
57+
}
58+
59+
60+
def _simulate(
61+
key: jrd.KeyArray,
62+
n: int,
63+
k: int,
64+
min_per_group: int = 5,
65+
true_log_lambda_0: float = -5.3,
66+
true_log_sigma: float = 0.8,
67+
) -> pd.DataFrame:
68+
"""Simulate an observations DataFrame from the RandomByTypeModel likelihood."""
69+
k_eps, k_grp, k_dt, k_y = jrd.split(key, 4)
70+
# Simulate group epsilons from N(0, exp(log_sigma))
71+
eps = np.asarray(
72+
jrd.normal(k_eps, (k,)) * np.exp(true_log_sigma)
73+
)
74+
log_lam = true_log_lambda_0 + eps
75+
76+
# Group assignment with a power-law-ish imbalance so we test uneven fits.
77+
# Weights ~ 1/(i+1); renormalised. Then enforce min_per_group per group.
78+
weights = 1.0 / (np.arange(k) + 1.0)
79+
weights = weights / weights.sum()
80+
# Sample (n - k*min_per_group) according to weights, then add min_per_group per group
81+
assert n > k * min_per_group, "n too small for requested min_per_group"
82+
n_weighted = n - k * min_per_group
83+
g_rand = np.asarray(
84+
jrd.categorical(k_grp, jnp.log(jnp.asarray(weights)), shape = (n_weighted,))
85+
)
86+
g = np.concatenate([
87+
g_rand,
88+
np.repeat(np.arange(k), min_per_group),
89+
]).astype(np.int32)
90+
rng = np.random.default_rng(int(jrd.randint(k_grp, (), 0, 2**31 - 1)))
91+
rng.shuffle(g)
92+
93+
# dt ~ Uniform(0.1, 10)
94+
dt = np.asarray(jrd.uniform(k_dt, (n,), minval = 0.1, maxval = 10.0))
95+
96+
lam_per_obs = np.exp(log_lam[g])
97+
p = 1.0 - np.exp(-lam_per_obs * dt)
98+
y = np.asarray(jrd.bernoulli(k_y, jnp.asarray(p))).astype(np.int32)
99+
100+
# Use string group labels so the category encoding exercises the real code path.
101+
group_names = np.array([f"grp_{i:03d}" for i in range(k)])
102+
return pd.DataFrame({
103+
"tag_years": dt,
104+
"changed": y,
105+
"shared_label": group_names[g],
106+
})
107+
108+
109+
def _load_real_observations() -> pd.DataFrame:
110+
"""Load the real OSM observations via config.yaml."""
111+
from config_versioned import Config
112+
cfg = Config("~/repos/openpois/config.yaml")
113+
path = cfg.get_file_path("osm_data", "osm_observations")
114+
min_value_count = cfg.get(
115+
"osm_turnover_model", "min_value_count", fail_if_none = False
116+
)
117+
group_key = cfg.get(
118+
"osm_turnover_model", "group_key", fail_if_none = False
119+
)
120+
df = pd.read_csv(path)
121+
prepared = prepare_data_for_model(
122+
data = df,
123+
group_key = group_key,
124+
group_values = None,
125+
min_value_count = min_value_count,
126+
t1_col = "last_tag_timestamp",
127+
t2_col = "obs_timestamp",
128+
)
129+
return prepared
130+
131+
132+
def _ess_per_param(param_draws: dict[str, jnp.ndarray]) -> dict[str, float]:
133+
"""Minimum ESS across elements of each pytree leaf."""
134+
out = {}
135+
for name, arr in param_draws.items():
136+
a = jnp.asarray(arr)
137+
if a.ndim == 1:
138+
ess = float(effective_sample_size(a[None, :]))
139+
out[name] = ess
140+
else:
141+
# Multiple elements: report min ESS (worst-case)
142+
flat = a.reshape(a.shape[0], -1).T # (n_elem, n_draws)
143+
esss = np.asarray(
144+
jax.vmap(lambda row: effective_sample_size(row[None, :]))(flat)
145+
)
146+
out[f"{name}__min"] = float(esss.min())
147+
out[f"{name}__median"] = float(np.median(esss))
148+
return out
149+
150+
151+
def _log_density_at_mean(
152+
fitter: ModelFitter,
153+
param_draws: dict[str, jnp.ndarray],
154+
) -> float:
155+
"""Evaluate log-density at the element-wise posterior mean."""
156+
post_mean = {
157+
name: jnp.mean(jnp.asarray(arr), axis = 0)
158+
for name, arr in param_draws.items()
159+
}
160+
return float(fitter.calculate_lp(post_mean))
161+
162+
163+
def _git_sha() -> str:
164+
try:
165+
return subprocess.check_output(
166+
["git", "rev-parse", "--short", "HEAD"],
167+
cwd = str(Path(__file__).resolve().parents[2]),
168+
).decode().strip()
169+
except Exception:
170+
return "unknown"
171+
172+
173+
def _run_one(
174+
tag: str,
175+
df: pd.DataFrame,
176+
num_draws: int,
177+
group_key: str = "shared_label",
178+
) -> dict:
179+
"""Build the model, fit it, collect timings and diagnostics."""
180+
n = len(df)
181+
model = RandomByTypeModel(
182+
dataset = df,
183+
metadata = {
184+
"dt_col": "tag_years",
185+
"group": group_key,
186+
"var_prior": (-1.0, 5.0),
187+
},
188+
)
189+
k = model.group_lookup.shape[0]
190+
print(f"[{tag}] n={n:,} k={k} — building fitter")
191+
192+
fitter = ModelFitter(
193+
event_rate_fun = model.event_rate_fun,
194+
starting_params = model.starting_params,
195+
data = model.data,
196+
target = model.target,
197+
num_warmup = num_draws,
198+
num_samples = num_draws,
199+
param_likelihood = model.param_likelihood,
200+
derive_draws = model.derive_draws,
201+
log_likelihood_fun = model.log_likelihood_fun,
202+
verbose = False,
203+
)
204+
205+
t_fit_start = time.perf_counter()
206+
fitter.fit()
207+
# Ensure the draws are realised on device before we stop the clock.
208+
jax.tree_util.tree_map(lambda x: x.block_until_ready(), fitter.param_draws)
209+
t_fit_end = time.perf_counter()
210+
211+
info = fitter.sampler_info
212+
mean_accept = float(jnp.mean(info.acceptance_rate))
213+
divergences = int(jnp.sum(info.is_divergent))
214+
mean_steps = float(jnp.mean(info.num_integration_steps))
215+
step_size = float(fitter.warmup_params["step_size"])
216+
217+
ess = _ess_per_param(fitter.param_draws)
218+
log_density_at_mean = _log_density_at_mean(fitter, fitter.param_draws)
219+
220+
return {
221+
"tag": tag,
222+
"n": int(n),
223+
"k": int(k),
224+
"num_draws": int(num_draws),
225+
"wall_fit_s": round(t_fit_end - t_fit_start, 3),
226+
"mean_acceptance": round(mean_accept, 4),
227+
"divergences": divergences,
228+
"mean_integration_steps": round(mean_steps, 3),
229+
"final_step_size": round(step_size, 6),
230+
"log_density_at_post_mean": round(log_density_at_mean, 3),
231+
"ess": {k: round(v, 2) for k, v in ess.items()},
232+
}
233+
234+
235+
def main():
236+
parser = argparse.ArgumentParser(description = __doc__)
237+
parser.add_argument(
238+
"--sizes",
239+
nargs = "+",
240+
default = ["small"],
241+
choices = ["small", "medium", "large", "real"],
242+
help = "Which size presets to run.",
243+
)
244+
parser.add_argument(
245+
"--num-draws",
246+
type = int,
247+
default = 250,
248+
help = "Draws for both warmup and sampling (matches current default).",
249+
)
250+
parser.add_argument(
251+
"--out",
252+
type = str,
253+
default = "~/data/openpois/bench/bench_latest.json",
254+
help = "JSON output path.",
255+
)
256+
parser.add_argument("--seed", type = int, default = 0)
257+
args = parser.parse_args()
258+
259+
out_path = Path(args.out).expanduser()
260+
out_path.parent.mkdir(parents = True, exist_ok = True)
261+
262+
rng = jrd.PRNGKey(args.seed) if args.seed else jax_rng()
263+
264+
runs = []
265+
for size in args.sizes:
266+
if size == "real":
267+
df = _load_real_observations()
268+
runs.append(_run_one("real", df, num_draws = args.num_draws))
269+
continue
270+
preset = SIZE_PRESETS[size]
271+
key, rng = jrd.split(rng)
272+
df = _simulate(
273+
key,
274+
n = preset["n"],
275+
k = preset["k"],
276+
min_per_group = preset["min_per_group"],
277+
)
278+
runs.append(_run_one(size, df, num_draws = args.num_draws))
279+
280+
payload = {
281+
"created_at": datetime.now(timezone.utc).isoformat(),
282+
"git_sha": _git_sha(),
283+
"jax_version": jax.__version__,
284+
"platform": jax.default_backend(),
285+
"num_draws": args.num_draws,
286+
"seed": args.seed,
287+
"runs": runs,
288+
}
289+
290+
with open(out_path, "w") as f:
291+
json.dump(payload, f, indent = 2)
292+
print(f"Wrote {out_path}")
293+
for r in runs:
294+
print(
295+
f" {r['tag']:>6s}: n={r['n']:>9,} k={r['k']:>3} "
296+
f"fit={r['wall_fit_s']:>7.2f}s "
297+
f"accept={r['mean_acceptance']:.3f} "
298+
f"div={r['divergences']:>4} "
299+
f"step={r['final_step_size']:.4f}"
300+
)
301+
302+
303+
if __name__ == "__main__":
304+
main()

scripts/models/osm_turnover.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
osm_turnover_model.default_model_type — "constant" or "random_by_type"
2525
(overridable via --model-type)
2626
osm_turnover_model.var_prior — (loc, scale) hyperprior on log_sigma
27-
osm_turnover_model.n_draws — number of posterior draws
27+
osm_turnover_model.n_warmup — NUTS warmup steps (adaptation)
28+
osm_turnover_model.n_samples — posterior draws retained
29+
osm_turnover_model.n_chains — number of NUTS chains (vmapped)
2830
osm_turnover_model.save_full_model — save param_draws and pickled fitter
2931
3032
Prerequisites:
@@ -33,6 +35,8 @@
3335
Output files (in ``model_output`` directory):
3436
fitted_params.csv — posterior summaries per parameter
3537
predictions.csv — P(change) at t = 0.0..10.0 years per group
38+
diagnostics.csv — per-parameter R-hat / bulk-ESS (multi-chain only)
39+
inference_data.nc — ArviZ InferenceData (optional, if arviz installed)
3640
param_draws.csv — posterior draws (if save_full_model = true)
3741
fitted_model.pkl — pickled ModelFitter (if save_full_model = true)
3842
"""
@@ -60,7 +64,19 @@
6064
MIN_VALUE_COUNT = config.get(
6165
"osm_turnover_model", "min_value_count", fail_if_none = False
6266
)
63-
N_DRAWS = config.get("osm_turnover_model", "n_draws")
67+
N_WARMUP = config.get("osm_turnover_model", "n_warmup", fail_if_none = False)
68+
N_SAMPLES = config.get("osm_turnover_model", "n_samples", fail_if_none = False)
69+
N_CHAINS = config.get("osm_turnover_model", "n_chains", fail_if_none = False)
70+
# Back-compat: older configs used `n_draws` for both warmup and sampling.
71+
_LEGACY_N_DRAWS = config.get(
72+
"osm_turnover_model", "n_draws", fail_if_none = False
73+
)
74+
if N_WARMUP is None:
75+
N_WARMUP = _LEGACY_N_DRAWS if _LEGACY_N_DRAWS is not None else 1_000
76+
if N_SAMPLES is None:
77+
N_SAMPLES = _LEGACY_N_DRAWS if _LEGACY_N_DRAWS is not None else 1_000
78+
if N_CHAINS is None:
79+
N_CHAINS = 1
6480
SAVE_FULL_MODEL = config.get("osm_turnover_model", "save_full_model")
6581

6682

@@ -137,8 +153,12 @@ def flatten_param_draws(
137153
starting_params = model.starting_params,
138154
data = model.data,
139155
target = model.target,
140-
num_draws = N_DRAWS,
156+
num_warmup = N_WARMUP,
157+
num_samples = N_SAMPLES,
158+
num_chains = N_CHAINS,
141159
param_likelihood = model.param_likelihood,
160+
derive_draws = model.derive_draws,
161+
log_likelihood_fun = model.log_likelihood_fun,
142162
verbose = True,
143163
)
144164
fitter.fit()
@@ -176,6 +196,15 @@ def flatten_param_draws(
176196
# Save ----------------------------------------------------------------->
177197
config.write(fitted_params, "model_output", "fitted_params")
178198
config.write(predictions, "model_output", "predictions")
199+
if fitter.diagnostics is not None:
200+
config.write(fitter.diagnostics, "model_output", "diagnostics")
201+
try:
202+
idata = fitter.to_inference_data()
203+
idata.to_netcdf(
204+
str(config.get_file_path("model_output", "inference_data"))
205+
)
206+
except ImportError:
207+
print("arviz not installed — skipping inference_data.nc")
179208
if SAVE_FULL_MODEL:
180209
config.write(
181210
flatten_param_draws(fitter.get_parameter_draws()),

0 commit comments

Comments
 (0)