@@ -69,6 +69,7 @@ using executorch::runtime::Span;
 using executorch::runtime::etensor::Tensor;
 
 // SlimTensor type aliases
+using cuda::CudaGraphPhase;
 using slim::CPU_DEVICE;
 using slim::DEFAULT_CUDA_DEVICE;
 using slim::DeviceTraits;
@@ -81,6 +82,8 @@ namespace {
 constexpr char kSkipCopyOutputToCpuForMethod[] =
     "skip_copy_output_to_cpu_for_method";
 constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
+constexpr char kEnableCudaGraphForMethod[] = "enable_cuda_graph_for_method";
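+// Number of eager-mode runs before graph capture. A few warmup iterations
+// let lazy CUDA initialization and allocator behavior settle before the
+// stream is recorded; 3 is a heuristic, not a tuned value.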
+constexpr int kCudaGraphWarmupSteps = 3;
 constexpr char kShareKvCacheAcrossMethods[] = "share_kv_cache_across_methods";
 } // anonymous namespace
 
@@ -148,6 +151,20 @@ class ET_EXPERIMENTAL CudaBackend final
     return method_in_csv(method_name, skip_copy_method_);
   }
 
+  void set_cuda_graph_method(
+      const std::array<char, kMaxOptionValueLength>& raw) {
+    std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_);
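+    // Option values arrive as fixed-size char arrays; this assumes the
+    // payload is NUL-terminated within kMaxOptionValueLength.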
+    cuda_graph_method_ = std::string(raw.data());
+  }
+
+  bool should_use_cuda_graph_for_method(const std::string& method_name) const {
+    if (method_name.empty()) {
+      return false;
+    }
+    std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_);
+    return method_in_csv(method_name, cuda_graph_method_);
+  }
+
   // Create the shared CUDA stream. Called when use_shared_cuda_stream option
   // is set to true. The presence of shared_cuda_stream_ indicates shared mode.
   void create_shared_cuda_stream() {
@@ -266,6 +283,17 @@ class ET_EXPERIMENTAL CudaBackend final
         ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
         return Error::InvalidArgument;
       }
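+    // enable_cuda_graph_for_method takes a CSV list of method names,
+    // matched via method_in_csv() like the skip-copy option above.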
+    } else if (std::strcmp(option.key, kEnableCudaGraphForMethod) == 0) {
+      if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
+              &option.value)) {
+        set_cuda_graph_method(*val);
+      } else {
+        ET_LOG(
+            Error,
+            "Option %s must be a method name string.",
+            kEnableCudaGraphForMethod);
+        return Error::InvalidArgument;
+      }
     }
   }
   return Error::Ok;
@@ -533,6 +561,17 @@ class ET_EXPERIMENTAL CudaBackend final
         method_name.c_str());
   }
 
+  // Initialize CUDA graph state if enabled for this method.
+  if (should_use_cuda_graph_for_method(method_name)) {
+    handle->cuda_graph_state.phase = CudaGraphPhase::Warmup;
+    handle->cuda_graph_state.warmup_remaining = kCudaGraphWarmupSteps;
+    ET_LOG(
+        Info,
+        "CUDA graph enabled for method '%s' (warmup=%d)",
+        method_name.c_str(),
+        kCudaGraphWarmupSteps);
+  }
+
   return (DelegateHandle*)handle; // Return the handle post-processing
 }
@@ -561,6 +600,68 @@ class ET_EXPERIMENTAL CudaBackend final
       n_outputs,
       args.size())
 
+  // ---------------------------------------------------------------
+  // CUDA graph REPLAY path — skip all tensor setup and just replay
+  // ---------------------------------------------------------------
+  if (handle->cuda_graph_state.phase == CudaGraphPhase::Replay) {
+    Result<cudaStream_t> csr = getCurrentCUDAStream(0);
+    ET_CHECK_OK_OR_RETURN_ERROR(csr.error());
+    cudaStream_t cs = csr.get();
+
+    // Copy new input data into static input buffers
+    for (size_t i = 0; i < n_inputs; i++) {
+      auto* cpu_tensor = &(args[i]->toTensor());
+      ET_CHECK_OR_RETURN_ERROR(
+          cpu_tensor->nbytes() ==
+              handle->cuda_graph_state.static_input_nbytes[i],
+          InvalidArgument,
+          "CUDA graph replay: input %zu size mismatch (expected %zu, got %zu)",
+          i,
+          handle->cuda_graph_state.static_input_nbytes[i],
+          cpu_tensor->nbytes());
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
+          handle->cuda_graph_state.static_input_ptrs[i],
+          cpu_tensor->const_data_ptr(),
+          handle->cuda_graph_state.static_input_nbytes[i],
+          cudaMemcpyHostToDevice,
+          cs));
+    }
+
+    // Replay the captured graph
+    cudaError_t gerr =
+        cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cs);
+    ET_CHECK_OR_RETURN_ERROR(
+        gerr == cudaSuccess,
+        Internal,
+        "cudaGraphLaunch failed: %s",
+        cudaGetErrorString(gerr));
+
+    // Copy outputs back to CPU
+    const bool copy_outputs =
+        !should_skip_copy_for_method(handle->method_name);
+    if (copy_outputs) {
+      for (size_t i = 0; i < n_outputs; i++) {
+        auto* cpu_out = &(args[i + n_inputs]->toTensor());
+        ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
+            cpu_out->mutable_data_ptr(),
+            handle->cuda_graph_state.static_output_ptrs[i],
+            handle->cuda_graph_state.static_output_nbytes[i],
+            cudaMemcpyDeviceToHost,
+            cs));
+      }
+      cudaStreamSynchronize(cs);
+    }
+
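+    // With skip-copy, outputs stay in the static GPU buffers and no sync is
+    // issued here; ordering then relies on consumers enqueuing on the same
+    // stream (see the shared-stream option).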
+    return Error::Ok;
+  }
+
+  // ---------------------------------------------------------------
+  // Normal path (also used for WARMUP and CAPTURE phases)
+  // ---------------------------------------------------------------
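+  // The capture step is the first execute() after the warmup counter has
+  // reached zero; warmup runs themselves take the normal path below.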
+  bool is_capture_step =
+      (handle->cuda_graph_state.phase == CudaGraphPhase::Warmup &&
+       handle->cuda_graph_state.warmup_remaining == 0);
+
   // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
   // optimization. We need to create GPU copies for CUDA kernel execution
   // using SlimTensor.
@@ -571,6 +672,33 @@ class ET_EXPERIMENTAL CudaBackend final
   for (size_t i = 0; i < n_inputs; i++) {
     auto* cpu_tensor = &(args[i]->toTensor());
 
+    // CAPTURE step: allocate persistent static GPU buffers
+    if (is_capture_step) {
+      size_t nbytes = cpu_tensor->nbytes();
+
+      void* static_ptr = nullptr;
+      cudaError_t merr = cudaMalloc(&static_ptr, nbytes);
+      ET_CHECK_OR_RETURN_ERROR(
+          merr == cudaSuccess,
+          Internal,
+          "cudaMalloc for static input %zu failed: %s",
+          i,
+          cudaGetErrorString(merr));
+
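+      // A synchronous copy is acceptable here: stream capture has not begun
+      // yet, and the buffer must hold this call's inputs so the post-capture
+      // replay below produces correct outputs.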
+      cudaMemcpy(
+          static_ptr,
+          cpu_tensor->const_data_ptr(),
+          nbytes,
+          cudaMemcpyHostToDevice);
+
+      handle->cuda_graph_state.static_input_ptrs.push_back(static_ptr);
+      handle->cuda_graph_state.static_input_nbytes.push_back(nbytes);
+
+      gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata(
+          static_ptr, cpu_tensor);
+      continue;
+    }
+
     // Check if input data is already on GPU (skip-copy optimization for
     // inputs). This can happen when the caller has pre-staged data on GPU.
     cudaPointerAttributes attributes{};
@@ -579,19 +707,8 @@ class ET_EXPERIMENTAL CudaBackend final
     cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr);
     if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) {
       // Data is already on GPU - wrap it directly without copy
-      auto sizes = cpu_tensor->sizes();
-      auto strides = cpu_tensor->strides();
-      std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
-      std::vector<int64_t> strides_vec(strides.begin(), strides.end());
-
-      gpu_inputs[i] = new SlimTensor(slim::from_blob(
-          const_cast<void*>(data_ptr),
-          slim::makeArrayRef(sizes_vec),
-          slim::makeArrayRef(strides_vec),
-          static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
-          DEFAULT_CUDA_DEVICE,
-          0 // storage_offset
-      ));
+      gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata(
+          const_cast<void*>(data_ptr), cpu_tensor);
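+      // The helper presumably factors out the removed inline code above:
+      // slim::from_blob() with sizes/strides/dtype taken from the source
+      // ETensor, placed on DEFAULT_CUDA_DEVICE with zero storage offset.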
 
       continue;
     }
@@ -643,8 +760,25 @@ class ET_EXPERIMENTAL CudaBackend final
   // NOTE: run() steals input handles (RAII wraps them at the start of
   // run_impl) and may replace output handles with its own.
   Result<cudaStream_t> cuda_stream_ret = getCurrentCUDAStream(0);
-  cudaStream_t cuda_stream = cuda_stream_ret.get();
   ET_CHECK_OK_OR_RETURN_ERROR(cuda_stream_ret.error());
+  cudaStream_t cuda_stream = cuda_stream_ret.get();
+
+  if (is_capture_step) {
+    // ----- CUDA graph CAPTURE -----
+    ET_LOG(
+        Info,
+        "CUDA graph: beginning stream capture for '%s'",
+        handle->method_name.c_str());
+
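+    // Relaxed capture mode tolerates potentially unsafe CUDA API calls
+    // (e.g. allocations) during capture, which run() may issue.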
+    cudaError_t cerr =
+        cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeRelaxed);
+    ET_CHECK_OR_RETURN_ERROR(
+        cerr == cudaSuccess,
+        Internal,
+        "cudaStreamBeginCapture failed: %s",
+        cudaGetErrorString(cerr));
+  }
+
   AOTIRuntimeError error = handle->run(
       handle->container_handle,
       reinterpret_cast<Tensor**>(gpu_inputs.data()),
@@ -670,6 +804,87 @@ class ET_EXPERIMENTAL CudaBackend final
670804 " AOTInductorModelContainerRun failed with error code %d" ,
671805 error);
672806
807+ if (is_capture_step) {
808+ // End capture → instantiate graph
809+ cudaError_t gerr =
810+ cudaStreamEndCapture (cuda_stream, &handle->cuda_graph_state .graph );
811+ ET_CHECK_OR_RETURN_ERROR (
812+ gerr == cudaSuccess,
813+ Internal,
814+ " cudaStreamEndCapture failed: %s" ,
815+ cudaGetErrorString (gerr));
816+
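+    // AutoFreeOnLaunch lets the runtime free graph-owned allocations before
+    // each relaunch, so repeated launches stay valid if run() allocated
+    // memory during capture.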
+    gerr = cudaGraphInstantiate(
+        &handle->cuda_graph_state.graph_exec,
+        handle->cuda_graph_state.graph,
+        cudaGraphInstantiateFlagAutoFreeOnLaunch);
+    ET_CHECK_OR_RETURN_ERROR(
+        gerr == cudaSuccess,
+        Internal,
+        "cudaGraphInstantiate failed: %s",
+        cudaGetErrorString(gerr));
+
+    // Record static output pointers (stable under graph replay)
+    for (size_t i = 0; i < n_outputs; i++) {
+      SlimTensor* out = gpu_outputs[i];
+      handle->cuda_graph_state.static_output_ptrs.push_back(out->data_ptr());
+      handle->cuda_graph_state.static_output_nbytes.push_back(out->nbytes());
+    }
+
+    handle->cuda_graph_state.phase = CudaGraphPhase::Replay;
+    ET_LOG(
+        Info,
+        "CUDA graph: captured and instantiated for '%s'",
+        handle->method_name.c_str());
+
+    // Replay once to actually produce output (capture doesn't execute)
+    gerr = cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cuda_stream);
+    ET_CHECK_OR_RETURN_ERROR(
+        gerr == cudaSuccess,
+        Internal,
+        "cudaGraphLaunch (first replay) failed: %s",
+        cudaGetErrorString(gerr));
+
+    // Copy capture-step outputs to CPU
+    const bool copy_outputs =
+        !should_skip_copy_for_method(handle->method_name);
+    if (copy_outputs) {
+      for (size_t i = 0; i < n_outputs; i++) {
+        auto* cpu_out = &(args[i + n_inputs]->toTensor());
+        ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
+            cpu_out->mutable_data_ptr(),
+            handle->cuda_graph_state.static_output_ptrs[i],
+            handle->cuda_graph_state.static_output_nbytes[i],
+            cudaMemcpyDeviceToHost,
+            cuda_stream));
+        // Don't delete — static buffers are owned by the handle
+        gpu_outputs[i] = nullptr;
+      }
+      cudaStreamSynchronize(cuda_stream);
+    } else {
+      // Even when skipping copy, null out gpu_outputs to prevent
+      // the ScopeGuard from deleting static output buffers.
+      for (size_t i = 0; i < n_outputs; i++) {
+        gpu_outputs[i] = nullptr;
+      }
+    }
+
+    return Error::Ok;
+  }
+
+  // ----- Normal / WARMUP execution continues here -----
+
+  // Decrement warmup counter if in warmup phase
+  if (handle->cuda_graph_state.phase == CudaGraphPhase::Warmup &&
+      handle->cuda_graph_state.warmup_remaining > 0) {
+    handle->cuda_graph_state.warmup_remaining--;
+    ET_LOG(
+        Info,
+        "CUDA graph warmup: %d steps remaining for '%s'",
+        handle->cuda_graph_state.warmup_remaining,
+        handle->method_name.c_str());
+  }
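+  // Once the counter hits zero, the next execute() for this method becomes
+  // the capture step (see is_capture_step above).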
+
   const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);
 
   if (copy_outputs) {
@@ -764,6 +979,9 @@ class ET_EXPERIMENTAL CudaBackend final
   mutable std::mutex skip_copy_method_mutex_;
   std::string skip_copy_method_;
 
+  mutable std::mutex cuda_graph_method_mutex_;
+  std::string cuda_graph_method_;
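+  // CSV of method names that should run under CUDA graphs, guarded by
+  // cuda_graph_method_mutex_ (mirrors the skip-copy members above).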
+
   // Shared CUDA stream for all methods. When set (non-null), all methods use
   // the same stream to ensure proper ordering (critical for skip-copy
   // optimization). Created when use_shared_cuda_stream option is set to true.