@@ -73,76 +73,14 @@ def _merge_function_to_scoring(shorthand: str) -> str:
7373 )
7474
7575
76- def _can_reuse_region_graph_for_dust (scoring_function : str ) -> bool :
77- """Return whether waterz region-graph scores can be mapped to affinities."""
78- return scoring_function .startswith ("OneMinus<" )
79-
80-
81- def _build_segment_counts (seg : np .ndarray ) -> np .ndarray :
82- """Build a dense counts array indexed by segment id."""
83- ids , cnts = np .unique (seg , return_counts = True )
84- max_id = int (ids .max ()) if len (ids ) else 0
85- counts = np .zeros (max_id + 1 , dtype = np .uint64 )
86- counts [ids ] = cnts
87- return counts
88-
89-
def _merge_dust_with_region_graph(
    seg: np.ndarray,
    region_graph: Sequence[Dict[str, Any]],
    *,
    scoring_function: str,
    size_th: int,
    weight_th: float,
    dust_th: int,
) -> bool:
    """Run dust merging from the region graph waterz already returned.

    Args:
        seg: Segmentation handed to ``waterz.merge_segments``.
            NOTE(review): presumably relabelled in place, since nothing but a
            flag is returned here — confirm against waterz.
        region_graph: Edges as mappings with ``"score"``, ``"u"`` and ``"v"``
            entries.
        scoring_function: Scoring-function string that produced the scores.
        size_th: Forwarded unchanged to ``waterz.merge_segments``.
        weight_th: Forwarded unchanged to ``waterz.merge_segments``.
        dust_th: Forwarded unchanged to ``waterz.merge_segments``.

    Returns:
        True if the region graph was reused and ``waterz.merge_segments`` was
        called. False if the scoring function is incompatible, in which case
        the caller is expected to fall back to ``waterz.merge_dust(...)``.
    """
    if not _can_reuse_region_graph_for_dust(scoring_function):
        return False

    edge_total = len(region_graph)
    weights = np.empty(edge_total, dtype=np.float32)
    endpoint_u = np.empty(edge_total, dtype=np.uint64)
    endpoint_v = np.empty(edge_total, dtype=np.uint64)

    for slot, link in enumerate(region_graph):
        # The stored score is a OneMinus<T> complement; undo it to get the
        # affinity-like weight that dust merging operates on.
        weights[slot] = 1.0 - float(link["score"])
        endpoint_u[slot] = int(link["u"])
        endpoint_v[slot] = int(link["v"])

    if edge_total > 0:
        np.clip(weights, 0.0, 1.0, out=weights)
        # Reorder all three columns together, highest weight first.
        descending = np.argsort(weights)[::-1]
        weights, endpoint_u, endpoint_v = (
            np.ascontiguousarray(column[descending])
            for column in (weights, endpoint_u, endpoint_v)
        )

    voxel_counts = _build_segment_counts(seg)
    waterz.merge_segments(
        seg,
        weights,
        endpoint_u,
        endpoint_v,
        voxel_counts,
        size_th,
        weight_th,
        dust_th,
    )
    return True
138-
13976
14077def decode_waterz (
14178 predictions : np .ndarray ,
14279 thresholds : Union [float , Sequence [float ]] = 0.3 ,
14380 merge_function : str = "aff50_his256" ,
14481 aff_threshold : Tuple [float , float ] = (0.0001 , 0.9999 ),
14582 channel_order : str = "xyz" ,
83+ use_uint8 : bool = False ,
14684 fragments : Optional [np .ndarray ] = None ,
14785 min_instance_size : int = 0 ,
14886 dust_merge : bool = True ,
@@ -172,12 +110,13 @@ def decode_waterz(
172110 Args:
173111 predictions: Affinity predictions of shape :math:`(C, Z, Y, X)` where
174112 ``C >= 3``. The first 3 channels are short-range affinities.
175- Values should be float32 in [0, 1].
113+ Supports **float32** [0, 1] or **uint8** [0, 255]. When uint8,
114+ the entire pipeline runs in integer arithmetic (4x less memory).
115+ Parameters (thresholds, aff_threshold) can be specified in float
116+ [0, 1] range — they are auto-scaled to [0, 255] for uint8.
176117 thresholds: Agglomeration threshold(s). Regions with merge score below
177- the threshold are merged. Can be a single float or a list of
178- floats. When multiple thresholds are provided, either the last
179- result is returned (default) or all results as a dict (see
180- *return_all_thresholds*). Default: 0.3
118+ the threshold are merged. Specify in [0, 1] float range
119+ regardless of input dtype (auto-scaled for uint8). Default: 0.3
181120 merge_function: Scoring function for agglomeration. Common options:
182121
183122 - ``"aff50_his256"``: Median affinity via 256-bin histogram (default, recommended)
@@ -197,6 +136,11 @@ def decode_waterz(
197136 (e.g. offsets ``["0-0-1", "0-1-0", "1-0-0"]``), set this to
198137 ``"xyz"`` and the channels will be transposed automatically.
199138 Default: ``"xyz"``
139+ use_uint8: Convert float affinities to uint8 before waterz.
140+ Saves 4x memory and runs the entire C++ pipeline in integer
141+ arithmetic. Lossless for ``HistogramQuantileAffinity`` with
142+ 256 bins. If input is already uint8, this is a no-op.
143+ Default: False
200144 fragments: Pre-computed over-segmentation (fragment IDs). If provided,
201145 the watershed step is skipped and agglomeration runs directly on
202146 these fragments. Shape :math:`(Z, Y, X)`, dtype uint64.
@@ -283,7 +227,19 @@ def decode_waterz(
283227 raise ValueError (f"Expected >= 3 affinity channels, got { predictions .shape [0 ]} ." )
284228
285229 # Use first 3 channels (short-range affinities)
286- affs = predictions [:3 ].astype (np .float32 , copy = False )
230+ affs = predictions [:3 ]
231+
232+ # Convert float → uint8 on the fly if requested (4x memory savings).
233+ # If already uint8, this is a no-op.
234+ if use_uint8 and affs .dtype != np .uint8 :
235+ affs = np .clip (affs , 0 , 1 )
236+ affs = (affs * 255 ).astype (np .uint8 )
237+
238+ # Detect uint8 affinities — pass through to waterz uint8 path directly.
239+ # For float dtypes, convert to float32. Never convert uint8 to float.
240+ is_uint8 = affs .dtype == np .uint8
241+ if not is_uint8 :
242+ affs = affs .astype (np .float32 , copy = False )
287243
288244 # Transpose channels to zyx order expected by waterz C++.
289245 # Waterz expects: channel 0=z, 1=y, 2=x.
@@ -300,25 +256,34 @@ def decode_waterz(
300256 if not affs .flags ["C_CONTIGUOUS" ]:
301257 affs = np .ascontiguousarray (affs )
302258
259+ # Scale parameters for uint8: user specifies float [0,1], we map to [0,255].
260+ if is_uint8 :
261+ _to_u8 = lambda v : int (round (v * 255 )) if isinstance (v , float ) and v <= 1.0 else int (v )
262+ else :
263+ _to_u8 = None # unused
264+
303265 # Normalize thresholds to sorted list (waterz requires ascending order)
304266 if isinstance (thresholds , (int , float )):
305267 thresholds_list = [float (thresholds )]
306268 else :
307269 thresholds_list = sorted (float (t ) for t in thresholds )
270+ if is_uint8 :
271+ thresholds_list = [_to_u8 (t ) for t in thresholds_list ]
308272
309273 # Convert shorthand merge function to C++ scoring function string
310274 scoring_function = _merge_function_to_scoring (merge_function )
311275
312- aff_low = float (aff_threshold [0 ])
313- aff_high = float (aff_threshold [1 ])
276+ aff_low = _to_u8 ( aff_threshold [ 0 ]) if is_uint8 else float (aff_threshold [0 ])
277+ aff_high = _to_u8 ( aff_threshold [ 1 ]) if is_uint8 else float (aff_threshold [1 ])
314278
315279 logger .info (
316- "Running waterz: %d thresholds=%s, scoring_function=%s, aff_threshold=(%.4f , %.4f) " ,
280+ "Running waterz: %d thresholds=%s, scoring_function=%s, aff_threshold=(%s , %s)%s " ,
317281 len (thresholds_list ),
318282 thresholds_list ,
319283 scoring_function ,
320284 aff_low ,
321285 aff_high ,
286+ " [uint8]" if is_uint8 else "" ,
322287 )
323288
324289 # Build kwargs for waterz.waterz()
@@ -330,44 +295,24 @@ def decode_waterz(
330295 if fragments is not None :
331296 waterz_kwargs ["fragments" ] = fragments .astype (np .uint64 , copy = False )
332297
333- # Dust postprocessing can request the current post-agglomeration region
334- # graph from waterz.waterz(). When the score type is compatible, we reuse
335- # that graph directly; otherwise we fall back to merge_dust() to rebuild an
336- # affinity-weighted graph from the final segmentation.
337298 do_dust_merge = bool (dust_merge ) and dust_merge_size > 0
338- waterz_kwargs ["return_region_graph" ] = do_dust_merge
339299
340300 # waterz.waterz() runs watershed + region-graph once, then incrementally
341301 # merges for each threshold. Returns all segmentations (copied).
342302 seg_list = waterz .waterz (affs , thresholds = thresholds_list , ** waterz_kwargs )
343303
344304 # Post-process each result
345305 processed : List [np .ndarray ] = []
346- for waterz_result in seg_list :
347- if do_dust_merge :
348- seg , _region_graph = waterz_result
349- else :
350- seg = waterz_result
351-
352- # Size+affinity dust merge (zwatershed-style)
306+ for seg in seg_list :
307+ # Size+affinity dust merge via buildRegionGraphOnly (fast path)
353308 if do_dust_merge :
354309 seg = seg .astype (np .uint64 , copy = False )
355- reused_region_graph = _merge_dust_with_region_graph (
356- seg ,
357- _region_graph ,
358- scoring_function = scoring_function ,
310+ waterz .merge_dust (
311+ seg , affs ,
359312 size_th = dust_merge_size ,
360313 weight_th = dust_merge_affinity ,
361314 dust_th = dust_remove_size ,
362315 )
363- if not reused_region_graph :
364- waterz .merge_dust (
365- seg ,
366- affs ,
367- dust_merge_size ,
368- dust_merge_affinity ,
369- dust_remove_size ,
370- )
371316 # Branch merge: resolve false splits via z-slice IOU analysis
372317 if branch_merge :
373318 from .branch_merge import branch_merge as _branch_merge
0 commit comments