Skip to content

Commit b66fb18

Browse files
mattlin1124林旻佑KumoLiuericspod
authored
Fix #8366: Add strict shape validation to sliding_window_inference (#8645)
### Description

This PR addresses issue #8366 by implementing strict shape validation in `sliding_window_inference`. Per the feedback from maintainers (@ericspod), implicit guessing (heuristics) for channel-last data has been avoided. Instead, this PR ensures that:

1. The input tensor explicitly matches the expected dimensions based on `roi_size` (e.g., must be 5D for 3D `roi_size`).
2. Validation is skipped if `roi_size` is an integer (broadcasting), preventing regressions in existing 1D/broadcasting tests.
3. A clear `ValueError` is raised if dimensions do not match, guiding users to handle channel-last data upstream using `EnsureChannelFirst` or `EnsureChannelFirstd`.

### Status

- [x] Code changes implemented in `monai/inferers/utils.py`
- [x] New unit tests added in `tests/inferers/test_sliding_window_inference.py`
- [x] No changes to `.gitignore`

### Types of changes

- [x] Bug fix (non-breaking change which fixes an issue)

---------

Signed-off-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
Signed-off-by: Matt Lin <mattlin1124@gmail.com>
Co-authored-by: 林旻佑 <linminyou@linminyoudeMacBook-Air.local>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 2147c11 commit b66fb18

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

monai/inferers/utils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def sliding_window_inference(
7676
7777
Args:
7878
inputs: input image to be processed (assuming NCHW[D])
79-
roi_size: the spatial window size for inferences.
79+
roi_size: the spatial window size for inferences. This must be a single value or a tuple with one value
80+
for each spatial dimension (e.g. 2 values for 2D, 3 values for 3D).
8081
When its components have None or non-positives, the corresponding inputs dimension will be used.
8182
if the components of the `roi_size` are non-positive values, the transform will use the
8283
corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
@@ -131,11 +132,30 @@ def sliding_window_inference(
131132
kwargs: optional keyword args to be passed to ``predictor``.
132133
133134
Note:
134-
- input must be channel-first and have a batch dim, supports N-D sliding window.
135+
- Inputs must be channel-first and have a batch dim (NCHW / NCDHW).
136+
- If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream.
137+
138+
Raises:
139+
ValueError: When the input dimensions do not match the expected dimensions based on ``roi_size``.
135140
136141
"""
137-
buffered = buffer_steps is not None and buffer_steps > 0
138142
num_spatial_dims = len(inputs.shape) - 2
143+
144+
# Only perform strict shape validation if roi_size is a sequence (explicit dimensions).
145+
# If roi_size is an integer, it is broadcast to all dimensions, so we cannot
146+
# infer the expected dimensionality to enforce a strict check here.
147+
if isinstance(roi_size, Sequence):
148+
roi_dims = len(roi_size)
149+
if num_spatial_dims != roi_dims:
150+
raise ValueError(
151+
f"Inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size "
152+
f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), "
153+
f"but got inputs shape {inputs.shape}.\n"
154+
"If you have channel-last data (e.g. B, D, H, W, C), please use "
155+
"monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream."
156+
)
157+
# -----------------------------------------------------------------
158+
buffered = buffer_steps is not None and buffer_steps > 0
139159
if buffered:
140160
if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
141161
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")

tests/inferers/test_sliding_window_inference.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,26 @@ def compute_dict(data):
372372
for rr, _ in zip(result_dict, expected_dict):
373373
np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)
374374

375+
def test_strict_shape_validation(self):
376+
"""Test strict shape validation to ensure inputs match roi_size dimensions."""
377+
device = "cpu"
378+
roi_size = (16, 16, 16)
379+
sw_batch_size = 4
380+
381+
def predictor(data):
382+
return data
383+
384+
# Case 1: Input has fewer dimensions than expected (e.g., missing Batch or Channel)
385+
# A 3D roi_size requires a 5D input (B, C, D, H, W), but a 4D tensor is passed here.
386+
inputs_4d = torch.randn((1, 16, 16, 16), device=device)
387+
with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"):
388+
sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor)
389+
390+
# Case 2: Input is 3D (missing Batch AND Channel)
391+
inputs_3d = torch.randn((16, 16, 16), device=device)
392+
with self.assertRaisesRegex(ValueError, "Inputs must have 5 dimensions"):
393+
sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor)
394+
375395

376396
class TestSlidingWindowInferenceCond(unittest.TestCase):
377397
@parameterized.expand(TEST_CASES)

0 commit comments

Comments
 (0)