diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index e50b3707d51..8ec0ab40a65 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -183,11 +183,11 @@ Tensor& sdpa_with_kv_cache_out_no_context( } at::Tensor sdpa_with_kv_cache_aten( - const at::Tensor& q_projected, - const at::Tensor& k_projected, - const at::Tensor& v_projected, - at::Tensor& key_cache, - at::Tensor& value_cache, + const at::Tensor& q_proj, + const at::Tensor& k_proj, + const at::Tensor& v_proj, + at::Tensor& k_cache, + at::Tensor& v_cache, const int64_t start_pos, const int64_t seq_len, // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue @@ -197,6 +197,11 @@ at::Tensor sdpa_with_kv_cache_aten( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale) { + auto q_projected = q_proj.contiguous(); + auto k_projected = k_proj.contiguous(); + auto v_projected = v_proj.contiguous(); + auto key_cache = k_cache.contiguous(); + auto value_cache = v_cache.contiguous(); auto output = at::empty_like(q_projected); WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11) (q_projected, @@ -256,11 +261,14 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale, const bool is_seq_dim_2) { - auto output = at::empty(q.sizes()); + auto q_projected = q.contiguous(); + auto k_projected = k.contiguous(); + auto v_projected = v.contiguous(); + auto output = at::empty_like(q_projected); WRAP_TO_ATEN(custom_sdpa_out_no_context, 9) - (q, - k, - v, + (q_projected, + k_projected, + v_projected, start_pos, attn_mask, dropout_p, @@ -331,7 +339,10 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_zero_points, const std::optional& v_scales, const bool is_seq_at_dim_2) { - auto output = at::empty(q.sizes()); + auto q_projected = q.contiguous(); + auto k_projected = k.contiguous(); + auto v_projected = v.contiguous(); + auto output = at::empty(q_projected.sizes()); WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15) (q, k,