Skip to content

Commit 01142cc

Browse files
committed
temp hack
Differential Revision: [D93870396](https://our.internmc.facebook.com/intern/diff/D93870396/) [ghstack-poisoned]
1 parent bab3440 commit 01142cc

2 files changed

Lines changed: 3 additions & 0 deletions

File tree

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def _replace_kv_cache_with_quantized_kv_cache(module):
325325
child,
326326
QuantizedCacheType.AffineAsymmetric,
327327
use_custom_update_cache_op=True,
328+
is_seq_at_dim_2=child.is_seq_at_dim_2,
328329
),
329330
)
330331
else:
@@ -421,6 +422,7 @@ def _replace_kv_cache_with_custom_kv_cache(module):
421422
n_heads,
422423
head_dim,
423424
dtype=cache_dtype,
425+
is_seq_at_dim_2=True, # hacking temporarily
424426
),
425427
)
426428
else:

examples/models/llama/source_transformation/sdpa.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def _replace_sdpa_with_custom_op(
9292
SDPACustom(
9393
child.dim,
9494
use_attention_mask=use_attention_mask,
95+
is_seq_at_dim_2=True, # hacking temporarily
9596
),
9697
)
9798
else:

0 commit comments

Comments
 (0)