Skip to content

Commit fbea241

Browse files
authored
support group_size=24 (#7636)
1 parent 2a606e3 commit fbea241

2 files changed

Lines changed: 6 additions & 3 deletions

File tree

custom_ops/gpu_ops/append_attn/template_config.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"IsDynamicC8"
1818
],
1919
"dispatch_params": {
20-
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
20+
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16, 24],
2121
"HEAD_DIM": [128],
2222
"BLOCK_SIZE": [64],
2323
"CAUSAL": [0, 1],
@@ -54,7 +54,7 @@
5454
"ENABLE_PREFILL"
5555
],
5656
"dispatch_params": {
57-
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
57+
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16, 24],
5858
"HEAD_DIM": [128],
5959
"BLOCK_SIZE": [64],
6060
"CAUSAL": [0, 1],
@@ -89,7 +89,7 @@
8989
"ENABLE_PREFILL"
9090
],
9191
"dispatch_params": {
92-
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
92+
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16, 24],
9393
"HEAD_DIM": [64,128],
9494
"BLOCK_SIZE": [64],
9595
"CAUSAL": [0, 1],

custom_ops/gpu_ops/append_attn/utils.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
445445
} else if (group_size == 16) { \
446446
constexpr size_t GROUP_SIZE = 16; \
447447
__VA_ARGS__ \
448+
} else if (group_size == 24) { \
449+
constexpr size_t GROUP_SIZE = 24; \
450+
__VA_ARGS__ \
448451
} else { \
449452
PD_THROW("not support the group_size", group_size); \
450453
}

0 commit comments

Comments
 (0)