Skip to content

Commit 5011b9d

Browse files
committed
Add model fitting using nationwide changeset.
1 parent 20c9458 commit 5011b9d

2 files changed

Lines changed: 113 additions & 30 deletions

File tree

src/openpois/io/geohash_partition.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,20 @@ def write_partitioned_dataset(
7373
"Pass overwrite=True to replace it."
7474
)
7575

76-
gdf = gdf.sort_values(["geohash_prefix", "geohash_sort"]).drop(
77-
columns = ["geohash_sort"]
78-
)
79-
cols = [c for c in gdf.columns if c != "geohash_prefix"]
80-
n_partitions = gdf["geohash_prefix"].nunique()
76+
cols = [c for c in gdf.columns if c not in ("geohash_prefix", "geohash_sort")]
8177
output_dir.mkdir(parents = True, exist_ok = True)
8278

79+
# Iterate without a global sort_values: that would double peak memory on
80+
# multi-GB frames. groupby(sort = False) yields one partition at a time;
81+
# each small partition is sorted (into a new, small frame) just before writing.
82+
groups = gdf.groupby("geohash_prefix", sort = False, observed = True)
83+
n_partitions = len(groups)
8384
print(f"Writing {n_partitions} partitions to {output_dir} ...")
84-
for i, (prefix, group) in enumerate(gdf.groupby("geohash_prefix")):
85+
for i, (prefix, group) in enumerate(groups):
8586
partition_dir = output_dir / f"geohash_prefix={prefix}"
8687
partition_dir.mkdir()
87-
group[cols].to_parquet(partition_dir / "part-0.parquet")
88+
group.sort_values("geohash_sort")[cols].to_parquet(
89+
partition_dir / "part-0.parquet"
90+
)
8891
if (i + 1) % 100 == 0:
8992
print(f" {i + 1}/{n_partitions} partitions written...")

src/openpois/io/osm_history_pbf.py

Lines changed: 103 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
("change", pa.string()),
7676
("id", pa.int64()),
7777
("version", pa.int64()),
78+
("type", pa.string()),
7879
])
7980

8081

@@ -493,6 +494,7 @@ def parse_history_pbf(
493494
for diff_row in _diff_tag_sets(prev_tags, curr_tags):
494495
diff_row["id"] = int(obj.id)
495496
diff_row["version"] = int(obj.version)
497+
diff_row["type"] = kind
496498
changes_buf.append(diff_row)
497499

498500
prev_key = key
@@ -520,23 +522,103 @@ def parse_history_pbf(
520522

521523

522524
# -----------------------------------------------------------------------------
523-
# Parquet concatenation (US + PR)
525+
# Parquet concatenation (US + PR) with cross-extract deduplication
524526
# -----------------------------------------------------------------------------
525527

526528

527-
def _concat_parquets(
528-
inputs: list[Path],
529-
output: Path,
530-
schema: pa.Schema,
531-
) -> Path:
532-
"""Stream-concatenate row groups from ``inputs`` into ``output``."""
533-
output.parent.mkdir(parents=True, exist_ok=True)
534-
with pq.ParquetWriter(output, schema) as writer:
535-
for path in inputs:
536-
reader = pq.ParquetFile(str(path))
537-
for batch in reader.iter_batches():
538-
writer.write_table(pa.Table.from_batches([batch], schema=schema))
539-
return output
529+
def _concat_history(
    us_versions_path: Path,
    pr_versions_path: Path,
    out_versions_path: Path,
    us_changes_path: Path,
    pr_changes_path: Path,
    out_changes_path: Path,
) -> tuple[Path, Path]:
    """
    Merge the US and PR history Parquets, with US winning on shared elements.

    The same ``(type, id)`` element can show up in both the US-mainland and
    the Puerto Rico extract. A plain concatenation would then emit duplicate
    ``(id, version, key)`` rows in the changes Parquet, which downstream
    ``format_observations`` cannot handle (its ``.loc[key, "change"]`` lookup
    expects a scalar, not a Series).

    How the merge works:
      1. US versions are copied batch by batch into the output while every
         ``(type, id)`` pair encountered is recorded.
      2. PR versions are read in one go (the PR file is small, on the order
         of ~70K versions). Rows whose ``(type, id)`` was recorded in step 1
         are discarded; the rest are appended.
      3. US changes are copied batch by batch into the output.
      4. PR changes are read in one go. Rows whose ``(type, id)`` belongs to
         an element discarded in step 2 are dropped; the rest are appended.

    Both filters key on ``(type, id)`` rather than ``id`` alone: OSM ids are
    unique only within an element type, so a node and a way may share the
    same integer id and must not be collapsed into one another.

    Args:
        us_versions_path: Intermediate US versions Parquet.
        pr_versions_path: Intermediate PR versions Parquet.
        out_versions_path: Final concatenated versions Parquet.
        us_changes_path: Intermediate US changes Parquet.
        pr_changes_path: Intermediate PR changes Parquet.
        out_changes_path: Final concatenated changes Parquet.

    Returns:
        Tuple ``(out_versions_path, out_changes_path)``.
    """
    for destination in (out_versions_path, out_changes_path):
        destination.parent.mkdir(parents=True, exist_ok=True)

    # ---- Versions: copy US verbatim, then append the non-overlapping PR rows.
    seen_in_us: set[tuple[str, int]] = set()
    with pq.ParquetWriter(out_versions_path, VERSIONS_SCHEMA) as versions_writer:
        us_versions_file = pq.ParquetFile(str(us_versions_path))
        for record_batch in us_versions_file.iter_batches():
            us_table = pa.Table.from_batches(
                [record_batch], schema=VERSIONS_SCHEMA
            )
            batch_types = us_table.column("type").to_pylist()
            batch_ids = us_table.column("id").to_pylist()
            seen_in_us.update(zip(batch_types, batch_ids))
            versions_writer.write_table(us_table)

        pr_versions = pq.read_table(
            str(pr_versions_path), schema=VERSIONS_SCHEMA
        )
        pr_version_keys = list(
            zip(
                pr_versions.column("type").to_pylist(),
                pr_versions.column("id").to_pylist(),
            )
        )
        # Elements present in both extracts; their PR rows are discarded here
        # and their PR change rows are discarded in the second pass.
        overlapping = {key for key in pr_version_keys if key in seen_in_us}
        versions_mask = pa.array(
            [key not in overlapping for key in pr_version_keys],
            type=pa.bool_(),
        )
        pr_versions_kept = pr_versions.filter(versions_mask)
        if pr_versions_kept.num_rows > 0:
            versions_writer.write_table(pr_versions_kept)

    # ---- Changes: copy US verbatim, then append PR rows of surviving elements.
    with pq.ParquetWriter(out_changes_path, CHANGES_SCHEMA) as changes_writer:
        us_changes_file = pq.ParquetFile(str(us_changes_path))
        for record_batch in us_changes_file.iter_batches():
            changes_writer.write_table(
                pa.Table.from_batches([record_batch], schema=CHANGES_SCHEMA)
            )

        pr_changes = pq.read_table(str(pr_changes_path), schema=CHANGES_SCHEMA)
        # Skip the filter entirely when nothing overlapped or PR has no changes.
        if overlapping and pr_changes.num_rows > 0:
            pr_change_keys = zip(
                pr_changes.column("type").to_pylist(),
                pr_changes.column("id").to_pylist(),
            )
            changes_mask = pa.array(
                [key not in overlapping for key in pr_change_keys],
                type=pa.bool_(),
            )
            pr_changes = pr_changes.filter(changes_mask)
        if pr_changes.num_rows > 0:
            changes_writer.write_table(pr_changes)

    return out_versions_path, out_changes_path
540622

541623

542624
# -----------------------------------------------------------------------------
@@ -691,15 +773,13 @@ def download_osm_history(
691773
"Concatenating US + PR Parquets into"
692774
f" {output_versions_path} / {output_changes_path}..."
693775
)
694-
_concat_parquets(
695-
inputs=[us_versions_path, pr_versions_path],
696-
output=output_versions_path,
697-
schema=VERSIONS_SCHEMA,
698-
)
699-
_concat_parquets(
700-
inputs=[us_changes_path, pr_changes_path],
701-
output=output_changes_path,
702-
schema=CHANGES_SCHEMA,
776+
_concat_history(
777+
us_versions_path=us_versions_path,
778+
pr_versions_path=pr_versions_path,
779+
out_versions_path=output_versions_path,
780+
us_changes_path=us_changes_path,
781+
pr_changes_path=pr_changes_path,
782+
out_changes_path=output_changes_path,
703783
)
704784
print(
705785
f"Saved OSM history to {output_versions_path} and {output_changes_path}"

0 commit comments

Comments
 (0)