Skip to content

Commit 8f50580

Browse files
Donglai Weiclaude
andcommitted
Use rescore_region_graph=False for dust merge (skip stale edge re-scoring)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5a16435 commit 8f50580

1 file changed

Lines changed: 11 additions & 46 deletions

File tree

connectomics/decoding/decoders/waterz.py

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -181,56 +181,20 @@ def decode_waterz(
181181
if predictions.shape[0] < 3:
182182
raise ValueError(f"Expected >= 3 affinity channels, got {predictions.shape[0]}.")
183183

184-
# Use first 3 channels (short-range affinities)
185-
affs = predictions[:3]
186-
187-
# Convert float → uint8 on the fly if requested (4x memory savings).
188-
# If already uint8, this is a no-op.
189-
if use_aff_uint8 and affs.dtype != np.uint8:
190-
affs = np.clip(affs, 0, 1)
191-
affs = (affs * 255).astype(np.uint8)
192-
193-
# Detect uint8 affinities — pass through to waterz uint8 path directly.
194-
# For float dtypes, convert to float32. Never convert uint8 to float.
195-
is_uint8 = affs.dtype == np.uint8
196-
if not is_uint8:
197-
affs = affs.astype(np.float32, copy=False)
198-
199-
# Transpose channels to zyx order expected by waterz C++.
200-
# Waterz expects: channel 0=z, 1=y, 2=x.
201-
channel_order = channel_order.lower()
202-
if channel_order == "xyz":
203-
# Model outputs x,y,z → reverse to z,y,x
204-
affs = affs[[2, 1, 0]]
205-
elif channel_order == "zyx":
206-
pass # Already in waterz order
207-
else:
208-
raise ValueError(f"Unknown channel_order '{channel_order}'. Expected 'xyz' or 'zyx'.")
209-
210-
# Ensure C-contiguous for waterz
211-
if not affs.flags["C_CONTIGUOUS"]:
212-
affs = np.ascontiguousarray(affs)
213-
214-
# Scale parameters for uint8: user specifies float [0,1], we map to [0,255].
215-
if is_uint8:
216-
_to_u8 = lambda v: int(round(v * 255)) if isinstance(v, float) and v <= 1.0 else int(v)
217-
else:
218-
_to_u8 = None # unused
219-
220-
# Normalize thresholds to sorted list (waterz requires ascending order)
221-
if isinstance(thresholds, (int, float)):
222-
thresholds_list = [float(thresholds)]
223-
else:
224-
thresholds_list = sorted(float(t) for t in thresholds)
225-
if is_uint8:
226-
thresholds_list = [_to_u8(t) for t in thresholds_list]
184+
from waterz._uint8 import prepare_affinities, scale_aff_threshold, scale_thresholds
185+
186+
# Prepare affinities: dtype normalisation, channel reorder, contiguous
187+
affs, is_uint8 = prepare_affinities(
188+
predictions, channel_order=channel_order, use_aff_uint8=use_aff_uint8,
189+
)
190+
191+
# Scale float [0,1] parameters to [0,255] for uint8
192+
thresholds_list = scale_thresholds(thresholds, is_uint8)
193+
aff_low, aff_high = scale_aff_threshold(aff_threshold, is_uint8)
227194

228195
# Convert shorthand merge function to C++ scoring function string
229196
scoring_function = merge_function_to_scoring(merge_function)
230197

231-
aff_low = _to_u8(aff_threshold[0]) if is_uint8 else float(aff_threshold[0])
232-
aff_high = _to_u8(aff_threshold[1]) if is_uint8 else float(aff_threshold[1])
233-
234198
logger.info(
235199
"Running waterz: %d thresholds=%s, scoring_function=%s, aff_threshold=(%s, %s)%s",
236200
len(thresholds_list),
@@ -256,6 +220,7 @@ def decode_waterz(
256220

257221
do_dust_merge = bool(dust_merge) and dust_merge_size > 0
258222
waterz_kwargs["return_region_graph"] = do_dust_merge
223+
waterz_kwargs["rescore_region_graph"] = False # fast: use cached scores for dust merge
259224

260225
# waterz.waterz() runs watershed + region-graph once, then incrementally
261226
# merges for each threshold. Returns all segmentations (copied).

0 commit comments

Comments
 (0)