connectomics/training/deep_supervision.py: 68 changes (41 additions, 27 deletions)
@@ -7,6 +7,7 @@
 
 from __future__ import annotations
 from typing import Dict, List, Tuple, Optional
+import inspect
 import warnings
 import pdb
 
@@ -18,6 +19,21 @@
 from ..config import Config
 
 
+def _loss_supports_weight(loss_fn: nn.Module) -> bool:
+    """Check if a loss function's forward method accepts a 'weight' keyword argument.
+
+    This is used to conditionally pass per-voxel weight masks only to loss
+    functions that support them (e.g., WeightedMSELoss, WeightedMAELoss,
+    SmoothL1Loss) while skipping the argument for standard losses that do
+    not (e.g., MONAI DiceLoss, BCEWithLogitsLoss).
+    """
+    try:
+        sig = inspect.signature(loss_fn.forward)
+        return "weight" in sig.parameters
+    except (ValueError, TypeError):
+        return False
+
+
 class DeepSupervisionHandler:
     """
     Handler for deep supervision and multi-task learning.
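
For a quick sanity check of the new helper outside the training loop, here is a minimal runnable sketch; the `WeightedMSELoss` below is a hypothetical stand-in for the repo's weighted loss, not its actual implementation:

```python
import inspect
import torch.nn as nn

class WeightedMSELoss(nn.Module):
    # Hypothetical stand-in: a weighted loss whose forward() declares a 'weight' kwarg.
    def forward(self, pred, target, weight=None):
        err = (pred - target) ** 2
        if weight is not None:
            err = err * weight
        return err.mean()

def _loss_supports_weight(loss_fn: nn.Module) -> bool:
    try:
        sig = inspect.signature(loss_fn.forward)
        return "weight" in sig.parameters
    except (ValueError, TypeError):
        return False

print(_loss_supports_weight(WeightedMSELoss()))       # True
print(_loss_supports_weight(nn.BCEWithLogitsLoss()))  # False: 'weight' is a constructor arg, not a forward kwarg
```

Signature inspection keeps the call sites generic: a new weighted loss opts in simply by declaring a `weight` keyword in its `forward`, with no registry or isinstance checks needed.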
@@ -132,15 +148,15 @@ def compute_multitask_loss(
             loss_fn = self.loss_functions[loss_idx]
             weight = self.loss_weights[loss_idx]
 
-            # [D3] Compute foreground-weighted mask for SDT loss
-            # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-            # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-            fg_weight = 2.0
-            loss_weight_mask = torch.ones_like(task_label)
-            loss_weight_mask[task_label > 0] = fg_weight
-
-            # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-            loss = loss_fn(task_output, task_label, weight=loss_weight_mask)
+            # [D3] Pass foreground-weighted mask to loss functions that support it
+            # (e.g., WeightedMSELoss, WeightedMAELoss, SmoothL1Loss)
+            if _loss_supports_weight(loss_fn):
+                fg_weight = 2.0
+                loss_weight_mask = torch.ones_like(task_label)
+                loss_weight_mask[task_label > 0] = fg_weight
+                loss = loss_fn(task_output, task_label, weight=loss_weight_mask)
+            else:
+                loss = loss_fn(task_output, task_label)
 
             # Check for NaN/Inf
             if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
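
The mask construction in the new branch is plain boolean indexing; a toy example with a made-up signed-distance (SDT) target:

```python
import torch

# Toy SDT target: > 0 inside objects, <= 0 in background.
target = torch.tensor([[-0.5, 0.0, 0.3],
                       [ 0.8, -0.1, 0.0]])

fg_weight = 2.0
loss_weight_mask = torch.ones_like(target)
loss_weight_mask[target > 0] = fg_weight
print(loss_weight_mask)
# tensor([[1., 1., 2.],
#         [2., 1., 1.]])
```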
@@ -269,16 +285,15 @@ def compute_loss_for_scale(
         # Clamp outputs to prevent numerical instability at coarser scales
         output_clamped = torch.clamp(output, min=self.clamp_min, max=self.clamp_max)
 
-        # [D3] Compute foreground-weighted mask for SDT loss
-        # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-        # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-        fg_weight = 2.0
-        loss_weight_mask = torch.ones_like(target)
-        loss_weight_mask[target > 0] = fg_weight
-
         for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
-            # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-            loss = loss_fn(output_clamped, target, weight=loss_weight_mask)
+            # [D3] Pass foreground-weighted mask to loss functions that support it
+            if _loss_supports_weight(loss_fn):
+                fg_weight = 2.0
+                loss_weight_mask = torch.ones_like(target)
+                loss_weight_mask[target > 0] = fg_weight
+                loss = loss_fn(output_clamped, target, weight=loss_weight_mask)
+            else:
+                loss = loss_fn(output_clamped, target)
 
             # Check for NaN/Inf (only in training mode)
             if (
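
For reference, the clamping above is a single `torch.clamp` call; in this sketch the bounds are assumed values, while the handler reads them from its config as `self.clamp_min` / `self.clamp_max`:

```python
import torch

clamp_min, clamp_max = -20.0, 20.0            # assumed bounds; the handler takes these from config
output = torch.randn(2, 1, 8, 8, 8) * 100.0   # exaggerated coarse-scale regression output
output_clamped = torch.clamp(output, min=clamp_min, max=clamp_max)
assert output_clamped.abs().max() <= 20.0
```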
@@ -387,16 +402,15 @@ def compute_standard_loss(
             total_loss, loss_dict = self.compute_multitask_loss(outputs, labels, stage=stage)
         else:
             # Standard single-scale loss: apply all losses to all outputs
-            # [D3] Compute foreground-weighted mask for SDT loss
-            # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-            # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-            fg_weight = 2.0
-            loss_weight_mask = torch.ones_like(labels)
-            loss_weight_mask[labels > 0] = fg_weight
-
             for i, (loss_fn, weight) in enumerate(zip(self.loss_functions, self.loss_weights)):
-                # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-                loss = loss_fn(outputs, labels, weight=loss_weight_mask)
+                # [D3] Pass foreground-weighted mask to loss functions that support it
+                if _loss_supports_weight(loss_fn):
+                    fg_weight = 2.0
+                    loss_weight_mask = torch.ones_like(labels)
+                    loss_weight_mask[labels > 0] = fg_weight
+                    loss = loss_fn(outputs, labels, weight=loss_weight_mask)
+                else:
+                    loss = loss_fn(outputs, labels)
 
                 # Check for NaN/Inf (only in training mode)
                 if (
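
The NaN/Inf check that follows each loss term can be factored as a small guard; the zero-replacement recovery below is an assumption for illustration, since the handler's actual handling is outside this diff:

```python
import warnings
import torch

def guard_finite(loss: torch.Tensor, name: str = "loss") -> torch.Tensor:
    # Hypothetical helper: warn and zero out a non-finite loss term so one
    # bad batch does not propagate NaN gradients through the whole step.
    if torch.isnan(loss) or torch.isinf(loss):
        warnings.warn(f"{name} is NaN/Inf; replacing with 0")
        return torch.zeros_like(loss)
    return loss
```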