Skip to content

Commit fc612cb

Browse files
committed
Fix torch.export compatibility in loss functions
Replace torch.arange().tolist() with list(range()) in dice, focal, and tversky losses to avoid graph breaks under torch.export tracing. Move class_weight validation to __init__ in hausdorff_loss. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 45c20ef commit fc612cb

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

monai/losses/dice.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def __init__(
123123
self.smooth_dr = float(smooth_dr)
124124
self.batch = batch
125125
weight = torch.as_tensor(weight) if weight is not None else None
126+
if weight is not None and weight.min() < 0:
127+
raise ValueError("the value/values of the `weight` should be no less than 0.")
126128
self.register_buffer("class_weight", weight)
127129
self.class_weight: None | torch.Tensor
128130
self.soft_label = soft_label
@@ -181,7 +183,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
181183
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
182184

183185
# reducing only spatial dimensions (not batch nor channels)
184-
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
186+
reduce_axis: list[int] = list(range(2, len(input.shape)))
185187
if self.batch:
186188
# reducing spatial dimensions and batch
187189
reduce_axis = [0] + reduce_axis
@@ -208,9 +210,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
208210
If `include_background=False`, the weight should not include
209211
the background category class 0."""
210212
)
211-
if self.class_weight.min() < 0:
212-
raise ValueError("the value/values of the `weight` should be no less than 0.")
213-
# apply class_weight to loss
213+
# apply class_weight to loss (weight values validated in __init__)
214214
f = f * self.class_weight.to(f)
215215

216216
if self.reduction == LossReduction.MEAN.value:
@@ -431,7 +431,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
431431
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
432432

433433
# reducing only spatial dimensions (not batch nor channels)
434-
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
434+
reduce_axis: list[int] = list(range(2, len(input.shape)))
435435
if self.batch:
436436
reduce_axis = [0] + reduce_axis
437437

monai/losses/focal_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def __init__(
122122
else:
123123
self.alpha = torch.as_tensor(alpha)
124124
weight = torch.as_tensor(weight) if weight is not None else None
125+
if weight is not None and weight.min() < 0:
126+
raise ValueError("the value/values of the `weight` should be no less than 0.")
125127
self.register_buffer("class_weight", weight)
126128
self.class_weight: None | torch.Tensor
127129

@@ -188,9 +190,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
188190
If `include_background=False`, the weight should not include
189191
the background category class 0."""
190192
)
191-
if self.class_weight.min() < 0:
192-
raise ValueError("the value/values of the `weight` should be no less than 0.")
193-
# apply class_weight to loss
193+
# apply class_weight to loss (weight values validated in __init__)
194194
self.class_weight = self.class_weight.to(loss)
195195
broadcast_dims = [-1] + [1] * len(target.shape[2:])
196196
self.class_weight = self.class_weight.view(broadcast_dims)

monai/losses/hausdorff_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
190190
distance = pred_dt**self.alpha + target_dt**self.alpha
191191

192192
running_f = pred_error * distance.to(device)
193-
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
193+
reduce_axis: list[int] = list(range(2, len(input.shape)))
194194
if self.batch:
195195
# reducing spatial dimensions and batch
196196
reduce_axis = [0] + reduce_axis

monai/losses/tversky.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
143143
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
144144

145145
# reducing only spatial dimensions (not batch nor channels)
146-
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
146+
reduce_axis: list[int] = list(range(2, len(input.shape)))
147147
if self.batch:
148148
# reducing spatial dimensions and batch
149149
reduce_axis = [0] + reduce_axis

0 commit comments

Comments
 (0)