Commit 40df2f6

Weights in alpha for FocalLoss (#8665)
Fixes #8601

### Description

Support `alpha` as a list, tuple, or tensor of floats, in addition to the existing scalar support.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 69a0fb1 commit 40df2f6
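
For context, a minimal usage sketch of the new per-class `alpha` (not part of the commit; the class count, shapes, and alpha values below are illustrative assumptions):

```python
# Hypothetical example: alpha given as a per-class sequence instead of a scalar.
import torch
from monai.losses import FocalLoss

# One alpha weight per class; the length must match the number of classes.
loss_fn = FocalLoss(to_onehot_y=True, use_softmax=True, alpha=[0.25, 0.5, 1.0])

logits = torch.randn(2, 3, 16, 16)            # (B, C, H, W) raw class scores
labels = torch.randint(0, 3, (2, 1, 16, 16))  # class-index labels, one-hot via to_onehot_y
print(loss_fn(logits, labels))                # scalar mean focal loss
```

Scalar `alpha` keeps its previous meaning; only the sequence/tensor forms are new.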

File tree: 2 files changed (+104, -16 lines)

monai/losses/focal_loss.py (60 additions, 14 deletions)

```diff
@@ -69,19 +69,22 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
-        alpha: float | None = None,
+        alpha: float | Sequence[float] | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
     ) -> None:
         """
         Args:
             include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
-                If False, `alpha` is invalid when using softmax.
+                If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).
             to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
             gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
-                The value should be in [0, 1]. Defaults to None.
+                The value should be in [0, 1].
+                If a sequence is provided, its length must match the number of classes
+                (excluding the background class if `include_background=False`).
+                Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
                 The input can be a single value (same weight for all classes), a sequence of values (the length
                 of the sequence should be the same as the number of classes. If not ``include_background``,
@@ -109,9 +112,15 @@ def __init__(
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
+        self.alpha: float | torch.Tensor | None
+        if alpha is None:
+            self.alpha = None
+        elif isinstance(alpha, (float, int)):
+            self.alpha = float(alpha)
+        else:
+            self.alpha = torch.as_tensor(alpha)
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
@@ -155,13 +164,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss: torch.Tensor | None = None
         input = input.float()
         target = target.float()
+        alpha_arg = self.alpha
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
-                self.alpha = None
-                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+                if isinstance(self.alpha, (float, int)):
+                    alpha_arg = None
+                    warnings.warn(
+                        "`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
+                    )
+            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)

         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -202,7 +215,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -214,8 +227,22 @@ def softmax_focal_loss(
     loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target

     if alpha is not None:
-        # (1-alpha) for the background class and alpha for the other classes
-        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        if isinstance(alpha, torch.Tensor):
+            alpha_t = alpha.to(device=input.device, dtype=input.dtype)
+        else:
+            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
+
+        if alpha_t.ndim == 0:  # scalar
+            alpha_val = alpha_t.item()
+            # (1-alpha) for the background class and alpha for the other classes
+            alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
+        else:  # tensor (sequence)
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            alpha_fac = alpha_t
+
         broadcast_dims = [-1] + [1] * len(target.shape[2:])
         alpha_fac = alpha_fac.view(broadcast_dims)
         loss = alpha_fac * loss
@@ -224,7 +251,7 @@ def softmax_focal_loss(


 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -247,8 +274,27 @@ def sigmoid_focal_loss(
     loss = (invprobs * gamma).exp() * loss

     if alpha is not None:
-        # alpha if t==1; (1-alpha) if t==0
-        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        if isinstance(alpha, torch.Tensor):
+            alpha_t = alpha.to(device=input.device, dtype=input.dtype)
+        else:
+            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
+
+        if alpha_t.ndim == 0:  # scalar
+            # alpha if t==1; (1-alpha) if t==0
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+        else:  # tensor (sequence)
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            # Reshape alpha for broadcasting: (1, C, 1, 1...)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            alpha_t = alpha_t.view(broadcast_dims)
+            # Apply per-class weight only to positive samples:
+            # for positive samples (target==1), multiply by alpha[c];
+            # for negative samples (target==0), keep the weight as 1.0.
+            alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
+
         loss = alpha_factor * loss

     return loss
```
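
To see what the broadcasting step in `softmax_focal_loss` does, here is a standalone sketch (the shapes and alpha values are assumptions, not taken from the diff):

```python
# Illustrates the per-class alpha broadcasting used in softmax_focal_loss above.
import torch

target = torch.zeros(2, 3, 4, 4)                     # one-hot target, shape (B, C, H, W)
alpha_t = torch.tensor([0.1, 0.5, 2.0])              # one weight per channel C
broadcast_dims = [-1] + [1] * len(target.shape[2:])  # -> [-1, 1, 1]
alpha_fac = alpha_t.view(broadcast_dims)             # shape (3, 1, 1)
print((alpha_fac * target).shape)                    # torch.Size([2, 3, 4, 4]): one weight per channel
```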

tests/losses/test_focal_loss.py (44 additions, 2 deletions)

```diff
@@ -21,10 +21,11 @@

 from monai.losses import FocalLoss
 from monai.networks import one_hot
-from tests.test_utils import test_script_save
+from tests.test_utils import TEST_DEVICES, test_script_save

 TEST_CASES = []
-for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
+for case in TEST_DEVICES:
+    device = case[0]
     input_data = {
         "input": torch.tensor(
             [[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device
@@ -77,6 +78,13 @@
     TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276])
     TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138])

+TEST_ALPHA_BROADCASTING = []
+for case in TEST_DEVICES:
+    device = case[0]
+    for include_background in [True, False]:
+        for use_softmax in [True, False]:
+            TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])
+

 class TestFocalLoss(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
@@ -374,6 +382,40 @@ def test_script(self):
         test_input = torch.ones(2, 2, 8, 8)
         test_script_save(loss, test_input, test_input)

+    @parameterized.expand(TEST_ALPHA_BROADCASTING)
+    def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax):
+        """
+        Test FocalLoss with alpha as a sequence for proper broadcasting.
+        """
+        num_classes = 3
+        batch_size = 2
+        spatial_dims = (4, 4)
+
+        logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
+        target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)
+
+        if include_background:
+            alpha_seq = [0.1, 0.5, 2.0]
+        else:
+            alpha_seq = [0.5, 2.0]
+
+        loss_func = FocalLoss(
+            to_onehot_y=True,
+            gamma=2.0,
+            alpha=alpha_seq,
+            include_background=include_background,
+            use_softmax=use_softmax,
+            reduction="mean",
+        )
+
+        result = loss_func(logits, target)
+
+        self.assertTrue(torch.is_tensor(result))
+        self.assertEqual(result.ndim, 0)
+        self.assertTrue(
+            result > 0, f"Loss should be positive. params: dev={device}, bg={include_background}, softmax={use_softmax}"
+        )
+

 if __name__ == "__main__":
     unittest.main()
```
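
One behavioral difference between the two alpha forms in the changed `sigmoid_focal_loss` is worth noting: a scalar weights positives by `alpha` and negatives by `1 - alpha`, while a sequence reweights positives per class and leaves negatives at weight 1.0 (per the comments in the diff). A hedged sketch, with tensor values and shapes assumed:

```python
# Contrast scalar vs per-class alpha in the sigmoid path (illustrative values).
import torch
from monai.losses.focal_loss import sigmoid_focal_loss

logits = torch.randn(2, 3, 4, 4)
target = torch.randint(0, 2, (2, 3, 4, 4)).float()  # multi-label binary target

# Scalar alpha: positives scaled by 0.25, negatives by 0.75.
scalar_loss = sigmoid_focal_loss(logits, target, gamma=2.0, alpha=0.25)

# Per-class alpha: channel c positives scaled by alpha[c]; negatives unchanged.
per_class_loss = sigmoid_focal_loss(logits, target, gamma=2.0, alpha=torch.tensor([0.1, 0.5, 2.0]))
print(scalar_loss.shape, per_class_loss.shape)  # elementwise losses, shape (2, 3, 4, 4)
```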
