Skip to content

[vLLM IR] rework gemma_rms_norm#39014

Merged
vllm-bot merged 16 commits into
vllm-project:mainfrom
ZJY0516:rework-gemma-norm
Apr 7, 2026
Merged

[vLLM IR] rework gemma_rms_norm#39014
vllm-bot merged 16 commits into
vllm-project:mainfrom
ZJY0516:rework-gemma-norm

Conversation

@ZJY0516
Copy link
Copy Markdown
Member

@ZJY0516 ZJY0516 commented Apr 5, 2026

Purpose

rework for #38780

cc @ProExpertProg @wxsIcey

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements mixed-dtype support for RMSNorm by ensuring that provider kernels (vllm_c, aiter, xpu) fall back to a native implementation when input and weight dtypes do not match. It refactors GemmaRMSNorm to use the unified IR operation and adds comprehensive tests for mixed-dtype scenarios. Review feedback identifies a potential precision loss in the native multiplication step, a bug where the residual tensor is returned in float32 instead of the original dtype, and a performance regression in GemmaRMSNorm caused by the removal of torch.compile and redundant weight conversions.

Comment thread vllm/ir/ops/layernorm.py
Comment thread vllm/model_executor/layers/layernorm.py Outdated
Comment thread vllm/model_executor/layers/layernorm.py
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 979bea2a64

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +172 to +174
if not _rms_weight_dtype_matches_input(input, weight):
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
return self.quant_matcher(result_rms, scale)[0]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Move mixed-dtype gate out of traced replacement function

This mixed-dtype fallback is placed inside a replacement closure that Inductor traces using example tensors, so the Python if is resolved at trace time (with homogeneous sample dtypes) and the fallback branch is not preserved in the replacement graph. As a result, mixed-dtype RMSNorm+quant graphs introduced by this commit can still be rewritten to fused kernels that require weight.dtype == input.dtype, which leads to runtime failures (or undefined behavior for kernels that reinterpret weight as input scalar type). Please enforce this constraint via match filtering (e.g., extra_check/separate pattern registration) rather than a runtime Python branch inside replacement.

Useful? React with 👍 / 👎.

ZJY0516 added 2 commits April 5, 2026 13:35
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
ZJY0516 added 5 commits April 5, 2026 16:27
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 added the ready-run-all-tests Trigger CI with all tests for wide-ranging PRs label Apr 5, 2026
@ProExpertProg
Copy link
Copy Markdown
Collaborator

This is a nice fix! IIUC, you convert x to fp32 before passing it to rms_norm so it's the same dtype as weight. I think it would be good to check the final compiled graph and get performance numbers - if x is 16bit and produced by a custom kernel, I'd want to avoid an extra triton kernel to convert to 32bit. Also wonder if it's slower to read & write x at 32 bit instead of 16.

@ZJY0516
Copy link
Copy Markdown
Member Author

ZJY0516 commented Apr 5, 2026

This is a nice fix! IIUC, you convert x to fp32 before passing it to rms_norm so it's the same dtype as weight. I think it would be good to check the final compiled graph and get performance numbers - if x is 16bit and produced by a custom kernel, I'd want to avoid an extra triton kernel to convert to 32bit. Also wonder if it's slower to read & write x at 32 bit instead of 16.

Actually, this will make models/quantization/test_gguf.py::test_models[1-5-32-bfloat16-model6] fail

Need to fix

https://buildkite.com/vllm/ci/builds/59823/steps/canvas?jid=019d5d54-ee19-441f-9417-854701716e26&tab=output

@wxsIcey
Copy link
Copy Markdown
Contributor

wxsIcey commented Apr 5, 2026

Thanks for your work! The current solution seems to be causing accuracy issues. I think we could try the solution provided by Luka, limiting fusion doesn't happen with fp32 weights.

@ZJY0516
Copy link
Copy Markdown
Member Author

ZJY0516 commented Apr 5, 2026

Thanks for your work! The current solution seems to be causing accuracy issues. I think we could try the solution provided by Luka, limiting fusion doesn't happen with fp32 weights.

It's not about fusion. It's beause in this pr gemma gguf test uses cuda kernel, but previous is native implementation

@wxsIcey
Copy link
Copy Markdown
Contributor

wxsIcey commented Apr 5, 2026

Thanks for your work! The current solution seems to be causing accuracy issues. I think we could try the solution provided by Luka, limiting fusion doesn't happen with fp32 weights.

It's not about fusion. It's beause in this pr gemma gguf test uses cuda kernel, but previous is native implementation

If x isn't converted to float(), the dtype of x and the weight are inconsistent, preventing the CUDA kernel from being used. Then, by implementing restriction in fusion pass, perhaps CI can pass completely.

@ZJY0516 ZJY0516 removed the ready-run-all-tests Trigger CI with all tests for wide-ranging PRs label Apr 5, 2026
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 added ready ONLY add when PR is ready to merge/full CI is needed and removed ready ONLY add when PR is ready to merge/full CI is needed labels Apr 5, 2026
Copy link
Copy Markdown
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slight simplification?

Comment thread vllm/ir/ops/layernorm.py Outdated
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZJY0516.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 6, 2026
# Conflicts:
#	tests/kernels/core/test_layernorm.py

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@mergify mergify Bot removed the needs-rebase label Apr 6, 2026
ZJY0516 added 3 commits April 6, 2026 10:22
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 force-pushed the rework-gemma-norm branch from da2e44c to 672fb48 Compare April 6, 2026 10:26
@ZJY0516
Copy link
Copy Markdown
Member Author

ZJY0516 commented Apr 6, 2026

@ProExpertProg I disable allreduce_rms fusion when dtype is mismatching. It wll cause accuracy issue for quantized model

Copy link
Copy Markdown
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disabling allreduce+rms fusion with mismatched types sounds good. CI failing though?

@yma11
Copy link
Copy Markdown
Contributor

yma11 commented Apr 7, 2026

cc @chaojun-zhang

@ZJY0516
Copy link
Copy Markdown
Member Author

ZJY0516 commented Apr 7, 2026

Disabling allreduce+rms fusion with mismatched types sounds good. CI failing though?

This failure is unrelated I think

@vllm-bot vllm-bot merged commit 8060bb0 into vllm-project:main Apr 7, 2026
179 of 181 checks passed
jacob-lou pushed a commit to jacob-lou/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Jacob Lou <jacoblou0924@gmail.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready-run-all-tests Trigger CI with all tests for wide-ranging PRs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants