diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 9aacded4b4c..a56e3de5782 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -190,9 +190,15 @@ def _validate_update_cache_params( value, cache, start_pos, + is_seq_dim_2=False, indices=None, ): - seq_len = value.size(1) + # Determine sequence dimension based on is_seq_dim_2 + # If is_seq_dim_2 is False: [batch, seq, heads, head_dim] + # If is_seq_dim_2 is True: [batch, heads, seq, head_dim] + seq_dim = 2 if is_seq_dim_2 else 1 + seq_len = value.size(seq_dim) + assert ( value.dim() == 4 ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions." @@ -201,22 +207,31 @@ def _validate_update_cache_params( value.dtype == cache.dtype ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}" - for i in [0, 2, 3]: - assert value.size(i) == cache.size( - i - ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}" + # Validate batch and head_dim dimensions match + assert value.size(0) == cache.size( + 0 + ), f"Expected value and cache to have same size in dimension 0 (batch) but got {value.size(0)} and {cache.size(0)}" + assert value.size(3) == cache.size( + 3 + ), f"Expected value and cache to have same size in dimension 3 (head_dim) but got {value.size(3)} and {cache.size(3)}" + + # Validate heads dimension matches based on layout + heads_dim = 1 if is_seq_dim_2 else 2 + assert value.size(heads_dim) == cache.size( + heads_dim + ), f"Expected value and cache to have same size in dimension {heads_dim} (heads) but got {value.size(heads_dim)} and {cache.size(heads_dim)}" torch._check_is_size(start_pos) if indices is None: - torch._check(start_pos < cache.size(1)) + torch._check(start_pos < cache.size(seq_dim)) assert start_pos < cache.size( - 1 - ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}" + seq_dim + ), f"Start position {start_pos} must be less than sequence length {cache.size(seq_dim)}" - torch._check((start_pos + seq_len) <= cache.size(1)) + torch._check((start_pos + seq_len) <= cache.size(seq_dim)) assert (start_pos + seq_len) <= cache.size( - 1 - ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}" + seq_dim + ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(seq_dim)}" if indices is not None: assert ( @@ -229,8 +244,8 @@ def _validate_update_cache_params( 0 ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}" assert indices.size(1) == value.size( - 1 - ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}" + seq_dim + ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(seq_dim)}" @impl(custom_ops_lib, "update_cache", "Meta") @@ -238,11 +253,13 @@ def update_cache_meta( value, cache, start_pos, + is_seq_dim_2=False, ): _validate_update_cache_params( value, cache, start_pos, + is_seq_dim_2, ) # Update cache doesnt really return anything but I dont know a better @@ -257,11 +274,13 @@ def update_cache_with_indices_meta( cache, start_pos, indices, + is_seq_dim_2=False, ): _validate_update_cache_params( value, cache, start_pos, + is_seq_dim_2, indices, ) diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 5bbf22d336e..7bed1e61b6b 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -122,12 +122,14 @@ Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, const int64_t start_pos, + const bool is_seq_dim_2, Tensor& output); at::Tensor update_cache_aten( const at::Tensor& value, at::Tensor& cache, - const int64_t start_pos); + const int64_t start_pos, + const bool is_seq_dim_2); // New functions for update_cache_with_indices Tensor& update_cache_with_indices_out_no_context( @@ -135,13 +137,15 @@ Tensor& update_cache_with_indices_out_no_context( Tensor& cache, const int64_t start_pos, const Tensor& indices, + const bool is_seq_dim_2, Tensor& output); at::Tensor update_cache_with_indices_aten( const at::Tensor& value, at::Tensor& cache, const int64_t start_pos, - const at::Tensor& indices); + const at::Tensor& indices, + const bool is_seq_dim_2); Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, @@ -338,19 +342,21 @@ Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, const int64_t start_pos, + const bool is_seq_dim_2, Tensor& output) { executorch::aten::RuntimeContext context{}; return torch::executor::native::update_cache_out( - context, value, cache, start_pos, output); + context, value, cache, start_pos, is_seq_dim_2, output); } at::Tensor update_cache_aten( const at::Tensor& value, at::Tensor& cache, - const int64_t start_pos) { + const int64_t start_pos, + const bool is_seq_dim_2) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_out_no_context, 3) - (value, cache, start_pos, output); + WRAP_TO_ATEN(update_cache_out_no_context, 4) + (value, cache, start_pos, is_seq_dim_2, output); return output; } @@ -360,20 +366,22 @@ Tensor& update_cache_with_indices_out_no_context( Tensor& cache, const int64_t start_pos, const Tensor& indices, + const bool is_seq_dim_2, Tensor& output) { executorch::aten::RuntimeContext context{}; return torch::executor::native::update_cache_with_indices_out( - context, value, cache, start_pos, indices, output); + context, value, cache, start_pos, indices, is_seq_dim_2, output); } at::Tensor update_cache_with_indices_aten( const at::Tensor& value, at::Tensor& cache, const int64_t start_pos, - const at::Tensor& indices) { + const at::Tensor& indices, + const bool is_seq_dim_2) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4) - (value, cache, start_pos, indices, output); + WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 5) + (value, cache, start_pos, indices, is_seq_dim_2, output); return output; } @@ -400,16 +408,16 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)"); m.def( "update_cache(Tensor value, Tensor(a!) cache, " - "SymInt start_pos) -> Tensor"); + "SymInt start_pos, bool is_seq_dim_2=False) -> Tensor"); m.def( "update_cache.out(Tensor value, Tensor(a!) cache, " - "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)"); + "SymInt start_pos, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)"); m.def( "update_cache_with_indices(Tensor value, Tensor(a!) cache, " - "SymInt start_pos, Tensor indices) -> Tensor"); + "SymInt start_pos, Tensor indices, bool is_seq_dim_2=False) -> Tensor"); m.def( "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, " - "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)"); + "SymInt start_pos, Tensor indices, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)"); m.def( "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " @@ -439,7 +447,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { m.impl("update_cache", torch::executor::native::update_cache_aten); m.impl( "update_cache.out", - WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3)); + WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4)); m.impl( "update_cache_with_indices", torch::executor::native::update_cache_with_indices_aten); @@ -447,7 +455,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { "update_cache_with_indices.out", WRAP_TO_ATEN( torch::executor::native::update_cache_with_indices_out_no_context, - 4)); + 5)); m.impl( "custom_quantized_sdpa", torch::executor::native::custom_quantized_sdpa_aten); diff --git a/extension/llm/custom_ops/op_update_cache.cpp b/extension/llm/custom_ops/op_update_cache.cpp index 7ab994deb5f..5f918bd90bb 100644 --- a/extension/llm/custom_ops/op_update_cache.cpp +++ b/extension/llm/custom_ops/op_update_cache.cpp @@ -26,6 +26,7 @@ bool validate_cache_params( const Tensor& quantized_cache, int64_t start_pos, int64_t seq_length, + bool is_seq_dim_2, const optional& indices = nullopt) { ET_CHECK_OR_RETURN_FALSE( quantized_cache.dim() == 4, "quantized cache must be a 4D tensor"); @@ -33,6 +34,9 @@ bool validate_cache_params( ET_CHECK_OR_RETURN_FALSE( quantized_value.dim() == 4, "quantized_value must be a 4D tensor"); + // Determine the sequence dimension based on is_seq_dim_2 + int64_t seq_dim = is_seq_dim_2 ? 2 : 1; + if (indices.has_value()) { const auto& indices_tensor = indices.value(); ET_CHECK_OR_RETURN_FALSE( @@ -44,7 +48,7 @@ bool validate_cache_params( "indices batch dimension must match value batch dimension"); ET_CHECK_OR_RETURN_FALSE( - indices_tensor.size(1) == quantized_value.size(1), + indices_tensor.size(1) == quantized_value.size(seq_dim), "indices sequence length dimension must match value sequence length dimension"); ET_CHECK_OR_RETURN_FALSE( @@ -57,20 +61,22 @@ bool validate_cache_params( "indices must be in contiguous dim order"); } else { ET_CHECK_OR_RETURN_FALSE( - start_pos < quantized_cache.size(1), - "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd", + start_pos < quantized_cache.size(seq_dim), + "start_pos: %" PRId64 " must be less than cache size at dim %" PRId64 + ": %zd", start_pos, - quantized_cache.size(1)); + seq_dim, + quantized_cache.size(seq_dim)); ET_CHECK_OR_RETURN_FALSE( - (start_pos + seq_length) <= quantized_cache.size(1), + (start_pos + seq_length) <= quantized_cache.size(seq_dim), "start_post + seq_length must be less than max seq length supported by cache." "start pos: %" PRId64 ", seq_length: %" PRId64 "." "cache size: %zd", start_pos, seq_length, - quantized_cache.size(1)); + quantized_cache.size(seq_dim)); } // Make sure they are in contiguous dim order @@ -93,25 +99,39 @@ Tensor& update_cache_impl( const Tensor& value, Tensor& cache, const int64_t start_pos, + bool is_seq_dim_2, Tensor& output, const optional& indices = nullopt) { (void)ctx; + // Determine dimensions based on is_seq_dim_2 + // If is_seq_dim_2 is false: [batch, seq, heads, head_dim] + // If is_seq_dim_2 is true: [batch, heads, seq, head_dim] + int64_t value_batch_size = value.size(0); + int64_t value_seq_len = is_seq_dim_2 ? value.size(2) : value.size(1); + int64_t value_num_heads = is_seq_dim_2 ? value.size(1) : value.size(2); + int64_t value_head_dim = value.size(3); + + int64_t cache_batch_size = cache.size(0); + int64_t cache_seq_len = is_seq_dim_2 ? cache.size(2) : cache.size(1); + int64_t cache_num_heads = is_seq_dim_2 ? cache.size(1) : cache.size(2); + int64_t cache_head_dim = cache.size(3); + ET_CHECK_MSG( - value.size(0) == cache.size(0), + value_batch_size == cache_batch_size, "projected_value batch size (%zd) should be equal to the cache batch size (%zd).", - value.size(0), - cache.size(0)); + value_batch_size, + cache_batch_size); ET_CHECK_MSG( - value.size(2) == cache.size(2), + value_num_heads == cache_num_heads, "projected_value number of heads (%zd) should be equal to the cache number of heads (%zd).", - value.size(2), - cache.size(2)); + value_num_heads, + cache_num_heads); ET_CHECK_MSG( - value.size(3) == cache.size(3), + value_head_dim == cache_head_dim, "projected_value embedding dimension (%zd) should be equal to the cache embedding dimension (%zd).", - value.size(3), - cache.size(3)); + value_head_dim, + cache_head_dim); ET_CHECK_MSG( value.element_size() == cache.element_size(), "projected_value data type size (%zd) should be equal to the cache data type size (%zd).", @@ -133,13 +153,17 @@ Tensor& update_cache_impl( auto cache_strides = cache.strides(); executorch::aten::StridesType cache_batch_dim_stride = cache_strides[0]; - executorch::aten::StridesType cache_seq_dim_stride = cache_strides[1]; + executorch::aten::StridesType cache_seq_dim_stride = + is_seq_dim_2 ? cache_strides[2] : cache_strides[1]; + executorch::aten::StridesType cache_head_dim_stride = + is_seq_dim_2 ? cache_strides[1] : cache_strides[2]; auto value_strides = value.strides(); executorch::aten::StridesType value_batch_dim_stride = value_strides[0]; - - executorch::aten::SizesType num_bytes_to_copy = - (value.numel() / value.size(0)) * value.element_size(); + executorch::aten::StridesType value_seq_dim_stride = + is_seq_dim_2 ? value_strides[2] : value_strides[1]; + executorch::aten::StridesType value_head_dim_stride = + is_seq_dim_2 ? value_strides[1] : value_strides[2]; if (indices.has_value()) { // Use the provided indices tensor for each batch and sequence position @@ -152,30 +176,53 @@ Tensor& update_cache_impl( // Calculate bytes to copy for a single token executorch::aten::SizesType bytes_per_token = - (value.numel() / (value.size(0) * value.size(1))) * + (value.numel() / (value_batch_size * value_seq_len)) * value.element_size(); + int64_t num_values_to_copy = value_batch_size; + executorch::aten::StridesType value_stride = value_batch_dim_stride; + executorch::aten::StridesType cache_stride = cache_batch_dim_stride; + if (is_seq_dim_2) { + /* + If is_seq_dim_2 is true, the value tensor is in the format + [batch, heads, seq, head_dim]. We assume we collapse this in + [batch * heads, seq, head_dim]. Thus the stride on the first dim + is actually stride of the heads dim + Then for each value in [batch * heads], we copy value tensor at the index + corresponding to the indices tensor. + Number of bytes to copy is seqlen (value.size(seq_dim)) * head_dim * element_size + */ + num_values_to_copy = value_batch_size * value_num_heads; + value_stride = value_head_dim_stride; + cache_stride = cache_head_dim_stride; + bytes_per_token = value.size(3) * value.element_size(); + } - for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { - for (int64_t seq_idx = 0; seq_idx < value.size(1); ++seq_idx) { + for (int64_t value_idx = 0; value_idx < num_values_to_copy; ++value_idx) { + for (int64_t seq_idx = 0; seq_idx < value_seq_len; ++seq_idx) { + int64_t batch_index = value_idx; + if (is_seq_dim_2) { + batch_index = value_idx / value_num_heads; + } // Get the target position from the indices tensor int64_t target_pos = indices_data - [batch_line * indices_batch_stride + seq_idx * indices_seq_stride]; + [batch_index * indices_batch_stride + seq_idx * indices_seq_stride]; // Ensure the target position is valid ET_CHECK_MSG( - target_pos >= 0 && target_pos < cache.size(1), + target_pos >= 0 && target_pos < cache_seq_len, "Index out of bounds: %" PRId64 " not in [0, %zd)", target_pos, - cache.size(1)); + cache_seq_len); // Calculate offsets for cache and value executorch::aten::SizesType cache_pos_offset = - (batch_line * cache_batch_dim_stride + + (value_idx * cache_stride + target_pos * cache_seq_dim_stride) * cache.element_size(); executorch::aten::SizesType value_pos_offset = - (batch_line * value_batch_dim_stride + seq_idx * value_strides[1]) * + (value_idx * value_stride + + seq_idx * value_seq_dim_stride) * value.element_size(); // Copy a single token @@ -187,13 +234,34 @@ Tensor& update_cache_impl( } } else { // Use the original implementation with start_pos - for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { + int64_t num_values_to_copy = value_batch_size; + executorch::aten::SizesType num_bytes_to_copy = + (value.numel() / value_batch_size) * value.element_size(); + executorch::aten::StridesType value_stride = value_batch_dim_stride; + executorch::aten::StridesType cache_stride = cache_batch_dim_stride; + if (is_seq_dim_2) { + /* + If is_seq_dim_2 is true, the value tensor is in the format + [batch, heads, seq, head_dim]. We assume we collapse this in + [batch * heads, seq, head_dim]. Thus the stride on the first dim + is actually stride of the heads dim + Then for each value in [batch * heads], we copy value tensor at that index + in the cache, starting at the start_pos. + Number of bytes to copy is seqlen (value.size(seq_dim)) * head_dim * element_size + */ + num_values_to_copy = value_batch_size * value_num_heads; + num_bytes_to_copy = (value.numel() / (value_batch_size * value_num_heads)) * value.element_size(); + value_stride = value_head_dim_stride; + cache_stride = cache_head_dim_stride; + } + + for (int64_t value_idx = 0; value_idx < num_values_to_copy; ++value_idx) { executorch::aten::SizesType cache_pos_offset = - (batch_line * cache_batch_dim_stride + + (value_idx * cache_stride + start_pos * cache_seq_dim_stride) * cache.element_size(); executorch::aten::SizesType value_pos_offset = - (batch_line * value_batch_dim_stride) * cache.element_size(); + (value_idx * value_stride) * cache.element_size(); std::memcpy( (uint8_t*)cache_data + cache_pos_offset, @@ -213,15 +281,16 @@ Tensor& update_cache_out( const Tensor& value, Tensor& cache, const int64_t start_pos, + bool is_seq_dim_2, Tensor& output) { - int64_t seq_len = value.size(1); + int64_t seq_len = is_seq_dim_2 ? value.size(2) : value.size(1); ET_KERNEL_CHECK( ctx, - validate_cache_params(value, cache, start_pos, seq_len), + validate_cache_params(value, cache, start_pos, seq_len, is_seq_dim_2), InvalidArgument, output); - return update_cache_impl(ctx, value, cache, start_pos, output); + return update_cache_impl(ctx, value, cache, start_pos, is_seq_dim_2, output); } // New function that explicitly takes indices @@ -231,15 +300,18 @@ Tensor& update_cache_with_indices_out( Tensor& cache, const int64_t start_pos, const Tensor& indices, + bool is_seq_dim_2, Tensor& output) { - int64_t seq_len = value.size(1); + int64_t seq_len = is_seq_dim_2 ? value.size(2) : value.size(1); ET_KERNEL_CHECK( ctx, - validate_cache_params(value, cache, start_pos, seq_len, indices), + validate_cache_params( + value, cache, start_pos, seq_len, is_seq_dim_2, indices), InvalidArgument, output); - return update_cache_impl(ctx, value, cache, start_pos, output, indices); + return update_cache_impl( + ctx, value, cache, start_pos, is_seq_dim_2, output, indices); } } // namespace native diff --git a/extension/llm/custom_ops/op_update_cache.h b/extension/llm/custom_ops/op_update_cache.h index 84c73039469..76eb48e05f0 100644 --- a/extension/llm/custom_ops/op_update_cache.h +++ b/extension/llm/custom_ops/op_update_cache.h @@ -16,20 +16,26 @@ namespace executor { namespace native { // Original update_cache_out function without indices parameter +// is_seq_dim_2: when false, expects [batch, seq, heads, head_dim] layout +// when true, expects [batch, heads, seq, head_dim] layout Tensor& update_cache_out( RuntimeContext& ctx, const Tensor& value, Tensor& cache, const int64_t start_pos, + bool is_seq_dim_2, Tensor& output); // New function that explicitly takes indices +// is_seq_dim_2: when false, expects [batch, seq, heads, head_dim] layout +// when true, expects [batch, heads, seq, head_dim] layout Tensor& update_cache_with_indices_out( RuntimeContext& ctx, const Tensor& value, Tensor& cache, const int64_t start_pos, const Tensor& indices, + bool is_seq_dim_2, Tensor& output); } // namespace native } // namespace executor diff --git a/extension/llm/custom_ops/test_update_cache.py b/extension/llm/custom_ops/test_update_cache.py index 84a349c97f0..cad19b1bd75 100644 --- a/extension/llm/custom_ops/test_update_cache.py +++ b/extension/llm/custom_ops/test_update_cache.py @@ -86,13 +86,13 @@ def _update_and_validate( self._update_k(start_pos, k, k_scales, k_zero_points) self._update_v(start_pos, v, v_scales, v_zero_points) - torch.ops.llama.update_cache(k, k_cache, start_pos) - torch.ops.llama.update_cache(k_scales, k_scales_cache, start_pos) - torch.ops.llama.update_cache(k_zero_points, k_zero_points_cache, start_pos) + torch.ops.llama.update_cache(k, k_cache, start_pos, False) + torch.ops.llama.update_cache(k_scales, k_scales_cache, start_pos, False) + torch.ops.llama.update_cache(k_zero_points, k_zero_points_cache, start_pos, False) - torch.ops.llama.update_cache(v, v_cache, start_pos) - torch.ops.llama.update_cache(v_scales, v_scales_cache, start_pos) - torch.ops.llama.update_cache(v_zero_points, v_zero_points_cache, start_pos) + torch.ops.llama.update_cache(v, v_cache, start_pos, False) + torch.ops.llama.update_cache(v_scales, v_scales_cache, start_pos, False) + torch.ops.llama.update_cache(v_zero_points, v_zero_points_cache, start_pos, False) self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache)) self.assertTrue(torch.allclose(v_cache, self.quantized_v_cache)) @@ -120,12 +120,12 @@ def _update_with_indices_and_validate( ] # Update using custom op - torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices) + torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices, False) torch.ops.llama.update_cache_with_indices( - k_scales, k_scales_cache, start_pos, indices + k_scales, k_scales_cache, start_pos, indices, False ) torch.ops.llama.update_cache_with_indices( - k_zero_points, k_zero_points_cache, start_pos, indices + k_zero_points, k_zero_points_cache, start_pos, indices, False ) # Validate results @@ -218,7 +218,7 @@ def test_indices_exceeding_cache_size(self): @run_in_subprocess def run_and_catch(k, k_cache, start_pos, indices): - torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices, False) exception_raised = False try: @@ -238,7 +238,7 @@ def test_negative_indices(self): @run_in_subprocess def run_and_catch(k, k_cache, start_pos, indices): - torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices, False) exception_raised = False try: @@ -270,19 +270,19 @@ def test_duplicate_indices(self): v_zero_points_cache = self.v_zero_points_cache.clone() # Update using custom op - torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices) + torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices, False) torch.ops.llama.update_cache_with_indices( - k_scales, k_scales_cache, start_pos, indices + k_scales, k_scales_cache, start_pos, indices, False ) torch.ops.llama.update_cache_with_indices( - k_zero_points, k_zero_points_cache, start_pos, indices + k_zero_points, k_zero_points_cache, start_pos, indices, False ) - torch.ops.llama.update_cache_with_indices(v, v_cache, start_pos, indices) + torch.ops.llama.update_cache_with_indices(v, v_cache, start_pos, indices, False) torch.ops.llama.update_cache_with_indices( - v_scales, v_scales_cache, start_pos, indices + v_scales, v_scales_cache, start_pos, indices, False ) torch.ops.llama.update_cache_with_indices( - v_zero_points, v_zero_points_cache, start_pos, indices + v_zero_points, v_zero_points_cache, start_pos, indices, False ) # Position 3 should have the value from the last update (index 2 in the sequence) @@ -338,7 +338,7 @@ def test_different_seq_lengths_per_batch(self): @run_in_subprocess def run_and_catch(k, k_cache, start_pos, indices): - torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices, False) exception_raised = False try: @@ -431,3 +431,75 @@ def test_batched_update_kv_cache_more_updates(self): self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) + + def test_update_cache_with_seq_dim_2(self): + """Test update_cache with is_seq_dim_2=True (layout: [batch, heads, seq, head_dim]).""" + # Reset and prepare caches in the new layout + batch_size = 1 + seq_len = 10 + num_heads = 8 + head_dim = 4 + + # Cache with layout [batch, heads, seq, head_dim] + k_cache = torch.zeros( + (batch_size, num_heads, seq_len, head_dim), + dtype=torch.int8, + ) + k_scales_cache = torch.zeros( + (batch_size, num_heads, seq_len, 1), dtype=torch.float64 + ) + k_zero_points_cache = torch.zeros( + (batch_size, num_heads, seq_len, 1), dtype=torch.int64 + ) + + # Value with layout [batch, heads, seq=1, head_dim] + k = torch.randint(0, 50, (batch_size, num_heads, 1, head_dim), dtype=torch.int8) + k_scales = torch.rand((batch_size, num_heads, 1, 1), dtype=torch.float64) + k_zero_points = torch.randint(0, 20, (batch_size, num_heads, 1, 1), dtype=torch.int64) + + start_pos = 3 + + # Update using custom op with is_seq_dim_2=True + torch.ops.llama.update_cache(k, k_cache, start_pos, True) + torch.ops.llama.update_cache(k_scales, k_scales_cache, start_pos, True) + torch.ops.llama.update_cache(k_zero_points, k_zero_points_cache, start_pos, True) + + # Verify the update happened at the correct position + # The sequence dimension is at index 2 when is_seq_dim_2=True + self.assertTrue(torch.allclose(k_cache[:, :, start_pos:start_pos+1, :], k)) + self.assertTrue(torch.allclose(k_scales_cache[:, :, start_pos:start_pos+1, :], k_scales)) + self.assertTrue(torch.allclose(k_zero_points_cache[:, :, start_pos:start_pos+1, :], k_zero_points)) + + def test_update_cache_with_indices_seq_dim_2(self): + """Test update_cache_with_indices with is_seq_dim_2=True.""" + batch_size = 1 + seq_len = 10 + num_heads = 8 + head_dim = 4 + + # Cache with layout [batch, heads, seq, head_dim] + k_cache = torch.zeros( + (batch_size, num_heads, seq_len, head_dim), + dtype=torch.int8, + ) + k_scales_cache = torch.zeros( + (batch_size, num_heads, seq_len, 1), dtype=torch.float64 + ) + + # Value with layout [batch, heads, seq=3, head_dim] + k = torch.randint(0, 50, (batch_size, num_heads, 3, head_dim), dtype=torch.int8) + k_scales = torch.rand((batch_size, num_heads, 3, 1), dtype=torch.float64) + + # Update positions 2, 5, 7 + indices = torch.tensor([[2, 5, 7]], dtype=torch.int64) + start_pos = 0 + + # Update using custom op with is_seq_dim_2=True + torch.ops.llama.update_cache_with_indices(k, k_cache, start_pos, indices, True) + torch.ops.llama.update_cache_with_indices(k_scales, k_scales_cache, start_pos, indices, True) + + # Verify the updates happened at the correct positions + for seq_idx in range(3): + target_pos = indices[0, seq_idx].item() + self.assertTrue(torch.allclose(k_cache[:, :, target_pos:target_pos+1, :], k[:, :, seq_idx:seq_idx+1, :])) + self.assertTrue(torch.allclose(k_scales_cache[:, :, target_pos:target_pos+1, :], k_scales[:, :, seq_idx:seq_idx+1, :]))