Skip to content

Commit 192ac61

Browse files
committed
Use a __getattr__ approach
1 parent 801ecb7 commit 192ac61

3 files changed

Lines changed: 33 additions & 11 deletions

File tree

cuda_bindings/cuda/bindings/nvml.pyx

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,15 +1593,6 @@ class FieldId(_FastEnum):
15931593

15941594
MAX = 289
15951595

1596-
# This value changed in CTK 13.3. We need to build a binary that works across
1597-
# all versions, so the only way to support this is to check the version at
1598-
# runtime and set the value accordingly.
1599-
1600-
if tuple(int(x) for x in system_get_nvml_version().split(".")) < (3, 13):
1601-
NVLINK_MAX_LINKS = 18
1602-
else:
1603-
NVLINK_MAX_LINKS = 36
1604-
16051596

16061597
class RUSD(_FastEnum):
16071598
POLL_NONE = (0x0, "Disable RUSD polling on all metric groups")
@@ -28458,3 +28449,17 @@ cpdef str vgpu_type_get_name(unsigned int vgpu_type_id):
2845828449
device_get_virtualization_mode.__doc__ = device_get_virtualization_mode.__doc__.replace("NVML_GPU_VIRTUALIZATION_?", "``NVML_GPU_VIRTUALIZATION_?``")
2845928450
device_set_virtualization_mode.__doc__ = device_set_virtualization_mode.__doc__.replace("NVML_GPU_VIRTUALIZATION_?", "``NVML_GPU_VIRTUALIZATION_?``")
2846028451
GpmMetricId.GPM_METRIC_DRAM_BW_UTIL.__doc__ = "Percentage of DRAM bw used vs theoretical maximum. ``0.0 - 100.0 *\u200d/``."
28452+
28453+
28454+
def __getattr__(name: str):
28455+
# This value changed in CTK 13.3. We need to build a binary that works across
28456+
# all versions, so the only way to support this is to check the version at
28457+
# runtime and set the value accordingly.
28458+
28459+
if name == "NVLINK_MAX_LINKS":
28460+
if tuple(int(x) for x in system_get_nvml_version().split(".")) < (3, 13):
28461+
return 18
28462+
else:
28463+
return 36
28464+
28465+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

cuda_bindings/tests/nvml/test_nvlink.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,12 @@ def test_nvlink_get_link_count(all_devices):
3232
# can't be more specific about how many links we should find.
3333
if value.nvml_return == nvml.Return.SUCCESS:
3434
assert value.value.ui_val[0] <= nvml.NVLINK_MAX_LINKS, f"Unexpected link count {value.value.ui_val[0]}"
35+
36+
37+
def test_nvlink_max_links():
38+
nvml_version = tuple(int(x) for x in nvml.system_get_nvml_version().split("."))
39+
40+
if nvml_version < (13, 3):
41+
assert nvml.NVLINK_MAX_LINKS == 18
42+
else:
43+
assert nvml.NVLINK_MAX_LINKS == 36

cuda_core/cuda/core/system/_nvlink.pxi

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@ if _NVLINK_VERSION_6_0 is not None:
1818
_NVLINK_VERSION_MAPPING[_NVLINK_VERSION_6_0] = (6, 0)
1919

2020

21-
cdef class NvlinkInfo:
21+
cdef class _NvlinkInfoMeta(type):
22+
@property
23+
def max_links(cls):
24+
return nvml.NVLINK_MAX_LINKS
25+
26+
27+
cdef class _NvlinkInfo:
2228
"""
2329
Nvlink information for a device.
2430
"""
@@ -67,4 +73,6 @@ cdef class NvlinkInfo:
6773
nvml.device_get_nvlink_state(self._device._handle, self._link) == nvml.EnableState.FEATURE_ENABLED
6874
)
6975

70-
max_links = nvml.NVLINK_MAX_LINKS
76+
77+
class NvlinkInfo(_NvlinkInfo, metaclass=_NvlinkInfoMeta):
78+
pass

0 commit comments

Comments
 (0)