Skip to content

Commit f84270e

Browse files
authored
ggml : use 64 bytes aligned tile buffers (#21058)
| Model                            | Test   |   t/s OLD |   t/s NEW |   Speedup |
|:---------------------------------|:-------|----------:|----------:|----------:|
| qwen35 0.8B BF16                 | pp512  |    584.59 |    595.41 |      1.02 |
| qwen35 0.8B BF16                 | tg128  |     52.23 |     52.82 |      1.01 |
| qwen35 0.8B IQ2_M - 2.7 bpw      | pp512  |    260.64 |    261.70 |      1.00 |
| qwen35 0.8B IQ2_M - 2.7 bpw      | tg128  |     81.17 |     80.89 |      1.00 |
| qwen35 0.8B IQ2_XXS - 2.0625 bpw | pp512  |    302.36 |    302.56 |      1.00 |
| qwen35 0.8B IQ2_XXS - 2.0625 bpw | tg128  |     84.93 |     85.12 |      1.00 |
| qwen35 0.8B IQ3_XXS - 3.0625 bpw | pp512  |    263.22 |    260.01 |      0.99 |
| qwen35 0.8B IQ3_XXS - 3.0625 bpw | tg128  |     80.29 |     78.94 |      0.98 |
| qwen35 0.8B IQ4_NL - 4.5 bpw     | pp512  |    728.65 |    742.09 |      1.02 |
| qwen35 0.8B IQ4_NL - 4.5 bpw     | tg128  |     82.39 |     84.46 |      1.03 |
| qwen35 0.8B IQ4_XS - 4.25 bpw    | pp512  |    681.33 |    677.06 |      0.99 |
| qwen35 0.8B IQ4_XS - 4.25 bpw    | tg128  |     80.18 |     79.28 |      0.99 |
| qwen35 0.8B Q2_K_M               | pp512  |    413.28 |    415.94 |      1.01 |
| qwen35 0.8B Q2_K_M               | tg128  |     81.90 |     82.78 |      1.01 |
| qwen35 0.8B Q3_K_M               | pp512  |    493.17 |    495.08 |      1.00 |
| qwen35 0.8B Q3_K_M               | tg128  |     82.75 |     83.23 |      1.01 |
| qwen35 0.8B Q3_K_S               | pp512  |    429.35 |    427.64 |      1.00 |
| qwen35 0.8B Q3_K_S               | tg128  |     86.69 |     87.02 |      1.00 |
| qwen35 0.8B Q4_0                 | pp512  |    783.46 |    782.32 |      1.00 |
| qwen35 0.8B Q4_0                 | tg128  |     88.23 |     87.90 |      1.00 |
| qwen35 0.8B Q4_1                 | pp512  |    741.71 |    729.76 |      0.98 |
| qwen35 0.8B Q4_1                 | tg128  |     85.44 |     86.01 |      1.01 |
| qwen35 0.8B Q4_K_M               | pp512  |    676.24 |    681.31 |      1.01 |
| qwen35 0.8B Q4_K_M               | tg128  |     76.59 |     77.06 |      1.01 |
| qwen35 0.8B Q4_K_S               | pp512  |    683.12 |    688.81 |      1.01 |
| qwen35 0.8B Q4_K_S               | tg128  |     80.50 |     81.19 |      1.01 |
| qwen35 0.8B Q5_K_M               | pp512  |    635.33 |    642.11 |      1.01 |
| qwen35 0.8B Q5_K_M               | tg128  |     72.07 |     72.49 |      1.01 |
| qwen35 0.8B Q5_K_S               | pp512  |    660.95 |    658.18 |      1.00 |
| qwen35 0.8B Q5_K_S               | tg128  |     72.19 |     72.95 |      1.01 |
| qwen35 0.8B Q6_K                 | pp512  |    647.97 |    638.84 |      0.99 |
| qwen35 0.8B Q6_K                 | tg128  |     72.83 |     72.49 |      1.00 |
| qwen35 0.8B Q8_0                 | pp512  |    805.01 |    785.49 |      0.98 |
| qwen35 0.8B Q8_0                 | tg128  |     70.10 |     70.13 |      1.00 |

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
1 parent 5594d13 commit f84270e

1 file changed

Lines changed: 16 additions & 16 deletions

File tree

ggml/src/ggml-cpu/amx/mmq.cpp

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2005,12 +2005,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
20052005
const int lda = KB * sizeof(TA);
20062006
//const int ldb = KB * sizeof(TB);
20072007

2008-
static thread_local packed_B_t Tile0[TILE_N * TILE_K];
2009-
static thread_local packed_B_t Tile1[TILE_N * TILE_K];
2010-
static thread_local int8_t Tile23[TILE_M * TILE_K];
2008+
alignas(64) static thread_local packed_B_t Tile0[TILE_N * TILE_K];
2009+
alignas(64) static thread_local packed_B_t Tile1[TILE_N * TILE_K];
2010+
alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K];
20112011

2012-
static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
2013-
static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
2012+
alignas(64) static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
2013+
alignas(64) static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
20142014

20152015
// double buffering C to interleave avx512 and amx
20162016
int32_t * C_cur = TileC0;
@@ -2187,21 +2187,21 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
21872187
const int m1 = std::max(M - TILE_M, 0);
21882188
//const int lda = KB * sizeof(TA);
21892189

2190-
static thread_local int8_t Tile0[TILE_N * TILE_K];
2191-
static thread_local int8_t Tile1[TILE_N * TILE_K];
2192-
static thread_local int8_t Tile23[TILE_M * TILE_K];
2190+
alignas(64) static thread_local int8_t Tile0[TILE_N * TILE_K];
2191+
alignas(64) static thread_local int8_t Tile1[TILE_N * TILE_K];
2192+
alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K];
21932193

21942194
// mat mul result for each group
2195-
static thread_local int32_t Tile4[TILE_M * TILE_N];
2196-
static thread_local int32_t Tile5[TILE_M * TILE_N];
2197-
static thread_local int32_t Tile6[TILE_M * TILE_N];
2198-
static thread_local int32_t Tile7[TILE_M * TILE_N];
2195+
alignas(64) static thread_local int32_t Tile4[TILE_M * TILE_N];
2196+
alignas(64) static thread_local int32_t Tile5[TILE_M * TILE_N];
2197+
alignas(64) static thread_local int32_t Tile6[TILE_M * TILE_N];
2198+
alignas(64) static thread_local int32_t Tile7[TILE_M * TILE_N];
21992199

22002200
// sum of each QK_K block, contains 8 groups, int32
2201-
static thread_local int32_t Sumi4[TILE_M * TILE_N];
2202-
static thread_local int32_t Sumi5[TILE_M * TILE_N];
2203-
static thread_local int32_t Sumi6[TILE_M * TILE_N];
2204-
static thread_local int32_t Sumi7[TILE_M * TILE_N];
2201+
alignas(64) static thread_local int32_t Sumi4[TILE_M * TILE_N];
2202+
alignas(64) static thread_local int32_t Sumi5[TILE_M * TILE_N];
2203+
alignas(64) static thread_local int32_t Sumi6[TILE_M * TILE_N];
2204+
alignas(64) static thread_local int32_t Sumi7[TILE_M * TILE_N];
22052205

22062206
const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
22072207
for (int i = 0; i < KB; ++i) {

0 commit comments

Comments
 (0)