88from cuda .core ._device import Device
99from cuda .core ._host import Host
1010from cuda .core ._memory ._buffer import Buffer
11+ from cuda .core ._memory ._managed_memory_ops import advise , discard , discard_prefetch , prefetch
1112from cuda .core ._utils .cuda_utils import driver , handle_return
1213
1314if TYPE_CHECKING :
1718
1819_INT_SIZE = 4
1920
21+ # Enum aliases — referenced once per property write, so cache the lookup.
22+ _ADV = driver .CUmem_advise
23+ _SET_READ_MOSTLY = _ADV .CU_MEM_ADVISE_SET_READ_MOSTLY
24+ _UNSET_READ_MOSTLY = _ADV .CU_MEM_ADVISE_UNSET_READ_MOSTLY
25+ _SET_PREFERRED = _ADV .CU_MEM_ADVISE_SET_PREFERRED_LOCATION
26+ _UNSET_PREFERRED = _ADV .CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION
27+ _SET_ACCESSED_BY = _ADV .CU_MEM_ADVISE_SET_ACCESSED_BY
28+ _UNSET_ACCESSED_BY = _ADV .CU_MEM_ADVISE_UNSET_ACCESSED_BY
29+
30+ _RANGE = driver .CUmem_range_attribute
31+ _ATTR_READ_MOSTLY = _RANGE .CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY
32+ _ATTR_PREFERRED = _RANGE .CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION
33+ _ATTR_ACCESSED_BY = _RANGE .CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY
34+
2035
2136def _get_int_attr (buf : Buffer , attribute ) -> int :
2237 return handle_return (driver .cuMemRangeGetAttribute (_INT_SIZE , attribute , buf .handle , buf .size ))
2338
2439
40+ def _query_accessed_by (buf : Buffer ) -> list [Device | Host ]:
41+ """Read the live ``CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY`` list.
42+
43+ Driver fills an int32 array: device id, ``-1`` = host, ``-2`` = empty.
44+ Sized to ``cuDeviceGetCount() + 1`` (every visible device plus host).
45+ """
46+ num_devices = handle_return (driver .cuDeviceGetCount ())
47+ n = num_devices + 1
48+ raw = handle_return (driver .cuMemRangeGetAttribute (n * _INT_SIZE , _ATTR_ACCESSED_BY , buf .handle , buf .size ))
49+ return [Host () if v == - 1 else Device (v ) for v in raw if v != - 2 ]
50+
51+
2552class AccessedBySet :
2653 """Live driver-backed view of ``set_accessed_by`` advice for a managed buffer.
2754
@@ -32,75 +59,51 @@ class AccessedBySet:
3259
3360 Note
3461 ----
35- The driver's read-back path returns integer device ordinals (``-1`` for
36- host); host NUMA distinctions applied via ``Host(numa_id=...)`` are not
37- distinguishable from a generic ``Host()`` when iterating this set.
62+ The driver returns integer device ordinals (``-1`` for host); host
63+ NUMA distinctions applied via ``Host(numa_id=...)`` collapse to a
64+ generic ``Host()`` when iterating this set.
3865 """
3966
4067 __slots__ = ("_buf" ,)
4168
4269 def __init__ (self , buf : ManagedBuffer ):
4370 self ._buf = buf
4471
45- def _query (self ) -> list [Device | Host ]:
46- # Driver fills the array with device ordinals: device id, -1 = host,
47- # -2 = empty slot. Size must accommodate every CUDA-visible device
48- # plus a slot for the host. We use cuDeviceGetCount (driver-side) to
49- # stay independent of NVML availability.
50- num_devices = handle_return (driver .cuDeviceGetCount ())
51- n = num_devices + 1
52- raw = handle_return (
53- driver .cuMemRangeGetAttribute (
54- n * _INT_SIZE ,
55- driver .CUmem_range_attribute .CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY ,
56- self ._buf .handle ,
57- self ._buf .size ,
58- )
59- )
60- result : list [Device | Host ] = []
61- for v in raw :
62- if v == - 2 : # CU_DEVICE_INVALID — empty slot
63- continue
64- result .append (Host () if v == - 1 else Device (v ))
65- return result
66-
6772 def __contains__ (self , location ) -> bool :
68- return location in self ._query ( )
73+ return location in _query_accessed_by ( self ._buf )
6974
7075 def __iter__ (self ):
71- return iter (self ._query ( ))
76+ return iter (_query_accessed_by ( self ._buf ))
7277
7378 def __len__ (self ) -> int :
74- return len (self ._query ( ))
79+ return len (_query_accessed_by ( self ._buf ))
7580
7681 def __eq__ (self , other ) -> bool :
7782 if isinstance (other , AccessedBySet ):
78- return set (self ._query ( )) == set (other ._query ( ))
83+ return set (_query_accessed_by ( self ._buf )) == set (_query_accessed_by ( other ._buf ))
7984 if isinstance (other , (set , frozenset )):
80- return set (self ._query ( )) == other
85+ return set (_query_accessed_by ( self ._buf )) == other
8186 return NotImplemented
8287
8388 def __repr__ (self ) -> str :
84- return f"AccessedBySet({ set (self ._query ( ))!r} )"
89+ return f"AccessedBySet({ set (_query_accessed_by ( self ._buf ))!r} )"
8590
8691 def add (self , location : Device | Host ) -> None :
8792 """Apply ``set_accessed_by`` advice for ``location``."""
88- from cuda .core .utils import advise
89-
90- advise (self ._buf , "set_accessed_by" , location )
93+ advise (self ._buf , _SET_ACCESSED_BY , location )
9194
9295 def discard (self , location : Device | Host ) -> None :
9396 """Apply ``unset_accessed_by`` advice for ``location``."""
94- from cuda .core .utils import advise
95-
96- advise (self ._buf , "unset_accessed_by" , location )
97+ advise (self ._buf , _UNSET_ACCESSED_BY , location )
9798
9899
99100class ManagedBuffer (Buffer ):
100101 """Managed (unified) memory buffer with a property-style advice API.
101102
102- Returned by :meth:`ManagedMemoryResource.allocate`. Wrap an external
103- managed-memory pointer with :meth:`ManagedBuffer.from_handle`.
103+ Returned by :meth:`ManagedMemoryResource.allocate`, or wrap an
104+ existing managed-memory pointer with :meth:`Buffer.from_handle`
105+ (which dispatches by class — ``ManagedBuffer.from_handle(...)``
106+ returns a ``ManagedBuffer``).
104107
105108 Examples
106109 --------
@@ -112,42 +115,25 @@ class ManagedBuffer(Buffer):
112115
113116 Note
114117 ----
115- The driver's read-back path for ``preferred_location `` and
116- ``accessed_by `` returns integer device ordinals; host NUMA distinctions
117- applied via ``Host(numa_id=...)`` collapse to a generic ``Host()`` when
118- queried. Setters preserve full NUMA information when issuing advice.
118+ The legacy ``cuMemRangeGetAttribute `` query path returns integer
119+ device ordinals, so ``Host(numa_id=...) `` collapses to ``Host()``
120+ on read-back. Setters preserve full NUMA information when issuing
121+ advice.
119122 """
120123
121- @classmethod
122- def from_handle (
123- cls ,
124- ptr ,
125- size : int ,
126- mr = None ,
127- owner = None ,
128- ) -> ManagedBuffer :
129- """Wrap an existing managed-memory pointer in a :class:`ManagedBuffer`."""
130- return cls ._init (ptr , size , mr = mr , owner = owner )
131-
132124 @property
133125 def read_mostly (self ) -> bool :
134- """Whether ``set_read_mostly`` advice is currently applied to this range ."""
135- return _get_int_attr (self , driver . CUmem_range_attribute . CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY ) != 0
126+ """Whether ``set_read_mostly`` advice is currently applied."""
127+ return _get_int_attr (self , _ATTR_READ_MOSTLY ) != 0
136128
137129 @read_mostly .setter
138130 def read_mostly (self , value : bool ) -> None :
139- from cuda .core .utils import advise
140-
141- advise (self , "set_read_mostly" if value else "unset_read_mostly" )
131+ advise (self , _SET_READ_MOSTLY if value else _UNSET_READ_MOSTLY )
142132
143133 @property
144134 def preferred_location (self ) -> Device | Host | None :
145- """Currently applied ``set_preferred_location`` target, or ``None`` if unset."""
146- # The legacy PREFERRED_LOCATION attribute returns a single int:
147- # -2 = invalid (no preferred location), -1 = host, >=0 = device ordinal.
148- # NUMA-specific preferences round-trip as a generic Host (CUDA driver
149- # limitation of the legacy query path).
150- loc_id = _get_int_attr (self , driver .CUmem_range_attribute .CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION )
135+ """Currently applied ``set_preferred_location`` target, or ``None``."""
136+ loc_id = _get_int_attr (self , _ATTR_PREFERRED )
151137 if loc_id == - 2 :
152138 return None
153139 if loc_id == - 1 :
@@ -156,12 +142,10 @@ def preferred_location(self) -> Device | Host | None:
156142
157143 @preferred_location .setter
158144 def preferred_location (self , value : Device | Host | None ) -> None :
159- from cuda .core .utils import advise
160-
161145 if value is None :
162- advise (self , "unset_preferred_location" )
146+ advise (self , _UNSET_PREFERRED )
163147 else :
164- advise (self , "set_preferred_location" , value )
148+ advise (self , _SET_PREFERRED , value )
165149
166150 @property
167151 def accessed_by (self ) -> AccessedBySet :
@@ -171,29 +155,21 @@ def accessed_by(self) -> AccessedBySet:
171155 @accessed_by .setter
172156 def accessed_by (self , locations ) -> None :
173157 # Diff against the current driver state and advise only the deltas.
174- from cuda .core .utils import advise
175-
176- current = set (AccessedBySet (self ))
158+ current = set (_query_accessed_by (self ))
177159 target = set (locations )
178160 for loc in current - target :
179- advise (self , "unset_accessed_by" , loc )
161+ advise (self , _UNSET_ACCESSED_BY , loc )
180162 for loc in target - current :
181- advise (self , "set_accessed_by" , loc )
163+ advise (self , _SET_ACCESSED_BY , loc )
182164
183165 def prefetch (self , location : Device | Host | int , * , stream : Stream | GraphBuilder ) -> None :
184166 """Prefetch this range to ``location`` on ``stream``."""
185- from cuda .core .utils import prefetch as _prefetch
186-
187- _prefetch (self , location , stream = stream )
167+ prefetch (self , location , stream = stream )
188168
189169 def discard (self , * , stream : Stream | GraphBuilder ) -> None :
190170 """Discard this range's resident pages on ``stream`` (CUDA 13+)."""
191- from cuda .core .utils import discard as _discard
192-
193- _discard (self , stream = stream )
171+ discard (self , stream = stream )
194172
195173 def discard_prefetch (self , location : Device | Host | int , * , stream : Stream | GraphBuilder ) -> None :
196174 """Discard this range and prefetch to ``location`` on ``stream`` (CUDA 13+)."""
197- from cuda .core .utils import discard_prefetch as _discard_prefetch
198-
199- _discard_prefetch (self , location , stream = stream )
175+ discard_prefetch (self , location , stream = stream )
0 commit comments