@@ -69,19 +69,22 @@ def __init__(
6969 include_background : bool = True ,
7070 to_onehot_y : bool = False ,
7171 gamma : float = 2.0 ,
72- alpha : float | None = None ,
72+ alpha : float | Sequence [ float ] | None = None ,
7373 weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
7474 reduction : LossReduction | str = LossReduction .MEAN ,
7575 use_softmax : bool = False ,
7676 ) -> None :
7777 """
7878 Args:
7979 include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
80- If False, `alpha` is invalid when using softmax.
80+ If False, `alpha` is invalid when using softmax, unless `alpha` is a sequence of explicit class weights.
8181 to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
8282 gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
8383 alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
84- The value should be in [0, 1]. Defaults to None.
84+ If a single value is provided, it should be in [0, 1].
85+ If a sequence is provided, its length must match the number of classes
86+ (excluding the background class if `include_background=False`).
87+ Defaults to None.
8588 weight: weights to apply to the voxels of each class. If None no weights are applied.
8689 The input can be a single value (same weight for all classes), a sequence of values (the length
8790 of the sequence should be the same as the number of classes. If not ``include_background``,
@@ -109,9 +112,15 @@ def __init__(
109112 self .include_background = include_background
110113 self .to_onehot_y = to_onehot_y
111114 self .gamma = gamma
112- self .alpha = alpha
113115 self .weight = weight
114116 self .use_softmax = use_softmax
117+ self .alpha : float | torch .Tensor | None
118+ if alpha is None :
119+ self .alpha = None
120+ elif isinstance (alpha , (float , int )):
121+ self .alpha = float (alpha )
122+ else :
123+ self .alpha = torch .as_tensor (alpha )
115124 weight = torch .as_tensor (weight ) if weight is not None else None
116125 self .register_buffer ("class_weight" , weight )
117126 self .class_weight : None | torch .Tensor
@@ -155,13 +164,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
155164 loss : torch .Tensor | None = None
156165 input = input .float ()
157166 target = target .float ()
167+ alpha_arg = self .alpha
158168 if self .use_softmax :
159169 if not self .include_background and self .alpha is not None :
160- self .alpha = None
161- warnings .warn ("`include_background=False`, `alpha` ignored when using softmax." )
162- loss = softmax_focal_loss (input , target , self .gamma , self .alpha )
170+ if isinstance (self .alpha , (float , int )):
171+ alpha_arg = None
172+ warnings .warn (
173+ "`include_background=False`, scalar `alpha` ignored when using softmax." , stacklevel = 2
174+ )
175+ loss = softmax_focal_loss (input , target , self .gamma , alpha_arg )
163176 else :
164- loss = sigmoid_focal_loss (input , target , self .gamma , self . alpha )
177+ loss = sigmoid_focal_loss (input , target , self .gamma , alpha_arg )
165178
166179 num_of_classes = target .shape [1 ]
167180 if self .class_weight is not None and num_of_classes != 1 :
@@ -202,7 +215,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
202215
203216
204217def softmax_focal_loss (
205- input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | None = None
218+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch . Tensor | None = None
206219) -> torch .Tensor :
207220 """
208221 FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -214,8 +227,22 @@ def softmax_focal_loss(
214227 loss : torch .Tensor = - (1 - input_ls .exp ()).pow (gamma ) * input_ls * target
215228
216229 if alpha is not None :
217- # (1-alpha) for the background class and alpha for the other classes
218- alpha_fac = torch .tensor ([1 - alpha ] + [alpha ] * (target .shape [1 ] - 1 )).to (loss )
230+ if isinstance (alpha , torch .Tensor ):
231+ alpha_t = alpha .to (device = input .device , dtype = input .dtype )
232+ else :
233+ alpha_t = torch .tensor (alpha , device = input .device , dtype = input .dtype )
234+
235+ if alpha_t .ndim == 0 : # scalar
236+ alpha_val = alpha_t .item ()
237+ # (1-alpha) for the background class and alpha for the other classes
238+ alpha_fac = torch .tensor ([1 - alpha_val ] + [alpha_val ] * (target .shape [1 ] - 1 )).to (loss )
239+ else : # tensor (sequence)
240+ if alpha_t .shape [0 ] != target .shape [1 ]:
241+ raise ValueError (
242+ f"The length of alpha ({ alpha_t .shape [0 ]} ) must match the number of classes ({ target .shape [1 ]} )."
243+ )
244+ alpha_fac = alpha_t
245+
219246 broadcast_dims = [- 1 ] + [1 ] * len (target .shape [2 :])
220247 alpha_fac = alpha_fac .view (broadcast_dims )
221248 loss = alpha_fac * loss
@@ -224,7 +251,7 @@ def softmax_focal_loss(
224251
225252
226253def sigmoid_focal_loss (
227- input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | None = None
254+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch . Tensor | None = None
228255) -> torch .Tensor :
229256 """
230257 FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -247,8 +274,27 @@ def sigmoid_focal_loss(
247274 loss = (invprobs * gamma ).exp () * loss
248275
249276 if alpha is not None :
250- # alpha if t==1; (1-alpha) if t==0
251- alpha_factor = target * alpha + (1 - target ) * (1 - alpha )
277+ if isinstance (alpha , torch .Tensor ):
278+ alpha_t = alpha .to (device = input .device , dtype = input .dtype )
279+ else :
280+ alpha_t = torch .tensor (alpha , device = input .device , dtype = input .dtype )
281+
282+ if alpha_t .ndim == 0 : # scalar
283+ # alpha if t==1; (1-alpha) if t==0
284+ alpha_factor = target * alpha_t + (1 - target ) * (1 - alpha_t )
285+ else : # tensor (sequence)
286+ if alpha_t .shape [0 ] != target .shape [1 ]:
287+ raise ValueError (
288+ f"The length of alpha ({ alpha_t .shape [0 ]} ) must match the number of classes ({ target .shape [1 ]} )."
289+ )
290+ # Reshape alpha for broadcasting: (1, C, 1, 1...)
291+ broadcast_dims = [- 1 ] + [1 ] * len (target .shape [2 :])
292+ alpha_t = alpha_t .view (broadcast_dims )
293+ # Apply per-class weight only to positive samples
294+ # For positive samples (target==1): multiply by alpha[c]
295+ # For negative samples (target==0): keep weight as 1.0
296+ alpha_factor = torch .where (target == 1 , alpha_t , torch .ones_like (alpha_t ))
297+
252298 loss = alpha_factor * loss
253299
254300 return loss
0 commit comments