Use fp32 accumulation in SkipLayerNorm/EmbedLayerNorm CUDA kernels #28682

Open
tianleiwu wants to merge 4 commits into main from tlwu/sln_fp32_compute_type

@tianleiwu tianleiwu commented May 26, 2026

Description

Use fp32 accumulation in SkipLayerNormalization, SkipSimplifiedLayerNormalization, and EmbedLayerNormalization CUDA kernels to avoid overflow and improve numerical accuracy when processing fp16/bf16 data.

The original implementation accumulated mean and variance statistics in the input data type (fp16/bf16), which can overflow for large hidden sizes or when input values have large magnitude. This change promotes all intermediate accumulation (mean, variance, normalization math) to fp32, matching the approach used by TensorRT-LLM's LayerNorm kernels.
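For reference, the statistics being promoted to fp32 can be written out as follows. This is a sketch that assumes the single-pass formulation suggested by the kernel comment "Reduce sum of x and x^2, and the results are divided by ld"; the symbol m_2 is introduced here for illustration, while ld, mu, rsigma, and epsilon follow the names used in the description. Here x_i is the element being normalized (for SkipLayerNorm, the sum of input and skip).

```latex
\mu = \sum_{i=1}^{ld} \frac{x_i}{ld}, \qquad
m_2 = \sum_{i=1}^{ld} \frac{x_i^2}{ld}, \qquad
r_\sigma = \frac{1}{\sqrt{m_2 - \mu^2 + \epsilon}}, \qquad
y_i = (x_i - \mu)\, r_\sigma\, \gamma_i + \beta_i

% Simplified (RMS) variant: no mean subtraction and no beta.
r_\sigma = \frac{1}{\sqrt{m_2 + \epsilon}}, \qquad
y_i = x_i\, r_\sigma\, \gamma_i
```

Under this formulation m_2 is the quantity most exposed to overflow, since it must hold the mean of the squared inputs.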

Motivation

  • fp16 has limited range (max 65504) and precision (10-bit mantissa). Accumulating x²/ld across thousands of elements in fp16 overflows once the mean square exceeds that limit, and loses precision as the running sum grows.
  • bf16 has even less precision (7-bit mantissa), making accumulation errors more severe.
  • The fix is straightforward: cast to float before accumulating, compute normalization in float, cast back to the output type.
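The effect can be reproduced without a GPU. The sketch below emulates fp16 and fp32 rounding with Python's `struct` module and accumulates x²/ld one element at a time. The helpers `to_fp16`, `to_fp32`, and `mean_square` are hypothetical names written for this illustration, not code from the PR, and the sequential loop is a simplification: the CUDA kernels use a block-wide reduction, so rounding loss on a GPU is less extreme than shown here, while the overflow case applies regardless of summation order.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct raises for finite values beyond the fp16 range; hardware would yield inf.
        return math.copysign(math.inf, x)


def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def mean_square(values, rnd):
    """Accumulate sum(x * x / ld) sequentially, rounding every intermediate with rnd."""
    inv_ld = rnd(1.0 / len(values))
    acc = 0.0
    for x in values:
        term = rnd(rnd(x * inv_ld) * x)  # (x / ld) * x, rounded after each operation
        acc = rnd(acc + term)
    return acc


# Overflow: every element is 512, so the mean square (262144) exceeds the fp16 maximum.
large = [512.0] * 4096
print("large, fp16 accum:", mean_square(large, to_fp16))
print("large, fp32 accum:", mean_square(large, to_fp32))

# Precision loss: each term (2**-13) eventually falls below half the spacing of the sum.
small = [1.0] * 8192
print("small, fp16 accum:", mean_square(small, to_fp16))
print("small, fp32 accum:", mean_square(small, to_fp32))
```

With fp16 rounding the first case ends at inf and the second stalls at 0.25 instead of 1.0, while fp32 rounding gives the exact values in both cases.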

Key Changes

| File | Change |
|---|---|
| layer_norm.cuh | Changed LayerNorm, SimplifiedLayerNorm, LayerNormSmall, SimplifiedLayerNormSmall to accept and operate on float for thread_data, epsilon, mu, rsigma. Removed unused KeyValuePairSum overloads for half/bfloat16. |
| skip_layer_norm_impl.cu | Changed SkipLayerNormKernel and SkipLayerNormKernelSmall to accumulate in fp32 (cub::KeyValuePair<float, float>). Removed maybe2half helper (no longer needed). |
| embed_layer_norm_impl.cu | Changed epsilon from T to float, accumulation to use float thread_data. |
| profile_skip_layer_norm.py | New profiling script for nsys-based kernel timing analysis. |
| profile_skip_layer_norm.sh | Shell wrapper for running nsys profiling. |
| parse_nsys.py | Utility to parse nsys SQLite output and extract CUDA kernel timings. |

Performance Results

Profiled on an NVIDIA GPU with nsys (B=1, seq_len=2048, fp16 data, 200 iterations, first 5 skipped as warmup):

| Hidden Size | fp16 accum (μs) | fp32 accum (μs) | Regression |
|---|---|---|---|
| 768 | 3.81 | 3.81 | 0.0% |
| 1024 | 4.22 | 4.22 | 0.0% |
| 4096 | 13.01 | 13.03 | +0.15% (noise) |
| 8192 | 28.94 | 28.94 | 0.0% |

No measurable performance regression: the largest difference, +0.15% at hidden size 4096, is within noise. The kernel is memory-bandwidth-bound, so the additional fp32 arithmetic is hidden behind memory latency.
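As a back-of-the-envelope check on these numbers, the sketch below recomputes the regression column and estimates the effective memory bandwidth implied by each timing. The timings are copied from the table; the traffic model is an assumption made for this illustration (two fp16 tensors read, one written, with gamma, beta, bias, and any optional outputs ignored), and the function names are hypothetical.

```python
# hidden size -> (fp16 accum us, fp32 accum us), copied from the Performance Results table.
TIMINGS_US = {
    768: (3.81, 3.81),
    1024: (4.22, 4.22),
    4096: (13.01, 13.03),
    8192: (28.94, 28.94),
}
SEQ_LEN = 2048       # B=1, seq_len=2048 from the profiling setup
BYTES_PER_ELEM = 2   # fp16 data
TENSORS_MOVED = 3    # ASSUMPTION: input and skip are read, output is written


def regression_pct(base_us: float, new_us: float) -> float:
    """Relative slowdown of new_us against base_us, in percent."""
    return (new_us - base_us) / base_us * 100.0


def effective_bandwidth_gbps(hidden: int, time_us: float) -> float:
    """Estimated bytes moved per second (in GB/s) under the traffic model above."""
    bytes_moved = TENSORS_MOVED * SEQ_LEN * hidden * BYTES_PER_ELEM
    return bytes_moved / (time_us * 1e-6) / 1e9


for hidden, (fp16_us, fp32_us) in TIMINGS_US.items():
    print(
        f"hidden={hidden:5d}  regression={regression_pct(fp16_us, fp32_us):+.2f}%  "
        f"bandwidth~{effective_bandwidth_gbps(hidden, fp32_us):.0f} GB/s"
    )
```

Under this assumed traffic model the estimate reaches the terabytes-per-second range at the larger hidden sizes, which is consistent with a kernel whose runtime is dominated by memory traffic rather than arithmetic.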

Testing

  • Existing unit tests pass (SkipLayerNorm, EmbedLayerNorm ops).
  • Profiling scripts added for reproducible performance measurement:
    cd onnxruntime/test/python/transformers
    nsys profile -o sln_fp16 --export=sqlite python profile_skip_layer_norm.py --mode fp16 --warmup 5 --repeat 100
    python parse_nsys.py sln_fp16.sqlite --skip-first 5

Related PRs

#28442
#15660

Comment thread onnxruntime/test/python/transformers/profile_skip_layer_norm.py Fixed

Copilot AI left a comment

Pull request overview

This PR improves numerical stability of the CUDA fused normalization kernels used in transformer models by switching mean/variance (and RMS) accumulation to FP32 when inputs are FP16/BF16, reducing overflow/NaN risk while keeping outputs in the original type. It also adds local profiling utilities to measure kernel performance with Nsight Systems.

Changes:

  • Promote SkipLayerNorm and EmbedLayerNorm CUDA kernel intermediate accumulation (mean/variance/RMS) from FP16/BF16 to FP32.
  • Refactor layer_norm.cuh helpers to operate on float thread/block reduction values (and remove half/bfloat16 reduction overloads).
  • Add nsys-based profiling + parsing scripts for reproducible kernel timing analysis.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh | Switch block-reduce inputs/outputs and normalization math to FP32 intermediates for improved stability. |
| onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu | Accumulate skip layer norm statistics in FP32; remove now-unneeded maybe2half epsilon casting. |
| onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu | Accumulate embedding sum statistics in FP32 and pass epsilon as float. |
| onnxruntime/test/python/transformers/profile_skip_layer_norm.py | New Python profiler that builds a minimal SLN model and benchmarks it (optionally with NVTX ranges). |
| onnxruntime/test/python/transformers/profile_skip_layer_norm.sh | New wrapper script to run nsys profile and parse results. |
| onnxruntime/test/python/transformers/parse_nsys.py | New SQLite parser for nsys --export=sqlite kernel timing summaries. |

Comment thread onnxruntime/test/python/transformers/parse_nsys.py Outdated
Comment thread onnxruntime/test/python/transformers/parse_nsys.py Outdated
Comment thread onnxruntime/test/python/transformers/profile_skip_layer_norm.sh Outdated
tianleiwu added 2 commits May 26, 2026 17:29
- Fix SQL injection in parse_nsys.py: use parameterized queries instead of
  string interpolation for kernel pattern matching
- Add --nvtx-range option to parse_nsys.py to filter kernels by NVTX range
  (e.g., 'benchmark'), eliminating the need for --skip-first to exclude warmup
- Update parse_nsys.py description/epilog to reflect current purpose
- Remove pip install nvtx from shell script; just check availability and warn
- Fix CodeQL import warning: use 'from onnx import save_model' instead of
  'import onnx' + 'onnx.save'
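
The parameterized-query fix in the first item can be illustrated with Python's built-in `sqlite3` module. This is a minimal sketch, not the code in parse_nsys.py: the `kernels` table, its columns, the sample rows, and the `kernel_durations` helper are all hypothetical stand-ins for the real nsys export schema.

```python
import sqlite3


def kernel_durations(conn: sqlite3.Connection, name_pattern: str):
    """Return (name, duration_ns) rows whose name matches a SQL LIKE pattern."""
    # The pattern travels as a bound parameter, so it is never parsed as SQL text.
    query = "SELECT name, duration_ns FROM kernels WHERE name LIKE ? ORDER BY rowid"
    return conn.execute(query, (name_pattern,)).fetchall()


# Hypothetical in-memory table standing in for an nsys SQLite export.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE kernels (name TEXT, duration_ns INTEGER)")
conn.executemany(
    "INSERT INTO kernels VALUES (?, ?)",
    [
        ("SkipLayerNormKernelSmall", 3810),
        ("EmbedLayerNormKernel", 5000),
        ("SkipLayerNormKernel", 13030),
    ],
)

print(kernel_durations(conn, "%SkipLayerNorm%"))
# A hostile pattern is treated as plain data and simply matches nothing.
print(kernel_durations(conn, "x'; DROP TABLE kernels; --"))
```

With string interpolation the second pattern would have been spliced into the statement; with a placeholder it is only ever compared against the `name` column.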

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

Comment on lines 81 to 85
  // Reduce sum of x and x^2, and the results are divided by ld.
  // Uses fp32 accumulation to avoid overflow in fp16/bf16.
  KeyValuePairSum pair_sum;
- cub::KeyValuePair<T, T> thread_data(0, 0);
+ cub::KeyValuePair<float, float> thread_data(0.f, 0.f);

Comment on lines 122 to 127
  template <typename T, unsigned TPB>
  __global__ void EmbedLayerNormKernel(
  int hidden_size, const int* input_ids, const int* segment_ids, const T* beta, const T* gamma,
  const T* word_embedding, const T* position_embedding, const T* segment_embedding,
- const T epsilon, T* output, T* embedding_sum, const int* position_ids, const bool broadcast_position_ids) {
+ float epsilon, T* output, T* embedding_sum, const int* position_ids, const bool broadcast_position_ids) {
  KeyValuePairSum pair_sum;