@@ -827,18 +827,51 @@ void wp_free_device_async(void* context, void* ptr)
827827 // check if the capture is still active
828828 auto capture_iter = g_captures.find (capture_id);
829829 if (capture_iter != g_captures.end ()) {
830- // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
831- // work completes before deallocating. This works with both Warp-initiated and external captures
832- // and avoids the need to explicitly track all streams used during the capture.
830+ // Add a mem free node. Prefer per-caller-stream capture dependencies so frees
831+ // are ordered with respect to work recorded on the stream where the free
832+ // occurs (handles forked substreams correctly). Fall back to using the
833+ // graph's leaf nodes if per-stream info isn't available.
833834 CaptureInfo* capture = capture_iter->second ;
834- cudaGraph_t graph = get_capture_graph (capture->stream );
835- std::vector<cudaGraphNode_t> leaf_nodes;
836- if (graph && get_graph_leaf_nodes (graph, leaf_nodes)) {
835+ cudaGraph_t begin_graph = get_capture_graph (capture->stream );
836+
837+ // get the caller stream (the stream on which wp_free_device_async was invoked)
838+ CUstream caller_cuda_stream = get_current_stream ();
839+
840+ cudaGraph_t caller_graph = nullptr ;
841+ const cudaGraphNode_t* capture_deps = nullptr ;
842+ size_t dep_count = 0 ;
843+ CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
844+
845+ bool added = false ;
846+
847+ // Query per-stream capture info into separate variables so we don't
848+ // clobber the begin_graph value used for the fallback.
849+ if (begin_graph
850+ && check_cu (cuStreamGetCaptureInfo_f (
851+ caller_cuda_stream, &capture_status, nullptr , &caller_graph, &capture_deps, &dep_count
852+ ))
853+ && caller_graph && capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE) {
837854 cudaGraphNode_t free_node;
838- if (check_cuda (cudaGraphAddMemFreeNode (&free_node, graph, leaf_nodes. data (), leaf_nodes. size () , ptr))) {
855+ if (check_cuda (cudaGraphAddMemFreeNode (&free_node, caller_graph, capture_deps, dep_count , ptr))) {
839856 check_cu (cuStreamUpdateCaptureDependencies_f (
840- capture-> stream , &free_node, 1 , cudaStreamSetCaptureDependencies
857+ caller_cuda_stream , &free_node, 1 , cudaStreamSetCaptureDependencies
841858 ));
859+ added = true ;
860+ }
861+ }
862+
863+ // Fallback: use the graph leaf nodes from the original capture if needed
864+ if (!added && begin_graph) {
865+ std::vector<cudaGraphNode_t> leaf_nodes;
866+ if (get_graph_leaf_nodes (begin_graph, leaf_nodes)) {
867+ cudaGraphNode_t free_node;
868+ if (check_cuda (
869+ cudaGraphAddMemFreeNode (&free_node, begin_graph, leaf_nodes.data (), leaf_nodes.size (), ptr)
870+ )) {
871+ check_cu (cuStreamUpdateCaptureDependencies_f (
872+ capture->stream , &free_node, 1 , cudaStreamSetCaptureDependencies
873+ ));
874+ }
842875 }
843876 }
844877
0 commit comments