Use fp32 accumulation in SkipLayerNorm/EmbedLayerNorm CUDA kernels #28682
Open
tianleiwu wants to merge 4 commits into
Pull request overview
This PR improves numerical stability of the CUDA fused normalization kernels used in transformer models by switching mean/variance (and RMS) accumulation to FP32 when inputs are FP16/BF16, reducing overflow/NaN risk while keeping outputs in the original type. It also adds local profiling utilities to measure kernel performance with Nsight Systems.
Changes:
- Promote SkipLayerNorm and EmbedLayerNorm CUDA kernel intermediate accumulation (mean/variance/RMS) from FP16/BF16 to FP32.
- Refactor `layer_norm.cuh` helpers to operate on `float` thread/block reduction values (and remove half/bfloat16 reduction overloads).
- Add `nsys`-based profiling + parsing scripts for reproducible kernel timing analysis.
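The accumulation change listed above can be illustrated with a small NumPy reference. This is only a sketch of the idea (promote to fp32, compute the statistics, cast the result back to the input type), not the CUDA kernel from this PR; the function name `skip_layer_norm_ref`, the default `epsilon`, and the test shapes are assumptions made for illustration.

```python
import numpy as np


def skip_layer_norm_ref(x, skip, gamma, beta, epsilon=1e-5):
    """Hypothetical reference: fp32 statistics, output in the input dtype."""
    out_dtype = x.dtype
    # Promote before any reduction so sums of x and x^2 stay within fp32 range.
    s = x.astype(np.float32) + skip.astype(np.float32)
    mu = s.mean(axis=-1, keepdims=True)
    var = (s * s).mean(axis=-1, keepdims=True) - mu * mu
    rsigma = 1.0 / np.sqrt(np.maximum(var, 0.0) + epsilon)
    y = (s - mu) * rsigma * gamma.astype(np.float32) + beta.astype(np.float32)
    # Only the final result is cast back to fp16/bf16-like storage.
    return y.astype(out_dtype)


rng = np.random.default_rng(0)
hidden = 4096
# Large-magnitude fp16 inputs: the mean of x^2 here is far above the fp16 maximum.
x = (rng.standard_normal((2, hidden)) * 300).astype(np.float16)
skip = (rng.standard_normal((2, hidden)) * 300).astype(np.float16)
gamma = np.ones(hidden, dtype=np.float16)
beta = np.zeros(hidden, dtype=np.float16)

out = skip_layer_norm_ref(x, skip, gamma, beta)
```

With these inputs the output stays finite and in fp16, even though the intermediate mean of squares could not be stored in fp16.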
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh | Switch block-reduce inputs/outputs and normalization math to FP32 intermediates for improved stability. |
| onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu | Accumulate skip layer norm statistics in FP32; remove now-unneeded maybe2half epsilon casting. |
| onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu | Accumulate embedding sum statistics in FP32 and pass epsilon as float. |
| onnxruntime/test/python/transformers/profile_skip_layer_norm.py | New Python profiler that builds a minimal SLN model and benchmarks it (optionally with NVTX ranges). |
| onnxruntime/test/python/transformers/profile_skip_layer_norm.sh | New wrapper script to run nsys profile and parse results. |
| onnxruntime/test/python/transformers/parse_nsys.py | New SQLite parser for nsys --export=sqlite kernel timing summaries. |
- Fix SQL injection in `parse_nsys.py`: use parameterized queries instead of string interpolation for kernel pattern matching
- Add `--nvtx-range` option to `parse_nsys.py` to filter kernels by NVTX range (e.g., 'benchmark'), eliminating the need for `--skip-first` to exclude warmup
- Update `parse_nsys.py` description/epilog to reflect current purpose
- Remove `pip install nvtx` from shell script; just check availability and warn
- Fix CodeQL import warning: use `from onnx import save_model` instead of `import onnx` + `onnx.save`
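The parameterized-query fix in the first item can be sketched as follows. The `kernels` table and its columns are a simplified, hypothetical schema chosen for illustration; they are not the schema of an `nsys --export=sqlite` file, and `kernel_durations` is not a function from `parse_nsys.py`.

```python
import sqlite3


def kernel_durations(conn, pattern):
    """Return (name, duration_ns) rows whose name matches a LIKE pattern."""
    # The pattern is bound as a parameter, never formatted into the SQL text.
    query = (
        "SELECT name, end_ns - start_ns AS duration_ns "
        "FROM kernels WHERE name LIKE ? ORDER BY start_ns"
    )
    return conn.execute(query, (pattern,)).fetchall()


# Hypothetical in-memory table standing in for a profiler export.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE kernels (name TEXT, start_ns INTEGER, end_ns INTEGER)")
conn.executemany(
    "INSERT INTO kernels VALUES (?, ?, ?)",
    [
        ("SkipLayerNormKernel", 100, 160),
        ("EmbedLayerNormKernel", 200, 290),
        ("SkipLayerNormKernelSmall", 300, 340),
    ],
)

rows = kernel_durations(conn, "%SkipLayerNorm%")
# A pattern containing quotes is treated as data, so it simply matches nothing.
hostile = kernel_durations(conn, "x' OR '1'='1")
```

Because the user-supplied pattern only ever reaches SQLite as a bound value, quote characters in it cannot change the structure of the query.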
Comment on lines 81 to 85

      // Reduce sum of x and x^2, and the results are divided by ld.
      // Uses fp32 accumulation to avoid overflow in fp16/bf16.
      KeyValuePairSum pair_sum;
    - cub::KeyValuePair<T, T> thread_data(0, 0);
    + cub::KeyValuePair<float, float> thread_data(0.f, 0.f);
Comment on lines 122 to 127

      template <typename T, unsigned TPB>
      __global__ void EmbedLayerNormKernel(
          int hidden_size, const int* input_ids, const int* segment_ids, const T* beta, const T* gamma,
          const T* word_embedding, const T* position_embedding, const T* segment_embedding,
    -     const T epsilon, T* output, T* embedding_sum, const int* position_ids, const bool broadcast_position_ids) {
    +     float epsilon, T* output, T* embedding_sum, const int* position_ids, const bool broadcast_position_ids) {
        KeyValuePairSum pair_sum;
Description
Use fp32 accumulation in SkipLayerNormalization, SkipSimplifiedLayerNormalization, and EmbedLayerNormalization CUDA kernels to avoid overflow and improve numerical accuracy when processing fp16/bf16 data.
The original implementation accumulated mean and variance statistics in the input data type (fp16/bf16), which can overflow for large hidden sizes or when input values have large magnitude. This change promotes all intermediate accumulation (mean, variance, normalization math) to fp32, matching the approach used by TensorRT-LLM's LayerNorm kernels.
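The overflow described above can be reproduced on the CPU with NumPy scalars. This is a minimal sketch that mimics one thread accumulating `x²/ld` sequentially in a chosen dtype; the helper name `mean_of_squares` and the constant input value are assumptions for illustration, and the loop does not model the kernel's block reduction.

```python
import numpy as np


def mean_of_squares(x, acc_dtype):
    """Sequentially accumulate sum(x^2 / ld) using acc_dtype for every step."""
    ld = x.shape[0]
    rld = acc_dtype(1.0 / ld)
    acc = acc_dtype(0)
    # Overflow to inf is the effect being demonstrated, so silence the warning.
    with np.errstate(over="ignore"):
        for v in x:
            val = acc_dtype(v)
            acc = acc_dtype(acc + rld * val * val)
    return acc


hidden = 4096
# Every element is 512, so the true mean of squares is 512^2 = 262144,
# which is above the largest finite fp16 value (65504).
x = np.full(hidden, 512.0, dtype=np.float16)

fp16_result = mean_of_squares(x, np.float16)
fp32_result = mean_of_squares(x, np.float32)
```

Here the fp16 accumulator saturates to `inf` partway through the row, while the fp32 accumulator reaches 262144 exactly, even though each individual input is comfortably representable in fp16.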
Motivation
Accumulating `x²/ld` across thousands of elements in fp16 easily overflows or loses precision.
Key Changes
| File | Changes |
|---|---|
| `layer_norm.cuh` | Changed `LayerNorm`, `SimplifiedLayerNorm`, `LayerNormSmall`, `SimplifiedLayerNormSmall` to accept and operate on `float` for thread_data, epsilon, mu, rsigma. Removed unused `KeyValuePairSum` overloads for half/bfloat16. |
| `skip_layer_norm_impl.cu` | Changed `SkipLayerNormKernel` and `SkipLayerNormKernelSmall` to accumulate in fp32 (`cub::KeyValuePair<float, float>`). Removed `maybe2half` helper (no longer needed). |
| `embed_layer_norm_impl.cu` | Changed epsilon from `T` to `float`, accumulation to use `float` thread_data. |
| `profile_skip_layer_norm.py` | |
| `profile_skip_layer_norm.sh` | |
| `parse_nsys.py` | |

Performance Results
Profiled on NVIDIA GPU with nsys (B=1, seq_len=2048, fp16 data, 200 iterations, skip first 5 warmup):
No measurable performance regression. The kernel is memory-bandwidth-bound, so the extra fp32 arithmetic is hidden behind memory latency.
Testing
    cd onnxruntime/test/python/transformers
    nsys profile -o sln_fp16 --export=sqlite python profile_skip_layer_norm.py --mode fp16 --warmup 5 --repeat 100
    python parse_nsys.py sln_fp16.sqlite --skip-first 5

Related PRs
#28442
#15660