|
21 | 21 | model_version — which model version was used |
22 | 22 | model_group — which group was matched, or "constant" |
23 | 23 |
|
24 | | -Saves as a spatially-sorted GeoParquet (Hilbert curve order, zstd compression, |
25 | | -50k-row row groups) for efficient cloud-native range reads from S3. |
| 24 | +Streams input row-groups via pyarrow, computes predictions per batch in numpy, |
| 25 | +and appends the new columns before writing each batch to the output parquet. |
| 26 | +Input row order is preserved — the downstream `format_for_upload.py` step |
| 27 | +re-sorts by geohash, so no intermediate spatial ordering is needed here. |
26 | 28 | """ |
27 | 29 | from __future__ import annotations |
28 | 30 |
|
29 | 31 | import argparse |
30 | | -import gc |
31 | | -import io |
32 | 32 | from pathlib import Path |
33 | 33 |
|
34 | | -import geopandas as gpd |
35 | 34 | import numpy as np |
36 | 35 | import pandas as pd |
37 | 36 | import pyarrow as pa |
|
60 | 59 | # Base directory containing all versioned model subdirectories |
61 | 60 | MODEL_BASE = Path(config.get_dir_path("model_output")).parent |
62 | 61 |
|
| 62 | +BATCH_ROWS = 500_000 |
| 63 | +ROW_GROUP_SIZE = 50_000 |
| 64 | + |
| 65 | + |
| 66 | +# ----------------------------------------------------------------------------- |
| 67 | +# Per-batch prediction logic (numpy only) |
| 68 | +# ----------------------------------------------------------------------------- |
| 69 | + |
def _compute_batch_predictions(
    df_lookup: pd.DataFrame,
    const_arr: np.ndarray,
    by_key_lookups: dict[str, tuple[list[str], np.ndarray]],
    constant_version: str,
    *,
    today: pd.Timestamp | None = None,
) -> tuple[dict[str, np.ndarray], np.ndarray]:
    """
    Compute the prediction columns for one batch of rows.

    Parameters
    ----------
    df_lookup
        Batch restricted to the `last_edited` column plus whichever
        `FILTER_KEYS` columns have an entry in `by_key_lookups`.
    const_arr
        Constant-model lookup indexed as `const_arr[t2_int]`; columns 0, 1, 2
        are read as p_mean, p_lower, p_upper.
    by_key_lookups
        Maps a filter key to `(groups, group_arr)`, where `group_arr` is
        indexed as `group_arr[group_index, t2_int]` and yields the same three
        columns as `const_arr`.
    constant_version
        Value written to `model_version` for rows that fall back to the
        constant model.
    today
        Reference instant that `last_edited` is aged against. Defaults to the
        current UTC time. Callers that process a file in several batches
        should pass one shared value so every batch is aged against the same
        instant; a naive value is interpreted as UTC.

    Returns
    -------
    (columns, matched)
        `columns` maps each of the 6 output column names (`t2_years`,
        `conf_mean`, `conf_lower`, `conf_upper`, `model_version`,
        `model_group`) to an array of length `len(df_lookup)`. `matched` is a
        boolean mask: True where a per-key model was used, False where the
        constant fallback applied.

    Raises
    ------
    ValueError
        If any row has a null `last_edited` timestamp.
    """
    n = len(df_lookup)

    # Resolve the reference instant once, up front, and normalise it to a
    # tz-aware Timestamp so the subtraction below never mixes naive and aware.
    if today is None:
        today = pd.Timestamp.now(tz = "UTC")
    else:
        today = pd.Timestamp(today)
        if today.tzinfo is None:
            today = today.tz_localize("UTC")

    last_edited = df_lookup["last_edited"]
    if last_edited.dt.tz is None:
        last_edited = last_edited.dt.tz_localize("UTC")
    n_null = int(last_edited.isna().sum())
    if n_null:
        raise ValueError(
            f"{n_null} rows have a null last_edited timestamp. "
            "Remove or impute these rows before applying the model."
        )

    # Years since last edit, snapped to 0.1-year steps and clipped to
    # [0, 10]; t2_int_arr is the matching integer index (0..100) into the
    # lookup arrays.
    elapsed_secs = (today - last_edited).dt.total_seconds().to_numpy()
    elapsed_years = elapsed_secs / (365.25 * 86_400)
    t2_years = np.clip(np.round(elapsed_years * 10) / 10, 0.0, 10.0)
    t2_int_arr = np.round(t2_years * 10).astype(int)

    # Start every row on the constant model; per-key models overwrite below.
    p_arr = const_arr[t2_int_arr].copy()
    model_version_arr = np.full(n, constant_version, dtype = object)
    model_group_arr = np.full(n, "constant", dtype = object)
    matched = np.zeros(n, dtype = bool)

    # FILTER_KEYS order is the priority order: the first key whose value
    # names a known group wins, and later keys never overwrite that row.
    for key in FILTER_KEYS:
        if key not in by_key_lookups:
            continue
        groups, group_arr = by_key_lookups[key]
        group_to_idx = {g: i for i, g in enumerate(groups)}

        key_values = df_lookup[key]
        # Values that are missing or not a known group map to -1.
        group_ids = (
            key_values.map(group_to_idx).fillna(-1).astype(int).to_numpy()
        )
        eli_pos = np.flatnonzero(~matched & (group_ids >= 0))
        if len(eli_pos) == 0:
            continue

        # 2D fancy indexing: one (group, t2) lookup per eligible row.
        p_arr[eli_pos] = group_arr[group_ids[eli_pos], t2_int_arr[eli_pos]]
        model_version_arr[eli_pos] = f"{MODEL_STUB}_by_{key}"
        model_group_arr[eli_pos] = key_values.to_numpy()[eli_pos]
        matched[eli_pos] = True

    # Confidence is 1 - p, so the bounds swap: the upper bound on p gives
    # the lower bound on confidence and vice versa.
    return (
        {
            "t2_years": t2_years,
            "conf_mean": 1.0 - p_arr[:, 0],
            "conf_lower": 1.0 - p_arr[:, 2],  # 1 - p_upper
            "conf_upper": 1.0 - p_arr[:, 1],  # 1 - p_lower
            "model_version": model_version_arr,
            "model_group": model_group_arr,
        },
        matched,
    )
| 131 | + |
| 132 | + |
63 | 133 | # ----------------------------------------------------------------------------- |
64 | 134 | # Main |
65 | 135 | # ----------------------------------------------------------------------------- |
|
71 | 141 | parser.add_argument( |
72 | 142 | "--test", |
73 | 143 | action = "store_true", |
74 | | - help = "Load only the first 10,000 rows of the snapshot for testing.", |
| 144 | + help = "Process only the first 10,000 rows of the snapshot for testing.", |
75 | 145 | ) |
76 | 146 | args = parser.parse_args() |
77 | 147 |
|
|
99 | 169 | else: |
100 | 170 | print(f" No predictions found for {version}; will skip") |
101 | 171 |
|
102 | | - # -- Read snapshot ---------------------------------------------------------- |
| 172 | + # -- Open input, build output schema ---------------------------------------- |
103 | 173 | print(f"\nReading OSM snapshot from {SNAPSHOT_PATH} ...") |
104 | | - if args.test: |
105 | | - # Read only the first row group from disk rather than all 7.8M rows, |
106 | | - # then round-trip through BytesIO to preserve GeoParquet metadata. |
107 | | - print(" (--test mode: loading first 10,000 rows only)") |
108 | | - pf = pq.ParquetFile(SNAPSHOT_PATH) |
109 | | - buf = io.BytesIO() |
110 | | - pq.write_table(pf.read_row_group(0).slice(0, 10_000), buf) |
111 | | - buf.seek(0) |
112 | | - gdf = gpd.read_parquet(buf) |
113 | | - else: |
114 | | - gdf = gpd.read_parquet(SNAPSHOT_PATH) |
115 | | - print(f" {len(gdf):,} POIs loaded") |
116 | | - |
117 | | - n = len(gdf) |
118 | | - |
119 | | - # -- Compute years since last edit ------------------------------------------ |
120 | | - today = pd.Timestamp.now(tz = "UTC") |
121 | | - last_edited = gdf["last_edited"] |
122 | | - if last_edited.dt.tz is None: |
123 | | - last_edited = last_edited.dt.tz_localize("UTC") |
124 | | - n_null = last_edited.isna().sum() |
125 | | - if n_null: |
126 | | - raise ValueError( |
127 | | - f"{n_null} rows have a null last_edited timestamp. " |
128 | | - "Remove or impute these rows before applying the model." |
129 | | - ) |
130 | | - elapsed_secs = (today - last_edited).dt.total_seconds().to_numpy() |
131 | | - elapsed_years = elapsed_secs / (365.25 * 86_400) |
132 | | - t2_years = np.clip(np.round(elapsed_years * 10) / 10, 0.0, 10.0) |
133 | | - # t2_int_arr stays in numpy only; never written to gdf |
134 | | - t2_int_arr = np.round(t2_years * 10).astype(int) |
135 | | - gdf["t2_years"] = t2_years |
136 | | - |
137 | | - # -- Assign predictions via numpy arrays ------------------------------------ |
138 | | - # All matching and lookup work is done in numpy and written to gdf once at |
139 | | - # the end, avoiding repeated pandas indexing overhead across 7.8M rows. |
140 | | - |
141 | | - # Initialize from constant model: single vectorized index → shape (n, 3) |
142 | | - p_arr = const_arr[t2_int_arr].copy() # columns: p_mean, p_lower, p_upper |
143 | | - model_version_arr = np.full(n, constant_version, dtype = object) |
144 | | - model_group_arr = np.full(n, "constant", dtype = object) |
145 | | - matched = np.zeros(n, dtype = bool) |
146 | | - |
147 | | - for key in FILTER_KEYS: |
148 | | - if key not in by_key_lookups: |
149 | | - continue |
150 | | - groups, group_arr = by_key_lookups[key] # group_arr: (n_groups, 101, 3) |
151 | | - |
152 | | - # Map tag values to group indices; NaN and unknown values become -1. |
153 | | - group_to_idx = {g: i for i, g in enumerate(groups)} |
154 | | - group_ids = ( |
155 | | - gdf[key].map(group_to_idx).fillna(-1).astype(int).to_numpy() |
156 | | - ) |
157 | | - |
158 | | - eligible = ~matched & (group_ids >= 0) |
159 | | - eli_pos = np.where(eligible)[0] |
160 | | - if len(eli_pos) == 0: |
161 | | - continue |
| 174 | + pf = pq.ParquetFile(SNAPSHOT_PATH) |
| 175 | + n_total = pf.metadata.num_rows |
| 176 | + print(f" {n_total:,} POIs across {pf.num_row_groups} row groups") |
| 177 | + |
| 178 | + # Preserve the input GeoParquet file-level metadata (contains the `geo` |
| 179 | + # block that marks `geometry` as the primary geometry column + its CRS). |
| 180 | + # We only append new columns — existing schema + metadata carry through. |
| 181 | + input_schema = pf.schema_arrow |
| 182 | + new_fields = [ |
| 183 | + pa.field("t2_years", pa.float64()), |
| 184 | + pa.field("conf_mean", pa.float64()), |
| 185 | + pa.field("conf_lower", pa.float64()), |
| 186 | + pa.field("conf_upper", pa.float64()), |
| 187 | + pa.field("model_version", pa.string()), |
| 188 | + pa.field("model_group", pa.string()), |
| 189 | + ] |
| 190 | + output_schema = pa.schema( |
| 191 | + list(input_schema) + new_fields, |
| 192 | + metadata = input_schema.metadata, |
| 193 | + ) |
162 | 194 |
|
163 | | - # Vectorized 2D fancy indexing: group_arr[m_gids, m_t2s] → (m, 3) |
164 | | - p_arr[eli_pos] = group_arr[group_ids[eli_pos], t2_int_arr[eli_pos]] |
165 | | - model_version_arr[eli_pos] = f"{MODEL_STUB}_by_{key}" |
166 | | - model_group_arr[eli_pos] = gdf[key].to_numpy()[eli_pos] |
167 | | - matched[eli_pos] = True |
| 195 | + # -- Stream: read batch → append prediction columns → write ----------------- |
| 196 | + lookup_cols = ["last_edited"] + [k for k in FILTER_KEYS if k in by_key_lookups] |
| 197 | + version_counts: dict[str, int] = { |
| 198 | + f"{MODEL_STUB}_by_{k}": 0 for k in by_key_lookups |
| 199 | + } |
| 200 | + version_counts[constant_version] = 0 |
| 201 | + n_written = 0 |
168 | 202 |
|
169 | | - print(f" {MODEL_STUB}_by_{key}: matched {len(eli_pos):,} POIs") |
170 | | - |
171 | | - n_constant = int((~matched).sum()) |
172 | | - print(f" {constant_version}: {n_constant:,} POIs (fallback)") |
173 | | - |
174 | | - # -- Assign back to GeoDataFrame -------------------------------------------- |
175 | | - # Convert change probability to confidence (1 - p). |
176 | | - # Note: conf_lower = 1 - p_upper and conf_upper = 1 - p_lower. |
177 | | - gdf["conf_mean"] = 1.0 - p_arr[:, 0] |
178 | | - gdf["conf_lower"] = 1.0 - p_arr[:, 2] # 1 - p_upper |
179 | | - gdf["conf_upper"] = 1.0 - p_arr[:, 1] # 1 - p_lower |
180 | | - # Categorical dtype stores integer codes + a small lookup, saving ~90% |
181 | | - # memory vs. object strings for these low-cardinality columns. |
182 | | - gdf["model_version"] = pd.Categorical(model_version_arr) |
183 | | - gdf["model_group"] = pd.Categorical(model_group_arr) |
184 | | - |
185 | | - # -- Spatial sort for cloud-native S3 reads --------------------------------- |
186 | | - # Two-pass approach avoids holding two GDF copies in memory at once: |
187 | | - # 1. compute sorted indices, report model breakdown, write unsorted |
188 | | - # GeoParquet to a temp file, drop the GDF |
189 | | - # 2. read the temp back as a pa.Table, reorder via .take(), write final |
190 | | - print("\nComputing Hilbert curve order ...") |
191 | | - sorted_indices = gdf.hilbert_distance().to_numpy().argsort() |
| 203 | + OUTPUT_PATH.parent.mkdir(parents = True, exist_ok = True) |
| 204 | + print(f"\nWriting to {OUTPUT_PATH} ...", flush = True) |
| 205 | + |
| 206 | + with pq.ParquetWriter( |
| 207 | + OUTPUT_PATH, output_schema, compression = "zstd" |
| 208 | + ) as writer: |
| 209 | + if args.test: |
| 210 | + print(" (--test mode: first 10,000 rows only)") |
| 211 | + batches = [next(pf.iter_batches(batch_size = 10_000))] |
| 212 | + else: |
| 213 | + batches = pf.iter_batches(batch_size = BATCH_ROWS) |
| 214 | + |
| 215 | + for batch in batches: |
| 216 | + tbl = pa.Table.from_batches([batch]) |
| 217 | + df_lookup = tbl.select(lookup_cols).to_pandas() |
| 218 | + preds, matched = _compute_batch_predictions( |
| 219 | + df_lookup, const_arr, by_key_lookups, constant_version, |
| 220 | + ) |
| 221 | + |
| 222 | + for key in by_key_lookups: |
| 223 | + version = f"{MODEL_STUB}_by_{key}" |
| 224 | + version_counts[version] += int( |
| 225 | + (preds["model_version"] == version).sum() |
| 226 | + ) |
| 227 | + version_counts[constant_version] += int(len(df_lookup) - matched.sum()) |
| 228 | + |
| 229 | + for field in new_fields: |
| 230 | + tbl = tbl.append_column( |
| 231 | + field.name, |
| 232 | + pa.array(preds[field.name], type = field.type), |
| 233 | + ) |
| 234 | + |
| 235 | + writer.write_table(tbl, row_group_size = ROW_GROUP_SIZE) |
| 236 | + n_written += batch.num_rows |
| 237 | + print(f" {n_written:,}/{n_total:,} rows written", flush = True) |
192 | 238 |
|
193 | 239 | print("\nModel version breakdown:") |
194 | | - print(gdf["model_version"].value_counts().to_string()) |
195 | | - |
196 | | - OUTPUT_PATH.parent.mkdir(parents = True, exist_ok = True) |
197 | | - tmp_path = OUTPUT_PATH.with_suffix(OUTPUT_PATH.suffix + ".unsorted") |
198 | | - print(f"\nWriting unsorted snapshot to {tmp_path} ...") |
199 | | - gdf.to_parquet(tmp_path, compression = "zstd") |
200 | | - n_total = len(gdf) |
201 | | - del gdf |
202 | | - gc.collect() |
203 | | - |
204 | | - print(f"Reordering via PyArrow and writing final output to {OUTPUT_PATH} ...") |
205 | | - table = pq.read_table(tmp_path) |
206 | | - sorted_table = table.take(pa.array(sorted_indices)) |
207 | | - del table |
208 | | - gc.collect() |
209 | | - pq.write_table( |
210 | | - sorted_table, |
211 | | - OUTPUT_PATH, |
212 | | - compression = "zstd", |
213 | | - row_group_size = 50_000, |
214 | | - ) |
215 | | - del sorted_table |
216 | | - gc.collect() |
217 | | - tmp_path.unlink() |
218 | | - print(f"Done. Saved {n_total:,} POIs.") |
| 240 | + for version, count in sorted(version_counts.items(), key = lambda kv: -kv[1]): |
| 241 | + print(f" {version}: {count:,}") |
| 242 | + print(f"\nDone. Saved {n_written:,} POIs.", flush = True) |
0 commit comments