@@ -11,6 +11,8 @@ from cuda.bindings cimport cydriver
1111from cuda.core.experimental._event cimport Event as cyEvent
1212from cuda.core.experimental._utils.cuda_utils cimport (
1313 check_or_create_options,
14+ CU_CONTEXT_INVALID,
15+ get_device_from_ctx,
1416 HANDLE_RETURN,
1517)
1618
@@ -29,8 +31,6 @@ from cuda.core.experimental._graph import GraphBuilder
2931from cuda.core.experimental._utils.clear_error_support import assert_type
3032from cuda.core.experimental._utils.cuda_utils import (
3133 driver,
32- get_device_from_ctx,
33- handle_return,
3434)
3535
3636
@@ -117,11 +117,13 @@ cdef class Stream:
117117 object _builtin
118118 object _nonblocking
119119 object _priority
120- object _device_id
121- object _ctx_handle
120+ cydriver.CUdevice _device_id
121+ cydriver.CUcontext _ctx_handle
122122
123123 def __cinit__ (self ):
124124 self ._handle = < cydriver.CUstream> (NULL )
125+ self ._device_id = cydriver.CU_DEVICE_INVALID # delayed
126+ self ._ctx_handle = CU_CONTEXT_INVALID # delayed
125127
126128 def __init__ (self , *args , **kwargs ):
127129 raise RuntimeError (
@@ -137,8 +139,6 @@ cdef class Stream:
137139 self ._builtin = True
138140 self ._nonblocking = None # delayed
139141 self ._priority = None # delayed
140- self ._device_id = None # delayed
141- self ._ctx_handle = None # delayed
142142 return self
143143
144144 @classmethod
@@ -149,8 +149,6 @@ cdef class Stream:
149149 self ._builtin = True
150150 self ._nonblocking = None # delayed
151151 self ._priority = None # delayed
152- self ._device_id = None # delayed
153- self ._ctx_handle = None # delayed
154152 return self
155153
156154 @classmethod
@@ -167,8 +165,6 @@ cdef class Stream:
167165 self ._owner = obj
168166 self ._nonblocking = None # delayed
169167 self ._priority = None # delayed
170- self ._device_id = None # delayed
171- self ._ctx_handle = None # delayed
172168 return self
173169
174170 cdef StreamOptions opts = check_or_create_options(StreamOptions, options, " Stream options" )
@@ -195,8 +191,7 @@ cdef class Stream:
195191 self ._owner = None
196192 self ._nonblocking = nonblocking
197193 self ._priority = priority
198- self ._device_id = device_id
199- self ._ctx_handle = None # delayed
194+ self ._device_id = device_id if device_id is not None else self ._device_id
200195 return self
201196
202197 def __del__ (self ):
@@ -284,7 +279,7 @@ cdef class Stream:
284279 # and CU_EVENT_RECORD_EXTERNAL , can be set in EventOptions.
285280 if event is None:
286281 self._get_device_and_context()
287- event = Event._init(self ._device_id, self ._ctx_handle, options)
282+ event = Event._init(< int > ( self ._device_id), < uintptr_t > ( self ._ctx_handle) , options)
288283 cdef cydriver.CUevent e = (< cyEvent?> (event))._handle
289284 with nogil:
290285 HANDLE_RETURN(cydriver.cuEventRecord(e , self._handle ))
@@ -340,30 +335,31 @@ cdef class Stream:
340335 """
341336 from cuda.core.experimental._device import Device # avoid circular import
342337 self._get_device_and_context()
343- return Device(self._device_id )
338+ return Device(<int>( self._device_id ) )
344339
345- cdef int _get_context(Stream self ) except?-1:
346- # TODO: consider making self._ctx_handle typed?
347- cdef cydriver.CUcontext ctx
348- if self._ctx_handle is None:
349- with nogil:
350- HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle , &ctx ))
351- self._ctx_handle = driver.CUcontext(< uintptr_t> ctx)
340+ cdef int _get_context(self ) except?-1 nogil:
341+ if self._ctx_handle == CU_CONTEXT_INVALID:
342+ HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle , &(self._ctx_handle )))
352343 return 0
353344
354- cdef int _get_device_and_context(Stream self ) except?-1:
355- if self._device_id is None:
356- # Get the stream context first
357- self._get_context()
358- self._device_id = get_device_from_ctx(self ._ctx_handle)
345+ cdef int _get_device_and_context(self ) except?-1:
346+ cdef cydriver.CUcontext curr_ctx
347+ if self._device_id == cydriver.CU_DEVICE_INVALID:
348+ # TODO: It is likely faster/safer to call cuCtxGetCurrent?
349+ from cuda.core.experimental._device import Device # avoid circular import
350+ curr_ctx = < cydriver.CUcontext>< uintptr_t> (Device().context._handle)
351+ with nogil:
352+ # Get the stream context first
353+ self._get_context()
354+ self._device_id = get_device_from_ctx(self ._ctx_handle, curr_ctx)
359355 return 0
360356
361357 @property
362358 def context(self ) -> Context:
363359 """Return the :obj:`~_context.Context` associated with this stream."""
364360 self._get_context()
365361 self._get_device_and_context()
366- return Context._from_ctx(self._ctx_handle , self._device_id )
362+ return Context._from_ctx(<uintptr_t>( self._ctx_handle ), <int>( self._device_id ) )
367363
368364 @staticmethod
369365 def from_handle(handle: int ) -> Stream:
0 commit comments