Skip to content

Commit 264524b

Browse files
authored
Feature/public default stream methods (#1483)
* feat: Make legacy_default and per_thread_default public - Fixes #1445 Signed-off-by: Monishver Chandrasekaran <monishver@Monishvers-MacBook-Air.local> * feat: Make legacy_default and per_thread_default public - Fixes #1445 Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com> * added init_cuda for test_stream methods --------- Signed-off-by: Monishver Chandrasekaran <monishver@Monishvers-MacBook-Air.local> Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
1 parent 7ef7873 commit 264524b

File tree

2 files changed

+60
-8
lines changed

2 files changed

+60
-8
lines changed

cuda_core/cuda/core/_stream.pyx

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,49 @@ cdef class Stream:
107107
return s
108108

109109
@classmethod
110-
def _legacy_default(cls):
111-
"""Return the legacy default stream (supports subclassing)."""
110+
def legacy_default(cls):
111+
"""Return the legacy default stream.
112+
113+
The legacy default stream is an implicit stream which synchronizes
114+
with all other streams in the same CUDA context except for non-blocking
115+
streams. When any operation is launched on the legacy default stream,
116+
it waits for all previously launched operations in blocking streams to
117+
complete, and all subsequent operations in blocking streams wait for
118+
the legacy default stream operation to complete.
119+
120+
Returns
121+
-------
122+
Stream
123+
The legacy default stream instance for the current context.
124+
125+
See Also
126+
--------
127+
per_thread_default : Per-thread default stream alternative.
128+
129+
"""
112130
return Stream._from_handle(cls, get_legacy_stream())
113131

114132
@classmethod
115-
def _per_thread_default(cls):
116-
"""Return the per-thread default stream (supports subclassing)."""
133+
def per_thread_default(cls):
134+
"""Return the per-thread default stream.
135+
136+
The per-thread default stream is local to both the calling thread and
137+
the CUDA context. Unlike the legacy default stream, it does not
138+
synchronize with other streams and behaves like an explicitly created
139+
non-blocking stream. This allows for better concurrency in multi-threaded
140+
applications.
141+
142+
Returns
143+
-------
144+
Stream
145+
The per-thread default stream instance for the current thread
146+
and context.
147+
148+
See Also
149+
--------
150+
legacy_default : Legacy default stream alternative.
151+
152+
"""
117153
return Stream._from_handle(cls, get_per_thread_stream())
118154

119155
@classmethod
@@ -378,8 +414,8 @@ cdef class Stream:
378414

379415

380416
# c-only python objects, not public
381-
cdef Stream C_LEGACY_DEFAULT_STREAM = Stream._legacy_default()
382-
cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
417+
cdef Stream C_LEGACY_DEFAULT_STREAM = Stream.legacy_default()
418+
cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream.per_thread_default()
383419

384420
# standard python objects, public
385421
LEGACY_DEFAULT_STREAM = C_LEGACY_DEFAULT_STREAM

cuda_core/tests/test_stream.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,34 @@ def test_stream_legacy_default_subclassing():
117117
class MyStream(Stream):
118118
pass
119119

120-
stream = MyStream._legacy_default()
120+
stream = MyStream.legacy_default()
121121
assert isinstance(stream, MyStream)
122122

123123

124124
def test_stream_per_thread_default_subclassing():
125125
class MyStream(Stream):
126126
pass
127127

128-
stream = MyStream._per_thread_default()
128+
stream = MyStream.per_thread_default()
129129
assert isinstance(stream, MyStream)
130130

131131

132+
def test_stream_legacy_default_public_api(init_cuda):
133+
"""Test public legacy_default() method."""
134+
stream = Stream.legacy_default()
135+
assert isinstance(stream, Stream)
136+
# Verify it's the same as LEGACY_DEFAULT_STREAM
137+
assert stream == LEGACY_DEFAULT_STREAM
138+
139+
140+
def test_stream_per_thread_default_public_api(init_cuda):
141+
"""Test public per_thread_default() method."""
142+
stream = Stream.per_thread_default()
143+
assert isinstance(stream, Stream)
144+
# Verify it's the same as PER_THREAD_DEFAULT_STREAM
145+
assert stream == PER_THREAD_DEFAULT_STREAM
146+
147+
132148
# ============================================================================
133149
# Stream Equality Tests
134150
# ============================================================================

0 commit comments

Comments
 (0)