@@ -109,6 +109,9 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
     runtime_prefixes = (
         ".mask",
         ".inv_freq",
+        ".kv_cache.",
+        ".conv_state",
+        ".recurrent_state",
     )
     expected_missing = {k for k in missing if any(p in k for p in runtime_prefixes)}
     weight_missing = set(missing) - expected_missing
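Note: the filter above tolerates missing keys only when they match a runtime-recomputed prefix; anything left over is a genuine checkpoint problem. A self-contained sketch of the same triage, using a toy module (`Toy`, `Block`, and the checkpoint contents are illustrative, not from this repo):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        # Runtime-only buffer: recomputed at load time, absent from checkpoints.
        self.register_buffer("mask", torch.ones(4, 4, dtype=torch.bool))

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

model = Toy()
# Checkpoint carries only real weights; "block.mask" is intentionally absent.
ckpt = {"block.proj.weight": torch.zeros(4, 4), "block.proj.bias": torch.zeros(4)}
missing = model.load_state_dict(ckpt, strict=False).missing_keys  # ["block.mask"]

runtime_prefixes = (".mask", ".inv_freq", ".kv_cache.", ".conv_state", ".recurrent_state")
expected_missing = {k for k in missing if any(p in k for p in runtime_prefixes)}
weight_missing = set(missing) - expected_missing
assert not weight_missing, f"real weights missing from checkpoint: {sorted(weight_missing)}"
```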
@@ -310,7 +313,8 @@ def _materialize_buffers(model, config):
 
     Replaces meta buffers with real tensors on CPU, recomputes RoPE
     inv_freq and causal masks. State buffers (KV cache, conv/recurrent
-    state) are no longer registered buffers — they are explicit function args.
+    state) are zero-initialized registered buffers that will be shared
+    across methods via share_mutable_buffers.
     """
     # Masks stay bool, inv_freq stays float32.
     for fqn, buf in list(model.named_buffers()):
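A minimal sketch of the materialization loop this docstring describes, assuming buffers sit on the meta device until load time. The helper name and zero-fill policy are illustrative; the real function also recomputes inv_freq and masks, which is omitted here:

```python
import torch
import torch.nn as nn

def materialize_meta_buffers(model: nn.Module) -> None:
    """Swap meta-device buffers for real zero tensors on CPU (hypothetical helper)."""
    for fqn, buf in list(model.named_buffers()):
        if buf.device.type != "meta":
            continue
        *path, name = fqn.split(".")
        owner = model.get_submodule(".".join(path))  # "" resolves to the root module
        # Assigning through __setattr__ replaces the entry in owner._buffers.
        setattr(owner, name, torch.zeros(buf.shape, dtype=buf.dtype, device="cpu"))
```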
@@ -359,8 +363,9 @@ def export_and_lower(model, config, args):
     - "prefill": prefill path (T>=2), uses chunked FLA triton_op with
       dynamic sequence length.
 
-    Both methods take explicit state tensors (conv_states, recurrent_states,
-    k_caches, v_caches) as inputs and return updated state as outputs.
+    Both methods share mutable state buffers (KV cache, conv_state,
+    recurrent_state) via share_mutable_buffers=True. The model uses
+    registered buffers with in-place updates — no state in/out args.
     """
     import torch._inductor.config as inductor_config
 
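The "registered buffers with in-place updates" pattern in that docstring can be shown on a toy module (`TinyCache` and its shapes are hypothetical, not the real model): torch.export functionalizes the in-place write and records it as a buffer mutation in the graph signature instead of threading state through inputs and outputs.

```python
import torch
import torch.nn as nn
from torch.export import export

class TinyCache(nn.Module):
    def __init__(self, max_seq_len: int = 8, dim: int = 4):
        super().__init__()
        # Zero-initialized mutable state, standing in for kv_cache/conv_state.
        self.register_buffer("k_cache", torch.zeros(max_seq_len, dim))

    def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # In-place write at the current position: no state in/out arguments.
        self.k_cache.index_copy_(0, input_pos, x)
        return self.k_cache.sum(dim=0)

ep = export(TinyCache(), (torch.randn(1, 4), torch.tensor([0])), strict=True)
# The mutation shows up in the signature, mapping an output node to "k_cache".
print(ep.graph_signature.buffers_to_mutate)
```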
@@ -381,19 +386,14 @@ def export_and_lower(model, config, args):
     # -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
 
-    # Create initial state tensors
-    conv_states, recurrent_states, k_caches, v_caches = \
-        Qwen35MoE.make_initial_state(config)
-
     # --- Decode method (T=1, static shape) ---
     print("Exporting decode method (forward)...")
     decode_tokens = torch.tensor([[0]], dtype=torch.long)
     decode_pos = torch.tensor([0], dtype=torch.long)
     with torch.no_grad():
         decode_ep = export(
             model,
-            (decode_tokens, decode_pos,
-             conv_states, recurrent_states, k_caches, v_caches),
+            (decode_tokens, decode_pos),
             strict=True,
         )
     print("Decode export successful!")
@@ -403,21 +403,14 @@ def export_and_lower(model, config, args):
     prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
     prefill_pos = torch.tensor([0, 1], dtype=torch.long)
     seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
-    # Dynamic shapes: only tokens dim 1 and pos dim 0 are dynamic;
-    # state tensors have static shapes.
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
         {0: seq_dim},  # input_pos
-        None,  # conv_states
-        None,  # recurrent_states
-        None,  # k_caches
-        None,  # v_caches
     )
     with torch.no_grad():
         prefill_ep = export(
             model,
-            (prefill_tokens, prefill_pos,
-             conv_states, recurrent_states, k_caches, v_caches),
+            (prefill_tokens, prefill_pos),
             dynamic_shapes=prefill_dynamic_shapes,
             strict=True,
         )
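The dynamic_shapes tuple is positional, one spec per model input, which is why the four `None` entries disappear together with the state arguments. A toy illustration of that contract (`Echo` is made up; the real bound comes from config.max_seq_len):

```python
import torch
from torch.export import Dim, export

class Echo(torch.nn.Module):
    def forward(self, tokens, input_pos):
        return tokens.float().sum() + input_pos.float().sum()

seq_dim = Dim("seq_len", min=2, max=4095)
ep = export(
    Echo(),
    (torch.zeros(1, 2, dtype=torch.long), torch.tensor([0, 1])),
    # One spec per positional input: tokens is dynamic in dim 1,
    # input_pos in dim 0; reusing the same Dim object ties them together.
    dynamic_shapes=({1: seq_dim}, {0: seq_dim}),
    strict=True,
)
```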
@@ -426,30 +419,13 @@ def export_and_lower(model, config, args):
     # Lower with CUDA backend (per-method partitioners to avoid so_blob collision)
     print("Lowering to ExecuTorch with CUDA...")
 
-    num_fla = sum(1 for t in config.layer_types if t == "linear_attention")
-    num_attn = sum(1 for t in config.layer_types if t == "full_attention")
-    conv_dim = (
-        config.linear_num_key_heads * config.linear_key_head_dim * 2
-        + config.linear_num_value_heads * config.linear_value_head_dim
-    )
-
     metadata = {
         "get_max_seq_len": config.max_seq_len,
         "get_vocab_size": config.vocab_size,
         "get_n_layers": config.num_hidden_layers,
         "use_kv_cache": True,
         "use_sdpa_with_kv_cache": False,
         "enable_dynamic_shape": True,
-        # State shape metadata for C++ runner
-        "get_num_fla_layers": num_fla,
-        "get_num_attn_layers": num_attn,
-        "get_conv_dim": conv_dim,
-        "get_conv_kernel_size": config.linear_conv_kernel_dim,
-        "get_num_v_heads": config.linear_num_value_heads,
-        "get_head_k_dim": config.linear_key_head_dim,
-        "get_head_v_dim": config.linear_value_head_dim,
-        "get_n_kv_heads": config.num_kv_heads,
-        "get_head_dim": config.head_dim,
     }
     et_prog = to_edge_transform_and_lower(
         {"forward": decode_ep, "prefill": prefill_ep},
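The rest of this call is outside the hunk, so the following is an assumption about how the metadata dict is consumed: dicts like this are conventionally passed as constant_methods, turning each key into a zero-argument method the C++ runner can query from the .pte. A toy sketch of that mechanism via to_edge:

```python
import torch
from torch.export import export
from executorch.exir import to_edge

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

ep = export(Toy(), (torch.zeros(2),), strict=True)
# Each constant_methods entry becomes a method alongside "forward",
# returning its baked-in value when called at runtime.
edge = to_edge(ep, constant_methods={"get_max_seq_len": 4096, "use_kv_cache": True})
pte = edge.to_executorch()
```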
@@ -471,7 +447,11 @@ def export_and_lower(model, config, args):
         config=ExecutorchBackendConfig(
             extract_delegate_segments=True,
             do_quant_fusion_and_const_prop=True,
-            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+            memory_planning_pass=MemoryPlanningPass(
+                alloc_graph_input=False,
+                share_mutable_buffers=True,
+            ),
+            emit_mutable_buffer_names=True,
         ),
     )
 
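This hunk is what makes the state rework above function: "forward" and "prefill" both mutate the same KV/conv/recurrent buffers, and share_mutable_buffers=True has the memory planner give both methods one shared allocation per buffer, so state written during prefill is visible to decode. The sketch below restates the config standalone with imports; the flag-semantics comments are my reading of the flag names, not documented claims.

```python
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass

backend_config = ExecutorchBackendConfig(
    extract_delegate_segments=True,
    do_quant_fusion_and_const_prop=True,
    memory_planning_pass=MemoryPlanningPass(
        alloc_graph_input=False,     # the runner supplies input tensors itself
        share_mutable_buffers=True,  # one shared copy of each state buffer across methods
    ),
    emit_mutable_buffer_names=True,  # keep buffer FQNs so a runner can locate state by name
)
# exec_prog = et_prog.to_executorch(config=backend_config)
```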