@@ -632,10 +632,22 @@ mha_fwd_kvcache(
         TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
     }
 
-    // Write new K/V to cache in-place
-    // Non-paged without padding: fused in kernel (knew/vnew passed to dispatch)
-    // Paged or needs-padding: API-layer scatter (kernel fusion not applicable)
-    bool fuse_knew = k_.has_value() && seqlen_new > 0 && !paged_KV && !needs_padding;
+    // Write new K/V to cache.
+    //
+    // Strategy:
+    // - Always prefer the kernel-fused scatter (knew/vnew are passed to the
+    //   kernel, which writes them in-place during the prologue). This avoids
+    //   any host sync and works for both contiguous and paged caches.
+    // - Fall back to the API-layer scatter only when fusion is impossible:
+    //   * needs_padding: the cache pad is a separate buffer, so the in-kernel
+    //     writer would write to the padded copy, not the user tensor; do the
+    //     scatter on the user tensor and re-pad.
+    //   * rotary_cos: the rotary embedding was applied on the padded buffer,
+    //     so the padding must be sliced off before scattering to the user
+    //     cache. (The kernel-fused scatter would copy the padded buffer
+    //     instead, which is wrong.)
+    bool fuse_knew = k_.has_value() && seqlen_new > 0
+        && !needs_padding && !rotary_cos_.has_value();
     if (k_.has_value() && seqlen_new > 0 && !fuse_knew) {
         auto seqlens_cpu = seqlens_k.to(torch::kCPU);
         auto seqlens_accessor = seqlens_cpu.accessor<int32_t, 1>();
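
The hunk above ends inside the fallback branch. As a rough illustration of what the API-layer scatter amounts to for a contiguous cache, here is a hedged sketch (the helper name, shapes, and loop body are assumptions, not the elided code, which additionally handles paged caches and re-padding):

// Hypothetical standalone helper, not part of the patch. Assumed shapes:
// caches are (batch, seqlen_max, num_heads_k, head_dim); the new K/V are
// (batch, seqlen_new, num_heads_k, head_dim); seqlens_k is int32 (batch,).
#include <torch/torch.h>

static void append_kv_host_scatter(at::Tensor& kcache, at::Tensor& vcache,
                                   const at::Tensor& k_new, const at::Tensor& v_new,
                                   const at::Tensor& seqlens_k) {
    using torch::indexing::Slice;
    auto lens_cpu = seqlens_k.to(torch::kCPU);   // host sync: the reason the fused path is preferred
    auto lens = lens_cpu.accessor<int32_t, 1>();
    const int64_t seqlen_new = k_new.size(1);
    for (int64_t b = 0; b < k_new.size(0); ++b) {
        const int64_t start = lens[b];           // current cached length of sequence b
        // Copy the new rows into positions [start, start + seqlen_new) of batch b's cache.
        kcache.index({b, Slice(start, start + seqlen_new)})
              .copy_(k_new.index({b, Slice(0, seqlen_new)}));
        vcache.index({b, Slice(start, start + seqlen_new)})
              .copy_(v_new.index({b, Slice(0, seqlen_new)}));
    }
}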
@@ -683,28 +695,8 @@ mha_fwd_kvcache(
         seqlens_k = seqlens_k + seqlen_new;
     }
 
-    // For paged KV, gather to contiguous format
-    if (paged_KV) {
-        int num_pages_needed = (seqlen_k + page_block_size - 1) / page_block_size;
-        auto block_indices = block_table.index({
-            torch::indexing::Slice(),
-            torch::indexing::Slice(0, num_pages_needed)
-        }).flatten();
-        auto k_gathered = kcache_padded.index_select(0, block_indices.to(torch::kLong));
-        auto v_gathered = vcache_padded.index_select(0, block_indices.to(torch::kLong));
-        k_gathered = k_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded});
-        v_gathered = v_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded});
-        k_gathered = k_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded});
-        v_gathered = v_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded});
-        kcache_padded = k_gathered.index({
-            torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k)
-        }).contiguous();
-        vcache_padded = v_gathered.index({
-            torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k)
-        }).contiguous();
-    }
-
-    // Dispatch to kernel
+    // Dispatch to kernel. Paged caches are now passed natively (block_table
+    // is routed straight through to the kernel; no host-side gather).
     auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue();
     const bool is_local = (window_size_left >= 0);
 
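For context on what replaces the gather deleted above: with block_table handed to the kernel, each K/V row's physical location is resolved in-kernel rather than materialized on the host. A minimal sketch of that logical-to-physical lookup (the helper and its signature are illustrative assumptions, not code from this patch):

#include <cstdint>

// block_table_row: one batch element's row of block_table, i.e. the physical
// page id of each logical page (int32). page_block_size: tokens per page.
struct PagedSlot {
    int32_t page;      // index into the paged cache's first (page) dimension
    int32_t in_page;   // token offset within that page
};

inline PagedSlot paged_lookup(const int32_t* block_table_row,
                              int32_t token_idx,
                              int32_t page_block_size) {
    return { block_table_row[token_idx / page_block_size],
             token_idx % page_block_size };
}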
@@ -718,19 +710,25 @@ mha_fwd_kvcache(
         leftpad_k_opt = leftpad_k;
     }
 
-    // For non-paged path with new KV, pass knew/vnew for fused scatter in kernel
+    // For paths where new KV is appended in-kernel, pass knew/vnew through.
     std::optional<at::Tensor> knew_opt, vnew_opt;
     if (fuse_knew) {
         knew_opt = k_padded;
         vnew_opt = v_padded;
     }
 
+    std::optional<at::Tensor> block_table_opt;
+    if (paged_KV) {
+        block_table_opt = block_table;
+    }
+
     cutlass_fmha_fwd_kvcache_impl(
         queue,
         q_padded, kcache_padded, vcache_padded,
         out, softmax_lse,
         seqlens_k, cache_batch_idx_opt, leftpad_k_opt,
         knew_opt, vnew_opt,
+        block_table_opt, seqlen_k,
         softmax_scale, window_size_left, window_size_right,
         is_causal, is_local);
 
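A caller-side sketch of the paged-cache layout this block_table path expects (shapes and the helper are assumptions for illustration, not the extension's documented contract; real tensors would live on the XPU device, CPU is used here only to keep the sketch device-neutral):

#include <torch/torch.h>

struct PagedCache {
    at::Tensor kcache, vcache;   // (num_pages_total, page_block_size, num_heads_k, head_dim)
    at::Tensor block_table;      // int32 (batch, max_pages_per_seq), physical page ids
    at::Tensor seqlens_k;        // int32 (batch,), current cached length per sequence
};

PagedCache make_paged_cache(int64_t batch, int64_t max_pages_per_seq,
                            int64_t page_block_size, int64_t num_heads_k,
                            int64_t head_dim) {
    const int64_t num_pages_total = batch * max_pages_per_seq;
    PagedCache c;
    c.kcache = torch::zeros({num_pages_total, page_block_size, num_heads_k, head_dim},
                            torch::dtype(torch::kFloat16));
    c.vcache = torch::zeros_like(c.kcache);
    // Trivial identity mapping: sequence b owns pages [b*max_pages_per_seq, (b+1)*max_pages_per_seq).
    c.block_table = torch::arange(num_pages_total, torch::kInt32)
                        .view({batch, max_pages_per_seq});
    c.seqlens_k = torch::zeros({batch}, torch::kInt32);   // caches start empty
    return c;
}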