Skip to content

Commit 637a218

Browse files
committed
Improve NVML_NVLINK_MAX_LINKS dynamic handling
1 parent a9156b6 commit 637a218

7 files changed

Lines changed: 212 additions & 27 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ repos:
115115
args: [--config-file=cuda_core/pyproject.toml, cuda_core/cuda/core]
116116
additional_dependencies:
117117
- numpy
118+
- types-Deprecated
118119

119120
- repo: https://github.com/rhysd/actionlint
120121
rev: "914e7df21a07ef503a81201c76d2b11c789d3fca" # frozen: v1.7.12

cuda_core/AGENTS.md

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,95 @@ so that they are documented but don't appear in the main index.
145145
### API stability
146146

147147
Reviews should point out where existing public APIs are broken.
148+
149+
### Deprecation and API lifecycle
150+
151+
`cuda.core` follows SemVer (see the support policy under `docs/source/`):
152+
153+
- **New APIs** may be added at any time (`x.Y.0`).
154+
- **Breaking removals** only happen in **major releases** (`X.0.0`).
155+
- Per the support policy, a deprecation notice must be present for **at least
156+
one minor release** before the API is actually removed.
157+
- Changes should be annotated in the code and also recorded in the release notes in the
158+
"Deprecated APIs" section.
159+
160+
**Annotating a new API** — Use the `deprecated.sphinx.versionadded` decorator:
161+
162+
```python
163+
164+
from deprecated.sphinx import versionadded
165+
166+
@versionadded("1.2.0")
167+
def new_feature(...):
168+
"""Short description.
169+
"""
170+
```
171+
172+
Alternatively, if the vagaries of how we implement functions in Cython do not
173+
allow this, you can add the reST `versionadded` directive directly:
174+
175+
```python
176+
def new_feature(...):
177+
"""Short description.
178+
179+
.. versionadded:: 1.2.0
180+
"""
181+
```
182+
183+
**Annotating a changed API** — Use the `deprecated.sphinx.versionchanged` decorator:
184+
185+
```python
186+
187+
from deprecated.sphinx import versionchanged
188+
189+
@versionchanged("1.2.0", reason="The old version was broken because...")
190+
def new_feature(...):
191+
"""Short description.
192+
"""
193+
```
194+
195+
Alternatively, if the vagaries of how we implement functions in Cython do not
196+
allow this, you can add the reST `versionchanged` directive directly:
197+
198+
```python
199+
def new_feature(...):
200+
"""Short description.
201+
202+
.. versionchanged:: 1.2.0
203+
The old version was broken because...
204+
"""
205+
```
206+
207+
**Deprecating an existing API** — use the `@deprecated` decorator from the
208+
`Deprecated` PyPI package (`deprecated.sphinx` import name), which also inserts a
209+
`.. deprecated::` directive into the docstring. The decorator emits a
210+
`DeprecationWarning` at call time; the docstring directive surfaces it in the
211+
generated docs.
212+
213+
```python
214+
from deprecated.sphinx import deprecated
215+
216+
@deprecated(version="1.2.0", reason="Use `new_feature` instead.")
217+
def old_feature(...):
218+
"""Short description.
219+
"""
220+
```
221+
222+
Rules to follow when deprecating:
223+
224+
- The `version=` argument must be the **first release** in which the
225+
deprecation appears, not the release in which removal is planned.
226+
- The `reason=` string must name the replacement (if one exists) so users
227+
know what to migrate to.
228+
- Keep the old implementation fully functional — do not change its behavior,
229+
only add the decorator.
230+
- The deprecated API must remain in the codebase for **at least one full minor
231+
release cycle** before it can be removed in a subsequent major release.
232+
233+
**Removing a deprecated API** — removals land in a **major release**. Before
234+
removing, verify that the deprecation has been present since at least the
235+
previous minor release. Remove the decorator, the implementation, and any
236+
`__all__` entry; update `api.rst` and the release notes accordingly.
237+
238+
At some point in the future, we will provide automation for removal of
239+
deprecated APIs.

cuda_core/cuda/core/system/_device.pyi

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from cuda.core.system.typing import (AddressingMode, AffinityScope, ClockId,
1414
GpuTopologyLevel, InforomObject,
1515
TemperatureThresholds, ThermalController,
1616
ThermalTarget)
17+
from deprecated.sphinx import deprecated, versionadded, versionchanged
1718

1819

1920
class ClockOffsets:
@@ -787,11 +788,23 @@ class MigInfo:
787788
A list of all MIG devices corresponding to this GPU.
788789
"""
789790

790-
class NvlinkInfo:
791+
class _NvlinkInfoMeta(type):
    """
    Metaclass that exposes the deprecated class-level ``max_links`` attribute.

    Defining the property on the metaclass makes it readable as a class
    attribute (``NvlinkInfo.max_links``) while still routing the access
    through the ``deprecated`` wrapper.
    """

    @property
    @deprecated(version='1.1.0', reason='Use Device.get_nvlink_count instead to get the actual number of Nvlinks available on a specific device.')
    def max_links(cls):
        """
        The statically-defined maximum number of Nvlinks available. Defined in
        upstream NVML as ``NVML_NVLINK_MAX_LINKS``.

        To find the actual number of Nvlinks available on a device, use
        :py:meth:`Device.get_nvlink_count`.
        """
803+
804+
class _NvlinkInfo:
791805
"""
792806
Nvlink information for a device.
793807
"""
794-
max_links = nvml.NVLINK_MAX_LINKS
795808

796809
def __init__(self, device: Device, link: int):
797810
...
@@ -824,6 +837,9 @@ class NvlinkInfo:
824837
`True` if the Nvlink is active.
825838
"""
826839

840+
class NvlinkInfo(_NvlinkInfo, metaclass=_NvlinkInfoMeta):
    """
    Nvlink information for a device.

    Combines the ``_NvlinkInfo`` implementation with ``_NvlinkInfoMeta``, which
    supplies the deprecated class-level ``max_links`` attribute.
    """
842+
827843
class PciInfo:
828844
"""
829845
PCI information about a GPU device.
@@ -1719,13 +1735,30 @@ class Device:
17191735
:obj:`~_device.MemoryInfo` object with memory information.
17201736
"""
17211737

1738+
@versionchanged(
    version='1.1.0',
    reason='Any link number not supported by this specific device will raise a `ValueError`.',
)
def get_nvlink(self, link: int) -> NvlinkInfo:
    """
    Return the :obj:`~NvlinkInfo` for the NVLink link ``link`` of this device.

    For devices with NVLink support.
    """
17281745

1746+
@versionadded(version='1.1.0')
def get_nvlink_count(self) -> int:
    """
    Return how many NVLink links this device has.

    For devices with NVLink support.
    """
1753+
1754+
@versionadded(version='1.1.0')
def get_nvlinks(self) -> Iterable[NvlinkInfo]:
    """
    Iterate over the :obj:`~NvlinkInfo` of every NVLink link on this device.

    For devices with NVLink support.
    """
1761+
17291762
@property
17301763
def pci_info(self) -> PciInfo:
17311764
"""

cuda_core/cuda/core/system/_device.pyx

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ from multiprocessing import cpu_count
99
from typing import Iterable, TYPE_CHECKING
1010
import warnings
1111

12+
from deprecated.sphinx import deprecated, versionadded, versionchanged
13+
1214
from cuda.bindings import nvml
1315

1416
from ._nvml_context cimport initialize
@@ -884,16 +886,40 @@ cdef class Device:
884886
# NVLINK
885887
# See external class definitions in _nvlink.pxi
886888
889+
@versionchanged(
    version="1.1.0",
    reason="Any link number not supported by this specific device will raise a `ValueError`."
)
def get_nvlink(self, link: int) -> NvlinkInfo:
    """
    Get :obj:`~NvlinkInfo` about one NVLink link of this device.

    For devices with NVLink support.

    Parameters
    ----------
    link: int
        Zero-based index of the link; must be smaller than the value
        returned by :meth:`get_nvlink_count`.

    Returns
    -------
    NvlinkInfo
        Information object bound to this device and ``link``.

    Raises
    ------
    ValueError
        If ``link`` is negative or not smaller than this device's link count.
    """
    # Query the per-device link count once and reuse it for both the range
    # check and the error message.
    num_links = self.get_nvlink_count()
    if link < 0 or link >= num_links:
        raise ValueError(f"Link index {link} is out of range [0, {num_links})")
    return NvlinkInfo(self, link)
896902
903+
@versionadded(version="1.1.0")
def get_nvlink_count(self) -> int:
    """
    Return how many NVLink links this device has.

    For devices with NVLink support.
    """
    # Only one field is requested, so the result of interest is at index 0.
    field_values = self.get_field_values([FieldId.DEV_NVLINK_LINK_COUNT])
    return field_values[0].value
911+
912+
@versionadded(version="1.1.0")
def get_nvlinks(self) -> Iterable[NvlinkInfo]:
    """
    Get :obj:`~NvlinkInfo` about all NVLink links on this device.

    For devices with NVLink support.

    This is a generator: the link count is queried when iteration starts,
    and one :obj:`~NvlinkInfo` is yielded per link index in ascending order.
    """
    # Every index produced by range() is within [0, count) by construction,
    # so the objects are built directly instead of re-validating each index.
    for link in range(self.get_nvlink_count()):
        yield NvlinkInfo(self, link)
922+
897923
##########################################################################
898924
# PCI INFO
899925
# See external class definitions in _pci_info.pxi

cuda_core/cuda/core/system/_nvlink.pxi

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,24 @@ if _NVLINK_VERSION_6_0 is not None:
1818
_NVLINK_VERSION_MAPPING[_NVLINK_VERSION_6_0] = (6, 0)
1919

2020

21-
cdef class NvlinkInfo:
21+
class _NvlinkInfoMeta(type):
    """
    Metaclass that exposes the deprecated class-level ``max_links`` attribute.

    Defining the property on the metaclass makes it readable as a class
    attribute (``NvlinkInfo.max_links``) while still routing the access
    through the ``deprecated`` wrapper.
    """

    @property
    @deprecated(
        version="1.1.0",
        reason="Use Device.get_nvlink_count instead to get the actual number of Nvlinks available on a specific device."
    )
    def max_links(cls):
        """
        The statically-defined maximum number of Nvlinks available. Defined in
        upstream NVML as ``NVML_NVLINK_MAX_LINKS``.

        To find the actual number of Nvlinks available on a device, use
        :py:meth:`Device.get_nvlink_count`.
        """
        return nvml.NVLINK_MAX_LINKS
36+
37+
38+
cdef class _NvlinkInfo:
2239
"""
2340
Nvlink information for a device.
2441
"""
@@ -67,4 +84,6 @@ cdef class NvlinkInfo:
6784
nvml.device_get_nvlink_state(self._device._handle, self._link) == nvml.EnableState.FEATURE_ENABLED
6885
)
6986

70-
max_links = nvml.NVLINK_MAX_LINKS
87+
88+
class NvlinkInfo(_NvlinkInfo, metaclass=_NvlinkInfoMeta):
    """
    Nvlink information for a device.

    Combines the ``_NvlinkInfo`` implementation with ``_NvlinkInfoMeta``, which
    supplies the deprecated class-level ``max_links`` attribute.
    """

cuda_core/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ classifiers = [
4949
]
5050
dependencies = [
5151
"cuda-pathfinder >=1.4.2",
52+
"Deprecated",
5253
"numpy",
5354
"backports.strenum; python_version < '3.11'",
5455
]

cuda_core/tests/system/test_system_device.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -763,27 +763,40 @@ def test_compute_running_processes():
763763

764764
def _check_nvlink_info(device, nvlink_info):
    """Check the type, ``state`` and (for active links) ``version`` of one NvlinkInfo."""
    assert isinstance(nvlink_info, _device.NvlinkInfo)

    with unsupported_before(device, None):
        state = nvlink_info.state
    assert isinstance(state, bool)

    if not state:
        # The version is only checked for links that report as active.
        return

    with unsupported_before(device, None):
        version = nvlink_info.version
    assert isinstance(version, tuple)
    assert len(version) == 2
    assert all(isinstance(i, int) for i in version)


def test_nvlink():
    """Exercise both NVLink access paths: by index and via the ``get_nvlinks`` iterator."""
    for device in system.Device.get_all_devices():
        with unsupported_before(device, None):
            # Path 1: explicit indices bounded by the per-device link count.
            for link in range(device.get_nvlink_count()):
                with unsupported_before(device, None):
                    nvlink_info = device.get_nvlink(link)
                _check_nvlink_info(device, nvlink_info)

            # Path 2: the iterator must satisfy the same per-link checks.
            for nvlink_info in device.get_nvlinks():
                _check_nvlink_info(device, nvlink_info)
787800

788801

789802
def test_utilization():

0 commit comments

Comments
 (0)