File tree Expand file tree Collapse file tree
include/ck_tile/ops/fmha/kernel Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -72,12 +72,14 @@ struct FmhaFwdKernel
7272 static constexpr std::string_view kPipelineName = FmhaPipeline::name;
7373
7474 // clang-format off
75- template <typename T > struct t2s ;
75+ template <typename T1 , typename T2 = T1 > struct t2s ;
7676 template <> struct t2s <float > { static constexpr const char * name = " fp32" ; };
7777 template <> struct t2s <ck_tile::fp16_t > { static constexpr const char * name = " fp16" ; };
7878 template <> struct t2s <ck_tile::bf16_t > { static constexpr const char * name = " bf16" ; };
7979 template <> struct t2s <ck_tile::fp8_t > { static constexpr const char * name = " fp8" ; };
8080 template <> struct t2s <ck_tile::bf8_t > { static constexpr const char * name = " bf8" ; };
81+ template <> struct t2s <ck_tile::fp8_t , ck_tile::bf16_t > { static constexpr const char * name = " fp8bf16" ; };
82+ template <> struct t2s <ck_tile::fp8_t , ck_tile::fp32_t > { static constexpr const char * name = " fp8fp32" ; };
8183 // clang-format on
8284
8385 CK_TILE_HOST static std::string GetName ()
@@ -99,7 +101,7 @@ struct FmhaFwdKernel
99101 if (kPadHeadDimV ) n += " dv" ;
100102 return n.empty () ? n : std::string (" p" ) + n; }();
101103 return
102- _SS_ (" fmha_fwd_d" ) + _TS_ (bfs::kQKHeaddim ) + " _" + _SS_ (t2s<QDataType>::name) +
104+ _SS_ (" fmha_fwd_d" ) + _TS_ (bfs::kQKHeaddim ) + " _" + _SS_ (t2s<QDataType, ODataType >::name) +
103105 " _" + (kIsGroupMode ? " group" : " batch" ) + " _"
104106 " b" + _TS_ (bfs::kM0 ) + " x" + _TS_ (bfs::kN0 ) + " x" + _TS_ (bfs::kK0 ) + " x" +
105107 _TS_ (bfs::kN1 ) + " x" + _TS_ (bfs::kK1 ) + " x" + _TS_ (bfs::kQKHeaddim ) + " _" +
You can’t perform that action at this time.
0 commit comments