
Commit 9bdad0b

add page 64
1 parent 6f76a31

1 file changed: 18 additions & 18 deletions


src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp

@@ -1102,12 +1102,12 @@ std::vector<at::Tensor> mha_fwd(
 
   auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) {
     switch (params.page_size) {
-      // case 32:
-      //   launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{});
-      //   break;
-      // case 64:
-      //   launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{});
-      //   break;
+      case 32:
+        launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{});
+        break;
+      case 64:
+        launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{});
+        break;
       case 128:
         launch_kernel(_QG_SZ, _HEAD_DIM, _128{}, _8{});
         break;
@@ -1142,21 +1142,21 @@ std::vector<at::Tensor> mha_fwd(
   };
 
   switch (params.d) {
-    // case 64:
-    //   dispatch_q_group(_64{});
-    //   break;
-    // case 96:
-    //   dispatch_q_group(_96{});
-    //   break;
+    case 64:
+      dispatch_q_group(_64{});
+      break;
+    case 96:
+      dispatch_q_group(_96{});
+      break;
     case 128:
       dispatch_q_group(_128{});
       break;
-    // case 192:
-    //   dispatch_q_group(_192{});
-    //   break;
-    // case 256:
-    //   dispatch_q_group(_256{});
-    //   break;
+    case 192:
+      dispatch_q_group(_192{});
+      break;
+    case 256:
+      dispatch_q_group(_256{});
+      break;
     default:
       TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d);
   }
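
The cases enabled here all follow the same runtime-to-compile-time dispatch pattern: a switch on a runtime parameter (params.page_size, params.d) maps each supported value to an integral-constant tag, so launch_kernel is instantiated once per configuration. Below is a minimal, self-contained sketch of that pattern, assuming the _32{}/_64{} tags behave like CuTe-style Int<N> (std::integral_constant<int, N>); the launch_kernel stand-in, the free-function dispatch_page_size, and the meaning of the second tag (_2{}, _4{}, _8{}) are illustrative assumptions, not the repository's actual definitions.

```cpp
// Minimal sketch of the runtime-to-compile-time dispatch pattern above.
// ASSUMPTION: _32{}, _64{}, ... behave like CuTe's Int<N>, i.e.
// std::integral_constant<int, N>; the names mirror CuTe's convention.
#include <cstdio>
#include <type_traits>

template <int N> using Int = std::integral_constant<int, N>;
using _2   = Int<2>;  using _4   = Int<4>;  using _8 = Int<8>;
using _32  = Int<32>; using _64  = Int<64>; using _128 = Int<128>;

// Hypothetical stand-in for the real templated kernel launcher. The page
// size arrives as a distinct type, so tile shapes and loop trip counts can
// be fixed at compile time instead of branching at run time.
template <class PageSize, class SecondTag>
void launch_kernel(PageSize, SecondTag) {
  std::printf("kernel instantiated for page=%d, tag=%d\n",
              PageSize::value, SecondTag::value);
}

// Runtime switch that lifts the page size into a template argument; each
// case instantiates one specialization, mirroring dispatch_page_size in
// the diff (the second constant per case is copied from the source, but
// its meaning here is a guess).
void dispatch_page_size(int page_size) {
  switch (page_size) {
    case 32:  launch_kernel(_32{},  _2{}); break;
    case 64:  launch_kernel(_64{},  _4{}); break;
    case 128: launch_kernel(_128{}, _8{}); break;
    default:  std::printf("unsupported page size: %d\n", page_size);
  }
}

int main() { dispatch_page_size(64); }
```

Each uncommented case adds one more template instantiation to the binary, which is presumably why the 32/64 page sizes and the extra head dims were compiled out before this commit: the trade is longer compile times and larger code size for broader runtime coverage.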
