Skip to content

Commit a75fd88

Browse files
committed
Fix marching squares case table and add tests (#964)
Corrected the edge assignments in the marching squares lookup table. The previous table had wrong edges for most cases (e.g. case 1 used the top edge instead of the bottom edge). Added a comprehensive test suite covering correctness, NaN handling, edge cases, backend equivalence, GeoDataFrame output, accessor integration, and closed-ring detection.
1 parent 8b57f6f commit a75fd88

File tree

2 files changed

+499
-85
lines changed

2 files changed

+499
-85
lines changed

xrspatial/contour.py

Lines changed: 120 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -37,20 +37,20 @@
3737
# Empty tuple means no contour passes through the quad.
3838
_MS_TABLE = (
3939
(), # 0000 - all below
40-
((3, 0),), # 0001 - bottom-left above
41-
((0, 1),), # 0010 - bottom-right above
42-
((3, 1),), # 0011 - bottom row above
43-
((1, 2),), # 0100 - top-right above
44-
((3, 0), (1, 2)), # 0101 - saddle: BL and TR above
45-
((0, 2),), # 0110 - right column above
46-
((3, 2),), # 0111 - only top-left below
47-
((2, 3),), # 1000 - top-left above
48-
((2, 0),), # 1001 - left column above
49-
((2, 3), (0, 1)), # 1010 - saddle: TL and BR above
50-
((2, 1),), # 1011 - only top-right below
51-
((1, 3),), # 1100 - top row above
52-
((1, 0),), # 1101 - only bottom-right below
53-
((0, 3),), # 1110 - only bottom-left below
40+
((3, 2),), # 0001 - bl above: left-bottom
41+
((2, 1),), # 0010 - br above: bottom-right
42+
((3, 1),), # 0011 - bl+br above: left-right
43+
((0, 1),), # 0100 - tr above: top-right
44+
((2, 3), (0, 1)), # 0101 - saddle: bl+tr above (default: separated)
45+
((0, 2),), # 0110 - tr+br above: top-bottom
46+
((0, 3),), # 0111 - only tl below: top-left
47+
((0, 3),), # 1000 - tl above: top-left
48+
((0, 2),), # 1001 - tl+bl above: top-bottom
49+
((0, 3), (1, 2)), # 1010 - saddle: tl+br above (default: separated)
50+
((0, 1),), # 1011 - only tr below: top-right
51+
((1, 3),), # 1100 - tl+tr above: right-left
52+
((1, 2),), # 1101 - only br below: right-bottom
53+
((2, 3),), # 1110 - only bl below: bottom-left
5454
(), # 1111 - all above
5555
)
5656

@@ -113,65 +113,76 @@ def _marching_squares_kernel(data, level, seg_rows, seg_cols, seg_count):
113113
continue
114114

115115
# Saddle disambiguation: use center value.
116-
if idx == 5 or idx == 10:
116+
# Default: above-level corners stay separated.
117+
# Flipped (center >= level): above-level corners connect.
118+
if idx == 5:
117119
center = (tl + tr + bl + br) * 0.25
118120
if center >= level:
119-
# Connect like-valued corners.
120-
if idx == 5:
121-
# BL+TR above, center above -> connect them
122-
idx = 5 # keep as-is: two separate segments
123-
else:
124-
# TL+BR above, center above -> connect them
125-
idx = 10
126-
# If center < level, keep the default table entry.
121+
idx = 55 # flipped saddle
122+
elif idx == 10:
123+
center = (tl + tr + bl + br) * 0.25
124+
if center >= level:
125+
idx = 100 # flipped saddle
127126

128127
# Emit segments for this case.
129-
# Inline the lookup and interpolation for Numba compatibility.
130-
if idx == 1:
131-
_emit_seg(r, c, tl, tr, bl, br, level, 3, 0,
128+
# Edge numbering: 0=top, 1=right, 2=bottom, 3=left.
129+
# Each edge is crossed where one corner is above and the
130+
# other is below the contour level.
131+
if idx == 1: # bl above: left-bottom
132+
_emit_seg(r, c, tl, tr, bl, br, level, 3, 2,
132133
seg_rows, seg_cols, seg_count)
133-
elif idx == 2:
134-
_emit_seg(r, c, tl, tr, bl, br, level, 0, 1,
134+
elif idx == 2: # br above: bottom-right
135+
_emit_seg(r, c, tl, tr, bl, br, level, 2, 1,
135136
seg_rows, seg_cols, seg_count)
136-
elif idx == 3:
137+
elif idx == 3: # bl+br above: left-right
137138
_emit_seg(r, c, tl, tr, bl, br, level, 3, 1,
138139
seg_rows, seg_cols, seg_count)
139-
elif idx == 4:
140-
_emit_seg(r, c, tl, tr, bl, br, level, 1, 2,
140+
elif idx == 4: # tr above: top-right
141+
_emit_seg(r, c, tl, tr, bl, br, level, 0, 1,
142+
seg_rows, seg_cols, seg_count)
143+
elif idx == 5: # saddle bl+tr (separated)
144+
_emit_seg(r, c, tl, tr, bl, br, level, 2, 3,
145+
seg_rows, seg_cols, seg_count)
146+
_emit_seg(r, c, tl, tr, bl, br, level, 0, 1,
141147
seg_rows, seg_cols, seg_count)
142-
elif idx == 5:
148+
elif idx == 55: # saddle bl+tr (connected via center)
143149
_emit_seg(r, c, tl, tr, bl, br, level, 3, 0,
144150
seg_rows, seg_cols, seg_count)
145-
_emit_seg(r, c, tl, tr, bl, br, level, 1, 2,
151+
_emit_seg(r, c, tl, tr, bl, br, level, 2, 1,
146152
seg_rows, seg_cols, seg_count)
147-
elif idx == 6:
153+
elif idx == 6: # tr+br above: top-bottom
148154
_emit_seg(r, c, tl, tr, bl, br, level, 0, 2,
149155
seg_rows, seg_cols, seg_count)
150-
elif idx == 7:
151-
_emit_seg(r, c, tl, tr, bl, br, level, 3, 2,
156+
elif idx == 7: # only tl below: top-left
157+
_emit_seg(r, c, tl, tr, bl, br, level, 0, 3,
152158
seg_rows, seg_cols, seg_count)
153-
elif idx == 8:
154-
_emit_seg(r, c, tl, tr, bl, br, level, 2, 3,
159+
elif idx == 8: # tl above: top-left
160+
_emit_seg(r, c, tl, tr, bl, br, level, 0, 3,
155161
seg_rows, seg_cols, seg_count)
156-
elif idx == 9:
157-
_emit_seg(r, c, tl, tr, bl, br, level, 2, 0,
162+
elif idx == 9: # tl+bl above: top-bottom
163+
_emit_seg(r, c, tl, tr, bl, br, level, 0, 2,
158164
seg_rows, seg_cols, seg_count)
159-
elif idx == 10:
160-
_emit_seg(r, c, tl, tr, bl, br, level, 2, 3,
165+
elif idx == 10: # saddle tl+br (separated)
166+
_emit_seg(r, c, tl, tr, bl, br, level, 0, 3,
161167
seg_rows, seg_cols, seg_count)
168+
_emit_seg(r, c, tl, tr, bl, br, level, 1, 2,
169+
seg_rows, seg_cols, seg_count)
170+
elif idx == 100: # saddle tl+br (connected via center)
162171
_emit_seg(r, c, tl, tr, bl, br, level, 0, 1,
163172
seg_rows, seg_cols, seg_count)
164-
elif idx == 11:
165-
_emit_seg(r, c, tl, tr, bl, br, level, 2, 1,
173+
_emit_seg(r, c, tl, tr, bl, br, level, 3, 2,
174+
seg_rows, seg_cols, seg_count)
175+
elif idx == 11: # only tr below: top-right
176+
_emit_seg(r, c, tl, tr, bl, br, level, 0, 1,
166177
seg_rows, seg_cols, seg_count)
167-
elif idx == 12:
178+
elif idx == 12: # tl+tr above: right-left
168179
_emit_seg(r, c, tl, tr, bl, br, level, 1, 3,
169180
seg_rows, seg_cols, seg_count)
170-
elif idx == 13:
171-
_emit_seg(r, c, tl, tr, bl, br, level, 1, 0,
181+
elif idx == 13: # only br below: right-bottom
182+
_emit_seg(r, c, tl, tr, bl, br, level, 1, 2,
172183
seg_rows, seg_cols, seg_count)
173-
elif idx == 14:
174-
_emit_seg(r, c, tl, tr, bl, br, level, 0, 3,
184+
elif idx == 14: # only bl below: bottom-left
185+
_emit_seg(r, c, tl, tr, bl, br, level, 2, 3,
175186
seg_rows, seg_cols, seg_count)
176187

177188

@@ -269,6 +280,28 @@ def _stitch_segments(seg_rows, seg_cols, n_segs):
269280
_extend_line(line_r, line_c, 0, rows, cols, used, endpoint_map,
270281
DECIMALS)
271282

283+
# Check if the polyline forms a closed ring.
284+
start_key = (round(line_r[0], DECIMALS), round(line_c[0], DECIMALS))
285+
end_key = (round(line_r[-1], DECIMALS), round(line_c[-1], DECIMALS))
286+
if start_key == end_key and len(line_r) > 2:
287+
# Already closed, ensure exact closure.
288+
line_r[-1] = line_r[0]
289+
line_c[-1] = line_c[0]
290+
elif len(line_r) > 2:
291+
# Check if an unused segment connects end back to start.
292+
end_candidates = endpoint_map.get(end_key, [])
293+
for seg_idx, end_idx in end_candidates:
294+
if used[seg_idx]:
295+
continue
296+
other = 1 - end_idx
297+
other_key = (round(rows[seg_idx, other], DECIMALS),
298+
round(cols[seg_idx, other], DECIMALS))
299+
if other_key == start_key:
300+
used[seg_idx] = True
301+
line_r.append(line_r[0])
302+
line_c.append(line_c[0])
303+
break
304+
272305
coords = np.column_stack([line_r, line_c])
273306
lines.append(coords)
274307

@@ -349,30 +382,38 @@ def _contours_cupy(data, levels):
349382

350383

351384
def _contours_dask(data, levels):
352-
"""Dask backend: process each chunk independently, then merge.
385+
"""Dask backend: process each chunk with 1-cell overlap, then merge.
353386
354-
Each chunk is processed with a 1-cell overlap so that quads spanning
355-
chunk boundaries are handled by both neighbors. Duplicate segments
356-
at boundaries are removed during stitching.
387+
Uses ``dask.array.overlap.overlap`` to give each chunk a 1-cell halo
388+
so that 2x2 quads at chunk boundaries are processed by both neighbors.
389+
Duplicate segments are removed during the merge/stitch step.
357390
"""
358391
if da is None:
359392
raise ImportError("Dask is required for chunked contour extraction")
360393

361-
# Compute chunks independently.
362-
chunks = data.to_delayed().ravel()
363-
chunk_slices = _get_chunk_slices(data.chunks)
394+
padded = da.overlap.overlap(data, depth={0: 1, 1: 1}, boundary=np.nan)
395+
orig_row_chunks = data.chunks[0]
396+
orig_col_chunks = data.chunks[1]
397+
padded_blocks = padded.to_delayed()
364398

365399
all_results = []
366-
for chunk_delayed, (r_off, c_off) in zip(chunks, chunk_slices):
367-
result = dask.delayed(_process_chunk_numpy)(
368-
chunk_delayed, levels, r_off, c_off
369-
)
370-
all_results.append(result)
400+
r_off = 0
401+
for ri, rsize in enumerate(orig_row_chunks):
402+
c_off = 0
403+
for ci, csize in enumerate(orig_col_chunks):
404+
chunk = padded_blocks[ri, ci]
405+
# Padded chunk has 1-cell halo on each side (NaN at edges).
406+
# Global coordinate of the padded chunk's (0,0) is
407+
# (r_off - 1, c_off - 1).
408+
result = dask.delayed(_process_chunk_numpy)(
409+
chunk, levels, r_off - 1, c_off - 1
410+
)
411+
all_results.append(result)
412+
c_off += csize
413+
r_off += rsize
371414

372-
# Compute all chunks and merge.
373415
chunk_results = dask.compute(*all_results)
374416

375-
# Flatten and deduplicate.
376417
merged = []
377418
for chunk_lines in chunk_results:
378419
merged.extend(chunk_lines)
@@ -381,19 +422,27 @@ def _contours_dask(data, levels):
381422

382423

383424
def _contours_dask_cupy(data, levels):
384-
"""Dask+CuPy backend: transfer each chunk to CPU independently."""
425+
"""Dask+CuPy backend: overlap chunks, transfer each to CPU."""
385426
if da is None:
386427
raise ImportError("Dask is required for chunked contour extraction")
387428

388-
chunks = data.to_delayed().ravel()
389-
chunk_slices = _get_chunk_slices(data.chunks)
429+
padded = da.overlap.overlap(data, depth={0: 1, 1: 1}, boundary=np.nan)
430+
orig_row_chunks = data.chunks[0]
431+
orig_col_chunks = data.chunks[1]
432+
padded_blocks = padded.to_delayed()
390433

391434
all_results = []
392-
for chunk_delayed, (r_off, c_off) in zip(chunks, chunk_slices):
393-
result = dask.delayed(_process_chunk_cupy)(
394-
chunk_delayed, levels, r_off, c_off
395-
)
396-
all_results.append(result)
435+
r_off = 0
436+
for ri, rsize in enumerate(orig_row_chunks):
437+
c_off = 0
438+
for ci, csize in enumerate(orig_col_chunks):
439+
chunk = padded_blocks[ri, ci]
440+
result = dask.delayed(_process_chunk_cupy)(
441+
chunk, levels, r_off - 1, c_off - 1
442+
)
443+
all_results.append(result)
444+
c_off += csize
445+
r_off += rsize
397446

398447
chunk_results = dask.compute(*all_results)
399448

@@ -404,20 +453,6 @@ def _contours_dask_cupy(data, levels):
404453
return _deduplicate_by_level(merged)
405454

406455

407-
def _get_chunk_slices(chunks_tuple):
408-
"""Compute (row_offset, col_offset) for each chunk."""
409-
row_chunks, col_chunks = chunks_tuple
410-
slices = []
411-
r_off = 0
412-
for rsize in row_chunks:
413-
c_off = 0
414-
for csize in col_chunks:
415-
slices.append((r_off, c_off))
416-
c_off += csize
417-
r_off += rsize
418-
return slices
419-
420-
421456
def _process_chunk_numpy(chunk_data, levels, r_offset, c_offset):
422457
"""Process a single numpy chunk, offsetting coordinates to global space."""
423458
chunk_data = np.asarray(chunk_data)

0 commit comments

Comments
 (0)