@@ -344,28 +344,30 @@ def __call__(self, data: Mapping) -> dict:
344344 restore_grad_state = torch .is_grad_enabled ()
345345 torch .set_grad_enabled (False )
346346
347- ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
348- ndas_label = d [self .label_key ] # (H,W,D)
347+ try :
348+ ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
349+ ndas_label = d [self .label_key ] # (H,W,D)
349350
350- if ndas_label .shape != ndas [0 ].shape :
351- raise ValueError (f"Label shape { ndas_label .shape } is different from image shape { ndas [0 ].shape } " )
351+ if ndas_label .shape != ndas [0 ].shape :
352+ raise ValueError (f"Label shape { ndas_label .shape } is different from image shape { ndas [0 ].shape } " )
352353
353- nda_foregrounds = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
354- nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
354+ nda_foregrounds = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
355+ nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
355356
356- # perform calculation
357- report = deepcopy (self .get_report_format ())
357+ # perform calculation
358+ report = deepcopy (self .get_report_format ())
358359
359- report [ImageStatsKeys .INTENSITY ] = [
360- self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_f ) for nda_f in nda_foregrounds
361- ]
360+ report [ImageStatsKeys .INTENSITY ] = [
361+ self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_f ) for nda_f in nda_foregrounds
362+ ]
362363
363- if not verify_report_format (report , self .get_report_format ()):
364- raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
364+ if not verify_report_format (report , self .get_report_format ()):
365+ raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
365366
366- d [self .stats_name ] = report
367+ d [self .stats_name ] = report
368+ finally :
369+ torch .set_grad_enabled (restore_grad_state )
367370
368- torch .set_grad_enabled (restore_grad_state )
369371 logger .debug (f"Get foreground image stats spent { time .time () - start } " )
370372 return d
371373
@@ -477,78 +479,80 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
477479 restore_grad_state = torch .is_grad_enabled ()
478480 torch .set_grad_enabled (False )
479481
480- if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (
481- label_tensor , (MetaTensor , torch .Tensor )
482- ):
483- if label_tensor .device != image_tensor .device :
484- if using_cuda :
485- # Move both tensors to CUDA when mixing devices
486- cuda_device = image_tensor .device if image_tensor .device .type == "cuda" else label_tensor .device
487- image_tensor = cast (MetaTensor , image_tensor .to (cuda_device ))
488- label_tensor = cast (MetaTensor , label_tensor .to (cuda_device ))
489- else :
490- label_tensor = cast (MetaTensor , label_tensor .to (image_tensor .device ))
491-
492- ndas : list [MetaTensor ] = [image_tensor [i ] for i in range (image_tensor .shape [0 ])] # type: ignore
493- ndas_label : MetaTensor = label_tensor .astype (torch .int16 ) # (H,W,D)
494-
495- if ndas_label .shape != ndas [0 ].shape :
496- raise ValueError (f"Label shape { ndas_label .shape } is different from image shape { ndas [0 ].shape } " )
497-
498- nda_foregrounds : list [torch .Tensor ] = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
499- nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
500-
501- unique_label = unique (ndas_label )
502- if isinstance (ndas_label , (MetaTensor , torch .Tensor )):
503- unique_label = unique_label .data .cpu ().numpy () # type: ignore[assignment]
504-
505- unique_label = unique_label .astype (np .int16 ).tolist ()
506-
507- label_substats = [] # each element is one label
508- pixel_sum = 0
509- pixel_arr = []
510- for index in unique_label :
511- start_label = time .time ()
512- label_dict : dict [str , Any ] = {}
513- mask_index = ndas_label == index
514-
515- nda_masks = [nda [mask_index ] for nda in ndas ]
516- label_dict [LabelStatsKeys .IMAGE_INTST ] = [
517- self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_m ) for nda_m in nda_masks
518- ]
519-
520- pixel_count = sum (mask_index )
521- pixel_arr .append (pixel_count )
522- pixel_sum += pixel_count
523- if self .do_ccp : # apply connected component
524- if using_cuda :
525- # The back end of get_label_ccp is CuPy
526- # which is unable to automatically release CUDA GPU memory held by PyTorch
527- del nda_masks
528- torch .cuda .empty_cache ()
529- shape_list , ncomponents = get_label_ccp (mask_index )
530- label_dict [LabelStatsKeys .LABEL_SHAPE ] = shape_list
531- label_dict [LabelStatsKeys .LABEL_NCOMP ] = ncomponents
532-
533- label_substats .append (label_dict )
534- logger .debug (f" label { index } stats takes { time .time () - start_label } " )
535-
536- for i , _ in enumerate (unique_label ):
537- label_substats [i ].update ({LabelStatsKeys .PIXEL_PCT : float (pixel_arr [i ] / pixel_sum )})
482+ try :
483+ if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (
484+ label_tensor , (MetaTensor , torch .Tensor )
485+ ):
486+ if label_tensor .device != image_tensor .device :
487+ if using_cuda :
488+ # Move both tensors to CUDA when mixing devices
489+ cuda_device = image_tensor .device if image_tensor .device .type == "cuda" else label_tensor .device
490+ image_tensor = cast (MetaTensor , image_tensor .to (cuda_device ))
491+ label_tensor = cast (MetaTensor , label_tensor .to (cuda_device ))
492+ else :
493+ label_tensor = cast (MetaTensor , label_tensor .to (image_tensor .device ))
494+
495+ ndas : list [MetaTensor ] = [image_tensor [i ] for i in range (image_tensor .shape [0 ])] # type: ignore
496+ ndas_label : MetaTensor = label_tensor .astype (torch .int16 ) # (H,W,D)
497+
498+ if ndas_label .shape != ndas [0 ].shape :
499+ raise ValueError (f"Label shape { ndas_label .shape } is different from image shape { ndas [0 ].shape } " )
500+
501+ nda_foregrounds : list [torch .Tensor ] = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
502+ nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
503+
504+ unique_label = unique (ndas_label )
505+ if isinstance (ndas_label , (MetaTensor , torch .Tensor )):
506+ unique_label = unique_label .data .cpu ().numpy () # type: ignore[assignment]
507+
508+ unique_label = unique_label .astype (np .int16 ).tolist ()
509+
510+ label_substats = [] # each element is one label
511+ pixel_sum = 0
512+ pixel_arr = []
513+ for index in unique_label :
514+ start_label = time .time ()
515+ label_dict : dict [str , Any ] = {}
516+ mask_index = ndas_label == index
517+
518+ nda_masks = [nda [mask_index ] for nda in ndas ]
519+ label_dict [LabelStatsKeys .IMAGE_INTST ] = [
520+ self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_m ) for nda_m in nda_masks
521+ ]
538522
539- report = deepcopy (self .get_report_format ())
540- report [LabelStatsKeys .LABEL_UID ] = unique_label
541- report [LabelStatsKeys .IMAGE_INTST ] = [
542- self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_f ) for nda_f in nda_foregrounds
543- ]
544- report [LabelStatsKeys .LABEL ] = label_substats
523+ pixel_count = sum (mask_index )
524+ pixel_arr .append (pixel_count )
525+ pixel_sum += pixel_count
526+ if self .do_ccp : # apply connected component
527+ if using_cuda :
528+ # The back end of get_label_ccp is CuPy
529+ # which is unable to automatically release CUDA GPU memory held by PyTorch
530+ del nda_masks
531+ torch .cuda .empty_cache ()
532+ shape_list , ncomponents = get_label_ccp (mask_index )
533+ label_dict [LabelStatsKeys .LABEL_SHAPE ] = shape_list
534+ label_dict [LabelStatsKeys .LABEL_NCOMP ] = ncomponents
535+
536+ label_substats .append (label_dict )
537+ logger .debug (f" label { index } stats takes { time .time () - start_label } " )
538+
539+ for i , _ in enumerate (unique_label ):
540+ label_substats [i ].update ({LabelStatsKeys .PIXEL_PCT : float (pixel_arr [i ] / pixel_sum )})
541+
542+ report = deepcopy (self .get_report_format ())
543+ report [LabelStatsKeys .LABEL_UID ] = unique_label
544+ report [LabelStatsKeys .IMAGE_INTST ] = [
545+ self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_f ) for nda_f in nda_foregrounds
546+ ]
547+ report [LabelStatsKeys .LABEL ] = label_substats
545548
546- if not verify_report_format (report , self .get_report_format ()):
547- raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
549+ if not verify_report_format (report , self .get_report_format ()):
550+ raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
548551
549- d [self .stats_name ] = report # type: ignore[assignment]
552+ d [self .stats_name ] = report # type: ignore[assignment]
553+ finally :
554+ torch .set_grad_enabled (restore_grad_state )
550555
551- torch .set_grad_enabled (restore_grad_state )
552556 logger .debug (f"Get label stats spent { time .time () - start } " )
553557 return d # type: ignore[return-value]
554558
0 commit comments