@@ -181,56 +181,20 @@ def decode_waterz(
181181 if predictions .shape [0 ] < 3 :
182182 raise ValueError (f"Expected >= 3 affinity channels, got { predictions .shape [0 ]} ." )
183183
184- # Use first 3 channels (short-range affinities)
185- affs = predictions [:3 ]
186-
187- # Convert float → uint8 on the fly if requested (4x memory savings).
188- # If already uint8, this is a no-op.
189- if use_aff_uint8 and affs .dtype != np .uint8 :
190- affs = np .clip (affs , 0 , 1 )
191- affs = (affs * 255 ).astype (np .uint8 )
192-
193- # Detect uint8 affinities — pass through to waterz uint8 path directly.
194- # For float dtypes, convert to float32. Never convert uint8 to float.
195- is_uint8 = affs .dtype == np .uint8
196- if not is_uint8 :
197- affs = affs .astype (np .float32 , copy = False )
198-
199- # Transpose channels to zyx order expected by waterz C++.
200- # Waterz expects: channel 0=z, 1=y, 2=x.
201- channel_order = channel_order .lower ()
202- if channel_order == "xyz" :
203- # Model outputs x,y,z → reverse to z,y,x
204- affs = affs [[2 , 1 , 0 ]]
205- elif channel_order == "zyx" :
206- pass # Already in waterz order
207- else :
208- raise ValueError (f"Unknown channel_order '{ channel_order } '. Expected 'xyz' or 'zyx'." )
209-
210- # Ensure C-contiguous for waterz
211- if not affs .flags ["C_CONTIGUOUS" ]:
212- affs = np .ascontiguousarray (affs )
213-
214- # Scale parameters for uint8: user specifies float [0,1], we map to [0,255].
215- if is_uint8 :
216- _to_u8 = lambda v : int (round (v * 255 )) if isinstance (v , float ) and v <= 1.0 else int (v )
217- else :
218- _to_u8 = None # unused
219-
220- # Normalize thresholds to sorted list (waterz requires ascending order)
221- if isinstance (thresholds , (int , float )):
222- thresholds_list = [float (thresholds )]
223- else :
224- thresholds_list = sorted (float (t ) for t in thresholds )
225- if is_uint8 :
226- thresholds_list = [_to_u8 (t ) for t in thresholds_list ]
184+ from waterz ._uint8 import prepare_affinities , scale_aff_threshold , scale_thresholds
185+
186+ # Prepare affinities: dtype normalisation, channel reorder, contiguous
187+ affs , is_uint8 = prepare_affinities (
188+ predictions , channel_order = channel_order , use_aff_uint8 = use_aff_uint8 ,
189+ )
190+
191+ # Scale float [0,1] parameters to [0,255] for uint8
192+ thresholds_list = scale_thresholds (thresholds , is_uint8 )
193+ aff_low , aff_high = scale_aff_threshold (aff_threshold , is_uint8 )
227194
228195 # Convert shorthand merge function to C++ scoring function string
229196 scoring_function = merge_function_to_scoring (merge_function )
230197
231- aff_low = _to_u8 (aff_threshold [0 ]) if is_uint8 else float (aff_threshold [0 ])
232- aff_high = _to_u8 (aff_threshold [1 ]) if is_uint8 else float (aff_threshold [1 ])
233-
234198 logger .info (
235199 "Running waterz: %d thresholds=%s, scoring_function=%s, aff_threshold=(%s, %s)%s" ,
236200 len (thresholds_list ),
@@ -256,6 +220,7 @@ def decode_waterz(
256220
257221 do_dust_merge = bool (dust_merge ) and dust_merge_size > 0
258222 waterz_kwargs ["return_region_graph" ] = do_dust_merge
223+ waterz_kwargs ["rescore_region_graph" ] = False # fast: use cached scores for dust merge
259224
260225 # waterz.waterz() runs watershed + region-graph once, then incrementally
261226 # merges for each threshold. Returns all segmentations (copied).
0 commit comments