Skip to content

Commit eebd4cf

Browse files
leofang and claude committed
Polish green context API: docs, error handling, simplification
- dev.create_context raises ValueError (not NotImplementedError) when options or resources are missing.
- Cache version checks (_check_green_ctx_support, _check_workqueue_support) at module level; raise ValueError instead of NotImplementedError.
- Simplify _device_resources.pyx: merge _as_uint and _count_to_sm_count into _to_sm_count; inline unsigned int casts for coscheduled params.
- Add green context classes to api.rst (Context, ContextOptions, DeviceResources, SMResource, SMResourceOptions, WorkqueueResource, WorkqueueResourceOptions).
- Update all docstrings to NumPy style with Attributes/Parameters/Returns sections matching the existing codebase convention.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 340506e commit eebd4cf

5 files changed

Lines changed: 149 additions & 64 deletions

File tree

cuda_core/cuda/core/_device.pyx

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1302,11 +1302,15 @@ class Device:
13021302
cdef GreenCtxHandle h_green
13031303

13041304
if options is None:
1305-
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/189")
1305+
raise ValueError(
1306+
"options with device resources must be provided to create a green context"
1307+
)
13061308

13071309
options = check_or_create_options(ContextOptions, options, "Context options")
13081310
if options.resources is None:
1309-
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/189")
1311+
raise ValueError(
1312+
"ContextOptions.resources must be provided to create a green context"
1313+
)
13101314

13111315
resources = tuple(options.resources)
13121316
if len(resources) == 0:
@@ -1334,9 +1338,9 @@ class Device:
13341338

13351339
h_green = create_green_ctx_handle(
13361340
c_resources,
1337-
<unsigned int>n_resources,
1338-
<cydriver.CUdevice>self._device_id,
1339-
<unsigned int>cydriver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM,
1341+
<unsigned int>(n_resources),
1342+
<cydriver.CUdevice>(self._device_id),
1343+
<unsigned int>(cydriver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM),
13401344
)
13411345
if h_green.get() == NULL:
13421346
HANDLE_RETURN(get_last_error())

cuda_core/cuda/core/_device_resources.pyx

Lines changed: 127 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -27,45 +27,85 @@ __all__ = [
2727
]
2828

2929

30+
# Module-level cached version checks (trinary: 0=unchecked, 1=supported, -1=unsupported)
31+
cdef int _green_ctx_checked = 0
32+
cdef int _workqueue_checked = 0
33+
cdef str _green_ctx_err_msg = ""
34+
cdef str _workqueue_err_msg = ""
35+
36+
3037
cdef inline int _check_green_ctx_support() except?-1:
38+
global _green_ctx_checked, _green_ctx_err_msg
39+
if _green_ctx_checked == 1:
40+
return 0
41+
if _green_ctx_checked == -1:
42+
raise ValueError(_green_ctx_err_msg)
3143
cdef tuple drv = cy_driver_version()
3244
cdef tuple bind = cy_binding_version()
3345
if drv < (12, 4, 0):
34-
raise NotImplementedError(
35-
"Green context support requires CUDA driver 12.4 or newer. "
36-
f"Using driver version {'.'.join(map(str, drv))}"
46+
_green_ctx_err_msg = (
47+
"Green context support requires CUDA driver 12.4 or newer "
48+
f"(current driver: {'.'.join(map(str, drv))})"
3749
)
50+
_green_ctx_checked = -1
51+
raise ValueError(_green_ctx_err_msg)
3852
if bind < (12, 4, 0):
39-
raise NotImplementedError(
40-
"Green context support requires cuda.bindings 12.4 or newer. "
41-
f"Using cuda.bindings version {'.'.join(map(str, bind))}"
53+
_green_ctx_err_msg = (
54+
"Green context support requires cuda.bindings 12.4 or newer "
55+
f"(current bindings: {'.'.join(map(str, bind))})"
4256
)
57+
_green_ctx_checked = -1
58+
raise ValueError(_green_ctx_err_msg)
59+
_green_ctx_checked = 1
4360
return 0
4461

4562

4663
cdef inline int _check_workqueue_support() except?-1:
64+
global _workqueue_checked, _workqueue_err_msg
65+
if _workqueue_checked == 1:
66+
return 0
67+
if _workqueue_checked == -1:
68+
raise ValueError(_workqueue_err_msg)
4769
cdef tuple drv = cy_driver_version()
4870
cdef tuple bind = cy_binding_version()
4971
if drv < (13, 1, 0):
50-
raise NotImplementedError(
51-
"WorkqueueResource requires CUDA driver 13.1 or newer. "
52-
f"Using driver version {'.'.join(map(str, drv))}"
72+
_workqueue_err_msg = (
73+
"WorkqueueResource requires CUDA driver 13.1 or newer "
74+
f"(current driver: {'.'.join(map(str, drv))})"
5375
)
76+
_workqueue_checked = -1
77+
raise ValueError(_workqueue_err_msg)
5478
if bind < (13, 1, 0):
55-
raise NotImplementedError(
56-
"WorkqueueResource requires cuda.bindings 13.1 or newer. "
57-
f"Using cuda.bindings version {'.'.join(map(str, bind))}"
79+
_workqueue_err_msg = (
80+
"WorkqueueResource requires cuda.bindings 13.1 or newer "
81+
f"(current bindings: {'.'.join(map(str, bind))})"
5882
)
83+
_workqueue_checked = -1
84+
raise ValueError(_workqueue_err_msg)
85+
_workqueue_checked = 1
5986
return 0
6087

6188

6289
@dataclass
6390
cdef class SMResourceOptions:
64-
"""Options for :meth:`SMResource.split`.
65-
66-
``count`` determines the number of requested groups. Scalar ``count`` or
67-
``None`` creates one group; a sequence creates ``len(count)`` groups. Other
68-
sequence fields must match the length of ``count``.
91+
"""Customizable :obj:`SMResource.split` options.
92+
93+
Each field accepts a scalar (for a single group) or a ``Sequence``
94+
(for multiple groups). ``count`` drives the number of groups; other
95+
``Sequence`` fields must match its length.
96+
97+
Attributes
98+
----------
99+
count : int or Sequence[int], optional
100+
Requested SM count per group. ``None`` means discovery mode
101+
(auto-detect). (Default to ``None``)
102+
coscheduled_sm_count : int or Sequence[int], optional
103+
Minimum number of SMs guaranteed to be co-scheduled in each
104+
group. (Default to ``None``)
105+
preferred_coscheduled_sm_count : int or Sequence[int], optional
106+
Preferred co-scheduled SM count; the driver tries to satisfy
107+
this but may fall back to ``coscheduled_sm_count``.
108+
(Default to ``None``)
69109
"""
70110

71111
count: int | SequenceABC | None = None
@@ -75,7 +115,14 @@ cdef class SMResourceOptions:
75115

76116
@dataclass
77117
cdef class WorkqueueResourceOptions:
78-
"""Options for :meth:`WorkqueueResource.configure`."""
118+
"""Customizable :obj:`WorkqueueResource.configure` options.
119+
120+
Attributes
121+
----------
122+
sharing_scope : str, optional
123+
Workqueue sharing scope. Accepted values: ``"device_ctx"``
124+
or ``"green_ctx_balanced"``. (Default to ``None``)
125+
"""
79126

80127
sharing_scope: str | None = None
81128

@@ -134,18 +181,13 @@ cdef object _broadcast_field(object value, int n_groups):
134181
return [value] * n_groups
135182

136183

137-
cdef inline unsigned int _as_uint(object value, str field_name) except? 0:
138-
if not isinstance(value, int):
139-
raise TypeError(f"{field_name} must be an int or None, got {type(value)}")
140-
if value < 0:
141-
raise ValueError(f"{field_name} must be non-negative")
142-
return <unsigned int>value
143-
144-
145-
cdef inline unsigned int _count_to_sm_count(object value) except? 0:
184+
cdef inline unsigned int _to_sm_count(object value) except? 0:
185+
"""Convert a count value to unsigned int. None maps to 0 (discovery)."""
146186
if value is None:
147187
return 0
148-
return _as_uint(value, "count")
188+
if value < 0:
189+
raise ValueError(f"count must be non-negative, got {value}")
190+
return <unsigned int>(value)
149191

150192

151193
cdef inline bint _can_use_structured_sm_split():
@@ -196,7 +238,7 @@ cdef object _resolve_split_by_count_request(SMResourceOptions options):
196238
"use CUDA 13.1 or newer for per-group counts"
197239
)
198240

199-
min_count = _count_to_sm_count(first)
241+
min_count = _to_sm_count(first)
200242
return n_groups, min_count
201243

202244

@@ -213,13 +255,11 @@ IF CUDA_CORE_BUILD_MAJOR >= 13:
213255

214256
for i in range(n_groups):
215257
memset(&params[i], 0, sizeof(cydriver.CU_DEV_SM_RESOURCE_GROUP_PARAMS))
216-
params[i].smCount = _count_to_sm_count(counts[i])
258+
params[i].smCount = _to_sm_count(counts[i])
217259
if coscheduled[i] is not None:
218-
params[i].coscheduledSmCount = _as_uint(coscheduled[i], "coscheduled_sm_count")
260+
params[i].coscheduledSmCount = <unsigned int>(coscheduled[i])
219261
if preferred[i] is not None:
220-
params[i].preferredCoscheduledSmCount = _as_uint(
221-
preferred[i], "preferred_coscheduled_sm_count"
222-
)
262+
params[i].preferredCoscheduledSmCount = <unsigned int>(preferred[i])
223263
params[i].flags = 0
224264
return 0
225265

@@ -253,7 +293,7 @@ IF CUDA_CORE_BUILD_MAJOR >= 13:
253293
with nogil:
254294
HANDLE_RETURN(cydriver.cuDevSmResourceSplit(
255295
result,
256-
<unsigned int>n_groups,
296+
<unsigned int>(n_groups),
257297
&sm._resource,
258298
&remaining,
259299
0,
@@ -285,8 +325,8 @@ ELSE:
285325

286326
cdef object _split_with_count_api(SMResource sm, SMResourceOptions options, bint dry_run):
287327
cdef object request = _resolve_split_by_count_request(options)
288-
cdef unsigned int nb_groups = <unsigned int>request[0]
289-
cdef unsigned int min_count = <unsigned int>request[1]
328+
cdef unsigned int nb_groups = <unsigned int>(request[0])
329+
cdef unsigned int min_count = <unsigned int>(request[1])
290330
cdef unsigned int actual_groups = nb_groups
291331
cdef cydriver.CUdevResource* result = NULL
292332
cdef cydriver.CUdevResource remaining
@@ -328,7 +368,7 @@ cdef inline unsigned int _sm_resource_granularity(int device_id) except? 0:
328368
HANDLE_RETURN(cydriver.cuDeviceGetAttribute(
329369
&major,
330370
cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
331-
<cydriver.CUdevice>device_id,
371+
<cydriver.CUdevice>(device_id),
332372
))
333373
if major >= 9:
334374
return 8
@@ -342,7 +382,11 @@ cdef inline unsigned int _fallback_if_zero(unsigned int value, unsigned int fall
342382

343383

344384
cdef class SMResource:
345-
"""SM resource queried from a device. Not user-constructible."""
385+
"""Represent an SM (streaming multiprocessor) resource partition.
386+
387+
Instances are returned by :obj:`DeviceResources.sm` or
388+
:meth:`SMResource.split` and cannot be instantiated directly.
389+
"""
346390

347391
def __init__(self, *args, **kwargs):
348392
raise RuntimeError(
@@ -391,7 +435,7 @@ cdef class SMResource:
391435
@property
392436
def handle(self) -> int:
393437
"""Return the address of the underlying ``CUdevResource`` struct."""
394-
return <intptr_t>&self._resource
438+
return <intptr_t>(&self._resource)
395439

396440
@property
397441
def sm_count(self) -> int:
@@ -414,7 +458,22 @@ cdef class SMResource:
414458
return self._flags
415459

416460
def split(self, options not None, *, bint dry_run=False):
417-
"""Split this SM resource into groups plus a remainder."""
461+
"""Split this SM resource into groups and a remainder.
462+
463+
Parameters
464+
----------
465+
options : :obj:`SMResourceOptions`
466+
Split configuration (count, co-scheduling constraints).
467+
dry_run : bool, optional
468+
If ``True``, return filled-in metadata without creating
469+
usable resource objects. (Default to ``False``)
470+
471+
Returns
472+
-------
473+
tuple[list[:obj:`SMResource`], :obj:`SMResource`]
474+
``(groups, remainder)`` where each group holds a disjoint
475+
SM partition and *remainder* holds any unassigned SMs.
476+
"""
418477
cdef SMResourceOptions opts = check_or_create_options(
419478
SMResourceOptions, options, "SM resource options"
420479
)
@@ -427,7 +486,13 @@ cdef class SMResource:
427486

428487

429488
cdef class WorkqueueResource:
430-
"""Workqueue resource. Not user-constructible."""
489+
"""Represent a workqueue resource for a device or green context.
490+
491+
Merges ``CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG`` and
492+
``CU_DEV_RESOURCE_TYPE_WORKQUEUE`` under one user-facing type.
493+
Instances are returned by :obj:`DeviceResources.workqueue` and
494+
cannot be instantiated directly.
495+
"""
431496

432497
def __init__(self, *args, **kwargs):
433498
raise RuntimeError(
@@ -448,10 +513,16 @@ cdef class WorkqueueResource:
448513
@property
449514
def handle(self) -> int:
450515
"""Return the address of the underlying config ``CUdevResource`` struct."""
451-
return <intptr_t>&self._wq_config_resource
516+
return <intptr_t>(&self._wq_config_resource)
452517

453518
def configure(self, options not None):
454-
"""Configure the workqueue resource in place."""
519+
"""Configure the workqueue resource in place.
520+
521+
Parameters
522+
----------
523+
options : :obj:`WorkqueueResourceOptions`
524+
Configuration options (sharing scope, etc.).
525+
"""
455526
cdef WorkqueueResourceOptions opts = check_or_create_options(
456527
WorkqueueResourceOptions, options, "Workqueue resource options"
457528
)
@@ -481,11 +552,14 @@ cdef class WorkqueueResource:
481552

482553

483554
cdef class DeviceResources:
484-
"""Namespace for hardware resource query. Not user-constructible.
555+
"""Namespace for hardware resource queries.
556+
557+
When obtained via :obj:`Device.resources`, queries return full device
558+
resources. When obtained via :obj:`Context.resources` or
559+
:obj:`Stream.resources`, queries return the resources provisioned for
560+
that context.
485561
486-
When obtained via ``dev.resources``, queries return full device resources.
487-
When obtained via ``ctx.resources``, queries return the resources
488-
provisioned for that context.
562+
This class cannot be instantiated directly.
489563
"""
490564

491565
def __init__(self, *args, **kwargs):
@@ -525,14 +599,14 @@ cdef class DeviceResources:
525599
))
526600
else:
527601
HANDLE_RETURN(cydriver.cuDeviceGetDevResource(
528-
<cydriver.CUdevice>self._device_id, res,
602+
<cydriver.CUdevice>(self._device_id), res,
529603
cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM,
530604
))
531605
return 0
532606

533607
@property
534608
def sm(self) -> SMResource:
535-
"""Query SM resources."""
609+
"""Return the :obj:`SMResource` for this device or context."""
536610
_check_green_ctx_support()
537611
cdef cydriver.CUdevResource res
538612
with nogil:
@@ -541,7 +615,7 @@ cdef class DeviceResources:
541615

542616
@property
543617
def workqueue(self) -> WorkqueueResource:
544-
"""Query workqueue resources."""
618+
"""Return the :obj:`WorkqueueResource` for this device or context."""
545619
_check_green_ctx_support()
546620
_check_workqueue_support()
547621
cdef cydriver.CUdevResource _wq_config
@@ -581,12 +655,12 @@ cdef class DeviceResources:
581655
# Device-level query
582656
with nogil:
583657
HANDLE_RETURN(cydriver.cuDeviceGetDevResource(
584-
<cydriver.CUdevice>self._device_id,
658+
<cydriver.CUdevice>(self._device_id),
585659
&_wq_config,
586660
cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG,
587661
))
588662
HANDLE_RETURN(cydriver.cuDeviceGetDevResource(
589-
<cydriver.CUdevice>self._device_id,
663+
<cydriver.CUdevice>(self._device_id),
590664
&_wq,
591665
cydriver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_WORKQUEUE,
592666
))

cuda_core/docs/source/api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,18 @@ Devices and execution
2626

2727
Stream
2828
Event
29+
Context
30+
SMResource
31+
WorkqueueResource
2932

3033
:template: dataclass.rst
3134

3235
StreamOptions
3336
EventOptions
3437
LaunchConfig
38+
ContextOptions
39+
SMResourceOptions
40+
WorkqueueResourceOptions
3541

3642
.. data:: LEGACY_DEFAULT_STREAM
3743

cuda_core/docs/source/api_private.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ CUDA runtime
3131
:template: autosummary/cyclass.rst
3232

3333
_device.DeviceProperties
34+
_device_resources.DeviceResources
3435
_memory._ipc.IPCAllocationHandle
3536
_memory._ipc.IPCBufferDescriptor
3637

0 commit comments

Comments
 (0)