Skip to content

Commit f59af4e

Browse files
committed
chore(cuda.core): simplify ManagedBuffer per /simplify review
- Buffer.from_handle is now a classmethod that dispatches via cls._init, so subclasses inherit it: ManagedBuffer.from_handle(...) returns a ManagedBuffer with no override needed. Drop ManagedBuffer.from_handle. - Hoist `advise / prefetch / discard / discard_prefetch` imports from per-method lazy imports to module-level (no circular import: they live in cuda.core._memory._managed_memory_ops, not cuda.core.utils). - Cache the CUmem_advise and CUmem_range_attribute enum lookups at module level and pass enum constants directly to advise() instead of re-resolving from string aliases on every property write. - Extract _query_accessed_by as a module-level helper; AccessedBySet delegates and the accessed_by setter calls it directly instead of constructing a throwaway view.
1 parent bede674 commit f59af4e

2 files changed

Lines changed: 65 additions & 85 deletions

File tree

cuda_core/cuda/core/_memory/_buffer.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ cdef class Buffer:
130130
# Must not serialize the parent's stream!
131131
return Buffer._reduce_helper, (self.memory_resource, self.get_ipc_descriptor())
132132

133-
@staticmethod
133+
@classmethod
134134
def from_handle(
135+
cls,
135136
ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None,
136137
owner: object | None = None,
137138
) -> Buffer:
@@ -157,8 +158,11 @@ cdef class Buffer:
157158
When neither ``mr`` nor ``owner`` is specified, this creates a
158159
non-owning reference. The pointer will NOT be freed when the
159160
:class:`Buffer` is closed or garbage collected.
161+
162+
Subclasses inherit this method via :meth:`Buffer._init`, so e.g.
163+
``ManagedBuffer.from_handle(ptr, size)`` returns a ``ManagedBuffer``.
160164
"""
161-
return Buffer._init(ptr, size, mr=mr, owner=owner)
165+
return cls._init(ptr, size, mr=mr, owner=owner)
162166

163167
@classmethod
164168
def from_ipc_descriptor(

cuda_core/cuda/core/_memory/_managed_buffer.py

Lines changed: 59 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from cuda.core._device import Device
99
from cuda.core._host import Host
1010
from cuda.core._memory._buffer import Buffer
11+
from cuda.core._memory._managed_memory_ops import advise, discard, discard_prefetch, prefetch
1112
from cuda.core._utils.cuda_utils import driver, handle_return
1213

1314
if TYPE_CHECKING:
@@ -17,11 +18,37 @@
1718

1819
_INT_SIZE = 4
1920

21+
# Enum aliases — referenced once per property write, so cache the lookup.
22+
_ADV = driver.CUmem_advise
23+
_SET_READ_MOSTLY = _ADV.CU_MEM_ADVISE_SET_READ_MOSTLY
24+
_UNSET_READ_MOSTLY = _ADV.CU_MEM_ADVISE_UNSET_READ_MOSTLY
25+
_SET_PREFERRED = _ADV.CU_MEM_ADVISE_SET_PREFERRED_LOCATION
26+
_UNSET_PREFERRED = _ADV.CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION
27+
_SET_ACCESSED_BY = _ADV.CU_MEM_ADVISE_SET_ACCESSED_BY
28+
_UNSET_ACCESSED_BY = _ADV.CU_MEM_ADVISE_UNSET_ACCESSED_BY
29+
30+
_RANGE = driver.CUmem_range_attribute
31+
_ATTR_READ_MOSTLY = _RANGE.CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY
32+
_ATTR_PREFERRED = _RANGE.CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION
33+
_ATTR_ACCESSED_BY = _RANGE.CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY
34+
2035

2136
def _get_int_attr(buf: Buffer, attribute) -> int:
2237
return handle_return(driver.cuMemRangeGetAttribute(_INT_SIZE, attribute, buf.handle, buf.size))
2338

2439

40+
def _query_accessed_by(buf: Buffer) -> list[Device | Host]:
41+
"""Read the live ``CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY`` list.
42+
43+
Driver fills an int32 array: device id, ``-1`` = host, ``-2`` = empty.
44+
Sized to ``cuDeviceGetCount() + 1`` (every visible device plus host).
45+
"""
46+
num_devices = handle_return(driver.cuDeviceGetCount())
47+
n = num_devices + 1
48+
raw = handle_return(driver.cuMemRangeGetAttribute(n * _INT_SIZE, _ATTR_ACCESSED_BY, buf.handle, buf.size))
49+
return [Host() if v == -1 else Device(v) for v in raw if v != -2]
50+
51+
2552
class AccessedBySet:
2653
"""Live driver-backed view of ``set_accessed_by`` advice for a managed buffer.
2754
@@ -32,75 +59,51 @@ class AccessedBySet:
3259
3360
Note
3461
----
35-
The driver's read-back path returns integer device ordinals (``-1`` for
36-
host); host NUMA distinctions applied via ``Host(numa_id=...)`` are not
37-
distinguishable from a generic ``Host()`` when iterating this set.
62+
The driver returns integer device ordinals (``-1`` for host); host
63+
NUMA distinctions applied via ``Host(numa_id=...)`` collapse to a
64+
generic ``Host()`` when iterating this set.
3865
"""
3966

4067
__slots__ = ("_buf",)
4168

4269
def __init__(self, buf: ManagedBuffer):
4370
self._buf = buf
4471

45-
def _query(self) -> list[Device | Host]:
46-
# Driver fills the array with device ordinals: device id, -1 = host,
47-
# -2 = empty slot. Size must accommodate every CUDA-visible device
48-
# plus a slot for the host. We use cuDeviceGetCount (driver-side) to
49-
# stay independent of NVML availability.
50-
num_devices = handle_return(driver.cuDeviceGetCount())
51-
n = num_devices + 1
52-
raw = handle_return(
53-
driver.cuMemRangeGetAttribute(
54-
n * _INT_SIZE,
55-
driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY,
56-
self._buf.handle,
57-
self._buf.size,
58-
)
59-
)
60-
result: list[Device | Host] = []
61-
for v in raw:
62-
if v == -2: # CU_DEVICE_INVALID — empty slot
63-
continue
64-
result.append(Host() if v == -1 else Device(v))
65-
return result
66-
6772
def __contains__(self, location) -> bool:
68-
return location in self._query()
73+
return location in _query_accessed_by(self._buf)
6974

7075
def __iter__(self):
71-
return iter(self._query())
76+
return iter(_query_accessed_by(self._buf))
7277

7378
def __len__(self) -> int:
74-
return len(self._query())
79+
return len(_query_accessed_by(self._buf))
7580

7681
def __eq__(self, other) -> bool:
7782
if isinstance(other, AccessedBySet):
78-
return set(self._query()) == set(other._query())
83+
return set(_query_accessed_by(self._buf)) == set(_query_accessed_by(other._buf))
7984
if isinstance(other, (set, frozenset)):
80-
return set(self._query()) == other
85+
return set(_query_accessed_by(self._buf)) == other
8186
return NotImplemented
8287

8388
def __repr__(self) -> str:
84-
return f"AccessedBySet({set(self._query())!r})"
89+
return f"AccessedBySet({set(_query_accessed_by(self._buf))!r})"
8590

8691
def add(self, location: Device | Host) -> None:
8792
"""Apply ``set_accessed_by`` advice for ``location``."""
88-
from cuda.core.utils import advise
89-
90-
advise(self._buf, "set_accessed_by", location)
93+
advise(self._buf, _SET_ACCESSED_BY, location)
9194

9295
def discard(self, location: Device | Host) -> None:
9396
"""Apply ``unset_accessed_by`` advice for ``location``."""
94-
from cuda.core.utils import advise
95-
96-
advise(self._buf, "unset_accessed_by", location)
97+
advise(self._buf, _UNSET_ACCESSED_BY, location)
9798

9899

99100
class ManagedBuffer(Buffer):
100101
"""Managed (unified) memory buffer with a property-style advice API.
101102
102-
Returned by :meth:`ManagedMemoryResource.allocate`. Wrap an external
103-
managed-memory pointer with :meth:`ManagedBuffer.from_handle`.
103+
Returned by :meth:`ManagedMemoryResource.allocate`, or wrap an
104+
existing managed-memory pointer with :meth:`Buffer.from_handle`
105+
(which dispatches by class — ``ManagedBuffer.from_handle(...)``
106+
returns a ``ManagedBuffer``).
104107
105108
Examples
106109
--------
@@ -112,42 +115,25 @@ class ManagedBuffer(Buffer):
112115
113116
Note
114117
----
115-
The driver's read-back path for ``preferred_location`` and
116-
``accessed_by`` returns integer device ordinals; host NUMA distinctions
117-
applied via ``Host(numa_id=...)`` collapse to a generic ``Host()`` when
118-
queried. Setters preserve full NUMA information when issuing advice.
118+
The legacy ``cuMemRangeGetAttribute`` query path returns integer
119+
device ordinals, so ``Host(numa_id=...)`` collapses to ``Host()``
120+
on read-back. Setters preserve full NUMA information when issuing
121+
advice.
119122
"""
120123

121-
@classmethod
122-
def from_handle(
123-
cls,
124-
ptr,
125-
size: int,
126-
mr=None,
127-
owner=None,
128-
) -> ManagedBuffer:
129-
"""Wrap an existing managed-memory pointer in a :class:`ManagedBuffer`."""
130-
return cls._init(ptr, size, mr=mr, owner=owner)
131-
132124
@property
133125
def read_mostly(self) -> bool:
134-
"""Whether ``set_read_mostly`` advice is currently applied to this range."""
135-
return _get_int_attr(self, driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY) != 0
126+
"""Whether ``set_read_mostly`` advice is currently applied."""
127+
return _get_int_attr(self, _ATTR_READ_MOSTLY) != 0
136128

137129
@read_mostly.setter
138130
def read_mostly(self, value: bool) -> None:
139-
from cuda.core.utils import advise
140-
141-
advise(self, "set_read_mostly" if value else "unset_read_mostly")
131+
advise(self, _SET_READ_MOSTLY if value else _UNSET_READ_MOSTLY)
142132

143133
@property
144134
def preferred_location(self) -> Device | Host | None:
145-
"""Currently applied ``set_preferred_location`` target, or ``None`` if unset."""
146-
# The legacy PREFERRED_LOCATION attribute returns a single int:
147-
# -2 = invalid (no preferred location), -1 = host, >=0 = device ordinal.
148-
# NUMA-specific preferences round-trip as a generic Host (CUDA driver
149-
# limitation of the legacy query path).
150-
loc_id = _get_int_attr(self, driver.CUmem_range_attribute.CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION)
135+
"""Currently applied ``set_preferred_location`` target, or ``None``."""
136+
loc_id = _get_int_attr(self, _ATTR_PREFERRED)
151137
if loc_id == -2:
152138
return None
153139
if loc_id == -1:
@@ -156,12 +142,10 @@ def preferred_location(self) -> Device | Host | None:
156142

157143
@preferred_location.setter
158144
def preferred_location(self, value: Device | Host | None) -> None:
159-
from cuda.core.utils import advise
160-
161145
if value is None:
162-
advise(self, "unset_preferred_location")
146+
advise(self, _UNSET_PREFERRED)
163147
else:
164-
advise(self, "set_preferred_location", value)
148+
advise(self, _SET_PREFERRED, value)
165149

166150
@property
167151
def accessed_by(self) -> AccessedBySet:
@@ -171,29 +155,21 @@ def accessed_by(self) -> AccessedBySet:
171155
@accessed_by.setter
172156
def accessed_by(self, locations) -> None:
173157
# Diff against the current driver state and advise only the deltas.
174-
from cuda.core.utils import advise
175-
176-
current = set(AccessedBySet(self))
158+
current = set(_query_accessed_by(self))
177159
target = set(locations)
178160
for loc in current - target:
179-
advise(self, "unset_accessed_by", loc)
161+
advise(self, _UNSET_ACCESSED_BY, loc)
180162
for loc in target - current:
181-
advise(self, "set_accessed_by", loc)
163+
advise(self, _SET_ACCESSED_BY, loc)
182164

183165
def prefetch(self, location: Device | Host | int, *, stream: Stream | GraphBuilder) -> None:
184166
"""Prefetch this range to ``location`` on ``stream``."""
185-
from cuda.core.utils import prefetch as _prefetch
186-
187-
_prefetch(self, location, stream=stream)
167+
prefetch(self, location, stream=stream)
188168

189169
def discard(self, *, stream: Stream | GraphBuilder) -> None:
190170
"""Discard this range's resident pages on ``stream`` (CUDA 13+)."""
191-
from cuda.core.utils import discard as _discard
192-
193-
_discard(self, stream=stream)
171+
discard(self, stream=stream)
194172

195173
def discard_prefetch(self, location: Device | Host | int, *, stream: Stream | GraphBuilder) -> None:
196174
"""Discard this range and prefetch to ``location`` on ``stream`` (CUDA 13+)."""
197-
from cuda.core.utils import discard_prefetch as _discard_prefetch
198-
199-
_discard_prefetch(self, location, stream=stream)
175+
discard_prefetch(self, location, stream=stream)

0 commit comments

Comments
 (0)