Skip to content

Commit c5d1417

Browse files
authored
Add __eq__ and __hash__ to Buffer, LaunchConfig, Kernel, ObjectCode (#1534)
* Add __eq__ and __hash__ to Buffer, LaunchConfig, Kernel, ObjectCode

Make these classes hashable and comparable:
- Buffer: identity based on (ptr, size)
- LaunchConfig: uses the _LAUNCH_CONFIG_ATTRS tuple for forward-compatible identity; also updates __repr__ to use the same attribute list
- Kernel: identity based on the kernel handle
- ObjectCode: identity based on the module handle; triggers lazy load

Stream, Event, Context, Device already had __eq__/__hash__.

* Fix test_hashable kernel fixture PTX version incompatibility

Change the object_code fixture to compile to cubin instead of ptx, to avoid CUDA_ERROR_UNSUPPORTED_PTX_VERSION when the toolkit version is newer than the driver version on test machines. Also remove the outdated "type salt" reference from the assertion message.
1 parent bcb803f commit c5d1417

6 files changed

Lines changed: 143 additions & 20 deletions

File tree

cuda_core/cuda/core/_context.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ cdef class Context:
4646
return as_intptr(self._h_context) == as_intptr(_other._h_context)
4747

4848
def __hash__(self) -> int:
49-
return hash((type(self), as_intptr(self._h_context)))
49+
return hash(as_intptr(self._h_context))
5050

5151

5252
@dataclass

cuda_core/cuda/core/_event.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ cdef class Event:
169169
raise RuntimeError(explanation)
170170

171171
def __hash__(self) -> int:
172-
return hash((type(self), as_intptr(self._h_event)))
172+
return hash(as_intptr(self._h_event))
173173

174174
def __eq__(self, other) -> bool:
175175
# Note: using isinstance because `Event` can be subclassed.

cuda_core/cuda/core/_launch_config.pyx

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ cdef bint _inited = False
2121
cdef bint _use_ex = False
2222
cdef object _lock = threading.Lock()
2323

24+
# Attribute names for identity comparison and representation
25+
_LAUNCH_CONFIG_ATTRS = ('grid', 'cluster', 'block', 'shmem_size', 'cooperative_launch')
26+
2427

2528
cdef int _lazy_init() except?-1:
2629
global _inited, _use_ex
@@ -131,11 +134,21 @@ cdef class LaunchConfig:
131134
if self.cooperative_launch and not Device().properties.cooperative_launch:
132135
raise CUDAError("cooperative kernels are not supported on this device")
133136

137+
def _identity(self):
138+
return tuple(getattr(self, attr) for attr in _LAUNCH_CONFIG_ATTRS)
139+
134140
def __repr__(self):
135141
"""Return string representation of LaunchConfig."""
136-
return (f"LaunchConfig(grid={self.grid}, cluster={self.cluster}, "
137-
f"block={self.block}, shmem_size={self.shmem_size}, "
138-
f"cooperative_launch={self.cooperative_launch})")
142+
parts = ', '.join(f'{attr}={getattr(self, attr)!r}' for attr in _LAUNCH_CONFIG_ATTRS)
143+
return f"LaunchConfig({parts})"
144+
145+
def __eq__(self, other) -> bool:
146+
if not isinstance(other, LaunchConfig):
147+
return NotImplemented
148+
return self._identity() == (<LaunchConfig>other)._identity()
149+
150+
def __hash__(self) -> int:
151+
return hash(self._identity())
139152

140153
cdef cydriver.CUlaunchConfig _to_native_launch_config(self):
141154
_lazy_init()

cuda_core/cuda/core/_memory/_buffer.pyx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,16 @@ cdef class Buffer:
324324
# that expect a raw pointer value
325325
return as_intptr(self._h_ptr)
326326

327+
def __eq__(self, other) -> bool:
328+
if not isinstance(other, Buffer):
329+
return NotImplemented
330+
cdef Buffer other_buf = <Buffer>other
331+
return (as_intptr(self._h_ptr) == as_intptr(other_buf._h_ptr) and
332+
self._size == other_buf._size)
333+
334+
def __hash__(self) -> int:
335+
return hash((as_intptr(self._h_ptr), self._size))
336+
327337
@property
328338
def is_device_accessible(self) -> bool:
329339
"""Return True if this buffer can be accessed by the GPU, otherwise False."""

cuda_core/cuda/core/_module.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,14 @@ def from_handle(handle: int, mod: ObjectCode = None) -> Kernel:
528528

529529
return Kernel._from_obj(kernel_obj, mod)
530530

531+
def __eq__(self, other) -> bool:
532+
if not isinstance(other, Kernel):
533+
return NotImplemented
534+
return int(self._handle) == int(other._handle)
535+
536+
def __hash__(self) -> int:
537+
return hash(int(self._handle))
538+
531539

532540
CodeTypeT = bytes | bytearray | str
533541

@@ -757,3 +765,13 @@ def handle(self):
757765
handle, call ``int(ObjectCode.handle)``.
758766
"""
759767
return self._handle
768+
769+
def __eq__(self, other) -> bool:
770+
if not isinstance(other, ObjectCode):
771+
return NotImplemented
772+
# Trigger lazy load for both objects to compare handles
773+
return int(self.handle) == int(other.handle)
774+
775+
def __hash__(self) -> int:
776+
# Trigger lazy load to get the handle
777+
return hash(int(self.handle))

cuda_core/tests/test_hashable.py

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""
@@ -12,24 +12,109 @@
1212
5. Hash/equality contract compliance (if a == b, then hash(a) must equal hash(b))
1313
"""
1414

15-
from cuda.core import Device
15+
import pytest
16+
from cuda.core import Device, LaunchConfig, Program
1617
from cuda.core._stream import Stream, StreamOptions
1718

19+
# ============================================================================
20+
# Fixtures for parameterized tests
21+
# ============================================================================
22+
23+
24+
@pytest.fixture
25+
def sample_device(init_cuda):
26+
return Device()
27+
28+
29+
@pytest.fixture
30+
def sample_stream(sample_device):
31+
return sample_device.create_stream()
32+
33+
34+
@pytest.fixture
35+
def sample_event(sample_device):
36+
return sample_device.create_event()
37+
38+
39+
@pytest.fixture
40+
def sample_context(sample_device):
41+
return sample_device.context
42+
43+
44+
@pytest.fixture
45+
def sample_buffer(sample_device):
46+
return sample_device.allocate(1024)
47+
48+
49+
@pytest.fixture
50+
def sample_launch_config():
51+
return LaunchConfig(grid=(1,), block=(1,))
52+
53+
54+
@pytest.fixture
55+
def sample_object_code(init_cuda):
56+
prog = Program('extern "C" __global__ void test_kernel() {}', "c++")
57+
return prog.compile("cubin")
58+
59+
60+
@pytest.fixture
61+
def sample_kernel(sample_object_code):
62+
return sample_object_code.get_kernel("test_kernel")
63+
64+
65+
# All hashable classes
66+
HASHABLE = [
67+
"sample_device",
68+
"sample_stream",
69+
"sample_event",
70+
"sample_context",
71+
"sample_buffer",
72+
"sample_launch_config",
73+
"sample_object_code",
74+
"sample_kernel",
75+
]
76+
77+
78+
# ============================================================================
79+
# Parameterized Hash Tests
80+
# ============================================================================
81+
82+
83+
@pytest.mark.parametrize("fixture_name", HASHABLE)
84+
def test_hash_consistency(fixture_name, request):
85+
"""Hash of same object is consistent across calls."""
86+
obj = request.getfixturevalue(fixture_name)
87+
assert hash(obj) == hash(obj)
88+
89+
90+
@pytest.mark.parametrize("fixture_name", HASHABLE)
91+
def test_set_membership(fixture_name, request):
92+
"""Objects work correctly in sets."""
93+
obj = request.getfixturevalue(fixture_name)
94+
s = {obj}
95+
assert obj in s
96+
assert len(s) == 1
97+
98+
99+
@pytest.mark.parametrize("fixture_name", HASHABLE)
100+
def test_dict_key(fixture_name, request):
101+
"""Objects work correctly as dict keys."""
102+
obj = request.getfixturevalue(fixture_name)
103+
d = {obj: "value"}
104+
assert d[obj] == "value"
105+
106+
18107
# ============================================================================
19108
# Integration Tests
20109
# ============================================================================
21110

22111

23-
def test_hash_type_disambiguation_and_mixed_dict(init_cuda):
24-
"""Test that hash salt (type(self)) prevents collisions between different types
25-
and that different object types can coexist in dictionaries.
112+
def test_mixed_type_dict(init_cuda):
113+
"""Test that different object types can coexist in dictionaries.
26114
27-
This test validates that:
28-
1. Including type(self) in the hash calculation ensures different types with
29-
potentially similar underlying values (like monotonically increasing handles
30-
or IDs) produce different hashes and don't collide.
31-
2. Different object types can be used together in the same dictionary without
32-
conflicts.
115+
Since each CUDA handle type has unique values within its type (handles are
116+
memory addresses or unique identifiers), hash collisions between different
117+
types are unlikely in practice.
33118
"""
34119
device = Device(0)
35120
device.set_current()
@@ -42,10 +127,7 @@ def test_hash_type_disambiguation_and_mixed_dict(init_cuda):
42127
# Test 1: Verify all hashes are unique (no collisions between different types)
43128
hashes = {hash(device), hash(stream), hash(event), hash(context)}
44129

45-
assert len(hashes) == 4, (
46-
f"Hash collision detected! Expected 4 unique hashes, got {len(hashes)}. "
47-
f"This indicates the type salt is not working correctly."
48-
)
130+
assert len(hashes) == 4, f"Hash collision detected! Expected 4 unique hashes, got {len(hashes)}. "
49131

50132
# Test 2: Verify all types can coexist in same dict without conflicts
51133
mixed_cache = {stream: "stream_data", event: "event_data", context: "context_data", device: "device_data"}

0 commit comments

Comments
 (0)