Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions extension/llm/custom_ops/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,15 @@ def _validate_update_cache_params(
value,
cache,
start_pos,
is_seq_dim_2=False,
indices=None,
):
seq_len = value.size(1)
# Determine sequence dimension based on is_seq_dim_2
# If is_seq_dim_2 is False: [batch, seq, heads, head_dim]
# If is_seq_dim_2 is True: [batch, heads, seq, head_dim]
seq_dim = 2 if is_seq_dim_2 else 1
seq_len = value.size(seq_dim)

assert (
value.dim() == 4
), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
Expand All @@ -201,22 +207,31 @@ def _validate_update_cache_params(
value.dtype == cache.dtype
), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"

for i in [0, 2, 3]:
assert value.size(i) == cache.size(
i
), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
# Validate batch and head_dim dimensions match
assert value.size(0) == cache.size(
0
), f"Expected value and cache to have same size in dimension 0 (batch) but got {value.size(0)} and {cache.size(0)}"
assert value.size(3) == cache.size(
3
), f"Expected value and cache to have same size in dimension 3 (head_dim) but got {value.size(3)} and {cache.size(3)}"

# Validate heads dimension matches based on layout
heads_dim = 1 if is_seq_dim_2 else 2
assert value.size(heads_dim) == cache.size(
heads_dim
), f"Expected value and cache to have same size in dimension {heads_dim} (heads) but got {value.size(heads_dim)} and {cache.size(heads_dim)}"

torch._check_is_size(start_pos)
if indices is None:
torch._check(start_pos < cache.size(1))
torch._check(start_pos < cache.size(seq_dim))
assert start_pos < cache.size(
1
), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
seq_dim
), f"Start position {start_pos} must be less than sequence length {cache.size(seq_dim)}"

torch._check((start_pos + seq_len) <= cache.size(1))
torch._check((start_pos + seq_len) <= cache.size(seq_dim))
assert (start_pos + seq_len) <= cache.size(
1
), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
seq_dim
), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(seq_dim)}"

if indices is not None:
assert (
Expand All @@ -229,20 +244,22 @@ def _validate_update_cache_params(
0
), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
assert indices.size(1) == value.size(
1
), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"
seq_dim
), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(seq_dim)}"


@impl(custom_ops_lib, "update_cache", "Meta")
def update_cache_meta(
value,
cache,
start_pos,
is_seq_dim_2=False,
):
_validate_update_cache_params(
value,
cache,
start_pos,
is_seq_dim_2,
)

# Update cache doesnt really return anything but I dont know a better
Expand All @@ -257,11 +274,13 @@ def update_cache_with_indices_meta(
cache,
start_pos,
indices,
is_seq_dim_2=False,
):
_validate_update_cache_params(
value,
cache,
start_pos,
is_seq_dim_2,
indices,
)

Expand Down
40 changes: 24 additions & 16 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,26 +122,30 @@ Tensor& update_cache_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
const bool is_seq_dim_2,
Tensor& output);

at::Tensor update_cache_aten(
const at::Tensor& value,
at::Tensor& cache,
const int64_t start_pos);
const int64_t start_pos,
const bool is_seq_dim_2);

// New functions for update_cache_with_indices
Tensor& update_cache_with_indices_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
const Tensor& indices,
const bool is_seq_dim_2,
Tensor& output);

at::Tensor update_cache_with_indices_aten(
const at::Tensor& value,
at::Tensor& cache,
const int64_t start_pos,
const at::Tensor& indices);
const at::Tensor& indices,
const bool is_seq_dim_2);

Tensor& sdpa_with_kv_cache_out_no_context(
const Tensor& q_projected,
Expand Down Expand Up @@ -338,19 +342,21 @@ Tensor& update_cache_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
const bool is_seq_dim_2,
Tensor& output) {
executorch::aten::RuntimeContext context{};
return torch::executor::native::update_cache_out(
context, value, cache, start_pos, output);
context, value, cache, start_pos, is_seq_dim_2, output);
}

at::Tensor update_cache_aten(
const at::Tensor& value,
at::Tensor& cache,
const int64_t start_pos) {
const int64_t start_pos,
const bool is_seq_dim_2) {
auto output = at::empty({1});
WRAP_TO_ATEN(update_cache_out_no_context, 3)
(value, cache, start_pos, output);
WRAP_TO_ATEN(update_cache_out_no_context, 4)
(value, cache, start_pos, is_seq_dim_2, output);
return output;
}

Expand All @@ -360,20 +366,22 @@ Tensor& update_cache_with_indices_out_no_context(
Tensor& cache,
const int64_t start_pos,
const Tensor& indices,
const bool is_seq_dim_2,
Tensor& output) {
executorch::aten::RuntimeContext context{};
return torch::executor::native::update_cache_with_indices_out(
context, value, cache, start_pos, indices, output);
context, value, cache, start_pos, indices, is_seq_dim_2, output);
}

at::Tensor update_cache_with_indices_aten(
const at::Tensor& value,
at::Tensor& cache,
const int64_t start_pos,
const at::Tensor& indices) {
const at::Tensor& indices,
const bool is_seq_dim_2) {
auto output = at::empty({1});
WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
(value, cache, start_pos, indices, output);
WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 5)
(value, cache, start_pos, indices, is_seq_dim_2, output);
return output;
}

Expand All @@ -400,16 +408,16 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
"float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"update_cache(Tensor value, Tensor(a!) cache, "
"SymInt start_pos) -> Tensor");
"SymInt start_pos, bool is_seq_dim_2=False) -> Tensor");
m.def(
"update_cache.out(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
"SymInt start_pos, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)");
m.def(
"update_cache_with_indices(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, Tensor indices) -> Tensor");
"SymInt start_pos, Tensor indices, bool is_seq_dim_2=False) -> Tensor");
m.def(
"update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
"SymInt start_pos, Tensor indices, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)");
m.def(
"custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
Expand Down Expand Up @@ -439,15 +447,15 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl("update_cache", torch::executor::native::update_cache_aten);
m.impl(
"update_cache.out",
WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
m.impl(
"update_cache_with_indices",
torch::executor::native::update_cache_with_indices_aten);
m.impl(
"update_cache_with_indices.out",
WRAP_TO_ATEN(
torch::executor::native::update_cache_with_indices_out_no_context,
4));
5));
m.impl(
"custom_quantized_sdpa",
torch::executor::native::custom_quantized_sdpa_aten);
Expand Down
Loading
Loading