Skip to content

Commit 10de998

Browse files
committed
review(cuda.core): address PR #1775 feedback
- Drop defensive cuInit retry in _query_memory_attrs (Andy): we don't auto-init CUDA elsewhere; let HANDLE_RETURN propagate the error.
- Use checked Cython cast `<Buffer?>t` in _coerce_buffer_targets (Leo) in place of the manual isinstance loop.
- Introduce *Options dataclasses (AdviseOptions, PrefetchOptions, DiscardOptions, DiscardPrefetchOptions) per cuda.core convention (Leo). Functions accept None or the matching dataclass; tests updated to match the new error message.
1 parent e0c782a commit 10de998

5 files changed

Lines changed: 89 additions & 31 deletions

File tree

cuda_core/cuda/core/_memory/_buffer.pyx

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ else:
3636

3737
from cuda.core._dlpack import classify_dl_device, make_py_capsule
3838
from cuda.core._utils.cuda_utils import driver
39-
from cuda.core._device import Device
4039

4140

4241
# =============================================================================
@@ -449,11 +448,6 @@ cdef inline int _query_memory_attrs(
449448

450449
cdef cydriver.CUresult ret
451450
ret = cydriver.cuPointerGetAttributes(3, attrs, <void**>vals, ptr)
452-
if ret == cydriver.CUresult.CUDA_ERROR_NOT_INITIALIZED:
453-
with cython.gil:
454-
# Device class handles the cuInit call internally
455-
Device()
456-
ret = cydriver.cuPointerGetAttributes(3, attrs, <void**>vals, ptr)
457451
HANDLE_RETURN(ret)
458452

459453
# TODO: HMM/ATS-enabled sysmem should also report is_managed=True; the

cuda_core/cuda/core/_memory/_managed_memory_ops.pyx

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
1515

1616
from cuda.core._utils.cuda_utils import driver
1717
from cuda.core._memory._managed_location import _coerce_location
18+
from cuda.core._memory._managed_memory_options import (
19+
AdviseOptions,
20+
DiscardOptions,
21+
DiscardPrefetchOptions,
22+
PrefetchOptions,
23+
)
1824

1925

2026
cdef dict _MANAGED_ADVICE_ALIASES = {
@@ -89,6 +95,7 @@ cdef void _require_managed_buffer(Buffer self, str what):
8995

9096

9197
cdef tuple _coerce_buffer_targets(object targets, str what):
98+
cdef Buffer buf
9299
cdef list out
93100
if isinstance(targets, Buffer):
94101
return (<Buffer>targets,)
@@ -97,11 +104,8 @@ cdef tuple _coerce_buffer_targets(object targets, str what):
97104
raise ValueError(f"{what}: empty targets sequence")
98105
out = []
99106
for t in targets:
100-
if not isinstance(t, Buffer):
101-
raise TypeError(
102-
f"{what}: each target must be a Buffer, got {type(t).__name__}"
103-
)
104-
out.append(t)
107+
buf = <Buffer?>t
108+
out.append(buf)
105109
return tuple(out)
106110
raise TypeError(
107111
f"{what}: targets must be a Buffer or sequence of Buffer, "
@@ -167,8 +171,9 @@ def discard(
167171
One or more managed allocations to discard. Their resident pages
168172
are released without prefetching new contents; subsequent access
169173
is satisfied by lazy migration.
170-
options : None
171-
Reserved for future per-call flags. Must be ``None``.
174+
options : :class:`DiscardOptions`, optional
175+
Reserved for future per-call flags. ``None`` (default) and
176+
``DiscardOptions()`` are equivalent.
172177
stream : :class:`~_stream.Stream` | :class:`~graph.GraphBuilder`
173178
Stream for the asynchronous discard (keyword-only).
174179
@@ -177,9 +182,10 @@ def discard(
177182
NotImplementedError
178183
On a CUDA 12 build of ``cuda.core``. Discard requires CUDA 13+.
179184
"""
180-
if options is not None:
185+
if options is not None and not isinstance(options, DiscardOptions):
181186
raise TypeError(
182-
f"discard options must be None (reserved); got {type(options).__name__}"
187+
"discard options must be a DiscardOptions instance or None, "
188+
f"got {type(options).__name__}"
183189
)
184190
cdef tuple bufs = _coerce_buffer_targets(targets, "discard")
185191
cdef Stream s = Stream_accept(stream)
@@ -246,12 +252,14 @@ def advise(
246252
location; ignored (may be ``None``) for ``set_read_mostly``,
247253
``unset_read_mostly``, and ``unset_preferred_location``. A sequence
248254
must match ``len(targets)``.
249-
options : None
250-
Reserved for future per-call flags. Must be ``None``.
255+
options : :class:`AdviseOptions`, optional
256+
Reserved for future per-call flags. ``None`` (default) and
257+
``AdviseOptions()`` are equivalent.
251258
"""
252-
if options is not None:
259+
if options is not None and not isinstance(options, AdviseOptions):
253260
raise TypeError(
254-
f"advise options must be None (reserved); got {type(options).__name__}"
261+
"advise options must be an AdviseOptions instance or None, "
262+
f"got {type(options).__name__}"
255263
)
256264
cdef str advice_name
257265
cdef object advice_value
@@ -317,8 +325,9 @@ def prefetch(
317325
Target location(s). A single location applies to all targets; a
318326
sequence must match ``len(targets)``. ``Device`` and ``int`` values
319327
are coerced to :class:`Location` (``-1`` maps to host).
320-
options : None
321-
Reserved for future per-call flags. Must be ``None``.
328+
options : :class:`PrefetchOptions`, optional
329+
Reserved for future per-call flags. ``None`` (default) and
330+
``PrefetchOptions()`` are equivalent.
322331
stream : :class:`~_stream.Stream` | :class:`~graph.GraphBuilder`
323332
Stream for the asynchronous prefetch (keyword-only).
324333
@@ -327,9 +336,10 @@ def prefetch(
327336
NotImplementedError
328337
If ``len(targets) > 1`` on a CUDA 12 build of ``cuda.core``.
329338
"""
330-
if options is not None:
339+
if options is not None and not isinstance(options, PrefetchOptions):
331340
raise TypeError(
332-
f"prefetch options must be None (reserved); got {type(options).__name__}"
341+
"prefetch options must be a PrefetchOptions instance or None, "
342+
f"got {type(options).__name__}"
333343
)
334344
cdef tuple bufs = _coerce_buffer_targets(targets, "prefetch")
335345
cdef Py_ssize_t n = len(bufs)
@@ -420,8 +430,9 @@ def discard_prefetch(
420430
location : :class:`Location` | :obj:`~_device.Device` | int | Sequence[...]
421431
Target location(s). A single location applies to all targets;
422432
a sequence must match ``len(targets)``.
423-
options : None
424-
Reserved for future per-call flags. Must be ``None``.
433+
options : :class:`DiscardPrefetchOptions`, optional
434+
Reserved for future per-call flags. ``None`` (default) and
435+
``DiscardPrefetchOptions()`` are equivalent.
425436
stream : :class:`~_stream.Stream` | :class:`~graph.GraphBuilder`
426437
Stream for the asynchronous operation (keyword-only).
427438
@@ -431,10 +442,10 @@ def discard_prefetch(
431442
On a CUDA 12 build of ``cuda.core``. Discard-and-prefetch
432443
requires CUDA 13+.
433444
"""
434-
if options is not None:
445+
if options is not None and not isinstance(options, DiscardPrefetchOptions):
435446
raise TypeError(
436-
f"discard_prefetch options must be None (reserved); "
437-
f"got {type(options).__name__}"
447+
"discard_prefetch options must be a DiscardPrefetchOptions "
448+
f"instance or None, got {type(options).__name__}"
438449
)
439450
cdef tuple bufs = _coerce_buffer_targets(targets, "discard_prefetch")
440451
cdef Py_ssize_t n = len(bufs)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass
7+
8+
9+
@dataclass(frozen=True)
10+
class AdviseOptions:
11+
"""Per-call options for :func:`cuda.core.utils.advise`.
12+
13+
Reserved for future advise flags. Currently has no fields; pass
14+
``AdviseOptions()`` or ``None`` to use driver defaults.
15+
"""
16+
17+
18+
@dataclass(frozen=True)
19+
class PrefetchOptions:
20+
"""Per-call options for :func:`cuda.core.utils.prefetch`.
21+
22+
Reserved for future prefetch flags. Currently has no fields; pass
23+
``PrefetchOptions()`` or ``None`` to use driver defaults.
24+
"""
25+
26+
27+
@dataclass(frozen=True)
28+
class DiscardOptions:
29+
"""Per-call options for :func:`cuda.core.utils.discard`.
30+
31+
Reserved for future discard flags. Currently has no fields; pass
32+
``DiscardOptions()`` or ``None`` to use driver defaults.
33+
"""
34+
35+
36+
@dataclass(frozen=True)
37+
class DiscardPrefetchOptions:
38+
"""Per-call options for :func:`cuda.core.utils.discard_prefetch`.
39+
40+
Reserved for future discard-and-prefetch flags. Currently has no
41+
fields; pass ``DiscardPrefetchOptions()`` or ``None`` to use driver
42+
defaults.
43+
"""

cuda_core/cuda/core/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,20 @@
44

55
from cuda.core._memory._managed_location import Location
66
from cuda.core._memory._managed_memory_ops import advise, discard, discard_prefetch, prefetch
7+
from cuda.core._memory._managed_memory_options import (
8+
AdviseOptions,
9+
DiscardOptions,
10+
DiscardPrefetchOptions,
11+
PrefetchOptions,
12+
)
713
from cuda.core._memoryview import StridedMemoryView, args_viewable_as_strided_memory
814

915
__all__ = [
16+
"AdviseOptions",
17+
"DiscardOptions",
18+
"DiscardPrefetchOptions",
1019
"Location",
20+
"PrefetchOptions",
1121
"StridedMemoryView",
1222
"advise",
1323
"args_viewable_as_strided_memory",

cuda_core/tests/test_memory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,7 +2095,7 @@ def test_options_must_be_none(self, init_cuda):
20952095
mr = create_managed_memory_resource_or_skip()
20962096
buf = mr.allocate(_MANAGED_TEST_ALLOCATION_SIZE)
20972097
stream = device.create_stream()
2098-
with pytest.raises(TypeError, match="must be None"):
2098+
with pytest.raises(TypeError, match="must be a .*Options instance or None"):
20992099
prefetch(buf, Location.host(), options={}, stream=stream)
21002100
buf.close()
21012101

@@ -2156,7 +2156,7 @@ def test_options_must_be_none(self, init_cuda):
21562156
mr = create_managed_memory_resource_or_skip()
21572157
buf = mr.allocate(_MANAGED_TEST_ALLOCATION_SIZE)
21582158
stream = device.create_stream()
2159-
with pytest.raises(TypeError, match="must be None"):
2159+
with pytest.raises(TypeError, match="must be a .*Options instance or None"):
21602160
discard(buf, options={}, stream=stream)
21612161
buf.close()
21622162

@@ -2284,6 +2284,6 @@ def test_options_must_be_none(self, init_cuda):
22842284
_skip_if_managed_allocation_unsupported(device)
22852285
device.set_current()
22862286
buf = DummyUnifiedMemoryResource(device).allocate(_MANAGED_TEST_ALLOCATION_SIZE)
2287-
with pytest.raises(TypeError, match="must be None"):
2287+
with pytest.raises(TypeError, match="must be a .*Options instance or None"):
22882288
advise(buf, "set_read_mostly", options={})
22892289
buf.close()

0 commit comments

Comments (0)