diff --git a/.github/workflows/ut.yaml b/.github/workflows/ut.yaml index a2e94ad44..599ace779 100644 --- a/.github/workflows/ut.yaml +++ b/.github/workflows/ut.yaml @@ -43,9 +43,9 @@ jobs: - name: build docker image & push to local id: build-image run: | - docker build -t xpu-kernel-ci-image:latest -f Dockerfile.xpu . - docker tag xpu-kernel-ci-image:latest ${{ env.REGISTRY }}/xpu-kernel-ci-image:latest - docker push ${{ env.REGISTRY }}/xpu-kernel-ci-image:latest + docker build -t xpu-kernel-ci-image:test-213 -f Dockerfile.xpu . + docker tag xpu-kernel-ci-image:test-213 ${{ env.REGISTRY }}/xpu-kernel-ci-image:test-213 + docker push ${{ env.REGISTRY }}/xpu-kernel-ci-image:test-213 build-docker-image-latest-bmg: runs-on: self-hosted-bmg @@ -65,9 +65,9 @@ jobs: - name: build docker image & push to local id: build-image run: | - docker build -t xpu-kernel-ci-image:latest -f Dockerfile.xpu . - docker tag xpu-kernel-ci-image:latest ${{ env.REGISTRY }}/xpu-kernel-ci-image:latest - docker push ${{ env.REGISTRY }}/xpu-kernel-ci-image:latest + docker build -t xpu-kernel-ci-image:test-213 -f Dockerfile.xpu . + docker tag xpu-kernel-ci-image:test-213 ${{ env.REGISTRY }}/xpu-kernel-ci-image:test-213 + docker push ${{ env.REGISTRY }}/xpu-kernel-ci-image:test-213 # Build wheel only once on PVC, then share via GitHub Actions artifact. 
# BMG runner cannot reach PVC directly (different network segment), so the @@ -76,7 +76,7 @@ jobs: runs-on: self-hosted-pvc needs: build-docker-image-latest-pvc container: - image: localhost:5000/xpu-kernel-ci-image:latest + image: localhost:5000/xpu-kernel-ci-image:test-213 options: --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --privileged -v ccache:/root/.ccache -e CCACHE_DIR=/root/.ccache steps: - name: Checkout @@ -125,8 +125,11 @@ jobs: runs-on: self-hosted-pvc needs: [build-docker-image-latest-pvc, build-wheel] timeout-minutes: 50 + defaults: + run: + shell: bash container: - image: localhost:5000/xpu-kernel-ci-image:latest + image: localhost:5000/xpu-kernel-ci-image:test-213 options: --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --privileged -v ccache:/root/.ccache -e CCACHE_DIR=/root/.ccache steps: - name: Checkout @@ -143,11 +146,15 @@ jobs: - name: install wheel run: | git config --global --add safe.directory "${GITHUB_WORKSPACE}" + source /opt/intel/oneapi/setvars.sh --force || true + source /opt/venv/bin/activate uv pip install -r requirements.txt VLLM_USE_PRECOMPILED=1 VLLM_PRECOMPILED_WHEEL_LOCATION=$(find dist -name '*.whl' -print -quit) uv pip install --no-build-isolation -e . 
-v - name: test run: | + source /opt/intel/oneapi/setvars.sh --force || true + source /opt/venv/bin/activate echo "Running tests with XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }}" XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }} ZE_AFFINITY_MASK=0,1 SKIP_ACC_ERROR_KERNEL=1 pytest -v -s tests/ --ignore=tests/test_fp8_gemm_onednn.py VLLM_XPU_FORCE_XE_DEFAULT_KERNEL=1 XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }} ZE_AFFINITY_MASK=0,1 pytest -v -s tests/fused_moe/test_grouped_gemm.py::test_grouped_gemm @@ -167,8 +174,11 @@ jobs: runs-on: self-hosted-bmg needs: [build-docker-image-latest-bmg, build-wheel] timeout-minutes: 50 + defaults: + run: + shell: bash container: - image: localhost:5000/xpu-kernel-ci-image:latest + image: localhost:5000/xpu-kernel-ci-image:test-213 options: --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --privileged -v ccache:/root/.ccache -e CCACHE_DIR=/root/.ccache steps: - name: Checkout @@ -185,11 +195,15 @@ jobs: - name: install wheel run: | git config --global --add safe.directory "${GITHUB_WORKSPACE}" + source /opt/intel/oneapi/setvars.sh --force || true + source /opt/venv/bin/activate uv pip install -r requirements.txt VLLM_USE_PRECOMPILED=1 VLLM_PRECOMPILED_WHEEL_LOCATION=$(find dist -name '*.whl' -print -quit) uv pip install --no-build-isolation -e . -v - name: test run: | + source /opt/intel/oneapi/setvars.sh --force || true + source /opt/venv/bin/activate echo "Running tests with XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }}" # tests/test_moe_align_block_size.py, tests/test_moe_lora_align_sum.py takes much time than expected. ignore it for now. 
XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }} ZE_AFFINITY_MASK=0,1 pytest -v -s tests/ --ignore=tests/test_lora_ops.py --ignore=tests/test_fp8_quant.py --ignore=tests/test_moe_align_block_size.py --ignore=tests/test_moe_lora_align_sum.py --ignore=tests/test_cache.py::test_swap_blocks --ignore=tests/test_topk_per_row.py --ignore=tests/test_lora_ops.py --ignore=tests/test_fp8_gemm_onednn.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 492aec7ec..5e9177169 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,7 @@ set(BUILD_SYCL_TLA_KERNELS CACHE BOOL "Build SYCL-TLA based kernels for XPU") # ARCHITECTURE OPTIONS option(VLLM_XPU_ENABLE_XE2 "Enable XE2 architecture kernels" ON) +option(VLLM_XPU_ENABLE_XE3 "Enable XE3 architecture kernels" OFF) option(VLLM_XPU_ENABLE_XE_DEFAULT "Enable XE Default architecture kernels" ON) # KERNEL OPTIONS — each controls whether the corresponding Python extension is @@ -71,6 +72,7 @@ message(STATUS "") message(STATUS "Kernel build configuration:") message(STATUS " BUILD_SYCL_TLA_KERNELS = ${BUILD_SYCL_TLA_KERNELS}") message(STATUS " VLLM_XPU_ENABLE_XE2 = ${VLLM_XPU_ENABLE_XE2}") +message(STATUS " VLLM_XPU_ENABLE_XE3 = ${VLLM_XPU_ENABLE_XE3}") message(STATUS " VLLM_XPU_ENABLE_XE_DEFAULT = ${VLLM_XPU_ENABLE_XE_DEFAULT}") message(STATUS " BASIC_KERNELS_ENABLED = ${BASIC_KERNELS_ENABLED}") message(STATUS " FA2_KERNELS_ENABLED = ${FA2_KERNELS_ENABLED}") @@ -178,8 +180,9 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") # VLLM_XPU_AOT_DEVICES and VLLM_XPU_XE2_AOT_DEVICES Example: export # VLLM_XPU_AOT_DEVICES="pvc,bmg-g21-a0" export # VLLM_XPU_XE2_AOT_DEVICES="pvc,bmg-g31-a0" - set(AOT_DEVICES "pvc,bmg,bmg-g21-a0,bmg-g31-a0") + set(AOT_DEVICES "pvc,bmg,bmg-g21-a0,bmg-g31-a0,xe3p,nvl-s") set(XE2_AOT_DEVICES "pvc,bmg,bmg-g21-a0,bmg-g31-a0") + set(XE3_AOT_DEVICES "xe3p,nvl-s") # Allow overriding via env, including explicitly disabling AOT by setting an # empty env var (e.g. export VLLM_XPU_AOT_DEVICES=""). 
@@ -312,7 +315,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") # header only library list(APPEND VLLM_CUTLASS_FLAGS "-DCUTLASS_ENABLE_HEADERS_ONLY") list(APPEND VLLM_CUTLASS_FLAGS "-DCUTLASS_ENABLE_SYCL") - list(APPEND VLLM_CUTLASS_FLAGS "-DSYCL_INTEL_TARGET") + # list(APPEND VLLM_CUTLASS_FLAGS "-DSYCL_INTEL_TARGET") list(APPEND VLLM_CUTLASS_FLAGS "-DCUTLASS_VERSIONS_GENERATED") list(APPEND VLLM_CUTLASS_FLAGS "-ftemplate-backtrace-limit=0") list(APPEND VLLM_CUTLASS_FLAGS "-fdiagnostics-color=always") @@ -363,6 +366,14 @@ if(BUILD_SYCL_TLA_KERNELS) endif() list(APPEND SYCL_TLA_COMPILE_OPTIONS -DVLLM_XPU_ENABLE_XE2) endif() + if(VLLM_XPU_ENABLE_XE3) + message(STATUS "Building XE3 attention kernels") + # add_subdirectory(csrc/xpu/grouped_gemm/xe_3) + add_subdirectory(csrc/xpu/attn/xe_3) + # list(APPEND GROUPED_GEMM_LIB_NAME "grouped_gemm_xe_3") + list(APPEND ATTN_KERNEL_LIB_NAME "attn_kernels_xe_3") + list(APPEND SYCL_TLA_COMPILE_OPTIONS -DVLLM_XPU_ENABLE_XE3) + endif() list(APPEND VLLM_GPU_COMPILE_FLAGS ${SYCL_TLA_COMPILE_OPTIONS}) endif() diff --git a/Dockerfile.xpu b/Dockerfile.xpu index dfa2f2bb7..527b214ee 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -1,9 +1,9 @@ -FROM intel/deep-learning-essentials:2025.3.2-0-devel-ubuntu24.04 AS vllm-base +FROM intel/deep-learning-essentials:2026.0.0-devel-ubuntu24.04 AS vllm-base WORKDIR /workspace/ ARG PYTHON_VERSION=3.12 -ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/xpu" +ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/test/xpu" RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ @@ -26,15 +26,16 @@ RUN apt clean && apt-get update -y && \ python3.12 \ python3.12-dev -RUN apt install -y libze1 libze-dev +RUN apt remove libze1 -y RUN mkdir
neo && cd neo && \ - wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.22.2/intel-igc-core-2_2.22.2+20121_amd64.deb && \ - wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.22.2/intel-igc-opencl-2_2.22.2+20121_amd64.deb && \ - wget https://github.com/intel/compute-runtime/releases/download/25.44.36015.8/intel-ocloc_25.44.36015.8-0_amd64.deb && \ - wget https://github.com/intel/compute-runtime/releases/download/25.44.36015.8/intel-opencl-icd_25.44.36015.8-0_amd64.deb && \ - wget https://github.com/intel/compute-runtime/releases/download/25.44.36015.8/libigdgmm12_22.8.2_amd64.deb && \ - wget https://github.com/intel/compute-runtime/releases/download/25.44.36015.8/libze-intel-gpu1_25.44.36015.8-0_amd64.deb && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.32.7/intel-igc-core-2_2.32.7+21184_amd64.deb && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.32.7/intel-igc-opencl-2_2.32.7+21184_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/26.14.37833.4/intel-ocloc_26.14.37833.4-0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/26.14.37833.4/intel-opencl-icd_26.14.37833.4-0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/26.14.37833.4/libigdgmm12_22.9.0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/26.14.37833.4/libze-intel-gpu1_26.14.37833.4-0_amd64.deb && \ + wget https://github.com/oneapi-src/level-zero/releases/download/v1.28.2/level-zero_1.28.2+u24.04_amd64.deb && \ dpkg -i *.deb && cd .. 
&& rm -rf neo ENV PATH="/root/.local/bin:$PATH" diff --git a/cmake/utils.cmake b/cmake/utils.cmake index ac9399e05..65634baed 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -583,6 +583,7 @@ function(add_xe2_kernel_library LIBRARY_NAME) target_compile_options(${LIBRARY_NAME} PRIVATE ${SYCL_TLA_KERNELS_COMPILE_FLAGS} -fPIC) target_compile_definitions(${LIBRARY_NAME} PRIVATE -DVLLM_XPU_ENABLE_XE2) + target_compile_definitions(${LIBRARY_NAME} PRIVATE -DSYCL_INTEL_TARGET=20) target_include_directories(${LIBRARY_NAME} PRIVATE ${SYCL_TLA_INCLUDE_DIRS}) # Link torch libraries @@ -610,6 +611,76 @@ function(add_xe2_kernel_library LIBRARY_NAME) target_link_options(${LIBRARY_NAME} PRIVATE ${XE2_GPU_LINK_FLAGS}) endfunction() +# +# Create a shared library for XE3 kernels with common configuration. +# +# Arguments: LIBRARY_NAME: Name of the library to create (e.g., +# attn_kernels_xe_3) DESTINATION: Installation destination directory (optional, +# defaults to vllm_xpu_kernels) INCLUDE_CMAKE_SOURCE_DIR: Optional flag to +# include ${CMAKE_SOURCE_DIR} in include directories +# +function(add_xe3_kernel_library LIBRARY_NAME) + cmake_parse_arguments( + PARSE_ARGV 1 ARG "INCLUDE_CMAKE_SOURCE_DIR" # Boolean options + "DESTINATION" # Single value keywords + "" # Multi-value keywords + ) + + # Set default destination if not provided + if(NOT ARG_DESTINATION) + set(ARG_DESTINATION "vllm_xpu_kernels") + endif() + + # Set C++ standard + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + + # Find all source files + file(GLOB_RECURSE KERNEL_SOURCES "*.cpp" ${ATTN_KERNEL_SRCS_GEN}) + + # Create shared library + add_library(${LIBRARY_NAME} SHARED ${KERNEL_SOURCES}) + + # Set include directories + target_include_directories( + ${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/..) 
+ + # Optionally add CMAKE_SOURCE_DIR + if(ARG_INCLUDE_CMAKE_SOURCE_DIR) + target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_SOURCE_DIR}) + endif() + + # Set compile options and definitions + target_compile_options( + ${LIBRARY_NAME} + PRIVATE ${SYCL_TLA_KERNELS_COMPILE_FLAGS} -fPIC -Wno-c++20-extensions + -Wno-intel-compat -Wno-pragma-once-outside-header) + target_compile_definitions(${LIBRARY_NAME} PRIVATE -DSYCL_INTEL_TARGET=35) + target_compile_definitions(${LIBRARY_NAME} PRIVATE -DVLLM_GRF_SIZE=512) + target_include_directories(${LIBRARY_NAME} PRIVATE ${SYCL_TLA_INCLUDE_DIRS}) + + # Link torch libraries + target_link_libraries(${LIBRARY_NAME} PRIVATE torch) + target_link_libraries(${LIBRARY_NAME} PRIVATE ${TORCH_LIBRARIES}) + + message( + STATUS + "Setting library output directory for target '${LIBRARY_NAME}' to '${CMAKE_BINARY_DIR}/'.'" + ) + set_target_properties(${LIBRARY_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_BINARY_DIR}/") + install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${ARG_DESTINATION} + COMPONENT ${LIBRARY_NAME}) + + # Set link options for XE3 devices + set(XE3_GPU_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS}) + list( + APPEND XE3_GPU_LINK_FLAGS -Xsycl-target-backend=spir64_gen + "-device ${XE3_AOT_DEVICES} -internal_options -cl-intel-512-GRF-per-thread") + target_link_options(${LIBRARY_NAME} PRIVATE ${XE3_GPU_LINK_FLAGS}) +endfunction() + # # Create a static library for XE default kernels with common configuration. 
# @@ -655,6 +726,7 @@ function(add_xe_default_kernel_library LIBRARY_NAME) PRIVATE ${SYCL_TLA_KERNELS_COMPILE_FLAGS} -fPIC) target_compile_definitions(${LIBRARY_NAME} PRIVATE -DVLLM_XPU_ENABLE_XE_DEFAULT) + target_compile_definitions(${LIBRARY_NAME} PRIVATE -DSYCL_INTEL_TARGET=20) target_include_directories(${LIBRARY_NAME} PRIVATE ${SYCL_TLA_INCLUDE_DIRS}) # Link torch libraries diff --git a/csrc/utils.h b/csrc/utils.h index 0b24dfba4..7d7976ee4 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -78,6 +78,13 @@ static inline bool is_xe3_arch(at::DeviceIndex device_index = -1) { arch == syclex::architecture::intel_gpu_wcl; } +#ifdef VLLM_XPU_ENABLE_XE3 +static inline bool is_xe3p_arch(at::DeviceIndex device_index = -1) { + auto arch = get_device_architecture(device_index); + return arch == syclex::architecture::intel_gpu_nvl_s; +} +#endif + static inline std::optional getEnv(const char* name) { if (const char* val = std::getenv(name)) return val; return std::nullopt; diff --git a/csrc/xpu/attn/xe_3/CMakeLists.txt b/csrc/xpu/attn/xe_3/CMakeLists.txt new file mode 100644 index 000000000..5d3d8861a --- /dev/null +++ b/csrc/xpu/attn/xe_3/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.18) + +set(ATTN_KERNEL_SRCS_GEN) # output +include("chunk_prefill_configure.cmake") +fmha_forward_configure(chunk_prefill_kernel_template) + +include("paged_decode_configure.cmake") +paged_decode_configure(paged_decode_kernel_template) + +add_xe3_kernel_library(attn_kernels_xe_3 INCLUDE_CMAKE_SOURCE_DIR) diff --git a/csrc/xpu/attn/xe_3/chunk_prefill_configure.cmake b/csrc/xpu/attn/xe_3/chunk_prefill_configure.cmake new file mode 100644 index 000000000..99f4f8254 --- /dev/null +++ b/csrc/xpu/attn/xe_3/chunk_prefill_configure.cmake @@ -0,0 +1,54 @@ +function(fmha_forward_configure FILENAME_SUFFIX) + set(GEN_KERNEL_SRCS) # output + set(L_TYPES "fp16" "bf16") + set(L_BOOLS "false" "true") + set(BOOL_FLAG_false "f") + set(BOOL_FLAG_true "t") + set(policy_list + 
"chunk_policy_head64" "chunk_policy_head96" "chunk_policy_head128" + "chunk_policy_head192" "chunk_policy_head256") + + set(IMPL_KV_T "fp16") + + foreach(IMPL_POLICY ${policy_list}) + # foreach(IMPL_T ${L_TYPES}) + foreach(IMPL_KISPAGED ${L_BOOLS}) + foreach(IMPL_KISCAUSAL ${L_BOOLS}) + foreach(IMPL_KISLOCAL ${L_BOOLS}) + foreach(IMPL_KISSINK ${L_BOOLS}) + set(FILE_SUFFIX "${IMPL_POLICY}_") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}") + configure_file(${FILENAME_SUFFIX}.cpp.in + "${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp") + list( + APPEND + GEN_KERNEL_SRCS + "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp" + ) + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + + list(REMOVE_DUPLICATES GEN_KERNEL_SRCS) + list(LENGTH GEN_KERNEL_SRCS GEN_KERNEL_SRCS_LENGTH) + message( + STATUS + "Generated ${FILENAME_SUFFIX} kernel sources: ${GEN_KERNEL_SRCS_LENGTH}") + set(GEN_KERNEL_SRCS + ${GEN_KERNEL_SRCS} + PARENT_SCOPE) + set(GEN_KERNEL_SRCS_LENGTH + ${GEN_KERNEL_SRCS_LENGTH} + PARENT_SCOPE) + + list(APPEND ATTN_KERNEL_SRCS_GEN ${GEN_KERNEL_SRCS}) + set(ATTN_KERNEL_SRCS_GEN + ${ATTN_KERNEL_SRCS_GEN} + PARENT_SCOPE) + +endfunction() diff --git a/csrc/xpu/attn/xe_3/chunk_prefill_kernel_template.cpp.in b/csrc/xpu/attn/xe_3/chunk_prefill_kernel_template.cpp.in new file mode 100644 index 000000000..25b292af3 --- /dev/null +++ b/csrc/xpu/attn/xe_3/chunk_prefill_kernel_template.cpp.in @@ -0,0 +1,27 @@ +#include "csrc/xpu/attn/xe_2/chunk_prefill.hpp" + +using namespace cute; + +// clang-format off +// macros to be filled in CMake +#define IMPL_T ${IMPL_T} +#define IMPL_KV_T ${IMPL_KV_T} +#define IMPL_POLICY ${IMPL_POLICY} +#cmakedefine01 IMPL_KISPAGED +#cmakedefine01 IMPL_KISCAUSAL +#cmakedefine01 IMPL_KISSINK +#cmakedefine01 
IMPL_KISLOCAL +// clang-format on + +#define INSTANTIATE_KERNEL() \ + template void policy_dispatch_impl< \ + IMPL_POLICY, \ + static_cast(IMPL_KISPAGED), \ + static_cast(IMPL_KISCAUSAL), \ + static_cast(IMPL_KISLOCAL), \ + static_cast(IMPL_KISSINK)>( \ + sycl::queue & queue, \ + CutlassQKType& cuQKType, \ + const chunk_prefill_args_t& args); + +INSTANTIATE_KERNEL() diff --git a/csrc/xpu/attn/xe_3/fmha.h b/csrc/xpu/attn/xe_3/fmha.h new file mode 100644 index 000000000..4ab391767 --- /dev/null +++ b/csrc/xpu/attn/xe_3/fmha.h @@ -0,0 +1,24 @@ +#include + +void cutlass_chunk_prefill_xe3( + sycl::queue& queue, + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + std::optional& k_scale, + std::optional& v_scale, + double sm_scale, + std::optional& sm_sink_, + int window_size_left, + int window_size_right, + bool is_varlen, + bool is_paged, + bool is_causal, + bool is_local, + bool is_sink); \ No newline at end of file diff --git a/csrc/xpu/attn/xe_3/fmha_xe3.cpp b/csrc/xpu/attn/xe_3/fmha_xe3.cpp new file mode 100644 index 000000000..2181e64d5 --- /dev/null +++ b/csrc/xpu/attn/xe_3/fmha_xe3.cpp @@ -0,0 +1,181 @@ +#include "fmha.h" +// FIXME: reuse chunk_prefill from xe2 now +#include "csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp" +#include "csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp" + +void cutlass_chunk_prefill_xe3( + sycl::queue& queue, + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + std::optional& k_scale, +
std::optional& v_scale, + double sm_scale, + std::optional& sm_sink_, + int window_size_left, + int window_size_right, + bool is_varlen, + bool is_paged, + bool is_causal, + bool is_local, + bool is_sink) { + cutlass_chunk_prefill_impl( + queue, + query, + key_cache, + value_cache, + out, + block_table, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + k_scale, + v_scale, + sm_scale, + sm_sink_, + window_size_left, + window_size_right, + is_varlen, + is_paged, + is_causal, + is_local, + is_sink); +} + +void cutlass_chunk_prefill_impl( + sycl::queue& queue, + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + std::optional& k_scale, + std::optional& v_scale, + double sm_scale, + std::optional& sm_sink_, + int window_size_left, + int window_size_right, + bool is_varlen, + bool is_paged, + bool is_causal, + bool is_local, + bool is_sink) { + // general params + int batch_size, num_heads_q, num_heads_kv, head_size; + // additional params (zero-init: only some layout branches assign them) + int total_seqlen_q = 0, total_seqlen_k = 0; + int num_blocks = 0, block_size = 0, max_blocks_per_seq = 0; + if (is_varlen) { + // query: [total_seq, num_heads, head_size] + batch_size = cu_seqlens_q.numel() - 1; + num_heads_q = query.size(1); + num_heads_kv = key_cache.size(1); + head_size = query.size(2); + total_seqlen_q = query.size(0); + total_seqlen_k = key_cache.size(0); + } else { + // query: [batch, num_heads, seq, head_size] + batch_size = query.size(0); + num_heads_q = query.size(1); + num_heads_kv = key_cache.size(1); + head_size = query.size(3); + max_seqlen_q = query.size(2); + max_seqlen_k = key_cache.size(2); + } + if (is_paged) { + num_blocks = key_cache.size(0); + block_size = key_cache.size(1); + num_heads_kv = key_cache.size(2); + max_blocks_per_seq
= block_table.size(1); + total_seqlen_k = num_blocks * block_size; + } + + if (is_local) { + window_size_left = window_size_left == -1 ? max_seqlen_k : window_size_left; + window_size_right = + window_size_right == -1 ? max_seqlen_k : window_size_right; + if (is_causal) { + window_size_right = 0; + is_causal = false; + } + } + + bool is_fp8_kv = + (key_cache.scalar_type() == at::ScalarType::Float8_e5m2 || + key_cache.scalar_type() == at::ScalarType::Float8_e4m3fn); + // k_scale/v_scale are dereferenced below whenever the KV cache is FP8. + if (is_fp8_kv) { + TORCH_CHECK( + k_scale.has_value() && v_scale.has_value(), + "FP8 KV cache requires both k_scale and v_scale tensors to be " + "provided."); + } + + chunk_prefill_args_t args = { + query.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + out.data_ptr(), + is_paged ? block_table.data_ptr() : nullptr, + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + is_fp8_kv ? k_scale.value().data_ptr() : nullptr, + is_fp8_kv ? v_scale.value().data_ptr() : nullptr, + static_cast(sm_scale), + is_sink ? sm_sink_.value().data_ptr() : nullptr, + batch_size, + num_heads_q, + num_heads_kv, + head_size, + max_blocks_per_seq, + block_size, + window_size_left, + window_size_right, + is_varlen, // varlen + is_paged, // paged + is_causal, + is_local, + is_sink}; + + CutlassQKType cuQKType = aten_to_Cutlass_qk_dtype(query, key_cache); + + static constexpr int max_head_size = 256; + TORCH_CHECK( + head_size <= max_head_size, + "FMHA forward only supports head dimension at most " + + std::to_string(max_head_size)); + + if (args.head_size <= HEAD_SIZE_LIMIT_0) { + policy_dispatch_func( + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + } else if (args.head_size <= HEAD_SIZE_LIMIT_1) { + policy_dispatch_func( + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + } else if (args.head_size <= HEAD_SIZE_LIMIT_2) { + policy_dispatch_func( + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + } else if (args.head_size <= HEAD_SIZE_LIMIT_3) { + policy_dispatch_func( + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + } else if
(args.head_size <= HEAD_SIZE_LIMIT_4) { + policy_dispatch_func( + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + } else { + TORCH_CHECK(false, "Unsupported head size for fmha"); + } +} diff --git a/csrc/xpu/attn/xe_3/paged_decode_configure.cmake b/csrc/xpu/attn/xe_3/paged_decode_configure.cmake new file mode 100644 index 000000000..2131afd19 --- /dev/null +++ b/csrc/xpu/attn/xe_3/paged_decode_configure.cmake @@ -0,0 +1,134 @@ +# ============================================================================= +# Paged Decode Kernel Configuration +# ============================================================================= +# This function generates kernel source files for all combinations of: - Policy +# types (q_group_size × head_size) - Boolean flags (Causal, Local, Sink) +# +# Each generated file instantiates one specific kernel configuration to enable +# parallel compilation and reduce individual object file sizes. +# +# Usage: paged_decode_configure(paged_decode_kernel_template) +# +# Parameters: FILENAME_SUFFIX - Base name for generated .cpp files (without +# extension) +# +# Output: GEN_KERNEL_SRCS - List of generated source file paths +# GEN_KERNEL_SRCS_LENGTH - Number of generated files ATTN_KERNEL_SRCS_GEN - +# Updated global list with appended sources +# ============================================================================= + +function(paged_decode_configure FILENAME_SUFFIX) + set(GEN_KERNEL_SRCS) # Initialize output list + + # Boolean flag values and their single-character representations + set(L_BOOLS "false" "true") + set(BOOL_FLAG_false "f") + set(BOOL_FLAG_true "t") + + # ============================================================================= + # Policy Configuration Mapping + # ============================================================================= + # Maps (q_group_size, head_size) pairs to policy type names These must match + # the policies defined in paged_decode_policy.hpp + + # Q-group size 8 policies +
set(policy_8_64_64 "decode_policy_q8_h64_p64") + set(policy_8_96_64 "decode_policy_q8_h96_p64") + set(policy_8_128_64 "decode_policy_q8_h128_p64") + set(policy_8_192_64 "decode_policy_q8_h192_p64") + set(policy_8_256_64 "decode_policy_q8_h256_p64") + + set(policy_8_64_128 "decode_policy_q8_h64_p128") + set(policy_8_96_128 "decode_policy_q8_h96_p128") + set(policy_8_128_128 "decode_policy_q8_h128_p128") + set(policy_8_192_128 "decode_policy_q8_h192_p128") + set(policy_8_256_128 "decode_policy_q8_h256_p128") + + # Q-group size 16 policies + set(policy_16_64_64 "decode_policy_q16_h64_p64") + set(policy_16_96_64 "decode_policy_q16_h96_p64") + set(policy_16_128_64 "decode_policy_q16_h128_p64") + set(policy_16_192_64 "decode_policy_q16_h192_p64") + set(policy_16_256_64 "decode_policy_q16_h256_p64") + + set(policy_16_64_128 "decode_policy_q16_h64_p128") + set(policy_16_96_128 "decode_policy_q16_h96_p128") + set(policy_16_128_128 "decode_policy_q16_h128_p128") + set(policy_16_192_128 "decode_policy_q16_h192_p128") + set(policy_16_256_128 "decode_policy_q16_h256_p128") + + # Configuration space dimensions + set(qgroup_list "8" "16") + set(headsize_list "64" "96" "128" "192" "256") + set(pagesize_list "64" "128") + + # ============================================================================= + # Generate Kernel Sources + # ============================================================================= + # Iterate over all combinations: policy × causal × local × sink + + foreach(IMPL_QGROUP ${qgroup_list}) + foreach(IMPL_HEADSIZE ${headsize_list}) + foreach(IMPL_PAGESIZE ${pagesize_list}) + # Lookup policy name from mapping + set(IMPL_POLICY + ${policy_${IMPL_QGROUP}_${IMPL_HEADSIZE}_${IMPL_PAGESIZE}}) + + foreach(IMPL_KISCAUSAL ${L_BOOLS}) + foreach(IMPL_KISLOCAL ${L_BOOLS}) + foreach(IMPL_KISSINK ${L_BOOLS}) + # Construct unique filename suffix: e.g., _q8_h64_fff + set(FILE_SUFFIX + "_q${IMPL_QGROUP}_h${IMPL_HEADSIZE}_p${IMPL_PAGESIZE}_") + set(FILE_SUFFIX 
"${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}") + + # Generate .cpp file from template + configure_file(${FILENAME_SUFFIX}.cpp.in + "${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp") + + # Add to output list + list( + APPEND + GEN_KERNEL_SRCS + "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp" + ) + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + + # ============================================================================= + # Output Results + # ============================================================================= + + list(REMOVE_DUPLICATES GEN_KERNEL_SRCS) + list(LENGTH GEN_KERNEL_SRCS GEN_KERNEL_SRCS_LENGTH) + + message( + STATUS + "Generated ${FILENAME_SUFFIX} sources: ${GEN_KERNEL_SRCS_LENGTH} files") + + # Export to parent scope + set(GEN_KERNEL_SRCS + ${GEN_KERNEL_SRCS} + PARENT_SCOPE) + set(GEN_KERNEL_SRCS_LENGTH + ${GEN_KERNEL_SRCS_LENGTH} + PARENT_SCOPE) + + # Update global kernel source list + list(APPEND ATTN_KERNEL_SRCS_GEN ${GEN_KERNEL_SRCS}) + set(ATTN_KERNEL_SRCS_GEN + ${ATTN_KERNEL_SRCS_GEN} + PARENT_SCOPE) + + message( + STATUS + "Total ATTN kernel sources after ${FILENAME_SUFFIX}: ${ATTN_KERNEL_SRCS_GEN}" + ) + +endfunction() diff --git a/csrc/xpu/attn/xe_3/paged_decode_kernel_template.cpp.in b/csrc/xpu/attn/xe_3/paged_decode_kernel_template.cpp.in new file mode 100644 index 000000000..49193485a --- /dev/null +++ b/csrc/xpu/attn/xe_3/paged_decode_kernel_template.cpp.in @@ -0,0 +1,35 @@ +#include "csrc/xpu/attn/xe_2/paged_decode.hpp" + +using namespace cute; + +// ============================================================================= +// CMake Template Variables +// ============================================================================= +// These macros are populated by CMake during the configuration process +// to generate specific kernel 
instantiations for each policy combination. + +// clang-format off +#define IMPL_POLICY ${IMPL_POLICY} +#cmakedefine01 IMPL_KISCAUSAL +#cmakedefine01 IMPL_KISLOCAL +#cmakedefine01 IMPL_KISSINK +// clang-format on + +// ============================================================================= +// Explicit Template Instantiation +// ============================================================================= +// Instantiate the decode_policy_dispatch_impl function template with the +// specific policy and boolean flag combinations provided by CMake. This +// produces one compiled kernel per source file. + +#define INSTANTIATE_KERNEL() \ + template void decode_policy_dispatch_impl< \ + IMPL_POLICY, \ + static_cast(IMPL_KISCAUSAL), \ + static_cast(IMPL_KISLOCAL), \ + static_cast(IMPL_KISSINK)>( \ + sycl::queue & queue, \ + CutlassQKType & cuQKType, \ + const paged_decode_args_t& args); + +INSTANTIATE_KERNEL() diff --git a/csrc/xpu/attn/xe_3/paged_decode_xe3.cpp b/csrc/xpu/attn/xe_3/paged_decode_xe3.cpp new file mode 100644 index 000000000..395d9f3fd --- /dev/null +++ b/csrc/xpu/attn/xe_3/paged_decode_xe3.cpp @@ -0,0 +1,207 @@ +#include "paged_decode_xe3.h" +#include "csrc/xpu/attn/xe_2/paged_decode_utils.hpp" +#include "csrc/xpu/attn/xe_2/paged_decode_extern.hpp" + +using namespace cute; + +void cutlass_paged_decode_xe3( + sycl::queue& queue, + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + at::Tensor& + temp_out, // [batch, num_head_q, seq_q, head_size, num_kv_splits] + at::Tensor& exp_sums, // [batch, num_head_q, seq_q, num_kv_splits] + at::Tensor& max_logits, // [batch, num_head_q, seq_q, num_kv_splits] + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + std::optional& k_scale, + std::optional& v_scale, + double sm_scale, + 
std::optional& sm_sink_, + int window_size_left, + int window_size_right, + bool is_varlen, + bool is_paged, + bool is_causal, + bool is_local, + bool is_sink, + int num_kv_splits) { + cutlass_paged_decode_impl( + queue, + query, + key_cache, + value_cache, + out, + temp_out, + exp_sums, + max_logits, + block_table, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + k_scale, + v_scale, + sm_scale, + sm_sink_, + window_size_left, + window_size_right, + is_varlen, + is_paged, + is_causal, + is_local, + is_sink, + num_kv_splits); +} + +void cutlass_paged_decode_impl( + sycl::queue& queue, + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + at::Tensor& + temp_out, // [batch, num_head_q, seq_q, head_size, num_kv_splits] + at::Tensor& exp_sums, // [batch, num_head_q, seq_q, num_kv_splits] + at::Tensor& max_logits, // [batch, num_head_q, seq_q, num_kv_splits] + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + std::optional& k_scale, + std::optional& v_scale, + double sm_scale, + std::optional& sm_sink_, + int window_size_left, + int window_size_right, + bool is_varlen, + bool is_paged, + bool is_causal, + bool is_local, + bool is_sink, + int num_kv_splits) { + bool is_fp8_kv = key_cache.scalar_type() == at::ScalarType::Float8_e5m2 || + key_cache.scalar_type() == at::ScalarType::Float8_e4m3fn; + if (is_fp8_kv) { + TORCH_CHECK( + k_scale.has_value() && v_scale.has_value(), + "FP8 KV cache requires both k_scale and v_scale tensors to be " + "provided."); + } + // general params + int batch_size, num_heads_q, num_heads_kv, head_size, v_head_size; + // additional params (zero-init: only some layout branches assign them) + int total_seqlen_q = 0, total_seqlen_k = 0; + int num_blocks = 0, block_size = 0, max_blocks_per_seq = 0; + if (is_varlen) { + // query: [total_seq, num_heads, head_size] + batch_size =
cu_seqlens_q.numel() - 1; + num_heads_q = query.size(1); + num_heads_kv = key_cache.size(1); + head_size = query.size(2); + v_head_size = value_cache.size(-1); + total_seqlen_q = query.size(0); + total_seqlen_k = key_cache.size(0); + } else { + // query: [batch, num_heads, seq, head_size] + batch_size = query.size(0); + num_heads_q = query.size(1); + num_heads_kv = key_cache.size(1); + head_size = query.size(3); + v_head_size = value_cache.size(-1); + max_seqlen_q = query.size(2); + max_seqlen_k = key_cache.size(2); + } + if (is_paged) { + // num_blocks is used to build total_seqlen_k for shape_K in kernels + // it is not just the meaning of used blocks for kv. + num_blocks = key_cache.size(0); + block_size = key_cache.size(1); + num_heads_kv = key_cache.size(2); + max_blocks_per_seq = block_table.size(1); + total_seqlen_k = num_blocks * block_size; + } + + if (is_local) { + window_size_left = window_size_left == -1 ? max_seqlen_k : window_size_left; + window_size_right = + window_size_right == -1 ? max_seqlen_k : window_size_right; + } + + paged_decode_args_t args = { + query.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + out.data_ptr(), + temp_out.data_ptr(), + exp_sums.data_ptr(), + max_logits.data_ptr(), + block_table.data_ptr(), + cu_seqlens_q.data_ptr(), + cu_seqlens_k.data_ptr(), + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + is_fp8_kv ? k_scale.value().data_ptr() : nullptr, + is_fp8_kv ? v_scale.value().data_ptr() : nullptr, + static_cast(sm_scale), + is_sink ? 
sm_sink_.value().data_ptr() : nullptr, + batch_size, + num_heads_q, + num_heads_kv, + head_size, + v_head_size, + max_blocks_per_seq, + block_size, + window_size_left, + window_size_right, + is_varlen, // varlen + is_paged, // paged + is_causal, + is_local, + is_sink, + false, // is_interleaved_kv_cache + num_kv_splits, + // KV cache strides + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + value_cache.stride(0), + value_cache.stride(1), + value_cache.stride(2)}; + + CutlassQKType cuQKType = aten_to_Cutlass_qk_dtype(query, key_cache); + + static constexpr int max_head_size = 256; + TORCH_CHECK( + head_size <= max_head_size, + "FMHA forward only supports head dimension at most " + + std::to_string(max_head_size)); + + auto get_head_size_case = [](int head_size) -> int { + if (head_size <= HEAD_SIZE_LIMIT_0) return 0; + if (head_size <= HEAD_SIZE_LIMIT_1) return 1; + if (head_size <= HEAD_SIZE_LIMIT_2) return 2; + if (head_size <= HEAD_SIZE_LIMIT_3) return 3; + if (head_size <= HEAD_SIZE_LIMIT_4) return 4; + return -1; + }; + + int head_case = get_head_size_case(args.head_size); + int num_q_group_size = num_heads_q / num_heads_kv; + + if (num_q_group_size <= 8) { + dispatch_by_page_size<_8>(block_size, head_case, queue, cuQKType, args); + } else if (num_q_group_size <= 16) { + dispatch_by_page_size<_16>(block_size, head_case, queue, cuQKType, args); + } else { + TORCH_CHECK(false, "Unsupported num_heads_q / num_heads_kv for fmha"); + } +} diff --git a/csrc/xpu/attn/xe_3/paged_decode_xe3.h b/csrc/xpu/attn/xe_3/paged_decode_xe3.h new file mode 100644 index 000000000..01ef3aab8 --- /dev/null +++ b/csrc/xpu/attn/xe_3/paged_decode_xe3.h @@ -0,0 +1,29 @@ +#include + +void cutlass_paged_decode_xe3( + sycl::queue& queue, + const at::Tensor& query, // [seq_q, heads, head_size] + const at::Tensor& key_cache, // [num_block, block_size, heads, head_size] + const at::Tensor& value_cache, + at::Tensor& out, + at::Tensor& + temp_out, // [batch, num_head_q, 
seq_q, head_size, num_kv_splits] + at::Tensor& exp_sums, // [batch, num_head_q, seq_q, num_kv_splits] + at::Tensor& max_logits, // [batch, num_head_q, seq_q, num_kv_splits] + const at::Tensor& block_table, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + int max_seqlen_q, + int max_seqlen_k, + std::optional& k_scale, + std::optional& v_scale, + double sm_scale, + std::optional& sm_sink_, + int window_size_left, + int window_size_right, + bool is_varlen, + bool is_paged, + bool is_causal, + bool is_local, + bool is_sink, + int num_kv_splits); diff --git a/pyproject.toml b/pyproject.toml index b70790b99..23b478017 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.11.0+xpu", + "torch @ https://download-r2.pytorch.org/whl/nightly/xpu/torch-2.13.0.dev20260522%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl", "wheel", "regex", "jinja2", diff --git a/requirements.txt b/requirements.txt index f96f2b3bb..8994b853a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,9 +14,9 @@ ninja psutil # torch dependency ---extra-index-url=https://download.pytorch.org/whl/xpu -torch==2.11.0+xpu -triton-xpu +--extra-index-url https://download.pytorch.org/whl/nightly/xpu +torch @ https://download-r2.pytorch.org/whl/nightly/xpu/torch-2.13.0.dev20260522%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl +triton-xpu @ https://download-r2.pytorch.org/whl/nightly/triton_xpu-3.7.1%2Bgit21033c4e-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl # may need oneapi packages # tests