Skip to content

Commit 7c35ef3

Browse files
committed
Simplify green context view handles
1 parent faf0d17 commit 7c35ef3

8 files changed

Lines changed: 22 additions & 88 deletions

File tree

cuda_core/cuda/core/_context.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,4 @@ cdef class Context:
2424
@staticmethod
2525
cdef Context _from_green_ctx(type cls, GreenCtxHandle h_green_ctx, int device_id)
2626

27-
cdef int _ensure_context_handle(self) except -1
28-
2927
cpdef close(self)

cuda_core/cuda/core/_context.pyx

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ from cuda.core._resource_handles cimport (
1010
GreenCtxHandle,
1111
as_cu,
1212
create_context_handle_from_green_ctx,
13-
ensure_context_handle,
1413
get_context_green_ctx,
1514
get_last_error,
1615
as_intptr,
@@ -48,29 +47,18 @@ cdef class Context:
4847
cdef Context ctx = cls.__new__(cls)
4948
ctx._h_green_ctx = h_green_ctx
5049
ctx._h_context = create_context_handle_from_green_ctx(h_green_ctx)
50+
if not ctx._h_context:
51+
HANDLE_RETURN(get_last_error())
52+
raise RuntimeError("Failed to create CUDA context view from green context")
5153
ctx._device_id = device_id
5254
ctx._is_green = True
5355
return ctx
5456

55-
cdef int _ensure_context_handle(self) except -1:
56-
cdef cydriver.CUcontext raw_ctx
57-
if not self._h_context:
58-
return 0
59-
if as_cu(self._h_context) != NULL:
60-
return 0
61-
with nogil:
62-
raw_ctx = ensure_context_handle(self._h_context)
63-
if raw_ctx == NULL:
64-
HANDLE_RETURN(get_last_error())
65-
raise RuntimeError("Failed to materialize CUDA context from green context")
66-
return 0
67-
6857
@property
6958
def handle(self):
7059
"""Return the underlying CUcontext handle."""
7160
if not self._h_context:
7261
return None
73-
self._ensure_context_handle()
7462
if as_cu(self._h_context) == NULL:
7563
return None
7664
return as_py(self._h_context)
@@ -102,16 +90,12 @@ cdef class Context:
10290
if not isinstance(other, Context):
10391
return NotImplemented
10492
cdef Context _other = <Context>other
105-
self._ensure_context_handle()
106-
_other._ensure_context_handle()
10793
return as_intptr(self._h_context) == as_intptr(_other._h_context)
10894

10995
def __hash__(self) -> int:
110-
self._ensure_context_handle()
11196
return hash(as_intptr(self._h_context))
11297

11398
def __repr__(self) -> str:
114-
self._ensure_context_handle()
11599
return f"<Context handle={as_intptr(self._h_context):#x} device={self._device_id}>"
116100

117101

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 16 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <mutex>
1212
#include <stdexcept>
1313
#include <unordered_map>
14-
#include <utility>
1514
#include <vector>
1615

1716
#ifndef _WIN32
@@ -227,12 +226,8 @@ void clear_last_error() noexcept {
227226

228227
namespace {
229228
struct ContextBox {
230-
mutable CUcontext resource;
229+
CUcontext resource;
231230
GreenCtxHandle h_green_ctx;
232-
mutable std::mutex mutex;
233-
234-
explicit ContextBox(CUcontext resource, GreenCtxHandle h_green_ctx = {})
235-
: resource(resource), h_green_ctx(std::move(h_green_ctx)) {}
236231
};
237232

238233
struct GreenCtxBox {
@@ -262,7 +257,7 @@ ContextHandle create_context_handle_ref(CUcontext ctx) {
262257
return h;
263258
}
264259
auto box = std::shared_ptr<const ContextBox>(
265-
new ContextBox(ctx),
260+
new ContextBox{ctx, {}},
266261
[](const ContextBox* b) {
267262
context_registry.unregister_handle(b->resource);
268263
delete b;
@@ -273,57 +268,31 @@ ContextHandle create_context_handle_ref(CUcontext ctx) {
273268
return h;
274269
}
275270

276-
static const GreenCtxBox* get_box(const GreenCtxHandle& h) noexcept {
277-
const CUgreenCtx* p = h.get();
278-
return reinterpret_cast<const GreenCtxBox*>(
279-
reinterpret_cast<const char*>(p) - offsetof(GreenCtxBox, resource)
280-
);
281-
}
282-
283271
ContextHandle create_context_handle_from_green_ctx(const GreenCtxHandle& h_green_ctx) {
284272
if (!h_green_ctx) {
285273
return {};
286274
}
287-
auto box = std::shared_ptr<const ContextBox>(
288-
new ContextBox(nullptr, h_green_ctx),
289-
[](const ContextBox* b) {
290-
if (b->resource) {
291-
context_registry.unregister_handle(b->resource);
292-
}
293-
delete b;
294-
}
295-
);
296-
return ContextHandle(box, &box->resource);
297-
}
298-
299-
CUcontext ensure_context_handle(const ContextHandle& h) noexcept {
300-
if (!h) {
301-
err = CUDA_ERROR_INVALID_CONTEXT;
302-
return nullptr;
303-
}
304-
305-
const ContextBox* box = get_box(h);
306-
std::lock_guard<std::mutex> lock(box->mutex);
307-
if (box->resource) {
308-
return box->resource;
309-
}
310-
if (!box->h_green_ctx) {
311-
err = CUDA_ERROR_INVALID_CONTEXT;
312-
return nullptr;
313-
}
314275
if (!p_cuCtxFromGreenCtx) {
315276
err = CUDA_ERROR_NOT_SUPPORTED;
316-
return nullptr;
277+
return {};
317278
}
318279

319280
GILReleaseGuard gil;
320281
CUcontext ctx = nullptr;
321-
if (CUDA_SUCCESS != (err = p_cuCtxFromGreenCtx(&ctx, as_cu(box->h_green_ctx)))) {
322-
return nullptr;
282+
if (CUDA_SUCCESS != (err = p_cuCtxFromGreenCtx(&ctx, as_cu(h_green_ctx)))) {
283+
return {};
323284
}
324-
box->resource = ctx;
285+
286+
auto box = std::shared_ptr<const ContextBox>(
287+
new ContextBox{ctx, h_green_ctx},
288+
[](const ContextBox* b) {
289+
context_registry.unregister_handle(b->resource);
290+
delete b;
291+
}
292+
);
293+
ContextHandle h(box, &box->resource);
325294
context_registry.register_handle(ctx, h);
326-
return ctx;
295+
return h;
327296
}
328297

329298
GreenCtxHandle get_context_green_ctx(const ContextHandle& h) noexcept {
@@ -399,7 +368,7 @@ ContextHandle get_primary_context(int device_id) {
399368
}
400369

401370
auto box = std::shared_ptr<const ContextBox>(
402-
new ContextBox(ctx),
371+
new ContextBox{ctx, {}},
403372
[device_id](const ContextBox* b) {
404373
context_registry.unregister_handle(b->resource);
405374
GILReleaseGuard gil;

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,11 @@ using FileDescriptorHandle = std::shared_ptr<const int>;
170170
// Function to create a non-owning context handle (references existing context).
171171
ContextHandle create_context_handle_ref(CUcontext ctx);
172172

173-
// Create a context handle whose CUcontext view is lazily materialized from
174-
// the provided green context. The returned ContextHandle keeps the green
175-
// context alive.
173+
// Create a context handle for the CUcontext view of the provided green context.
174+
// The returned ContextHandle keeps the green context alive, but the CUcontext
175+
// view is non-owning and is not destroyed independently.
176176
ContextHandle create_context_handle_from_green_ctx(const GreenCtxHandle& h_green_ctx);
177177

178-
// Ensure a ContextHandle has a materialized CUcontext value. For green-context
179-
// views this calls cuCtxFromGreenCtx once and caches the non-owning CUcontext.
180-
CUcontext ensure_context_handle(const ContextHandle& h) noexcept;
181-
182178
// Return the green context dependency associated with a ContextHandle, if any.
183179
GreenCtxHandle get_context_green_ctx(const ContextHandle& h) noexcept;
184180

cuda_core/cuda/core/_device.pyx

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,7 +1262,6 @@ class Device:
12621262
if self._has_inited and self._context is not None:
12631263
prev_owned = self._context
12641264
# prev_ctx is the previous context
1265-
ctx._ensure_context_handle()
12661265
curr_ctx = as_cu(ctx._h_context)
12671266
prev_ctx = NULL
12681267
with nogil:
@@ -1271,8 +1270,6 @@ class Device:
12711270
self._has_inited = True
12721271
self._context = ctx # Store owning context reference
12731272
if prev_ctx != NULL:
1274-
if prev_owned is not None:
1275-
prev_owned._ensure_context_handle()
12761273
if prev_owned is not None and as_cu(prev_owned._h_context) == prev_ctx:
12771274
return prev_owned
12781275
return Context._from_handle(Context, create_context_handle_ref(prev_ctx), self._device_id)
@@ -1418,7 +1415,6 @@ class Device:
14181415
"""
14191416
self._check_context_initialized()
14201417
cdef Context ctx = self._context
1421-
ctx._ensure_context_handle()
14221418
return cyEvent._init(cyEvent, self._device_id, ctx._h_context, options, True)
14231419

14241420
def allocate(self, size, stream: Stream | GraphBuilder | None = None) -> Buffer:

cuda_core/cuda/core/_resource_handles.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ cdef void clear_last_error() noexcept nogil
116116
# Context handles
117117
cdef ContextHandle create_context_handle_ref(cydriver.CUcontext ctx) except+ nogil
118118
cdef ContextHandle create_context_handle_from_green_ctx(const GreenCtxHandle& h_green_ctx) except+ nogil
119-
cdef cydriver.CUcontext ensure_context_handle(const ContextHandle& h) noexcept nogil
120119
cdef GreenCtxHandle get_context_green_ctx(const ContextHandle& h) noexcept nogil
121120
cdef GreenCtxHandle create_green_ctx_handle(
122121
cydriver.CUdevResourceDesc desc, cydriver.CUdevice dev, unsigned int flags) except+ nogil

cuda_core/cuda/core/_resource_handles.pyx

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
5959
cydriver.CUcontext ctx) except+ nogil
6060
ContextHandle create_context_handle_from_green_ctx "cuda_core::create_context_handle_from_green_ctx" (
6161
const GreenCtxHandle& h_green_ctx) except+ nogil
62-
cydriver.CUcontext ensure_context_handle "cuda_core::ensure_context_handle" (
63-
const ContextHandle& h) noexcept nogil
6462
GreenCtxHandle get_context_green_ctx "cuda_core::get_context_green_ctx" (
6563
const ContextHandle& h) noexcept nogil
6664
GreenCtxHandle create_green_ctx_handle "cuda_core::create_green_ctx_handle" (

cuda_core/cuda/core/_stream.pyx

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ from cuda.core._resource_handles cimport (
3030
create_event_handle_noctx,
3131
create_stream_handle,
3232
create_stream_handle_with_owner,
33-
ensure_context_handle,
3433
get_current_context,
35-
get_last_error,
3634
get_legacy_stream,
3735
get_per_thread_stream,
3836
get_stream_context,
@@ -411,10 +409,6 @@ cdef inline int Stream_ensure_ctx(Stream self) except?-1 nogil:
411409
if not self._h_context:
412410
self._h_context = get_stream_context(self._h_stream)
413411
if self._h_context:
414-
if as_cu(self._h_context) == NULL:
415-
ctx = ensure_context_handle(self._h_context)
416-
if ctx == NULL:
417-
HANDLE_RETURN(get_last_error())
418412
return 0
419413
HANDLE_RETURN(cydriver.cuStreamGetCtx(as_cu(self._h_stream), &ctx))
420414
if ctx != NULL:

0 commit comments

Comments
 (0)