You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update on "Add ONNX Runtime GQA-style SDPA benchmark"
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:
- Scale baked into GEMM alpha (no separate scaling pass)
- Scores buffer padded to max_seq_len columns
- Causal mask: zero out future positions, softmax on valid window only
- Output always in [B, S, Hq, D] format
Extends validation to verify ONNX GQA output matches custom_sdpa_out
reference. Adds OnnxGQABenchFixture for benchmarking both layouts.
Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/)
[ghstack-poisoned]
"The stable softmax decomposition is now supported by all arm targets and will be made default in a future release. Overwrite the default config using `compile_spec.set_pass_pipeline_config(ArmPassPipelineConfig())` to use the stable algorithm and avoid this error."
32
-
)
33
-
defdisable_masked_softmax(self) ->None:
34
-
"""
35
-
.. warning::
36
-
37
-
The stable softmax decomposition is now supported by all arm targets and will be made default in a future release. Overwrite the default config using `compile_spec.set_pass_pipeline_config(ArmPassPipelineConfig())` to use the stable algorithm and avoid this error."
0 commit comments