Skip to content

Commit 0726a72

Browse files
committed
Successful run of JAX model by shared_label with a zero-inflation term.
1 parent f11cc16 commit 0726a72

20 files changed

Lines changed: 1196 additions & 178 deletions

File tree

.claude/skills/iterate-model-types/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Rerun just the modeling step (step 3 of the full pipeline) for a fixed source-da
99

1010
## Prerequisites
1111

12-
- `osm_observations.csv` already generated by `scripts/osm_data/format_tabular.py` at the pinned `versions.osm_data`. Each row is one (POI version, shared_label) pair — POIs mapping to multiple taxonomy categories are exploded into multiple rows.
12+
- `osm_observations.parquet` already generated by `scripts/osm_data/format_tabular.py` at the pinned `versions.osm_data`. Each row is one (POI version, shared_label) pair — POIs mapping to multiple taxonomy categories are exploded into multiple rows.
1313
- You know which model variant you want next.
1414

1515
## Steps

.claude/skills/model-history-pipeline/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ End-to-end: Geofabrik full-history PBFs → observations table → fitted λ →
2121
```
2222
Runs `osmium tags-filter --omit-referenced` then `osmium time-filter`, then pyosmium streams results. Controlled by `download.osm.*` in config.yaml.
2323

24-
2. **Format tabular observations**`osm_observations.csv`
24+
2. **Format tabular observations** → `osm_observations.parquet`
2525
```bash
2626
python scripts/osm_data/format_tabular.py
2727
```

config.yaml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Versioned directories (used with config.get_dir_path())
22
versions:
33
osm_data: "20260416"
4-
model_output: "20260421_by_shared_label"
4+
model_output: "20260422_by_shared_label"
55
snapshot_osm: "20260417"
66
snapshot_overture: "20260417"
77
snapshot_foursquare: "20260416"
8-
aws: "20260417"
9-
conflation: "20260417"
8+
aws: "20260422"
9+
conflation: "20260422"
1010

1111
# Settings for downloading data
1212
download:
@@ -126,14 +126,18 @@ osm_data:
126126
- last_obs_timestamp
127127
- last_tag_timestamp
128128
apply_model:
129-
model_stub: '20260421'
129+
model_stub: '20260422'
130130

131131
# Settings for scripts/models/osm_turnover.py (JAX turnover model)
132132
osm_turnover_model:
133133
# Overridable at the CLI via --model-type {constant,random_by_type}.
134134
default_model_type: constant
135135
var_prior: [-1.0, 5.0]
136-
# Column in osm_observations.csv for grouping random effects.
136+
# Tight hyperprior on log_tau (random-effect scale for per-group logit_delta
137+
# in RandomByTypeModel). Tau median ≈ exp(-2) ≈ 0.135 on the logit scale —
138+
# shrinks per-group δ toward the global intercept.
139+
logit_delta_var_prior: [-2.0, 0.5]
140+
# Column in osm_observations.parquet for grouping random effects.
137141
# "shared_label" = shared taxonomy category
138142
group_key: shared_label
139143
group_values: null
@@ -170,7 +174,7 @@ directories:
170174
osm_elements: osm_elements.csv
171175
osm_failed_elements: osm_failed_elements.csv
172176
# Modelling-ready observations (one row per POI version × shared_label)
173-
osm_observations: osm_observations.csv
177+
osm_observations: osm_observations.parquet
174178
model_output:
175179
versioned: true
176180
path: ~/data/openpois/osm_turnover_model
@@ -258,7 +262,7 @@ upload:
258262
s3_prefix_osm: "snapshots/osm"
259263
s3_prefix_conflation: "snapshots/conflated"
260264
latest_url_osm: "https://openpois-public.s3.us-west-2.amazonaws.com/snapshots/osm/20260417/osm_snapshot_partitioned/"
261-
latest_url_conflation: "https://openpois-public.s3.us-west-2.amazonaws.com/snapshots/conflated/20260417/conflated_partitioned/"
265+
latest_url_conflation: "https://openpois-public.s3.us-west-2.amazonaws.com/snapshots/conflated/20260422/conflated_partitioned/"
262266
geohash_precision_partition: 4 # ~39 km x 20 km cells; ~1,000–3,000 cells over CONUS
263267
geohash_precision_sort: 6 # ~0.6 km x 1.2 km; fine-grained sort within each partition
264268
pmtiles:

scripts/exploratory/bench_random_by_type.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
small — n = 10 000, K = 20
2525
medium — n = 1 000 000, K = 91
2626
large — n = 4 200 000, K = 91 (slow; ~matches production scale)
27-
real — reads real osm_observations.csv via config.yaml
27+
real — reads real osm_observations.parquet via config.yaml
2828
"""
2929

3030
from __future__ import annotations
@@ -117,7 +117,7 @@ def _load_real_observations() -> pd.DataFrame:
117117
group_key = cfg.get(
118118
"osm_turnover_model", "group_key", fail_if_none = False
119119
)
120-
df = pd.read_csv(path)
120+
df = pd.read_parquet(path)
121121
prepared = prepare_data_for_model(
122122
data = df,
123123
group_key = group_key,
@@ -199,6 +199,7 @@ def _run_one(
199199
param_likelihood = model.param_likelihood,
200200
derive_draws = model.derive_draws,
201201
log_likelihood_fun = model.log_likelihood_fun,
202+
log_1md_fun = getattr(model, "log_1md_fun", None),
202203
verbose = False,
203204
)
204205

scripts/models/osm_turnover.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Fit an empirical Bayes JAX model for OSM POI tag change rates.
33
4-
Reads ``osm_observations.csv`` (produced by ``osm_data/format_tabular.py``,
4+
Reads ``osm_observations.parquet`` (produced by ``osm_data/format_tabular.py``,
55
one row per (POI version, shared_label)) and fits a Poisson change-rate
66
model using BlackJAX NUTS. The model estimates a per-group change rate λ
77
(events per year). Predictions give the probability that a tag changes
@@ -24,6 +24,10 @@
2424
osm_turnover_model.default_model_type — "constant" or "random_by_type"
2525
(overridable via --model-type)
2626
osm_turnover_model.var_prior — (loc, scale) hyperprior on log_sigma
27+
osm_turnover_model.logit_delta_prior — optional (loc, scale) prior on logit_delta_0
28+
intercept
29+
osm_turnover_model.logit_delta_var_prior — optional (loc, scale) tight hyperprior on
30+
log_tau (per-group δ scale)
2731
osm_turnover_model.n_warmup — NUTS warmup steps (adaptation)
2832
osm_turnover_model.n_samples — posterior draws retained
2933
osm_turnover_model.n_chains — number of NUTS chains (vmapped)
@@ -120,13 +124,15 @@ def flatten_param_draws(
120124
config.write_self("model_output")
121125

122126
# Data preparation ------------------------------------------------------>
123-
observations_df = pd.read_csv(OBSERVATIONS_PATH)
127+
observations_df = pd.read_parquet(OBSERVATIONS_PATH)
128+
# t1_col defaults to "last_obs_timestamp" in prepare_data_for_model, so
129+
# tag_years is the inter-observation interval the per-row Bernoulli-on-
130+
# Poisson likelihood requires (methodology §1.2).
124131
obs_sub = prepare_data_for_model(
125132
data = observations_df,
126133
group_key = GROUP_KEY,
127134
group_values = GROUP_VALUES,
128135
min_value_count = MIN_VALUE_COUNT,
129-
t1_col = "last_tag_timestamp",
130136
t2_col = "obs_timestamp",
131137
)
132138

@@ -135,15 +141,26 @@ def flatten_param_draws(
135141
"osm_turnover_model", "default_model_type"
136142
)
137143
print(f"Model type: {model_type}")
144+
metadata = {
145+
"dt_col": "tag_years",
146+
"group": GROUP_KEY,
147+
"var_prior": tuple(
148+
config.get("osm_turnover_model", "var_prior")
149+
),
150+
}
151+
logit_delta_prior = config.get(
152+
"osm_turnover_model", "logit_delta_prior", fail_if_none = False
153+
)
154+
if logit_delta_prior is not None:
155+
metadata["logit_delta_prior"] = tuple(logit_delta_prior)
156+
logit_delta_var_prior = config.get(
157+
"osm_turnover_model", "logit_delta_var_prior", fail_if_none = False
158+
)
159+
if logit_delta_var_prior is not None:
160+
metadata["logit_delta_var_prior"] = tuple(logit_delta_var_prior)
138161
model = get_model_class(model_type)(
139162
dataset = obs_sub,
140-
metadata = {
141-
"dt_col": "tag_years",
142-
"group": GROUP_KEY,
143-
"var_prior": tuple(
144-
config.get("osm_turnover_model", "var_prior")
145-
),
146-
},
163+
metadata = metadata,
147164
)
148165

149166
fitter = ModelFitter(
@@ -157,6 +174,7 @@ def flatten_param_draws(
157174
param_likelihood = model.param_likelihood,
158175
derive_draws = model.derive_draws,
159176
log_likelihood_fun = model.log_likelihood_fun,
177+
log_1md_fun = getattr(model, "log_1md_fun", None),
160178
verbose = True,
161179
)
162180
fitter.fit()
@@ -172,10 +190,23 @@ def flatten_param_draws(
172190
)
173191

174192
# Predictions ----------------------------------------------------------->
193+
# Emit both regimes (methodology §4.2 Step G): the conditional formula
194+
# populates p_mean/p_lower/p_upper (δ-independent, right for rating
195+
# already-observed POIs); the fresh formula populates p_fresh_* (uses δ,
196+
# right for rating a hypothetical freshly tagged POI).
175197
predict_times = jnp.arange(101) / 10.0
176198
predict_data = model.build_predict_data(predict_times)
199+
conditional = fitter.predict(data = predict_data, mode = "conditional")
200+
fresh = (
201+
fitter.predict(data = predict_data, mode = "fresh")
202+
.rename(columns = {
203+
"p_mean": "p_fresh_mean",
204+
"p_lower": "p_fresh_lower",
205+
"p_upper": "p_fresh_upper",
206+
})
207+
)
177208
predictions = (
178-
fitter.predict(data = predict_data)
209+
pd.concat([conditional, fresh], axis = 1)
179210
.assign(t1 = 0.0, units = "years")
180211
)
181212
predictions["t2"] = np.asarray(predict_data["dt"])

scripts/osm_data/data_viz.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Plot OSM tag stability curves from observation data.
33
4-
Reads ``osm_observations.csv`` (one row per (POI version, shared_label)) and
4+
Reads ``osm_observations.parquet`` (one row per (POI version, shared_label)) and
55
computes Kaplan-Meier-style survival estimates showing what fraction of
66
observations remain unchanged over time. Saves two types of PNG figures:
77
@@ -12,7 +12,7 @@
1212
count, shown as separate facets.
1313
1414
Config keys used (config.yaml):
15-
directories.osm_data — directory containing input CSV and viz/ output
15+
directories.osm_data — directory containing input parquet and viz/ output
1616
osm_data.tag_key — tag key whose changes define observation
1717
events (used only in plot titles)
1818
osm_data.timestamp_cols — columns to parse as timestamps (rows with nulls dropped)
@@ -112,9 +112,9 @@ def get_preds_df(version: str) -> pd.DataFrame | None:
112112
return None
113113
return pd.read_csv(preds_fp).assign(
114114
year = pd.col('t2'),
115-
conf_mean = (1.0 - pd.col('p_mean')),
116-
conf_lower = (1.0 - pd.col('p_upper')),
117-
conf_upper = (1.0 - pd.col('p_lower')),
115+
conf_mean = (1.0 - pd.col('p_fresh_mean')),
116+
conf_lower = (1.0 - pd.col('p_fresh_upper')),
117+
conf_upper = (1.0 - pd.col('p_fresh_lower')),
118118
)
119119

120120
return {
@@ -135,7 +135,7 @@ def get_preds_df(version: str) -> pd.DataFrame | None:
135135
# observation timestamp will be missing for these rows
136136
timestamp_cols = config.get("osm_data", "timestamp_cols")
137137
observations_df = (
138-
pd.read_csv(OBSERVATIONS_PATH)
138+
pd.read_parquet(OBSERVATIONS_PATH)
139139
.dropna(subset = timestamp_cols)
140140
)
141141
for timestamp_col in timestamp_cols:

scripts/osm_data/download_history.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,21 @@
111111
chunk_size = CHUNK_SIZE,
112112
verbose = VERBOSE,
113113
)
114+
115+
# -------------------------------------------------------------------------
116+
# Clean up intermediates
117+
# -------------------------------------------------------------------------
118+
finals_ok = all(
119+
p.exists() and p.stat().st_size > 0
120+
for p in (OUTPUT_VERSIONS, OUTPUT_CHANGES)
121+
)
122+
if finals_ok:
123+
intermediates = (
124+
RAW_PBF, FILTERED_PBF, TIME_FILTERED_PBF,
125+
RAW_PR_PBF, FILTERED_PR_PBF, TIME_FILTERED_PR_PBF,
126+
US_VERSIONS, US_CHANGES, PR_VERSIONS, PR_CHANGES,
127+
)
128+
for p in intermediates:
129+
if p.exists():
130+
print(f"Removing intermediate {p} ...")
131+
p.unlink()

scripts/osm_data/format_tabular.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
osm_changes.parquet.
3333
3434
Output file (in osm_data directory):
35-
osm_observations.csv — one row per (POI version, shared_label). Columns:
35+
osm_observations.parquet — one row per (POI version, shared_label). Columns:
3636
id, version, tag_key, last_tag_timestamp, obs_timestamp, changed,
3737
shared_label, plus every filter_keys column for reference.
3838
"""
@@ -86,7 +86,7 @@
8686
print(f"DuckDB wrote {n_written:,} raw observations to {OUT_PATH}")
8787

8888
print("Loading raw observations for shared-label assignment ...")
89-
obs_df = pd.read_csv(OUT_PATH, dtype = {k: str for k in OSM_KEYS})
89+
obs_df = pd.read_parquet(OUT_PATH)
9090

9191
print("Assigning shared taxonomy labels (multi-label, exploded) ...")
9292
labels_per_row, _ = assign_osm_shared_label(
@@ -109,5 +109,5 @@
109109
print("Top shared labels by row count:")
110110
print(obs_df["shared_label"].value_counts().head(15).to_string())
111111

112-
obs_df.to_csv(OUT_PATH, index = False)
112+
obs_df.to_parquet(OUT_PATH, index = False)
113113
print(f"Saved {len(obs_df):,} observations to {OUT_PATH}")

scripts/osm_snapshot/download.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,12 @@
8787
chunk_dir = CHUNK_DIR,
8888
verbose = VERBOSE,
8989
)
90+
91+
# -------------------------------------------------------------------------
92+
# Clean up intermediates
93+
# -------------------------------------------------------------------------
94+
if OUTPUT_PATH.exists() and OUTPUT_PATH.stat().st_size > 0:
95+
for p in (RAW_PBF, FILTERED_PBF, RAW_PR_PBF, FILTERED_PR_PBF):
96+
if p.exists():
97+
print(f"Removing intermediate {p} ...")
98+
p.unlink()

scripts/overture/download.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
Columns: overture_id, overture_name, taxonomy_l0, taxonomy_l1,
3131
taxonomy_l2, brand_name, confidence, geometry, source
3232
"""
33+
import shutil
34+
3335
import pyarrow.parquet as pq
3436
from config_versioned import Config
3537
from openpois.io.boundary import get_us_pr_boundary
@@ -91,3 +93,12 @@
9193
)
9294
n_rows = pq.read_metadata(output_path).num_rows
9395
print(f"Saved {n_rows:,} Overture POIs to {output_path}")
96+
97+
# -------------------------------------------------------------------------
98+
# Clean up intermediates
99+
# -------------------------------------------------------------------------
100+
if OUTPUT_PATH.exists() and OUTPUT_PATH.stat().st_size > 0:
101+
parts_dir = SAVE_DIR / ".parts"
102+
if parts_dir.is_dir():
103+
print(f"Removing intermediate {parts_dir} ...")
104+
shutil.rmtree(parts_dir)

0 commit comments

Comments
 (0)