Skip to content

Commit 1114907

Browse files
committed
fix: resolve all mypy and CodeRabbit issues
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent c80eeeb commit 1114907

2 files changed

Lines changed: 4 additions & 12 deletions

File tree

monai/metrics/generalized_dice.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
7474
y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
7575
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
7676
y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
77-
ignore_index: class index to ignore from the metric computation.
77+
Note:
78+
The ignore_index for this computation is taken from self.ignore_index if set during initialization.
7879
7980
Returns:
8081
torch.Tensor: Generalized Dice Score averaged across batch and class

monai/metrics/utils.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def get_surface_distance(
318318

319319
dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0]
320320
if isinstance(seg_pred, torch.Tensor):
321-
return dis[seg_pred.bool()] # type: ignore[union-attr]
321+
return dis[seg_pred.bool()] # type: ignore[union-attr,no-any-return]
322322
else:
323323
# NumPy array
324324
return dis[seg_pred.astype(bool)] # type: ignore[union-attr,no-any-return]
@@ -352,7 +352,6 @@ def get_edge_surface_distance(
352352
This will return the areas of the edges.
353353
symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`.
354354
class_index: The class-index used for context when warning about empty ground truth or prediction.
355-
mask: optional boolean mask indicating valid pixels.
356355
357356
Returns:
358357
(edges_pred, edges_gt), (distances_pred_to_gt, [distances_gt_to_pred]), (areas_pred, areas_gt) | tuple()
@@ -365,14 +364,6 @@ def get_edge_surface_distance(
365364
edge_results = get_mask_edges(y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False)
366365
edges_pred, edges_gt = edge_results[0], edge_results[1]
367366

368-
if mask is not None:
369-
if len(edge_results) > 2 and isinstance(edge_results[2], tuple):
370-
slices = edge_results[2]
371-
mask = mask[slices]
372-
mask = torch.as_tensor(mask, device=edges_pred.device, dtype=torch.bool)
373-
edges_pred = edges_pred & mask
374-
edges_gt = edges_gt & mask
375-
376367
distances_raw: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]
377368
if symmetric:
378369
distances_raw = (
@@ -382,7 +373,7 @@ def get_edge_surface_distance(
382373
else:
383374
distances_raw = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore
384375

385-
distances_list = [d if d is not None else edges_pred.new_empty((0,)) for d in distances_raw]
376+
distances_list = list(distances_raw)
386377
distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] = (
387378
tuple(distances_list) if len(distances_list) == 2 else (distances_list[0],) # type: ignore[assignment]
388379
)

0 commit comments

Comments
 (0)