Skip to content

Commit d3d0209

Browse files
authored
8800: potentially wrong device — `using_cuda` variable status in monai/auto3dseg/analyzer.py (#8801)
Fixes #8800. ### Description Fixes a bug introduced by me in @benediktjohannes' PR #8708. Fixed a device synchronization bug in `LabelStats.__call__()` where the `using_cuda` flag was being ignored. When the image and label tensors were on different devices, the code would: 1. Set `using_cuda` to True if either tensor was on the GPU 2. Ignore `using_cuda` and move tensors to the CPU if there was a device mismatch Now `using_cuda` is calculated, and when devices don't match, both tensors are moved to CUDA (when either is on CUDA) or to the CPU otherwise. Replaced `# type: ignore` comments with proper `cast()` calls for type safety. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com>
1 parent 853f702 commit d3d0209

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

monai/auto3dseg/analyzer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from abc import ABC, abstractmethod
1616
from collections.abc import Hashable, Mapping
1717
from copy import deepcopy
18-
from typing import Any
18+
from typing import Any, cast
1919

2020
import numpy as np
2121
import torch
@@ -470,6 +470,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470470
start = time.time()
471471
image_tensor = d[self.image_key]
472472
label_tensor = d[self.label_key]
473+
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
473474
using_cuda = any(
474475
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
475476
)
@@ -480,7 +481,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
480481
label_tensor, (MetaTensor, torch.Tensor)
481482
):
482483
if label_tensor.device != image_tensor.device:
483-
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
484+
if using_cuda:
485+
# Move both tensors to CUDA when mixing devices
486+
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
487+
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
488+
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
489+
else:
490+
label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
484491

485492
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
486493
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)

tests/apps/test_auto3dseg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
393393
result = analyzer({"image": image_tensor, "label": label_tensor})
394394
report = result["label_stats"]
395395

396+
# Verify report format and computation succeeded despite mixed/unified devices
396397
assert verify_report_format(report, analyzer.get_report_format())
397398
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]
398399

0 commit comments

Comments
 (0)