[vLLM IR] rework gemma_rms_norm#39014
Conversation
Code Review
This pull request implements mixed-dtype support for RMSNorm by ensuring that provider kernels (vllm_c, aiter, xpu) fall back to a native implementation when input and weight dtypes do not match. It refactors GemmaRMSNorm to use the unified IR operation and adds comprehensive tests for mixed-dtype scenarios. Review feedback identifies a potential precision loss in the native multiplication step, a bug where the residual tensor is returned in float32 instead of the original dtype, and a performance regression in GemmaRMSNorm caused by the removal of torch.compile and redundant weight conversions.
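The dispatch described above can be sketched as follows. This is an illustrative sketch using numpy, not vLLM's actual API; `rms_norm_native` and the `provider_kernel` parameter are hypothetical names:

```python
import numpy as np

def rms_norm_native(x, weight, eps):
    # Native fallback: upcast to float32 for the variance computation,
    # scale by the (possibly differently-typed) weight, cast back.
    x32 = x.astype(np.float32)
    var = np.mean(x32 * x32, axis=-1, keepdims=True)
    out = x32 / np.sqrt(var + eps) * weight.astype(np.float32)
    return out.astype(x.dtype)

def rms_norm(x, weight, eps, provider_kernel=None):
    # Provider kernels (vllm_c, aiter, xpu) assume weight.dtype == x.dtype;
    # on a mismatch, fall back to the native implementation.
    if provider_kernel is not None and x.dtype == weight.dtype:
        return provider_kernel(x, weight, eps)
    return rms_norm_native(x, weight, eps)
```

With a float16 input and a float32 weight, the native path is taken and the result is cast back to the input dtype.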
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 979bea2a64
```python
if not _rms_weight_dtype_matches_input(input, weight):
    result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
    return self.quant_matcher(result_rms, scale)[0]
```
Move mixed-dtype gate out of traced replacement function
This mixed-dtype fallback is placed inside a replacement closure that Inductor traces using example tensors, so the Python if is resolved at trace time (with homogeneous sample dtypes) and the fallback branch is not preserved in the replacement graph. As a result, mixed-dtype RMSNorm+quant graphs introduced by this commit can still be rewritten to fused kernels that require weight.dtype == input.dtype, which leads to runtime failures (or undefined behavior for kernels that reinterpret weight as input scalar type). Please enforce this constraint via match filtering (e.g., extra_check/separate pattern registration) rather than a runtime Python branch inside replacement.
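The suggested match-filtering approach can be sketched with a toy stand-in for Inductor's pattern matcher. The point is that the dtype gate runs as an `extra_check` at match time, so mixed-dtype graphs are simply never rewritten; all class and function names here are illustrative, and the real `torch._inductor` registration API differs:

```python
from dataclasses import dataclass

@dataclass
class Match:
    # Toy stand-in for a pattern-matcher Match object.
    input_dtype: str
    weight_dtype: str

def extra_check(match: Match) -> bool:
    # Only allow the fused RMSNorm+quant replacement when dtypes agree;
    # this runs at match time, not at trace time, so it is preserved.
    return match.input_dtype == match.weight_dtype

_registry = []

def register_replacement(pattern, replacement, check):
    _registry.append((pattern, replacement, check))

def try_rewrite(match: Match):
    for pattern, replacement, check in _registry:
        if check(match):
            return replacement   # graph rewritten to the fused kernel
    return "native_rms_norm"     # mixed-dtype graphs keep the fallback

register_replacement("rms_norm+quant", "fused_rms_norm_quant", extra_check)
```

Matching dtypes select the fused replacement; mismatched dtypes leave the native op in the graph, which is the behavior the review asks for.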
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This is a nice fix! IIUC, you convert …

Actually, this will make … Need to fix …
Thanks for your work! The current solution seems to be causing accuracy issues. I think we could try Luka's suggested solution: make sure fusion doesn't happen with fp32 weights.
It's not about fusion. It's because in this PR the gemma gguf test uses the CUDA kernel, whereas previously it used the native implementation.
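For reference, a numpy sketch of what the native Gemma-style RMSNorm computes. Gemma stores a zero-centered weight and scales by `1 + weight`, keeping the computation in float32; this mirrors the reference formulation rather than vLLM's CUDA kernel, which is where the two paths can diverge numerically:

```python
import numpy as np

def gemma_rms_norm_native(x, weight, eps=1e-6):
    # Compute entirely in float32, then cast back to the input dtype.
    x32 = x.astype(np.float32)
    var = np.mean(x32 * x32, axis=-1, keepdims=True)
    normed = x32 / np.sqrt(var + eps)
    # Gemma's zero-centered weight: scale by (1 + weight).
    return (normed * (1.0 + weight.astype(np.float32))).astype(x.dtype)
```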
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
This pull request has merge conflicts that must be resolved before it can be merged.
# Conflicts:
#   tests/kernels/core/test_layernorm.py

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Force-pushed from da2e44c to 672fb48
@ProExpertProg I disabled allreduce_rms fusion when the dtypes mismatch, since fusing there will cause accuracy issues for quantized models.
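The gate described here amounts to a simple predicate on the fusion pass; a minimal sketch, assuming a hypothetical helper name (the actual vLLM pass plumbing is more involved):

```python
def allreduce_rms_fusion_enabled(input_dtype, weight_dtype, fusion_flag=True):
    # Disable the allreduce+rms fusion when dtypes mismatch: the fused
    # kernel assumes matching dtypes, and fusing anyway causes accuracy
    # issues for quantized models.
    return fusion_flag and input_dtype == weight_dtype
```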
ProExpertProg left a comment
Disabling allreduce+rms fusion with mismatched types sounds good. CI failing though?
I think this failure is unrelated.
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Jacob Lou <jacoblou0924@gmail.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Purpose
Rework of #38780.
cc @ProExpertProg @wxsIcey
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
- Update `supported_models.md` and `examples` for a new model.