Skip to content

Commit 1149b50

Browse files
shubham-61969pre-commit-ci[bot]ericspod
authored
ENH: support additional dtypes in pad_nd (#8672)
Prefer the PyTorch padding backend when supported and safely fall back to NumPy on error. Add unit tests to validate backend selection and ensure output dtype is preserved. Fixes #7842 ### Description This pull request relaxes dtype restrictions in `pad_nd` and prefers the PyTorch padding backend when supported, with a safe fallback to NumPy on error. This enables support for additional dtypes (e.g. bool) that are already handled correctly by recent PyTorch versions. Unit tests are added to validate backend selection and ensure dtype preservation. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. --------- Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent edb58f1 commit 1149b50

File tree

2 files changed

+122
-9
lines changed

2 files changed

+122
-9
lines changed

monai/transforms/croppad/functional.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,27 @@ def pad_nd(
9191
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
9292
kwargs: other arguments for the `np.pad` or `torch.pad` function.
9393
note that `np.pad` treats channel dimension as the first dimension.
94+
Raises:
95+
ValueError: If `value` is provided when `mode` is not ``"constant"``.
9496
"""
97+
if mode != "constant" and "value" in kwargs:
98+
raise ValueError("'value' argument is only valid when mode='constant'")
9599
if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
96100
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
97101
try:
98102
_pad = _np_pad
99-
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in {
100-
torch.int16,
101-
torch.int64,
102-
torch.bool,
103-
torch.uint8,
104-
}:
103+
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}:
104+
# Try PyTorch pad for these modes; fallback to NumPy on error.
105105
_pad = _pt_pad
106106
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
107+
except NotImplementedError:
108+
# PyTorch does not support this combination, fall back to NumPy
109+
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
107110
except (ValueError, TypeError, RuntimeError) as err:
108-
if isinstance(err, NotImplementedError) or any(
109-
k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")
110-
):
111+
# PyTorch may raise generic errors for unsupported modes/dtypes or kwargs.
112+
# Since there are no stable exception types for these cases, we fall back
113+
# to NumPy by matching known error message patterns.
114+
if any(k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")):
111115
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
112116
raise ValueError(
113117
f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}"
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
Tests for pad_nd dtype support and backend selection.
13+
Validates PyTorch padding preference and NumPy fallback behavior.
14+
"""
15+
from __future__ import annotations
16+
17+
import unittest
18+
from unittest.mock import Mock, patch
19+
20+
import torch
21+
from parameterized.parameterized import parameterized
22+
23+
import monai.transforms.croppad.functional as F
24+
from monai.transforms.croppad.functional import pad_nd
25+
26+
DTYPES = [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32]
27+
MODES_DTYPES = [
28+
("constant", torch.bool),
29+
("constant", torch.int8),
30+
("constant", torch.float32),
31+
("reflect", torch.bool),
32+
("reflect", torch.int8),
33+
("reflect", torch.float32),
34+
("replicate", torch.bool),
35+
("replicate", torch.int8),
36+
("replicate", torch.float32),
37+
]
38+
39+
40+
class TestPadNdDtypes(unittest.TestCase):
41+
def test_pad_uses_pt_for_bool(self):
42+
"""Test that pad_nd uses PyTorch backend for bool dtype in constant mode."""
43+
img = torch.ones((1, 4, 4), dtype=torch.bool)
44+
to_pad = [(0, 0), (1, 1), (2, 2)]
45+
with (
46+
patch.object(F, "_pt_pad", wraps=F._pt_pad) as mock_pt,
47+
patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np,
48+
):
49+
out = pad_nd(img, to_pad, mode="constant", value=0)
50+
51+
self.assertTrue(mock_pt.called)
52+
self.assertFalse(mock_np.called)
53+
self.assertEqual(out.dtype, img.dtype)
54+
self.assertEqual(out.shape, (1, 6, 8))
55+
56+
def test_pad_falls_back_to_np_if_pt_raises(self):
57+
"""Test that pad_nd falls back to NumPy when PyTorch raises NotImplementedError."""
58+
img = torch.ones((1, 4, 4), dtype=torch.bool)
59+
to_pad = [(0, 0), (1, 1), (2, 2)]
60+
with (
61+
patch.object(F, "_pt_pad", new=Mock(side_effect=NotImplementedError("no"))) as mock_pt,
62+
patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np,
63+
):
64+
out = pad_nd(img, to_pad, mode="constant", value=0)
65+
66+
self.assertTrue(mock_pt.called)
67+
self.assertTrue(mock_np.called)
68+
self.assertEqual(out.dtype, img.dtype)
69+
self.assertEqual(out.shape, (1, 6, 8))
70+
71+
@parameterized.expand(DTYPES)
72+
def test_pad_dtype_no_error_and_dtype_preserved(self, dtype):
73+
"""Test that pad_nd handles various dtypes without error and preserves dtype.
74+
Args:
75+
dtype: Input dtype under test.
76+
"""
77+
img = torch.ones((1, 4, 4), dtype=dtype)
78+
to_pad = [(0, 0), (1, 1), (2, 2)]
79+
out = pad_nd(img, to_pad, mode="constant", value=0)
80+
81+
self.assertEqual(out.shape, (1, 6, 8))
82+
self.assertEqual(out.dtype, img.dtype)
83+
84+
@parameterized.expand(MODES_DTYPES)
85+
def test_pad_multiple_modes_dtype_preserved(self, mode, dtype):
86+
"""Test that pad_nd preserves dtype across multiple padding modes.
87+
Args:
88+
mode: Padding mode under test.
89+
dtype: Input dtype under test.
90+
"""
91+
img = torch.ones((1, 4, 4), dtype=dtype)
92+
to_pad = [(0, 0), (1, 1), (2, 2)]
93+
94+
kwargs = {"value": 0} if mode == "constant" else {}
95+
out = pad_nd(img, to_pad, mode=mode, **kwargs)
96+
97+
self.assertEqual(out.shape, (1, 6, 8))
98+
self.assertEqual(out.dtype, img.dtype)
99+
100+
def test_value_with_non_constant_mode_raises(self):
101+
"""Test that pad_nd raises ValueError when 'value' is provided with non-constant mode."""
102+
img = torch.ones((1, 4, 4))
103+
to_pad = [(0, 0), (1, 1), (2, 2)]
104+
with self.assertRaises(ValueError):
105+
pad_nd(img, to_pad, mode="reflect", value=0)
106+
107+
108+
if __name__ == "__main__":
109+
unittest.main()

0 commit comments

Comments
 (0)