
Commit 4ec0d4d

Fix #5889: restore grad state on exception in FgImageStats and LabelStats
FgImageStats.__call__ and LabelStats.__call__ both saved and disabled the torch grad state on entry, but only restored it via a plain assignment before the return. Any exception raised between the disable and the restore (e.g. a shape mismatch, or the RuntimeError from verify_report_format) left torch.is_grad_enabled() permanently False for the remainder of the process, silently breaking all subsequent gradient computations.

Wrapped the computation body of both methods in try/finally so the grad state is guaranteed to be restored regardless of how the function exits.

Fixes #5889
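To make the failure mode concrete, here is a minimal, self-contained sketch (hypothetical stand-in functions, not the MONAI analyzers themselves; only torch.is_grad_enabled and torch.set_grad_enabled are real APIs) of the pattern before and after the fix:

import torch

def buggy_stats(x: torch.Tensor) -> torch.Tensor:
    # pre-fix pattern: grad mode is restored only on the success path
    restore_grad_state = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    if x.ndim != 3:
        raise ValueError("expected a 3D tensor")  # skips the restore below
    result = x.sum()
    torch.set_grad_enabled(restore_grad_state)
    return result

def fixed_stats(x: torch.Tensor) -> torch.Tensor:
    # post-fix pattern: the finally block runs on every exit path
    restore_grad_state = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    try:
        if x.ndim != 3:
            raise ValueError("expected a 3D tensor")
        result = x.sum()
    finally:
        torch.set_grad_enabled(restore_grad_state)
    return result

torch.set_grad_enabled(True)
try:
    buggy_stats(torch.zeros(2, 2))  # wrong rank -> ValueError
except ValueError:
    pass
print(torch.is_grad_enabled())  # False: autograd left disabled process-wide

torch.set_grad_enabled(True)
try:
    fixed_stats(torch.zeros(2, 2))
except ValueError:
    pass
print(torch.is_grad_enabled())  # True: grad mode restored despite the exception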
1 parent 8d39519 commit 4ec0d4d

1 file changed

monai/auto3dseg/analyzer.py
Lines changed: 87 additions & 83 deletions
@@ -344,28 +344,30 @@ def __call__(self, data: Mapping) -> dict:
         restore_grad_state = torch.is_grad_enabled()
         torch.set_grad_enabled(False)

-        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
-        ndas_label = d[self.label_key]  # (H,W,D)
+        try:
+            ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+            ndas_label = d[self.label_key]  # (H,W,D)

-        if ndas_label.shape != ndas[0].shape:
-            raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
+            if ndas_label.shape != ndas[0].shape:
+                raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

-        nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
-        nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
+            nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
+            nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

-        # perform calculation
-        report = deepcopy(self.get_report_format())
+            # perform calculation
+            report = deepcopy(self.get_report_format())

-        report[ImageStatsKeys.INTENSITY] = [
-            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
-        ]
+            report[ImageStatsKeys.INTENSITY] = [
+                self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
+            ]

-        if not verify_report_format(report, self.get_report_format()):
-            raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+            if not verify_report_format(report, self.get_report_format()):
+                raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

-        d[self.stats_name] = report
+            d[self.stats_name] = report
+        finally:
+            torch.set_grad_enabled(restore_grad_state)

-        torch.set_grad_enabled(restore_grad_state)
         logger.debug(f"Get foreground image stats spent {time.time() - start}")
         return d

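As an aside (not what this commit does), PyTorch's grad-mode context managers give the same guarantee, since their __exit__ restores the previous mode even when an exception propagates; a minimal sketch:

import torch

def compute_stats(x: torch.Tensor) -> torch.Tensor:
    # torch.no_grad() restores the previous grad mode on exit, exception or not
    with torch.no_grad():
        if x.ndim != 3:
            raise ValueError("expected a 3D tensor")
        return x.sum()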
@@ -477,78 +479,80 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
         restore_grad_state = torch.is_grad_enabled()
         torch.set_grad_enabled(False)

-        if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
-            label_tensor, (MetaTensor, torch.Tensor)
-        ):
-            if label_tensor.device != image_tensor.device:
-                if using_cuda:
-                    # Move both tensors to CUDA when mixing devices
-                    cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
-                    image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
-                    label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
-                else:
-                    label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
-
-        ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])]  # type: ignore
-        ndas_label: MetaTensor = label_tensor.astype(torch.int16)  # (H,W,D)
-
-        if ndas_label.shape != ndas[0].shape:
-            raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
-
-        nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
-        nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
-
-        unique_label = unique(ndas_label)
-        if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
-            unique_label = unique_label.data.cpu().numpy()  # type: ignore[assignment]
-
-        unique_label = unique_label.astype(np.int16).tolist()
-
-        label_substats = []  # each element is one label
-        pixel_sum = 0
-        pixel_arr = []
-        for index in unique_label:
-            start_label = time.time()
-            label_dict: dict[str, Any] = {}
-            mask_index = ndas_label == index
-
-            nda_masks = [nda[mask_index] for nda in ndas]
-            label_dict[LabelStatsKeys.IMAGE_INTST] = [
-                self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
-            ]
-
-            pixel_count = sum(mask_index)
-            pixel_arr.append(pixel_count)
-            pixel_sum += pixel_count
-            if self.do_ccp:  # apply connected component
-                if using_cuda:
-                    # The back end of get_label_ccp is CuPy
-                    # which is unable to automatically release CUDA GPU memory held by PyTorch
-                    del nda_masks
-                    torch.cuda.empty_cache()
-                shape_list, ncomponents = get_label_ccp(mask_index)
-                label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
-                label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents
-
-            label_substats.append(label_dict)
-            logger.debug(f" label {index} stats takes {time.time() - start_label}")
-
-        for i, _ in enumerate(unique_label):
-            label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})
+        try:
+            if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
+                label_tensor, (MetaTensor, torch.Tensor)
+            ):
+                if label_tensor.device != image_tensor.device:
+                    if using_cuda:
+                        # Move both tensors to CUDA when mixing devices
+                        cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
+                        image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
+                        label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
+                    else:
+                        label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
+
+            ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])]  # type: ignore
+            ndas_label: MetaTensor = label_tensor.astype(torch.int16)  # (H,W,D)
+
+            if ndas_label.shape != ndas[0].shape:
+                raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
+
+            nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
+            nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
+
+            unique_label = unique(ndas_label)
+            if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
+                unique_label = unique_label.data.cpu().numpy()  # type: ignore[assignment]
+
+            unique_label = unique_label.astype(np.int16).tolist()
+
+            label_substats = []  # each element is one label
+            pixel_sum = 0
+            pixel_arr = []
+            for index in unique_label:
+                start_label = time.time()
+                label_dict: dict[str, Any] = {}
+                mask_index = ndas_label == index
+
+                nda_masks = [nda[mask_index] for nda in ndas]
+                label_dict[LabelStatsKeys.IMAGE_INTST] = [
+                    self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
+                ]

-        report = deepcopy(self.get_report_format())
-        report[LabelStatsKeys.LABEL_UID] = unique_label
-        report[LabelStatsKeys.IMAGE_INTST] = [
-            self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
-        ]
-        report[LabelStatsKeys.LABEL] = label_substats
+                pixel_count = sum(mask_index)
+                pixel_arr.append(pixel_count)
+                pixel_sum += pixel_count
+                if self.do_ccp:  # apply connected component
+                    if using_cuda:
+                        # The back end of get_label_ccp is CuPy
+                        # which is unable to automatically release CUDA GPU memory held by PyTorch
+                        del nda_masks
+                        torch.cuda.empty_cache()
+                    shape_list, ncomponents = get_label_ccp(mask_index)
+                    label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
+                    label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents
+
+                label_substats.append(label_dict)
+                logger.debug(f" label {index} stats takes {time.time() - start_label}")
+
+            for i, _ in enumerate(unique_label):
+                label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})
+
+            report = deepcopy(self.get_report_format())
+            report[LabelStatsKeys.LABEL_UID] = unique_label
+            report[LabelStatsKeys.IMAGE_INTST] = [
+                self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
+            ]
+            report[LabelStatsKeys.LABEL] = label_substats

-        if not verify_report_format(report, self.get_report_format()):
-            raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+            if not verify_report_format(report, self.get_report_format()):
+                raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

-        d[self.stats_name] = report  # type: ignore[assignment]
+            d[self.stats_name] = report  # type: ignore[assignment]
+        finally:
+            torch.set_grad_enabled(restore_grad_state)

-        torch.set_grad_enabled(restore_grad_state)
         logger.debug(f"Get label stats spent {time.time() - start}")
         return d  # type: ignore[return-value]

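A possible regression-test sketch for the restored-grad-state behaviour (unverified; the FgImageStats constructor arguments and the minimal input dict are assumptions read off the diff above, not checked against the class signature):

import unittest

import torch

from monai.auto3dseg.analyzer import FgImageStats

class TestGradStateRestoredOnError(unittest.TestCase):
    def test_grad_state_restored_on_error(self):
        torch.set_grad_enabled(True)
        analyzer = FgImageStats(image_key="image", label_key="label")
        data = {
            "image": torch.zeros(1, 4, 4, 4),  # one channel, spatial shape (4, 4, 4)
            "label": torch.zeros(3, 3, 3),     # deliberately mismatched shape
        }
        with self.assertRaises(ValueError):
            analyzer(data)
        # before this fix, the exception path left grad mode disabled process-wide
        self.assertTrue(torch.is_grad_enabled())

if __name__ == "__main__":
    unittest.main()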
0 commit comments

Comments
 (0)