@@ -169,6 +169,18 @@ llama_kv_cache::llama_kv_cache(
169169 continue ;
170170 }
171171
172+ if (n_embd_head_k_all == 0 ) {
173+ n_embd_head_k_all = (int32_t ) hparams.n_embd_head_k (il);
174+ } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t ) hparams.n_embd_head_k (il)) {
175+ n_embd_head_k_all = -1 ;
176+ }
177+
178+ if (n_embd_head_v_all == 0 ) {
179+ n_embd_head_v_all = (int32_t ) hparams.n_embd_head_v (il);
180+ } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t ) hparams.n_embd_head_v (il)) {
181+ n_embd_head_v_all = -1 ;
182+ }
183+
172184 // [TAG_V_CACHE_VARIABLE]
173185 const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il);
174186 const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa (il) : hparams.n_embd_v_gqa_max ();
@@ -213,6 +225,9 @@ llama_kv_cache::llama_kv_cache(
213225 layers.push_back ({ il, k, v, k_stream, v_stream, });
214226 }
215227
228+ GGML_ASSERT (n_embd_head_k_all == -1 || n_embd_head_k_all > 0 );
229+ GGML_ASSERT (n_embd_head_v_all == -1 || n_embd_head_v_all > 0 );
230+
216231 if (reuse) {
217232 LLAMA_LOG_DEBUG (" %s: reusing layers:\n " , __func__);
218233
@@ -276,23 +291,23 @@ llama_kv_cache::llama_kv_cache(
276291
277292 attn_rot_k =
278293 !attn_rot_disable &&
294+ n_embd_head_k_all > 0 &&
279295 ggml_is_quantized (type_k) &&
280- !hparams.is_n_embd_k_gqa_variable () &&
281296 hparams.n_embd_head_k () % 64 == 0 ;
282297
283298 attn_rot_v =
284299 !attn_rot_disable &&
300+ n_embd_head_v_all > 0 &&
285301 ggml_is_quantized (type_v) &&
286- !hparams.is_n_embd_v_gqa_variable () &&
287302 hparams.n_embd_head_v () % 64 == 0 ;
288303
289- LLAMA_LOG_INFO (" %s: attn_rot_k = %d\n " , __func__, attn_rot_k);
290- LLAMA_LOG_INFO (" %s: attn_rot_v = %d\n " , __func__, attn_rot_v);
304+ LLAMA_LOG_INFO (" %s: attn_rot_k = %d, n_embd_head_k_all = %d \n " , __func__, attn_rot_k, n_embd_head_k_all );
305+ LLAMA_LOG_INFO (" %s: attn_rot_v = %d, n_embd_head_v_all = %d \n " , __func__, attn_rot_v, n_embd_head_v_all );
291306
292307 // pre-compute the haramard matrices and keep them in host memory
293308 // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
294309 if (attn_rot_k || attn_rot_v) {
295- for (int64_t n = 64 ; n <= std::max (hparams. n_embd_head_k (), hparams. n_embd_head_v () ); n *= 2 ) {
310+ for (int64_t n = 64 ; n <= std::max (n_embd_head_k_all, n_embd_head_v_all ); n *= 2 ) {
296311 attn_rot_hadamard[n] = std::vector<float >(n*n);
297312
298313 ggml_init_params params = {
@@ -1308,7 +1323,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
13081323 // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
13091324 do {
13101325 nrot *= 2 ;
1311- } while (hparams. n_embd_head_k () % nrot == 0 );
1326+ } while (n_embd_head_k_all % nrot == 0 );
13121327 nrot /= 2 ;
13131328
13141329 res = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, nrot, nrot);
0 commit comments