Skip to content

Commit 3420a85

Browse files
committed
Fix TorchScript compatibility and invalid test input shape for DiNTS
Signed-off-by: Adrian Caderno <adriancaderno@gmail.com>
1 parent c0cc47a commit 3420a85

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

monai/networks/nets/dints.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ def __init__(
490490
def weight_parameters(self):
491491
return [param for name, param in self.named_parameters()]
492492

493+
@torch.jit.unused
493494
def _check_input_size(self, spatial_shape):
494495
"""
495496
Validate that input spatial dimensions satisfy the divisibility requirement.
@@ -521,7 +522,8 @@ def forward(self, x: torch.Tensor):
521522
ValueError: if any spatial dimension of ``x`` is not divisible by
522523
``2 ** (num_depths + int(use_downsample))``.
523524
"""
524-
self._check_input_size(x.shape[2:])
525+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
526+
self._check_input_size(x.shape[2:])
525527
inputs = []
526528
for d in range(self.num_depths):
527529
# allow multi-resolution input

tests/networks/nets/test_dints_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@
8484
"use_downsample": True,
8585
"spatial_dims": 2,
8686
},
87-
(2, 2, 32, 16),
88-
(2, 2, 32, 16),
87+
(2, 2, 32, 32), # use_downsample=True, num_depths=4 -> factor=32; both dims must be divisible by 32
88+
(2, 2, 32, 32),
8989
]
9090
]
9191
if torch.cuda.is_available():

0 commit comments

Comments
 (0)