@@ -1022,6 +1022,18 @@ def _check_context_initialized(self):
10221022 f"Device { self ._id } is not yet initialized, perhaps you forgot to call .set_current() first?"
10231023 )
10241024
1025+ def _get_primary_context (self ) -> driver .CUcontext :
1026+ try :
1027+ primary_ctxs = _tls .primary_ctxs
1028+ except AttributeError :
1029+ total = len (_tls .devices )
1030+ primary_ctxs = _tls .primary_ctxs = [None ] * total
1031+ ctx = primary_ctxs [self ._id ]
1032+ if ctx is None :
1033+ ctx = handle_return (driver .cuDevicePrimaryCtxRetain (self ._id ))
1034+ primary_ctxs [self ._id ] = ctx
1035+ return ctx
1036+
10251037 def _get_current_context (self , check_consistency = False ) -> driver .CUcontext :
10261038 err , ctx = driver .cuCtxGetCurrent ()
10271039
@@ -1186,20 +1198,9 @@ def set_current(self, ctx: Context = None) -> Union[Context, None]:
11861198 if int (prev_ctx ) != 0 :
11871199 return Context ._from_ctx (prev_ctx , self ._id )
11881200 else :
1189- ctx = handle_return (driver .cuCtxGetCurrent ())
1190- if int (ctx ) == 0 :
1191- # use primary ctx
1192- ctx = handle_return (driver .cuDevicePrimaryCtxRetain (self ._id ))
1193- handle_return (driver .cuCtxPushCurrent (ctx ))
1194- else :
1195- ctx_id = handle_return (driver .cuCtxGetDevice ())
1196- if ctx_id != self ._id :
1197- # use primary ctx
1198- ctx = handle_return (driver .cuDevicePrimaryCtxRetain (self ._id ))
1199- handle_return (driver .cuCtxPushCurrent (ctx ))
1200- else :
1201- # no-op, a valid context already exists and is set current
1202- pass
1201+ # use primary ctx
1202+ ctx = self ._get_primary_context ()
1203+ handle_return (driver .cuCtxSetCurrent (ctx ))
12031204 self ._has_inited = True
12041205
12051206 def create_context (self , options : ContextOptions = None ) -> Context :
0 commit comments