Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ Tensor& sdpa_with_kv_cache_out_no_context(
}

at::Tensor sdpa_with_kv_cache_aten(
const at::Tensor& q_projected,
const at::Tensor& k_projected,
const at::Tensor& v_projected,
at::Tensor& key_cache,
at::Tensor& value_cache,
const at::Tensor& q_proj,
const at::Tensor& k_proj,
const at::Tensor& v_proj,
at::Tensor& k_cache,
at::Tensor& v_cache,
const int64_t start_pos,
const int64_t seq_len,
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
Expand All @@ -197,6 +197,11 @@ at::Tensor sdpa_with_kv_cache_aten(
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<double> scale) {
auto q_projected = q_proj.contiguous();
auto k_projected = k_proj.contiguous();
auto v_projected = v_proj.contiguous();
auto key_cache = k_cache.contiguous();
auto value_cache = v_cache.contiguous();
auto output = at::empty_like(q_projected);
WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
(q_projected,
Expand Down Expand Up @@ -256,11 +261,14 @@ at::Tensor custom_sdpa_aten(
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<double> scale,
const bool is_seq_dim_2) {
auto output = at::empty(q.sizes());
auto q_projected = q.contiguous();
auto k_projected = k.contiguous();
auto v_projected = v.contiguous();
auto output = at::empty_like(q_projected);
WRAP_TO_ATEN(custom_sdpa_out_no_context, 9)
(q,
k,
v,
(q_projected,
k_projected,
v_projected,
start_pos,
attn_mask,
dropout_p,
Expand Down Expand Up @@ -331,7 +339,10 @@ at::Tensor custom_quantized_sdpa_aten(
const std::optional<at::Tensor>& v_zero_points,
const std::optional<at::Tensor>& v_scales,
const bool is_seq_at_dim_2) {
auto output = at::empty(q.sizes());
auto q_projected = q.contiguous();
auto k_projected = k.contiguous();
auto v_projected = v.contiguous();
auto output = at::empty(q_projected.sizes());
WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
(q,
k,
Expand Down
Loading