Skip to content

Commit c015241

Browse files
mergennachinclaude
andauthored
Enable native GQA and hoist mask computation for Metal SDPA (#17720)
Replace F.scaled_dot_product_attention with a direct call to torch.ops.aten._scaled_dot_product_attention_math_for_mps in StandardSDPA. This is necessary because F.scaled_dot_product_attention is CompositeImplicitAutograd — torch.export() decomposes it back into repeat_interleave + matmul for GQA, defeating native kernel support. The _for_mps op stays as a single node in the exported graph and resolves at runtime to the custom Metal SDPA shader in op_sdpa.mm, which handles GQA natively via gqa_factor = n_heads / n_kv_heads. For voxtral (32 Q heads, 8 KV heads), this eliminates 4x redundant K/V memory traffic per layer (repeat_interleave materialized 128MB of expanded KV per layer vs 32MB with native GQA). Also hoist the attention mask computation from StandardSDPA (called 26x per token, once per layer) to MistralDecoder.forward (called 1x). Build the mask using integer arithmetic (clamp) instead of bool comparisons since Metal AOTI doesn't support bool tensor allocation. The XNNPACK path is unchanged. Improved from 19 tokens/s to 40 tokens/s Test Plan: ``` I 00:00:26.888453 executorch:voxtral_realtime_runner.cpp:247] Audio: 314240 samples -> 375 frames One, two, three. This is February 18th, Wednesday, and I'm testing an AI model. Tell me an interesting story. One, two, three. Thank you..</s> PyTorchObserver {"prompt_tokens":0,"generated_tokens":378,"model_load_start_ms":1772062281912,"model_load_end_ms":1772062303509,"inference_start_ms":1772062303509,"inference_end_ms":1772062319075,"prompt_eval_end_ms":1772062309797,"first_token_ms":1772062309797,"aggregate_sampling_time_ms":0,"SCALING_FACTOR_UNITS_PER_SECOND":1000} I 00:00:37.164707 executorch:stats.h:143] Prompt Tokens: 0 Generated Tokens: 378 I 00:00:37.164710 executorch:stats.h:149] Model Load Time: 21.597000 (seconds) I 00:00:37.164713 executorch:stats.h:159] Total inference time: 15.566000 (seconds) Rate: 24.283695 (tokens/second) I 00:00:37.164715 executorch:stats.h:167] Prompt evaluation: 6.288000 (seconds) Rate: 0.000000 (tokens/second) I 00:00:37.164736 executorch:stats.h:178] Generated 378 tokens: 9.278000 (seconds) Rate: 40.741539 (tokens/second) I 00:00:37.164747 executorch:stats.h:186] Time to first generated token: 6.288000 (seconds) I 00:00:37.164749 executorch:stats.h:193] Sampling time over 378 tokens: 0.000000 (seconds) I 00:00:37.166004 executorch:metal_backend.cpp:716] Removed temporary shared library file: /var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/text_decoder_so_blob46972.so I 00:00:37.291288 executorch:memory.cpp:642] Cleared all tensors and Metal resources I 00:00:37.291937 executorch:metal_backend.cpp:716] Removed temporary shared library file: /var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/token_embedding_so_blob46972.so I 00:00:37.291944 executorch:memory.cpp:642] Cleared all tensors and Metal resources I 00:00:37.292499 executorch:metal_backend.cpp:716] Removed temporary shared library file: /var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/audio_encoder_so_blob46972.so I 00:00:37.292507 executorch:memory.cpp:642] Cleared all tensors and Metal resources ``` Co-authored-by: Claude <noreply@anthropic.com>
1 parent f30d5ed commit c015241

1 file changed

Lines changed: 54 additions & 48 deletions

File tree

  • examples/models/voxtral_realtime

examples/models/voxtral_realtime/model.py

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,33 @@ def forward(
427427
return y.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
428428

429429

430-
class StandardSDPA(nn.Module):
431-
"""Standard scaled dot-product attention using F.scaled_dot_product_attention.
430+
def _build_attn_mask(
431+
input_pos: torch.Tensor, max_seq_len: int, device: torch.device
432+
) -> torch.Tensor:
433+
"""Build float additive attention mask without bool intermediates.
432434
433-
Compatible with AOTI/Metal backend. Handles GQA and causal masking using
434-
standard PyTorch operations. Works with StaticKVCache that returns [B, H, S, D] layout.
435+
Metal AOTI doesn't support bool tensor allocation on MPS, so we use
436+
integer arithmetic: clamp(curr_pos - k_pos + 1, 0, 1) gives 1 for
437+
valid positions (k <= curr_pos) and 0 for invalid, then convert to
438+
additive mask (0.0 = attend, -1e9 = don't attend).
439+
"""
440+
seqlen = input_pos.shape[0]
441+
k_pos = torch.arange(max_seq_len, device=device)
442+
if seqlen > 1:
443+
# Prefill: [seqlen, max_seq_len]
444+
diff = input_pos.unsqueeze(1) - k_pos.unsqueeze(0) + 1
445+
else:
446+
# Decode: [1, max_seq_len]
447+
diff = (input_pos[0] - k_pos + 1).unsqueeze(0)
448+
valid = torch.clamp(diff, min=0, max=1)
449+
return (valid.float() - 1.0) * 1e9
450+
451+
452+
class MetalSDPA(nn.Module):
453+
"""Standard SDPA calling the MPS op directly for native GQA support.
454+
455+
The Metal SDPA kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads,
456+
avoiding the 4x memory bandwidth overhead of repeat_interleave.
435457
"""
436458

437459
def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int):
@@ -449,59 +471,30 @@ def forward(
449471
v: torch.Tensor,
450472
bsz: int,
451473
seqlen: int,
474+
attn_mask: torch.Tensor | None = None,
452475
) -> torch.Tensor:
453476
"""
454477
Args:
455478
input_pos: (seq_len,) position indices.
456479
q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout.
457480
k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache.
458481
bsz, seqlen: batch size and query sequence length.
482+
attn_mask: precomputed float additive mask, or None to compute here.
459483
Returns:
460484
output: (B, seq_len, n_heads * head_dim).
461485
"""
462-
# Convert q from [B, S, H, D] to [B, H, S, D] for F.scaled_dot_product_attention
463486
q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim]
464-
# k, v are already in [B, H, S, D] from StaticKVCache
465-
466-
# Handle GQA: repeat k/v heads if needed
467-
if self.n_heads != self.n_kv_heads:
468-
n_rep = self.n_heads // self.n_kv_heads
469-
k = k.repeat_interleave(n_rep, dim=1) # [B, n_heads, max_seq_len, head_dim]
470-
v = v.repeat_interleave(n_rep, dim=1)
471-
472-
# Create causal attention mask
473-
# input_pos contains the positions being attended to, e.g., [0,1,2,3] or [pos]
474-
# We need a mask of shape [seqlen, max_seq_len]
475-
max_seq_len = k.shape[2]
476-
477-
if seqlen > 1:
478-
# Prefill: create causal mask where position i can attend to positions 0..input_pos[i]
479-
# Create position matrix for queries and keys
480-
q_pos = input_pos.unsqueeze(1) # [seqlen, 1]
481-
k_pos = torch.arange(max_seq_len, device=q.device).unsqueeze(
482-
0
483-
) # [1, max_seq_len]
484-
# Causal mask: can attend where k_pos <= q_pos
485-
# PyTorch convention: True = attend, False = don't attend
486-
attn_mask = k_pos <= q_pos # [seqlen, max_seq_len], True = can attend
487-
else:
488-
# Decode: single token can attend to all positions up to current position
489-
# Current position is input_pos[0]
490-
curr_pos = input_pos[0] # scalar tensor
491-
k_pos = torch.arange(max_seq_len, device=q.device)
492-
attn_mask = k_pos <= curr_pos # [max_seq_len], True = can attend
493-
attn_mask = attn_mask.unsqueeze(0) # [1, max_seq_len] for broadcasting
494-
495-
# Standard SDPA
496-
y = F.scaled_dot_product_attention(
497-
q,
498-
k,
499-
v,
500-
attn_mask=attn_mask,
501-
is_causal=False, # We handle causal masking explicitly via attn_mask
487+
488+
if attn_mask is None:
489+
attn_mask = _build_attn_mask(input_pos, k.shape[2], q.device)
490+
491+
# Call the MPS SDPA op directly — bypasses CompositeImplicitAutograd
492+
# decomposition which would insert repeat_interleave for GQA.
493+
# The Metal kernel handles GQA natively via gqa_factor = n_heads / n_kv_heads.
494+
y, _ = torch.ops.aten._scaled_dot_product_attention_math_for_mps(
495+
q, k, v, attn_mask, 0.0, False, None
502496
) # [B, n_heads, seq_len, head_dim]
503497

504-
# Convert back to [B, S, H, D] and flatten
505498
y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim]
506499
return y.view(bsz, seqlen, self.dim)
507500

@@ -580,7 +573,7 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int):
580573
# Choose KV cache and SDPA based on backend
581574
if self.use_standard_attention:
582575
self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim)
583-
self.sdpa = StandardSDPA(self.n_heads, self.n_kv_heads, self.head_dim)
576+
self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim)
584577
else:
585578
self.kv_cache = KVCache(max_seq_len, self.n_kv_heads, self.head_dim)
586579
self.sdpa = SDPA(self.n_heads, self.head_dim)
@@ -591,6 +584,7 @@ def forward(
591584
freqs_cos: torch.Tensor,
592585
freqs_sin: torch.Tensor,
593586
input_pos: torch.Tensor,
587+
attn_mask: torch.Tensor | None = None,
594588
) -> torch.Tensor:
595589
B, T, _ = x.shape
596590
q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
@@ -601,7 +595,10 @@ def forward(
601595

602596
k, v = self.kv_cache.update(input_pos, k, v)
603597

604-
y = self.sdpa(input_pos, q, k, v, B, T)
598+
if self.use_standard_attention:
599+
y = self.sdpa(input_pos, q, k, v, B, T, attn_mask)
600+
else:
601+
y = self.sdpa(input_pos, q, k, v, B, T)
605602

606603
return self.wo(y)
607604

@@ -647,8 +644,11 @@ def forward(
647644
freqs_sin: torch.Tensor,
648645
input_pos: torch.Tensor,
649646
t_cond: torch.Tensor,
647+
attn_mask: torch.Tensor | None = None,
650648
) -> torch.Tensor:
651-
x = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin, input_pos)
649+
x = x + self.attention(
650+
self.attention_norm(x), freqs_cos, freqs_sin, input_pos, attn_mask
651+
)
652652
normed = self.ffn_norm(x)
653653
scale = 1.0 + self.ada_rms_norm_t_cond(t_cond)
654654
x = x + self.feed_forward(normed * scale)
@@ -683,9 +683,15 @@ def forward(
683683
freqs_cos = self.freqs_cos[input_pos]
684684
freqs_sin = self.freqs_sin[input_pos]
685685

686+
# Compute attention mask once for all 26 layers (P3 optimization).
687+
attn_mask: torch.Tensor | None = None
688+
if self.config.use_standard_attention:
689+
max_seq_len = self.freqs_cos.shape[0]
690+
attn_mask = _build_attn_mask(input_pos, max_seq_len, input_embeds.device)
691+
686692
x = input_embeds
687693
for layer in self.layers:
688-
x = layer(x, freqs_cos, freqs_sin, input_pos, t_cond)
694+
x = layer(x, freqs_cos, freqs_sin, input_pos, t_cond, attn_mask)
689695

690696
return self.output(self.norm(x))
691697

0 commit comments

Comments
 (0)