Skip to content

Commit ec4f5cf

Browse files
Donglai Wei and claude
committed
Add use_uint8 decode param, saved_prediction_path, decode-only config, h5 converter
decode_waterz: - Add use_uint8 parameter: converts float→uint8 on the fly before waterz (4x memory savings, lossless for histogram scoring) - Remove _merge_dust_with_region_graph: dust merge now always uses buildRegionGraphOnly (faster, no return_region_graph overhead) - Output filename appends "uint8" when use_uint8=true, nothing when false Pipeline: - Add test.saved_prediction_path to TestConfig: loads external affinity HDF5 directly, skips model inference entirely (decode-only mode) - Add tutorials/waterz_decoding.yaml: minimal decode-only config Scripts: - Add scripts/convert_h5_to_uint8.py: chunked float32→uint8 HDF5 conversion with RAM-aware chunk sizing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6ecaeb2 commit ec4f5cf

7 files changed

Lines changed: 287 additions & 98 deletions

File tree

connectomics/config/schema/stages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class TestConfig:
4747

4848
output_path: Optional[str] = None
4949
cache_suffix: str = "_x1_prediction.h5"
50+
# Path to a pre-computed affinity prediction HDF5 file.
51+
# When set, skips model inference entirely — loads and decodes directly.
52+
saved_prediction_path: str = ""
5053

5154

5255
@dataclass

connectomics/decoding/decoders/waterz.py

Lines changed: 41 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -73,76 +73,14 @@ def _merge_function_to_scoring(shorthand: str) -> str:
7373
)
7474

7575

76-
def _can_reuse_region_graph_for_dust(scoring_function: str) -> bool:
77-
"""Return whether waterz region-graph scores can be mapped to affinities."""
78-
return scoring_function.startswith("OneMinus<")
79-
80-
81-
def _build_segment_counts(seg: np.ndarray) -> np.ndarray:
82-
"""Build a dense counts array indexed by segment id."""
83-
ids, cnts = np.unique(seg, return_counts=True)
84-
max_id = int(ids.max()) if len(ids) else 0
85-
counts = np.zeros(max_id + 1, dtype=np.uint64)
86-
counts[ids] = cnts
87-
return counts
88-
89-
90-
def _merge_dust_with_region_graph(
91-
seg: np.ndarray,
92-
region_graph: Sequence[Dict[str, Any]],
93-
*,
94-
scoring_function: str,
95-
size_th: int,
96-
weight_th: float,
97-
dust_th: int,
98-
) -> bool:
99-
"""Reuse waterz's returned region graph for dust postprocessing.
100-
101-
Returns True when the existing region graph could be reused directly.
102-
When False, callers should fall back to ``waterz.merge_dust(...)``.
103-
"""
104-
if not _can_reuse_region_graph_for_dust(scoring_function):
105-
return False
106-
107-
num_edges = len(region_graph)
108-
rg_affs = np.empty(num_edges, dtype=np.float32)
109-
id1 = np.empty(num_edges, dtype=np.uint64)
110-
id2 = np.empty(num_edges, dtype=np.uint64)
111-
112-
for idx, edge in enumerate(region_graph):
113-
# OneMinus<T> stores complement scores, so invert them back to the
114-
# underlying affinity-like merge weight before dust merging.
115-
rg_affs[idx] = 1.0 - float(edge["score"])
116-
id1[idx] = int(edge["u"])
117-
id2[idx] = int(edge["v"])
118-
119-
if num_edges:
120-
np.clip(rg_affs, 0.0, 1.0, out=rg_affs)
121-
order = np.argsort(rg_affs)[::-1]
122-
rg_affs = np.ascontiguousarray(rg_affs[order])
123-
id1 = np.ascontiguousarray(id1[order])
124-
id2 = np.ascontiguousarray(id2[order])
125-
126-
counts = _build_segment_counts(seg)
127-
waterz.merge_segments(
128-
seg,
129-
rg_affs,
130-
id1,
131-
id2,
132-
counts,
133-
size_th,
134-
weight_th,
135-
dust_th,
136-
)
137-
return True
138-
13976

14077
def decode_waterz(
14178
predictions: np.ndarray,
14279
thresholds: Union[float, Sequence[float]] = 0.3,
14380
merge_function: str = "aff50_his256",
14481
aff_threshold: Tuple[float, float] = (0.0001, 0.9999),
14582
channel_order: str = "xyz",
83+
use_uint8: bool = False,
14684
fragments: Optional[np.ndarray] = None,
14785
min_instance_size: int = 0,
14886
dust_merge: bool = True,
@@ -172,12 +110,13 @@ def decode_waterz(
172110
Args:
173111
predictions: Affinity predictions of shape :math:`(C, Z, Y, X)` where
174112
``C >= 3``. The first 3 channels are short-range affinities.
175-
Values should be float32 in [0, 1].
113+
Supports **float32** [0, 1] or **uint8** [0, 255]. When uint8,
114+
the entire pipeline runs in integer arithmetic (4x less memory).
115+
Parameters (thresholds, aff_threshold) can be specified in float
116+
[0, 1] range — they are auto-scaled to [0, 255] for uint8.
176117
thresholds: Agglomeration threshold(s). Regions with merge score below
177-
the threshold are merged. Can be a single float or a list of
178-
floats. When multiple thresholds are provided, either the last
179-
result is returned (default) or all results as a dict (see
180-
*return_all_thresholds*). Default: 0.3
118+
the threshold are merged. Specify in [0, 1] float range
119+
regardless of input dtype (auto-scaled for uint8). Default: 0.3
181120
merge_function: Scoring function for agglomeration. Common options:
182121
183122
- ``"aff50_his256"``: Median affinity via 256-bin histogram (default, recommended)
@@ -197,6 +136,11 @@ def decode_waterz(
197136
(e.g. offsets ``["0-0-1", "0-1-0", "1-0-0"]``), set this to
198137
``"xyz"`` and the channels will be transposed automatically.
199138
Default: ``"xyz"``
139+
use_uint8: Convert float affinities to uint8 before waterz.
140+
Saves 4x memory and runs the entire C++ pipeline in integer
141+
arithmetic. Lossless for ``HistogramQuantileAffinity`` with
142+
256 bins. If input is already uint8, this is a no-op.
143+
Default: False
200144
fragments: Pre-computed over-segmentation (fragment IDs). If provided,
201145
the watershed step is skipped and agglomeration runs directly on
202146
these fragments. Shape :math:`(Z, Y, X)`, dtype uint64.
@@ -283,7 +227,19 @@ def decode_waterz(
283227
raise ValueError(f"Expected >= 3 affinity channels, got {predictions.shape[0]}.")
284228

285229
# Use first 3 channels (short-range affinities)
286-
affs = predictions[:3].astype(np.float32, copy=False)
230+
affs = predictions[:3]
231+
232+
# Convert float → uint8 on the fly if requested (4x memory savings).
233+
# If already uint8, this is a no-op.
234+
if use_uint8 and affs.dtype != np.uint8:
235+
affs = np.clip(affs, 0, 1)
236+
affs = (affs * 255).astype(np.uint8)
237+
238+
# Detect uint8 affinities — pass through to waterz uint8 path directly.
239+
# For float dtypes, convert to float32. Never convert uint8 to float.
240+
is_uint8 = affs.dtype == np.uint8
241+
if not is_uint8:
242+
affs = affs.astype(np.float32, copy=False)
287243

288244
# Transpose channels to zyx order expected by waterz C++.
289245
# Waterz expects: channel 0=z, 1=y, 2=x.
@@ -300,25 +256,34 @@ def decode_waterz(
300256
if not affs.flags["C_CONTIGUOUS"]:
301257
affs = np.ascontiguousarray(affs)
302258

259+
# Scale parameters for uint8: user specifies float [0,1], we map to [0,255].
260+
if is_uint8:
261+
_to_u8 = lambda v: int(round(v * 255)) if isinstance(v, float) and v <= 1.0 else int(v)
262+
else:
263+
_to_u8 = None # unused
264+
303265
# Normalize thresholds to sorted list (waterz requires ascending order)
304266
if isinstance(thresholds, (int, float)):
305267
thresholds_list = [float(thresholds)]
306268
else:
307269
thresholds_list = sorted(float(t) for t in thresholds)
270+
if is_uint8:
271+
thresholds_list = [_to_u8(t) for t in thresholds_list]
308272

309273
# Convert shorthand merge function to C++ scoring function string
310274
scoring_function = _merge_function_to_scoring(merge_function)
311275

312-
aff_low = float(aff_threshold[0])
313-
aff_high = float(aff_threshold[1])
276+
aff_low = _to_u8(aff_threshold[0]) if is_uint8 else float(aff_threshold[0])
277+
aff_high = _to_u8(aff_threshold[1]) if is_uint8 else float(aff_threshold[1])
314278

315279
logger.info(
316-
"Running waterz: %d thresholds=%s, scoring_function=%s, aff_threshold=(%.4f, %.4f)",
280+
"Running waterz: %d thresholds=%s, scoring_function=%s, aff_threshold=(%s, %s)%s",
317281
len(thresholds_list),
318282
thresholds_list,
319283
scoring_function,
320284
aff_low,
321285
aff_high,
286+
" [uint8]" if is_uint8 else "",
322287
)
323288

324289
# Build kwargs for waterz.waterz()
@@ -330,44 +295,24 @@ def decode_waterz(
330295
if fragments is not None:
331296
waterz_kwargs["fragments"] = fragments.astype(np.uint64, copy=False)
332297

333-
# Dust postprocessing can request the current post-agglomeration region
334-
# graph from waterz.waterz(). When the score type is compatible, we reuse
335-
# that graph directly; otherwise we fall back to merge_dust() to rebuild an
336-
# affinity-weighted graph from the final segmentation.
337298
do_dust_merge = bool(dust_merge) and dust_merge_size > 0
338-
waterz_kwargs["return_region_graph"] = do_dust_merge
339299

340300
# waterz.waterz() runs watershed + region-graph once, then incrementally
341301
# merges for each threshold. Returns all segmentations (copied).
342302
seg_list = waterz.waterz(affs, thresholds=thresholds_list, **waterz_kwargs)
343303

344304
# Post-process each result
345305
processed: List[np.ndarray] = []
346-
for waterz_result in seg_list:
347-
if do_dust_merge:
348-
seg, _region_graph = waterz_result
349-
else:
350-
seg = waterz_result
351-
352-
# Size+affinity dust merge (zwatershed-style)
306+
for seg in seg_list:
307+
# Size+affinity dust merge via buildRegionGraphOnly (fast path)
353308
if do_dust_merge:
354309
seg = seg.astype(np.uint64, copy=False)
355-
reused_region_graph = _merge_dust_with_region_graph(
356-
seg,
357-
_region_graph,
358-
scoring_function=scoring_function,
310+
waterz.merge_dust(
311+
seg, affs,
359312
size_th=dust_merge_size,
360313
weight_th=dust_merge_affinity,
361314
dust_th=dust_remove_size,
362315
)
363-
if not reused_region_graph:
364-
waterz.merge_dust(
365-
seg,
366-
affs,
367-
dust_merge_size,
368-
dust_merge_affinity,
369-
dust_remove_size,
370-
)
371316
# Branch merge: resolve false splits via z-slice IOU analysis
372317
if branch_merge:
373318
from .branch_merge import branch_merge as _branch_merge

connectomics/training/lightning/model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,23 @@ def _load_cached_predictions(
574574
self, output_dir_value: Optional[str], filenames: List[str], cache_suffix: str, mode: str
575575
):
576576
"""Attempt to load cached predictions from disk."""
577+
# Check test.saved_prediction_path first (decode-only mode)
578+
saved_path = getattr(getattr(self.cfg, "test", None), "saved_prediction_path", "")
579+
if saved_path and isinstance(saved_path, str) and saved_path.strip():
580+
pred_file = Path(saved_path.strip()).expanduser()
581+
if not pred_file.is_absolute():
582+
pred_file = Path.cwd() / pred_file
583+
if pred_file.exists():
584+
logger.info(f"Loading saved prediction (decode-only): {pred_file}")
585+
pred = read_volume(str(pred_file), dataset="main")
586+
if pred.ndim < 4:
587+
pred = pred[np.newaxis, ...]
588+
return pred, True, cache_suffix
589+
else:
590+
raise FileNotFoundError(
591+
f"saved_prediction_path not found: {pred_file}"
592+
)
593+
577594
explicit_prediction = self._resolve_tta_result_path_override()
578595
if isinstance(explicit_prediction, str) and explicit_prediction.strip():
579596
pred_file = Path(explicit_prediction).expanduser()

connectomics/training/lightning/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,12 @@ def _sanitize_decode_component(text: str) -> str:
485485
"dust_merge_affinity",
486486
"dust_remove_size",
487487
],
488+
"use_uint8": [], # gate only: shows "uint8" when true, omitted when false
489+
}
490+
491+
# Custom labels for boolean gates (instead of showing "true")
492+
gate_labels = {
493+
"use_uint8": "uint8",
488494
}
489495

490496
def _flatten_decode_values(value) -> list[str]:
@@ -501,6 +507,9 @@ def _flatten_decode_values(value) -> list[str]:
501507
for key, nested_value in sorted(value_dict.items()):
502508
if key in gated_keys:
503509
if key in gated_value_groups and nested_value is True:
510+
# Use custom label if defined, otherwise expand children
511+
if key in gate_labels:
512+
result.append(gate_labels[key])
504513
for grouped_key in gated_value_groups[key]:
505514
if grouped_key in value_dict:
506515
result.extend(_flatten_decode_values(value_dict[grouped_key]))
@@ -716,6 +725,12 @@ def resolve_prediction_cache_suffix(
716725
if tta_cfg is not None and bool(getattr(tta_cfg, "enabled", False)):
717726
return tta_cache_suffix(cfg, checkpoint_path=checkpoint_path, output_head=output_head)
718727

728+
# Include head + checkpoint tags in the suffix so different heads don't
729+
# collide (e.g. _x1_head-aff_r1_ckpt-last_prediction.h5).
730+
head = format_output_head_tag(cfg, output_head=output_head)
731+
ckpt = format_checkpoint_name_tag(checkpoint_path)
732+
if head or ckpt:
733+
return f"_x1{head}{ckpt}_prediction.h5"
719734
return configured_suffix
720735

721736

0 commit comments

Comments
 (0)