Commit bd6c331

Merge pull request #21 from henryspatialanalysis/feature/jax
Rebuild POI change models in JAX
2 parents 697ce26 + 0726a72 commit bd6c331

38 files changed

Lines changed: 4319 additions & 1215 deletions

.claude/CLAUDE.md

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ Style: Black (format-on-save in VSCode). Lint: flake8 + pylint, configured in `p
3939
|---|---|
4040
| [src/openpois/io/](../src/openpois/io/) | I/O adapters: OSM history/snapshot, Overture, Foursquare, Census boundary |
4141
| [src/openpois/osm/](../src/openpois/osm/) | OSM-specific transforms: `format_observations`, `change_plots` |
42-
| [src/openpois/models/](../src/openpois/models/) | PyTorch empirical Bayes: `EventRate`, `ModelFitter`, model registry |
42+
| [src/openpois/models/](../src/openpois/models/) | JAX/BlackJAX empirical Bayes: `ModelFitter`, model registry |
4343
| [src/openpois/conflation/](../src/openpois/conflation/) | OSM×Overture matching: `taxonomy`, `match`, `merge` |
4444
| [scripts/](../scripts/) | End-to-end pipelines using config.yaml — not installed, reference only |
4545
| [site/](../site/) | Vue 3 + Vite frontend |
@@ -49,6 +49,7 @@ Style: Black (format-on-save in VSCode). Lint: flake8 + pylint, configured in `p
4949
- [docs/data-sources.md](docs/data-sources.md) — URLs, auth, schema quirks for every source
5050
- [docs/taxonomy-setup.md](docs/taxonomy-setup.md) — crosswalk CSVs, build_taxonomy.py, frontend sync
5151
- [docs/data-versioning.md](docs/data-versioning.md) — `versions:` block, path resolution, external references
52+
- [docs/turnover-model-methodology.md](docs/turnover-model-methodology.md) — statistical derivation of the POI turnover model with ZIE extension
5253

5354
## Running to-do
5455

.claude/docs/turnover-model-methodology.md

Lines changed: 391 additions & 0 deletions
Large diffs are not rendered by default.

.claude/skills/iterate-model-types/SKILL.md

Lines changed: 7 additions & 6 deletions
@@ -9,20 +9,21 @@ Rerun just the modeling step (step 3 of the full pipeline) for a fixed source-da
99

1010
## Prerequisites
1111

12-
- `osm_observations_{tag_key}.csv` already generated by `scripts/osm_data/format_tabular.py` at the pinned `versions.osm_data`.
12+
- `osm_observations.parquet` already generated by `scripts/osm_data/format_tabular.py` at the pinned `versions.osm_data`. Each row is one (POI version, shared_label) pair — POIs mapping to multiple taxonomy categories are exploded into multiple rows.
1313
- You know which model variant you want next.
1414

1515
## Steps
1616

1717
1. **Pin `versions.osm_data`** in `config.yaml`. Do *not* change it — that's the whole point.
1818

1919
2. **Bump `versions.model_output`** using the convention:
20-
- `{date}_by_{group_key}` for grouped models (e.g., `20260416_by_leisure`, `20260416_by_amenity`)
20+
- `{date}_by_shared_label` for the unified random-effects model (one intercept per taxonomy category)
21+
- `{date}_by_{group_key}` for ad-hoc groupings on other columns
2122
- `{date}_constant` for the single-rate baseline
2223

2324
3. **Edit `osm_turnover_model`** in `config.yaml`:
24-
- `model_type`: one of `constant`, `random_by_type`, `pseudo_varying` (registry at [src/openpois/models/osm_models.py](../../../src/openpois/models/osm_models.py))
25-
- `group_key`: column to group by (e.g., `leisure_last_value`, `shop`, `amenity`). Null for constant.
25+
- `model_type`: one of `constant`, `random_by_type` (registry at [src/openpois/models/osm_models.py](../../../src/openpois/models/osm_models.py))
26+
- `group_key`: column to group by. Default `shared_label` (unified taxonomy). Null for constant. Other observation columns (raw OSM keys like `shop`, `amenity`, ...) are still accepted.
2627
- `group_values`: restrict to specific values, or null for all
2728
- `min_value_count`: drop groups below this count
2829

@@ -35,7 +36,7 @@ Rerun just the modeling step (step 3 of the full pipeline) for a fixed source-da
3536
```bash
3637
python scripts/osm_snapshot/apply_model.py
3738
```
38-
`apply_model.py` picks up every `{stub}_by_*` directory at the stub date and falls back to `{stub}_constant` for unmatched groups.
39+
`apply_model.py` loads a single `{stub}_by_shared_label` random-effects model (if present) and falls back to `{stub}_constant` for rows with no matching taxonomy label.
3940

4041
## Comparing variants
4142

@@ -46,5 +47,5 @@ Rerun just the modeling step (step 3 of the full pipeline) for a fixed source-da
4647
## Pitfalls
4748

4849
- Forgetting to change `versions.model_output` overwrites the previous variant's outputs.
49-
- `group_key` must exist as a column in the observations CSV; run `format_tabular.py` first if adding a new tag key to `osm_data.tag_key`.
50+
- `group_key` must exist as a column in the observations table (`osm_observations.parquet`). `shared_label` is populated by `format_tabular.py` from the conflation taxonomy crosswalk; if you change the crosswalk, rerun `format_tabular.py`.
5051
- `min_value_count` filters groups silently — check `fitted_params.csv` row count vs. expected group count.
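
The last pitfall can be checked mechanically. The following is a minimal sketch, not code from this repository: it assumes `fitted_params.csv` carries one row per group under a column named after `group_key`, and that `min_value_count` keeps groups with at least that many rows. Both are assumptions, and the `lambda` column and sample values are made up for illustration.

```python
import csv
import io
from collections import Counter


def expected_groups(rows, group_key, min_value_count):
    """Return the group values that have at least `min_value_count` rows.

    Assumption: the threshold is inclusive (a group with exactly
    `min_value_count` rows is kept).
    """
    counts = Counter(row[group_key] for row in rows)
    return {value for value, n in counts.items() if n >= min_value_count}


def missing_groups(observations, fitted_params, group_key, min_value_count):
    """List groups that pass the count filter but have no fitted row."""
    expected = expected_groups(observations, group_key, min_value_count)
    fitted = {row[group_key] for row in fitted_params}
    return sorted(expected - fitted)


# Tiny in-memory stand-ins for the observations table and fitted_params.csv.
observations = (
    [{"shared_label": "cafe"}] * 6
    + [{"shared_label": "park"}] * 5
    + [{"shared_label": "zoo"}] * 2
)
fitted_csv = "shared_label,lambda\ncafe,0.12\n"  # hypothetical columns
fitted_params = list(csv.DictReader(io.StringIO(fitted_csv)))

# "zoo" is below the threshold, so only "park" is reported as missing.
print(missing_groups(observations, fitted_params, "shared_label", 5))
```

An empty result means every group that survived the filter has a fitted row; anything listed is a group that disappeared for a reason other than `min_value_count`.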

.claude/skills/model-history-pipeline/SKILL.md

Lines changed: 5 additions & 5 deletions
@@ -21,23 +21,23 @@ End-to-end: Geofabrik full-history PBFs → observations table → fitted λ →
2121
```
2222
Runs `osmium tags-filter --omit-referenced` then `osmium time-filter`, then pyosmium streams results. Controlled by `download.osm.*` in config.yaml.
2323

24-
2. **Format tabular observations** → `osm_observations_{tag_key}.csv`
24+
2. **Format tabular observations** → `osm_observations.parquet`
2525
```bash
2626
python scripts/osm_data/format_tabular.py
2727
```
28-
Uses `osm_data.tag_key` (e.g., `name`) to produce one observation row per version with change/deletion flags.
28+
Uses `osm_data.tag_key` (e.g., `name`) to flag change/deletion per POI version, then assigns shared taxonomy labels from the conflation crosswalk and explodes rows per label. One row = (POI version, shared_label). Rows with no matching taxonomy category are dropped.
2929

3030
3. **Pick a modeling config and fit λ** — see [skills/iterate-model-types](../iterate-model-types/SKILL.md) for choosing `model_type` / `group_key`.
3131
```bash
3232
python scripts/models/osm_turnover.py
3333
```
34-
Writes `fitted_params.csv`, `param_draws.csv`, `predictions.csv`, `fitted_model.pt` to `{date}_by_{group_key}` (or `{date}_constant`) under `directories.model_output.path`.
34+
Writes `fitted_params.csv`, `param_draws.csv`, `predictions.csv` to `{date}_by_shared_label` (the unified random-effects model) or `{date}_constant` (single-rate baseline) under `directories.model_output.path`.
3535

3636
4. **Apply predictions to the OSM snapshot** → `osm_snapshot_rated.parquet`
3737
```bash
3838
python scripts/osm_snapshot/apply_model.py
3939
```
40-
Reads the `osm_data.apply_model.model_stub` date, loads all `{stub}_by_*` dirs (plus a `{stub}_constant` fallback), and rates every POI in `osm_snapshot.parquet`.
40+
Reads the `osm_data.apply_model.model_stub` date, loads the `{stub}_by_shared_label` random-effects model (if present), falls back to `{stub}_constant` for rows with no matching taxonomy label, and rates every POI in `osm_snapshot.parquet`.
4141

4242
## Verification
4343

@@ -51,5 +51,5 @@ Hand off to [skills/verify-pipeline-run](../verify-pipeline-run/SKILL.md) — in
5151

5252
- Entry: [src/openpois/io/osm_history_pbf.py](../../../src/openpois/io/osm_history_pbf.py) (`download_osm_history`)
5353
- Entry: [src/openpois/osm/format_observations.py](../../../src/openpois/osm/format_observations.py)
54-
- Entry: [src/openpois/models/](../../../src/openpois/models/) — `ModelFitter`, `EventRate`, `pytorch_setup`
54+
- Entry: [src/openpois/models/](../../../src/openpois/models/) — `ModelFitter` (JAX/BlackJAX), model classes
5555
- Registry: [src/openpois/models/osm_models.py](../../../src/openpois/models/osm_models.py) — `MODEL_REGISTRY`, `get_model_class`
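
Step 2 above describes exploding observations so that one row is a (POI version, shared_label) pair and unmatched rows are dropped. A minimal sketch of that reshaping follows. The function name, the record fields, and the crosswalk keyed by (tag, value) are all hypothetical, chosen only for illustration; the real logic lives in `format_observations.py` and may differ.

```python
def explode_by_label(versions, crosswalk):
    """Return one row per (POI version, shared_label) pair.

    `crosswalk` maps a (tag, value) pair to a list of taxonomy labels.
    Versions with no matching entry yield no rows, i.e. they are dropped.
    """
    rows = []
    for version in versions:
        labels = crosswalk.get((version["tag"], version["value"]), [])
        for label in labels:
            # Copy the version record and attach a single label to it.
            rows.append({**version, "shared_label": label})
    return rows


# Hypothetical records and crosswalk entries.
versions = [
    {"osm_id": 1, "version": 1, "tag": "amenity", "value": "cafe"},
    {"osm_id": 2, "version": 1, "tag": "shop", "value": "bakery"},
    {"osm_id": 3, "version": 1, "tag": "amenity", "value": "bench"},
]
crosswalk = {
    ("amenity", "cafe"): ["cafe"],
    ("shop", "bakery"): ["bakery", "cafe"],  # maps to two categories
}

rows = explode_by_label(versions, crosswalk)
for row in rows:
    print(row["osm_id"], row["shared_label"])
```

One consequence worth keeping in mind: a POI that maps to several categories contributes the same observation to each of those groups, so per-group row counts can sum to more than the number of POI versions.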

.claude/skills/verify-pipeline-run/SKILL.md

Lines changed: 0 additions & 1 deletion
@@ -32,7 +32,6 @@ Flag >5% drops. Known regression patterns:
3232
fitted_params.csv # λ and σ per group
3333
param_draws.csv # uncertainty bounds
3434
predictions.csv # predictions per POI
35-
fitted_model.pt # torch state
3635
```
3736

3837
Checks:

README.md

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ The core workflow models how long POI tags remain stable over time using histori
7070
```bash
7171
python exploratory/osm_data/download.py # Download OSM history for a bounding box
7272
python exploratory/osm_data/format_tabular.py # Format into observation records
73-
python exploratory/models/pytorch_simple.py # Fit Poisson change-rate model
73+
python scripts/models/osm_turnover.py # Fit Poisson change-rate model (JAX)
7474
```
7575

7676
---

config.yaml

Lines changed: 26 additions & 11 deletions
@@ -1,12 +1,12 @@
11
# Versioned directories (used with config.get_dir_path())
22
versions:
33
osm_data: "20260416"
4-
model_output: "20260416_by_leisure"
4+
model_output: "20260422_by_shared_label"
55
snapshot_osm: "20260417"
66
snapshot_overture: "20260417"
77
snapshot_foursquare: "20260416"
8-
aws: "20260417"
9-
conflation: "20260417"
8+
aws: "20260422"
9+
conflation: "20260422"
1010

1111
# Settings for downloading data
1212
download:
@@ -126,17 +126,29 @@ osm_data:
126126
- last_obs_timestamp
127127
- last_tag_timestamp
128128
apply_model:
129-
model_stub: '20260416'
129+
model_stub: '20260422'
130130

131-
# Settings for exploratory/models/pytorch_simple.py
131+
# Settings for scripts/models/osm_turnover.py (JAX turnover model)
132132
osm_turnover_model:
133-
model_type: random_by_type
134-
tag_key: name
133+
# Overridable at the CLI via --model-type {constant,random_by_type}.
134+
default_model_type: constant
135135
var_prior: [-1.0, 5.0]
136-
group_key: leisure_last_value
136+
# Tight hyperprior on log_tau (random-effect scale for per-group logit_delta
137+
# in RandomByTypeModel). Tau median ≈ exp(-2) ≈ 0.135 on the logit scale —
138+
# shrinks per-group δ toward the global intercept.
139+
logit_delta_var_prior: [-2.0, 0.5]
140+
# Column in osm_observations.parquet for grouping random effects.
141+
# "shared_label" = shared taxonomy category
142+
group_key: shared_label
137143
group_values: null
138144
min_value_count: 5
139-
n_draws: 250
145+
# NUTS warmup (window adaptation) and retained-sample counts. Warmup should
146+
# generally be >= n_samples for hierarchical models.
147+
n_warmup: 500
148+
n_samples: 500
149+
# Number of independent chains (vmapped in parallel). n_chains > 1 enables
150+
# R-hat and bulk ESS diagnostics at roughly linear wall-time cost on CPU.
151+
n_chains: 4
140152
save_full_model: true
141153

142154
# Directory definitions (used with config.get_dir_path())
@@ -161,14 +173,17 @@ directories:
161173
# Legacy Overpass-based pipeline (still used by scripts/osm_data/download.py)
162174
osm_elements: osm_elements.csv
163175
osm_failed_elements: osm_failed_elements.csv
176+
# Modelling-ready observations (one row per POI version × shared_label)
177+
osm_observations: osm_observations.parquet
164178
model_output:
165179
versioned: true
166180
path: ~/data/openpois/osm_turnover_model
167181
files:
168182
fitted_params: fitted_params.csv
169183
param_draws: param_draws.csv
170184
predictions: predictions.csv
171-
fitted_model: fitted_model.pt
185+
diagnostics: diagnostics.csv
186+
inference_data: inference_data.nc
172187
snapshot_foursquare:
173188
versioned: true
174189
path: ~/data/openpois/snapshots/foursquare
@@ -247,7 +262,7 @@ upload:
247262
s3_prefix_osm: "snapshots/osm"
248263
s3_prefix_conflation: "snapshots/conflated"
249264
latest_url_osm: "https://openpois-public.s3.us-west-2.amazonaws.com/snapshots/osm/20260417/osm_snapshot_partitioned/"
250-
latest_url_conflation: "https://openpois-public.s3.us-west-2.amazonaws.com/snapshots/conflated/20260417/conflated_partitioned/"
265+
latest_url_conflation: "https://openpois-public.s3.us-west-2.amazonaws.com/snapshots/conflated/20260422/conflated_partitioned/"
251266
geohash_precision_partition: 4 # ~39 km x 20 km cells; ~1,000–3,000 cells over CONUS
252267
geohash_precision_sort: 6 # ~0.6 km x 1.2 km; fine-grained sort within each partition
253268
pmtiles:
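
The comment on `logit_delta_var_prior` states that the tau median is about exp(-2) ≈ 0.135. A small sketch can reproduce that figure and show how tight the prior is, assuming the two numbers in `[-2.0, 0.5]` are the mean and standard deviation of a normal prior on `log_tau`. That reading matches the stated median but is not confirmed by the diff, and `lognormal_quantile` is a helper written for this example.

```python
import math
from statistics import NormalDist


def lognormal_quantile(mu, sigma, p):
    """Quantile of tau when log(tau) ~ Normal(mu, sigma)."""
    return math.exp(NormalDist(mu, sigma).inv_cdf(p))


# Assumed interpretation of logit_delta_var_prior: [mean, sd] of log_tau.
mu, sigma = -2.0, 0.5

median = lognormal_quantile(mu, sigma, 0.5)
lower = lognormal_quantile(mu, sigma, 0.025)
upper = lognormal_quantile(mu, sigma, 0.975)

print(f"tau median: {median:.3f}")
print(f"central 95% prior interval for tau: [{lower:.3f}, {upper:.3f}]")
```

Under this reading, most of the prior mass keeps tau well below 1 on the logit scale, which is what produces the shrinkage of per-group values toward the global intercept described in the config comment.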

docs/api.rst

Lines changed: 24 additions & 14 deletions
@@ -136,38 +136,48 @@ AWS credentials via environment variables or ``~/.aws/credentials``.
136136
models
137137
------
138138

139-
openpois.models.event_rate
140-
~~~~~~~~~~~~~~~~~~~~~~~~~~
139+
openpois.models.jax_core
140+
~~~~~~~~~~~~~~~~~~~~~~~~
141141

142-
Representation of a Poisson event rate (λ) used by the change-rate model.
143-
Wraps a constant or time-varying λ tensor and computes the probability that
144-
at least one change event occurs within a given time interval via numerical
145-
or closed-form integration.
142+
JAX/BlackJAX helpers: a PRNG factory, a jitted Markov-chain scan, a NUTS
143+
sampler with window adaptation, and a vmap-based predictive-draw utility.
146144

147-
.. automodule:: openpois.models.event_rate
145+
.. automodule:: openpois.models.jax_core
148146
:members:
149147
:undoc-members:
150148
:show-inheritance:
151149

152150
openpois.models.model_fitter
153151
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
154152

155-
L-BFGS optimizer wrapper for POI change-rate models. Fits model parameters
156-
using PyTorch and the ``torchmin`` optimizer, generates posterior parameter
157-
draws for uncertainty quantification, and produces prediction tables of
158-
change probability versus time.
153+
BlackJAX NUTS fitter for POI change-rate models. Takes an ``event_rate_fun``
154+
plus starting parameters as a pytree, runs window-adapted NUTS to draw from
155+
the posterior, and produces posterior summaries and predictive distributions
156+
of change probability versus time.
159157

160158
.. automodule:: openpois.models.model_fitter
161159
:members:
162160
:undoc-members:
163161
:show-inheritance:
164162

163+
openpois.models.osm_models
164+
~~~~~~~~~~~~~~~~~~~~~~~~~~
165+
166+
JAX model classes for OSM turnover. ``ConstantModel`` and
167+
``RandomByTypeModel`` package their own data, priors, and event-rate
168+
functions to hand to ``ModelFitter``. Selectable via ``get_model_class``.
169+
170+
.. automodule:: openpois.models.osm_models
171+
:members:
172+
:undoc-members:
173+
:show-inheritance:
174+
165175
openpois.models.setup
166176
~~~~~~~~~~~~~~~~~~~~~
167177

168-
Environment setup utilities for PyTorch model runs. Selects GPU or CPU
169-
device, configures ``torch_continuum`` optimisation level, and prepares
170-
filtered and grouped observation data for model fitting.
178+
Data-preparation helpers. ``prepare_data_for_model`` filters and groups
179+
observation records and computes the ``tag_years`` elapsed column used as
180+
the per-observation interval length in fitting.
171181

172182
.. automodule:: openpois.models.setup
173183
:members:
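
The module descriptions above refer to the probability that at least one change event occurs within a time interval, and to `tag_years` as the interval length. For a constant Poisson rate λ that probability has the closed form 1 - exp(-λt). The sketch below illustrates only that constant-rate case; the function names are hypothetical and the package's own event-rate functions may be parameterised differently.

```python
import math


def change_probability(rate_per_year, years):
    """P(at least one change in `years`) for a constant Poisson rate.

    Uses 1 - exp(-rate * t), computed with expm1 for accuracy when the
    product rate * t is very small.
    """
    if rate_per_year < 0 or years < 0:
        raise ValueError("rate and interval length must be non-negative")
    return -math.expm1(-rate_per_year * years)


def half_life_years(rate_per_year):
    """Interval length at which the change probability reaches 0.5."""
    return math.log(2.0) / rate_per_year


# Example: an illustrative rate of 0.1 changes per year.
rate = 0.1
for t in (1, 5, 10):
    print(t, round(change_probability(rate, t), 4))
print("half-life (years):", round(half_life_years(rate), 2))
```

Evaluating `change_probability` over a grid of interval lengths gives a curve of change probability against time, the quantity the `model_fitter` description says the predictions cover.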

docs/conf.py

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,9 @@
2828
# Mock heavy imports so Sphinx can parse docstrings without installing the
2929
# full conda environment in CI.
3030
autodoc_mock_imports = [
31-
"torch",
31+
"jax",
32+
"jaxlib",
33+
"blackjax",
3234
"numpy",
3335
"pandas",
3436
"geopandas",

docs/workflows.rst

Lines changed: 9 additions & 8 deletions
@@ -92,8 +92,9 @@ See :mod:`openpois.io.osm_history`.
9292
python scripts/osm_data/format_tabular.py
9393
9494
Converts raw version histories into one-row-per-observation records, each
95-
flagged for whether the configured tag changed. Output:
96-
``osm_observations_{tag_key}.csv``.
95+
flagged for whether the configured ``osm_data.tag_key`` changed, then
96+
assigns a shared taxonomy label and explodes rows for POIs mapping to
97+
multiple labels. Output: ``osm_observations.parquet``.
9798

9899
See :mod:`openpois.osm.format_observations`.
99100

@@ -103,13 +104,13 @@ See :mod:`openpois.osm.format_observations`.
103104
104105
python scripts/models/osm_turnover.py
105106
106-
Fits an empirical Bayes PyTorch model (constant or random-effects by type)
107-
estimating the Poisson change rate λ per group. Outputs ``fitted_params.csv``
108-
and ``predictions.csv`` (and optionally ``param_draws.csv`` /
109-
``fitted_model.pt``).
107+
Fits an empirical Bayes JAX model (constant or random-effects by type)
108+
estimating the Poisson change rate λ per group via BlackJAX NUTS. Outputs
109+
``fitted_params.csv`` and ``predictions.csv`` (and optionally
110+
``param_draws.csv``).
110111

111-
See :mod:`openpois.models.model_fitter`, :mod:`openpois.models.setup`, and
112-
:mod:`openpois.models.event_rate`.
112+
See :mod:`openpois.models.model_fitter`, :mod:`openpois.models.osm_models`,
113+
and :mod:`openpois.models.setup`.
113114

114115
**Step 4 — Visualise stability curves** *(optional)*
115116
