Skip to content

Commit 47d2358

Browse files
committed
fix(cuda.core): make AccessedBySetProxy.discard a no-op on NUMA-host inputs
`MutableSet.discard(x)` is contractually a no-op when `x` is not in the set. `set_accessed_by` only accepts `device` and generic `host` kinds, so NUMA-aware host variants (`Host(numa_id=...)`, `Host.numa_current()`) can never enter the set — they are necessarily non-members. Previously, `discard` forwarded such inputs to `_advise_one`, which raised `ValueError` from the kind-allowed check. That broke the `MutableSet` contract that `AccessedBySetProxy` claims by inheriting from `MutableSet`, and caused `test_accessed_by_mutable_set_interface` to fail on the helper's non-member sentinel. Add an explicit short-circuit for the NUMA-host kinds. Fixes CI failure for tests/memory/test_managed_ops.py::TestManagedBuffer::test_accessed_by_mutable_set_interface.
1 parent be46eed commit 47d2358

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

cuda_core/cuda/core/_memory/_managed_buffer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,18 @@ def add(self, location: Device | Host) -> None:
103103
_advise_one(self._buf, _SET_ACCESSED_BY, location)
104104

105105
def discard(self, location: Device | Host) -> None:
106-
"""Apply ``unset_accessed_by`` advice for ``location``."""
106+
"""Apply ``unset_accessed_by`` advice for ``location``.
107+
108+
Per the ``MutableSet`` contract, ``discard`` is a no-op for elements
109+
not in the set. ``set_accessed_by`` only accepts ``Device`` and the
110+
generic ``Host()`` — NUMA-aware host variants (``Host(numa_id=...)``,
111+
``Host.numa_current()``) can never enter the set, so discarding them
112+
is silently ignored rather than forwarded to the driver.
113+
"""
107114
if not isinstance(location, (Device, Host)):
108115
return
116+
if isinstance(location, Host) and (location.numa_id is not None or location.is_numa_current):
117+
return
109118
_advise_one(self._buf, _UNSET_ACCESSED_BY, location)
110119

111120
def __repr__(self) -> str:

0 commit comments

Comments
 (0)