Skip to content

Commit 0651704

Browse files
ytl0623pre-commit-ci[bot]ericspod
authored
Generalize TestTimeAugmentation to non-spatial predictions (#8715)
Fixes #8276 ### Description - Added a new argument `apply_inverse_to_pred`. Defaults to `True` to preserve backward compatibility. When set to `False`, it skips the inverse transformation step and aggregates the model predictions directly. - Added a new unit test to simulate a classification task with spatial augmentation, verifying that the aggregation works correctly without spatial inversion. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <david89062388@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 2713c2b commit 0651704

File tree

2 files changed

+65
-16
lines changed

2 files changed

+65
-16
lines changed

monai/data/test_time_augmentation.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from copy import deepcopy
1717
from typing import TYPE_CHECKING, Any
1818

19-
import numpy as np
2019
import torch
2120

2221
from monai.config.type_definitions import NdarrayOrTensor
@@ -68,7 +67,7 @@ class TestTimeAugmentation:
6867
Args:
6968
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
7069
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
71-
. All random transforms must be of type `InvertibleTransform`.
70+
When `apply_inverse_to_pred` is True, all random transforms must be of type `InvertibleTransform`.
7271
batch_size: number of realizations to infer at once.
7372
num_workers: how many subprocesses to use for data.
7473
inferrer_fn: function to use to perform inference.
@@ -92,6 +91,11 @@ class TestTimeAugmentation:
9291
will return the full data. Dimensions will be same size as when passing a single image through
9392
`inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
9493
progress: whether to display a progress bar.
94+
apply_inverse_to_pred: whether to apply inverse transformations to the predictions.
95+
If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions
96+
back to the original spatial reference.
97+
If the prediction is non-spatial (e.g. classification label or score), this should be `False` to
98+
aggregate the raw predictions directly. Defaults to `True`.
9599
96100
Example:
97101
.. code-block:: python
@@ -125,6 +129,7 @@ def __init__(
125129
post_func: Callable = _identity,
126130
return_full_data: bool = False,
127131
progress: bool = True,
132+
apply_inverse_to_pred: bool = True,
128133
) -> None:
129134
self.transform = transform
130135
self.batch_size = batch_size
@@ -134,6 +139,7 @@ def __init__(
134139
self.image_key = image_key
135140
self.return_full_data = return_full_data
136141
self.progress = progress
142+
self.apply_inverse_to_pred = apply_inverse_to_pred
137143
self._pred_key = CommonKeys.PRED
138144
self.inverter = Invertd(
139145
keys=self._pred_key,
@@ -152,20 +158,23 @@ def __init__(
152158

153159
def _check_transforms(self):
154160
"""Should be at least 1 random transform, and all random transforms should be invertible."""
155-
ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
156-
randoms = np.array([isinstance(t, Randomizable) for t in ts])
157-
invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts])
158-
# check at least 1 random
159-
if sum(randoms) == 0:
161+
transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
162+
warns = []
163+
randoms = []
164+
165+
for idx, t in enumerate(transforms):
166+
if isinstance(t, Randomizable):
167+
randoms.append(t)
168+
if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
169+
warns.append(f"Transform #{idx} (type {type(t).__name__}) is random but not invertible.")
170+
171+
if len(randoms) == 0:
172+
warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")
173+
174+
if len(warns) > 0:
160175
warnings.warn(
161-
"TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms."
176+
"TTA has encountered issues with the given transforms:\n " + "\n ".join(warns), stacklevel=2
162177
)
163-
# check that whenever randoms is True, invertibles is also true
164-
for r, i in zip(randoms, invertibles):
165-
if r and not i:
166-
warnings.warn(
167-
f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}"
168-
)
169178

170179
def __call__(
171180
self, data: dict[str, Any], num_examples: int = 10
@@ -199,7 +208,10 @@ def __call__(
199208
for b in tqdm(dl) if has_tqdm and self.progress else dl:
200209
# do model forward pass
201210
b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
202-
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
211+
if self.apply_inverse_to_pred:
212+
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
213+
else:
214+
outs.extend([i[self._pred_key] for i in decollate_batch(b)])
203215

204216
output: NdarrayOrTensor = stack(outs, 0)
205217

tests/integration/test_testtimeaugmentation.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_test_time_augmentation(self):
104104
# output might be different size, so pad so that they match
105105
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)
106106

107-
model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
107+
model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device)
108108
loss_function = DiceLoss(sigmoid=True)
109109
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
110110

@@ -181,6 +181,43 @@ def test_image_no_label(self):
181181
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
182182
tta(self.get_data(1, (20, 20), include_label=False))
183183

184+
def test_non_spatial_output(self):
185+
"""
186+
Test TTA for non-spatial output (e.g., classification scores).
187+
Verifies that setting `apply_inverse_to_pred=False` correctly aggregates
188+
predictions without attempting spatial inversion.
189+
"""
190+
input_size = (20, 20)
191+
data = {"image": np.random.rand(1, *input_size).astype(np.float32)}
192+
193+
transforms = Compose(
194+
[EnsureChannelFirstd("image", channel_dim="no_channel"), RandFlipd("image", prob=1.0, spatial_axis=0)]
195+
)
196+
197+
def mock_classifier(x):
198+
batch_size = x.shape[0]
199+
return torch.tensor([[0.2, 0.8]] * batch_size, dtype=torch.float32, device=x.device)
200+
201+
tt_aug = TestTimeAugmentation(
202+
transform=transforms,
203+
batch_size=2,
204+
num_workers=0,
205+
inferrer_fn=mock_classifier,
206+
device="cpu",
207+
orig_key="image",
208+
apply_inverse_to_pred=False,
209+
return_full_data=False,
210+
)
211+
mode, mean, std, vvc = tt_aug(data, num_examples=4)
212+
213+
self.assertEqual(mean.shape, (2,))
214+
np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6)
215+
np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6)
216+
217+
tt_aug.return_full_data = True
218+
full_output = tt_aug(data, num_examples=4)
219+
self.assertEqual(full_output.shape, (4, 2))
220+
184221

185222
if __name__ == "__main__":
186223
unittest.main()

0 commit comments

Comments
 (0)