@@ -220,6 +220,18 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len


+def _is_kv_shared_layer(
+    layer_idx: int, n_layers: int, num_kv_shared_layers: int
+) -> bool:
+    """Check if this layer uses shared K/V from a donor layer (YOCO)."""
+    if num_kv_shared_layers <= 0:
+        return False
+    first_shared = n_layers - num_kv_shared_layers
+    if first_shared <= 0:
+        return False
+    return layer_idx >= first_shared
+
+
 class StaticAttentionIOManager:
     class NGramCache:
         def __init__(self, max_size):
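A note on the new helper (illustration only, not part of the change): it flags exactly the trailing `num_kv_shared_layers` layers, and returns `False` across the board when sharing is disabled or would cover every layer. A standalone sketch, with the helper copied from the diff above:

```python
# Boundary check of _is_kv_shared_layer, copied from the diff above.
def _is_kv_shared_layer(layer_idx: int, n_layers: int, num_kv_shared_layers: int) -> bool:
    if num_kv_shared_layers <= 0:
        return False
    first_shared = n_layers - num_kv_shared_layers
    if first_shared <= 0:
        return False
    return layer_idx >= first_shared

# Only the trailing num_kv_shared_layers layers are flagged as shared.
assert [_is_kv_shared_layer(i, 4, 2) for i in range(4)] == [False, False, True, True]
# Sharing disabled: no layer is shared.
assert [_is_kv_shared_layer(i, 4, 0) for i in range(4)] == [False, False, False, False]
# first_shared <= 0 (shared count covers the whole stack) also disables sharing.
assert [_is_kv_shared_layer(i, 4, 4) for i in range(4)] == [False, False, False, False]
```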
@@ -286,11 +298,7 @@ def _from_config(
         self.freqs_sin = freqs[1].to(dtype)

         split_mha = config.attention_type in ("static", "static_shas")
-        # YOCO: skip cache creation for KV-shared layers (2nd half)
         num_kv_shared = getattr(config, "num_kv_shared_layers", 0)
-        first_kv_shared = (
-            config.n_layers - num_kv_shared if num_kv_shared > 0 else config.n_layers
-        )
         if split_mha:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
@@ -301,7 +309,8 @@ def _from_config(
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
@@ -312,7 +321,8 @@ def _from_config(
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }
         else:
             self.k_caches = {
@@ -324,7 +334,8 @@ def _from_config(
                     dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
@@ -335,7 +346,8 @@ def _from_config(
                     dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }

         self.generate_full_logits = config.generate_full_logits
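The four cache comprehensions above all apply the same filter; a minimal sketch of which entries survive it, reusing the helper from the earlier sketch (hypothetical sizes, plain tuples standing in for `StaticKVCache.calculate_cache_key`, and ignoring the `cache_lens > 0` check):

```python
# Hypothetical config: n_layers=4, num_kv_shared_layers=2, n_kv_heads=2.
# Only layers 0 and 1 get K/V cache entries; layers 2 and 3 reuse shared K/V.
allocated = [
    (layer_id, head_id)
    for layer_id in range(4)
    for head_id in range(2)
    if not _is_kv_shared_layer(layer_id, 4, 2)
]
assert allocated == [(0, 0), (0, 1), (1, 0), (1, 1)]
```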
@@ -998,8 +1010,10 @@ def forward(
         new_qs = [wq(x) for wq in self.wqs]

         shared_kv = kwargs.get("shared_kv")
-        if self.is_kv_shared_layer:
-            assert shared_kv is not None
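+        # New contract: branch on whether the caller supplied shared K/V, and
+        # assert it only ever reaches a layer flagged as KV-shared.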
+        if shared_kv is not None:
+            assert (
+                self.is_kv_shared_layer
+            ), "shared_kv provided but this is not a KV shared layer"
             new_ks = []
             new_vs = []
         else:
@@ -1222,7 +1236,7 @@ def _process_normal_kv(

         kv_to_share = None
         if self.num_kv_shared_layers > 0:
-            kv_to_share = (torch.cat(all_ks, dim=1), torch.cat(all_vs, dim=1))
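+            # stack (rather than cat) adds a new dim=1 axis, keeping each
+            # entry of all_ks/all_vs separately indexable by the shared layers.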
+            kv_to_share = (torch.stack(all_ks, dim=1), torch.stack(all_vs, dim=1))

         return torch.cat(heads, dim=-1), out_cache_state, kv_to_share

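For context on the last hunk: `torch.cat` merges tensors along an existing dimension, while `torch.stack` inserts a new one, so the shared K/V tuple now keeps each tensor in `all_ks`/`all_vs` addressable by index instead of fused into one axis. A quick standalone illustration (shapes are hypothetical, not taken from the model config):

```python
import torch

# Two tensors of hypothetical shape (batch, seq, head_dim).
a = torch.zeros(1, 8, 64)
b = torch.zeros(1, 8, 64)

assert torch.cat([a, b], dim=1).shape == (1, 16, 64)      # merged along dim 1
assert torch.stack([a, b], dim=1).shape == (1, 2, 8, 64)  # new dim 1 keeps them separate
```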