Skip to content

Commit 7ca634b

Browse files
Implement __hash__ and __eq__ for cuda.core classes (NVIDIA#1198)
* Add hash and equality support to Stream, Event, Context, and Device classes, enabling their use as dictionary keys and in sets. * pre-commit fixes * pre-commit fixes * Update cuda_core/cuda/core/experimental/_event.pyx Co-authored-by: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> * return False on type error * reverting int conversion fix due to cython compilation error * Addressing performance nit feedback * Reverting back to returning NotImplemented on type error in __eq__ based on feedback. * Adding to the __hash__ function for all core types. * Updating __eq__: Adding comment to explain perf optimization and ensuring we are doing a direct c-cast instead of isinstance * Test organization: moving type specific tests to their corresponding files * Device __hash__ function not uses uuid instead of device ordinal * Adding context to the stream hash implementation * forgot to include event changes * cython compilation fixes * Checking for context initalization * Handle builtin streams for fetching context * Fixing seg fault * memoize uuid property * memoize the uuid * Removing unnecessary casts * Making docstrings consistent with other dunder methods * Address _event.pyx potential seg fault feedback when casting Event objects * reverting change that prevented lazy init * Adding an isinstance check for Stream.__eq__ to avoid seg faults * Fixes equality and hash to make these consistent. Both include context handle. * Moving the uuid string generation code from a global function to an inline * Pushing missing updates to Device.__eq__ * Addressing feedback and adding more unit tests that validate equality and hash behaviour is consistent. * formatting * pre-commit fixes --------- Co-authored-by: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com>
1 parent 69aac67 commit 7ca634b

10 files changed

Lines changed: 931 additions & 14 deletions

File tree

cuda_core/cuda/core/experimental/_context.pyx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,10 @@ cdef class Context:
2929
return ctx
3030

3131
def __eq__(self, other):
32-
return int(self._handle) == int(other._handle)
32+
if not isinstance(other, Context):
33+
return NotImplemented
34+
cdef Context _other = <Context>other
35+
return int(self._handle) == int(_other._handle)
36+
37+
def __hash__(self) -> int:
38+
return hash(int(self._handle))

cuda_core/cuda/core/experimental/_device.pyx

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ class Device:
949949
Default value of `None` return the currently used device.
950950

951951
"""
952-
__slots__ = ("_id", "_mr", "_has_inited", "_properties")
952+
__slots__ = ("_id", "_mr", "_has_inited", "_properties", "_uuid")
953953

954954
def __new__(cls, device_id: int | None = None):
955955
global _is_cuInit
@@ -1002,6 +1002,7 @@ class Device:
10021002

10031003
device._has_inited = False
10041004
device._properties = None
1005+
device._uuid = None
10051006
devices.append(device)
10061007

10071008
try:
@@ -1053,18 +1054,26 @@ class Device:
10531054
MIG UUID is only returned when device is in MIG mode and the
10541055
driver is older than CUDA 11.4.
10551056

1057+
The UUID is cached after first access to avoid repeated CUDA API calls.
1058+
10561059
"""
10571060
cdef cydriver.CUuuid uuid
1058-
cdef cydriver.CUdevice this_dev = self._id
1059-
with nogil:
1060-
IF CUDA_CORE_BUILD_MAJOR == "12":
1061-
HANDLE_RETURN(cydriver.cuDeviceGetUuid_v2(&uuid, this_dev))
1062-
ELSE: # 13.0+
1063-
HANDLE_RETURN(cydriver.cuDeviceGetUuid(&uuid, this_dev))
1064-
cdef bytes uuid_b = cpython.PyBytes_FromStringAndSize(uuid.bytes, sizeof(uuid.bytes))
1065-
cdef str uuid_hex = uuid_b.hex()
1066-
# 8-4-4-4-12
1067-
return f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-{uuid_hex[16:20]}-{uuid_hex[20:]}"
1061+
cdef cydriver.CUdevice dev
1062+
cdef bytes uuid_b
1063+
cdef str uuid_hex
1064+
1065+
if self._uuid is None:
1066+
dev = self._id
1067+
with nogil:
1068+
IF CUDA_CORE_BUILD_MAJOR == "12":
1069+
HANDLE_RETURN(cydriver.cuDeviceGetUuid_v2(&uuid, dev))
1070+
ELSE: # 13.0+
1071+
HANDLE_RETURN(cydriver.cuDeviceGetUuid(&uuid, dev))
1072+
uuid_b = cpython.PyBytes_FromStringAndSize(uuid.bytes, sizeof(uuid.bytes))
1073+
uuid_hex = uuid_b.hex()
1074+
# 8-4-4-4-12
1075+
self._uuid = f"{uuid_hex[:8]}-{uuid_hex[8:12]}-{uuid_hex[12:16]}-{uuid_hex[16:20]}-{uuid_hex[20:]}"
1076+
return self._uuid
10681077

10691078
@property
10701079
def name(self) -> str:
@@ -1145,6 +1154,14 @@ class Device:
11451154
def __repr__(self):
11461155
return f"<Device {self._id} ({self.name})>"
11471156

1157+
def __hash__(self) -> int:
1158+
return hash(self.uuid)
1159+
1160+
def __eq__(self, other) -> bool:
1161+
if not isinstance(other, Device):
1162+
return NotImplemented
1163+
return self._id == other._id
1164+
11481165
def __reduce__(self):
11491166
return Device, (self.device_id,)
11501167

cuda_core/cuda/core/experimental/_event.pyx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,16 @@ cdef class Event:
165165
raise CUDAError(err)
166166
raise RuntimeError(explanation)
167167

168+
def __hash__(self) -> int:
169+
return hash((self._ctx_handle, <uintptr_t>(self._handle)))
170+
171+
def __eq__(self, other) -> bool:
172+
# Note: using isinstance because `Event` can be subclassed.
173+
if not isinstance(other, Event):
174+
return NotImplemented
175+
cdef Event _other = <Event>other
176+
return <uintptr_t>(self._handle) == <uintptr_t>(_other._handle)
177+
168178
def get_ipc_descriptor(self) -> IPCEventDescriptor:
169179
"""Export an event allocated for sharing between processes."""
170180
if self._ipc_descriptor is not None:

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,27 @@ cdef class Stream:
197197
"""Return an instance of a __cuda_stream__ protocol."""
198198
return (0, <uintptr_t>(self._handle))
199199

200+
def __hash__(self) -> int:
201+
# Ensure context is initialized for hash consistency
202+
if self._ctx_handle == CU_CONTEXT_INVALID:
203+
self._get_context()
204+
return hash((<uintptr_t>(self._ctx_handle), <uintptr_t>(self._handle)))
205+
206+
def __eq__(self, other) -> bool:
207+
if not isinstance(other, Stream):
208+
return NotImplemented
209+
cdef Stream _other = <Stream>other
210+
# Fast path: compare handles first
211+
if <uintptr_t>(self._handle) != <uintptr_t>((_other)._handle):
212+
return False
213+
# Ensure contexts are initialized for both streams
214+
if self._ctx_handle == CU_CONTEXT_INVALID:
215+
self._get_context()
216+
if _other._ctx_handle == CU_CONTEXT_INVALID:
217+
_other._get_context()
218+
# Compare contexts as well
219+
return <uintptr_t>(self._ctx_handle) == <uintptr_t>((_other)._ctx_handle)
220+
200221
@property
201222
def handle(self) -> cuda.bindings.driver.CUstream:
202223
"""Return the underlying ``CUstream`` object.

cuda_core/tests/test_comparable.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
Tests for __eq__ and __ne__ implementations in cuda.core classes.
6+
7+
These tests verify multi-type equality behavior and subclassing equality behavior
8+
across Device, Stream, Event, and Context objects.
9+
"""
10+
11+
from cuda.core.experimental import Device, Stream
12+
from cuda.core.experimental._context import Context
13+
from cuda.core.experimental._event import Event, EventOptions
14+
from cuda.core.experimental._stream import StreamOptions
15+
16+
# ============================================================================
17+
# Equality Contract Tests
18+
# ============================================================================
19+
20+
21+
def test_equality_is_not_identity():
22+
"""Test that equality (==) is different from identity (is)."""
23+
device = Device(0)
24+
device.set_current()
25+
26+
# Streams: Different objects can be equal
27+
s1 = device.create_stream()
28+
s2 = Stream.from_handle(int(s1.handle))
29+
30+
assert s1 == s2, "Streams with same handle are equal"
31+
assert s1 is not s2, "But they are not the same object"
32+
33+
# Device: Same object due to singleton (special case)
34+
d1 = Device(0)
35+
d2 = Device(0)
36+
37+
assert d1 == d2, "Devices with same ID are equal"
38+
assert d1 is d2, "And they ARE the same object (singleton)"
39+
40+
41+
# ============================================================================
42+
# Subclassing Equality Tests
43+
# ============================================================================
44+
45+
46+
def test_device_subclass_equality(init_cuda):
47+
"""Test Device subclass equality behavior.
48+
49+
Device uses a singleton pattern where Device(0) always returns the same
50+
cached instance. This means subclassing Device doesn't create new instances;
51+
MyDevice(0) returns the original Device(0) instance from the cache.
52+
"""
53+
54+
class MyDevice(Device):
55+
pass
56+
57+
device = Device(0)
58+
device.set_current()
59+
my_device = MyDevice(0)
60+
61+
# Due to singleton pattern, both return the exact same instance
62+
assert device is my_device, "Device singleton returns same instance for same device_id"
63+
assert type(device) is Device, "Singleton returns original Device type, not subclass"
64+
assert type(my_device) is Device, "Even MyDevice(0) returns Device instance due to singleton"
65+
66+
# Since they're the same object, they're equal
67+
assert device == my_device
68+
69+
70+
def test_stream_subclass_equality(init_cuda):
71+
"""Test Stream subclass equality behavior.
72+
73+
Stream uses isinstance() for equality checking, which means a Stream instance
74+
and a MyStream subclass instance wrapping the same handle will compare equal.
75+
"""
76+
77+
class MyStream(Stream):
78+
pass
79+
80+
device = Device(0)
81+
device.set_current()
82+
83+
# Create base Stream instance
84+
stream = Stream._init(options=StreamOptions(), device_id=device.device_id)
85+
86+
# Create another Stream wrapping same handle
87+
stream2 = Stream.from_handle(int(stream.handle))
88+
assert stream == stream2, "Streams wrapping same handle are equal"
89+
90+
# Create subclass instance with different handle
91+
my_stream = MyStream._init(options=StreamOptions(), device_id=device.device_id)
92+
93+
# Different handles -> not equal
94+
assert stream != my_stream, "Streams with different handles are not equal"
95+
assert stream.handle != my_stream.handle
96+
97+
# sanity check: base and subclass compare equal (and hash equal)
98+
stream_from_handle = MyStream.from_handle(int(my_stream.handle))
99+
assert my_stream == stream_from_handle, "MyStream and Stream wrapping same handle compare equal"
100+
assert hash(my_stream) == hash(stream_from_handle)
101+
102+
103+
def test_event_subclass_equality(init_cuda):
104+
"""Test Event subclass equality behavior.
105+
106+
Event uses isinstance() for equality checking, similar to Stream.
107+
"""
108+
109+
class MyEvent(Event):
110+
pass
111+
112+
device = Device(0)
113+
device.set_current()
114+
115+
# Create two different events
116+
event = Event._init(device.device_id, device.context, options=EventOptions())
117+
my_event = MyEvent._init(device.device_id, device.context, options=EventOptions())
118+
119+
# Different events should not be equal (different handles)
120+
assert event != my_event, "Different Event instances are not equal"
121+
122+
# Same subclass type with different handles
123+
my_event2 = MyEvent._init(device.device_id, device.context, options=EventOptions())
124+
assert my_event != my_event2, "Different MyEvent instances are not equal"
125+
126+
127+
def test_context_subclass_equality(init_cuda):
128+
"""Test Context subclass equality behavior."""
129+
130+
class MyContext(Context):
131+
pass
132+
133+
device = Device(0)
134+
device.set_current()
135+
stream = device.create_stream()
136+
context = stream.context
137+
138+
# MyContext._from_ctx() returns a Context instance, not MyContext
139+
my_context = MyContext._from_ctx(context._handle, device.device_id)
140+
assert type(my_context) is Context, "_from_ctx returns Context, not subclass"
141+
assert type(my_context) is not MyContext
142+
143+
# Since both are Context instances with same handle, they're equal
144+
assert context == my_context, "Context instances with same handle are equal"
145+
146+
# Create another context from different stream
147+
stream2 = device.create_stream()
148+
context2 = stream2.context
149+
150+
# Same device, same primary context, should be equal
151+
assert context == context2, "Contexts from same device are equal"
152+
153+
154+
def test_subclass_type_safety(init_cuda):
155+
"""Test that equality checks with incompatible types return False or NotImplemented."""
156+
device = Device(0)
157+
device.set_current()
158+
159+
stream = device.create_stream()
160+
event = stream.record()
161+
context = stream.context
162+
163+
# None of these should be equal to each other
164+
assert device != stream
165+
assert device != event
166+
assert device != context
167+
assert stream != event
168+
assert stream != context
169+
assert event != context
170+
171+
# None should be equal to arbitrary types
172+
assert device != "device"
173+
assert stream != 123
174+
assert event != []
175+
assert context != {"key": "value"}

cuda_core/tests/test_context.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,88 @@
33

44
import cuda.core.experimental
55
import pytest
6+
from cuda.core.experimental import Device
67

78

89
def test_context_init_disabled():
910
with pytest.raises(RuntimeError, match=r"^Context objects cannot be instantiated directly\."):
1011
cuda.core.experimental._context.Context() # Ensure back door is locked.
12+
13+
14+
# ============================================================================
15+
# Context Equality Tests
16+
# ============================================================================
17+
18+
19+
def test_context_equality_same_context(init_cuda):
20+
"""Contexts from same device should be equal."""
21+
device = Device()
22+
23+
s1 = device.create_stream()
24+
s2 = device.create_stream()
25+
26+
ctx1 = s1.context
27+
ctx2 = s2.context
28+
29+
# Same device, should have same context
30+
assert ctx1 == ctx2, "Streams on same device should share context"
31+
32+
33+
def test_context_equality_reflexive(init_cuda):
34+
"""Context should equal itself (reflexive property)."""
35+
device = Device()
36+
stream = device.create_stream()
37+
context = stream.context
38+
39+
assert context == context, "Context should equal itself"
40+
41+
42+
def test_context_type_safety(init_cuda):
43+
"""Comparing Context with wrong type should return False."""
44+
device = Device()
45+
context = device.create_stream().context
46+
47+
assert (context == "not a context") is False
48+
assert (context == 123) is False
49+
assert (context is None) is False
50+
51+
52+
# ============================================================================
53+
# Context Hash Tests
54+
# ============================================================================
55+
56+
57+
def test_context_hash_consistency(init_cuda):
58+
"""Hash of same Context object should be consistent."""
59+
device = Device()
60+
stream = device.create_stream()
61+
context = stream.context
62+
63+
hash1 = hash(context)
64+
hash2 = hash(context)
65+
assert hash1 == hash2, "Hash should be consistent for same object"
66+
67+
68+
def test_context_hash_equality(init_cuda):
69+
"""Contexts from same device should hash equal."""
70+
device = Device()
71+
72+
s1 = device.create_stream()
73+
s2 = device.create_stream()
74+
75+
ctx1 = s1.context
76+
ctx2 = s2.context
77+
78+
# Same device, should have same context
79+
assert ctx1 == ctx2, "Streams on same device should share context"
80+
assert hash(ctx1) == hash(ctx2), "Same context should hash equal"
81+
82+
83+
def test_context_dict_key(init_cuda):
84+
"""Contexts should be usable as dictionary keys."""
85+
device = Device()
86+
stream = device.create_stream()
87+
context = stream.context
88+
89+
ctx_cache = {context: "context_data"}
90+
assert ctx_cache[context] == "context_data"

0 commit comments

Comments
 (0)