Skip to content

Commit 1c31519

Browse files
authored
[CK_TILE][FMHA] Add FP8 support for batch_prefill kernel (#3425)
* Add fp8bf16 support for batch_prefill
* Fix wrong scale_s re-compute logic in batch_prefill
* Fix wrong scale_s re-compute logic in fmha fwd
* Fix batch_prefill codegen error
* Remove no-longer used GetName() function
* Add fp8 logits=True instances
* Update CHANGELOG.md
1 parent c0797c1 commit 1c31519

6 files changed

Lines changed: 175 additions & 90 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
1010
* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
1111
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
1212
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
13+
* Added FP8 KV cache support for FMHA batch prefill.
1314

1415
### Changed
1516

example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,15 @@
2424
)
2525
from codegen.utils import update_file
2626

27-
28-
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
27+
DTYPE_BITS = {
28+
"fp32": 32,
29+
"fp16": 16,
30+
"bf16": 16,
31+
"fp8": 8,
32+
"fp8bf16": 8,
33+
"fp8fp32": 8,
34+
"bf8": 8,
35+
}
2936

3037
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
3138

@@ -108,7 +115,7 @@
108115
{{
109116
using k_ = fmha_kernel_{F_idx};
110117
if(s.log_level_ > 0)
111-
std::cout << ", " << k_::GetName() << std::flush;
118+
std::cout << ", {F_kname}" << std::flush;
112119
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
113120
const dim3 blocks = k_::BlockSize();
114121
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
@@ -494,6 +501,7 @@ class FmhaFwdKernel:
494501
@property
495502
def template(self) -> str:
496503
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
504+
F_kname=self.name,
497505
F_idx=self.F_idx,
498506
F_hdim=self.F_hdim,
499507
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
@@ -576,10 +584,14 @@ def api_trait(self) -> FmhaFwdApiTrait:
576584
class KernelComponentFactory:
577585
@staticmethod
578586
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
579-
if dtype == "fp16" or dtype == "bf16":
587+
if dtype in ["fp16", "bf16"]:
580588
return {
581589
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
582590
} # fmt: skip
591+
elif dtype in ["fp8bf16"]:
592+
return {
593+
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
594+
} # fmt: skip
583595
else:
584596
return None
585597

@@ -589,20 +601,26 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
589601
# TODO: the order of List matters! the later in this list will be also be checked later
590602
# TODO: currently for qr pipeline, let 't' padding to appear later!!
591603
# TODO: how to design this more generic?
592-
qscale = "no"
593604
pipelines = []
594605
if dtype in ["fp16", "bf16"]:
606+
qscale = "no"
595607
for logits, mask, bias, lse, dropout in itertools.product(
596608
["t", "f"],
597609
get_mask_map(mask_impl).keys(),
598610
BIAS_MAP.keys(),
599611
["t", "f"],
600612
["t", "f"],
601613
):
602-
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
603614
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
604-
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
605-
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
615+
elif dtype in ["fp8bf16"]:
616+
# no need lse/dropout kernels
617+
for logits, qscale, mask, bias in itertools.product(
618+
["t", "f"],
619+
["pertensor"],
620+
get_mask_map(mask_impl).keys(),
621+
["no"],
622+
):
623+
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
606624
else:
607625
assert False
608626
return pipelines
@@ -612,7 +630,7 @@ class CustomFactory(KernelComponentFactory):
612630
@staticmethod
613631
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
614632
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
615-
if dtype == "fp16" or dtype == "bf16":
633+
if dtype in ["fp16", "bf16"]:
616634
if 128 in result.keys():
617635
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
618636
return result
@@ -695,15 +713,14 @@ def get_fwd_blobs(
695713
continue
696714
# Aiter(mha_batch_prefill) integration
697715
elif receipt == 200:
698-
cond = dtype in ["fp16", "bf16"]
716+
cond = dtype in ["fp16", "bf16", "fp8bf16"]
699717
cond &= mode == "group"
700718
cond &= pipeline.F_vlayout == "row"
701-
cond &= pipeline.F_qscale == "no"
702719
if not cond:
703720
continue
704721
# aiter::mha_batch_prefill C++ api integration
705722
elif receipt == 600:
706-
cond = dtype in ["fp16", "bf16"]
723+
cond = dtype in ["fp16", "bf16", "fp8bf16"]
707724
cond &= mode == "group"
708725
cond &= pipeline.F_vlayout == "row"
709726
cond &= pipeline.F_qscale == "no"

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ def get_pipelines(
10171017
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
10181018
# no need lse/dropout kernels
10191019
for logits, qscale, mask, bias, sink in itertools.product(
1020-
["f"],
1020+
["t", "f"],
10211021
["no", "pertensor"],
10221022
get_mask_map(mask_impl).keys(),
10231023
["no"],

example/ck_tile/01_fmha/fmha_fwd.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,9 @@ struct fmha_batch_prefill_args
500500
const void* k_ptr;
501501
const void* v_ptr;
502502
const void* bias_ptr; // bias or alibi_slope pointer
503+
const void* q_descale_ptr;
504+
const void* k_descale_ptr;
505+
const void* v_descale_ptr;
503506
void* rand_val_ptr;
504507
void* lse_ptr;
505508
void* o_ptr;
@@ -1118,6 +1121,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
11181121
args.k_ptr,
11191122
args.v_ptr,
11201123
args.bias_ptr,
1124+
args.q_descale_ptr,
1125+
args.k_descale_ptr,
1126+
args.v_descale_ptr,
11211127
args.rand_val_ptr,
11221128
args.lse_ptr,
11231129
args.o_ptr,
@@ -1166,6 +1172,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
11661172
args.k_ptr,
11671173
args.v_ptr,
11681174
args.bias_ptr,
1175+
args.q_descale_ptr,
1176+
args.k_descale_ptr,
1177+
args.v_descale_ptr,
11691178
args.rand_val_ptr,
11701179
args.lse_ptr,
11711180
args.o_ptr,

0 commit comments

Comments
 (0)