Skip to content

Commit 44ceebc

Browse files
efficient clean_crowns (spacial join)
1 parent 10a7ff9 commit 44ceebc

1 file changed

Lines changed: 70 additions & 57 deletions

File tree

detectree2/models/outputs.py

Lines changed: 70 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from shapely.affinity import scale
2222
from shapely.geometry import Polygon, box, shape
2323
from shapely.ops import orient
24+
from tqdm import tqdm
2425

2526
# Type aliases definitions
2627
Feature = Dict[str, Any]
@@ -343,78 +344,90 @@ def calc_iou(shape1, shape2):
343344
return iou
344345

345346

346-
def clean_crowns(crowns: gpd.GeoDataFrame,
347-
iou_threshold: float = 0.7,
348-
confidence: float = 0.2,
349-
area_threshold: float = 2,
350-
field: str = "Confidence_score") -> gpd.GeoDataFrame:
351-
"""Clean overlapping crowns.
352-
353-
Outputs can contain highly overlapping crowns including in the buffer region.
354-
This function removes crowns with a high degree of overlap with others but a
355-
lower Confidence Score.
356-
347+
def clean_crowns(crowns,
348+
iou_threshold= 0.7,
349+
confidence= 0.2,
350+
area_threshold = 2,
351+
field= "Confidence_score") -> gpd.GeoDataFrame:
352+
"""
353+
Clean overlapping crowns by first identifying all candidate overlapping pairs via a spatial join,
354+
then clustering crowns into connected components (where an edge is added if two crowns have IoU
355+
above a threshold), and finally keeping the best crown (by confidence or any given field) in each cluster.
356+
357357
Args:
358358
crowns (gpd.GeoDataFrame): Crowns to be cleaned.
359359
iou_threshold (float, optional): IoU threshold that determines whether crowns are overlapping. Defaults to 0.7.
360360
confidence (float, optional): Minimum confidence score for crowns to be retained. Defaults to 0.2. Note that
361361
this should be adjusted to fit "field".
362-
area_threshold (float, optional): Minimum area of crowns to be retained. Defaults to 1m2 (assuming UTM).
362+
area_threshold (float, optional): Minimum area of crowns to be retained. Defaults to 2m2 (assuming UTM).
363363
field (str): Field used to prioritise selection of crowns. Defaults to "Confidence_score" but this should
364364
be changed to "Area" if using a model that outputs area.
365365
366366
Returns:
367367
gpd.GeoDataFrame: Cleaned crowns.
368368
"""
369-
# Filter any rows with empty or invalid geometry
370-
crowns = crowns[~crowns.is_empty & crowns.is_valid]
369+
# 1. Filter out invalid geometries and tiny artifacts.
370+
crowns = crowns[~crowns.is_empty & crowns.is_valid].copy()
371+
crowns = crowns[crowns.area > area_threshold].copy()
371372

372-
# Filter any rows with polgon of less than 1m2 as these are likely to be artifacts
373-
crowns = crowns[crowns.area > area_threshold]
373+
if confidence:
374+
crowns = crowns[crowns[field] > confidence]
374375

375376
crowns.reset_index(drop=True, inplace=True)
376377

377-
cleaned_crowns = []
378-
print(f"Cleaning {len(crowns)} crowns")
379-
380-
for index, row in crowns.iterrows():
381-
if index % 1000 == 0:
382-
print(f"{index} / {len(crowns)} crowns cleaned")
383-
384-
intersecting_rows = crowns[crowns.intersects(shape(row.geometry))]
385-
386-
if len(intersecting_rows) > 1:
387-
iou_values = intersecting_rows.geometry.map(lambda x: calc_iou(row.geometry, x))
388-
intersecting_rows = intersecting_rows.assign(iou=iou_values)
389-
390-
# Filter rows with IoU over threshold and get the one with the highest confidence score
391-
match = intersecting_rows[intersecting_rows["iou"] > iou_threshold].nlargest(1, field)
392-
393-
if match["iou"].iloc[0] < 1:
394-
continue
395-
396-
else:
397-
match = row.to_frame().T
398-
399-
cleaned_crowns.append(match)
400-
401-
crowns_out = pd.concat(cleaned_crowns, ignore_index=True)
402-
403-
# Drop 'iou' column if it exists
404-
if "iou" in crowns_out.columns:
405-
crowns_out = crowns_out.drop("iou", axis=1)
406-
407-
# Ensuring crowns_out is a GeoDataFrame
408-
if not isinstance(crowns_out, gpd.GeoDataFrame):
409-
crowns_out = gpd.GeoDataFrame(crowns_out, crs=crowns.crs)
410-
else:
411-
crowns_out = crowns_out.set_crs(crowns.crs)
412-
413-
# Filter remaining crowns based on confidence score
414-
if confidence != 0:
415-
crowns_out = crowns_out[crowns_out[field] > confidence]
416-
417-
return crowns_out.reset_index(drop=True)
378+
# 2. Use a spatial join to quickly find all candidate overlapping pairs.
379+
# The join pairs each crown with every crown whose geometry intersects it (the spatial index only pre-filters by bounding box).
380+
print("clearn_crowns: Performing spatial join...")
381+
join = gpd.sjoin(crowns, crowns, how="inner", predicate="intersects")
382+
# Remove self-joins (where a crown is paired with itself).
383+
join = join[join.index != join.index_right]
384+
385+
# 3. Set up a union-find structure to cluster overlapping crowns.
386+
n = len(crowns)
387+
parent = list(range(n)) # Initially, each crown is its own group.
388+
389+
def find(x):
390+
# Path compression to flatten the tree.
391+
while parent[x] != x:
392+
parent[x] = parent[parent[x]]
393+
x = parent[x]
394+
return x
395+
396+
def union(x, y):
397+
rx, ry = find(x), find(y)
398+
if rx != ry:
399+
parent[ry] = rx
400+
401+
# 4. For each candidate pair, compute IoU and, if it exceeds the threshold, merge the groups.
402+
for idx, row in tqdm(join.iterrows(), total=len(join), desc="clean_crowns: Processing candidate pairs", smoothing=0):
403+
i = row.name # index from left table (crowns)
404+
j = row["index_right"] # index from right table (crowns)
405+
# To avoid duplicate work, skip if i and j are already in the same group.
406+
if find(i) == find(j):
407+
continue
408+
# Compute the IoU for the pair.
409+
iou_val = calc_iou(crowns.at[i, "geometry"], crowns.at[j, "geometry"])
410+
if iou_val > iou_threshold:
411+
union(i, j)
412+
413+
# 5. Group crowns by their union-find root.
414+
groups = {}
415+
for i in range(n):
416+
root = find(i)
417+
groups.setdefault(root, []).append(i)
418+
419+
# 6. In each group, select the crown with the highest "confidence" (or the value in `field`).
420+
selected_indices = []
421+
for comp in groups.values():
422+
group_df = crowns.loc[comp]
423+
best_idx = group_df[field].idxmax()
424+
selected_indices.append(best_idx)
425+
426+
# 7. Assemble the cleaned crowns.
427+
cleaned_crowns = crowns.loc[selected_indices].copy()
428+
429+
430+
return gpd.GeoDataFrame(cleaned_crowns, crs=crowns.crs).reset_index(drop=True)
418431

419432

420433
def post_clean(unclean_df: gpd.GeoDataFrame,

0 commit comments

Comments
 (0)