Skip to content

Commit 150d359

Browse files
author
Donglai Wei
committed
Split TTA and prediction crop helpers
1 parent f6b83d3 commit 150d359

5 files changed

Lines changed: 555 additions & 518 deletions

File tree

connectomics/inference/tta.py

Lines changed: 7 additions & 240 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import annotations
99

1010
import logging
11-
from itertools import combinations
1211
from typing import Any, Optional
1312

1413
import numpy as np
@@ -17,7 +16,7 @@
1716
from monai.data.utils import dense_patch_slices
1817
from monai.inferers.utils import _get_scan_interval, compute_importance_map
1918

20-
from ..utils.channel_slices import resolve_channel_indices, resolve_channel_range
19+
from ..utils.channel_slices import resolve_channel_indices
2120
from ..utils.model_outputs import (
2221
resolve_output_channels,
2322
resolve_output_head,
@@ -29,14 +28,12 @@
2928
is_2d_inference_mode,
3029
resolve_inferer_roi_size,
3130
)
32-
33-
try:
34-
from omegaconf import ListConfig, OmegaConf
35-
36-
HAS_OMEGACONF = True
37-
except ImportError:
38-
HAS_OMEGACONF = False
39-
ListConfig = list # Fallback
31+
from .tta_combinations import (
32+
_resolve_ensemble_mode_map,
33+
_resolve_spatial_dims,
34+
_to_plain_list,
35+
resolve_tta_augmentation_combinations,
36+
)
4037

4138
try:
4239
from tqdm.auto import tqdm
@@ -46,236 +43,6 @@
4643
logger = logging.getLogger(__name__)
4744

4845

49-
def _to_plain_list(config_value) -> list:
50-
"""Convert OmegaConf ListConfig (or plain list) to nested plain Python lists."""
51-
if HAS_OMEGACONF and isinstance(config_value, ListConfig):
52-
return OmegaConf.to_container(config_value, resolve=True)
53-
if isinstance(config_value, (list, tuple)):
54-
return list(config_value)
55-
return [config_value]
56-
57-
58-
def _resolve_spatial_dims(ndim: int) -> int:
59-
if ndim == 5:
60-
return 3
61-
if ndim == 4:
62-
return 2
63-
raise ValueError(f"Unsupported data dimensions: {ndim}")
64-
65-
66-
def _normalize_spatial_axes(
67-
axes: Any,
68-
*,
69-
spatial_dims: int,
70-
context: str,
71-
) -> list[int]:
72-
if isinstance(axes, int):
73-
axes = [axes]
74-
if not isinstance(axes, (list, tuple)):
75-
raise ValueError(f"{context} must be an int or list of ints, got {axes!r}.")
76-
77-
normalized: list[int] = []
78-
seen: set[int] = set()
79-
for raw_axis in axes:
80-
axis = int(raw_axis)
81-
if axis < 0 or axis >= spatial_dims:
82-
raise ValueError(f"{context} axis must be in [0, {spatial_dims - 1}], got {axis}.")
83-
if axis in seen:
84-
continue
85-
normalized.append(axis)
86-
seen.add(axis)
87-
return normalized
88-
89-
90-
def _resolve_flip_augmentations(tta_cfg, *, spatial_dims: int) -> list[list[int]]:
91-
flip_axes_cfg = getattr(tta_cfg, "flip_axes", None)
92-
if isinstance(flip_axes_cfg, str) and flip_axes_cfg.lower() == "none":
93-
return [[]]
94-
95-
if flip_axes_cfg == "all" or flip_axes_cfg == []:
96-
spatial_axes = list(range(spatial_dims))
97-
tta_flip_axes = [[]]
98-
for r in range(1, len(spatial_axes) + 1):
99-
for combo in combinations(spatial_axes, r):
100-
tta_flip_axes.append(list(combo))
101-
return tta_flip_axes
102-
103-
if flip_axes_cfg is None:
104-
return [[]]
105-
106-
tta_flip_axes = [[]]
107-
for raw_axes in _to_plain_list(flip_axes_cfg):
108-
tta_flip_axes.append(
109-
_normalize_spatial_axes(
110-
raw_axes,
111-
spatial_dims=spatial_dims,
112-
context="flip_axes",
113-
)
114-
)
115-
return tta_flip_axes
116-
117-
118-
def _resolve_rotation_planes(tta_cfg, *, spatial_dims: int) -> list[tuple[int, int]]:
119-
rotation90_axes_cfg = getattr(tta_cfg, "rotation90_axes", None)
120-
if isinstance(rotation90_axes_cfg, str) and rotation90_axes_cfg.lower() == "none":
121-
return []
122-
123-
if rotation90_axes_cfg == "all":
124-
if spatial_dims == 3:
125-
return [(0, 1), (0, 2), (1, 2)]
126-
if spatial_dims == 2:
127-
return [(0, 1)]
128-
raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}")
129-
130-
if rotation90_axes_cfg is None:
131-
return []
132-
133-
resolved_planes: list[tuple[int, int]] = []
134-
for axes in _to_plain_list(rotation90_axes_cfg):
135-
normalized = _normalize_spatial_axes(
136-
axes,
137-
spatial_dims=spatial_dims,
138-
context="rotation90_axes",
139-
)
140-
if len(normalized) != 2:
141-
raise ValueError(
142-
f"Invalid rotation plane: {axes}. Each plane must contain exactly 2 axes."
143-
)
144-
plane = (normalized[0], normalized[1])
145-
if plane not in resolved_planes:
146-
resolved_planes.append(plane)
147-
return resolved_planes
148-
149-
150-
def _resolve_rotation_k_values(tta_cfg) -> list[int]:
151-
rotate90_k_cfg = getattr(tta_cfg, "rotate90_k", None)
152-
if rotate90_k_cfg is None:
153-
return [0, 1, 2, 3]
154-
155-
resolved_values: list[int] = []
156-
seen: set[int] = set()
157-
for raw_k in _to_plain_list(rotate90_k_cfg):
158-
k = int(raw_k) % 4
159-
if k in seen:
160-
continue
161-
resolved_values.append(k)
162-
seen.add(k)
163-
return resolved_values or [0]
164-
165-
166-
def _augmentation_signature(
167-
*,
168-
spatial_dims: int,
169-
flip_axes: list[int],
170-
rotation_plane: Optional[tuple[int, int]],
171-
k_rotations: int,
172-
) -> tuple[int, ...]:
173-
if spatial_dims == 3:
174-
base = torch.arange(2 * 3 * 5, dtype=torch.int64).reshape(2, 3, 5)
175-
elif spatial_dims == 2:
176-
base = torch.arange(2 * 5, dtype=torch.int64).reshape(2, 5)
177-
else:
178-
raise ValueError(f"Unsupported spatial dimensions: {spatial_dims}")
179-
180-
if flip_axes:
181-
base = torch.flip(base, dims=flip_axes)
182-
if rotation_plane is not None and k_rotations % 4:
183-
base = torch.rot90(base, k=k_rotations, dims=rotation_plane)
184-
return tuple(int(v) for v in base.reshape(-1).tolist())
185-
186-
187-
def resolve_tta_augmentation_combinations(
188-
tta_cfg,
189-
*,
190-
spatial_dims: int,
191-
) -> list[tuple[list[int], Optional[tuple[int, int]], int]]:
192-
"""Return unique spatial TTA combinations for the configured flips/rotations."""
193-
flip_variants = _resolve_flip_augmentations(tta_cfg, spatial_dims=spatial_dims)
194-
rotation_planes = _resolve_rotation_planes(tta_cfg, spatial_dims=spatial_dims)
195-
196-
if not rotation_planes:
197-
return [(flip_axes, None, 0) for flip_axes in flip_variants]
198-
199-
rotation_ks = _resolve_rotation_k_values(tta_cfg)
200-
combinations_out: list[tuple[list[int], Optional[tuple[int, int]], int]] = []
201-
seen_signatures: set[tuple[int, ...]] = set()
202-
203-
for flip_axes in flip_variants:
204-
for rotation_plane in rotation_planes:
205-
for k_rotations in rotation_ks:
206-
signature = _augmentation_signature(
207-
spatial_dims=spatial_dims,
208-
flip_axes=flip_axes,
209-
rotation_plane=rotation_plane,
210-
k_rotations=k_rotations,
211-
)
212-
if signature in seen_signatures:
213-
continue
214-
seen_signatures.add(signature)
215-
combinations_out.append((flip_axes, rotation_plane, k_rotations))
216-
217-
return combinations_out
218-
219-
220-
def _resolve_ensemble_mode_map(
221-
ensemble_mode: Any,
222-
num_channels: int,
223-
) -> list[str]:
224-
"""Resolve ``ensemble_mode`` config to a per-channel mode list.
225-
226-
When *ensemble_mode* is a plain string (``"mean"``, ``"min"``, ``"max"``),
227-
every channel gets the same mode.
228-
229-
When it is a list of ``[channel_selector, mode]`` pairs — e.g.
230-
``[["0:3", "min"], ["3:", "mean"]]`` — each channel is assigned the mode
231-
from its matching entry. Channel selectors use the same syntax as loss
232-
``pred_slice`` / ``target_slice`` (parsed via
233-
:func:`~connectomics.utils.channel_slices.resolve_channel_range`).
234-
"""
235-
if isinstance(ensemble_mode, str):
236-
return [ensemble_mode] * num_channels
237-
238-
# Convert OmegaConf containers to plain Python.
239-
raw_list = _to_plain_list(ensemble_mode)
240-
if not isinstance(raw_list, list) or not raw_list:
241-
raise ValueError(
242-
f"ensemble_mode must be a string or a list of [channel_selector, mode] pairs, "
243-
f"got {ensemble_mode!r}."
244-
)
245-
246-
# If first element is a string (not a list), it's a single mode string
247-
# that was wrapped in a ListConfig — treat it as a plain string.
248-
if isinstance(raw_list[0], str) and len(raw_list) == 1:
249-
return [raw_list[0]] * num_channels
250-
251-
modes: list[str | None] = [None] * num_channels
252-
for entry in raw_list:
253-
if not isinstance(entry, (list, tuple)) or len(entry) != 2:
254-
raise ValueError(
255-
f"Each ensemble_mode entry must be [channel_selector, mode], got {entry!r}."
256-
)
257-
selector, mode = entry
258-
if mode not in ("mean", "min", "max"):
259-
raise ValueError(
260-
f"Unknown ensemble mode {mode!r} in per-channel spec. Use 'mean', 'min', or 'max'."
261-
)
262-
start, stop = resolve_channel_range(
263-
str(selector),
264-
num_channels=num_channels,
265-
context="ensemble_mode channel selector",
266-
)
267-
for ch in range(start, stop):
268-
modes[ch] = mode
269-
270-
unset = [i for i, m in enumerate(modes) if m is None]
271-
if unset:
272-
raise ValueError(
273-
f"ensemble_mode does not cover channels {unset}. "
274-
f"Every channel must be assigned a mode."
275-
)
276-
return modes # type: ignore[return-value]
277-
278-
27946
class TTAPredictor:
28047
"""Encapsulates TTA preprocessing and flip ensemble logic."""
28148

0 commit comments

Comments
 (0)