@@ -2959,12 +2959,20 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
29592959 if (group_size_ == 32 ) { LAUNCH_TILED (hip_bfloat16, hip_bfloat16, 4 , 32 ); }
29602960 else if (group_size_ == 64 ) { LAUNCH_TILED (hip_bfloat16, hip_bfloat16, 4 , 64 ); }
29612961 else if (group_size_ == 128 ) { LAUNCH_TILED (hip_bfloat16, hip_bfloat16, 4 , 128 ); }
2962+ } else if (bits_ == 8 ) {
2963+ if (group_size_ == 32 ) { LAUNCH_TILED (hip_bfloat16, hip_bfloat16, 8 , 32 ); }
2964+ else if (group_size_ == 64 ) { LAUNCH_TILED (hip_bfloat16, hip_bfloat16, 8 , 64 ); }
2965+ else if (group_size_ == 128 ) { LAUNCH_TILED (hip_bfloat16, hip_bfloat16, 8 , 128 ); }
29622966 }
29632967 } else if (x.dtype () == float16) {
29642968 if (bits_ == 4 ) {
29652969 if (group_size_ == 32 ) { LAUNCH_TILED (__half, __half, 4 , 32 ); }
29662970 else if (group_size_ == 64 ) { LAUNCH_TILED (__half, __half, 4 , 64 ); }
29672971 else if (group_size_ == 128 ) { LAUNCH_TILED (__half, __half, 4 , 128 ); }
2972+ } else if (bits_ == 8 ) {
2973+ if (group_size_ == 32 ) { LAUNCH_TILED (__half, __half, 8 , 32 ); }
2974+ else if (group_size_ == 64 ) { LAUNCH_TILED (__half, __half, 8 , 64 ); }
2975+ else if (group_size_ == 128 ) { LAUNCH_TILED (__half, __half, 8 , 128 ); }
29682976 }
29692977 }
29702978 #undef LAUNCH_TILED
0 commit comments