
Commit fa4ec27

fix: fallback to torch.cuda.mem_get_info when NVML memory query is unsupported
nvmlDeviceGetMemoryInfo returns NVML_ERROR_NOT_SUPPORTED on DGX Spark (GB10). Fall back to torch.cuda.mem_get_info, which works on all CUDA devices.

Signed-off-by: Daniel Bustamante Ospina <dbustamante70@gmail.com>
1 parent 3069f28 commit fa4ec27

1 file changed: nemo_rl/utils/nvml.py (9 additions, 5 deletions)
--- a/nemo_rl/utils/nvml.py
+++ b/nemo_rl/utils/nvml.py
@@ -78,13 +78,17 @@ def get_device_uuid(device_idx: int) -> str:
 
 
 def get_free_memory_bytes(device_idx: int) -> float:
-    """Get the free memory of a CUDA device in bytes using NVML."""
+    """Get the free memory of a CUDA device in bytes using NVML, with torch.cuda fallback."""
     global_device_idx = device_id_to_physical_device_id(device_idx)
     with nvml_context():
         try:
             handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
             return pynvml.nvmlDeviceGetMemoryInfo(handle).free
-        except pynvml.NVMLError as e:
-            raise RuntimeError(
-                f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}"
-            )
+        except pynvml.NVMLError:
+            pass
+
+    # Fallback for GPUs where NVML memory query is not supported (e.g. DGX Spark)
+    import torch
+
+    free, _total = torch.cuda.mem_get_info(device_idx)
+    return free
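The shape of the fix above is a generic try-then-fallback pattern: attempt the preferred query (NVML), and on a known failure mode fall through to a universally available one (torch.cuda.mem_get_info). A minimal runnable sketch of that pattern is below; the NVMLNotSupportedError class and the two query functions are hypothetical stand-ins (not real pynvml or torch APIs) so the example runs without a GPU.

```python
class NVMLNotSupportedError(Exception):
    """Stand-in for pynvml.NVMLError raised as NVML_ERROR_NOT_SUPPORTED."""


def nvml_free_bytes(device_idx: int) -> int:
    # Simulates nvmlDeviceGetMemoryInfo failing on devices such as
    # DGX Spark (GB10), where the NVML memory query is unsupported.
    raise NVMLNotSupportedError("NVML_ERROR_NOT_SUPPORTED")


def torch_mem_get_info(device_idx: int) -> tuple[int, int]:
    # Simulates torch.cuda.mem_get_info, which returns (free, total) bytes.
    return (8 * 1024**3, 16 * 1024**3)


def get_free_memory_bytes(device_idx: int) -> int:
    # Preferred path: query NVML first.
    try:
        return nvml_free_bytes(device_idx)
    except NVMLNotSupportedError:
        pass
    # Fallback path: works on all CUDA devices; discard the total.
    free, _total = torch_mem_get_info(device_idx)
    return free


print(get_free_memory_bytes(0))  # 8589934592 (8 GiB free via the fallback)
```

Catching only the NVML error type (rather than a bare except) keeps unrelated failures visible, and swallowing the error with `pass` instead of re-raising is what lets control reach the fallback.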
