Skip to content

[ExecuTorch][WebGPU] rms_norm: add a vec4 kernel for 4-aligned row widths#20458

Open
JulianCloudNTH wants to merge 4 commits into
gh/JulianCloudNTH/53/basefrom
gh/JulianCloudNTH/53/head
Open

[ExecuTorch][WebGPU] rms_norm: add a vec4 kernel for 4-aligned row widths#20458
JulianCloudNTH wants to merge 4 commits into
gh/JulianCloudNTH/53/basefrom
gh/JulianCloudNTH/53/head

Conversation

@JulianCloudNTH

@JulianCloudNTH JulianCloudNTH commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.

Problem: Scalar kernel loads one element per lane per iteration — bandwidth-limited on Llama decode.

Solution: Add vec4 kernel that loads/stores four contiguous elements as vec4<f32> and squares them with dot(v, v), cutting loop iterations 4× and widening memory transactions.

Routing Logic:

  • Use vec4 when: row_width % 4 == 0
  • Otherwise: Fall back to scalar kernel

Constraints:

  • row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
  • Llama models (all hidden sizes 4-aligned) satisfy constraint

Implementation:

  • New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
  • Shared infrastructure: Same bind layout, Params, dispatch
  • Numerical: Float reassociation differs, not bit-identical to scalar

Performance: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest on decode, smallest on long prefill where already bandwidth-bound). This change was authored with assistance from Claude.
@exported-using-ghexport

Differential Revision: D109333390

Differential Revision: D109333390

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jun 23, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20458

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

⏳ No Failures, 102 Pending

As of commit fcdca19 with merge base 68bb668 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 23, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@JulianCloudNTH

Copy link
Copy Markdown
Contributor Author

@claude review if there are any bugs

@claude

claude Bot commented Jun 24, 2026

Copy link
Copy Markdown

Claude finished @JulianCloudNTH's task in 8m 14s —— View job


Bug review — vec4 rms_norm kernel

  • Read the diff and both kernels (scalar + vec4)
  • Check kernel correctness (indexing, reduction, normalization)
  • Check host routing / bind group / dispatch
  • Verify codegen header + test-count parity
  • Triage the reported CI failures

I reviewed the WGSL kernel, the host routing in RmsNorm.cpp, and the test changes. No correctness bugs found — the implementation looks right.

Kernel correctness ✅

  • Weight indexing is correct. This is the one spot that's easy to get wrong. In the normalize loop t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4] (rms_norm_vec4.wgsl:69), the input is indexed with the row offset base4 + x4 while the weight is indexed with the bare column-texel x4. That's correct: t_weight is array<vec4<f32>> of length row_width/4, weights are per-column (row-independent), and texel x4 maps to columns 4*x4 .. 4*x4+3. Matches the scalar kernel's t_weight[x].
  • Mean divides by element count, not texel count. mean_sq = shared_sum[0] / f32(params.row_width) — correct, since dot(v,v) already sums 4 elements per texel. The comment calls this out explicitly. 👍
  • Reduction is byte-identical to the scalar kernel.
  • vec4 math t_in[...] * rstd * t_weight[x4] is componentwise (vec4 × scalar × vec4). Correct.

Buffer/alignment safety ✅

The array<vec4<f32>> bindings are safe because the row_width % 4 == 0 guard implies every buffer length is a multiple of 16 bytes: input/output = num_rows * row_width elements, weight = row_width elements, all bound at offset 0. No partial-texel tail, as documented.

Numerical note (minor)

The new comment in RmsNorm.cpp:96-99 says the two kernels are "numerically equivalent", but the PR description correctly notes the float reassociation means they are not bit-identical. Consider softening to "algebraically equivalent (results match within fp32 tolerance, not bit-identical)" so the comment doesn't overstate. Not a bug — the native test uses a tolerance.

CI failures — unrelated to this PR

The "New Failures" flagged by Dr. CI are not caused by this change:

  • unittest / linux & unittest-editable / linux: the only error is ModuleNotFoundError: No module named 'tosa_reference_model' in backends/transforms/test/test_quantize_fused_convbn_bias_pass.py (Arm/TOSA). The run reports 1642 passed, 21 skipped, 1 error — the WebGPU rms_norm tests (incl. the new llama_hidden_2048 case and the == 15 parity assertion in test_schema.py) are in the passing set.
  • test-llama-runner-qnn-linux (exit 137) and test-qnn-models-linux (exit 92): docker OOM/killed on QNN jobs, infra-related.

A rebase onto viable/strict should clear the TOSA collection error.

Verdict

LGTM on correctness. The only actionable item is the optional comment-wording tweak above. (Note: I couldn't run the gen_wgsl_headers.py --check drift gate locally — Bash is sandboxed here — but the embedded wgsl-sha256 header is present and the linux unittest run that exercises test_committed_headers_match_generator passed.)
· branch gh/JulianCloudNTH/53/head

[ghstack-poisoned]
JulianCloudNTH added a commit that referenced this pull request Jun 24, 2026
…dths

Pull Request resolved: #20458

Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.

**Problem**: Scalar kernel loads one element per lane per iteration — bandwidth-limited on Llama decode.

**Solution**: Add vec4 kernel that loads/stores four contiguous elements as `vec4<f32>` and squares them with `dot(v, v)`, cutting loop iterations 4× and widening memory transactions.

**Routing Logic**:
- Use vec4 when: row_width % 4 == 0
- Otherwise: Fall back to scalar kernel

**Constraints**:
- row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
- Llama models (all hidden sizes 4-aligned) satisfy constraint

**Implementation**:
- New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
- Shared infrastructure: Same bind layout, Params, dispatch
- Numerical: Float reassociation differs, not bit-identical to scalar

**Performance**: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest on decode, smallest on long prefill where already bandwidth-bound). This change was authored with assistance from Claude.
ghstack-source-id: 396619676
@exported-using-ghexport

Differential Revision: [D109333390](https://our.internmc.facebook.com/intern/diff/D109333390/)
[ghstack-poisoned]
JulianCloudNTH added a commit that referenced this pull request Jun 24, 2026
…dths

Pull Request resolved: #20458

Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.

**Problem**: Scalar kernel loads one element per lane per iteration — bandwidth-limited on Llama decode.

**Solution**: Add vec4 kernel that loads/stores four contiguous elements as `vec4<f32>` and squares them with `dot(v, v)`, cutting loop iterations 4× and widening memory transactions.

**Routing Logic**:
- Use vec4 when: row_width % 4 == 0
- Otherwise: Fall back to scalar kernel

**Constraints**:
- row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
- Llama models (all hidden sizes 4-aligned) satisfy constraint

**Implementation**:
- New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
- Shared infrastructure: Same bind layout, Params, dispatch
- Numerical: Float reassociation differs, not bit-identical to scalar

**Performance**: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest on decode, smallest on long prefill where already bandwidth-bound). This change was authored with assistance from Claude.
ghstack-source-id: 396619676
@exported-using-ghexport

Differential Revision: [D109333390](https://our.internmc.facebook.com/intern/diff/D109333390/)
[ghstack-poisoned]
JulianCloudNTH added a commit that referenced this pull request Jun 24, 2026
…dths

Pull Request resolved: #20458

Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.

**Problem**: Scalar kernel loads one element per lane per iteration — bandwidth-limited on Llama decode.

**Solution**: Add vec4 kernel that loads/stores four contiguous elements as `vec4<f32>` and squares them with `dot(v, v)`, cutting loop iterations 4× and widening memory transactions.

**Routing Logic**:
- Use vec4 when: row_width % 4 == 0
- Otherwise: Fall back to scalar kernel

**Constraints**:
- row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
- Llama models (all hidden sizes 4-aligned) satisfy constraint

**Implementation**:
- New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
- Shared infrastructure: Same bind layout, Params, dispatch
- Numerical: Float reassociation differs, not bit-identical to scalar

**Performance**: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest on decode, smallest on long prefill where already bandwidth-bound). This change was authored with assistance from Claude.
ghstack-source-id: 396677654
@exported-using-ghexport

Differential Revision: [D109333390](https://our.internmc.facebook.com/intern/diff/D109333390/)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants