Skip to content

Commit c06dc9d

Browse files
committed
reduce further the number of Python objects held by Stream
1 parent 1083d90 commit c06dc9d

File tree

4 files changed

+49
-40
lines changed

4 files changed

+49
-40
lines changed

cuda_core/cuda/core/experimental/_device.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ from cuda.core.experimental._utils.cuda_utils import (
2727
)
2828

2929

30+
# TODO: Ideally these would be typed as "cdef object" so they cannot be accessed from Python,
31+
# but exposing them is very convenient for testing purposes.
3032
_tls = threading.local()
3133
_lock = threading.Lock()
3234
cdef bint _is_cuInit = False

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ from cuda.bindings cimport cydriver
1111
from cuda.core.experimental._event cimport Event as cyEvent
1212
from cuda.core.experimental._utils.cuda_utils cimport (
1313
check_or_create_options,
14+
CU_CONTEXT_INVALID,
15+
get_device_from_ctx,
1416
HANDLE_RETURN,
1517
)
1618

@@ -29,8 +31,6 @@ from cuda.core.experimental._graph import GraphBuilder
2931
from cuda.core.experimental._utils.clear_error_support import assert_type
3032
from cuda.core.experimental._utils.cuda_utils import (
3133
driver,
32-
get_device_from_ctx,
33-
handle_return,
3434
)
3535

3636

@@ -117,11 +117,13 @@ cdef class Stream:
117117
object _builtin
118118
object _nonblocking
119119
object _priority
120-
object _device_id
121-
object _ctx_handle
120+
cydriver.CUdevice _device_id
121+
cydriver.CUcontext _ctx_handle
122122

123123
def __cinit__(self):
124124
self._handle = <cydriver.CUstream>(NULL)
125+
self._device_id = cydriver.CU_DEVICE_INVALID # delayed
126+
self._ctx_handle = CU_CONTEXT_INVALID # delayed
125127

126128
def __init__(self, *args, **kwargs):
127129
raise RuntimeError(
@@ -137,8 +139,6 @@ cdef class Stream:
137139
self._builtin = True
138140
self._nonblocking = None # delayed
139141
self._priority = None # delayed
140-
self._device_id = None # delayed
141-
self._ctx_handle = None # delayed
142142
return self
143143

144144
@classmethod
@@ -149,8 +149,6 @@ cdef class Stream:
149149
self._builtin = True
150150
self._nonblocking = None # delayed
151151
self._priority = None # delayed
152-
self._device_id = None # delayed
153-
self._ctx_handle = None # delayed
154152
return self
155153

156154
@classmethod
@@ -167,8 +165,6 @@ cdef class Stream:
167165
self._owner = obj
168166
self._nonblocking = None # delayed
169167
self._priority = None # delayed
170-
self._device_id = None # delayed
171-
self._ctx_handle = None # delayed
172168
return self
173169

174170
cdef StreamOptions opts = check_or_create_options(StreamOptions, options, "Stream options")
@@ -195,8 +191,7 @@ cdef class Stream:
195191
self._owner = None
196192
self._nonblocking = nonblocking
197193
self._priority = priority
198-
self._device_id = device_id
199-
self._ctx_handle = None # delayed
194+
self._device_id = device_id if device_id is not None else self._device_id
200195
return self
201196

202197
def __del__(self):
@@ -284,7 +279,7 @@ cdef class Stream:
284279
# and CU_EVENT_RECORD_EXTERNAL, can be set in EventOptions.
285280
if event is None:
286281
self._get_device_and_context()
287-
event = Event._init(self._device_id, self._ctx_handle, options)
282+
event = Event._init(<int>(self._device_id), <uintptr_t>(self._ctx_handle), options)
288283
cdef cydriver.CUevent e = (<cyEvent?>(event))._handle
289284
with nogil:
290285
HANDLE_RETURN(cydriver.cuEventRecord(e, self._handle))
@@ -340,30 +335,31 @@ cdef class Stream:
340335
"""
341336
from cuda.core.experimental._device import Device # avoid circular import
342337
self._get_device_and_context()
343-
return Device(self._device_id)
338+
return Device(<int>(self._device_id))
344339

345-
cdef int _get_context(Stream self) except?-1:
346-
# TODO: consider making self._ctx_handle typed?
347-
cdef cydriver.CUcontext ctx
348-
if self._ctx_handle is None:
349-
with nogil:
350-
HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle, &ctx))
351-
self._ctx_handle = driver.CUcontext(<uintptr_t>ctx)
340+
cdef int _get_context(self) except?-1 nogil:
341+
if self._ctx_handle == CU_CONTEXT_INVALID:
342+
HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle, &(self._ctx_handle)))
352343
return 0
353344

354-
cdef int _get_device_and_context(Stream self) except?-1:
355-
if self._device_id is None:
356-
# Get the stream context first
357-
self._get_context()
358-
self._device_id = get_device_from_ctx(self._ctx_handle)
345+
cdef int _get_device_and_context(self) except?-1:
346+
cdef cydriver.CUcontext curr_ctx
347+
if self._device_id == cydriver.CU_DEVICE_INVALID:
348+
# TODO: Would it be faster/safer to call cuCtxGetCurrent directly instead?
349+
from cuda.core.experimental._device import Device # avoid circular import
350+
curr_ctx = <cydriver.CUcontext><uintptr_t>(Device().context._handle)
351+
with nogil:
352+
# Get the stream context first
353+
self._get_context()
354+
self._device_id = get_device_from_ctx(self._ctx_handle, curr_ctx)
359355
return 0
360356

361357
@property
362358
def context(self) -> Context:
363359
"""Return the :obj:`~_context.Context` associated with this stream."""
364360
self._get_context()
365361
self._get_device_and_context()
366-
return Context._from_ctx(self._ctx_handle, self._device_id)
362+
return Context._from_ctx(<uintptr_t>(self._ctx_handle), <int>(self._device_id))
367363

368364
@staticmethod
369365
def from_handle(handle: int) -> Stream:

cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ ctypedef fused supported_error_type:
1212
cydriver.CUresult
1313

1414

15+
# Sentinel marking a context handle that has not been resolved yet; mimics the driver's CU_DEVICE_INVALID.
16+
cdef cydriver.CUcontext CU_CONTEXT_INVALID = <cydriver.CUcontext>(-2)
17+
18+
19+
cdef cydriver.CUdevice get_device_from_ctx(
20+
cydriver.CUcontext target_ctx, cydriver.CUcontext curr_ctx) except?cydriver.CU_DEVICE_INVALID nogil
21+
22+
1523
cdef int HANDLE_RETURN(supported_error_type err) except?-1 nogil
1624

1725

cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -192,20 +192,23 @@ def precondition(checker: Callable[..., None], str what="") -> Callable:
192192
return outer
193193

194194

195-
def get_device_from_ctx(ctx_handle) -> int:
195+
cdef cydriver.CUdevice get_device_from_ctx(
196+
cydriver.CUcontext target_ctx, cydriver.CUcontext curr_ctx) except?cydriver.CU_DEVICE_INVALID nogil:
196197
"""Get device ID from the given ctx."""
197-
from cuda.core.experimental._device import Device # avoid circular import
198-
199-
prev_ctx = Device().context._handle
200-
switch_context = int(ctx_handle) != int(prev_ctx)
201-
if switch_context:
202-
assert prev_ctx == handle_return(driver.cuCtxPopCurrent())
203-
handle_return(driver.cuCtxPushCurrent(ctx_handle))
204-
device_id = int(handle_return(driver.cuCtxGetDevice()))
205-
if switch_context:
206-
assert ctx_handle == handle_return(driver.cuCtxPopCurrent())
207-
handle_return(driver.cuCtxPushCurrent(prev_ctx))
208-
return device_id
198+
cdef bint switch_context = (curr_ctx != target_ctx)
199+
cdef cydriver.CUcontext ctx
200+
cdef cydriver.CUdevice target_dev
201+
with nogil:
202+
if switch_context:
203+
HANDLE_RETURN(cydriver.cuCtxPopCurrent(&ctx))
204+
assert curr_ctx == ctx
205+
HANDLE_RETURN(cydriver.cuCtxPushCurrent(target_ctx))
206+
HANDLE_RETURN(cydriver.cuCtxGetDevice(&target_dev))
207+
if switch_context:
208+
HANDLE_RETURN(cydriver.cuCtxPopCurrent(&ctx))
209+
assert target_ctx == ctx
210+
HANDLE_RETURN(cydriver.cuCtxPushCurrent(curr_ctx))
211+
return target_dev
209212

210213

211214
def is_sequence(obj):

0 commit comments

Comments
 (0)