
Commit 502d2de

Improve YOCO static attention: reusable helper, correct tensor op, runtime guard (pytorch#18545)

Differential Revision: D97637849
Pull Request resolved: pytorch#18545
Parent: 24751f1

1 file changed: examples/models/llama/static_attention.py (25 additions, 11 deletions)
@@ -220,6 +220,18 @@ def unmask(self, new_unmasked_len):
         self.unmasked_len += new_unmasked_len


+def _is_kv_shared_layer(
+    layer_idx: int, n_layers: int, num_kv_shared_layers: int
+) -> bool:
+    """Check if this layer uses shared K/V from a donor layer (YOCO)."""
+    if num_kv_shared_layers <= 0:
+        return False
+    first_shared = n_layers - num_kv_shared_layers
+    if first_shared <= 0:
+        return False
+    return layer_idx >= first_shared
+
+
 class StaticAttentionIOManager:
     class NGramCache:
         def __init__(self, max_size):
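
As a quick illustration of the new helper (example values are mine, not from the PR): with n_layers=8 and num_kv_shared_layers=4, first_shared is 4, so layers 0-3 keep their own K/V caches and layers 4-7 are treated as KV-shared.

# Illustrative values only; exercises _is_kv_shared_layer from the hunk above.
for layer_idx in range(8):
    print(layer_idx, _is_kv_shared_layer(layer_idx, n_layers=8, num_kv_shared_layers=4))
# prints False for layers 0-3 and True for layers 4-7
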
@@ -286,11 +298,7 @@ def _from_config(
         self.freqs_sin = freqs[1].to(dtype)

         split_mha = config.attention_type in ("static", "static_shas")
-        # YOCO: skip cache creation for KV-shared layers (2nd half)
         num_kv_shared = getattr(config, "num_kv_shared_layers", 0)
-        first_kv_shared = (
-            config.n_layers - num_kv_shared if num_kv_shared > 0 else config.n_layers
-        )
         if split_mha:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
@@ -301,7 +309,8 @@ def _from_config(
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
@@ -312,7 +321,8 @@ def _from_config(
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }
         else:
             self.k_caches = {
@@ -324,7 +334,8 @@ def _from_config(
                     dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
@@ -335,7 +346,8 @@ def _from_config(
                     dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
-                if cache_lens[layer_id] > 0 and layer_id < first_kv_shared
+                if cache_lens[layer_id] > 0
+                and not _is_kv_shared_layer(layer_id, config.n_layers, num_kv_shared)
             }

         self.generate_full_logits = config.generate_full_logits
@@ -998,8 +1010,10 @@ def forward(
         new_qs = [wq(x) for wq in self.wqs]

         shared_kv = kwargs.get("shared_kv")
-        if self.is_kv_shared_layer:
-            assert shared_kv is not None
+        if shared_kv is not None:
+            assert (
+                self.is_kv_shared_layer
+            ), "shared_kv provided but this is not a KV shared layer"
             new_ks = []
             new_vs = []
         else:
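
A note on the guard inversion above (my reading, not PR text): the old assert required every KV-shared layer to be handed shared_kv, while the new one checks the converse, that shared_kv only ever reaches a layer flagged as KV-shared; a shared layer called without shared_kv now falls through to the normal K/V computation. A minimal sketch of the two invariants, with hypothetical stand-in functions:

# Sketch only; plain functions standing in for the module's forward().
def old_guard(is_kv_shared_layer, shared_kv):
    if is_kv_shared_layer:
        assert shared_kv is not None  # shared layer must receive donor K/V

def new_guard(is_kv_shared_layer, shared_kv):
    if shared_kv is not None:
        assert is_kv_shared_layer, "shared_kv provided but this is not a KV shared layer"
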
@@ -1222,7 +1236,7 @@ def _process_normal_kv(

         kv_to_share = None
         if self.num_kv_shared_layers > 0:
-            kv_to_share = (torch.cat(all_ks, dim=1), torch.cat(all_vs, dim=1))
+            kv_to_share = (torch.stack(all_ks, dim=1), torch.stack(all_vs, dim=1))

         return torch.cat(heads, dim=-1), out_cache_state, kv_to_share

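The torch.cat -> torch.stack swap is the "correct tensor op" from the commit title: stack inserts a new axis for the list entries, while cat runs them together along an existing axis. A minimal sketch, assuming each entry of all_ks has shape (batch, seq_len, head_dim); the actual shapes are not visible in this diff:

import torch

# Hypothetical shapes, chosen only to show the stack/cat difference.
batch, seq_len, head_dim, n_heads = 1, 16, 64, 4
all_ks = [torch.randn(batch, seq_len, head_dim) for _ in range(n_heads)]

print(torch.stack(all_ks, dim=1).shape)  # torch.Size([1, 4, 16, 64]): new head axis
print(torch.cat(all_ks, dim=1).shape)    # torch.Size([1, 64, 64]): heads merged into seq_len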