@@ -68,6 +68,7 @@ using executorch::runtime::Span;
6868using executorch::runtime::etensor::Tensor;
6969
7070// SlimTensor type aliases
71+ using cuda::CudaGraphPhase;
7172using slim::CPU_DEVICE;
7273using slim::DEFAULT_CUDA_DEVICE;
7374using slim::DeviceTraits;
@@ -541,8 +542,8 @@ class ET_EXPERIMENTAL CudaBackend final
541542
542543 // Initialize CUDA graph state if enabled for this method.
543544 if (should_use_cuda_graph_for_method (method_name)) {
544- handle->cuda_graph_phase = 1 ; // warmup
545- handle->cuda_graph_warmup_remaining = kCudaGraphWarmupSteps ;
545+ handle->cuda_graph_state . phase = CudaGraphPhase::Warmup;
546+ handle->cuda_graph_state . warmup_remaining = kCudaGraphWarmupSteps ;
546547 ET_LOG (
547548 Info,
548549 " CUDA graph enabled for method '%s' (warmup=%d)" ,
@@ -579,7 +580,7 @@ class ET_EXPERIMENTAL CudaBackend final
579580 // ---------------------------------------------------------------
580581 // CUDA graph REPLAY path — skip all tensor setup and just replay
581582 // ---------------------------------------------------------------
582- if (handle->cuda_graph_phase == 2 ) {
583+ if (handle->cuda_graph_state . phase == CudaGraphPhase::Replay ) {
583584 Result<cudaStream_t> csr = getCurrentCUDAStream (0 );
584585 cudaStream_t cs = csr.get ();
585586 ET_CHECK_OK_OR_RETURN_ERROR (csr.error ());
@@ -588,15 +589,16 @@ class ET_EXPERIMENTAL CudaBackend final
588589 for (size_t i = 0 ; i < n_inputs; i++) {
589590 auto * cpu_tensor = &(args[i]->toTensor ());
590591 cudaMemcpyAsync (
591- handle->static_input_ptrs [i],
592+ handle->cuda_graph_state . static_input_ptrs [i],
592593 cpu_tensor->const_data_ptr (),
593- handle->static_input_nbytes [i],
594+ handle->cuda_graph_state . static_input_nbytes [i],
594595 cudaMemcpyHostToDevice,
595596 cs);
596597 }
597598
598599 // Replay the captured graph
599- cudaError_t gerr = cudaGraphLaunch (handle->cuda_graph_exec , cs);
600+ cudaError_t gerr =
601+ cudaGraphLaunch (handle->cuda_graph_state .graph_exec , cs);
600602 ET_CHECK_OR_RETURN_ERROR (
601603 gerr == cudaSuccess,
602604 Internal,
@@ -611,8 +613,8 @@ class ET_EXPERIMENTAL CudaBackend final
611613 auto * cpu_out = &(args[i + n_inputs]->toTensor ());
612614 cudaMemcpyAsync (
613615 cpu_out->mutable_data_ptr (),
614- handle->static_output_ptrs [i],
615- handle->static_output_nbytes [i],
616+ handle->cuda_graph_state . static_output_ptrs [i],
617+ handle->cuda_graph_state . static_output_nbytes [i],
616618 cudaMemcpyDeviceToHost,
617619 cs);
618620 }
@@ -626,8 +628,8 @@ class ET_EXPERIMENTAL CudaBackend final
626628 // Normal path (also used for WARMUP and CAPTURE phases)
627629 // ---------------------------------------------------------------
628630 bool is_capture_step =
629- (handle->cuda_graph_phase == 1 &&
630- handle->cuda_graph_warmup_remaining == 0 );
631+ (handle->cuda_graph_state . phase == CudaGraphPhase::Warmup &&
632+ handle->cuda_graph_state . warmup_remaining == 0 );
631633
632634 // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
633635 // optimization. We need to create GPU copies for CUDA kernel execution
@@ -662,12 +664,12 @@ class ET_EXPERIMENTAL CudaBackend final
662664 nbytes,
663665 cudaMemcpyHostToDevice);
664666
665- handle->static_input_ptrs .push_back (static_ptr);
666- handle->static_input_sizes .push_back (sizes_vec);
667- handle->static_input_strides .push_back (strides_vec);
668- handle->static_input_scalar_types .push_back (
667+ handle->cuda_graph_state . static_input_ptrs .push_back (static_ptr);
668+ handle->cuda_graph_state . static_input_sizes .push_back (sizes_vec);
669+ handle->cuda_graph_state . static_input_strides .push_back (strides_vec);
670+ handle->cuda_graph_state . static_input_scalar_types .push_back (
669671 static_cast <int >(cpu_tensor->scalar_type ()));
670- handle->static_input_nbytes .push_back (nbytes);
672+ handle->cuda_graph_state . static_input_nbytes .push_back (nbytes);
671673
672674 gpu_inputs[i] = new SlimTensor (slim::from_blob (
673675 static_ptr,
@@ -797,16 +799,17 @@ class ET_EXPERIMENTAL CudaBackend final
797799
798800 if (is_capture_step) {
799801 // End capture → instantiate graph
800- cudaError_t gerr = cudaStreamEndCapture (cuda_stream, &handle->cuda_graph );
802+ cudaError_t gerr =
803+ cudaStreamEndCapture (cuda_stream, &handle->cuda_graph_state .graph );
801804 ET_CHECK_OR_RETURN_ERROR (
802805 gerr == cudaSuccess,
803806 Internal,
804807 " cudaStreamEndCapture failed: %s" ,
805808 cudaGetErrorString (gerr));
806809
807810 gerr = cudaGraphInstantiate (
808- &handle->cuda_graph_exec ,
809- handle->cuda_graph ,
811+ &handle->cuda_graph_state . graph_exec ,
812+ handle->cuda_graph_state . graph ,
810813 cudaGraphInstantiateFlagAutoFreeOnLaunch);
811814 ET_CHECK_OR_RETURN_ERROR (
812815 gerr == cudaSuccess,
@@ -817,27 +820,27 @@ class ET_EXPERIMENTAL CudaBackend final
817820 // Record static output pointers (stable under graph replay)
818821 for (size_t i = 0 ; i < n_outputs; i++) {
819822 SlimTensor* out = gpu_outputs[i];
820- handle->static_output_ptrs .push_back (out->data_ptr ());
823+ handle->cuda_graph_state . static_output_ptrs .push_back (out->data_ptr ());
821824
822825 auto out_sizes = out->sizes ();
823826 auto out_strides = out->strides ();
824- handle->static_output_sizes .push_back (
827+ handle->cuda_graph_state . static_output_sizes .push_back (
825828 std::vector<int64_t >(out_sizes.begin (), out_sizes.end ()));
826- handle->static_output_strides .push_back (
829+ handle->cuda_graph_state . static_output_strides .push_back (
827830 std::vector<int64_t >(out_strides.begin (), out_strides.end ()));
828- handle->static_output_scalar_types .push_back (
831+ handle->cuda_graph_state . static_output_scalar_types .push_back (
829832 static_cast <int >(out->dtype ()));
830- handle->static_output_nbytes .push_back (out->nbytes ());
833+ handle->cuda_graph_state . static_output_nbytes .push_back (out->nbytes ());
831834 }
832835
833- handle->cuda_graph_phase = 2 ; // switch to replay mode
836+ handle->cuda_graph_state . phase = CudaGraphPhase::Replay;
834837 ET_LOG (
835838 Info,
836839 " CUDA graph: captured and instantiated for '%s'" ,
837840 handle->method_name .c_str ());
838841
839842 // Replay once to actually produce output (capture doesn't execute)
840- gerr = cudaGraphLaunch (handle->cuda_graph_exec , cuda_stream);
843+ gerr = cudaGraphLaunch (handle->cuda_graph_state . graph_exec , cuda_stream);
841844 ET_CHECK_OR_RETURN_ERROR (
842845 gerr == cudaSuccess,
843846 Internal,
@@ -852,8 +855,8 @@ class ET_EXPERIMENTAL CudaBackend final
852855 auto * cpu_out = &(args[i + n_inputs]->toTensor ());
853856 cudaMemcpyAsync (
854857 cpu_out->mutable_data_ptr (),
855- handle->static_output_ptrs [i],
856- handle->static_output_nbytes [i],
858+ handle->cuda_graph_state . static_output_ptrs [i],
859+ handle->cuda_graph_state . static_output_nbytes [i],
857860 cudaMemcpyDeviceToHost,
858861 cuda_stream);
859862 // Don't delete — static buffers are owned by the handle
@@ -867,13 +870,13 @@ class ET_EXPERIMENTAL CudaBackend final
867870 // ----- Normal / WARMUP execution continues here -----
868871
869872 // Decrement warmup counter if in warmup phase
870- if (handle->cuda_graph_phase == 1 &&
871- handle->cuda_graph_warmup_remaining > 0 ) {
872- handle->cuda_graph_warmup_remaining --;
873+ if (handle->cuda_graph_state . phase == CudaGraphPhase::Warmup &&
874+ handle->cuda_graph_state . warmup_remaining > 0 ) {
875+ handle->cuda_graph_state . warmup_remaining --;
873876 ET_LOG (
874877 Info,
875878 " CUDA graph warmup: %d steps remaining for '%s'" ,
876- handle->cuda_graph_warmup_remaining ,
879+ handle->cuda_graph_state . warmup_remaining ,
877880 handle->method_name .c_str ());
878881 }
879882
0 commit comments