Skip to content

Commit 05292b3

Browse files
poyencillsilin
andauthored
[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)
* Let fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks out from get_fwd_blobs() * Remove duplicated code in factories and redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property as render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. * Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatch to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to reserve the dtype orders * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product(aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warning v3 kernel users * Fix wrong codegen logics * Remove unnecessary param * Fix format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
1 parent d1193e8 commit 05292b3

22 files changed

Lines changed: 897 additions & 1456 deletions

example/ck_tile/01_fmha/CMakeLists.txt

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
208208
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
209209
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
210210

211-
# add fmha_fwd_v3 example
212-
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
213-
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
214-
215-
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
216-
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
217-
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
218-
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
219-
)
220-
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
221-
fmha_fwd_v3.cpp
222-
${FMHA_FWD_V3_INSTANCES}
223-
)
224-
225-
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
226-
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
227-
-fgpu-flush-denormals-to-zero
228-
-Wno-undefined-func-template
229-
--save-temps
230-
)
231-
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
232-
233-
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
234-
if(HAS_DISABLE_PACKED_FP32)
235-
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
236-
-mllvm --amdgpu-disable-packed-fp32=1
237-
)
238-
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
239-
-DCK_TILE_DISABLE_PACKED_FP32=1
240-
)
241-
endif()
242-
243-
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
244-
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
245211
# TODO: we have to turn off this global prop, otherwise the progress bar generated
246212
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
247213
# however, this property may affect global

example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,24 @@
3030
}
3131

3232

33-
def get_mask_map(mask: str):
34-
if mask == "generic":
33+
def get_mask_map(mask_impl: str):
34+
if mask_impl == "generic":
3535
return _MASK_MAP
36-
elif mask == "simplified":
36+
elif mask_impl == "simplified":
3737
return _MASK_SIMPLIFIED_MAP
3838
else:
3939
assert False
4040
return None
4141

4242

43+
def get_mask_impl(mask: str) -> str:
44+
return "simplified" if mask.startswith("s_") else "generic"
45+
46+
47+
def get_mask_cpp_type(mask: str) -> str:
48+
return get_mask_map(get_mask_impl(mask))[mask]
49+
50+
4351
_MASK_CHECK_MAP = {
4452
"no": "t.mask_type == mask_enum::no_mask",
4553
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
@@ -62,6 +70,10 @@ def get_mask_check_map(mask: str):
6270
return None
6371

6472

73+
def get_mask_cpp_check_expr(mask: str) -> str:
74+
return get_mask_check_map(get_mask_impl(mask))[mask]
75+
76+
6577
QSCALE_MAP = {
6678
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
6779
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
@@ -122,6 +134,7 @@ def get_mask_check_map(mask: str):
122134
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
123135
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
124136
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
137+
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
125138
}
126139

127140
PIPELINE_ENUM_MAP = {
@@ -131,6 +144,7 @@ def get_mask_check_map(mask: str):
131144
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
132145
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
133146
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
147+
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
134148
}
135149

136150
BOOL_MAP = {

0 commit comments

Comments
 (0)