Skip to content

Commit c9c7d0e

Browse files
committed
Track peak memory during run.
1 parent d8886ee commit c9c7d0e

1 file changed

Lines changed: 92 additions & 4 deletions

File tree

scripts/conflation/conflate.py

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from config_versioned import Config
4848
from shapely.geometry import box
4949

50+
from openpois.conflation.chunking import extract_centroids_lonlat
5051
from openpois.conflation.match import (
5152
compute_match_scores,
5253
find_and_score_matches_chunked,
@@ -72,6 +73,45 @@
7273
)
7374

7475

76+
# -----------------------------------------------------------------
77+
# Memory instrumentation
78+
# -----------------------------------------------------------------
79+
80+
# Reference timestamp captured when this module-level statement runs;
# ``log_rss`` prints elapsed seconds relative to it.
_RSS_T0 = time.time()
81+
82+
83+
def _read_proc_status(path: str = "/proc/self/status") -> dict[str, int]:
    """Return VmRSS and VmHWM in bytes from a Linux ``/proc`` status file.

    This is best-effort instrumentation: it never raises for an
    unreadable file or an unexpected line, because a failure here
    should not abort the surrounding run.

    Args:
        path: Status file to parse. Defaults to the current process's
            ``/proc/self/status``.

    Returns:
        A dict holding whichever of ``"VmRSS"`` and ``"VmHWM"`` were
        found, with the reported value multiplied by 1024. The dict is
        empty when the file cannot be opened (e.g. no ``/proc``).
    """
    out: dict[str, int] = {}
    try:
        with open(path, "r") as f:
            for line in f:
                if not line.startswith(("VmRSS:", "VmHWM:")):
                    continue
                fields = line.split()
                # Expected shape is "<Key>: <value> kB". Anything else
                # is skipped instead of raising.
                if len(fields) < 2:
                    continue
                try:
                    out[fields[0].rstrip(":")] = int(fields[1]) * 1024
                except ValueError:
                    continue
    except OSError:
        # Covers a missing file as well as permission and other I/O
        # errors from restricted or non-Linux environments.
        pass
    return out
95+
96+
97+
def log_rss(label: str) -> None:
    """Report current and peak resident memory for the phase *label*.

    Runs ``gc.collect()`` before sampling so the figures describe
    memory that is still retained rather than garbage awaiting
    collection. Prints a single flushed line with the current RSS and
    peak RSS in GB (0.00 when the value is unavailable) and the
    seconds elapsed since ``_RSS_T0``.
    """
    gc.collect()
    status = _read_proc_status()
    bytes_per_gb = 2**30
    current_gb = status.get("VmRSS", 0) / bytes_per_gb
    peak_gb = status.get("VmHWM", 0) / bytes_per_gb
    seconds = time.time() - _RSS_T0
    message = (
        f" [RSS {current_gb:5.2f} GB | peak {peak_gb:5.2f} GB | "
        f"+{seconds:6.0f}s] {label}"
    )
    print(message, flush = True)
113+
114+
75115
# -----------------------------------------------------------------
76116
# Configuration
77117
# -----------------------------------------------------------------
@@ -110,6 +150,17 @@
110150
"overture_id", "taxonomy_l0", "taxonomy_l1", "taxonomy_l2",
111151
"overture_name", "brand_name", "confidence", "geometry",
112152
]
153+
# Column subsets required downstream by ``build_merge_parts*``. The
# source frames are reloaded with only these columns after the chunked
# matcher returns, so the matching phase can run without the full
# source GeoDataFrames resident in memory.
OSM_MERGE_COLS = [
    "osm_id",
    "name",
    "brand",
    "conf_mean",
    "conf_lower",
    "conf_upper",
    "geometry",
]
OVERTURE_MERGE_COLS = [
    "overture_id",
    "overture_name",
    "brand_name",
    "confidence",
    "geometry",
]
113164

114165

115166
# -----------------------------------------------------------------
@@ -169,15 +220,19 @@ def _load_gdf(
169220

170221
test_bbox = TEST_BBOX if args.test else None
171222

223+
log_rss("startup")
224+
172225
# -- Load data -------------------------------------------------
173226
osm_gdf = _load_gdf(
174227
OSM_PATH, OSM_MATCH_COLS,
175228
test_bbox = test_bbox, label = "OSM rated",
176229
)
230+
log_rss("after OSM load")
177231
overture_gdf = _load_gdf(
178232
OVERTURE_PATH, OVERTURE_MATCH_COLS,
179233
test_bbox = test_bbox, label = "Overture",
180234
)
235+
log_rss("after Overture load")
181236

182237
# -- Taxonomy assignment ---------------------------------------
183238
print("\nAssigning shared labels ...")
@@ -225,7 +280,8 @@ def _load_gdf(
225280
for col in ["taxonomy_l0", "taxonomy_l1", "taxonomy_l2"]:
226281
if col in overture_gdf.columns:
227282
overture_gdf.drop(columns = col, inplace = True)
228-
gc.collect()
283+
del osm_crosswalk, overture_crosswalk, top_level_matches
284+
log_rss("after taxonomy assignment + tag-col drop")
229285

230286
# -- Matching --------------------------------------------------
231287
# Prepare name/brand arrays once (used by both code paths).
@@ -313,9 +369,23 @@ def _load_gdf(
313369
f"(target ~{CHUNK_TARGET_POIS:,} POIs/chunk, "
314370
f"max {MAX_RADIUS_M}m) ..."
315371
)
372+
373+
# Precompute centroids before freeing the source frames so
374+
# the matching phase never holds the full GeoDataFrames in
375+
# memory. Geometries are reloaded from disk for the merge.
376+
print(" Precomputing centroids ...")
377+
osm_centroids_lonlat = extract_centroids_lonlat(
378+
np.asarray(osm_gdf.geometry.values)
379+
)
380+
overture_centroids_lonlat = extract_centroids_lonlat(
381+
np.asarray(overture_gdf.geometry.values)
382+
)
383+
del osm_gdf, overture_gdf
384+
log_rss("after dropping gdfs (centroids extracted)")
385+
316386
matches, chunk_summary = find_and_score_matches_chunked(
317-
osm_geom = osm_gdf.geometry.values,
318-
overture_geom = overture_gdf.geometry.values,
387+
osm_centroids_lonlat = osm_centroids_lonlat,
388+
overture_centroids_lonlat = overture_centroids_lonlat,
319389
osm_radii_m = osm_radii,
320390
osm_shared_labels = osm_shared_labels,
321391
overture_shared_labels = overture_shared_labels,
@@ -335,12 +405,14 @@ def _load_gdf(
335405
chunk_size = CHUNK_SIZE,
336406
checkpoint_dir = checkpoint_dir,
337407
)
408+
del osm_centroids_lonlat, overture_centroids_lonlat
338409
print(
339410
f" Selected {len(matches):,} one-to-one matches "
340411
f"across {chunk_summary['n_chunks']} chunks "
341412
f"(Overture dedup drops: "
342413
f"{chunk_summary['n_overture_dedup_drops']:,})"
343414
)
415+
log_rss("after chunked matching + dedup")
344416

345417
del osm_names, osm_brands
346418
del overture_names, overture_brands
@@ -370,6 +442,21 @@ def _load_gdf(
370442
n_matches = len(matches)
371443

372444
if chunk_summary is not None:
445+
# Reload only the columns the merge needs — the matching
446+
# phase has already returned the dedup-resolved matches, so
447+
# we no longer need the wide load schema.
448+
print(" Reloading source frames for merge ...")
449+
osm_gdf = _load_gdf(
450+
OSM_PATH, OSM_MERGE_COLS,
451+
test_bbox = test_bbox, label = "OSM (merge cols)",
452+
)
453+
overture_gdf = _load_gdf(
454+
OVERTURE_PATH, OVERTURE_MERGE_COLS,
455+
test_bbox = test_bbox,
456+
label = "Overture (merge cols)",
457+
)
458+
log_rss("after reload for merge")
459+
373460
part_paths = build_merge_parts_chunked(
374461
osm_gdf = osm_gdf,
375462
overture_gdf = overture_gdf,
@@ -395,11 +482,12 @@ def _load_gdf(
395482
del osm_gdf, overture_gdf, matches
396483
del osm_shared_labels, overture_shared_labels, osm_radii
397484
del osm_l0_bits, overture_l0_bits
398-
gc.collect()
485+
log_rss("after merge parts written")
399486

400487
# -- Save ------------------------------------------------------
401488
print("\nSaving conflated dataset ...")
402489
n_total = save_conflated_from_parts(part_paths, OUTPUT_PATH)
490+
log_rss("after final parquet stream")
403491
config.write_self("conflation")
404492

405493
# Clear chunk checkpoints after a successful save.

0 commit comments

Comments
 (0)