From afdeb5ac0a9073d68fca9d67c9055d160993dff2 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:14:17 -0700 Subject: [PATCH 01/14] VSA refactoring Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/attention_backend/__init__.py | 14 +- .../visual_gen/attention_backend/cute_dsl.py | 372 ++- .../visual_gen/attention_backend/parallel.py | 21 + .../visual_gen/attention_backend/utils.py | 3 + .../video_sparse_attention/__init__.py | 26 + .../block_sparse_attn_dsl_fwd.py | 2348 +++++++++++++++++ .../video_sparse_attention/interface.py | 168 ++ .../blackwell/video_sparse_attention/ptx.py | 377 +++ .../video_sparse_attention/scheduler.py | 223 ++ .../visual_gen/models/flux/pipeline_flux.py | 8 + .../visual_gen/models/flux/pipeline_flux2.py | 8 + .../visual_gen/models/ltx2/pipeline_ltx2.py | 10 + .../visual_gen/models/wan/pipeline_wan.py | 57 + .../visual_gen/models/wan/pipeline_wan_i2v.py | 28 + .../visual_gen/models/wan/transformer_wan.py | 94 +- .../_torch/visual_gen/modules/attention.py | 45 +- .../_torch/visual_gen/pipeline_loader.py | 17 + tensorrt_llm/visual_gen/__init__.py | 2 + tensorrt_llm/visual_gen/args.py | 45 +- tensorrt_llm/visual_gen/params.py | 7 + tensorrt_llm/visual_gen/sparse_attention.py | 27 +- .../test_lists/test-db/l0_b200.yml | 1 + .../multi_gpu/test_wan_vsa_ulysses.py | 248 ++ .../visual_gen/test_attention_cute_dsl_vsa.py | 329 +++ .../visual_gen/test_attention_integration.py | 109 +- .../_torch/visual_gen/test_attention_perf.py | 310 ++- 26 files changed, 4855 insertions(+), 42 deletions(-) create mode 100644 tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py create mode 100644 tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py create mode 100644 tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py create mode 100644 tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py create mode 100644 tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py create mode 100644 tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py create mode 100644 tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py index 5c5f6f18a007..fbcbc931f0c4 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py @@ -20,7 +20,14 @@ simplified metadata that doesn't require KV caching. """ -from .cute_dsl import CuTeDSLAttention +from .cute_dsl import ( + VSA_TILE_SIZE, + CuTeDSLAttention, + VSAMetadata, + VSAMetadataBuilder, + get_vsa_forward_context, + set_vsa_forward_context, +) from .flash_attn4 import FlashAttn4Attention from .interface import AttentionBackend, AttentionTensorLayout from .parallel import Attention2DAttention, RingAttention, UlyssesAttention, wrap_parallel_attention @@ -42,4 +49,9 @@ "VanillaAttention", "RingAttention", "wrap_parallel_attention", + "VSAMetadata", + "VSAMetadataBuilder", + "VSA_TILE_SIZE", + "get_vsa_forward_context", + "set_vsa_forward_context", ] diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py index 013065600207..e96781b345e8 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py @@ -13,16 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -CuTe DSL (NVIDIA kernels) Backend for Visual Generation Models +CuTe DSL Backend for Visual Generation Models -Uses pre-compiled cubins derived from CUTLASS CuTe DSL FMHA. +CuTeDSLAttention runs the VSA sparse path when sparse_attention_config is set, +otherwise the dense cubin path (with optional QK16PV8 quantization). Expects NHD layout ([B, S, H, D]) and supports float16/bfloat16. """ import math -from typing import Optional, Tuple +from contextlib import contextmanager +from dataclasses import dataclass +from math import ceil +from typing import Dict, Optional, Tuple import torch +import torch.nn.functional as F from tensorrt_llm.visual_gen.args import QuantAttentionConfig @@ -43,11 +48,211 @@ _cute_dsl_import_error = e +# VSA (Video Sparse Attention) sparse-path helpers + +# Must match the Blackwell kernel's block_size expectation. +VSA_TILE_SIZE: Tuple[int, int, int] = (4, 4, 4) + +# Kernel's SMEM buffer for variable_block_sizes is fixed-size and unchecked, +# so num_cubes must stay <= this. +VSA_KERNEL_MAX_CUBES: int = 4 * 1024 + + +def _get_tile_partition_indices( + dit_seq_shape: Tuple[int, int, int], + tile_size: Tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + T, H, W = dit_seq_shape + tT, tH, tW = tile_size + nT, nH, nW = ceil(T / tT), ceil(H / tH), ceil(W / tW) + + bt = torch.arange(nT, device=device).view(nT, 1, 1, 1, 1, 1) + bh = torch.arange(nH, device=device).view(1, nH, 1, 1, 1, 1) + bw = torch.arange(nW, device=device).view(1, 1, nW, 1, 1, 1) + lt = torch.arange(tT, device=device).view(1, 1, 1, tT, 1, 1) + lh = torch.arange(tH, device=device).view(1, 1, 1, 1, tH, 1) + lw = torch.arange(tW, device=device).view(1, 1, 1, 1, 1, tW) + + gt = bt * tT + lt + gh = bh * tH + lh + gw = bw * tW + lw + valid = (gt < T) & (gh < H) & (gw < W) + flat = gt * (H * W) + gh * W + gw + out = torch.where(valid, flat, torch.full_like(flat, -1)) + return out.reshape(-1).to(torch.long) + + +def _construct_variable_block_sizes( + dit_seq_shape: Tuple[int, int, int], + num_tiles: Tuple[int, int, int], + tile_size: Tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + T, H, W = dit_seq_shape + tT, tH, tW = tile_size + nT, nH, nW = num_tiles + + bt = torch.arange(nT, device=device) + bh = torch.arange(nH, device=device) + bw = torch.arange(nW, device=device) + valid_t = (T - bt * tT).clamp(max=tT) + valid_h = (H - bh * tH).clamp(max=tH) + valid_w = (W - bw * tW).clamp(max=tW) + sizes = valid_t.view(nT, 1, 1) * valid_h.view(1, nH, 1) * valid_w.view(1, 1, nW) + return sizes.reshape(-1).to(torch.long) + + +@dataclass +class VSAMetadata: + """Per-timestep metadata required by the VSA sparse path.""" + + current_timestep: int + dit_seq_shape: Tuple[int, int, int] + vsa_sparsity: float + num_tiles: Tuple[int, int, int] + total_seq_length: int + padded_seq_length: int + tile_partition_indices: torch.LongTensor + reverse_tile_partition_indices: torch.LongTensor + variable_block_sizes: torch.LongTensor + non_pad_index: torch.LongTensor + gather_idx: torch.LongTensor + + +class VSAMetadataBuilder: + """Builds VSAMetadata; caches per-shape index tensors so torch.compile + guards stay stable across denoising steps.""" + + def __init__(self) -> None: + self._cache: Dict[Tuple[Tuple[int, int, int], str], Dict[str, object]] = {} + + def _build_shape_payload( + self, + dit_seq_shape: Tuple[int, int, int], + device: torch.device, + ) -> Dict[str, object]: + T, H, W = dit_seq_shape + tT, tH, tW = VSA_TILE_SIZE + num_tiles = (ceil(T / tT), ceil(H / tH), ceil(W / tW)) + total_seq_length = T * H * W + padded_seq_length = num_tiles[0] * num_tiles[1] * num_tiles[2] * tT * tH * tW + + tile_partition_indices = _get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device) + non_pad_index = (tile_partition_indices >= 0).nonzero(as_tuple=True)[0] + gather_idx = tile_partition_indices[non_pad_index] + + reverse = torch.zeros(total_seq_length, dtype=torch.long, device=device) + reverse[gather_idx] = torch.arange(len(non_pad_index), dtype=torch.long, device=device) + + variable_block_sizes = _construct_variable_block_sizes( + dit_seq_shape, num_tiles, VSA_TILE_SIZE, device + ) + + return { + "dit_seq_shape": dit_seq_shape, + "num_tiles": num_tiles, + "total_seq_length": total_seq_length, + "padded_seq_length": padded_seq_length, + "tile_partition_indices": tile_partition_indices, + "reverse_tile_partition_indices": reverse, + "variable_block_sizes": variable_block_sizes, + "non_pad_index": non_pad_index, + "gather_idx": gather_idx, + } + + def build( + self, + current_timestep: int, + raw_latent_shape: Tuple[int, int, int], + patch_size: Tuple[int, int, int], + vsa_sparsity: float, + device: torch.device, + ) -> VSAMetadata: + dit_seq_shape = ( + raw_latent_shape[0] // patch_size[0], + raw_latent_shape[1] // patch_size[1], + raw_latent_shape[2] // patch_size[2], + ) + cache_key = (dit_seq_shape, str(device)) + payload = self._cache.get(cache_key) + if payload is None: + payload = self._build_shape_payload(dit_seq_shape, device) + self._cache[cache_key] = payload + + return VSAMetadata( + current_timestep=current_timestep, + vsa_sparsity=vsa_sparsity, + **payload, # type: ignore[arg-type] + ) + + +_vsa_forward_context: Optional[VSAMetadata] = None + + +@contextmanager +def set_vsa_forward_context(metadata: VSAMetadata): + global _vsa_forward_context + prev = _vsa_forward_context + _vsa_forward_context = metadata + try: + yield + finally: + _vsa_forward_context = prev + + +def get_vsa_forward_context() -> Optional[VSAMetadata]: + return _vsa_forward_context + + +def _mean_pool_cubes( + x_tiled: torch.Tensor, + variable_block_sizes: torch.LongTensor, + prod_tile: int, + num_cubes: int, +) -> torch.Tensor: + B, _padded, H, D = x_tiled.shape + x_cubes = x_tiled.view(B, num_cubes, prod_tile, H, D) + # fp32 accumulation: bf16 sum over 64 tokens perturbs the coarse softmax. + x_sum = x_cubes.float().sum(dim=2) + valid_counts = variable_block_sizes.float().clamp(min=1).view(1, num_cubes, 1, 1) + return (x_sum / valid_counts).to(x_tiled.dtype) + + +class VSAPreprocessor: + """Reorders NHD tokens into tile-major layout and zero-pads to tile boundaries.""" + + @staticmethod + def tile( + x: torch.Tensor, + non_pad_index: torch.LongTensor, + gather_idx: torch.LongTensor, + padded_seq_len: int, + ) -> torch.Tensor: + # index_select + index_copy_ instead of chained advanced indexing so + # torch.compile can trace this without a graph break. + B, _S, H, D = x.shape + x_valid = x.index_select(1, gather_idx) + x_padded = x.new_zeros(B, padded_seq_len, H, D) + x_padded.index_copy_(1, non_pad_index, x_valid) + return x_padded + + @staticmethod + def untile( + x: torch.Tensor, + reverse_tile_partition_indices: torch.LongTensor, + non_pad_index: torch.LongTensor, + ) -> torch.Tensor: + return x.index_select(1, non_pad_index).index_select(1, reverse_tile_partition_indices) + + class CuTeDSLAttention(AttentionBackend): """ CuTe DSL (NVIDIA kernels) backend for diffusion models. - Uses pre-compiled cubin kernels (head_dim=128 only). + Dense path uses pre-compiled cubins and requires head_dim=128. The VSA + sparse path (sparse_attention_config set) uses a JIT-compiled CuTe kernel + when head_dim=128 / fp16-bf16 / sm100+, and otherwise falls back to dense SDPA. """ def __init__( @@ -58,11 +263,13 @@ def __init__( num_kv_heads: Optional[int] = None, dtype: Optional[torch.dtype] = None, quant_attention_config: Optional[QuantAttentionConfig] = None, + sparse_attention_config=None, skip_softmax_threshold_scale: Optional[float] = None, **kwargs, ): - # Only head_dim=128 cubins are packaged. - if head_dim != 128: + # Dense path requires head_dim=128 (packaged cubins); the VSA sparse + # path JIT-compiles per shape, so it has no such restriction. + if sparse_attention_config is None and head_dim != 128: raise ValueError(f"CUTEDSL cubins require head_dim=128, got head_dim={head_dim}.") self.layer_idx = layer_idx self.num_heads = num_heads @@ -70,6 +277,7 @@ def __init__( self.num_kv_heads = num_kv_heads or num_heads self.dtype = dtype self.quant_attention_config = quant_attention_config + self.sparse_attention_config = sparse_attention_config self.skip_softmax_threshold_scale = skip_softmax_threshold_scale self.scale = 1.0 / math.sqrt(head_dim) @@ -173,22 +381,31 @@ def forward( v: torch.Tensor, *, attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL, + gate_compress: Optional[torch.Tensor] = None, + gate_fine: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ Forward pass using CuTe DSL (NVIDIA kernels). - Dimensions are derived from tensor shapes (NHD layout: ``[B, S, H, D]``). + Dimensions are derived from tensor shapes (NHD layout: [B, S, H, D]). + Dispatches to _forward_vsa when sparse_attention_config is set + (VSA sparse path); otherwise runs the dense cubins via forward_with_lse. Args: q: Query tensor [batch_size, seq_len, num_heads, head_dim] k: Key tensor [batch_size, seq_len_kv, num_kv_heads, head_dim] v: Value tensor [batch_size, seq_len_kv, num_kv_heads, head_dim] - attention_mask: Attention mask type (CAUSAL or FULL) + attention_mask: Attention mask type (CAUSAL or FULL) — dense path only. + gate_compress: VSA path only — G_c gate for the coarse branch. + gate_fine: VSA path only — G_f gate for the fine branch. None means + constant 1. Returns: Output tensor [batch_size, seq_len, num_heads, head_dim] """ + if self.sparse_attention_config is not None: + return self._forward_vsa(q, k, v, gate_compress=gate_compress, gate_fine=gate_fine) output, _ = self.forward_with_lse(q, k, v, attention_mask=attention_mask, **kwargs) return output @@ -201,7 +418,8 @@ def forward_with_lse( **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Forward pass returning both output and log-sum-exp (LSE). + Forward pass returning both output and log-sum-exp (LSE). Dense path + only — the VSA sparse path does not produce an LSE. Returns: output: [batch_size, seq_len, num_heads, head_dim] @@ -209,12 +427,148 @@ def forward_with_lse( always in float32. Used for numerically stable combination of partial attention results in Attention2D parallelism. """ + if self.sparse_attention_config is not None: + raise RuntimeError( + "CuTeDSLAttention.forward_with_lse() does not support the VSA " + "sparse path. Use forward() instead, or construct without " + "sparse_attention_config to use the dense path." + ) q, k, v, is_causal, origin_dtype = self._prepare_inputs(q, k, v, attention_mask) output, lse = self._fwd(q, k, v, is_causal, **kwargs) if output.dtype != origin_dtype: output = output.to(origin_dtype) return output, lse.transpose(1, 2) + # Dynamo can't guard on the module-level mutable global, so this read + # runs in eager. + @torch.compiler.disable + def _get_vsa_inputs(self): + ctx: Optional[VSAMetadata] = get_vsa_forward_context() + if ctx is None: + raise RuntimeError( + "CuTeDSLAttention._forward_vsa called without an active VSA forward context. " + "Wrap each transformer call with set_vsa_forward_context()." + ) + return ( + ctx.non_pad_index, + ctx.gather_idx, + ctx.reverse_tile_partition_indices, + ctx.variable_block_sizes, + ctx.padded_seq_length, + ctx.num_tiles[0] * ctx.num_tiles[1] * ctx.num_tiles[2], + ctx.vsa_sparsity, + ) + + def _forward_vsa( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + gate_compress: Optional[torch.Tensor], + gate_fine: Optional[torch.Tensor], + ) -> torch.Tensor: + """VSA forward: coarse mean-pool + fine block-sparse top-K. + + Args: + q, k, v: [B, S, H, D] in original (un-tiled) token order. + gate_compress: [B, S, H, D] G_c gate weighting the coarse branch O_c. + gate_fine: Optional [B, S, H, D] G_f gate weighting the fine branch + O_f. None means constant 1 (dense behavior preserved). + + Returns: + [B, S, H, D] in the same original token order. + """ + # Lazy import: the VSA kernels package is optional and may not be + # importable in environments without the cute-dsl runtime. + from ..cute_dsl_kernels.blackwell.video_sparse_attention import ( + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if gate_compress is None: + raise ValueError( + "CuTeDSLAttention VSA path requires gate_compress. " + "Ensure to_gate_compress is wired in the transformer block." + ) + + ( + non_pad_index, + gather_idx, + reverse_tile_partition_indices, + variable_block_sizes, + padded_len, + num_cubes, + vsa_sparsity, + ) = self._get_vsa_inputs() + + B, S, H, D = q.shape + prod_tile = VSA_TILE_SIZE[0] * VSA_TILE_SIZE[1] * VSA_TILE_SIZE[2] + cur_topk = max(1, ceil((1.0 - vsa_sparsity) * num_cubes)) + + q_t = VSAPreprocessor.tile(q, non_pad_index, gather_idx, padded_len) + k_t = VSAPreprocessor.tile(k, non_pad_index, gather_idx, padded_len) + v_t = VSAPreprocessor.tile(v, non_pad_index, gather_idx, padded_len) + + q_c = _mean_pool_cubes(q_t, variable_block_sizes, prod_tile, num_cubes) + k_c = _mean_pool_cubes(k_t, variable_block_sizes, prod_tile, num_cubes) + v_c = _mean_pool_cubes(v_t, variable_block_sizes, prod_tile, num_cubes) + + scale = D**-0.5 + scores_c = torch.einsum("bnhd,bmhd->bhnm", q_c, k_c) * scale + attn_probs_c = scores_c.softmax(dim=-1) + o_c = torch.einsum("bhnm,bmhd->bnhd", attn_probs_c, v_c) + + use_cute = is_cute_supported(q) and (q.dtype == k.dtype == v.dtype) + topk_indices = attn_probs_c.topk(cur_topk, dim=-1).indices.to(torch.int32) + + o_c_tiled = ( + o_c.unsqueeze(2).expand(B, num_cubes, prod_tile, H, D).reshape(B, padded_len, H, D) + ) + + if use_cute: + assert num_cubes <= VSA_KERNEL_MAX_CUBES, ( + f"VSA CuTe kernel supports at most {VSA_KERNEL_MAX_CUBES} cubes " + f"(SMEM-allocated variable_block_sizes buffer); got num_cubes={num_cubes}. " + "Lower video resolution/length or fall back to dense SDPA." + ) + q_hnd = q_t.transpose(1, 2).contiguous() + k_hnd = k_t.transpose(1, 2).contiguous() + v_hnd = v_t.transpose(1, 2).contiguous() + q2k_num = torch.full((B, H, num_cubes), cur_topk, dtype=torch.int32, device=q.device) + o_hnd, _lse = block_sparse_attn_from_indices_cute( + q_hnd, + k_hnd, + v_hnd, + q2k_idx=topk_indices.contiguous(), + q2k_num=q2k_num, + variable_block_sizes=variable_block_sizes.to(torch.int32), + ) + o_f_tiled = o_hnd.transpose(1, 2) + + # Padded rows hold kernel garbage; zero-padded gates mask the coarse + # term and untile discards padded positions from both branches. + gate_c_t = VSAPreprocessor.tile(gate_compress, non_pad_index, gather_idx, padded_len) + if gate_fine is not None: + gate_f_t = VSAPreprocessor.tile(gate_fine, non_pad_index, gather_idx, padded_len) + combined_tiled = gate_c_t * o_c_tiled + gate_f_t * o_f_tiled + else: + combined_tiled = gate_c_t * o_c_tiled + o_f_tiled + return VSAPreprocessor.untile( + combined_tiled, reverse_tile_partition_indices, non_pad_index + ) + + # SDPA must run on the un-tiled Q/K/V — padded zero K/V slots would + # otherwise absorb softmax mass and pollute the output. Untile o_c so + # both branches combine in original-flat order. + o_c_full = VSAPreprocessor.untile(o_c_tiled, reverse_tile_partition_indices, non_pad_index) + o_f = F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + ).transpose(1, 2) + if gate_fine is not None: + return gate_compress * o_c_full + gate_fine * o_f + return gate_compress * o_c_full + o_f + @classmethod def support_lse(cls) -> bool: return True diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py index 241aa8b2d166..2e3d5a81b453 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py @@ -143,10 +143,23 @@ def _forward_unfused( v: torch.Tensor, **kwargs, ) -> torch.Tensor: + # gate_compress / gate_fine (VSA) must follow the same all-to-all as + # Q/K/V so they arrive at the inner backend in the same (full-S, sharded-H) layout. + gate_compress = kwargs.pop("gate_compress", None) + gate_fine = kwargs.pop("gate_fine", None) + batch_size = q.shape[0] q = all_to_all_4d(q, scatter_dim=2, gather_dim=1, process_group=self.process_group) k = all_to_all_4d(k, scatter_dim=2, gather_dim=1, process_group=self.process_group) v = all_to_all_4d(v, scatter_dim=2, gather_dim=1, process_group=self.process_group) + if gate_compress is not None: + gate_compress = all_to_all_4d( + gate_compress, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) + if gate_fine is not None: + gate_fine = all_to_all_4d( + gate_fine, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) seq_len_full = q.shape[1] kv_seq_len_full = k.shape[1] @@ -155,12 +168,20 @@ def _forward_unfused( q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) + if gate_compress is not None: + gate_compress = gate_compress.transpose(1, 2) + if gate_fine is not None: + gate_fine = gate_fine.transpose(1, 2) # Caller passed pre-A2A (sharded) seq_lens; hand the inner # backend the post-A2A lengths instead. kwargs["batch_size"] = batch_size kwargs["seq_len"] = seq_len_full kwargs["seq_len_kv"] = kv_seq_len_full + if gate_compress is not None: + kwargs["gate_compress"] = gate_compress + if gate_fine is not None: + kwargs["gate_fine"] = gate_fine output = self.inner_backend.forward(q=q, k=k, v=v, **kwargs) diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py index a3ce3df89cd9..de0c143e976d 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py @@ -128,6 +128,9 @@ def create_attention( "DiffusionModelConfig; creation path must not allocate metadata implicitly." ) kwargs["attention_metadata_state"] = attention_metadata_state + if backend.upper() == "CUTEDSL" and attention_config is not None: + # CuTeDSLAttention dispatches dense / VSA based on this sub-config. + kwargs["sparse_attention_config"] = attention_config.sparse_attention_config return attn_cls( layer_idx=layer_idx, diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py new file mode 100644 index 000000000000..2ced031ae28c --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .block_sparse_attn_dsl_fwd import ( + VideoSparseAttentionForwardGroup2QInterleaveKV as VideoSparseAttentionForward, +) +from .interface import CUTE_AVAILABLE, block_sparse_attn_from_indices_cute, is_cute_supported + +__all__ = [ + "CUTE_AVAILABLE", + "VideoSparseAttentionForward", + "block_sparse_attn_from_indices_cute", + "is_cute_supported", +] diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py new file mode 100644 index 000000000000..3bae813fc7f1 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py @@ -0,0 +1,2348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Video Sparse Attention Forward (Blackwell) + +This file implements the forward pass of the video sparse attention. +It will produce the output and the log-sum-exp of the attention scores. + +This implementation requires zero-padding on the input tensor, to align with the block-size (64x64). +""" + +import math +from functools import partial +from typing import Callable, Optional, Tuple + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import cpasync, tcgen05 + +from . import ptx, scheduler + + +def make_thread_cooperative_group(size: int): + """ + Create a thread cooperative group. + """ + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) + + +def next_power_of_2(x: int) -> int: + if x <= 0: + return 1 + return 1 << (x - 1).bit_length() + + +SM100_TMEM_CAPACITY_COLUMNS: int = 512 + + +class VideoSparseAttentionForwardGroup2QInterleaveKV: + """ + This class implements the forward pass of the video sparse attention. + It doesn't require swap QK. + V0 + K0 K1 V1 + Q0 S0 0 O0 + Q1 0 S1 O1 + """ + + def __init__( + self, + block_m: int, + block_n: int, + headdim: int, + ): + self.block_m = block_m + self.block_n = block_n + assert block_m == 64 and block_n == 64, "Block size must be 64x64" + self.headdim = headdim + assert self.headdim == 128, "Head dimension must be 128" + + self.acc_dtype: cutlass.Numeric = cutlass.Float32 + self.cta_group = tcgen05.CtaGroup.ONE + self.cluster_shape_mn = (1, 1) + self.mma_tiler_qk = (self.block_m * 2, self.block_n * 2, 1) + self.mma_tiler_pv = (self.block_m * 2, self.headdim, 1) + + self.occupancy = 1 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + self.threads_per_warp: int = 32 + self.threads_per_wg: int = 128 + + self.load_warp_id = 0 + self.mma_warp_id = 1 + self.epilogue_warp_id = 2 + self.empty_warp_ids = (3,) + self.softmax_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + + self.threads_per_cta: int = self.threads_per_warp * len( + ( + *self.softmax_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epilogue_warp_id, + *self.empty_warp_ids, + ) + ) + + self.buffer_align_bytes: int = 1024 + + self.scheduler_cls = scheduler.StaticPersistentScheduler + + self.rescale_threshold: float = 8.0 + + self.mask_O: bool = False + # either P is in SMEM or in TMEM + self.p_in_smem: bool = False + + self.num_regs_load: int = 104 + self.num_regs_mma: int = 104 + self.num_regs_epi: int = 104 + self.num_regs_softmax: int = 224 # bigger as much as possible + self.num_regs_correction: int = 176 + self.num_regs_empty: int = 24 + + def _compute_grid( + self, + num_q_blocks: int, + num_kv_blocks: int, + num_heads: int, + batchsize: int, + headdim: int, + headdim_v: int, + ) -> Tuple[Tuple[int, int, int], scheduler.ParamsBase]: + cluster_shape = (*self.cluster_shape_mn, 1) + + scheduler_params = scheduler.TileSchedulerParams( + num_block=cute.ceil_div(num_q_blocks, 2), + num_head=num_heads, + num_batch=batchsize, + headdim=headdim, + headdim_v=headdim_v, + ) + params = self.scheduler_cls.to_underlying_arguments(scheduler_params) + + grid = self.scheduler_cls.get_grid_shape(params) + grid = cute.round_up(grid, cluster_shape) + return grid, params + + def _compute_stages( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + input_dtype: cutlass.Numeric, + ): + self.q_stage = 1 + assert self.q_stage == 1, "Q will be preserving in SMEM" + self.kv_stage = 2 if self.p_in_smem else 3 + self.s_stage = 2 + assert self.s_stage == 2, "S stage must be 2 for Interleaving KV blocks" + self.o_stage = 1 + self.epi_stage = 1 if self.p_in_smem else 2 + + self.scale_buffers = 2 + + self.tmem_cols_S = self.mma_tiler_qk[1] * self.s_stage + self.tmem_cols_O = self.mma_tiler_pv[1] * self.o_stage + + # P reuses S TMEM, when there is no space for P + self.p_in_s: bool = (not self.p_in_smem) and (self.o_stage == 2) + + self.tmem_cols_P = (self.mma_tiler_pv[1] // 2) * self.s_stage * (not self.p_in_s) + + self.tmem_offset_S = 0 + self.tmem_offset_O = self.tmem_cols_S + self.tmem_offset_P = self.tmem_offset_O + self.tmem_cols_O + + self.tmem_alloc_cols = self.tmem_cols_S + self.tmem_cols_O + self.tmem_cols_P + self.tmem_alloc_cols = next_power_of_2(self.tmem_alloc_cols) + assert self.tmem_alloc_cols <= SM100_TMEM_CAPACITY_COLUMNS + self.do_tmem_alloc: bool = self.tmem_alloc_cols != SM100_TMEM_CAPACITY_COLUMNS + + def _setup_attributes( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + input_dtype: cutlass.Numeric, + ): + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,) + ) + + qk_mma_inst_shape_k = cute.size(tiled_mma_qk.shape_mnk, mode=[2]) + qk_mma_inst_tile_k: int = self.headdim // qk_mma_inst_shape_k + self.mma_tiler_qk = ( + self.mma_tiler_qk[0], + self.mma_tiler_qk[1], + qk_mma_inst_shape_k * qk_mma_inst_tile_k, + ) + + pv_mma_inst_shape_k = cute.size(tiled_mma_pv.shape_mnk, mode=[2]) + pv_mma_inst_tile_k: int = self.mma_tiler_qk[1] // pv_mma_inst_shape_k + self.mma_tiler_pv = ( + self.mma_tiler_pv[0], + self.mma_tiler_pv[1], + pv_mma_inst_shape_k * pv_mma_inst_tile_k, + ) + + self._compute_stages(tiled_mma_qk, tiled_mma_pv, input_dtype) + + @cute.jit + def __call__( + self, + Q: cute.Tensor, + K: cute.Tensor, + V: cute.Tensor, + sm_scale: cutlass.Float32, + O: cute.Tensor, # noqa: E741 + LSE: cute.Tensor, + q2k_block_sparse_index: cute.Tensor, + q2k_block_sparse_num: cute.Tensor, + variable_block_sizes: cute.Tensor, + stream: cuda.CUstream, + ) -> None: + input_dtype = Q.element_type + if cutlass.const_expr(input_dtype not in [cutlass.Float16, cutlass.BFloat16]): + raise RuntimeError("Input dtype must be Float16 or BFloat16") + if cutlass.const_expr( + not (input_dtype == K.element_type == V.element_type == O.element_type) + ): + raise RuntimeError("All input tensors must have the same data type") + + batch, heads, seqlen, dim = Q.layout.shape + if cutlass.const_expr(dim != self.headdim): + raise RuntimeError(f"Dimension mismatch: {dim} != {self.headdim}") + + _, _, _, dim_v = V.layout.shape + if cutlass.const_expr(dim_v != self.headdim): + raise RuntimeError(f"Dimension mismatch: {dim_v} != {self.headdim}") + + _, _, num_q_blocks, num_kv_blocks = q2k_block_sparse_index.layout.shape + + grid, params = self._compute_grid( + num_q_blocks=num_q_blocks, + num_kv_blocks=num_kv_blocks, + num_heads=heads, + batchsize=batch, + headdim=dim, + headdim_v=dim_v, + ) + + # [batch, heads, seqlen, dim] -> [seqlen, dim, heads, batch] + Q_layout_transpose = [2, 3, 1, 0] + KV_layout_transpose = [2, 3, 1, 0] + # [batch, heads, seqlen, dim] -> [seqlen, dim, heads, batch] + O_layout_transpose = [2, 3, 1, 0] + # [seqlen, dim, heads, batch] -> [dim, seqlen, heads, batch] + V_layout_transpose = [1, 0, 2, 3] + # [batch, heads, seqlen] -> [seqlen, heads, batch] + LSE_layout_transpose = [2, 1, 0] + Q = cute.make_tensor(Q.iterator, cute.select(Q.layout, mode=Q_layout_transpose)) + K = cute.make_tensor(K.iterator, cute.select(K.layout, mode=KV_layout_transpose)) + V = cute.make_tensor(V.iterator, cute.select(V.layout, mode=KV_layout_transpose)) + # NOTE: need transpose here, make V be N-major + V = cute.make_tensor(V.iterator, cute.select(V.layout, mode=V_layout_transpose)) + O = cute.make_tensor(O.iterator, cute.select(O.layout, mode=O_layout_transpose)) # noqa: E741 + LSE = cute.make_tensor(LSE.iterator, cute.select(LSE.layout, mode=LSE_layout_transpose)) + + q_major_mode = utils.LayoutEnum.from_tensor(Q).mma_major_mode() + k_major_mode = utils.LayoutEnum.from_tensor(K).mma_major_mode() + v_major_mode = tcgen05.OperandMajorMode.MN + + tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + input_dtype, + q_major_mode, + k_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler_qk[:2], + ) + p_major_mode = tcgen05.OperandMajorMode.K + p_source = tcgen05.OperandSource.SMEM if self.p_in_smem else tcgen05.OperandSource.TMEM + tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + input_dtype, + p_major_mode, + v_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler_pv[:2], + a_source=p_source, + ) + self._setup_attributes(tiled_mma_qk, tiled_mma_pv, input_dtype) + + o_layout = utils.LayoutEnum.from_tensor(O) + self.epi_tile = (self.mma_tiler_pv[0], self.mma_tiler_pv[1]) + + sQ_layout = sm100_utils.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, input_dtype, self.q_stage + ) + sK_layout = sm100_utils.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, input_dtype, self.kv_stage + ) + + sP_layout = sm100_utils.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, input_dtype, self.s_stage + ) + + sV_layout = sm100_utils.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, input_dtype, self.kv_stage + ) + + sO_layout = sm100_utils.make_smem_layout_epi( + input_dtype, o_layout, self.epi_tile, self.epi_stage + ) + fake_sO_layout = sm100_utils.make_smem_layout_epi( + input_dtype, o_layout, (self.block_m, self.headdim), self.epi_stage + ) + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", Q, sQ_layout), + ("K", K, sK_layout), + ("V", V, sV_layout), + ] + } + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + # use fake tiled mma to get TMA atoms + fake_tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + input_dtype, + q_major_mode, + k_major_mode, + self.acc_dtype, + self.cta_group, + (self.block_m, self.block_n), + ) + fake_sQ_layout = sm100_utils.make_smem_layout_a( + fake_tiled_mma_qk, + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + input_dtype, + self.q_stage, + ) + fake_sK_layout = sm100_utils.make_smem_layout_b( + fake_tiled_mma_qk, + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + input_dtype, + self.kv_stage, + ) + + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + Q, + cute.select(fake_sQ_layout, mode=[0, 1, 2]), + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + fake_tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + K, + cute.select(fake_sK_layout, mode=[0, 1, 2]), + (self.block_m, self.block_n, self.mma_tiler_qk[2]), + fake_tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + + fake_tiled_mma_pv = sm100_utils.make_trivial_tiled_mma( + input_dtype, + p_major_mode, + v_major_mode, + self.acc_dtype, + self.cta_group, + (self.block_m, self.block_n), + ) + fake_sV_layout = sm100_utils.make_smem_layout_b( + fake_tiled_mma_pv, + (self.block_m, self.block_n, self.block_n), + input_dtype, + self.kv_stage, + ) + + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + V, + cute.select(fake_sV_layout, mode=[0, 1, 2]), + (self.block_m, self.block_n, self.block_n), + fake_tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + + o_cta_v_layout = cute.composition( + cute.make_identity_layout(O.shape), (self.block_m, self.headdim) + ) + tma_atom_O, mO = cpasync.make_tiled_tma_atom( + tma_store_op, O, cute.select(fake_sO_layout, mode=[0, 1]), o_cta_v_layout + ) + + sScale_layout = None + if cutlass.const_expr(self.o_stage == 1): + # [acc_scale] + sScale_layout = cute.make_ordered_layout((self.mma_tiler_qk[0], self.s_stage), (0, 1)) + else: + # [acc_scale, running_max] + sScale_layout = cute.make_ordered_layout( + (2, self.mma_tiler_qk[0], self.s_stage), (0, 1, 2) + ) + + # [running_sum, running_max] * scale_buffers + # TODO: rearrange this to make to LDS.64 instead of LDS.32 x 2 + sFinal_layout = cute.make_ordered_layout( + (2, self.mma_tiler_qk[0], self.scale_buffers), (0, 1, 2) + ) + + self.max_indices = 4 * 1024 + + @cute.struct + class SharedStorage: + load_KV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2] + qk_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.s_stage * 2] + p_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.s_stage * 2] + o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.o_stage * 2] + pv_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.o_stage * 2] + store_O_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_stage * 2] + corr_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.s_stage * 2] + + tmem_holding_buf: cutlass.Int32 + + sScale: cute.struct.MemRange[cutlass.Float32, cute.cosize(sScale_layout)] + + sFinal: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(sFinal_layout)], + 8, # for LDS.64 it shall be 8 bytes aligned. + ] + sFinal_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.scale_buffers * 2] + + O_final_guard: cutlass.Int64 + + sQ: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + sP: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sP_layout) * (self.p_in_smem)], + self.buffer_align_bytes, + ] + sO: cute.struct.Align[ + cute.struct.MemRange[input_dtype, cute.cosize(sO_layout)], + self.buffer_align_bytes, + ] + sVariable_block_sizes: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, self.max_indices], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.LOG2_E: float = math.log2(math.e) + self.LN2: float = math.log(2.0) + sm_scale_log2 = sm_scale * self.LOG2_E + + self.kernel( + mQ, + mK, + mV, + sm_scale_log2, + mO, + LSE, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + tiled_mma_qk, + tiled_mma_pv, + fake_tiled_mma_qk, + fake_tiled_mma_pv, + sQ_layout, + sK_layout, + sV_layout, + sP_layout, + sScale_layout, + sFinal_layout, + sO_layout, + self.cluster_layout_vmnk, + params, + q2k_block_sparse_index, + q2k_block_sparse_num, + variable_block_sizes, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, # [seq, dim, head, batch] + mK: cute.Tensor, # [seq, dim, head, batch] + mV: cute.Tensor, # [dim, seq, head, batch] + sm_scale_log2: cutlass.Float32, # softmax scale in log2 + mO: cute.Tensor, # [dim, seq, head, batch] + mLSE: cute.Tensor, # [seq, head, batch] + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: cute.CopyAtom, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + fake_tiled_mma_qk: cute.TiledMma, + fake_tiled_mma_pv: cute.TiledMma, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout, + sScale_layout: cute.Layout, + sFinal_layout: cute.Layout, + sO_layout: cute.ComposedLayout, + cluster_layout_vmnk: cute.Layout, + scheduler_params: scheduler.ParamsBase, + q2k_block_sparse_index: cute.Tensor, + q2k_block_sparse_num: cute.Tensor, + variable_block_sizes: cute.Tensor, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx = cute.arch.thread_idx()[0] + + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_O) + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + KV_mbar_ptr = storage.load_KV_mbar_ptr.data_ptr() + O_final_guard = storage.O_final_guard + if warp_idx == self.load_warp_id: + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + O_final_guard, self.threads_per_warp * len(self.correction_warp_ids) + ) + for i in cutlass.range_constexpr(self.kv_stage, unroll_full=True): + # producer + cute.arch.mbarrier_init(KV_mbar_ptr + i * 2, len([self.load_warp_id])) + # consumer + cute.arch.mbarrier_init(KV_mbar_ptr + i * 2 + 1, len([self.mma_warp_id])) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage + ) + kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage + ) + + qk_mma_pipeline = pipeline.PipelineUmmaAsync.create( + num_stages=self.s_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + barrier_storage=storage.qk_mma_mbar_ptr.data_ptr(), + ) + qk_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.s_stage + ) + qk_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.s_stage + ) + + p_pipeline = pipeline.PipelineAsyncUmma.create( + num_stages=self.s_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.p_mbar_ptr.data_ptr(), + ) + p_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.s_stage + ) + p_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.s_stage + ) + o_pipeline = pipeline.PipelineAsyncUmma.create( + num_stages=self.o_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.o_mbar_ptr.data_ptr(), + ) + o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.o_stage + ) + o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.o_stage + ) + + pv_mma_pipeline = pipeline.PipelineUmmaAsync.create( + num_stages=self.o_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.pv_mma_mbar_ptr.data_ptr(), + ) + pv_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.o_stage + ) + pv_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.o_stage + ) + + st_O_pipeline = pipeline.PipelineAsync.create( + num_stages=self.epi_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len([self.epilogue_warp_id]) + ), + barrier_storage=storage.store_O_mbar_ptr.data_ptr(), + ) + st_O_producer_state = pipeline.PipelineState( + self.epi_stage, + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int32(0), + ) + st_O_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.epi_stage + ) + + corr_pipeline = pipeline.PipelineAsync.create( + num_stages=self.s_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.corr_mbar_ptr.data_ptr(), + ) + + correction_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.s_stage + ) + + sFinal_pipeline = pipeline.PipelineAsync.create( + num_stages=self.scale_buffers, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.sFinal_mbar_ptr.data_ptr(), + ) + sFinal_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.scale_buffers + ) + sFinal_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.scale_buffers + ) + + cute.arch.mbarrier_init_fence() + + tmem_holding_buf = storage.tmem_holding_buf + if cutlass.const_expr(self.do_tmem_alloc): + if warp_idx == 0: + cute.arch.alloc_tmem( + self.tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=False, + ) + cute.arch.sync_threads() + + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + # NOTE: V will reuse the same smem as K + # stripe swizzle info to reuse smem + sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) + + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + + sVariable_block_sizes_layout = cute.make_layout((self.max_indices,)) + sVariable_block_sizes = storage.sVariable_block_sizes.get_tensor( + sVariable_block_sizes_layout + ) + + for i in cutlass.range(tidx, variable_block_sizes.shape[0], cute.arch.block_dim()[0]): + sVariable_block_sizes[i] = variable_block_sizes[i] + # cta sync + cute.arch.sync_threads() + + sScale = storage.sScale.get_tensor(sScale_layout) + sFinal = storage.sFinal.get_tensor(sFinal_layout) + + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + + # ----- GMEM partition ----------- # + # [block_m, dimk, block_cnt, loop_k, head, batch] + gQ = cute.flat_divide(mQ, (self.block_m, self.headdim)) + gK = cute.flat_divide(mK, (self.block_n, self.headdim)) + gO = cute.flat_divide(mO, (self.block_m, self.headdim)) + # [dimK, block_n, loop_k, block_cnt, head, batch] + gV = cute.flat_divide(mV, (self.headdim, self.block_n)) + + # [block_m, block_cnt, head, batch] + gLSE = cute.flat_divide(mLSE, (self.block_m,)) + + # ----- TMEM partition ----------- # + # NOTE: we already know each SM's occupancy is 1, + # so tmem_ptr starting at zero + tmem_ptr = cute.make_ptr( + self.acc_dtype, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16 + ) + if cutlass.const_expr(self.do_tmem_alloc): + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + + tS_shape = (self.mma_tiler_qk[0], self.tmem_cols_S) + tCtS_shape = thr_mma_qk.partition_shape_C(tS_shape) + tCtS_fake = thr_mma_qk.make_fragment_C(tCtS_shape) + tCtS = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_offset_S, dtype=self.acc_dtype), tCtS_fake.layout + ) + sP = None + if cutlass.const_expr(self.p_in_smem): + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + else: + if cutlass.const_expr(self.p_in_s): + p_ptr = cute.recast_ptr(tmem_ptr + self.tmem_offset_S, dtype=sV.element_type) + tP = cute.make_tensor(p_ptr, sP_layout.outer) + tCtP = thr_mma_pv.make_fragment_A(tP) + stride = tCtP.stride + # since it will reuse the S TMEM, we need to make sure the stride is correct + stride = (*(stride[i] for i in range(len(stride) - 1)), stride[len(stride) - 1] * 2) + layout = cute.make_layout(tCtP.shape, stride=stride) + tCtP = cute.make_tensor(tCtP.iterator, layout) + sP = cute.make_tensor(p_ptr, tCtP.layout) + else: + # P has its own TMEM + p_ptr = cute.recast_ptr(tmem_ptr + self.tmem_offset_P, dtype=sV.element_type) + tP = cute.make_tensor(p_ptr, sP_layout.outer) + tCtP = thr_mma_pv.make_fragment_A(tP) + sP = cute.make_tensor(p_ptr, tCtP.layout) + + tO_shape = (self.mma_tiler_pv[0], self.tmem_cols_O) + tCtO_shape = thr_mma_pv.partition_shape_C(tO_shape) + tCtO_fake = thr_mma_pv.make_fragment_C(tCtO_shape) + tCtO = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_offset_O, dtype=self.acc_dtype), tCtO_fake.layout + ) + + TileSchedulerCls = partial(self.scheduler_cls.create, scheduler_params) + + if warp_idx in self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + + self.load( + TileSchedulerCls, + gQ, + gK, + gV, + KV_mbar_ptr, + kv_producer_state, + q2k_block_sparse_num, + q2k_block_sparse_index, + sQ, + sK, + sV, + fake_tiled_mma_qk, + fake_tiled_mma_pv, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + ) + + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_mma) + + self.mma( + TileSchedulerCls, + KV_mbar_ptr, + kv_consumer_state, + q2k_block_sparse_num, + qk_mma_pipeline, + qk_mma_producer_state, + tiled_mma_qk, + p_pipeline, + p_consumer_state, + o_pipeline, + o_consumer_state, + pv_mma_pipeline, + pv_mma_producer_state, + tiled_mma_pv, + sQ, + sK, + sV, + sP, + tCtS, + tCtO, + ) + + if warp_idx == self.epilogue_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_epi) + + self.epilogue( + TileSchedulerCls, + st_O_pipeline, + st_O_consumer_state, + gO, + sO, + tma_atom_O, + ) + + if warp_idx in self.softmax_warp_ids: + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + num_q_blocks = cute.size(gQ, mode=[2]) + self.softmax( + TileSchedulerCls, + q2k_block_sparse_num, + q2k_block_sparse_index, + variable_block_sizes, + sVariable_block_sizes, + qk_mma_pipeline, + qk_mma_consumer_state, + p_pipeline, + p_producer_state, + num_q_blocks, + sP, + tCtS, + thr_mma_qk, + sm_scale_log2, + sScale, + corr_pipeline, + sFinal, + sFinal_pipeline, + sFinal_producer_state, + O_final_guard, + ) + + if warp_idx in self.correction_warp_ids: + cute.arch.warpgroup_reg_alloc(self.num_regs_correction) + + num_q_blocks = cute.size(gQ, mode=[2]) + + self.correction( + TileSchedulerCls, + q2k_block_sparse_num, + p_producer_state, + o_pipeline, + o_producer_state, + pv_mma_pipeline, + pv_mma_consumer_state, + st_O_pipeline, + st_O_producer_state, + tCtO, + thr_mma_pv, + sO, + corr_pipeline, + correction_consumer_state, + sScale, + sm_scale_log2, + num_q_blocks, + variable_block_sizes, + sFinal, + sFinal_pipeline, + sFinal_consumer_state, + O_final_guard, + gLSE, + ) + + if cutlass.const_expr(self.do_tmem_alloc): + cute.arch.sync_threads() + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.dealloc_tmem( + tmem_ptr, + self.tmem_alloc_cols, + ) + + @cute.jit + def load( + self, + TileSchedulerCls: Callable, + gQ: cute.Tensor, + gK: cute.Tensor, + gV: cute.Tensor, + KV_mbar_ptr: cutlass.Pointer, + kv_producer_state: pipeline.PipelineState, + q2k_block_sparse_num: cute.Tensor, + q2k_block_sparse_index: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + fake_tiled_mma_qk: cute.TiledMma, + fake_tiled_mma_pv: cute.TiledMma, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + ): + fake_thr_mma_qk = fake_tiled_mma_qk.get_slice(0) + fake_thr_mma_pv = fake_tiled_mma_pv.get_slice(0) + + def _get_sQ_cpy( + sQ: cute.Tensor, + ) -> cute.Tensor: + _tmp = cute.flatten(sQ) + mode0 = cute.get(_tmp.layout, mode=[0]) + mode0_split = cute.flat_divide(mode0, (self.block_m,)) + shape = ( + mode0_split.shape, + cute.get(_tmp.layout, mode=[1]).shape, + cute.get(_tmp.layout, mode=[2]).shape, + cute.get(_tmp.layout, mode=[3]).shape, + cute.get(_tmp.layout, mode=[4]).shape, + cute.get(_tmp.layout, mode=[5]).shape, + ) + stride = ( + mode0_split.stride, + cute.get(_tmp.layout, mode=[1]).stride, + cute.get(_tmp.layout, mode=[2]).stride, + cute.get(_tmp.layout, mode=[3]).stride, + cute.get(_tmp.layout, mode=[4]).stride, + cute.get(_tmp.layout, mode=[5]).stride, + ) + layout = cute.make_layout(shape, stride=stride) + layout = cute.flatten(layout) + layout = cute.select(layout, mode=[0, 2, 3, 4, 5, 1, 6]) + layout = cute.group_modes(layout, 0, 2) + cpy_tensor = cute.make_tensor(sQ.iterator, layout) + return cpy_tensor + + sQ_cpy = _get_sQ_cpy(sQ) + + def _get_gQ_cpy( + tSgQ: cute.Tensor, + ): + layout = cute.select(tSgQ.layout, mode=[2, 0, 1, 3, 4, 5, 6]) + layout = cute.flat_divide( + layout, + (4,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 3, 0, 1, 4, 5, 6, 7]) + cpy_tensor = cute.make_tensor(tSgQ.iterator, layout) + return cpy_tensor + + tSgQ = _get_gQ_cpy(fake_thr_mma_qk.partition_A(gQ)) + tTMAsQ, tTMAgQ = cpasync.tma_partition( + tma_atom_Q, + 0, + cute.make_layout(1), + cute.group_modes(sQ_cpy, 0, 3), + cute.group_modes(tSgQ, 0, 3), + ) + + def _load_Q( + tTMAgQ: cute.Tensor, + tTMAsQ: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + producer_mbar: cutlass.Pointer, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + ): + tTMAgQ_1st = tTMAgQ[None, None, m_block_1st, None, head_idx, batch_idx] + tTMAgQ_2nd = tTMAgQ[None, None, m_block_2nd, None, head_idx, batch_idx] + + cute.copy( + tma_atom_Q, + tTMAgQ_1st[None, 0, 0], + tTMAsQ[None, 0, 0, 0], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_Q, + tTMAgQ_1st[None, 1, 0], + tTMAsQ[None, 1, 0, 0], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_Q, + tTMAgQ_2nd[None, 0, 0], + tTMAsQ[None, 0, 1, 0], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_Q, + tTMAgQ_2nd[None, 1, 0], + tTMAsQ[None, 1, 1, 0], + tma_bar_ptr=producer_mbar, + ) + + load_Q = partial( + _load_Q, + tTMAgQ=tTMAgQ, + tTMAsQ=tTMAsQ, + tma_atom_Q=tma_atom_Q, + ) + + def _get_sK_cpy( + sK: cute.Tensor, + ) -> cute.Tensor: + return _get_sQ_cpy(sK) + + sK_cpy = _get_sK_cpy(sK) + + def _get_gK_cpy( + tSgK: cute.Tensor, + ): + layout = cute.select(tSgK.layout, mode=[2, 0, 1, 3, 4, 5, 6]) + layout = cute.flat_divide( + layout, + (4,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 3, 0, 1, 4, 5, 6, 7]) + cpy_tensor = cute.make_tensor(tSgK.iterator, layout) + return cpy_tensor + + tSgK = _get_gK_cpy(fake_thr_mma_qk.partition_B(gK)) + tTMAsK, tTMAgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK_cpy, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + + def _load_K( + tTMAgK: cute.Tensor, + tTMAsK: cute.Tensor, + tma_atom_K: cute.CopyAtom, + producer_mbar: cutlass.Pointer, + buffer_idx: cutlass.Int32, + n: cutlass.Int32, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + valid_m_block_2nd: cutlass.Boolean, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + q2k_block_sparse_index: cute.Tensor, + num_k_blocks: cutlass.Int32, + ): + n_block_1st = q2k_block_sparse_index[batch_idx, head_idx, m_block_1st, n] + n_block_2nd = ( + q2k_block_sparse_index[batch_idx, head_idx, m_block_2nd, n] + if valid_m_block_2nd + else num_k_blocks + ) + + tTMAgK_1st = tTMAgK[None, None, n_block_1st, None, head_idx, batch_idx] + tTMAgK_2nd = tTMAgK[None, None, n_block_2nd, None, head_idx, batch_idx] + + cute.copy( + tma_atom_K, + tTMAgK_1st[None, 0, 0], + tTMAsK[None, 0, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_K, + tTMAgK_1st[None, 1, 0], + tTMAsK[None, 1, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_K, + tTMAgK_2nd[None, 0, 0], + tTMAsK[None, 0, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_K, + tTMAgK_2nd[None, 1, 0], + tTMAsK[None, 1, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + + load_K = partial( + _load_K, + tTMAgK=tTMAgK, + tTMAsK=tTMAsK, + tma_atom_K=tma_atom_K, + q2k_block_sparse_index=q2k_block_sparse_index, + ) + + def _get_sV_cpy( + sV: cute.Tensor, + ) -> cute.Tensor: + layout = cute.flatten(sV.layout) + layout = cute.select(layout, mode=[4, 0, 1, 2, 3, 5]) + layout = cute.flat_divide( + layout, + (4,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 3, 4, 5, 0, 1, 6]) + layout = cute.select(layout, mode=[0, 2, 3, 4, 5, 1, 6]) + layout = cute.group_modes(layout, 0, 2) + cpy_tensor = cute.make_tensor(sV.iterator, layout) + return cpy_tensor + + sV_cpy = _get_sV_cpy(sV) + + def _get_gV_cpy( + tOgV: cute.Tensor, + ): + layout = cute.append_ones(tOgV.layout) + layout = cute.select(layout, mode=[0, 7, 2, 1, 3, 4, 5, 6]) + cpy_tensor = cute.make_tensor(tOgV.iterator, layout) + return cpy_tensor + + tOgV = _get_gV_cpy(fake_thr_mma_pv.partition_B(gV)) + + tTMAsV, tTMAgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV_cpy, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + + def _load_V( + tTMAgV: cute.Tensor, + tTMAsV: cute.Tensor, + tma_atom_V: cute.CopyAtom, + producer_mbar: cutlass.Pointer, + buffer_idx: cutlass.Int32, + n: cutlass.Int32, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + valid_m_block_2nd: cutlass.Boolean, + num_v_blocks: cutlass.Int32, + head_idx: cutlass.Int32, + batch_idx: cutlass.Int32, + q2k_block_sparse_index: cute.Tensor, + ): + n_block_1st = q2k_block_sparse_index[batch_idx, head_idx, m_block_1st, n] + n_block_2nd = ( + q2k_block_sparse_index[batch_idx, head_idx, m_block_2nd, n] + if valid_m_block_2nd + else num_v_blocks + ) + + tTMAgV_1st = tTMAgV[None, None, None, n_block_1st, head_idx, batch_idx] + tTMAgV_2nd = tTMAgV[None, None, None, n_block_2nd, head_idx, batch_idx] + + cute.copy( + tma_atom_V, + tTMAgV_1st[None, 0, 0], + tTMAsV[None, 0, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_V, + tTMAgV_1st[None, 1, 0], + tTMAsV[None, 0, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_V, + tTMAgV_2nd[None, 0, 0], + tTMAsV[None, 1, 0, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + cute.copy( + tma_atom_V, + tTMAgV_2nd[None, 1, 0], + tTMAsV[None, 1, 1, buffer_idx], + tma_bar_ptr=producer_mbar, + ) + + load_V = partial( + _load_V, + tTMAgV=tTMAgV, + tTMAsV=tTMAsV, + tma_atom_V=tma_atom_V, + q2k_block_sparse_index=q2k_block_sparse_index, + ) + + num_q_blocks = cute.size(gQ, mode=[2]) + num_k_blocks = cute.size(gK, mode=[2]) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + m_block_2nd = m_block_1st + 1 + valid_m_block_2nd: cutlass.Boolean = m_block_2nd < num_q_blocks + m_block_2nd = m_block_2nd if valid_m_block_2nd else num_q_blocks + + # NOTE: Assumption: different Q-tile has the same number of KV-blocks + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m_block_1st] + + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + # load Q + load_Q( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + head_idx=head_idx, + batch_idx=batch_idx, + ) + load_K( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=0, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_k_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + # load K0 + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, + self.tma_copy_bytes["Q"] + self.tma_copy_bytes["K"], + ) + kv_producer_state.advance() + + for n in cutlass.range(_num_n_blocks - 1): + # load Kn + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + + load_K( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=n + 1, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_k_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, self.tma_copy_bytes["K"] + ) + kv_producer_state.advance() + # load Vn-1 + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + + load_V( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=n, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_v_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, self.tma_copy_bytes["V"] + ) + kv_producer_state.advance() + + # load Vn + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_producer_state.index * 2 + 1, kv_producer_state.phase + ) + + load_V( + producer_mbar=KV_mbar_ptr + kv_producer_state.index * 2, + buffer_idx=kv_producer_state.index, + n=_num_n_blocks - 1, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + valid_m_block_2nd=valid_m_block_2nd, + num_v_blocks=num_k_blocks, + head_idx=head_idx, + batch_idx=batch_idx, + ) + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_producer_state.index * 2, self.tma_copy_bytes["V"] + ) + kv_producer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma( + self, + TileSchedulerCls: Callable, + KV_mbar_ptr: cutlass.Pointer, + kv_consumer_state: pipeline.PipelineState, + q2k_block_sparse_num: cute.Tensor, + qk_mma_pipeline: pipeline.PipelineUmmaAsync, + qk_mma_producer_state: pipeline.PipelineState, + tiled_mma_qk: cute.TiledMma, + p_pipeline: pipeline.PipelineAsyncUmma, + p_consumer_state: pipeline.PipelineState, + o_pipeline: pipeline.PipelineAsyncUmma, + o_consumer_state: pipeline.PipelineState, + pv_mma_pipeline: pipeline.PipelineUmmaAsync, + pv_mma_producer_state: pipeline.PipelineState, + tiled_mma_pv: cute.TiledMma, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sP: cute.Tensor, + tCtS: cute.Tensor, + tCtO: cute.Tensor, + ): + thr_mma_qk = tiled_mma_qk.get_slice(0) + thr_mma_pv = tiled_mma_pv.get_slice(0) + + tCsQ = thr_mma_qk.make_fragment_A(sQ) + tCsK = thr_mma_qk.make_fragment_B(sK) + + tCsP = None + if cutlass.const_expr(self.p_in_smem): + tCsP = thr_mma_pv.make_fragment_A(sP) + else: + # it resides in TMEM + tCsP = sP + tCsV = thr_mma_pv.make_fragment_B(sV) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m_block_1st] + + # Q @ K0 + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + qk_mma_pipeline.producer_acquire(qk_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sQ, mode=[2])): + cute.gemm( + tiled_mma_qk, + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + tCsQ[None, None, kblock_idx, 0], + tCsK[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + qk_mma_pipeline.producer_commit(qk_mma_producer_state) + qk_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + + for n in cutlass.range(_num_n_blocks - 1): + # Q @ Kn + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + qk_mma_pipeline.producer_acquire(qk_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sQ, mode=[2])): + cute.gemm( + tiled_mma_qk, + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + tCsQ[None, None, kblock_idx, 0], + tCsK[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtS[None, None, qk_mma_producer_state.index]), + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + + qk_mma_pipeline.producer_commit(qk_mma_producer_state) + qk_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + + # P @ Vn-1 + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, n >= self.o_stage) + p_pipeline.consumer_wait(p_consumer_state) + o_pipeline.consumer_wait(o_consumer_state) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + pv_mma_pipeline.producer_acquire(pv_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sV, mode=[2])): + cute.gemm( + tiled_mma_pv, + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + tCsP[None, None, kblock_idx, p_consumer_state.index], + tCsV[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + pv_mma_pipeline.producer_commit(pv_mma_producer_state) + pv_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + p_pipeline.consumer_release(p_consumer_state) + p_consumer_state.advance() + o_pipeline.consumer_release(o_consumer_state) + o_consumer_state.advance() + + # P @ Vn + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, (_num_n_blocks - 1) >= self.o_stage) + p_pipeline.consumer_wait(p_consumer_state) + o_pipeline.consumer_wait(o_consumer_state) + cute.arch.mbarrier_wait( + KV_mbar_ptr + kv_consumer_state.index * 2, kv_consumer_state.phase + ) + pv_mma_pipeline.producer_acquire(pv_mma_producer_state) + for kblock_idx in cutlass.range_constexpr(cute.size(sV, mode=[2])): + cute.gemm( + tiled_mma_pv, + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + tCsP[None, None, kblock_idx, p_consumer_state.index], + tCsV[None, None, kblock_idx, kv_consumer_state.index], + cute.append_ones(tCtO[None, None, pv_mma_producer_state.index]), + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + pv_mma_pipeline.producer_commit(pv_mma_producer_state) + pv_mma_producer_state.advance() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + KV_mbar_ptr + kv_consumer_state.index * 2 + 1, 0 + ) + kv_consumer_state.advance() + p_pipeline.consumer_release(p_consumer_state) + p_consumer_state.advance() + o_pipeline.consumer_release(o_consumer_state) + o_consumer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def exp2f( + self, + value: cutlass.Float32, + ) -> cutlass.Float32: + return cute.arch.exp2(value) + + @cute.jit + def update_row_max( + self, + max_new: cutlass.Float32, + max_old: cutlass.Float32, + is_first: cutlass.Boolean, + sm_scale_log2: cutlass.Float32, + ): + _max_safe = max_new if max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + if cutlass.const_expr(not is_first): + acc_scale_ = (max_old - _max_safe) * sm_scale_log2 + acc_scale = self.exp2f(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + _max_safe = max_old + acc_scale = 1.0 + return _max_safe, acc_scale + + @cute.jit + def mask_then_cal_local_max( + self, + tensor: cute.Tensor, + right_bound: cutlass.Int32, + ) -> cutlass.Float32: + def clamp( + tensor: cute.Tensor, + index: cutlass.Int32, + right_bound: cutlass.Int32, + ) -> cutlass.Float32: + tensor[index] = tensor[index] if (index < right_bound) else -cutlass.Float32.inf + return tensor[index] + + _clamp = partial(clamp, tensor=tensor, right_bound=right_bound) + + if cutlass.const_expr(cute.size(tensor, mode=[0]) < 8): + _max = -cutlass.Float32.inf + for i in cutlass.range_constexpr(0, cute.size(tensor, mode=[0]), 2): + _max = ptx.max3f(_max, _clamp(index=i), _clamp(index=i + 1)) + return _max + else: + local_max = [ + ptx.max3f(_clamp(index=0), _clamp(index=1), -cutlass.Float32.inf), + ptx.max3f(_clamp(index=2), _clamp(index=3), -cutlass.Float32.inf), + ptx.max3f(_clamp(index=4), _clamp(index=5), -cutlass.Float32.inf), + ptx.max3f(_clamp(index=6), _clamp(index=7), -cutlass.Float32.inf), + ] + for i in cutlass.range_constexpr(8, cute.size(tensor, mode=[0]), 8): + local_max[0] = ptx.max3f(local_max[0], _clamp(index=i), _clamp(index=i + 1)) + local_max[1] = ptx.max3f(local_max[1], _clamp(index=i + 2), _clamp(index=i + 3)) + local_max[2] = ptx.max3f(local_max[2], _clamp(index=i + 4), _clamp(index=i + 5)) + local_max[3] = ptx.max3f(local_max[3], _clamp(index=i + 6), _clamp(index=i + 7)) + local_max[0] = cute.arch.fmax(local_max[0], local_max[1]) + return ptx.max3f(local_max[0], local_max[2], local_max[3]) + + @cute.jit + def cal_local_sum( + self, + tensor: cute.Tensor, + init: cutlass.Float32, + ) -> cutlass.Float32: + if cutlass.const_expr(cute.size(tensor, mode=[0]) < 8): + _sum = init + for i in cutlass.range_constexpr(cute.size(tensor, mode=[0])): + _sum += tensor[i] + return _sum + else: + local_sum = [ + cute.arch.add_packed_f32x2((init, 0.0), (tensor[0], tensor[1])), + (tensor[2], tensor[3]), + (tensor[4], tensor[5]), + (tensor[6], tensor[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(tensor, mode=[0]), 8): + local_sum[0] = cute.arch.add_packed_f32x2( + local_sum[0], (tensor[i + 0], tensor[i + 1]) + ) + local_sum[1] = cute.arch.add_packed_f32x2( + local_sum[1], (tensor[i + 2], tensor[i + 3]) + ) + local_sum[2] = cute.arch.add_packed_f32x2( + local_sum[2], (tensor[i + 4], tensor[i + 5]) + ) + local_sum[3] = cute.arch.add_packed_f32x2( + local_sum[3], (tensor[i + 6], tensor[i + 7]) + ) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + @cute.jit + def softmax_step( + self, + which: cutlass.Constexpr[cutlass.Int32], + stage: cutlass.Constexpr[cutlass.Int32], + n: cutlass.Int32, + q2k_block_sparse_index: cute.Tensor, + sVariable_block_sizes: cute.Tensor, + p_producer_state: pipeline.PipelineState, + sP: cute.Tensor, + tCcS_ld: cute.Tensor, + qk_mma_pipeline: pipeline.PipelineUmmaAsync, + p_pipeline: pipeline.PipelineAsyncUmma, + qk_mma_consumer_state: pipeline.PipelineState, + tiled_copy_t2r: cute.TiledCopy, + tiled_copy_r2t: Optional[cute.TiledCopy], + tCrS_ld_half_zeros: Optional[cute.Tensor], + tCtS_ld: cute.Tensor, + tCrS_ld: cute.Tensor, + tCrS_ld_half: cute.Tensor, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block_1st: cutlass.Int32, + m_block_2nd: cutlass.Int32, + tidx: cutlass.Int32, + running_max: cute.Tensor, + running_sum: cute.Tensor, + sm_scale_log2: cutlass.Float32, + sScale: cute.Tensor, + correction_pipeline: pipeline.PipelineAsync, + is_first: cutlass.Boolean, + ): + n_size = 0 + if cutlass.const_expr(which == 0): + n_block_1st = q2k_block_sparse_index[ + batch_idx, head_idx, m_block_1st, n + ] # TODO: move to smem + n_size = sVariable_block_sizes[n_block_1st] + elif cutlass.const_expr(which == 1): + n_block_2nd = q2k_block_sparse_index[batch_idx, head_idx, m_block_2nd, n] + n_size = sVariable_block_sizes[n_block_2nd] + + qk_mma_pipeline.consumer_wait(qk_mma_consumer_state) + correction_pipeline.producer_acquire(p_producer_state) + cute.copy( + tiled_copy_t2r, + tCtS_ld[None, None, None, 0, which, qk_mma_consumer_state.index], + tCrS_ld, + ) + # calculate P + _max = self.mask_then_cal_local_max(tCrS_ld, n_size) + _max_safe, _acc_scale = self.update_row_max( + max_new=_max, + is_first=is_first, + max_old=running_max[0], + sm_scale_log2=sm_scale_log2, + ) + if cutlass.const_expr(self.o_stage == 1): + sScale[tidx, p_producer_state.index] = _acc_scale + else: + acc_scale = cute.arch.exp2( + (sScale[1, tidx, p_producer_state.index] - _max_safe) * sm_scale_log2 + ) + sScale[0, tidx, p_producer_state.index] = acc_scale + sScale[1, tidx, p_producer_state.index] = _max_safe + correction_pipeline.producer_commit(p_producer_state) + + running_max[0] = _max_safe + minus_coeff = -_max_safe * sm_scale_log2 + for i in cutlass.range(0, cute.size(tCrS_ld.shape), 2, unroll_full=True): + tCrS_ld[i], tCrS_ld[i + 1] = cute.arch.fma_packed_f32x2( + (tCrS_ld[i], tCrS_ld[i + 1]), + (sm_scale_log2, sm_scale_log2), + (minus_coeff, minus_coeff), + ) + tCrS_ld[i] = self.exp2f(tCrS_ld[i]) + tCrS_ld[i + 1] = self.exp2f(tCrS_ld[i + 1]) + + # type conversion + tCrS_ld_half.store(tCrS_ld.load().to(tCrS_ld_half.element_type)) + + # update running sum + running_sum[0] *= _acc_scale + running_sum[0] = self.cal_local_sum(tCrS_ld, running_sum[0]) + + p_pipeline.producer_acquire(p_producer_state) + if cutlass.const_expr(self.p_in_smem): + # copy P to SMEM + cute.autovec_copy(tCrS_ld_half, sP[None, which, p_producer_state.index]) + else: + # copy P to TMEM + cute.copy( + tiled_copy_r2t, tCrS_ld_half, sP[None, None, None, 0, which, p_producer_state.index] + ) + if cutlass.const_expr(self.p_in_s): + cute.copy( + tiled_copy_r2t, + tCrS_ld_half_zeros, + sP[None, None, None, 0, which ^ 1, p_producer_state.index], + ) + p_pipeline.producer_commit(p_producer_state) + + qk_mma_pipeline.consumer_release(qk_mma_consumer_state) + + @cute.jit + def softmax_loop( + self, + TileSchedulerCls: Callable, + q2k_block_sparse_num: cute.Tensor, + p_producer_state: pipeline.PipelineState, + softmax_step: Callable, + qk_mma_consumer_state: pipeline.PipelineState, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_producer_state: pipeline.PipelineState, + num_q_blocks: cutlass.Int32, + tidx: cutlass.Int32, + which: cutlass.Constexpr[cutlass.Int32], + corr_pipeline: pipeline.PipelineAsync, + O_final_guard: cutlass.Pointer, + sScale: cute.Tensor, + ): + running_states = cute.make_rmem_tensor(cute.make_layout((2, 1)), self.acc_dtype) + running_sum = running_states[0, None] + running_max = running_states[1, None] + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + O_final_phase = 1 + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m_block_1st] + m_block_2nd = m_block_1st + 1 if m_block_1st + 1 < num_q_blocks else m_block_1st + + cute.arch.mbarrier_wait(O_final_guard, O_final_phase) + O_final_phase ^= 1 + + running_max.fill(-cutlass.Float32.inf) + running_sum.fill(0.0) + if cutlass.const_expr(self.o_stage != 1): + sScale[1, tidx, None].fill(-cutlass.Float32.inf) + + softmax_step( + n=0, + stage=0, + batch_idx=batch_idx, + head_idx=head_idx, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + p_producer_state=p_producer_state, + qk_mma_consumer_state=qk_mma_consumer_state, + running_max=running_max, + running_sum=running_sum, + which=which, + correction_pipeline=corr_pipeline, + is_first=True, + ) + p_producer_state.advance() + qk_mma_consumer_state.advance() + + for n in cutlass.range(1, _num_n_blocks): + softmax_step( + n=n, + stage=0, + batch_idx=batch_idx, + head_idx=head_idx, + m_block_1st=m_block_1st, + m_block_2nd=m_block_2nd, + p_producer_state=p_producer_state, + qk_mma_consumer_state=qk_mma_consumer_state, + running_max=running_max, + running_sum=running_sum, + which=which, + correction_pipeline=corr_pipeline, + is_first=False, + ) + p_producer_state.advance() + qk_mma_consumer_state.advance() + + # put running stages to SMEM + sFinal_pipeline.producer_acquire(sFinal_producer_state) + cute.autovec_copy( + running_states, cute.append_ones(sFinal[None, tidx, sFinal_producer_state.index]) + ) + sFinal_pipeline.producer_commit(sFinal_producer_state) + sFinal_producer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def softmax( + self, + TileSchedulerCls: Callable, + q2k_block_sparse_num: cute.Tensor, + q2k_block_sparse_index: cute.Tensor, + variable_block_sizes: cute.Tensor, + sVariable_block_sizes: cute.Tensor, + qk_mma_pipeline: pipeline.PipelineUmmaAsync, + qk_mma_consumer_state: pipeline.PipelineState, + p_pipeline: pipeline.PipelineAsyncUmma, + p_producer_state: pipeline.PipelineState, + num_q_blocks: cutlass.Int32, + sP: cute.Tensor, + tCtS: cute.Tensor, + thr_mma_qk: cute.core.ThrMma, + sm_scale_log2: cutlass.Float32, + sScale: cute.Tensor, + corr_pipeline: pipeline.PipelineAsync, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_producer_state: pipeline.PipelineState, + O_final_guard: cutlass.Pointer, + ): + tidx = cute.arch.thread_idx()[0] % (self.threads_per_warp * len(self.softmax_warp_ids)) + in_which = cute.arch.make_warp_uniform(tidx // self.block_m) + + copy_atom_t2r = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.block_n)), + self.acc_dtype, + ) + tS_load = cute.flat_divide( + tCtS[(None, None), 0, None], (self.mma_tiler_qk[0], self.block_n) + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tS_load[(None, None, 0, 0, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + + tCtS_ld = thr_copy_t2r.partition_S(tS_load) + + cS = cute.make_identity_tensor(self.mma_tiler_qk[:2]) + tCcS = thr_mma_qk.partition_C(cS) + cS_load = cute.flat_divide( + tCcS[(None, None), 0, None], (self.mma_tiler_qk[0], self.block_n) + ) + tCcS_ld = thr_copy_t2r.partition_D(cS_load) + tCrS_ld = cute.make_fragment(cute.select(tCcS_ld.shape, mode=[0, 1, 2]), self.acc_dtype) + tCrS_ld_half = cute.make_fragment(tCrS_ld.layout, sP.element_type) + + sP_cpy_slice = None + tiled_copy_r2t = None + tCrS_ld_half_zeros = None + if cutlass.const_expr(self.p_in_smem): + tv_layout = cute.make_ordered_layout( + (self.threads_per_wg, self.mma_tiler_qk[1], self.s_stage), (0, 1, 2) + ) + sP_cpy = cute.composition(sP, tv_layout) + sP_cpy_slice = cute.flatten(sP_cpy[tidx, None, None]) + else: + copy_atom_r2t = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(self.block_n // 2)), sP.element_type + ) + + def _get_tP_store(tP: cute.Tensor): + layout = cute.flatten(tP.layout) + mode1 = cute.make_layout( + cute.get(layout, mode=[1]).shape * cute.get(layout, mode=[3]).shape, + stride=cute.get(layout, mode=[1]).stride, + ) + shape = ( + cute.get(layout, mode=[0]).shape, + mode1.shape, + cute.get(layout, mode=[2]).shape, + cute.get(layout, mode=[4]).shape, + cute.get(layout, mode=[5]).shape, + ) + stride = ( + cute.get(layout, mode=[0]).stride, + mode1.stride, + cute.get(layout, mode=[2]).stride, + cute.get(layout, mode=[4]).stride, + cute.get(layout, mode=[5]).stride, + ) + layout = cute.make_layout(shape, stride=stride) + return cute.make_tensor(tP.iterator, layout) + + tP_store = _get_tP_store(sP) + + tiled_copy_r2t = tcgen05.make_tmem_copy(copy_atom_r2t, tP_store[(None, None, 0, 0, 0)]) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tCtP_st = thr_copy_r2t.partition_D(tP_store) + sP_cpy_slice = tCtP_st + if cutlass.const_expr(self.p_in_s): + tCrS_ld_half_zeros = cute.make_rmem_tensor( + tCrS_ld_half.layout, tCrS_ld_half.element_type + ) + tCrS_ld_half_zeros.fill(0.0) + else: + # P has its own TMEM + _zeros = cute.make_rmem_tensor(tCrS_ld_half.layout, tCrS_ld_half.element_type) + _zeros.fill(0.0) + for stage in cutlass.range_constexpr(self.s_stage): + cute.copy( + tiled_copy_r2t, _zeros, tCtP_st[None, None, None, 0, in_which ^ 1, stage] + ) + + _softmax_step = partial( + self.softmax_step, + q2k_block_sparse_index=q2k_block_sparse_index, + sVariable_block_sizes=sVariable_block_sizes, + sP=sP_cpy_slice, + tCcS_ld=tCcS_ld, + qk_mma_pipeline=qk_mma_pipeline, + p_pipeline=p_pipeline, + tiled_copy_t2r=tiled_copy_t2r, + tiled_copy_r2t=tiled_copy_r2t, + tCrS_ld_half_zeros=tCrS_ld_half_zeros, + tCtS_ld=tCtS_ld, + tCrS_ld=tCrS_ld, + tCrS_ld_half=tCrS_ld_half, + tidx=tidx, + sm_scale_log2=sm_scale_log2, + sScale=sScale, + ) + + _softmax_loop = partial( + self.softmax_loop, + TileSchedulerCls=TileSchedulerCls, + q2k_block_sparse_num=q2k_block_sparse_num, + p_producer_state=p_producer_state, + softmax_step=_softmax_step, + qk_mma_consumer_state=qk_mma_consumer_state, + sFinal=sFinal, + sFinal_pipeline=sFinal_pipeline, + sFinal_producer_state=sFinal_producer_state, + num_q_blocks=num_q_blocks, + tidx=tidx, + O_final_guard=O_final_guard, + sScale=sScale, + ) + + if cutlass.const_expr(self.p_in_smem): + sP_cpy = cute.composition( + sP, + cute.make_ordered_layout( + (self.threads_per_wg, self.mma_tiler_qk[1] * self.s_stage), (0, 1) + ), + ) + sP_cpy_slice = cute.flatten(sP_cpy[tidx, None]) + _zeros = cute.make_rmem_tensor((self.block_n, 1), sP_cpy_slice.element_type) + _zeros.fill(0.0) + + for stage in cutlass.range_constexpr(self.s_stage): + cute.autovec_copy(_zeros, sP_cpy_slice[None, stage * 2 + (in_which ^ 1)]) + if in_which == 0: + _softmax_loop(which=0, corr_pipeline=corr_pipeline) + else: + _softmax_loop(which=1, corr_pipeline=corr_pipeline) + + @cute.jit + def correction_loop( + self, + TileSchedulerCls: Callable, + variable_block_sizes: cute.Tensor, + q2k_block_sparse_num: cute.Tensor, + correction_pipeline: pipeline.PipelineAsync, + p_producer_state: pipeline.PipelineState, + o_pipeline: pipeline.PipelineAsyncUmma, + o_producer_state: pipeline.PipelineState, + correction_consumer_state: pipeline.PipelineState, + corr_ld_repeat: cutlass.Constexpr[cutlass.Int32], + corr_tiled_copy_t2r: cute.TiledCopy, + corr_tCtO_ld: cute.Tensor, + corr_tCrO_ld: cute.Tensor, + corr_tiled_copy_r2t: cute.TiledCopy, + pv_mma_pipeline: pipeline.PipelineUmmaAsync, + pv_mma_consumer_state: pipeline.PipelineState, + st_O_pipeline: pipeline.PipelineAsync, + st_O_producer_state: pipeline.PipelineState, + tv_layout: cute.Layout, + sO: cute.Tensor, + sScale: cute.Tensor, + wb_tCcO_ld: cute.Tensor, + wb_tCrO_ld: cute.Tensor, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_consumer_state: pipeline.PipelineState, + wb_ld_repeat: cutlass.Constexpr[cutlass.Int32], + wb_tCrO_reduction: Optional[cute.Tensor], + wb_tiled_copy_t2r: cute.TiledCopy, + wb_tCtO_ld: cute.Tensor, + wb_tCrO_ld_half: cute.Tensor, + tidx: cutlass.Int32, + sm_scale_log2: cutlass.Float32, + num_q_blocks: cutlass.Int32, + O_final_guard: cutlass.Pointer, + gLSE: cute.Tensor, + in_which: cutlass.Int32, + ): + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + _num_n_blocks = q2k_block_sparse_num[batch_idx, head_idx, m2_idx * 2] + + # non-rescaling + for n in cutlass.range_constexpr(self.o_stage): + correction_pipeline.consumer_wait(correction_consumer_state) + correction_pipeline.consumer_release(correction_consumer_state) + p_producer_state.advance() + correction_consumer_state.advance() + + o_pipeline.producer_acquire(o_producer_state) + o_pipeline.producer_commit(o_producer_state) + o_producer_state.advance() + + pv_mma_pipeline.consumer_wait(pv_mma_consumer_state) + pv_mma_pipeline.consumer_release(pv_mma_consumer_state) + pv_mma_consumer_state.advance() + + for n in cutlass.range(self.o_stage, _num_n_blocks): + correction_pipeline.consumer_wait(correction_consumer_state) + o_pipeline.producer_acquire(o_producer_state) + + acc_scale = 0.0 + if cutlass.const_expr(self.o_stage == 1): + acc_scale = sScale[tidx, p_producer_state.index] + else: + acc_scale = sScale[0, tidx, p_producer_state.index] + should_rescale = cute.arch.vote_ballot_sync(acc_scale < 1.0) != 0 + if should_rescale: + for repeat in cutlass.range(corr_ld_repeat): + cute.copy( + corr_tiled_copy_t2r, + corr_tCtO_ld[None, None, None, 0, repeat, pv_mma_consumer_state.index], + corr_tCrO_ld, + ) + + # apply correction + for i in cutlass.range_constexpr(0, cute.size(corr_tCrO_ld, mode=[0]), 2): + corr_tCrO_ld[i], corr_tCrO_ld[i + 1] = cute.arch.mul_packed_f32x2( + (corr_tCrO_ld[i], corr_tCrO_ld[i + 1]), + (acc_scale, acc_scale), + ) + + cute.copy( + corr_tiled_copy_r2t, + corr_tCrO_ld, + corr_tCtO_ld[None, None, None, 0, repeat, pv_mma_consumer_state.index], + ) + o_pipeline.producer_commit(o_producer_state) + o_producer_state.advance() + + correction_pipeline.consumer_release(correction_consumer_state) + p_producer_state.advance() + correction_consumer_state.advance() + + # currently, we don't need pv_mma_result + pv_mma_pipeline.consumer_wait(pv_mma_consumer_state) + pv_mma_pipeline.consumer_release(pv_mma_consumer_state) + pv_mma_consumer_state.advance() + + # ----- Correction Epilogue: Store O to SMEM ----- # + st_O_pipeline.producer_acquire(st_O_producer_state) + running_states = cute.make_rmem_tensor(cute.make_layout((2, 1)), self.acc_dtype) + running_sum = running_states[0, None] + running_max = running_states[1, None] + + sFinal_pipeline.consumer_wait(sFinal_consumer_state) + cute.autovec_copy( + cute.append_ones(sFinal[None, tidx, sFinal_consumer_state.index]), running_states + ) + sFinal_pipeline.consumer_release(sFinal_consumer_state) + sFinal_consumer_state.advance() + + # calculate LSE + lse = ( + running_max[0] * sm_scale_log2 + cute.math.log2(running_sum[0], fastmath=True) + ) * self.LN2 + + gLSE_1st = gLSE[None, m2_idx * 2, head_idx, batch_idx] + gLSE_2nd = gLSE[None, m2_idx * 2 + 1, head_idx, batch_idx] + if in_which == 0: + gLSE_1st[tidx] = lse + elif m2_idx * 2 + 1 < num_q_blocks: + gLSE_2nd[tidx - self.block_m] = lse + + # calculate scale + scale = cute.arch.rcp_approx(running_sum[0]) + + if cutlass.const_expr(self.o_stage == 1): + for repeat in cutlass.range(wb_ld_repeat): + cute.copy( + wb_tiled_copy_t2r, + wb_tCtO_ld[None, None, None, 0, repeat, pv_mma_consumer_state.index], + wb_tCrO_ld, + ) + # scale + for i in cutlass.range_constexpr(0, cute.size(wb_tCrO_ld, mode=[0]), 2): + wb_tCrO_ld[i], wb_tCrO_ld[i + 1] = cute.arch.mul_packed_f32x2( + (wb_tCrO_ld[i], wb_tCrO_ld[i + 1]), + (scale, scale), + ) + # type conversion + wb_tCrO_ld_half.store(wb_tCrO_ld.load().to(wb_tCrO_ld_half.element_type)) + + # store to SMEM + cute.autovec_copy(wb_tCrO_ld_half, sO[None, repeat, st_O_producer_state.index]) + else: + stage_in_use = pv_mma_consumer_state.index - 1 + stage_in_use = stage_in_use if stage_in_use != -1 else self.o_stage - 1 + for repeat in cutlass.range(wb_ld_repeat): + wb_tCrO_reduction.fill(0.0) + for stage in cutlass.range(cutlass.min(self.o_stage, _num_n_blocks)): + _stage_id = (stage_in_use + stage) % self.o_stage + cute.copy( + wb_tiled_copy_t2r, + wb_tCtO_ld[None, None, None, 0, repeat, _stage_id], + wb_tCrO_ld, + ) + + _scale = scale + if cutlass.const_expr(self.o_stage != 1): + acc_scale = cute.arch.exp2( + (sScale[1, tidx, _stage_id] - running_max[0]) * sm_scale_log2 + ) + _scale *= acc_scale + + # scale and reduction + for i in cutlass.range_constexpr(0, cute.size(wb_tCrO_ld, mode=[0]), 2): + wb_tCrO_reduction[i], wb_tCrO_reduction[i + 1] = ( + cute.arch.fma_packed_f32x2( + (wb_tCrO_ld[i], wb_tCrO_ld[i + 1]), + (_scale, _scale), + (wb_tCrO_reduction[i], wb_tCrO_reduction[i + 1]), + ) + ) + # type conversion + wb_tCrO_ld_half.store(wb_tCrO_reduction.load().to(wb_tCrO_ld_half.element_type)) + + # store to SMEM + cute.autovec_copy(wb_tCrO_ld_half, sO[None, repeat, st_O_producer_state.index]) + + # NOTE: It is safe now to do next wave + cute.arch.mbarrier_arrive_and_expect_tx(O_final_guard, 0) + + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + st_O_pipeline.producer_commit(st_O_producer_state) + st_O_producer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def correction( + self, + TileSchedulerCls: Callable, + q2k_block_sparse_num: cute.Tensor, + p_producer_state: pipeline.PipelineState, + o_pipeline: pipeline.PipelineAsyncUmma, + o_producer_state: pipeline.PipelineState, + pv_mma_pipeline: pipeline.PipelineUmmaAsync, + pv_mma_consumer_state: pipeline.PipelineState, + st_O_pipeline: pipeline.PipelineAsync, + st_O_producer_state: pipeline.PipelineState, + tCtO: cute.Tensor, + thr_mma_pv: cute.core.ThrMma, + sO: cute.Tensor, + corr_pipeline: pipeline.PipelineAsync, + correction_consumer_state: pipeline.PipelineState, + sScale: cute.Tensor, + sm_scale_log2: cutlass.Float32, + num_q_blocks: cutlass.Int32, + variable_block_sizes: cute.Tensor, + sFinal: cute.Tensor, + sFinal_pipeline: pipeline.PipelineAsync, + sFinal_consumer_state: pipeline.PipelineState, + O_final_guard: cutlass.Pointer, + gLSE: cute.Tensor, + ): + tidx = cute.arch.thread_idx()[0] % (self.threads_per_warp * len(self.correction_warp_ids)) + in_which = cute.arch.make_warp_uniform(tidx // self.block_m) + corr_ld_inst: int = 32 + corr_ld_repeat: int = cute.ceil_div(self.mma_tiler_pv[1], corr_ld_inst) + + corr_copy_atom_t2r = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(corr_ld_inst), + ), + self.acc_dtype, + ) + corr_tO_load = cute.flat_divide( + tCtO[(None, None), 0, None], (self.mma_tiler_pv[0], corr_ld_inst) + ) + corr_tiled_copy_t2r = tcgen05.make_tmem_copy( + corr_copy_atom_t2r, corr_tO_load[None, None, 0, 0, 0] + ) + corr_thr_copy_t2r = corr_tiled_copy_t2r.get_slice(tidx) + corr_tCtO_ld = corr_thr_copy_t2r.partition_S(corr_tO_load) + + corr_copy_atom_r2t = cute.make_copy_atom( + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(corr_ld_inst), + ), + self.acc_dtype, + ) + corr_tiled_copy_r2t = tcgen05.make_tmem_copy( + corr_copy_atom_r2t, corr_tO_load[None, None, 0, 0, 0] + ) + + cO = cute.make_identity_tensor(self.mma_tiler_pv[:2]) + tCcO = thr_mma_pv.partition_C(cO) + corr_cO_load = cute.flat_divide( + tCcO[(None, None), 0, None], (self.mma_tiler_pv[0], corr_ld_inst) + ) + corr_tCcO_ld = corr_thr_copy_t2r.partition_D(corr_cO_load) + corr_tCrO_ld = cute.make_fragment( + cute.select(corr_tCcO_ld.shape, mode=[0, 1, 2]), self.acc_dtype + ) + + # for writeback + wb_ld_inst: int = 64 + wb_ld_repeat: int = cute.ceil_div(self.mma_tiler_pv[1], wb_ld_inst) + + wb_copy_atom_t2r = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp( + tcgen05.copy.Repetition(wb_ld_inst), + ), + self.acc_dtype, + ) + wb_tO_load = cute.flat_divide( + tCtO[(None, None), 0, None], (self.mma_tiler_pv[0], wb_ld_inst) + ) + wb_tiled_copy_t2r = tcgen05.make_tmem_copy( + wb_copy_atom_t2r, wb_tO_load[(None, None, 0, 0, 0)] + ) + wb_thr_copy_t2r = wb_tiled_copy_t2r.get_slice(tidx) + wb_tCtO_ld = wb_thr_copy_t2r.partition_S(wb_tO_load) + + wb_cO_load = cute.flat_divide( + tCcO[(None, None), 0, None], (self.mma_tiler_pv[0], wb_ld_inst) + ) + wb_tCcO_ld = wb_thr_copy_t2r.partition_D(wb_cO_load) + wb_tCrO_ld = cute.make_fragment( + cute.select(wb_tCcO_ld.shape, mode=[0, 1, 2]), self.acc_dtype + ) + wb_tCrO_ld_half = cute.make_fragment(wb_tCrO_ld.layout, sO.element_type) + tv_layout = cute.make_ordered_layout( + (self.threads_per_wg, self.mma_tiler_pv[1], self.s_stage), (0, 1, 2) + ) + sO_cpy = cute.composition(sO, tv_layout) + sO_cpy_slice = cute.flatten(sO_cpy[tidx, None, None]) + + wb_tCrO_reduction = None + if cutlass.const_expr(self.o_stage != 1): + wb_tCrO_reduction = cute.make_fragment(wb_tCrO_ld.layout, self.acc_dtype) + + _correction_loop = partial( + self.correction_loop, + TileSchedulerCls=TileSchedulerCls, + variable_block_sizes=variable_block_sizes, + q2k_block_sparse_num=q2k_block_sparse_num, + p_producer_state=p_producer_state, + o_pipeline=o_pipeline, + o_producer_state=o_producer_state, + correction_consumer_state=correction_consumer_state, + corr_ld_repeat=corr_ld_repeat, + corr_tiled_copy_t2r=corr_tiled_copy_t2r, + corr_tCtO_ld=corr_tCtO_ld, + corr_tCrO_ld=corr_tCrO_ld, + corr_tiled_copy_r2t=corr_tiled_copy_r2t, + pv_mma_pipeline=pv_mma_pipeline, + pv_mma_consumer_state=pv_mma_consumer_state, + st_O_pipeline=st_O_pipeline, + st_O_producer_state=st_O_producer_state, + tv_layout=tv_layout, + sO=sO_cpy_slice, + sScale=sScale, + wb_tCcO_ld=wb_tCcO_ld, + wb_tCrO_ld=wb_tCrO_ld, + sFinal=sFinal, + sFinal_pipeline=sFinal_pipeline, + sFinal_consumer_state=sFinal_consumer_state, + wb_ld_repeat=wb_ld_repeat, + wb_tCrO_reduction=wb_tCrO_reduction, + wb_tiled_copy_t2r=wb_tiled_copy_t2r, + wb_tCtO_ld=wb_tCtO_ld, + wb_tCrO_ld_half=wb_tCrO_ld_half, + tidx=tidx, + sm_scale_log2=sm_scale_log2, + num_q_blocks=num_q_blocks, + O_final_guard=O_final_guard, + gLSE=gLSE, + in_which=in_which, + ) + + _correction_loop(correction_pipeline=corr_pipeline) + + @cute.jit + def epilogue( + self, + TileSchedulerCls: Callable, + st_O_pipeline: pipeline.PipelineAsync, + st_O_consumer_state: pipeline.PipelineState, + gO: cute.Tensor, + sO: cute.Tensor, + tma_atom_O: cute.CopyAtom, + ): + def _get_sO_cpy( + sO: cute.Tensor, + ) -> cute.Tensor: + layout = cute.flatten(sO.layout) + layout = cute.select(layout, mode=[1, 0, 2, 3, 4, 5]) + layout = cute.flat_divide( + layout, + (8,), # it relates to 128B swizzle for 16-bit data + ) + layout = cute.select(layout, mode=[2, 0, 3, 4, 1, 5, 6]) + layout = cute.group_modes(layout, 0, 2) + layout = cute.group_modes(layout, 4, cute.rank(layout)) + cpy_tensor = cute.make_tensor(sO.iterator, layout) + return cpy_tensor + + sO_cpy = _get_sO_cpy(sO) + + def _get_gO_cpy( + gO: cute.Tensor, + ) -> cute.Tensor: + layout = cute.flat_divide(gO.layout, (64, 64)) + layout = cute.group_modes(layout, 0, 2) + cpy_tensor = cute.make_tensor(gO.iterator, layout) + return cpy_tensor + + _gO = _get_gO_cpy(gO) + tTMAsO, tTMAgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_O, + 0, + cute.make_layout(1), + cute.group_modes(sO_cpy, 0, 2), + cute.group_modes(_gO, 0, 2), + ) + + num_o_blocks = cute.size(gO, mode=[2]) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m2_idx, head_idx, batch_idx = work_tile.tile_idx + + m_block_1st = m2_idx * 2 + m_block_2nd = m_block_1st + 1 + m_block_2nd = m_block_2nd if m_block_2nd < num_o_blocks else num_o_blocks + + tTMAgO_1st = tTMAgO[None, None, m_block_1st, None, head_idx, batch_idx] + tTMAgO_2nd = tTMAgO[None, None, m_block_2nd, None, head_idx, batch_idx] + + cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True) + st_O_pipeline.consumer_release(st_O_consumer_state) + st_O_pipeline.consumer_wait(st_O_consumer_state) + + cute.copy( + tma_atom_O, tTMAsO[None, 0, 0, st_O_consumer_state.index], tTMAgO_1st[None, 0, 0] + ) + cute.copy( + tma_atom_O, tTMAsO[None, 1, 0, st_O_consumer_state.index], tTMAgO_1st[None, 1, 0] + ) + cute.copy( + tma_atom_O, tTMAsO[None, 0, 1, st_O_consumer_state.index], tTMAgO_2nd[None, 0, 0] + ) + cute.copy( + tma_atom_O, tTMAsO[None, 1, 1, st_O_consumer_state.index], tTMAgO_2nd[None, 1, 0] + ) + + cute.arch.cp_async_bulk_commit_group() + st_O_consumer_state.advance() + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + cute.arch.cp_async_bulk_wait_group(0, read=True) diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py new file mode 100644 index 000000000000..e590234f45e7 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Torch entry point for the CuTe DSL block-sparse attention forward. + +Blackwell (sm_100) fast path for VSA's fine stage. The kernel JIT-compiles +on first call and is cached per process; the caller +(CuTeDSLAttention._forward_vsa) falls back to dense SDPA when the +device/dtype/head_dim envelope is not met. +""" + +from __future__ import annotations + +import math +import os +from typing import Tuple + +import torch + +try: + import cuda.bindings.driver as _cuda + import cutlass.cute as cute + from cutlass.cute.runtime import from_dlpack + + from .block_sparse_attn_dsl_fwd import ( + VideoSparseAttentionForwardGroup2QInterleaveKV as VideoSparseAttentionForward, + ) + + CUTE_AVAILABLE = True +except ImportError: # cuda-bindings / cutlass-dsl not installed + _cuda = None + cute = None + from_dlpack = None + VideoSparseAttentionForward = None + CUTE_AVAILABLE = False + + +__all__ = [ + "CUTE_AVAILABLE", + "is_cute_supported", + "block_sparse_attn_from_indices_cute", +] + + +# JIT compile is multi-second; reuse aggressively. +_COMPILE_CACHE: dict = {} + + +def is_cute_supported(q: torch.Tensor) -> bool: + """Capability check for the CuTe path. Set TLLM_VSA_DISABLE_CUTE=1 to force off.""" + # Kernel asserts head_dim==128, block_m==block_n==64, fp16/bf16, sm_100+. + if os.environ.get("TLLM_VSA_DISABLE_CUTE", "").strip() in ("1", "true", "True"): + return False + if not CUTE_AVAILABLE: + return False + if not q.is_cuda: + return False + if q.dtype not in (torch.float16, torch.bfloat16): + return False + if q.shape[-1] != 128: + return False + cap = torch.cuda.get_device_capability(q.device) + return cap[0] >= 10 + + +def _to_cute_tensor(t: torch.Tensor): + """Convert a 4-D BHSD tensor into a CuTe tensor with dynamic B/H/S strides.""" + # Head-dim mode left static (=128) since the kernel specializes on it. + return ( + from_dlpack(t.detach(), assumed_align=128) + .mark_compact_shape_dynamic(mode=0, stride_order=t.dim_order()) + .mark_compact_shape_dynamic(mode=1, stride_order=t.dim_order()) + .mark_compact_shape_dynamic(mode=2, stride_order=t.dim_order()) + ) + + +@torch.compiler.disable +def block_sparse_attn_from_indices_cute( + q: torch.Tensor, # [B, H, S, D] fp16/bf16, sm_100+ + k: torch.Tensor, # [B, H, S, D] + v: torch.Tensor, # [B, H, S, D] + q2k_idx: torch.Tensor, # [B, H, num_q_blk, K] int32 + q2k_num: torch.Tensor, # [B, H, num_q_blk] int32 + variable_block_sizes: torch.Tensor, # [num_q_blk] int32 + sm_scale: float | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Run VSA fine-stage attention on Blackwell using the CuTe DSL kernel. + + Returns: + out: [B, H, S, D] same dtype as Q. + lse: [B, H, S] fp32. + """ + # Disabled for torch.compile: cuda.bindings.driver.CUstream + cute.compile + # are not Dynamo-traceable, and torch.cuda.current_stream() returns a + # proxy without .cuda_stream inside compiled regions. + if not CUTE_AVAILABLE: + raise RuntimeError( + "block_sparse_attn_from_indices_cute called but cuda.bindings or " + "cutlass-dsl is not importable." + ) + + B, H, T, D = q.shape + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(D) + + out = torch.empty_like(q) + lse = torch.empty((B, H, T), device=q.device, dtype=torch.float32) + + cuda_stream = _cuda.CUstream(torch.cuda.current_stream(q.device).cuda_stream) + + q_packed = _to_cute_tensor(q) + k_packed = _to_cute_tensor(k) + v_packed = _to_cute_tensor(v) + o_packed = _to_cute_tensor(out) + + lse_packed = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + idx_packed = from_dlpack(q2k_idx.detach()).mark_layout_dynamic(leading_dim=3) + num_packed = from_dlpack(q2k_num.detach()).mark_layout_dynamic(leading_dim=2) + var_packed = from_dlpack(variable_block_sizes.detach()).mark_layout_dynamic(leading_dim=0) + + # Full q2k_idx shape is part of the key: mark_layout_dynamic(leading_dim=3) + # only makes the innermost stride dynamic; B-loop bound and inner strides + # are baked in at compile time, so each shape needs its own compiled kernel. + compile_key = (D, q.dtype, float(sm_scale)) + tuple(q2k_idx.shape) + compiled = _COMPILE_CACHE.get(compile_key) + if compiled is None: + fwd_kernel = VideoSparseAttentionForward(block_m=64, block_n=64, headdim=D) + compiled = cute.compile( + fwd_kernel, + q_packed, + k_packed, + v_packed, + sm_scale, + o_packed, + lse_packed, + idx_packed, + num_packed, + var_packed, + cuda_stream, + ) + _COMPILE_CACHE[compile_key] = compiled + + compiled( + q_packed, + k_packed, + v_packed, + sm_scale, + o_packed, + lse_packed, + idx_packed, + num_packed, + var_packed, + cuda_stream, + ) + return out, lse diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py new file mode 100644 index 000000000000..12fa15bc5930 --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/ptx.py @@ -0,0 +1,377 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import llvm, nvvm +from cutlass.cute import typing as cutlass_typing +from cutlass.cutlass_dsl import dsl_user_op + + +@dsl_user_op +def warp_reduction_fmax( + val: cutlass.Float32, + mask: cutlass.Int32 = 0xFFFFFFFF, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + cutlass_typing.Int32(mask).ir_value(loc=loc, ip=ip), + ], + """{\n\t + redux.sync.max.f32 $0, $1, $2;\n\t + \n\t}""", + "=f,f,r", + ) + ) + + +@dsl_user_op +def __cvta_generic_to_shared( + ptr: cutlass.Pointer, + *, + loc=None, + ip=None, +) -> cutlass.Uint32: + # NOTE: assume the SMEM pointer fits in a 32-bit register + return cutlass.Uint32( + llvm.inline_asm( + cutlass_typing.Uint32.mlir_type, + [ + cutlass_typing.Int32(ptr.toint()).ir_value(loc=loc, ip=ip), + ], + """{\n\t + mov.u32 $0, $1; + \n\t}""", + "=r, r", + ) + ) + + +@dsl_user_op +def atomicAdd_f32( + val: cutlass.Float32, + ptr: cutlass.Pointer, + *, + loc=None, + ip=None, +): + if cutlass.const_expr(ptr.memspace == cutlass_typing.AddressSpace.smem): + ptr = __cvta_generic_to_shared(ptr, loc=loc, ip=ip) + llvm.inline_asm( + None, + [ + cutlass_typing.Uint32(ptr).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.add.f32 _, [$0], $1;\n\t + \n\t}""", + "r, f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + else: + llvm.inline_asm( + None, + [ + cutlass_typing.Int64(ptr.toint()).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.add.f32 _, [$0], $1;\n\t + \n\t}""", + "l, f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def atomicMax_f32( + val: cutlass.Float32, + ptr: cutlass.Pointer, + *, + loc=None, + ip=None, +): + val_i32 = llvm.bitcast( + cutlass_typing.Int32.mlir_type, val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + if cutlass.const_expr(ptr.memspace == cutlass_typing.AddressSpace.smem): + ptr = __cvta_generic_to_shared(ptr, loc=loc, ip=ip) + llvm.inline_asm( + None, + [ + cutlass_typing.Uint32(ptr).ir_value(loc=loc, ip=ip), + cutlass_typing.Int32(val_i32).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.max.s32 _, [$0], $1;\n\t + \n\t}""", + "r, r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + else: + llvm.inline_asm( + None, + [ + cutlass_typing.Int64(ptr.toint()).ir_value(loc=loc, ip=ip), + cutlass_typing.Int32(val_i32).ir_value(loc=loc, ip=ip), + ], + """{\n\t + atom.relaxed.shared::cta.cta.max.s32 _, [$0], $1;\n\t + \n\t}""", + "l, r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def exp2f( + val: cutlass.Float32, + *, + loc=None, + ip=None, +): + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(val).ir_value(loc=loc, ip=ip), + ], + """{\n\t + .reg .f32 f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11;\n\t + .reg .s32 r1, r2, r3;\n\t + max.ftz.f32 f1, $1, 0fC2FE0000;\n\t + mov.f32 f3, 0f4B400000;\n\t + add.rm.ftz.f32 f4, f1, f3;\n\t + sub.rn.ftz.f32 f5, f4, f3;\n\t + sub.rn.ftz.f32 f6, f1, f5;\n\t + mov.f32 f7, 0f3D9DF09D;\n\t + mov.f32 f8, 0f3E6906A4;\n\t + mov.f32 f9, 0f3F31F519;\n\t + mov.f32 f10, 0f3F800000;\n\t + fma.rn.ftz.f32 f11, f6, f7, f8;\n\t + fma.rn.ftz.f32 f11, f11, f6, f9;\n\t + fma.rn.ftz.f32 f11, f11, f6, f10;\n\t + mov.b32 r3, f11;\n\t + shl.b32 r1, f4, 23;\n\t + add.s32 r2, r1, r3;\n\t + mov.b32 $0, r2;\n\t + \n\t}""", + "=f, f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def fma( + a: cutlass.Float32, + b: cutlass.Float32, + c: cutlass.Float32, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(a).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(b).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(c).ir_value(loc=loc, ip=ip), + ], + """{\n\t + fma.rn.ftz.f32 $0, $1, $2, $3;\n\t + \n\t}""", + "=f, f, f, f", + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def max3f( + a: cutlass.Float32, + b: cutlass.Float32, + c: cutlass.Float32, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(a).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(b).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(c).ir_value(loc=loc, ip=ip), + ], + """{\n\t + max.f32 $0, $1, $2, $3;\n\t + \n\t}""", + "=f, f, f, f", + loc=loc, + ip=ip, + ) + ) + + +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, + rnd=nvvm.RoundingModeKind.RN, +) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: cutlass.Float32, + y: cutlass.Float32, + poly: Tuple[cutlass.Float32, ...], + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Float32, cutlass.Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def combine_int_frac_ex2( + x_rounded: cutlass.Float32, + frac_ex2: cutlass.Float32, + *, + loc=None, + ip=None, +) -> cutlass.Float32: + return cutlass.Float32( + llvm.inline_asm( + cutlass_typing.Float32.mlir_type, + [ + cutlass_typing.Float32(x_rounded).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + """{\n\t + .reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i; \n\t + mov.b32 x_rounded_i, $1; \n\t + mov.b32 frac_ex_i, $2; \n\t + shl.b32 x_rounded_e, x_rounded_i, 23; \n\t + add.s32 out_i, x_rounded_e, frac_ex_i; \n\t + mov.b32 $0, out_i; \n\t + \n\t}""", + "=f, f, f", + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def exp2_emulation_2( + x: cutlass.Float32, + y: cutlass.Float32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Float32, cutlass.Float32]: + # assume x <= 127.0 and y <= 127.0 + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + xy_rounded = cute.arch.add_packed_f32x2( + xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM + ) + xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int)) + xy_frac = sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, poly_ex2_deg3, loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def exp2f_packed_f32x2( + x: cutlass.Float32, + y: cutlass.Float32, + *, + loc=None, + ip=None, +) -> Tuple[cutlass.Float32, cutlass.Float32]: + result = llvm.inline_asm( + llvm.StructType.get_literal( + [ + cutlass_typing.Float32.mlir_type, + cutlass_typing.Float32.mlir_type, + ] + ), + [ + cutlass_typing.Float32(x).ir_value(loc=loc, ip=ip), + cutlass_typing.Float32(y).ir_value(loc=loc, ip=ip), + ], + """{\n\t + ex2.approx.f32 $0, $2;\n\t + ex2.approx.f32 $1, $3;\n\t + \n\t}""", + # Keep constraints compact (no spaces) and in the correct order + "=f,=f,f,f", + loc=loc, + ip=ip, + ) + + # Extract struct fields + out0_val = llvm.extractvalue(cutlass_typing.Float32.mlir_type, result, [0], loc=loc, ip=ip) + out1_val = llvm.extractvalue(cutlass_typing.Float32.mlir_type, result, [1], loc=loc, ip=ip) + + # Wrap back into cutlass.Float32 + return cutlass.Float32(out0_val), cutlass.Float32(out1_val) diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py new file mode 100644 index 000000000000..921251c613ad --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/scheduler.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, fields +from functools import partial +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass._mlir import ir +from cutlass.cute import FastDivmodDivisor + +try: + from typing import override +except ImportError: + from typing_extensions import override + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """ + It includes block, head, batch, and is_valid_tile + """ + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 4 + new_tile_idx = cutlass.new_from_mlir_values(self.tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self.is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class TileSchedulerParams(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + headdim: Int32 + headdim_v: Int32 + + +class StaticPersistentScheduler: + @dataclass + class Params(ParamsBase): + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks: Int32 + + @staticmethod + def create( + args: TileSchedulerParams, + *, + loc=None, + ip=None, + ) -> "StaticPersistentScheduler.Params": + total_blocks = args.num_block * args.num_head * args.num_batch + return StaticPersistentScheduler.Params( + FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks + ) + + def __init__( + self, + params: Params, + tile_idx: Int32, + *, + loc=None, + ip=None, + ): + self.params = params + self.tile_idx = tile_idx + self.loc = loc + self.ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerParams, *, loc=None, ip=None) -> Params: + return StaticPersistentScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentScheduler": + tile_idx = cute.arch.block_idx()[0] + return StaticPersistentScheduler(params, tile_idx, loc=loc, ip=ip) + + @staticmethod + def get_grid_shape( + params: Params, + *, + sm_count: Optional[Int32] = None, + occupancy: Int32 = 1, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + if sm_count is None: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + vacancies = sm_count * occupancy + return (cutlass.min(vacancies, params.total_blocks), Int32(1), Int32(1)) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self.tile_idx, self.params.num_block_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self.tile_idx < self.params.total_blocks + return WorkTileInfo((Int32(block_idx), Int32(head_idx), Int32(batch_idx)), is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self.tile_idx += cute.arch.grid_dim()[0] + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self.tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self.tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentScheduler(*(tuple(obj_list)), loc=self.loc) + + +if __name__ == "__main__": + + class TestScheduler: + def __init__(self, scheduler_cls): + self.scheduler_cls = scheduler_cls + + @cute.jit + def __call__(self): + scheduler_params = TileSchedulerParams( + num_block=Int32(5), + num_head=Int32(4), + num_batch=Int32(3), + headdim=Int32(128), + headdim_v=Int32(64), + ) + params = self.scheduler_cls.to_underlying_arguments(scheduler_params) + + grid_dim = self.scheduler_cls.get_grid_shape(params) + print(f"grid_dim: {grid_dim}") + + self.kernel(params).launch(grid=grid_dim, block=[32, 2, 1], min_blocks_per_mp=1) + + @cute.kernel + def kernel(self, params: ParamsBase): + TileSchedulerCls = partial(self.scheduler_cls.create, params) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + bidx, _, _ = cute.arch.block_idx() + + if warp_idx == 0: + lane_idx = cute.arch.lane_idx() + + scheduler = TileSchedulerCls() + + work_tile = scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + + if lane_idx == 0: + cute.printf( + "block_idx: {}, warp_idx: {}, m_block: {}, head_idx: {}, batch_idx: {}", + bidx, + warp_idx, + m_block, + head_idx, + batch_idx, + ) + + scheduler.prefetch_next_work() + scheduler.advance_to_next_work() + work_tile = scheduler.get_current_work() + + elif warp_idx == 1: + scheduler = TileSchedulerCls() + + test_scheduler = TestScheduler(StaticPersistentScheduler) + test_scheduler() diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py index c43345f9c8c4..930414d840f4 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py @@ -59,6 +59,14 @@ def __init__(self, model_config): "FluxPipeline does not support CFG parallelism. Please set cfg_size to 1." ) + _sa_cfg = model_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by Wan T2V pipelines. Remove sparse_attention_config " + "for FLUX." + ) + super().__init__(model_config) @staticmethod diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index 8675a0387e79..874a3edeed45 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -128,6 +128,14 @@ def __init__(self, model_config): "Flux2Pipeline does not support CFG parallelism. Please set cfg_size to 1." ) + _sa_cfg = model_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by Wan T2V pipelines. Remove sparse_attention_config " + "for FLUX.2." + ) + super().__init__(model_config) @staticmethod diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index fc5a21758478..6cd874997bab 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -599,6 +599,16 @@ class LTX2Pipeline(BasePipeline): ``transformers`` library. """ + def __init__(self, model_config): + _sa_cfg = model_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by Wan T2V pipelines. Remove sparse_attention_config " + "for LTX-2." + ) + super().__init__(model_config) + @classmethod def resolve_variant(cls, config): if getattr(config, "cache_backend", None) == "cache_dit": diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index f314f33fad3f..6f5bb0a0ffb0 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -9,6 +9,10 @@ from diffusers.video_processor import VideoProcessor from transformers import AutoTokenizer, UMT5EncoderModel +from tensorrt_llm._torch.visual_gen.attention_backend import ( + VSAMetadataBuilder, + set_vsa_forward_context, +) from tensorrt_llm._torch.visual_gen.cache.teacache import ( ExtractorConfig, register_extractor_from_config, @@ -374,8 +378,18 @@ def infer(self, req): seed=req.params.seed, max_sequence_length=req.params.max_sequence_length, image=image, + flow_shift=req.params.flow_shift, ) + def _default_flow_shift(self, height: int, width: int) -> float: + """Recommended flow_shift for the active Wan variant + resolution.""" + + if self.is_wan22_14b: + return 12.0 # Wan2.2 T2V A14B + if self.is_wan22_5b: + return 5.0 # Wan2.2 TI2V 5B + return 5.0 if max(height, width) >= 1280 else 3.0 # Wan2.1 T2V (720P vs 480P) + @nvtx_range("WanPipeline.forward") @torch.no_grad() def forward( @@ -392,6 +406,7 @@ def forward( seed: int = 42, max_sequence_length: int = 512, image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None, + flow_shift: Optional[float] = None, ): pipeline_start = time.time() timer = CudaPhaseTimer() @@ -480,6 +495,17 @@ def forward( latents = self._prepare_latents(batch_size, height, width, num_frames, generator) logger.debug(f"Latents shape: {latents.shape}") + # Resolve flow_shift: user override wins, else the per-variant recommended default. + resolved_flow_shift = ( + flow_shift if flow_shift is not None else self._default_flow_shift(height, width) + ) + if self.scheduler.config.shift != resolved_flow_shift: + logger.info( + f"flow_shift: {self.scheduler.config.shift} -> {resolved_flow_shift} " + f"({'user' if flow_shift is not None else 'variant default'})" + ) + self.scheduler.config.shift = resolved_flow_shift + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # Wan2.2 A14B: Calculate boundary timestep for two-stage denoising @@ -491,9 +517,22 @@ def forward( f"guidance_scale={guidance_scale}, guidance_scale_2={guidance_scale_2}" ) + # VSA: build metadata builder once per forward() call; reused across timesteps. + _attn_cfg = self.model_config.attention + _sparse_cfg = getattr(_attn_cfg, "sparse_attention_config", None) + _vsa_active = ( + getattr(_attn_cfg, "backend", "VANILLA") == "CUTEDSL" + and _sparse_cfg is not None + and getattr(_sparse_cfg, "algorithm", None) == "vsa" + ) + _vsa_builder = VSAMetadataBuilder() if _vsa_active else None + _vsa_patch_size = tuple(getattr(self.config, "patch_size", [1, 2, 2])) # (pT, pH, pW) + _vsa_sparsity = _sparse_cfg.vsa_sparsity if _vsa_active else 0.0 + # Denoising with two-stage support # Track which model was used in last step (for logging model transitions) last_model_used = [None] + _vsa_step_counter = [0] def forward_fn( latents, extra_stream_latents, timestep, encoder_hidden_states, extra_tensors @@ -537,6 +576,24 @@ def forward_fn( # T2V: current_t for all frames timestep = current_t.reshape(1, 1).expand(latents.shape[0], nf * nh * nw) + if _vsa_active and _vsa_builder is not None: + # latents: [B, C, T_latent, H_latent, W_latent] + raw_latent_shape = (latents.shape[2], latents.shape[3], latents.shape[4]) + vsa_metadata = _vsa_builder.build( + current_timestep=_vsa_step_counter[0], + raw_latent_shape=raw_latent_shape, + patch_size=_vsa_patch_size, + vsa_sparsity=_vsa_sparsity, + device=latents.device, + ) + _vsa_step_counter[0] += 1 + with set_vsa_forward_context(vsa_metadata): + return current_model( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + ) + return current_model( hidden_states=latents, timestep=timestep, diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index ca9e53ecf9a8..b7402d0b3630 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -101,6 +101,14 @@ def __init__(self, model_config): "Use cache_backend='none' or 'cache_dit' (not 'teacache')." ) + _sa_cfg = model_config.attention.sparse_attention_config + if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": + raise ValueError( + "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " + "only supported by Wan T2V pipelines (Wan 2.1 and Wan 2.2). Remove " + "sparse_attention_config for Wan I2V." + ) + super().__init__(model_config) def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs): @@ -404,8 +412,16 @@ def infer(self, req): seed=req.params.seed, max_sequence_length=req.params.max_sequence_length, last_image=last_image, + flow_shift=req.params.flow_shift, ) + def _default_flow_shift(self, height: int, width: int) -> float: + """Recommended flow_shift for the active Wan I2V variant + resolution.""" + + if self.is_wan22_14b: + return 5.0 # Wan2.2 I2V A14B + return 5.0 if max(height, width) >= 1280 else 3.0 # Wan2.1 I2V (720P vs 480P) + @torch.no_grad() def forward( self, @@ -422,6 +438,7 @@ def forward( seed: int = 42, max_sequence_length: int = 512, last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None, + flow_shift: Optional[float] = None, ): pipeline_start = time.time() timer = CudaPhaseTimer() @@ -529,6 +546,17 @@ def forward( batch_size, image, height, width, num_frames, generator, last_image ) + # Resolve flow_shift: user override wins, else the per-variant default. + resolved_flow_shift = ( + flow_shift if flow_shift is not None else self._default_flow_shift(height, width) + ) + if self.scheduler.config.shift != resolved_flow_shift: + logger.info( + f"flow_shift: {self.scheduler.config.shift} -> {resolved_flow_shift} " + f"({'user' if flow_shift is not None else 'variant default'})" + ) + self.scheduler.config.shift = resolved_flow_shift + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # Wan2.2: Calculate boundary timestep for two-stage denoising diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 0bac474df0ff..3eb08b415ad3 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -341,6 +341,41 @@ def __init__( reduce_output=(tp_size != 1), ) + # VSA gates (CUTEDSL backend, sparse_attention_config.algorithm == "vsa"). + # G_c weights the coarse branch; G_f weights the fine branch. + self.to_gate_compress = None + self.to_gate_fine = None + _attn_cfg = getattr(model_config, "attention", None) + _sa_cfg = getattr(_attn_cfg, "sparse_attention_config", None) if _attn_cfg else None + _is_vsa = ( + _attn_cfg is not None + and getattr(_attn_cfg, "backend", "VANILLA") == "CUTEDSL" + and _sa_cfg is not None + and getattr(_sa_cfg, "algorithm", None) == "vsa" + ) + if _is_vsa: + q_dim = num_heads * head_dim + self.to_gate_compress = Linear( + hidden_size, + q_dim, + bias=True, + dtype=dtype, + mapping=model_config.mapping, + quant_config=quant_config, + skip_create_weights_in_init=skip_create_weights, + force_dynamic_quantization=force_dynamic_quant, + ) + self.to_gate_fine = Linear( + hidden_size, + q_dim, + bias=True, + dtype=dtype, + mapping=model_config.mapping, + quant_config=quant_config, + skip_create_weights_in_init=skip_create_weights, + force_dynamic_quantization=force_dynamic_quant, + ) + # I2V: Additional K/V projections for image embeddings. self.add_k_proj = self.add_v_proj = None self.norm_added_k = None @@ -382,6 +417,32 @@ def __init__( torch.empty(1, 6, hidden_size).normal_(std=hidden_size**-0.5) ) + def init_gate_compress_zero(self) -> None: + """Zero-initialize to_gate_compress.""" + if self.to_gate_compress is None: + return + if not self.to_gate_compress._weights_created: + self.to_gate_compress.create_weights() + if self.to_gate_compress.weight.is_meta: + return + with torch.no_grad(): + self.to_gate_compress.weight.zero_() + if self.to_gate_compress.bias is not None: + self.to_gate_compress.bias.zero_() + + def init_gate_fine_default(self) -> None: + """Initialize to_gate_fine to emit constant 1 (weight=0, bias=1).""" + if self.to_gate_fine is None: + return + if not self.to_gate_fine._weights_created: + self.to_gate_fine.create_weights() + if self.to_gate_fine.weight.is_meta: + return + with torch.no_grad(): + self.to_gate_fine.weight.zero_() + if self.to_gate_fine.bias is not None: + self.to_gate_fine.bias.fill_(1.0) + def forward( self, x, @@ -414,15 +475,16 @@ def forward( # Prepare frequencies for Attention freqs = (freqs_cos, freqs_sin) if freqs_cos is not None and freqs_sin is not None else None + attn1_kwargs = {} + if self.to_gate_compress is not None: + attn1_kwargs["gate_compress"] = self.to_gate_compress(normed) + if self.to_gate_fine is not None: + attn1_kwargs["gate_fine"] = self.to_gate_fine(normed) + # Self-attention with RoPE - x = ( - x.float() - + self.attn1( - normed, - freqs=freqs, - ).float() - * gate_msa - ).to(x.dtype) + x = (x.float() + self.attn1(normed, freqs=freqs, **attn1_kwargs).float() * gate_msa).to( + x.dtype + ) norm_x = self.norm2(x.float()).to(x.dtype) @@ -777,6 +839,7 @@ def load_weights(self, weights: dict) -> None: } loader = DynamicLinearWeightLoader(self.model_config, params_map=params_map) + loaded_linears: set = set() for name, module in tqdm(self.named_modules(), desc="Loading weights"): if len(module._parameters) == 0: continue @@ -786,6 +849,7 @@ def load_weights(self, weights: dict) -> None: if weight_dicts: loader.load_linear_weights(module, name, weight_dicts) + loaded_linears.add(name) elif "add_k_proj" in name or "add_v_proj" in name: logger.info(f"[Weight Loading] No weights found for I2V module: {name}") elif isinstance(module, RMSNormTPAware): @@ -799,6 +863,20 @@ def load_weights(self, weights: dict) -> None: module_weights[param_name].to(self.model_config.torch_dtype) ) + # Default any VSA gates not loaded from the checkpoint: G_c=0, G_f=1 + # (preserves dense behavior at sparsity=0). + for name, module in self.named_modules(): + if not isinstance(module, WanBlock): + continue + if module.to_gate_compress is not None: + gate_path = f"{name}.to_gate_compress" if name else "to_gate_compress" + if gate_path not in loaded_linears: + module.init_gate_compress_zero() + if module.to_gate_fine is not None: + gate_path = f"{name}.to_gate_fine" if name else "to_gate_fine" + if gate_path not in loaded_linears: + module.init_gate_fine_default() + def post_load_weights(self) -> None: """Call post_load_weights on all Linear modules and convert embedders to target dtype.""" # Convert condition_embedder components to target dtype diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index 3465fcd63acd..48cdb4004766 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -92,13 +92,27 @@ def __init__( # Select compute backend (orthogonal to parallelism) vgm = config.visual_gen_mapping ulysses_size = vgm.ulysses_size if vgm else 1 + attn2d_size = (vgm.attn2d_row_size * vgm.attn2d_col_size) if vgm else 1 base_backend = config.attention.backend + _sa_cfg = config.attention.sparse_attention_config + _is_vsa = ( + base_backend == "CUTEDSL" + and _sa_cfg is not None + and getattr(_sa_cfg, "algorithm", None) == "vsa" + ) - # TRTLLM doesn't support cross-attention (different Q/KV seq lengths); fall back to VANILLA - if self.qkv_mode == QKVMode.SEPARATE_QKV and base_backend == "TRTLLM": + # Cross-attention fallback: TRTLLM and CUTEDSL VSA are self-attn only. + if self.qkv_mode == QKVMode.SEPARATE_QKV and (base_backend == "TRTLLM" or _is_vsa): backend_name = "VANILLA" else: backend_name = base_backend + + if _is_vsa and attn2d_size > 1: + raise ValueError( + f"VSA needs the full token sequence per rank, so it is incompatible " + f"with Attention2D (attn2d_size={attn2d_size}). Use ulysses or cfg " + f"parallelism instead." + ) self.attn_backend = backend_name self.qk_norm = qk_norm self.qk_norm_mode = qk_norm_mode @@ -425,8 +439,10 @@ def _attn_impl( """ Call attention backend with appropriate tensor layout. - Dimensions are derived from tensor shapes. Extra ``**kwargs`` - (e.g. ``attention_mask``) are forwarded to the backend. + Dimensions are derived from tensor shapes. Extra **kwargs are + forwarded to the backend. Backend-specific tensors that share + Q/K/V's [B, S, H*D] layout (e.g. VSA's gate_compress / + gate_fine) are reshaped here to the backend's 4-D layout. Two layout paths: 1. HND backends (VANILLA): [B, S, H*D] -> [B, H, S, D] @@ -438,6 +454,12 @@ def _attn_impl( seq_len = q.shape[1] seq_len_kv = k.shape[1] if k is not None else seq_len + def _reshape_gate(gate: torch.Tensor) -> torch.Tensor: + gate = gate.view(batch_size, -1, self.num_attention_heads, self.head_dim) + if backend_layout == AttentionTensorLayout.HND: + gate = gate.transpose(1, 2) + return gate + # Reshape inputs: [B, S, H*D] -> backend's preferred 4D layout if backend_layout == AttentionTensorLayout.HND: q = q.view(batch_size, -1, self.local_num_attention_heads, self.head_dim).transpose( @@ -461,13 +483,11 @@ def _attn_impl( "seq_len_kv": seq_len_kv, } ) + for gate_key in ("gate_compress", "gate_fine"): + if kwargs.get(gate_key) is not None: + kwargs[gate_key] = _reshape_gate(kwargs[gate_key]) - out = self.attn.forward( - q=q, - k=k, - v=v, - **kwargs, - ) + out = self.attn.forward(q=q, k=k, v=v, **kwargs) # Flatten back to [B, S, H*D] if backend_layout == AttentionTensorLayout.HND: @@ -480,6 +500,7 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, ) -> torch.Tensor: assert hidden_states.ndim == 3, "hidden_states must be a 3D tensor" batch_size, seq_len = hidden_states.shape[:2] @@ -498,7 +519,7 @@ def forward( freqs_cos, freqs_sin = freqs self.apply_packed_qk_norm_rope(qkv, freqs_cos, freqs_sin) q, k, v = qkv.split([self.local_q_dim, self.local_kv_dim, self.local_kv_dim], dim=-1) - out = self._attn_impl(q, k, v) + out = self._attn_impl(q, k, v, **kwargs) return self.to_out[0](out) # Unfused path: separate QK norm → separate RoPE → attention @@ -517,6 +538,6 @@ def forward( q = q.flatten(2) k = k.flatten(2) - out = self._attn_impl(q, k, v) + out = self._attn_impl(q, k, v, **kwargs) out = self.to_out[0](out) return out diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py index ac61ce729fbc..0098ff57f7f4 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py @@ -218,6 +218,23 @@ def load( logger.info(f"Quantization: {config.quant_config.quant_algo.name}") logger.info(f"Dynamic weight quant: {config.dynamic_weight_quant}") + _attn_backend = config.attention.backend + _sa_cfg = config.attention.sparse_attention_config + if ( + _attn_backend == "CUTEDSL" + and _sa_cfg is not None + and getattr(_sa_cfg, "algorithm", None) == "vsa" + ): + from .cute_dsl_kernels.blackwell.video_sparse_attention import CUTE_AVAILABLE + + kernel_path = "CuTe DSL block-sparse" if CUTE_AVAILABLE else "dense SDPA fallback" + logger.info( + f"Attention backend: CUTEDSL (algorithm=vsa, " + f"sparsity={_sa_cfg.vsa_sparsity}, fine-stage={kernel_path})" + ) + else: + logger.info(f"Attention backend: {_attn_backend}") + # ===================================================================== # STEP 1b: Build VisualGenMapping (must precede model creation) # ===================================================================== diff --git a/tensorrt_llm/visual_gen/__init__.py b/tensorrt_llm/visual_gen/__init__.py index 0975fa2ee2ca..c44abd65a31d 100644 --- a/tensorrt_llm/visual_gen/__init__.py +++ b/tensorrt_llm/visual_gen/__init__.py @@ -36,6 +36,7 @@ SparseAttentionConfig, TeaCacheConfig, TorchCompileConfig, + VideoSparseAttentionConfig, VisualGenArgs, ) from .output import VisualGenMetrics, VisualGenOutput @@ -60,6 +61,7 @@ "QuantAttentionConfig", "SparseAttentionConfig", "SkipSoftmaxConfig", + "VideoSparseAttentionConfig", "CacheConfig", "TeaCacheConfig", "CacheDiTConfig", diff --git a/tensorrt_llm/visual_gen/args.py b/tensorrt_llm/visual_gen/args.py index 8daefd5d7bb6..f87c82fda230 100644 --- a/tensorrt_llm/visual_gen/args.py +++ b/tensorrt_llm/visual_gen/args.py @@ -30,7 +30,7 @@ from tensorrt_llm.llmapi.utils import StrictBaseModel, set_api_status from tensorrt_llm.models.modeling_utils import QuantConfig -from .sparse_attention import SkipSoftmaxConfig +from .sparse_attention import SkipSoftmaxConfig, VideoSparseAttentionConfig # ============================================================================= # Type aliases @@ -86,7 +86,7 @@ class QuantAttentionConfig(StrictBaseModel): # Discriminated union of sparse attention configs. SparseAttentionConfig = Annotated[ - Union[SkipSoftmaxConfig], + Union[SkipSoftmaxConfig, VideoSparseAttentionConfig], Field(discriminator="algorithm"), ] @@ -111,7 +111,10 @@ class AttentionConfig(StrictBaseModel): sparse_attention_config: Optional[SparseAttentionConfig] = Field( None, status="prototype", - description="Sparse attention configuration. Currently supports: skip_softmax.", + description=( + "Sparse attention recipe. Discriminated by algorithm: " + "skip_softmax (TRTLLM backend) or VSA (CUTEDSL backend)." + ), ) sparse_config_path: Optional[str] = Field( None, @@ -169,6 +172,41 @@ def _validate_quant_attention_config(self) -> "AttentionConfig": ) return self + @model_validator(mode="after") + def _validate_sparse_attention_config(self) -> "AttentionConfig": + if self.sparse_attention_config is None: + return self + + algo = self.sparse_attention_config.algorithm + required_backend = {"skip_softmax": "TRTLLM", "vsa": "CUTEDSL"}.get(algo) + if required_backend is None: + return self + + if self.backend != required_backend: + raise ValueError( + f"sparse_attention_config with algorithm='{algo}' requires " + f"backend='{required_backend}', got backend='{self.backend}'. " + f"Either set backend='{required_backend}' or remove " + f"sparse_attention_config." + ) + return self + + @model_validator(mode="after") + def _validate_cutedsl_quant_sparse_mutex(self) -> "AttentionConfig": + # quant_attention_config and sparse_attention_config are mutually exclusive. + if ( + self.backend == "CUTEDSL" + and self.quant_attention_config is not None + and self.sparse_attention_config is not None + ): + raise ValueError( + "CUTEDSL backend: quant_attention_config and " + "sparse_attention_config are mutually exclusive (the " + "CuTeDSLAttention dispatcher selects either the dense path " + "or the sparse VSA path, not both)." + ) + return self + class ParallelConfig(StrictBaseModel): """Configuration for distributed parallelism across DiT-shaped models. @@ -587,6 +625,7 @@ def from_yaml(cls, yaml_path: Union[str, Path], **overrides: Any) -> "VisualGenA "QuantAttentionConfig", "SparseAttentionConfig", "SkipSoftmaxConfig", + "VideoSparseAttentionConfig", "AttentionConfig", "ParallelConfig", "BaseCacheConfig", diff --git a/tensorrt_llm/visual_gen/params.py b/tensorrt_llm/visual_gen/params.py index 87754a56c956..41fa444ec9a6 100644 --- a/tensorrt_llm/visual_gen/params.py +++ b/tensorrt_llm/visual_gen/params.py @@ -46,6 +46,13 @@ class VisualGenParams(StrictBaseModel): max_sequence_length: Optional[int] = Field( default=None, description="Max tokens for text encoding." ) + flow_shift: Optional[float] = Field( + default=None, + description=( + "Override the scheduler's flow-matching shift. None = pipeline's " + "per-variant recommended default. Currently honored only by the Wan pipelines." + ), + ) seed: int = Field(default=42, description="Random seed for reproducibility.") # Video diff --git a/tensorrt_llm/visual_gen/sparse_attention.py b/tensorrt_llm/visual_gen/sparse_attention.py index 41409c6ac6d6..a66cda701b6b 100644 --- a/tensorrt_llm/visual_gen/sparse_attention.py +++ b/tensorrt_llm/visual_gen/sparse_attention.py @@ -12,12 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Skip-softmax sparse attention helpers for visual generation.""" +"""Sparse attention recipe classes for visual generation (skip-softmax, VSA).""" import fnmatch import math from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union import yaml from pydantic import Field as PydanticField @@ -249,6 +249,29 @@ def _shared_formula_coefficients( return None +class VideoSparseAttentionConfig(StrictBaseModel): + """Video Sparse Attention (VSA) sparse-attention recipe (CUTEDSL backend only). + + Two-stage hybrid attention: a coarse mean-pooled stage over (4,4,4) cubes + and a block-sparse fine stage over the top-K cubes selected per head. + vsa_sparsity controls the fraction of cubes dropped on the fine stage. + """ + + algorithm: Literal["vsa"] = PydanticField( + "vsa", + description="Sparse attention algorithm discriminator.", + ) + vsa_sparsity: float = PydanticField( + 0.9, + ge=0.0, + le=1.0, + description=( + "Fraction of cubes dropped on the fine stage. 0.0 keeps all cubes " + "(dense fine stage); values closer to 1.0 keep fewer cubes." + ), + ) + + def _load_sparse_config_group_container(data: Dict[str, Any]) -> Optional[SkipSoftmaxConfig]: """Load one component's skip-softmax config from a ``config_groups`` container.""" config_groups = data.get("config_groups", {}) diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 0d3d5f4932e9..3c8db98fb552 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -188,6 +188,7 @@ l0_b200: - unittest/_torch/visual_gen/test_cache_dit.py - unittest/_torch/visual_gen/test_quant_ops.py - unittest/_torch/visual_gen/test_attention_cute_dsl.py + - unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py - unittest/_torch/visual_gen/test_attention_trtllm_sage.py - unittest/_torch/visual_gen/test_attention_integration.py - unittest/_torch/visual_gen/test_attention_perf.py diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py new file mode 100644 index 000000000000..4133c3dc1a55 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-GPU tests for VSA + Ulysses sequence parallelism.""" + +import functools +import math +import os + +os.environ["TLLM_DISABLE_MPI"] = "1" + +from typing import Callable + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.functional as F + +try: + from tensorrt_llm._torch.visual_gen.attention_backend import ( + CuTeDSLAttention, + UlyssesAttention, + VSAMetadataBuilder, + set_vsa_forward_context, + ) + from tensorrt_llm._utils import get_free_port + from tensorrt_llm.visual_gen.args import VideoSparseAttentionConfig + + MODULES_AVAILABLE = True +except ImportError: + MODULES_AVAILABLE = False + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +# ============================================================================= +# Distributed helpers (same pattern as test_ulysses_sage_attention.py) +# ============================================================================= + + +def init_distributed_worker(rank: int, world_size: int, backend: str = "nccl", port: int = 29500): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + torch.cuda.set_device(rank % torch.cuda.device_count()) + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + + +def cleanup_distributed(): + if dist.is_initialized(): + dist.destroy_process_group() + + +def _distributed_worker(rank, world_size, backend, test_fn, port): + try: + init_distributed_worker(rank, world_size, backend, port) + test_fn(rank, world_size) + except Exception as e: + print(f"Rank {rank} failed with error: {e}") + raise + finally: + cleanup_distributed() + + +def run_test_in_distributed(world_size: int, test_fn: Callable): + if not MODULES_AVAILABLE: + pytest.skip("Required modules not available") + if not torch.cuda.is_available(): + pytest.skip("CUDA required for VSA") + if torch.cuda.device_count() < world_size: + pytest.skip(f"Test requires {world_size} GPUs, only {torch.cuda.device_count()} available") + + port = get_free_port() + mp.spawn( + _distributed_worker, + args=(world_size, "nccl", test_fn, port), + nprocs=world_size, + join=True, + ) + + +# (8,8,8) latent -> 512 tokens (256/rank at P=2); the (4,4,4) tile gives 8 cubes +# (even, as the paired-block kernel needs) of 64 tokens (the kernel block_size). +_DIT_SEQ_SHAPE = (8, 8, 8) +_VSA_PATCH_SIZE = (1, 1, 1) +_HEAD_DIM = 128 # CuTe VSA fine-stage kernel requires head_dim == 128 +_HEADS_PER_RANK = 4 + + +def _make_vsa_backend(num_heads: int, vsa_sparsity: float) -> "CuTeDSLAttention": + """CUTEDSL backend on the VSA path; effective sparsity comes from the forward context.""" + return CuTeDSLAttention( + layer_idx=0, + num_heads=num_heads, + head_dim=_HEAD_DIM, + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity), + ) + + +def _build_full_seq_vsa_metadata(vsa_sparsity: float, device: torch.device): + """VSAMetadata for the full sequence — identical on every rank after Ulysses all-to-all.""" + builder = VSAMetadataBuilder() + return builder.build( + current_timestep=0, + raw_latent_shape=_DIT_SEQ_SHAPE, + patch_size=_VSA_PATCH_SIZE, + vsa_sparsity=vsa_sparsity, + device=device, + ) + + +# ============================================================================= +# Test logic functions (module-level so mp.spawn can pickle them) +# ============================================================================= + + +def _logic_vsa_ulysses_forward(rank, world_size, *, vsa_sparsity: float): + """Forward pass: output shape correct and finite.""" + batch = 1 + seq_full = math.prod(_DIT_SEQ_SHAPE) + assert seq_full % world_size == 0 + seq_per_rank = seq_full // world_size + num_heads = world_size * _HEADS_PER_RANK + + device = torch.device(f"cuda:{rank}") + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + + inner = _make_vsa_backend(num_heads // world_size, vsa_sparsity) + attention = UlyssesAttention(inner_backend=inner, process_group=None) + + # Ulysses input: sequence-sharded, head-full [B, S/P, H, D]. + shape = (batch, seq_per_rank, num_heads, _HEAD_DIM) + q = torch.randn(shape, device=device, dtype=torch.bfloat16) + k = torch.randn(shape, device=device, dtype=torch.bfloat16) + v = torch.randn(shape, device=device, dtype=torch.bfloat16) + gate_compress = torch.randn(shape, device=device, dtype=torch.bfloat16) + gate_fine = torch.randn(shape, device=device, dtype=torch.bfloat16) + + metadata = _build_full_seq_vsa_metadata(vsa_sparsity, device) + with set_vsa_forward_context(metadata): + output = attention(q, k, v, gate_compress=gate_compress, gate_fine=gate_fine) + + assert output.shape == (batch, seq_per_rank, num_heads, _HEAD_DIM), ( + f"Rank {rank}: expected {(batch, seq_per_rank, num_heads, _HEAD_DIM)}, got {output.shape}" + ) + assert torch.isfinite(output).all(), f"Rank {rank}: Inf/NaN in output" + + +def _logic_vsa_ulysses_vs_reference(rank, world_size, *, vsa_sparsity: float): + """Each rank's Ulysses+VSA output matches the single-GPU VSA reference's sequence slice.""" + batch = 1 + seq_full = math.prod(_DIT_SEQ_SHAPE) + assert seq_full % world_size == 0 + seq_per_rank = seq_full // world_size + num_heads = world_size * _HEADS_PER_RANK + + device = torch.device(f"cuda:{rank}") + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + + full_shape = (batch, seq_full, num_heads, _HEAD_DIM) + q_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) + k_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) + v_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) + gate_c_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) + gate_f_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) + + sl = slice(rank * seq_per_rank, (rank + 1) * seq_per_rank) + q_shard = q_full[:, sl].contiguous() + k_shard = k_full[:, sl].contiguous() + v_shard = v_full[:, sl].contiguous() + gate_c_shard = gate_c_full[:, sl].contiguous() + gate_f_shard = gate_f_full[:, sl].contiguous() + + metadata = _build_full_seq_vsa_metadata(vsa_sparsity, device) + + # Ulysses path: sharded input, head-sharded inner backend. + inner = _make_vsa_backend(num_heads // world_size, vsa_sparsity) + attention = UlyssesAttention(inner_backend=inner, process_group=None) + with set_vsa_forward_context(metadata): + ulysses_out = attention( + q_shard, k_shard, v_shard, gate_compress=gate_c_shard, gate_fine=gate_f_shard + ) + + # Single-GPU VSA reference over the full sequence. + ref_attn = _make_vsa_backend(num_heads, vsa_sparsity) + with set_vsa_forward_context(metadata): + ref_out = ref_attn.forward( + q_full, k_full, v_full, gate_compress=gate_c_full, gate_fine=gate_f_full + ) + ref_shard = ref_out[:, sl] + + ulysses_out = ulysses_out.view(batch, seq_per_rank, num_heads, _HEAD_DIM).to(torch.bfloat16) + ref_shard = ref_shard.to(torch.bfloat16) + + cos_sim = F.cosine_similarity( + ulysses_out.reshape(-1).float(), + ref_shard.reshape(-1).float(), + dim=0, + ).item() + assert cos_sim > 0.990, f"Rank {rank}: cosine similarity {cos_sim:.6f} is below threshold 0.990" + torch.testing.assert_close(ulysses_out, ref_shard, atol=2e-2, rtol=2e-2) + + +# ============================================================================= +# Test class +# ============================================================================= + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +class TestWanVsaUlysses: + """VSA (CUTEDSL) attention backend combined with Ulysses sequence parallelism.""" + + @pytest.mark.parametrize("vsa_sparsity", [0.0, 0.5]) + def test_vsa_ulysses_forward(self, vsa_sparsity: float): + run_test_in_distributed( + world_size=2, + test_fn=functools.partial(_logic_vsa_ulysses_forward, vsa_sparsity=vsa_sparsity), + ) + + @pytest.mark.parametrize("vsa_sparsity", [0.0, 0.5, 0.75]) + def test_vsa_ulysses_vs_reference(self, vsa_sparsity: float): + run_test_in_distributed( + world_size=2, + test_fn=functools.partial(_logic_vsa_ulysses_vs_reference, vsa_sparsity=vsa_sparsity), + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py new file mode 100644 index 000000000000..d079b79ae1e5 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py @@ -0,0 +1,329 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VSA correctness tests: CuTe kernel, tile/untile roundtrip, top-k math, backend guards. + +Module-level dense-equivalence and finite-output checks live in +test_attention_integration.py. +""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn.functional as F + +from tensorrt_llm._torch.visual_gen.attention_backend import VSAMetadataBuilder +from tensorrt_llm._torch.visual_gen.config import ( + DiffusionModelConfig, + create_attention_metadata_state, +) +from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode +from tensorrt_llm.visual_gen.args import AttentionConfig, VideoSparseAttentionConfig + + +def _make_config( + hidden_size: int, + num_heads: int, + head_dim: int, + backend: str, + vsa_sparsity: "float | None" = None, +) -> DiffusionModelConfig: + """Minimal DiffusionModelConfig for one Attention module.""" + pretrained_config = SimpleNamespace( + hidden_size=hidden_size, + num_attention_heads=num_heads, + attention_head_dim=head_dim, + eps=1e-6, + ) + sparse_attention_config = ( + VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity) if vsa_sparsity is not None else None + ) + config = DiffusionModelConfig( + pretrained_config=pretrained_config, + attention=AttentionConfig(backend=backend, sparse_attention_config=sparse_attention_config), + skip_create_weights_in_init=False, + ) + config.attention_metadata_state = ( + create_attention_metadata_state() if backend == "TRTLLM" else None + ) + return config + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +def test_vsa_falls_back_to_vanilla_for_cross_attention(): + """Cross-attention (SEPARATE_QKV) falls back to VANILLA — it has no cube structure.""" + device = torch.device("cuda") + dtype = torch.bfloat16 + cfg = _make_config( + hidden_size=64, num_heads=4, head_dim=16, backend="CUTEDSL", vsa_sparsity=0.5 + ) + cross_attn = ( + Attention(64, 4, qkv_mode=QKVMode.SEPARATE_QKV, config=cfg) + .to(device=device, dtype=dtype) + .eval() + ) + assert cross_attn.attn_backend == "VANILLA", ( + f"VSA on cross-attention should fall back to VANILLA, got {cross_attn.attn_backend!r}" + ) + + +def test_vsa_with_attn2d_raises(): + """VSA + Attention2D must error at construction (VSA needs the full sequence per rank).""" + pretrained_config = SimpleNamespace( + hidden_size=64, + num_attention_heads=4, + attention_head_dim=16, + eps=1e-6, + ) + cfg = DiffusionModelConfig( + pretrained_config=pretrained_config, + attention=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=0.0), + ), + skip_create_weights_in_init=False, + ) + cfg.visual_gen_mapping = SimpleNamespace( + ring_size=1, + ring_group=None, + ulysses_size=1, + ulysses_group=None, + attn2d_row_size=2, + attn2d_col_size=2, + attn2d_row_group=None, + attn2d_col_group=None, + ) + with pytest.raises(ValueError, match="incompatible with Attention2D"): + Attention(64, 4, qkv_mode=QKVMode.FUSE_QKV, config=cfg) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +def test_vsa_topk_collapses_to_dense_at_sparsity_zero(): + """At sparsity=0, top_k equals num_cubes (dense connectivity).""" + from math import ceil + + device = torch.device("cuda") + builder = VSAMetadataBuilder() + metadata = builder.build( + current_timestep=0, + raw_latent_shape=(8, 8, 8), + patch_size=(1, 1, 1), + vsa_sparsity=0.0, + device=device, + ) + num_cubes = metadata.num_tiles[0] * metadata.num_tiles[1] * metadata.num_tiles[2] + cur_topk = max(1, ceil((1.0 - metadata.vsa_sparsity) * num_cubes)) + assert cur_topk == num_cubes, ( + f"sparsity=0 should select all {num_cubes} cubes, got top_k={cur_topk}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +@pytest.mark.parametrize( + "latent_shape", + [ + (8, 8, 8), + (9, 9, 9), + (21, 45, 80), + ], + ids=["clean_8x8x8", "ragged_9x9x9", "wan720p_21x45x80"], +) +def test_vsa_tile_untile_roundtrip(latent_shape): + """VSAPreprocessor.tile then .untile must losslessly reproduce the input.""" + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import VSAPreprocessor + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, D = 2, 4, 32 + seq_len = latent_shape[0] * latent_shape[1] * latent_shape[2] + + builder = VSAMetadataBuilder() + meta = builder.build( + current_timestep=0, + raw_latent_shape=latent_shape, + patch_size=(1, 1, 1), + vsa_sparsity=0.0, + device=device, + ) + + x = torch.randn(B, seq_len, H, D, device=device, dtype=dtype) + + x_tiled = VSAPreprocessor.tile( + x, + meta.non_pad_index, + meta.gather_idx, + meta.padded_seq_length, + ) + + pad_mask = torch.ones(meta.padded_seq_length, dtype=torch.bool, device=device) + pad_mask[meta.non_pad_index] = False + if pad_mask.any(): + assert x_tiled[:, pad_mask, :, :].abs().max().item() == 0.0, ( + "tile() must zero-fill padded positions" + ) + + x_roundtrip = VSAPreprocessor.untile( + x_tiled, + meta.reverse_tile_partition_indices, + meta.non_pad_index, + ) + + assert x_roundtrip.shape == x.shape, ( + f"shape mismatch after tile/untile: {x_roundtrip.shape} vs {x.shape}" + ) + assert torch.equal(x_roundtrip, x), ( + f"tile/untile round-trip is not lossless for latent_shape={latent_shape}: " + f"max_diff={(x_roundtrip - x).abs().max().item():.3e}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +def test_cute_kernel_matches_dense_at_full_topk(): + """CuTe block-sparse kernel matches dense SDPA when every cube is selected.""" + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, num_cubes, D = 1, 4, 4, 128 + block_size = 64 + seq_len = num_cubes * block_size + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + topk = num_cubes + q2k_idx = ( + torch.arange(num_cubes, device=device, dtype=torch.int32) + .view(1, 1, 1, num_cubes) + .expand(B, H, num_cubes, topk) + .contiguous() + ) + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + out_kernel, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + out_ref = F.scaled_dot_product_attention(q, k, v) + + max_diff = (out_kernel - out_ref).abs().max().item() + mean_diff = (out_kernel - out_ref).abs().mean().item() + + rtol, atol = 1e-2, 1e-2 + assert torch.allclose(out_kernel, out_ref, rtol=rtol, atol=atol), ( + f"CuTe block-sparse kernel deviates from dense SDPA at full top-K: " + f"max_diff={max_diff:.3e}, mean_diff={mean_diff:.3e} (rtol={rtol}, atol={atol})" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +def test_cute_kernel_matches_ref_with_independent_indices(): + """CuTe kernel: paired Q-blocks (2i, 2i+1) attend to independent KV index lists.""" + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(42) + + B, H, num_cubes, D = 2, 4, 16, 128 + block_size = 64 + topk = num_cubes // 2 + seq_len = num_cubes * block_size + assert num_cubes % 2 == 0, "num_cubes must be even for the paired-block kernel" + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + q2k_idx = ( + torch.stack( + [ + torch.randperm(num_cubes, device=device, dtype=torch.int32)[:topk] + for _ in range(B * H * num_cubes) + ] + ) + .view(B, H, num_cubes, topk) + .contiguous() + ) + + paired = q2k_idx.view(B, H, num_cubes // 2, 2, topk).sort(dim=-1).values + pair_mismatch = (paired[..., 0, :] != paired[..., 1, :]).sum().item() + assert pair_mismatch > 0, ( + "Pre-condition failed: random permutations matched across every pair; " + "re-seed or raise num_cubes." + ) + + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + attn_mask = torch.full( + (B, H, seq_len, seq_len), float("-inf"), device=device, dtype=torch.float32 + ) + for b in range(B): + for h in range(H): + for q_blk in range(num_cubes): + for ki in range(topk): + k_blk = q2k_idx[b, h, q_blk, ki].item() + qs = q_blk * block_size + ks = k_blk * block_size + attn_mask[b, h, qs : qs + block_size, ks : ks + block_size] = 0.0 + + out_kernel, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + + # Manual fp32 masked-softmax: bf16 inputs + an fp32 -inf mask through SDPA + # mishandle the masked region at this shape, so compute it explicitly. + scale = 1.0 / (D**0.5) + scores = (q.float() @ k.float().transpose(-2, -1)) * scale + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + out_ref = (probs @ v.float()).to(dtype) + + abs_diff = (out_kernel.float() - out_ref.float()).abs() + max_diff = abs_diff.max().item() + mean_diff = abs_diff.mean().item() + + rtol, atol = 1e-2, 1e-2 + assert torch.allclose(out_kernel, out_ref, rtol=rtol, atol=atol), ( + f"CuTe kernel with independent per-Q-block indices deviated from masked fp32 " + f"reference: max_diff={max_diff:.3e}, mean_diff={mean_diff:.3e} " + f"(rtol={rtol}, atol={atol}, pair_mismatch={pair_mismatch})" + ) diff --git a/tests/unittest/_torch/visual_gen/test_attention_integration.py b/tests/unittest/_torch/visual_gen/test_attention_integration.py index 22f094ffd05e..bb798fd62fa1 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_integration.py +++ b/tests/unittest/_torch/visual_gen/test_attention_integration.py @@ -35,7 +35,11 @@ # Import new integrated versions from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode, apply_rotary_emb -from tensorrt_llm.visual_gen.args import AttentionConfig, QuantAttentionConfig +from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + QuantAttentionConfig, + VideoSparseAttentionConfig, +) _flash_attn4_available = _fa4_fwd is not None _cute_dsl_available = _cute_dsl_import_error is None @@ -128,6 +132,8 @@ def create_model_config( eps: float = 1e-6, attn_backend: str = "VANILLA", quant_attention_config: "QuantAttentionConfig | None" = None, + sparse_attention_config=None, + vsa_sparsity: "float | None" = None, *, visual_gen_mapping: VisualGenMapping | None = None, skip_create_weights_in_init: bool = False, @@ -140,12 +146,15 @@ def create_model_config( eps=eps, ) - # Create a minimal config without quantization + if vsa_sparsity is not None and sparse_attention_config is None: + sparse_attention_config = VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity) + config = DiffusionModelConfig( pretrained_config=pretrained_config, attention=AttentionConfig( backend=attn_backend, quant_attention_config=quant_attention_config, + sparse_attention_config=sparse_attention_config, ), skip_create_weights_in_init=skip_create_weights_in_init, ) @@ -662,6 +671,102 @@ def test_fast_cross_attention_wan_shapes( assert is_close, f"{attn_backend} cross-attn mismatch at Wan shapes: max_diff={max_diff:.2e}" +# ============================================================================ +# VSA self-attention (CUTEDSL backend, sparse_attention_config.algorithm='vsa') +# ============================================================================ + + +def _build_vsa_setup(sparsity: float, batch_size: int, seed: int): + """Build naive + integrated models, VSA metadata, and inputs for a VSA test. + + latent (8,8,8) -> 512 tokens (divisible by block_size=64), head_dim=128. + """ + from tensorrt_llm._torch.visual_gen.attention_backend import VSAMetadataBuilder + + latent_shape = (8, 8, 8) + seq_len = latent_shape[0] * latent_shape[1] * latent_shape[2] + num_heads = 4 + head_dim = 128 + hidden_size = num_heads * head_dim + device = torch.device("cuda") + dtype = torch.bfloat16 + + naive = NaiveWanSelfAttention(hidden_size, num_heads, head_dim, dtype=dtype).to(device) + cfg_vsa = create_model_config( + hidden_size, num_heads, head_dim, attn_backend="CUTEDSL", vsa_sparsity=sparsity + ) + integrated = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=cfg_vsa).to( + device + ) + copy_weights_self_attention(naive, integrated) + naive.eval() + integrated.eval() + + metadata = VSAMetadataBuilder().build( + current_timestep=0, + raw_latent_shape=latent_shape, + patch_size=(1, 1, 1), + vsa_sparsity=sparsity, + device=device, + ) + torch.manual_seed(seed) + hidden_states = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype) + return SimpleNamespace( + naive=naive, + integrated=integrated, + metadata=metadata, + hidden_states=hidden_states, + gate_compress_zero=torch.zeros_like(hidden_states), + freqs_HSD=generate_rope_embeddings(seq_len, head_dim, device, is_HSD=True), + freqs_SHD=generate_rope_embeddings(seq_len, head_dim, device, is_HSD=False), + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +def test_vsa_self_attention_equivalence_at_sparsity_zero(): + """VSA at sparsity=0 with G_c=0 reduces to dense attention (top_k=num_cubes, + output=O_f); must match the naive SDPA reference modulo bf16 rounding.""" + from tensorrt_llm._torch.visual_gen.attention_backend import set_vsa_forward_context + + s = _build_vsa_setup(sparsity=0.0, batch_size=2, seed=42) + + with torch.no_grad(): + out_naive = s.naive(s.hidden_states, *s.freqs_HSD) + with torch.no_grad(), set_vsa_forward_context(s.metadata): + out_vsa = s.integrated( + s.hidden_states, freqs=s.freqs_SHD, gate_compress=s.gate_compress_zero + ) + + assert out_naive.shape == out_vsa.shape, ( + f"shape mismatch: naive={out_naive.shape}, vsa={out_vsa.shape}" + ) + max_diff = (out_naive - out_vsa).abs().max().item() + mean_diff = (out_naive - out_vsa).abs().mean().item() + assert torch.allclose(out_naive, out_vsa, rtol=1e-2, atol=1e-2), ( + f"VSA(sparsity=0, G_c=0) deviates from naive dense SDPA: " + f"max_diff={max_diff:.2e}, mean_diff={mean_diff:.2e}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="VSA needs CUDA") +@pytest.mark.parametrize("sparsity", [0.0, 0.5], ids=["s0", "s0p5"]) +def test_vsa_self_attention_finite(sparsity: float): + """VSA forward must produce finite output (no NaN/Inf) at any supported sparsity.""" + from tensorrt_llm._torch.visual_gen.attention_backend import set_vsa_forward_context + + s = _build_vsa_setup(sparsity=sparsity, batch_size=1, seed=0) + + with torch.no_grad(), set_vsa_forward_context(s.metadata): + out = s.integrated(s.hidden_states, freqs=s.freqs_SHD, gate_compress=s.gate_compress_zero) + + assert out.shape == s.hidden_states.shape + nan_count = torch.isnan(out).sum().item() + inf_count = torch.isinf(out).sum().item() + assert nan_count == 0 and inf_count == 0, ( + f"VSA produced non-finite output at sparsity={sparsity}: NaN={nan_count}, Inf={inf_count}" + ) + + def test_trtllm_cached_prepare(): """Test that TRTLLM attention cached prepare works correctly. diff --git a/tests/unittest/_torch/visual_gen/test_attention_perf.py b/tests/unittest/_torch/visual_gen/test_attention_perf.py index 90a8ac3608c0..693efcc663bd 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_perf.py +++ b/tests/unittest/_torch/visual_gen/test_attention_perf.py @@ -29,13 +29,18 @@ """ import time -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from types import SimpleNamespace from typing import Dict, Optional, Tuple import pytest import torch +from tensorrt_llm._torch.visual_gen.attention_backend import ( + VSAMetadataBuilder, + set_vsa_forward_context, +) + # ============================================================================ # Flash Attention 4 availability # ============================================================================ @@ -49,7 +54,11 @@ create_attention_metadata_state, ) from tensorrt_llm._torch.visual_gen.modules.attention import Attention, QKVMode -from tensorrt_llm.visual_gen.args import AttentionConfig, QuantAttentionConfig +from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + QuantAttentionConfig, + VideoSparseAttentionConfig, +) _flash_attn4_available = _fa4_fwd is not None _cute_dsl_available = _cute_dsl_import_error is None @@ -130,6 +139,17 @@ def get_elapsed_time(): yield get_elapsed_time +def bench_fn(device: torch.device, fn, n_iters: int) -> torch.Tensor: + """Time `fn` over `n_iters` GPU-timed iterations; returns per-iteration ms.""" + times = [] + with torch.no_grad(): + for _ in range(n_iters): + with cuda_timer(device) as get_t: + fn() + times.append(get_t()) + return torch.tensor(times) + + @contextmanager def nvtx_range(name: str): """Context manager for NVTX range profiling.""" @@ -171,6 +191,7 @@ def create_model_config( eps: float = 1e-6, attn_backend: str = "VANILLA", quant_attention_config: "QuantAttentionConfig | None" = None, + vsa_sparsity: "float | None" = None, ) -> DiffusionModelConfig: """Create a mock DiffusionModelConfig for testing.""" pretrained_config = SimpleNamespace( @@ -180,11 +201,16 @@ def create_model_config( eps=eps, ) + sparse_attention_config = ( + VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity) if vsa_sparsity is not None else None + ) + config = DiffusionModelConfig( pretrained_config=pretrained_config, attention=AttentionConfig( backend=attn_backend, quant_attention_config=quant_attention_config, + sparse_attention_config=sparse_attention_config, ), skip_create_weights_in_init=False, ) @@ -230,6 +256,21 @@ def _is_sage_attention_enabled(model: Attention) -> bool: return getattr(inner_backend, "quant_attention_config", None) is not None +def _init_attention_weights(attn: Attention, std: float = 0.02) -> None: + """Initialize TRT-LLM Linear weights to N(0, std) to avoid RMSNorm NaN.""" + with torch.no_grad(): + if getattr(attn, "qkv_proj", None) is not None: + attn.qkv_proj.weight.normal_(0.0, std) + if attn.qkv_proj.bias is not None: + attn.qkv_proj.bias.zero_() + if getattr(attn, "qk_norm", False): + attn.norm_q.weight.fill_(1.0) + attn.norm_k.weight.fill_(1.0) + attn.to_out[0].weight.normal_(0.0, std) + if attn.to_out[0].bias is not None: + attn.to_out[0].bias.zero_() + + # ============================================================================ # Performance benchmark class # ============================================================================ @@ -286,6 +327,8 @@ def create_attention_model( head_dim: int, backend: str, quant_attention_config: "QuantAttentionConfig | None" = None, + vsa_sparsity: "float | None" = None, + init_std: "float | None" = None, ) -> Attention: """Create a WAN self-attention model with specified backend.""" config = create_model_config( @@ -294,11 +337,14 @@ def create_attention_model( head_dim, attn_backend=backend, quant_attention_config=quant_attention_config, + vsa_sparsity=vsa_sparsity, ) model = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=config).to( self.device ) model.eval() + if init_std is not None: + _init_attention_weights(model, init_std) return model def create_cross_attention_model( @@ -434,9 +480,16 @@ def benchmark_single( backend: str, verbose: bool = True, quant_attention_config: "QuantAttentionConfig | None" = None, + vsa_sparsity: "float | None" = None, + latent_shape: "Tuple[int, int, int] | None" = None, + init_std: "float | None" = None, ) -> Optional[Dict]: """Benchmark a single configuration. + For the VSA path, pass ``vsa_sparsity`` and the 3-D ``latent_shape`` + (T, H, W) with ``seq_len == T * H * W``; the forward is run inside a + VSA forward context with a zero ``gate_compress``. + Returns: Dict with timing statistics or None if test failed/skipped """ @@ -452,16 +505,46 @@ def benchmark_single( try: # Create model and data model = self.create_attention_model( - hidden_size, num_heads, head_dim, backend, quant_attention_config + hidden_size, + num_heads, + head_dim, + backend, + quant_attention_config, + vsa_sparsity=vsa_sparsity, + init_std=init_std, ) hidden_states, freqs = self.create_test_data(batch_size, seq_len, hidden_size, head_dim) + # VSA needs an active forward context + a zero gate_compress; other + # backends use neither. + vsa_metadata = None + gate_compress = None + if vsa_sparsity is not None: + vsa_metadata = VSAMetadataBuilder().build( + current_timestep=0, + raw_latent_shape=latent_shape, + patch_size=(1, 1, 1), + vsa_sparsity=vsa_sparsity, + device=self.device, + ) + gate_compress = torch.zeros_like(hidden_states) + + def _forward(): + ctx = ( + set_vsa_forward_context(vsa_metadata) + if vsa_metadata is not None + else nullcontext() + ) + kwargs = {"gate_compress": gate_compress} if gate_compress is not None else {} + with ctx: + return model(hidden_states, freqs=freqs, **kwargs) + # Warmup with nvtx_range(f"warmup_{backend}"): with torch_profiler_range(f"warmup_{backend}"): with torch.no_grad(): for _ in range(self.warmup_iterations): - _ = model(hidden_states, freqs=freqs) + _ = _forward() if self.device.type == "cuda": torch.cuda.synchronize() @@ -474,7 +557,7 @@ def benchmark_single( for i in range(self.benchmark_iterations): with nvtx_range(f"iter_{backend}_{i}"): with cuda_timer(self.device) as get_time: - _ = model(hidden_states, freqs=freqs) + _ = _forward() times.append(get_time()) # Statistics @@ -935,6 +1018,223 @@ def test_fa4_cross_attn_quick( assert result["avg_ms"] > 0 +# ============================================================================ +# VSA fine-stage kernel vs FA4 (kernel-level) +# ============================================================================ + + +class TestVsaVsFa4KernelPerformance: + """Kernel-level VSA fine-stage vs FA4 benchmark. + + Times the two kernels directly (no QKV proj / norm / gate / tile-untile) so + the comparison reflects only the attention math. + """ + + @pytest.fixture(autouse=True) + def setup(self): + if not _flash_attn4_available: + pytest.skip( + "FlashAttention 4 not available" + + (f": {_fa4_import_error}" if _fa4_import_error else "") + ) + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable (VSA CuTe path)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + self.warmup = 10 + self.iters = 50 + self._cute_fn = block_sparse_attn_from_indices_cute + self._is_cute_supported = is_cute_supported + + @pytest.mark.parametrize( + "batch,num_heads,seq_len,head_dim,block_size,sparsity", + [ + (1, 40, 131_072, 128, 64, 0.5), + ], + ids=["wan_BS1_H40_S131k_D128_blk64_s0.5"], + ) + def test_vsa_kernel_vs_fa4( + self, + batch: int, + num_heads: int, + seq_len: int, + head_dim: int, + block_size: int, + sparsity: float, + ): + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import VSA_KERNEL_MAX_CUBES + + assert seq_len % block_size == 0, "seq_len must be a multiple of block_size" + num_cubes = seq_len // block_size + if num_cubes > VSA_KERNEL_MAX_CUBES: + pytest.skip( + f"num_cubes={num_cubes} exceeds VSA_KERNEL_MAX_CUBES={VSA_KERNEL_MAX_CUBES}" + ) + topk = max(1, int(round((1.0 - sparsity) * num_cubes))) + + # FA4 expects [B, S, H, D]; VSA CuTe expects [B, H, S, D]. + q_nhd = torch.randn( + batch, seq_len, num_heads, head_dim, device=self.device, dtype=self.dtype + ) + k_nhd = torch.randn_like(q_nhd) + v_nhd = torch.randn_like(q_nhd) + q_bhsd = q_nhd.transpose(1, 2).contiguous() + k_bhsd = k_nhd.transpose(1, 2).contiguous() + v_bhsd = v_nhd.transpose(1, 2).contiguous() + + if not self._is_cute_supported(q_bhsd): + pytest.skip("VSA CuTe path requires sm_100+ Blackwell + bf16/fp16 + head_dim=128") + + # Group2QInterleaveKV requires paired Q-blocks (2i, 2i+1) to share KV + # selections. A constant index set across all Q-blocks trivially satisfies that. + q2k_idx = ( + torch.arange(topk, device=self.device, dtype=torch.int32) + .view(1, 1, 1, topk) + .expand(batch, num_heads, num_cubes, topk) + .contiguous() + ) + q2k_num = torch.full( + (batch, num_heads, num_cubes), topk, dtype=torch.int32, device=self.device + ) + variable_block_sizes = torch.full( + (num_cubes,), block_size, dtype=torch.int32, device=self.device + ) + + softmax_scale = head_dim**-0.5 + + def fa4_call(): + _fa4_fwd( + q_nhd, + k_nhd, + v_nhd, + softmax_scale=softmax_scale, + causal=False, + window_size_left=None, + window_size_right=None, + learnable_sink=None, + softcap=0.0, + pack_gqa=None, + mask_mod=None, + block_sparse_tensors=None, + return_lse=True, + ) + + def vsa_call(): + self._cute_fn( + q_bhsd, + k_bhsd, + v_bhsd, + q2k_idx, + q2k_num, + variable_block_sizes, + ) + + with torch.no_grad(): + for _ in range(self.warmup): + fa4_call() + vsa_call() + torch.cuda.synchronize() + + with nvtx_range("bench_fa4"): + fa4_times = bench_fn(self.device, fa4_call, self.iters) + with nvtx_range("bench_vsa_fine"): + vsa_times = bench_fn(self.device, vsa_call, self.iters) + + fa4_avg = fa4_times.mean().item() + vsa_avg = vsa_times.mean().item() + speedup = fa4_avg / vsa_avg + + print( + f"\n VSA vs FA4 (BS={batch}, H={num_heads}, S={seq_len}, " + f"D={head_dim}, blk={block_size}, sparsity={sparsity:.2f}, " + f"topk={topk}/{num_cubes}):" + ) + print(f" FA4 (dense): avg={fa4_avg:.3f}ms") + print(f" VSA (fine-stage): avg={vsa_avg:.3f}ms") + print(f" VSA speedup vs FA4: {speedup:.2f}x ({'faster' if speedup > 1 else 'slower'})") + assert fa4_avg > 0 and vsa_avg > 0 + + +# ============================================================================ +# VSA vs VANILLA at Wan 2.2 T2V 14B production shape — module-level +# ============================================================================ + + +class TestVsaVsVanillaWan22T2v14bModulePerformance: + """Module-level VSA vs VANILLA at Wan 2.2 T2V 14B 720p / 81-frame shape. + + Drives the shared WanAttentionPerformanceBenchmark engine so VSA is timed + and reported exactly like the VANILLA/TRTLLM/FA4 backends. + """ + + _BATCH = 1 + _NUM_HEADS = 40 + _HEAD_DIM = 128 + _LATENT_SHAPE = (21, 45, 80) # DiT seq shape after VAE + patchify + _SEQ_LEN = 21 * 45 * 80 # 75_600 + + @pytest.fixture(autouse=True) + def setup(self): + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable (VSA CuTe path)") + self._is_cute_supported = is_cute_supported + self.benchmark = WanAttentionPerformanceBenchmark( + warmup_iterations=5, benchmark_iterations=20 + ) + + @pytest.mark.parametrize( + "sparsity", + [0.5, 0.75, 0.85, 0.875], + ids=["s0.5", "s0.75", "s0.85", "s0.875"], + ) + def test_vsa_module_vs_vanilla_wan22_t2v_14b(self, sparsity: float): + bench = self.benchmark + probe_q = torch.empty( + self._BATCH, + self._NUM_HEADS, + 64, + self._HEAD_DIM, + device=bench.device, + dtype=bench.dtype, + ) + if not self._is_cute_supported(probe_q): + pytest.skip("VSA CuTe path requires sm_100+ Blackwell + bf16/fp16 + head_dim=128") + + common = dict( + batch_size=self._BATCH, + num_heads=self._NUM_HEADS, + seq_len=self._SEQ_LEN, + head_dim=self._HEAD_DIM, + init_std=0.02, + verbose=False, + ) + vanilla = bench.benchmark_single(backend="VANILLA", **common) + vsa = bench.benchmark_single( + backend="CUTEDSL", vsa_sparsity=sparsity, latent_shape=self._LATENT_SHAPE, **common + ) + assert vanilla is not None and vsa is not None, "benchmark returned None (skipped/failed)" + assert vanilla["avg_ms"] > 0 and vsa["avg_ms"] > 0 + + speedup = vanilla["avg_ms"] / vsa["avg_ms"] + print( + f"\n VSA vs VANILLA (Wan 2.2 T2V 14B 720p/81f, latent={self._LATENT_SHAPE}, " + f"S={self._SEQ_LEN}, sparsity={sparsity:.3f}): " + f"VANILLA={vanilla['avg_ms']:.3f}ms, VSA={vsa['avg_ms']:.3f}ms, " + f"speedup={speedup:.2f}x ({'faster' if speedup > 1 else 'slower'})" + ) + + # ============================================================================ # Main entry point # ============================================================================ From fe5c812737c4d23fd2e9ffe8429c25cd4562c79d Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:41:02 -0700 Subject: [PATCH 02/14] update default flow shift Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../_torch/visual_gen/models/wan/pipeline_wan.py | 11 ++++++++--- .../_torch/visual_gen/models/wan/pipeline_wan_i2v.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index 6f5bb0a0ffb0..cf7164f5a14f 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -499,12 +499,17 @@ def forward( resolved_flow_shift = ( flow_shift if flow_shift is not None else self._default_flow_shift(height, width) ) - if self.scheduler.config.shift != resolved_flow_shift: + + sched_cfg = self.scheduler.config + shift_key = ( + "shift" if "shift" in sched_cfg else "flow_shift" if "flow_shift" in sched_cfg else None + ) + if shift_key is not None and sched_cfg[shift_key] != resolved_flow_shift: logger.info( - f"flow_shift: {self.scheduler.config.shift} -> {resolved_flow_shift} " + f"flow_shift: {sched_cfg[shift_key]} -> {resolved_flow_shift} " f"({'user' if flow_shift is not None else 'variant default'})" ) - self.scheduler.config.shift = resolved_flow_shift + self.scheduler.register_to_config(**{shift_key: resolved_flow_shift}) self.scheduler.set_timesteps(num_inference_steps, device=self.device) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index b7402d0b3630..94204e378331 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -550,12 +550,17 @@ def forward( resolved_flow_shift = ( flow_shift if flow_shift is not None else self._default_flow_shift(height, width) ) - if self.scheduler.config.shift != resolved_flow_shift: + + sched_cfg = self.scheduler.config + shift_key = ( + "shift" if "shift" in sched_cfg else "flow_shift" if "flow_shift" in sched_cfg else None + ) + if shift_key is not None and sched_cfg[shift_key] != resolved_flow_shift: logger.info( - f"flow_shift: {self.scheduler.config.shift} -> {resolved_flow_shift} " + f"flow_shift: {sched_cfg[shift_key]} -> {resolved_flow_shift} " f"({'user' if flow_shift is not None else 'variant default'})" ) - self.scheduler.config.shift = resolved_flow_shift + self.scheduler.register_to_config(**{shift_key: resolved_flow_shift}) self.scheduler.set_timesteps(num_inference_steps, device=self.device) From 5426f8cea7fa72cb52bba12f42fe2a59b706acd3 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Mon, 1 Jun 2026 16:44:21 -0700 Subject: [PATCH 03/14] small updates, mostly documentation Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/attention_backend/cute_dsl.py | 29 +++++---- .../visual_gen/models/flux/pipeline_flux.py | 4 +- .../visual_gen/models/flux/pipeline_flux2.py | 4 +- .../visual_gen/models/ltx2/pipeline_ltx2.py | 4 +- .../visual_gen/models/wan/pipeline_wan_i2v.py | 2 +- .../visual_gen/models/wan/transformer_wan.py | 61 ++++++------------- .../_torch/visual_gen/pipeline_loader.py | 5 +- 7 files changed, 47 insertions(+), 62 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py index e96781b345e8..c173bffd37d5 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -CuTe DSL Backend for Visual Generation Models +CuTe DSL (NVIDIA kernels) Backend for Visual Generation Models CuTeDSLAttention runs the VSA sparse path when sparse_attention_config is set, otherwise the dense cubin path (with optional QK16PV8 quantization). @@ -47,6 +47,17 @@ cute_dsl = None _cute_dsl_import_error = e +_vsa_import_error = None +try: + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) +except (ImportError, OSError) as e: + block_sparse_attn_from_indices_cute = None + is_cute_supported = None + _vsa_import_error = e + # VSA (Video Sparse Attention) sparse-path helpers @@ -267,8 +278,9 @@ def __init__( skip_softmax_threshold_scale: Optional[float] = None, **kwargs, ): - # Dense path requires head_dim=128 (packaged cubins); the VSA sparse - # path JIT-compiles per shape, so it has no such restriction. + # Dense cubin path is head_dim=128-only (packaged cubins), so enforce it + # here. The VSA path needs no check: it is gated at runtime by + # is_cute_supported and falls back to dense SDPA when head_dim != 128. if sparse_attention_config is None and head_dim != 128: raise ValueError(f"CUTEDSL cubins require head_dim=128, got head_dim={head_dim}.") self.layer_idx = layer_idx @@ -479,13 +491,6 @@ def _forward_vsa( Returns: [B, S, H, D] in the same original token order. """ - # Lazy import: the VSA kernels package is optional and may not be - # importable in environments without the cute-dsl runtime. - from ..cute_dsl_kernels.blackwell.video_sparse_attention import ( - block_sparse_attn_from_indices_cute, - is_cute_supported, - ) - if gate_compress is None: raise ValueError( "CuTeDSLAttention VSA path requires gate_compress. " @@ -519,7 +524,9 @@ def _forward_vsa( attn_probs_c = scores_c.softmax(dim=-1) o_c = torch.einsum("bhnm,bmhd->bnhd", attn_probs_c, v_c) - use_cute = is_cute_supported(q) and (q.dtype == k.dtype == v.dtype) + use_cute = ( + _vsa_import_error is None and is_cute_supported(q) and (q.dtype == k.dtype == v.dtype) + ) topk_indices = attn_probs_c.topk(cur_topk, dim=-1).indices.to(torch.int32) o_c_tiled = ( diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py index 930414d840f4..bd21aba0b6f2 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py @@ -63,8 +63,8 @@ def __init__(self, model_config): if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": raise ValueError( "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " - "only supported by Wan T2V pipelines. Remove sparse_attention_config " - "for FLUX." + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " + "sparse_attention_config for FLUX." ) super().__init__(model_config) diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index 874a3edeed45..fabc8bd3ef19 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -132,8 +132,8 @@ def __init__(self, model_config): if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": raise ValueError( "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " - "only supported by Wan T2V pipelines. Remove sparse_attention_config " - "for FLUX.2." + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " + "sparse_attention_config for FLUX.2." ) super().__init__(model_config) diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index 6cd874997bab..a0d722a1277d 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -604,8 +604,8 @@ def __init__(self, model_config): if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": raise ValueError( "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " - "only supported by Wan T2V pipelines. Remove sparse_attention_config " - "for LTX-2." + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " + "sparse_attention_config for LTX-2." ) super().__init__(model_config) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index 94204e378331..d6162af4d7ac 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -105,7 +105,7 @@ def __init__(self, model_config): if _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa": raise ValueError( "Video Sparse Attention (sparse_attention_config.algorithm='vsa') is " - "only supported by Wan T2V pipelines (Wan 2.1 and Wan 2.2). Remove " + "only supported by the Wan 2.1 T2V 14B (720P) pipeline. Remove " "sparse_attention_config for Wan I2V." ) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 3eb08b415ad3..5162c49962c5 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -246,6 +246,19 @@ def forward( return temb, temb_proj, encoder_hidden_states, encoder_hidden_states_image +def _default_vsa_gate(linear: Linear, bias_value: float) -> None: + """Default a VSA gate Linear absent from the checkpoint: weight=0, bias=bias_value + (G_c=0 / G_f=1, preserving dense behavior at sparsity=0).""" + if not linear._weights_created: + linear.create_weights() + if linear.weight.is_meta: + return + with torch.no_grad(): + linear.weight.zero_() + if linear.bias is not None: + linear.bias.fill_(bias_value) + + class WanBlock(nn.Module): def __init__( self, @@ -417,32 +430,6 @@ def __init__( torch.empty(1, 6, hidden_size).normal_(std=hidden_size**-0.5) ) - def init_gate_compress_zero(self) -> None: - """Zero-initialize to_gate_compress.""" - if self.to_gate_compress is None: - return - if not self.to_gate_compress._weights_created: - self.to_gate_compress.create_weights() - if self.to_gate_compress.weight.is_meta: - return - with torch.no_grad(): - self.to_gate_compress.weight.zero_() - if self.to_gate_compress.bias is not None: - self.to_gate_compress.bias.zero_() - - def init_gate_fine_default(self) -> None: - """Initialize to_gate_fine to emit constant 1 (weight=0, bias=1).""" - if self.to_gate_fine is None: - return - if not self.to_gate_fine._weights_created: - self.to_gate_fine.create_weights() - if self.to_gate_fine.weight.is_meta: - return - with torch.no_grad(): - self.to_gate_fine.weight.zero_() - if self.to_gate_fine.bias is not None: - self.to_gate_fine.bias.fill_(1.0) - def forward( self, x, @@ -839,7 +826,6 @@ def load_weights(self, weights: dict) -> None: } loader = DynamicLinearWeightLoader(self.model_config, params_map=params_map) - loaded_linears: set = set() for name, module in tqdm(self.named_modules(), desc="Loading weights"): if len(module._parameters) == 0: continue @@ -849,7 +835,12 @@ def load_weights(self, weights: dict) -> None: if weight_dicts: loader.load_linear_weights(module, name, weight_dicts) - loaded_linears.add(name) + # VSA gates absent from the checkpoint default to G_c=0 / G_f=1 + # (dense behavior at sparsity=0). + elif name.endswith(".to_gate_compress"): + _default_vsa_gate(module, 0.0) + elif name.endswith(".to_gate_fine"): + _default_vsa_gate(module, 1.0) elif "add_k_proj" in name or "add_v_proj" in name: logger.info(f"[Weight Loading] No weights found for I2V module: {name}") elif isinstance(module, RMSNormTPAware): @@ -863,20 +854,6 @@ def load_weights(self, weights: dict) -> None: module_weights[param_name].to(self.model_config.torch_dtype) ) - # Default any VSA gates not loaded from the checkpoint: G_c=0, G_f=1 - # (preserves dense behavior at sparsity=0). - for name, module in self.named_modules(): - if not isinstance(module, WanBlock): - continue - if module.to_gate_compress is not None: - gate_path = f"{name}.to_gate_compress" if name else "to_gate_compress" - if gate_path not in loaded_linears: - module.init_gate_compress_zero() - if module.to_gate_fine is not None: - gate_path = f"{name}.to_gate_fine" if name else "to_gate_fine" - if gate_path not in loaded_linears: - module.init_gate_fine_default() - def post_load_weights(self) -> None: """Call post_load_weights on all Linear modules and convert embedders to target dtype.""" # Convert condition_embedder components to target dtype diff --git a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py index 0098ff57f7f4..ec24e2a839da 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline_loader.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline_loader.py @@ -23,6 +23,9 @@ from tensorrt_llm._torch.autotuner import autotune from tensorrt_llm._torch.models.modeling_utils import MetaInitMode +from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, +) from tensorrt_llm.llmapi.utils import download_hf_model from tensorrt_llm.logger import logger from tensorrt_llm.visual_gen.args import VisualGenArgs @@ -225,8 +228,6 @@ def load( and _sa_cfg is not None and getattr(_sa_cfg, "algorithm", None) == "vsa" ): - from .cute_dsl_kernels.blackwell.video_sparse_attention import CUTE_AVAILABLE - kernel_path = "CuTe DSL block-sparse" if CUTE_AVAILABLE else "dense SDPA fallback" logger.info( f"Attention backend: CUTEDSL (algorithm=vsa, " From d4cf5b78e56f0b7f8b771d8f895a8e2ecd05f9ab Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:48:06 -0700 Subject: [PATCH 04/14] flow_shift update: no override unless value provided Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/models/wan/pipeline_wan.py | 36 +++++++------------ .../visual_gen/models/wan/pipeline_wan_i2v.py | 33 +++++++---------- 2 files changed, 25 insertions(+), 44 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index cf7164f5a14f..8e991eb014e4 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -381,15 +381,6 @@ def infer(self, req): flow_shift=req.params.flow_shift, ) - def _default_flow_shift(self, height: int, width: int) -> float: - """Recommended flow_shift for the active Wan variant + resolution.""" - - if self.is_wan22_14b: - return 12.0 # Wan2.2 T2V A14B - if self.is_wan22_5b: - return 5.0 # Wan2.2 TI2V 5B - return 5.0 if max(height, width) >= 1280 else 3.0 # Wan2.1 T2V (720P vs 480P) - @nvtx_range("WanPipeline.forward") @torch.no_grad() def forward( @@ -495,21 +486,20 @@ def forward( latents = self._prepare_latents(batch_size, height, width, num_frames, generator) logger.debug(f"Latents shape: {latents.shape}") - # Resolve flow_shift: user override wins, else the per-variant recommended default. - resolved_flow_shift = ( - flow_shift if flow_shift is not None else self._default_flow_shift(height, width) - ) - - sched_cfg = self.scheduler.config - shift_key = ( - "shift" if "shift" in sched_cfg else "flow_shift" if "flow_shift" in sched_cfg else None - ) - if shift_key is not None and sched_cfg[shift_key] != resolved_flow_shift: - logger.info( - f"flow_shift: {sched_cfg[shift_key]} -> {resolved_flow_shift} " - f"({'user' if flow_shift is not None else 'variant default'})" + # Apply an explicit user flow_shift override; otherwise keep the checkpoint + # scheduler default so output matches the reference HuggingFace pipeline. + if flow_shift is not None: + sched_cfg = self.scheduler.config + shift_key = ( + "shift" + if "shift" in sched_cfg + else "flow_shift" + if "flow_shift" in sched_cfg + else None ) - self.scheduler.register_to_config(**{shift_key: resolved_flow_shift}) + if shift_key is not None and sched_cfg[shift_key] != flow_shift: + logger.info(f"flow_shift: {sched_cfg[shift_key]} -> {flow_shift} (user)") + self.scheduler.register_to_config(**{shift_key: flow_shift}) self.scheduler.set_timesteps(num_inference_steps, device=self.device) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index d6162af4d7ac..0efaa98acbba 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -415,13 +415,6 @@ def infer(self, req): flow_shift=req.params.flow_shift, ) - def _default_flow_shift(self, height: int, width: int) -> float: - """Recommended flow_shift for the active Wan I2V variant + resolution.""" - - if self.is_wan22_14b: - return 5.0 # Wan2.2 I2V A14B - return 5.0 if max(height, width) >= 1280 else 3.0 # Wan2.1 I2V (720P vs 480P) - @torch.no_grad() def forward( self, @@ -546,21 +539,19 @@ def forward( batch_size, image, height, width, num_frames, generator, last_image ) - # Resolve flow_shift: user override wins, else the per-variant default. - resolved_flow_shift = ( - flow_shift if flow_shift is not None else self._default_flow_shift(height, width) - ) - - sched_cfg = self.scheduler.config - shift_key = ( - "shift" if "shift" in sched_cfg else "flow_shift" if "flow_shift" in sched_cfg else None - ) - if shift_key is not None and sched_cfg[shift_key] != resolved_flow_shift: - logger.info( - f"flow_shift: {sched_cfg[shift_key]} -> {resolved_flow_shift} " - f"({'user' if flow_shift is not None else 'variant default'})" + # Apply an explicit user flow_shift override; otherwise keep the checkpoint scheduler default. + if flow_shift is not None: + sched_cfg = self.scheduler.config + shift_key = ( + "shift" + if "shift" in sched_cfg + else "flow_shift" + if "flow_shift" in sched_cfg + else None ) - self.scheduler.register_to_config(**{shift_key: resolved_flow_shift}) + if shift_key is not None and sched_cfg[shift_key] != flow_shift: + logger.info(f"flow_shift: {sched_cfg[shift_key]} -> {flow_shift} (user)") + self.scheduler.register_to_config(**{shift_key: flow_shift}) self.scheduler.set_timesteps(num_inference_steps, device=self.device) From 7eb5039431f0b8549cf9770f4e3ee42b921ef415 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 3 Jun 2026 11:29:22 -0700 Subject: [PATCH 05/14] address CodeRabbit comments Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/attention_backend/cute_dsl.py | 23 +++++++++---------- .../block_sparse_attn_dsl_fwd.py | 4 +++- .../video_sparse_attention/interface.py | 9 ++++++++ .../visual_gen/models/wan/pipeline_wan.py | 13 ++++++++--- .../visual_gen/models/wan/pipeline_wan_i2v.py | 13 ++++++++--- .../visual_gen/models/wan/transformer_wan.py | 5 ++++ .../_torch/visual_gen/modules/attention.py | 2 +- .../visual_gen/test_attention_integration.py | 5 ++++ .../_torch/visual_gen/test_attention_perf.py | 17 ++++++++++++++ 9 files changed, 71 insertions(+), 20 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py index c173bffd37d5..92f85f6a387d 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py @@ -20,6 +20,7 @@ Expects NHD layout ([B, S, H, D]) and supports float16/bfloat16. """ +import contextvars import math from contextlib import contextmanager from dataclasses import dataclass @@ -198,22 +199,22 @@ def build( ) -_vsa_forward_context: Optional[VSAMetadata] = None +_vsa_forward_context_var: contextvars.ContextVar[Optional[VSAMetadata]] = contextvars.ContextVar( + "_vsa_forward_context", default=None +) @contextmanager def set_vsa_forward_context(metadata: VSAMetadata): - global _vsa_forward_context - prev = _vsa_forward_context - _vsa_forward_context = metadata + token = _vsa_forward_context_var.set(metadata) try: yield finally: - _vsa_forward_context = prev + _vsa_forward_context_var.reset(token) def get_vsa_forward_context() -> Optional[VSAMetadata]: - return _vsa_forward_context + return _vsa_forward_context_var.get(None) def _mean_pool_cubes( @@ -525,7 +526,10 @@ def _forward_vsa( o_c = torch.einsum("bhnm,bmhd->bnhd", attn_probs_c, v_c) use_cute = ( - _vsa_import_error is None and is_cute_supported(q) and (q.dtype == k.dtype == v.dtype) + _vsa_import_error is None + and is_cute_supported(q) + and (q.dtype == k.dtype == v.dtype) + and num_cubes <= VSA_KERNEL_MAX_CUBES ) topk_indices = attn_probs_c.topk(cur_topk, dim=-1).indices.to(torch.int32) @@ -534,11 +538,6 @@ def _forward_vsa( ) if use_cute: - assert num_cubes <= VSA_KERNEL_MAX_CUBES, ( - f"VSA CuTe kernel supports at most {VSA_KERNEL_MAX_CUBES} cubes " - f"(SMEM-allocated variable_block_sizes buffer); got num_cubes={num_cubes}. " - "Lower video resolution/length or fall back to dense SDPA." - ) q_hnd = q_t.transpose(1, 2).contiguous() k_hnd = k_t.transpose(1, 2).contiguous() v_hnd = v_t.transpose(1, 2).contiguous() diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py index 3bae813fc7f1..f0d548992c2a 100644 --- a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py @@ -62,6 +62,8 @@ class VideoSparseAttentionForwardGroup2QInterleaveKV: Q1 0 S1 O1 """ + MAX_INDICES = 4 * 1024 + def __init__( self, block_m: int, @@ -413,7 +415,7 @@ def __call__( (2, self.mma_tiler_qk[0], self.scale_buffers), (0, 1, 2) ) - self.max_indices = 4 * 1024 + self.max_indices = self.MAX_INDICES @cute.struct class SharedStorage: diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py index e590234f45e7..13a9e08d1d5e 100644 --- a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/interface.py @@ -112,6 +112,15 @@ def block_sparse_attn_from_indices_cute( "cutlass-dsl is not importable." ) + num_q_blk = variable_block_sizes.shape[0] + if num_q_blk > VideoSparseAttentionForward.MAX_INDICES: + raise ValueError( + f"variable_block_sizes has {num_q_blk} entries but the CuTe kernel " + f"supports at most {VideoSparseAttentionForward.MAX_INDICES} " + "(SMEM-allocated sVariable_block_sizes buffer). Lower video " + "resolution/length or fall back to dense SDPA." + ) + B, H, T, D = q.shape if sm_scale is None: sm_scale = 1.0 / math.sqrt(D) diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index 8e991eb014e4..4928804a0df5 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -488,8 +488,10 @@ def forward( # Apply an explicit user flow_shift override; otherwise keep the checkpoint # scheduler default so output matches the reference HuggingFace pipeline. + sched_cfg = self.scheduler.config + shift_key = None + orig_shift = None if flow_shift is not None: - sched_cfg = self.scheduler.config shift_key = ( "shift" if "shift" in sched_cfg @@ -498,10 +500,15 @@ def forward( else None ) if shift_key is not None and sched_cfg[shift_key] != flow_shift: - logger.info(f"flow_shift: {sched_cfg[shift_key]} -> {flow_shift} (user)") + orig_shift = sched_cfg[shift_key] + logger.info(f"flow_shift: {orig_shift} -> {flow_shift} (user)") self.scheduler.register_to_config(**{shift_key: flow_shift}) - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + try: + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + finally: + if orig_shift is not None: + self.scheduler.register_to_config(**{shift_key: orig_shift}) # Wan2.2 A14B: Calculate boundary timestep for two-stage denoising boundary_timestep = None diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index 0efaa98acbba..8d41ec8df1da 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -540,8 +540,10 @@ def forward( ) # Apply an explicit user flow_shift override; otherwise keep the checkpoint scheduler default. + sched_cfg = self.scheduler.config + shift_key = None + orig_shift = None if flow_shift is not None: - sched_cfg = self.scheduler.config shift_key = ( "shift" if "shift" in sched_cfg @@ -550,10 +552,15 @@ def forward( else None ) if shift_key is not None and sched_cfg[shift_key] != flow_shift: - logger.info(f"flow_shift: {sched_cfg[shift_key]} -> {flow_shift} (user)") + orig_shift = sched_cfg[shift_key] + logger.info(f"flow_shift: {orig_shift} -> {flow_shift} (user)") self.scheduler.register_to_config(**{shift_key: flow_shift}) - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + try: + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + finally: + if orig_shift is not None: + self.scheduler.register_to_config(**{shift_key: orig_shift}) # Wan2.2: Calculate boundary timestep for two-stage denoising boundary_timestep = None diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 5162c49962c5..a87583ef3363 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -368,6 +368,7 @@ def __init__( ) if _is_vsa: q_dim = num_heads * head_dim + gate_tp_mode = TensorParallelMode.COLUMN if tp_size > 1 else None self.to_gate_compress = Linear( hidden_size, q_dim, @@ -377,6 +378,8 @@ def __init__( quant_config=quant_config, skip_create_weights_in_init=skip_create_weights, force_dynamic_quantization=force_dynamic_quant, + tensor_parallel_mode=gate_tp_mode, + reduce_output=False, ) self.to_gate_fine = Linear( hidden_size, @@ -387,6 +390,8 @@ def __init__( quant_config=quant_config, skip_create_weights_in_init=skip_create_weights, force_dynamic_quantization=force_dynamic_quant, + tensor_parallel_mode=gate_tp_mode, + reduce_output=False, ) # I2V: Additional K/V projections for image embeddings. diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index 48cdb4004766..2820095d9a47 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -455,7 +455,7 @@ def _attn_impl( seq_len_kv = k.shape[1] if k is not None else seq_len def _reshape_gate(gate: torch.Tensor) -> torch.Tensor: - gate = gate.view(batch_size, -1, self.num_attention_heads, self.head_dim) + gate = gate.view(batch_size, -1, self.local_num_attention_heads, self.head_dim) if backend_layout == AttentionTensorLayout.HND: gate = gate.transpose(1, 2) return gate diff --git a/tests/unittest/_torch/visual_gen/test_attention_integration.py b/tests/unittest/_torch/visual_gen/test_attention_integration.py index bb798fd62fa1..b08d3469e96e 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_integration.py +++ b/tests/unittest/_torch/visual_gen/test_attention_integration.py @@ -698,6 +698,11 @@ def _build_vsa_setup(sparsity: float, batch_size: int, seed: int): integrated = Attention(hidden_size, num_heads, qkv_mode=QKVMode.FUSE_QKV, config=cfg_vsa).to( device ) + # Fail loudly if the VSA path silently fell back to dense (which would set + # attn_backend to "VANILLA") instead of selecting the CUTEDSL/VSA backend. + assert integrated.attn_backend == "CUTEDSL", ( + f"Expected CUTEDSL (VSA) backend, got {integrated.attn_backend!r}" + ) copy_weights_self_attention(naive, integrated) naive.eval() integrated.eval() diff --git a/tests/unittest/_torch/visual_gen/test_attention_perf.py b/tests/unittest/_torch/visual_gen/test_attention_perf.py index 693efcc663bd..763309f3f96d 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_perf.py +++ b/tests/unittest/_torch/visual_gen/test_attention_perf.py @@ -515,6 +515,16 @@ def benchmark_single( ) hidden_states, freqs = self.create_test_data(batch_size, seq_len, hidden_size, head_dim) + # Fail fast if the requested backend silently resolved to something + # else (e.g. CUTEDSL/VSA downgraded to VANILLA dense for an + # unsupported config). + resolved_backend = getattr(model, "attn_backend", backend) + if vsa_sparsity is not None: + assert resolved_backend == "CUTEDSL", ( + f"VSA run requested CUTEDSL but the attention module resolved to " + f"{resolved_backend!r}; refusing to report VSA timings for a dense fallback." + ) + # VSA needs an active forward context + a zero gate_compress; other # backends use neither. vsa_metadata = None @@ -572,6 +582,7 @@ def _forward(): "p99_ms": torch.quantile(times_tensor, 0.99).item(), "times_ms": times, "uses_sage": _is_sage_attention_enabled(model), + "resolved_backend": resolved_backend, } # Calculate throughput (approximate TOPS) @@ -587,6 +598,9 @@ def _forward(): return stats + except AssertionError: + # Backend-resolution / fail-fast checks are genuine test failures. + raise except Exception as e: if verbose: print(f" {backend}: ERROR - {e}") @@ -1224,6 +1238,9 @@ def test_vsa_module_vs_vanilla_wan22_t2v_14b(self, sparsity: float): backend="CUTEDSL", vsa_sparsity=sparsity, latent_shape=self._LATENT_SHAPE, **common ) assert vanilla is not None and vsa is not None, "benchmark returned None (skipped/failed)" + assert vsa["resolved_backend"] == "CUTEDSL", ( + f"VSA timings came from {vsa['resolved_backend']!r}, not the CUTEDSL/VSA path" + ) assert vanilla["avg_ms"] > 0 and vsa["avg_ms"] > 0 speedup = vanilla["avg_ms"] / vsa["avg_ms"] From c30cae8c633ba0493415728bb268247039ac42a9 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Fri, 5 Jun 2026 11:30:25 -0700 Subject: [PATCH 06/14] decouple fmha/vsa interfaces for cutedsl Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/attention_backend/__init__.py | 2 + .../attention_backend/cute_dsl/__init__.py | 46 ++++ .../attention_backend/cute_dsl/fmha.py | 211 ++++++++++++++++ .../{cute_dsl.py => cute_dsl/vsa.py} | 230 ++---------------- .../visual_gen/attention_backend/utils.py | 8 +- .../visual_gen/test_attention_cute_dsl_vsa.py | 2 +- 6 files changed, 290 insertions(+), 209 deletions(-) create mode 100644 tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py create mode 100644 tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py rename tensorrt_llm/_torch/visual_gen/attention_backend/{cute_dsl.py => cute_dsl/vsa.py} (61%) diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py index fbcbc931f0c4..6b5d9b538d8b 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py @@ -23,6 +23,7 @@ from .cute_dsl import ( VSA_TILE_SIZE, CuTeDSLAttention, + VSAAttention, VSAMetadata, VSAMetadataBuilder, get_vsa_forward_context, @@ -42,6 +43,7 @@ "get_visual_gen_attention_backend", "create_attention", "CuTeDSLAttention", + "VSAAttention", "FlashAttn4Attention", "TrtllmAttention", "TrtllmAttentionMetadata", diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py new file mode 100644 index 000000000000..11b8858db56f --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CuTe DSL attention backend family for visual generation models. + + fmha.py — CuTeDSLAttention (dense cubin path, head_dim=128) + vsa.py — VSAAttention (Video Sparse Attention, CuTe JIT + SDPA fallback) + _common.py — shared CuTe DSL plumbing +""" + +from .fmha import CuTeDSLAttention, _cute_dsl_import_error +from .vsa import ( + VSA_KERNEL_MAX_CUBES, + VSA_TILE_SIZE, + VSAAttention, + VSAMetadata, + VSAMetadataBuilder, + VSAPreprocessor, + get_vsa_forward_context, + set_vsa_forward_context, +) + +__all__ = [ + "CuTeDSLAttention", + "VSAAttention", + "VSAMetadata", + "VSAMetadataBuilder", + "VSAPreprocessor", + "VSA_TILE_SIZE", + "VSA_KERNEL_MAX_CUBES", + "set_vsa_forward_context", + "get_vsa_forward_context", + "_cute_dsl_import_error", +] diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py new file mode 100644 index 000000000000..e71ede03f3ca --- /dev/null +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CuTe DSL dense FMHA backend for visual generation models. + +CuTeDSLAttention uses pre-compiled cubin kernels and requires head_dim=128. +Supports float16/bfloat16 with optional QK16PV8 quantization and a +skip-softmax threshold optimisation. Expects NHD layout ([B, S, H, D]). + +For the VSA sparse path use VSAAttention in vsa.py. +""" + +import math +from typing import Optional, Tuple + +import torch + +from tensorrt_llm.visual_gen.args import QuantAttentionConfig + +from ....attention_backend.interface import PredefinedAttentionMask +from ..interface import AttentionBackend, AttentionTensorLayout + +_cute_dsl_import_error = None +try: + import tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.attention as cute_dsl + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.attention.fmha import ( + _cute_runtime_import_error, + ) + + if _cute_runtime_import_error is not None: + raise ImportError(_cute_runtime_import_error) +except (ImportError, OSError) as e: + cute_dsl = None + _cute_dsl_import_error = e + + +class CuTeDSLAttention(AttentionBackend): + """ + CuTe DSL dense FMHA backend for diffusion models. + + Uses pre-compiled cubins and requires head_dim=128. + Supports float16/bfloat16 with optional QK16PV8 quantization (quant_attention_config) + and a skip-softmax threshold optimisation (skip_softmax_threshold_scale). + """ + + def __init__( + self, + layer_idx: int = 0, + num_heads: int = 8, + head_dim: int = 64, + num_kv_heads: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + quant_attention_config: Optional[QuantAttentionConfig] = None, + skip_softmax_threshold_scale: Optional[float] = None, + **kwargs, + ): + if head_dim != 128: + raise ValueError(f"CUTEDSL cubins require head_dim=128, got head_dim={head_dim}.") + self.layer_idx = layer_idx + self.num_heads = num_heads + self.head_dim = head_dim + self.num_kv_heads = num_kv_heads or num_heads + self.dtype = dtype + self.quant_attention_config = quant_attention_config + self.skip_softmax_threshold_scale = skip_softmax_threshold_scale + self.scale = 1.0 / math.sqrt(head_dim) + + def _prepare_inputs( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: PredefinedAttentionMask, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool, torch.dtype]: + """Cast inputs to CuTeDSL-compatible dtype and resolve causal flag.""" + if _cute_dsl_import_error is not None: + raise ImportError( + f"CuTe DSL kernels are not available. Import error: {_cute_dsl_import_error}" + ) from _cute_dsl_import_error + + is_causal = attention_mask == PredefinedAttentionMask.CAUSAL + + # Packaged cubins support float16 and bfloat16 only. + origin_dtype = q.dtype + if q.dtype not in (torch.float16, torch.bfloat16): + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + return q, k, v, is_causal, origin_dtype + + # cute_dsl.cute_dsl_fmha_fwd is already decorated with @torch.compiler.disable + # Allow torch.compile to fuse preceding linear/norm with quantization of V / seq-preprocess + def _fwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + is_causal: bool, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len_q, num_heads, _ = q.shape + _, seq_len_kv, _, value_head_dim = v.shape + out = torch.empty( + batch_size, + seq_len_q, + num_heads, + value_head_dim, + dtype=q.dtype, + device=q.device, + ) + lse = torch.empty( + batch_size, + seq_len_q, + num_heads, + dtype=torch.float32, + device=q.device, + ) + + # Options that instructs quantization of V + scale_v = kwargs.get("scale_v", 1.0) + if self.quant_attention_config is not None: + v_qscale = 448.0 / v.abs().amax().clamp(min=1e-3) + v = (v * v_qscale).to(torch.float8_e4m3fn) + scale_v = scale_v / v_qscale + + # Sequence preproc. + qo_indptr_host = [i * seq_len_q for i in range(batch_size + 1)] + qo_indptr = torch.tensor(qo_indptr_host).to(device=q.device, dtype=torch.int32) + kv_indptr_host = [i * seq_len_kv for i in range(batch_size + 1)] + kv_indptr = torch.tensor(kv_indptr_host).to(device=q.device, dtype=torch.int32) + + # Skip softmax. + skip_softmax_threshold_scale = self.skip_softmax_threshold_scale + if skip_softmax_threshold_scale is not None and skip_softmax_threshold_scale <= 0.0: + skip_softmax_threshold_scale = None + + cute_dsl.cute_dsl_fmha_fwd( + q.flatten(0, 1).contiguous(), + k.flatten(0, 1).contiguous(), + v.flatten(0, 1).contiguous(), + out.flatten(0, 1), + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + is_causal=is_causal, + sm_scale=self.scale, + lse=lse.flatten(0, 1).contiguous(), + scale_q=kwargs.get("scale_q", 1.0), + scale_k=kwargs.get("scale_k", 1.0), + scale_v=scale_v, + scale_o=kwargs.get("scale_o", 1.0), + max_qo_len=seq_len_q, + max_kv_len=seq_len_kv, + is_persistent=False, + skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale, + ) + return out, lse + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL, + **kwargs, + ) -> torch.Tensor: + output, _ = self.forward_with_lse(q, k, v, attention_mask=attention_mask, **kwargs) + return output + + def forward_with_lse( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + output: [batch_size, seq_len, num_heads, head_dim] + lse: [batch_size, num_heads, seq_len] — log-sum-exp in float32. + """ + q, k, v, is_causal, origin_dtype = self._prepare_inputs(q, k, v, attention_mask) + output, lse = self._fwd(q, k, v, is_causal, **kwargs) + if output.dtype != origin_dtype: + output = output.to(origin_dtype) + return output, lse.transpose(1, 2) + + @classmethod + def support_lse(cls) -> bool: + return True + + @property + def preferred_layout(self) -> AttentionTensorLayout: + return AttentionTensorLayout.NHD + + @classmethod + def support_fused_qkv(cls) -> bool: + return False diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py similarity index 61% rename from tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py rename to tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py index 92f85f6a387d..9cd2c8e1b414 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -CuTe DSL (NVIDIA kernels) Backend for Visual Generation Models +Video Sparse Attention (VSA) backend for visual generation models. -CuTeDSLAttention runs the VSA sparse path when sparse_attention_config is set, -otherwise the dense cubin path (with optional QK16PV8 quantization). -Expects NHD layout ([B, S, H, D]) and supports float16/bfloat16. +VSAAttention implements hierarchical sparse attention: + - Coarse branch: mean-pooled cube attention (always dense) + - Fine branch: block-sparse top-K attention via CuTe JIT kernel (sm100+) + or dense SDPA fallback when CuTe is unavailable / head_dim != 128. """ import contextvars -import math from contextlib import contextmanager from dataclasses import dataclass from math import ceil @@ -30,23 +30,7 @@ import torch import torch.nn.functional as F -from tensorrt_llm.visual_gen.args import QuantAttentionConfig - -from ...attention_backend.interface import PredefinedAttentionMask -from .interface import AttentionBackend, AttentionTensorLayout - -_cute_dsl_import_error = None -try: - import tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.attention as cute_dsl - from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.attention.fmha import ( - _cute_runtime_import_error, - ) - - if _cute_runtime_import_error is not None: - raise ImportError(_cute_runtime_import_error) -except (ImportError, OSError) as e: - cute_dsl = None - _cute_dsl_import_error = e +from ..interface import AttentionBackend, AttentionTensorLayout _vsa_import_error = None try: @@ -60,8 +44,6 @@ _vsa_import_error = e -# VSA (Video Sparse Attention) sparse-path helpers - # Must match the Blackwell kernel's block_size expectation. VSA_TILE_SIZE: Tuple[int, int, int] = (4, 4, 4) @@ -258,13 +240,16 @@ def untile( return x.index_select(1, non_pad_index).index_select(1, reverse_tile_partition_indices) -class CuTeDSLAttention(AttentionBackend): +class VSAAttention(AttentionBackend): """ - CuTe DSL (NVIDIA kernels) backend for diffusion models. + Video Sparse Attention (VSA) backend for diffusion models. + + Implements coarse mean-pool + fine block-sparse top-K attention. + The fine branch uses a JIT-compiled CuTe kernel on sm100+ for + head_dim=128 / fp16-bf16; otherwise falls back to dense SDPA. - Dense path uses pre-compiled cubins and requires head_dim=128. The VSA - sparse path (sparse_attention_config set) uses a JIT-compiled CuTe kernel - when head_dim=128 / fp16-bf16 / sm100+, and otherwise falls back to dense SDPA. + Requires an active VSA forward context (set_vsa_forward_context) during + each forward call. Does not support LSE output. """ def __init__( @@ -274,183 +259,15 @@ def __init__( head_dim: int = 64, num_kv_heads: Optional[int] = None, dtype: Optional[torch.dtype] = None, - quant_attention_config: Optional[QuantAttentionConfig] = None, sparse_attention_config=None, - skip_softmax_threshold_scale: Optional[float] = None, **kwargs, ): - # Dense cubin path is head_dim=128-only (packaged cubins), so enforce it - # here. The VSA path needs no check: it is gated at runtime by - # is_cute_supported and falls back to dense SDPA when head_dim != 128. - if sparse_attention_config is None and head_dim != 128: - raise ValueError(f"CUTEDSL cubins require head_dim=128, got head_dim={head_dim}.") self.layer_idx = layer_idx self.num_heads = num_heads self.head_dim = head_dim self.num_kv_heads = num_kv_heads or num_heads self.dtype = dtype - self.quant_attention_config = quant_attention_config self.sparse_attention_config = sparse_attention_config - self.skip_softmax_threshold_scale = skip_softmax_threshold_scale - self.scale = 1.0 / math.sqrt(head_dim) - - # CuTe DSL expects [B, S, H, D] format - self._preferred_layout = AttentionTensorLayout.NHD - - def _prepare_inputs( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attention_mask: PredefinedAttentionMask, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool, torch.dtype]: - """Cast inputs to CuTeDSL-compatible dtype and resolve causal flag.""" - if _cute_dsl_import_error is not None: - raise ImportError( - f"CuTe DSL kernels are not available. Import error: {_cute_dsl_import_error}" - ) from _cute_dsl_import_error - - is_causal = attention_mask == PredefinedAttentionMask.CAUSAL - - # Packaged cubins support float16 and bfloat16 only. - origin_dtype = q.dtype - if q.dtype not in (torch.float16, torch.bfloat16): - q = q.to(torch.bfloat16) - k = k.to(torch.bfloat16) - v = v.to(torch.bfloat16) - return q, k, v, is_causal, origin_dtype - - # cute_dsl.cute_dsl_fmha_fwd is already decorated with @torch.compiler.disable - # Allow torch.compile to fuse preceding linear/norm with quantization of V / seq-preprocess - def _fwd( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - is_causal: bool, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size, seq_len_q, num_heads, _ = q.shape - _, seq_len_kv, _, value_head_dim = v.shape - out = torch.empty( - batch_size, - seq_len_q, - num_heads, - value_head_dim, - dtype=q.dtype, - device=q.device, - ) - lse = torch.empty( - batch_size, - seq_len_q, - num_heads, - dtype=torch.float32, - device=q.device, - ) - - # Options that instructs quantization of V - scale_v = kwargs.get("scale_v", 1.0) - if self.quant_attention_config is not None: - v_qscale = 448.0 / v.abs().amax().clamp(min=1e-3) - v = (v * v_qscale).to(torch.float8_e4m3fn) - scale_v = scale_v / v_qscale - - # Sequence preproc. - qo_indptr_host = [i * seq_len_q for i in range(batch_size + 1)] - qo_indptr = torch.tensor(qo_indptr_host).to(device=q.device, dtype=torch.int32) - kv_indptr_host = [i * seq_len_kv for i in range(batch_size + 1)] - kv_indptr = torch.tensor(kv_indptr_host).to(device=q.device, dtype=torch.int32) - - # Skip softmax. - skip_softmax_threshold_scale = self.skip_softmax_threshold_scale - if skip_softmax_threshold_scale is not None and skip_softmax_threshold_scale <= 0.0: - skip_softmax_threshold_scale = None - - cute_dsl.cute_dsl_fmha_fwd( - q.flatten(0, 1).contiguous(), - k.flatten(0, 1).contiguous(), - v.flatten(0, 1).contiguous(), - out.flatten(0, 1), - qo_indptr=qo_indptr, - kv_indptr=kv_indptr, - is_causal=is_causal, - sm_scale=self.scale, - lse=lse.flatten(0, 1).contiguous(), - scale_q=kwargs.get("scale_q", 1.0), - scale_k=kwargs.get("scale_k", 1.0), - scale_v=scale_v, - scale_o=kwargs.get("scale_o", 1.0), - max_qo_len=seq_len_q, - max_kv_len=seq_len_kv, - is_persistent=False, - skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale, - ) - return out, lse - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - *, - attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL, - gate_compress: Optional[torch.Tensor] = None, - gate_fine: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Forward pass using CuTe DSL (NVIDIA kernels). - - Dimensions are derived from tensor shapes (NHD layout: [B, S, H, D]). - Dispatches to _forward_vsa when sparse_attention_config is set - (VSA sparse path); otherwise runs the dense cubins via forward_with_lse. - - Args: - q: Query tensor [batch_size, seq_len, num_heads, head_dim] - k: Key tensor [batch_size, seq_len_kv, num_kv_heads, head_dim] - v: Value tensor [batch_size, seq_len_kv, num_kv_heads, head_dim] - attention_mask: Attention mask type (CAUSAL or FULL) — dense path only. - gate_compress: VSA path only — G_c gate for the coarse branch. - gate_fine: VSA path only — G_f gate for the fine branch. None means - constant 1. - - Returns: - Output tensor [batch_size, seq_len, num_heads, head_dim] - """ - if self.sparse_attention_config is not None: - return self._forward_vsa(q, k, v, gate_compress=gate_compress, gate_fine=gate_fine) - output, _ = self.forward_with_lse(q, k, v, attention_mask=attention_mask, **kwargs) - return output - - def forward_with_lse( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass returning both output and log-sum-exp (LSE). Dense path - only — the VSA sparse path does not produce an LSE. - - Returns: - output: [batch_size, seq_len, num_heads, head_dim] - lse: [batch_size, num_heads, seq_len] - log-sum-exp per query position, - always in float32. Used for numerically stable combination of - partial attention results in Attention2D parallelism. - """ - if self.sparse_attention_config is not None: - raise RuntimeError( - "CuTeDSLAttention.forward_with_lse() does not support the VSA " - "sparse path. Use forward() instead, or construct without " - "sparse_attention_config to use the dense path." - ) - q, k, v, is_causal, origin_dtype = self._prepare_inputs(q, k, v, attention_mask) - output, lse = self._fwd(q, k, v, is_causal, **kwargs) - if output.dtype != origin_dtype: - output = output.to(origin_dtype) - return output, lse.transpose(1, 2) # Dynamo can't guard on the module-level mutable global, so this read # runs in eager. @@ -459,7 +276,7 @@ def _get_vsa_inputs(self): ctx: Optional[VSAMetadata] = get_vsa_forward_context() if ctx is None: raise RuntimeError( - "CuTeDSLAttention._forward_vsa called without an active VSA forward context. " + "VSAAttention.forward called without an active VSA forward context. " "Wrap each transformer call with set_vsa_forward_context()." ) return ( @@ -472,16 +289,18 @@ def _get_vsa_inputs(self): ctx.vsa_sparsity, ) - def _forward_vsa( + def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *, - gate_compress: Optional[torch.Tensor], - gate_fine: Optional[torch.Tensor], + gate_compress: Optional[torch.Tensor] = None, + gate_fine: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: - """VSA forward: coarse mean-pool + fine block-sparse top-K. + """ + VSA forward: coarse mean-pool + fine block-sparse top-K. Args: q, k, v: [B, S, H, D] in original (un-tiled) token order. @@ -494,7 +313,7 @@ def _forward_vsa( """ if gate_compress is None: raise ValueError( - "CuTeDSLAttention VSA path requires gate_compress. " + "VSAAttention requires gate_compress. " "Ensure to_gate_compress is wired in the transformer block." ) @@ -577,12 +396,11 @@ def _forward_vsa( @classmethod def support_lse(cls) -> bool: - return True + return False @property def preferred_layout(self) -> AttentionTensorLayout: - """Return the preferred tensor layout for this backend.""" - return self._preferred_layout + return AttentionTensorLayout.NHD @classmethod def support_fused_qkv(cls) -> bool: diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py index de0c143e976d..01761cc1f513 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py @@ -129,8 +129,12 @@ def create_attention( ) kwargs["attention_metadata_state"] = attention_metadata_state if backend.upper() == "CUTEDSL" and attention_config is not None: - # CuTeDSLAttention dispatches dense / VSA based on this sub-config. - kwargs["sparse_attention_config"] = attention_config.sparse_attention_config + if attention_config.sparse_attention_config is not None: + # VSA sparse path: use VSAAttention + from .cute_dsl.vsa import VSAAttention + + attn_cls = VSAAttention + kwargs["sparse_attention_config"] = attention_config.sparse_attention_config return attn_cls( layer_idx=layer_idx, diff --git a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py index d079b79ae1e5..500afbe99d15 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py +++ b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py @@ -142,7 +142,7 @@ def test_vsa_topk_collapses_to_dense_at_sparsity_zero(): ) def test_vsa_tile_untile_roundtrip(latent_shape): """VSAPreprocessor.tile then .untile must losslessly reproduce the input.""" - from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import VSAPreprocessor + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl.vsa import VSAPreprocessor device = torch.device("cuda") dtype = torch.bfloat16 From 1856befce60fd3012eef90ea1a7789b996bd7b4f Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Tue, 9 Jun 2026 16:39:41 -0700 Subject: [PATCH 07/14] address code comments Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- docs/source/models/visual-generation.md | 75 ++++- .../attention_backend/cute_dsl/__init__.py | 1 - .../attention_backend/cute_dsl/vsa.py | 7 +- .../block_sparse_attn_dsl_fwd.py | 44 ++- .../visual_gen/models/wan/pipeline_wan.py | 26 +- .../visual_gen/models/wan/pipeline_wan_i2v.py | 25 +- .../_torch/visual_gen/modules/attention.py | 7 + tensorrt_llm/visual_gen/params.py | 7 - .../test_lists/test-db/l0_b200.yml | 1 + .../multi_gpu/test_wan_vsa_ulysses.py | 102 ++++--- .../visual_gen/test_attention_cute_dsl_vsa.py | 117 +++++++- .../visual_gen/test_wan_vsa_pipeline.py | 266 ++++++++++++++++++ 12 files changed, 567 insertions(+), 111 deletions(-) create mode 100644 tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md index b6f125ed027c..0f26d051a781 100644 --- a/docs/source/models/visual-generation.md +++ b/docs/source/models/visual-generation.md @@ -27,6 +27,7 @@ TensorRT-LLM **VisualGen** provides a unified inference stack for diffusion mode | `black-forest-labs/FLUX.2-dev` | Text-to-Image | | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | Text-to-Video | | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | Text-to-Video | +| `FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers` | Text-to-Video (VSA) | | `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` | Image-to-Video | | `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` | Image-to-Video | | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | Text-to-Video | @@ -42,19 +43,22 @@ Models are auto-detected from the checkpoint directory. Diffusers-format models ### Feature Matrix -| Model | FP8 blockwise | NVFP4 | TeaCache | CFG Parallelism | Ulysses Parallelism | Parallel VAE | CUDA Graph | torch.compile | trtllm-serve | Attention2D | Ring Attention | Tensor Parallelism | -|---|---|---|---|---|---|---|---|---|---|--|--|--| -| **FLUX.1** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | -| **FLUX.2** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | -| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| **LTX-2** | Yes | Yes | No | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | -| **Qwen-Image** [^2] | Yes | Yes | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | -| **Cosmos3** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | +| Model | FP8 blockwise | NVFP4 | TeaCache | CFG Parallelism | Ulysses Parallelism | Parallel VAE | CUDA Graph | torch.compile | trtllm-serve | Attention2D | Ring Attention | Tensor Parallelism | VSA | +|---|---|---|---|---|---|---|---|---|---|--|--|--|--| +| **FLUX.1** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **FLUX.2** | Yes | Yes | Yes | No [^1] | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **Wan 2.1** | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **Wan 2.1 VSA** [^3] | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | +| **Wan 2.2** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | +| **LTX-2** | Yes | Yes | No | Yes | Yes | No | No | Yes | Yes | Yes | Yes | No | No | +| **Qwen-Image** [^2] | Yes | Yes | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | +| **Cosmos3** | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | No | Yes | No | [^1]: FLUX models use embedded guidance and do not have a separate negative prompt path, so CFG parallelism is not applicable. -[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. +[^2]: Qwen-Image ships a native BF16 implementation with per-module numerical parity vs `diffusers.QwenImagePipeline` (cosine >= 0.999 on the full 20B transformer) and `trtllm-serve` / `/v1/images/generations` support. FP8 blockwise and NVFP4 use VisualGen dynamic quantization from BF16 checkpoints; no pre-quantized checkpoint is required. + +[^3]: `FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers` — VSA-fine-tuned checkpoint with learned sparse-attention gates. Requires `CUTEDSL` on Blackwell sm_100+ (falls back to dense SDPA on older hardware). Ring and Attention2D not supported (no LSE output); Ulysses supported. ## Quick Start @@ -166,6 +170,57 @@ cache_config: The `teacache_thresh` parameter controls the similarity threshold. Cache-DiT is also supported via `cache_backend: cache_dit` with its own set of knobs (see `CacheDiTConfig`). +### Video Sparse Attention (VSA) + +VSA reduces the compute cost of self-attention in video diffusion models by selectively attending to only the most relevant spatial-temporal blocks. It uses a two-branch design: a lightweight coarse mean-pool branch computes block-level attention scores to identify the top-K most relevant token blocks, then a fine branch runs a block-sparse CuTe kernel over only those blocks. The two outputs are blended with learned gates. + +**Requirements:** +- VSA-fine-tuned checkpoint: [`FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers`](https://huggingface.co/FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers). Standard Wan checkpoints do not have the learned VSA gates. +- Blackwell GPU (sm_100+) for the CuTe JIT kernel. Falls back to dense SDPA on older hardware with no accuracy loss. +- `CUTEDSL` attention backend. +- BF16 or FP16 dtype; `head_dim = 128`; multi-head attention (MHA) — GQA/MQA models are not supported. +- Not compatible with Ring attention or Attention2D (VSA does not produce per-split LSE). Ulysses is supported. + +**`vsa_sparsity`** controls the fraction of K/V blocks skipped in the fine branch (0.0 = dense, 0.9 = 90% blocks skipped). Higher sparsity gives more speedup at the cost of some quality. + +Python API: + +```python +from tensorrt_llm import VisualGenArgs +from tensorrt_llm.visual_gen.args import AttentionConfig, VideoSparseAttentionConfig + +args = VisualGenArgs( + model="FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers", + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=0.9), + ), +) +``` + +YAML (for use with `--visual_gen_args` or `trtllm-serve`): + +```yaml +attention_config: + backend: CUTEDSL + sparse_attention_config: + algorithm: vsa + vsa_sparsity: 0.90 +``` + +Ready-to-use configs are provided in [`examples/visual_gen/configs/`](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/configs): +- `wan-vsa-1gpu.yaml` — single-GPU VSA with NVFP4 quantization +- `wan-vsa-4gpu.yaml` — 4-GPU VSA with Ulysses parallelism +- `wan-vsa-8gpu.yaml` — 8-GPU VSA with Ulysses parallelism + +Run with the existing Wan example script: + +```bash +python examples/visual_gen/models/wan_t2v.py \ + --model FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers \ + --visual_gen_args examples/visual_gen/configs/wan-vsa-1gpu.yaml +``` + ### Multi-GPU Parallelism Configured under `VisualGenArgs.parallel_config`. Modes can be combined: diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py index 11b8858db56f..292be3036d5b 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/__init__.py @@ -17,7 +17,6 @@ fmha.py — CuTeDSLAttention (dense cubin path, head_dim=128) vsa.py — VSAAttention (Video Sparse Attention, CuTe JIT + SDPA fallback) - _common.py — shared CuTe DSL plumbing """ from .fmha import CuTeDSLAttention, _cute_dsl_import_error diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py index 9cd2c8e1b414..741f82097ca2 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/vsa.py @@ -256,7 +256,7 @@ def __init__( self, layer_idx: int = 0, num_heads: int = 8, - head_dim: int = 64, + head_dim: int = 128, num_kv_heads: Optional[int] = None, dtype: Optional[torch.dtype] = None, sparse_attention_config=None, @@ -266,6 +266,11 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.num_kv_heads = num_kv_heads or num_heads + assert self.num_kv_heads == self.num_heads, ( + f"VSA coarse mean-pool assumes MHA (num_kv_heads == num_heads), " + f"got num_kv_heads={self.num_kv_heads}, num_heads={self.num_heads}. " + f"GQA/MQA is not supported." + ) self.dtype = dtype self.sparse_attention_config = sparse_attention_config diff --git a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py index f0d548992c2a..a2d814b2dbc8 100644 --- a/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py +++ b/tensorrt_llm/_torch/visual_gen/cute_dsl_kernels/blackwell/video_sparse_attention/block_sparse_attn_dsl_fwd.py @@ -1535,6 +1535,34 @@ def clamp( local_max[0] = cute.arch.fmax(local_max[0], local_max[1]) return ptx.max3f(local_max[0], local_max[2], local_max[3]) + @cute.jit + def cal_local_max( + self, + tensor: cute.Tensor, + ) -> cutlass.Float32: + # Full-block fast path: no column is masked, so read directly without + # the per-element bound test / -inf write that mask_then_cal_local_max + # performs. Mirrors the masked variant's reduction structure exactly. + if cutlass.const_expr(cute.size(tensor, mode=[0]) < 8): + _max = -cutlass.Float32.inf + for i in cutlass.range_constexpr(0, cute.size(tensor, mode=[0]), 2): + _max = ptx.max3f(_max, tensor[i], tensor[i + 1]) + return _max + else: + local_max = [ + ptx.max3f(tensor[0], tensor[1], -cutlass.Float32.inf), + ptx.max3f(tensor[2], tensor[3], -cutlass.Float32.inf), + ptx.max3f(tensor[4], tensor[5], -cutlass.Float32.inf), + ptx.max3f(tensor[6], tensor[7], -cutlass.Float32.inf), + ] + for i in cutlass.range_constexpr(8, cute.size(tensor, mode=[0]), 8): + local_max[0] = ptx.max3f(local_max[0], tensor[i], tensor[i + 1]) + local_max[1] = ptx.max3f(local_max[1], tensor[i + 2], tensor[i + 3]) + local_max[2] = ptx.max3f(local_max[2], tensor[i + 4], tensor[i + 5]) + local_max[3] = ptx.max3f(local_max[3], tensor[i + 6], tensor[i + 7]) + local_max[0] = cute.arch.fmax(local_max[0], local_max[1]) + return ptx.max3f(local_max[0], local_max[2], local_max[3]) + @cute.jit def cal_local_sum( self, @@ -1621,7 +1649,18 @@ def softmax_step( tCrS_ld, ) # calculate P - _max = self.mask_then_cal_local_max(tCrS_ld, n_size) + # Full-block fast path: when the KV block is full (n_size == tile width, + # i.e. BLOCK==64 at the focus shapes) the variable-block mask masks + # nothing every iteration. Skip the per-element bound test / -inf write + # and take the unmasked local-max path. The masked path is preserved for + # the partial-block case (n_size < tile width) so correctness holds. + # NOTE: predeclare _max before the dynamic branch — CuTe DSL does not + # propagate variables first bound inside dynamic control flow. + _max = -cutlass.Float32.inf + if n_size >= cute.size(tCrS_ld, mode=[0]): + _max = self.cal_local_max(tCrS_ld) + else: + _max = self.mask_then_cal_local_max(tCrS_ld, n_size) _max_safe, _acc_scale = self.update_row_max( max_new=_max, is_first=is_first, @@ -2151,7 +2190,8 @@ def correction( ): tidx = cute.arch.thread_idx()[0] % (self.threads_per_warp * len(self.correction_warp_ids)) in_which = cute.arch.make_warp_uniform(tidx // self.block_m) - corr_ld_inst: int = 32 + # A-tmem: widen correction TMEM<->reg copy 32->64 cols (halve LDTM/STTM count) + corr_ld_inst: int = 64 corr_ld_repeat: int = cute.ceil_div(self.mma_tiler_pv[1], corr_ld_inst) corr_copy_atom_t2r = cute.make_copy_atom( diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index 4928804a0df5..00aadf1fed3b 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -378,7 +378,6 @@ def infer(self, req): seed=req.params.seed, max_sequence_length=req.params.max_sequence_length, image=image, - flow_shift=req.params.flow_shift, ) @nvtx_range("WanPipeline.forward") @@ -397,7 +396,6 @@ def forward( seed: int = 42, max_sequence_length: int = 512, image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None, - flow_shift: Optional[float] = None, ): pipeline_start = time.time() timer = CudaPhaseTimer() @@ -486,29 +484,7 @@ def forward( latents = self._prepare_latents(batch_size, height, width, num_frames, generator) logger.debug(f"Latents shape: {latents.shape}") - # Apply an explicit user flow_shift override; otherwise keep the checkpoint - # scheduler default so output matches the reference HuggingFace pipeline. - sched_cfg = self.scheduler.config - shift_key = None - orig_shift = None - if flow_shift is not None: - shift_key = ( - "shift" - if "shift" in sched_cfg - else "flow_shift" - if "flow_shift" in sched_cfg - else None - ) - if shift_key is not None and sched_cfg[shift_key] != flow_shift: - orig_shift = sched_cfg[shift_key] - logger.info(f"flow_shift: {orig_shift} -> {flow_shift} (user)") - self.scheduler.register_to_config(**{shift_key: flow_shift}) - - try: - self.scheduler.set_timesteps(num_inference_steps, device=self.device) - finally: - if orig_shift is not None: - self.scheduler.register_to_config(**{shift_key: orig_shift}) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # Wan2.2 A14B: Calculate boundary timestep for two-stage denoising boundary_timestep = None diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index 8d41ec8df1da..117b3cf34282 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -412,7 +412,6 @@ def infer(self, req): seed=req.params.seed, max_sequence_length=req.params.max_sequence_length, last_image=last_image, - flow_shift=req.params.flow_shift, ) @torch.no_grad() @@ -431,7 +430,6 @@ def forward( seed: int = 42, max_sequence_length: int = 512, last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None, - flow_shift: Optional[float] = None, ): pipeline_start = time.time() timer = CudaPhaseTimer() @@ -539,28 +537,7 @@ def forward( batch_size, image, height, width, num_frames, generator, last_image ) - # Apply an explicit user flow_shift override; otherwise keep the checkpoint scheduler default. - sched_cfg = self.scheduler.config - shift_key = None - orig_shift = None - if flow_shift is not None: - shift_key = ( - "shift" - if "shift" in sched_cfg - else "flow_shift" - if "flow_shift" in sched_cfg - else None - ) - if shift_key is not None and sched_cfg[shift_key] != flow_shift: - orig_shift = sched_cfg[shift_key] - logger.info(f"flow_shift: {orig_shift} -> {flow_shift} (user)") - self.scheduler.register_to_config(**{shift_key: flow_shift}) - - try: - self.scheduler.set_timesteps(num_inference_steps, device=self.device) - finally: - if orig_shift is not None: - self.scheduler.register_to_config(**{shift_key: orig_shift}) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # Wan2.2: Calculate boundary timestep for two-stage denoising boundary_timestep = None diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index 2820095d9a47..8b06d04ef0e0 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -92,6 +92,7 @@ def __init__( # Select compute backend (orthogonal to parallelism) vgm = config.visual_gen_mapping ulysses_size = vgm.ulysses_size if vgm else 1 + ring_size = vgm.ring_size if vgm else 1 attn2d_size = (vgm.attn2d_row_size * vgm.attn2d_col_size) if vgm else 1 base_backend = config.attention.backend _sa_cfg = config.attention.sparse_attention_config @@ -113,6 +114,12 @@ def __init__( f"with Attention2D (attn2d_size={attn2d_size}). Use ulysses or cfg " f"parallelism instead." ) + if _is_vsa and ring_size > 1: + raise ValueError( + f"VSA needs the full token sequence per rank, so it is incompatible " + f"with Ring attention (ring_size={ring_size}). Use ulysses or cfg " + f"parallelism instead." + ) self.attn_backend = backend_name self.qk_norm = qk_norm self.qk_norm_mode = qk_norm_mode diff --git a/tensorrt_llm/visual_gen/params.py b/tensorrt_llm/visual_gen/params.py index 41fa444ec9a6..87754a56c956 100644 --- a/tensorrt_llm/visual_gen/params.py +++ b/tensorrt_llm/visual_gen/params.py @@ -46,13 +46,6 @@ class VisualGenParams(StrictBaseModel): max_sequence_length: Optional[int] = Field( default=None, description="Max tokens for text encoding." ) - flow_shift: Optional[float] = Field( - default=None, - description=( - "Override the scheduler's flow-matching shift. None = pipeline's " - "per-variant recommended default. Currently honored only by the Wan pipelines." - ), - ) seed: int = Field(default=42, description="Random seed for reproducibility.") # Video diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 3c8db98fb552..02bf1f1d5c04 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -206,6 +206,7 @@ l0_b200: - unittest/_torch/visual_gen/test_wan22_i2v_pipeline.py - unittest/_torch/visual_gen/test_wan22_t2v_pipeline.py - unittest/_torch/visual_gen/test_wan22_ti2v_5b_pipeline.py + - unittest/_torch/visual_gen/test_wan_vsa_pipeline.py - unittest/_torch/visual_gen/test_wan21_i2v_teacache.py - unittest/_torch/visual_gen/test_wan21_t2v_teacache.py - unittest/_torch/visual_gen/test_wan_transformer.py diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py index 4133c3dc1a55..9fb3487fe414 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Multi-GPU tests for VSA + Ulysses sequence parallelism.""" +"""Multi-GPU tests for VSA + Ulysses sequence parallelism (ulysses=2, cfg=2, 4 GPUs).""" -import functools import math import os @@ -35,6 +34,7 @@ VSAMetadataBuilder, set_vsa_forward_context, ) + from tensorrt_llm._torch.visual_gen.mapping import VisualGenMapping from tensorrt_llm._utils import get_free_port from tensorrt_llm.visual_gen.args import VideoSparseAttentionConfig @@ -96,55 +96,73 @@ def run_test_in_distributed(world_size: int, test_fn: Callable): ) -# (8,8,8) latent -> 512 tokens (256/rank at P=2); the (4,4,4) tile gives 8 cubes -# (even, as the paired-block kernel needs) of 64 tokens (the kernel block_size). +_ULYSSES_SIZE = 2 +_CFG_SIZE = 2 +_VSA_SPARSITY = 0.5 + +# (8,8,8) latent -> 512 tokens (256/ulysses-rank at P=2); the (4,4,4) tile gives +# 8 cubes (even, as the paired-block kernel needs) of 64 tokens (kernel block_size). _DIT_SEQ_SHAPE = (8, 8, 8) _VSA_PATCH_SIZE = (1, 1, 1) _HEAD_DIM = 128 # CuTe VSA fine-stage kernel requires head_dim == 128 _HEADS_PER_RANK = 4 -def _make_vsa_backend(num_heads: int, vsa_sparsity: float) -> "CuTeDSLAttention": +def _make_vsa_backend(num_heads: int) -> "CuTeDSLAttention": """CUTEDSL backend on the VSA path; effective sparsity comes from the forward context.""" return CuTeDSLAttention( layer_idx=0, num_heads=num_heads, head_dim=_HEAD_DIM, - sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity), + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=_VSA_SPARSITY), ) -def _build_full_seq_vsa_metadata(vsa_sparsity: float, device: torch.device): +def _build_full_seq_vsa_metadata(device: torch.device): """VSAMetadata for the full sequence — identical on every rank after Ulysses all-to-all.""" builder = VSAMetadataBuilder() return builder.build( current_timestep=0, raw_latent_shape=_DIT_SEQ_SHAPE, patch_size=_VSA_PATCH_SIZE, - vsa_sparsity=vsa_sparsity, + vsa_sparsity=_VSA_SPARSITY, device=device, ) +def _make_vgm(rank: int, world_size: int) -> "VisualGenMapping": + from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl + + DeviceMeshTopologyImpl.device_mesh = None + return VisualGenMapping( + world_size=world_size, + rank=rank, + cfg_size=_CFG_SIZE, + ulysses_size=_ULYSSES_SIZE, + ) + + # ============================================================================= # Test logic functions (module-level so mp.spawn can pickle them) # ============================================================================= -def _logic_vsa_ulysses_forward(rank, world_size, *, vsa_sparsity: float): - """Forward pass: output shape correct and finite.""" +def _logic_vsa_ulysses_forward(rank, world_size): + """Forward pass: output shape correct and finite (ulysses=2, cfg=2).""" + vgm = _make_vgm(rank, world_size) + ulysses_rank = vgm.ulysses_rank + batch = 1 seq_full = math.prod(_DIT_SEQ_SHAPE) - assert seq_full % world_size == 0 - seq_per_rank = seq_full // world_size - num_heads = world_size * _HEADS_PER_RANK + seq_per_rank = seq_full // _ULYSSES_SIZE + num_heads = _ULYSSES_SIZE * _HEADS_PER_RANK device = torch.device(f"cuda:{rank}") torch.manual_seed(42) torch.cuda.manual_seed_all(42) - inner = _make_vsa_backend(num_heads // world_size, vsa_sparsity) - attention = UlyssesAttention(inner_backend=inner, process_group=None) + inner = _make_vsa_backend(num_heads // _ULYSSES_SIZE) + attention = UlyssesAttention(inner_backend=inner, process_group=vgm.ulysses_group) # Ulysses input: sequence-sharded, head-full [B, S/P, H, D]. shape = (batch, seq_per_rank, num_heads, _HEAD_DIM) @@ -154,23 +172,28 @@ def _logic_vsa_ulysses_forward(rank, world_size, *, vsa_sparsity: float): gate_compress = torch.randn(shape, device=device, dtype=torch.bfloat16) gate_fine = torch.randn(shape, device=device, dtype=torch.bfloat16) - metadata = _build_full_seq_vsa_metadata(vsa_sparsity, device) + metadata = _build_full_seq_vsa_metadata(device) with set_vsa_forward_context(metadata): output = attention(q, k, v, gate_compress=gate_compress, gate_fine=gate_fine) assert output.shape == (batch, seq_per_rank, num_heads, _HEAD_DIM), ( - f"Rank {rank}: expected {(batch, seq_per_rank, num_heads, _HEAD_DIM)}, got {output.shape}" + f"Rank {rank} (ulysses={ulysses_rank}, cfg={vgm.cfg_rank}): " + f"expected {(batch, seq_per_rank, num_heads, _HEAD_DIM)}, got {output.shape}" + ) + assert torch.isfinite(output).all(), ( + f"Rank {rank} (ulysses={ulysses_rank}, cfg={vgm.cfg_rank}): Inf/NaN in output" ) - assert torch.isfinite(output).all(), f"Rank {rank}: Inf/NaN in output" -def _logic_vsa_ulysses_vs_reference(rank, world_size, *, vsa_sparsity: float): +def _logic_vsa_ulysses_vs_reference(rank, world_size): """Each rank's Ulysses+VSA output matches the single-GPU VSA reference's sequence slice.""" + vgm = _make_vgm(rank, world_size) + ulysses_rank = vgm.ulysses_rank + batch = 1 seq_full = math.prod(_DIT_SEQ_SHAPE) - assert seq_full % world_size == 0 - seq_per_rank = seq_full // world_size - num_heads = world_size * _HEADS_PER_RANK + seq_per_rank = seq_full // _ULYSSES_SIZE + num_heads = _ULYSSES_SIZE * _HEADS_PER_RANK device = torch.device(f"cuda:{rank}") torch.manual_seed(42) @@ -183,25 +206,25 @@ def _logic_vsa_ulysses_vs_reference(rank, world_size, *, vsa_sparsity: float): gate_c_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) gate_f_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) - sl = slice(rank * seq_per_rank, (rank + 1) * seq_per_rank) + sl = slice(ulysses_rank * seq_per_rank, (ulysses_rank + 1) * seq_per_rank) q_shard = q_full[:, sl].contiguous() k_shard = k_full[:, sl].contiguous() v_shard = v_full[:, sl].contiguous() gate_c_shard = gate_c_full[:, sl].contiguous() gate_f_shard = gate_f_full[:, sl].contiguous() - metadata = _build_full_seq_vsa_metadata(vsa_sparsity, device) + metadata = _build_full_seq_vsa_metadata(device) - # Ulysses path: sharded input, head-sharded inner backend. - inner = _make_vsa_backend(num_heads // world_size, vsa_sparsity) - attention = UlyssesAttention(inner_backend=inner, process_group=None) + # Ulysses path: sequence-sharded input, head-sharded inner backend. + inner = _make_vsa_backend(num_heads // _ULYSSES_SIZE) + attention = UlyssesAttention(inner_backend=inner, process_group=vgm.ulysses_group) with set_vsa_forward_context(metadata): ulysses_out = attention( q_shard, k_shard, v_shard, gate_compress=gate_c_shard, gate_fine=gate_f_shard ) # Single-GPU VSA reference over the full sequence. - ref_attn = _make_vsa_backend(num_heads, vsa_sparsity) + ref_attn = _make_vsa_backend(num_heads) with set_vsa_forward_context(metadata): ref_out = ref_attn.forward( q_full, k_full, v_full, gate_compress=gate_c_full, gate_fine=gate_f_full @@ -216,7 +239,10 @@ def _logic_vsa_ulysses_vs_reference(rank, world_size, *, vsa_sparsity: float): ref_shard.reshape(-1).float(), dim=0, ).item() - assert cos_sim > 0.990, f"Rank {rank}: cosine similarity {cos_sim:.6f} is below threshold 0.990" + assert cos_sim > 0.990, ( + f"Rank {rank} (ulysses={ulysses_rank}, cfg={vgm.cfg_rank}): " + f"cosine similarity {cos_sim:.6f} below threshold 0.990" + ) torch.testing.assert_close(ulysses_out, ref_shard, atol=2e-2, rtol=2e-2) @@ -226,21 +252,19 @@ def _logic_vsa_ulysses_vs_reference(rank, world_size, *, vsa_sparsity: float): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -class TestWanVsaUlysses: - """VSA (CUTEDSL) attention backend combined with Ulysses sequence parallelism.""" +class TestWanVsaUlyssesCfg: + """VSA (CUTEDSL) with Ulysses sequence parallelism and CFG parallelism (ulysses=2, cfg=2).""" - @pytest.mark.parametrize("vsa_sparsity", [0.0, 0.5]) - def test_vsa_ulysses_forward(self, vsa_sparsity: float): + def test_vsa_ulysses_forward(self): run_test_in_distributed( - world_size=2, - test_fn=functools.partial(_logic_vsa_ulysses_forward, vsa_sparsity=vsa_sparsity), + world_size=_ULYSSES_SIZE * _CFG_SIZE, + test_fn=_logic_vsa_ulysses_forward, ) - @pytest.mark.parametrize("vsa_sparsity", [0.0, 0.5, 0.75]) - def test_vsa_ulysses_vs_reference(self, vsa_sparsity: float): + def test_vsa_ulysses_vs_reference(self): run_test_in_distributed( - world_size=2, - test_fn=functools.partial(_logic_vsa_ulysses_vs_reference, vsa_sparsity=vsa_sparsity), + world_size=_ULYSSES_SIZE * _CFG_SIZE, + test_fn=_logic_vsa_ulysses_vs_reference, ) diff --git a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py index 500afbe99d15..86035a3bd8a0 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py +++ b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py @@ -309,8 +309,6 @@ def test_cute_kernel_matches_ref_with_independent_indices(): q, k, v, q2k_idx, q2k_num, variable_block_sizes ) - # Manual fp32 masked-softmax: bf16 inputs + an fp32 -inf mask through SDPA - # mishandle the masked region at this shape, so compute it explicitly. scale = 1.0 / (D**0.5) scores = (q.float() @ k.float().transpose(-2, -1)) * scale scores = scores + attn_mask @@ -327,3 +325,118 @@ def test_cute_kernel_matches_ref_with_independent_indices(): f"reference: max_diff={max_diff:.3e}, mean_diff={mean_diff:.3e} " f"(rtol={rtol}, atol={atol}, pair_mismatch={pair_mismatch})" ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +def test_cute_kernel_50pct_sparsity_quality_vs_dense(): + """50% sparse CuTe kernel with score-based topk should stay close to dense SDPA.""" + + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, num_cubes, D = 1, 4, 16, 128 + block_size = 64 + topk = num_cubes // 2 + seq_len = num_cubes * block_size + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + q_blocks = q.reshape(B, H, num_cubes, block_size, D).mean(dim=3) + k_blocks = k.reshape(B, H, num_cubes, block_size, D).mean(dim=3) + scale = D**-0.5 + block_scores = torch.einsum("bhqd,bhkd->bhqk", q_blocks.float(), k_blocks.float()) * scale + q2k_idx = block_scores.topk(topk, dim=-1).indices.to(torch.int32).contiguous() + + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + out_sparse, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + out_dense = F.scaled_dot_product_attention(q, k, v) + + cos_sim = F.cosine_similarity( + out_sparse.float().reshape(-1), out_dense.float().reshape(-1), dim=0 + ).item() + print(f"\n 50% sparse (score-based topk) vs dense SDPA cos_sim: {cos_sim:.4f}") + + assert cos_sim >= 0.65, ( + f"50% sparse CuTe kernel deviated too far from dense SDPA: cos_sim={cos_sim:.4f} < 0.65" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="kernel test needs CUDA") +@pytest.mark.parametrize( + "num_cubes", + [1, 3, 9], + ids=["1cube_odd", "3cubes_odd", "9cubes_odd"], +) +def test_cute_kernel_odd_num_cubes_correctness(num_cubes): + """CuTe kernel with odd num_cubes must match dense SDPA (last Q-block has no pair).""" + from tensorrt_llm._torch.visual_gen.cute_dsl_kernels.blackwell.video_sparse_attention import ( + CUTE_AVAILABLE, + block_sparse_attn_from_indices_cute, + is_cute_supported, + ) + + if not CUTE_AVAILABLE: + pytest.skip("cuda-bindings or cutlass-dsl not importable") + + assert num_cubes % 2 == 1, f"pre-condition: num_cubes={num_cubes} must be odd" + + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(0) + + B, H, D = 1, 4, 128 + block_size = 64 + seq_len = num_cubes * block_size + + q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + v = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) + + if not is_cute_supported(q): + pytest.skip("CuTe path needs sm_100+ Blackwell (current device unsupported)") + + topk = num_cubes + q2k_idx = ( + torch.arange(num_cubes, device=device, dtype=torch.int32) + .view(1, 1, 1, num_cubes) + .expand(B, H, num_cubes, topk) + .contiguous() + ) + q2k_num = torch.full((B, H, num_cubes), topk, dtype=torch.int32, device=device) + variable_block_sizes = torch.full((num_cubes,), block_size, dtype=torch.int32, device=device) + + out_kernel, _lse = block_sparse_attn_from_indices_cute( + q, k, v, q2k_idx, q2k_num, variable_block_sizes + ) + out_ref = F.scaled_dot_product_attention(q, k, v) + + assert torch.isfinite(out_kernel).all(), ( + f"CuTe kernel produced non-finite output for odd num_cubes={num_cubes}" + ) + + max_diff = (out_kernel - out_ref).abs().max().item() + mean_diff = (out_kernel - out_ref).abs().mean().item() + rtol, atol = 1e-2, 1e-2 + assert torch.allclose(out_kernel, out_ref, rtol=rtol, atol=atol), ( + f"CuTe kernel deviated from dense SDPA for odd num_cubes={num_cubes}: " + f"max_diff={max_diff:.3e}, mean_diff={mean_diff:.3e} (rtol={rtol}, atol={atol})" + ) diff --git a/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py new file mode 100644 index 000000000000..4706c91d7e69 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Correctness tests for VSA T2V pipeline with Video Sparse Attention (VSA). + +Verifies >= 0.95 cosine similarity on decoded video frames against the +dense TRTLLM reference when VSA is enabled at sparsity=0.0 (dense). + +Models: + - FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers (720x1280, 9 frames) + +Run: + pytest tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py -v -s + +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN21_VSA=/path/to/vsa \\ + pytest tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py -v -s +""" + +import gc +import os +from pathlib import Path + +os.environ["TLLM_DISABLE_MPI"] = "1" + +import pytest +import torch +import torch.nn.functional as F + +from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import _cute_dsl_import_error +from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader +from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + TorchCompileConfig, + VideoSparseAttentionConfig, + VisualGenArgs, +) + +_cute_dsl_available = _cute_dsl_import_error is None + + +@pytest.fixture(autouse=True, scope="module") +def _cleanup_mpi_env(): + yield + os.environ.pop("TLLM_DISABLE_MPI", None) + + +# ============================================================================ +# Path helpers +# ============================================================================ + + +def _llm_models_root() -> str: + """Return LLM_MODELS_ROOT path if set in env, assert when it's set but not a valid path.""" + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return str(root) + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) + + +WAN21_VSA_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_VSA", "Wan2.1-VSA-T2V-14B-720P-Diffusers") + +# ============================================================================ +# Test constants +# ============================================================================ + +PROMPT = "A cat sitting on a sunny windowsill watching birds outside." +NEGATIVE_PROMPT = "" +NUM_STEPS = 4 +SEED = 42 +COS_SIM_THRESHOLD = 0.95 + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _load_vsa_pipeline(checkpoint_path: str, vsa_sparsity: float = 0.0): + """Load TRTLLM WanPipeline with CUTEDSL + VSA backend.""" + if not os.path.exists(checkpoint_path): + pytest.skip(f"Checkpoint not found: {checkpoint_path}") + if not _cute_dsl_available: + pytest.skip(f"CUTEDSL not available (requires Blackwell GPU): {_cute_dsl_import_error}") + args = VisualGenArgs( + model=checkpoint_path, + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=vsa_sparsity), + ), + torch_compile_config=TorchCompileConfig(enable=False), + ) + return PipelineLoader(args).load(skip_warmup=True) + + +def _load_dense_pipeline(checkpoint_path: str): + """Load TRTLLM WanPipeline with default dense attention (no VSA).""" + if not os.path.exists(checkpoint_path): + pytest.skip(f"Checkpoint not found: {checkpoint_path}") + args = VisualGenArgs( + model=checkpoint_path, + torch_compile_config=TorchCompileConfig(enable=False), + ) + return PipelineLoader(args).load(skip_warmup=True) + + +def _capture_trtllm_video( + pipeline, + prompt: str, + negative_prompt: str, + height: int, + width: int, + num_frames: int, + num_inference_steps: int, + guidance_scale: float, + seed: int, +) -> torch.Tensor: + """Run full TRTLLM pipeline including VAE decode; return (T, H, W, C) float in [0, 1].""" + with torch.no_grad(): + result = pipeline.forward( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + video = result.video # (B, T, H, W, C) uint8 + return video.float() / 255.0 + + +def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float: + """Cosine similarity between two tensors (flattened to 1D, cast to float32 on CPU).""" + a_flat = a.float().cpu().reshape(-1) + b_flat = b.float().cpu().reshape(-1) + return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item() + + +def _assert_vsa_matches_dense( + checkpoint_path: str, + height: int, + width: int, + num_frames: int, + guidance_scale: float, + model_label: str, + vsa_sparsity: float = 0.0, +) -> None: + """Run VSA and dense TRTLLM pipelines sequentially, compare decoded video output.""" + # --- VSA (sparsity=0.0 is fully dense via CUTEDSL path) --- + vsa_pipe = _load_vsa_pipeline(checkpoint_path, vsa_sparsity=vsa_sparsity) + vsa_video = _capture_trtllm_video( + vsa_pipe, + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=NUM_STEPS, + guidance_scale=guidance_scale, + seed=SEED, + ) + del vsa_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- Dense reference (default TRTLLM attention, no VSA) --- + dense_pipe = _load_dense_pipeline(checkpoint_path) + dense_video = _capture_trtllm_video( + dense_pipe, + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=NUM_STEPS, + guidance_scale=guidance_scale, + seed=SEED, + ) + del dense_pipe + gc.collect() + torch.cuda.empty_cache() + + # --- Compare --- + assert vsa_video.numel() == dense_video.numel(), ( + f"{model_label}: element count mismatch — " + f"VSA {vsa_video.shape} ({vsa_video.numel()}) vs " + f"dense {dense_video.shape} ({dense_video.numel()})" + ) + + cos_sim = _cosine_similarity(vsa_video, dense_video) + print(f"\n {model_label} cosine similarity: {cos_sim:.6f}") + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"{model_label}: cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. " + f"VSA pipeline output diverges from the dense reference. " + f"Video shapes — VSA: {vsa_video.shape}, dense: {dense_video.shape}." + ) + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanVsa14B_PipelineCorrectness: + """Wan2.1-VSA-T2V-14B correctness vs dense TRTLLM reference (720x1280, 9 frames). + + VSA at sparsity=0.0 routes through CUTEDSL; threshold is 0.95. + """ + + def test_cosine_similarity(self): + _assert_vsa_matches_dense( + checkpoint_path=WAN21_VSA_PATH, + height=720, + width=1280, + num_frames=9, + guidance_scale=5.0, + model_label="Wan2.1-VSA-T2V-14B+VSA", + vsa_sparsity=0.0, + ) + + +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanVsaSparse: + """VSA at sparsity=0.9: config propagates, output is correctly shaped and finite.""" + + def test_sparse_vsa(self): + pipeline = _load_vsa_pipeline(WAN21_VSA_PATH, vsa_sparsity=0.9) + try: + attn_cfg = pipeline.model_config.attention + assert attn_cfg.backend == "CUTEDSL" + assert attn_cfg.sparse_attention_config.vsa_sparsity == 0.9 + + with torch.no_grad(): + result = pipeline.forward( + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=720, + width=1280, + num_frames=9, + num_inference_steps=NUM_STEPS, + guidance_scale=5.0, + seed=SEED, + ) + assert result.video.dim() == 5 + B, _T, H, W, C = result.video.shape + assert B == 1 and H == 720 and W == 1280 and C == 3 + assert torch.isfinite(result.video.float()).all() + assert pipeline.transformer.blocks[0].attn1.attn_backend == "CUTEDSL" + finally: + del pipeline + gc.collect() + torch.cuda.empty_cache() From b0d370b5809572ecfff01ddc528c1d24582d9d9c Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 10 Jun 2026 10:00:14 -0700 Subject: [PATCH 08/14] update VSA documentation Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- docs/source/models/visual-generation.md | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/docs/source/models/visual-generation.md b/docs/source/models/visual-generation.md index 0f26d051a781..81de20786611 100644 --- a/docs/source/models/visual-generation.md +++ b/docs/source/models/visual-generation.md @@ -178,7 +178,6 @@ VSA reduces the compute cost of self-attention in video diffusion models by sele - VSA-fine-tuned checkpoint: [`FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers`](https://huggingface.co/FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers). Standard Wan checkpoints do not have the learned VSA gates. - Blackwell GPU (sm_100+) for the CuTe JIT kernel. Falls back to dense SDPA on older hardware with no accuracy loss. - `CUTEDSL` attention backend. -- BF16 or FP16 dtype; `head_dim = 128`; multi-head attention (MHA) — GQA/MQA models are not supported. - Not compatible with Ring attention or Attention2D (VSA does not produce per-split LSE). Ulysses is supported. **`vsa_sparsity`** controls the fraction of K/V blocks skipped in the fine branch (0.0 = dense, 0.9 = 90% blocks skipped). Higher sparsity gives more speedup at the cost of some quality. @@ -208,18 +207,6 @@ attention_config: vsa_sparsity: 0.90 ``` -Ready-to-use configs are provided in [`examples/visual_gen/configs/`](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/visual_gen/configs): -- `wan-vsa-1gpu.yaml` — single-GPU VSA with NVFP4 quantization -- `wan-vsa-4gpu.yaml` — 4-GPU VSA with Ulysses parallelism -- `wan-vsa-8gpu.yaml` — 8-GPU VSA with Ulysses parallelism - -Run with the existing Wan example script: - -```bash -python examples/visual_gen/models/wan_t2v.py \ - --model FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers \ - --visual_gen_args examples/visual_gen/configs/wan-vsa-1gpu.yaml -``` ### Multi-GPU Parallelism From 3a352fe5e5bed39d494ed5d12b15ada78f0ca8c1 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 10 Jun 2026 10:10:33 -0700 Subject: [PATCH 09/14] remove even cubes gating Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py index 86035a3bd8a0..63e7fd2498ff 100644 --- a/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py +++ b/tests/unittest/_torch/visual_gen/test_attention_cute_dsl_vsa.py @@ -263,7 +263,6 @@ def test_cute_kernel_matches_ref_with_independent_indices(): block_size = 64 topk = num_cubes // 2 seq_len = num_cubes * block_size - assert num_cubes % 2 == 0, "num_cubes must be even for the paired-block kernel" q = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) k = torch.randn(B, H, seq_len, D, device=device, dtype=dtype) From e28cd9b94c101e865fe34eb024d01a13551d31f5 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Fri, 12 Jun 2026 02:39:49 +0000 Subject: [PATCH 10/14] address code review comments Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../attention_backend/cute_dsl/fmha.py | 41 ++- .../visual_gen/attention_backend/parallel.py | 17 +- .../visual_gen/attention_backend/utils.py | 5 +- .../visual_gen/models/wan/pipeline_wan.py | 2 +- .../multi_gpu/test_wan_vsa_ulysses.py | 331 +++++++++--------- 5 files changed, 225 insertions(+), 171 deletions(-) diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py index e71ede03f3ca..c15f8b2ac47d 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/cute_dsl/fmha.py @@ -13,12 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -CuTe DSL dense FMHA backend for visual generation models. - -CuTeDSLAttention uses pre-compiled cubin kernels and requires head_dim=128. -Supports float16/bfloat16 with optional QK16PV8 quantization and a -skip-softmax threshold optimisation. Expects NHD layout ([B, S, H, D]). +CuTe DSL (NVIDIA kernels) Dense FMHA Backend for Visual Generation Models +Uses pre-compiled cubins derived from CUTLASS CuTe DSL FMHA. +Expects NHD layout ([B, S, H, D]) and supports float16/bfloat16. For the VSA sparse path use VSAAttention in vsa.py. """ @@ -48,11 +46,9 @@ class CuTeDSLAttention(AttentionBackend): """ - CuTe DSL dense FMHA backend for diffusion models. + CuTe DSL (NVIDIA kernels) backend for diffusion models. - Uses pre-compiled cubins and requires head_dim=128. - Supports float16/bfloat16 with optional QK16PV8 quantization (quant_attention_config) - and a skip-softmax threshold optimisation (skip_softmax_threshold_scale). + Uses pre-compiled cubin kernels (head_dim=128 only). """ def __init__( @@ -66,6 +62,7 @@ def __init__( skip_softmax_threshold_scale: Optional[float] = None, **kwargs, ): + # Only head_dim=128 cubins are packaged. if head_dim != 128: raise ValueError(f"CUTEDSL cubins require head_dim=128, got head_dim={head_dim}.") self.layer_idx = layer_idx @@ -77,6 +74,9 @@ def __init__( self.skip_softmax_threshold_scale = skip_softmax_threshold_scale self.scale = 1.0 / math.sqrt(head_dim) + # CuTe DSL expects [B, S, H, D] format + self._preferred_layout = AttentionTensorLayout.NHD + def _prepare_inputs( self, q: torch.Tensor, @@ -176,6 +176,20 @@ def forward( attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL, **kwargs, ) -> torch.Tensor: + """ + Forward pass using CuTe DSL (NVIDIA kernels). + + Dimensions are derived from tensor shapes (NHD layout: ``[B, S, H, D]``). + + Args: + q: Query tensor [batch_size, seq_len, num_heads, head_dim] + k: Key tensor [batch_size, seq_len_kv, num_kv_heads, head_dim] + v: Value tensor [batch_size, seq_len_kv, num_kv_heads, head_dim] + attention_mask: Attention mask type (CAUSAL or FULL) + + Returns: + Output tensor [batch_size, seq_len, num_heads, head_dim] + """ output, _ = self.forward_with_lse(q, k, v, attention_mask=attention_mask, **kwargs) return output @@ -188,9 +202,13 @@ def forward_with_lse( **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """ + Forward pass returning both output and log-sum-exp (LSE). + Returns: output: [batch_size, seq_len, num_heads, head_dim] - lse: [batch_size, num_heads, seq_len] — log-sum-exp in float32. + lse: [batch_size, num_heads, seq_len] - log-sum-exp per query position, + always in float32. Used for numerically stable combination of + partial attention results in Attention2D parallelism. """ q, k, v, is_causal, origin_dtype = self._prepare_inputs(q, k, v, attention_mask) output, lse = self._fwd(q, k, v, is_causal, **kwargs) @@ -204,7 +222,8 @@ def support_lse(cls) -> bool: @property def preferred_layout(self) -> AttentionTensorLayout: - return AttentionTensorLayout.NHD + """Return the preferred tensor layout for this backend.""" + return self._preferred_layout @classmethod def support_fused_qkv(cls) -> bool: diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py index 6a61a84dd498..4233f0238640 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py @@ -163,17 +163,33 @@ def _forward_fused( v: torch.Tensor, **kwargs, ) -> torch.Tensor: + gate_compress = kwargs.pop("gate_compress", None) + gate_fine = kwargs.pop("gate_fine", None) + batch_size = q.shape[0] qkv = torch.stack([q, k, v], dim=2) qkv = all_to_all_5d(qkv, scatter_dim=3, gather_dim=1, process_group=self.process_group) B, seq_len, _, Hp, D = qkv.shape + if gate_compress is not None: + gate_compress = all_to_all_4d( + gate_compress, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) + if gate_fine is not None: + gate_fine = all_to_all_4d( + gate_fine, scatter_dim=2, gather_dim=1, process_group=self.process_group + ) + # Caller passed pre-A2A (sharded) seq_len; the inner backend # reshapes by it, so hand it the post-A2A length instead. kwargs["batch_size"] = batch_size kwargs["seq_len"] = seq_len kwargs["seq_len_kv"] = seq_len + if gate_compress is not None: + kwargs["gate_compress"] = gate_compress + if gate_fine is not None: + kwargs["gate_fine"] = gate_fine output = self.inner_backend.forward(q=qkv, k=None, v=None, **kwargs) @@ -365,7 +381,6 @@ def _output_a2a( output = all_to_all_4d( output, scatter_dim=1, gather_dim=2, process_group=self.process_group ) - return output @property diff --git a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py index 01761cc1f513..e52f52c87247 100644 --- a/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py +++ b/tensorrt_llm/_torch/visual_gen/attention_backend/utils.py @@ -129,7 +129,10 @@ def create_attention( ) kwargs["attention_metadata_state"] = attention_metadata_state if backend.upper() == "CUTEDSL" and attention_config is not None: - if attention_config.sparse_attention_config is not None: + if ( + attention_config.sparse_attention_config is not None + and getattr(attention_config.sparse_attention_config, "algorithm", None) == "vsa" + ): # VSA sparse path: use VSAAttention from .cute_dsl.vsa import VSAAttention diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index 043cec1a07d6..5366c35c22b3 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -503,7 +503,7 @@ def forward( ) # VSA: build metadata builder once per forward() call; reused across timesteps. - _attn_cfg = self.model_config.attention + _attn_cfg = self.pipeline_config.primary_model_config.attention _sparse_cfg = getattr(_attn_cfg, "sparse_attention_config", None) _vsa_active = ( getattr(_attn_cfg, "backend", "VANILLA") == "CUTEDSL" diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py index 9fb3487fe414..d35da3aced81 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py @@ -12,15 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Multi-GPU tests for VSA + Ulysses sequence parallelism (ulysses=2, cfg=2, 4 GPUs).""" +"""Multi-GPU end-to-end test for the Wan T2V VSA pipeline. -import math -import os +Run with: + pytest tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py -v -s -os.environ["TLLM_DISABLE_MPI"] = "1" +Override checkpoint path: + DIFFUSION_MODEL_PATH_WAN21_VSA=/path/to/vsa \\ + pytest tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py -v -s +""" +import gc +import os +from pathlib import Path from typing import Callable +os.environ["TLLM_DISABLE_MPI"] = "1" + import pytest import torch import torch.distributed as dist @@ -28,19 +36,22 @@ import torch.nn.functional as F try: - from tensorrt_llm._torch.visual_gen.attention_backend import ( - CuTeDSLAttention, - UlyssesAttention, - VSAMetadataBuilder, - set_vsa_forward_context, - ) - from tensorrt_llm._torch.visual_gen.mapping import VisualGenMapping + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import _cute_dsl_import_error + from tensorrt_llm._torch.visual_gen.pipeline_loader import PipelineLoader from tensorrt_llm._utils import get_free_port - from tensorrt_llm.visual_gen.args import VideoSparseAttentionConfig + from tensorrt_llm.visual_gen.args import ( + AttentionConfig, + ParallelConfig, + TorchCompileConfig, + VideoSparseAttentionConfig, + VisualGenArgs, + ) MODULES_AVAILABLE = True + _cute_dsl_available = _cute_dsl_import_error is None except ImportError: MODULES_AVAILABLE = False + _cute_dsl_available = False @pytest.fixture(autouse=True, scope="module") @@ -50,7 +61,48 @@ def _cleanup_mpi_env(): # ============================================================================= -# Distributed helpers (same pattern as test_ulysses_sage_attention.py) +# Path helpers (mirrors test_wan_vsa_pipeline.py) +# ============================================================================= + + +def _llm_models_root() -> str: + root = Path("/home/scratch.trt_llm_data_ci/llm-models/") + if "LLM_MODELS_ROOT" in os.environ: + root = Path(os.environ["LLM_MODELS_ROOT"]) + if not root.exists(): + root = Path("/scratch.trt_llm_data/llm-models/") + assert root.exists(), ( + "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." + ) + return str(root) + + +def _checkpoint(env_var: str, default_name: str) -> str: + return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) + + +WAN21_VSA_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_VSA", "Wan2.1-VSA-T2V-14B-720P-Diffusers") + + +# ============================================================================= +# Inference constants +# ============================================================================= + +PROMPT = "A cat sitting on a sunny windowsill watching birds outside." +NEGATIVE_PROMPT = "" + +HEIGHT = 720 +WIDTH = 1280 +NUM_FRAMES = 9 +NUM_STEPS = 4 +GUIDANCE_SCALE = 5.0 +SEED = 42 + +COS_SIM_THRESHOLD = 0.97 + + +# ============================================================================= +# Distributed harness (mirrors test_wan_pipeline_parallel.py) # ============================================================================= @@ -68,10 +120,10 @@ def cleanup_distributed(): dist.destroy_process_group() -def _distributed_worker(rank, world_size, backend, test_fn, port): +def _distributed_worker(rank, world_size, backend, test_fn, port, kwargs): try: init_distributed_worker(rank, world_size, backend, port) - test_fn(rank, world_size) + test_fn(rank, world_size, **kwargs) except Exception as e: print(f"Rank {rank} failed with error: {e}") raise @@ -79,171 +131,131 @@ def _distributed_worker(rank, world_size, backend, test_fn, port): cleanup_distributed() -def run_test_in_distributed(world_size: int, test_fn: Callable): +def run_test_in_distributed(world_size: int, test_fn: Callable, **kwargs): if not MODULES_AVAILABLE: pytest.skip("Required modules not available") - if not torch.cuda.is_available(): - pytest.skip("CUDA required for VSA") if torch.cuda.device_count() < world_size: pytest.skip(f"Test requires {world_size} GPUs, only {torch.cuda.device_count()} available") - port = get_free_port() mp.spawn( _distributed_worker, - args=(world_size, "nccl", test_fn, port), + args=(world_size, "nccl", test_fn, port, kwargs), nprocs=world_size, join=True, ) -_ULYSSES_SIZE = 2 -_CFG_SIZE = 2 -_VSA_SPARSITY = 0.5 +# ============================================================================= +# Inference helpers +# ============================================================================= + -# (8,8,8) latent -> 512 tokens (256/ulysses-rank at P=2); the (4,4,4) tile gives -# 8 cubes (even, as the paired-block kernel needs) of 64 tokens (kernel block_size). -_DIT_SEQ_SHAPE = (8, 8, 8) -_VSA_PATCH_SIZE = (1, 1, 1) -_HEAD_DIM = 128 # CuTe VSA fine-stage kernel requires head_dim == 128 -_HEADS_PER_RANK = 4 +VSA_SPARSITY = 0.9 -def _make_vsa_backend(num_heads: int) -> "CuTeDSLAttention": - """CUTEDSL backend on the VSA path; effective sparsity comes from the forward context.""" - return CuTeDSLAttention( - layer_idx=0, - num_heads=num_heads, - head_dim=_HEAD_DIM, - sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=_VSA_SPARSITY), +def _build_vsa_parallel_args(checkpoint_path: str) -> "VisualGenArgs": + """cfg=2, ulysses=4, CUTEDSL backend with vsa_sparsity=0.9.""" + return VisualGenArgs( + model=checkpoint_path, + torch_compile_config=TorchCompileConfig(enable=False), + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=VSA_SPARSITY), + ), + parallel_config=ParallelConfig(cfg_size=2, ulysses_size=4), ) -def _build_full_seq_vsa_metadata(device: torch.device): - """VSAMetadata for the full sequence — identical on every rank after Ulysses all-to-all.""" - builder = VSAMetadataBuilder() - return builder.build( - current_timestep=0, - raw_latent_shape=_DIT_SEQ_SHAPE, - patch_size=_VSA_PATCH_SIZE, - vsa_sparsity=_VSA_SPARSITY, - device=device, +def _build_vsa_single_args(checkpoint_path: str) -> "VisualGenArgs": + """Single-GPU CUTEDSL reference at the same vsa_sparsity=0.9.""" + return VisualGenArgs( + model=checkpoint_path, + torch_compile_config=TorchCompileConfig(enable=False), + attention_config=AttentionConfig( + backend="CUTEDSL", + sparse_attention_config=VideoSparseAttentionConfig(vsa_sparsity=VSA_SPARSITY), + ), ) -def _make_vgm(rank: int, world_size: int) -> "VisualGenMapping": - from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl +def _capture_trtllm_video(pipeline) -> "torch.Tensor | None": + """Run full TRTLLM pipeline; return (T, H, W, C) float in [0, 1] or None.""" + with torch.no_grad(): + result = pipeline.forward( + prompt=PROMPT, + negative_prompt=NEGATIVE_PROMPT, + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + num_inference_steps=NUM_STEPS, + guidance_scale=GUIDANCE_SCALE, + seed=SEED, + ) + if result is None or result.video is None: + return None + return result.video.float() / 255.0 - DeviceMeshTopologyImpl.device_mesh = None - return VisualGenMapping( - world_size=world_size, - rank=rank, - cfg_size=_CFG_SIZE, - ulysses_size=_ULYSSES_SIZE, - ) + +def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float: + a_flat = a.float().cpu().reshape(-1) + b_flat = b.float().cpu().reshape(-1) + return F.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0)).clamp(-1.0, 1.0).item() + + +def _free(*objs) -> None: + for o in objs: + del o + gc.collect() + torch.cuda.empty_cache() # ============================================================================= -# Test logic functions (module-level so mp.spawn can pickle them) +# Worker logic (module-level for mp.spawn pickling) # ============================================================================= -def _logic_vsa_ulysses_forward(rank, world_size): - """Forward pass: output shape correct and finite (ulysses=2, cfg=2).""" - vgm = _make_vgm(rank, world_size) - ulysses_rank = vgm.ulysses_rank +def _logic_vsa_cfg2_ulysses4(rank: int, world_size: int, *, checkpoint_path: str) -> None: + """End-to-end pipeline: 8-GPU (cfg=2, ulysses=4) VSA vs 1-GPU VSA reference. - batch = 1 - seq_full = math.prod(_DIT_SEQ_SHAPE) - seq_per_rank = seq_full // _ULYSSES_SIZE - num_heads = _ULYSSES_SIZE * _HEADS_PER_RANK + Both runs use vsa_sparsity=0.9. All 8 ranks run the parallel denoising loop. + Rank 0 then frees the distributed model and loads a single-GPU VSA reference + at the same sparsity to compare. + """ + assert world_size == 8, f"This test is hardcoded to world_size=8, got {world_size}" - device = torch.device(f"cuda:{rank}") - torch.manual_seed(42) - torch.cuda.manual_seed_all(42) + vsa_pipe = PipelineLoader(_build_vsa_parallel_args(checkpoint_path)).load(skip_warmup=True) + vsa_video = _capture_trtllm_video(vsa_pipe) - inner = _make_vsa_backend(num_heads // _ULYSSES_SIZE) - attention = UlyssesAttention(inner_backend=inner, process_group=vgm.ulysses_group) + if rank == 0: + assert vsa_video is not None, "Rank 0 produced no video from the VSA pipeline." - # Ulysses input: sequence-sharded, head-full [B, S/P, H, D]. - shape = (batch, seq_per_rank, num_heads, _HEAD_DIM) - q = torch.randn(shape, device=device, dtype=torch.bfloat16) - k = torch.randn(shape, device=device, dtype=torch.bfloat16) - v = torch.randn(shape, device=device, dtype=torch.bfloat16) - gate_compress = torch.randn(shape, device=device, dtype=torch.bfloat16) - gate_fine = torch.randn(shape, device=device, dtype=torch.bfloat16) + _free(vsa_pipe) + if rank != 0: + vsa_video = None + dist.barrier() - metadata = _build_full_seq_vsa_metadata(device) - with set_vsa_forward_context(metadata): - output = attention(q, k, v, gate_compress=gate_compress, gate_fine=gate_fine) - - assert output.shape == (batch, seq_per_rank, num_heads, _HEAD_DIM), ( - f"Rank {rank} (ulysses={ulysses_rank}, cfg={vgm.cfg_rank}): " - f"expected {(batch, seq_per_rank, num_heads, _HEAD_DIM)}, got {output.shape}" - ) - assert torch.isfinite(output).all(), ( - f"Rank {rank} (ulysses={ulysses_rank}, cfg={vgm.cfg_rank}): Inf/NaN in output" - ) + if rank != 0: + return + ref_pipe = PipelineLoader(_build_vsa_single_args(checkpoint_path)).load(skip_warmup=True) + ref_video = _capture_trtllm_video(ref_pipe) + _free(ref_pipe) -def _logic_vsa_ulysses_vs_reference(rank, world_size): - """Each rank's Ulysses+VSA output matches the single-GPU VSA reference's sequence slice.""" - vgm = _make_vgm(rank, world_size) - ulysses_rank = vgm.ulysses_rank - - batch = 1 - seq_full = math.prod(_DIT_SEQ_SHAPE) - seq_per_rank = seq_full // _ULYSSES_SIZE - num_heads = _ULYSSES_SIZE * _HEADS_PER_RANK - - device = torch.device(f"cuda:{rank}") - torch.manual_seed(42) - torch.cuda.manual_seed_all(42) - - full_shape = (batch, seq_full, num_heads, _HEAD_DIM) - q_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) - k_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) - v_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) - gate_c_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) - gate_f_full = torch.randn(full_shape, device=device, dtype=torch.bfloat16) - - sl = slice(ulysses_rank * seq_per_rank, (ulysses_rank + 1) * seq_per_rank) - q_shard = q_full[:, sl].contiguous() - k_shard = k_full[:, sl].contiguous() - v_shard = v_full[:, sl].contiguous() - gate_c_shard = gate_c_full[:, sl].contiguous() - gate_f_shard = gate_f_full[:, sl].contiguous() - - metadata = _build_full_seq_vsa_metadata(device) - - # Ulysses path: sequence-sharded input, head-sharded inner backend. - inner = _make_vsa_backend(num_heads // _ULYSSES_SIZE) - attention = UlyssesAttention(inner_backend=inner, process_group=vgm.ulysses_group) - with set_vsa_forward_context(metadata): - ulysses_out = attention( - q_shard, k_shard, v_shard, gate_compress=gate_c_shard, gate_fine=gate_f_shard - ) + assert vsa_video.numel() == ref_video.numel(), ( + f"Element count mismatch — 8-GPU {tuple(vsa_video.shape)} " + f"({vsa_video.numel()}) vs 1-GPU {tuple(ref_video.shape)} ({ref_video.numel()})" + ) - # Single-GPU VSA reference over the full sequence. - ref_attn = _make_vsa_backend(num_heads) - with set_vsa_forward_context(metadata): - ref_out = ref_attn.forward( - q_full, k_full, v_full, gate_compress=gate_c_full, gate_fine=gate_f_full - ) - ref_shard = ref_out[:, sl] - - ulysses_out = ulysses_out.view(batch, seq_per_rank, num_heads, _HEAD_DIM).to(torch.bfloat16) - ref_shard = ref_shard.to(torch.bfloat16) - - cos_sim = F.cosine_similarity( - ulysses_out.reshape(-1).float(), - ref_shard.reshape(-1).float(), - dim=0, - ).item() - assert cos_sim > 0.990, ( - f"Rank {rank} (ulysses={ulysses_rank}, cfg={vgm.cfg_rank}): " - f"cosine similarity {cos_sim:.6f} below threshold 0.990" + cos_sim = _cosine_similarity(vsa_video, ref_video) + print( + f"\n Wan2.1-VSA-T2V-14B (cfg=2, ulysses=4, sparsity={VSA_SPARSITY}) " + f"cosine similarity vs 1-GPU VSA: {cos_sim:.6f}" + ) + assert cos_sim >= COS_SIM_THRESHOLD, ( + f"8-GPU VSA pipeline diverges from 1-GPU VSA reference: " + f"cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. " + f"Shapes — 8-GPU: {tuple(vsa_video.shape)}, 1-GPU: {tuple(ref_video.shape)}." ) - torch.testing.assert_close(ulysses_out, ref_shard, atol=2e-2, rtol=2e-2) # ============================================================================= @@ -251,22 +263,27 @@ def _logic_vsa_ulysses_vs_reference(rank, world_size): # ============================================================================= -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -class TestWanVsaUlyssesCfg: - """VSA (CUTEDSL) with Ulysses sequence parallelism and CFG parallelism (ulysses=2, cfg=2).""" - - def test_vsa_ulysses_forward(self): - run_test_in_distributed( - world_size=_ULYSSES_SIZE * _CFG_SIZE, - test_fn=_logic_vsa_ulysses_forward, - ) - - def test_vsa_ulysses_vs_reference(self): +@pytest.mark.integration +@pytest.mark.wan_t2v +class TestWanVsaUlysses: + """Multi-GPU correctness for the Wan T2V VSA pipeline (cfg=2, ulysses=4, 8 GPUs).""" + + def test_cfg2_ulysses4(self): + """world=8, cfg=2, ulysses=4, VSA sparsity=0.9 vs 1-GPU VSA reference.""" + if not MODULES_AVAILABLE: + pytest.skip("Required modules not available") + if not _cute_dsl_available: + pytest.skip(f"CUTEDSL not available (requires Blackwell GPU): {_cute_dsl_import_error}") + if not os.path.exists(WAN21_VSA_PATH): + pytest.skip( + f"Checkpoint not found: {WAN21_VSA_PATH}. Set DIFFUSION_MODEL_PATH_WAN21_VSA." + ) run_test_in_distributed( - world_size=_ULYSSES_SIZE * _CFG_SIZE, - test_fn=_logic_vsa_ulysses_vs_reference, + world_size=8, + test_fn=_logic_vsa_cfg2_ulysses4, + checkpoint_path=WAN21_VSA_PATH, ) if __name__ == "__main__": - pytest.main([__file__, "-v"]) + pytest.main([__file__, "-v", "-s"]) From 79d5f860f9da3f87f4e2071e1ccf7c1fbd8d5ee7 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Mon, 15 Jun 2026 10:29:46 -0700 Subject: [PATCH 11/14] resolve CI failure Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py index 4706c91d7e69..6bc9ae08674a 100644 --- a/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py +++ b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py @@ -240,7 +240,7 @@ class TestWanVsaSparse: def test_sparse_vsa(self): pipeline = _load_vsa_pipeline(WAN21_VSA_PATH, vsa_sparsity=0.9) try: - attn_cfg = pipeline.model_config.attention + attn_cfg = pipeline.pipeline_config.primary_model_config.attention assert attn_cfg.backend == "CUTEDSL" assert attn_cfg.sparse_attention_config.vsa_sparsity == 0.9 From 3583a9c6dfe9e82e2ba96ac132134c5454c348b9 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Tue, 16 Jun 2026 13:17:10 -0700 Subject: [PATCH 12/14] skip wan multi-gpu test if checkpoint unavailable Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../multi_gpu/test_wan_pipeline_parallel.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py index bb88b0e474ac..a23924dcd504 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_pipeline_parallel.py @@ -76,20 +76,20 @@ def _cleanup_mpi_env(): # ============================================================================= -def _llm_models_root() -> str: +def _llm_models_root() -> str | None: root = Path("/home/scratch.trt_llm_data_ci/llm-models/") if "LLM_MODELS_ROOT" in os.environ: root = Path(os.environ["LLM_MODELS_ROOT"]) if not root.exists(): root = Path("/scratch.trt_llm_data/llm-models/") - assert root.exists(), ( - "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." - ) - return str(root) + return str(root) if root.exists() else None -def _checkpoint(env_var: str, default_name: str) -> str: - return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) +def _checkpoint(env_var: str, default_name: str) -> str | None: + if env_var in os.environ: + return os.environ[env_var] + models_root = _llm_models_root() + return os.path.join(models_root, default_name) if models_root is not None else None WAN21_1_3B_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_1_3B", "Wan2.1-T2V-1.3B-Diffusers") @@ -416,7 +416,7 @@ def test_cfg2_ulysses2_pvae2(self): """world=4, cfg=2, ulysses=2, parallel_vae=2 vs HF reference.""" if not MODULES_AVAILABLE: pytest.skip("Required modules not available") - if not os.path.exists(WAN21_1_3B_PATH): + if not WAN21_1_3B_PATH or not os.path.exists(WAN21_1_3B_PATH): pytest.skip( f"Checkpoint not found: {WAN21_1_3B_PATH}. Set DIFFUSION_MODEL_PATH_WAN21_1_3B." ) From 56a94001f3a3382af13e30a37832e5ebe259e7ed Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Tue, 23 Jun 2026 10:54:54 -0700 Subject: [PATCH 13/14] fix VSA correctness test to compare CuTe-DSL vs SDPA fallback Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../visual_gen/test_wan_vsa_pipeline.py | 65 +++++++------------ 1 file changed, 25 insertions(+), 40 deletions(-) diff --git a/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py index 6bc9ae08674a..979619e8e433 100644 --- a/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py +++ b/tests/unittest/_torch/visual_gen/test_wan_vsa_pipeline.py @@ -3,8 +3,8 @@ """Correctness tests for VSA T2V pipeline with Video Sparse Attention (VSA). -Verifies >= 0.95 cosine similarity on decoded video frames against the -dense TRTLLM reference when VSA is enabled at sparsity=0.0 (dense). +Verifies >= 0.95 cosine similarity between the CuTe-DSL VSA kernel and the +SDPA-fallback VSA path (same gated coarse+fine formulation, different fine kernel). Models: - FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers (720x1280, 9 frames) @@ -102,17 +102,6 @@ def _load_vsa_pipeline(checkpoint_path: str, vsa_sparsity: float = 0.0): return PipelineLoader(args).load(skip_warmup=True) -def _load_dense_pipeline(checkpoint_path: str): - """Load TRTLLM WanPipeline with default dense attention (no VSA).""" - if not os.path.exists(checkpoint_path): - pytest.skip(f"Checkpoint not found: {checkpoint_path}") - args = VisualGenArgs( - model=checkpoint_path, - torch_compile_config=TorchCompileConfig(enable=False), - ) - return PipelineLoader(args).load(skip_warmup=True) - - def _capture_trtllm_video( pipeline, prompt: str, @@ -156,11 +145,12 @@ def _assert_vsa_matches_dense( model_label: str, vsa_sparsity: float = 0.0, ) -> None: - """Run VSA and dense TRTLLM pipelines sequentially, compare decoded video output.""" - # --- VSA (sparsity=0.0 is fully dense via CUTEDSL path) --- - vsa_pipe = _load_vsa_pipeline(checkpoint_path, vsa_sparsity=vsa_sparsity) - vsa_video = _capture_trtllm_video( - vsa_pipe, + """Compare CuTe-DSL VSA against SDPA-fallback VSA (same gated formulation, different fine kernel).""" + from unittest.mock import patch + + from tensorrt_llm._torch.visual_gen.attention_backend.cute_dsl import vsa as _vsa_module + + common_kwargs = dict( prompt=PROMPT, negative_prompt=NEGATIVE_PROMPT, height=height, @@ -170,40 +160,35 @@ def _assert_vsa_matches_dense( guidance_scale=guidance_scale, seed=SEED, ) + + # --- CuTe-DSL path --- + vsa_pipe = _load_vsa_pipeline(checkpoint_path, vsa_sparsity=vsa_sparsity) + vsa_video = _capture_trtllm_video(vsa_pipe, **common_kwargs) del vsa_pipe gc.collect() torch.cuda.empty_cache() - # --- Dense reference (default TRTLLM attention, no VSA) --- - dense_pipe = _load_dense_pipeline(checkpoint_path) - dense_video = _capture_trtllm_video( - dense_pipe, - prompt=PROMPT, - negative_prompt=NEGATIVE_PROMPT, - height=height, - width=width, - num_frames=num_frames, - num_inference_steps=NUM_STEPS, - guidance_scale=guidance_scale, - seed=SEED, - ) - del dense_pipe + # --- SDPA fallback reference (same VSA formulation, fine attn via SDPA) --- + sdpa_pipe = _load_vsa_pipeline(checkpoint_path, vsa_sparsity=vsa_sparsity) + with patch.object(_vsa_module, "is_cute_supported", return_value=False): + sdpa_video = _capture_trtllm_video(sdpa_pipe, **common_kwargs) + del sdpa_pipe gc.collect() torch.cuda.empty_cache() # --- Compare --- - assert vsa_video.numel() == dense_video.numel(), ( + assert vsa_video.numel() == sdpa_video.numel(), ( f"{model_label}: element count mismatch — " - f"VSA {vsa_video.shape} ({vsa_video.numel()}) vs " - f"dense {dense_video.shape} ({dense_video.numel()})" + f"CuTe {vsa_video.shape} ({vsa_video.numel()}) vs " + f"SDPA {sdpa_video.shape} ({sdpa_video.numel()})" ) - cos_sim = _cosine_similarity(vsa_video, dense_video) + cos_sim = _cosine_similarity(vsa_video, sdpa_video) print(f"\n {model_label} cosine similarity: {cos_sim:.6f}") assert cos_sim >= COS_SIM_THRESHOLD, ( f"{model_label}: cosine similarity {cos_sim:.6f} < {COS_SIM_THRESHOLD}. " - f"VSA pipeline output diverges from the dense reference. " - f"Video shapes — VSA: {vsa_video.shape}, dense: {dense_video.shape}." + f"CuTe-DSL VSA diverges from SDPA-fallback VSA. " + f"Video shapes — CuTe: {vsa_video.shape}, SDPA: {sdpa_video.shape}." ) @@ -215,9 +200,9 @@ def _assert_vsa_matches_dense( @pytest.mark.integration @pytest.mark.wan_t2v class TestWanVsa14B_PipelineCorrectness: - """Wan2.1-VSA-T2V-14B correctness vs dense TRTLLM reference (720x1280, 9 frames). + """Wan2.1-VSA-T2V-14B: CuTe-DSL vs SDPA-fallback correctness (720x1280, 9 frames). - VSA at sparsity=0.0 routes through CUTEDSL; threshold is 0.95. + Verifies CuTe-DSL kernel at sparsity=0.0 matches SDPA fallback with >= 0.95 cosine sim. """ def test_cosine_similarity(self): From fa1764eeb5419c9710caef388731b3fb27115169 Mon Sep 17 00:00:00 2001 From: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> Date: Wed, 24 Jun 2026 13:35:55 -0700 Subject: [PATCH 14/14] fix multi-GPU VSA + Ulysses test Signed-off-by: Olivia Stoner <245287810+o-stoner@users.noreply.github.com> --- .../multi_gpu/test_wan_vsa_ulysses.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py index d35da3aced81..9094ba4abf7e 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_vsa_ulysses.py @@ -65,20 +65,20 @@ def _cleanup_mpi_env(): # ============================================================================= -def _llm_models_root() -> str: +def _llm_models_root() -> str | None: root = Path("/home/scratch.trt_llm_data_ci/llm-models/") if "LLM_MODELS_ROOT" in os.environ: root = Path(os.environ["LLM_MODELS_ROOT"]) if not root.exists(): root = Path("/scratch.trt_llm_data/llm-models/") - assert root.exists(), ( - "Set LLM_MODELS_ROOT or ensure /home/scratch.trt_llm_data_ci/llm-models/ is accessible." - ) - return str(root) + return str(root) if root.exists() else None -def _checkpoint(env_var: str, default_name: str) -> str: - return os.environ.get(env_var) or os.path.join(_llm_models_root(), default_name) +def _checkpoint(env_var: str, default_name: str) -> str | None: + if env_var in os.environ: + return os.environ[env_var] + models_root = _llm_models_root() + return os.path.join(models_root, default_name) if models_root is not None else None WAN21_VSA_PATH = _checkpoint("DIFFUSION_MODEL_PATH_WAN21_VSA", "Wan2.1-VSA-T2V-14B-720P-Diffusers") @@ -237,6 +237,9 @@ def _logic_vsa_cfg2_ulysses4(rank: int, world_size: int, *, checkpoint_path: str if rank != 0: return + # Destroy the world_size=8 group before loading the single-GPU reference. + dist.destroy_process_group() + ref_pipe = PipelineLoader(_build_vsa_single_args(checkpoint_path)).load(skip_warmup=True) ref_video = _capture_trtllm_video(ref_pipe) _free(ref_pipe) @@ -274,7 +277,7 @@ def test_cfg2_ulysses4(self): pytest.skip("Required modules not available") if not _cute_dsl_available: pytest.skip(f"CUTEDSL not available (requires Blackwell GPU): {_cute_dsl_import_error}") - if not os.path.exists(WAN21_VSA_PATH): + if WAN21_VSA_PATH is None or not os.path.exists(WAN21_VSA_PATH): pytest.skip( f"Checkpoint not found: {WAN21_VSA_PATH}. Set DIFFUSION_MODEL_PATH_WAN21_VSA." )