Skip to content

Commit d0374ad

Browse files
Donglai Weiclaude
andcommitted
Reuse agglomeration region graph for dust merge instead of rebuilding
The ec4f5cf refactor switched dust merge from reusing the agglomeration's region graph to calling waterz.merge_dust() which rebuilds from scratch with MeanAffinity scoring. This changed results because the dust merge used different edge weights than the agglomeration (p85 histogram quantile). Restore return_region_graph=True and invert the OneMinus uint8 scores back to affinities for waterz.merge_segments(), matching the original behavior. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 49d9fe2 commit d0374ad

2 files changed

Lines changed: 38 additions & 22 deletions

File tree

connectomics/decoding/decoders/waterz.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def decode_waterz(
147147
min_instance_size: Minimum instance size in voxels. Instances smaller
148148
than this are removed (set to background). Set to 0 to disable.
149149
Default: 0
150-
dust_merge: Enable dust postprocessing. When WaterZ returns a
151-
reusable region graph, dust merging uses it directly via
152-
``waterz.merge_segments``; otherwise it falls back to
153-
``waterz.merge_dust``. When False, the dust merge and dust
154-
removal thresholds below are ignored. Default: True
150+
dust_merge: Enable dust postprocessing. Reuses the agglomeration's
151+
region graph (returned by waterz) and calls
152+
``waterz.merge_segments`` directly — no redundant graph rebuild.
153+
When False, the dust merge and dust removal thresholds below
154+
are ignored. Default: True
155155
dust_merge_size: Size+affinity dust merge (zwatershed-style).
156156
Segments with fewer voxels than this are merged into their
157157
highest-affinity neighbor. Unlike *min_instance_size* which
@@ -296,19 +296,34 @@ def decode_waterz(
296296
waterz_kwargs["fragments"] = fragments.astype(np.uint64, copy=False)
297297

298298
do_dust_merge = bool(dust_merge) and dust_merge_size > 0
299+
waterz_kwargs["return_region_graph"] = do_dust_merge
299300

300301
# waterz.waterz() runs watershed + region-graph once, then incrementally
301302
# merges for each threshold. Returns all segmentations (copied).
302303
seg_list = waterz.waterz(affs, thresholds=thresholds_list, **waterz_kwargs)
303304

304305
# Post-process each result
305306
processed: List[np.ndarray] = []
306-
for seg in seg_list:
307-
# Size+affinity dust merge via buildRegionGraphOnly (fast path)
307+
for waterz_result in seg_list:
308+
if do_dust_merge:
309+
seg, (rg_id, rg_sc) = waterz_result
310+
else:
311+
seg = waterz_result
312+
313+
# Size+affinity dust merge reusing the agglomeration's region graph.
314+
# rg_sc is uint8 sorted ascending (low score = high affinity).
315+
# Invert OneMinus/One255Minus scores to raw affinities in [0, 1].
308316
if do_dust_merge:
309317
seg = seg.astype(np.uint64, copy=False)
310-
waterz.merge_dust(
311-
seg, affs,
318+
rg_affs = (255.0 - rg_sc.astype(np.float32)) / 255.0
319+
id1 = rg_id[:, 0].astype(np.uint64)
320+
id2 = rg_id[:, 1].astype(np.uint64)
321+
ids, cnts = np.unique(seg, return_counts=True)
322+
max_id = int(ids.max()) if len(ids) else 0
323+
counts = np.zeros(max_id + 1, dtype=np.uint64)
324+
counts[ids] = cnts
325+
waterz.merge_segments(
326+
seg, rg_affs, id1, id2, counts,
312327
size_th=dust_merge_size,
313328
weight_th=dust_merge_affinity,
314329
dust_th=dust_remove_size,

tests/unit/test_decode_waterz.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ def waterz(self, affs, thresholds, **kwargs):
2121
seg[:, :, :2] = 1
2222
seg[:, :, 2:] = 2
2323
if kwargs.get("return_region_graph", False):
24-
rg = [{"u": 1, "v": 2, "score": 0.2}]
25-
return [(seg.copy(), rg.copy()) for _ in thresholds]
24+
# rgToArr format: (rg_id (N,2) uint32, rg_sc (N,) uint8)
25+
# score=51 → affinity = (255-51)/255 ≈ 0.8
26+
rg_id = np.array([[1, 2]], dtype=np.uint32)
27+
rg_sc = np.array([51], dtype=np.uint8)
28+
return [(seg.copy(), (rg_id.copy(), rg_sc.copy())) for _ in thresholds]
2629
return [seg.copy() for _ in thresholds]
2730

2831
def merge_dust(self, seg, affs, size_th, weight_th, dust_th):
@@ -131,9 +134,10 @@ def test_decode_waterz_reuses_region_graph_for_dust_when_scores_are_compatible(
131134
]
132135

133136

134-
def test_decode_waterz_falls_back_to_merge_dust_for_incompatible_scores(
137+
def test_decode_waterz_reuses_region_graph_for_any_scoring_function(
135138
monkeypatch,
136139
):
140+
"""Region graph reuse works for any scoring function, not just OneMinus."""
137141
fake_waterz = _FakeWaterzModule()
138142
monkeypatch.setattr(waterz_decoder, "waterz", fake_waterz)
139143
monkeypatch.setattr(waterz_decoder, "WATERZ_AVAILABLE", True)
@@ -159,13 +163,10 @@ def test_decode_waterz_falls_back_to_merge_dust_for_incompatible_scores(
159163
"return_region_graph": True,
160164
}
161165
]
162-
assert fake_waterz.merge_segments_calls == []
163-
assert fake_waterz.merge_dust_calls == [
164-
{
165-
"seg_shape": (4, 4, 4),
166-
"aff_shape": (3, 4, 4, 4),
167-
"size_th": 100,
168-
"weight_th": 0.3,
169-
"dust_th": 50,
170-
}
171-
]
166+
assert fake_waterz.merge_dust_calls == []
167+
assert len(fake_waterz.merge_segments_calls) == 1
168+
call = fake_waterz.merge_segments_calls[0]
169+
assert call["seg_shape"] == (4, 4, 4)
170+
assert call["size_th"] == 100
171+
assert call["weight_th"] == 0.3
172+
assert call["dust_th"] == 50

0 commit comments

Comments
 (0)