|
| 1 | +"""Public chunk-grid helpers shared by chunked inference and streamed decoding.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from dataclasses import dataclass |
| 6 | +from itertools import product |
| 7 | +from typing import Any, Sequence |
| 8 | + |
| 9 | +from ..data.processing.affinity import ( |
| 10 | + compute_affinity_crop_pad, |
| 11 | + resolve_affinity_channel_groups_from_cfg, |
| 12 | + resolve_affinity_mode_from_cfg, |
| 13 | +) |
| 14 | +from ..utils.channel_slices import resolve_channel_indices |
| 15 | + |
| 16 | + |
@dataclass(frozen=True)
class ChunkRef:
    """Immutable reference to one chunk of a 3-D volume grid.

    Attributes:
        index: (z, y, x) position of the chunk within the chunk grid.
        start: Inclusive (z, y, x) start voxel of the chunk in the volume.
        stop: Exclusive (z, y, x) stop voxel of the chunk in the volume.
    """

    index: tuple[int, int, int]
    start: tuple[int, int, int]
    stop: tuple[int, int, int]

    @property
    def key(self) -> str:
        """Stable string identifier for this chunk, e.g. ``"z0_y1_x2"``."""
        return "z{}_y{}_x{}".format(*self.index)

    @property
    def slices(self) -> tuple[slice, slice, slice]:
        """Per-axis slices selecting this chunk out of the full volume."""
        sz, sy, sx = (slice(lo, hi) for lo, hi in zip(self.start, self.stop))
        return (sz, sy, sx)
| 31 | + |
| 32 | + |
def build_chunk_grid(volume_shape: Sequence[int], chunk_shape: Sequence[int]) -> list[ChunkRef]:
    """Partition a 3-D volume into an ordered list of chunk references.

    Args:
        volume_shape: (z, y, x) extent of the full volume.
        chunk_shape: (z, y, x) extent of each chunk; edge chunks are clipped
            to the volume boundary.

    Returns:
        ``ChunkRef``s in C order over the grid (x varies fastest).

    Raises:
        ValueError: If either shape is not length 3, or any chunk dimension
            is not positive (previously a zero chunk dimension surfaced as an
            opaque ``ZeroDivisionError``).
    """
    volume = tuple(int(v) for v in volume_shape)
    chunk = tuple(int(v) for v in chunk_shape)
    if len(volume) != 3 or len(chunk) != 3:
        raise ValueError(
            f"volume_shape and chunk_shape must have length 3, got {volume_shape!r} and {chunk_shape!r}"
        )
    if any(extent <= 0 for extent in chunk):
        raise ValueError(f"chunk_shape entries must be positive, got {chunk_shape!r}")
    # Ceiling division: number of chunks along each axis.
    counts = tuple((volume[axis] + chunk[axis] - 1) // chunk[axis] for axis in range(3))
    result: list[ChunkRef] = []
    for index in product(*(range(count) for count in counts)):
        start = tuple(index[axis] * chunk[axis] for axis in range(3))
        # Clip the final chunk on each axis to the volume boundary.
        stop = tuple(min(start[axis] + chunk[axis], volume[axis]) for axis in range(3))
        result.append(ChunkRef(index=tuple(int(v) for v in index), start=start, stop=stop))
    return result
| 43 | + |
| 44 | + |
def normalize_crop_pad(value: Any) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int]]:
    """Normalize a crop/pad spec into per-axis ``(before, after)`` pairs.

    Accepts ``None`` or an empty sequence (no crop), a 3-sequence of
    symmetric per-axis amounts, or a 6-sequence of
    ``(z0, z1, y0, y1, x0, x1)`` amounts.

    Raises:
        ValueError: If the sequence length is not 3 or 6.
    """
    if value is None:
        return ((0, 0), (0, 0), (0, 0))
    # Materialize first, then check emptiness. The previous
    # ``value in (None, [], ())`` test used ``==``-based membership, which
    # misbehaves on array-like inputs (e.g. numpy arrays raise on truthiness).
    values = [int(v) for v in value]
    if not values:
        return ((0, 0), (0, 0), (0, 0))
    if len(values) == 3:
        return tuple((v, v) for v in values)  # type: ignore[return-value]
    if len(values) == 6:
        return ((values[0], values[1]), (values[2], values[3]), (values[4], values[5]))
    raise ValueError(f"inference.crop_pad must have length 3 or 6, got {value!r}")
| 54 | + |
| 55 | + |
def resolve_selected_affinity_offsets(cfg: Any) -> list[tuple[int, int, int]]:
    """Return the affinity offsets for the channels kept at inference time.

    Builds a per-channel offset table from the configured affinity channel
    groups, narrows it by ``inference.select_channel`` when that option is
    set, and drops channels that carry no offset.
    """
    groups = resolve_affinity_channel_groups_from_cfg(cfg)
    if not groups:
        return []

    # Total label channels = the highest group end; fill the table group by group.
    total_channels = max(span[1] for span, _ in groups)
    per_channel: list[tuple[int, int, int] | None] = [None] * total_channels
    for (lo, hi), offsets in groups:
        for position, offset in zip(range(lo, hi), offsets):
            per_channel[position] = offset

    inference_cfg = getattr(cfg, "inference", None)
    select_channel = getattr(inference_cfg, "select_channel", None)
    if select_channel is not None:
        kept = resolve_channel_indices(
            select_channel,
            num_channels=len(per_channel),
            context="inference.select_channel",
        )
        per_channel = [per_channel[i] for i in kept]

    return [offset for offset in per_channel if offset is not None]
| 77 | + |
| 78 | + |
def resolve_global_prediction_crop(
    cfg: Any,
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int]]:
    """Combine the user crop_pad with any affinity-derived crop, per axis.

    Returns per-axis ``(before, after)`` crop amounts for z/y/x: the
    user-specified ``inference.crop_pad`` plus, when an affinity mode is
    configured and offsets are selected, the crop implied by those offsets.
    """
    inference_cfg = getattr(cfg, "inference", None)
    user_crop = normalize_crop_pad(getattr(inference_cfg, "crop_pad", None))

    affinity_crop: tuple[tuple[int, int], ...] = ((0, 0), (0, 0), (0, 0))
    affinity_mode = resolve_affinity_mode_from_cfg(cfg)
    if affinity_mode is not None:
        offsets = resolve_selected_affinity_offsets(cfg)
        if offsets:
            affinity_crop = compute_affinity_crop_pad(offsets, affinity_mode=affinity_mode)

    combined = []
    for axis in range(3):
        before = int(user_crop[axis][0]) + int(affinity_crop[axis][0])
        after = int(user_crop[axis][1]) + int(affinity_crop[axis][1])
        combined.append((before, after))
    return tuple(combined)  # type: ignore[return-value]
| 100 | + |
| 101 | + |
def validate_chunked_output_format(cfg: Any) -> None:
    """Reject output formats other than HDF5 for chunked inference.

    Chunked inference streams into a single HDF5 file, so every entry of
    ``inference.save_prediction.output_formats`` must be ``h5``/``hdf5``
    (case-insensitive; defaults to ``["h5"]`` when unset).

    Raises:
        ValueError: If any other format is requested.
    """
    inference_cfg = getattr(cfg, "inference", None)
    save_cfg = getattr(inference_cfg, "save_prediction", None)
    requested = getattr(save_cfg, "output_formats", ["h5"])
    normalized = [str(entry).lower() for entry in requested]
    unsupported_formats = [entry for entry in normalized if entry not in {"h5", "hdf5"}]
    if unsupported_formats:
        raise ValueError(
            "Chunked inference writes a single streamed HDF5 output only; "
            f"unsupported save_prediction.output_formats={unsupported_formats}."
        )
| 111 | + |
| 112 | + |
def resolve_chunk_shape(cfg: Any, final_shape: Sequence[int]) -> tuple[int, int, int]:
    """Resolve the effective (z, y, x) chunk shape for chunked inference.

    With ``inference.chunking.axes == "z"`` only the z axis is chunked and
    y/x span the full volume; with ``"all"`` (the default) every axis uses
    the configured chunk size. Chunk sizes are clamped to the volume extent
    so a chunk never exceeds the output shape (the z-only branch previously
    skipped this clamp, inconsistent with the "all" branch).

    Raises:
        ValueError: If ``axes`` is neither ``"all"`` nor ``"z"``.
    """
    chunking_cfg = cfg.inference.chunking
    chunk_size = tuple(int(v) for v in chunking_cfg.chunk_size)
    axes = str(getattr(chunking_cfg, "axes", "all")).lower()
    if axes == "z":
        return (min(chunk_size[0], int(final_shape[0])), int(final_shape[1]), int(final_shape[2]))
    if axes != "all":
        raise ValueError("inference.chunking.axes must be 'all' or 'z'")
    return tuple(min(chunk_size[axis], int(final_shape[axis])) for axis in range(3))
| 122 | + |
| 123 | + |
def resolve_h5_spatial_chunks(spatial_shape: Sequence[int]) -> tuple[int, int, int]:
    """Pick per-axis HDF5 dataset chunk sizes: 64 voxels, capped at the extent."""
    capped = [min(int(extent), 64) for extent in spatial_shape[:3]]
    return (capped[0], capped[1], capped[2])
| 127 | + |
| 128 | + |
def resolve_chunk_output_mode(cfg: Any) -> str:
    """Return the validated chunk output mode.

    Reads ``inference.chunking.output_mode`` (default ``"decoded"``),
    lower-cased; only ``"decoded"`` and ``"raw_prediction"`` are accepted.

    Raises:
        ValueError: For any other value.
    """
    raw = getattr(cfg.inference.chunking, "output_mode", "decoded")
    mode = str(raw).lower()
    if mode in {"decoded", "raw_prediction"}:
        return mode
    raise ValueError("inference.chunking.output_mode must be 'decoded' or 'raw_prediction'.")
| 135 | + |
| 136 | + |
# Public API of this module, in definition order.
__all__ = [
    "ChunkRef",
    "build_chunk_grid",
    "normalize_crop_pad",
    "resolve_selected_affinity_offsets",
    "resolve_global_prediction_crop",
    "validate_chunked_output_format",
    "resolve_chunk_shape",
    "resolve_h5_spatial_chunks",
    "resolve_chunk_output_mode",
]
0 commit comments