diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index d66b9e2bb94..64822fd0c42 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -325,6 +325,7 @@ def _replace_kv_cache_with_quantized_kv_cache(module): child, QuantizedCacheType.AffineAsymmetric, use_custom_update_cache_op=True, + is_seq_at_dim_2=child.is_seq_at_dim_2, ), ) else: @@ -421,6 +422,7 @@ def _replace_kv_cache_with_custom_kv_cache(module): n_heads, head_dim, dtype=cache_dtype, + is_seq_at_dim_2=True, # hacking temporarily ), ) else: diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 5217a103ce4..a3544d5704e 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -92,6 +92,7 @@ def _replace_sdpa_with_custom_op( SDPACustom( child.dim, use_attention_mask=use_attention_mask, + is_seq_at_dim_2=True, # hacking temporarily ), ) else: