
Add the Skip softmax for diffusion #1166

Open
jingyu-ml wants to merge 25 commits into main from jingyux/diffusion-skip-softmax

Conversation

@jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Apr 2, 2026

What does this PR do?

Type of change: new feature, new example

Summary

  • Add skip-softmax sparse attention (BLASST) for diffusion models via dedicated Triton kernels — an inference kernel with tile skipping and a calibration kernel with vectorized multi-threshold sparsity measurement
  • Add triton_skip_softmax method with exponential model calibration (scale_factor = a * exp(b * sparsity)) and log-space fitting for diffusion models
  • Add Triton kernel backends for diffusers and LTX attention dispatch
  • Fix calibration to skip RULER dataset generation when user provides their own forward_loop (required for non-LLM models)

Changes

Triton kernels (modelopt/torch/kernels/triton_fa.py)

  • _attn_fwd: Forward kernel with optional tile skipping — tiles whose max attention score is far below the running softmax max are skipped entirely (no V load, no softmax, no accumulation). Runtime sparsity measurement via atomic counters.
  • _attn_fwd_calibrate: Calibration kernel that computes full attention while measuring how many tiles would be skipped at each of N thresholds simultaneously. Uses per-program output buffers (zero atomic contention) and vectorized multi-threshold comparison.
  • attention() / attention_calibrate(): Python wrappers for inference and calibration kernels.
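The tile-skipping recurrence can be sketched outside Triton. The following single-head NumPy reference is only an illustration of the description above (the function name, default tile size, and exact threshold semantics are assumptions, not the kernel's API):

```python
import numpy as np

def skip_softmax_attention(q, k, v, tile=64, skip_threshold_log2=-8.0):
    """Hypothetical single-head reference for tile skipping (not the kernel).

    q: [Sq, D], k, v: [Sk, D]. Scores are kept in the log2 domain, as flash
    attention kernels commonly do. A K/V tile is skipped when even its best
    score sits more than |skip_threshold_log2| below the running row max,
    i.e. its softmax contribution would be negligible.
    """
    sm_scale = 1.0 / np.sqrt(q.shape[-1])
    log2e = np.log2(np.e)
    m = np.full(q.shape[0], -np.inf)   # running row max (log2 domain)
    l = np.zeros(q.shape[0])           # running softmax denominator
    acc = np.zeros_like(q)             # running numerator (weighted V sum)
    skipped = total = 0
    for s0 in range(0, k.shape[0], tile):
        ks, vs = k[s0:s0 + tile], v[s0:s0 + tile]
        scores = (q @ ks.T) * sm_scale * log2e
        total += 1
        if np.all(scores.max(axis=1) < m + skip_threshold_log2):
            skipped += 1               # no V load, no softmax, no accumulation
            continue
        m_new = np.maximum(m, scores.max(axis=1))
        p = np.exp2(scores - m_new[:, None])
        alpha = np.exp2(m - m_new)     # rescale previously accumulated tiles
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vs
        m = m_new
    return acc / l[:, None], skipped / total
```

With a very negative threshold no tile is ever skipped and the result matches dense softmax attention; as the threshold approaches 0, more tiles are dropped and sparsity rises.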

Kernel backends (modelopt/torch/sparsity/attention_sparsity/kernels/)

  • diffusers_triton_attention.py: Registers modelopt_triton backend in diffusers' attention dispatch. Handles [B, S, H, D] → varlen layout conversion, calibration/inference mode switching, thread-local configuration, and counter accumulation.
  • ltx_triton_attention.py: Patches ltx_core.Attention modules for Triton dispatch with the same calibration/inference modes.
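The [B, S, H, D] → varlen conversion these backends perform can be illustrated with a small sketch (NumPy here for brevity; the real backends operate on torch tensors, and the helper names and the equal-length cu_seqlens construction are assumptions):

```python
import numpy as np

def to_varlen(x):
    """Hypothetical sketch of packing [B, S, H, D] into varlen layout before
    dispatching to the Triton kernel. With equal-length sequences, the
    cumulative-sequence-length tensor is simply 0, S, 2S, ..., B*S."""
    b, s, h, d = x.shape
    x_varlen = x.reshape(b * s, h, d)             # fuse batch and sequence axes
    cu_seqlens = np.arange(b + 1, dtype=np.int32) * s
    return x_varlen, cu_seqlens, s                # s doubles as max_seqlen

def from_varlen(out, b, s):
    """Undo the packing on the kernel output: [B*S, H, D] -> [B, S, H, D]."""
    return out.reshape(b, s, *out.shape[1:])
```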

Method (modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py)

  • TritonSkipSoftmaxMethod: Context managers for calibration (→ calibration kernel) and inference (→ forward kernel with tile skipping). Three threshold priority levels: raw threshold > calibrated scale_factor > static threshold.
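A minimal sketch of that priority order, assuming the runtime formulas from the threshold-modes section (the helper name and defaults are illustrative, not the actual API):

```python
import math

def resolve_skip_threshold(seq_k, raw_threshold=None, scale_factor=None,
                           static_threshold=0.1, sm_scale=1.0):
    """Illustrative three-level priority: raw > calibrated > static."""
    if raw_threshold is not None:
        return raw_threshold                       # passed straight to the kernel
    if scale_factor is not None:
        return scale_factor / seq_k                # seqlen-adaptive calibrated mode
    return math.log2(static_threshold) * sm_scale  # static fallback
```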

Calibration (modelopt/torch/sparsity/attention_sparsity/calibration/)

  • calibrator.py: DynamicThresholdCalibrator with fit_logspace option — fits exponential model in log space (minimizes relative error) for diffusion models where scale_factors span many orders of magnitude. Records observed sparsity range for extrapolation warnings.
  • calibrate.py: Skips RULER dataset when forward_loop is provided; passes fit_logspace through from config.
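The log-space fit amounts to linear least squares after a log transform: scale_factor = a * exp(b * s) becomes log(scale_factor) = log(a) + b * s, so fitting a line in log space minimizes relative rather than absolute error, which matters when scale_factors span orders of magnitude. A minimal sketch, assuming the calibrator collects (sparsity, scale_factor) pairs (function names are hypothetical):

```python
import numpy as np

def fit_exponential_logspace(sparsity, scale_factor):
    """Fit scale_factor = a * exp(b * sparsity) via a line fit in log space."""
    b, log_a = np.polyfit(sparsity, np.log(scale_factor), 1)
    return np.exp(log_a), b

def predict_scale(a, b, target_sparsity):
    """Evaluate the fitted exponential model at a target sparsity."""
    return a * np.exp(b * target_sparsity)
```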

Config & conversion

  • config.py: CalibrationConfig.fit_logspace field (default False, recommended True for diffusion models). skip_softmax_raw_threshold field for direct threshold mode.
  • conversion.py: Auto-registers diffusers/LTX Triton backends on sparsify(). Updated summary display.

Example

  • wan22_skip_softmax.py: End-to-end example for WAN 2.2 5B/14B with baseline, raw-threshold, and calibrated modes. Supports runtime sparsity reporting.

Threshold modes

  • Raw threshold (--raw-threshold -0.7): passed directly to the kernel as skip_threshold_log2. Use case: quick testing and sweeps.
  • Calibrated (--calibrate --target-sparsity 0.5): scale_factor = a * exp(b * target), then threshold = scale_factor / seq_k at runtime. Use case: production use with sequence-length adaptation.
  • Static (default skip_softmax_threshold=0.1): threshold = log2(lambda) * sm_scale. Use case: fallback.

Usage

# Fixed raw threshold (no calibration)
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --raw-threshold -0.7 \
    --prompt "A cat playing piano" --output out.mp4

# With calibration (log-space fit for diffusion models)
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --calibrate --target-sparsity 0.5 \
    --prompt "A cat playing piano" --output out.mp4

# Dense baseline for comparison
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
    --baseline \
    --prompt "A cat playing piano" --output baseline.mp4

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ❌

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added skip-softmax sparse attention support for Diffusers models, enabling efficient video generation
    • Added support for both eager and Triton attention backends for sparse attention
    • Added new example script for Wan 2.2 text-to-video generation with sparse attention optimization
  • Documentation

    • Updated documentation with sparse attention configuration guide and usage examples
  • Tests

    • Added comprehensive unit tests for kernel backend registration and skip-softmax functionality

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners April 2, 2026 06:02
@jingyu-ml jingyu-ml requested a review from kaix-nv April 2, 2026 06:02
@jingyu-ml jingyu-ml marked this pull request as draft April 2, 2026 06:02
@coderabbitai
Contributor

coderabbitai bot commented Apr 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR adds skip-softmax sparse attention support for Diffusers and LTX-2 models, including new eager and Triton kernel backends, calibration refinements, example scripts, and comprehensive tests for framework-specific attention implementations.

Changes

  • Example Script & Documentation (examples/diffusers/sparsity/wan22_skip_softmax.py, examples/diffusers/README.md): New executable example for WAN 2.2 video generation using skip-softmax sparse attention with CLI argument parsing, calibration support, and sparsity summary reporting. README updated with a sparse attention section and example script instructions.
  • Calibration & Conversion Core (modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py, modelopt/torch/sparsity/attention_sparsity/conversion.py): Modified calibrate_sparse_attention() to defer tokenizer/dataset generation when forward_loop is provided. Added _register_diffusers_backends_if_needed() to conditionally register Diffusers backends and patch LTX modules. Updated print_sparse_attention_summary() to skip disabled modules when computing sparsity counts.
  • Diffusers Kernel Backends (modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py): New eager attention backend implementing scaled dot-product with a softmax interception point. New Triton backend reshaping the Diffusers layout to varlen format with optional skip-softmax threshold support. Both include idempotent backend registration and context managers.
  • LTX Kernel Backends (modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py): New eager wrapper for LTX-2 attention with skip-softmax context detection. New Triton backend reshaping the LTX fused-head layout to varlen format with skip-softmax threshold support. Both include thread-local configuration and idempotent module wrapping.
  • Kernel Infrastructure (modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py): Added thread-local skip-softmax context helpers (set_skip_softmax_context, get_skip_softmax_context) and optional backend registration symbols with conditional imports for Diffusers and LTX backends.
  • Sparse Attention Methods (modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py, modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py): Updated flash_skip_softmax get_sparse_context() to toggle the skip-softmax context flag and conditionally enter the Diffusers eager backend context. Added calculate_sparsity() and apply_sparsity() to TritonSkipSoftmaxMethod with an explicit NotImplementedError for Python-path sparsity.
  • Plugin System & Infrastructure (modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py, modelopt/torch/sparsity/attention_sparsity/stats_manager.py): Deferred the transformers import in _is_supported_model() and added Diffusers ModelMixin detection. Modified stats collection to conditionally extend sample stats with "normalized_gaps" when present in incoming statistics.
  • Unit Tests (tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py): New comprehensive test module covering thread-local context behavior, eager/Triton backend registration idempotence, shape validation across attention scenarios (basic, cross-attention, causal, GQA), and diffusers backend registration via mocked dependencies.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Pipeline as WAN Pipeline
    participant Sparsify as mtsa.sparsify()
    participant Transformer as Transformer<br/>Modules
    participant CalibLoop as Calibration<br/>Forward Loop
    participant Config as Sparse<br/>Config

    User->>Pipeline: build_pipeline(model_path)
    activate Pipeline
    Pipeline-->>User: pipeline ready
    deactivate Pipeline

    User->>Config: build_sparse_config(args)
    activate Config
    Config-->>User: sparse config dict
    deactivate Config

    alt Calibration Mode
        User->>CalibLoop: build_calibration_forward_loop()
        activate CalibLoop
        CalibLoop-->>User: forward_loop callable
        deactivate CalibLoop
        
        User->>Sparsify: sparsify(transformer, config,<br/>forward_loop=...)
        activate Sparsify
        Sparsify->>CalibLoop: invoke forward_loop<br/>(multiple prompts)
        CalibLoop->>Transformer: collect attention stats
        Transformer-->>CalibLoop: attention outputs
        Sparsify->>Transformer: apply sparse config
        Transformer-->>Sparsify: sparse attention modules
        Sparsify-->>User: sparsified transformer
        deactivate Sparsify
    else No Calibration
        User->>Sparsify: sparsify(transformer, config)
        activate Sparsify
        Sparsify->>Transformer: apply sparse config
        Transformer-->>Sparsify: sparse attention modules
        Sparsify-->>User: sparsified transformer
        deactivate Sparsify
    end

    User->>Pipeline: generate(prompt, ...)
    activate Pipeline
    Pipeline->>Transformer: forward with sparse<br/>attention
    Transformer-->>Pipeline: output frames
    deactivate Pipeline
    
    User->>User: print_sparsity_summary(model)
    activate User
    User->>Transformer: enumerate SparseAttentionModule
    Transformer-->>User: module configs
    deactivate User
sequenceDiagram
    participant Model as Model<br/>(Diffusers/LTX)
    participant ConversionFn as convert_to_sparse_<br/>attention_model()
    participant RegFn as _register_diffusers_<br/>backends_if_needed()
    participant DiffusersBackends as Diffusers<br/>Backends
    participant LTXBackends as LTX<br/>Backends
    participant Method as Sparse<br/>Method

    Model->>ConversionFn: convert_to_sparse_attention_model(model, ...)
    activate ConversionFn
    ConversionFn->>RegFn: _register_diffusers_backends_if_needed(model)
    activate RegFn
    
    alt Is Diffusers ModelMixin
        RegFn->>DiffusersBackends: register_diffusers_eager_attention()
        RegFn->>DiffusersBackends: register_diffusers_triton_attention()
        DiffusersBackends-->>RegFn: backends registered
    end
    
    alt Has LTX modules
        RegFn->>LTXBackends: patch ltx attention<br/>modules
        LTXBackends-->>RegFn: wrappers installed
    end
    
    RegFn-->>ConversionFn: registration complete
    deactivate RegFn
    
    ConversionFn->>ConversionFn: _set_attn_implementation()
    ConversionFn->>Method: apply sparse config
    Method-->>ConversionFn: sparse model
    ConversionFn-->>Model: sparse model ready
    deactivate ConversionFn

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • PR #1078: Adds and wires up Triton-based N:M sparse softmax support affecting the Triton flash-attention path and sparse-attention kernel integration.

Suggested reviewers

  • Edwardf0t1
  • cjluo-nv
🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 61.11%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Title check: ❓ Inconclusive. The title 'Add the Skip softmax for diffusion' is vague and only partially describes the changeset: it mentions skip-softmax for diffusion but not what is being added (kernel backends, calibration support, example scripts). Consider a more specific title such as 'Add skip-softmax sparse attention support for diffusion models' or 'Implement skip-softmax sparse attention with diffusers/LTX backends and calibration'.
✅ Passed checks (2 passed)
  • Security Anti-Patterns: ✅ Passed. Comprehensive security analysis confirms no critical security anti-patterns in the pull request changes.
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.


@copy-pr-bot

copy-pr-bot bot commented Apr 2, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@jingyu-ml jingyu-ml self-assigned this Apr 2, 2026
@github-actions
Contributor

github-actions bot commented Apr 2, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1166/

Built to branch gh-pages at 2026-04-09 06:36 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov
Copy link
Copy Markdown

codecov bot commented Apr 2, 2026

Codecov Report

❌ Patch coverage is 23.30827% with 408 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.97%. Comparing base (04cd596) to head (3845b47).

Files with missing lines Patch % Lines
modelopt/torch/kernels/triton_fa.py 0.00% 108 Missing ⚠️
.../attention_sparsity/methods/triton_skip_softmax.py 16.37% 97 Missing ⚠️
...attention_sparsity/kernels/ltx_triton_attention.py 4.70% 81 Missing ⚠️
...ion_sparsity/kernels/diffusers_triton_attention.py 48.51% 52 Missing ⚠️
...rsity/attention_sparsity/calibration/calibrator.py 10.81% 33 Missing ⚠️
...arsity/attention_sparsity/calibration/calibrate.py 22.22% 14 Missing ⚠️
...y/attention_sparsity/methods/flash_skip_softmax.py 8.33% 11 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 65.00% 7 Missing ⚠️
modelopt/torch/kernels/__init__.py 33.33% 2 Missing ⚠️
...ch/sparsity/attention_sparsity/methods/registry.py 50.00% 1 Missing ⚠️
... and 2 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1166      +/-   ##
==========================================
- Coverage   75.68%   74.97%   -0.71%     
==========================================
  Files         353      355       +2     
  Lines       40491    40982     +491     
==========================================
+ Hits        30644    30728      +84     
- Misses       9847    10254     +407     
Flag Coverage Δ
unit 54.89% <23.30%> (-0.35%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@jingyu-ml jingyu-ml force-pushed the jingyux/diffusion-skip-softmax branch from 8151232 to 5873652 Compare April 2, 2026 08:38
jingyu-ml and others added 2 commits April 2, 2026 21:29
@jingyu-ml jingyu-ml marked this pull request as ready for review April 3, 2026 06:15
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2)

306-316: ⚠️ Potential issue | 🟠 Major

Decode calibration fails when forward_loop is provided.

When a user supplies forward_loop, lines 264-265 skip building tokenizer and calibration_data. However, decode calibration (lines 312-316) unconditionally requires both, raising RuntimeError even though the user intended to use their own loop.

This creates an inconsistency: prefill calibration supports user-provided forward_loop, but decode calibration does not. The docstring (line 227) also states forward_loop is "Only used for prefill", but this limitation should either be enforced earlier or decode should also accept a custom loop.

💡 Suggested approach

Either:

  1. Skip decode calibration when forward_loop is provided and calibration_data is None, with a warning
  2. Accept a separate decode_forward_loop parameter
  3. Document and enforce that decode calibration requires RULER dataset

     # Run decode calibration if enabled
     if calibrate_decode:
+        if calibration_data is None or tokenizer is None:
+            warnings.warn(
+                "Decode calibration requires RULER dataset. Skipping decode calibration "
+                "because a custom forward_loop was provided without calibration_data."
+            )
+        else:
             print("\n" + "=" * 60)
             # ... rest of decode calibration

24-24: ⚠️ Potential issue | 🔴 Critical

Unconditional transformers import causes pipeline failure.

The module-level import of transformers.AutoTokenizer fails when transformers is not installed. This should be deferred to usage sites (inside _load_tokenizer or guarded) to allow the module to be imported when only diffusers-based workflows are used.

🐛 Proposed fix: defer import to usage site
-from transformers import AutoTokenizer

Then update _load_tokenizer:

 def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer":
     """Load tokenizer and ensure pad_token is set."""
+    from transformers import AutoTokenizer
+
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py (1)

30-36: ⚠️ Potential issue | 🟠 Major

Same top-level import issue as the eager backend.

Both diffusers and modelopt.torch.kernels are imported unconditionally at the top level. This will cause import failures for users who don't have diffusers installed or don't have CUDA/Triton available.

🧹 Nitpick comments (4)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)

124-125: Overly broad exception handling may hide bugs.

except (ImportError, Exception) catches all exceptions including programming errors (e.g., TypeError, AttributeError). Consider narrowing to specific expected exceptions.

♻️ Suggested fix
-    except (ImportError, Exception):
+    except (ImportError, RuntimeError):
         pass

Or log unexpected exceptions for debugging:

-    except (ImportError, Exception):
-        pass
+    except ImportError:
+        pass
+    except Exception as e:
+        import logging
+        logging.debug(f"Diffusers backend registration failed: {e}")
examples/diffusers/sparsity/ltx2_skip_softmax.py (2)

66-81: Hardcoded user-specific paths should be placeholders.

The defaults embed user-specific paths (/home/scratch.omniml_data_2/jingyux/...) that won't exist on other systems. Consider defaulting to empty strings and raising a clear error when the environment variables are not set.

Proposed fix: Require explicit configuration
-CHECKPOINT_PATH = os.environ.get(
-    "LTX2_CHECKPOINT",
-    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev.safetensors",
-)
+CHECKPOINT_PATH = os.environ.get("LTX2_CHECKPOINT", "")
+DISTILLED_LORA_PATH = os.environ.get("LTX2_DISTILLED_LORA", "")
+SPATIAL_UPSAMPLER_PATH = os.environ.get("LTX2_SPATIAL_UPSAMPLER", "")
+GEMMA_ROOT = os.environ.get("LTX2_GEMMA_ROOT", "")

Then in build_pipeline():

def build_pipeline() -> TI2VidTwoStagesPipeline:
    if not all([CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT]):
        raise ValueError(
            "Missing required environment variables. Set: "
            "LTX2_CHECKPOINT, LTX2_DISTILLED_LORA, LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT"
        )
    ...

260-267: load_dataset call may download data unexpectedly.

The load_dataset("nkp37/OpenVid-1M") call will download the dataset on first run, which could be surprising for users. Consider adding a note in the docstring or CLI help, or making this behavior opt-in.

tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py (1)

130-158: Test mock doesn't match actual kernel call signature.

The mock mk.attention = lambda q, k, v, **kw: q returns q directly, but per the context snippet, the actual kernel receives varlen metadata and returns output with shape [B*S, H, D]. The mock should return a tensor with the correct output shape to avoid masking reshape bugs in the code under test.

Proposed fix: Return correctly shaped tensor
-        mk.attention = lambda q, k, v, **kw: q
+        def mock_attention(q, k, v, **kw):
+            # Return tensor with same shape as q (correct for varlen format)
+            return torch.zeros_like(q)
+        mk.attention = mock_attention
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`:
- Around line 31-35: The top-level import of diffusers internals
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) should be
guarded to avoid ImportError for users without diffusers; wrap the import in a
try/except and set a module-level flag (e.g., _DIFFUSERS_AVAILABLE =
False/True). Update any registration or accessor functions that reference
AttentionBackendName, _AttentionBackendRegistry or attention_backend to check
_DIFFUSERS_AVAILABLE before using them and return/do nothing or raise a clear
runtime error when diffusers is unavailable. Ensure all places that previously
assumed the imports (registration functions) consult _DIFFUSERS_AVAILABLE so the
module can be imported without diffusers installed.
- Around line 120-132: The code unconditionally manipulates private diffusers
internals (AttentionBackendName, _AttentionBackendRegistry, etc.) that exist
only in diffusers >= 0.36.0; add a runtime/version guard before creating
new_member and registering _diffusers_eager_attention: check
diffusers.__version__ (or use the same utility used in
modelopt/torch/quantization/plugins/diffusion/diffusers.py) or probe for the
presence of attributes like AttentionBackendName._member_map_ and
_AttentionBackendRegistry._backends, and only perform the enum extension and
registry assignments when those APIs exist; also update pyproject.toml to
require diffusers>=0.36.0 or make the registration conditional so the code
no-ops on older diffusers versions.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 61-70: The _diffusers_triton_attention function currently accepts
attn_mask and enable_gqa but ignores them; update the function so it either
implements GQA and mask handling consistent with the eager backend or explicitly
fails fast: if enable_gqa is True or attn_mask is not None, raise a clear
NotImplementedError mentioning "_diffusers_triton_attention does not support
enable_gqa/attn_mask yet" (or implement the same GQA reshaping/aggregation logic
used in diffusers_eager_attention for query/key/value before calling the Triton
kernel) so callers won't silently get incorrect results.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 60-67: The _ltx_triton_attention function currently accepts a mask
parameter but never uses it; update the implementation to handle masks: either
pass the mask into the Triton kernel via the attn_mask argument when invoking
the kernel (ensure shapes/dtypes match and add logic to convert/expand the mask
to the kernel's expected form), or if kernel masking isn't supported yet,
explicitly reject masks by raising a clear error (e.g., raise
NotImplementedError("mask not supported by _ltx_triton_attention") when mask is
not None) so callers won't silently get incorrect results. Ensure the change is
applied inside _ltx_triton_attention and that any conversion/validation of mask
is performed before the kernel call.
- Line 29: The module unconditionally imports Attention from ltx_core which will
raise ImportError for users without LTX-2; change the top-level import to a
guarded import (try/except ImportError) or defer importing until registration,
set Attention = None on failure, and update register_ltx_triton_attention to
check if Attention is None and raise a clear ImportError like "ltx_core is
required for LTX-2 Triton attention" before proceeding; reference the symbols
Attention and register_ltx_triton_attention in ltx_triton_attention.py to locate
where to apply the guard.
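The guarded-import pattern the comment asks for looks roughly like this (the import path and registration body are taken from the review text; the function body here is a stub):

```python
try:
    from ltx_core.model.transformer.attention import Attention
except ImportError:
    Attention = None  # module stays importable without LTX-2


def register_ltx_triton_attention(model):
    # Fail with a clear message only when registration is actually attempted.
    if Attention is None:
        raise ImportError("ltx_core is required for LTX-2 Triton attention")
    # ... patch the model's Attention modules here ...
```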

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 389-401: The code sets the thread-wide flag via
set_skip_softmax_context(True) immediately which can leak if an exception occurs
before the returned ExitStack is entered; instead create a small context manager
(e.g., using contextlib.contextmanager or a tiny class) that calls
set_skip_softmax_context(True) on __enter__/enter and
set_skip_softmax_context(False) on __exit__/exit, and then register that context
with stack.enter_context rather than calling set_skip_softmax_context and
stack.callback directly; update the function that builds the stack (the block
using ExitStack, get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax)) to enter the
new flag-context via stack.enter_context so the flag is only set when the stack
is actually entered and always cleaned up on exit.
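The flag-context pattern described above can be sketched as follows. The setter/getter here are thread-local stand-ins for the real `set_skip_softmax_context`/`get_skip_softmax_context`; the point is that the flag flips only when the `ExitStack` actually enters the context, and is always cleared on exit.

```python
import threading
from contextlib import ExitStack, contextmanager

_local = threading.local()


def set_skip_softmax_context(enabled: bool) -> None:  # stand-in for the real setter
    _local.enabled = enabled


def get_skip_softmax_context() -> bool:
    return getattr(_local, "enabled", False)


@contextmanager
def _skip_softmax_flag():
    # Set on enter, cleared on exit -- even if a later enter_context raises,
    # so the thread-wide flag can no longer leak.
    set_skip_softmax_context(True)
    try:
        yield
    finally:
        set_skip_softmax_context(False)


with ExitStack() as stack:
    stack.enter_context(_skip_softmax_flag())
    inside = get_skip_softmax_context()
after = get_skip_softmax_context()
```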

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 56-72: The tests import
modelopt.torch.sparsity.attention_sparsity.kernels which transitively imports
transformers and breaks CI; update TestSkipSoftmaxContext to skip when optional
dependency missing by using pytest.importorskip('transformers') or catching
ImportError before importing get_skip_softmax_context/set_skip_softmax_context
(or call pytest.skip) so the test cleanly skips in environments without
transformers; ensure the changes are applied around the imports used in
TestSkipSoftmaxContext (references: get_skip_softmax_context,
set_skip_softmax_context, TestSkipSoftmaxContext).
- Around line 178-207: The test fails because patching targets under
"modelopt.torch.sparsity.attention_sparsity.kernels" occurs before that
submodule is loaded, causing a Module attribute error; to fix, ensure the module
is imported before patching or patch the symbols at the location they are looked
up by _register_diffusers_backends_if_needed: import
modelopt.torch.sparsity.attention_sparsity.conversion (or the parent package)
first, then patch the call targets register_diffusers_eager_attention and
register_diffusers_triton_attention as used by that conversion module (i.e.,
patch where _register_diffusers_backends_if_needed resolves them) so the
MagicMock replacement applies correctly.
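The `importorskip` fix suggested for the first comment reduces to a module-availability probe; a minimal stdlib sketch (in the real test file, `pytest.importorskip("transformers")` performs this check and skips in one call):

```python
import importlib.util

# True only when the optional dependency can be found; the test module can
# guard its imports on this instead of failing at collection time.
HAS_TRANSFORMERS = importlib.util.find_spec("transformers") is not None
```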

---

Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 306-316: The decode calibration block (calibrate_decode) currently
raises RuntimeError if calibration_data or tokenizer are missing even when the
user supplied forward_loop; change the logic in the calibrate_decode section to
detect when forward_loop is provided and calibration_data is None and skip
decode calibration with a warning instead of raising, i.e., only call
create_decode_calibration_forward_loop when calibration_data and tokenizer exist
(use create_decode_calibration_forward_loop(calibration_data, tokenizer, ...)),
otherwise log/warn that decode calibration is skipped due to missing
calibration_data/tokenizer while a custom forward_loop was supplied; update any
related docstring or comment near the calibrate_decode and forward_loop mentions
to reflect this behavior.
- Line 24: The file currently imports transformers.AutoTokenizer at module scope
which raises ImportError when transformers isn't installed; move the import into
the tokenizer-loading code path so the module can be imported without
transformers. Specifically, remove the top-level "from transformers import
AutoTokenizer" and instead import AutoTokenizer inside the _load_tokenizer
function (or guard the import with a try/except that raises a clear error when
the function is called), ensuring _load_tokenizer handles the absence of
transformers and only then attempts to create the tokenizer.
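The decode-calibration guard from the first comment can be sketched like this; the function name and three-argument signature are hypothetical stand-ins for the logic inside `calibrate.py`:

```python
import warnings


def maybe_run_decode_calibration(forward_loop, calibration_data, tokenizer):
    # When the user supplied their own forward_loop and no calibration
    # data/tokenizer was built, skip decode calibration with a warning
    # instead of raising.
    if calibration_data is None or tokenizer is None:
        if forward_loop is not None:
            warnings.warn(
                "Skipping decode calibration: custom forward_loop provided "
                "without calibration_data/tokenizer"
            )
            return False
        raise RuntimeError("calibration_data and tokenizer must be built before decode")
    return True  # proceed to create_decode_calibration_forward_loop(...)
```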

---

Duplicate comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 30-36: Top-level unconditional imports of diffusers symbols
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) and
modelopt.torch.kernels.attention can fail for users without diffusers or
CUDA/Triton; move these imports into the function or class that actually uses
them (e.g., inside the registration function or the backend implementation in
diffusers_triton_attention.py) and guard with try/except ImportError to raise a
clear error only when the backend is instantiated; ensure you reference the same
symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend, and
attention) after relocating the imports so registration only occurs when
dependencies are present.

---

Nitpick comments:
In `@examples/diffusers/sparsity/ltx2_skip_softmax.py`:
- Around line 66-81: The file defines CHECKPOINT_PATH, DISTILLED_LORA_PATH,
SPATIAL_UPSAMPLER_PATH, and GEMMA_ROOT with hardcoded user-specific default
paths; remove those user paths and default to empty string or None when reading
the env vars (os.environ.get(..., "") or None), and add a validation at the
start of build_pipeline() that checks these constants (CHECKPOINT_PATH,
DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT) and raises a clear
ValueError listing the required env names (LTX2_CHECKPOINT, LTX2_DISTILLED_LORA,
LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT) if any are missing so callers must
explicitly configure them.
- Around line 260-267: The load_calib_prompts function calls
load_dataset("nkp37/OpenVid-1M") which may trigger a large download
unexpectedly; update load_calib_prompts to make dataset download explicit or
opt-in (e.g., add a parameter like download: bool = False or a dataset_path
argument) and update the docstring to warn that calling this function will
download the OpenVid-1M dataset unless an existing local dataset path is
provided; ensure the function checks the opt-in flag or uses the provided path
before invoking load_dataset to avoid surprising downloads.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 124-125: The current broad except clause "except (ImportError,
Exception)" in conversion.py swallows all errors and can hide bugs; change it to
only catch ImportError (e.g., "except ImportError as e") for the import-failure
path, and if you must catch other runtime issues around the same block, catch
specific exceptions or log unexpected exceptions (use a logger.exception or
re-raise after logging) so that programming errors like TypeError/AttributeError
are not silently ignored; update the except block that follows the import
attempt in conversion.py accordingly.

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 130-158: The mock attention implementation in
TestDiffusersTritonAttention._setup (mk.attention = lambda q, k, v, **kw: q)
does not match the real kernel signature/behavior and returns the wrong shape,
which hides reshape/masking bugs; update the mock in _setup (mk.attention) to
accept the same args including varlen metadata (keep **kw) and return a tensor
with shape [B*S, H, D] derived from the input q/k/v shapes (e.g., compute B, S,
H, D from q and construct/return a tensor of that shape instead of returning q
directly) so the code paths in _diffusers_triton_attention,
set_triton_skip_softmax_config, clear_triton_skip_softmax_config,
register_diffusers_triton_attention and get_triton_attention_backend are
exercised with realistic output shapes.
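A torch-free sketch of the shape-aware mock the comment asks for (`FakeTensor` is a hypothetical stand-in carrying only a `.shape`; the real test would build a `torch.Tensor`): the fake kernel derives `[B*S, H, D]` from `q` instead of echoing `q` back, so reshape bugs in the caller surface in tests.

```python
class FakeTensor:  # minimal stand-in carrying only a shape
    def __init__(self, shape):
        self.shape = shape


def mock_attention(q, k, v, **kw):
    # q arrives in varlen layout [B*S, H, D]; return a *fresh* tensor of
    # that shape rather than q itself, exercising downstream reshapes.
    total, heads, dim = q.shape
    return FakeTensor((total, heads, dim))


out = mock_attention(FakeTensor((2 * 128, 8, 64)), None, None, cu_seqlens_q=None)
```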

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 578e7893-06c2-4586-93e1-4726205a2f84

📥 Commits

Reviewing files that changed from the base of the PR and between 87ea8ba and 2c323df.

📒 Files selected for processing (14)
  • examples/diffusers/sparsity/ltx2_skip_softmax.py
  • examples/diffusers/sparsity/wan22_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

jingyu-ml and others added 2 commits April 6, 2026 17:20
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)

308-317: ⚠️ Potential issue | 🟡 Minor

Improve error message and consider early validation when custom forward_loop conflicts with decode calibration.

When a user provides a custom forward_loop (e.g., for diffusion models) but also configures calibrate_decode=True, the current RuntimeError message "calibration_data and tokenizer must be built before decode" is confusing—it implies a bug rather than an unsupported configuration.

Consider either:

  1. Early validation (preferred): Check at the start of calibration if forward_loop is not None and calibrate_decode and raise with a clear message, or
  2. Improve the error message to explicitly state the limitation.
Option 1: Add early validation near line 246
     # Skip if both phases are disabled
     if not calibrate_prefill and not calibrate_decode:
         print("Both prefill and decode target sparsity are 0.0, skipping calibration")
         return {}

+    # Decode calibration requires RULER dataset, which is incompatible with custom forward_loop
+    if forward_loop is not None and calibrate_decode:
+        raise ValueError(
+            "Decode calibration is not supported when a custom forward_loop is provided. "
+            "Either set decode target_sparse_ratio to 0.0 or remove the forward_loop argument "
+            "to use auto-generated RULER dataset calibration."
+        )
+
     # Get sparse attention modules
Option 2: Improve error message at lines 313-314
         if calibration_data is None or tokenizer is None:
-            raise RuntimeError("calibration_data and tokenizer must be built before decode")
+            raise RuntimeError(
+                "Decode calibration requires tokenizer and RULER dataset, which are not available "
+                "when a custom forward_loop is provided. Set decode target_sparse_ratio to 0.0 "
+                "or remove the forward_loop argument to enable decode calibration."
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 308 - 317, Add an early validation that explicitly rejects using a custom
forward_loop with decode calibration: if forward_loop is not None and
calibrate_decode is True, raise a clear RuntimeError stating that decode
calibration is incompatible with a custom forward_loop and instruct the user to
disable calibrate_decode or remove the custom forward_loop; place this check
near the start of the calibration flow (before
create_decode_calibration_forward_loop is called). Alternatively, if you prefer
the minimal change, improve the existing RuntimeError in the block that checks
calibration_data and tokenizer to a message that either explains missing
calibration_data/tokenizer or that decode calibration is unsupported when a
custom forward_loop is provided (reference forward_loop, calibrate_decode,
calibration_data, tokenizer, and create_decode_calibration_forward_loop).
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)

34-41: Consider restoring the return type annotation.

The lazy import is a good practice per project guidelines. However, removing the return type annotation reduces type safety. Consider adding it back using a string literal or TYPE_CHECKING import to maintain mypy compatibility:

Suggested fix
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
-def _load_tokenizer(tokenizer_name_or_path: str):
+def _load_tokenizer(tokenizer_name_or_path: str) -> "PreTrainedTokenizerBase":
     """Load tokenizer and ensure pad_token is set."""
     from transformers import AutoTokenizer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 34 - 41, Restore the return type annotation for _load_tokenizer to
preserve type safety; update the signature to annotate the return as
"PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase
under TYPE_CHECKING from transformers and use it as the return type, keeping the
lazy import of AutoTokenizer inside the function and leaving runtime behavior
unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in
the annotation for mypy compatibility while avoiding importing transformers at
module import time.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 405-443: Guard against empty per-phase payloads by selecting the
first non-empty calibration dict instead of blindly using
next(iter(calibration_params.values())). Replace sample_params =
next(iter(calibration_params.values())) with logic that finds the first
non-empty value: sample_params = next((v for v in calibration_params.values() if
v), None); set is_percentile only when sample_params is not None, and if
sample_params is None skip building threshold_config / threshold_scale_factor
and the per-phase loops (so export_config stays without threshold entries),
referencing the existing names calibration_params, sample_params, is_percentile,
export_config, threshold_config and threshold_scale_factor.
- Around line 125-139: The code currently suppresses all exceptions around
diffusers backend registration (the block checking isinstance(model, ModelMixin)
and the uses of contextlib.suppress(Exception)), which can hide real
registration errors; change the try/except to only catch ImportError when
importing ModelMixin and only suppress ImportError when calling
register_diffusers_eager_attention and register_diffusers_triton_attention so
that any other Exception raised by those registration functions bubbles up (or
is re-raised) instead of being swallowed; update the other places using
contextlib.suppress(Exception) to suppress ImportError only and ensure failures
in register_diffusers_* or in _set_attn_implementation are not silenced so the
error surfaces during conversion.
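The non-empty-payload selection suggested in the first comment is a one-liner; a small sketch with made-up per-phase data:

```python
# Hypothetical per-phase calibration results: prefill produced nothing.
calibration_params = {"prefill": {}, "decode": {"threshold": 0.01}}

# First non-empty payload, or None when every phase is empty -- never an
# empty dict picked by accident.
sample_params = next((v for v in calibration_params.values() if v), None)
is_percentile = sample_params is not None and "percentile" in sample_params
```

When `sample_params` is `None`, the caller skips building `threshold_config` entirely so `export_config` carries no threshold entries.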

---

Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 308-317: Add an early validation that explicitly rejects using a
custom forward_loop with decode calibration: if forward_loop is not None and
calibrate_decode is True, raise a clear RuntimeError stating that decode
calibration is incompatible with a custom forward_loop and instruct the user to
disable calibrate_decode or remove the custom forward_loop; place this check
near the start of the calibration flow (before
create_decode_calibration_forward_loop is called). Alternatively, if you prefer
the minimal change, improve the existing RuntimeError in the block that checks
calibration_data and tokenizer to a message that either explains missing
calibration_data/tokenizer or that decode calibration is unsupported when a
custom forward_loop is provided (reference forward_loop, calibrate_decode,
calibration_data, tokenizer, and create_decode_calibration_forward_loop).

---

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 34-41: Restore the return type annotation for _load_tokenizer to
preserve type safety; update the signature to annotate the return as
"PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase
under TYPE_CHECKING from transformers and use it as the return type, keeping the
lazy import of AutoTokenizer inside the function and leaving runtime behavior
unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in
the annotation for mypy compatibility while avoiding importing transformers at
module import time.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 728d5b26-043d-44d2-a39e-cc945d944a89

📥 Commits

Reviewing files that changed from the base of the PR and between 2c323df and bbe2123.

📒 Files selected for processing (2)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)

125-154: ⚠️ Potential issue | 🟠 Major

Don’t suppress all backend registration errors.

Line 138 and Lines 150-154 currently swallow arbitrary exceptions. That can hide real backend registration bugs and leave conversion proceeding with a partially configured runtime.

Proposed fix
-    except (ImportError, Exception):
+    except ImportError:
         pass
@@
-    except (ImportError, RuntimeError):
+    except ImportError:
         return
@@
-    if register_ltx_eager_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_eager_attention(model)
-    if register_ltx_triton_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_triton_attention(model)
+    if register_ltx_eager_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_eager_attention(model)
+    if register_ltx_triton_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_triton_attention(model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 125 -
154, The code currently suppresses all exceptions when importing/calling backend
registration functions (the try/except around ModelMixin and the
contextlib.suppress usage for
register_ltx_eager_attention/register_ltx_triton_attention), which hides real
errors; change these to only catch expected import/runtime errors and surface or
log unexpected exceptions: when importing from .kernels and checking
isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when
invoking register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention, and
register_ltx_triton_attention, wrap each call in a try/except that logs the full
exception (including stack trace) via the module logger and re-raises or returns
a clear error instead of silently swallowing it so backend-registration failures
are visible and conversion does not proceed silently with a misconfigured
runtime.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 125-154: The code currently suppresses all exceptions when
importing/calling backend registration functions (the try/except around
ModelMixin and the contextlib.suppress usage for
register_ltx_eager_attention/register_ltx_triton_attention), which hides real
errors; change these to only catch expected import/runtime errors and surface or
log unexpected exceptions: when importing from .kernels and checking
isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when
invoking register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention, and
register_ltx_triton_attention, wrap each call in a try/except that logs the full
exception (including stack trace) via the module logger and re-raises or returns
a clear error instead of silently swallowing it so backend-registration failures
are visible and conversion does not proceed silently with a misconfigured
runtime.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 09ecbcc7-b977-4605-a5e2-b4e013bcaee8

📥 Commits

Reviewing files that changed from the base of the PR and between bbe2123 and 70099a5.

📒 Files selected for processing (1)
  • modelopt/torch/sparsity/attention_sparsity/conversion.py

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from Edwardf0t1 April 7, 2026 02:16
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 198-206: The calibration loop in build_calibration_forward_loop
currently hardcodes guidance_scale=5.0 so calibration activations ignore the CLI
--guidance-scale; add a guidance_scale parameter to
build_calibration_forward_loop (and the other calibration-related functions
mentioned) and pass that guidance_scale through to any calls that currently use
guidance_scale=5.0 (look for explicit guidance_scale=5.0 in the function body
and replace with the new guidance_scale parameter), ensuring the same parameter
is threaded into the calibration runs so activation collection honors the
user-specified guidance scale.
- Around line 177-183: The calibration block is being added at the top-level
sparse_cfg (making "calibration" a selector) instead of being nested under the
self-attention selector; update the code so that when args.calibrate is true you
insert the calibration dict (using args.target_sparsity,
DEFAULT_THRESHOLD_TRIALS and samples:1) into the "*.attn1*" entry of sparse_cfg
(i.e., sparse_cfg["*.attn1*"]["calibration"] = {...}) so the attn1 selector
receives the calibration config rather than creating a new selector at the top
level.
- Around line 77-135: The CLI allows invalid values that later break
calibration; update parse_args() to validate arguments after
parser.parse_args(): ensure args.num_frames and args.calib_frames satisfy (value
- 1) % 4 == 0 and are > 1 (i.e., 4k+1), ensure args.target_sparsity is between
0.0 and 1.0 inclusive, and if a check fails call parser.error(...) with a clear
message referencing the offending flag (e.g., --num-frames, --calib-frames,
--target-sparsity) so users get immediate, actionable feedback; implement these
checks in parse_args() using the parsed args variables.
- Around line 192-193: The code currently loads the entire "caption" column then
slices it; change the dataset load to request only the needed rows by using the
HuggingFace split range: replace load_dataset("nkp37/OpenVid-1M", split="train")
with load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]") and then
build prompts from the returned small dataset (e.g., prompts =
list(dataset["caption"])). This avoids materializing the full caption column for
the tiny calibration sample.
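The validation and config-nesting fixes above can be sketched together. Flag names follow the review text; the `sparse_cfg` key names (`target_sparse_ratio`, `threshold_trials`, `samples`) and the `DEFAULT_THRESHOLD_TRIALS` values are assumptions for illustration, not the example script's actual schema.

```python
import argparse

DEFAULT_THRESHOLD_TRIALS = (1e-4, 1e-3, 1e-2)  # assumed trial values


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-frames", type=int, default=81)
    parser.add_argument("--calib-frames", type=int, default=17)
    parser.add_argument("--target-sparsity", type=float, default=0.5)
    parser.add_argument("--calibrate", action="store_true")
    args = parser.parse_args(argv)

    # Validate right after parsing so bad values fail fast with the flag name.
    for flag, value in (("--num-frames", args.num_frames),
                        ("--calib-frames", args.calib_frames)):
        if value <= 1 or (value - 1) % 4 != 0:
            parser.error(f"{flag} must be of the form 4k+1 (and > 1), got {value}")
    if not 0.0 <= args.target_sparsity <= 1.0:
        parser.error(f"--target-sparsity must be in [0.0, 1.0], got {args.target_sparsity}")
    return args


def build_sparse_cfg(args):
    # Nest the calibration block under the attn1 selector, not at the
    # top level of sparse_cfg where it would be treated as a selector.
    cfg = {"*.attn1*": {"method": "triton_skip_softmax"}}  # assumed shape
    if args.calibrate:
        cfg["*.attn1*"]["calibration"] = {
            "target_sparse_ratio": args.target_sparsity,
            "threshold_trials": DEFAULT_THRESHOLD_TRIALS,
            "samples": 1,
        }
    return cfg
```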

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a396d14c-ebef-49a3-aeb9-c3d398add26a

📥 Commits

Reviewing files that changed from the base of the PR and between 6cc96a4 and 4de0d3b.

📒 Files selected for processing (2)
  • examples/diffusers/README.md
  • examples/diffusers/sparsity/wan22_skip_softmax.py
✅ Files skipped from review due to trivial changes (1)
  • examples/diffusers/README.md

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from a team as a code owner April 7, 2026 19:55
@jingyu-ml jingyu-ml requested a review from kevalmorabia97 April 7, 2026 19:55
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
yeyu-nvidia added a commit that referenced this pull request Apr 8, 2026
Move NVFP4 P-matrix quantization (quantize_p) out of the sparsity module
and into a new modelopt/torch/quantization/sage_attention/ module.

Key changes:
- Add modelopt/torch/quantization/sage_attention/__init__.py with
  apply_sage_attention(transformer) API exposed via mtq namespace.
  Wraps the transformer forward to activate the modelopt_triton diffusers
  backend and set quantize_p=True in thread-local for every call.

- Remove quantize_p from SparseAttentionAttributeConfig (config.py),
  TritonSkipSoftmaxMethod, and TritonSparseSoftmaxMethod — sparsity
  methods no longer control quantization.

- Split thread-local management in diffusers_triton_attention.py:
  * set_triton_skip_softmax_config() no longer accepts quantize_p
  * clear_triton_skip_softmax_config() does NOT reset quantize_p
  * New set_sage_attention_config() / clear_sage_attention_config()
    manage quantize_p independently
  This enables transparent composition: apply_sage_attention() sets
  quantize_p=True at the outer forward level; per-layer sparsity
  contexts clear only their own params without clobbering quantize_p.

- Delete plugins/diffusers.py (WanSparseAttentionModule) — superseded
  by PR #1166's diffusers_triton_attention.py backend approach.

- Update wan2_sage_attention.py example: apply_triton_sparse_kernel()
  no longer accepts quantize_p; --quantize-p now calls
  apply_sage_attention() from modelopt.torch.quantization.

- Update tests to reflect the new API boundaries.

Usage:
    from modelopt.torch.quantization import apply_sage_attention
    apply_sage_attention(pipe.transformer)  # standalone

    # combined with N:M sparse softmax:
    mtsa.sparsify(transformer, mtsa.SPARSE_SOFTMAX_DEFAULT)
    apply_sage_attention(transformer)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Apr 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

jingyu-ml and others added 7 commits April 8, 2026 15:36
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from kevalmorabia97 April 9, 2026 04:36
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
import threading

import torch
from ltx_core.model.transformer.attention import Attention
Collaborator


We need to be careful about any usage of LTX, as it's under a different license. See the email thread for Model Optimizer Legal Approval. We'll need to add a clear notice about LTX in the README.

@kevalmorabia97
Collaborator

/ok to test 3845b47

