[ExecuTorch][WebGPU] rms_norm: add a vec4 kernel for 4-aligned row widths#20458
JulianCloudNTH wants to merge 4 commits
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20458
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV. There is 1 currently active SEV. If your PR is affected, please view it below.
⏳ No Failures, 102 Pending, as of commit fcdca19 with merge base 68bb668.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review if there are any bugs
Claude finished @JulianCloudNTH's task in 8m 14s

Bug review: vec4 rms_norm kernel

I reviewed the WGSL kernel, the host routing in …

Kernel correctness ✅

Buffer/alignment safety ✅

The …

Numerical note (minor)

The new comment in …

CI failures (unrelated to this PR)

The "New Failures" flagged by Dr. CI are not caused by this change: …

A rebase onto …

Verdict

LGTM on correctness. The only actionable item is the optional comment-wording tweak above. (Note: I couldn't run the …
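The alignment side of the `row_width % 4 == 0` gate can be checked with simple arithmetic. The following is a minimal sketch, not code from this PR: it assumes rows are stored contiguously as f32 and that the vec4 kernel views the buffer as an array of `vec4<f32>` (16-byte elements in WGSL). The helper name `row_starts_vec4_aligned` is hypothetical.

```python
F32_BYTES = 4    # size of one f32 element
VEC4_BYTES = 16  # size and alignment of vec4<f32> in WGSL


def row_starts_vec4_aligned(row_width: int, num_rows: int) -> bool:
    """Return True if every row begins on a vec4<f32> boundary.

    Assumes a densely packed [num_rows, row_width] f32 buffer whose base
    address is itself 16-byte aligned (an assumption of this sketch).
    """
    return all(
        (r * row_width * F32_BYTES) % VEC4_BYTES == 0
        for r in range(num_rows)
    )


if __name__ == "__main__":
    # A 4-aligned width keeps every row on a vec4 boundary.
    print(row_starts_vec4_aligned(4096, 8))
    # A width that is not a multiple of 4 misaligns the second row.
    print(row_starts_vec4_aligned(4098, 8))
```

When the width is a multiple of 4, each row also holds a whole number of vec4 elements (`row_width // 4`), which is why no tail handling is needed under this layout.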
…dths

Pull Request resolved: #20458

Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.

**Problem**: Scalar kernel loads one element per lane per iteration, which is bandwidth-limited on Llama decode.

**Solution**: Add vec4 kernel that loads/stores four contiguous elements as `vec4<f32>` and squares them with `dot(v, v)`, cutting loop iterations 4× and widening memory transactions.

**Routing Logic**:
- Use vec4 when: row_width % 4 == 0
- Otherwise: Fall back to scalar kernel

**Constraints**:
- row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
- Llama models (all hidden sizes 4-aligned) satisfy constraint

**Implementation**:
- New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
- Shared infrastructure: Same bind layout, Params, dispatch
- Numerical: Float reassociation differs, not bit-identical to scalar

**Performance**: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest on decode, smallest on long prefill where already bandwidth-bound).

This change was authored with assistance from Claude.

ghstack-source-id: 396619676
@exported-using-ghexport

Differential Revision: [D109333390](https://our.internmc.facebook.com/intern/diff/D109333390/)
…dths

Pull Request resolved: #20458

Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.

**Problem**: Scalar kernel loads one element per lane per iteration, which is bandwidth-limited on Llama decode.

**Solution**: Add vec4 kernel that loads/stores four contiguous elements as `vec4<f32>` and squares them with `dot(v, v)`, cutting loop iterations 4× and widening memory transactions.

**Routing Logic**:
- Use vec4 when: row_width % 4 == 0
- Otherwise: Fall back to scalar kernel

**Constraints**:
- row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
- Llama models (all hidden sizes 4-aligned) satisfy constraint

**Implementation**:
- New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
- Shared infrastructure: Same bind layout, Params, dispatch
- Numerical: Float reassociation differs, not bit-identical to scalar

**Performance**: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest on decode, smallest on long prefill where already bandwidth-bound).

This change was authored with assistance from Claude.

ghstack-source-id: 396677654
@exported-using-ghexport

Differential Revision: [D109333390](https://our.internmc.facebook.com/intern/diff/D109333390/)
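The commit message's note that float reassociation makes the vec4 result not bit-identical to the scalar one can be illustrated on the CPU. This is a rough sketch under stated assumptions, not the shader's actual arithmetic: f32 is emulated by rounding through `struct`, the pairwise order used for `dot(v, v)` is only one possible order (the GPU's order is implementation-defined), and all function names are hypothetical.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (f64) to the nearest f32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_squares_scalar_order(row):
    """Accumulate x*x one element at a time, rounding to f32 after each op."""
    acc = 0.0
    for x in row:
        acc = f32(acc + f32(x * x))
    return acc


def sum_squares_vec4_order(row):
    """Accumulate four elements at a time via an emulated dot(v, v)."""
    assert len(row) % 4 == 0, "vec4 path requires a 4-aligned row width"
    acc = 0.0
    for j in range(0, len(row), 4):
        a, b, c, d = row[j:j + 4]
        # Assumed evaluation order for dot(v, v); real hardware may differ.
        dot = f32(f32(f32(a * a) + f32(b * b)) + f32(f32(c * c) + f32(d * d)))
        acc = f32(acc + dot)
    return acc


if __name__ == "__main__":
    row = [f32(0.1 * (i % 13) - 0.6) for i in range(4096)]
    s = sum_squares_scalar_order(row)
    v = sum_squares_vec4_order(row)
    exact = math.fsum(x * x for x in row)
    # The two orders may differ in the last bits but both track the exact sum.
    print(s, v, exact)
```

Because the two orders group the additions differently, a test comparing the kernels would compare with a tolerance rather than for exact equality.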
Stack from ghstack (oldest at bottom):
Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.
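Problem: Scalar kernel loads one element per lane per iteration, which is bandwidth-limited on Llama decode.

Solution: Add a vec4 kernel that loads/stores four contiguous elements as `vec4<f32>` and sums their squares with `dot(v, v)`, cutting loop iterations 4× and widening memory transactions.

Routing Logic:
- Use vec4 when: row_width % 4 == 0
- Otherwise: Fall back to scalar kernel

Constraints:
- row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
- Llama models (all hidden sizes 4-aligned) satisfy constraint

Implementation:
- New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
- Shared infrastructure: Same bind layout, Params, dispatch
- Numerical: Float reassociation differs, not bit-identical to scalar

Performance: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest gain on decode, smallest on long prefill, where the kernel is already bandwidth-bound).

This change was authored with assistance from Claude.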
@exported-using-ghexport
Differential Revision: D109333390
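The routing rule and the vec4 reduction described in this PR can be sketched on the CPU as follows. This is a minimal Python emulation, not the WGSL shader or the host code: the function names, the `eps` default, the per-element `weight` multiply, and the lane-strided loop shape are assumptions for illustration. Only the 64-lane workgroup, the `row_width % 4 == 0` condition, and the `dot(v, v)` accumulation come from the description.

```python
import math

WORKGROUP_SIZE = 64  # lanes per workgroup, per the PR description


def pick_kernel(row_width: int) -> str:
    """Hypothetical routing helper: vec4 only when the row is 4-aligned."""
    return "vec4" if row_width % 4 == 0 else "scalar"


def iterations_per_lane(row_width: int, elems_per_load: int) -> int:
    """Loop trips per lane when each load covers elems_per_load elements."""
    loads = row_width // elems_per_load
    return (loads + WORKGROUP_SIZE - 1) // WORKGROUP_SIZE


def rms_norm_scalar(row, weight, eps=1e-6):
    """Each lane accumulates one element per iteration."""
    partial = [0.0] * WORKGROUP_SIZE
    for lane in range(WORKGROUP_SIZE):
        for i in range(lane, len(row), WORKGROUP_SIZE):
            partial[lane] += row[i] * row[i]
    # sum(partial) stands in for the workgroup reduction.
    inv_rms = 1.0 / math.sqrt(sum(partial) / len(row) + eps)
    return [x * inv_rms * w for x, w in zip(row, weight)]


def rms_norm_vec4(row, weight, eps=1e-6):
    """Each lane accumulates four contiguous elements per iteration."""
    assert len(row) % 4 == 0, "no tail handling: row width must be 4-aligned"
    n_vec4 = len(row) // 4
    partial = [0.0] * WORKGROUP_SIZE
    for lane in range(WORKGROUP_SIZE):
        for j in range(lane, n_vec4, WORKGROUP_SIZE):
            v = row[4 * j:4 * j + 4]
            partial[lane] += sum(c * c for c in v)  # emulates dot(v, v)
    inv_rms = 1.0 / math.sqrt(sum(partial) / len(row) + eps)
    out = []
    for j in range(n_vec4):  # emulates one vec4 store per iteration
        for k in range(4 * j, 4 * j + 4):
            out.append(row[k] * inv_rms * weight[k])
    return out


def rms_norm(row, weight, eps=1e-6):
    """Route to the vec4 path when possible, else fall back to scalar."""
    if pick_kernel(len(row)) == "vec4":
        return rms_norm_vec4(row, weight, eps)
    return rms_norm_scalar(row, weight, eps)
```

For a row width of 4096, `iterations_per_lane(4096, 1)` is 64 and `iterations_per_lane(4096, 4)` is 16, which is the 4× reduction in loop iterations the description refers to.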