Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5465d8b
Replace chunked FLA with recurrent gated delta rule for T=1 decode
Gasoonjia Apr 2, 2026
a6ebe8a
Runtime dispatch: recurrent (T=1) vs chunked (T>1) inside triton_op
Gasoonjia Apr 3, 2026
fc5018e
Revert model.py, export.py, main.cpp to main branch
Gasoonjia Apr 3, 2026
c90a8e8
Add tests for recurrent (T=1) and multi-T dispatch
Gasoonjia Apr 3, 2026
ce3e9ca
lint fix - 2
Gasoonjia Apr 3, 2026
8d35c65
lint fix - 2
Gasoonjia Apr 3, 2026
709deb0
Merge branch 'main' into recurrent-fla
Gasoonjia Apr 3, 2026
eff976d
lint fix - 3
Gasoonjia Apr 3, 2026
7dd4280
Optimize recurrent kernel: parallelize over V tiles
Gasoonjia Apr 3, 2026
3a1ee31
Dual-method PTE with GPU-resident state for Qwen3.5 MoE
Apr 5, 2026
63c162e
Use share_mutable_buffers to eliminate select_scatter overhead
Apr 6, 2026
47d6b98
Merge branch 'main' into recurrent-fla
Gasoonjia Apr 6, 2026
375e5c0
lint
Gasoonjia Apr 6, 2026
2b36797
remove reduntdant updates
Gasoonjia Apr 6, 2026
c06d58b
Cross-method AOTI constant sharing for KV cache
Apr 7, 2026
6945b2a
Fix cross-method AOTI constant sharing and add dual-method runner
Gasoonjia Apr 7, 2026
ea51d0d
Remove debug printf and decode_only flag
Gasoonjia Apr 7, 2026
a0a62f1
Lint formatting fixes
Gasoonjia Apr 7, 2026
ca69871
Improve CUDA backend error handling and add dual-method runner fallback
Apr 9, 2026
7c148f7
Add CUDA graph capture/replay for decode method
Apr 10, 2026
ee75c2e
Merge branch 'main' into cuda-graph
Gasoonjia Apr 10, 2026
10e7aad
lint and reformat
Gasoonjia Apr 13, 2026
9042f36
Merge branch 'main' into cuda-graph
Gasoonjia Apr 13, 2026
84d1587
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
e00a499
solve claude
Gasoonjia Apr 15, 2026
aa7bb82
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
cef386b
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
2d32422
Merge branch 'main' into cuda-graph
Gasoonjia Apr 16, 2026
1270870
Merge branch 'main' into cuda-graph
Gasoonjia Apr 16, 2026
8fc7355
solve stride out of scope
Gasoonjia Apr 17, 2026
2c46ed2
Merge branch 'main' into cuda-graph
Gasoonjia Apr 21, 2026
855eb93
Merge branch 'main' into cuda-graph
Gasoonjia Apr 22, 2026
4237d17
remove unused env var
Gasoonjia Apr 22, 2026
9b4705e
Merge branch 'main' into cuda-graph
Gasoonjia Apr 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/test_model_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ EOF
fi
;;
qwen3_5_moe)
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0"
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0 --cuda_graph"
;;
voxtral_realtime)
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"
Expand Down
246 changes: 232 additions & 14 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;

// SlimTensor type aliases
using cuda::CudaGraphPhase;
using slim::CPU_DEVICE;
using slim::DEFAULT_CUDA_DEVICE;
using slim::DeviceTraits;
Expand All @@ -80,6 +81,8 @@ namespace {
constexpr char kSkipCopyOutputToCpuForMethod[] =
"skip_copy_output_to_cpu_for_method";
constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
constexpr char kEnableCudaGraphForMethod[] = "enable_cuda_graph_for_method";
constexpr int kCudaGraphWarmupSteps = 3;
constexpr char kShareKvCacheAcrossMethods[] = "share_kv_cache_across_methods";
} // anonymous namespace

Expand Down Expand Up @@ -147,6 +150,20 @@ class ET_EXPERIMENTAL CudaBackend final
return method_in_csv(method_name, skip_copy_method_);
}

void set_cuda_graph_method(
const std::array<char, kMaxOptionValueLength>& raw) {
std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_);
cuda_graph_method_ = std::string(raw.data());
}

bool should_use_cuda_graph_for_method(const std::string& method_name) const {
if (method_name.empty()) {
return false;
}
std::lock_guard<std::mutex> guard(cuda_graph_method_mutex_);
return method_in_csv(method_name, cuda_graph_method_);
}

// Create the shared CUDA stream. Called when use_shared_cuda_stream option
// is set to true. The presence of shared_cuda_stream_ indicates shared mode.
void create_shared_cuda_stream() {
Expand Down Expand Up @@ -265,6 +282,17 @@ class ET_EXPERIMENTAL CudaBackend final
ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
return Error::InvalidArgument;
}
} else if (std::strcmp(option.key, kEnableCudaGraphForMethod) == 0) {
if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
&option.value)) {
set_cuda_graph_method(*val);
} else {
ET_LOG(
Error,
"Option %s must be a method name string.",
kEnableCudaGraphForMethod);
return Error::InvalidArgument;
}
}
}
return Error::Ok;
Expand Down Expand Up @@ -532,6 +560,17 @@ class ET_EXPERIMENTAL CudaBackend final
method_name.c_str());
}

// Initialize CUDA graph state if enabled for this method.
if (should_use_cuda_graph_for_method(method_name)) {
handle->cuda_graph_state.phase = CudaGraphPhase::Warmup;
handle->cuda_graph_state.warmup_remaining = kCudaGraphWarmupSteps;
ET_LOG(
Info,
"CUDA graph enabled for method '%s' (warmup=%d)",
method_name.c_str(),
kCudaGraphWarmupSteps);
}

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand All @@ -558,6 +597,68 @@ class ET_EXPERIMENTAL CudaBackend final
n_outputs,
args.size())

// ---------------------------------------------------------------
// CUDA graph REPLAY path — skip all tensor setup and just replay
// ---------------------------------------------------------------
if (handle->cuda_graph_state.phase == CudaGraphPhase::Replay) {
Result<cudaStream_t> csr = getCurrentCUDAStream(0);
ET_CHECK_OK_OR_RETURN_ERROR(csr.error());
cudaStream_t cs = csr.get();

// Copy new input data into static input buffers
for (size_t i = 0; i < n_inputs; i++) {
auto* cpu_tensor = &(args[i]->toTensor());
ET_CHECK_OR_RETURN_ERROR(
cpu_tensor->nbytes() ==
handle->cuda_graph_state.static_input_nbytes[i],
InvalidArgument,
"CUDA graph replay: input %zu size mismatch (expected %zu, got %zu)",
i,
handle->cuda_graph_state.static_input_nbytes[i],
cpu_tensor->nbytes());
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
handle->cuda_graph_state.static_input_ptrs[i],
cpu_tensor->const_data_ptr(),
handle->cuda_graph_state.static_input_nbytes[i],
cudaMemcpyHostToDevice,
cs));
}

// Replay the captured graph
cudaError_t gerr =
cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cs);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaGraphLaunch failed: %s",
cudaGetErrorString(gerr));

// Copy outputs back to CPU
const bool copy_outputs =
!should_skip_copy_for_method(handle->method_name);
if (copy_outputs) {
for (size_t i = 0; i < n_outputs; i++) {
auto* cpu_out = &(args[i + n_inputs]->toTensor());
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
cpu_out->mutable_data_ptr(),
handle->cuda_graph_state.static_output_ptrs[i],
handle->cuda_graph_state.static_output_nbytes[i],
cudaMemcpyDeviceToHost,
cs));
}
cudaStreamSynchronize(cs);
}

return Error::Ok;
}

// ---------------------------------------------------------------
// Normal path (also used for WARMUP and CAPTURE phases)
// ---------------------------------------------------------------
bool is_capture_step =
(handle->cuda_graph_state.phase == CudaGraphPhase::Warmup &&
handle->cuda_graph_state.warmup_remaining == 0);

// NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
// optimization. We need to create GPU copies for CUDA kernel execution
// using SlimTensor.
Expand All @@ -568,6 +669,33 @@ class ET_EXPERIMENTAL CudaBackend final
for (size_t i = 0; i < n_inputs; i++) {
auto* cpu_tensor = &(args[i]->toTensor());

// CAPTURE step: allocate persistent static GPU buffers
if (is_capture_step) {
size_t nbytes = cpu_tensor->nbytes();

void* static_ptr = nullptr;
cudaError_t merr = cudaMalloc(&static_ptr, nbytes);
ET_CHECK_OR_RETURN_ERROR(
merr == cudaSuccess,
Internal,
"cudaMalloc for static input %zu failed: %s",
i,
cudaGetErrorString(merr));

cudaMemcpy(
static_ptr,
cpu_tensor->const_data_ptr(),
nbytes,
cudaMemcpyHostToDevice);

handle->cuda_graph_state.static_input_ptrs.push_back(static_ptr);
handle->cuda_graph_state.static_input_nbytes.push_back(nbytes);

gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata(
static_ptr, cpu_tensor);
continue;
}

// Check if input data is already on GPU (skip-copy optimization for
// inputs) This can happen when the caller has pre-staged data on GPU
cudaPointerAttributes attributes{};
Expand All @@ -576,19 +704,8 @@ class ET_EXPERIMENTAL CudaBackend final
cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr);
if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) {
// Data is already on GPU - wrap it directly without copy
auto sizes = cpu_tensor->sizes();
auto strides = cpu_tensor->strides();
std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
std::vector<int64_t> strides_vec(strides.begin(), strides.end());

gpu_inputs[i] = new SlimTensor(slim::from_blob(
const_cast<void*>(data_ptr),
slim::makeArrayRef(sizes_vec),
slim::makeArrayRef(strides_vec),
static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
DEFAULT_CUDA_DEVICE,
0 // storage_offset
));
gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata(
const_cast<void*>(data_ptr), cpu_tensor);

continue;
}
Expand Down Expand Up @@ -640,8 +757,25 @@ class ET_EXPERIMENTAL CudaBackend final
// NOTE: run() steals input handles (RAII wraps them at the start of
// run_impl) and may replace output handles with its own.
Result<cudaStream_t> cuda_stream_ret = getCurrentCUDAStream(0);
cudaStream_t cuda_stream = cuda_stream_ret.get();
ET_CHECK_OK_OR_RETURN_ERROR(cuda_stream_ret.error());
cudaStream_t cuda_stream = cuda_stream_ret.get();

if (is_capture_step) {
// ----- CUDA graph CAPTURE -----
ET_LOG(
Info,
"CUDA graph: beginning stream capture for '%s'",
handle->method_name.c_str());

cudaError_t cerr =
cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeRelaxed);
ET_CHECK_OR_RETURN_ERROR(
cerr == cudaSuccess,
Internal,
"cudaStreamBeginCapture failed: %s",
cudaGetErrorString(cerr));
}

AOTIRuntimeError error = handle->run(
handle->container_handle,
reinterpret_cast<Tensor**>(gpu_inputs.data()),
Expand All @@ -667,6 +801,87 @@ class ET_EXPERIMENTAL CudaBackend final
"AOTInductorModelContainerRun failed with error code %d",
error);

if (is_capture_step) {
// End capture → instantiate graph
cudaError_t gerr =
cudaStreamEndCapture(cuda_stream, &handle->cuda_graph_state.graph);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaStreamEndCapture failed: %s",
cudaGetErrorString(gerr));

gerr = cudaGraphInstantiate(
&handle->cuda_graph_state.graph_exec,
handle->cuda_graph_state.graph,
cudaGraphInstantiateFlagAutoFreeOnLaunch);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaGraphInstantiate failed: %s",
cudaGetErrorString(gerr));

// Record static output pointers (stable under graph replay)
for (size_t i = 0; i < n_outputs; i++) {
SlimTensor* out = gpu_outputs[i];
handle->cuda_graph_state.static_output_ptrs.push_back(out->data_ptr());
handle->cuda_graph_state.static_output_nbytes.push_back(out->nbytes());
}

handle->cuda_graph_state.phase = CudaGraphPhase::Replay;
ET_LOG(
Info,
"CUDA graph: captured and instantiated for '%s'",
handle->method_name.c_str());

// Replay once to actually produce output (capture doesn't execute)
gerr = cudaGraphLaunch(handle->cuda_graph_state.graph_exec, cuda_stream);
ET_CHECK_OR_RETURN_ERROR(
gerr == cudaSuccess,
Internal,
"cudaGraphLaunch (first replay) failed: %s",
cudaGetErrorString(gerr));

// Copy capture-step outputs to CPU
const bool copy_outputs =
!should_skip_copy_for_method(handle->method_name);
if (copy_outputs) {
for (size_t i = 0; i < n_outputs; i++) {
auto* cpu_out = &(args[i + n_inputs]->toTensor());
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
cpu_out->mutable_data_ptr(),
handle->cuda_graph_state.static_output_ptrs[i],
handle->cuda_graph_state.static_output_nbytes[i],
cudaMemcpyDeviceToHost,
cuda_stream));
// Don't delete — static buffers are owned by the handle
gpu_outputs[i] = nullptr;
}
cudaStreamSynchronize(cuda_stream);
} else {
// Even when skipping copy, null out gpu_outputs to prevent
// the ScopeGuard from deleting static output buffers.
for (size_t i = 0; i < n_outputs; i++) {
gpu_outputs[i] = nullptr;
}
}

return Error::Ok;
}

// ----- Normal / WARMUP execution continues here -----

// Decrement warmup counter if in warmup phase
if (handle->cuda_graph_state.phase == CudaGraphPhase::Warmup &&
handle->cuda_graph_state.warmup_remaining > 0) {
handle->cuda_graph_state.warmup_remaining--;
ET_LOG(
Info,
"CUDA graph warmup: %d steps remaining for '%s'",
handle->cuda_graph_state.warmup_remaining,
handle->method_name.c_str());
}

const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

if (copy_outputs) {
Expand Down Expand Up @@ -761,6 +976,9 @@ class ET_EXPERIMENTAL CudaBackend final
mutable std::mutex skip_copy_method_mutex_;
std::string skip_copy_method_;

mutable std::mutex cuda_graph_method_mutex_;
std::string cuda_graph_method_;

// Shared CUDA stream for all methods. When set (non-null), all methods use
// the same stream to ensure proper ordering (critical for skip-copy
// optimization). Created when use_shared_cuda_stream option is set to true.
Expand Down
Loading
Loading