diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md index 58c2874b3045..41f1226e8e88 100644 --- a/docs/source/models/visual-generation.md +++ b/docs/source/models/visual-generation.md @@ -28,6 +28,7 @@ TensorRT-LLM **VisualGen** provides a unified inference stack for diffusion mode | `black-forest-labs/FLUX.2-dev` | Text-to-Image | | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | Text-to-Video | | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | Text-to-Video | +| `FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers` | Text-to-Video (VSA) | | `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` | Image-to-Video | | `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` | Image-to-Video | | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | Text-to-Video | @@ -43,19 +44,22 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models ### Feature Matrix -| Model | FP8 blockwise | NVFP4 | TeaCache | Cache-DiT | CFG Parallelism | Ulysses Parallelism | Parallel VAE | CUDA Graph | torch.compile | trtllm-serve | Attention2D | Ring Attention | Tensor Parallelism | -|---|---|---|---|---|---|---|---|---|---|---|--|--|--| -| **FLUX.1** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | -| **FLUX.2** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | -| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **LTX-2** | Yes | Yes | No | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | -| **Qwen-Image** [^2] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | -| **Cosmos3** | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | +| Model | FP8 blockwise | NVFP4 | TeaCache | Cache-DiT | CFG Parallelism | Ulysses Parallelism | Parallel VAE | CUDA Graph | torch.compile | trtllm-serve | Attention2D | Ring Attention | Tensor Parallelism | VSA | +|---|---|---|---|---|---|---|---|---|---|---|--|--|--|--| +| **FLUX.1** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **FLUX.2** | Yes | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **Wan 2.1 VSA** [^3] | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | +| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **LTX-2** | Yes | Yes | No | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | No | +| **Qwen-Image** [^2] | Yes | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | +| **Cosmos3** | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | No | [^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable. -[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. +[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. + +[^3]: `FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers` — VSA-fine-tuned checkpoint with learned sparse-attention gates. Requires `CUTEDSL` on Blackwell sm_100+ (falls back to dense SDPA on older hardware). Ring and Attention2D not supported (no LSE output); Ulysses supported. ## Quick Start @@ -222,6 +226,44 @@ args = VisualGenArgs( **Wan 2.2 dual-transformer note:** Wan 2.2 uses two expert transformers (high-noise and low-noise stacks). All `CacheDiTConfig` parameters apply to both stacks, except `max_warmup_steps` and `max_cached_steps`: the low-noise stack always uses fixed internal caps (`max_warmup_steps=2`, `max_cached_steps=20`) regardless of user config. +### Video Sparse Attention (VSA) + +VSA reduces the compute cost of self-attention in video diffusion models by selectively attending to only the most relevant spatial-temporal blocks. It uses a two-branch design: a lightweight coarse mean-pool branch computes block-level attention scores to identify the top-K most relevant token blocks, then a fine branch runs a block-sparse CuTe kernel over only those blocks. The two outputs are blended with learned gates. + +**Requirements:** +- VSA-fine-tuned checkpoint: [`FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers`](https://huggingface.co/FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers). Standard Wan checkpoints do not have the learned VSA gates. +- Blackwell GPU (sm_100+) for the CuTe JIT kernel. Falls back to dense SDPA on older hardware with no accuracy loss. +- `CUTEDSL` attention backend. +- Not compatible with Ring attention or Attention2D (VSA does not produce per-split LSE). Ulysses is supported. + +**`vsa_sparsity`** controls the fraction of K/V blocks skipped in the fine branch (0.0 = dense, 0.9 = 90% blocks skipped). Higher sparsity gives more speedup at the cost of some quality. + +Python API: + +```python +from tensorrt_llm import VisualGenArgs +from tensorrt_llm.visual_gen.args import AttentionConfig, VideoSparseAttentionConfig + +args = VisualGenArgs( + model="FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers", + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=0.9), + ), +) +``` + +YAML (for use with `--visual_gen_args` or `trtllm-serve`): + +```yaml +attention_config: + backend: CUTEDSL + sparse_attention_config: + algorithm: vsa + vsa_sparsity: 0.90 +``` + + ### Multi-GPU Parallelism Configured under `VisualGenArgs.parallel_config`. Modes can be combined: diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py index 5c5f6f18a007..6b5d9b538d8b 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py @@ -20,7 +20,15 @@ simplified metadata that doesn't require KV caching. """ -from .cute_dsl import CuTeDSLAttention +from .cute_dsl import ( + VSA_TILE_SIZE, + CuTeDSLAttention, + VSAAttention, + VSAMetadata, + VSAMetadataBuilder, + get_vsa_forward_context, + set_vsa_forward_context, +) from .flash_attn4 import FlashAttn4Attention from .interface import AttentionBackend, AttentionTensorLayout from .parallel import Attention2DAttention, RingAttention, UlyssesAttention, wrap_parallel_attention @@ -35,6 +43,7 @@ "get_visual_gen_attention_backend", "create_attention", "CuTeDSLAttention", + "VSAAttention", "FlashAttn4Attention", "TrtllmAttention", "TrtllmAttentionMetadata", @@ -42,4 +51,9 @@ "VanillaAttention", "RingAttention", "wrap_parallel_attention", + "VSAMetadata", + "VSAMetadataBuilder", + "VSA_TILE_SIZE", + "get_vsa_forward_context", + "set_vsa_forward_context", ] diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py new file mode 100644 index 000000000000..292be3036d5b --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CuTe DSL attention backend family for visual generation models. + + fmha.py — CuTeDSLAttention (dense cubin path, head_dim=128) + vsa.py — VSAAttention (Video Sparse Attention, CuTe JIT + SDPA fallback) +""" + +from .fmha import CuTeDSLAttention, _cute_dsl_import_error +from .vsa import ( + VSA_KERNEL_MAX_CUBES, + VSA_TILE_SIZE, + VSAAttention, + VSAMetadata, + VSAMetadataBuilder, + VSAPreprocessor, + get_vsa_forward_context, + set_vsa_forward_context, +) + +__all__ = [ + "CuTeDSLAttention", + "VSAAttention", + "VSAMetadata", + "VSAMetadataBuilder", + "VSAPreprocessor", + "VSA_TILE_SIZE", + "VSA_KERNEL_MAX_CUBES", + "set_vsa_forward_context", + "get_vsa_forward_context", + "_cute_dsl_import_error", +] diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py similarity index 96% rename from tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py rename to tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py index 013065600207..c15f8b2ac47d 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -CuTe DSL (NVIDIA kernels) Backend for Visual Generation Models +CuTe DSL (NVIDIA kernels) Dense FMHA Backend for Visual Generation Models Uses pre-compiled cubins derived from CUTLASS CuTe DSL FMHA. Expects NHD layout ([B, S, H, D]) and supports float16/bfloat16. +For the VSA sparse path use VSAAttention in vsa.py. """ import math @@ -26,8 +27,8 @@ from tensorrt_llm.visual_gen.args import QuantAttentionConfig -from ...attention_backend.interface import PredefinedAttentionMask -from .interface import AttentionBackend, AttentionTensorLayout +from ....attention_backend.interface import PredefinedAttentionMask +from ..interface import AttentionBackend, AttentionTensorLayout _cute_dsl_import_error = None try: diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py new file mode 100644 index 000000000000..741f82097ca2 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py @@ -0,0 +1,412 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Video Sparse Attention (VSA) backend for visual generation models. + +VSAAttention implements hierarchical sparse attention: + - Coarse branch: mean-pooled cube attention (always dense) + - Fine branch: block-sparse top-K attention via CuTe JIT kernel (sm100+) + or dense SDPA fallback when CuTe is unavailable / head_dim != 128. +""" + +import contextvars +from contextlib import contextmanager +from dataclasses import dataclass +from math import ceil +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F + +from ..interface import AttentionBackend, AttentionTensorLayout + +_vsa_import_error = None +try: + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) +except (ImportError, OSError) as e: + block_sparse_attn_from_indices_cute = None + is_cute_supported = None + _vsa_import_error = e + + +# Must match the Blackwell kernel's block_size expectation. +VSA_TILE_SIZE: Tuple[int, int, int] = (4, 4, 4) + +# Kernel's SMEM buffer for variable_block_sizes is fixed-size and unchecked, +# so num_cubes must stay <= this. +VSA_KERNEL_MAX_CUBES: int = 4 * 1024 + + +def _get_tile_partition_indices( + dit_seq_shape: Tuple[int, int, int], + tile_size: Tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + T, H, W = dit_seq_shape + tT, tH, tW = tile_size + nT, nH, nW = ceil(T / tT), ceil(H / tH), ceil(W / tW) + + bt = torch.arange(nT, device=device).view(nT, 1, 1, 1, 1, 1) + bh = torch.arange(nH, device=device).view(1, nH, 1, 1, 1, 1) + bw = torch.arange(nW, device=device).view(1, 1, nW, 1, 1, 1) + lt = torch.arange(tT, device=device).view(1, 1, 1, tT, 1, 1) + lh = torch.arange(tH, device=device).view(1, 1, 1, 1, tH, 1) + lw = torch.arange(tW, device=device).view(1, 1, 1, 1, 1, tW) + + gt = bt * tT + lt + gh = bh * tH + lh + gw = bw * tW + lw + valid = (gt < T) & (gh < H) & (gw < W) + flat = gt * (H * W) + gh * W + gw + out = torch.where(valid, flat, torch.full_like(flat, -1)) + return out.reshape(-1).to(torch.long) + + +def _construct_variable_block_sizes( + dit_seq_shape: Tuple[int, int, int], + num_tiles: Tuple[int, int, int], + tile_size: Tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + T, H, W = dit_seq_shape + tT, tH, tW = tile_size + nT, nH, nW = num_tiles + + bt = torch.arange(nT, device=device) + bh = torch.arange(nH, device=device) + bw = torch.arange(nW, device=device) + valid_t = (T - bt * tT).clamp(max=tT) + valid_h = (H - bh * tH).clamp(max=tH) + valid_w = (W - bw * tW).clamp(max=tW) + sizes = valid_t.view(nT, 1, 1) * valid_h.view(1, nH, 1) * valid_w.view(1, 1, nW) + return sizes.reshape(-1).to(torch.long) + + +@dataclass +class VSAMetadata: + """Per-timestep metadata required by the VSA sparse path.""" + + current_timestep: int + dit_seq_shape: Tuple[int, int, int] + vsa_sparsity: float + num_tiles: Tuple[int, int, int] + total_seq_length: int + padded_seq_length: int + tile_partition_indices: torch.LongTensor + reverse_tile_partition_indices: torch.LongTensor + variable_block_sizes: torch.LongTensor + non_pad_index: torch.LongTensor + gather_idx: torch.LongTensor + + +class VSAMetadataBuilder: + """Builds VSAMetadata; caches per-shape index tensors so torch.compile + guards stay stable across denoising steps.""" + + def __init__(self) -> None: + self._cache: Dict[Tuple[Tuple[int, int, int], str], Dict[str, object]] = {} + + def _build_shape_payload( + self, + dit_seq_shape: Tuple[int, int, int], + device: torch.device, + ) -> Dict[str, object]: + T, H, W = dit_seq_shape + tT, tH, tW = VSA_TILE_SIZE + num_tiles = (ceil(T / tT), ceil(H / tH), ceil(W / tW)) + total_seq_length = T * H * W + padded_seq_length = num_tiles[0] * num_tiles[1] * num_tiles[2] * tT * tH * tW + + tile_partition_indices = _get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device) + non_pad_index = (tile_partition_indices >= 0).nonzero(as_tuple=True)[0] + gather_idx = tile_partition_indices[non_pad_index] + + reverse = torch.zeros(total_seq_length, dtype=torch.long, device=device) + reverse[gather_idx] = torch.arange(len(non_pad_index), dtype=torch.long, device=device) + + variable_block_sizes = _construct_variable_block_sizes( + dit_seq_shape, num_tiles, VSA_TILE_SIZE, device + ) + + return { + "dit_seq_shape": dit_seq_shape, + "num_tiles": num_tiles, + "total_seq_length": total_seq_length, + "padded_seq_length": padded_seq_length, + "tile_partition_indices": tile_partition_indices, + "reverse_tile_partition_indices": reverse, + "variable_block_sizes": variable_block_sizes, + "non_pad_index": non_pad_index, + "gather_idx": gather_idx, + } + + def build( + self, + current_timestep: int, + raw_latent_shape: Tuple[int, int, int], + patch_size: Tuple[int, int, int], + vsa_sparsity: float, + device: torch.device, + ) -> VSAMetadata: + dit_seq_shape = ( + raw_latent_shape[0] // patch_size[0], + raw_latent_shape[1] // patch_size[1], + raw_latent_shape[2] // patch_size[2], + ) + cache_key = (dit_seq_shape, str(device)) + payload = self._cache.get(cache_key) + if payload is None: + payload = self._build_shape_payload(dit_seq_shape, device) + self._cache[cache_key] = payload + + return VSAMetadata( + current_timestep=current_timestep, + vsa_sparsity=vsa_sparsity, + **payload, # type: ignore[arg-type] + ) + + +_vsa_forward_context_var: contextvars.ContextVar[Optional[VSAMetadata]] = contextvars.ContextVar( + "_vsa_forward_context", default=None +) + + +@contextmanager +def set_vsa_forward_context(metadata: VSAMetadata): + token = _vsa_forward_context_var.set(metadata) + try: + yield + finally: + _vsa_forward_context_var.reset(token) + + +def get_vsa_forward_context() -> Optional[VSAMetadata]: + return _vsa_forward_context_var.get(None) + + +def _mean_pool_cubes( + x_tiled: torch.Tensor, + variable_block_sizes: torch.LongTensor, + prod_tile: int, + num_cubes: int, +) -> torch.Tensor: + B, _padded, H, D = x_tiled.shape + x_cubes = x_tiled.view(B, num_cubes, prod_tile, H, D) + # fp32 accumulation: bf16 sum over 64 tokens perturbs the coarse softmax. + x_sum = x_cubes.float().sum(dim=2) + valid_counts = variable_block_sizes.float().clamp(min=1).view(1, num_cubes, 1, 1) + return (x_sum / valid_counts).to(x_tiled.dtype) + + +class VSAPreprocessor: + """Reorders NHD tokens into tile-major layout and zero-pads to tile boundaries.""" + + @staticmethod + def tile( + x: torch.Tensor, + non_pad_index: torch.LongTensor, + gather_idx: torch.LongTensor, + padded_seq_len: int, + ) -> torch.Tensor: + # index_select + index_copy_ instead of chained advanced indexing so + # torch.compile can trace this without a graph break. + B, _S, H, D = x.shape + x_valid = x.index_select(1, gather_idx) + x_padded = x.new_zeros(B, padded_seq_len, H, D) + x_padded.index_copy_(1, non_pad_index, x_valid) + return x_padded + + @staticmethod + def untile( + x: torch.Tensor, + reverse_tile_partition_indices: torch.LongTensor, + non_pad_index: torch.LongTensor, + ) -> torch.Tensor: + return x.index_select(1, non_pad_index).index_select(1, reverse_tile_partition_indices) + + +class VSAAttention(AttentionBackend): + """ + Video Sparse Attention (VSA) backend for diffusion models. + + Implements coarse mean-pool + fine block-sparse top-K attention. + The fine branch uses a JIT-compiled CuTe kernel on sm100+ for + head_dim=128 / fp16-bf16; otherwise falls back to dense SDPA. + + Requires an active VSA forward context (set_vsa_forward_context) during + each forward call. Does not support LSE output. + """ + + def __init__( + self, + layer_idx: int = 0, + num_heads: int = 8, + head_dim: int = 128, + num_kv_heads: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + sparse_attention_config=None, + **kwargs, + ): + self.layer_idx = layer_idx + self.num_heads = num_heads + self.head_dim = head_dim + self.num_kv_heads = num_kv_heads or num_heads + assert self.num_kv_heads == self.num_heads, ( + f"VSA coarse mean-pool assumes MHA (num_kv_heads == num_heads), " + f"got num_kv_heads={self.num_kv_heads}, num_heads={self.num_heads}. " + f"GQA/MQA is not supported." + ) + self.dtype = dtype + self.sparse_attention_config = sparse_attention_config + + # Dynamo can't guard on the module-level mutable global, so this read + # runs in eager. + @torch.compiler.disable + def _get_vsa_inputs(self): + ctx: Optional[VSAMetadata] = get_vsa_forward_context() + if ctx is None: + raise RuntimeError( + "VSAAttention.forward called without an active VSA forward context. " + "Wrap each transformer call with set_vsa_forward_context()." + ) + return ( + ctx.non_pad_index, + ctx.gather_idx, + ctx.reverse_tile_partition_indices, + ctx.variable_block_sizes, + ctx.padded_seq_length, + ctx.num_tiles[0] * ctx.num_tiles[1] * ctx.num_tiles[2], + ctx.vsa_sparsity, + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + gate_compress: Optional[torch.Tensor] = None, + gate_fine: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + VSA forward: coarse mean-pool + fine block-sparse top-K. + + Args: + q, k, v: [B, S, H, D] in original (un-tiled) token order. + gate_compress: [B, S, H, D] G_c gate weighting the coarse branch O_c. + gate_fine: Optional [B, S, H, D] G_f gate weighting the fine branch + O_f. None means constant 1 (dense behavior preserved). + + Returns: + [B, S, H, D] in the same original token order. + """ + if gate_compress is None: + raise ValueError( + "VSAAttention requires gate_compress. " + "Ensure to_gate_compress is wired in the transformer block." + ) + + ( + non_pad_index, + gather_idx, + reverse_tile_partition_indices, + variable_block_sizes, + padded_len, + num_cubes, + vsa_sparsity, + ) = self._get_vsa_inputs() + + B, S, H, D = q.shape + prod_tile = VSA_TILE_SIZE[0] * VSA_TILE_SIZE[1] * VSA_TILE_SIZE[2] + cur_topk = max(1, ceil((1.0 - vsa_sparsity) * num_cubes)) + + q_t = VSAPreprocessor.tile(q, non_pad_index, gather_idx, padded_len) + k_t = VSAPreprocessor.tile(k, non_pad_index, gather_idx, padded_len) + v_t = VSAPreprocessor.tile(v, non_pad_index, gather_idx, padded_len) + + q_c = _mean_pool_cubes(q_t, variable_block_sizes, prod_tile, num_cubes) + k_c = _mean_pool_cubes(k_t, variable_block_sizes, prod_tile, num_cubes) + v_c = _mean_pool_cubes(v_t, variable_block_sizes, prod_tile, num_cubes) + + scale = D**-0.5 + scores_c = torch.einsum("bnhd,bmhd->bhnm", q_c, k_c) * scale + attn_probs_c = scores_c.softmax(dim=-1) + o_c = torch.einsum("bhnm,bmhd->bnhd", attn_probs_c, v_c) + + use_cute = ( + _vsa_import_error is None + and is_cute_supported(q) + and (q.dtype == k.dtype == v.dtype) + and num_cubes <= VSA_KERNEL_MAX_CUBES + ) + topk_indices = attn_probs_c.topk(cur_topk, dim=-1).indices.to(torch.int32) + + o_c_tiled = ( + o_c.unsqueeze(2).expand(B, num_cubes, prod_tile, H, D).reshape(B, padded_len, H, D) + ) + + if use_cute: + q_hnd = q_t.transpose(1, 2).contiguous() + k_hnd = k_t.transpose(1, 2).contiguous() + v_hnd = v_t.transpose(1, 2).contiguous() + q2k_num = torch.full((B, H, num_cubes), cur_topk, dtype=torch.int32, device=q.device) + o_hnd, _lse = block_sparse_attn_from_indices_cute( + q_hnd, + k_hnd, + v_hnd, + q2k_idx=topk_indices.contiguous(), + q2k_num=q2k_num, + variable_block_sizes=variable_block_sizes.to(torch.int32), + ) + o_f_tiled = o_hnd.transpose(1, 2) + + # Padded rows hold kernel garbage; zero-padded gates mask the coarse + # term and untile discards padded positions from both branches. + gate_c_t = VSAPreprocessor.tile(gate_compress, non_pad_index, gather_idx, padded_len) + if gate_fine is not None: + gate_f_t = VSAPreprocessor.tile(gate_fine, non_pad_index, gather_idx, padded_len) + combined_tiled = gate_c_t * o_c_tiled + gate_f_t * o_f_tiled + else: + combined_tiled = gate_c_t * o_c_tiled + o_f_tiled + return VSAPreprocessor.untile( + combined_tiled, reverse_tile_partition_indices, non_pad_index + ) + + # SDPA must run on the un-tiled Q/K/V — padded zero K/V slots would + # otherwise absorb softmax mass and pollute the output. Untile o_c so + # both branches combine in original-flat order. + o_c_full = VSAPreprocessor.untile(o_c_tiled, reverse_tile_partition_indices, non_pad_index) + o_f = F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + ).transpose(1, 2) + if gate_fine is not None: + return gate_compress * o_c_full + gate_fine * o_f + return gate_compress * o_c_full + o_f + + @classmethod + def support_lse(cls) -> bool: + return False + + @property + def preferred_layout(self) -> AttentionTensorLayout: + return AttentionTensorLayout.NHD + + @classmethod + def support_fused_qkv(cls) -> bool: + return False diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py index 74fc7d968e81..c92c870534c6 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py @@ -163,17 +163,33 @@ def _forward_fused( v: torch.Tensor, **kwargs, ) -> torch.Tensor: + gate_compress = kwargs.pop("gate_compress", None) + gate_fine = kwargs.pop("gate_fine", None) + batch_size = q.shape[0] qkv = torch.stack([q, k, v], dim=2) qkv = all_to_all_5d(qkv, scatter_dim=3, gather_dim=1, process_group=self.process_group) B, seq_len, _, Hp, D = qkv.shape + if gate_compress is not None: + gate_compress = all_to_all_4d( + gate_compress, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) + if gate_fine is not None: + gate_fine = all_to_all_4d( + gate_fine, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) + # Caller passed pre-A2A (sharded) seq_len; the inner backend # reshapes by it, so hand it the post-A2A length instead. kwargs["batch_size"] = batch_size kwargs["seq_len"] = seq_len kwargs["seq_len_kv"] = seq_len + if gate_compress is not None: + kwargs["gate_compress"] = gate_compress + if gate_fine is not None: + kwargs["gate_fine"] = gate_fine output = self.inner_backend.forward(q=qkv, k=None, v=None, **kwargs) @@ -186,10 +202,23 @@ def _forward_unfused( v: torch.Tensor, **kwargs, ) -> torch.Tensor: + # gate_compress / gate_fine (VSA) must follow the same all-to-all as + # Q/K/V so they arrive at the inner backend in the same (full-S, sharded-H) layout. + gate_compress = kwargs.pop("gate_compress", None) + gate_fine = kwargs.pop("gate_fine", None) + batch_size = q.shape[0] q = all_to_all_4d(q, scatter_dim=2, gather_dim=1, process_group=self.process_group) k = all_to_all_4d(k, scatter_dim=2, gather_dim=1, process_group=self.process_group) v = all_to_all_4d(v, scatter_dim=2, gather_dim=1, process_group=self.process_group) + if gate_compress is not None: + gate_compress = all_to_all_4d( + gate_compress, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) + if gate_fine is not None: + gate_fine = all_to_all_4d( + gate_fine, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) seq_len_full = q.shape[1] kv_seq_len_full = k.shape[1] @@ -198,12 +227,20 @@ def _forward_unfused( q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) + if gate_compress is not None: + gate_compress = gate_compress.transpose(1, 2) + if gate_fine is not None: + gate_fine = gate_fine.transpose(1, 2) # Caller passed pre-A2A (sharded) seq_lens; hand the inner # backend the post-A2A lengths instead. kwargs["batch_size"] = batch_size kwargs["seq_len"] = seq_len_full kwargs["seq_len_kv"] = kv_seq_len_full + if gate_compress is not None: + kwargs["gate_compress"] = gate_compress + if gate_fine is not None: + kwargs["gate_fine"] = gate_fine output = self.inner_backend.forward(q=q, k=k, v=v, **kwargs) @@ -344,7 +381,6 @@ def _output_a2a( output = all_to_all_4d( output, scatter_dim=1, gather_dim=2, process_group=self.process_group ) - return output @property diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py index a3ce3df89cd9..e52f52c87247 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py @@ -128,6 +128,16 @@ def create_attention( "DiffusionModelConfig; creation path must not allocate metadata implicitly." ) kwargs["attention_metadata_state"] = attention_metadata_state + if backend.upper() == "CUTEDSL" and attention_config is not None: + if ( + attention_config.sparse_attention_config is not None + and getattr(attention_config.sparse_attention_config, "algorithm", None) == "vsa" + ): + # VSA sparse path: use VSAAttention + from .cute_dsl.vsa import VSAAttention + + attn_cls = VSAAttention + kwargs["sparse_attention_config"] = attention_config.sparse_attention_config return attn_cls( layer_idx=layer_idx, diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py new file mode 100644 index 000000000000..2ced031ae28c --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .block_sparse_attn_dsl_fwd import ( + VideoSparseAttentionForwardGroup2QInterleaveKV as VideoSparseAttentionForward, +) +from .interface import CUTE_AVAILABLE, block_sparse_attn_from_indices_cute, is_cute_supported + +__all__ = [ + "CUTE_AVAILABLE", + "VideoSparseAttentionForward", + "block_sparse_attn_from_indices_cute", + "is_cute_supported", +] diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py new file mode 100644 index 000000000000..a2d814b2dbc8 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py @@ -0,0 +1,2390 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Video Sparse Attention Forward (Blackwell) + +This file implements the forward pass of the video sparse attention. +It will produce the output and the log-sum-exp of the attention scores. + +This implementation requires zero-padding on the input tensor, to align with the block-size (64x64). +""" + +import math +from functools import partial +from typing import Callable, Optional, Tuple + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import cpasync, tcgen05 + +from . import ptx, scheduler + + +def make_thread_cooperative_group(size: int): + """ + Create a thread cooperative group. + """ + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) + + +def next_power_of_2(x: int) -> int: + if x <= 0: + return 1 + return 1 << (x - 1).bit_length() + + +SM100_TMEM_CAPACITY_COLUMNS: int = 512 + + +class VideoSparseAttentionForwardGroup2QInterleaveKV: + """ + This class implements the forward pass of the video sparse attention. + It doesn't require swap QK. + V0 + K0 K1 V1 + Q0 S0 0 O0 + Q1 0 S1 O1 + """ + + MAX_INDICES = 4 * 1024 + + def __init__( + self, + block_m: int, + block_n: int, + headdim: int, + ): + self.block_m = block_m + self.block_n = block_n + assert block_m == 64 and block_n == 64, "Block size must be 64x64" + self.headdim = headdim + assert self.headdim == 128, "Head dimension must be 128" + + self.acc_dtype: cutlass.Numeric = cutlass.Float32 + self.cta_group = tcgen05.CtaGroup.ONE + self.cluster_shape_mn = (1, 1) + self.mma_tiler_qk = (self.block_m * 2, self.block_n * 2, 1) + self.mma_tiler_pv = (self.block_m * 2, self.headdim, 1) + + self.occupancy = 1 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + self.threads_per_warp: int = 32 + self.threads_per_wg: int = 128 + + self.load_warp_id = 0 + self.mma_warp_id = 1 + self.epilogue_warp_id = 2 + self.empty_warp_ids = (3,) + self.softmax_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + + self.threads_per_cta: int = self.threads_per_warp * len( + ( + *self.softmax_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epilogue_warp_id, + *self.empty_warp_ids, + ) + ) + + self.buffer_align_bytes: int = 1024 + + self.scheduler_cls = scheduler.StaticPersistentScheduler + + self.rescale_threshold: float = 8.0 + + self.mask_O: bool = False + # either P is in SMEM or in TMEM + self.p_in_smem: bool = False + + self.num_regs_load: int = 104 + self.num_regs_mma: int = 104 + self.num_regs_epi: int = 104 + self.num_regs_softmax: int = 224 # bigger as much as possible + self.num_regs_correction: int = 176 + self.num_regs_empty: int = 24 + + def _compute_grid( + self, + num_q_blocks: int, + num_kv_blocks: int, + num_heads: int, + batchsize: int, + headdim: int, + headdim_v: int, + ) -> Tuple[Tuple[int, int, int], scheduler.ParamsBase]: + cluster_shape = (*self.cluster_shape_mn, 1) + + scheduler_params = scheduler.TileSchedulerParams( + num_block=cute.ceil_div(num_q_blocks, 2), + num_head=num_heads, + num_batch=batchsize, + headdim=headdim, + headdim_v=headdim_v, + ) + params = self.scheduler_cls.to_underlying_arguments(scheduler_params) + + grid = self.scheduler_cls.get_grid_shape(params) + grid = cute.round_up(grid, cluster_shape) + return grid, params + + def _compute_stages( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + input_dtype: cutlass.Numeric, + ): + self.q_stage = 1 + assert self.q_stage == 1, "Q will be preserving in SMEM" + self.kv_stage = 2 if self.p_in_smem else 3 + self.s_stage = 2 + assert self.s_stage == 2, "S stage must be 2 for Interleaving KV blocks" + self.o_stage = 1 + self.epi_stage = 1 if self.p_in_smem else 2 + + self.scale_buffers = 2 + + self.tmem_cols_S = self.mma_tiler_qk[1] * self.s_stage + self.tmem_cols_O = self.mma_tiler_pv[1] * self.o_stage + + # P reuses S TMEM, when there is no space for P + self.p_in_s: bool = (not self.p_in_smem) and (self.o_stage == 2) + + self.tmem_cols_P = (self.mma_tiler_pv[1] // 2) * self.s_stage * (not self.p_in_s) + + self.tmem_offset_S = 0 + self.tmem_offset_O = self.tmem_cols_S + self.tmem_offset_P = self.tmem_offset_O + self.tmem_cols_O + + self.tmem_alloc_cols = self.tmem_cols_S + self.tmem_cols_O + self.tmem_cols_P + self.tmem_alloc_cols = next_power_of_2(self.tmem_alloc_cols) + assert self.tmem_alloc_cols <= SM100_TMEM_CAPACITY_COLUMNS + self.do_tmem_alloc: bool = self.tmem_alloc_cols != SM100_TMEM_CAPACITY_COLUMNS + + def _setup_attributes( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + input_dtype: cutlass.Numeric, + ): + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,) + ) + + qk_mma_inst_shape_k = cute.size(tiled_mma_qk.shape_mnk, mode=[2]) + qk_mma_inst_tile_k: int = self.headdim // qk_mma_inst_shape_k + self.mma_tiler_qk = ( + self.mma_tiler_qk[0], + self.mma_tiler_qk[1], + qk_mma_inst_shape_k * qk_mma_inst_tile_k, + ) + + pv_mma_inst_shape_k = cute.size(tiled_mma_pv.shape_mnk, mode=[2]) + pv_mma_inst_tile_k: int = self.mma_tiler_qk[1] // pv_mma_inst_shape_k + self.mma_tiler_pv = ( + self.mma_tiler_pv[0], + self.mma_tiler_pv[1], + pv_mma_inst_shape_k * pv_mma_inst_tile_k, + ) + + self._compute_stages(tiled_mma_qk, tiled_mma_pv, input_dtype) + + @cute.jit + def __call__( + self, + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + sm_scale: cutlass.Float32, + O: cute.Tensor, # noqa: E741 + LSE: cute.Tensor, + q2k_block_sparse_index: cute.Tensor, + q2k_block_sparse_num: cute.Tensor, + variable_block_sizes: cute.Tensor, + stream: cuda.CUstream, + ) -> None: + input_dtype = Q.element_type + if cutlass.const_expr(input_dtype not in [cutlass.Float16, cutlass.BFloat16]): + raise RuntimeError("Input dtype must be Float16 or BFloat16") + if cutlass.const_expr( + not (input_dtype == K.element_type == V.element_type == O.element_type) + ): + raise RuntimeError("All input tensors must have the same data type") + + batch, heads, seqlen, dim = Q.layout.shape + if cutlass.const_expr(dim != self.headdim): + raise RuntimeError(f"Dimension mismatch: {dim} != {self.headdim}") + + _, _, _, dim_v = V.layout.shape + if cutlass.const_expr(dim_v != self.headdim): + raise RuntimeError(f"Dimension mismatch: {dim_v} != {self.headdim}") + + _, _, num_q_blocks, num_kv_blocks = q2k_block_sparse_index.layout.shape + + grid, params = self._compute_grid( + num_q_blocks=num_q_blocks, + num_kv_blocks=num_kv_blocks, + num_heads=heads, + batchsize=batch, + headdim=dim, + headdim_v=dim_v, + ) + + # [batch, heads, seqlen, dim] -> [seqlen, dim, heads, batch] + Q_layout_transpose = [2, 3, 1, 0] + KV_layout_transpose = [2, 3, 1, 0] + # [batch, heads, seqlen, dim] -> [seqlen, dim, heads, batch] + O_layout_transpose = [2, 3, 1, 0] + # [seqlen, dim, heads, batch] -> [dim, seqlen, heads, batch] + V_layout_transpose = [1, 0, 2, 3] + # [batch, heads, seqlen] -> [seqlen, heads, batch] + LSE_layout_transpose = [2, 1, 0] + Q = cute.make_tensor(Q.iterator, cute.select(Q.layout, mode=Q_layout_transpose)) + K = cute.make_tensor(K.iterator, cute.select(K.layout, mode=KV_layout_transpose)) + V = cute.make_tensor(V.iterator, cute.select(V.layout, mode=KV_layout_transpose)) + # NOTE: need transpose here, make V be N-major + V = cute.make_tensor(V.iterator, cute.select(V.layout, mode=V_layout_transpose)) + O = cute.make_tensor(O.iterator, cute.select(O.layout, mode=O_layout_transpose)) # noqa: E741 + LSE = cute.make_tensor(LSE.iterator, cute.select(LSE.layout, mode=LSE_layout_transpose)) + + q_major_mode = utils.LayoutEnum.from_tensor(Q).mma_major_mode() + k_major_mode = utils.LayoutEnum.from_tensor(K).mma_major_mode() + v_major_mode = tcgen05.OperandMajorMode.MN + + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + input_dtype, + q_major_mode, + k_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler_qk[:2], + ) + p_major_mode = tcgen05.OperandMajorMode.K + p_source = tcgen05.OperandSource.SMEM if self.p_in_smem else tcgen05.OperandSource.TMEM + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + input_dtype, + p_major_mode, + v_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler_pv[:2], + a_source=p_source, + ) + self._setup_attributes(tiled_mma_qk, tiled_mma_pv, input_dtype) + + o_layout = utils.LayoutEnum.from_tensor(O) + self.epi_tile = (self.mma_tiler_pv[0], self.mma_tiler_pv[1]) + + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, input_dtype, self.q_stage + ) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, input_dtype, self.kv_stage + ) + + sP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, input_dtype, self.s_stage + ) + + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, input_dtype, self.kv_stage + ) + + sO_layout = sm100_utils.make_smem_layout_epi( + input_dtype, o_layout, self.epi_tile, self.epi_stage + ) + fake_sO_layout = sm100_utils.make_smem_layout_epi( + input_dtype, o_layout, (self.block_m, self.headdim), self.epi_stage + ) + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", Q, sQ_layout), + ("K", K, sK_layout), + ("V", V, sV_layout), + ] + } + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + # use fake tiled mma to get TMA atoms + fake_tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + input_dtype, + q_major_mode, + k_major_mode, + self.acc_dtype, + self.cta_group, + (self.block_m, self.block_n), + ) + fake_sQ_layout = sm100_utils.make_smem_layout_a( + fake_tiled_mma_qk, + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + input_dtype, + self.q_stage, + ) + fake_sK_layout = sm100_utils.make_smem_layout_b( + fake_tiled_mma_qk, + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + input_dtype, + self.kv_stage, + ) + + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + Q, + cute.select(fake_sQ_layout, mode=[0, 1, 2]), + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + fake_tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + K, + cute.select(fake_sK_layout, mode=[0, 1, 2]), + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + fake_tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + + fake_tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + input_dtype, + p_major_mode, + v_major_mode, + self.acc_dtype, + self.cta_group, + (self.block_m, self.block_n), + ) + fake_sV_layout = sm100_utils.make_smem_layout_b( + fake_tiled_mma_pv, + (self.block_m, self.block_n, self.block_n), + input_dtype, + self.kv_stage, + ) + + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + V, + cute.select(fake_sV_layout, mode=[0, 1, 2]), + (self.block_m, self.block_n, self.block_n), + fake_tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + + o_cta_v_layout = cute.composition( + cute.make_identity_layout(O.shape), (self.block_m, self.headdim) + ) + tma_atom_O, mO = cpasync.make_tiled_tma_atom( + tma_store_op, O, cute.select(fake_sO_layout, mode=[0, 1]), o_cta_v_layout + ) + + sScale_layout = None + if cutlass.const_expr(self.o_stage == 1): + # [acc_scale] + sScale_layout = cute.make_ordered_layout((self.mma_tiler_qk[0], self.s_stage), (0, 1)) + else: + # [acc_scale, running_max] + sScale_layout = cute.make_ordered_layout( + (2, self.mma_tiler_qk[0], self.s_stage), (0, 1, 2) + ) + + # [running_sum, running_max] * scale_buffers + # TODO: rearrange this to make to LDS.64 instead of LDS.32 x 2 + sFinal_layout = cute.make_ordered_layout( + (2, self.mma_tiler_qk[0], self.scale_buffers), (0, 1, 2) + ) + + self.max_indices = self.MAX_INDICES + + @cute.struct + class SharedStorage: + load_KV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2] + qk_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.s_stage * 2] + p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.s_stage * 2] + o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.o_stage * 2] + pv_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.o_stage * 2] + store_O_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_stage * 2] + corr_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.s_stage * 2] + + tmem_holding_buf: cutlass.Int32 + + sScale: cute.struct.MemRange[cutlass.Float32, cute.cosize(sScale_layout)] + + sFinal: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(sFinal_layout)], + 8, # for LDS.64 it shall be 8 bytes aligned. + ] + sFinal_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.scale_buffers * 2] + + O_final_guard: cutlass.Int64 + + sQ: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + sP: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sP_layout) * (self.p_in_smem)], + self.buffer_align_bytes, + ] + sO: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sO_layout)], + self.buffer_align_bytes, + ] + sVariable_block_sizes: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, self.max_indices], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.LOG2_E: float = math.log2(math.e) + self.LN2: float = math.log(2.0) + sm_scale_log2 = sm_scale * self.LOG2_E + + self.kernel( + mQ, + mK, + mV, + sm_scale_log2, + mO, + LSE, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + tiled_mma_qk, + tiled_mma_pv, + fake_tiled_mma_qk, + fake_tiled_mma_pv, + sQ_layout, + sK_layout, + sV_layout, + sP_layout, + sScale_layout, + sFinal_layout, + sO_layout, + self.cluster_layout_vmnk, + params, + q2k_block_sparse_index, + q2k_block_sparse_num, + variable_block_sizes, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, # [seq, dim, head, batch] + mK: cute.Tensor, # [seq, dim, head, batch] + mV: cute.Tensor, # [dim, seq, head, batch] + sm_scale_log2: cutlass.Float32, # softmax scale in log2 + mO: cute.Tensor, # [dim, seq, head, batch] + mLSE: cute.Tensor, # [seq, head, batch] + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + fake_tiled_mma_qk: cute.TiledMma, + fake_tiled_mma_pv: cute.TiledMma, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout, + sScale_layout: cute.Layout, + sFinal_layout: cute.Layout, + sO_layout: cute.ComposedLayout, + cluster_layout_vmnk: cute.Layout, + scheduler_params: scheduler.ParamsBase, + q2k_block_sparse_index: cute.Tensor, + q2k_block_sparse_num: cute.Tensor, + variable_block_sizes: cute.Tensor, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_O) + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + KV_mbar_ptr = storage.load_KV_mbar_ptr.data_ptr() + O_final_guard = storage.O_final_guard + if warp_idx == self.load_warp_id: + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + O_final_guard, self.threads_per_warp * len(self.correction_warp_ids) + ) + for i in cutlass.range_constexpr(self.kv_stage, unroll_full=True): + # producer + cute.arch.mbarrier_init(KV_mbar_ptr + i * 2, len([self.load_warp_id])) + # consumer + cute.arch.mbarrier_init(KV_mbar_ptr + i * 2 + 1, len([self.mma_warp_id])) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage + ) + kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage + ) + + qk_mma_pipeline = pipeline.PipelineUmmaAsync.create( + num_stages=self.s_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + barrier_storage=storage.qk_mma_mbar_ptr.data_ptr(), + ) + qk_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.s_stage + ) + qk_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.s_stage + ) + + p_pipeline = pipeline.PipelineAsyncUmma.create( + num_stages=self.s_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.p_mbar_ptr.data_ptr(), + ) + p_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.s_stage + ) + p_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.s_stage + ) + o_pipeline = pipeline.PipelineAsyncUmma.create( + num_stages=self.o_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.o_mbar_ptr.data_ptr(), + ) + o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.o_stage + ) + o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.o_stage + ) + + pv_mma_pipeline = pipeline.PipelineUmmaAsync.create( + num_stages=self.o_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.pv_mma_mbar_ptr.data_ptr(), + ) + pv_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.o_stage + ) + pv_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.o_stage + ) + + st_O_pipeline = pipeline.PipelineAsync.create( + num_stages=self.epi_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len([self.epilogue_warp_id]) + ), + barrier_storage=storage.store_O_mbar_ptr.data_ptr(), + ) + st_O_producer_state = pipeline.PipelineState( + self.epi_stage, + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int32(0), + ) + st_O_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.epi_stage + ) + + corr_pipeline = pipeline.PipelineAsync.create( + num_stages=self.s_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.corr_mbar_ptr.data_ptr(), + ) + + correction_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.s_stage + ) + + sFinal_pipeline = pipeline.PipelineAsync.create( + num_stages=self.scale_buffers, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.sFinal_mbar_ptr.data_ptr(), + ) + sFinal_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.scale_buffers + ) + sFinal_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.scale_buffers + ) + + cute.arch.mbarrier_init_fence() + + tmem_holding_buf = storage.tmem_holding_buf + if cutlass.const_expr(self.do_tmem_alloc): + if warp_idx == 0: + cute.arch.alloc_tmem( + self.tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=False, + ) + cute.arch.sync_threads() + + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + # NOTE: V will reuse the same smem as K + # stripe swizzle info to reuse smem + sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) + + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + + sVariable_block_sizes_layout = cute.make_layout((self.max_indices,)) + sVariable_block_sizes = storage.sVariable_block_sizes.get_tensor( + sVariable_block_sizes_layout + ) + + for i in cutlass.range(tidx, variable_block_sizes.shape[0], cute.arch.block_dim()[0]): + sVariable_block_sizes[i] = variable_block_sizes[i] + # cta sync + cute.arch.sync_threads() + + sScale = storage.sScale.get_tensor(sScale_layout) + sFinal = storage.sFinal.get_tensor(sFinal_layout) + + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + + # ----- GMEM partition ----------- # + # [block_m, dimk, block_cnt, loop_k, head, batch] + gQ = cute.flat_divide(mQ, (self.block_m, self.headdim)) + gK = cute.flat_divide(mK, (self.block_n, self.headdim)) + gO = cute.flat_divide(mO, (self.block_m, self.headdim)) + # [dimK, block_n, loop_k, block_cnt, head, batch] + gV = cute.flat_divide(mV, (self.headdim, self.block_n)) + + # [block_m, block_cnt, head, batch] + gLSE = cute.flat_divide(mLSE, (self.block_m,)) + + # ----- TMEM partition ----------- # + # NOTE: we already know each SM's occupancy is 1, + # so tmem_ptr starting at zero + tmem_ptr = cute.make_ptr( + self.acc_dtype, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16 + ) + if cutlass.const_expr(self.do_tmem_alloc): + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + + tS_shape = (self.mma_tiler_qk[0], self.tmem_cols_S) + tCtS_shape = thr_mma_qk.partition_shape_C(tS_shape) + tCtS_fake = thr_mma_qk.make_fragment_C(tCtS_shape) + tCtS = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_offset_S, dtype=self.acc_dtype), tCtS_fake.layout + ) + sP = None + if cutlass.const_expr(self.p_in_smem): + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + else: + if cutlass.const_expr(self.p_in_s): + p_ptr = cute.recast_ptr(tmem_ptr + self.tmem_offset_S, dtype=sV.element_type) + tP = cute.make_tensor(p_ptr, sP_layout.outer) + tCtP = thr_mma_pv.make_fragment_A(tP) + stride = tCtP.stride + # since it will reuse the S TMEM, we need to make sure the stride is correct + stride = (*(stride[i] for i in range(len(stride) - 1)), stride[len(stride) - 1] * 2) + layout = cute.make_layout(tCtP.shape, stride=stride) + tCtP = cute.make_tensor(tCtP.iterator, layout) + sP = cute.make_tensor(p_ptr, tCtP.layout) + else: + # P has its own TMEM + p_ptr = cute.recast_ptr(tmem_ptr + self.tmem_offset_P, dtype=sV.element_type) + tP = cute.make_tensor(p_ptr, sP_layout.outer) + tCtP = thr_mma_pv.make_fragment_A(tP) + sP = cute.make_tensor(p_ptr, tCtP.layout) + + tO_shape = (self.mma_tiler_pv[0], self.tmem_cols_O) + tCtO_shape = thr_mma_pv.partition_shape_C(tO_shape) + tCtO_fake = thr_mma_pv.make_fragment_C(tCtO_shape) + tCtO = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_offset_O, dtype=self.acc_dtype), tCtO_fake.layout + ) + + TileSchedulerCls = partial(self.scheduler_cls.create, scheduler_params) + + if warp_idx in self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + + self.load( + TileSchedulerCls, + gQ, + gK, + gV, + KV_mbar_ptr, + kv_producer_state, + q2k_block_sparse_num, + q2k_block_sparse_index, + sQ, + sK, + sV, + fake_tiled_mma_qk, + fake_tiled_mma_pv, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + ) + + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_mma) + + self.mma( + TileSchedulerCls, + KV_mbar_ptr, + kv_consumer_state, + q2k_block_sparse_num, + qk_mma_pipeline, + qk_mma_producer_state, + tiled_mma_qk, + p_pipeline, + p_consumer_state, + o_pipeline, + o_consumer_state, + pv_mma_pipeline, + pv_mma_producer_state, + tiled_mma_pv, + sQ, + sK, + sV, + sP, + tCtS, + tCtO, + ) + + if warp_idx == self.epilogue_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_epi) + + self.epilogue( + TileSchedulerCls, + st_O_pipeline, + st_O_consumer_state, + gO, + sO, + tma_atom_O, + ) + + if warp_idx in self.softmax_warp_ids: + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + num_q_blocks = cute.size(gQ, mode=[2]) + self.softmax( + TileSchedulerCls, + q2k_block_sparse_num, + q2k_block_sparse_index, + variable_block_sizes, + sVariable_block_sizes, + qk_mma_pipeline, + qk_mma_consumer_state, + p_pipeline, + p_producer_state, + num_q_blocks, + sP, + tCtS, + thr_mma_qk, + sm_scale_log2, + sScale, + corr_pipeline, + sFinal, + sFinal_pipeline, + sFinal_producer_state, + O_final_guard, + ) + + if warp_idx in self.correction_warp_ids: + cute.arch.warpgroup_reg_alloc(self.num_regs_correction) + + num_q_blocks = cute.size(gQ, mode=[2]) + + self.correction( + TileSchedulerCls, + q2k_block_sparse_num, + p_producer_state, + o_pipeline, + o_producer_state, + pv_mma_pipeline, + pv_mma_consumer_state, + st_O_pipeline, + st_O_producer_state, + tCtO, + thr_mma_pv, + sO, + corr_pipeline, + correction_consumer_state, + sScale, + sm_scale_log2, + num_q_blocks, + variable_block_sizes, + sFinal, + sFinal_pipeline, + sFinal_consumer_state, + O_final_guard, + gLSE, + ) + + if cutlass.const_expr(self.do_tmem_alloc): + cute.arch.sync_threads() + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.dealloc_tmem( + tmem_ptr, + self.tmem_alloc_cols, + ) + + @cute.jit + def load( + self, + TileSchedulerCls: Callable, + gQ: cute.Tensor, + gK: cute.Tensor, + gV: cute.Tensor, + KV_mbar_ptr: cutlass.Pointer, + kv_producer_state: pipeline.PipelineState, + q2k_block_sparse_num: cute.Tensor, + q2k_block_sparse_index: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + fake_tiled_mma_qk: cute.TiledMma, + fake_tiled_mma_pv: cute.TiledMma, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + ): + fake_thr_mma_qk = fake_tiled_mma_qk.get_slice(0) + fake_thr_mma_pv = fake_tiled_mma_pv.get_slice(0) + + def _get_sQ_cpy( + sQ: cute.Tensor, + ) -> cute.Tensor: + _tmp = cute.flatten(sQ) + mode0 = cute.get(_tmp.layout, mode=[0]) + mode0_split = cute.flat_divide(mode0, (self.block_m,)) + shape = ( + mode0_split.shape, + cute.get(_tmp.layout, mode=[1]).shape, + cute.get(_tmp.layout, mode=[2]).shape, + cute.get(_tmp.layout, mode=[3]).shape, + cute.get(_tmp.layout, mode=[4]).shape, + cute.get(_tmp.layout, mode=[5]).shape, + ) + stride = ( + mode0_split.stride, + cute.get(_tmp.layout, mode=[1]).stride, + cute.get(_tmp.layout, mode=[2]).stride, + cute.get(_tmp.layout, mode=[3]).stride, + cute.get(_tmp.layout, mode=[4]).stride, + cute.get(_tmp.layout, mode=[5]).stride, + ) + layout = cute.make_layout(shape, stride=stride) + layout = cute.flatten(layout) + layout = cute.select(layout, mode=[0, 2, 3, 4, 5, 1, 6]) + layout = cute.group_modes(layout, 0, 2) + cpy_tensor = cute.make_tensor(sQ.iterator, layout) + return cpy_tensor + + sQ_cpy = _get_sQ_cpy(sQ) + + def _get_gQ_cpy( + tSgQ: cute.Tensor, + ): + layout = cute.select(tSgQ.layout, mode=[2, 0, 1, 3, 4, 5, 6]) + layout = cute.flat_divide( + layout, + (4,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 3, 0, 1, 4, 5, 6, 7]) + cpy_tensor = cute.make_tensor(tSgQ.iterator, layout) + return cpy_tensor + + tSgQ = _get_gQ_cpy(fake_thr_mma_qk.partition_A(gQ)) + tTMAsQ, tTMAgQ = cpasync.tma_partition( + tma_atom_Q, + 0, + cute.make_layout(1), + cute.group_modes(sQ_cpy, 0, 3), + cute.group_modes(tSgQ, 0, 3), + ) + + def _load_Q( + tTMAgQ: cute.Tensor, + tTMAsQ: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + producer_mbar: cutlass.Pointer, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + ): + tTMAgQ_1st = tTMAgQ[None, None, m_block_1st, None, head_idx, batch_idx] + tTMAgQ_2nd = tTMAgQ[None, None, m_block_2nd, None, head_idx, batch_idx] + + cute.copy( + tma_atom_Q, + tTMAgQ_1st[None, 0, 0], + tTMAsQ[None, 0, 0, 0], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_Q, + tTMAgQ_1st[None, 1, 0], + tTMAsQ[None, 1, 0, 0], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_Q, + tTMAgQ_2nd[None, 0, 0], + tTMAsQ[None, 0, 1, 0], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_Q, + tTMAgQ_2nd[None, 1, 0], + tTMAsQ[None, 1, 1, 0], + tma_bar_ptr=producer_mbar, + ) + + load_Q = partial( + _load_Q, + tTMAgQ=tTMAgQ, + tTMAsQ=tTMAsQ, + tma_atom_Q=tma_atom_Q, + ) + + def _get_sK_cpy( + sK: cute.Tensor, + ) -> cute.Tensor: + return _get_sQ_cpy(sK) + + sK_cpy = _get_sK_cpy(sK) + + def _get_gK_cpy( + tSgK: cute.Tensor, + ): + layout = cute.select(tSgK.layout, mode=[2, 0, 1, 3, 4, 5, 6]) + layout = cute.flat_divide( + layout, + (4,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 3, 0, 1, 4, 5, 6, 7]) + cpy_tensor = cute.make_tensor(tSgK.iterator, layout) + return cpy_tensor + + tSgK = _get_gK_cpy(fake_thr_mma_qk.partition_B(gK)) + tTMAsK, tTMAgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK_cpy, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + + def _load_K( + tTMAgK: cute.Tensor, + tTMAsK: cute.Tensor, + tma_atom_K: cute.CopyAtom, + producer_mbar: cutlass.Pointer, + buffer_idx: cutlass.Int32, + n: cutlass.Int32, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + valid_m_block_2nd: cutlass.Boolean, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + q2k_block_sparse_index: cute.Tensor, + num_k_blocks: cutlass.Int32, + ): + n_block_1st = q2k_block_sparse_index[batch_idx, head_idx, m_block_1st, n] + n_block_2nd = ( + q2k_block_sparse_index[batch_idx, head_idx, m_block_2nd, n] + if valid_m_block_2nd + else num_k_blocks + ) + + tTMAgK_1st = tTMAgK[None, None, n_block_1st, None, head_idx, batch_idx] + tTMAgK_2nd = tTMAgK[None, None, n_block_2nd, None, head_idx, batch_idx] + + cute.copy( + tma_atom_K, + tTMAgK_1st[None, 0, 0], + tTMAsK[None, 0, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_K, + tTMAgK_1st[None, 1, 0], + tTMAsK[None, 1, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_K, + tTMAgK_2nd[None, 0, 0], + tTMAsK[None, 0, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_K, + tTMAgK_2nd[None, 1, 0], + tTMAsK[None, 1, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + + load_K = partial( + _load_K, + tTMAgK=tTMAgK, + tTMAsK=tTMAsK, + tma_atom_K=tma_atom_K, + q2k_block_sparse_index=q2k_block_sparse_index, + ) + + def _get_sV_cpy( + sV: cute.Tensor, + ) -> cute.Tensor: + layout = cute.flatten(sV.layout) + layout = cute.select(layout, mode=[4, 0, 1, 2, 3, 5]) + layout = cute.flat_divide( + layout, + (4,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 3, 4, 5, 0, 1, 6]) + layout = cute.select(layout, mode=[0, 2, 3, 4, 5, 1, 6]) + layout = cute.group_modes(layout, 0, 2) + cpy_tensor = cute.make_tensor(sV.iterator, layout) + return cpy_tensor + + sV_cpy = _get_sV_cpy(sV) + + def _get_gV_cpy( + tOgV: cute.Tensor, + ): + layout = cute.append_ones(tOgV.layout) + layout = cute.select(layout, mode=[0, 7, 2, 1, 3, 4, 5, 6]) + cpy_tensor = cute.make_tensor(tOgV.iterator, layout) + return cpy_tensor + + tOgV = _get_gV_cpy(fake_thr_mma_pv.partition_B(gV)) + + tTMAsV, tTMAgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV_cpy, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + + def _load_V( + tTMAgV: cute.Tensor, + tTMAsV: cute.Tensor, + tma_atom_V: cute.CopyAtom, + producer_mbar: cutlass.Pointer, + buffer_idx: cutlass.Int32, + n: cutlass.Int32, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + valid_m_block_2nd: cutlass.Boolean, + num_v_blocks: cutlass.Int32, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + q2k_block_sparse_index: cute.Tensor, + ): + n_block_1st = q2k_block_sparse_index[batch_idx, head_idx, m_block_1st, n] + n_block_2nd = ( + q2k_block_sparse_index[batch_idx, head_idx, m_block_2nd, n] + if valid_m_block_2nd + else num_v_blocks + ) + + tTMAgV_1st = tTMAgV[None, None, None, n_block_1st, head_idx, batch_idx] + tTMAgV_2nd = tTMAgV[None, None, None, n_block_2nd, head_idx, batch_idx] + + cute.copy( + tma_atom_V, + tTMAgV_1st[None, 0, 0], + tTMAsV[None, 0, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_V, + tTMAgV_1st[None, 1, 0], + tTMAsV[None, 0, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_V, + tTMAgV_2nd[None, 0, 0], + tTMAsV[None, 1, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_V, + tTMAgV_2nd[None, 1, 0], + tTMAsV[None, 1, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + + load_V = partial( + _load_V, + tTMAgV=tTMAgV, + tTMAsV=tTMAsV, + tma_atom_V=tma_atom_V, + q2k_block_sparse_index=q2k_block_sparse_index, + ) + + num_q_blocks = cute.size(gQ, mode=[2]) + num_k_blocks = cute.size(gK, mode=[2]) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + m_block_2nd = m_block_1st + 1 + valid_m_block_2nd: cutlass.Boolean = m_block_2nd < num_q_blocks + m_block_2nd = m_block_2nd if valid_m_block_2nd else num_q_blocks + + # NOTE: Assumption: different Q-tile has the same number of KV-blocks + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m_block_1st] + + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + # load Q + load_Q( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + head_idx=head_idx, + batch_idx=batch_idx, + ) + load_K( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=0, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_k_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + # load K0 + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, + self.tma_copy_bytes["Q"] + self.tma_copy_bytes["K"], + ) + kv_producer_state.advance() + + for n in cutlass.range(_num_n_blocks - 1): + # load Kn + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + + load_K( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=n + 1, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_k_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, self.tma_copy_bytes["K"] + ) + kv_producer_state.advance() + # load Vn-1 + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + + load_V( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=n, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_v_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, self.tma_copy_bytes["V"] + ) + kv_producer_state.advance() + + # load Vn + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + + load_V( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=_num_n_blocks - 1, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_v_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, self.tma_copy_bytes["V"] + ) + kv_producer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma( + self, + TileSchedulerCls: Callable, + KV_mbar_ptr: cutlass.Pointer, + kv_consumer_state: pipeline.PipelineState, + q2k_block_sparse_num: cute.Tensor, + qk_mma_pipeline: pipeline.PipelineUmmaAsync, + qk_mma_producer_state: pipeline.PipelineState, + tiled_mma_qk: cute.TiledMma, + p_pipeline: pipeline.PipelineAsyncUmma, + p_consumer_state: pipeline.PipelineState, + o_pipeline: pipeline.PipelineAsyncUmma, + o_consumer_state: pipeline.PipelineState, + pv_mma_pipeline: pipeline.PipelineUmmaAsync, + pv_mma_producer_state: pipeline.PipelineState, + tiled_mma_pv: cute.TiledMma, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sP: cute.Tensor, + tCtS: cute.Tensor, + tCtO: cute.Tensor, + ): + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + + tCsQ = thr_mma_qk.make_fragment_A(sQ) + tCsK = thr_mma_qk.make_fragment_B(sK) + + tCsP = None + if cutlass.const_expr(self.p_in_smem): + tCsP = thr_mma_pv.make_fragment_A(sP) + else: + # it resides in TMEM + tCsP = sP + tCsV = thr_mma_pv.make_fragment_B(sV) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m_block_1st] + + # Q @ K0 + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + qk_mma_pipeline.producer_acquire(qk_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sQ, mode=[2])): + cute.gemm( + tiled_mma_qk, + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + tCsQ[None, None, kblock_idx, 0], + tCsK[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + qk_mma_pipeline.producer_commit(qk_mma_producer_state) + qk_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + + for n in cutlass.range(_num_n_blocks - 1): + # Q @ Kn + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + qk_mma_pipeline.producer_acquire(qk_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sQ, mode=[2])): + cute.gemm( + tiled_mma_qk, + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + tCsQ[None, None, kblock_idx, 0], + tCsK[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + + qk_mma_pipeline.producer_commit(qk_mma_producer_state) + qk_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + + # P @ Vn-1 + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, n >= self.o_stage) + p_pipeline.consumer_wait(p_consumer_state) + o_pipeline.consumer_wait(o_consumer_state) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + pv_mma_pipeline.producer_acquire(pv_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sV, mode=[2])): + cute.gemm( + tiled_mma_pv, + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + tCsP[None, None, kblock_idx, p_consumer_state.index], + tCsV[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + pv_mma_pipeline.producer_commit(pv_mma_producer_state) + pv_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + p_pipeline.consumer_release(p_consumer_state) + p_consumer_state.advance() + o_pipeline.consumer_release(o_consumer_state) + o_consumer_state.advance() + + # P @ Vn + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, (_num_n_blocks - 1) >= self.o_stage) + p_pipeline.consumer_wait(p_consumer_state) + o_pipeline.consumer_wait(o_consumer_state) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + pv_mma_pipeline.producer_acquire(pv_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sV, mode=[2])): + cute.gemm( + tiled_mma_pv, + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + tCsP[None, None, kblock_idx, p_consumer_state.index], + tCsV[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + pv_mma_pipeline.producer_commit(pv_mma_producer_state) + pv_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + p_pipeline.consumer_release(p_consumer_state) + p_consumer_state.advance() + o_pipeline.consumer_release(o_consumer_state) + o_consumer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def exp2f( + self, + value: cutlass.Float32, + ) -> cutlass.Float32: + return cute.arch.exp2(value) + + @cute.jit + def update_row_max( + self, + max_new: cutlass.Float32, + max_old: cutlass.Float32, + is_first: cutlass.Boolean, + sm_scale_log2: cutlass.Float32, + ): + _max_safe = max_new if max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + if cutlass.const_expr(not is_first): + acc_scale_ = (max_old - _max_safe) * sm_scale_log2 + acc_scale = self.exp2f(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + _max_safe = max_old + acc_scale = 1.0 + return _max_safe, acc_scale + + @cute.jit + def mask_then_cal_local_max( + self, + tensor: cute.Tensor, + right_bound: cutlass.Int32, + ) -> cutlass.Float32: + def clamp( + tensor: cute.Tensor, + index: cutlass.Int32, + right_bound: cutlass.Int32, + ) -> cutlass.Float32: + tensor[index] = tensor[index] if (index < right_bound) else -cutlass.Float32.inf + return tensor[index] + + _clamp = partial(clamp, tensor=tensor, right_bound=right_bound) + + if cutlass.const_expr(cute.size(tensor, mode=[0]) < 8): + _max = -cutlass.Float32.inf + for i in cutlass.range_constexpr(0, cute.size(tensor, mode=[0]), 2): + _max = ptx.max3f(_max, _clamp(index=i), _clamp(index=i + 1)) + return _max + else: + local_max = [ + ptx.max3f(_clamp(index=0), _clamp(index=1), -cutlass.Float32.inf), + ptx.max3f(_clamp(index=2), _clamp(index=3), -cutlass.Float32.inf), + ptx.max3f(_clamp(index=4), _clamp(index=5), -cutlass.Float32.inf), + ptx.max3f(_clamp(index=6), _clamp(index=7), -cutlass.Float32.inf), + ] + for i in cutlass.range_constexpr(8, cute.size(tensor, mode=[0]), 8): + local_max[0] = ptx.max3f(local_max[0], _clamp(index=i), _clamp(index=i + 1)) + local_max[1] = ptx.max3f(local_max[1], _clamp(index=i + 2), _clamp(index=i + 3)) + local_max[2] = ptx.max3f(local_max[2], _clamp(index=i + 4), _clamp(index=i + 5)) + local_max[3] = ptx.max3f(local_max[3], _clamp(index=i + 6), _clamp(index=i + 7)) + local_max[0] = cute.arch.fmax(local_max[0], local_max[1]) + return ptx.max3f(local_max[0], local_max[2], local_max[3]) + + @cute.jit + def cal_local_max( + self, + tensor: cute.Tensor, + ) -> cutlass.Float32: + # Full-block fast path: no column is masked, so read directly without + # the per-element bound test / -inf write that mask_then_cal_local_max + # performs. Mirrors the masked variant's reduction structure exactly. + if cutlass.const_expr(cute.size(tensor, mode=[0]) < 8): + _max = -cutlass.Float32.inf + for i in cutlass.range_constexpr(0, cute.size(tensor, mode=[0]), 2): + _max = ptx.max3f(_max, tensor[i], tensor[i + 1]) + return _max + else: + local_max = [ + ptx.max3f(tensor[0], tensor[1], -cutlass.Float32.inf), + ptx.max3f(tensor[2], tensor[3], -cutlass.Float32.inf), + ptx.max3f(tensor[4], tensor[5], -cutlass.Float32.inf), + ptx.max3f(tensor[6], tensor[7], -cutlass.Float32.inf), + ] + for i in cutlass.range_constexpr(8, cute.size(tensor, mode=[0]), 8): + local_max[0] = ptx.max3f(local_max[0], tensor[i], tensor[i + 1]) + local_max[1] = ptx.max3f(local_max[1], tensor[i + 2], tensor[i + 3]) + local_max[2] = ptx.max3f(local_max[2], tensor[i + 4], tensor[i + 5]) + local_max[3] = ptx.max3f(local_max[3], tensor[i + 6], tensor[i + 7]) + local_max[0] = cute.arch.fmax(local_max[0], local_max[1]) + return ptx.max3f(local_max[0], local_max[2], local_max[3]) + + @cute.jit + def cal_local_sum( + self, + tensor: cute.Tensor, + init: cutlass.Float32, + ) -> cutlass.Float32: + if cutlass.const_expr(cute.size(tensor, mode=[0]) < 8): + _sum = init + for i in cutlass.range_constexpr(cute.size(tensor, mode=[0])): + _sum += tensor[i] + return _sum + else: + local_sum = [ + cute.arch.add_packed_f32x2((init, 0.0), (tensor[0], tensor[1])), + (tensor[2], tensor[3]), + (tensor[4], tensor[5]), + (tensor[6], tensor[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(tensor, mode=[0]), 8): + local_sum[0] = cute.arch.add_packed_f32x2( + local_sum[0], (tensor[i + 0], tensor[i + 1]) + ) + local_sum[1] = cute.arch.add_packed_f32x2( + local_sum[1], (tensor[i + 2], tensor[i + 3]) + ) + local_sum[2] = cute.arch.add_packed_f32x2( + local_sum[2], (tensor[i + 4], tensor[i + 5]) + ) + local_sum[3] = cute.arch.add_packed_f32x2( + local_sum[3], (tensor[i + 6], tensor[i + 7]) + ) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + @cute.jit + def softmax_step( + self, + which: cutlass.Constexpr[cutlass.Int32], + stage: cutlass.Constexpr[cutlass.Int32], + n: cutlass.Int32, + q2k_block_sparse_index: cute.Tensor, + sVariable_block_sizes: cute.Tensor, + p_producer_state: pipeline.PipelineState, + sP: cute.Tensor, + tCcS_ld: cute.Tensor, + qk_mma_pipeline: pipeline.PipelineUmmaAsync, + p_pipeline: pipeline.PipelineAsyncUmma, + qk_mma_consumer_state: pipeline.PipelineState, + tiled_copy_t2r: cute.TiledCopy, + tiled_copy_r2t: Optional[cute.TiledCopy], + tCrS_ld_half_zeros: Optional[cute.Tensor], + tCtS_ld: cute.Tensor, + tCrS_ld: cute.Tensor, + tCrS_ld_half: cute.Tensor, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + tidx: cutlass.Int32, + running_max: cute.Tensor, + running_sum: cute.Tensor, + sm_scale_log2: cutlass.Float32, + sScale: cute.Tensor, + correction_pipeline: pipeline.PipelineAsync, + is_first: cutlass.Boolean, + ): + n_size = 0 + if cutlass.const_expr(which == 0): + n_block_1st = q2k_block_sparse_index[ + batch_idx, head_idx, m_block_1st, n + ] # TODO: move to smem + n_size = sVariable_block_sizes[n_block_1st] + elif cutlass.const_expr(which == 1): + n_block_2nd = q2k_block_sparse_index[batch_idx, head_idx, m_block_2nd, n] + n_size = sVariable_block_sizes[n_block_2nd] + + qk_mma_pipeline.consumer_wait(qk_mma_consumer_state) + correction_pipeline.producer_acquire(p_producer_state) + cute.copy( + tiled_copy_t2r, + tCtS_ld[None, None, None, 0, which, qk_mma_consumer_state.index], + tCrS_ld, + ) + # calculate P + # Full-block fast path: when the KV block is full (n_size == tile width, + # i.e. BLOCK==64 at the focus shapes) the variable-block mask masks + # nothing every iteration. Skip the per-element bound test / -inf write + # and take the unmasked local-max path. The masked path is preserved for + # the partial-block case (n_size < tile width) so correctness holds. + # NOTE: predeclare _max before the dynamic branch — CuTe DSL does not + # propagate variables first bound inside dynamic control flow. + _max = -cutlass.Float32.inf + if n_size >= cute.size(tCrS_ld, mode=[0]): + _max = self.cal_local_max(tCrS_ld) + else: + _max = self.mask_then_cal_local_max(tCrS_ld, n_size) + _max_safe, _acc_scale = self.update_row_max( + max_new=_max, + is_first=is_first, + max_old=running_max[0], + sm_scale_log2=sm_scale_log2, + ) + if cutlass.const_expr(self.o_stage == 1): + sScale[tidx, p_producer_state.index] = _acc_scale + else: + acc_scale = cute.arch.exp2( + (sScale[1, tidx, p_producer_state.index] - _max_safe) * sm_scale_log2 + ) + sScale[0, tidx, p_producer_state.index] = acc_scale + sScale[1, tidx, p_producer_state.index] = _max_safe + correction_pipeline.producer_commit(p_producer_state) + + running_max[0] = _max_safe + minus_coeff = -_max_safe * sm_scale_log2 + for i in cutlass.range(0, cute.size(tCrS_ld.shape), 2, unroll_full=True): + tCrS_ld[i], tCrS_ld[i + 1] = cute.arch.fma_packed_f32x2( + (tCrS_ld[i], tCrS_ld[i + 1]), + (sm_scale_log2, sm_scale_log2), + (minus_coeff, minus_coeff), + ) + tCrS_ld[i] = self.exp2f(tCrS_ld[i]) + tCrS_ld[i + 1] = self.exp2f(tCrS_ld[i + 1]) + + # type conversion + tCrS_ld_half.store(tCrS_ld.load().to(tCrS_ld_half.element_type)) + + # update running sum + running_sum[0] *= _acc_scale + running_sum[0] = self.cal_local_sum(tCrS_ld, running_sum[0]) + + p_pipeline.producer_acquire(p_producer_state) + if cutlass.const_expr(self.p_in_smem): + # copy P to SMEM + cute.autovec_copy(tCrS_ld_half, sP[None, which, p_producer_state.index]) + else: + # copy P to TMEM + cute.copy( + tiled_copy_r2t, tCrS_ld_half, sP[None, None, None, 0, which, p_producer_state.index] + ) + if cutlass.const_expr(self.p_in_s): + cute.copy( + tiled_copy_r2t, + tCrS_ld_half_zeros, + sP[None, None, None, 0, which ^ 1, p_producer_state.index], + ) + p_pipeline.producer_commit(p_producer_state) + + qk_mma_pipeline.consumer_release(qk_mma_consumer_state) + + @cute.jit + def softmax_loop( + self, + TileSchedulerCls: Callable, + q2k_block_sparse_num: cute.Tensor, + p_producer_state: pipeline.PipelineState, + softmax_step: Callable, + qk_mma_consumer_state: pipeline.PipelineState, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_producer_state: pipeline.PipelineState, + num_q_blocks: cutlass.Int32, + tidx: cutlass.Int32, + which: cutlass.Constexpr[cutlass.Int32], + corr_pipeline: pipeline.PipelineAsync, + O_final_guard: cutlass.Pointer, + sScale: cute.Tensor, + ): + running_states = cute.make_rmem_tensor(cute.make_layout((2, 1)), self.acc_dtype) + running_sum = running_states[0, None] + running_max = running_states[1, None] + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + O_final_phase = 1 + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m_block_1st] + m_block_2nd = m_block_1st + 1 if m_block_1st + 1 < num_q_blocks else m_block_1st + + cute.arch.mbarrier_wait(O_final_guard, O_final_phase) + O_final_phase ^= 1 + + running_max.fill(-cutlass.Float32.inf) + running_sum.fill(0.0) + if cutlass.const_expr(self.o_stage != 1): + sScale[1, tidx, None].fill(-cutlass.Float32.inf) + + softmax_step( + n=0, + stage=0, + batch_idx=batch_idx, + head_idx=head_idx, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + p_producer_state=p_producer_state, + qk_mma_consumer_state=qk_mma_consumer_state, + running_max=running_max, + running_sum=running_sum, + which=which, + correction_pipeline=corr_pipeline, + is_first=True, + ) + p_producer_state.advance() + qk_mma_consumer_state.advance() + + for n in cutlass.range(1, _num_n_blocks): + softmax_step( + n=n, + stage=0, + batch_idx=batch_idx, + head_idx=head_idx, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + p_producer_state=p_producer_state, + qk_mma_consumer_state=qk_mma_consumer_state, + running_max=running_max, + running_sum=running_sum, + which=which, + correction_pipeline=corr_pipeline, + is_first=False, + ) + p_producer_state.advance() + qk_mma_consumer_state.advance() + + # put running stages to SMEM + sFinal_pipeline.producer_acquire(sFinal_producer_state) + cute.autovec_copy( + running_states, cute.append_ones(sFinal[None, tidx, sFinal_producer_state.index]) + ) + sFinal_pipeline.producer_commit(sFinal_producer_state) + sFinal_producer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def softmax( + self, + TileSchedulerCls: Callable, + q2k_block_sparse_num: cute.Tensor, + q2k_block_sparse_index: cute.Tensor, + variable_block_sizes: cute.Tensor, + sVariable_block_sizes: cute.Tensor, + qk_mma_pipeline: pipeline.PipelineUmmaAsync, + qk_mma_consumer_state: pipeline.PipelineState, + p_pipeline: pipeline.PipelineAsyncUmma, + p_producer_state: pipeline.PipelineState, + num_q_blocks: cutlass.Int32, + sP: cute.Tensor, + tCtS: cute.Tensor, + thr_mma_qk: cute.core.ThrMma, + sm_scale_log2: cutlass.Float32, + sScale: cute.Tensor, + corr_pipeline: pipeline.PipelineAsync, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_producer_state: pipeline.PipelineState, + O_final_guard: cutlass.Pointer, + ): + tidx = cute.arch.thread_idx()[0] % (self.threads_per_warp * len(self.softmax_warp_ids)) + in_which = cute.arch.make_warp_uniform(tidx // self.block_m) + + copy_atom_t2r = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.block_n)), + self.acc_dtype, + ) + tS_load = cute.flat_divide( + tCtS[(None, None), 0, None], (self.mma_tiler_qk[0], self.block_n) + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tS_load[(None, None, 0, 0, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + + tCtS_ld = thr_copy_t2r.partition_S(tS_load) + + cS = cute.make_identity_tensor(self.mma_tiler_qk[:2]) + tCcS = thr_mma_qk.partition_C(cS) + cS_load = cute.flat_divide( + tCcS[(None, None), 0, None], (self.mma_tiler_qk[0], self.block_n) + ) + tCcS_ld = thr_copy_t2r.partition_D(cS_load) + tCrS_ld = cute.make_fragment(cute.select(tCcS_ld.shape, mode=[0, 1, 2]), self.acc_dtype) + tCrS_ld_half = cute.make_fragment(tCrS_ld.layout, sP.element_type) + + sP_cpy_slice = None + tiled_copy_r2t = None + tCrS_ld_half_zeros = None + if cutlass.const_expr(self.p_in_smem): + tv_layout = cute.make_ordered_layout( + (self.threads_per_wg, self.mma_tiler_qk[1], self.s_stage), (0, 1, 2) + ) + sP_cpy = cute.composition(sP, tv_layout) + sP_cpy_slice = cute.flatten(sP_cpy[tidx, None, None]) + else: + copy_atom_r2t = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(self.block_n // 2)), sP.element_type + ) + + def _get_tP_store(tP: cute.Tensor): + layout = cute.flatten(tP.layout) + mode1 = cute.make_layout( + cute.get(layout, mode=[1]).shape * cute.get(layout, mode=[3]).shape, + stride=cute.get(layout, mode=[1]).stride, + ) + shape = ( + cute.get(layout, mode=[0]).shape, + mode1.shape, + cute.get(layout, mode=[2]).shape, + cute.get(layout, mode=[4]).shape, + cute.get(layout, mode=[5]).shape, + ) + stride = ( + cute.get(layout, mode=[0]).stride, + mode1.stride, + cute.get(layout, mode=[2]).stride, + cute.get(layout, mode=[4]).stride, + cute.get(layout, mode=[5]).stride, + ) + layout = cute.make_layout(shape, stride=stride) + return cute.make_tensor(tP.iterator, layout) + + tP_store = _get_tP_store(sP) + + tiled_copy_r2t = tcgen05.make_tmem_copy(copy_atom_r2t, tP_store[(None, None, 0, 0, 0)]) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tCtP_st = thr_copy_r2t.partition_D(tP_store) + sP_cpy_slice = tCtP_st + if cutlass.const_expr(self.p_in_s): + tCrS_ld_half_zeros = cute.make_rmem_tensor( + tCrS_ld_half.layout, tCrS_ld_half.element_type + ) + tCrS_ld_half_zeros.fill(0.0) + else: + # P has its own TMEM + _zeros = cute.make_rmem_tensor(tCrS_ld_half.layout, tCrS_ld_half.element_type) + _zeros.fill(0.0) + for stage in cutlass.range_constexpr(self.s_stage): + cute.copy( + tiled_copy_r2t, _zeros, tCtP_st[None, None, None, 0, in_which ^ 1, stage] + ) + + _softmax_step = partial( + self.softmax_step, + q2k_block_sparse_index=q2k_block_sparse_index, + sVariable_block_sizes=sVariable_block_sizes, + sP=sP_cpy_slice, + tCcS_ld=tCcS_ld, + qk_mma_pipeline=qk_mma_pipeline, + p_pipeline=p_pipeline, + tiled_copy_t2r=tiled_copy_t2r, + tiled_copy_r2t=tiled_copy_r2t, + tCrS_ld_half_zeros=tCrS_ld_half_zeros, + tCtS_ld=tCtS_ld, + tCrS_ld=tCrS_ld, + tCrS_ld_half=tCrS_ld_half, + tidx=tidx, + sm_scale_log2=sm_scale_log2, + sScale=sScale, + ) + + _softmax_loop = partial( + self.softmax_loop, + TileSchedulerCls=TileSchedulerCls, + q2k_block_sparse_num=q2k_block_sparse_num, + p_producer_state=p_producer_state, + softmax_step=_softmax_step, + qk_mma_consumer_state=qk_mma_consumer_state, + sFinal=sFinal, + sFinal_pipeline=sFinal_pipeline, + sFinal_producer_state=sFinal_producer_state, + num_q_blocks=num_q_blocks, + tidx=tidx, + O_final_guard=O_final_guard, + sScale=sScale, + ) + + if cutlass.const_expr(self.p_in_smem): + sP_cpy = cute.composition( + sP, + cute.make_ordered_layout( + (self.threads_per_wg, self.mma_tiler_qk[1] * self.s_stage), (0, 1) + ), + ) + sP_cpy_slice = cute.flatten(sP_cpy[tidx, None]) + _zeros = cute.make_rmem_tensor((self.block_n, 1), sP_cpy_slice.element_type) + _zeros.fill(0.0) + + for stage in cutlass.range_constexpr(self.s_stage): + cute.autovec_copy(_zeros, sP_cpy_slice[None, stage * 2 + (in_which ^ 1)]) + if in_which == 0: + _softmax_loop(which=0, corr_pipeline=corr_pipeline) + else: + _softmax_loop(which=1, corr_pipeline=corr_pipeline) + + @cute.jit + def correction_loop( + self, + TileSchedulerCls: Callable, + variable_block_sizes: cute.Tensor, + q2k_block_sparse_num: cute.Tensor, + correction_pipeline: pipeline.PipelineAsync, + p_producer_state: pipeline.PipelineState, + o_pipeline: pipeline.PipelineAsyncUmma, + o_producer_state: pipeline.PipelineState, + correction_consumer_state: pipeline.PipelineState, + corr_ld_repeat: cutlass.Constexpr[cutlass.Int32], + corr_tiled_copy_t2r: cute.TiledCopy, + corr_tCtO_ld: cute.Tensor, + corr_tCrO_ld: cute.Tensor, + corr_tiled_copy_r2t: cute.TiledCopy, + pv_mma_pipeline: pipeline.PipelineUmmaAsync, + pv_mma_consumer_state: pipeline.PipelineState, + st_O_pipeline: pipeline.PipelineAsync, + st_O_producer_state: pipeline.PipelineState, + tv_layout: cute.Layout, + sO: cute.Tensor, + sScale: cute.Tensor, + wb_tCcO_ld: cute.Tensor, + wb_tCrO_ld: cute.Tensor, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_consumer_state: pipeline.PipelineState, + wb_ld_repeat: cutlass.Constexpr[cutlass.Int32], + wb_tCrO_reduction: Optional[cute.Tensor], + wb_tiled_copy_t2r: cute.TiledCopy, + wb_tCtO_ld: cute.Tensor, + wb_tCrO_ld_half: cute.Tensor, + tidx: cutlass.Int32, + sm_scale_log2: cutlass.Float32, + num_q_blocks: cutlass.Int32, + O_final_guard: cutlass.Pointer, + gLSE: cute.Tensor, + in_which: cutlass.Int32, + ): + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m2_idx * 2] + + # non-rescaling + for n in cutlass.range_constexpr(self.o_stage): + correction_pipeline.consumer_wait(correction_consumer_state) + correction_pipeline.consumer_release(correction_consumer_state) + p_producer_state.advance() + correction_consumer_state.advance() + + o_pipeline.producer_acquire(o_producer_state) + o_pipeline.producer_commit(o_producer_state) + o_producer_state.advance() + + pv_mma_pipeline.consumer_wait(pv_mma_consumer_state) + pv_mma_pipeline.consumer_release(pv_mma_consumer_state) + pv_mma_consumer_state.advance() + + for n in cutlass.range(self.o_stage, _num_n_blocks): + correction_pipeline.consumer_wait(correction_consumer_state) + o_pipeline.producer_acquire(o_producer_state) + + acc_scale = 0.0 + if cutlass.const_expr(self.o_stage == 1): + acc_scale = sScale[tidx, p_producer_state.index] + else: + acc_scale = sScale[0, tidx, p_producer_state.index] + should_rescale = cute.arch.vote_ballot_sync(acc_scale < 1.0) != 0 + if should_rescale: + for repeat in cutlass.range(corr_ld_repeat): + cute.copy( + corr_tiled_copy_t2r, + corr_tCtO_ld[None, None, None, 0, repeat, pv_mma_consumer_state.index], + corr_tCrO_ld, + ) + + # apply correction + for i in cutlass.range_constexpr(0, cute.size(corr_tCrO_ld, mode=[0]), 2): + corr_tCrO_ld[i], corr_tCrO_ld[i + 1] = cute.arch.mul_packed_f32x2( + (corr_tCrO_ld[i], corr_tCrO_ld[i + 1]), + (acc_scale, acc_scale), + ) + + cute.copy( + corr_tiled_copy_r2t, + corr_tCrO_ld, + corr_tCtO_ld[None, None, None, 0, repeat, pv_mma_consumer_state.index], + ) + o_pipeline.producer_commit(o_producer_state) + o_producer_state.advance() + + correction_pipeline.consumer_release(correction_consumer_state) + p_producer_state.advance() + correction_consumer_state.advance() + + # currently, we don't need pv_mma_result + pv_mma_pipeline.consumer_wait(pv_mma_consumer_state) + pv_mma_pipeline.consumer_release(pv_mma_consumer_state) + pv_mma_consumer_state.advance() + + # ----- Correction Epilogue: Store O to SMEM ----- # + st_O_pipeline.producer_acquire(st_O_producer_state) + running_states = cute.make_rmem_tensor(cute.make_layout((2, 1)), self.acc_dtype) + running_sum = running_states[0, None] + running_max = running_states[1, None] + + sFinal_pipeline.consumer_wait(sFinal_consumer_state) + cute.autovec_copy( + cute.append_ones(sFinal[None, tidx, sFinal_consumer_state.index]), running_states + ) + sFinal_pipeline.consumer_release(sFinal_consumer_state) + sFinal_consumer_state.advance() + + # calculate LSE + lse = ( + running_max[0] * sm_scale_log2 + cute.math.log2(running_sum[0], fastmath=True) + ) * self.LN2 + + gLSE_1st = gLSE[None, m2_idx * 2, head_idx, batch_idx] + gLSE_2nd = gLSE[None, m2_idx * 2 + 1, head_idx, batch_idx] + if in_which == 0: + gLSE_1st[tidx] = lse + elif m2_idx * 2 + 1 < num_q_blocks: + gLSE_2nd[tidx - self.block_m] = lse + + # calculate scale + scale = cute.arch.rcp_approx(running_sum[0]) + + if cutlass.const_expr(self.o_stage == 1): + for repeat in cutlass.range(wb_ld_repeat): + cute.copy( + wb_tiled_copy_t2r, + wb_tCtO_ld[None, None, None, 0, repeat, pv_mma_consumer_state.index], + wb_tCrO_ld, + ) + # scale + for i in cutlass.range_constexpr(0, cute.size(wb_tCrO_ld, mode=[0]), 2): + wb_tCrO_ld[i], wb_tCrO_ld[i + 1] = cute.arch.mul_packed_f32x2( + (wb_tCrO_ld[i], wb_tCrO_ld[i + 1]), + (scale, scale), + ) + # type conversion + wb_tCrO_ld_half.store(wb_tCrO_ld.load().to(wb_tCrO_ld_half.element_type)) + + # store to SMEM + cute.autovec_copy(wb_tCrO_ld_half, sO[None, repeat, st_O_producer_state.index]) + else: + stage_in_use = pv_mma_consumer_state.index - 1 + stage_in_use = stage_in_use if stage_in_use != -1 else self.o_stage - 1 + for repeat in cutlass.range(wb_ld_repeat): + wb_tCrO_reduction.fill(0.0) + for stage in cutlass.range(cutlass.min(self.o_stage, _num_n_blocks)): + _stage_id = (stage_in_use + stage) % self.o_stage + cute.copy( + wb_tiled_copy_t2r, + wb_tCtO_ld[None, None, None, 0, repeat, _stage_id], + wb_tCrO_ld, + ) + + _scale = scale + if cutlass.const_expr(self.o_stage != 1): + acc_scale = cute.arch.exp2( + (sScale[1, tidx, _stage_id] - running_max[0]) * sm_scale_log2 + ) + _scale *= acc_scale + + # scale and reduction + for i in cutlass.range_constexpr(0, cute.size(wb_tCrO_ld, mode=[0]), 2): + wb_tCrO_reduction[i], wb_tCrO_reduction[i + 1] = ( + cute.arch.fma_packed_f32x2( + (wb_tCrO_ld[i], wb_tCrO_ld[i + 1]), + (_scale, _scale), + (wb_tCrO_reduction[i], wb_tCrO_reduction[i + 1]), + ) + ) + # type conversion + wb_tCrO_ld_half.store(wb_tCrO_reduction.load().to(wb_tCrO_ld_half.element_type)) + + # store to SMEM + cute.autovec_copy(wb_tCrO_ld_half, sO[None, repeat, st_O_producer_state.index]) + + # NOTE: It is safe now to do next wave + cute.arch.mbarrier_arrive_and_expect_tx(O_final_guard, 0) + + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + st_O_pipeline.producer_commit(st_O_producer_state) + st_O_producer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def correction( + self, + TileSchedulerCls: Callable, + q2k_block_sparse_num: cute.Tensor, + p_producer_state: pipeline.PipelineState, + o_pipeline: pipeline.PipelineAsyncUmma, + o_producer_state: pipeline.PipelineState, + pv_mma_pipeline: pipeline.PipelineUmmaAsync, + pv_mma_consumer_state: pipeline.PipelineState, + st_O_pipeline: pipeline.PipelineAsync, + st_O_producer_state: pipeline.PipelineState, + tCtO: cute.Tensor, + thr_mma_pv: cute.core.ThrMma, + sO: cute.Tensor, + corr_pipeline: pipeline.PipelineAsync, + correction_consumer_state: pipeline.PipelineState, + sScale: cute.Tensor, + sm_scale_log2: cutlass.Float32, + num_q_blocks: cutlass.Int32, + variable_block_sizes: cute.Tensor, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_consumer_state: pipeline.PipelineState, + O_final_guard: cutlass.Pointer, + gLSE: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] % (self.threads_per_warp * len(self.correction_warp_ids)) + in_which = cute.arch.make_warp_uniform(tidx // self.block_m) + # A-tmem: widen correction TMEM<->reg copy 32->64 cols (halve LDTM/STTM count) + corr_ld_inst: int = 64 + corr_ld_repeat: int = cute.ceil_div(self.mma_tiler_pv[1], corr_ld_inst) + + corr_copy_atom_t2r = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(corr_ld_inst), + ), + self.acc_dtype, + ) + corr_tO_load = cute.flat_divide( + tCtO[(None, None), 0, None], (self.mma_tiler_pv[0], corr_ld_inst) + ) + corr_tiled_copy_t2r = tcgen05.make_tmem_copy( + corr_copy_atom_t2r, corr_tO_load[None, None, 0, 0, 0] + ) + corr_thr_copy_t2r = corr_tiled_copy_t2r.get_slice(tidx) + corr_tCtO_ld = corr_thr_copy_t2r.partition_S(corr_tO_load) + + corr_copy_atom_r2t = cute.make_copy_atom( + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(corr_ld_inst), + ), + self.acc_dtype, + ) + corr_tiled_copy_r2t = tcgen05.make_tmem_copy( + corr_copy_atom_r2t, corr_tO_load[None, None, 0, 0, 0] + ) + + cO = cute.make_identity_tensor(self.mma_tiler_pv[:2]) + tCcO = thr_mma_pv.partition_C(cO) + corr_cO_load = cute.flat_divide( + tCcO[(None, None), 0, None], (self.mma_tiler_pv[0], corr_ld_inst) + ) + corr_tCcO_ld = corr_thr_copy_t2r.partition_D(corr_cO_load) + corr_tCrO_ld = cute.make_fragment( + cute.select(corr_tCcO_ld.shape, mode=[0, 1, 2]), self.acc_dtype + ) + + # for writeback + wb_ld_inst: int = 64 + wb_ld_repeat: int = cute.ceil_div(self.mma_tiler_pv[1], wb_ld_inst) + + wb_copy_atom_t2r = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(wb_ld_inst), + ), + self.acc_dtype, + ) + wb_tO_load = cute.flat_divide( + tCtO[(None, None), 0, None], (self.mma_tiler_pv[0], wb_ld_inst) + ) + wb_tiled_copy_t2r = tcgen05.make_tmem_copy( + wb_copy_atom_t2r, wb_tO_load[(None, None, 0, 0, 0)] + ) + wb_thr_copy_t2r = wb_tiled_copy_t2r.get_slice(tidx) + wb_tCtO_ld = wb_thr_copy_t2r.partition_S(wb_tO_load) + + wb_cO_load = cute.flat_divide( + tCcO[(None, None), 0, None], (self.mma_tiler_pv[0], wb_ld_inst) + ) + wb_tCcO_ld = wb_thr_copy_t2r.partition_D(wb_cO_load) + wb_tCrO_ld = cute.make_fragment( + cute.select(wb_tCcO_ld.shape, mode=[0, 1, 2]), self.acc_dtype + ) + wb_tCrO_ld_half = cute.make_fragment(wb_tCrO_ld.layout, sO.element_type) + tv_layout = cute.make_ordered_layout( + (self.threads_per_wg, self.mma_tiler_pv[1], self.s_stage), (0, 1, 2) + ) + sO_cpy = cute.composition(sO, tv_layout) + sO_cpy_slice = cute.flatten(sO_cpy[tidx, None, None]) + + wb_tCrO_reduction = None + if cutlass.const_expr(self.o_stage != 1): + wb_tCrO_reduction = cute.make_fragment(wb_tCrO_ld.layout, self.acc_dtype) + + _correction_loop = partial( + self.correction_loop, + TileSchedulerCls=TileSchedulerCls, + variable_block_sizes=variable_block_sizes, + q2k_block_sparse_num=q2k_block_sparse_num, + p_producer_state=p_producer_state, + o_pipeline=o_pipeline, + o_producer_state=o_producer_state, + correction_consumer_state=correction_consumer_state, + corr_ld_repeat=corr_ld_repeat, + corr_tiled_copy_t2r=corr_tiled_copy_t2r, + corr_tCtO_ld=corr_tCtO_ld, + corr_tCrO_ld=corr_tCrO_ld, + corr_tiled_copy_r2t=corr_tiled_copy_r2t, + pv_mma_pipeline=pv_mma_pipeline, + pv_mma_consumer_state=pv_mma_consumer_state, + st_O_pipeline=st_O_pipeline, + st_O_producer_state=st_O_producer_state, + tv_layout=tv_layout, + sO=sO_cpy_slice, + sScale=sScale, + wb_tCcO_ld=wb_tCcO_ld, + wb_tCrO_ld=wb_tCrO_ld, + sFinal=sFinal, + sFinal_pipeline=sFinal_pipeline, + sFinal_consumer_state=sFinal_consumer_state, + wb_ld_repeat=wb_ld_repeat, + wb_tCrO_reduction=wb_tCrO_reduction, + wb_tiled_copy_t2r=wb_tiled_copy_t2r, + wb_tCtO_ld=wb_tCtO_ld, + wb_tCrO_ld_half=wb_tCrO_ld_half, + tidx=tidx, + sm_scale_log2=sm_scale_log2, + num_q_blocks=num_q_blocks, + O_final_guard=O_final_guard, + gLSE=gLSE, + in_which=in_which, + ) + + _correction_loop(correction_pipeline=corr_pipeline) + + @cute.jit + def epilogue( + self, + TileSchedulerCls: Callable, + st_O_pipeline: pipeline.PipelineAsync, + st_O_consumer_state: pipeline.PipelineState, + gO: cute.Tensor, + sO: cute.Tensor, + tma_atom_O: cute.CopyAtom, + ): + def _get_sO_cpy( + sO: cute.Tensor, + ) -> cute.Tensor: + layout = cute.flatten(sO.layout) + layout = cute.select(layout, mode=[1, 0, 2, 3, 4, 5]) + layout = cute.flat_divide( + layout, + (8,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 0, 3, 4, 1, 5, 6]) + layout = cute.group_modes(layout, 0, 2) + layout = cute.group_modes(layout, 4, cute.rank(layout)) + cpy_tensor = cute.make_tensor(sO.iterator, layout) + return cpy_tensor + + sO_cpy = _get_sO_cpy(sO) + + def _get_gO_cpy( + gO: cute.Tensor, + ) -> cute.Tensor: + layout = cute.flat_divide(gO.layout, (64, 64)) + layout = cute.group_modes(layout, 0, 2) + cpy_tensor = cute.make_tensor(gO.iterator, layout) + return cpy_tensor + + _gO = _get_gO_cpy(gO) + tTMAsO, tTMAgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_O, + 0, + cute.make_layout(1), + cute.group_modes(sO_cpy, 0, 2), + cute.group_modes(_gO, 0, 2), + ) + + num_o_blocks = cute.size(gO, mode=[2]) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + m_block_2nd = m_block_1st + 1 + m_block_2nd = m_block_2nd if m_block_2nd < num_o_blocks else num_o_blocks + + tTMAgO_1st = tTMAgO[None, None, m_block_1st, None, head_idx, batch_idx] + tTMAgO_2nd = tTMAgO[None, None, m_block_2nd, None, head_idx, batch_idx] + + cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True) + st_O_pipeline.consumer_release(st_O_consumer_state) + st_O_pipeline.consumer_wait(st_O_consumer_state) + + cute.copy( + tma_atom_O, tTMAsO[None, 0, 0, st_O_consumer_state.index], tTMAgO_1st[None, 0, 0] + ) + cute.copy( + tma_atom_O, tTMAsO[None, 1, 0, st_O_consumer_state.index], tTMAgO_1st[None, 1, 0] + ) + cute.copy( + tma_atom_O, tTMAsO[None, 0, 1, st_O_consumer_state.index], tTMAgO_2nd[None, 0, 0] + ) + cute.copy( + tma_atom_O, tTMAsO[None, 1, 1, st_O_consumer_state.index], tTMAgO_2nd[None, 1, 0] + ) + + cute.arch.cp_async_bulk_commit_group() + st_O_consumer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + cute.arch.cp_async_bulk_wait_group(0, read=True) diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py new file mode 100644 index 000000000000..13a9e08d1d5e --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Torch entry point for the CuTe DSL block-sparse attention forward. + +Blackwell (sm_100) fast path for VSA's fine stage. The kernel JIT-compiles +on first call and is cached per process; the caller +(CuTeDSLAttention._forward_vsa) falls back to dense SDPA when the +device/dtype/head_dim envelope is not met. +""" + +from __future__ import annotations + +import math +import os +from typing import Tuple + +import torch + +try: + import cuda.bindings.driver as _cuda + import cutlass.cute as cute + from cutlass.cute.runtime import from_dlpack + + from .block_sparse_attn_dsl_fwd import ( + VideoSparseAttentionForwardGroup2QInterleaveKV as VideoSparseAttentionForward, + ) + + CUTE_AVAILABLE = True +except ImportError: # cuda-bindings / cutlass-dsl not installed + _cuda = None + cute = None + from_dlpack = None + VideoSparseAttentionForward = None + CUTE_AVAILABLE = False + + +__all__ = [ + "CUTE_AVAILABLE", + "is_cute_supported", + "block_sparse_attn_from_indices_cute", +] + + +# JIT compile is multi-second; reuse aggressively. +_COMPILE_CACHE: dict = {} + + +def is_cute_supported(q: torch.Tensor) -> bool: + """Capability check for the CuTe path. Set TLLM_VSA_DISABLE_CUTE=1 to force off.""" + # Kernel asserts head_dim==128, block_m==block_n==64, fp16/bf16, sm_100+. + if os.environ.get("TLLM_VSA_DISABLE_CUTE", "").strip() in ("1", "true", "True"): + return False + if not CUTE_AVAILABLE: + return False + if not q.is_cuda: + return False + if q.dtype not in (torch.float16, torch.bfloat16): + return False + if q.shape[-1] != 128: + return False + cap = torch.cuda.get_device_capability(q.device) + return cap[0] >= 10 + + +def _to_cute_tensor(t: torch.Tensor): + """Convert a 4-D BHSD tensor into a CuTe tensor with dynamic B/H/S strides.""" + # Head-dim mode left static (=128) since the kernel specializes on it. + return ( + from_dlpack(t.detach(), assumed_align=128) + .mark_compact_shape_dynamic(mode=0, stride_order=t.dim_order()) + .mark_compact_shape_dynamic(mode=1, stride_order=t.dim_order()) + .mark_compact_shape_dynamic(mode=2, stride_order=t.dim_order()) + ) + + +@torch.compiler.disable +def block_sparse_attn_from_indices_cute( + q: torch.Tensor, # [B, H, S, D] fp16/bf16, sm_100+ + k: torch.Tensor, # [B, H, S, D] + v: torch.Tensor, # [B, H, S, D] + q2k_idx: torch.Tensor, # [B, H, num_q_blk, K] int32 + q2k_num: torch.Tensor, # [B, H, num_q_blk] int32 + variable_block_sizes: torch.Tensor, # [num_q_blk] int32 + sm_scale: float | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Run VSA fine-stage attention on Blackwell using the CuTe DSL kernel. + + Returns: + out: [B, H, S, D] same dtype as Q. + lse: [B, H, S] fp32. + """ + # Disabled for torch.compile: cuda.bindings.driver.CUstream + cute.compile + # are not Dynamo-traceable, and torch.cuda.current_stream() returns a + # proxy without .cuda_stream inside compiled regions. + if not CUTE_AVAILABLE: + raise RuntimeError( + "block_sparse_attn_from_indices_cute called but cuda.bindings or " + "cutlass-dsl is not importable." + ) + + num_q_blk = variable_block_sizes.shape[0] + if num_q_blk > VideoSparseAttentionForward.MAX_INDICES: + raise ValueError( + f"variable_block_sizes has {num_q_blk} entries but the CuTe kernel " + f"supports at most {VideoSparseAttentionForward.MAX_INDICES} " + "(SMEM-allocated sVariable_block_sizes buffer). Lower video " + "resolution/length or fall back to dense SDPA." + ) + + B, H, T, D = q.shape + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(D) + + out = torch.empty_like(q) + lse = torch.empty((B, H, T), device=q.device, dtype=torch.float32) + + cuda_stream = _cuda.CUstream(torch.cuda.current_stream(q.device).cuda_stream) + + q_packed = _to_cute_tensor(q) + k_packed = _to_cute_tensor(k) + v_packed = _to_cute_tensor(v) + o_packed = _to_cute_tensor(out) + + lse_packed = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + idx_packed = from_dlpack(q2k_idx.detach()).mark_layout_dynamic(leading_dim=3) + num_packed = from_dlpack(q2k_num.detach()).mark_layout_dynamic(leading_dim=2) + var_packed = from_dlpack(variable_block_sizes.detach()).mark_layout_dynamic(leading_dim=0) + + # Full q2k_idx shape is part of the key: mark_layout_dynamic(leading_dim=3) + # only makes the innermost stride dynamic; B-loop bound and inner strides + # are baked in at compile time, so each shape needs its own compiled kernel. + compile_key = (D, q.dtype, float(sm_scale)) + tuple(q2k_idx.shape) + compiled = _COMPILE_CACHE.get(compile_key) + if compiled is None: + fwd_kernel = VideoSparseAttentionForward(block_m=64, block_n=64, headdim=D) + compiled = cute.compile( + fwd_kernel, + q_packed, + k_packed, + v_packed, + sm_scale, + o_packed, + lse_packed, + idx_packed, + num_packed, + var_packed, + cuda_stream, + ) + _COMPILE_CACHE[compile_key] = compiled + + compiled( + q_packed, + k_packed, + v_packed, + sm_scale, + o_packed, + lse_packed, + idx_packed, + num_packed, + var_packed, + cuda_stream, + ) + return out, lse diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py new file mode 100644 index 000000000000..12fa15bc5930 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py @@ -0,0 +1,377 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import llvm, nvvm +from cutlass.cute import typing as cutlass_typing +from cutlass.cutlass_dsl import dsl_user_op + + +@dsl_user_op +def warp_reduction_fmax( + val: cutlass.Float32, + mask: cutlass.Int32 = 0xFFFFFFFF, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + cutlass_typing.Int32(mask).ir_value(loc=loc, ip=ip), + ], + """{\n\t + redux.sync.max.f32 $0, $1, $2;\n\t + \n\t}""", + "=f,f,r", + ) + ) + + +@dsl_user_op +def __cvta_generic_to_shared( + ptr: cutlass.Pointer, + *, + loc=None, + ip=None, +) -> cutlass.Uint32: + # NOTE: assume the SMEM pointer fits in a 32-bit register + return cutlass.Uint32( + llvm.inline_asm( + cutlass_typing.Uint32.mlir_type, + [ + cutlass_typing.Int32(ptr.toint()).ir_value(loc=loc, ip=ip), + ], + """{\n\t + mov.u32 $0, $1; + \n\t}""", + "=r, r", + ) + ) + + +@dsl_user_op +def atomicAdd_f32( + val: cutlass.Float32, + ptr: cutlass.Pointer, + *, + loc=None, + ip=None, +): + if cutlass.const_expr(ptr.memspace == cutlass_typing.AddressSpace.smem): + ptr = __cvta_generic_to_shared(ptr, loc=loc, ip=ip) + llvm.inline_asm( + None, + [ + cutlass_typing.Uint32(ptr).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.add.f32 _, [$0], $1;\n\t + \n\t}""", + "r, f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + else: + llvm.inline_asm( + None, + [ + cutlass_typing.Int64(ptr.toint()).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.add.f32 _, [$0], $1;\n\t + \n\t}""", + "l, f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def atomicMax_f32( + val: cutlass.Float32, + ptr: cutlass.Pointer, + *, + loc=None, + ip=None, +): + val_i32 = llvm.bitcast( + cutlass_typing.Int32.mlir_type, val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + if cutlass.const_expr(ptr.memspace == cutlass_typing.AddressSpace.smem): + ptr = __cvta_generic_to_shared(ptr, loc=loc, ip=ip) + llvm.inline_asm( + None, + [ + cutlass_typing.Uint32(ptr).ir_value(loc=loc, ip=ip), + cutlass_typing.Int32(val_i32).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.max.s32 _, [$0], $1;\n\t + \n\t}""", + "r, r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + else: + llvm.inline_asm( + None, + [ + cutlass_typing.Int64(ptr.toint()).ir_value(loc=loc, ip=ip), + cutlass_typing.Int32(val_i32).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.max.s32 _, [$0], $1;\n\t + \n\t}""", + "l, r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def exp2f( + val: cutlass.Float32, + *, + loc=None, + ip=None, +): + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + ], + """{\n\t + .reg .f32 f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11;\n\t + .reg .s32 r1, r2, r3;\n\t + max.ftz.f32 f1, $1, 0fC2FE0000;\n\t + mov.f32 f3, 0f4B400000;\n\t + add.rm.ftz.f32 f4, f1, f3;\n\t + sub.rn.ftz.f32 f5, f4, f3;\n\t + sub.rn.ftz.f32 f6, f1, f5;\n\t + mov.f32 f7, 0f3D9DF09D;\n\t + mov.f32 f8, 0f3E6906A4;\n\t + mov.f32 f9, 0f3F31F519;\n\t + mov.f32 f10, 0f3F800000;\n\t + fma.rn.ftz.f32 f11, f6, f7, f8;\n\t + fma.rn.ftz.f32 f11, f11, f6, f9;\n\t + fma.rn.ftz.f32 f11, f11, f6, f10;\n\t + mov.b32 r3, f11;\n\t + shl.b32 r1, f4, 23;\n\t + add.s32 r2, r1, r3;\n\t + mov.b32 $0, r2;\n\t + \n\t}""", + "=f, f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def fma( + a: cutlass.Float32, + b: cutlass.Float32, + c: cutlass.Float32, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(a).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(b).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(c).ir_value(loc=loc, ip=ip), + ], + """{\n\t + fma.rn.ftz.f32 $0, $1, $2, $3;\n\t + \n\t}""", + "=f, f, f, f", + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def max3f( + a: cutlass.Float32, + b: cutlass.Float32, + c: cutlass.Float32, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(a).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(b).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(c).ir_value(loc=loc, ip=ip), + ], + """{\n\t + max.f32 $0, $1, $2, $3;\n\t + \n\t}""", + "=f, f, f, f", + loc=loc, + ip=ip, + ) + ) + + +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, + rnd=nvvm.RoundingModeKind.RN, +) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: cutlass.Float32, + y: cutlass.Float32, + poly: Tuple[cutlass.Float32, ...], + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Float32, cutlass.Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def combine_int_frac_ex2( + x_rounded: cutlass.Float32, + frac_ex2: cutlass.Float32, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(x_rounded).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + """{\n\t + .reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i; \n\t + mov.b32 x_rounded_i, $1; \n\t + mov.b32 frac_ex_i, $2; \n\t + shl.b32 x_rounded_e, x_rounded_i, 23; \n\t + add.s32 out_i, x_rounded_e, frac_ex_i; \n\t + mov.b32 $0, out_i; \n\t + \n\t}""", + "=f, f, f", + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def exp2_emulation_2( + x: cutlass.Float32, + y: cutlass.Float32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Float32, cutlass.Float32]: + # assume x <= 127.0 and y <= 127.0 + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + xy_rounded = cute.arch.add_packed_f32x2( + xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM + ) + xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int)) + xy_frac = sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, poly_ex2_deg3, loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def exp2f_packed_f32x2( + x: cutlass.Float32, + y: cutlass.Float32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Float32, cutlass.Float32]: + result = llvm.inline_asm( + llvm.StructType.get_literal( + [ + cutlass_typing.Float32.mlir_type, + cutlass_typing.Float32.mlir_type, + ] + ), + [ + cutlass_typing.Float32(x).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(y).ir_value(loc=loc, ip=ip), + ], + """{\n\t + ex2.approx.f32 $0, $2;\n\t + ex2.approx.f32 $1, $3;\n\t + \n\t}""", + # Keep constraints compact (no spaces) and in the correct order + "=f,=f,f,f", + loc=loc, + ip=ip, + ) + + # Extract struct fields + out0_val = llvm.extractvalue(cutlass_typing.Float32.mlir_type, result, [0], loc=loc, ip=ip) + out1_val = llvm.extractvalue(cutlass_typing.Float32.mlir_type, result, [1], loc=loc, ip=ip) + + # Wrap back into cutlass.Float32 + return cutlass.Float32(out0_val), cutlass.Float32(out1_val) diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py new file mode 100644 index 000000000000..921251c613ad --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, fields +from functools import partial +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass._mlir import ir +from cutlass.cute import FastDivmodDivisor + +try: + from typing import override +except ImportError: + from typing_extensions import override + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """ + It includes block, head, batch, and is_valid_tile + """ + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 4 + new_tile_idx = cutlass.new_from_mlir_values(self.tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self.is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class TileSchedulerParams(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + headdim: Int32 + headdim_v: Int32 + + +class StaticPersistentScheduler: + @dataclass + class Params(ParamsBase): + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks: Int32 + + @staticmethod + def create( + args: TileSchedulerParams, + *, + loc=None, + ip=None, + ) -> "StaticPersistentScheduler.Params": + total_blocks = args.num_block * args.num_head * args.num_batch + return StaticPersistentScheduler.Params( + FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + *, + loc=None, + ip=None, + ): + self.params = params + self.tile_idx = tile_idx + self.loc = loc + self.ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerParams, *, loc=None, ip=None) -> Params: + return StaticPersistentScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentScheduler": + tile_idx = cute.arch.block_idx()[0] + return StaticPersistentScheduler(params, tile_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + sm_count: Optional[Int32] = None, + occupancy: Int32 = 1, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if sm_count is None: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + vacancies = sm_count * occupancy + return (cutlass.min(vacancies, params.total_blocks), Int32(1), Int32(1)) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self.tile_idx, self.params.num_block_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self.tile_idx < self.params.total_blocks + return WorkTileInfo((Int32(block_idx), Int32(head_idx), Int32(batch_idx)), is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self.tile_idx += cute.arch.grid_dim()[0] + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self.tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self.tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentScheduler(*(tuple(obj_list)), loc=self.loc) + + +if __name__ == "__main__": + + class TestScheduler: + def __init__(self, scheduler_cls): + self.scheduler_cls = scheduler_cls + + @cute.jit + def __call__(self): + scheduler_params = TileSchedulerParams( + num_block=Int32(5), + num_head=Int32(4), + num_batch=Int32(3), + headdim=Int32(128), + headdim_v=Int32(64), + ) + params = self.scheduler_cls.to_underlying_arguments(scheduler_params) + + grid_dim = self.scheduler_cls.get_grid_shape(params) + print(f"grid_dim: {grid_dim}") + + self.kernel(params).launch(grid=grid_dim, block=[32, 2, 1], min_blocks_per_mp=1) + + @cute.kernel + def kernel(self, params: ParamsBase): + TileSchedulerCls = partial(self.scheduler_cls.create, params) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + bidx, _, _ = cute.arch.block_idx() + + if warp_idx == 0: + lane_idx = cute.arch.lane_idx() + + scheduler = TileSchedulerCls() + + work_tile = scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + + if lane_idx == 0: + cute.printf( + "block_idx: {}, warp_idx: {}, m_block: {}, head_idx: {}, batch_idx: {}", + bidx, + warp_idx, + m_block, + head_idx, + batch_idx, + ) + + scheduler.prefetch_next_work() + scheduler.advance_to_next_work() + work_tile = scheduler.get_current_work() + + elif warp_idx == 1: + scheduler = TileSchedulerCls() + + test_scheduler = TestScheduler(StaticPersistentScheduler) + test_scheduler() diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py index 4528a657709f..a6cb6eeffaa8 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py @@ -59,6 +59,14 @@ def __init__(self, pipeline_config): "FluxPipeline does not support CFG parallelism. Please set cfg_size to 1." ) + _sa_cfg = pipeline_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " + "sparse_attention_config for FLUX." + ) + super().__init__(pipeline_config) @staticmethod diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index 7850d9413e0f..b30fc0a5ef79 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -128,6 +128,14 @@ def __init__(self, pipeline_config): "Flux2Pipeline does not support CFG parallelism. Please set cfg_size to 1." ) + _sa_cfg = pipeline_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " + "sparse_attention_config for FLUX.2." + ) + super().__init__(pipeline_config) @staticmethod diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index ff919c739f17..2069e3762c1e 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -619,6 +619,16 @@ class LTX2Pipeline(BasePipeline): ``transformers`` library. """ + def __init__(self, model_config): + _sa_cfg = model_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " + "sparse_attention_config for LTX-2." + ) + super().__init__(model_config) + @classmethod def resolve_variant(cls, config): if getattr(config, "cache_backend", None) == "cache_dit": diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index ea4e1b880608..98a48b9745c9 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -10,6 +10,10 @@ from diffusers.video_processor import VideoProcessor from transformers import AutoTokenizer, UMT5EncoderModel +from tensorrt_llm._torch.visual_gen.attention_backend import ( + VSAMetadataBuilder, + set_vsa_forward_context, +) from tensorrt_llm._torch.visual_gen.cache.teacache import ( ExtractorConfig, register_extractor_from_config, @@ -516,9 +520,22 @@ def forward( f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}" ) + # VSA: build metadata builder once per forward() call; reused across timesteps. + _attn_cfg = self.pipeline_config.primary_model_config.attention + _sparse_cfg = getattr(_attn_cfg, "sparse_attention_config", None) + _vsa_active = ( + getattr(_attn_cfg, "backend", "VANILLA") == "CUTEDSL" + and _sparse_cfg is not None + and getattr(_sparse_cfg, "algorithm", None) == "vsa" + ) + _vsa_builder = VSAMetadataBuilder() if _vsa_active else None + _vsa_patch_size = tuple(getattr(self.config, "patch_size", [1, 2, 2])) # (pT, pH, pW) + _vsa_sparsity = _sparse_cfg.vsa_sparsity if _vsa_active else 0.0 + # Denoising with two-stage support # Track which model was used in last step (for logging model transitions) last_model_used = [None] + _vsa_step_counter = [0] def forward_fn( latents, @@ -567,6 +584,24 @@ def forward_fn( # T2V: current_t for all frames timestep = current_t.reshape(1, 1).expand(latents.shape[0], nf * nh * nw) + if _vsa_active and _vsa_builder is not None: + # latents: [B, C, T_latent, H_latent, W_latent] + raw_latent_shape = (latents.shape[2], latents.shape[3], latents.shape[4]) + vsa_metadata = _vsa_builder.build( + current_timestep=_vsa_step_counter[0], + raw_latent_shape=raw_latent_shape, + patch_size=_vsa_patch_size, + vsa_sparsity=_vsa_sparsity, + device=latents.device, + ) + _vsa_step_counter[0] += 1 + with set_vsa_forward_context(vsa_metadata): + return current_model( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + ) + return current_model( hidden_states=latents, timestep=timestep / self.scheduler.config.num_train_timesteps, diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index 5d8ea711abea..0bd7adba1573 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -103,6 +103,14 @@ def __init__(self, pipeline_config): "Use cache_backend='none' or 'cache_dit' (not 'teacache')." ) + _sa_cfg = pipeline_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " + "sparse_attention_config for Wan I2V." + ) + super().__init__(pipeline_config) def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs): diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 7e86f23d7c70..17c0847cda22 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -246,6 +246,19 @@ def forward( return temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image +def _default_vsa_gate(linear: Linear, bias_value: float) -> None: + """Default a VSA gate Linear absent from the checkpoint: weight=0, bias=bias_value + (G_c=0 / G_f=1, preserving dense behavior at sparsity=0).""" + if not linear._weights_created: + linear.create_weights() + if linear.weight.is_meta: + return + with torch.no_grad(): + linear.weight.zero_() + if linear.bias is not None: + linear.bias.fill_(bias_value) + + class WanBlock(nn.Module): def __init__( self, @@ -356,6 +369,46 @@ def __init__( reduce_output=(tp_size != 1), ) + # VSA gates (CUTEDSL backend, sparse_attention_config.algorithm == "vsa"). + # G_c weights the coarse branch; G_f weights the fine branch. + self.to_gate_compress = None + self.to_gate_fine = None + _attn_cfg = getattr(model_config, "attention", None) + _sa_cfg = getattr(_attn_cfg, "sparse_attention_config", None) if _attn_cfg else None + _is_vsa = ( + _attn_cfg is not None + and getattr(_attn_cfg, "backend", "VANILLA") == "CUTEDSL" + and _sa_cfg is not None + and getattr(_sa_cfg, "algorithm", None) == "vsa" + ) + if _is_vsa: + q_dim = num_heads * head_dim + gate_tp_mode = TensorParallelMode.COLUMN if tp_size > 1 else None + self.to_gate_compress = Linear( + hidden_size, + q_dim, + bias=True, + dtype=dtype, + mapping=model_config.mapping, + quant_config=quant_config, + skip_create_weights_in_init=skip_create_weights, + force_dynamic_quantization=force_dynamic_quant, + tensor_parallel_mode=gate_tp_mode, + reduce_output=False, + ) + self.to_gate_fine = Linear( + hidden_size, + q_dim, + bias=True, + dtype=dtype, + mapping=model_config.mapping, + quant_config=quant_config, + skip_create_weights_in_init=skip_create_weights, + force_dynamic_quantization=force_dynamic_quant, + tensor_parallel_mode=gate_tp_mode, + reduce_output=False, + ) + # I2V: Additional K/V projections for image embeddings. self.add_k_proj = self.add_v_proj = None self.norm_added_k = None @@ -430,13 +483,19 @@ def forward( # Prepare frequencies for Attention freqs = (freqs_cos, freqs_sin) if freqs_cos is not None and freqs_sin is not None else None + attn1_kwargs = {} + if self.to_gate_compress is not None: + attn1_kwargs["gate_compress"] = self.to_gate_compress(normed) + if self.to_gate_fine is not None: + attn1_kwargs["gate_fine"] = self.to_gate_fine(normed) + # Self-attention with RoPE. Async-ulysses dispatches to forward_async # so each V/Q/K GEMM + norm + RoPE overlaps with the peer push on the # side stream; both paths return 3D [B, S, H*D]. if self._use_async_ulysses: attn1_out = self.attn1.forward_async(normed, freqs=freqs, timestep=timestep) else: - attn1_out = self.attn1(normed, freqs=freqs, timestep=timestep) + attn1_out = self.attn1(normed, freqs=freqs, timestep=timestep, **attn1_kwargs) x = (x.float() + attn1_out.float() * gate_msa).to(x.dtype) @@ -809,6 +868,12 @@ def load_weights(self, weights: dict) -> None: if weight_dicts: loader.load_linear_weights(module, name, weight_dicts) + # VSA gates absent from the checkpoint default to G_c=0 / G_f=1 + # (dense behavior at sparsity=0). + elif name.endswith(".to_gate_compress"): + _default_vsa_gate(module, 0.0) + elif name.endswith(".to_gate_fine"): + _default_vsa_gate(module, 1.0) elif "add_k_proj" in name or "add_v_proj" in name: logger.info(f"[Weight Loading] No weights found for I2V module: {name}") elif isinstance(module, RMSNormTPAware): diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index eafac1e328c4..c35a9dfb0834 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -95,13 +95,34 @@ def __init__( # Select compute backend (orthogonal to parallelism) vgm = config.visual_gen_mapping ulysses_size = vgm.ulysses_size if vgm else 1 + ring_size = vgm.ring_size if vgm else 1 + attn2d_size = (vgm.attn2d_row_size * vgm.attn2d_col_size) if vgm else 1 base_backend = config.attention.backend + _sa_cfg = config.attention.sparse_attention_config + _is_vsa = ( + base_backend == "CUTEDSL" + and _sa_cfg is not None + and getattr(_sa_cfg, "algorithm", None) == "vsa" + ) - # TRTLLM doesn't support cross-attention (different Q/KV seq lengths); fall back to VANILLA - if self.qkv_mode == QKVMode.SEPARATE_QKV and base_backend == "TRTLLM": + # Cross-attention fallback: TRTLLM and CUTEDSL VSA are self-attn only. + if self.qkv_mode == QKVMode.SEPARATE_QKV and (base_backend == "TRTLLM" or _is_vsa): backend_name = "VANILLA" else: backend_name = base_backend + + if _is_vsa and attn2d_size > 1: + raise ValueError( + f"VSA needs the full token sequence per rank, so it is incompatible " + f"with Attention2D (attn2d_size={attn2d_size}). Use ulysses or cfg " + f"parallelism instead." + ) + if _is_vsa and ring_size > 1: + raise ValueError( + f"VSA needs the full token sequence per rank, so it is incompatible " + f"with Ring attention (ring_size={ring_size}). Use ulysses or cfg " + f"parallelism instead." + ) self.attn_backend = backend_name self.qk_norm = qk_norm self.qk_norm_mode = qk_norm_mode @@ -458,8 +479,10 @@ def _attn_impl( """ Call attention backend with appropriate tensor layout. - Dimensions are derived from tensor shapes. Extra ``**kwargs`` - (e.g. ``attention_mask``) are forwarded to the backend. + Dimensions are derived from tensor shapes. Extra **kwargs are + forwarded to the backend. Backend-specific tensors that share + Q/K/V's [B, S, H*D] layout (e.g. VSA's gate_compress / + gate_fine) are reshaped here to the backend's 4-D layout. Two layout paths: 1. HND backends (VANILLA): [B, S, H*D] -> [B, H, S, D] @@ -471,6 +494,12 @@ def _attn_impl( seq_len = q.shape[1] seq_len_kv = k.shape[1] if k is not None else seq_len + def _reshape_gate(gate: torch.Tensor) -> torch.Tensor: + gate = gate.view(batch_size, -1, self.local_num_attention_heads, self.head_dim) + if backend_layout == AttentionTensorLayout.HND: + gate = gate.transpose(1, 2) + return gate + # Reshape inputs: [B, S, H*D] -> backend's preferred 4D layout if backend_layout == AttentionTensorLayout.HND: q = q.view(batch_size, -1, self.local_num_attention_heads, self.head_dim).transpose( @@ -494,13 +523,11 @@ def _attn_impl( "seq_len_kv": seq_len_kv, } ) + for gate_key in ("gate_compress", "gate_fine"): + if kwargs.get(gate_key) is not None: + kwargs[gate_key] = _reshape_gate(kwargs[gate_key]) - out = self.attn.forward( - q=q, - k=k, - v=v, - **kwargs, - ) + out = self.attn.forward(q=q, k=k, v=v, **kwargs) # Flatten back to [B, S, H*D] if backend_layout == AttentionTensorLayout.HND: @@ -514,6 +541,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, timestep: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: assert hidden_states.ndim == 3, "hidden_states must be a 3D tensor" batch_size, seq_len = hidden_states.shape[:2] @@ -532,7 +560,7 @@ def forward( freqs_cos, freqs_sin = freqs self.apply_packed_qk_norm_rope(qkv, freqs_cos, freqs_sin) q, k, v = qkv.split([self.local_q_dim, self.local_kv_dim, self.local_kv_dim], dim=-1) - out = self._attn_impl(q, k, v, timestep=timestep) + out = self._attn_impl(q, k, v, timestep=timestep, **kwargs) return self.to_out[0](out) # Unfused path: separate QK norm → separate RoPE → attention @@ -551,7 +579,7 @@ def forward( q = q.flatten(2) k = k.flatten(2) - out = self._attn_impl(q, k, v, timestep=timestep) + out = self._attn_impl(q, k, v, timestep=timestep, **kwargs) out = self.to_out[0](out) return out diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py index 2ef4eb84c885..3551371fa173 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py @@ -23,6 +23,9 @@ from tensorrt_llm._torch.autotuner import autotune from tensorrt_llm._torch.models.modeling_utils import MetaInitMode +from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, +) from tensorrt_llm.llmapi.utils import download_hf_model from tensorrt_llm.logger import logger from tensorrt_llm.visual_gen.args import VisualGenArgs @@ -221,6 +224,21 @@ def load( logger.info(f"Quantization: {config.quant_config.quant_algo.name}") logger.info(f"Dynamic weight quant: {config.dynamic_weight_quant}") + _attn_backend = config.attention.backend + _sa_cfg = config.attention.sparse_attention_config + if ( + _attn_backend == "CUTEDSL" + and _sa_cfg is not None + and getattr(_sa_cfg, "algorithm", None) == "vsa" + ): + kernel_path = "CuTe DSL block-sparse" if CUTE_AVAILABLE else "dense SDPA fallback" + logger.info( + f"Attention backend: CUTEDSL (algorithm=vsa, " + f"sparsity={_sa_cfg.vsa_sparsity}, fine-stage={kernel_path})" + ) + else: + logger.info(f"Attention backend: {_attn_backend}") + # ===================================================================== # STEP 1b: Build VisualGenMapping (must precede model creation) # ===================================================================== diff --git a/tensorrt_llm/visual_gen/__init__.py b/tensorrt_llm/visual_gen/__init__.py index 5b5ecb70048f..f15294afec4c 100644 --- a/tensorrt_llm/visual_gen/__init__.py +++ b/tensorrt_llm/visual_gen/__init__.py @@ -36,6 +36,7 @@ SparseAttentionConfig, TeaCacheConfig, TorchCompileConfig, + VideoSparseAttentionConfig, VisualGenArgs, ) from .output import VisualGenMetrics, VisualGenOutput @@ -60,6 +61,7 @@ "QuantAttentionConfig", "SparseAttentionConfig", "SkipSoftmaxAttentionConfig", + "VideoSparseAttentionConfig", "CacheConfig", "TeaCacheConfig", "CacheDiTConfig", diff --git a/tensorrt_llm/visual_gen/args.py b/tensorrt_llm/visual_gen/args.py index 570744f6c6b9..49033a3c4f3a 100644 --- a/tensorrt_llm/visual_gen/args.py +++ b/tensorrt_llm/visual_gen/args.py @@ -30,7 +30,7 @@ from tensorrt_llm.llmapi.utils import StrictBaseModel, set_api_status from tensorrt_llm.models.modeling_utils import QuantConfig -from .sparse_attention import SkipSoftmaxAttentionConfig +from .sparse_attention import SkipSoftmaxAttentionConfig, VideoSparseAttentionConfig # ============================================================================= # Type aliases @@ -86,7 +86,7 @@ class QuantAttentionConfig(StrictBaseModel): # Discriminated union of sparse attention configs. SparseAttentionConfig = Annotated[ - Union[SkipSoftmaxAttentionConfig], + Union[SkipSoftmaxAttentionConfig, VideoSparseAttentionConfig], Field(discriminator="algorithm"), ] @@ -111,7 +111,10 @@ class AttentionConfig(StrictBaseModel): sparse_attention_config: Optional[SparseAttentionConfig] = Field( None, status="prototype", - description="Sparse attention configuration. Currently supports: skip_softmax.", + description=( + "Sparse attention recipe. Discriminated by algorithm: " + "skip_softmax (TRTLLM backend) or VSA (CUTEDSL backend)." + ), ) @model_validator(mode="after") @@ -161,6 +164,41 @@ def _validate_quant_attention_config(self) -> "AttentionConfig": ) return self + @model_validator(mode="after") + def _validate_sparse_attention_config(self) -> "AttentionConfig": + if self.sparse_attention_config is None: + return self + + algo = self.sparse_attention_config.algorithm + required_backend = {"skip_softmax": "TRTLLM", "vsa": "CUTEDSL"}.get(algo) + if required_backend is None: + return self + + if self.backend != required_backend: + raise ValueError( + f"sparse_attention_config with algorithm='{algo}' requires " + f"backend='{required_backend}', got backend='{self.backend}'. " + f"Either set backend='{required_backend}' or remove " + f"sparse_attention_config." + ) + return self + + @model_validator(mode="after") + def _validate_cutedsl_quant_sparse_mutex(self) -> "AttentionConfig": + # quant_attention_config and sparse_attention_config are mutually exclusive. + if ( + self.backend == "CUTEDSL" + and self.quant_attention_config is not None + and self.sparse_attention_config is not None + ): + raise ValueError( + "CUTEDSL backend: quant_attention_config and " + "sparse_attention_config are mutually exclusive (the " + "CuTeDSLAttention dispatcher selects either the dense path " + "or the sparse VSA path, not both)." + ) + return self + class ParallelConfig(StrictBaseModel): """Configuration for distributed parallelism across DiT-shaped models. @@ -605,6 +643,7 @@ def from_yaml(cls, yaml_path: Union[str, Path], **overrides: Any) -> "VisualGenA "QuantAttentionConfig", "SparseAttentionConfig", "SkipSoftmaxAttentionConfig", + "VideoSparseAttentionConfig", "AttentionConfig", "ParallelConfig", "BaseCacheConfig", diff --git a/tensorrt_llm/visual_gen/sparse_attention.py b/tensorrt_llm/visual_gen/sparse_attention.py index 2316dffd41ca..09d7e7e10d21 100644 --- a/tensorrt_llm/visual_gen/sparse_attention.py +++ b/tensorrt_llm/visual_gen/sparse_attention.py @@ -197,3 +197,26 @@ def _ckpt_sparse_attention_config_from_kwargs( if isinstance(sparse_config, dict): return sparse_config return None + + +class VideoSparseAttentionConfig(StrictBaseModel): + """Video Sparse Attention (VSA) sparse-attention recipe (CUTEDSL backend only). + + Two-stage hybrid attention: a coarse mean-pooled stage over (4,4,4) cubes + and a block-sparse fine stage over the top-K cubes selected per head. + vsa_sparsity controls the fraction of cubes dropped on the fine stage. + """ + + algorithm: Literal["vsa"] = PydanticField( + "vsa", + description="Sparse attention algorithm discriminator.", + ) + vsa_sparsity: float = PydanticField( + 0.9, + ge=0.0, + le=1.0, + description=( + "Fraction of cubes dropped on the fine stage. 0.0 keeps all cubes " + "(dense fine stage); values closer to 1.0 keep fewer cubes." + ), + ) diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 219388348a69..114851b3821e 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -229,6 +229,7 @@ l0_b200: - unittest/_torch/visual_gen/test_cache_dit.py - unittest/_torch/visual_gen/test_quant_ops.py - unittest/_torch/visual_gen/test_attention_cute_dsl.py + - unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py - unittest/_torch/visual_gen/test_attention_trtllm_sage.py - unittest/_torch/visual_gen/test_attention_integration.py - unittest/_torch/visual_gen/test_attention_perf.py @@ -246,6 +247,7 @@ l0_b200: - unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py - unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py - unittest/_torch/visual_gen/test_wan22_ti2v_5b_pipeline.py + - unittest/_torch/visual_gen/test_wan_vsa_pipeline.py - unittest/_torch/visual_gen/test_wan21_i2v_teacache.py - unittest/_torch/visual_gen/test_wan21_t2v_teacache.py - unittest/_torch/visual_gen/test_wan_transformer.py diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py index bb88b0e474ac..a23924dcd504 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py @@ -76,20 +76,20 @@ def _cleanup_mpi_env(): # ============================================================================= -def _llm_models_root() -> str: +def _llm_models_root() -> str | None: root = Path("/home/scratch.trt_llm_data_ci/llm-models/") if "LLM_MODELS_ROOT" in os.environ: root = Path(os.environ["LLM_MODELS_ROOT"]) if not root.exists(): root = Path("/scratch.trt_llm_data/llm-models/") - assert root.exists(), ( - "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." - ) - return str(root) + return str(root) if root.exists() else None -def _checkpoint(env_var: str, default_name: str) -> str: - return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) +def _checkpoint(env_var: str, default_name: str) -> str | None: + if env_var in os.environ: + return os.environ[env_var] + models_root = _llm_models_root() + return os.path.join(models_root, default_name) if models_root is not None else None WAN21_1_3B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_1_3B", "Wan2.1-T2V-1.3B-Diffusers") @@ -416,7 +416,7 @@ def test_cfg2_ulysses2_pvae2(self): """world=4, cfg=2, ulysses=2, parallel_vae=2 vs HF reference.""" if not MODULES_AVAILABLE: pytest.skip("Required modules not available") - if not os.path.exists(WAN21_1_3B_PATH): + if not WAN21_1_3B_PATH or not os.path.exists(WAN21_1_3B_PATH): pytest.skip( f"Checkpoint not found: {WAN21_1_3B_PATH}. Set DIFFUSION_MODEL_PATH_WAN21_1_3B." ) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py new file mode 100644 index 000000000000..9094ba4abf7e --- /dev/null +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-GPU end-to-end test for the Wan T2V VSA pipeline. + +Run with: + pytest tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py -v -s + +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN21_VSA=/path/to/vsa \\ + pytest tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py -v -s +""" + +import gc +import os +from pathlib import Path +from typing import Callable + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.functional as F + +try: + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import _cute_dsl_import_error + from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader + from tensorrt_llm._utils import get_free_port + from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + ParallelConfig, + TorchCompileConfig, + VideoSparseAttentionConfig, + VisualGenArgs, + ) + + MODULES_AVAILABLE = True + _cute_dsl_available = _cute_dsl_import_error is None +except ImportError: + MODULES_AVAILABLE = False + _cute_dsl_available = False + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +# ============================================================================= +# Path helpers (mirrors test_wan_vsa_pipeline.py) +# ============================================================================= + + +def _llm_models_root() -> str | None: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + return str(root) if root.exists() else None + + +def _checkpoint(env_var: str, default_name: str) -> str | None: + if env_var in os.environ: + return os.environ[env_var] + models_root = _llm_models_root() + return os.path.join(models_root, default_name) if models_root is not None else None + + +WAN21_VSA_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_VSA", "Wan2.1-VSA-T2V-14B-720P-Diffusers") + + +# ============================================================================= +# Inference constants +# ============================================================================= + +PROMPT = "A cat sitting on a sunny windowsill watching birds outside." +NEGATIVE_PROMPT = "" + +HEIGHT = 720 +WIDTH = 1280 +NUM_FRAMES = 9 +NUM_STEPS = 4 +GUIDANCE_SCALE = 5.0 +SEED = 42 + +COS_SIM_THRESHOLD = 0.97 + + +# ============================================================================= +# Distributed harness (mirrors test_wan_pipeline_parallel.py) +# ============================================================================= + + +def init_distributed_worker(rank: int, world_size: int, backend: str = "nccl", port: int = 29500): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + torch.cuda.set_device(rank % torch.cuda.device_count()) + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + + +def cleanup_distributed(): + if dist.is_initialized(): + dist.destroy_process_group() + + +def _distributed_worker(rank, world_size, backend, test_fn, port, kwargs): + try: + init_distributed_worker(rank, world_size, backend, port) + test_fn(rank, world_size, **kwargs) + except Exception as e: + print(f"Rank {rank} failed with error: {e}") + raise + finally: + cleanup_distributed() + + +def run_test_in_distributed(world_size: int, test_fn: Callable, **kwargs): + if not MODULES_AVAILABLE: + pytest.skip("Required modules not available") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Test requires {world_size} GPUs, only {torch.cuda.device_count()} available") + port = get_free_port() + mp.spawn( + _distributed_worker, + args=(world_size, "nccl", test_fn, port, kwargs), + nprocs=world_size, + join=True, + ) + + +# ============================================================================= +# Inference helpers +# ============================================================================= + + +VSA_SPARSITY = 0.9 + + +def _build_vsa_parallel_args(checkpoint_path: str) -> "VisualGenArgs": + """cfg=2, ulysses=4, CUTEDSL backend with vsa_sparsity=0.9.""" + return VisualGenArgs( + model=checkpoint_path, + torch_compile_config=TorchCompileConfig(enable=False), + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=VSA_SPARSITY), + ), + parallel_config=ParallelConfig(cfg_size=2, ulysses_size=4), + ) + + +def _build_vsa_single_args(checkpoint_path: str) -> "VisualGenArgs": + """Single-GPU CUTEDSL reference at the same vsa_sparsity=0.9.""" + return VisualGenArgs( + model=checkpoint_path, + torch_compile_config=TorchCompileConfig(enable=False), + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=VSA_SPARSITY), + ), + ) + + +def _capture_trtllm_video(pipeline) -> "torch.Tensor | None": + """Run full TRTLLM pipeline; return (T, H, W, C) float in [0, 1] or None.""" + with torch.no_grad(): + result = pipeline.forward( + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + seed=SEED, + ) + if result is None or result.video is None: + return None + return result.video.float() / 255.0 + + +def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float: + a_flat = a.float().cpu().reshape(-1) + b_flat = b.float().cpu().reshape(-1) + return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item() + + +def _free(*objs) -> None: + for o in objs: + del o + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================================= +# Worker logic (module-level for mp.spawn pickling) +# ============================================================================= + + +def _logic_vsa_cfg2_ulysses4(rank: int, world_size: int, *, checkpoint_path: str) -> None: + """End-to-end pipeline: 8-GPU (cfg=2, ulysses=4) VSA vs 1-GPU VSA reference. + + Both runs use vsa_sparsity=0.9. All 8 ranks run the parallel denoising loop. + Rank 0 then frees the distributed model and loads a single-GPU VSA reference + at the same sparsity to compare. + """ + assert world_size == 8, f"This test is hardcoded to world_size=8, got {world_size}" + + vsa_pipe = PipelineLoader(_build_vsa_parallel_args(checkpoint_path)).load(skip_warmup=True) + vsa_video = _capture_trtllm_video(vsa_pipe) + + if rank == 0: + assert vsa_video is not None, "Rank 0 produced no video from the VSA pipeline." + + _free(vsa_pipe) + if rank != 0: + vsa_video = None + dist.barrier() + + if rank != 0: + return + + # Destroy the world_size=8 group before loading the single-GPU reference. + dist.destroy_process_group() + + ref_pipe = PipelineLoader(_build_vsa_single_args(checkpoint_path)).load(skip_warmup=True) + ref_video = _capture_trtllm_video(ref_pipe) + _free(ref_pipe) + + assert vsa_video.numel() == ref_video.numel(), ( + f"Element count mismatch — 8-GPU {tuple(vsa_video.shape)} " + f"({vsa_video.numel()}) vs 1-GPU {tuple(ref_video.shape)} ({ref_video.numel()})" + ) + + cos_sim = _cosine_similarity(vsa_video, ref_video) + print( + f"\n Wan2.1-VSA-T2V-14B (cfg=2, ulysses=4, sparsity={VSA_SPARSITY}) " + f"cosine similarity vs 1-GPU VSA: {cos_sim:.6f}" + ) + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"8-GPU VSA pipeline diverges from 1-GPU VSA reference: " + f"cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. " + f"Shapes — 8-GPU: {tuple(vsa_video.shape)}, 1-GPU: {tuple(ref_video.shape)}." + ) + + +# ============================================================================= +# Test class +# ============================================================================= + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanVsaUlysses: + """Multi-GPU correctness for the Wan T2V VSA pipeline (cfg=2, ulysses=4, 8 GPUs).""" + + def test_cfg2_ulysses4(self): + """world=8, cfg=2, ulysses=4, VSA sparsity=0.9 vs 1-GPU VSA reference.""" + if not MODULES_AVAILABLE: + pytest.skip("Required modules not available") + if not _cute_dsl_available: + pytest.skip(f"CUTEDSL not available (requires Blackwell GPU): {_cute_dsl_import_error}") + if WAN21_VSA_PATH is None or not os.path.exists(WAN21_VSA_PATH): + pytest.skip( + f"Checkpoint not found: {WAN21_VSA_PATH}. Set DIFFUSION_MODEL_PATH_WAN21_VSA." + ) + run_test_in_distributed( + world_size=8, + test_fn=_logic_vsa_cfg2_ulysses4, + checkpoint_path=WAN21_VSA_PATH, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py new file mode 100644 index 000000000000..63e7fd2498ff --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py @@ -0,0 +1,441 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VSA correctness tests: CuTe kernel, tile/untile roundtrip, top-k math, backend guards. + +Module-level dense-equivalence and finite-output checks live in +test_attention_integration.py. +""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn.functional as F + +from tensorrt_llm._torch.visual_gen.attention_backend import VSAMetadataBuilder +from tensorrt_llm._torch.visual_gen.config import ( + DiffusionModelConfig, + create_attention_metadata_state, +) +from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode +from tensorrt_llm.visual_gen.args import AttentionConfig, VideoSparseAttentionConfig + + +def _make_config( + hidden_size: int, + num_heads: int, + head_dim: int, + backend: str, + vsa_sparsity: "float | None" = None, +) -> DiffusionModelConfig: + """Minimal DiffusionModelConfig for one Attention module.""" + pretrained_config = SimpleNamespace( + hidden_size=hidden_size, + num_attention_heads=num_heads, + attention_head_dim=head_dim, + eps=1e-6, + ) + sparse_attention_config = ( + VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity) if vsa_sparsity is not None else None + ) + config = DiffusionModelConfig( + pretrained_config=pretrained_config, + attention=AttentionConfig(backend=backend, sparse_attention_config=sparse_attention_config), + skip_create_weights_in_init=False, + ) + config.attention_metadata_state = ( + create_attention_metadata_state() if backend == "TRTLLM" else None + ) + return config + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +def test_vsa_falls_back_to_vanilla_for_cross_attention(): + """Cross-attention (SEPARATE_QKV) falls back to VANILLA — it has no cube structure.""" + device = torch.device("cuda") + dtype = torch.bfloat16 + cfg = _make_config( + hidden_size=64, num_heads=4, head_dim=16, backend="CUTEDSL", vsa_sparsity=0.5 + ) + cross_attn = ( + Attention(64, 4, qkv_mode=QKVMode.SEPARATE_QKV, config=cfg) + .to(device=device, dtype=dtype) + .eval() + ) + assert cross_attn.attn_backend == "VANILLA", ( + f"VSA on cross-attention should fall back to VANILLA, got {cross_attn.attn_backend!r}" + ) + + +def test_vsa_with_attn2d_raises(): + """VSA + Attention2D must error at construction (VSA needs the full sequence per rank).""" + pretrained_config = SimpleNamespace( + hidden_size=64, + num_attention_heads=4, + attention_head_dim=16, + eps=1e-6, + ) + cfg = DiffusionModelConfig( + pretrained_config=pretrained_config, + attention=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=0.0), + ), + skip_create_weights_in_init=False, + ) + cfg.visual_gen_mapping = SimpleNamespace( + ring_size=1, + ring_group=None, + ulysses_size=1, + ulysses_group=None, + attn2d_row_size=2, + attn2d_col_size=2, + attn2d_row_group=None, + attn2d_col_group=None, + ) + with pytest.raises(ValueError, match="incompatible with Attention2D"): + Attention(64, 4, qkv_mode=QKVMode.FUSE_QKV, config=cfg) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +def test_vsa_topk_collapses_to_dense_at_sparsity_zero(): + """At sparsity=0, top_k equals num_cubes (dense connectivity).""" + from math import ceil + + device = torch.device("cuda") + builder = VSAMetadataBuilder() + metadata = builder.build( + current_timestep=0, + raw_latent_shape=(8, 8, 8), + patch_size=(1, 1, 1), + vsa_sparsity=0.0, + device=device, + ) + num_cubes = metadata.num_tiles[0] * metadata.num_tiles[1] * metadata.num_tiles[2] + cur_topk = max(1, ceil((1.0 - metadata.vsa_sparsity) * num_cubes)) + assert cur_topk == num_cubes, ( + f"sparsity=0 should select all {num_cubes} cubes, got top_k={cur_topk}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +@pytest.mark.parametrize( + "latent_shape", + [ + (8, 8, 8), + (9, 9, 9), + (21, 45, 80), + ], + ids=["clean_8x8x8", "ragged_9x9x9", "wan720p_21x45x80"], +) +def test_vsa_tile_untile_roundtrip(latent_shape): + """VSAPreprocessor.tile then .untile must losslessly reproduce the input.""" + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl.vsa import VSAPreprocessor + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, D = 2, 4, 32 + seq_len = latent_shape[0] * latent_shape[1] * latent_shape[2] + + builder = VSAMetadataBuilder() + meta = builder.build( + current_timestep=0, + raw_latent_shape=latent_shape, + patch_size=(1, 1, 1), + vsa_sparsity=0.0, + device=device, + ) + + x = torch.randn(B, seq_len, H, D, device=device, dtype=dtype) + + x_tiled = VSAPreprocessor.tile( + x, + meta.non_pad_index, + meta.gather_idx, + meta.padded_seq_length, + ) + + pad_mask = torch.ones(meta.padded_seq_length, dtype=torch.bool, device=device) + pad_mask[meta.non_pad_index] = False + if pad_mask.any(): + assert x_tiled[:, pad_mask, :, :].abs().max().item() == 0.0, ( + "tile() must zero-fill padded positions" + ) + + x_roundtrip = VSAPreprocessor.untile( + x_tiled, + meta.reverse_tile_partition_indices, + meta.non_pad_index, + ) + + assert x_roundtrip.shape == x.shape, ( + f"shape mismatch after tile/untile: {x_roundtrip.shape} vs {x.shape}" + ) + assert torch.equal(x_roundtrip, x), ( + f"tile/untile round-trip is not lossless for latent_shape={latent_shape}: " + f"max_diff={(x_roundtrip - x).abs().max().item():.3e}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +def test_cute_kernel_matches_dense_at_full_topk(): + """CuTe block-sparse kernel matches dense SDPA when every cube is selected.""" + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, num_cubes, D = 1, 4, 4, 128 + block_size = 64 + seq_len = num_cubes * block_size + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + topk = num_cubes + q2k_idx = ( + torch.arange(num_cubes, device=device, dtype=torch.int32) + .view(1, 1, 1, num_cubes) + .expand(B, H, num_cubes, topk) + .contiguous() + ) + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + out_kernel, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + out_ref = F.scaled_dot_product_attention(q, k, v) + + max_diff = (out_kernel - out_ref).abs().max().item() + mean_diff = (out_kernel - out_ref).abs().mean().item() + + rtol, atol = 1e-2, 1e-2 + assert torch.allclose(out_kernel, out_ref, rtol=rtol, atol=atol), ( + f"CuTe block-sparse kernel deviates from dense SDPA at full top-K: " + f"max_diff={max_diff:.3e}, mean_diff={mean_diff:.3e} (rtol={rtol}, atol={atol})" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +def test_cute_kernel_matches_ref_with_independent_indices(): + """CuTe kernel: paired Q-blocks (2i, 2i+1) attend to independent KV index lists.""" + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(42) + + B, H, num_cubes, D = 2, 4, 16, 128 + block_size = 64 + topk = num_cubes // 2 + seq_len = num_cubes * block_size + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + q2k_idx = ( + torch.stack( + [ + torch.randperm(num_cubes, device=device, dtype=torch.int32)[:topk] + for _ in range(B * H * num_cubes) + ] + ) + .view(B, H, num_cubes, topk) + .contiguous() + ) + + paired = q2k_idx.view(B, H, num_cubes // 2, 2, topk).sort(dim=-1).values + pair_mismatch = (paired[..., 0, :] != paired[..., 1, :]).sum().item() + assert pair_mismatch > 0, ( + "Pre-condition failed: random permutations matched across every pair; " + "re-seed or raise num_cubes." + ) + + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + attn_mask = torch.full( + (B, H, seq_len, seq_len), float("-inf"), device=device, dtype=torch.float32 + ) + for b in range(B): + for h in range(H): + for q_blk in range(num_cubes): + for ki in range(topk): + k_blk = q2k_idx[b, h, q_blk, ki].item() + qs = q_blk * block_size + ks = k_blk * block_size + attn_mask[b, h, qs : qs + block_size, ks : ks + block_size] = 0.0 + + out_kernel, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + + scale = 1.0 / (D**0.5) + scores = (q.float() @ k.float().transpose(-2, -1)) * scale + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + out_ref = (probs @ v.float()).to(dtype) + + abs_diff = (out_kernel.float() - out_ref.float()).abs() + max_diff = abs_diff.max().item() + mean_diff = abs_diff.mean().item() + + rtol, atol = 1e-2, 1e-2 + assert torch.allclose(out_kernel, out_ref, rtol=rtol, atol=atol), ( + f"CuTe kernel with independent per-Q-block indices deviated from masked fp32 " + f"reference: max_diff={max_diff:.3e}, mean_diff={mean_diff:.3e} " + f"(rtol={rtol}, atol={atol}, pair_mismatch={pair_mismatch})" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +def test_cute_kernel_50pct_sparsity_quality_vs_dense(): + """50% sparse CuTe kernel with score-based topk should stay close to dense SDPA.""" + + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, num_cubes, D = 1, 4, 16, 128 + block_size = 64 + topk = num_cubes // 2 + seq_len = num_cubes * block_size + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + q_blocks = q.reshape(B, H, num_cubes, block_size, D).mean(dim=3) + k_blocks = k.reshape(B, H, num_cubes, block_size, D).mean(dim=3) + scale = D**-0.5 + block_scores = torch.einsum("bhqd,bhkd->bhqk", q_blocks.float(), k_blocks.float()) * scale + q2k_idx = block_scores.topk(topk, dim=-1).indices.to(torch.int32).contiguous() + + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + out_sparse, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + out_dense = F.scaled_dot_product_attention(q, k, v) + + cos_sim = F.cosine_similarity( + out_sparse.float().reshape(-1), out_dense.float().reshape(-1), dim=0 + ).item() + print(f"\n 50% sparse (score-based topk) vs dense SDPA cos_sim: {cos_sim:.4f}") + + assert cos_sim >= 0.65, ( + f"50% sparse CuTe kernel deviated too far from dense SDPA: cos_sim={cos_sim:.4f} < 0.65" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +@pytest.mark.parametrize( + "num_cubes", + [1, 3, 9], + ids=["1cube_odd", "3cubes_odd", "9cubes_odd"], +) +def test_cute_kernel_odd_num_cubes_correctness(num_cubes): + """CuTe kernel with odd num_cubes must match dense SDPA (last Q-block has no pair).""" + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + assert num_cubes % 2 == 1, f"pre-condition: num_cubes={num_cubes} must be odd" + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, D = 1, 4, 128 + block_size = 64 + seq_len = num_cubes * block_size + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + topk = num_cubes + q2k_idx = ( + torch.arange(num_cubes, device=device, dtype=torch.int32) + .view(1, 1, 1, num_cubes) + .expand(B, H, num_cubes, topk) + .contiguous() + ) + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + out_kernel, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + out_ref = F.scaled_dot_product_attention(q, k, v) + + assert torch.isfinite(out_kernel).all(), ( + f"CuTe kernel produced non-finite output for odd num_cubes={num_cubes}" + ) + + max_diff = (out_kernel - out_ref).abs().max().item() + mean_diff = (out_kernel - out_ref).abs().mean().item() + rtol, atol = 1e-2, 1e-2 + assert torch.allclose(out_kernel, out_ref, rtol=rtol, atol=atol), ( + f"CuTe kernel deviated from dense SDPA for odd num_cubes={num_cubes}: " + f"max_diff={max_diff:.3e}, mean_diff={mean_diff:.3e} (rtol={rtol}, atol={atol})" + ) diff --git a/tests/unittest/_torch/visual_gen/test_attention_integration.py b/tests/unittest/_torch/visual_gen/test_attention_integration.py index f88f1a9db738..979a94406095 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_integration.py +++ b/tests/unittest/_torch/visual_gen/test_attention_integration.py @@ -35,7 +35,11 @@ # Import new integrated versions from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode, apply_rotary_emb -from tensorrt_llm.visual_gen.args import AttentionConfig, QuantAttentionConfig +from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + QuantAttentionConfig, + VideoSparseAttentionConfig, +) _flash_attn4_available = _fa4_fwd is not None _cute_dsl_available = _cute_dsl_import_error is None @@ -128,6 +132,8 @@ def create_model_config( eps: float = 1e-6, attn_backend: str = "VANILLA", quant_attention_config: "QuantAttentionConfig | None" = None, + sparse_attention_config=None, + vsa_sparsity: "float | None" = None, *, visual_gen_mapping: VisualGenMapping | None = None, skip_create_weights_in_init: bool = False, @@ -140,12 +146,15 @@ def create_model_config( eps=eps, ) - # Create a minimal config without quantization + if vsa_sparsity is not None and sparse_attention_config is None: + sparse_attention_config = VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity) + config = DiffusionModelConfig( pretrained_config=pretrained_config, attention=AttentionConfig( backend=attn_backend, quant_attention_config=quant_attention_config, + sparse_attention_config=sparse_attention_config, ), skip_create_weights_in_init=skip_create_weights_in_init, ) @@ -652,6 +661,107 @@ def test_fast_cross_attention_wan_shapes( assert is_close, f"{attn_backend} cross-attn mismatch at Wan shapes: max_diff={max_diff:.2e}" +# ============================================================================ +# VSA self-attention (CUTEDSL backend, sparse_attention_config.algorithm='vsa') +# ============================================================================ + + +def _build_vsa_setup(sparsity: float, batch_size: int, seed: int): + """Build naive + integrated models, VSA metadata, and inputs for a VSA test. + + latent (8,8,8) -> 512 tokens (divisible by block_size=64), head_dim=128. + """ + from tensorrt_llm._torch.visual_gen.attention_backend import VSAMetadataBuilder + + latent_shape = (8, 8, 8) + seq_len = latent_shape[0] * latent_shape[1] * latent_shape[2] + num_heads = 4 + head_dim = 128 + hidden_size = num_heads * head_dim + device = torch.device("cuda") + dtype = torch.bfloat16 + + naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device) + cfg_vsa = create_model_config( + hidden_size, num_heads, head_dim, attn_backend="CUTEDSL", vsa_sparsity=sparsity + ) + integrated = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=cfg_vsa).to( + device + ) + # Fail loudly if the VSA path silently fell back to dense (which would set + # attn_backend to "VANILLA") instead of selecting the CUTEDSL/VSA backend. + assert integrated.attn_backend == "CUTEDSL", ( + f"Expected CUTEDSL (VSA) backend, got {integrated.attn_backend!r}" + ) + copy_weights_self_attention(naive, integrated) + naive.eval() + integrated.eval() + + metadata = VSAMetadataBuilder().build( + current_timestep=0, + raw_latent_shape=latent_shape, + patch_size=(1, 1, 1), + vsa_sparsity=sparsity, + device=device, + ) + torch.manual_seed(seed) + hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype) + return SimpleNamespace( + naive=naive, + integrated=integrated, + metadata=metadata, + hidden_states=hidden_states, + gate_compress_zero=torch.zeros_like(hidden_states), + freqs_HSD=generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True), + freqs_SHD=generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False), + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +def test_vsa_self_attention_equivalence_at_sparsity_zero(): + """VSA at sparsity=0 with G_c=0 reduces to dense attention (top_k=num_cubes, + output=O_f); must match the naive SDPA reference modulo bf16 rounding.""" + from tensorrt_llm._torch.visual_gen.attention_backend import set_vsa_forward_context + + s = _build_vsa_setup(sparsity=0.0, batch_size=2, seed=42) + + with torch.no_grad(): + out_naive = s.naive(s.hidden_states, *s.freqs_HSD) + with torch.no_grad(), set_vsa_forward_context(s.metadata): + out_vsa = s.integrated( + s.hidden_states, freqs=s.freqs_SHD, gate_compress=s.gate_compress_zero + ) + + assert out_naive.shape == out_vsa.shape, ( + f"shape mismatch: naive={out_naive.shape}, vsa={out_vsa.shape}" + ) + max_diff = (out_naive - out_vsa).abs().max().item() + mean_diff = (out_naive - out_vsa).abs().mean().item() + assert torch.allclose(out_naive, out_vsa, rtol=1e-2, atol=1e-2), ( + f"VSA(sparsity=0, G_c=0) deviates from naive dense SDPA: " + f"max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +@pytest.mark.parametrize("sparsity", [0.0, 0.5], ids=["s0", "s0p5"]) +def test_vsa_self_attention_finite(sparsity: float): + """VSA forward must produce finite output (no NaN/Inf) at any supported sparsity.""" + from tensorrt_llm._torch.visual_gen.attention_backend import set_vsa_forward_context + + s = _build_vsa_setup(sparsity=sparsity, batch_size=1, seed=0) + + with torch.no_grad(), set_vsa_forward_context(s.metadata): + out = s.integrated(s.hidden_states, freqs=s.freqs_SHD, gate_compress=s.gate_compress_zero) + + assert out.shape == s.hidden_states.shape + nan_count = torch.isnan(out).sum().item() + inf_count = torch.isinf(out).sum().item() + assert nan_count == 0 and inf_count == 0, ( + f"VSA produced non-finite output at sparsity={sparsity}: NaN={nan_count}, Inf={inf_count}" + ) + + def test_trtllm_cached_prepare(): """Test that TRTLLM attention cached prepare works correctly. diff --git a/tests/unittest/_torch/visual_gen/test_attention_perf.py b/tests/unittest/_torch/visual_gen/test_attention_perf.py index 90a8ac3608c0..763309f3f96d 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_perf.py +++ b/tests/unittest/_torch/visual_gen/test_attention_perf.py @@ -29,13 +29,18 @@ """ import time -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from types import SimpleNamespace from typing import Dict, Optional, Tuple import pytest import torch +from tensorrt_llm._torch.visual_gen.attention_backend import ( + VSAMetadataBuilder, + set_vsa_forward_context, +) + # ============================================================================ # Flash Attention 4 availability # ============================================================================ @@ -49,7 +54,11 @@ create_attention_metadata_state, ) from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode -from tensorrt_llm.visual_gen.args import AttentionConfig, QuantAttentionConfig +from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + QuantAttentionConfig, + VideoSparseAttentionConfig, +) _flash_attn4_available = _fa4_fwd is not None _cute_dsl_available = _cute_dsl_import_error is None @@ -130,6 +139,17 @@ def get_elapsed_time(): yield get_elapsed_time +def bench_fn(device: torch.device, fn, n_iters: int) -> torch.Tensor: + """Time `fn` over `n_iters` GPU-timed iterations; returns per-iteration ms.""" + times = [] + with torch.no_grad(): + for _ in range(n_iters): + with cuda_timer(device) as get_t: + fn() + times.append(get_t()) + return torch.tensor(times) + + @contextmanager def nvtx_range(name: str): """Context manager for NVTX range profiling.""" @@ -171,6 +191,7 @@ def create_model_config( eps: float = 1e-6, attn_backend: str = "VANILLA", quant_attention_config: "QuantAttentionConfig | None" = None, + vsa_sparsity: "float | None" = None, ) -> DiffusionModelConfig: """Create a mock DiffusionModelConfig for testing.""" pretrained_config = SimpleNamespace( @@ -180,11 +201,16 @@ def create_model_config( eps=eps, ) + sparse_attention_config = ( + VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity) if vsa_sparsity is not None else None + ) + config = DiffusionModelConfig( pretrained_config=pretrained_config, attention=AttentionConfig( backend=attn_backend, quant_attention_config=quant_attention_config, + sparse_attention_config=sparse_attention_config, ), skip_create_weights_in_init=False, ) @@ -230,6 +256,21 @@ def _is_sage_attention_enabled(model: Attention) -> bool: return getattr(inner_backend, "quant_attention_config", None) is not None +def _init_attention_weights(attn: Attention, std: float = 0.02) -> None: + """Initialize TRT-LLM Linear weights to N(0, std) to avoid RMSNorm NaN.""" + with torch.no_grad(): + if getattr(attn, "qkv_proj", None) is not None: + attn.qkv_proj.weight.normal_(0.0, std) + if attn.qkv_proj.bias is not None: + attn.qkv_proj.bias.zero_() + if getattr(attn, "qk_norm", False): + attn.norm_q.weight.fill_(1.0) + attn.norm_k.weight.fill_(1.0) + attn.to_out[0].weight.normal_(0.0, std) + if attn.to_out[0].bias is not None: + attn.to_out[0].bias.zero_() + + # ============================================================================ # Performance benchmark class # ============================================================================ @@ -286,6 +327,8 @@ def create_attention_model( head_dim: int, backend: str, quant_attention_config: "QuantAttentionConfig | None" = None, + vsa_sparsity: "float | None" = None, + init_std: "float | None" = None, ) -> Attention: """Create a WAN self-attention model with specified backend.""" config = create_model_config( @@ -294,11 +337,14 @@ def create_attention_model( head_dim, attn_backend=backend, quant_attention_config=quant_attention_config, + vsa_sparsity=vsa_sparsity, ) model = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=config).to( self.device ) model.eval() + if init_std is not None: + _init_attention_weights(model, init_std) return model def create_cross_attention_model( @@ -434,9 +480,16 @@ def benchmark_single( backend: str, verbose: bool = True, quant_attention_config: "QuantAttentionConfig | None" = None, + vsa_sparsity: "float | None" = None, + latent_shape: "Tuple[int, int, int] | None" = None, + init_std: "float | None" = None, ) -> Optional[Dict]: """Benchmark a single configuration. + For the VSA path, pass ``vsa_sparsity`` and the 3-D ``latent_shape`` + (T, H, W) with ``seq_len == T * H * W``; the forward is run inside a + VSA forward context with a zero ``gate_compress``. + Returns: Dict with timing statistics or None if test failed/skipped """ @@ -452,16 +505,56 @@ def benchmark_single( try: # Create model and data model = self.create_attention_model( - hidden_size, num_heads, head_dim, backend, quant_attention_config + hidden_size, + num_heads, + head_dim, + backend, + quant_attention_config, + vsa_sparsity=vsa_sparsity, + init_std=init_std, ) hidden_states, freqs = self.create_test_data(batch_size, seq_len, hidden_size, head_dim) + # Fail fast if the requested backend silently resolved to something + # else (e.g. CUTEDSL/VSA downgraded to VANILLA dense for an + # unsupported config). + resolved_backend = getattr(model, "attn_backend", backend) + if vsa_sparsity is not None: + assert resolved_backend == "CUTEDSL", ( + f"VSA run requested CUTEDSL but the attention module resolved to " + f"{resolved_backend!r}; refusing to report VSA timings for a dense fallback." + ) + + # VSA needs an active forward context + a zero gate_compress; other + # backends use neither. + vsa_metadata = None + gate_compress = None + if vsa_sparsity is not None: + vsa_metadata = VSAMetadataBuilder().build( + current_timestep=0, + raw_latent_shape=latent_shape, + patch_size=(1, 1, 1), + vsa_sparsity=vsa_sparsity, + device=self.device, + ) + gate_compress = torch.zeros_like(hidden_states) + + def _forward(): + ctx = ( + set_vsa_forward_context(vsa_metadata) + if vsa_metadata is not None + else nullcontext() + ) + kwargs = {"gate_compress": gate_compress} if gate_compress is not None else {} + with ctx: + return model(hidden_states, freqs=freqs, **kwargs) + # Warmup with nvtx_range(f"warmup_{backend}"): with torch_profiler_range(f"warmup_{backend}"): with torch.no_grad(): for _ in range(self.warmup_iterations): - _ = model(hidden_states, freqs=freqs) + _ = _forward() if self.device.type == "cuda": torch.cuda.synchronize() @@ -474,7 +567,7 @@ def benchmark_single( for i in range(self.benchmark_iterations): with nvtx_range(f"iter_{backend}_{i}"): with cuda_timer(self.device) as get_time: - _ = model(hidden_states, freqs=freqs) + _ = _forward() times.append(get_time()) # Statistics @@ -489,6 +582,7 @@ def benchmark_single( "p99_ms": torch.quantile(times_tensor, 0.99).item(), "times_ms": times, "uses_sage": _is_sage_attention_enabled(model), + "resolved_backend": resolved_backend, } # Calculate throughput (approximate TOPS) @@ -504,6 +598,9 @@ def benchmark_single( return stats + except AssertionError: + # Backend-resolution / fail-fast checks are genuine test failures. + raise except Exception as e: if verbose: print(f" {backend}: ERROR - {e}") @@ -935,6 +1032,226 @@ def test_fa4_cross_attn_quick( assert result["avg_ms"] > 0 +# ============================================================================ +# VSA fine-stage kernel vs FA4 (kernel-level) +# ============================================================================ + + +class TestVsaVsFa4KernelPerformance: + """Kernel-level VSA fine-stage vs FA4 benchmark. + + Times the two kernels directly (no QKV proj / norm / gate / tile-untile) so + the comparison reflects only the attention math. + """ + + @pytest.fixture(autouse=True) + def setup(self): + if not _flash_attn4_available: + pytest.skip( + "FlashAttention 4 not available" + + (f": {_fa4_import_error}" if _fa4_import_error else "") + ) + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable (VSA CuTe path)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + self.warmup = 10 + self.iters = 50 + self._cute_fn = block_sparse_attn_from_indices_cute + self._is_cute_supported = is_cute_supported + + @pytest.mark.parametrize( + "batch,num_heads,seq_len,head_dim,block_size,sparsity", + [ + (1, 40, 131_072, 128, 64, 0.5), + ], + ids=["wan_BS1_H40_S131k_D128_blk64_s0.5"], + ) + def test_vsa_kernel_vs_fa4( + self, + batch: int, + num_heads: int, + seq_len: int, + head_dim: int, + block_size: int, + sparsity: float, + ): + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import VSA_KERNEL_MAX_CUBES + + assert seq_len % block_size == 0, "seq_len must be a multiple of block_size" + num_cubes = seq_len // block_size + if num_cubes > VSA_KERNEL_MAX_CUBES: + pytest.skip( + f"num_cubes={num_cubes} exceeds VSA_KERNEL_MAX_CUBES={VSA_KERNEL_MAX_CUBES}" + ) + topk = max(1, int(round((1.0 - sparsity) * num_cubes))) + + # FA4 expects [B, S, H, D]; VSA CuTe expects [B, H, S, D]. + q_nhd = torch.randn( + batch, seq_len, num_heads, head_dim, device=self.device, dtype=self.dtype + ) + k_nhd = torch.randn_like(q_nhd) + v_nhd = torch.randn_like(q_nhd) + q_bhsd = q_nhd.transpose(1, 2).contiguous() + k_bhsd = k_nhd.transpose(1, 2).contiguous() + v_bhsd = v_nhd.transpose(1, 2).contiguous() + + if not self._is_cute_supported(q_bhsd): + pytest.skip("VSA CuTe path requires sm_100+ Blackwell + bf16/fp16 + head_dim=128") + + # Group2QInterleaveKV requires paired Q-blocks (2i, 2i+1) to share KV + # selections. A constant index set across all Q-blocks trivially satisfies that. + q2k_idx = ( + torch.arange(topk, device=self.device, dtype=torch.int32) + .view(1, 1, 1, topk) + .expand(batch, num_heads, num_cubes, topk) + .contiguous() + ) + q2k_num = torch.full( + (batch, num_heads, num_cubes), topk, dtype=torch.int32, device=self.device + ) + variable_block_sizes = torch.full( + (num_cubes,), block_size, dtype=torch.int32, device=self.device + ) + + softmax_scale = head_dim**-0.5 + + def fa4_call(): + _fa4_fwd( + q_nhd, + k_nhd, + v_nhd, + softmax_scale=softmax_scale, + causal=False, + window_size_left=None, + window_size_right=None, + learnable_sink=None, + softcap=0.0, + pack_gqa=None, + mask_mod=None, + block_sparse_tensors=None, + return_lse=True, + ) + + def vsa_call(): + self._cute_fn( + q_bhsd, + k_bhsd, + v_bhsd, + q2k_idx, + q2k_num, + variable_block_sizes, + ) + + with torch.no_grad(): + for _ in range(self.warmup): + fa4_call() + vsa_call() + torch.cuda.synchronize() + + with nvtx_range("bench_fa4"): + fa4_times = bench_fn(self.device, fa4_call, self.iters) + with nvtx_range("bench_vsa_fine"): + vsa_times = bench_fn(self.device, vsa_call, self.iters) + + fa4_avg = fa4_times.mean().item() + vsa_avg = vsa_times.mean().item() + speedup = fa4_avg / vsa_avg + + print( + f"\n VSA vs FA4 (BS={batch}, H={num_heads}, S={seq_len}, " + f"D={head_dim}, blk={block_size}, sparsity={sparsity:.2f}, " + f"topk={topk}/{num_cubes}):" + ) + print(f" FA4 (dense): avg={fa4_avg:.3f}ms") + print(f" VSA (fine-stage): avg={vsa_avg:.3f}ms") + print(f" VSA speedup vs FA4: {speedup:.2f}x ({'faster' if speedup > 1 else 'slower'})") + assert fa4_avg > 0 and vsa_avg > 0 + + +# ============================================================================ +# VSA vs VANILLA at Wan 2.2 T2V 14B production shape — module-level +# ============================================================================ + + +class TestVsaVsVanillaWan22T2v14bModulePerformance: + """Module-level VSA vs VANILLA at Wan 2.2 T2V 14B 720p / 81-frame shape. + + Drives the shared WanAttentionPerformanceBenchmark engine so VSA is timed + and reported exactly like the VANILLA/TRTLLM/FA4 backends. + """ + + _BATCH = 1 + _NUM_HEADS = 40 + _HEAD_DIM = 128 + _LATENT_SHAPE = (21, 45, 80) # DiT seq shape after VAE + patchify + _SEQ_LEN = 21 * 45 * 80 # 75_600 + + @pytest.fixture(autouse=True) + def setup(self): + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable (VSA CuTe path)") + self._is_cute_supported = is_cute_supported + self.benchmark = WanAttentionPerformanceBenchmark( + warmup_iterations=5, benchmark_iterations=20 + ) + + @pytest.mark.parametrize( + "sparsity", + [0.5, 0.75, 0.85, 0.875], + ids=["s0.5", "s0.75", "s0.85", "s0.875"], + ) + def test_vsa_module_vs_vanilla_wan22_t2v_14b(self, sparsity: float): + bench = self.benchmark + probe_q = torch.empty( + self._BATCH, + self._NUM_HEADS, + 64, + self._HEAD_DIM, + device=bench.device, + dtype=bench.dtype, + ) + if not self._is_cute_supported(probe_q): + pytest.skip("VSA CuTe path requires sm_100+ Blackwell + bf16/fp16 + head_dim=128") + + common = dict( + batch_size=self._BATCH, + num_heads=self._NUM_HEADS, + seq_len=self._SEQ_LEN, + head_dim=self._HEAD_DIM, + init_std=0.02, + verbose=False, + ) + vanilla = bench.benchmark_single(backend="VANILLA", **common) + vsa = bench.benchmark_single( + backend="CUTEDSL", vsa_sparsity=sparsity, latent_shape=self._LATENT_SHAPE, **common + ) + assert vanilla is not None and vsa is not None, "benchmark returned None (skipped/failed)" + assert vsa["resolved_backend"] == "CUTEDSL", ( + f"VSA timings came from {vsa['resolved_backend']!r}, not the CUTEDSL/VSA path" + ) + assert vanilla["avg_ms"] > 0 and vsa["avg_ms"] > 0 + + speedup = vanilla["avg_ms"] / vsa["avg_ms"] + print( + f"\n VSA vs VANILLA (Wan 2.2 T2V 14B 720p/81f, latent={self._LATENT_SHAPE}, " + f"S={self._SEQ_LEN}, sparsity={sparsity:.3f}): " + f"VANILLA={vanilla['avg_ms']:.3f}ms, VSA={vsa['avg_ms']:.3f}ms, " + f"speedup={speedup:.2f}x ({'faster' if speedup > 1 else 'slower'})" + ) + + # ============================================================================ # Main entry point # ============================================================================ diff --git a/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py new file mode 100644 index 000000000000..979619e8e433 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Correctness tests for VSA T2V pipeline with Video Sparse Attention (VSA). + +Verifies >= 0.95 cosine similarity between the CuTe-DSL VSA kernel and the +SDPA-fallback VSA path (same gated coarse+fine formulation, different fine kernel). + +Models: + - FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers (720x1280, 9 frames) + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py -v -s + +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN21_VSA=/path/to/vsa \\ + pytest tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py -v -s +""" + +import gc +import os +from pathlib import Path + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import pytest +import torch +import torch.nn.functional as F + +from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import _cute_dsl_import_error +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + TorchCompileConfig, + VideoSparseAttentionConfig, + VisualGenArgs, +) + +_cute_dsl_available = _cute_dsl_import_error is None + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> str: + """Return LLM_MODELS_ROOT path if set in env, assert when it's set but not a valid path.""" + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return str(root) + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) + + +WAN21_VSA_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_VSA", "Wan2.1-VSA-T2V-14B-720P-Diffusers") + +# ============================================================================ +# Test constants +# ============================================================================ + +PROMPT = "A cat sitting on a sunny windowsill watching birds outside." +NEGATIVE_PROMPT = "" +NUM_STEPS = 4 +SEED = 42 +COS_SIM_THRESHOLD = 0.95 + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _load_vsa_pipeline(checkpoint_path: str, vsa_sparsity: float = 0.0): + """Load TRTLLM WanPipeline with CUTEDSL + VSA backend.""" + if not os.path.exists(checkpoint_path): + pytest.skip(f"Checkpoint not found: {checkpoint_path}") + if not _cute_dsl_available: + pytest.skip(f"CUTEDSL not available (requires Blackwell GPU): {_cute_dsl_import_error}") + args = VisualGenArgs( + model=checkpoint_path, + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity), + ), + torch_compile_config=TorchCompileConfig(enable=False), + ) + return PipelineLoader(args).load(skip_warmup=True) + + +def _capture_trtllm_video( + pipeline, + prompt: str, + negative_prompt: str, + height: int, + width: int, + num_frames: int, + num_inference_steps: int, + guidance_scale: float, + seed: int, +) -> torch.Tensor: + """Run full TRTLLM pipeline including VAE decode; return (T, H, W, C) float in [0, 1].""" + with torch.no_grad(): + result = pipeline.forward( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + video = result.video # (B, T, H, W, C) uint8 + return video.float() / 255.0 + + +def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float: + """Cosine similarity between two tensors (flattened to 1D, cast to float32 on CPU).""" + a_flat = a.float().cpu().reshape(-1) + b_flat = b.float().cpu().reshape(-1) + return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item() + + +def _assert_vsa_matches_dense( + checkpoint_path: str, + height: int, + width: int, + num_frames: int, + guidance_scale: float, + model_label: str, + vsa_sparsity: float = 0.0, +) -> None: + """Compare CuTe-DSL VSA against SDPA-fallback VSA (same gated formulation, different fine kernel).""" + from unittest.mock import patch + + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import vsa as _vsa_module + + common_kwargs = dict( + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=NUM_STEPS, + guidance_scale=guidance_scale, + seed=SEED, + ) + + # --- CuTe-DSL path --- + vsa_pipe = _load_vsa_pipeline(checkpoint_path, vsa_sparsity=vsa_sparsity) + vsa_video = _capture_trtllm_video(vsa_pipe, **common_kwargs) + del vsa_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- SDPA fallback reference (same VSA formulation, fine attn via SDPA) --- + sdpa_pipe = _load_vsa_pipeline(checkpoint_path, vsa_sparsity=vsa_sparsity) + with patch.object(_vsa_module, "is_cute_supported", return_value=False): + sdpa_video = _capture_trtllm_video(sdpa_pipe, **common_kwargs) + del sdpa_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- Compare --- + assert vsa_video.numel() == sdpa_video.numel(), ( + f"{model_label}: element count mismatch — " + f"CuTe {vsa_video.shape} ({vsa_video.numel()}) vs " + f"SDPA {sdpa_video.shape} ({sdpa_video.numel()})" + ) + + cos_sim = _cosine_similarity(vsa_video, sdpa_video) + print(f"\n {model_label} cosine similarity: {cos_sim:.6f}") + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"{model_label}: cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. " + f"CuTe-DSL VSA diverges from SDPA-fallback VSA. " + f"Video shapes — CuTe: {vsa_video.shape}, SDPA: {sdpa_video.shape}." + ) + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanVsa14B_PipelineCorrectness: + """Wan2.1-VSA-T2V-14B: CuTe-DSL vs SDPA-fallback correctness (720x1280, 9 frames). + + Verifies CuTe-DSL kernel at sparsity=0.0 matches SDPA fallback with >= 0.95 cosine sim. + """ + + def test_cosine_similarity(self): + _assert_vsa_matches_dense( + checkpoint_path=WAN21_VSA_PATH, + height=720, + width=1280, + num_frames=9, + guidance_scale=5.0, + model_label="Wan2.1-VSA-T2V-14B+VSA", + vsa_sparsity=0.0, + ) + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanVsaSparse: + """VSA at sparsity=0.9: config propagates, output is correctly shaped and finite.""" + + def test_sparse_vsa(self): + pipeline = _load_vsa_pipeline(WAN21_VSA_PATH, vsa_sparsity=0.9) + try: + attn_cfg = pipeline.pipeline_config.primary_model_config.attention + assert attn_cfg.backend == "CUTEDSL" + assert attn_cfg.sparse_attention_config.vsa_sparsity == 0.9 + + with torch.no_grad(): + result = pipeline.forward( + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=720, + width=1280, + num_frames=9, + num_inference_steps=NUM_STEPS, + guidance_scale=5.0, + seed=SEED, + ) + assert result.video.dim() == 5 + B, _T, H, W, C = result.video.shape + assert B == 1 and H == 720 and W == 1280 and C == 3 + assert torch.isfinite(result.video.float()).all() + assert pipeline.transformer.blocks[0].attn1.attn_backend == "CUTEDSL" + finally: + del pipeline + gc.collect() + torch.cuda.empty_cache()