Skip to content

Commit 4d2f8c1

Browse files
jiangyon-amd, poyenc, and asleepzzz
authored
[CK_TILE][FMHA] Add sparse attention VSA (#3341)
* add sparse attention VSA * fix the pre-commit * Add jenga test and pre-commit * add bf16 for vsa * add jenga support bf16 * remove lse arg * split kernel code to block & kernel * fix the pre-commit * fix the pre-commit * fix the copyrights * fix the copyright * fix the copyright & rename block to pipeline * fix the copyright and pipeline * remove lse & dropout & add fmt * fix the jenga&VSA code review * remove the useless code & resolved the comments * remove useless code * remove useless code * Clean up code * Remove more unused code * Re-format .hpp * Refactor codegen scripts --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
1 parent 2377a62 commit 4d2f8c1

22 files changed

Lines changed: 6058 additions & 0 deletions
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
3+
# CMakeLists.txt for sparse attention (Jenga and VSA)
4+
5+
# Seed both target lists from SUPPORTED_GPU_TARGETS. The expansions are left
# unquoted on purpose: an empty SUPPORTED_GPU_TARGETS leaves the variables
# unset rather than defined-but-empty.
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})

# Logged before filtering, so at this point INST_TARGETS still mirrors
# SUPPORTED_GPU_TARGETS.
message(STATUS "Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}")

# Keep only the architectures whose name contains "gfx9" or "gfx12"
# (the regex is unanchored, so e.g. gfx90a and gfx1200 both match).
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")

# Nothing left to build for: warn and stop processing this CMakeLists.txt.
if(NOT INST_TARGETS)
    message(WARNING "Skipping Tile Engine Sparse Attention: No supported GPU targets found")
    return()
endif()

message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")
18+
19+
# Code generation scripts
#
# Collect generate.py and every Python file under codegen/. CONFIGURE_DEPENDS
# makes the build system re-run the glob so that added or removed scripts are
# noticed.
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
    "${CMAKE_CURRENT_LIST_DIR}/generate.py"
    "${CMAKE_CURRENT_LIST_DIR}/codegen/*.py"
)

# The kernel lists below are produced by running generate.py at configure
# time, so editing any of these scripts must trigger a re-configure. APPEND is
# used so that entries already registered on this directory are preserved
# instead of being overwritten.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${CODE_GEN_SCRIPTS})
25+
26+
# ============================================================================
# Jenga Sparse Attention
# ============================================================================
# Arguments shared by both generate.py invocations below (list + generate).
# NOTE(review): the meaning of receipt 600 is defined by generate.py, which is
# not visible from this file.
set(SPARSE_ATTN_JENGA_CODE_GEN_ARGS
    "${CMAKE_CURRENT_LIST_DIR}/generate.py"
    --api fwd_jenga
    --receipt 600
)

# Generate list of Jenga kernels (at configure time, only list)
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS}
            --list_blobs "${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt"
    RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
    message(FATAL_ERROR "Failed to generate Jenga kernel list (result: ${ret})")
endif()

# One generated source path per line.
file(STRINGS "${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt" SPARSE_ATTN_JENGA_GEN_BLOBS)

# An empty list would otherwise surface as an obscure add_custom_command()
# syntax error (OUTPUT with no files), so fail early with a clear message.
if("${SPARSE_ATTN_JENGA_GEN_BLOBS}" STREQUAL "")
    message(FATAL_ERROR
        "Jenga kernel list is empty: ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt")
endif()

# Generate Jenga kernel source files at build time.
# VERBATIM keeps argument escaping consistent across platforms and generators
# (e.g. for build directories containing spaces).
add_custom_command(
    OUTPUT ${SPARSE_ATTN_JENGA_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS}
            --output_dir "${CMAKE_CURRENT_BINARY_DIR}"
    DEPENDS ${CODE_GEN_SCRIPTS}
    COMMENT "Generate CK Tile Jenga Sparse Attention kernels"
    VERBATIM
)

message(STATUS "Jenga kernel files to be generated: ${SPARSE_ATTN_JENGA_GEN_BLOBS}")

# Jenga Instances
# Object library holding the generated kernel instances plus the hand-written
# dispatch source; it is only built when something depends on it.
set(SPARSE_ATTN_JENGA_INSTANCES "tile_sparse_attn_jenga_instances")

add_library(${SPARSE_ATTN_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
    ${SPARSE_ATTN_JENGA_GEN_BLOBS}
    "${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp"
)
target_include_directories(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE
    "${CMAKE_CURRENT_LIST_DIR}"
    "${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn"
)

# All sources in this library are compiled as HIP for the filtered targets.
set_source_files_properties(${SPARSE_ATTN_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_source_files_properties("${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp" PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARSE_ATTN_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

target_compile_options(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE
    -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
    -DCK_TILE_FMHA_FWD_FAST_EXP2
    -Wno-undefined-func-template
    -Wno-float-equal
)

# Jenga Example executable
set(EXAMPLE_JENGA_SPARSE_ATTN "tile_example_jenga_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}")
add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp)
# Explicit PRIVATE: the keyword-less signature has legacy semantics, and an
# executable has no consumers to propagate usage requirements to.
target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${SPARSE_ATTN_JENGA_INSTANCES})
target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE "${CMAKE_CURRENT_LIST_DIR}")
target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
    -Wno-undefined-func-template
    -Wno-float-equal
)
90+
91+
# ============================================================================
# VSA Sparse Attention
# ============================================================================
# Arguments shared by both generate.py invocations below (list + generate).
# NOTE(review): the meaning of receipt 600 is defined by generate.py, which is
# not visible from this file.
set(SPARSE_ATTN_VSA_CODE_GEN_ARGS
    "${CMAKE_CURRENT_LIST_DIR}/generate.py"
    --api fwd_vsa
    --receipt 600
)

# Generate list of VSA kernels (at configure time, only list)
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS}
            --list_blobs "${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt"
    RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
    message(FATAL_ERROR "Failed to generate VSA kernel list (result: ${ret})")
endif()

# One generated source path per line.
file(STRINGS "${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt" SPARSE_ATTN_VSA_GEN_BLOBS)

# An empty list would otherwise surface as an obscure add_custom_command()
# syntax error (OUTPUT with no files), so fail early with a clear message.
if("${SPARSE_ATTN_VSA_GEN_BLOBS}" STREQUAL "")
    message(FATAL_ERROR
        "VSA kernel list is empty: ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt")
endif()

# Generate VSA kernel source files at build time.
# VERBATIM keeps argument escaping consistent across platforms and generators
# (e.g. for build directories containing spaces).
add_custom_command(
    OUTPUT ${SPARSE_ATTN_VSA_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS}
            --output_dir "${CMAKE_CURRENT_BINARY_DIR}"
    DEPENDS ${CODE_GEN_SCRIPTS}
    COMMENT "Generate CK Tile VSA Sparse Attention kernels"
    VERBATIM
)

message(STATUS "VSA kernel files to be generated: ${SPARSE_ATTN_VSA_GEN_BLOBS}")

# VSA Instances
# Object library holding the generated kernel instances plus the hand-written
# dispatch source; it is only built when something depends on it.
set(SPARSE_ATTN_VSA_INSTANCES "tile_sparse_attn_vsa_instances")

add_library(${SPARSE_ATTN_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
    ${SPARSE_ATTN_VSA_GEN_BLOBS}
    "${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp"
)
target_include_directories(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE
    "${CMAKE_CURRENT_LIST_DIR}"
    "${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn"
)

# All sources in this library are compiled as HIP for the filtered targets.
set_source_files_properties(${SPARSE_ATTN_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_source_files_properties("${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp" PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARSE_ATTN_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE
    -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
    -DCK_TILE_FMHA_FWD_FAST_EXP2
    -Wno-undefined-func-template
    -Wno-float-equal
)

# VSA Example executable
set(EXAMPLE_VSA_SPARSE_ATTN "tile_example_vsa_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}")
add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp)
# Explicit PRIVATE: the keyword-less signature has legacy semantics, and an
# executable has no consumers to propagate usage requirements to.
target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${SPARSE_ATTN_VSA_INSTANCES})
target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE "${CMAKE_CURRENT_LIST_DIR}")
target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
    -Wno-undefined-func-template
    -Wno-float-equal
)
155+
156+
# Turn off the per-rule progress messages emitted by Makefile generators.
# NOTE(review): this is a GLOBAL property, so it applies to the entire build
# tree and not only to this directory. Presumably it is here to keep build
# output manageable given the number of generated kernel sources -- confirm
# this project-wide effect is intended.
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
3+
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
3+
# generate kernel instances to speed up compilation
4+
5+
# Data-type key -> name of the C++ type emitted into generated code.
FWD_DTYPE_MAP = {
    "fp16": "FmhaSparseFwdFp16",
    "bf16": "FmhaSparseFwdBf16",
}

# Mask key -> C++ mask type, for the "simplified" mask implementation.
_MASK_SIMPLIFIED_MAP = {
    "s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
    "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}

# Mask key -> C++ mask type, for the "generic" mask implementation.
_MASK_MAP = {
    "no": "FmhaMasks::NoMask",
    "causal": "FmhaMasks::CausalMask",
    "generic": "FmhaMasks::GenericMask",
}


def get_mask_map(mask: str):
    """Return the mask-key -> C++ mask type table for a mask implementation.

    Args:
        mask: Either ``"generic"`` or ``"simplified"``.

    Returns:
        The module-level dict for the requested implementation (the same
        object on every call, not a copy).

    Raises:
        ValueError: If ``mask`` is not one of the supported names. An explicit
            exception is used instead of ``assert False`` so the check is not
            stripped under ``python -O``, which would silently return ``None``.
    """
    if mask == "generic":
        return _MASK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_MAP
    raise ValueError(
        f"unsupported mask implementation {mask!r}; expected 'generic' or 'simplified'"
    )
30+
31+
32+
# Mask key -> C++ boolean expression testing ``t.mask_type``, for the
# "generic" mask implementation.
_MASK_CHECK_MAP = {
    "no": "t.mask_type == mask_enum::no_mask",
    "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
    "generic": "t.mask_type == mask_enum::window_generic",
}

# Mask key -> C++ boolean expression testing ``t.mask_type``, for the
# "simplified" mask implementation.
_MASK_SIMPLIFIED_CHECK_MAP = {
    "s_no": "t.mask_type == mask_enum::no_mask",
    "s_mask": "t.mask_type != mask_enum::no_mask",
}


def get_mask_check_map(mask: str):
    """Return the mask-key -> C++ check expression table for a mask implementation.

    Args:
        mask: Either ``"generic"`` or ``"simplified"``.

    Returns:
        The module-level dict for the requested implementation (the same
        object on every call, not a copy).

    Raises:
        ValueError: If ``mask`` is not one of the supported names. An explicit
            exception is used instead of ``assert False`` so the check is not
            stripped under ``python -O``, which would silently return ``None``.
    """
    if mask == "generic":
        return _MASK_CHECK_MAP
    if mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    raise ValueError(
        f"unsupported mask implementation {mask!r}; expected 'generic' or 'simplified'"
    )
52+
53+
54+
# Mode key -> C++ boolean literal; "batch" is the only mode listed here.
MODE_MAP = dict(batch="false")

# Layout key -> C++ boolean literal.
LAYOUT_MAP = dict(row="true", col="false")

# Pipeline key -> C++ block pipeline class name.
PIPELINE_MAP = dict(
    qr_async="ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
    qr_async_vsa="ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
)

# Pipeline key -> C++ pipeline enum value. Both keys map to the same
# QRKSVS_ASYNC enumerator.
PIPELINE_ENUM_MAP = dict(
    qr_async="ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
    qr_async_vsa="ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
)

# Accepts either the short string flags ("t"/"f") or Python bools and yields
# the matching C++ boolean literal.
BOOL_MAP = dict(
    (
        ("t", "true"),
        ("f", "false"),
        (True, "true"),
        (False, "false"),
    )
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
3+

0 commit comments

Comments
 (0)