Skip to content

Commit 9766e54

Browse files
committed
Merge branch 'graph-node-identity' into graph-node-repr
2 parents 6b36e47 + 42131b6 commit 9766e54

File tree

5 files changed

+28
-34
lines changed

5 files changed

+28
-34
lines changed

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,8 @@ class HandleRegistry {
174174
}
175175

176176
void unregister_handle(const Key& key) noexcept {
177-
try {
178-
std::lock_guard<std::mutex> lock(mutex_);
179-
auto it = map_.find(key);
180-
if (it != map_.end() && it->second.expired()) {
181-
map_.erase(it);
182-
}
183-
} catch (...) {}
177+
std::lock_guard<std::mutex> lock(mutex_);
178+
map_.erase(key);
184179
}
185180

186181
Handle lookup(const Key& key) {
@@ -989,7 +984,7 @@ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
989984
return h ? get_box(h)->h_graph : GraphHandle{};
990985
}
991986

992-
void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept {
987+
void invalidate_graph_node(const GraphNodeHandle& h) noexcept {
993988
if (h) {
994989
CUgraphNode node = get_box(h)->resource;
995990
if (node) {

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_
416416
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;
417417

418418
// Zero the CUgraphNode resource inside the handle, marking it invalid.
419-
void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept;
419+
void invalidate_graph_node(const GraphNodeHandle& h) noexcept;
420420

421421
// ============================================================================
422422
// Graphics resource handle functions

cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ from cuda.core._resource_handles cimport (
4848
create_graph_handle_ref,
4949
create_graph_node_handle,
5050
graph_node_get_graph,
51-
invalidate_graph_node_handle,
51+
invalidate_graph_node,
5252
)
5353
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value
5454

@@ -63,11 +63,11 @@ from cuda.core import Device
6363
from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
6464
from cuda.core._utils.cuda_utils import driver, handle_return
6565

66-
_node_cache = weakref.WeakValueDictionary()
66+
_node_registry = weakref.WeakValueDictionary()
6767

6868

69-
cdef inline GraphNode _cached(GraphNode n):
70-
_node_cache[<uintptr_t>n._h_node.get()] = n
69+
cdef inline GraphNode _registered(GraphNode n):
70+
_node_registry[<uintptr_t>n._h_node.get()] = n
7171
return n
7272

7373

@@ -153,8 +153,8 @@ cdef class GraphNode:
153153
return
154154
with nogil:
155155
HANDLE_RETURN(cydriver.cuGraphDestroyNode(node))
156-
_node_cache.pop(<uintptr_t>self._h_node.get(), None)
157-
invalidate_graph_node_handle(self._h_node)
156+
_node_registry.pop(<uintptr_t>self._h_node.get(), None)
157+
invalidate_graph_node(self._h_node)
158158

159159
@property
160160
def pred(self):
@@ -532,8 +532,7 @@ cdef inline ConditionalNode _make_conditional_node(
532532
n._cond_type = cond_type
533533
n._branches = branches
534534

535-
return _cached(n)
536-
535+
return _registered(n)
537536

538537
cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node):
539538
cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
@@ -544,12 +543,12 @@ cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node):
544543
(<GraphNode>n)._h_node = h_node
545544
return n
546545

547-
# Return a cached object or create and cache a new one.
548-
cached = _node_cache.get(<uintptr_t>h_node.get())
549-
if cached is not None:
550-
return <GraphNode>cached
546+
# Return a registered object or create and register a new one.
547+
registered = _node_registry.get(<uintptr_t>h_node.get())
548+
if registered is not None:
549+
return <GraphNode>registered
551550
else:
552-
return _cached(GN_create_impl(h_node))
551+
return _registered(GN_create_impl(h_node))
553552

554553

555554
cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node):
@@ -617,7 +616,7 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
617616
_attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel),
618617
<cydriver.CUhostFn>_destroy_kernel_handle_copy)
619618

620-
return _cached(KernelNode._create_with_params(
619+
return _registered(KernelNode._create_with_params(
621620
create_graph_node_handle(new_node, h_graph),
622621
conf.grid, conf.block, conf.shmem_size,
623622
ker._h_kernel))
@@ -646,7 +645,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
646645
HANDLE_RETURN(cydriver.cuGraphAddEmptyNode(
647646
&new_node, as_cu(h_graph), deps_ptr, num_deps))
648647

649-
return _cached(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)))
648+
return _registered(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)))
650649

651650

652651
cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
@@ -722,7 +721,7 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
722721
HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode(
723722
&new_node, as_cu(h_graph), deps, num_deps, &alloc_params))
724723

725-
return _cached(AllocNode._create_with_params(
724+
return _registered(AllocNode._create_with_params(
726725
create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size,
727726
device_id, memory_type, tuple(peer_ids)))
728727

@@ -742,7 +741,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
742741
HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode(
743742
&new_node, as_cu(h_graph), deps, num_deps, c_dptr))
744743

745-
return _cached(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr))
744+
return _registered(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr))
746745

747746

748747
cdef inline MemsetNode GN_memset(
@@ -777,7 +776,7 @@ cdef inline MemsetNode GN_memset(
777776
&new_node, as_cu(h_graph), deps, num_deps,
778777
&memset_params, ctx))
779778

780-
return _cached(MemsetNode._create_with_params(
779+
return _registered(MemsetNode._create_with_params(
781780
create_graph_node_handle(new_node, h_graph), c_dst,
782781
val, elem_size, width, height, pitch))
783782

@@ -838,7 +837,7 @@ cdef inline MemcpyNode GN_memcpy(
838837
HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode(
839838
&new_node, as_cu(h_graph), deps, num_deps, &params, ctx))
840839

841-
return _cached(MemcpyNode._create_with_params(
840+
return _registered(MemcpyNode._create_with_params(
842841
create_graph_node_handle(new_node, h_graph), c_dst, c_src, size,
843842
c_dst_type, c_src_type))
844843

@@ -865,7 +864,7 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):
865864

866865
cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph)
867866

868-
return _cached(ChildGraphNode._create_with_params(
867+
return _registered(ChildGraphNode._create_with_params(
869868
create_graph_node_handle(new_node, h_graph), h_embedded))
870869

871870

@@ -887,7 +886,7 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
887886
_attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
888887
<cydriver.CUhostFn>_destroy_event_handle_copy)
889888

890-
return _cached(EventRecordNode._create_with_params(
889+
return _registered(EventRecordNode._create_with_params(
891890
create_graph_node_handle(new_node, h_graph), ev._h_event))
892891

893892

@@ -909,7 +908,7 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
909908
_attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
910909
<cydriver.CUhostFn>_destroy_event_handle_copy)
911910

912-
return _cached(EventWaitNode._create_with_params(
911+
return _registered(EventWaitNode._create_with_params(
913912
create_graph_node_handle(new_node, h_graph), ev._h_event))
914913

915914

@@ -936,6 +935,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_
936935
&new_node, as_cu(h_graph), deps, num_deps, &node_params))
937936

938937
cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None
939-
return _cached(HostCallbackNode._create_with_params(
938+
return _registered(HostCallbackNode._create_with_params(
940939
create_graph_node_handle(new_node, h_graph), callable_obj,
941940
node_params.fn, node_params.userData))

cuda_core/cuda/core/_resource_handles.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand
186186
# Graph node handles
187187
cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
188188
cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil
189-
cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil
189+
cdef void invalidate_graph_node(const GraphNodeHandle& h) noexcept nogil
190190

191191
# Graphics resource handles
192192
cdef GraphicsResourceHandle create_graphics_resource_handle(

cuda_core/cuda/core/_resource_handles.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
159159
cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
160160
GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" (
161161
const GraphNodeHandle& h) noexcept nogil
162-
void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" (
162+
void invalidate_graph_node "cuda_core::invalidate_graph_node" (
163163
const GraphNodeHandle& h) noexcept nogil
164164

165165
# Graphics resource handles

0 commit comments

Comments
 (0)