Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/models/llama/source_transformation/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def forward(
0, # dropout probability. Ignored by the code
True, # is_causal
)
if self.is_seq_at_dim_2:
output = output.transpose(1, 2).contiguous()
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


Expand Down Expand Up @@ -198,6 +200,8 @@ def forward(
v_scale_fp32,
)

if self.is_seq_at_dim_2:
output = output.transpose(1, 2).contiguous()
return output.view(bsz, seqlen, self.dim)


Expand Down
3 changes: 2 additions & 1 deletion extension/llm/custom_ops/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ def custom_sdpa(
drpout_p=0.0,
is_causal=False,
scale=None,
is_seq_dim_2=False,
):
seq_len = query.size(1)
seq_len = query.size(2) if is_seq_dim_2 else query.size(1)
_validate_params(
query,
key_cache,
Expand Down
25 changes: 20 additions & 5 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -360,16 +360,13 @@
output,
"Invalid arguments");

int64_t seq_len = q.size(1);
SeqDim seq_dim{SeqDim::TWO};
if (!is_seq_at_dim_2) {
seq_dim = SeqDim::ONE;
}
int64_t seq_len = q.size(static_cast<int64_t>(seq_dim));

if (q.scalar_type() == ScalarType::Char) {
if (seq_dim == SeqDim::TWO) {
seq_len = q.size(2);
}
ET_KERNEL_CHECK_MSG(
ctx,
q_scales.has_value() && q_zero_points.has_value() &&
Expand Down Expand Up @@ -564,9 +561,26 @@
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const bool is_seq_dim_2,
Tensor& output) {
return custom_sdpa_out_impl(
ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
ctx,
q,
k,
v,
start_pos,
attn_mask,
dropout_p,
is_causal,
scale,
output,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt,
nullopt,
is_seq_dim_2);
}
/*
Input params
Expand Down Expand Up @@ -621,6 +635,7 @@
dropout_p,
is_causal,
scale,
false, // is_seq_dim_2 - default to false for backward compatibility
output);

return output;
Expand Down
1 change: 1 addition & 0 deletions extension/llm/custom_ops/op_sdpa.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Tensor& custom_sdpa_out(
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const bool is_seq_dim_2,
Tensor& output);

Tensor& flash_attention_kernel_out(
Expand Down
28 changes: 21 additions & 7 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Tensor& custom_sdpa_out_no_context(
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const bool is_seq_dim_2,
Tensor& output);

at::Tensor custom_sdpa_aten(
Expand All @@ -75,7 +76,8 @@ at::Tensor custom_sdpa_aten(
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<double> scale);
const std::optional<double> scale,
const bool is_seq_dim_2);

Tensor& custom_quantized_sdpa_out_no_context(
const Tensor& q,
Expand Down Expand Up @@ -224,6 +226,7 @@ Tensor& custom_sdpa_out_no_context(
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const optional<double> scale,
const bool is_seq_dim_2,
Tensor& output) {
executorch::aten::RuntimeContext context{};
return torch::executor::native::custom_sdpa_out(
Expand All @@ -236,6 +239,7 @@ Tensor& custom_sdpa_out_no_context(
dropout_p,
is_causal,
scale,
is_seq_dim_2,
output);
}

Expand All @@ -250,10 +254,20 @@ at::Tensor custom_sdpa_aten(
const double dropout_p,
const bool is_causal,
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
const std::optional<double> scale) {
const std::optional<double> scale,
const bool is_seq_dim_2) {
auto output = at::empty(q.sizes());
WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
(q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
WRAP_TO_ATEN(custom_sdpa_out_no_context, 9)
(q,
k,
v,
start_pos,
attn_mask,
dropout_p,
is_causal,
scale,
is_seq_dim_2,
output);
return output;
}

Expand Down Expand Up @@ -401,11 +415,11 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
m.def(
"custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None) -> Tensor");
"float? scale=None, bool is_seq_dim_2=False) -> Tensor");
m.def(
"custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
"float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
"float? scale=None, bool is_seq_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"update_cache(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, bool is_seq_dim_2=False) -> Tensor");
Expand Down Expand Up @@ -443,7 +457,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
m.impl(
"custom_sdpa.out",
WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 9));
m.impl("update_cache", torch::executor::native::update_cache_aten);
m.impl(
"update_cache.out",
Expand Down
Loading
Loading