2424)
2525from codegen .utils import update_file
2626
# Width in bits associated with each dtype string used by this generator.
# NOTE(review): the compound names ("fp8bf16", "fp8fp32") presumably denote
# fp8 inputs paired with a wider output type -- confirm against the kernel
# traits; here they are simply assigned the same width as "fp8".
DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
    "bf16": 16,
    "fp8": 8,
    "fp8bf16": 8,
    "fp8fp32": 8,
    "bf8": 8,
}

# Every key maps to itself except 96, which maps to 128.
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
3138
108115{{
109116 using k_ = fmha_kernel_{F_idx};
110117 if(s.log_level_ > 0)
111- std::cout << ", " << k_::GetName() << std::flush;
118+ std::cout << ", {F_kname}" << std::flush;
112119 auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
113120 const dim3 blocks = k_::BlockSize();
114121 constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
@@ -494,6 +501,7 @@ class FmhaFwdKernel:
494501 @property
495502 def template (self ) -> str :
496503 return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY .format (
504+ F_kname = self .name ,
497505 F_idx = self .F_idx ,
498506 F_hdim = self .F_hdim ,
499507 F_dtype = FWD_DTYPE_MAP [self .F_dtype ],
@@ -576,10 +584,14 @@ def api_trait(self) -> FmhaFwdApiTrait:
576584class KernelComponentFactory :
577585 @staticmethod
578586 def get_hdim_tile_size_dict (dtype : str ) -> Optional [dict ]:
579- if dtype == "fp16" or dtype == "bf16" :
587+ if dtype in [ "fp16" , "bf16" ] :
580588 return {
581589 128 : [FmhaFwdTileSize (128 , 128 , 32 , 128 , 32 , 128 , 4 , 1 , 1 , 4 , 1 , 1 , 32 , 32 , 16 , 32 , 32 , 16 , - 1 )],
582590 } # fmt: skip
591+ elif dtype in ["fp8bf16" ]:
592+ return {
593+ 128 : [FmhaFwdTileSize (128 , 128 , 32 , 128 , 32 , 128 , 4 , 1 , 1 , 4 , 1 , 1 , 32 , 32 , 32 , 32 , 32 , 32 , - 1 )],
594+ } # fmt: skip
583595 else :
584596 return None
585597
@@ -589,20 +601,26 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
589601 # TODO: the order of List matters! the later in this list will be also be checked later
590602 # TODO: currently for qr pipeline, let 't' padding to appear later!!
591603 # TODO: how to design this more generic?
592- qscale = "no"
593604 pipelines = []
594605 if dtype in ["fp16" , "bf16" ]:
606+ qscale = "no"
595607 for logits , mask , bias , lse , dropout in itertools .product (
596608 ["t" , "f" ],
597609 get_mask_map (mask_impl ).keys (),
598610 BIAS_MAP .keys (),
599611 ["t" , "f" ],
600612 ["t" , "f" ],
601613 ):
602- pipelines .append (FmhaFwdPipeline ("qr_async" , "row" , "t" , "f" , "t" , "t" , logits , bias , lse , dropout , qscale , mask )) # fmt: skip
603614 pipelines .append (FmhaFwdPipeline ("qr_async" , "row" , "t" , "t" , "t" , "t" , logits , bias , lse , dropout , qscale , mask )) # fmt: skip
604- # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
605- # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
615+ elif dtype in ["fp8bf16" ]:
616+ # no need lse/dropout kernels
617+ for logits , qscale , mask , bias in itertools .product (
618+ ["t" , "f" ],
619+ ["pertensor" ],
620+ get_mask_map (mask_impl ).keys (),
621+ ["no" ],
622+ ):
623+ pipelines .append (FmhaFwdPipeline ("qr_async" , "row" , "t" , "t" , "t" , "t" , logits , bias , "f" , "f" , qscale , mask )) # fmt: skip
606624 else :
607625 assert False
608626 return pipelines
@@ -612,7 +630,7 @@ class CustomFactory(KernelComponentFactory):
612630 @staticmethod
613631 def get_hdim_tile_size_dict (dtype : str ) -> Optional [dict ]:
614632 result = KernelComponentFactory .get_hdim_tile_size_dict (dtype )
615- if dtype == "fp16" or dtype == "bf16" :
633+ if dtype in [ "fp16" , "bf16" ] :
616634 if 128 in result .keys ():
617635 result [128 ].insert (0 , FmhaFwdTileSize ( 64 , 128 , 64 , 128 , 64 , 128 , 4 , 1 , 1 , 4 , 1 , 1 , 16 , 16 , 16 , 16 , 16 , 16 , - 1 , CppConstraint ("get_num_blocks(128) < num_cus * min_cu_util_rate" ))) # fmt: skip
618636 return result
@@ -695,15 +713,14 @@ def get_fwd_blobs(
695713 continue
696714 # Aiter(mha_batch_prefill) integration
697715 elif receipt == 200 :
698- cond = dtype in ["fp16" , "bf16" ]
716+ cond = dtype in ["fp16" , "bf16" , "fp8bf16" ]
699717 cond &= mode == "group"
700718 cond &= pipeline .F_vlayout == "row"
701- cond &= pipeline .F_qscale == "no"
702719 if not cond :
703720 continue
704721 # aiter::mha_batch_prefill C++ api integration
705722 elif receipt == 600 :
706- cond = dtype in ["fp16" , "bf16" ]
723+ cond = dtype in ["fp16" , "bf16" , "fp8bf16" ]
707724 cond &= mode == "group"
708725 cond &= pipeline .F_vlayout == "row"
709726 cond &= pipeline .F_qscale == "no"
0 commit comments