@@ -322,16 +322,61 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
322322 y_pred = torch .sigmoid (y_pred )
323323 y_pred = y_pred > 0.5
324324
325+ # Vectorized computation (replaces nested loops for better performance)
326+ batch_size = y_pred .shape [0 ]
327+ device = y_pred .device
328+
329+ # Convert to boolean for computation
330+ if y_pred .shape [1 ] == 1 and n_pred_ch > 1 :
331+ # Single-channel class indices: convert to one-hot
332+ y_pred_bool = torch .zeros (batch_size , n_pred_ch , * y_pred .shape [2 :], dtype = torch .bool , device = device )
333+ y_bool = torch .zeros (batch_size , n_pred_ch , * y .shape [2 :], dtype = torch .bool , device = device )
334+
335+ for c in range (n_pred_ch ):
336+ y_pred_bool [:, c ] = (y_pred [:, 0 ] == c )
337+ y_bool [:, c ] = (y [:, 0 ] == c )
338+ else :
339+ # One-hot format: cast to bool
340+ y_pred_bool = y_pred .bool ()
341+ if y .shape [1 ] == 1 and y_pred .shape [1 ] > 1 :
342+ # Expand y to match y_pred channels
343+ y_bool = (y == 1 ).expand (batch_size , n_pred_ch , * y .shape [2 :])
344+ else :
345+ y_bool = y .bool ()
346+
347+ # Flatten spatial dimensions for vectorized computation: (batch, channels, -1)
348+ y_pred_flat = y_pred_bool .reshape (batch_size , n_pred_ch , - 1 ).float ()
349+ y_flat = y_bool .reshape (batch_size , n_pred_ch , - 1 ).float ()
350+
351+ # Compute Dice per (batch, channel) vectorized: all reductions at once
352+ intersection = torch .sum (y_pred_flat * y_flat , dim = - 1 ) # (batch, n_pred_ch)
353+ pred_sum = torch .sum (y_pred_flat , dim = - 1 ) # (batch, n_pred_ch)
354+ y_sum = torch .sum (y_flat , dim = - 1 ) # (batch, n_pred_ch)
355+
356+ # Dice formula: 2 * intersection / (pred_sum + y_sum)
357+ union = pred_sum + y_sum
358+ dice = (2.0 * intersection ) / union # (batch, n_pred_ch)
359+
360+ # Handle empty ground truth cases
361+ if self .ignore_empty :
362+ # Set NaN where ground truth is empty
363+ dice = torch .where (y_sum > 0 , dice , torch .tensor (float ("nan" ), device = device , dtype = dice .dtype ))
364+ else :
365+ # Set 1.0 if both empty, 0.0 if only pred is non-empty
366+ empty_mask = y_sum == 0
367+ dice = torch .where (
368+ empty_mask ,
369+ torch .where (pred_sum == 0 , torch .tensor (1.0 , device = device , dtype = dice .dtype ),
370+ torch .tensor (0.0 , device = device , dtype = dice .dtype )),
371+ dice
372+ )
373+
374+ # Select channels: exclude background if requested
325375 first_ch = 0 if self .include_background else 1
326- data = []
327- for b in range (y_pred .shape [0 ]):
328- c_list = []
329- for c in range (first_ch , n_pred_ch ) if n_pred_ch > 1 else [1 ]:
330- x_pred = (y_pred [b , 0 ] == c ) if (y_pred .shape [1 ] == 1 ) else y_pred [b , c ].bool ()
331- x = (y [b , 0 ] == c ) if (y .shape [1 ] == 1 ) else y [b , c ]
332- c_list .append (self .compute_channel (x_pred , x ))
333- data .append (torch .stack (c_list ))
334- data = torch .stack (data , dim = 0 ).contiguous () # type: ignore
376+ if n_pred_ch > 1 :
377+ data = dice [:, first_ch :] # (batch, num_classes_selected)
378+ else :
379+ data = dice # (batch, 1)
335380
336381 f , not_nans = do_metric_reduction (data , self .reduction ) # type: ignore
337382 return (f , not_nans ) if self .get_not_nans else f