Skip to content

Commit 2d20100

Browse files
committed
cuda.core.system: Add MIG-related APIs
1 parent 82e6bb8 commit 2d20100

File tree

2 files changed

+205
-4
lines changed

2 files changed

+205
-4
lines changed

cuda_core/cuda/core/system/_device.pyx

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include "_fan.pxi"
3232
include "_field_values.pxi"
3333
include "_inforom.pxi"
3434
include "_memory.pxi"
35+
include "_mig.pxi"
3536
include "_pci_info.pxi"
3637
include "_performance.pxi"
3738
include "_repair_status.pxi"
@@ -132,14 +133,20 @@ cdef class Device:
132133
board serial identifier.
133134

134135
In the upstream NVML C++ API, the UUID includes a ``gpu-`` or ``mig-``
135-
prefix. That is not included in ``cuda.core.system``.
136+
prefix. If you want that prefix, use the `uuid_with_prefix` property.
136137
"""
137138
# NVML UUIDs have a `GPU-` or `MIG-` prefix. We remove that here.
138-
139-
# TODO: If the user cares about the prefix, we will expose that in the
140-
# future using the MIG-related APIs in NVML.
141139
return nvml.device_get_uuid(self._handle)[4:]
142140

141+
@property
142+
def uuid_with_prefix(self) -> str:
143+
"""
144+
Retrieves the globally unique immutable UUID associated with this
145+
device, as a 5 part hexadecimal string, that augments the immutable,
146+
board serial identifier.
147+
"""
148+
return nvml.device_get_uuid(self._handle)
149+
143150
@property
144151
def pci_bus_id(self) -> str:
145152
"""
@@ -280,6 +287,8 @@ cdef class Device:
280287
int
281288
The number of available devices.
282289
"""
290+
initialize()
291+
283292
return nvml.device_get_count_v2()
284293

285294
@classmethod
@@ -292,6 +301,8 @@ cdef class Device:
292301
Iterator of Device
293302
An iterator over available devices.
294303
"""
304+
initialize()
305+
295306
for device_id in range(nvml.device_get_count_v2()):
296307
yield cls(index=device_id)
297308

@@ -317,6 +328,18 @@ cdef class Device:
317328
"""
318329
return AddressingMode(nvml.device_get_addressing_mode(self._handle).value)
319330

331+
#########################################################################
332+
# MIG (MULTI-INSTANCE GPU) DEVICES
333+
334+
@property
335+
def mig(self) -> MigInfo:
336+
"""
337+
Accessor for MIG (Multi-Instance GPU) information.
338+
339+
For Ampere™ or newer fully supported devices.
340+
"""
341+
return MigInfo(self)
342+
320343
#########################################################################
321344
# AFFINITY
322345

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
from typing import Iterable
7+
8+
9+
InforomObject = nvml.InforomObject
10+
11+
12+
cdef class MigInfo:
13+
cdef Device _device
14+
15+
def __init__(self, device: Device):
16+
self._device = device
17+
18+
@property
19+
def is_mig_device(self) -> bool:
20+
"""
21+
Whether this device is a MIG (Multi-Instance GPU) device.
22+
23+
A MIG device handle is an NVML abstraction which maps to a MIG compute
24+
instance. These overloaded references can be used (with some
25+
restrictions) interchangeably with a GPU device handle to execute
26+
queries at a per-compute instance granularity.
27+
28+
For Ampere™ or newer fully supported devices.
29+
"""
30+
return nvml.device_is_mig_device_handle(self._device._handle)
31+
32+
@property
33+
def mode(self) -> bool:
34+
"""
35+
Get current MIG mode for the device.
36+
37+
For Ampere™ or newer fully supported devices.
38+
39+
Changing MIG modes may require device unbind or reset. The "pending" MIG
40+
mode refers to the target mode following the next activation trigger.
41+
42+
If the device is not a MIG device, returns `False`.
43+
44+
Returns
45+
-------
46+
bool
47+
`True` if current MIG mode is enabled.
48+
"""
49+
if not self.is_mig_device:
50+
return False
51+
52+
current, _ = nvml.device_get_mig_mode(self._device._handle)
53+
return current == nvml.EnableState.FEATURE_ENABLED
54+
55+
@mode.setter
56+
def mode(self, mode: bool):
57+
"""
58+
Set the MIG mode for the device.
59+
60+
For Ampere™ or newer fully supported devices.
61+
62+
Changing MIG modes may require device unbind or reset. The "pending" MIG
63+
mode refers to the target mode following the next activation trigger.
64+
65+
Parameters
66+
----------
67+
mode: bool
68+
`True` to enable MIG mode, `False` to disable MIG mode.
69+
"""
70+
if not self.is_mig_device:
71+
raise ValueError("Device is not a MIG device")
72+
73+
nvml.device_set_mig_mode(
74+
self._device._handle,
75+
nvml.EnableState.FEATURE_ENABLED if mode else nvml.EnableState.FEATURE_DISABLED
76+
)
77+
78+
@property
79+
def pending_mode(self) -> bool:
80+
"""
81+
Get pending MIG mode for the device.
82+
83+
For Ampere™ or newer fully supported devices.
84+
85+
Changing MIG modes may require device unbind or reset. The "pending" MIG
86+
mode refers to the target mode following the next activation trigger.
87+
88+
If the device is not a MIG device, returns `False`.
89+
90+
Returns
91+
-------
92+
bool
93+
`True` if pending MIG mode is enabled.
94+
"""
95+
if not self.is_mig_device:
96+
return False
97+
98+
_, pending = nvml.device_get_mig_mode(self._device._handle)
99+
return pending == nvml.EnableState.FEATURE_ENABLED
100+
101+
def get_device_count(self) -> int:
102+
"""
103+
Get the maximum number of MIG devices that can exist under this device.
104+
105+
Returns zero if MIG is not supported or enabled.
106+
107+
For Ampere™ or newer fully supported devices.
108+
109+
Returns
110+
-------
111+
int
112+
The number of MIG devices (compute instances) on this GPU.
113+
"""
114+
return nvml.device_get_max_mig_device_count(self._device._handle)
115+
116+
def get_parent_device(self) -> Device:
117+
"""
118+
For MIG devices, get the parent GPU device.
119+
120+
For Ampere™ or newer fully supported devices.
121+
122+
Returns
123+
-------
124+
Device
125+
The parent GPU device for this MIG device.
126+
"""
127+
parent_handle = nvml.device_get_handle_from_mig_device_handle(self._handle)
128+
parent_device = Device.__new__(Device)
129+
parent_device._handle = parent_handle
130+
return parent_device
131+
132+
def get_device_by_index(self, index: int) -> Device:
133+
"""
134+
Get MIG device for the given index under its parent device.
135+
136+
If the compute instance is destroyed either explicitly or by destroying,
137+
resetting or unbinding the parent GPU instance or the GPU device itself
138+
the MIG device handle would remain invalid and must be requested again
139+
using this API. Handles may be reused and their properties can change in
140+
the process.
141+
142+
For Ampere™ or newer fully supported devices.
143+
144+
Parameters
145+
----------
146+
index: int
147+
The index of the MIG device (compute instance) to retrieve. Must be
148+
between 0 and the value returned by `get_device_count() - 1`.
149+
150+
Returns
151+
-------
152+
Device
153+
The MIG device corresponding to the given index.
154+
"""
155+
mig_device_handle = nvml.device_get_mig_device_handle_by_index(self._device._handle, index)
156+
mig_device = Device.__new__(Device)
157+
mig_device._handle = mig_device_handle
158+
return mig_device
159+
160+
def get_all_devices(self) -> Iterable[Device]:
161+
"""
162+
Get all MIG devices under its parent device.
163+
164+
If the compute instance is destroyed either explicitly or by destroying,
165+
resetting or unbinding the parent GPU instance or the GPU device itself
166+
the MIG device handle would remain invalid and must be requested again
167+
using this API. Handles may be reused and their properties can change in
168+
the process.
169+
170+
For Ampere™ or newer fully supported devices.
171+
172+
Returns
173+
-------
174+
list[Device]
175+
A list of all MIG devices corresponding to this GPU.
176+
"""
177+
for i in range(self.get_device_count()):
178+
yield self.get_device_by_index(i)

0 commit comments

Comments
 (0)