Skip to content

Commit b162f64

Browse files
authored
cuda.core.system: Add MIG-related APIs (#1916)
* cuda.core.system: Add MIG-related APIs * Add "need" * Add missing file * Make properties * Fix test * Fix test * Elaborate in the docstring * Address comments in PR * Address comments in PR
1 parent 515b513 commit b162f64

5 files changed

Lines changed: 231 additions & 10 deletions

File tree

cuda_core/cuda/core/system/_device.pyx

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include "_fan.pxi"
3232
include "_field_values.pxi"
3333
include "_inforom.pxi"
3434
include "_memory.pxi"
35+
include "_mig.pxi"
3536
include "_pci_info.pxi"
3637
include "_performance.pxi"
3738
include "_repair_status.pxi"
@@ -132,12 +133,23 @@ cdef class Device:
132133
board serial identifier.
133134

134135
In the upstream NVML C++ API, the UUID includes a ``gpu-`` or ``mig-``
135-
prefix. That is not included in ``cuda.core.system``.
136+
prefix. If you need a `uuid` without that prefix (for example, to
137+
interact with CUDA), use the `uuid_without_prefix` property.
136138
"""
137-
# NVML UUIDs have a `GPU-` or `MIG-` prefix. We remove that here.
139+
return nvml.device_get_uuid(self._handle)
138140

139-
# TODO: If the user cares about the prefix, we will expose that in the
140-
# future using the MIG-related APIs in NVML.
141+
@property
142+
def uuid_without_prefix(self) -> str:
143+
"""
144+
Retrieves the globally unique immutable UUID associated with this
145+
device, as a 5 part hexadecimal string, that augments the immutable,
146+
board serial identifier.
147+
148+
In the upstream NVML C++ API, the UUID includes a ``gpu-`` or ``mig-``
149+
prefix. This property returns it without the prefix, to match the UUIDs
150+
used in CUDA. If you need the prefix, use the `uuid` property.
151+
"""
152+
# NVML UUIDs have a `gpu-` or `mig-` prefix. We remove that here.
141153
return nvml.device_get_uuid(self._handle)[4:]
142154

143155
@property
@@ -265,7 +277,7 @@ cdef class Device:
265277
# search all the devices for one with a matching UUID.
266278

267279
for cuda_device in CudaDevice.get_all_devices():
268-
if cuda_device.uuid == self.uuid:
280+
if cuda_device.uuid == self.uuid_without_prefix:
269281
return cuda_device
270282

271283
raise RuntimeError("No corresponding CUDA device found for this NVML device.")
@@ -280,6 +292,8 @@ cdef class Device:
280292
int
281293
The number of available devices.
282294
"""
295+
initialize()
296+
283297
return nvml.device_get_count_v2()
284298

285299
@classmethod
@@ -292,6 +306,8 @@ cdef class Device:
292306
Iterator over :obj:`~Device`
293307
An iterator over available devices.
294308
"""
309+
initialize()
310+
295311
for device_id in range(nvml.device_get_count_v2()):
296312
yield cls(index=device_id)
297313

@@ -317,6 +333,18 @@ cdef class Device:
317333
"""
318334
return AddressingMode(nvml.device_get_addressing_mode(self._handle).value)
319335

336+
#########################################################################
337+
# MIG (MULTI-INSTANCE GPU) DEVICES
338+
339+
@property
340+
def mig(self) -> MigInfo:
341+
"""
342+
Get :obj:`~MigInfo` accessor for MIG (Multi-Instance GPU) information.
343+
344+
For Ampere™ or newer fully supported devices.
345+
"""
346+
return MigInfo(self)
347+
320348
#########################################################################
321349
# AFFINITY
322350

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
from typing import Iterable
7+
8+
9+
cdef class MigInfo:
10+
cdef Device _device
11+
12+
def __init__(self, device: Device):
13+
self._device = device
14+
15+
@property
16+
def is_mig_device(self) -> bool:
17+
"""
18+
Whether this device is a MIG (Multi-Instance GPU) device.
19+
20+
A MIG device handle is an NVML abstraction which maps to a MIG compute
21+
instance. These overloaded references can be used (with some
22+
restrictions) interchangeably with a GPU device handle to execute
23+
queries at a per-compute instance granularity.
24+
25+
For Ampere™ or newer fully supported devices.
26+
"""
27+
return bool(nvml.device_is_mig_device_handle(self._device._handle))
28+
29+
@property
30+
def mode(self) -> bool:
31+
"""
32+
Get current MIG mode for the device.
33+
34+
For Ampere™ or newer fully supported devices.
35+
36+
Changing MIG modes may require device unbind or reset. The "pending" MIG
37+
mode refers to the target mode following the next activation trigger.
38+
39+
Returns
40+
-------
41+
bool
42+
`True` if current MIG mode is enabled.
43+
"""
44+
current, _ = nvml.device_get_mig_mode(self._device._handle)
45+
return current == nvml.EnableState.FEATURE_ENABLED
46+
47+
@mode.setter
48+
def mode(self, mode: bool):
49+
"""
50+
Set the MIG mode for the device.
51+
52+
For Ampere™ or newer fully supported devices.
53+
54+
Changing MIG modes may require device unbind or reset. The "pending" MIG
55+
mode refers to the target mode following the next activation trigger.
56+
57+
Parameters
58+
----------
59+
mode: bool
60+
`True` to enable MIG mode, `False` to disable MIG mode.
61+
"""
62+
nvml.device_set_mig_mode(
63+
self._device._handle,
64+
nvml.EnableState.FEATURE_ENABLED if mode else nvml.EnableState.FEATURE_DISABLED
65+
)
66+
67+
@property
68+
def pending_mode(self) -> bool:
69+
"""
70+
Get pending MIG mode for the device.
71+
72+
For Ampere™ or newer fully supported devices.
73+
74+
Changing MIG modes may require device unbind or reset. The "pending" MIG
75+
mode refers to the target mode following the next activation trigger.
76+
77+
If the device is not a MIG device, returns `False`.
78+
79+
Returns
80+
-------
81+
bool
82+
`True` if pending MIG mode is enabled.
83+
"""
84+
_, pending = nvml.device_get_mig_mode(self._device._handle)
85+
return pending == nvml.EnableState.FEATURE_ENABLED
86+
87+
@property
88+
def device_count(self) -> int:
89+
"""
90+
Get the maximum number of MIG devices that can exist under this device.
91+
92+
Returns zero if MIG is not supported or enabled.
93+
94+
For Ampere™ or newer fully supported devices.
95+
96+
Returns
97+
-------
98+
int
99+
The number of MIG devices (compute instances) on this GPU.
100+
"""
101+
return nvml.device_get_max_mig_device_count(self._device._handle)
102+
103+
@property
104+
def parent(self) -> Device:
105+
"""
106+
For MIG devices, get the parent GPU device.
107+
108+
For Ampere™ or newer fully supported devices.
109+
110+
Returns
111+
-------
112+
Device
113+
The parent GPU device for this MIG device.
114+
"""
115+
parent_handle = nvml.device_get_device_handle_from_mig_device_handle(self._device._handle)
116+
parent_device = Device.__new__(Device)
117+
parent_device._handle = parent_handle
118+
return parent_device
119+
120+
def get_device_by_index(self, index: int) -> Device:
121+
"""
122+
Get MIG device for the given index under its parent device.
123+
124+
If the compute instance is destroyed either explicitly or by destroying,
125+
resetting or unbinding the parent GPU instance or the GPU device itself
126+
the MIG device handle would remain invalid and must be requested again
127+
using this API. Handles may be reused and their properties can change in
128+
the process.
129+
130+
For Ampere™ or newer fully supported devices.
131+
132+
Parameters
133+
----------
134+
index: int
135+
The index of the MIG device (compute instance) to retrieve. Must be
136+
between 0 and the value returned by `device_count - 1`.
137+
138+
Returns
139+
-------
140+
Device
141+
The MIG device corresponding to the given index.
142+
"""
143+
mig_device_handle = nvml.device_get_mig_device_handle_by_index(self._device._handle, index)
144+
mig_device = Device.__new__(Device)
145+
mig_device._handle = mig_device_handle
146+
return mig_device
147+
148+
def get_all_devices(self) -> Iterable[Device]:
149+
"""
150+
Get all MIG devices under its parent device.
151+
152+
If the compute instance is destroyed either explicitly or by destroying,
153+
resetting or unbinding the parent GPU instance or the GPU device itself
154+
the MIG device handle would remain invalid and must be requested again
155+
using this API. Handles may be reused and their properties can change in
156+
the process.
157+
158+
For Ampere™ or newer fully supported devices.
159+
160+
Returns
161+
-------
162+
list[Device]
163+
A list of all MIG devices corresponding to this GPU.
164+
"""
165+
for i in range(self.device_count):
166+
yield self.get_device_by_index(i)

cuda_core/docs/source/api_private.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
.. SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
.. SPDX-License-Identifier: Apache-2.0
33
44
:orphan:
@@ -76,6 +76,7 @@ NVML
7676
system._device.GpuTopologyLevel
7777
system._device.InforomInfo
7878
system._device.MemoryInfo
79+
system._device.MigInfo
7980
system._device.PciInfo
8081
system._device.RepairStatus
8182
system._device.Temperature

cuda_core/tests/system/test_system_device.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_to_cuda_device():
5757
cuda_device = device.to_cuda_device()
5858

5959
assert isinstance(cuda_device, CudaDevice)
60-
assert cuda_device.uuid == device.uuid
60+
assert cuda_device.uuid == device.uuid_without_prefix
6161

6262
# Technically, this test will only work with PCI devices, but are there
6363
# non-PCI devices we need to support?
@@ -227,9 +227,9 @@ def test_device_serial():
227227
assert len(serial) > 0
228228

229229

230-
def test_device_uuid():
230+
def test_device_uuid_without_prefix():
231231
for device in system.Device.get_all_devices():
232-
uuid = device.uuid
232+
uuid = device.uuid_without_prefix
233233
assert isinstance(uuid, str)
234234

235235
# Expands to GPU-8hex-4hex-4hex-4hex-12hex, where 8hex means 8 consecutive
@@ -729,3 +729,29 @@ def test_pstates():
729729
assert isinstance(utilization.percentage, int)
730730
assert isinstance(utilization.inc_threshold, int)
731731
assert isinstance(utilization.dec_threshold, int)
732+
733+
734+
@pytest.mark.skipif(helpers.IS_WSL or helpers.IS_WINDOWS, reason="MIG not supported on WSL or Windows")
735+
def test_mig():
736+
for device in system.Device.get_all_devices():
737+
with unsupported_before(device, None):
738+
mig = device.mig
739+
740+
assert isinstance(mig.is_mig_device, bool)
741+
assert isinstance(mig.mode, bool)
742+
assert isinstance(mig.pending_mode, bool)
743+
744+
device_count = mig.device_count
745+
assert isinstance(device_count, int)
746+
assert device_count >= 0
747+
748+
for mig_device in mig.get_all_devices():
749+
assert isinstance(mig_device, system.Device)
750+
751+
752+
def test_uuid():
753+
for device in system.Device.get_all_devices():
754+
uuid = device.uuid
755+
assert isinstance(uuid, str)
756+
assert uuid.startswith(("GPU-", "MIG-"))
757+
assert uuid == device.uuid

cuda_core/tests/test_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_to_system_device(deinit_cuda):
3838

3939
system_device = device.to_system_device()
4040
assert isinstance(system_device, SystemDevice)
41-
assert system_device.uuid == device.uuid
41+
assert system_device.uuid_without_prefix == device.uuid
4242

4343
# Technically, this test will only work with PCI devices, but are there
4444
# non-PCI devices we need to support?

0 commit comments

Comments
 (0)