|
7 | 7 |
|
8 | 8 | from __future__ import annotations |
9 | 9 | from typing import Dict, List, Tuple, Optional |
| 10 | +import inspect |
10 | 11 | import warnings |
11 | 12 | import pdb |
12 | 13 |
|
|
18 | 19 | from ..config import Config |
19 | 20 |
|
20 | 21 |
|
| 22 | +def _loss_supports_weight(loss_fn: nn.Module) -> bool: |
| 23 | + """Check if a loss function's forward method accepts a 'weight' keyword argument. |
| 24 | +
|
| 25 | + This is used to conditionally pass per-voxel weight masks only to loss |
| 26 | + functions that support them (e.g., WeightedMSELoss, WeightedMAELoss, |
| 27 | + SmoothL1Loss) while skipping the argument for standard losses that do |
| 28 | + not (e.g., MONAI DiceLoss, BCEWithLogitsLoss). |
| 29 | + """ |
| 30 | + try: |
| 31 | + sig = inspect.signature(loss_fn.forward) |
| 32 | + return "weight" in sig.parameters |
| 33 | + except (ValueError, TypeError): |
| 34 | + return False |
| 35 | + |
| 36 | + |
21 | 37 | class DeepSupervisionHandler: |
22 | 38 | """ |
23 | 39 | Handler for deep supervision and multi-task learning. |
@@ -132,15 +148,15 @@ def compute_multitask_loss( |
132 | 148 | loss_fn = self.loss_functions[loss_idx] |
133 | 149 | weight = self.loss_weights[loss_idx] |
134 | 150 |
|
135 | | - # [D3] Compute foreground-weighted mask for SDT loss |
136 | | - # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning |
137 | | - # REDUCED from 5.0 to 2.0 to prevent numerical explosion |
138 | | - fg_weight = 2.0 |
139 | | - loss_weight_mask = torch.ones_like(task_label) |
140 | | - loss_weight_mask[task_label > 0] = fg_weight |
141 | | - |
142 | | - # [D3] Pass weight mask to loss function (WeightedMSELoss supports this) |
143 | | - loss = loss_fn(task_output, task_label, weight=loss_weight_mask) |
| 151 | + # [D3] Pass foreground-weighted mask to loss functions that support it |
| 152 | + # (e.g., WeightedMSELoss, WeightedMAELoss, SmoothL1Loss) |
| 153 | + if _loss_supports_weight(loss_fn): |
| 154 | + fg_weight = 2.0 |
| 155 | + loss_weight_mask = torch.ones_like(task_label) |
| 156 | + loss_weight_mask[task_label > 0] = fg_weight |
| 157 | + loss = loss_fn(task_output, task_label, weight=loss_weight_mask) |
| 158 | + else: |
| 159 | + loss = loss_fn(task_output, task_label) |
144 | 160 |
|
145 | 161 | # Check for NaN/Inf |
146 | 162 | if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)): |
@@ -269,16 +285,15 @@ def compute_loss_for_scale( |
269 | 285 | # Clamp outputs to prevent numerical instability at coarser scales |
270 | 286 | output_clamped = torch.clamp(output, min=self.clamp_min, max=self.clamp_max) |
271 | 287 |
|
272 | | - # [D3] Compute foreground-weighted mask for SDT loss |
273 | | - # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning |
274 | | - # REDUCED from 5.0 to 2.0 to prevent numerical explosion |
275 | | - fg_weight = 2.0 |
276 | | - loss_weight_mask = torch.ones_like(target) |
277 | | - loss_weight_mask[target > 0] = fg_weight |
278 | | - |
279 | 288 | for loss_fn, weight in zip(self.loss_functions, self.loss_weights): |
280 | | - # [D3] Pass weight mask to loss function (WeightedMSELoss supports this) |
281 | | - loss = loss_fn(output_clamped, target, weight=loss_weight_mask) |
| 289 | + # [D3] Pass foreground-weighted mask to loss functions that support it |
| 290 | + if _loss_supports_weight(loss_fn): |
| 291 | + fg_weight = 2.0 |
| 292 | + loss_weight_mask = torch.ones_like(target) |
| 293 | + loss_weight_mask[target > 0] = fg_weight |
| 294 | + loss = loss_fn(output_clamped, target, weight=loss_weight_mask) |
| 295 | + else: |
| 296 | + loss = loss_fn(output_clamped, target) |
282 | 297 |
|
283 | 298 | # Check for NaN/Inf (only in training mode) |
284 | 299 | if ( |
@@ -387,16 +402,15 @@ def compute_standard_loss( |
387 | 402 | total_loss, loss_dict = self.compute_multitask_loss(outputs, labels, stage=stage) |
388 | 403 | else: |
389 | 404 | # Standard single-scale loss: apply all losses to all outputs |
390 | | - # [D3] Compute foreground-weighted mask for SDT loss |
391 | | - # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning |
392 | | - # REDUCED from 5.0 to 2.0 to prevent numerical explosion |
393 | | - fg_weight = 2.0 |
394 | | - loss_weight_mask = torch.ones_like(labels) |
395 | | - loss_weight_mask[labels > 0] = fg_weight |
396 | | - |
397 | 405 | for i, (loss_fn, weight) in enumerate(zip(self.loss_functions, self.loss_weights)): |
398 | | - # [D3] Pass weight mask to loss function (WeightedMSELoss supports this) |
399 | | - loss = loss_fn(outputs, labels, weight=loss_weight_mask) |
| 406 | + # [D3] Pass foreground-weighted mask to loss functions that support it |
| 407 | + if _loss_supports_weight(loss_fn): |
| 408 | + fg_weight = 2.0 |
| 409 | + loss_weight_mask = torch.ones_like(labels) |
| 410 | + loss_weight_mask[labels > 0] = fg_weight |
| 411 | + loss = loss_fn(outputs, labels, weight=loss_weight_mask) |
| 412 | + else: |
| 413 | + loss = loss_fn(outputs, labels) |
400 | 414 |
|
401 | 415 | # Check for NaN/Inf (only in training mode) |
402 | 416 | if ( |
|
0 commit comments