Skip to content

Commit fec95b8

Browse files
authored
Make Device.set_current() faster (#781)
* cache primary context * avoid increasing stack size * unconditionally set primary context to current
1 parent 382f49b commit fec95b8

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,18 @@ def _check_context_initialized(self):
10221022
f"Device {self._id} is not yet initialized, perhaps you forgot to call .set_current() first?"
10231023
)
10241024

1025+
def _get_primary_context(self) -> driver.CUcontext:
1026+
try:
1027+
primary_ctxs = _tls.primary_ctxs
1028+
except AttributeError:
1029+
total = len(_tls.devices)
1030+
primary_ctxs = _tls.primary_ctxs = [None] * total
1031+
ctx = primary_ctxs[self._id]
1032+
if ctx is None:
1033+
ctx = handle_return(driver.cuDevicePrimaryCtxRetain(self._id))
1034+
primary_ctxs[self._id] = ctx
1035+
return ctx
1036+
10251037
def _get_current_context(self, check_consistency=False) -> driver.CUcontext:
10261038
err, ctx = driver.cuCtxGetCurrent()
10271039

@@ -1186,20 +1198,9 @@ def set_current(self, ctx: Context = None) -> Union[Context, None]:
11861198
if int(prev_ctx) != 0:
11871199
return Context._from_ctx(prev_ctx, self._id)
11881200
else:
1189-
ctx = handle_return(driver.cuCtxGetCurrent())
1190-
if int(ctx) == 0:
1191-
# use primary ctx
1192-
ctx = handle_return(driver.cuDevicePrimaryCtxRetain(self._id))
1193-
handle_return(driver.cuCtxPushCurrent(ctx))
1194-
else:
1195-
ctx_id = handle_return(driver.cuCtxGetDevice())
1196-
if ctx_id != self._id:
1197-
# use primary ctx
1198-
ctx = handle_return(driver.cuDevicePrimaryCtxRetain(self._id))
1199-
handle_return(driver.cuCtxPushCurrent(ctx))
1200-
else:
1201-
# no-op, a valid context already exists and is set current
1202-
pass
1201+
# use primary ctx
1202+
ctx = self._get_primary_context()
1203+
handle_return(driver.cuCtxSetCurrent(ctx))
12031204
self._has_inited = True
12041205

12051206
def create_context(self, options: ContextOptions = None) -> Context:

0 commit comments

Comments
 (0)