@@ -318,7 +318,7 @@ def get_surface_distance(
318318
319319 dis = convert_to_dst_type (dis , seg_pred , dtype = lib .float32 )[0 ]
320320 if isinstance (seg_pred , torch .Tensor ):
321- return dis [seg_pred .bool ()] # type: ignore[union-attr]
321+ return dis [seg_pred .bool ()] # type: ignore[union-attr,no-any-return ]
322322 else :
323323 # NumPy array
324324 return dis [seg_pred .astype (bool )] # type: ignore[union-attr,no-any-return]
@@ -352,7 +352,6 @@ def get_edge_surface_distance(
352352 This will return the areas of the edges.
353353 symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`.
354354 class_index: The class-index used for context when warning about empty ground truth or prediction.
355- mask: optional boolean mask indicating valid pixels.
356355
357356 Returns:
358357 (edges_pred, edges_gt), (distances_pred_to_gt, [distances_gt_to_pred]), (areas_pred, areas_gt) | tuple()
@@ -365,14 +364,6 @@ def get_edge_surface_distance(
365364 edge_results = get_mask_edges (y_pred , y , crop = True , spacing = edges_spacing , always_return_as_numpy = False )
366365 edges_pred , edges_gt = edge_results [0 ], edge_results [1 ]
367366
368- if mask is not None :
369- if len (edge_results ) > 2 and isinstance (edge_results [2 ], tuple ):
370- slices = edge_results [2 ]
371- mask = mask [slices ]
372- mask = torch .as_tensor (mask , device = edges_pred .device , dtype = torch .bool )
373- edges_pred = edges_pred & mask
374- edges_gt = edges_gt & mask
375-
376367 distances_raw : tuple [torch .Tensor , torch .Tensor ] | tuple [torch .Tensor ]
377368 if symmetric :
378369 distances_raw = (
@@ -382,7 +373,7 @@ def get_edge_surface_distance(
382373 else :
383374 distances_raw = (get_surface_distance (edges_pred , edges_gt , distance_metric , spacing ),) # type: ignore
384375
385- distances_list = [ d if d is not None else edges_pred . new_empty (( 0 ,)) for d in distances_raw ]
376+ distances_list = list ( distances_raw )
386377 distances : tuple [torch .Tensor , torch .Tensor ] | tuple [torch .Tensor ] = (
387378 tuple (distances_list ) if len (distances_list ) == 2 else (distances_list [0 ],) # type: ignore[assignment]
388379 )
0 commit comments