Skip to content

Commit 7c72b8e

Browse files
committed
Fully working random effects model; update viz for clarity.
1 parent 03be559 commit 7c72b8e

8 files changed

Lines changed: 29 additions & 26 deletions

File tree

.claude/CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ Style: Black (format-on-save in VSCode). Lint: flake8 + pylint, configured in `p
4949
- [docs/data-sources.md](docs/data-sources.md) — URLs, auth, schema quirks for every source
5050
- [docs/taxonomy-setup.md](docs/taxonomy-setup.md) — crosswalk CSVs, build_taxonomy.py, frontend sync
5151
- [docs/data-versioning.md](docs/data-versioning.md)`versions:` block, path resolution, external references
52+
- [docs/turnover-model-methodology.md](docs/turnover-model-methodology.md) — statistical derivation of the POI turnover model with ZIE extension
5253

5354
## Running to-do
5455

.claude/skills/model-history-pipeline/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ End-to-end: Geofabrik full-history PBFs → observations table → fitted λ →
3131
```bash
3232
python scripts/models/osm_turnover.py
3333
```
34-
Writes `fitted_params.csv`, `param_draws.csv`, `predictions.csv`, `fitted_model.pkl` to `{date}_by_shared_label` (the unified random-effects model) or `{date}_constant` (single-rate baseline) under `directories.model_output.path`.
34+
Writes `fitted_params.csv`, `param_draws.csv`, `predictions.csv` to `{date}_by_shared_label` (the unified random-effects model) or `{date}_constant` (single-rate baseline) under `directories.model_output.path`.
3535

3636
4. **Apply predictions to the OSM snapshot**`osm_snapshot_rated.parquet`
3737
```bash

.claude/skills/verify-pipeline-run/SKILL.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ Flag >5% drops. Known regression patterns:
3232
fitted_params.csv # λ and σ per group
3333
param_draws.csv # uncertainty bounds
3434
predictions.csv # predictions per POI
35-
fitted_model.pkl # pickled JAX ModelFitter
3635
```
3736

3837
Checks:

config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ osm_data:
126126
- last_obs_timestamp
127127
- last_tag_timestamp
128128
apply_model:
129-
model_stub: '20260416'
129+
model_stub: '20260421'
130130

131131
# Settings for scripts/models/osm_turnover.py (JAX turnover model)
132132
osm_turnover_model:
@@ -180,7 +180,6 @@ directories:
180180
predictions: predictions.csv
181181
diagnostics: diagnostics.csv
182182
inference_data: inference_data.nc
183-
fitted_model: fitted_model.pkl
184183
snapshot_foursquare:
185184
versioned: true
186185
path: ~/data/openpois/snapshots/foursquare

docs/workflows.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ See :mod:`openpois.osm.format_observations`.
107107
Fits an empirical Bayes JAX model (constant or random-effects by type)
108108
estimating the Poisson change rate λ per group via BlackJAX NUTS. Outputs
109109
``fitted_params.csv`` and ``predictions.csv`` (and optionally
110-
``param_draws.csv`` / ``fitted_model.pkl``).
110+
``param_draws.csv``).
111111

112112
See :mod:`openpois.models.model_fitter`, :mod:`openpois.models.osm_models`,
113113
and :mod:`openpois.models.setup`.

scripts/models/osm_turnover.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,9 @@
3838
diagnostics.csv — per-parameter R-hat / bulk-ESS (multi-chain only)
3939
inference_data.nc — ArviZ InferenceData (optional, if arviz installed)
4040
param_draws.csv — posterior draws (if save_full_model = true)
41-
fitted_model.pkl — pickled ModelFitter (if save_full_model = true)
4241
"""
4342

4443
import argparse
45-
import pickle
4644

4745
import jax.numpy as jnp
4846
import numpy as np
@@ -211,7 +209,3 @@ def flatten_param_draws(
211209
"model_output",
212210
"param_draws",
213211
)
214-
with open(
215-
config.get_file_path("model_output", "fitted_model"), "wb",
216-
) as fh:
217-
pickle.dump(fitter, fh)

scripts/osm_data/data_viz.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
osm_changes_all.png — overall survival curve
2828
osm_changes_all_preds.png — overall curve with constant-model
2929
prediction overlay
30-
osm_changes_by_shared_label.png — per-label facet grid
31-
osm_changes_shared_label_preds_<label>.png — per-label curves with
32-
shared-label model predictions
30+
osm_changes_by_shared_label.png — per-label facet grid (top N)
31+
by_type/osm_changes_<label>.png — per-label curves with
32+
shared-label model predictions,
33+
one file per shared_label with a
34+
fitted prediction
3335
"""
3436

3537
import numpy as np
@@ -68,20 +70,29 @@
6870

6971

7072
def fig_save(
71-
fig: gg.ggplot, stub: str, width: float = 10, height: float = 6, **kwargs
73+
fig: gg.ggplot,
74+
stub: str,
75+
width: float = 10,
76+
height: float = 6,
77+
subdir: str | None = None,
78+
**kwargs,
7279
) -> None:
7380
"""
74-
Save a ggplot figure as a PNG file to VIZ_DIR.
81+
Save a ggplot figure as a PNG file to VIZ_DIR (or a subdirectory of it).
7582
7683
Args:
7784
fig: The ggplot figure to save.
7885
stub: Output filename stem (without extension).
7986
width: Figure width in inches.
8087
height: Figure height in inches.
88+
subdir: Optional subdirectory under VIZ_DIR. Created if it doesn't
89+
exist.
8190
**kwargs: Additional keyword arguments forwarded to fig.save().
8291
"""
92+
out_dir = VIZ_DIR if subdir is None else VIZ_DIR / subdir
93+
out_dir.mkdir(parents = True, exist_ok = True)
8394
fig.save(
84-
filename = VIZ_DIR / f"{stub}.png",
95+
filename = out_dir / f"{stub}.png",
8596
width = width,
8697
height = height,
8798
units = 'in',
@@ -209,12 +220,9 @@ def get_preds_df(version: str) -> pd.DataFrame | None:
209220
fig_save(fig = fig, stub = "osm_changes_by_shared_label")
210221

211222
if preds.get('shared_label') is not None:
212-
top_n_labels = (
213-
to_plot_df["shared_label"].value_counts().head(TOP_N_TYPES).index
214-
)
215-
pred_groups = preds['shared_label']['group_name'].unique().tolist()
216-
keep_preds = list(set(top_n_labels) & set(pred_groups))
217-
for pred_label in keep_preds:
223+
observed_labels = set(to_plot_df["shared_label"].dropna().unique())
224+
pred_groups = set(preds['shared_label']['group_name'].unique())
225+
for pred_label in sorted(pred_groups & observed_labels):
218226
print(f"Plotting shared_label = {pred_label}")
219227
fig = change_plot_create(
220228
observations = to_plot_df.query(
@@ -235,5 +243,6 @@ def get_preds_df(version: str) -> pd.DataFrame | None:
235243
)
236244
fig_save(
237245
fig,
238-
stub = f"osm_changes_shared_label_preds_{pred_label}",
246+
stub = f"osm_changes_{pred_label}",
247+
subdir = "by_type",
239248
)

src/openpois/osm/change_plots.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,10 @@ def change_plot_create(
166166
)
167167
fig = fig + gg.geom_ribbon(
168168
data = p_renamed,
169-
fill = 'red', alpha = 0.25, linetype = 'dashed', color = '#444444'
169+
fill = 'darkred', alpha = 0.25, linetype = 'dashed',
170+
color = 'darkred',
170171
) + gg.geom_line(
171-
data = p_renamed, color = '#444444',
172+
data = p_renamed, color = 'darkred',
172173
mapping = gg.aes(x = 'year', y = 'y')
173174
)
174175
return fig

0 commit comments

Comments
 (0)