Skip to content

Make loss weight masking conditional on loss function support#184

Merged
donglaiw merged 1 commit intomasterfrom
claude/fix-issue-183-ElNXa
Feb 5, 2026
Merged

Make loss weight masking conditional on loss function support#184
donglaiw merged 1 commit intomasterfrom
claude/fix-issue-183-ElNXa

Conversation

@donglaiw
Copy link
Copy Markdown
Collaborator

@donglaiw donglaiw commented Feb 5, 2026

Summary

This PR refactors the deep supervision loss computation to conditionally apply per-voxel weight masks only to loss functions that explicitly support them, rather than unconditionally passing weight masks to all loss functions.

Key Changes

  • Added _loss_supports_weight() utility function that uses introspection to check if a loss function's forward method accepts a weight parameter
  • Updated three loss computation methods to conditionally apply foreground-weighted masks:
    • compute_multitask_loss()
    • compute_loss_for_scale()
    • compute_standard_loss()
  • Loss functions that support weight masks (e.g., WeightedMSELoss, WeightedMAELoss, SmoothL1Loss) receive the mask
  • Loss functions that don't support weight masks (e.g., MONAI DiceLoss, BCEWithLogitsLoss) are called without the weight argument

Implementation Details

  • The _loss_supports_weight() function uses inspect.signature() to safely check the loss function's forward method signature
  • Gracefully handles inspection errors by returning False as a fallback
  • Foreground weighting (2.0x for SDT > 0) is only computed when needed, reducing unnecessary tensor operations
  • This approach maintains backward compatibility while enabling use of diverse loss functions in the same training pipeline

https://claude.ai/code/session_01XRs3k4q3869VksCMeHsMTN

The deep supervision handler was unconditionally passing weight=loss_weight_mask
to all loss functions, but only custom losses (WeightedMSELoss, WeightedMAELoss,
SmoothL1Loss) accept this parameter. Standard losses like MONAI's DiceLoss and
PyTorch's BCEWithLogitsLoss do not, causing a TypeError during validation.

Add _loss_supports_weight() helper that uses inspect.signature to check if a
loss function's forward method accepts a 'weight' parameter, and conditionally
pass it only when supported.

https://claude.ai/code/session_01XRs3k4q3869VksCMeHsMTN
@donglaiw donglaiw merged commit a8f2ce0 into master Feb 5, 2026
1 of 5 checks passed
@donglaiw donglaiw deleted the claude/fix-issue-183-ElNXa branch March 5, 2026 14:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants