Skip to content

Commit f9153ab

Browse files
committed
lintrunner
1 parent 504709c commit f9153ab

4 files changed

Lines changed: 47 additions & 6 deletions

File tree

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,7 @@ class ET_EXPERIMENTAL CudaBackend final
693693

694694
gpu_inputs[i] = make_slimtensor_from_blob_with_etensor_metadata(
695695
static_ptr, cpu_tensor);
696+
696697
continue;
697698
}
698699

@@ -805,6 +806,7 @@ class ET_EXPERIMENTAL CudaBackend final
805806
// End capture → instantiate graph
806807
cudaError_t gerr =
807808
cudaStreamEndCapture(cuda_stream, &handle->cuda_graph_state.graph);
809+
808810
ET_CHECK_OR_RETURN_ERROR(
809811
gerr == cudaSuccess,
810812
Internal,
@@ -814,6 +816,7 @@ class ET_EXPERIMENTAL CudaBackend final
814816
gerr = cudaGraphInstantiate(
815817
&handle->cuda_graph_state.graph_exec,
816818
handle->cuda_graph_state.graph,
819+
817820
cudaGraphInstantiateFlagAutoFreeOnLaunch);
818821
ET_CHECK_OR_RETURN_ERROR(
819822
gerr == cudaSuccess,

backends/cuda/runtime/cuda_delegate_handle.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,45 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
149149

150150
// CUDA graph state (warmup, capture, replay, static buffers)
151151
CudaGraphState cuda_graph_state;
152+
// --- CUDA graph state ---
153+
// Phase: 0=disabled, 1=warmup, 2=captured (replay mode)
154+
int cuda_graph_phase = 0;
155+
int cuda_graph_warmup_remaining = 0;
156+
157+
// Captured graph and executable instance
158+
cudaGraph_t cuda_graph = nullptr;
159+
cudaGraphExec_t cuda_graph_exec = nullptr;
160+
161+
// Static input/output GPU buffers pinned during capture.
162+
// These hold the tensor metadata; the underlying data pointers are fixed
163+
// addresses that CUDA graph replay will write to / read from.
164+
// SlimTensor pointers — owned by this handle.
165+
std::vector<void*> static_input_ptrs; // raw GPU data pointers for inputs
166+
std::vector<void*> static_output_ptrs; // raw GPU data pointers for outputs
167+
std::vector<std::vector<int64_t>> static_input_sizes;
168+
std::vector<std::vector<int64_t>> static_input_strides;
169+
std::vector<std::vector<int64_t>> static_output_sizes;
170+
std::vector<std::vector<int64_t>> static_output_strides;
171+
std::vector<int> static_input_scalar_types;
172+
std::vector<int> static_output_scalar_types;
173+
std::vector<size_t> static_input_nbytes;
174+
std::vector<size_t> static_output_nbytes;
175+
176+
~CudaDelegateHandle() {
177+
if (cuda_graph_exec) {
178+
cudaGraphExecDestroy(cuda_graph_exec);
179+
}
180+
if (cuda_graph) {
181+
cudaGraphDestroy(cuda_graph);
182+
}
183+
// Only free input buffers — output buffers are owned by the AOTI runtime
184+
// (allocated during graph capture via the caching allocator).
185+
for (auto* ptr : static_input_ptrs) {
186+
if (ptr)
187+
cudaFree(ptr);
188+
}
189+
}
190+
>>>>>>> 028894ef8e (lintrunner)
152191
};
153192

154193
} // namespace cuda

examples/models/qwen3_5_moe/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ def _export_cuda(model, config, args):
794794
prefill_dynamic_shapes = (
795795
{1: seq_dim}, # tokens
796796
{0: seq_dim}, # input_pos
797-
None, # temperature (static scalar)
797+
None, # temperature (static scalar)
798798
)
799799
with torch.no_grad():
800800
prefill_ep = export(

examples/models/qwen3_5_moe/main.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,10 @@ int main(int argc, char** argv) {
147147

148148
// Use a very small temperature for greedy to avoid division by zero
149149
// while keeping the Gumbel noise negligible relative to logit differences.
150-
float temp_val = FLAGS_temperature <= 0.0
151-
? 1e-6f
152-
: static_cast<float>(FLAGS_temperature);
153-
auto temp_tensor = from_blob(
154-
&temp_val, {1}, executorch::aten::ScalarType::Float);
150+
float temp_val =
151+
FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
152+
auto temp_tensor =
153+
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
155154

156155
// ---------------------------------------------------------------
157156
// Prefill

0 commit comments

Comments
 (0)