# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Currently only the gfx9 architecture is supported: start from the requested
# GPU targets and keep only the entries whose name matches "gfx9".
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")

# No gfx9 target requested: warn and stop processing this file.
if(NOT INST_TARGETS)
  message(WARNING
    "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
  return()
endif()

# ====================================================================
# SageAttention codegen - only FWD API, minimal instances
# ====================================================================
# Collect the Python scripts that take part in code generation.
# CONFIGURE_DEPENDS makes the build re-check the glob so that added or removed
# scripts are noticed.
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
  ${CMAKE_CURRENT_LIST_DIR}/generate.py
  ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
# Re-run CMake whenever one of the scripts changes. APPEND preserves any
# configure-time dependencies already registered on this directory instead of
# overwriting them.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${CODE_GEN_SCRIPTS})

# The generator receives the GPU targets as one comma-separated argument.
list(JOIN INST_TARGETS "," SAGEATTN_TARGETS_ARG)

# Arguments shared by every invocation of the generator script.
# Only generate FWD API, only supported head dimension (128).
# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py
set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS
  ${CMAKE_CURRENT_LIST_DIR}/generate.py
  --targets ${SAGEATTN_TARGETS_ARG}
  --api fwd
  --optdim 128
)

# Generate list of kernels to build.
# Runs the generator at configure time; it writes the list of kernel files to
# sageattn_fwd_blob_list.txt in the current binary directory.
execute_process(
  COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
          --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt
  RESULT_VARIABLE ret
)
# RESULT_VARIABLE holds the exit code of the process, or a string describing
# the error when the process could not be run; anything other than 0 is fatal.
if(NOT ret EQUAL 0)
  message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python (result: ${ret}).")
endif()

# One generated file per line of the list file.
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt SAGEATTN_FWD_GEN_BLOBS)

# Fail early with a clear message instead of an obscure error from the
# commands that consume the list further down.
if("${SAGEATTN_FWD_GEN_BLOBS}" STREQUAL "")
  message(FATAL_ERROR
    "SageAttention kernel list is empty: ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt")
endif()

# Build-time rule that generates the kernel instance files listed in
# SAGEATTN_FWD_GEN_BLOBS; it depends on the codegen scripts, so it is re-run
# when any of them changes.
add_custom_command(
  OUTPUT ${SAGEATTN_FWD_GEN_BLOBS}
  COMMAND
    ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
    --output_dir ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS ${CODE_GEN_SCRIPTS}
  COMMENT "Generate SageAttention FWD kernels"
  VERBATIM
)

# Object library made of the generated kernel instance files. It is excluded
# from the default "all" target.
add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL
  ${SAGEATTN_FWD_GEN_BLOBS}
)
target_include_directories(tile_sageattn_fwd_instances
  PRIVATE ${CMAKE_CURRENT_LIST_DIR}
)

# Compile options for kernel instances.
set(SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS
  -Wno-undefined-func-template
  -Wno-float-equal
  -fgpu-flush-denormals-to-zero
)
# Pass the OCP FP8 define only when CK_USE_OCP_FP8 is enabled.
if(CK_USE_OCP_FP8)
  list(APPEND SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_sageattn_fwd_instances
  PRIVATE ${SAGEATTN_FWD_INSTANCE_COMPILE_OPTIONS}
)

# Compile for the filtered gfx9 targets only, as position-independent code.
set_target_properties(tile_sageattn_fwd_instances PROPERTIES
  HIP_ARCHITECTURES "${INST_TARGETS}"
  POSITION_INDEPENDENT_CODE ON
)

# ====================================================================
# SageAttention FWD Example
# ====================================================================
set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd")

message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}")

# Example executable; excluded from the default "all" target.
add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp)
target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# Link with our own minimal instances library (INDEPENDENT from FMHA!).
# PRIVATE: the explicit keyword signature replaces the legacy keyword-less form.
target_link_libraries(${EXAMPLE_SAGEATTN_FWD} PRIVATE tile_sageattn_fwd_instances)

# Compile options for the example driver.
set(SAGEATTN_FWD_COMPILE_OPTIONS
  -Wno-undefined-func-template
  -Wno-float-equal
  -fgpu-flush-denormals-to-zero
)
# Pass the OCP FP8 define only when CK_USE_OCP_FP8 is enabled.
if(CK_USE_OCP_FP8)
  list(APPEND SAGEATTN_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMPILE_OPTIONS})
# Compile for the filtered gfx9 targets only.
set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
0 commit comments