|
8 | 8 | from __future__ import annotations |
9 | 9 |
|
10 | 10 | import logging |
11 | | -from itertools import combinations |
12 | 11 | from typing import Any, Optional |
13 | 12 |
|
14 | 13 | import numpy as np |
|
17 | 16 | from monai.data.utils import dense_patch_slices |
18 | 17 | from monai.inferers.utils import _get_scan_interval, compute_importance_map |
19 | 18 |
|
20 | | -from ..utils.channel_slices import resolve_channel_indices, resolve_channel_range |
| 19 | +from ..utils.channel_slices import resolve_channel_indices |
21 | 20 | from ..utils.model_outputs import ( |
22 | 21 | resolve_output_channels, |
23 | 22 | resolve_output_head, |
|
29 | 28 | is_2d_inference_mode, |
30 | 29 | resolve_inferer_roi_size, |
31 | 30 | ) |
32 | | - |
33 | | -try: |
34 | | - from omegaconf import ListConfig, OmegaConf |
35 | | - |
36 | | - HAS_OMEGACONF = True |
37 | | -except ImportError: |
38 | | - HAS_OMEGACONF = False |
39 | | - ListConfig = list # Fallback |
| 31 | +from .tta_combinations import ( |
| 32 | + _resolve_ensemble_mode_map, |
| 33 | + _resolve_spatial_dims, |
| 34 | + _to_plain_list, |
| 35 | + resolve_tta_augmentation_combinations, |
| 36 | +) |
40 | 37 |
|
41 | 38 | try: |
42 | 39 | from tqdm.auto import tqdm |
|
46 | 43 | logger = logging.getLogger(__name__) |
47 | 44 |
|
48 | 45 |
|
49 | | -def _to_plain_list(config_value) -> list: |
50 | | - """Convert OmegaConf ListConfig (or plain list) to nested plain Python lists.""" |
51 | | - if HAS_OMEGACONF and isinstance(config_value, ListConfig): |
52 | | - return OmegaConf.to_container(config_value, resolve=True) |
53 | | - if isinstance(config_value, (list, tuple)): |
54 | | - return list(config_value) |
55 | | - return [config_value] |
56 | | - |
57 | | - |
58 | | -def _resolve_spatial_dims(ndim: int) -> int: |
59 | | - if ndim == 5: |
60 | | - return 3 |
61 | | - if ndim == 4: |
62 | | - return 2 |
63 | | - raise ValueError(f"Unsupported data dimensions: {ndim}") |
64 | | - |
65 | | - |
66 | | -def _normalize_spatial_axes( |
67 | | - axes: Any, |
68 | | - *, |
69 | | - spatial_dims: int, |
70 | | - context: str, |
71 | | -) -> list[int]: |
72 | | - if isinstance(axes, int): |
73 | | - axes = [axes] |
74 | | - if not isinstance(axes, (list, tuple)): |
75 | | - raise ValueError(f"{context} must be an int or list of ints, got {axes!r}.") |
76 | | - |
77 | | - normalized: list[int] = [] |
78 | | - seen: set[int] = set() |
79 | | - for raw_axis in axes: |
80 | | - axis = int(raw_axis) |
81 | | - if axis < 0 or axis >= spatial_dims: |
82 | | - raise ValueError(f"{context} axis must be in [0, {spatial_dims - 1}], got {axis}.") |
83 | | - if axis in seen: |
84 | | - continue |
85 | | - normalized.append(axis) |
86 | | - seen.add(axis) |
87 | | - return normalized |
88 | | - |
89 | | - |
90 | | -def _resolve_flip_augmentations(tta_cfg, *, spatial_dims: int) -> list[list[int]]: |
91 | | - flip_axes_cfg = getattr(tta_cfg, "flip_axes", None) |
92 | | - if isinstance(flip_axes_cfg, str) and flip_axes_cfg.lower() == "none": |
93 | | - return [[]] |
94 | | - |
95 | | - if flip_axes_cfg == "all" or flip_axes_cfg == []: |
96 | | - spatial_axes = list(range(spatial_dims)) |
97 | | - tta_flip_axes = [[]] |
98 | | - for r in range(1, len(spatial_axes) + 1): |
99 | | - for combo in combinations(spatial_axes, r): |
100 | | - tta_flip_axes.append(list(combo)) |
101 | | - return tta_flip_axes |
102 | | - |
103 | | - if flip_axes_cfg is None: |
104 | | - return [[]] |
105 | | - |
106 | | - tta_flip_axes = [[]] |
107 | | - for raw_axes in _to_plain_list(flip_axes_cfg): |
108 | | - tta_flip_axes.append( |
109 | | - _normalize_spatial_axes( |
110 | | - raw_axes, |
111 | | - spatial_dims=spatial_dims, |
112 | | - context="flip_axes", |
113 | | - ) |
114 | | - ) |
115 | | - return tta_flip_axes |
116 | | - |
117 | | - |
118 | | -def _resolve_rotation_planes(tta_cfg, *, spatial_dims: int) -> list[tuple[int, int]]: |
119 | | - rotation90_axes_cfg = getattr(tta_cfg, "rotation90_axes", None) |
120 | | - if isinstance(rotation90_axes_cfg, str) and rotation90_axes_cfg.lower() == "none": |
121 | | - return [] |
122 | | - |
123 | | - if rotation90_axes_cfg == "all": |
124 | | - if spatial_dims == 3: |
125 | | - return [(0, 1), (0, 2), (1, 2)] |
126 | | - if spatial_dims == 2: |
127 | | - return [(0, 1)] |
128 | | - raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}") |
129 | | - |
130 | | - if rotation90_axes_cfg is None: |
131 | | - return [] |
132 | | - |
133 | | - resolved_planes: list[tuple[int, int]] = [] |
134 | | - for axes in _to_plain_list(rotation90_axes_cfg): |
135 | | - normalized = _normalize_spatial_axes( |
136 | | - axes, |
137 | | - spatial_dims=spatial_dims, |
138 | | - context="rotation90_axes", |
139 | | - ) |
140 | | - if len(normalized) != 2: |
141 | | - raise ValueError( |
142 | | - f"Invalid rotation plane: {axes}. Each plane must contain exactly 2 axes." |
143 | | - ) |
144 | | - plane = (normalized[0], normalized[1]) |
145 | | - if plane not in resolved_planes: |
146 | | - resolved_planes.append(plane) |
147 | | - return resolved_planes |
148 | | - |
149 | | - |
150 | | -def _resolve_rotation_k_values(tta_cfg) -> list[int]: |
151 | | - rotate90_k_cfg = getattr(tta_cfg, "rotate90_k", None) |
152 | | - if rotate90_k_cfg is None: |
153 | | - return [0, 1, 2, 3] |
154 | | - |
155 | | - resolved_values: list[int] = [] |
156 | | - seen: set[int] = set() |
157 | | - for raw_k in _to_plain_list(rotate90_k_cfg): |
158 | | - k = int(raw_k) % 4 |
159 | | - if k in seen: |
160 | | - continue |
161 | | - resolved_values.append(k) |
162 | | - seen.add(k) |
163 | | - return resolved_values or [0] |
164 | | - |
165 | | - |
166 | | -def _augmentation_signature( |
167 | | - *, |
168 | | - spatial_dims: int, |
169 | | - flip_axes: list[int], |
170 | | - rotation_plane: Optional[tuple[int, int]], |
171 | | - k_rotations: int, |
172 | | -) -> tuple[int, ...]: |
173 | | - if spatial_dims == 3: |
174 | | - base = torch.arange(2 * 3 * 5, dtype=torch.int64).reshape(2, 3, 5) |
175 | | - elif spatial_dims == 2: |
176 | | - base = torch.arange(2 * 5, dtype=torch.int64).reshape(2, 5) |
177 | | - else: |
178 | | - raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}") |
179 | | - |
180 | | - if flip_axes: |
181 | | - base = torch.flip(base, dims=flip_axes) |
182 | | - if rotation_plane is not None and k_rotations % 4: |
183 | | - base = torch.rot90(base, k=k_rotations, dims=rotation_plane) |
184 | | - return tuple(int(v) for v in base.reshape(-1).tolist()) |
185 | | - |
186 | | - |
187 | | -def resolve_tta_augmentation_combinations( |
188 | | - tta_cfg, |
189 | | - *, |
190 | | - spatial_dims: int, |
191 | | -) -> list[tuple[list[int], Optional[tuple[int, int]], int]]: |
192 | | - """Return unique spatial TTA combinations for the configured flips/rotations.""" |
193 | | - flip_variants = _resolve_flip_augmentations(tta_cfg, spatial_dims=spatial_dims) |
194 | | - rotation_planes = _resolve_rotation_planes(tta_cfg, spatial_dims=spatial_dims) |
195 | | - |
196 | | - if not rotation_planes: |
197 | | - return [(flip_axes, None, 0) for flip_axes in flip_variants] |
198 | | - |
199 | | - rotation_ks = _resolve_rotation_k_values(tta_cfg) |
200 | | - combinations_out: list[tuple[list[int], Optional[tuple[int, int]], int]] = [] |
201 | | - seen_signatures: set[tuple[int, ...]] = set() |
202 | | - |
203 | | - for flip_axes in flip_variants: |
204 | | - for rotation_plane in rotation_planes: |
205 | | - for k_rotations in rotation_ks: |
206 | | - signature = _augmentation_signature( |
207 | | - spatial_dims=spatial_dims, |
208 | | - flip_axes=flip_axes, |
209 | | - rotation_plane=rotation_plane, |
210 | | - k_rotations=k_rotations, |
211 | | - ) |
212 | | - if signature in seen_signatures: |
213 | | - continue |
214 | | - seen_signatures.add(signature) |
215 | | - combinations_out.append((flip_axes, rotation_plane, k_rotations)) |
216 | | - |
217 | | - return combinations_out |
218 | | - |
219 | | - |
220 | | -def _resolve_ensemble_mode_map( |
221 | | - ensemble_mode: Any, |
222 | | - num_channels: int, |
223 | | -) -> list[str]: |
224 | | - """Resolve ``ensemble_mode`` config to a per-channel mode list. |
225 | | -
|
226 | | - When *ensemble_mode* is a plain string (``"mean"``, ``"min"``, ``"max"``), |
227 | | - every channel gets the same mode. |
228 | | -
|
229 | | - When it is a list of ``[channel_selector, mode]`` pairs — e.g. |
230 | | - ``[["0:3", "min"], ["3:", "mean"]]`` — each channel is assigned the mode |
231 | | - from its matching entry. Channel selectors use the same syntax as loss |
232 | | - ``pred_slice`` / ``target_slice`` (parsed via |
233 | | - :func:`~connectomics.utils.channel_slices.resolve_channel_range`). |
234 | | - """ |
235 | | - if isinstance(ensemble_mode, str): |
236 | | - return [ensemble_mode] * num_channels |
237 | | - |
238 | | - # Convert OmegaConf containers to plain Python. |
239 | | - raw_list = _to_plain_list(ensemble_mode) |
240 | | - if not isinstance(raw_list, list) or not raw_list: |
241 | | - raise ValueError( |
242 | | - f"ensemble_mode must be a string or a list of [channel_selector, mode] pairs, " |
243 | | - f"got {ensemble_mode!r}." |
244 | | - ) |
245 | | - |
246 | | - # If first element is a string (not a list), it's a single mode string |
247 | | - # that was wrapped in a ListConfig — treat it as a plain string. |
248 | | - if isinstance(raw_list[0], str) and len(raw_list) == 1: |
249 | | - return [raw_list[0]] * num_channels |
250 | | - |
251 | | - modes: list[str | None] = [None] * num_channels |
252 | | - for entry in raw_list: |
253 | | - if not isinstance(entry, (list, tuple)) or len(entry) != 2: |
254 | | - raise ValueError( |
255 | | - f"Each ensemble_mode entry must be [channel_selector, mode], got {entry!r}." |
256 | | - ) |
257 | | - selector, mode = entry |
258 | | - if mode not in ("mean", "min", "max"): |
259 | | - raise ValueError( |
260 | | - f"Unknown ensemble mode {mode!r} in per-channel spec. Use 'mean', 'min', or 'max'." |
261 | | - ) |
262 | | - start, stop = resolve_channel_range( |
263 | | - str(selector), |
264 | | - num_channels=num_channels, |
265 | | - context="ensemble_mode channel selector", |
266 | | - ) |
267 | | - for ch in range(start, stop): |
268 | | - modes[ch] = mode |
269 | | - |
270 | | - unset = [i for i, m in enumerate(modes) if m is None] |
271 | | - if unset: |
272 | | - raise ValueError( |
273 | | - f"ensemble_mode does not cover channels {unset}. " |
274 | | - f"Every channel must be assigned a mode." |
275 | | - ) |
276 | | - return modes # type: ignore[return-value] |
277 | | - |
278 | | - |
279 | 46 | class TTAPredictor: |
280 | 47 | """Encapsulates TTA preprocessing and flip ensemble logic.""" |
281 | 48 |
|
|
0 commit comments