Limit softmax to causally-valid elements in cpu_sdpa #18650
kimishpatel wants to merge 2 commits into gh/kimishpatel/223/base
Instead of setting masked positions to -inf and then computing max/exp/normalize over all kvSize elements, limit the softmax to the causally-valid range of each row. This skips the unnecessary computation on masked positions. The masked positions are zero-filled so that they contribute nothing to GEMM 2 (the multiplication of the attention weights by V).
Differential Revision: [D96044307](https://our.internmc.facebook.com/intern/diff/D96044307/)
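The idea in this description can be illustrated with a minimal sketch. This is not the code from the PR: the function names `causal_row_softmax` and `causal_softmax`, the row-major layout, and the `start_pos` parameter are all assumptions made for illustration, and the sketch works on whole rows rather than on the tiles a real kernel would use.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not the ExecuTorch kernel): softmax over the first
// `num_valid` scores of one row, with the masked tail zero-filled.
inline void causal_row_softmax(float* row, std::size_t kv_size,
                               std::size_t num_valid) {
  assert(num_valid >= 1 && num_valid <= kv_size);
  // Max over the valid prefix only, for numerical stability.
  const float row_max = *std::max_element(row, row + num_valid);
  float sum = 0.0f;
  for (std::size_t j = 0; j < num_valid; ++j) {
    row[j] = std::exp(row[j] - row_max);
    sum += row[j];
  }
  const float inv_sum = 1.0f / sum;
  for (std::size_t j = 0; j < num_valid; ++j) {
    row[j] *= inv_sum;
  }
  // Masked positions are never exponentiated; they are written as zeros
  // so the following weights-times-V product ignores them.
  std::fill(row + num_valid, row + kv_size, 0.0f);
}

// Applies the row helper to a row-major q_len x kv_size score matrix.
// Assumption: query row i sits at absolute position start_pos + i and may
// attend to keys 0 .. start_pos + i.
inline void causal_softmax(std::vector<float>& scores, std::size_t q_len,
                           std::size_t kv_size, std::size_t start_pos) {
  assert(scores.size() == q_len * kv_size);
  for (std::size_t i = 0; i < q_len; ++i) {
    const std::size_t num_valid = std::min(start_pos + i + 1, kv_size);
    causal_row_softmax(scores.data() + i * kv_size, kv_size, num_valid);
  }
}
```

Compared with writing -inf into the masked slots and running the softmax over all `kv_size` entries, each row here touches only `num_valid` entries in the max, exp, and normalize passes, plus one fill over the tail.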
Dr. CI (automatically generated comment, updated every 15 minutes):
Artifacts and rendered test results: hud.pytorch.org/pr/pytorch/executorch/18650
As of commit 2e6dc3a with merge base fb1618e: 2 new failures, 2 cancelled jobs.
digantdesai left a comment:
Review automatically exported from Phabricator review in Meta.