Skip to content

Commit 973662b

Browse files
committed
Unify use_downsample source in DiNTS to avoid inconsistent validation
Signed-off-by: Adrian Caderno <adriancaderno@gmail.com>
1 parent 214ab53 commit 973662b

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

monai/networks/nets/dints.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,12 @@ def __init__(
373373
):
374374
super().__init__()
375375

376+
if hasattr(dints_space, "use_downsample") and dints_space.use_downsample != use_downsample:
377+
raise ValueError(
378+
f"DiNTS.use_downsample ({use_downsample}) must match dints_space.use_downsample "
379+
f"({dints_space.use_downsample})."
380+
)
381+
self.use_downsample = use_downsample
376382
self.dints_space = dints_space
377383
self.filter_nums = dints_space.filter_nums
378384
self.num_blocks = dints_space.num_blocks
@@ -503,7 +509,7 @@ def _check_input_size(self, spatial_shape):
503509
Raises:
504510
ValueError: if any spatial dimension is not divisible by the required factor.
505511
"""
506-
factor = 2 ** (self.num_depths + int(self.dints_space.use_downsample))
512+
factor = 2 ** (self.num_depths + int(self.use_downsample))
507513
wrong_dims = [i + 2 for i, s in enumerate(spatial_shape) if s % factor != 0]
508514
if wrong_dims:
509515
raise ValueError(

0 commit comments

Comments
 (0)