@@ -107,9 +107,6 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
107107 # Any missing weight key indicates a version mismatch between the
108108 # checkpoint and the model (e.g., unfused vs fused projections).
109109 runtime_prefixes = (
110- ".kv_cache." ,
111- ".conv_state" ,
112- ".recurrent_state" ,
113110 ".mask" ,
114111 ".inv_freq" ,
115112 )
@@ -312,10 +309,10 @@ def _materialize_buffers(model, config):
312309 """Materialize meta-device buffers before torch.export.
313310
314311 Replaces meta buffers with real tensors on CPU, recomputes RoPE
315- inv_freq and causal masks.
312+ inv_freq and causal masks. State buffers (KV cache, conv/recurrent
313+ state) are no longer registered buffers — they are explicit function args.
316314 """
317- # State buffers (KV cache, conv/recurrent state) are bf16 to match
318- # compute dtype. Masks stay bool, inv_freq stays float32.
315+ # Masks stay bool, inv_freq stays float32.
319316 for fqn , buf in list (model .named_buffers ()):
320317 if buf .device .type == "meta" :
321318 dtype = torch .bfloat16 if buf .dtype != torch .bool else torch .bool
@@ -354,7 +351,17 @@ def _materialize_buffers(model, config):
354351
355352
356353def export_and_lower (model , config , args ):
357- """Export model to .pte via torch.export + CUDA backend."""
354+ """Export model to .pte via torch.export + CUDA backend.
355+
356+ Exports two methods:
357+ - "forward": decode path (T=1), uses native PyTorch recurrent FLA
358+ so AOTI can fuse with surrounding ops for maximum decode throughput.
359+ - "prefill": prefill path (T>=2), uses chunked FLA triton_op with
360+ dynamic sequence length.
361+
362+ Both methods take explicit state tensors (conv_states, recurrent_states,
363+ k_caches, v_caches) as inputs and return updated state as outputs.
364+ """
358365 import torch ._inductor .config as inductor_config
359366
360367 from executorch .backends .cuda .cuda_backend import CudaBackend
@@ -374,36 +381,86 @@ def export_and_lower(model, config, args):
374381 # -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
375382 inductor_config .aot_inductor .compile_wrapper_opt_level = "O0"
376383
377- # Dynamic shapes
378- example_tokens = torch .tensor ([[0 , 1 ]], dtype = torch .long )
379- example_input_pos = torch .tensor ([0 , 1 ], dtype = torch .long )
380- seq_dim = Dim ("seq_len" , min = 1 , max = config .max_seq_len - 1 )
381- dynamic_shapes = ({1 : seq_dim }, {0 : seq_dim })
384+ # Create initial state tensors
385+ conv_states , recurrent_states , k_caches , v_caches = \
386+ Qwen35MoE .make_initial_state (config )
382387
383- print ("Exporting with torch.export..." )
388+ # --- Decode method (T=1, static shape) ---
389+ print ("Exporting decode method (forward)..." )
390+ decode_tokens = torch .tensor ([[0 ]], dtype = torch .long )
391+ decode_pos = torch .tensor ([0 ], dtype = torch .long )
384392 with torch .no_grad ():
385- exported = export (
393+ decode_ep = export (
386394 model ,
387- (example_tokens , example_input_pos ) ,
388- dynamic_shapes = dynamic_shapes ,
395+ (decode_tokens , decode_pos ,
396+ conv_states , recurrent_states , k_caches , v_caches ) ,
389397 strict = True ,
390398 )
391- print ("Export successful!" )
399+ print ("Decode export successful!" )
400+
401+ # --- Prefill method (T>=2, dynamic shape) ---
402+ print ("Exporting prefill method..." )
403+ prefill_tokens = torch .tensor ([[0 , 1 ]], dtype = torch .long )
404+ prefill_pos = torch .tensor ([0 , 1 ], dtype = torch .long )
405+ seq_dim = Dim ("seq_len" , min = 2 , max = config .max_seq_len - 1 )
406+ # Dynamic shapes: only tokens dim 1 and pos dim 0 are dynamic;
407+ # state tensors have static shapes.
408+ prefill_dynamic_shapes = (
409+ {1 : seq_dim }, # tokens
410+ {0 : seq_dim }, # input_pos
411+ None , # conv_states
412+ None , # recurrent_states
413+ None , # k_caches
414+ None , # v_caches
415+ )
416+ with torch .no_grad ():
417+ prefill_ep = export (
418+ model ,
419+ (prefill_tokens , prefill_pos ,
420+ conv_states , recurrent_states , k_caches , v_caches ),
421+ dynamic_shapes = prefill_dynamic_shapes ,
422+ strict = True ,
423+ )
424+ print ("Prefill export successful!" )
392425
393- # Lower with CUDA backend
426+ # Lower with CUDA backend (per-method partitioners to avoid so_blob collision)
394427 print ("Lowering to ExecuTorch with CUDA..." )
395- compile_specs = [CudaBackend .generate_method_name_compile_spec ("forward" )]
428+
429+ num_fla = sum (1 for t in config .layer_types if t == "linear_attention" )
430+ num_attn = sum (1 for t in config .layer_types if t == "full_attention" )
431+ conv_dim = (
432+ config .linear_num_key_heads * config .linear_key_head_dim * 2
433+ + config .linear_num_value_heads * config .linear_value_head_dim
434+ )
435+
396436 metadata = {
397437 "get_max_seq_len" : config .max_seq_len ,
398438 "get_vocab_size" : config .vocab_size ,
399439 "get_n_layers" : config .num_hidden_layers ,
400440 "use_kv_cache" : True ,
401441 "use_sdpa_with_kv_cache" : False ,
402442 "enable_dynamic_shape" : True ,
443+ # State shape metadata for C++ runner
444+ "get_num_fla_layers" : num_fla ,
445+ "get_num_attn_layers" : num_attn ,
446+ "get_conv_dim" : conv_dim ,
447+ "get_conv_kernel_size" : config .linear_conv_kernel_dim ,
448+ "get_num_v_heads" : config .linear_num_value_heads ,
449+ "get_head_k_dim" : config .linear_key_head_dim ,
450+ "get_head_v_dim" : config .linear_value_head_dim ,
451+ "get_n_kv_heads" : config .num_kv_heads ,
452+ "get_head_dim" : config .head_dim ,
403453 }
404454 et_prog = to_edge_transform_and_lower (
405- exported ,
406- partitioner = [CudaPartitioner (compile_specs )],
455+ {"forward" : decode_ep , "prefill" : prefill_ep },
456+ partitioner = {
457+ "forward" : [CudaPartitioner (
458+ [CudaBackend .generate_method_name_compile_spec ("forward" )]
459+ )],
460+ "prefill" : [CudaPartitioner (
461+ [CudaBackend .generate_method_name_compile_spec ("prefill" )]
462+ )],
463+ },
407464 compile_config = EdgeCompileConfig (
408465 _check_ir_validity = False ,
409466 _skip_dim_order = True ,