FP8 support for FusedQKNormRope #133
Merged
Signed-off-by: Adarsh Dubey <adarsh.dubey@intel.com>
jayachandranb-ai
approved these changes
Mar 13, 2026
Collaborator
kareemshaik80
left a comment
It might be useful to add FP8 benchmarks as well so we can evaluate performance.
bb12bad to
05599e7
Contributor
Pull request overview
Adds FP8 (E4M3) support for the XPU SYCL fused_qk_norm_rope kernel and validates it with a new unit test, while also adjusting SYCL toolchain flags intended to fix icx linking behavior.
Changes:
- Extend `fused_qk_norm_rope` SYCL dispatch to support `torch.float8_e4m3fn` (via CUTLASS `float_e4m3_t`) and adjust accumulation behavior.
- Add a dedicated FP8 E4M3 unit test comparing dequantized FP8 results against a float32 reference.
- Attempt to make `-Xs` handling conditional in `FindSYCL.cmake` to avoid passing a bare `-Xs`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/test_fused_qk_norm_rope.py | Adds an FP8 E4M3 test case and compares outputs in float32. |
| src/sycl/FusedQKNormRope.cpp | Adds FP8 dtype dispatch and adjusts kernel accumulation/weight access. |
| cmake/Modules/FindSYCL.cmake | Introduces conditional construction of -Xs flags for SYCL device link. |
Comments suppressed due to low confidence (1)
cmake/Modules/FindSYCL.cmake:421
- The new `_sycl_xs_flags` variable is computed but never used; the custom command still passes `-Xs ${sycl_offline_compiler_flags}` unconditionally. This likely means the original `icx` failure (bare `-Xs` consuming `-o`) is still present. Use the computed `_sycl_xs_flags` in the COMMAND line (and base the condition on the macro argument `sycl_offline_compiler_flags`, not the global `SYCL_OFFLINE_COMPILER_FLAGS`).
# Only pass -Xs when there are offline compiler flags (AOT targets).
# For JIT (spir64) targets SYCL_OFFLINE_COMPILER_FLAGS is empty and
# passing a bare -Xs causes icx to consume the following -o flag as
# its argument, producing "no such file or directory" errors.
if(SYCL_OFFLINE_COMPILER_FLAGS)
  set(_sycl_xs_flags -Xs ${SYCL_OFFLINE_COMPILER_FLAGS})
else()
  set(_sycl_xs_flags)
endif()
add_custom_command(
  OUTPUT ${output_file}
  DEPENDS ${object_files}
  COMMAND ${SYCL_EXECUTABLE}
  ${SYCL_device_link_flags}
  -fsycl-link ${object_files}
  -Xs ${sycl_offline_compiler_flags}
  -o ${output_file}
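The failure mode described in the comment above can be modeled outside CMake. The following is a minimal Python sketch, not the project's build logic: `device_link_command` and `option_value` are hypothetical helpers, `option_value` only imitates a driver that treats the next token as the option's argument, and `"-device pvc"` is just an illustrative AOT flag string.

```python
def device_link_command(compiler, object_files, output_file, offline_flags):
    """Build a device-link argv list (hypothetical helper, for illustration)."""
    cmd = [compiler, "-fsycl-link", *object_files]
    # Emit -Xs only when it has a value to carry (AOT case).
    if offline_flags:
        cmd += ["-Xs", offline_flags]
    cmd += ["-o", output_file]
    return cmd


def option_value(argv, option):
    """Return the token right after `option`, or None if absent.

    This imitates a driver that always takes the next token as the
    option's argument; it is an assumption about the failure, not icx code.
    """
    if option not in argv:
        return None
    i = argv.index(option)
    return argv[i + 1] if i + 1 < len(argv) else None


# Unconditional form with empty flags: -Xs is followed directly by -o.
broken = ["icx", "-fsycl-link", "a.o", "-Xs", "-o", "out.o"]
print(option_value(broken, "-Xs"))   # -o

# Conditional form: JIT builds get no -Xs, AOT builds keep their flags.
jit = device_link_command("icx", ["a.o"], "out.o", "")
aot = device_link_command("icx", ["a.o"], "out.o", "-device pvc")
print(jit)
print(aot)
```

In the CMake macro the same idea would mean expanding the computed `_sycl_xs_flags` in the COMMAND line instead of the literal `-Xs ${sycl_offline_compiler_flags}`, as the comment suggests.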
Signed-off-by: adubey <adarsh.dubey@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Adarsh Dubey <adarsh.dubey@intel.com>
airMeng
reviewed
Apr 10, 2026
Collaborator
airMeng
left a comment
This is functional support only, with no performance-specific optimization, right?
The baseline FP8 changes were taken from PR #109. Thanks.
Fix the head_dim=16 failure: change the test to head_dim=64 and add a C++ guard to avoid a divide/modulo by zero when rotary_dim==0.
Replace the undefined CHECK_INPUT with CHECK_DEVICE/CHECK_CONTIGUOUS in FusedQKNormRope.cpp.
Update the test reference to use qkv.to(torch.float32) so comparisons match the FP8 dequantized kernel outputs.
Make the -Xs flag conditional in FindSYCL.cmake to avoid passing a bare -Xs to icx.
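To illustrate why the reference is built from `qkv.to(torch.float32)` rather than from the pre-quantization values, here is a rough pure-Python sketch. It is not the test in this PR: `fake_quant_e4m3` and `rms_norm` are hypothetical helpers, the rounding assumes round-to-nearest-even with clamping to ±448 (a real cast may treat out-of-range values differently), and a plain RMS norm without weights or RoPE stands in for the kernel math.

```python
import math

E4M3_MAX = 448.0     # largest finite float8_e4m3fn value
E4M3_MIN_EXP = -6    # exponent of the smallest normal value


def fake_quant_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value, returned as a Python float.

    Assumption: out-of-range inputs are clamped to +/-448.
    """
    if x == 0.0 or math.isnan(x):
        return x
    sign = math.copysign(1.0, x)
    a = min(abs(x), E4M3_MAX)
    # frexp gives a = m * 2**e with 0.5 <= m < 1, so the binade exponent is e - 1.
    exp = max(math.frexp(a)[1] - 1, E4M3_MIN_EXP)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits: 8 steps per binade
    return sign * min(round(a / step) * step, E4M3_MAX)


def rms_norm(row, eps=1e-6):
    """Unweighted RMS normalization of one row (stand-in for the kernel)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v * scale for v in row]


original = [0.3, -1.7, 2.45, 0.05]
dequantized = [fake_quant_e4m3(v) for v in original]
print(dequantized)  # [0.3125, -1.75, 2.5, 0.05078125]

# A kernel fed FP8 inputs only ever sees the dequantized values, so the
# reference should start from them as well.
ref_from_dequantized = rms_norm(dequantized)
ref_from_original = rms_norm(original)
max_gap = max(abs(a - b) for a, b in zip(ref_from_dequantized, ref_from_original))
print(max_gap)
```

The nonzero `max_gap` is input quantization error alone. Building the reference from the dequantized values removes it from the comparison, so the tolerance only has to cover the kernel's own arithmetic and output rounding.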
================================================================================