diff --git a/src/segger/geometry/query.py b/src/segger/geometry/query.py index dd4a0c3..0f73e2e 100644 --- a/src/segger/geometry/query.py +++ b/src/segger/geometry/query.py @@ -18,6 +18,17 @@ logger = logging.getLogger(__name__) + +def _empty_match_frame() -> cudf.DataFrame: + """Empty match frame with the canonical output schema.""" + return cudf.DataFrame( + { + "index_query": cudf.Series([], dtype="int64"), + "index_match": cudf.Series([], dtype="int64"), + } + ) + + def _points_in_polygons_contains( points: cuspatial.GeoSeries, polygons: cuspatial.GeoSeries, @@ -59,32 +70,50 @@ def _points_in_polygons_contains( with_bounds=False ) - # Perform spatial join in batches - batch_idx = np.linspace(0, len(polygons), (batches or 1) + 1, dtype=int) - results = [] - for start_idx, end_idx in zip(batch_idx, batch_idx[1:]): - - # Get polygons for this batch - batch_polygons = polygons.iloc[start_idx:end_idx] - bboxes = cuspatial.polygon_bounding_boxes(batch_polygons) - poly_quad_pairs = cuspatial.join_quadtree_and_bounding_boxes( - quadtree=quadtree, - bounding_boxes=bboxes, - **kwargs - ) - # Run spatial join - result = cuspatial.quadtree_point_in_polygon( - poly_quad_pairs, - quadtree, - point_indices, - points, - batch_polygons, - ) - # Adjust polygon indices back to global indices - result['polygon_index'] += start_idx - results.append(result) + # Perform spatial join in batches. If a large join hits GPU OOM, retry with + # progressively finer batching so huge inputs (e.g. MERSCOPE) don't crash. + batch_count = batches or 1 + while True: + try: + batch_idx = np.linspace(0, len(polygons), batch_count + 1, dtype=int) + results = [] + for start_idx, end_idx in zip(batch_idx, batch_idx[1:]): + if end_idx <= start_idx: + continue + # Get polygons for this batch + batch_polygons = polygons.iloc[start_idx:end_idx] + bboxes = cuspatial.polygon_bounding_boxes(batch_polygons) + poly_quad_pairs = cuspatial.join_quadtree_and_bounding_boxes( + quadtree=quadtree, + bounding_boxes=bboxes, + **kwargs + ) + # Run spatial join + result = cuspatial.quadtree_point_in_polygon( + poly_quad_pairs, + quadtree, + point_indices, + points, + batch_polygons, + ) + # Adjust polygon indices back to global indices + result['polygon_index'] += start_idx + results.append(result) + break + except Exception as exc: + message = str(exc).lower() + is_oom = isinstance(exc, MemoryError) or ( + "out_of_memory" in message + or "out of memory" in message + or ("cuda error" in message and "memory" in message) + ) + if not is_oom or batch_count >= 256: + raise + batch_count *= 2 # Concatenate all batch results + if not results: + return _empty_match_frame() result = cudf.concat(results, ignore_index=True) result = result.rename( {'point_index': 'index_query', 'polygon_index': 'index_match'},