@@ -68,6 +68,7 @@ using executorch::runtime::Span;
6868using executorch::runtime::etensor::Tensor;
6969
7070// SlimTensor type aliases
71+ using cuda::CudaGraphPhase;
7172using slim::CPU_DEVICE;
7273using slim::DEFAULT_CUDA_DEVICE;
7374using slim::DeviceTraits;
@@ -80,8 +81,7 @@ namespace {
8081constexpr char kSkipCopyOutputToCpuForMethod [] =
8182 " skip_copy_output_to_cpu_for_method" ;
8283constexpr char kUseSharedCudaStream [] = " use_shared_cuda_stream" ;
83- constexpr char kEnableCudaGraphForMethod [] =
84- " enable_cuda_graph_for_method" ;
84+ constexpr char kEnableCudaGraphForMethod [] = " enable_cuda_graph_for_method" ;
8585constexpr int kCudaGraphWarmupSteps = 3 ;
8686} // anonymous namespace
8787
@@ -410,7 +410,9 @@ class ET_EXPERIMENTAL CudaBackend final
410410 cudaDeviceSynchronize ();
411411 buffer_res->Free ();
412412 } else {
413- ET_LOG (Info, " weights_blob '%s' not found or update fn is null" ,
413+ ET_LOG (
414+ Info,
415+ " weights_blob '%s' not found or update fn is null" ,
414416 weights_blob_key.c_str ());
415417 }
416418
@@ -540,8 +542,8 @@ class ET_EXPERIMENTAL CudaBackend final
540542
541543 // Initialize CUDA graph state if enabled for this method.
542544 if (should_use_cuda_graph_for_method (method_name)) {
543- handle->cuda_graph_phase = 1 ; // warmup
544- handle->cuda_graph_warmup_remaining = kCudaGraphWarmupSteps ;
545+ handle->cuda_graph_state . phase = CudaGraphPhase::Warmup;
546+ handle->cuda_graph_state . warmup_remaining = kCudaGraphWarmupSteps ;
545547 ET_LOG (
546548 Info,
547549 " CUDA graph enabled for method '%s' (warmup=%d)" ,
@@ -578,7 +580,7 @@ class ET_EXPERIMENTAL CudaBackend final
578580 // ---------------------------------------------------------------
579581 // CUDA graph REPLAY path — skip all tensor setup and just replay
580582 // ---------------------------------------------------------------
581- if (handle->cuda_graph_phase == 2 ) {
583+ if (handle->cuda_graph_state . phase == CudaGraphPhase::Replay ) {
582584 Result<cudaStream_t> csr = getCurrentCUDAStream (0 );
583585 cudaStream_t cs = csr.get ();
584586 ET_CHECK_OK_OR_RETURN_ERROR (csr.error ());
@@ -587,15 +589,16 @@ class ET_EXPERIMENTAL CudaBackend final
587589 for (size_t i = 0 ; i < n_inputs; i++) {
588590 auto * cpu_tensor = &(args[i]->toTensor ());
589591 cudaMemcpyAsync (
590- handle->static_input_ptrs [i],
592+ handle->cuda_graph_state . static_input_ptrs [i],
591593 cpu_tensor->const_data_ptr (),
592- handle->static_input_nbytes [i],
594+ handle->cuda_graph_state . static_input_nbytes [i],
593595 cudaMemcpyHostToDevice,
594596 cs);
595597 }
596598
597599 // Replay the captured graph
598- cudaError_t gerr = cudaGraphLaunch (handle->cuda_graph_exec , cs);
600+ cudaError_t gerr =
601+ cudaGraphLaunch (handle->cuda_graph_state .graph_exec , cs);
599602 ET_CHECK_OR_RETURN_ERROR (
600603 gerr == cudaSuccess,
601604 Internal,
@@ -610,8 +613,8 @@ class ET_EXPERIMENTAL CudaBackend final
610613 auto * cpu_out = &(args[i + n_inputs]->toTensor ());
611614 cudaMemcpyAsync (
612615 cpu_out->mutable_data_ptr (),
613- handle->static_output_ptrs [i],
614- handle->static_output_nbytes [i],
616+ handle->cuda_graph_state . static_output_ptrs [i],
617+ handle->cuda_graph_state . static_output_nbytes [i],
615618 cudaMemcpyDeviceToHost,
616619 cs);
617620 }
@@ -625,8 +628,8 @@ class ET_EXPERIMENTAL CudaBackend final
625628 // Normal path (also used for WARMUP and CAPTURE phases)
626629 // ---------------------------------------------------------------
627630 bool is_capture_step =
628- (handle->cuda_graph_phase == 1 &&
629- handle->cuda_graph_warmup_remaining == 0 );
631+ (handle->cuda_graph_state . phase == CudaGraphPhase::Warmup &&
632+ handle->cuda_graph_state . warmup_remaining == 0 );
630633
631634 // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
632635 // optimization. We need to create GPU copies for CUDA kernel execution
@@ -649,27 +652,32 @@ class ET_EXPERIMENTAL CudaBackend final
649652 void * static_ptr = nullptr ;
650653 cudaError_t merr = cudaMalloc (&static_ptr, nbytes);
651654 ET_CHECK_OR_RETURN_ERROR (
652- merr == cudaSuccess, Internal,
655+ merr == cudaSuccess,
656+ Internal,
653657 " cudaMalloc for static input %zu failed: %s" ,
654- i, cudaGetErrorString (merr));
658+ i,
659+ cudaGetErrorString (merr));
655660
656661 cudaMemcpy (
657- static_ptr, cpu_tensor->const_data_ptr (),
658- nbytes, cudaMemcpyHostToDevice);
662+ static_ptr,
663+ cpu_tensor->const_data_ptr (),
664+ nbytes,
665+ cudaMemcpyHostToDevice);
659666
660- handle->static_input_ptrs .push_back (static_ptr);
661- handle->static_input_sizes .push_back (sizes_vec);
662- handle->static_input_strides .push_back (strides_vec);
663- handle->static_input_scalar_types .push_back (
667+ handle->cuda_graph_state . static_input_ptrs .push_back (static_ptr);
668+ handle->cuda_graph_state . static_input_sizes .push_back (sizes_vec);
669+ handle->cuda_graph_state . static_input_strides .push_back (strides_vec);
670+ handle->cuda_graph_state . static_input_scalar_types .push_back (
664671 static_cast <int >(cpu_tensor->scalar_type ()));
665- handle->static_input_nbytes .push_back (nbytes);
672+ handle->cuda_graph_state . static_input_nbytes .push_back (nbytes);
666673
667674 gpu_inputs[i] = new SlimTensor (slim::from_blob (
668675 static_ptr,
669676 slim::makeArrayRef (sizes_vec),
670677 slim::makeArrayRef (strides_vec),
671678 static_cast <slim::c10::ScalarType>(cpu_tensor->scalar_type ()),
672- DEFAULT_CUDA_DEVICE, 0 ));
679+ DEFAULT_CUDA_DEVICE,
680+ 0 ));
673681 continue ;
674682 }
675683
@@ -755,8 +763,8 @@ class ET_EXPERIMENTAL CudaBackend final
755763 " CUDA graph: beginning stream capture for '%s'" ,
756764 handle->method_name .c_str ());
757765
758- cudaError_t cerr = cudaStreamBeginCapture (
759- cuda_stream, cudaStreamCaptureModeRelaxed);
766+ cudaError_t cerr =
767+ cudaStreamBeginCapture ( cuda_stream, cudaStreamCaptureModeRelaxed);
760768 ET_CHECK_OR_RETURN_ERROR (
761769 cerr == cudaSuccess,
762770 Internal,
@@ -792,15 +800,16 @@ class ET_EXPERIMENTAL CudaBackend final
792800 if (is_capture_step) {
793801 // End capture → instantiate graph
794802 cudaError_t gerr =
795- cudaStreamEndCapture (cuda_stream, &handle->cuda_graph );
803+ cudaStreamEndCapture (cuda_stream, &handle->cuda_graph_state . graph );
796804 ET_CHECK_OR_RETURN_ERROR (
797805 gerr == cudaSuccess,
798806 Internal,
799807 " cudaStreamEndCapture failed: %s" ,
800808 cudaGetErrorString (gerr));
801809
802810 gerr = cudaGraphInstantiate (
803- &handle->cuda_graph_exec , handle->cuda_graph ,
811+ &handle->cuda_graph_state .graph_exec ,
812+ handle->cuda_graph_state .graph ,
804813 cudaGraphInstantiateFlagAutoFreeOnLaunch);
805814 ET_CHECK_OR_RETURN_ERROR (
806815 gerr == cudaSuccess,
@@ -811,27 +820,27 @@ class ET_EXPERIMENTAL CudaBackend final
811820 // Record static output pointers (stable under graph replay)
812821 for (size_t i = 0 ; i < n_outputs; i++) {
813822 SlimTensor* out = gpu_outputs[i];
814- handle->static_output_ptrs .push_back (out->data_ptr ());
823+ handle->cuda_graph_state . static_output_ptrs .push_back (out->data_ptr ());
815824
816825 auto out_sizes = out->sizes ();
817826 auto out_strides = out->strides ();
818- handle->static_output_sizes .push_back (
827+ handle->cuda_graph_state . static_output_sizes .push_back (
819828 std::vector<int64_t >(out_sizes.begin (), out_sizes.end ()));
820- handle->static_output_strides .push_back (
829+ handle->cuda_graph_state . static_output_strides .push_back (
821830 std::vector<int64_t >(out_strides.begin (), out_strides.end ()));
822- handle->static_output_scalar_types .push_back (
831+ handle->cuda_graph_state . static_output_scalar_types .push_back (
823832 static_cast <int >(out->dtype ()));
824- handle->static_output_nbytes .push_back (out->nbytes ());
833+ handle->cuda_graph_state . static_output_nbytes .push_back (out->nbytes ());
825834 }
826835
827- handle->cuda_graph_phase = 2 ; // switch to replay mode
836+ handle->cuda_graph_state . phase = CudaGraphPhase::Replay;
828837 ET_LOG (
829838 Info,
830839 " CUDA graph: captured and instantiated for '%s'" ,
831840 handle->method_name .c_str ());
832841
833842 // Replay once to actually produce output (capture doesn't execute)
834- gerr = cudaGraphLaunch (handle->cuda_graph_exec , cuda_stream);
843+ gerr = cudaGraphLaunch (handle->cuda_graph_state . graph_exec , cuda_stream);
835844 ET_CHECK_OR_RETURN_ERROR (
836845 gerr == cudaSuccess,
837846 Internal,
@@ -846,8 +855,8 @@ class ET_EXPERIMENTAL CudaBackend final
846855 auto * cpu_out = &(args[i + n_inputs]->toTensor ());
847856 cudaMemcpyAsync (
848857 cpu_out->mutable_data_ptr (),
849- handle->static_output_ptrs [i],
850- handle->static_output_nbytes [i],
858+ handle->cuda_graph_state . static_output_ptrs [i],
859+ handle->cuda_graph_state . static_output_nbytes [i],
851860 cudaMemcpyDeviceToHost,
852861 cuda_stream);
853862 // Don't delete — static buffers are owned by the handle
@@ -861,13 +870,13 @@ class ET_EXPERIMENTAL CudaBackend final
861870 // ----- Normal / WARMUP execution continues here -----
862871
863872 // Decrement warmup counter if in warmup phase
864- if (handle->cuda_graph_phase == 1 &&
865- handle->cuda_graph_warmup_remaining > 0 ) {
866- handle->cuda_graph_warmup_remaining --;
873+ if (handle->cuda_graph_state . phase == CudaGraphPhase::Warmup &&
874+ handle->cuda_graph_state . warmup_remaining > 0 ) {
875+ handle->cuda_graph_state . warmup_remaining --;
867876 ET_LOG (
868877 Info,
869878 " CUDA graph warmup: %d steps remaining for '%s'" ,
870- handle->cuda_graph_warmup_remaining ,
879+ handle->cuda_graph_state . warmup_remaining ,
871880 handle->method_name .c_str ());
872881 }
873882
0 commit comments