Skip to content

Commit de0a61e

Browse files
ltqinassistant-librarian[bot]
authored and committed
[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)
[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware. ### Quantization design | Tensor | Supported data types | Scale granularity options | |--------|---------------------|--------------------------| | Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) | | K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) | | V | fp8 | per-channel (always) | | O | bf16 | — | Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16). ### Architecture support - **gfx9** (CDNA2/3, e.g. 
gfx90a, gfx942) — full tile set - **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes) ### Implementation - Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy) - Masking support: no mask, causal (top-left / bottom-right), and generic windowed - Batch and group (variable-length) modes - Head dimension: d=128, d_v=128 - Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination - Smoke tests included via `tile_example_sageattn_fwd` ### Test commands \`\`\`bash # fp8 QKV ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3 # int8 QK, fp8 V ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3 \`\`\` \`-qscale\` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
1 parent e8d64ad commit de0a61e

30 files changed

Lines changed: 7809 additions & 0 deletions
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 arch is supported
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
    message(WARNING "Skipping SageAttention compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
    return()
endif()

# ====================================================================
# SageAttention codegen - only FWD API, minimal instances
# ====================================================================
# Collect every codegen script so configure re-runs when any of them changes.
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")

# Quote the glue character so the comma is unambiguously a single argument.
list(JOIN INST_TARGETS "," SAGEATTN_TARGETS_ARG)

# Only generate FWD API, only supported head dimension (128)
# Note: Only d=128, d_v=128 has kernel tile definitions in sageattn_fwd.py
set(SAGEATTN_FWD_CODE_GEN_COMMON_ARGS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --targets ${SAGEATTN_TARGETS_ARG}
    --api fwd
    --optdim 128
)

# Generate list of kernels to build (runs at configure time)
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
            --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt
    RESULT_VARIABLE ret
)
# ret is "0" on success; a non-zero code or an error string both fail EQUAL 0.
if(NOT ret EQUAL 0)
    message(FATAL_ERROR "SageAttention FAILED to generate kernel list via Python.")
endif()

file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sageattn_fwd_blob_list.txt SAGEATTN_FWD_GEN_BLOBS)

# Generate the kernel instance files (runs at build time)
add_custom_command(
    OUTPUT ${SAGEATTN_FWD_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${SAGEATTN_FWD_CODE_GEN_COMMON_ARGS}
            --output_dir ${CMAKE_CURRENT_BINARY_DIR}
    DEPENDS ${CODE_GEN_SCRIPTS}
    COMMENT "Generate SageAttention FWD kernels"
    VERBATIM
)

# Compile options shared by the instance library and the example
# (previously duplicated verbatim for both targets).
set(SAGEATTN_FWD_COMMON_COMPILE_OPTIONS
    -Wno-undefined-func-template
    -Wno-float-equal
    -fgpu-flush-denormals-to-zero
)
if(CK_USE_OCP_FP8)
    list(APPEND SAGEATTN_FWD_COMMON_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

# Build the kernel instances library
add_library(tile_sageattn_fwd_instances OBJECT EXCLUDE_FROM_ALL ${SAGEATTN_FWD_GEN_BLOBS})
target_include_directories(tile_sageattn_fwd_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(tile_sageattn_fwd_instances PRIVATE ${SAGEATTN_FWD_COMMON_COMPILE_OPTIONS})
set_property(TARGET tile_sageattn_fwd_instances PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
set_property(TARGET tile_sageattn_fwd_instances PROPERTY POSITION_INDEPENDENT_CODE ON)

# ====================================================================
# SageAttention FWD Example
# ====================================================================
set(EXAMPLE_SAGEATTN_FWD "tile_example_sageattn_fwd")

message(DEBUG "adding example ${EXAMPLE_SAGEATTN_FWD}")

add_executable(${EXAMPLE_SAGEATTN_FWD} EXCLUDE_FROM_ALL example_sageattn_fwd.cpp)
target_include_directories(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# Link with our own minimal instances library (INDEPENDENT from FMHA!)
# PRIVATE: the instances are an implementation detail of the example binary.
target_link_libraries(${EXAMPLE_SAGEATTN_FWD} PRIVATE tile_sageattn_fwd_instances)

target_compile_options(${EXAMPLE_SAGEATTN_FWD} PRIVATE ${SAGEATTN_FWD_COMMON_COMPILE_OPTIONS})
set_property(TARGET ${EXAMPLE_SAGEATTN_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
3+
4+
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional
6+
7+
8+
@dataclass(frozen=True)
class ArchTrait:
    """Describes one GPU architecture target for kernel code generation.

    Any field left as None is derived from ``name`` in ``__post_init__``,
    so callers usually only supply ``name`` (e.g. ``ArchTrait("gfx942")``)
    and may override individual derived strings when needed.
    """

    name: str
    # C preprocessor guard, e.g. 'defined(__gfx942__)'
    preprocessor_check: Optional[str] = None
    # C++ expression testing a runtime device-name prefix match
    device_name_check: Optional[str] = None
    # ck_tile arch tag type, e.g. 'ck_tile::gfx942_t'
    tag: Optional[str] = None
    # suffix appended to generated instance filenames, e.g. '_gfx942'
    filename_suffix: Optional[str] = None

    def __post_init__(self):
        # Frozen dataclass: derived defaults must be written via
        # object.__setattr__ since normal assignment raises FrozenInstanceError.
        if self.preprocessor_check is None:
            object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
        if self.device_name_check is None:
            object.__setattr__(
                self,
                "device_name_check",
                f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
            )
        if self.tag is None:
            object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
        if self.filename_suffix is None:
            object.__setattr__(self, "filename_suffix", f"_{self.name}")
29+
30+
31+
def get_factories_for_targets(
    targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
    """Build one factory per distinct architecture named in ``targets``.

    ``get_factory`` is called once per target; when several targets resolve
    to the same ``factory.arch.name``, the factory produced last wins.
    The result is ordered with longer (more specific) architecture names
    first, so specific matches are tried before generic ones.
    """
    by_arch = {}
    for tgt in targets:
        made = get_factory(tgt)
        by_arch[made.arch.name] = made
    # Place more specific architectures first
    return sorted(by_arch.values(), key=lambda f: len(f.arch.name), reverse=True)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
3+
# generate kernel instances to speed up compilation
# Empty string: with CMake-driven codegen the generated files are written
# into the current working folder — TODO confirm against generate.py usage.
GEN_DIR = ""  # in Cmake, have to generate files in same folder
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT
3+
# generate kernel instances to speed up compilation
4+
FWD_DTYPE_MAP = {
5+
"fp16": "SageAttentionFwdFp16",
6+
"bf16": "SageAttentionFwdBf16",
7+
"fp8bf16": "SageAttentionFwdFp8Bf16",
8+
"i8fp8bf16": "SageAttentionFwdI8Fp8Bf16",
9+
"i4fp8bf16": "SageAttentionFwdI4Fp8Bf16",
10+
}
11+
12+
_MASK_SIMPLIFIED_MAP = {
13+
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
14+
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
15+
}
16+
17+
_MASK_MAP = {
18+
"no": "SageAttnMasks::NoMask",
19+
"causal": "SageAttnMasks::CausalMask",
20+
"generic": "SageAttnMasks::GenericMask",
21+
}
22+
23+
24+
def get_mask_map(mask_impl: str):
25+
if mask_impl == "generic":
26+
return _MASK_MAP
27+
elif mask_impl == "simplified":
28+
return _MASK_SIMPLIFIED_MAP
29+
else:
30+
assert False
31+
return None
32+
33+
34+
def get_mask_impl(mask: str) -> str:
35+
return "simplified" if mask.startswith("s_") else "generic"
36+
37+
38+
def get_mask_cpp_type(mask: str) -> str:
39+
return get_mask_map(get_mask_impl(mask))[mask]
40+
41+
42+
_MASK_CHECK_MAP = {
43+
"no": "t.mask_type == mask_enum::no_mask",
44+
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
45+
"generic": "t.mask_type == mask_enum::window_generic",
46+
}
47+
48+
_MASK_SIMPLIFIED_CHECK_MAP = {
49+
"s_no": "t.mask_type == mask_enum::no_mask",
50+
"s_mask": "t.mask_type != mask_enum::no_mask",
51+
}
52+
53+
54+
def get_mask_check_map(mask: str):
55+
if mask == "generic":
56+
return _MASK_CHECK_MAP
57+
elif mask == "simplified":
58+
return _MASK_SIMPLIFIED_CHECK_MAP
59+
else:
60+
assert False
61+
return None
62+
63+
64+
def get_mask_cpp_check_expr(mask: str) -> str:
65+
return get_mask_check_map(get_mask_impl(mask))[mask]
66+
67+
68+
# CLI qscale granularity name -> C++ BlockSageAttentionQuantScaleEnum value.
QSCALE_MAP = {
    "no": "ck_tile::BlockSageAttentionQuantScaleEnum::NO_SCALE",
    "pertensor": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTENSOR",
    "blockscale": "ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE",
    "perwarp": "ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP",
    "perthread": "ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD",
}

# qscale granularity name -> C++ runtime enum expression used in dispatch checks.
QSCALE_CHECK_MAP = {
    "no": "quant_scale_enum::no_scale",
    "pertensor": "quant_scale_enum::pertensor",
    "blockscale": "quant_scale_enum::blockscale",
    "perwarp": "quant_scale_enum::perwarp",
    "perthread": "quant_scale_enum::perthread",
}

# Batch vs. group (variable-length) mode -> C++ boolean literal
# (presumably a group-mode template flag — confirm at the use site).
MODE_MAP = {"batch": "false", "group": "true"}

# Tensor layout -> C++ boolean literal; "row" maps to "true"
# (presumably a row-major flag — confirm at the use site).
LAYOUT_MAP = {"row": "true", "col": "false"}

# Pipeline short name -> C++ pipeline class template.
PIPELINE_MAP = {
    "qr": "ck_tile::BlockSageAttentionPipelineQRKSVS",
    "qr_async": "ck_tile::BlockSageAttentionPipelineQRKSVSAsync",
}

# Pipeline short name -> C++ pipeline enum value.
PIPELINE_ENUM_MAP = {
    "qr": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS",
    "qr_async": "ck_tile::BlockSageAttnPipelineEnum::QRKSVS_ASYNC",
}

# Accepts both single-letter strings ('t'/'f') and Python bools as keys.
BOOL_MAP = {
    "t": "true",
    "f": "false",
    True: "true",
    False: "false",
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
# SPDX-License-Identifier: MIT

0 commit comments

Comments
 (0)