Skip to content

Commit 2f3fd75

Browse files
authored
issue/481 fix: support more data types for rmsnorm in moore gpu
1 parent e698ef6 commit 2f3fd75

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/infiniop/ops/rms_norm/moore/rms_norm_moore.mu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,14 @@ infiniStatus_t launchKernel(
7777

7878
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
7979
LAUNCH_KERNEL(half, half, float);
80+
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16){
81+
LAUNCH_KERNEL(half, __mt_bfloat16, float);
8082
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
8183
LAUNCH_KERNEL(half, float, float);
8284
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
8385
LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
86+
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
87+
LAUNCH_KERNEL(__mt_bfloat16, half, float);
8488
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
8589
LAUNCH_KERNEL(__mt_bfloat16, float, float);
8690
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {

0 commit comments

Comments
 (0)