|
21 | 21 | from shapely.affinity import scale |
22 | 22 | from shapely.geometry import Polygon, box, shape |
23 | 23 | from shapely.ops import orient |
| 24 | +from tqdm import tqdm |
24 | 25 |
|
25 | 26 | # Type aliases definitions |
26 | 27 | Feature = Dict[str, Any] |
@@ -343,78 +344,90 @@ def calc_iou(shape1, shape2): |
343 | 344 | return iou |
344 | 345 |
|
345 | 346 |
|
346 | | -def clean_crowns(crowns: gpd.GeoDataFrame, |
347 | | - iou_threshold: float = 0.7, |
348 | | - confidence: float = 0.2, |
349 | | - area_threshold: float = 2, |
350 | | - field: str = "Confidence_score") -> gpd.GeoDataFrame: |
351 | | - """Clean overlapping crowns. |
352 | | -
|
353 | | - Outputs can contain highly overlapping crowns including in the buffer region. |
354 | | - This function removes crowns with a high degree of overlap with others but a |
355 | | - lower Confidence Score. |
356 | | -
|
| 347 | +def clean_crowns(crowns, |
| 348 | + iou_threshold= 0.7, |
| 349 | + confidence= 0.2, |
| 350 | + area_threshold = 2, |
| 351 | + field= "Confidence_score") -> gpd.GeoDataFrame: |
| 352 | + """ |
| 353 | + Clean overlapping crowns by first identifying all candidate overlapping pairs via a spatial join, |
| 354 | + then clustering crowns into connected components (where an edge is added if two crowns have IoU |
| 355 | + above a threshold), and finally keeping the best crown (by confidence or any given field) in each cluster. |
| 356 | + |
357 | 357 | Args: |
358 | 358 | crowns (gpd.GeoDataFrame): Crowns to be cleaned. |
359 | 359 | iou_threshold (float, optional): IoU threshold above which two crowns are treated as overlapping. Defaults to 0.7. |
360 | 360 | confidence (float, optional): Minimum confidence score for crowns to be retained. Defaults to 0.2. Note that |
361 | 361 | this should be adjusted to fit "field". |
362 | | - area_threshold (float, optional): Minimum area of crowns to be retained. Defaults to 1m2 (assuming UTM). |
| 362 | + area_threshold (float, optional): Minimum area of crowns to be retained. Defaults to 2m2 (assuming UTM). |
363 | 363 | field (str): Field used to prioritise selection of crowns. Defaults to "Confidence_score" but this should |
364 | 364 | be changed to "Area" if using a model that outputs area. |
365 | 365 |
|
366 | 366 | Returns: |
367 | 367 | gpd.GeoDataFrame: Cleaned crowns. |
368 | 368 | """ |
369 | | - # Filter any rows with empty or invalid geometry |
370 | | - crowns = crowns[~crowns.is_empty & crowns.is_valid] |
| 369 | + # 1. Filter out invalid geometries and tiny artifacts. |
| 370 | + crowns = crowns[~crowns.is_empty & crowns.is_valid].copy() |
| 371 | + crowns = crowns[crowns.area > area_threshold].copy() |
371 | 372 |
|
372 | | - # Filter any rows with polgon of less than 1m2 as these are likely to be artifacts |
373 | | - crowns = crowns[crowns.area > area_threshold] |
| 373 | + if confidence: |
| 374 | + crowns = crowns[crowns[field] > confidence] |
374 | 375 |
|
375 | 376 | crowns.reset_index(drop=True, inplace=True) |
376 | 377 |
|
377 | | - cleaned_crowns = [] |
378 | | - print(f"Cleaning {len(crowns)} crowns") |
379 | | - |
380 | | - for index, row in crowns.iterrows(): |
381 | | - if index % 1000 == 0: |
382 | | - print(f"{index} / {len(crowns)} crowns cleaned") |
383 | | - |
384 | | - intersecting_rows = crowns[crowns.intersects(shape(row.geometry))] |
385 | | - |
386 | | - if len(intersecting_rows) > 1: |
387 | | - iou_values = intersecting_rows.geometry.map(lambda x: calc_iou(row.geometry, x)) |
388 | | - intersecting_rows = intersecting_rows.assign(iou=iou_values) |
389 | | - |
390 | | - # Filter rows with IoU over threshold and get the one with the highest confidence score |
391 | | - match = intersecting_rows[intersecting_rows["iou"] > iou_threshold].nlargest(1, field) |
392 | | - |
393 | | - if match["iou"].iloc[0] < 1: |
394 | | - continue |
395 | | - |
396 | | - else: |
397 | | - match = row.to_frame().T |
398 | | - |
399 | | - cleaned_crowns.append(match) |
400 | | - |
401 | | - crowns_out = pd.concat(cleaned_crowns, ignore_index=True) |
402 | | - |
403 | | - # Drop 'iou' column if it exists |
404 | | - if "iou" in crowns_out.columns: |
405 | | - crowns_out = crowns_out.drop("iou", axis=1) |
406 | | - |
407 | | - # Ensuring crowns_out is a GeoDataFrame |
408 | | - if not isinstance(crowns_out, gpd.GeoDataFrame): |
409 | | - crowns_out = gpd.GeoDataFrame(crowns_out, crs=crowns.crs) |
410 | | - else: |
411 | | - crowns_out = crowns_out.set_crs(crowns.crs) |
412 | | - |
413 | | - # Filter remaining crowns based on confidence score |
414 | | - if confidence != 0: |
415 | | - crowns_out = crowns_out[crowns_out[field] > confidence] |
416 | | - |
417 | | - return crowns_out.reset_index(drop=True) |
| 378 | + # 2. Use a spatial join to quickly find all candidate overlapping pairs. |
| 379 | + # The join pairs each crown with every crown whose geometry intersects it (bounding boxes only pre-filter via the spatial index). |
| 380 | + print("clearn_crowns: Performing spatial join...") |
| 381 | + join = gpd.sjoin(crowns, crowns, how="inner", predicate="intersects") |
| 382 | + # Remove self-joins (where a crown is paired with itself). |
| 383 | + join = join[join.index != join.index_right] |
| 384 | + |
| 385 | + # 3. Set up a union-find structure to cluster overlapping crowns. |
| 386 | + n = len(crowns) |
| 387 | + parent = list(range(n)) # Initially, each crown is its own group. |
| 388 | + |
| 389 | + def find(x): |
| 390 | + # Path compression to flatten the tree. |
| 391 | + while parent[x] != x: |
| 392 | + parent[x] = parent[parent[x]] |
| 393 | + x = parent[x] |
| 394 | + return x |
| 395 | + |
| 396 | + def union(x, y): |
| 397 | + rx, ry = find(x), find(y) |
| 398 | + if rx != ry: |
| 399 | + parent[ry] = rx |
| 400 | + |
| 401 | + # 4. For each candidate pair, compute IoU and, if it exceeds the threshold, merge the groups. |
| 402 | + for idx, row in tqdm(join.iterrows(), total=len(join), desc="clean_crowns: Processing candidate pairs", smoothing=0): |
| 403 | + i = row.name # index from left table (crowns) |
| 404 | + j = row["index_right"] # index from right table (crowns) |
| 405 | + # To avoid duplicate work, skip if i and j are already in the same group. |
| 406 | + if find(i) == find(j): |
| 407 | + continue |
| 408 | + # Compute the IoU for the pair. |
| 409 | + iou_val = calc_iou(crowns.at[i, "geometry"], crowns.at[j, "geometry"]) |
| 410 | + if iou_val > iou_threshold: |
| 411 | + union(i, j) |
| 412 | + |
| 413 | + # 5. Group crowns by their union-find root. |
| 414 | + groups = {} |
| 415 | + for i in range(n): |
| 416 | + root = find(i) |
| 417 | + groups.setdefault(root, []).append(i) |
| 418 | + |
| 419 | + # 6. In each group, select the crown with the highest "confidence" (or the value in `field`). |
| 420 | + selected_indices = [] |
| 421 | + for comp in groups.values(): |
| 422 | + group_df = crowns.loc[comp] |
| 423 | + best_idx = group_df[field].idxmax() |
| 424 | + selected_indices.append(best_idx) |
| 425 | + |
| 426 | + # 7. Assemble the cleaned crowns. |
| 427 | + cleaned_crowns = crowns.loc[selected_indices].copy() |
| 428 | + |
| 429 | + |
| 430 | + return gpd.GeoDataFrame(cleaned_crowns, crs=crowns.crs).reset_index(drop=True) |
418 | 431 |
|
419 | 432 |
|
420 | 433 | def post_clean(unclean_df: gpd.GeoDataFrame, |
|
0 commit comments