Skip to content

Commit 8f2ff30

Browse files
author
gasoonjia
committed
Dual-method PTE with GPU-resident state for Qwen3.5 MoE

- Split model into prefill (chunked FLA triton_op) and decode (native PyTorch recurrent delta rule) methods with explicit state passing
- Add runtime_specs processing in CudaBackend::init() so LoadBackendOptionsMap options (skip_copy_output_to_cpu, use_shared_cuda_stream) take effect
- Keep state tensors GPU-resident across method calls; only copy logits to CPU for sampling via cudaMemcpy
- Achieves 77.4 tok/s decode (3.75x over naive dual-method baseline)

Modified files:
- cuda_backend.cpp: read runtime_specs in init() for skip_copy + shared stream
- main.cpp: dual-method runner with GPU-resident state, logits CPU copy helper
- CMakeLists.txt: link CUDA::cudart for cudaMemcpy
- model.py: dual-method model definition (prefill + decode)
- export.py: export script for dual-method PTE
1 parent 7dd4280 commit 8f2ff30

File tree

5 files changed

+501
-103
lines changed

5 files changed

+501
-103
lines changed

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,27 @@ class ET_EXPERIMENTAL CudaBackend final
262262
FreeableBuffer* processed, // This will be an empty buffer
263263
ArrayRef<CompileSpec> compile_specs // This will be my empty list
264264
) const override {
265+
// Apply runtime_specs from LoadBackendOptionsMap (if provided)
266+
auto runtime_specs = context.runtime_specs();
267+
if (runtime_specs.size() > 0) {
268+
for (size_t i = 0; i < runtime_specs.size(); ++i) {
269+
const auto& opt = runtime_specs[i];
270+
if (std::strcmp(opt.key, kSkipCopyOutputToCpuForMethod) == 0) {
271+
if (auto* val =
272+
std::get_if<std::array<char, kMaxOptionValueLength>>(
273+
&opt.value)) {
274+
const_cast<CudaBackend*>(this)->set_skip_copy_method(*val);
275+
}
276+
} else if (std::strcmp(opt.key, kUseSharedCudaStream) == 0) {
277+
if (auto* val = std::get_if<bool>(&opt.value)) {
278+
if (*val) {
279+
const_cast<CudaBackend*>(this)->create_shared_cuda_stream();
280+
}
281+
}
282+
}
283+
}
284+
}
285+
265286
std::string method_name;
266287
for (const CompileSpec& spec : compile_specs) {
267288
if (std::strcmp(spec.key, "method_name") == 0) {

examples/models/qwen3_5_moe/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ list(
4444

4545
# CUDA backend (required)
4646
find_package(CUDAToolkit REQUIRED)
47-
list(APPEND link_libraries aoti_cuda_backend)
47+
list(APPEND link_libraries aoti_cuda_backend CUDA::cudart)
4848
executorch_target_link_options_shared_lib(aoti_cuda_backend)
4949

5050
# Tokenizer

examples/models/qwen3_5_moe/export.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,6 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
107107
# Any missing weight key indicates a version mismatch between the
108108
# checkpoint and the model (e.g., unfused vs fused projections).
109109
runtime_prefixes = (
110-
".kv_cache.",
111-
".conv_state",
112-
".recurrent_state",
113110
".mask",
114111
".inv_freq",
115112
)
@@ -312,10 +309,10 @@ def _materialize_buffers(model, config):
312309
"""Materialize meta-device buffers before torch.export.
313310
314311
Replaces meta buffers with real tensors on CPU, recomputes RoPE
315-
inv_freq and causal masks.
312+
inv_freq and causal masks. State buffers (KV cache, conv/recurrent
313+
state) are no longer registered buffers — they are explicit function args.
316314
"""
317-
# State buffers (KV cache, conv/recurrent state) are bf16 to match
318-
# compute dtype. Masks stay bool, inv_freq stays float32.
315+
# Masks stay bool, inv_freq stays float32.
319316
for fqn, buf in list(model.named_buffers()):
320317
if buf.device.type == "meta":
321318
dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
@@ -354,7 +351,17 @@ def _materialize_buffers(model, config):
354351

355352

356353
def export_and_lower(model, config, args):
357-
"""Export model to .pte via torch.export + CUDA backend."""
354+
"""Export model to .pte via torch.export + CUDA backend.
355+
356+
Exports two methods:
357+
- "forward": decode path (T=1), uses native PyTorch recurrent FLA
358+
so AOTI can fuse with surrounding ops for maximum decode throughput.
359+
- "prefill": prefill path (T>=2), uses chunked FLA triton_op with
360+
dynamic sequence length.
361+
362+
Both methods take explicit state tensors (conv_states, recurrent_states,
363+
k_caches, v_caches) as inputs and return updated state as outputs.
364+
"""
358365
import torch._inductor.config as inductor_config
359366

360367
from executorch.backends.cuda.cuda_backend import CudaBackend
@@ -374,36 +381,86 @@ def export_and_lower(model, config, args):
374381
# -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
375382
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
376383

377-
# Dynamic shapes
378-
example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
379-
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
380-
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
381-
dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
384+
# Create initial state tensors
385+
conv_states, recurrent_states, k_caches, v_caches = \
386+
Qwen35MoE.make_initial_state(config)
382387

383-
print("Exporting with torch.export...")
388+
# --- Decode method (T=1, static shape) ---
389+
print("Exporting decode method (forward)...")
390+
decode_tokens = torch.tensor([[0]], dtype=torch.long)
391+
decode_pos = torch.tensor([0], dtype=torch.long)
384392
with torch.no_grad():
385-
exported = export(
393+
decode_ep = export(
386394
model,
387-
(example_tokens, example_input_pos),
388-
dynamic_shapes=dynamic_shapes,
395+
(decode_tokens, decode_pos,
396+
conv_states, recurrent_states, k_caches, v_caches),
389397
strict=True,
390398
)
391-
print("Export successful!")
399+
print("Decode export successful!")
400+
401+
# --- Prefill method (T>=2, dynamic shape) ---
402+
print("Exporting prefill method...")
403+
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
404+
prefill_pos = torch.tensor([0, 1], dtype=torch.long)
405+
seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
406+
# Dynamic shapes: only tokens dim 1 and pos dim 0 are dynamic;
407+
# state tensors have static shapes.
408+
prefill_dynamic_shapes = (
409+
{1: seq_dim}, # tokens
410+
{0: seq_dim}, # input_pos
411+
None, # conv_states
412+
None, # recurrent_states
413+
None, # k_caches
414+
None, # v_caches
415+
)
416+
with torch.no_grad():
417+
prefill_ep = export(
418+
model,
419+
(prefill_tokens, prefill_pos,
420+
conv_states, recurrent_states, k_caches, v_caches),
421+
dynamic_shapes=prefill_dynamic_shapes,
422+
strict=True,
423+
)
424+
print("Prefill export successful!")
392425

393-
# Lower with CUDA backend
426+
# Lower with CUDA backend (per-method partitioners to avoid so_blob collision)
394427
print("Lowering to ExecuTorch with CUDA...")
395-
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
428+
429+
num_fla = sum(1 for t in config.layer_types if t == "linear_attention")
430+
num_attn = sum(1 for t in config.layer_types if t == "full_attention")
431+
conv_dim = (
432+
config.linear_num_key_heads * config.linear_key_head_dim * 2
433+
+ config.linear_num_value_heads * config.linear_value_head_dim
434+
)
435+
396436
metadata = {
397437
"get_max_seq_len": config.max_seq_len,
398438
"get_vocab_size": config.vocab_size,
399439
"get_n_layers": config.num_hidden_layers,
400440
"use_kv_cache": True,
401441
"use_sdpa_with_kv_cache": False,
402442
"enable_dynamic_shape": True,
443+
# State shape metadata for C++ runner
444+
"get_num_fla_layers": num_fla,
445+
"get_num_attn_layers": num_attn,
446+
"get_conv_dim": conv_dim,
447+
"get_conv_kernel_size": config.linear_conv_kernel_dim,
448+
"get_num_v_heads": config.linear_num_value_heads,
449+
"get_head_k_dim": config.linear_key_head_dim,
450+
"get_head_v_dim": config.linear_value_head_dim,
451+
"get_n_kv_heads": config.num_kv_heads,
452+
"get_head_dim": config.head_dim,
403453
}
404454
et_prog = to_edge_transform_and_lower(
405-
exported,
406-
partitioner=[CudaPartitioner(compile_specs)],
455+
{"forward": decode_ep, "prefill": prefill_ep},
456+
partitioner={
457+
"forward": [CudaPartitioner(
458+
[CudaBackend.generate_method_name_compile_spec("forward")]
459+
)],
460+
"prefill": [CudaPartitioner(
461+
[CudaBackend.generate_method_name_compile_spec("prefill")]
462+
)],
463+
},
407464
compile_config=EdgeCompileConfig(
408465
_check_ir_validity=False,
409466
_skip_dim_order=True,

0 commit comments

Comments
 (0)