Skip to content

Commit 10e7aad

Browse files
committed
lint and reformat
1 parent ee75c2e commit 10e7aad

3 files changed

Lines changed: 101 additions & 99 deletions

File tree

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ using executorch::runtime::Span;
6868
using executorch::runtime::etensor::Tensor;
6969

7070
// SlimTensor type aliases
71+
using cuda::CudaGraphPhase;
7172
using slim::CPU_DEVICE;
7273
using slim::DEFAULT_CUDA_DEVICE;
7374
using slim::DeviceTraits;
@@ -80,8 +81,7 @@ namespace {
8081
constexpr char kSkipCopyOutputToCpuForMethod[] =
8182
"skip_copy_output_to_cpu_for_method";
8283
constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
83-
constexpr char kEnableCudaGraphForMethod[] =
84-
"enable_cuda_graph_for_method";
84+
constexpr char kEnableCudaGraphForMethod[] = "enable_cuda_graph_for_method";
8585
constexpr int kCudaGraphWarmupSteps = 3;
8686
} // anonymous namespace
8787

@@ -410,7 +410,9 @@ class ET_EXPERIMENTAL CudaBackend final
410410
cudaDeviceSynchronize();
411411
buffer_res->Free();
412412
} else {
413-
ET_LOG(Info, "weights_blob '%s' not found or update fn is null",
413+
ET_LOG(
414+
Info,
415+
"weights_blob '%s' not found or update fn is null",
414416
weights_blob_key.c_str());
415417
}
416418

@@ -540,8 +542,8 @@ class ET_EXPERIMENTAL CudaBackend final
540542

541543
// Initialize CUDA graph state if enabled for this method.
542544
if (should_use_cuda_graph_for_method(method_name)) {
543-
handle->cuda_graph_phase = 1; // warmup
544-
handle->cuda_graph_warmup_remaining = kCudaGraphWarmupSteps;
545+
handle->cuda_graph_state.phase = CudaGraphPhase::Warmup;
546+
handle->cuda_graph_state.warmup_remaining = kCudaGraphWarmupSteps;
545547
ET_LOG(
546548
Info,
547549
"CUDA graph enabled for method '%s' (warmup=%d)",
@@ -578,7 +580,7 @@ class ET_EXPERIMENTAL CudaBackend final
578580
// ---------------------------------------------------------------
579581
// CUDA graph REPLAY path — skip all tensor setup and just replay
580582
// ---------------------------------------------------------------
581-
if (handle->cuda_graph_phase == 2) {
583+
if (handle->cuda_graph_state.phase == CudaGraphPhase::Replay) {
582584
Result<cudaStream_t> csr = getCurrentCUDAStream(0);
583585
cudaStream_t cs = csr.get();
584586
ET_CHECK_OK_OR_RETURN_ERROR(csr.error());
@@ -587,15 +589,16 @@ class ET_EXPERIMENTAL CudaBackend final
587589
for (size_t i = 0; i < n_inputs; i++) {
588590
auto* cpu_tensor = &(args[i]->toTensor());
589591
cudaMemcpyAsync(
590-
handle->static_input_ptrs[i],
592+
handle->cuda_graph_state.static_input_ptrs[i],
591593
cpu_tensor->const_data_ptr(),
592-
handle->static_input_nbytes[i],
594+
handle->cuda_graph_state.static_input_nbytes[i],
593595
cudaMemcpyHostToDevice,
594596
cs);
595597
}
596598

597599
// Replay the captured graph
598-
cudaError_t gerr = cudaGraphLaunch(handle->cuda_graph_exec, cs);
600+
cudaError_t gerr =
601+
cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cs);
599602
ET_CHECK_OR_RETURN_ERROR(
600603
gerr == cudaSuccess,
601604
Internal,
@@ -610,8 +613,8 @@ class ET_EXPERIMENTAL CudaBackend final
610613
auto* cpu_out = &(args[i + n_inputs]->toTensor());
611614
cudaMemcpyAsync(
612615
cpu_out->mutable_data_ptr(),
613-
handle->static_output_ptrs[i],
614-
handle->static_output_nbytes[i],
616+
handle->cuda_graph_state.static_output_ptrs[i],
617+
handle->cuda_graph_state.static_output_nbytes[i],
615618
cudaMemcpyDeviceToHost,
616619
cs);
617620
}
@@ -625,8 +628,8 @@ class ET_EXPERIMENTAL CudaBackend final
625628
// Normal path (also used for WARMUP and CAPTURE phases)
626629
// ---------------------------------------------------------------
627630
bool is_capture_step =
628-
(handle->cuda_graph_phase == 1 &&
629-
handle->cuda_graph_warmup_remaining == 0);
631+
(handle->cuda_graph_state.phase == CudaGraphPhase::Warmup &&
632+
handle->cuda_graph_state.warmup_remaining == 0);
630633

631634
// NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
632635
// optimization. We need to create GPU copies for CUDA kernel execution
@@ -649,27 +652,32 @@ class ET_EXPERIMENTAL CudaBackend final
649652
void* static_ptr = nullptr;
650653
cudaError_t merr = cudaMalloc(&static_ptr, nbytes);
651654
ET_CHECK_OR_RETURN_ERROR(
652-
merr == cudaSuccess, Internal,
655+
merr == cudaSuccess,
656+
Internal,
653657
"cudaMalloc for static input %zu failed: %s",
654-
i, cudaGetErrorString(merr));
658+
i,
659+
cudaGetErrorString(merr));
655660

656661
cudaMemcpy(
657-
static_ptr, cpu_tensor->const_data_ptr(),
658-
nbytes, cudaMemcpyHostToDevice);
662+
static_ptr,
663+
cpu_tensor->const_data_ptr(),
664+
nbytes,
665+
cudaMemcpyHostToDevice);
659666

660-
handle->static_input_ptrs.push_back(static_ptr);
661-
handle->static_input_sizes.push_back(sizes_vec);
662-
handle->static_input_strides.push_back(strides_vec);
663-
handle->static_input_scalar_types.push_back(
667+
handle->cuda_graph_state.static_input_ptrs.push_back(static_ptr);
668+
handle->cuda_graph_state.static_input_sizes.push_back(sizes_vec);
669+
handle->cuda_graph_state.static_input_strides.push_back(strides_vec);
670+
handle->cuda_graph_state.static_input_scalar_types.push_back(
664671
static_cast<int>(cpu_tensor->scalar_type()));
665-
handle->static_input_nbytes.push_back(nbytes);
672+
handle->cuda_graph_state.static_input_nbytes.push_back(nbytes);
666673

667674
gpu_inputs[i] = new SlimTensor(slim::from_blob(
668675
static_ptr,
669676
slim::makeArrayRef(sizes_vec),
670677
slim::makeArrayRef(strides_vec),
671678
static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
672-
DEFAULT_CUDA_DEVICE, 0));
679+
DEFAULT_CUDA_DEVICE,
680+
0));
673681
continue;
674682
}
675683

@@ -755,8 +763,8 @@ class ET_EXPERIMENTAL CudaBackend final
755763
"CUDA graph: beginning stream capture for '%s'",
756764
handle->method_name.c_str());
757765

758-
cudaError_t cerr = cudaStreamBeginCapture(
759-
cuda_stream, cudaStreamCaptureModeRelaxed);
766+
cudaError_t cerr =
767+
cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeRelaxed);
760768
ET_CHECK_OR_RETURN_ERROR(
761769
cerr == cudaSuccess,
762770
Internal,
@@ -792,15 +800,16 @@ class ET_EXPERIMENTAL CudaBackend final
792800
if (is_capture_step) {
793801
// End capture → instantiate graph
794802
cudaError_t gerr =
795-
cudaStreamEndCapture(cuda_stream, &handle->cuda_graph);
803+
cudaStreamEndCapture(cuda_stream, &handle->cuda_graph_state.graph);
796804
ET_CHECK_OR_RETURN_ERROR(
797805
gerr == cudaSuccess,
798806
Internal,
799807
"cudaStreamEndCapture failed: %s",
800808
cudaGetErrorString(gerr));
801809

802810
gerr = cudaGraphInstantiate(
803-
&handle->cuda_graph_exec, handle->cuda_graph,
811+
&handle->cuda_graph_state.graph_exec,
812+
handle->cuda_graph_state.graph,
804813
cudaGraphInstantiateFlagAutoFreeOnLaunch);
805814
ET_CHECK_OR_RETURN_ERROR(
806815
gerr == cudaSuccess,
@@ -811,27 +820,27 @@ class ET_EXPERIMENTAL CudaBackend final
811820
// Record static output pointers (stable under graph replay)
812821
for (size_t i = 0; i < n_outputs; i++) {
813822
SlimTensor* out = gpu_outputs[i];
814-
handle->static_output_ptrs.push_back(out->data_ptr());
823+
handle->cuda_graph_state.static_output_ptrs.push_back(out->data_ptr());
815824

816825
auto out_sizes = out->sizes();
817826
auto out_strides = out->strides();
818-
handle->static_output_sizes.push_back(
827+
handle->cuda_graph_state.static_output_sizes.push_back(
819828
std::vector<int64_t>(out_sizes.begin(), out_sizes.end()));
820-
handle->static_output_strides.push_back(
829+
handle->cuda_graph_state.static_output_strides.push_back(
821830
std::vector<int64_t>(out_strides.begin(), out_strides.end()));
822-
handle->static_output_scalar_types.push_back(
831+
handle->cuda_graph_state.static_output_scalar_types.push_back(
823832
static_cast<int>(out->dtype()));
824-
handle->static_output_nbytes.push_back(out->nbytes());
833+
handle->cuda_graph_state.static_output_nbytes.push_back(out->nbytes());
825834
}
826835

827-
handle->cuda_graph_phase = 2; // switch to replay mode
836+
handle->cuda_graph_state.phase = CudaGraphPhase::Replay;
828837
ET_LOG(
829838
Info,
830839
"CUDA graph: captured and instantiated for '%s'",
831840
handle->method_name.c_str());
832841

833842
// Replay once to actually produce output (capture doesn't execute)
834-
gerr = cudaGraphLaunch(handle->cuda_graph_exec, cuda_stream);
843+
gerr = cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cuda_stream);
835844
ET_CHECK_OR_RETURN_ERROR(
836845
gerr == cudaSuccess,
837846
Internal,
@@ -846,8 +855,8 @@ class ET_EXPERIMENTAL CudaBackend final
846855
auto* cpu_out = &(args[i + n_inputs]->toTensor());
847856
cudaMemcpyAsync(
848857
cpu_out->mutable_data_ptr(),
849-
handle->static_output_ptrs[i],
850-
handle->static_output_nbytes[i],
858+
handle->cuda_graph_state.static_output_ptrs[i],
859+
handle->cuda_graph_state.static_output_nbytes[i],
851860
cudaMemcpyDeviceToHost,
852861
cuda_stream);
853862
// Don't delete — static buffers are owned by the handle
@@ -861,13 +870,13 @@ class ET_EXPERIMENTAL CudaBackend final
861870
// ----- Normal / WARMUP execution continues here -----
862871

863872
// Decrement warmup counter if in warmup phase
864-
if (handle->cuda_graph_phase == 1 &&
865-
handle->cuda_graph_warmup_remaining > 0) {
866-
handle->cuda_graph_warmup_remaining--;
873+
if (handle->cuda_graph_state.phase == CudaGraphPhase::Warmup &&
874+
handle->cuda_graph_state.warmup_remaining > 0) {
875+
handle->cuda_graph_state.warmup_remaining--;
867876
ET_LOG(
868877
Info,
869878
"CUDA graph warmup: %d steps remaining for '%s'",
870-
handle->cuda_graph_warmup_remaining,
879+
handle->cuda_graph_state.warmup_remaining,
871880
handle->method_name.c_str());
872881
}
873882

backends/cuda/runtime/cuda_delegate_handle.h

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,42 +39,27 @@ inline std::shared_ptr<cudaStream_t> create_cuda_stream() {
3939
return std::shared_ptr<cudaStream_t>(
4040
new cudaStream_t(stream), CudaStreamDeleter());
4141
}
42-
// CUDA-specific delegate handle that extends AOTIDelegateHandle.
43-
// This consolidates CUDA stream management into a single location.
44-
struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
45-
// CUDA stream for this handle, support both shared mode and single mode.
46-
// In shared mode, all cuda delegate handles share the same stream (e.g., for
47-
// skip-copy optimization), they will all hold a reference to the same
48-
// shared_ptr. The stream is automatically destroyed when the last handle is
49-
// destroyed. In single mode, every cuda delegate handle has its own stream.
50-
std::shared_ptr<cudaStream_t> cuda_stream;
51-
52-
// Get the raw CUDA stream pointer for use in CUDA API calls.
53-
// Returns nullptr if no stream is set.
54-
cudaStream_t get_cuda_stream() const {
55-
return cuda_stream ? *cuda_stream : nullptr;
56-
}
5742

58-
// Check if this handle has a valid CUDA stream.
59-
bool has_cuda_stream() const {
60-
return cuda_stream != nullptr && *cuda_stream != nullptr;
61-
}
43+
enum class CudaGraphPhase {
44+
Disabled = 0,
45+
Warmup = 1,
46+
Replay = 2,
47+
};
6248

63-
// --- CUDA graph state ---
64-
// Phase: 0=disabled, 1=warmup, 2=captured (replay mode)
65-
int cuda_graph_phase = 0;
66-
int cuda_graph_warmup_remaining = 0;
49+
// All CUDA graph related state grouped into a single struct.
50+
struct CudaGraphState {
51+
CudaGraphPhase phase = CudaGraphPhase::Disabled;
52+
int warmup_remaining = 0;
6753

6854
// Captured graph and executable instance
69-
cudaGraph_t cuda_graph = nullptr;
70-
cudaGraphExec_t cuda_graph_exec = nullptr;
55+
cudaGraph_t graph = nullptr;
56+
cudaGraphExec_t graph_exec = nullptr;
7157

7258
// Static input/output GPU buffers pinned during capture.
7359
// These hold the tensor metadata; the underlying data pointers are fixed
7460
// addresses that CUDA graph replay will write to / read from.
75-
// SlimTensor pointers — owned by this handle.
76-
std::vector<void*> static_input_ptrs; // raw GPU data pointers for inputs
77-
std::vector<void*> static_output_ptrs; // raw GPU data pointers for outputs
61+
std::vector<void*> static_input_ptrs;
62+
std::vector<void*> static_output_ptrs;
7863
std::vector<std::vector<int64_t>> static_input_sizes;
7964
std::vector<std::vector<int64_t>> static_input_strides;
8065
std::vector<std::vector<int64_t>> static_output_sizes;
@@ -84,12 +69,12 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
8469
std::vector<size_t> static_input_nbytes;
8570
std::vector<size_t> static_output_nbytes;
8671

87-
~CudaDelegateHandle() {
88-
if (cuda_graph_exec) {
89-
cudaGraphExecDestroy(cuda_graph_exec);
72+
~CudaGraphState() {
73+
if (graph_exec) {
74+
cudaGraphExecDestroy(graph_exec);
9075
}
91-
if (cuda_graph) {
92-
cudaGraphDestroy(cuda_graph);
76+
if (graph) {
77+
cudaGraphDestroy(graph);
9378
}
9479
// Only free input buffers — output buffers are owned by the AOTI runtime
9580
// (allocated during graph capture via the caching allocator).
@@ -100,6 +85,31 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
10085
}
10186
};
10287

88+
// CUDA-specific delegate handle that extends AOTIDelegateHandle.
89+
// This consolidates CUDA stream management into a single location.
90+
struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
91+
// CUDA stream for this handle, support both shared mode and single mode.
92+
// In shared mode, all cuda delegate handles share the same stream (e.g., for
93+
// skip-copy optimization), they will all hold a reference to the same
94+
// shared_ptr. The stream is automatically destroyed when the last handle is
95+
// destroyed. In single mode, every cuda delegate handle has its own stream.
96+
std::shared_ptr<cudaStream_t> cuda_stream;
97+
98+
// Get the raw CUDA stream pointer for use in CUDA API calls.
99+
// Returns nullptr if no stream is set.
100+
cudaStream_t get_cuda_stream() const {
101+
return cuda_stream ? *cuda_stream : nullptr;
102+
}
103+
104+
// Check if this handle has a valid CUDA stream.
105+
bool has_cuda_stream() const {
106+
return cuda_stream != nullptr && *cuda_stream != nullptr;
107+
}
108+
109+
// CUDA graph state (warmup, capture, replay, static buffers)
110+
CudaGraphState cuda_graph_state;
111+
};
112+
103113
} // namespace cuda
104114
} // namespace backends
105115
} // namespace executorch

examples/models/qwen3_5_moe/main.cpp

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ int main(int argc, char** argv) {
8787
}
8888
auto metadata = metadata_result.get();
8989

90-
printf("Loading methods...\n");
91-
9290
// Set CUDA graph option if requested (must be before load_method)
9391
if (FLAGS_cuda_graph) {
9492
executorch::runtime::BackendOptions<2> cuda_opts;
@@ -97,27 +95,17 @@ int main(int argc, char** argv) {
9795
printf("CUDA graph enabled for decode method\n");
9896
}
9997

100-
// Try loading both methods; fall back to single "forward" method
101-
bool dual_method = true;
102-
std::string prefill_method = "prefill";
98+
printf("Loading methods...\n");
99+
103100
auto err = module->load_method("prefill");
104101
if (err != Error::Ok) {
105-
// Try "forward" for single-method export
106-
err = module->load_method("forward");
107-
if (err != Error::Ok) {
108-
ET_LOG(Error, "Failed to load prefill/forward method");
109-
return 1;
110-
}
111-
prefill_method = "forward";
112-
dual_method = false;
113-
printf("Using single-method mode (forward)\n");
102+
ET_LOG(Error, "Failed to load prefill method");
103+
return 1;
114104
}
115-
if (dual_method) {
116-
err = module->load_method("decode");
117-
if (err != Error::Ok) {
118-
ET_LOG(Error, "Failed to load decode method");
119-
return 1;
120-
}
105+
err = module->load_method("decode");
106+
if (err != Error::Ok) {
107+
ET_LOG(Error, "Failed to load decode method");
108+
return 1;
121109
}
122110

123111
// Get EOS ids
@@ -160,7 +148,7 @@ int main(int argc, char** argv) {
160148
prefill_inputs.push_back(tokens_tensor);
161149
prefill_inputs.push_back(pos_tensor);
162150

163-
auto prefill_result = module->execute(prefill_method, prefill_inputs);
151+
auto prefill_result = module->execute("prefill", prefill_inputs);
164152
if (prefill_result.error() != Error::Ok) {
165153
ET_LOG(Error, "Prefill failed");
166154
return 1;
@@ -187,11 +175,6 @@ int main(int argc, char** argv) {
187175
// decode method, which may run on a different CUDA stream.
188176
cudaDeviceSynchronize();
189177

190-
if (!dual_method) {
191-
printf("Single-method mode: skipping decode\n");
192-
return 0;
193-
}
194-
195178
// ---------------------------------------------------------------
196179
// Decode — generate tokens one at a time
197180
// ---------------------------------------------------------------

0 commit comments

Comments
 (0)