Skip to content

Commit db8ae12

Browse files
committed
fix: fallback to torch.cuda.mem_get_info when NVML memory query is unsupported
nvmlDeviceGetMemoryInfo returns NVML_ERROR_NOT_SUPPORTED on DGX Spark (GB10). Log the error and fall back to torch.cuda.mem_get_info which works on all CUDA devices. Signed-off-by: Daniel Bustamante Ospina <dbustamante70@gmail.com>
1 parent fa4ec27 commit db8ae12

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

nemo_rl/utils/nvml.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import contextlib
15+
import logging
1516
import os
1617
from typing import Generator
1718

1819
import pynvml
1920

21+
logger = logging.getLogger(__name__)
22+
2023

2124
@contextlib.contextmanager
2225
def nvml_context() -> Generator[None, None, None]:
@@ -84,10 +87,9 @@ def get_free_memory_bytes(device_idx: int) -> float:
8487
try:
8588
handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
8689
return pynvml.nvmlDeviceGetMemoryInfo(handle).free
87-
except pynvml.NVMLError:
88-
pass
90+
except pynvml.NVMLError as e:
91+
logger.warning("NVML memory query failed for device %d: %s. Falling back to torch.cuda.mem_get_info.", device_idx, e)
8992

90-
# Fallback for GPUs where NVML memory query is not supported (e.g. DGX Spark)
9193
import torch
9294

9395
free, _total = torch.cuda.mem_get_info(device_idx)

0 commit comments

Comments
 (0)