From ffb29f710edda400b8fa63ce202eb347067a37c2 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 6 Apr 2026 08:47:18 -0700 Subject: [PATCH] Add support for seq at dim=2 This brings custom sdpa and quantized sdpa to feature parity and allows for a transposed kv cache Differential Revision: [D93870394](https://our.internmc.facebook.com/intern/diff/D93870394/) [ghstack-poisoned] --- .../llama/source_transformation/sdpa.py | 4 + extension/llm/custom_ops/custom_ops.py | 3 +- extension/llm/custom_ops/op_sdpa.cpp | 25 +- extension/llm/custom_ops/op_sdpa.h | 1 + extension/llm/custom_ops/op_sdpa_aot.cpp | 28 +- .../llm/custom_ops/test_sdpa_with_kv_cache.py | 346 ++++++++++++++++++ 6 files changed, 394 insertions(+), 13 deletions(-) diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index c54e689ba8d..2e108b2ec19 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -69,6 +69,8 @@ def forward( 0, # dropout probability.
Ignored by the code True, # is_causal ) + if self.is_seq_at_dim_2: + output = output.transpose(1, 2).contiguous() return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype) @@ -198,6 +200,8 @@ def forward( v_scale_fp32, ) + if self.is_seq_at_dim_2: + output = output.transpose(1, 2).contiguous() return output.view(bsz, seqlen, self.dim) diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index a56e3de5782..366061d4b7c 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -167,8 +167,9 @@ def custom_sdpa( drpout_p=0.0, is_causal=False, scale=None, + is_seq_dim_2=False, ): - seq_len = query.size(1) + seq_len = query.size(2) if is_seq_dim_2 else query.size(1) _validate_params( query, key_cache, diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 0906f039df4..955f42fe711 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -360,16 +360,13 @@ Tensor& custom_sdpa_out_impl( output, "Invalid arguments"); - int64_t seq_len = q.size(1); SeqDim seq_dim{SeqDim::TWO}; if (!is_seq_at_dim_2) { seq_dim = SeqDim::ONE; } + int64_t seq_len = q.size(static_cast<int64_t>(seq_dim)); if (q.scalar_type() == ScalarType::Char) { - if (seq_dim == SeqDim::TWO) { - seq_len = q.size(2); - } ET_KERNEL_CHECK_MSG( ctx, q_scales.has_value() && q_zero_points.has_value() && @@ -564,9 +561,26 @@ Tensor& custom_sdpa_out( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional<double> scale, + const bool is_seq_dim_2, Tensor& output) { return custom_sdpa_out_impl( - ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); + ctx, + q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + output, + nullopt, + nullopt, + nullopt, + nullopt, + nullopt, + nullopt, + is_seq_dim_2); } /* Input params @@ -621,6 +635,7 @@ Tensor& sdpa_with_kv_cache_out( dropout_p, is_causal,
scale, + false, // is_seq_dim_2 - default to false for backward compatibility output); return output; diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index 9d357eb6ea1..9b065201f30 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -42,6 +42,7 @@ Tensor& custom_sdpa_out( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional<double> scale, + const bool is_seq_dim_2, Tensor& output); Tensor& flash_attention_kernel_out( diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 7bed1e61b6b..e50b3707d51 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -62,6 +62,7 @@ Tensor& custom_sdpa_out_no_context( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional<double> scale, + const bool is_seq_dim_2, Tensor& output); at::Tensor custom_sdpa_aten( @@ -75,7 +76,8 @@ at::Tensor custom_sdpa_aten( const double dropout_p, const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy - const std::optional<double> scale); + const std::optional<double> scale, + const bool is_seq_dim_2); Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, @@ -224,6 +226,7 @@ Tensor& custom_sdpa_out_no_context( const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const optional<double> scale, + const bool is_seq_dim_2, Tensor& output) { executorch::aten::RuntimeContext context{}; return torch::executor::native::custom_sdpa_out( @@ -236,6 +239,7 @@ Tensor& custom_sdpa_out_no_context( dropout_p, is_causal, scale, + is_seq_dim_2, output); } @@ -250,10 +254,20 @@ at::Tensor custom_sdpa_aten( const double dropout_p, const bool is_causal, // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy - const std::optional<double> scale) { + const std::optional<double> scale, + const bool is_seq_dim_2) { auto output =
at::empty(q.sizes()); - WRAP_TO_ATEN(custom_sdpa_out_no_context, 8) - (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); + WRAP_TO_ATEN(custom_sdpa_out_no_context, 9) + (q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + is_seq_dim_2, + output); return output; } @@ -401,11 +415,11 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { m.def( "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " - "float? scale=None) -> Tensor"); + "float? scale=None, bool is_seq_dim_2=False) -> Tensor"); m.def( "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " - "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)"); + "float? scale=None, bool is_seq_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)"); m.def( "update_cache(Tensor value, Tensor(a!) cache, " "SymInt start_pos, bool is_seq_dim_2=False) -> Tensor"); @@ -443,7 +457,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten); m.impl( "custom_sdpa.out", - WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8)); + WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 9)); m.impl("update_cache", torch::executor::native::update_cache_aten); m.impl( "update_cache.out", diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py index d044a4789ff..1ad2046e8a9 100644 --- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py +++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py @@ -45,6 +45,76 @@ def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq return out +def _custom_sdpa_ref( + q, k_cache, v_cache, attn_mask, start_pos, is_seq_dim_2=False +): + """ + Reference implementation for custom_sdpa operation. 
+ + Args: + q: Query tensor [batch, seq, heads, head_dim] if is_seq_dim_2=False + or [batch, heads, seq, head_dim] if is_seq_dim_2=True + k_cache: Key cache [batch, max_seq, heads, head_dim] if is_seq_dim_2=False + or [batch, heads, max_seq, head_dim] if is_seq_dim_2=True + v_cache: Value cache [batch, max_seq, heads, head_dim] if is_seq_dim_2=False + or [batch, heads, max_seq, head_dim] if is_seq_dim_2=True + attn_mask: Optional attention mask [seq_len, max_seq_len] + start_pos: Starting position in cache + is_seq_dim_2: If True, sequence dimension is at position 2, else at position 1 + """ + if is_seq_dim_2: + # Input: [batch, heads, seq, head_dim] + # Transpose to [batch, seq, heads, head_dim] for processing + q_transposed = q.transpose(1, 2) + k_cache_transposed = k_cache.transpose(1, 2) + v_cache_transposed = v_cache.transpose(1, 2) + else: + # Input: [batch, seq, heads, head_dim] + q_transposed = q + k_cache_transposed = k_cache + v_cache_transposed = v_cache + + # Now all tensors are in [batch, seq, heads, head_dim] format + seq_len = q_transposed.size(1) + + # Transpose for SDPA: [batch, heads, seq, head_dim] + q_sdpa = q_transposed.transpose(1, 2) + + # Slice cache up to start_pos + seq_len + sliced_k_cache = k_cache_transposed[:, : start_pos + seq_len, :, :] + sliced_v_cache = v_cache_transposed[:, : start_pos + seq_len, :, :] + + # Transpose cache: [batch, heads, cached_seq, head_dim] + sliced_k_cache = sliced_k_cache.transpose(1, 2) + sliced_v_cache = sliced_v_cache.transpose(1, 2) + + # Handle MQA/GQA - repeat key/value heads if necessary + num_heads_q = q_sdpa.size(1) + num_heads_kv = sliced_k_cache.size(1) + if num_heads_q != num_heads_kv: + assert ( + num_heads_q % num_heads_kv == 0 + ), f"{num_heads_q} not divisible by {num_heads_kv}" + n_reps = num_heads_q // num_heads_kv + if n_reps > 1: + sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1) + sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1) + + # Run scaled dot 
product attention + out = F.scaled_dot_product_attention( + q_sdpa, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask + ) + + # Transpose back: [batch, seq, heads, head_dim] + out = out.transpose(1, 2) + + if is_seq_dim_2: + # Transpose back to [batch, heads, seq, head_dim] + out = out.transpose(1, 2) + + return out + + class SDPATest(unittest.TestCase): def setUp(self): torch.manual_seed(42) @@ -645,3 +715,279 @@ def test_sdpa_to_repro_long_seq_failure(self): self._test_sdpa_common( n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len ) + + +class CustomSDPASeqDim2Test(unittest.TestCase): + """Tests for custom_sdpa with is_seq_dim_2=True parameter.""" + + def setUp(self): + torch.manual_seed(42) + self.batch_size = 1 + self.num_heads = 8 + self.head_dim = 4 + self.max_seq_len = 10 + self.mask = torch.full( + (self.max_seq_len, self.max_seq_len), + float("-inf"), + ) + self.mask = torch.triu(self.mask, diagonal=1) + + def test_custom_sdpa_seq_dim_2_decode_phase(self): + """Test custom_sdpa with is_seq_dim_2=True during decode phase (seq_len=1).""" + # Setup: Create tensors with layout [batch, heads, seq, head_dim] + seq_len = 1 + q = torch.rand((self.batch_size, self.num_heads, seq_len, self.head_dim)) + k_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + v_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + start_pos = 5 + + # Execute: Run custom_sdpa with is_seq_dim_2=True and reference implementation + op_output = torch.ops.llama.custom_sdpa( + q, + k_cache, + v_cache, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + None, # scale + True, # is_seq_dim_2 + ) + + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + ref_output = _custom_sdpa_ref( + q, k_cache, v_cache, attn_mask, start_pos, is_seq_dim_2=True + ) + + # Assert: Verify outputs match + 
self.assertTrue(torch.allclose(op_output, ref_output, atol=1e-5)) + + def test_custom_sdpa_seq_dim_2_prefill_phase(self): + """Test custom_sdpa with is_seq_dim_2=True during prefill phase (seq_len>1).""" + # Setup: Create tensors with layout [batch, heads, seq, head_dim] for prefill + seq_len = 4 + q = torch.rand((self.batch_size, self.num_heads, seq_len, self.head_dim)) + k_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + v_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + start_pos = 0 + + # Execute: Run custom_sdpa with is_seq_dim_2=True and reference implementation + op_output = torch.ops.llama.custom_sdpa( + q, + k_cache, + v_cache, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + None, # scale + True, # is_seq_dim_2 + ) + + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + ref_output = _custom_sdpa_ref( + q, k_cache, v_cache, attn_mask, start_pos, is_seq_dim_2=True + ) + + # Assert: Verify outputs match + self.assertTrue(torch.allclose(op_output, ref_output, atol=1e-5)) + + def test_custom_sdpa_seq_dim_2_vs_seq_dim_1(self): + """Test that results are consistent between seq_dim_1 and seq_dim_2 layouts.""" + # Setup: Create tensors in both layouts with same data + seq_len = 2 + + # Layout 1: [batch, seq, heads, head_dim] + q_dim1 = torch.rand((self.batch_size, seq_len, self.num_heads, self.head_dim)) + k_cache_dim1 = torch.rand( + (self.batch_size, self.max_seq_len, self.num_heads, self.head_dim) + ) + v_cache_dim1 = torch.rand( + (self.batch_size, self.max_seq_len, self.num_heads, self.head_dim) + ) + + # Layout 2: [batch, heads, seq, head_dim] - transpose from layout 1 + q_dim2 = q_dim1.transpose(1, 2).contiguous() + k_cache_dim2 = k_cache_dim1.transpose(1, 2).contiguous() + v_cache_dim2 = v_cache_dim1.transpose(1, 2).contiguous() + + start_pos = 3 + + # Execute: Run custom_sdpa with 
both layouts + output_dim1 = torch.ops.llama.custom_sdpa( + q_dim1, + k_cache_dim1, + v_cache_dim1, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + None, # scale + False, # is_seq_dim_2 + ) + + output_dim2 = torch.ops.llama.custom_sdpa( + q_dim2, + k_cache_dim2, + v_cache_dim2, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + None, # scale + True, # is_seq_dim_2 + ) + + # Assert: Results should be equivalent when transposed back + output_dim1_transposed = output_dim1.transpose(1, 2) + self.assertTrue(torch.allclose(output_dim2, output_dim1_transposed, atol=1e-5)) + + def test_custom_sdpa_seq_dim_2_with_attention_mask(self): + """Test custom_sdpa with is_seq_dim_2=True and attention mask.""" + # Setup: Create tensors with layout [batch, heads, seq, head_dim] + seq_len = 3 + q = torch.rand((self.batch_size, self.num_heads, seq_len, self.head_dim)) + k_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + v_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + + start_pos = 2 + # Create attention mask + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask.contiguous() + + # Execute: Run custom_sdpa with attention mask and reference implementation + op_output = torch.ops.llama.custom_sdpa( + q, + k_cache, + v_cache, + start_pos, + attn_mask, + 0.0, # dropout_p + False, # is_causal (using mask instead) + None, # scale + True, # is_seq_dim_2 + ) + + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + ref_output = _custom_sdpa_ref( + q, k_cache, v_cache, attn_mask, start_pos, is_seq_dim_2=True + ) + + # Assert: Verify outputs match + self.assertTrue(torch.allclose(op_output, ref_output, atol=1e-5)) + + def test_custom_sdpa_seq_dim_2_with_scale(self): + """Test custom_sdpa with is_seq_dim_2=True and custom scale.""" + # Setup: Create tensors with layout 
[batch, heads, seq, head_dim] + q = torch.rand((self.batch_size, self.num_heads, 1, self.head_dim)) + k_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + v_cache = torch.rand( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + start_pos = 4 + custom_scale = 0.5 + + # Execute: Run custom_sdpa with custom scale + # Note: Reference implementation doesn't support scale parameter easily, + # so we just verify the operation runs without error + op_output = torch.ops.llama.custom_sdpa( + q, + k_cache, + v_cache, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + custom_scale, # scale + True, # is_seq_dim_2 + ) + + # Assert: Verify output shape is correct + self.assertEqual(op_output.shape, q.shape) + self.assertEqual(op_output.dtype, torch.float32) + + def test_custom_sdpa_seq_dim_2_backward_compatible(self): + """Test that is_seq_dim_2 defaults to False for backward compatibility.""" + # Setup: Create tensors with default layout [batch, seq, heads, head_dim] + q = torch.rand((self.batch_size, 1, self.num_heads, self.head_dim)) + k_cache = torch.rand( + (self.batch_size, self.max_seq_len, self.num_heads, self.head_dim) + ) + v_cache = torch.rand( + (self.batch_size, self.max_seq_len, self.num_heads, self.head_dim) + ) + start_pos = 3 + + # Execute: Call without is_seq_dim_2 parameter (should default to False) + op_output = torch.ops.llama.custom_sdpa( + q, + k_cache, + v_cache, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + None, # scale + ) + + ref_output = _custom_sdpa_ref( + q, k_cache, v_cache, None, start_pos, is_seq_dim_2=False + ) + + # Assert: Should work with default layout and match reference + self.assertTrue(torch.allclose(op_output, ref_output, atol=1e-5)) + + def test_custom_sdpa_seq_dim_2_with_gqa(self): + """Test custom_sdpa with is_seq_dim_2=True and GQA (grouped query attention).""" + # Setup: Create tensors with GQA (different number of 
heads for Q vs K/V) + num_heads_q = 8 + num_heads_kv = 4 + seq_len = 2 + + q = torch.rand((self.batch_size, num_heads_q, seq_len, self.head_dim)) + k_cache = torch.rand( + (self.batch_size, num_heads_kv, self.max_seq_len, self.head_dim) + ) + v_cache = torch.rand( + (self.batch_size, num_heads_kv, self.max_seq_len, self.head_dim) + ) + start_pos = 3 + + # Execute: Run custom_sdpa with GQA setup + op_output = torch.ops.llama.custom_sdpa( + q, + k_cache, + v_cache, + start_pos, + None, # attn_mask + 0.0, # dropout_p + True, # is_causal + None, # scale + True, # is_seq_dim_2 + ) + + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + ref_output = _custom_sdpa_ref( + q, k_cache, v_cache, attn_mask, start_pos, is_seq_dim_2=True + ) + + # Assert: Verify outputs match for GQA case + self.assertTrue(torch.allclose(op_output, ref_output, atol=1e-5))