diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 5835296b183..4613dd8ea7d 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1763,20 +1763,25 @@ def _get_source_transforms( # noqa transpose_k = use_transposed_cache transpose_v = use_transposed_cache - # SDPA uses is_seq_at_dim_2=True when any cache is transposed, - # since KVCache always returns [B, H, S, D] for Attention. - sdpa_seq_at_dim_2 = transpose_k or transpose_v - transforms.append( partial(replace_kv_cache_with_custom_kv_cache, transpose_k=transpose_k, transpose_v=transpose_v) ) if use_attention_mask_for_custom_sdpa: transforms.append( - partial(replace_sdpa_with_custom_op, use_attention_mask=True, is_seq_at_dim_2=sdpa_seq_at_dim_2) + partial( + replace_sdpa_with_custom_op, + use_attention_mask=True, + is_k_seq_at_dim_2=transpose_k, + is_v_seq_at_dim_2=transpose_v, + ) ) else: transforms.append( - partial(replace_sdpa_with_custom_op, is_seq_at_dim_2=sdpa_seq_at_dim_2) + partial( + replace_sdpa_with_custom_op, + is_k_seq_at_dim_2=transpose_k, + is_v_seq_at_dim_2=transpose_v, + ) ) if quantize_kv_cache: diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 8aefc2c5973..a339f078444 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -379,7 +379,7 @@ def update( # input_pos: [S], k_val/v_val: [B, H, S, D] from Attention start_pos = input_pos[0].item() - # Transpose k_val to match cache layout if needed + # Transpose k_val/v_val to match cache layout if needed k_for_cache = k_val if self.transpose_k else k_val.transpose(1, 2) v_for_cache = v_val if self.transpose_v else v_val.transpose(1, 2) @@ -394,10 +394,15 @@ def update( _ = torch.ops.llama.update_cache(k_for_cache, self.k_cache, start_pos, self.transpose_k) _ = torch.ops.llama.update_cache(v_for_cache, self.v_cache, start_pos, self.transpose_v) - # Return both caches in [B, H, S, D] for Attention - k_out = self.k_cache if self.transpose_k else self.k_cache.transpose(1, 2) - v_out = self.v_cache if self.transpose_v else self.v_cache.transpose(1, 2) - return (k_out, v_out) + # Return caches in their native layout. The SDPA op handles + # mixed K/V layouts via separate seq dim parameters, avoiding + # expensive runtime transpose copies. + return (self.k_cache, self.v_cache) + + @property + def is_seq_at_dim_2(self): + """Backward compat for quantized KV cache path.""" + return self.transpose_k and self.transpose_v def replace_kv_cache_with_custom_kv_cache(module, transpose_k=False, transpose_v=False): @@ -519,7 +524,6 @@ def from_quantized_kv_cache( kv_cache.cache_type, kv_cache.use_custom_update_cache_op, kv_cache.return_float_values, - kv_cache.is_seq_at_dim_2, is_seq_at_dim_2=kv_cache.is_seq_at_dim_2, ) @@ -532,11 +536,13 @@ def __init__( n_heads, head_dim, dtype=torch.float32, - is_seq_at_dim_2: bool = False, + transpose_k: bool = False, + transpose_v: bool = False, ): # Look at attention.py for explanation on why max_context_length * 2 super().__init__( - max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, is_seq_at_dim_2 + max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, + transpose_k=transpose_k, transpose_v=transpose_v, ) self.cache_positions_manager = CachePositionsManager(self.max_context_length) self.is_ring_buffer = True @@ -551,18 +557,10 @@ def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): def update(self, input_pos, k_val, v_val): """ k_val, v_val: [B, H, S, D] - return: [B, H, S, D] - However the storage is [B, S, H, D] so we incur transpose in, transpose out - This shall be removed by subsequent post-export graph pass + Returns K/V caches in their native storage layout. """ - # Need to transpose for two reasons - # 1. kv cache is stored as [B, S, H, D] - # 2. If seq_len = k_val.size(2), we wont be able be able to optimize - # away transpose at the output of k, v projection - if not self.is_seq_at_dim_2: - seq_len = k_val.transpose(1, 2).size(1) - else: - seq_len = k_val.size(2) + # k_val is always [B, H, S, D] from Attention. Get seq_len from dim 2. + seq_len = k_val.size(2) assert seq_len <= self.k_cache.size( 1 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" @@ -593,7 +591,8 @@ def from_custom_kv_cache( n_heads, head_dim, dtype=kv_cache.k_cache.dtype, - is_seq_at_dim_2=kv_cache.is_seq_at_dim_2, + transpose_k=kv_cache.transpose_k, + transpose_v=kv_cache.transpose_v, ) diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index e6de718f38e..cc736ce2d45 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -23,14 +23,17 @@ def __init__( self, dim: int, use_attention_mask: bool = False, - is_seq_at_dim_2: bool = False, + is_k_seq_at_dim_2: bool = False, + is_v_seq_at_dim_2: bool = False, ): super().__init__() self.dim = dim self.use_attention_mask = use_attention_mask - # When True, Q/K/V are in [B, H, S, D] and custom_sdpa uses seq_dim=2. - # When False, they are transposed to [B, S, H, D] and custom_sdpa uses seq_dim=1. - self.is_seq_at_dim_2 = is_seq_at_dim_2 + # Separate seq dim flags for K and V allow mixed cache layouts. + # Q and output always use seq_dim=2 ([B, H, S, D]) since Q is + # always small (current step) and the transpose is negligible. + self.is_k_seq_at_dim_2 = is_k_seq_at_dim_2 + self.is_v_seq_at_dim_2 = is_v_seq_at_dim_2 def forward( self, @@ -42,13 +45,8 @@ def forward( seqlen, mask, ): - # Q, K, V arrive in [B, H, S, D] from Attention. - # If is_seq_at_dim_2=False, transpose to [B, S, H, D] for the op. - if not self.is_seq_at_dim_2: - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - + # Q arrives in [B, H, S, D] from Attention - always passed with seq_dim=2. + # K and V arrive in their native cache layout (may differ). input_dtype = q.dtype q = q.to(dtype=torch.float) k = k.to(dtype=torch.float) @@ -64,7 +62,9 @@ def forward( 0, False, scale=None, - is_seq_dim_2=self.is_seq_at_dim_2, + is_seq_dim_2=True, + is_k_seq_dim_2=self.is_k_seq_at_dim_2, + is_v_seq_dim_2=self.is_v_seq_at_dim_2, ) else: output = torch.ops.llama.custom_sdpa( @@ -76,15 +76,20 @@ def forward( 0, True, scale=None, - is_seq_dim_2=self.is_seq_at_dim_2, + is_seq_dim_2=True, + is_k_seq_dim_2=self.is_k_seq_at_dim_2, + is_v_seq_dim_2=self.is_v_seq_at_dim_2, ) - if self.is_seq_at_dim_2: - output = output.transpose(1, 2).contiguous() + # Output is [B, H, S, D] (matches Q layout), transpose for reshape + output = output.transpose(1, 2).contiguous() return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype) def _replace_sdpa_with_custom_op( - module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True + module: torch.nn.Module, + use_attention_mask: bool = False, + is_k_seq_at_dim_2: bool = False, + is_v_seq_at_dim_2: bool = False, ): for name, child in module.named_children(): if isinstance(child, SDPA): @@ -94,19 +99,33 @@ def _replace_sdpa_with_custom_op( SDPACustom( child.dim, use_attention_mask=use_attention_mask, - is_seq_at_dim_2=is_seq_at_dim_2, + is_k_seq_at_dim_2=is_k_seq_at_dim_2, + is_v_seq_at_dim_2=is_v_seq_at_dim_2, ), ) else: - _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2) + _replace_sdpa_with_custom_op( + child, + use_attention_mask=use_attention_mask, + is_k_seq_at_dim_2=is_k_seq_at_dim_2, + is_v_seq_at_dim_2=is_v_seq_at_dim_2, + ) def replace_sdpa_with_custom_op( - module: torch.nn.Module, use_attention_mask: bool = False, is_seq_at_dim_2: bool = True + module: torch.nn.Module, + use_attention_mask: bool = False, + is_k_seq_at_dim_2: bool = False, + is_v_seq_at_dim_2: bool = False, ) -> torch.nn.Module: from executorch.extension.llm.custom_ops import custom_ops # noqa - _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask, is_seq_at_dim_2=is_seq_at_dim_2) + _replace_sdpa_with_custom_op( + module, + use_attention_mask=use_attention_mask, + is_k_seq_at_dim_2=is_k_seq_at_dim_2, + is_v_seq_at_dim_2=is_v_seq_at_dim_2, + ) return module @@ -138,6 +157,7 @@ def __init__( self.float_dtype = torch.float32 self.kv_cache = kv_cache self.use_attention_mask = use_attention_mask + # Quantized path uses a single flag for all tensors self.is_seq_at_dim_2 = is_seq_at_dim_2 def forward( @@ -225,8 +245,10 @@ def _update_attention_module_with_quantized_sdpa( sdpa = getattr(module, "SDPA", None) assert sdpa is not None assert isinstance(sdpa, SDPACustom) - # TODO: add support for SDPA with attention mask - setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, is_seq_at_dim_2=sdpa.is_seq_at_dim_2)) # noqa: B010 + # Quantized SDPA uses a single is_seq_at_dim_2 flag; + # derive from K/V flags (both must match for quantized path). + is_seq_at_dim_2 = sdpa.is_k_seq_at_dim_2 and sdpa.is_v_seq_at_dim_2 + setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, is_seq_at_dim_2=is_seq_at_dim_2)) # noqa: B010 def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module): diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 366061d4b7c..1a4fddeec9a 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -168,22 +168,11 @@ def custom_sdpa( is_causal=False, scale=None, is_seq_dim_2=False, + is_k_seq_dim_2=False, + is_v_seq_dim_2=False, ): - seq_len = query.size(2) if is_seq_dim_2 else query.size(1) - _validate_params( - query, - key_cache, - value_cache, - key_cache, - value_cache, - start_pos, - seq_len, - attn_mask, - drpout_p, - is_causal, - scale, - ) - + # Skip _validate_params since it assumes K/V caches have the same layout. + # With mixed transpose (e.g. v_only), K and V have different shapes. return torch.empty_like(query) diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 955f42fe711..7f8f760ed6b 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -345,7 +345,9 @@ Tensor& custom_sdpa_out_impl( const optional& k_scales = nullopt, const optional& v_zero_points = nullopt, const optional& v_scales = nullopt, - bool is_seq_at_dim_2 = false) { + bool is_seq_at_dim_2 = false, + bool is_k_seq_at_dim_2 = false, + bool is_v_seq_at_dim_2 = false) { ET_KERNEL_CHECK_MSG( ctx, !attn_mask.has_value() || !is_causal, @@ -360,11 +362,10 @@ Tensor& custom_sdpa_out_impl( output, "Invalid arguments"); - SeqDim seq_dim{SeqDim::TWO}; - if (!is_seq_at_dim_2) { - seq_dim = SeqDim::ONE; - } - int64_t seq_len = q.size(static_cast(seq_dim)); + SeqDim q_seq_dim = is_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE; + SeqDim k_seq_dim = is_k_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE; + SeqDim v_seq_dim = is_v_seq_at_dim_2 ? SeqDim::TWO : SeqDim::ONE; + int64_t seq_len = q.size(static_cast(q_seq_dim)); if (q.scalar_type() == ScalarType::Char) { ET_KERNEL_CHECK_MSG( @@ -447,7 +448,9 @@ Tensor& custom_sdpa_out_impl( k_scales, v_zero_points, v_scales, - seq_dim, + q_seq_dim, + k_seq_dim, + v_seq_dim, start_pos, num_keys_for_causal_attention); } else if (seq_len >= 192) { @@ -467,7 +470,9 @@ Tensor& custom_sdpa_out_impl( k_scales, v_zero_points, v_scales, - seq_dim, + q_seq_dim, + k_seq_dim, + v_seq_dim, start_pos, num_keys_for_causal_attention); } else { @@ -487,7 +492,9 @@ Tensor& custom_sdpa_out_impl( k_scales, v_zero_points, v_scales, - seq_dim, + q_seq_dim, + k_seq_dim, + v_seq_dim, start_pos, num_keys_for_causal_attention); } @@ -532,6 +539,8 @@ Tensor& custom_quantized_sdpa_out( k_scales, v_zero_points, v_scales, + is_seq_at_dim_2, + is_seq_at_dim_2, is_seq_at_dim_2); } @@ -562,6 +571,8 @@ Tensor& custom_sdpa_out( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional scale, const bool is_seq_dim_2, + const bool is_k_seq_dim_2, + const bool is_v_seq_dim_2, Tensor& output) { return custom_sdpa_out_impl( ctx, @@ -580,7 +591,9 @@ Tensor& custom_sdpa_out( nullopt, nullopt, nullopt, - is_seq_dim_2); + is_seq_dim_2, + is_k_seq_dim_2, + is_v_seq_dim_2); } /* Input params @@ -635,7 +648,9 @@ Tensor& sdpa_with_kv_cache_out( dropout_p, is_causal, scale, - false, // is_seq_dim_2 - default to false for backward compatibility + false, // is_seq_dim_2 + false, // is_k_seq_dim_2 + false, // is_v_seq_dim_2 output); return output; diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index 9b065201f30..ad0a5ad60f1 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -43,6 +43,8 @@ Tensor& custom_sdpa_out( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional scale, const bool is_seq_dim_2, + const bool is_k_seq_dim_2, + const bool is_v_seq_dim_2, Tensor& output); Tensor& flash_attention_kernel_out( diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 8ec0ab40a65..2692e414603 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -63,6 +63,8 @@ Tensor& custom_sdpa_out_no_context( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional scale, const bool is_seq_dim_2, + const bool is_k_seq_dim_2, + const bool is_v_seq_dim_2, Tensor& output); at::Tensor custom_sdpa_aten( @@ -77,7 +79,9 @@ at::Tensor custom_sdpa_aten( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale, - const bool is_seq_dim_2); + const bool is_seq_dim_2, + const bool is_k_seq_dim_2, + const bool is_v_seq_dim_2); Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, @@ -232,6 +236,8 @@ Tensor& custom_sdpa_out_no_context( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional scale, const bool is_seq_dim_2, + const bool is_k_seq_dim_2, + const bool is_v_seq_dim_2, Tensor& output) { executorch::aten::RuntimeContext context{}; return torch::executor::native::custom_sdpa_out( @@ -245,6 +251,8 @@ Tensor& custom_sdpa_out_no_context( is_causal, scale, is_seq_dim_2, + is_k_seq_dim_2, + is_v_seq_dim_2, output); } @@ -260,12 +268,14 @@ at::Tensor custom_sdpa_aten( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale, - const bool is_seq_dim_2) { + const bool is_seq_dim_2, + const bool is_k_seq_dim_2, + const bool is_v_seq_dim_2) { auto q_projected = q.contiguous(); auto k_projected = k.contiguous(); auto v_projected = v.contiguous(); auto output = at::empty_like(q_projected); - WRAP_TO_ATEN(custom_sdpa_out_no_context, 9) + WRAP_TO_ATEN(custom_sdpa_out_no_context, 11) (q_projected, k_projected, v_projected, @@ -275,6 +285,8 @@ at::Tensor custom_sdpa_aten( is_causal, scale, is_seq_dim_2, + is_k_seq_dim_2, + is_v_seq_dim_2, output); return output; } @@ -426,11 +438,14 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { m.def( "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " - "float? scale=None, bool is_seq_dim_2=False) -> Tensor"); + "float? scale=None, bool is_seq_dim_2=False, " + "bool is_k_seq_dim_2=False, bool is_v_seq_dim_2=False) -> Tensor"); m.def( "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " - "float? scale=None, bool is_seq_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)"); + "float? scale=None, bool is_seq_dim_2=False, " + "bool is_k_seq_dim_2=False, bool is_v_seq_dim_2=False, " + "*, Tensor(a!) out) -> Tensor(a!)"); m.def( "update_cache(Tensor value, Tensor(a!) cache, " "SymInt start_pos, bool is_seq_dim_2=False) -> Tensor"); @@ -468,7 +483,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten); m.impl( "custom_sdpa.out", - WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 9)); + WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 11)); m.impl("update_cache", torch::executor::native::update_cache_aten); m.impl( "update_cache.out", diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h index a8d098579ed..47ba3632d26 100644 --- a/extension/llm/custom_ops/op_sdpa_impl.h +++ b/extension/llm/custom_ops/op_sdpa_impl.h @@ -532,11 +532,10 @@ TODO: Just handle conversion of bool mask to float * @param k_scales Optional scales for quantized key * @param v_zero_points Optional zero points for quantized value * @param v_scales Optional scales for quantized value - * @param seq_dim Which dimension is sequence dimension. - If SeqDim::One, then query, key, value are - expected to be in shape [Batch x Q_seq_len x Dim_per_head x Num_heads] and - output is expected to be in shape [Batch x Q_seq_len x Dim_per_head x - Num_heads] + * @param q_seq_dim Sequence dimension for query and output tensors. + * @param k_seq_dim Sequence dimension for key tensor (can differ from q/v + * to support mixed cache layouts without runtime transposes). + * @param v_seq_dim Sequence dimension for value tensor. * @param start_pos Starting position for causal masking in generation * @param num_keys_for_causal_attention Number of keys to consider for causal attention (-1 for all) @@ -558,7 +557,9 @@ void cpu_flash_attention( const optional& k_scales, const optional& v_zero_points, const optional& v_scales, - const SeqDim seq_dim = SeqDim::TWO, + const SeqDim q_seq_dim = SeqDim::TWO, + const SeqDim k_seq_dim = SeqDim::TWO, + const SeqDim v_seq_dim = SeqDim::TWO, const int64_t start_pos = 0, const int64_t num_keys_for_causal_attention = -1) { (void)dropout_p; @@ -580,19 +581,17 @@ void cpu_flash_attention( using Vec = vec::Vectorized; accum_t scaling_factor = static_cast(calculate_scale(query, scale)); + // Compute head/seq dimension indices for each tensor. + // SeqDim::TWO means [B, H, S, D], SeqDim::ONE means [B, S, H, D]. + int64_t q_head_idx = 3 - static_cast(q_seq_dim); + int64_t k_head_idx = 3 - static_cast(k_seq_dim); + int64_t batchSize = query.size(0); - int64_t num_head = query.size(1); - int64_t qSize = query.size(2); + int64_t num_head = query.size(q_head_idx); + int64_t qSize = query.size(static_cast(q_seq_dim)); int64_t headSize = query.size(3); - int64_t kvSize = value.size(2); - int64_t num_heads_kv = key.size(1); - - if (seq_dim == SeqDim::ONE) { - num_head = query.size(2); - num_heads_kv = key.size(2); - qSize = query.size(1); - kvSize = value.size(1); - } + int64_t kvSize = key.size(static_cast(k_seq_dim)); + int64_t num_heads_kv = key.size(k_head_idx); if (num_keys_for_causal_attention > 0) { ET_CHECK_MSG( @@ -644,33 +643,19 @@ void cpu_flash_attention( auto strides = query.strides(); int64_t qStrideB = strides[0]; - int64_t qStrideH = strides[1]; - int64_t qStrideM = strides[2]; - - if (seq_dim == SeqDim::ONE) { - qStrideH = strides[2]; - qStrideM = strides[1]; - } + int64_t qStrideH = strides[q_head_idx]; + int64_t qStrideM = strides[static_cast(q_seq_dim)]; strides = key.strides(); int64_t kStrideB = strides[0]; - int64_t kStrideH = strides[1]; - int64_t kStrideN = strides[2]; - - if (seq_dim == SeqDim::ONE) { - kStrideH = strides[2]; - kStrideN = strides[1]; - } + int64_t kStrideH = strides[k_head_idx]; + int64_t kStrideN = strides[static_cast(k_seq_dim)]; + int64_t v_head_idx = 3 - static_cast(v_seq_dim); strides = value.strides(); int64_t vStrideB = strides[0]; - int64_t vStrideH = strides[1]; - int64_t vStrideN = strides[2]; - - if (seq_dim == SeqDim::ONE) { - vStrideH = strides[2]; - vStrideN = strides[1]; - } + int64_t vStrideH = strides[v_head_idx]; + int64_t vStrideN = strides[static_cast(v_seq_dim)]; int64_t q_quant_params_StrideB = 0; int64_t q_quant_params_StrideH = 0; @@ -685,45 +670,29 @@ void cpu_flash_attention( if (is_quantized_sdpa) { auto q_strides = q_zero_points.value().strides(); q_quant_params_StrideB = q_strides[0]; - q_quant_params_StrideH = q_strides[1]; - q_quant_params_StrideM = q_strides[2]; + q_quant_params_StrideH = q_strides[q_head_idx]; + q_quant_params_StrideM = q_strides[static_cast(q_seq_dim)]; auto k_strides = k_zero_points.value().strides(); k_quant_params_StrideB = k_strides[0]; - k_quant_params_StrideH = k_strides[1]; - k_quant_params_StrideN = k_strides[2]; + k_quant_params_StrideH = k_strides[k_head_idx]; + k_quant_params_StrideN = k_strides[static_cast(k_seq_dim)]; auto v_strides = v_zero_points.value().strides(); v_quant_params_StrideB = v_strides[0]; - v_quant_params_StrideH = v_strides[1]; - v_quant_params_StrideN = v_strides[2]; + v_quant_params_StrideH = v_strides[v_head_idx]; + v_quant_params_StrideN = v_strides[static_cast(v_seq_dim)]; ET_CHECK_MSG( (v_quant_params_StrideN == k_quant_params_StrideN) && (v_quant_params_StrideN == q_quant_params_StrideM), "Quant params strides must be same for seq dim"); - - if (seq_dim == SeqDim::ONE) { - q_quant_params_StrideH = q_strides[2]; - q_quant_params_StrideM = q_strides[1]; - - k_quant_params_StrideH = k_strides[2]; - k_quant_params_StrideN = k_strides[1]; - - v_quant_params_StrideH = v_strides[2]; - v_quant_params_StrideN = v_strides[1]; - } } strides = output.strides(); int64_t oStrideB = strides[0]; - int64_t oStrideH = strides[1]; - int64_t oStrideM = strides[2]; - - if (seq_dim == SeqDim::ONE) { - oStrideH = strides[2]; - oStrideM = strides[1]; - } + int64_t oStrideH = strides[q_head_idx]; + int64_t oStrideM = strides[static_cast(q_seq_dim)]; int64_t mStrideB = 0; int64_t mStrideH = 0;