Commit c015241
Enable native GQA and hoist mask computation for Metal SDPA (#17720)
Replace F.scaled_dot_product_attention with a direct call to
torch.ops.aten._scaled_dot_product_attention_math_for_mps in
StandardSDPA. This is necessary because F.scaled_dot_product_attention
is CompositeImplicitAutograd — torch.export() decomposes it back into
repeat_interleave + matmul for GQA, defeating native kernel support.
The _for_mps op stays as a single node in the exported graph and
resolves at runtime to the custom Metal SDPA shader in op_sdpa.mm,
which handles GQA natively via gqa_factor = n_heads / n_kv_heads.
For voxtral (32 Q heads, 8 KV heads), this eliminates 4x redundant
K/V memory traffic per layer (repeat_interleave materialized 128MB
of expanded KV per layer vs 32MB with native GQA).
Also hoist the attention mask computation from StandardSDPA (called
26x per token, once per layer) to MistralDecoder.forward (called 1x).
Build the mask using integer arithmetic (clamp) instead of bool
comparisons since Metal AOTI doesn't support bool tensor allocation.
The XNNPACK path is unchanged.
Improved from 19 tokens/s to 40 tokens/s
Test Plan:
```
I 00:00:26.888453 executorch:voxtral_realtime_runner.cpp:247] Audio: 314240 samples -> 375 frames
One, two, three. This is February 18th, Wednesday, and I'm testing an AI model. Tell me an interesting story. One, two, three. Thank you..</s>
PyTorchObserver {"prompt_tokens":0,"generated_tokens":378,"model_load_start_ms":1772062281912,"model_load_end_ms":1772062303509,"inference_start_ms":1772062303509,"inference_end_ms":1772062319075,"prompt_eval_end_ms":1772062309797,"first_token_ms":1772062309797,"aggregate_sampling_time_ms":0,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:37.164707 executorch:stats.h:143] Prompt Tokens: 0 Generated Tokens: 378
I 00:00:37.164710 executorch:stats.h:149] Model Load Time: 21.597000 (seconds)
I 00:00:37.164713 executorch:stats.h:159] Total inference time: 15.566000 (seconds) Rate: 24.283695 (tokens/second)
I 00:00:37.164715 executorch:stats.h:167] Prompt evaluation: 6.288000 (seconds) Rate: 0.000000 (tokens/second)
I 00:00:37.164736 executorch:stats.h:178] Generated 378 tokens: 9.278000 (seconds) Rate: 40.741539 (tokens/second)
I 00:00:37.164747 executorch:stats.h:186] Time to first generated token: 6.288000 (seconds)
I 00:00:37.164749 executorch:stats.h:193] Sampling time over 378 tokens: 0.000000 (seconds)
I 00:00:37.166004 executorch:metal_backend.cpp:716] Removed temporary shared library file: /var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/text_decoder_so_blob46972.so
I 00:00:37.291288 executorch:memory.cpp:642] Cleared all tensors and Metal resources
I 00:00:37.291937 executorch:metal_backend.cpp:716] Removed temporary shared library file: /var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/token_embedding_so_blob46972.so
I 00:00:37.291944 executorch:memory.cpp:642] Cleared all tensors and Metal resources
I 00:00:37.292499 executorch:metal_backend.cpp:716] Removed temporary shared library file: /var/folders/_1/z_wzgpv50gn73c02j5xmcnf40000gn/T/audio_encoder_so_blob46972.so
I 00:00:37.292507 executorch:memory.cpp:642] Cleared all tensors and Metal resources
```
Co-authored-by: Claude <noreply@anthropic.com>1 parent f30d5ed commit c015241
1 file changed
Lines changed: 54 additions & 48 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
427 | 427 | | |
428 | 428 | | |
429 | 429 | | |
430 | | - | |
431 | | - | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
432 | 434 | | |
433 | | - | |
434 | | - | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
435 | 457 | | |
436 | 458 | | |
437 | 459 | | |
| |||
449 | 471 | | |
450 | 472 | | |
451 | 473 | | |
| 474 | + | |
452 | 475 | | |
453 | 476 | | |
454 | 477 | | |
455 | 478 | | |
456 | 479 | | |
457 | 480 | | |
458 | 481 | | |
| 482 | + | |
459 | 483 | | |
460 | 484 | | |
461 | 485 | | |
462 | | - | |
463 | 486 | | |
464 | | - | |
465 | | - | |
466 | | - | |
467 | | - | |
468 | | - | |
469 | | - | |
470 | | - | |
471 | | - | |
472 | | - | |
473 | | - | |
474 | | - | |
475 | | - | |
476 | | - | |
477 | | - | |
478 | | - | |
479 | | - | |
480 | | - | |
481 | | - | |
482 | | - | |
483 | | - | |
484 | | - | |
485 | | - | |
486 | | - | |
487 | | - | |
488 | | - | |
489 | | - | |
490 | | - | |
491 | | - | |
492 | | - | |
493 | | - | |
494 | | - | |
495 | | - | |
496 | | - | |
497 | | - | |
498 | | - | |
499 | | - | |
500 | | - | |
501 | | - | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
502 | 496 | | |
503 | 497 | | |
504 | | - | |
505 | 498 | | |
506 | 499 | | |
507 | 500 | | |
| |||
580 | 573 | | |
581 | 574 | | |
582 | 575 | | |
583 | | - | |
| 576 | + | |
584 | 577 | | |
585 | 578 | | |
586 | 579 | | |
| |||
591 | 584 | | |
592 | 585 | | |
593 | 586 | | |
| 587 | + | |
594 | 588 | | |
595 | 589 | | |
596 | 590 | | |
| |||
601 | 595 | | |
602 | 596 | | |
603 | 597 | | |
604 | | - | |
| 598 | + | |
| 599 | + | |
| 600 | + | |
| 601 | + | |
605 | 602 | | |
606 | 603 | | |
607 | 604 | | |
| |||
647 | 644 | | |
648 | 645 | | |
649 | 646 | | |
| 647 | + | |
650 | 648 | | |
651 | | - | |
| 649 | + | |
| 650 | + | |
| 651 | + | |
652 | 652 | | |
653 | 653 | | |
654 | 654 | | |
| |||
683 | 683 | | |
684 | 684 | | |
685 | 685 | | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
| 689 | + | |
| 690 | + | |
| 691 | + | |
686 | 692 | | |
687 | 693 | | |
688 | | - | |
| 694 | + | |
689 | 695 | | |
690 | 696 | | |
691 | 697 | | |
| |||
0 commit comments