Skip to content

Commit 39be54d

Browse files
committed
nvbug6084457: Fix device architecture handling and NVLink link count query
1 parent d818a75 commit 39be54d

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

cuda_bindings/tests/nvml/test_init.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,18 @@ def test_devices_are_the_same_architecture(all_devices):
2525
# they won't be tested properly. This tests for the (hopefully rare) case
2626
# where a system has devices of different architectures and produces a warning.
2727

28-
all_arches = {nvml.DeviceArch(nvml.device_get_architecture(device)) for device in all_devices}
28+
def get_architecture_name(arch):
29+
try:
30+
arch = nvml.DeviceArch(arch)
31+
return arch.name
32+
except ValueError:
33+
return f"UNKNOWN({arch})"
34+
35+
all_arches = {nvml.device_get_architecture(device) for device in all_devices}
2936

3037
if len(all_arches) > 1:
3138
warnings.warn(
32-
f"System has devices of multiple architectures ({', '.join(x.name for x in all_arches)}). "
39+
f"System has devices of multiple architectures ({', '.join(get_architecture_name(x) for x in all_arches)}). "
3340
f" Some tests may be skipped unexpectedly",
3441
UserWarning,
3542
)

cuda_bindings/tests/nvml/test_nvlink.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ def test_nvlink_get_link_count(all_devices):
2626
# The feature_nvlink_supported detection is not robust, so we
2727
# can't be more specific about how many links we should find.
2828
if value.nvml_return == nvml.Return.SUCCESS:
29-
assert value.value.ui_val <= nvml.NVLINK_MAX_LINKS, f"Unexpected link count {value.value.ui_val}"
29+
assert value.value.ui_val[0] <= nvml.NVLINK_MAX_LINKS, f"Unexpected link count {value.value.ui_val[0]}"

cuda_core/cuda/core/system/_device.pyx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,12 @@ cdef class Device:
165165
"VOLTA"``, and RTX A6000 will report ``DeviceArchitecture.name ==
166166
"AMPERE"``.
167167
"""
168-
return DeviceArch(nvml.device_get_architecture(self._handle))
168+
arch = nvml.device_get_architecture(self._handle)
169+
try:
170+
arch = DeviceArch(arch)
171+
return arch
172+
except ValueError:
173+
return nvml.DeviceArch.UNKNOWN
169174

170175
@property
171176
def name(self) -> str:

0 commit comments

Comments
 (0)