Commit a8f2ce0

Merge pull request #184 from PytorchConnectomics/claude/fix-issue-183-ElNXa

Make loss weight masking conditional on loss function support

2 parents: 70df515 + 84d74c7

1 file changed: connectomics/training/deep_supervision.py (41 additions, 27 deletions)
@@ -7,6 +7,7 @@
 
 from __future__ import annotations
 from typing import Dict, List, Tuple, Optional
+import inspect
 import warnings
 import pdb
 
@@ -18,6 +19,21 @@
 from ..config import Config
 
 
+def _loss_supports_weight(loss_fn: nn.Module) -> bool:
+    """Check if a loss function's forward method accepts a 'weight' keyword argument.
+
+    This is used to conditionally pass per-voxel weight masks only to loss
+    functions that support them (e.g., WeightedMSELoss, WeightedMAELoss,
+    SmoothL1Loss) while skipping the argument for standard losses that do
+    not (e.g., MONAI DiceLoss, BCEWithLogitsLoss).
+    """
+    try:
+        sig = inspect.signature(loss_fn.forward)
+        return "weight" in sig.parameters
+    except (ValueError, TypeError):
+        return False
+
+
 class DeepSupervisionHandler:
     """
     Handler for deep supervision and multi-task learning.
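
The helper detects support by introspecting the loss's forward signature rather than keeping a hard-coded allowlist, so custom weighted losses are picked up automatically. A minimal, self-contained sketch of how the check classifies losses; WeightedMSELoss below is a hypothetical stand-in for the repo's weighted loss, not the actual implementation:

import inspect

import torch.nn as nn


class WeightedMSELoss(nn.Module):
    # Hypothetical stand-in: forward() takes an optional per-voxel weight.
    def forward(self, pred, target, weight=None):
        se = (pred - target) ** 2
        if weight is not None:
            se = se * weight  # scale each voxel's squared error
        return se.mean()


def _loss_supports_weight(loss_fn: nn.Module) -> bool:
    try:
        sig = inspect.signature(loss_fn.forward)
        return "weight" in sig.parameters
    except (ValueError, TypeError):
        return False


print(_loss_supports_weight(WeightedMSELoss()))       # True
print(_loss_supports_weight(nn.MSELoss()))            # False: forward(input, target) only
print(_loss_supports_weight(nn.BCEWithLogitsLoss()))  # False: weight is a constructor argument

Inspecting forward rather than __call__ is the right target here: nn.Module.__call__ exposes only *args/**kwargs, which would hide the weight parameter from the signature check.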
@@ -132,15 +148,15 @@ def compute_multitask_loss(
             loss_fn = self.loss_functions[loss_idx]
             weight = self.loss_weights[loss_idx]
 
-            # [D3] Compute foreground-weighted mask for SDT loss
-            # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-            # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-            fg_weight = 2.0
-            loss_weight_mask = torch.ones_like(task_label)
-            loss_weight_mask[task_label > 0] = fg_weight
-
-            # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-            loss = loss_fn(task_output, task_label, weight=loss_weight_mask)
+            # [D3] Pass foreground-weighted mask to loss functions that support it
+            # (e.g., WeightedMSELoss, WeightedMAELoss, SmoothL1Loss)
+            if _loss_supports_weight(loss_fn):
+                fg_weight = 2.0
+                loss_weight_mask = torch.ones_like(task_label)
+                loss_weight_mask[task_label > 0] = fg_weight
+                loss = loss_fn(task_output, task_label, weight=loss_weight_mask)
+            else:
+                loss = loss_fn(task_output, task_label)
 
             # Check for NaN/Inf
             if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
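
The branch builds the same mask the old code always built, but only when it will actually be consumed: foreground voxels (SDT label > 0) contribute twice as much to the loss, while background stays at 1.0. A tiny sketch with made-up values:

import torch

task_label = torch.tensor([[-0.7, 0.0, 0.3, 1.2]])  # signed-distance values
mask = torch.ones_like(task_label)
mask[task_label > 0] = 2.0  # fg_weight
print(mask)  # tensor([[1., 1., 2., 2.]])

Note that exactly-zero labels count as background, so voxels at SDT = 0 keep weight 1.0.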
@@ -269,16 +285,15 @@ def compute_loss_for_scale(
         # Clamp outputs to prevent numerical instability at coarser scales
         output_clamped = torch.clamp(output, min=self.clamp_min, max=self.clamp_max)
 
-        # [D3] Compute foreground-weighted mask for SDT loss
-        # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-        # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-        fg_weight = 2.0
-        loss_weight_mask = torch.ones_like(target)
-        loss_weight_mask[target > 0] = fg_weight
-
         for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
-            # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-            loss = loss_fn(output_clamped, target, weight=loss_weight_mask)
+            # [D3] Pass foreground-weighted mask to loss functions that support it
+            if _loss_supports_weight(loss_fn):
+                fg_weight = 2.0
+                loss_weight_mask = torch.ones_like(target)
+                loss_weight_mask[target > 0] = fg_weight
+                loss = loss_fn(output_clamped, target, weight=loss_weight_mask)
+            else:
+                loss = loss_fn(output_clamped, target)
 
             # Check for NaN/Inf (only in training mode)
             if (
@@ -387,16 +402,15 @@ def compute_standard_loss(
             total_loss, loss_dict = self.compute_multitask_loss(outputs, labels, stage=stage)
         else:
             # Standard single-scale loss: apply all losses to all outputs
-            # [D3] Compute foreground-weighted mask for SDT loss
-            # Weight foreground (SDT > 0) more heavily to prevent background-dominated learning
-            # REDUCED from 5.0 to 2.0 to prevent numerical explosion
-            fg_weight = 2.0
-            loss_weight_mask = torch.ones_like(labels)
-            loss_weight_mask[labels > 0] = fg_weight
-
             for i, (loss_fn, weight) in enumerate(zip(self.loss_functions, self.loss_weights)):
-                # [D3] Pass weight mask to loss function (WeightedMSELoss supports this)
-                loss = loss_fn(outputs, labels, weight=loss_weight_mask)
+                # [D3] Pass foreground-weighted mask to loss functions that support it
+                if _loss_supports_weight(loss_fn):
+                    fg_weight = 2.0
+                    loss_weight_mask = torch.ones_like(labels)
+                    loss_weight_mask[labels > 0] = fg_weight
+                    loss = loss_fn(outputs, labels, weight=loss_weight_mask)
+                else:
+                    loss = loss_fn(outputs, labels)
 
                 # Check for NaN/Inf (only in training mode)
                 if (
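
For reference, a minimal sketch (assumed, not taken from issue #183) of the failure mode that unconditional masking triggers: forwarding weight= to a loss whose forward() has no such parameter raises a TypeError.

import torch
import torch.nn as nn

pred = torch.randn(2, 1, 4, 4)
target = torch.randn(2, 1, 4, 4)
mask = torch.ones_like(target)
mask[target > 0] = 2.0

try:
    nn.MSELoss()(pred, target, weight=mask)  # forward() does not accept weight=
except TypeError as exc:
    print(type(exc).__name__, exc)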
