
Commit 4473ab1

Author: gasoonjia
Use share_mutable_buffers to eliminate select_scatter overhead
Revert from explicit state passing back to registered buffers with in-place updates (KVCache, conv_state, recurrent_state). Export with share_mutable_buffers=True so the prefill and forward methods share mutable state via mem_id=2. The C++ runner sets share_memory_arenas=true and passes only (tokens, input_pos), with no CUDA runtime dependency.

Results: 84.5 tok/s (up from 77.4), 0 select_scatter ops in the profile, 65 D2H memcpys (logits only).
1 parent 3a1ee31 commit 4473ab1

5 files changed: 119 additions, 297 deletions
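Why this removes the scatter ops: with explicit state passing, every cache write has to be captured functionally, roughly a select_scatter that materializes an updated copy of the cache each step and threads it out as a graph output. With registered buffers and in-place writes, torch.export records the write as a buffer mutation instead, and share_mutable_buffers lets both exported methods alias the same planned storage. A minimal sketch of the pattern, using a toy single-layer cache (hypothetical, not the repo's KVCache):

import torch
from torch.export import export


class TinyKVCache(torch.nn.Module):
    """Toy single-layer cache illustrating the registered-buffer pattern."""

    def __init__(self, max_seq_len: int = 8, dim: int = 4):
        super().__init__()
        # Zero-initialized state buffer. With share_mutable_buffers=True at
        # lowering time, buffers like this are planned into one shared arena
        # (mem_id=2 per the commit message) visible to both exported methods.
        self.register_buffer("k_cache", torch.zeros(1, max_seq_len, dim))

    def forward(self, k: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # In-place write into the registered buffer: torch.export records a
        # buffer mutation here rather than a select_scatter that rebuilds the
        # whole cache tensor every step.
        self.k_cache[:, input_pos] = k
        # Stand-in for the downstream attention that reads the cache.
        return self.k_cache.sum(dim=1)


ep = export(TinyKVCache(), (torch.randn(1, 1, 4), torch.tensor([0])), strict=True)
print(ep.graph_signature.buffers_to_mutate)  # k_cache appears as mutated state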


examples/models/qwen3_5_moe/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ list(
 
 # CUDA backend (required)
 find_package(CUDAToolkit REQUIRED)
-list(APPEND link_libraries aoti_cuda_backend CUDA::cudart)
+list(APPEND link_libraries aoti_cuda_backend)
 executorch_target_link_options_shared_lib(aoti_cuda_backend)
 
 # Tokenizer

examples/models/qwen3_5_moe/export.py

Lines changed: 15 additions & 35 deletions
@@ -109,6 +109,9 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
     runtime_prefixes = (
         ".mask",
         ".inv_freq",
+        ".kv_cache.",
+        ".conv_state",
+        ".recurrent_state",
     )
     expected_missing = {k for k in missing if any(p in k for p in runtime_prefixes)}
     weight_missing = set(missing) - expected_missing

@@ -310,7 +313,8 @@ def _materialize_buffers(model, config):
 
     Replaces meta buffers with real tensors on CPU, recomputes RoPE
     inv_freq and causal masks. State buffers (KV cache, conv/recurrent
-    state) are no longer registered buffers — they are explicit function args.
+    state) are zero-initialized registered buffers that will be shared
+    across methods via share_mutable_buffers.
     """
     # Masks stay bool, inv_freq stays float32.
     for fqn, buf in list(model.named_buffers()):

@@ -359,8 +363,9 @@ def export_and_lower(model, config, args):
     - "prefill": prefill path (T>=2), uses chunked FLA triton_op with
       dynamic sequence length.
 
-    Both methods take explicit state tensors (conv_states, recurrent_states,
-    k_caches, v_caches) as inputs and return updated state as outputs.
+    Both methods share mutable state buffers (KV cache, conv_state,
+    recurrent_state) via share_mutable_buffers=True. The model uses
+    registered buffers with in-place updates — no state in/out args.
     """
     import torch._inductor.config as inductor_config
 

@@ -381,19 +386,14 @@ def export_and_lower(model, config, args):
     # -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
 
-    # Create initial state tensors
-    conv_states, recurrent_states, k_caches, v_caches = \
-        Qwen35MoE.make_initial_state(config)
-
     # --- Decode method (T=1, static shape) ---
     print("Exporting decode method (forward)...")
     decode_tokens = torch.tensor([[0]], dtype=torch.long)
     decode_pos = torch.tensor([0], dtype=torch.long)
     with torch.no_grad():
         decode_ep = export(
             model,
-            (decode_tokens, decode_pos,
-             conv_states, recurrent_states, k_caches, v_caches),
+            (decode_tokens, decode_pos),
             strict=True,
         )
     print("Decode export successful!")

@@ -403,21 +403,14 @@ def export_and_lower(model, config, args):
     prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
     prefill_pos = torch.tensor([0, 1], dtype=torch.long)
     seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
-    # Dynamic shapes: only tokens dim 1 and pos dim 0 are dynamic;
-    # state tensors have static shapes.
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
         {0: seq_dim},  # input_pos
-        None,  # conv_states
-        None,  # recurrent_states
-        None,  # k_caches
-        None,  # v_caches
     )
     with torch.no_grad():
         prefill_ep = export(
             model,
-            (prefill_tokens, prefill_pos,
-             conv_states, recurrent_states, k_caches, v_caches),
+            (prefill_tokens, prefill_pos),
             dynamic_shapes=prefill_dynamic_shapes,
             strict=True,
         )

@@ -426,30 +419,13 @@ def export_and_lower(model, config, args):
     # Lower with CUDA backend (per-method partitioners to avoid so_blob collision)
     print("Lowering to ExecuTorch with CUDA...")
 
-    num_fla = sum(1 for t in config.layer_types if t == "linear_attention")
-    num_attn = sum(1 for t in config.layer_types if t == "full_attention")
-    conv_dim = (
-        config.linear_num_key_heads * config.linear_key_head_dim * 2
-        + config.linear_num_value_heads * config.linear_value_head_dim
-    )
-
     metadata = {
         "get_max_seq_len": config.max_seq_len,
         "get_vocab_size": config.vocab_size,
         "get_n_layers": config.num_hidden_layers,
         "use_kv_cache": True,
         "use_sdpa_with_kv_cache": False,
         "enable_dynamic_shape": True,
-        # State shape metadata for C++ runner
-        "get_num_fla_layers": num_fla,
-        "get_num_attn_layers": num_attn,
-        "get_conv_dim": conv_dim,
-        "get_conv_kernel_size": config.linear_conv_kernel_dim,
-        "get_num_v_heads": config.linear_num_value_heads,
-        "get_head_k_dim": config.linear_key_head_dim,
-        "get_head_v_dim": config.linear_value_head_dim,
-        "get_n_kv_heads": config.num_kv_heads,
-        "get_head_dim": config.head_dim,
     }
     et_prog = to_edge_transform_and_lower(
         {"forward": decode_ep, "prefill": prefill_ep},

@@ -471,7 +447,11 @@ def export_and_lower(model, config, args):
         config=ExecutorchBackendConfig(
             extract_delegate_segments=True,
             do_quant_fusion_and_const_prop=True,
-            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+            memory_planning_pass=MemoryPlanningPass(
+                alloc_graph_input=False,
+                share_mutable_buffers=True,
+            ),
+            emit_mutable_buffer_names=True,
         ),
     )
examples/models/qwen3_5_moe/inference.py

Lines changed: 7 additions & 21 deletions
@@ -71,34 +71,24 @@ def _move_to_cuda(model, config):
 
 
 def generate(
-    model, config, tokenizer, prompt, max_new_tokens=128, temperature=0.0, eos_token_ids=None
+    model, tokenizer, prompt, max_new_tokens=128, temperature=0.0, eos_token_ids=None
 ):
-    """Generate text autoregressively with explicit state passing.
+    """Generate text autoregressively.
 
-    Prefills one token at a time (the chunk_gated_delta_rule kernel's chunked
-    path has numerical issues with T>1 in eager mode; token-by-token uses the
-    stable recurrent path).
+    State (KV cache, conv_state, recurrent_state) is managed internally
+    via registered buffers — the model signature is just (tokens, input_pos).
     """
-    from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE
-
     if eos_token_ids is None:
         eos_token_ids = set()
 
     input_ids = tokenizer.encode(prompt).ids
 
-    # Initialize state tensors
-    conv_states, recurrent_states, k_caches, v_caches = (
-        Qwen35MoE.make_initial_state(config, dtype=torch.bfloat16, device="cuda")
-    )
-
-    # Prefill: one token at a time
+    # Prefill: one token at a time (recurrent path is stable for T=1)
     with torch.no_grad():
         for i, tok_id in enumerate(input_ids):
             tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda")
             pos = torch.tensor([i], dtype=torch.long, device="cuda")
-            logits, conv_states, recurrent_states, k_caches, v_caches = model(
-                tok, pos, conv_states, recurrent_states, k_caches, v_caches
-            )
+            logits = model(tok, pos)
 
     # Sample first generated token
     next_token = _sample(logits[:, -1, :], temperature)

@@ -109,10 +99,7 @@ def generate(
     with torch.no_grad():
         for i in range(max_new_tokens - 1):
             pos = torch.tensor([seq_len + i], device="cuda")
-            logits, conv_states, recurrent_states, k_caches, v_caches = model(
-                next_token.unsqueeze(0), pos,
-                conv_states, recurrent_states, k_caches, v_caches
-            )
+            logits = model(next_token.unsqueeze(0), pos)
             next_token = _sample(logits[:, -1, :], temperature)
             tok_id = next_token.item()
             generated.append(tok_id)

@@ -193,7 +180,6 @@ def main():
     t0 = time.perf_counter()
     output = generate(
         model,
-        config,
         tokenizer,
         args.prompt,
         max_new_tokens=args.max_new_tokens,
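A side effect of buffer-held state: it now persists inside the module across generate() calls, so starting a new prompt requires resetting the buffers. The diff doesn't include a reset; a plausible helper, reusing the state-buffer prefixes from the export.py hunk (hypothetical, not part of this commit):

import torch

# Same prefixes export.py uses to identify runtime state buffers.
STATE_PREFIXES = (".kv_cache.", ".conv_state", ".recurrent_state")

def reset_state(model: torch.nn.Module) -> None:
    """Zero all state buffers so the next generate() starts from a clean cache."""
    with torch.no_grad():
        for name, buf in model.named_buffers():
            if any(p in name for p in STATE_PREFIXES):
                buf.zero_()

# Usage: call reset_state(model) before generating from a new prompt.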
