Skip to content

Commit bede674

Browse files
committed
feat(cuda.core): add ManagedBuffer subclass + Host location
Land Andy's ManagedBuffer + Device/Host design (review #3976251223, #3164213789). The free-function shape introduced earlier in this PR is preserved; ManagedBuffer methods delegate into it, so existing call sites keep working. ManagedBuffer - Subclass of Buffer returned by ManagedMemoryResource.allocate, also constructable from an external pointer via ManagedBuffer.from_handle. - Property-style advice API: - read_mostly (bool, driver-backed get/set) - preferred_location (Device | Host | None, get/set; None unsets) - accessed_by (live AccessedBySet view: __contains__/__iter__/len query the driver, add()/discard() issue advice; setter diffs and advises only the deltas) - Instance methods prefetch / discard / discard_prefetch delegate to the matching cuda.core.utils functions. Host - New top-level class symmetric to Device. Host(), Host(numa_id=N), Host.numa_current(). Replaces Location.host()/host_numa()/etc. Location -> Device|Host|int - Drop the public Location dataclass and its classmethod constructors. - _coerce_location now accepts Device | Host | int | None and produces an internal _LocSpec record; advise/prefetch/discard/discard_prefetch signatures and docstrings updated accordingly. - int still accepted for ergonomic compatibility (-1 = host, >=0 = device ordinal). Plumbing - Buffer_from_deviceptr_handle takes an optional `cls` parameter so the pool allocator can materialize Buffer subclasses; _MP_allocate threads the same parameter through; ManagedMemoryResource.allocate passes ManagedBuffer. Tests - TestHost replaces TestLocation; TestLocationCoerce adapted to the new coerce signature. New TestManagedBuffer covers from_handle, isinstance(allocate(), ManagedBuffer), read_mostly/preferred_location/ accessed_by roundtrips, and instance methods. Property tests use external (cuMemAllocManaged) backing wrapped via from_handle, since some driver/device combinations decline cuMemAdvise on pool-allocated managed memory. 
- Use cuDeviceGetCount in AccessedBySet._query so the read path doesn't pull in NVML. Docs - 1.0.0 notes describe Host, ManagedBuffer, the property API, and the Device/Host location inputs. api.rst lists Host, ManagedBuffer, and the *Options dataclasses; Location is removed.
1 parent c2a9662 commit bede674

16 files changed

Lines changed: 610 additions & 176 deletions

cuda_core/cuda/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _import_versioned_module():
3232
from cuda.core._device import Device
3333
from cuda.core._event import Event, EventOptions
3434
from cuda.core._graphics import GraphicsResource
35+
from cuda.core._host import Host
3536
from cuda.core._launch_config import LaunchConfig
3637
from cuda.core._launcher import launch
3738
from cuda.core._linker import Linker, LinkerOptions
@@ -41,6 +42,7 @@ def _import_versioned_module():
4142
DeviceMemoryResourceOptions,
4243
GraphMemoryResource,
4344
LegacyPinnedMemoryResource,
45+
ManagedBuffer,
4446
ManagedMemoryResource,
4547
ManagedMemoryResourceOptions,
4648
MemoryResource,

cuda_core/cuda/core/_host.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from dataclasses import dataclass
7+
8+
9+
@dataclass(frozen=True)
class Host:
    """Host (CPU) location for managed-memory operations.

    Use one of the three forms:

    * ``Host()`` — generic host (any NUMA node).
    * ``Host(numa_id=N)`` — specific NUMA node ``N``.
    * ``Host.numa_current()`` — NUMA node of the calling thread.

    ``Host`` is the symmetric counterpart of :class:`~cuda.core.Device`
    for managed-memory `prefetch`, `advise`, and `discard_prefetch`
    targets. Pass either a ``Device`` or a ``Host`` to those operations
    and to ``ManagedBuffer.preferred_location`` / ``accessed_by``.
    """

    # NUMA node ordinal, or None for "any NUMA node".
    numa_id: int | None = None
    # True only for instances built via Host.numa_current().
    is_numa_current: bool = False

    def __post_init__(self) -> None:
        """Validate the field combination.

        Raises
        ------
        ValueError
            If ``is_numa_current`` is combined with an explicit ``numa_id``,
            or if ``numa_id`` is not a non-negative ``int``.
        """
        if self.is_numa_current and self.numa_id is not None:
            raise ValueError("Host.numa_current() cannot have an explicit numa_id")
        if self.numa_id is not None and (
            # bool is a subclass of int; reject it explicitly so that
            # Host(numa_id=True) does not silently mean NUMA node 1.
            isinstance(self.numa_id, bool)
            or not isinstance(self.numa_id, int)
            or self.numa_id < 0
        ):
            raise ValueError(f"numa_id must be a non-negative int, got {self.numa_id!r}")

    @classmethod
    def numa_current(cls) -> Host:
        """Construct a ``Host`` referring to the calling thread's NUMA node."""
        return cls(is_numa_current=True)

    def __repr__(self) -> str:
        # Render each of the three construction forms back to its spelling.
        if self.is_numa_current:
            return "Host.numa_current()"
        if self.numa_id is None:
            return "Host()"
        return f"Host(numa_id={self.numa_id})"

cuda_core/cuda/core/_memory/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._graph_memory_resource import *
88
from ._ipc import *
99
from ._legacy import *
10+
from ._managed_buffer import ManagedBuffer
1011
from ._managed_memory_resource import *
1112
from ._pinned_memory_resource import *
1213
from ._virtual_memory_resource import *

cuda_core/cuda/core/_memory/_buffer.pxd

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,15 @@ cdef class MemoryResource:
3232
pass
3333

3434

35-
# Helper function to create a Buffer from a DevicePtrHandle
35+
# Helper function to create a Buffer from a DevicePtrHandle.
36+
# `cls` lets callers materialize Buffer subclasses (e.g. ManagedBuffer for
37+
# managed-memory allocations); defaults to Buffer.
3638
cdef Buffer Buffer_from_deviceptr_handle(
3739
DevicePtrHandle h_ptr,
3840
size_t size,
3941
MemoryResource mr,
40-
object ipc_descriptor = *
42+
object ipc_descriptor = *,
43+
type cls = *,
4144
)
4245

4346
# Memory attribute query helpers (used by _managed_memory_ops)

cuda_core/cuda/core/_memory/_buffer.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -548,14 +548,15 @@ cdef class MemoryResource:
548548

549549
# Buffer Implementation Helpers
550550
# -----------------------------
551-
cdef inline Buffer Buffer_from_deviceptr_handle(
551+
cdef Buffer Buffer_from_deviceptr_handle(
552552
DevicePtrHandle h_ptr,
553553
size_t size,
554554
MemoryResource mr,
555-
object ipc_descriptor = None
555+
object ipc_descriptor = None,
556+
type cls = Buffer,
556557
):
557-
"""Create a Buffer from an existing DevicePtrHandle."""
558-
cdef Buffer buf = Buffer.__new__(Buffer)
558+
"""Create a Buffer (or subclass instance) from an existing DevicePtrHandle."""
559+
cdef Buffer buf = cls.__new__(cls)
559560
buf._h_ptr = h_ptr
560561
buf._size = size
561562
buf._memory_resource = mr
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from typing import TYPE_CHECKING
7+
8+
from cuda.core._device import Device
9+
from cuda.core._host import Host
10+
from cuda.core._memory._buffer import Buffer
11+
from cuda.core._utils.cuda_utils import driver, handle_return
12+
13+
if TYPE_CHECKING:
14+
from cuda.core._stream import Stream
15+
from cuda.core.graph import GraphBuilder
16+
17+
18+
# Byte width of the 32-bit integer the driver writes for scalar
# cuMemRangeGetAttribute queries.
_INT_SIZE = 4


def _get_int_attr(buf: Buffer, attribute) -> int:
    """Return a single-int ``cuMemRangeGetAttribute`` value for ``buf``'s range."""
    return handle_return(driver.cuMemRangeGetAttribute(_INT_SIZE, attribute, buf.handle, buf.size))
23+
24+
25+
class AccessedBySet:
    """Live, driver-backed set-like view of ``set_accessed_by`` advice.

    Membership tests (``in``), iteration, and ``len()`` query the driver
    through ``cuMemRangeGetAttribute``; ``add`` and ``discard`` issue
    ``cuMemAdvise`` calls. No state is mirrored on the Python side, so the
    view always reflects what the driver currently reports.

    Note
    ----
    The driver's read-back path returns integer device ordinals (``-1`` for
    host); host NUMA distinctions applied via ``Host(numa_id=...)`` are not
    distinguishable from a generic ``Host()`` when iterating this set.
    """

    __slots__ = ("_buf",)

    def __init__(self, buf: ManagedBuffer):
        self._buf = buf

    def _query(self) -> list[Device | Host]:
        # The driver writes one ordinal per slot: a device id, -1 for the
        # host, or -2 for an unused slot. The array must have room for
        # every CUDA-visible device plus one host slot. cuDeviceGetCount
        # (a driver call) keeps the read path independent of NVML.
        slot_count = handle_return(driver.cuDeviceGetCount()) + 1
        ordinals = handle_return(
            driver.cuMemRangeGetAttribute(
                slot_count * _INT_SIZE,
                driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY,
                self._buf.handle,
                self._buf.size,
            )
        )
        # -2 == CU_DEVICE_INVALID marks an empty slot and is skipped.
        return [Host() if o == -1 else Device(o) for o in ordinals if o != -2]

    def __contains__(self, location) -> bool:
        return location in self._query()

    def __iter__(self):
        return iter(self._query())

    def __len__(self) -> int:
        return len(self._query())

    def __eq__(self, other) -> bool:
        if isinstance(other, AccessedBySet):
            return set(self._query()) == set(other._query())
        if isinstance(other, (set, frozenset)):
            return set(self._query()) == other
        return NotImplemented

    def __repr__(self) -> str:
        return f"AccessedBySet({set(self._query())!r})"

    def add(self, location: Device | Host) -> None:
        """Apply ``set_accessed_by`` advice for ``location``."""
        from cuda.core.utils import advise

        advise(self._buf, "set_accessed_by", location)

    def discard(self, location: Device | Host) -> None:
        """Apply ``unset_accessed_by`` advice for ``location``."""
        from cuda.core.utils import advise

        advise(self._buf, "unset_accessed_by", location)
97+
98+
99+
class ManagedBuffer(Buffer):
    """Managed (unified) memory buffer with a property-style advice API.

    Returned by :meth:`ManagedMemoryResource.allocate`. Wrap an external
    managed-memory pointer with :meth:`ManagedBuffer.from_handle`.

    Examples
    --------
    >>> buf = mr.allocate(size)
    >>> buf.read_mostly = True
    >>> buf.preferred_location = Device(0)
    >>> buf.accessed_by.add(Device(1))
    >>> buf.prefetch(Device(0), stream=stream)

    Note
    ----
    The driver's read-back path for ``preferred_location`` and
    ``accessed_by`` returns integer device ordinals; host NUMA distinctions
    applied via ``Host(numa_id=...)`` collapse to a generic ``Host()`` when
    queried. Setters preserve full NUMA information when issuing advice.
    """

    @classmethod
    def from_handle(
        cls,
        ptr,
        size: int,
        mr=None,
        owner=None,
    ) -> ManagedBuffer:
        """Wrap an existing managed-memory pointer in a :class:`ManagedBuffer`."""
        return cls._init(ptr, size, mr=mr, owner=owner)

    @property
    def read_mostly(self) -> bool:
        """Whether ``set_read_mostly`` advice is currently applied to this range."""
        attr = driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY
        return bool(_get_int_attr(self, attr))

    @read_mostly.setter
    def read_mostly(self, value: bool) -> None:
        from cuda.core.utils import advise

        if value:
            advise(self, "set_read_mostly")
        else:
            advise(self, "unset_read_mostly")

    @property
    def preferred_location(self) -> Device | Host | None:
        """Currently applied ``set_preferred_location`` target, or ``None`` if unset."""
        # The legacy PREFERRED_LOCATION attribute returns a single int:
        # -2 = invalid (no preferred location), -1 = host, >=0 = device ordinal.
        # NUMA-specific preferences round-trip as a generic Host (CUDA driver
        # limitation of the legacy query path).
        ordinal = _get_int_attr(self, driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION)
        if ordinal == -2:
            return None
        return Host() if ordinal == -1 else Device(ordinal)

    @preferred_location.setter
    def preferred_location(self, value: Device | Host | None) -> None:
        from cuda.core.utils import advise

        # None unsets the preference; anything else is advised as the target.
        if value is None:
            advise(self, "unset_preferred_location")
        else:
            advise(self, "set_preferred_location", value)

    @property
    def accessed_by(self) -> AccessedBySet:
        """Live set-like view of ``set_accessed_by`` locations."""
        return AccessedBySet(self)

    @accessed_by.setter
    def accessed_by(self, locations) -> None:
        # Diff against the current driver state and advise only the deltas.
        from cuda.core.utils import advise

        existing = set(AccessedBySet(self))
        desired = set(locations)
        for stale in existing - desired:
            advise(self, "unset_accessed_by", stale)
        for missing in desired - existing:
            advise(self, "set_accessed_by", missing)

    def prefetch(self, location: Device | Host | int, *, stream: Stream | GraphBuilder) -> None:
        """Prefetch this range to ``location`` on ``stream``."""
        from cuda.core.utils import prefetch

        prefetch(self, location, stream=stream)

    def discard(self, *, stream: Stream | GraphBuilder) -> None:
        """Discard this range's resident pages on ``stream`` (CUDA 13+)."""
        from cuda.core.utils import discard

        discard(self, stream=stream)

    def discard_prefetch(self, location: Device | Host | int, *, stream: Stream | GraphBuilder) -> None:
        """Discard this range and prefetch to ``location`` on ``stream`` (CUDA 13+)."""
        from cuda.core.utils import discard_prefetch

        discard_prefetch(self, location, stream=stream)

0 commit comments

Comments (0)