Commit 867a499

Authored by ytl0623, with pre-commit-ci[bot] and ericspod as co-authors
Fix AutoencoderKlMaisi forcing CUDA transfer on CPU inputs (#8736)
Fixes #8735

### Description

This PR fixes a bug in `AutoencoderKlMaisi` where the model would force a transfer to `cuda` even if the input tensors and the model were placed on the CPU.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: ytl0623 <david89062388@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 00de6fb commit 867a499

File tree

1 file changed: +5 −1 lines changed

monai/apps/generation/maisi/networks/autoencoderkl_maisi.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -214,6 +214,8 @@ def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, pad
         if max(outputs[0].size()) < 500:
             x = torch.cat(outputs, dim=self.dim_split + 2)
         else:
+            target_device = outputs[0].device
+
             x = outputs[0].clone().to("cpu", non_blocking=True)
             outputs[0] = torch.Tensor(0)
             _empty_cuda_cache(self.save_mem)
@@ -225,7 +227,9 @@ def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, pad
             if self.print_info:
                 logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.")

-        x = x.to("cuda", non_blocking=True)
+        if target_device.type != "cpu":
+            x = x.to(target_device, non_blocking=True)
+
         return x

     def forward(self, x: torch.Tensor) -> torch.Tensor:
```
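The pattern behind the fix can be sketched in isolation: record the device of the input tensors before staging work on the CPU, then move the result back only when the inputs did not start on the CPU. The helper name `concatenate_preserving_device` below is hypothetical, a minimal standalone illustration rather than the actual MONAI method:

```python
import torch


def concatenate_preserving_device(outputs: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    # Remember where the inputs live BEFORE staging anything on the CPU,
    # mirroring the `target_device = outputs[0].device` line in the fix.
    target_device = outputs[0].device

    # Stage the concatenation on the CPU (a memory-saving trick when the
    # tensors are large and GPU memory is tight).
    staged = [t.to("cpu", non_blocking=True) for t in outputs]
    x = torch.cat(staged, dim=dim)

    # Only transfer back when the inputs were not already on the CPU.
    # The pre-fix code unconditionally called x.to("cuda", ...), which
    # fails or forces a device change for CPU-only runs.
    if target_device.type != "cpu":
        x = x.to(target_device, non_blocking=True)
    return x
```

With CPU inputs the result stays on the CPU and no CUDA call is ever made, which is exactly the behavior the PR restores for CPU-only environments.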

0 commit comments