@@ -1102,12 +1102,12 @@ std::vector<at::Tensor> mha_fwd(

   auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) {
     switch (params.page_size) {
-      // case 32:
-      //   launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{});
-      //   break;
-      // case 64:
-      //   launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{});
-      //   break;
+      case 32:
+        launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{});
+        break;
+      case 64:
+        launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{});
+        break;
       case 128:
         launch_kernel(_QG_SZ, _HEAD_DIM, _128{}, _8{});
         break;
@@ -1142,21 +1142,21 @@ std::vector<at::Tensor> mha_fwd(
   };

   switch (params.d) {
-    // case 64:
-    //   dispatch_q_group(_64{});
-    //   break;
-    // case 96:
-    //   dispatch_q_group(_96{});
-    //   break;
+    case 64:
+      dispatch_q_group(_64{});
+      break;
+    case 96:
+      dispatch_q_group(_96{});
+      break;
     case 128:
       dispatch_q_group(_128{});
       break;
-    // case 192:
-    //   dispatch_q_group(_192{});
-    //   break;
-    // case 256:
-    //   dispatch_q_group(_256{});
-    //   break;
+    case 192:
+      dispatch_q_group(_192{});
+      break;
+    case 256:
+      dispatch_q_group(_256{});
+      break;
     default:
       TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d);
   }
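
The re-enabled branches follow the standard runtime-to-compile-time dispatch idiom: a runtime switch over params.page_size / params.d selects a CuTe-style integral constant (_32{}, _64{}, ...), which instantiates a dedicated kernel specialization for that size. Below is a minimal, self-contained sketch of the idiom; launch_kernel, dispatch, and the two-parameter signature here are illustrative stand-ins, not the repository's exact four-argument launcher.

// Minimal sketch of runtime-to-compile-time dispatch (illustrative only).
#include <cstdio>
#include <stdexcept>
#include <type_traits>

// Stand-ins for CuTe's cute::Int<N> constants (_64{}, _128{}, ...).
template <int N> using Int = std::integral_constant<int, N>;
using _64  = Int<64>;
using _128 = Int<128>;

// Hypothetical templated launcher: each (HeadDim, PageSize) pair
// becomes a separate compile-time specialization.
template <class HeadDim, class PageSize>
void launch_kernel(HeadDim, PageSize) {
    std::printf("launching kernel<head_dim=%d, page_size=%d>\n",
                HeadDim::value, PageSize::value);
}

// Runtime switches pick the compile-time specialization; a generic
// lambda lets the two dispatch levels nest without duplication.
void dispatch(int head_dim, int page_size) {
    auto dispatch_page_size = [&](auto head_dim_c) {
        switch (page_size) {
            case 64:  launch_kernel(head_dim_c, _64{});  break;
            case 128: launch_kernel(head_dim_c, _128{}); break;
            default:  throw std::runtime_error("unsupported page size");
        }
    };
    switch (head_dim) {
        case 64:  dispatch_page_size(_64{});  break;
        case 128: dispatch_page_size(_128{}); break;
        default:  throw std::runtime_error("unsupported head size");
    }
}

int main() {
    dispatch(128, 64);  // prints: launching kernel<head_dim=128, page_size=64>
}

Note that every enabled case adds a template instantiation, so binary size and compile time grow with the number of supported (head-dim, page-size) pairs; that trade-off is presumably why the extra sizes were commented out before this change.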