Skip to content

FP8 support for FusedQKNormRope#133

Merged
pralay-das merged 12 commits into
sgl-project:mainfrom
adubey-ai:main
Apr 10, 2026
Merged

FP8 support for FusedQKNormRope#133
pralay-das merged 12 commits into
sgl-project:mainfrom
adubey-ai:main

Conversation

@adubey-ai
Copy link
Copy Markdown
Contributor

@adubey-ai adubey-ai commented Mar 13, 2026

The baseline fp8 changes were taken from PR #109 Thanks.

Fix head_dim=16 cause: change test to head_dim=64 and add C++ guard to avoid rotary_dim==0 divide/modulo.

Replace undefined CHECK_INPUT with CHECK_DEVICE/CHECK_CONTIGUOUS in FusedQKNormRope.cpp.

Update test reference to use qkv.to(torch.float32) so comparisons match FP8 dequantized kernel outputs.

Make -Xs flag conditional in FindSYCL.cmake to avoid passing a bare -Xs to icx.

  • Note : The unit quant/dequant scale factors are considered for fp8. If actual scale factors need to be used, the API needs to be updated and those scales should be handled in this kernel. This will be taken up in another PR

================================================================================

('bandwidth_gbs', 'mean') ('bandwidth_gbs', 'min') ('bandwidth_gbs', 'max') ('time_us', 'mean') ('time_us', 'min') ('time_us', 'max') ('gflops', 'mean') ('gflops', 'min') ('gflops', 'max')
('bf16', 'sglang') 156.321 93.1 205.42 875.286 14.14 11662.2 280.237 192.78 363.2
('bf16', 'torch') 18.9632 2.28 27.1 5974.76 582.61 69409.4 33.5531 4.74 45.17
('fp16', 'sglang') 154.439 92.25 202.33 885.874 14.25 11785.6 276.858 190.78 357.79
('fp16', 'torch') 19.0132 2.22 27.18 5990.07 558.71 69453.3 33.6375 4.63 45.3
('fp8_e4m3fn', 'sglang') 63.3408 39.5 77.96 1029.48 17.21 13222.7 226.378 158.53 267.58
('fp8_e4m3fn', 'torch') 8.02333 1.11 11.05 7291.13 498.71 90833.8 28.4285 4.63 36.83

Signed-off-by: Adarsh Dubey <adarsh.dubey@intel.com>
Copy link
Copy Markdown
Collaborator

@kareemshaik80 kareemshaik80 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be useful to add FP8 benchmarks as well so we can evaluate performance

@adubey-ai adubey-ai force-pushed the main branch 2 times, most recently from bb12bad to 05599e7 Compare March 16, 2026 08:36
Copilot AI review requested due to automatic review settings March 20, 2026 03:55
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds FP8 (E4M3) support for the XPU SYCL fused_qk_norm_rope kernel and validates it with a new unit test, while also adjusting SYCL toolchain flags intended to fix icx linking behavior.

Changes:

  • Extend fused_qk_norm_rope SYCL dispatch to support torch.float8_e4m3fn (via CUTLASS float_e4m3_t) and adjust accumulation behavior.
  • Add a dedicated FP8 E4M3 unit test comparing dequantized FP8 results against a float32 reference.
  • Attempt to make -Xs handling conditional in FindSYCL.cmake to avoid passing a bare -Xs.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File Description
tests/test_fused_qk_norm_rope.py Adds an FP8 E4M3 test case and compares outputs in float32.
src/sycl/FusedQKNormRope.cpp Adds FP8 dtype dispatch and adjusts kernel accumulation/weight access.
cmake/Modules/FindSYCL.cmake Introduces conditional construction of -Xs flags for SYCL device link.
Comments suppressed due to low confidence (1)

cmake/Modules/FindSYCL.cmake:421

  • The new _sycl_xs_flags variable is computed but never used; the custom command still passes -Xs ${sycl_offline_compiler_flags} unconditionally. This likely means the original icx failure (bare -Xs consuming -o) is still present. Use the computed _sycl_xs_flags in the COMMAND line (and base the condition on the macro argument sycl_offline_compiler_flags, not the global SYCL_OFFLINE_COMPILER_FLAGS).
    # Only pass -Xs when there are offline compiler flags (AOT targets).
    # For JIT (spir64) targets SYCL_OFFLINE_COMPILER_FLAGS is empty and
    # passing a bare -Xs causes icx to consume the following -o flag as
    # its argument, producing "no such file or directory" errors.
    if(SYCL_OFFLINE_COMPILER_FLAGS)
      set(_sycl_xs_flags -Xs ${SYCL_OFFLINE_COMPILER_FLAGS})
    else()
      set(_sycl_xs_flags)
    endif()

    add_custom_command(
      OUTPUT ${output_file}
      DEPENDS ${object_files}
      COMMAND ${SYCL_EXECUTABLE}
      ${SYCL_device_link_flags}
      -fsycl-link ${object_files}
      -Xs ${sycl_offline_compiler_flags}
      -o ${output_file}

Comment thread src/sycl/FusedQKNormRope.cpp
Comment thread src/sycl/FusedQKNormRope.cpp
Comment thread tests/test_fused_qk_norm_rope.py Outdated
Comment thread tests/test_fused_qk_norm_rope.py
Copy link
Copy Markdown
Collaborator

@kareemshaik80 kareemshaik80 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall LGTM!

Comment thread src/sycl/FusedQKNormRope.cpp Outdated
adubey-ai and others added 2 commits April 6, 2026 08:49
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Copy link
Copy Markdown
Collaborator

@kareemshaik80 kareemshaik80 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@pralay-das pralay-das merged commit ffccd5c into sgl-project:main Apr 10, 2026
2 checks passed
Copy link
Copy Markdown
Collaborator

@airMeng airMeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is functional support, no performance specific optimization, right?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants