Skip to content

Commit ea45bbf

Browse files
mdboom and rwgk authored
Fix test_cuda_device_order on some multi-GPU systems (#1590)
* Fix test_cuda_device_order
* Fix test
* Update cuda_bindings/tests/nvml/test_cuda.py

Co-authored-by: Ralf W. Grosse-Kunstleve <rwgkio@gmail.com>

---------

Co-authored-by: Ralf W. Grosse-Kunstleve <rwgkio@gmail.com>
1 parent 1d7c89e commit ea45bbf

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

cuda_bindings/tests/nvml/test_cuda.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4+
import os
5+
46
import cuda.bindings.driver as cuda
57
from cuda.bindings import nvml
68

@@ -54,4 +56,12 @@ def test_cuda_device_order():
5456
cuda_devices = get_cuda_device_names()
5557
nvml_devices = get_nvml_device_names()
5658

57-
assert cuda_devices == nvml_devices, "CUDA and NVML device lists do not match"
59+
if "CUDA_VISIBLE_DEVICES" not in os.environ:
60+
# If that environment variable isn't set, the device lists should match exactly
61+
assert cuda_devices == nvml_devices, "CUDA and NVML device lists do not match"
62+
else:
63+
# If the environment variable is set, there may possibly be fewer CUDA devices,
64+
# and each of them should still be found in NVML devices.
65+
assert len(cuda_devices) <= len(nvml_devices)
66+
for cuda_device in cuda_devices:
67+
assert cuda_device in nvml_devices, f"CUDA device {cuda_device} not found in NVML device list"

0 commit comments

Comments
 (0)