@@ -827,18 +827,49 @@ void wp_free_device_async(void* context, void* ptr)
827827 // check if the capture is still active
828828 auto capture_iter = g_captures.find (capture_id);
829829 if (capture_iter != g_captures.end ()) {
830- // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
831- // work completes before deallocating. This works with both Warp-initiated and external captures
832- // and avoids the need to explicitly track all streams used during the capture.
830+ // Add a mem free node. Use the caller stream's capture dependencies so frees
831+ // are ordered with respect to work recorded on the stream where the free
832+ // occurs (handles forked substreams correctly). Fall back to using the
833+ // global graph leaf nodes if capture info isn't available for the caller.
833834 CaptureInfo* capture = capture_iter->second ;
834835 cudaGraph_t graph = get_capture_graph (capture->stream );
835- std::vector<cudaGraphNode_t> leaf_nodes;
836- if (graph && get_graph_leaf_nodes (graph, leaf_nodes)) {
837- cudaGraphNode_t free_node;
838- if (check_cuda (cudaGraphAddMemFreeNode (&free_node, graph, leaf_nodes.data (), leaf_nodes.size (), ptr))) {
839- check_cu (cuStreamUpdateCaptureDependencies_f (
840- capture->stream , &free_node, 1 , cudaStreamSetCaptureDependencies
841- ));
836+
837+ // get the caller stream (the stream on which wp_free_device_async was invoked)
838+ CUstream caller_cuda_stream = get_current_stream ();
839+
840+ const cudaGraphNode_t* capture_deps = nullptr ;
841+ size_t dep_count = 0 ;
842+ CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
843+
844+ bool added = false ;
845+
846+ // Try to get per-stream capture dependencies for the caller stream and use
847+ // them as predecessors for the memfree node. This ensures the memfree will
848+ // be ordered after work recorded on the caller stream (including forked
849+ // substreams brought into the capture via wait_stream/wait_event).
850+ if (graph && check_cu (cuStreamGetCaptureInfo_f (caller_cuda_stream, &capture_status, nullptr , &graph, &capture_deps, &dep_count))) {
851+ if (graph && (capture_deps != nullptr || dep_count > 0 )) {
852+ cudaGraphNode_t free_node;
853+ if (check_cuda (cudaGraphAddMemFreeNode (&free_node, graph, capture_deps, dep_count, ptr))) {
854+ check_cu (cuStreamUpdateCaptureDependencies_f (
855+ caller_cuda_stream, &free_node, 1 , cudaStreamSetCaptureDependencies
856+ ));
857+ added = true ;
858+ }
859+ }
860+ }
861+
862+ // Fallback: if we couldn't obtain per-stream capture dependencies, use
863+ // the graph leaf nodes as before.
864+ if (!added && graph) {
865+ std::vector<cudaGraphNode_t> leaf_nodes;
866+ if (get_graph_leaf_nodes (graph, leaf_nodes)) {
867+ cudaGraphNode_t free_node;
868+ if (check_cuda (cudaGraphAddMemFreeNode (&free_node, graph, leaf_nodes.data (), leaf_nodes.size (), ptr))) {
869+ check_cu (cuStreamUpdateCaptureDependencies_f (
870+ capture->stream , &free_node, 1 , cudaStreamSetCaptureDependencies
871+ ));
872+ }
842873 }
843874 }
844875
0 commit comments