diff --git a/connectomics/training/deep_supervision.py b/connectomics/training/deep_supervision.py
index e9e42c5a..a4551ef5 100644
--- a/connectomics/training/deep_supervision.py
+++ b/connectomics/training/deep_supervision.py
@@ -7,6 +7,7 @@ from __future__ import annotations
 
 from typing import Dict, List, Tuple, Optional
 
+import inspect
 import warnings
 import pdb
 
@@ -18,6 +19,21 @@ from ..config import Config
 
 
+def _loss_supports_weight(loss_fn: nn.Module) -> bool:
+    """Check if a loss function's forward method accepts a 'weight' keyword argument.
+
+    This is used to conditionally pass per-voxel weight masks only to loss
+    functions that support them (e.g., WeightedMSELoss, WeightedMAELoss,
+    SmoothL1Loss) while skipping the argument for standard losses that do
+    not (e.g., MONAI DiceLoss, BCEWithLogitsLoss).
+    """
+    try:
+        sig = inspect.signature(loss_fn.forward)
+        return "weight" in sig.parameters
+    except (ValueError, TypeError):
+        return False
+
+
 class DeepSupervisionHandler:
     """
     Handler for deep supervision and multi-task learning.
 
@@ -132,15 +148,15 @@ def compute_multitask_loss(
             loss_fn = self.loss_functions[loss_idx]
             weight = self.loss_weights[loss_idx]
 
-            # [D3] Compute foreground-weighted mask for SDT loss
-            # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-            # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-            fg_weight = 2.0
-            loss_weight_mask = torch.ones_like(task_label)
-            loss_weight_mask[task_label > 0] = fg_weight
-
-            # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-            loss = loss_fn(task_output, task_label, weight=loss_weight_mask)
+            # [D3] Pass foreground-weighted mask to loss functions that support it
+            # (e.g., WeightedMSELoss, WeightedMAELoss, SmoothL1Loss)
+            if _loss_supports_weight(loss_fn):
+                fg_weight = 2.0
+                loss_weight_mask = torch.ones_like(task_label)
+                loss_weight_mask[task_label > 0] = fg_weight
+                loss = loss_fn(task_output, task_label, weight=loss_weight_mask)
+            else:
+                loss = loss_fn(task_output, task_label)
 
             # Check for NaN/Inf
             if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
@@ -269,16 +285,15 @@ def compute_loss_for_scale(
         # Clamp outputs to prevent numerical instability at coarser scales
         output_clamped = torch.clamp(output, min=self.clamp_min, max=self.clamp_max)
 
-        # [D3] Compute foreground-weighted mask for SDT loss
-        # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-        # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-        fg_weight = 2.0
-        loss_weight_mask = torch.ones_like(target)
-        loss_weight_mask[target > 0] = fg_weight
-
         for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
-            # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-            loss = loss_fn(output_clamped, target, weight=loss_weight_mask)
+            # [D3] Pass foreground-weighted mask to loss functions that support it
+            if _loss_supports_weight(loss_fn):
+                fg_weight = 2.0
+                loss_weight_mask = torch.ones_like(target)
+                loss_weight_mask[target > 0] = fg_weight
+                loss = loss_fn(output_clamped, target, weight=loss_weight_mask)
+            else:
+                loss = loss_fn(output_clamped, target)
 
             # Check for NaN/Inf (only in training mode)
             if (
@@ -387,16 +402,15 @@ def compute_standard_loss(
             total_loss, loss_dict = self.compute_multitask_loss(outputs, labels, stage=stage)
         else:
             # Standard single-scale loss: apply all losses to all outputs
-            # [D3] Compute foreground-weighted mask for SDT loss
-            # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-            # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-            fg_weight = 2.0
-            loss_weight_mask = torch.ones_like(labels)
-            loss_weight_mask[labels > 0] = fg_weight
-
             for i, (loss_fn, weight) in enumerate(zip(self.loss_functions, self.loss_weights)):
-                # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-                loss = loss_fn(outputs, labels, weight=loss_weight_mask)
+                # [D3] Pass foreground-weighted mask to loss functions that support it
+                if _loss_supports_weight(loss_fn):
+                    fg_weight = 2.0
+                    loss_weight_mask = torch.ones_like(labels)
+                    loss_weight_mask[labels > 0] = fg_weight
+                    loss = loss_fn(outputs, labels, weight=loss_weight_mask)
+                else:
+                    loss = loss_fn(outputs, labels)
 
                 # Check for NaN/Inf (only in training mode)
                 if (
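
Reviewer note (not part of the patch): a minimal, runnable sketch of the capability check the new helper performs. The WeightedMSELoss below is a hypothetical stand-in for the repo's weight-aware loss, not its real implementation; nn.MSELoss is the standard torch loss, whose forward accepts only (input, target).

    import inspect
    import torch.nn as nn

    class WeightedMSELoss(nn.Module):
        # Hypothetical stand-in: forward exposes an optional per-voxel 'weight'.
        def forward(self, pred, target, weight=None):
            err = (pred - target) ** 2
            return (err * weight).mean() if weight is not None else err.mean()

    def _loss_supports_weight(loss_fn: nn.Module) -> bool:
        # Same check as in the diff: look for 'weight' in forward's signature.
        try:
            sig = inspect.signature(loss_fn.forward)
            return "weight" in sig.parameters
        except (ValueError, TypeError):
            return False

    assert _loss_supports_weight(WeightedMSELoss())    # 'weight' kwarg present
    assert not _loss_supports_weight(nn.MSELoss())     # forward(input, target) only

The design point: passing an unexpected 'weight' kwarg to a forward that does not declare it raises a TypeError when the module is invoked through nn.Module.__call__, so gating on the signature lets one loss list mix weight-aware and standard losses.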