Skip to content

Commit 06b92f7

Browse files
committed
Rewrites to reduce peak memory use.
1 parent cae17d5 commit 06b92f7

2 files changed

Lines changed: 398 additions & 283 deletions

File tree

scripts/osm_snapshot/apply_model.py

Lines changed: 142 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,16 @@
2121
model_version — which model version was used
2222
model_group — which group was matched, or "constant"
2323
24-
Saves as a spatially-sorted GeoParquet (Hilbert curve order, zstd compression,
25-
50k-row row groups) for efficient cloud-native range reads from S3.
24+
Streams input row-groups via pyarrow, computes predictions per batch in numpy,
25+
and appends the new columns before writing each batch to the output parquet.
26+
Input row order is preserved — the downstream `format_for_upload.py` step
27+
re-sorts by geohash, so no intermediate spatial ordering is needed here.
2628
"""
2729
from __future__ import annotations
2830

2931
import argparse
30-
import gc
31-
import io
3232
from pathlib import Path
3333

34-
import geopandas as gpd
3534
import numpy as np
3635
import pandas as pd
3736
import pyarrow as pa
@@ -60,6 +59,77 @@
6059
# Base directory containing all versioned model subdirectories
6160
MODEL_BASE = Path(config.get_dir_path("model_output")).parent
6261

62+
# Number of rows requested per batch when streaming the input snapshot
# (passed as `batch_size` to `ParquetFile.iter_batches` in the main loop).
BATCH_ROWS = 500_000
63+
# Row-group size used for the output file (passed as `row_group_size` to
# `ParquetWriter.write_table` for every batch written).
ROW_GROUP_SIZE = 50_000
64+
65+
66+
# -----------------------------------------------------------------------------
67+
# Per-batch prediction logic (numpy only)
68+
# -----------------------------------------------------------------------------
69+
70+
def _compute_batch_predictions(
    df_lookup: pd.DataFrame,
    const_arr: np.ndarray,
    by_key_lookups: dict[str, tuple[list[str], np.ndarray]],
    constant_version: str,
    today: pd.Timestamp | None = None,
) -> tuple[dict[str, np.ndarray], np.ndarray]:
    """
    Compute the 6 prediction columns for one batch of POIs.

    Parameters
    ----------
    df_lookup
        Batch restricted to `last_edited` plus the `FILTER_KEYS` columns that
        have a per-key model.
    const_arr
        Constant-model lookup, indexed by the integer t2 bin (0..100) on the
        first axis; columns 0, 1, 2 are read as p_mean, p_lower, p_upper.
    by_key_lookups
        Maps a filter key to `(groups, group_arr)`, where `group_arr` is
        indexed as `[group_index, t2_bin]` and yields the same 3 columns.
    constant_version
        Value written to `model_version` for rows using the constant fallback.
    today
        Reference timestamp used to compute the time since last edit. Pass
        the same value for every batch of a run so that all rows are measured
        against one instant; a naive timestamp is interpreted as UTC. When
        omitted, the current UTC time is taken at call time (the previous
        behaviour), which differs slightly from batch to batch.

    Returns
    -------
    (columns, matched)
        `columns` maps each new column name (`t2_years`, `conf_mean`,
        `conf_lower`, `conf_upper`, `model_version`, `model_group`) to an
        array of length `len(df_lookup)`. `matched` is a boolean mask: True
        where a per-key model was used, False where the constant fallback
        applied.

    Raises
    ------
    ValueError
        If any row has a null `last_edited` timestamp.
    """
    n = len(df_lookup)

    if today is None:
        today = pd.Timestamp.now(tz = "UTC")
    elif today.tzinfo is None:
        today = today.tz_localize("UTC")

    # Validate before touching the `.dt` accessor so that missing timestamps
    # always surface as the intended ValueError.
    last_edited = df_lookup["last_edited"]
    n_null = last_edited.isna().sum()
    if n_null:
        raise ValueError(
            f"{n_null} rows have a null last_edited timestamp. "
            "Remove or impute these rows before applying the model."
        )
    if last_edited.dt.tz is None:
        last_edited = last_edited.dt.tz_localize("UTC")

    # Years since last edit, rounded to 0.1-year steps and clamped to [0, 10];
    # the integer bin (0..100) is the row index into the lookup arrays.
    elapsed_secs = (today - last_edited).dt.total_seconds().to_numpy()
    elapsed_years = elapsed_secs / (365.25 * 86_400)
    t2_years = np.clip(np.round(elapsed_years * 10) / 10, 0.0, 10.0)
    t2_int_arr = np.round(t2_years * 10).astype(int)

    # Start every row on the constant model; per-key models overwrite below.
    p_arr = const_arr[t2_int_arr].copy()
    model_version_arr = np.full(n, constant_version, dtype = object)
    model_group_arr = np.full(n, "constant", dtype = object)
    matched = np.zeros(n, dtype = bool)

    # Keys are tried in FILTER_KEYS order; a row keeps the first key whose
    # value belongs to a known group.
    for key in FILTER_KEYS:
        if key not in by_key_lookups:
            continue
        groups, group_arr = by_key_lookups[key]
        tag_values = df_lookup[key]

        # Map tag values to group indices; null and unknown values become -1.
        group_to_idx = {g: i for i, g in enumerate(groups)}
        group_ids = (
            tag_values.map(group_to_idx).fillna(-1).astype(int).to_numpy()
        )

        eligible = ~matched & (group_ids >= 0)
        eli_pos = np.where(eligible)[0]
        if len(eli_pos) == 0:
            continue

        # Paired fancy indexing: one (group, t2 bin) lookup per eligible row.
        p_arr[eli_pos] = group_arr[group_ids[eli_pos], t2_int_arr[eli_pos]]
        model_version_arr[eli_pos] = f"{MODEL_STUB}_by_{key}"
        model_group_arr[eli_pos] = tag_values.to_numpy()[eli_pos]
        matched[eli_pos] = True

    # Confidence is 1 - change probability, so the interval bounds swap.
    return (
        {
            "t2_years": t2_years,
            "conf_mean": 1.0 - p_arr[:, 0],
            "conf_lower": 1.0 - p_arr[:, 2],  # 1 - p_upper
            "conf_upper": 1.0 - p_arr[:, 1],  # 1 - p_lower
            "model_version": model_version_arr,
            "model_group": model_group_arr,
        },
        matched,
    )
131+
132+
63133
# -----------------------------------------------------------------------------
64134
# Main
65135
# -----------------------------------------------------------------------------
@@ -71,7 +141,7 @@
71141
parser.add_argument(
72142
"--test",
73143
action = "store_true",
74-
help = "Load only the first 10,000 rows of the snapshot for testing.",
144+
help = "Process only the first 10,000 rows of the snapshot for testing.",
75145
)
76146
args = parser.parse_args()
77147

@@ -99,120 +169,74 @@
99169
else:
100170
print(f" No predictions found for {version}; will skip")
101171

102-
# -- Read snapshot ----------------------------------------------------------
172+
# -- Open input, build output schema ----------------------------------------
103173
print(f"\nReading OSM snapshot from {SNAPSHOT_PATH} ...")
104-
if args.test:
105-
# Read only the first row group from disk rather than all 7.8M rows,
106-
# then round-trip through BytesIO to preserve GeoParquet metadata.
107-
print(" (--test mode: loading first 10,000 rows only)")
108-
pf = pq.ParquetFile(SNAPSHOT_PATH)
109-
buf = io.BytesIO()
110-
pq.write_table(pf.read_row_group(0).slice(0, 10_000), buf)
111-
buf.seek(0)
112-
gdf = gpd.read_parquet(buf)
113-
else:
114-
gdf = gpd.read_parquet(SNAPSHOT_PATH)
115-
print(f" {len(gdf):,} POIs loaded")
116-
117-
n = len(gdf)
118-
119-
# -- Compute years since last edit ------------------------------------------
120-
today = pd.Timestamp.now(tz = "UTC")
121-
last_edited = gdf["last_edited"]
122-
if last_edited.dt.tz is None:
123-
last_edited = last_edited.dt.tz_localize("UTC")
124-
n_null = last_edited.isna().sum()
125-
if n_null:
126-
raise ValueError(
127-
f"{n_null} rows have a null last_edited timestamp. "
128-
"Remove or impute these rows before applying the model."
129-
)
130-
elapsed_secs = (today - last_edited).dt.total_seconds().to_numpy()
131-
elapsed_years = elapsed_secs / (365.25 * 86_400)
132-
t2_years = np.clip(np.round(elapsed_years * 10) / 10, 0.0, 10.0)
133-
# t2_int_arr stays in numpy only; never written to gdf
134-
t2_int_arr = np.round(t2_years * 10).astype(int)
135-
gdf["t2_years"] = t2_years
136-
137-
# -- Assign predictions via numpy arrays ------------------------------------
138-
# All matching and lookup work is done in numpy and written to gdf once at
139-
# the end, avoiding repeated pandas indexing overhead across 7.8M rows.
140-
141-
# Initialize from constant model: single vectorized index → shape (n, 3)
142-
p_arr = const_arr[t2_int_arr].copy() # columns: p_mean, p_lower, p_upper
143-
model_version_arr = np.full(n, constant_version, dtype = object)
144-
model_group_arr = np.full(n, "constant", dtype = object)
145-
matched = np.zeros(n, dtype = bool)
146-
147-
for key in FILTER_KEYS:
148-
if key not in by_key_lookups:
149-
continue
150-
groups, group_arr = by_key_lookups[key] # group_arr: (n_groups, 101, 3)
151-
152-
# Map tag values to group indices; NaN and unknown values become -1.
153-
group_to_idx = {g: i for i, g in enumerate(groups)}
154-
group_ids = (
155-
gdf[key].map(group_to_idx).fillna(-1).astype(int).to_numpy()
156-
)
157-
158-
eligible = ~matched & (group_ids >= 0)
159-
eli_pos = np.where(eligible)[0]
160-
if len(eli_pos) == 0:
161-
continue
174+
pf = pq.ParquetFile(SNAPSHOT_PATH)
175+
n_total = pf.metadata.num_rows
176+
print(f" {n_total:,} POIs across {pf.num_row_groups} row groups")
177+
178+
# Preserve the input GeoParquet file-level metadata (contains the `geo`
179+
# block that marks `geometry` as the primary geometry column + its CRS).
180+
# We only append new columns — existing schema + metadata carry through.
181+
input_schema = pf.schema_arrow
182+
new_fields = [
183+
pa.field("t2_years", pa.float64()),
184+
pa.field("conf_mean", pa.float64()),
185+
pa.field("conf_lower", pa.float64()),
186+
pa.field("conf_upper", pa.float64()),
187+
pa.field("model_version", pa.string()),
188+
pa.field("model_group", pa.string()),
189+
]
190+
output_schema = pa.schema(
191+
list(input_schema) + new_fields,
192+
metadata = input_schema.metadata,
193+
)
162194

163-
# Vectorized 2D fancy indexing: group_arr[m_gids, m_t2s] → (m, 3)
164-
p_arr[eli_pos] = group_arr[group_ids[eli_pos], t2_int_arr[eli_pos]]
165-
model_version_arr[eli_pos] = f"{MODEL_STUB}_by_{key}"
166-
model_group_arr[eli_pos] = gdf[key].to_numpy()[eli_pos]
167-
matched[eli_pos] = True
195+
# -- Stream: read batch → append prediction columns → write -----------------
196+
lookup_cols = ["last_edited"] + [k for k in FILTER_KEYS if k in by_key_lookups]
197+
version_counts: dict[str, int] = {
198+
f"{MODEL_STUB}_by_{k}": 0 for k in by_key_lookups
199+
}
200+
version_counts[constant_version] = 0
201+
n_written = 0
168202

169-
print(f" {MODEL_STUB}_by_{key}: matched {len(eli_pos):,} POIs")
170-
171-
n_constant = int((~matched).sum())
172-
print(f" {constant_version}: {n_constant:,} POIs (fallback)")
173-
174-
# -- Assign back to GeoDataFrame --------------------------------------------
175-
# Convert change probability to confidence (1 - p).
176-
# Note: conf_lower = 1 - p_upper and conf_upper = 1 - p_lower.
177-
gdf["conf_mean"] = 1.0 - p_arr[:, 0]
178-
gdf["conf_lower"] = 1.0 - p_arr[:, 2] # 1 - p_upper
179-
gdf["conf_upper"] = 1.0 - p_arr[:, 1] # 1 - p_lower
180-
# Categorical dtype stores integer codes + a small lookup, saving ~90%
181-
# memory vs. object strings for these low-cardinality columns.
182-
gdf["model_version"] = pd.Categorical(model_version_arr)
183-
gdf["model_group"] = pd.Categorical(model_group_arr)
184-
185-
# -- Spatial sort for cloud-native S3 reads ---------------------------------
186-
# Two-pass approach avoids holding two GDF copies in memory at once:
187-
# 1. compute sorted indices, report model breakdown, write unsorted
188-
# GeoParquet to a temp file, drop the GDF
189-
# 2. read the temp back as a pa.Table, reorder via .take(), write final
190-
print("\nComputing Hilbert curve order ...")
191-
sorted_indices = gdf.hilbert_distance().to_numpy().argsort()
203+
OUTPUT_PATH.parent.mkdir(parents = True, exist_ok = True)
204+
print(f"\nWriting to {OUTPUT_PATH} ...", flush = True)
205+
206+
with pq.ParquetWriter(
207+
OUTPUT_PATH, output_schema, compression = "zstd"
208+
) as writer:
209+
if args.test:
210+
print(" (--test mode: first 10,000 rows only)")
211+
batches = [next(pf.iter_batches(batch_size = 10_000))]
212+
else:
213+
batches = pf.iter_batches(batch_size = BATCH_ROWS)
214+
215+
for batch in batches:
216+
tbl = pa.Table.from_batches([batch])
217+
df_lookup = tbl.select(lookup_cols).to_pandas()
218+
preds, matched = _compute_batch_predictions(
219+
df_lookup, const_arr, by_key_lookups, constant_version,
220+
)
221+
222+
for key in by_key_lookups:
223+
version = f"{MODEL_STUB}_by_{key}"
224+
version_counts[version] += int(
225+
(preds["model_version"] == version).sum()
226+
)
227+
version_counts[constant_version] += int(len(df_lookup) - matched.sum())
228+
229+
for field in new_fields:
230+
tbl = tbl.append_column(
231+
field.name,
232+
pa.array(preds[field.name], type = field.type),
233+
)
234+
235+
writer.write_table(tbl, row_group_size = ROW_GROUP_SIZE)
236+
n_written += batch.num_rows
237+
print(f" {n_written:,}/{n_total:,} rows written", flush = True)
192238

193239
print("\nModel version breakdown:")
194-
print(gdf["model_version"].value_counts().to_string())
195-
196-
OUTPUT_PATH.parent.mkdir(parents = True, exist_ok = True)
197-
tmp_path = OUTPUT_PATH.with_suffix(OUTPUT_PATH.suffix + ".unsorted")
198-
print(f"\nWriting unsorted snapshot to {tmp_path} ...")
199-
gdf.to_parquet(tmp_path, compression = "zstd")
200-
n_total = len(gdf)
201-
del gdf
202-
gc.collect()
203-
204-
print(f"Reordering via PyArrow and writing final output to {OUTPUT_PATH} ...")
205-
table = pq.read_table(tmp_path)
206-
sorted_table = table.take(pa.array(sorted_indices))
207-
del table
208-
gc.collect()
209-
pq.write_table(
210-
sorted_table,
211-
OUTPUT_PATH,
212-
compression = "zstd",
213-
row_group_size = 50_000,
214-
)
215-
del sorted_table
216-
gc.collect()
217-
tmp_path.unlink()
218-
print(f"Done. Saved {n_total:,} POIs.")
240+
for version, count in sorted(version_counts.items(), key = lambda kv: -kv[1]):
241+
print(f" {version}: {count:,}")
242+
print(f"\nDone. Saved {n_written:,} POIs.", flush = True)

0 commit comments

Comments
 (0)