Conversation
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 Walkthrough

This PR adds skip-softmax sparse attention support for Diffusers and LTX-2 models, including new eager and Triton kernel backends, calibration refinements, example scripts, and comprehensive tests for framework-specific attention implementations.

Changes
Sequence Diagram(s)

sequenceDiagram
actor User
participant Pipeline as WAN Pipeline
participant Sparsify as mtsa.sparsify()
participant Transformer as Transformer<br/>Modules
participant CalibLoop as Calibration<br/>Forward Loop
participant Config as Sparse<br/>Config
User->>Pipeline: build_pipeline(model_path)
activate Pipeline
Pipeline-->>User: pipeline ready
deactivate Pipeline
User->>Config: build_sparse_config(args)
activate Config
Config-->>User: sparse config dict
deactivate Config
alt Calibration Mode
User->>CalibLoop: build_calibration_forward_loop()
activate CalibLoop
CalibLoop-->>User: forward_loop callable
deactivate CalibLoop
User->>Sparsify: sparsify(transformer, config,<br/>forward_loop=...)
activate Sparsify
Sparsify->>CalibLoop: invoke forward_loop<br/>(multiple prompts)
CalibLoop->>Transformer: collect attention stats
Transformer-->>CalibLoop: attention outputs
Sparsify->>Transformer: apply sparse config
Transformer-->>Sparsify: sparse attention modules
Sparsify-->>User: sparsified transformer
deactivate Sparsify
else No Calibration
User->>Sparsify: sparsify(transformer, config)
activate Sparsify
Sparsify->>Transformer: apply sparse config
Transformer-->>Sparsify: sparse attention modules
Sparsify-->>User: sparsified transformer
deactivate Sparsify
end
User->>Pipeline: generate(prompt, ...)
activate Pipeline
Pipeline->>Transformer: forward with sparse<br/>attention
Transformer-->>Pipeline: output frames
deactivate Pipeline
User->>User: print_sparsity_summary(model)
activate User
User->>Transformer: enumerate SparseAttentionModule
Transformer-->>User: module configs
deactivate User
sequenceDiagram
participant Model as Model<br/>(Diffusers/LTX)
participant ConversionFn as convert_to_sparse_<br/>attention_model()
participant RegFn as _register_diffusers_<br/>backends_if_needed()
participant DiffusersBackends as Diffusers<br/>Backends
participant LTXBackends as LTX<br/>Backends
participant Method as Sparse<br/>Method
Model->>ConversionFn: convert_to_sparse_attention_model(model, ...)
activate ConversionFn
ConversionFn->>RegFn: _register_diffusers_backends_if_needed(model)
activate RegFn
alt Is Diffusers ModelMixin
RegFn->>DiffusersBackends: register_diffusers_eager_attention()
RegFn->>DiffusersBackends: register_diffusers_triton_attention()
DiffusersBackends-->>RegFn: backends registered
end
alt Has LTX modules
RegFn->>LTXBackends: patch ltx attention<br/>modules
LTXBackends-->>RegFn: wrappers installed
end
RegFn-->>ConversionFn: registration complete
deactivate RegFn
ConversionFn->>ConversionFn: _set_attn_implementation()
ConversionFn->>Method: apply sparse config
Method-->>ConversionFn: sparse model
ConversionFn-->>Model: sparse model ready
deactivate ConversionFn
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches: 🧪 Generate unit tests (beta)
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

@@ Coverage Diff @@
## main #1166 +/- ##
==========================================
- Coverage 75.68% 74.97% -0.71%
==========================================
Files 353 355 +2
Lines 40491 40982 +491
==========================================
+ Hits 30644 30728 +84
- Misses 9847 10254 +407
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry.
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Force-pushed 8151232 to 5873652 (Compare)
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2)
306-316: ⚠️ Potential issue | 🟠 Major

Decode calibration fails when `forward_loop` is provided.

When a user supplies `forward_loop`, lines 264-265 skip building `tokenizer` and `calibration_data`. However, decode calibration (lines 312-316) unconditionally requires both, raising `RuntimeError` even though the user intended to use their own loop.

This creates an inconsistency: prefill calibration supports a user-provided `forward_loop`, but decode calibration does not. The docstring (line 227) also states `forward_loop` is "Only used for prefill", but this limitation should either be enforced earlier or decode should also accept a custom loop.

💡 Suggested approach
Either:
- Skip decode calibration when `forward_loop` is provided and `calibration_data` is `None`, with a warning
- Accept a separate `decode_forward_loop` parameter
- Document and enforce that decode calibration requires the RULER dataset
     # Run decode calibration if enabled
     if calibrate_decode:
    +    if calibration_data is None or tokenizer is None:
    +        warnings.warn(
    +            "Decode calibration requires RULER dataset. Skipping decode calibration "
    +            "because a custom forward_loop was provided without calibration_data."
    +        )
    +    else:
             print("\n" + "=" * 60)
             # ... rest of decode calibration

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around lines 306 - 316, The decode calibration block (calibrate_decode) currently raises RuntimeError if calibration_data or tokenizer are missing even when the user supplied forward_loop; change the logic in the calibrate_decode section to detect when forward_loop is provided and calibration_data is None and skip decode calibration with a warning instead of raising, i.e., only call create_decode_calibration_forward_loop when calibration_data and tokenizer exist (use create_decode_calibration_forward_loop(calibration_data, tokenizer, ...)), otherwise log/warn that decode calibration is skipped due to missing calibration_data/tokenizer while a custom forward_loop was supplied; update any related docstring or comment near the calibrate_decode and forward_loop mentions to reflect this behavior.
24-24: ⚠️ Potential issue | 🔴 Critical

Unconditional `transformers` import causes pipeline failure.

The module-level import of `transformers.AutoTokenizer` fails when `transformers` is not installed. This should be deferred to usage sites (inside `_load_tokenizer` or guarded) to allow the module to be imported when only diffusers-based workflows are used.

🐛 Proposed fix: defer import to usage site

    -from transformers import AutoTokenizer

Then update `_load_tokenizer`:

     def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer":
         """Load tokenizer and ensure pad_token is set."""
    +    from transformers import AutoTokenizer
    +
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` at line 24, The file currently imports transformers.AutoTokenizer at module scope which raises ImportError when transformers isn't installed; move the import into the tokenizer-loading code path so the module can be imported without transformers. Specifically, remove the top-level "from transformers import AutoTokenizer" and instead import AutoTokenizer inside the _load_tokenizer function (or guard the import with a try/except that raises a clear error when the function is called), ensuring _load_tokenizer handles the absence of transformers and only then attempts to create the tokenizer.
♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py (1)
30-36: ⚠️ Potential issue | 🟠 Major

Same top-level import issue as the eager backend.

Both `diffusers` and `modelopt.torch.kernels` are imported unconditionally at the top level. This will cause import failures for users who don't have diffusers installed or don't have CUDA/Triton available.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py` around lines 30 - 36, Top-level unconditional imports of diffusers symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend) and modelopt.torch.kernels.attention can fail for users without diffusers or CUDA/Triton; move these imports into the function or class that actually uses them (e.g., inside the registration function or the backend implementation in diffusers_triton_attention.py) and guard with try/except ImportError to raise a clear error only when the backend is instantiated; ensure you reference the same symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend, and attention) after relocating the imports so registration only occurs when dependencies are present.
🧹 Nitpick comments (4)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
124-125: Overly broad exception handling may hide bugs.

`except (ImportError, Exception)` catches all exceptions, including programming errors (e.g., `TypeError`, `AttributeError`). Consider narrowing to specific expected exceptions.

♻️ Suggested fix

    -    except (ImportError, Exception):
    +    except (ImportError, RuntimeError):
             pass

Or log unexpected exceptions for debugging:

    -    except (ImportError, Exception):
    -        pass
    +    except ImportError:
    +        pass
    +    except Exception as e:
    +        import logging
    +        logging.debug(f"Diffusers backend registration failed: {e}")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 124 - 125, The current broad except clause "except (ImportError, Exception)" in conversion.py swallows all errors and can hide bugs; change it to only catch ImportError (e.g., "except ImportError as e") for the import-failure path, and if you must catch other runtime issues around the same block, catch specific exceptions or log unexpected exceptions (use a logger.exception or re-raise after logging) so that programming errors like TypeError/AttributeError are not silently ignored; update the except block that follows the import attempt in conversion.py accordingly.

examples/diffusers/sparsity/ltx2_skip_softmax.py (2)
66-81: Hardcoded user-specific paths should be placeholders.

The default values contain user-specific paths (`/home/scratch.omniml_data_2/jingyux/...`) that won't exist on other systems. Consider using empty strings or raising a clear error when environment variables are not set.

Proposed fix: Require explicit configuration
    -CHECKPOINT_PATH = os.environ.get(
    -    "LTX2_CHECKPOINT",
    -    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev.safetensors",
    -)
    +CHECKPOINT_PATH = os.environ.get("LTX2_CHECKPOINT", "")
    +DISTILLED_LORA_PATH = os.environ.get("LTX2_DISTILLED_LORA", "")
    +SPATIAL_UPSAMPLER_PATH = os.environ.get("LTX2_SPATIAL_UPSAMPLER", "")
    +GEMMA_ROOT = os.environ.get("LTX2_GEMMA_ROOT", "")

Then in `build_pipeline()`:

    def build_pipeline() -> TI2VidTwoStagesPipeline:
        if not all([CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT]):
            raise ValueError(
                "Missing required environment variables. Set: "
                "LTX2_CHECKPOINT, LTX2_DISTILLED_LORA, LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT"
            )
        ...

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/ltx2_skip_softmax.py` around lines 66 - 81, The file defines CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, and GEMMA_ROOT with hardcoded user-specific default paths; remove those user paths and default to empty string or None when reading the env vars (os.environ.get(..., "") or None), and add a validation at the start of build_pipeline() that checks these constants (CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT) and raises a clear ValueError listing the required env names (LTX2_CHECKPOINT, LTX2_DISTILLED_LORA, LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT) if any are missing so callers must explicitly configure them.
260-267: `load_dataset` call may download data unexpectedly.

The `load_dataset("nkp37/OpenVid-1M")` call will download the dataset on first run, which could be surprising for users. Consider adding a note in the docstring or CLI help, or making this behavior opt-in.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/ltx2_skip_softmax.py` around lines 260 - 267, The load_calib_prompts function calls load_dataset("nkp37/OpenVid-1M") which may trigger a large download unexpectedly; update load_calib_prompts to make dataset download explicit or opt-in (e.g., add a parameter like download: bool = False or a dataset_path argument) and update the docstring to warn that calling this function will download the OpenVid-1M dataset unless an existing local dataset path is provided; ensure the function checks the opt-in flag or uses the provided path before invoking load_dataset to avoid surprising downloads.

tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py (1)
130-158: Test mock doesn't match actual kernel call signature.

The mock `mk.attention = lambda q, k, v, **kw: q` returns `q` directly, but per the context snippet, the actual kernel receives varlen metadata and returns output with shape `[B*S, H, D]`. The mock should return a tensor with the correct output shape to avoid masking reshape bugs in the code under test.

Proposed fix: Return correctly shaped tensor

    -    mk.attention = lambda q, k, v, **kw: q
    +    def mock_attention(q, k, v, **kw):
    +        # Return tensor with same shape as q (correct for varlen format)
    +        return torch.zeros_like(q)
    +    mk.attention = mock_attention

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around lines 130 - 158, The mock attention implementation in TestDiffusersTritonAttention._setup (mk.attention = lambda q, k, v, **kw: q) does not match the real kernel signature/behavior and returns the wrong shape, which hides reshape/masking bugs; update the mock in _setup (mk.attention) to accept the same args including varlen metadata (keep **kw) and return a tensor with shape [B*S, H, D] derived from the input q/k/v shapes (e.g., compute B, S, H, D from q and construct/return a tensor of that shape instead of returning q directly) so the code paths in _diffusers_triton_attention, set_triton_skip_softmax_config, clear_triton_skip_softmax_config, register_diffusers_triton_attention and get_triton_attention_backend are exercised with realistic output shapes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`:
- Around line 31-35: The top-level import of diffusers internals
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) should be
guarded to avoid ImportError for users without diffusers; wrap the import in a
try/except and set a module-level flag (e.g., _DIFFUSERS_AVAILABLE =
False/True). Update any registration or accessor functions that reference
AttentionBackendName, _AttentionBackendRegistry or attention_backend to check
_DIFFUSERS_AVAILABLE before using them and return/do nothing or raise a clear
runtime error when diffusers is unavailable. Ensure all places that previously
assumed the imports (registration functions) consult _DIFFUSERS_AVAILABLE so the
module can be imported without diffusers installed.
- Around line 120-132: The code unconditionally manipulates private diffusers
internals (AttentionBackendName, _AttentionBackendRegistry, etc.) that exist
only in diffusers >= 0.36.0; add a runtime/version guard before creating
new_member and registering _diffusers_eager_attention: check
diffusers.__version__ (or use the same utility used in
modelopt/torch/quantization/plugins/diffusion/diffusers.py) or probe for the
presence of attributes like AttentionBackendName._member_map_ and
_AttentionBackendRegistry._backends, and only perform the enum extension and
registry assignments when those APIs exist; also update pyproject.toml to
require diffusers>=0.36.0 or make the registration conditional so the code
no-ops on older diffusers versions.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 61-70: The _diffusers_triton_attention function currently accepts
attn_mask and enable_gqa but ignores them; update the function so it either
implements GQA and mask handling consistent with the eager backend or explicitly
fails fast: if enable_gqa is True or attn_mask is not None, raise a clear
NotImplementedError mentioning "_diffusers_triton_attention does not support
enable_gqa/attn_mask yet" (or implement the same GQA reshaping/aggregation logic
used in diffusers_eager_attention for query/key/value before calling the Triton
kernel) so callers won't silently get incorrect results.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 60-67: The _ltx_triton_attention function currently accepts a mask
parameter but never uses it; update the implementation to handle masks: either
pass the mask into the Triton kernel via the attn_mask argument when invoking
the kernel (ensure shapes/dtypes match and add logic to convert/expand the mask
to the kernel's expected form), or if kernel masking isn't supported yet,
explicitly reject masks by raising a clear error (e.g., raise
NotImplementedError("mask not supported by _ltx_triton_attention") when mask is
not None) so callers won't silently get incorrect results. Ensure the change is
applied inside _ltx_triton_attention and that any conversion/validation of mask
is performed before the kernel call.
- Line 29: The module unconditionally imports Attention from ltx_core which will
raise ImportError for users without LTX-2; change the top-level import to a
guarded import (try/except ImportError) or defer importing until registration,
set Attention = None on failure, and update register_ltx_triton_attention to
check if Attention is None and raise a clear ImportError like "ltx_core is
required for LTX-2 Triton attention" before proceeding; reference the symbols
Attention and register_ltx_triton_attention in ltx_triton_attention.py to locate
where to apply the guard.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 389-401: The code sets the thread-wide flag via
set_skip_softmax_context(True) immediately which can leak if an exception occurs
before the returned ExitStack is entered; instead create a small context manager
(e.g., using contextlib.contextmanager or a tiny class) that calls
set_skip_softmax_context(True) on __enter__/enter and
set_skip_softmax_context(False) on __exit__/exit, and then register that context
with stack.enter_context rather than calling set_skip_softmax_context and
stack.callback directly; update the function that builds the stack (the block
using ExitStack, get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax)) to enter the
new flag-context via stack.enter_context so the flag is only set when the stack
is actually entered and always cleaned up on exit.
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 56-72: The tests import
modelopt.torch.sparsity.attention_sparsity.kernels which transitively imports
transformers and breaks CI; update TestSkipSoftmaxContext to skip when optional
dependency missing by using pytest.importorskip('transformers') or catching
ImportError before importing get_skip_softmax_context/set_skip_softmax_context
(or call pytest.skip) so the test cleanly skips in environments without
transformers; ensure the changes are applied around the imports used in
TestSkipSoftmaxContext (references: get_skip_softmax_context,
set_skip_softmax_context, TestSkipSoftmaxContext).
- Around line 178-207: The test fails because patching targets under
"modelopt.torch.sparsity.attention_sparsity.kernels" occurs before that
submodule is loaded, causing a Module attribute error; to fix, ensure the module
is imported before patching or patch the symbols at the location they are looked
up by _register_diffusers_backends_if_needed: import
modelopt.torch.sparsity.attention_sparsity.conversion (or the parent package)
first, then patch the call targets register_diffusers_eager_attention and
register_diffusers_triton_attention as used by that conversion module (i.e.,
patch where _register_diffusers_backends_if_needed resolves them) so the
MagicMock replacement applies correctly.
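The patch-where-it-is-looked-up rule behind this fix can be demonstrated with two toy modules (module names are illustrative, not the real packages):

```python
import types
from unittest import mock

# `consumer` bound `helper` at import time, so a patch must target
# consumer.helper (the lookup site), not lib.helper (the definition site).
lib = types.ModuleType("lib")
lib.helper = lambda: "real"

consumer = types.ModuleType("consumer")
consumer.helper = lib.helper
consumer.run = lambda: consumer.helper()

with mock.patch.object(lib, "helper", return_value="wrong site"):
    assert consumer.run() == "real"    # patching the definition site has no effect

with mock.patch.object(consumer, "helper", return_value="mocked"):
    assert consumer.run() == "mocked"  # patching the lookup site works
```

The same reasoning applies to the test fix: patch `register_diffusers_eager_attention` / `register_diffusers_triton_attention` on the `conversion` module that calls them, after that module has been imported.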
---
Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 306-316: The decode calibration block (calibrate_decode) currently
raises RuntimeError if calibration_data or tokenizer are missing even when the
user supplied forward_loop; change the logic in the calibrate_decode section to
detect when forward_loop is provided and calibration_data is None and skip
decode calibration with a warning instead of raising, i.e., only call
create_decode_calibration_forward_loop when calibration_data and tokenizer exist
(use create_decode_calibration_forward_loop(calibration_data, tokenizer, ...)),
otherwise log/warn that decode calibration is skipped due to missing
calibration_data/tokenizer while a custom forward_loop was supplied; update any
related docstring or comment near the calibrate_decode and forward_loop mentions
to reflect this behavior.
- Line 24: The file currently imports transformers.AutoTokenizer at module scope
which raises ImportError when transformers isn't installed; move the import into
the tokenizer-loading code path so the module can be imported without
transformers. Specifically, remove the top-level "from transformers import
AutoTokenizer" and instead import AutoTokenizer inside the _load_tokenizer
function (or guard the import with a try/except that raises a clear error when
the function is called), ensuring _load_tokenizer handles the absence of
transformers and only then attempts to create the tokenizer.
---
Duplicate comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 30-36: Top-level unconditional imports of diffusers symbols
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) and
modelopt.torch.kernels.attention can fail for users without diffusers or
CUDA/Triton; move these imports into the function or class that actually uses
them (e.g., inside the registration function or the backend implementation in
diffusers_triton_attention.py) and guard with try/except ImportError to raise a
clear error only when the backend is instantiated; ensure you reference the same
symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend, and
attention) after relocating the imports so registration only occurs when
dependencies are present.
---
Nitpick comments:
In `@examples/diffusers/sparsity/ltx2_skip_softmax.py`:
- Around line 66-81: The file defines CHECKPOINT_PATH, DISTILLED_LORA_PATH,
SPATIAL_UPSAMPLER_PATH, and GEMMA_ROOT with hardcoded user-specific default
paths; remove those user paths and default to empty string or None when reading
the env vars (os.environ.get(..., "") or None), and add a validation at the
start of build_pipeline() that checks these constants (CHECKPOINT_PATH,
DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT) and raises a clear
ValueError listing the required env names (LTX2_CHECKPOINT, LTX2_DISTILLED_LORA,
LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT) if any are missing so callers must
explicitly configure them.
- Around line 260-267: The load_calib_prompts function calls
load_dataset("nkp37/OpenVid-1M") which may trigger a large download
unexpectedly; update load_calib_prompts to make dataset download explicit or
opt-in (e.g., add a parameter like download: bool = False or a dataset_path
argument) and update the docstring to warn that calling this function will
download the OpenVid-1M dataset unless an existing local dataset path is
provided; ensure the function checks the opt-in flag or uses the provided path
before invoking load_dataset to avoid surprising downloads.
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 124-125: The current broad except clause "except (ImportError,
Exception)" in conversion.py swallows all errors and can hide bugs; change it to
only catch ImportError (e.g., "except ImportError as e") for the import-failure
path, and if you must catch other runtime issues around the same block, catch
specific exceptions or log unexpected exceptions (use a logger.exception or
re-raise after logging) so that programming errors like TypeError/AttributeError
are not silently ignored; update the except block that follows the import
attempt in conversion.py accordingly.
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 130-158: The mock attention implementation in
TestDiffusersTritonAttention._setup (mk.attention = lambda q, k, v, **kw: q)
does not match the real kernel signature/behavior and returns the wrong shape,
which hides reshape/masking bugs; update the mock in _setup (mk.attention) to
accept the same args including varlen metadata (keep **kw) and return a tensor
with shape [B*S, H, D] derived from the input q/k/v shapes (e.g., compute B, S,
H, D from q and construct/return a tensor of that shape instead of returning q
directly) so the code paths in _diffusers_triton_attention,
set_triton_skip_softmax_config, clear_triton_skip_softmax_config,
register_diffusers_triton_attention and get_triton_attention_backend are
exercised with realistic output shapes.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 578e7893-06c2-4586-93e1-4726205a2f84
📒 Files selected for processing (14)
- examples/diffusers/sparsity/ltx2_skip_softmax.py
- examples/diffusers/sparsity/wan22_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
- modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
Resolved review threads:
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py (2 threads, outdated)
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py (1 thread)
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py (2 threads)
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1 thread)
- tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py (1 thread, outdated)
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
308-317: ⚠️ Potential issue | 🟡 Minor

Improve error message and consider early validation when a custom `forward_loop` conflicts with decode calibration.

When a user provides a custom `forward_loop` (e.g., for diffusion models) but also configures `calibrate_decode=True`, the current RuntimeError message "calibration_data and tokenizer must be built before decode" is confusing: it implies a bug rather than an unsupported configuration.

Consider either:
- Early validation (preferred): check at the start of calibration if `forward_loop is not None and calibrate_decode` and raise with a clear message, or
- Improve the error message to explicitly state the limitation.
Option 1: Add early validation near line 246

     # Skip if both phases are disabled
     if not calibrate_prefill and not calibrate_decode:
         print("Both prefill and decode target sparsity are 0.0, skipping calibration")
         return {}

    +# Decode calibration requires RULER dataset, which is incompatible with custom forward_loop
    +if forward_loop is not None and calibrate_decode:
    +    raise ValueError(
    +        "Decode calibration is not supported when a custom forward_loop is provided. "
    +        "Either set decode target_sparse_ratio to 0.0 or remove the forward_loop argument "
    +        "to use auto-generated RULER dataset calibration."
    +    )
    +
     # Get sparse attention modules

Option 2: Improve error message at lines 313-314

     if calibration_data is None or tokenizer is None:
    -    raise RuntimeError("calibration_data and tokenizer must be built before decode")
    +    raise RuntimeError(
    +        "Decode calibration requires tokenizer and RULER dataset, which are not available "
    +        "when a custom forward_loop is provided. Set decode target_sparse_ratio to 0.0 "
    +        "or remove the forward_loop argument to enable decode calibration."
    +    )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around lines 308 - 317, Add an early validation that explicitly rejects using a custom forward_loop with decode calibration: if forward_loop is not None and calibrate_decode is True, raise a clear RuntimeError stating that decode calibration is incompatible with a custom forward_loop and instruct the user to disable calibrate_decode or remove the custom forward_loop; place this check near the start of the calibration flow (before create_decode_calibration_forward_loop is called). Alternatively, if you prefer the minimal change, improve the existing RuntimeError in the block that checks calibration_data and tokenizer to a message that either explains missing calibration_data/tokenizer or that decode calibration is unsupported when a custom forward_loop is provided (reference forward_loop, calibrate_decode, calibration_data, tokenizer, and create_decode_calibration_forward_loop).
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
34-41: Consider restoring the return type annotation.

The lazy import is a good practice per project guidelines. However, removing the return type annotation reduces type safety. Consider adding it back using a string literal or a `TYPE_CHECKING` import to maintain mypy compatibility:

Suggested fix
```diff
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
-def _load_tokenizer(tokenizer_name_or_path: str):
+def _load_tokenizer(tokenizer_name_or_path: str) -> "PreTrainedTokenizerBase":
     """Load tokenizer and ensure pad_token is set."""
     from transformers import AutoTokenizer
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around lines 34 - 41, Restore the return type annotation for _load_tokenizer to preserve type safety; update the signature to annotate the return as "PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase under TYPE_CHECKING from transformers and use it as the return type, keeping the lazy import of AutoTokenizer inside the function and leaving runtime behavior unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in the annotation for mypy compatibility while avoiding importing transformers at module import time.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 405-443: Guard against empty per-phase payloads by selecting the
first non-empty calibration dict instead of blindly using
next(iter(calibration_params.values())). Replace sample_params =
next(iter(calibration_params.values())) with logic that finds the first
non-empty value: sample_params = next((v for v in calibration_params.values() if
v), None); set is_percentile only when sample_params is not None, and if
sample_params is None skip building threshold_config / threshold_scale_factor
and the per-phase loops (so export_config stays without threshold entries),
referencing the existing names calibration_params, sample_params, is_percentile,
export_config, threshold_config and threshold_scale_factor.
- Around line 125-139: The code currently suppresses all exceptions around
diffusers backend registration (the block checking isinstance(model, ModelMixin)
and the uses of contextlib.suppress(Exception)), which can hide real
registration errors; change the try/except to only catch ImportError when
importing ModelMixin and only suppress ImportError when calling
register_diffusers_eager_attention and register_diffusers_triton_attention so
that any other Exception raised by those registration functions bubbles up (or
is re-raised) instead of being swallowed; update the other places using
contextlib.suppress(Exception) to suppress ImportError only and ensure failures
in register_diffusers_* or in _set_attn_implementation are not silenced so the
error surfaces during conversion.
---
Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 308-317: Add an early validation that explicitly rejects using a
custom forward_loop with decode calibration: if forward_loop is not None and
calibrate_decode is True, raise a clear RuntimeError stating that decode
calibration is incompatible with a custom forward_loop and instruct the user to
disable calibrate_decode or remove the custom forward_loop; place this check
near the start of the calibration flow (before
create_decode_calibration_forward_loop is called). Alternatively, if you prefer
the minimal change, improve the existing RuntimeError in the block that checks
calibration_data and tokenizer to a message that either explains missing
calibration_data/tokenizer or that decode calibration is unsupported when a
custom forward_loop is provided (reference forward_loop, calibrate_decode,
calibration_data, tokenizer, and create_decode_calibration_forward_loop).
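The early guard described in the prompt above can be sketched as a small standalone check (the function name and boolean flags here are stand-ins for the actual calibration flow):

```python
def check_calibration_setup(calibrate_prefill: bool, calibrate_decode: bool, forward_loop=None) -> bool:
    """Return False when there is nothing to calibrate; reject invalid combos early.

    Mirrors the guard suggested in the review: decode calibration depends on the
    auto-generated RULER dataset, which is bypassed by a custom forward_loop.
    """
    if not calibrate_prefill and not calibrate_decode:
        print("Both prefill and decode target sparsity are 0.0, skipping calibration")
        return False
    if forward_loop is not None and calibrate_decode:
        raise ValueError(
            "Decode calibration is not supported when a custom forward_loop is provided. "
            "Either set decode target_sparse_ratio to 0.0 or remove the forward_loop "
            "argument to use auto-generated RULER dataset calibration."
        )
    return True
```

Raising before any calibration data is built gives the user an actionable error instead of a late failure deep inside the decode path.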
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 34-41: Restore the return type annotation for _load_tokenizer to
preserve type safety; update the signature to annotate the return as
"PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase
under TYPE_CHECKING from transformers and use it as the return type, keeping the
lazy import of AutoTokenizer inside the function and leaving runtime behavior
unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in
the annotation for mypy compatibility while avoiding importing transformers at
module import time.
📒 Files selected for processing (2)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
125-154: ⚠️ Potential issue | 🟠 Major: Don’t suppress all backend registration errors.
Line 138 and Lines 150-154 currently swallow arbitrary exceptions. That can hide real backend registration bugs and leave conversion proceeding with a partially configured runtime.
Proposed fix
```diff
-    except (ImportError, Exception):
+    except ImportError:
         pass
@@
-    except (ImportError, RuntimeError):
+    except ImportError:
         return
@@
-    if register_ltx_eager_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_eager_attention(model)
-    if register_ltx_triton_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_triton_attention(model)
+    if register_ltx_eager_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_eager_attention(model)
+    if register_ltx_triton_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_triton_attention(model)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 125 - 154, The code currently suppresses all exceptions when importing/calling backend registration functions (the try/except around ModelMixin and the contextlib.suppress usage for register_ltx_eager_attention/register_ltx_triton_attention), which hides real errors; change these to only catch expected import/runtime errors and surface or log unexpected exceptions: when importing from .kernels and checking isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when invoking register_diffusers_eager_attention, register_diffusers_triton_attention, register_ltx_eager_attention, and register_ltx_triton_attention, wrap each call in a try/except that logs the full exception (including stack trace) via the module logger and re-raises or returns a clear error instead of silently swallowing it so backend-registration failures are visible and conversion does not proceed silently with a misconfigured runtime.
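The fix pattern (suppress only the expected ImportError, let real registration bugs propagate) can be sketched independently of the module; the registrar functions below are stand-ins, not the real backend hooks:

```python
import contextlib

def register_backends(model, registrars):
    """Register optional attention backends on a model.

    Only ImportError (an uninstalled optional dependency) is tolerated;
    any other exception from a registrar surfaces to the caller.
    """
    for register in registrars:
        if register is not None:
            with contextlib.suppress(ImportError):
                register(model)

def needs_missing_dep(model):
    # Hypothetical optional dependency: raises ImportError when called,
    # which register_backends silently skips.
    import some_uninstalled_backend_pkg  # noqa: F401
```

A registrar that raises anything other than ImportError (for example a RuntimeError from a real bug) is not swallowed, so misconfiguration is visible during conversion.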
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 125-154: The code currently suppresses all exceptions when
importing/calling backend registration functions (the try/except around
ModelMixin and the contextlib.suppress usage for
register_ltx_eager_attention/register_ltx_triton_attention), which hides real
errors; change these to only catch expected import/runtime errors and surface or
log unexpected exceptions: when importing from .kernels and checking
isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when
invoking register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention, and
register_ltx_triton_attention, wrap each call in a try/except that logs the full
exception (including stack trace) via the module logger and re-raises or returns
a clear error instead of silently swallowing it so backend-registration failures
are visible and conversion does not proceed silently with a misconfigured
runtime.
📒 Files selected for processing (1)
modelopt/torch/sparsity/attention_sparsity/conversion.py
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 198-206: The calibration loop in build_calibration_forward_loop
currently hardcodes guidance_scale=5.0 so calibration activations ignore the CLI
--guidance-scale; add a guidance_scale parameter to
build_calibration_forward_loop (and the other calibration-related functions
mentioned) and pass that guidance_scale through to any calls that currently use
guidance_scale=5.0 (look for explicit guidance_scale=5.0 in the function body
and replace with the new guidance_scale parameter), ensuring the same parameter
is threaded into the calibration runs so activation collection honors the
user-specified guidance scale.
- Around line 177-183: The calibration block is being added at the top-level
sparse_cfg (making "calibration" a selector) instead of being nested under the
self-attention selector; update the code so that when args.calibrate is true you
insert the calibration dict (using args.target_sparsity,
DEFAULT_THRESHOLD_TRIALS and samples:1) into the "*.attn1*" entry of sparse_cfg
(i.e., sparse_cfg["*.attn1*"]["calibration"] = {...}) so the attn1 selector
receives the calibration config rather than creating a new selector at the top
level.
- Around line 77-135: The CLI allows invalid values that later break
calibration; update parse_args() to validate arguments after
parser.parse_args(): ensure args.num_frames and args.calib_frames satisfy (value
- 1) % 4 == 0 and are > 1 (i.e., 4k+1), ensure args.target_sparsity is between
0.0 and 1.0 inclusive, and if a check fails call parser.error(...) with a clear
message referencing the offending flag (e.g., --num-frames, --calib-frames,
--target-sparsity) so users get immediate, actionable feedback; implement these
checks in parse_args() using the parsed args variables.
- Around line 192-193: The code currently loads the entire "caption" column then
slices it; change the dataset load to request only the needed rows by using the
HuggingFace split range: replace load_dataset("nkp37/OpenVid-1M", split="train")
with load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]") and then
build prompts from the returned small dataset (e.g., prompts =
list(dataset["caption"])). This avoids materializing the full caption column for
the tiny calibration sample.
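The CLI validation asked for above can be sketched with argparse's `parser.error` (flag names follow the comment; the defaults here are illustrative, not the example script's actual defaults):

```python
import argparse

def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-frames", type=int, default=121)
    parser.add_argument("--calib-frames", type=int, default=21)
    parser.add_argument("--target-sparsity", type=float, default=0.5)
    args = parser.parse_args(argv)

    # Frame counts must be of the form 4k + 1 and greater than 1.
    for flag, value in (("--num-frames", args.num_frames), ("--calib-frames", args.calib_frames)):
        if value <= 1 or (value - 1) % 4 != 0:
            parser.error(f"{flag} must be of the form 4k+1 and > 1, got {value}")
    if not 0.0 <= args.target_sparsity <= 1.0:
        parser.error(f"--target-sparsity must be in [0.0, 1.0], got {args.target_sparsity}")
    return args
```

`parser.error` prints the usage string plus the message and exits, so the user gets immediate feedback naming the offending flag instead of a failure mid-calibration.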
📒 Files selected for processing (2)
- examples/diffusers/README.md
- examples/diffusers/sparsity/wan22_skip_softmax.py
✅ Files skipped from review due to trivial changes (1)
- examples/diffusers/README.md
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Move NVFP4 P-matrix quantization (quantize_p) out of the sparsity module
and into a new modelopt/torch/quantization/sage_attention/ module.
Key changes:
- Add modelopt/torch/quantization/sage_attention/__init__.py with
apply_sage_attention(transformer) API exposed via mtq namespace.
Wraps the transformer forward to activate the modelopt_triton diffusers
backend and set quantize_p=True in thread-local for every call.
- Remove quantize_p from SparseAttentionAttributeConfig (config.py),
TritonSkipSoftmaxMethod, and TritonSparseSoftmaxMethod — sparsity
methods no longer control quantization.
- Split thread-local management in diffusers_triton_attention.py:
* set_triton_skip_softmax_config() no longer accepts quantize_p
* clear_triton_skip_softmax_config() does NOT reset quantize_p
* New set_sage_attention_config() / clear_sage_attention_config()
manage quantize_p independently
This enables transparent composition: apply_sage_attention() sets
quantize_p=True at the outer forward level; per-layer sparsity
contexts clear only their own params without clobbering quantize_p.
- Delete plugins/diffusers.py (WanSparseAttentionModule) — superseded
by PR #1166's diffusers_triton_attention.py backend approach.
- Update wan2_sage_attention.py example: apply_triton_sparse_kernel()
no longer accepts quantize_p; --quantize-p now calls
apply_sage_attention() from modelopt.torch.quantization.
- Update tests to reflect the new API boundaries.
Usage:
from modelopt.torch.quantization import apply_sage_attention
apply_sage_attention(pipe.transformer) # standalone
# combined with N:M sparse softmax:
mtsa.sparsify(transformer, mtsa.SPARSE_SOFTMAX_DEFAULT)
apply_sage_attention(transformer)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
```python
import threading

import torch
from ltx_core.model.transformer.attention import Attention
```
We need to be careful about any usages of LTX, as it's under a different license. See the email thread for Model Optimizer Legal Approval. We'll need to add a clear notice in the README about LTX.
/ok to test 3845b47
What does this PR do?
Type of change: new feature, new example
Summary
- `triton_skip_softmax` method with exponential model calibration (`scale_factor = a * exp(b * sparsity)`) and log-space fitting for diffusion models
- Calibration via a custom `forward_loop` (required for non-LLM models)

Changes

Triton kernels (`modelopt/torch/kernels/triton_fa.py`)

- `_attn_fwd`: Forward kernel with optional tile skipping — tiles whose max attention score is far below the running softmax max are skipped entirely (no V load, no softmax, no accumulation). Runtime sparsity measurement via atomic counters.
- `_attn_fwd_calibrate`: Calibration kernel that computes full attention while measuring how many tiles would be skipped at each of N thresholds simultaneously. Uses per-program output buffers (zero atomic contention) and vectorized multi-threshold comparison.
- `attention()` / `attention_calibrate()`: Python wrappers for inference and calibration kernels.

Kernel backends (`modelopt/torch/sparsity/attention_sparsity/kernels/`)

- `diffusers_triton_attention.py`: Registers the `modelopt_triton` backend in diffusers' attention dispatch. Handles [B, S, H, D] → varlen layout conversion, calibration/inference mode switching, thread-local configuration, and counter accumulation.
- `ltx_triton_attention.py`: Patches `ltx_core.Attention` modules for Triton dispatch with the same calibration/inference modes.

Method (`modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`)

- `TritonSkipSoftmaxMethod`: Context managers for calibration (→ calibration kernel) and inference (→ forward kernel with tile skipping). Three threshold priority levels: raw threshold > calibrated scale_factor > static threshold.

Calibration (`modelopt/torch/sparsity/attention_sparsity/calibration/`)

- `calibrator.py`: `DynamicThresholdCalibrator` with `fit_logspace` option — fits exponential model in log space (minimizes relative error) for diffusion models where scale_factors span many orders of magnitude. Records observed sparsity range for extrapolation warnings.
- `calibrate.py`: Skips RULER dataset when `forward_loop` is provided; passes `fit_logspace` through from config.

Config & conversion

- `config.py`: `CalibrationConfig.fit_logspace` field (default False, recommended True for diffusion models). `skip_softmax_raw_threshold` field for direct threshold mode.
- `conversion.py`: Auto-registers diffusers/LTX Triton backends on `sparsify()`. Updated summary display.

Example

- `wan22_skip_softmax.py`: End-to-end example for WAN 2.2 5B/14B with baseline, raw-threshold, and calibrated modes. Supports runtime sparsity reporting.

Threshold modes

- Raw (`--raw-threshold -0.7`): used directly as `skip_threshold_log2`
- Calibrated (`--calibrate --target-sparsity 0.5`): `scale_factor = a * exp(b * target)`, then `threshold = scale_factor / seq_k` at runtime
- Static (`skip_softmax_threshold=0.1`): converted to `log2(lambda) * sm_scale`

Usage
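The calibrated mode fits `scale_factor = a * exp(b * sparsity)`, which in log space is an ordinary linear regression. A minimal sketch of the fit and the runtime threshold rule (illustrative helpers, not the `DynamicThresholdCalibrator` API):

```python
import math

def fit_logspace(sparsities, scale_factors):
    """Fit scale_factor = a * exp(b * sparsity) via least squares on ln(scale_factor).

    Fitting in log space minimizes *relative* error, which matters when
    scale_factors span many orders of magnitude.
    """
    ys = [math.log(s) for s in scale_factors]
    n = len(sparsities)
    mx = sum(sparsities) / n
    my = sum(ys) / n
    b = sum((x - mx) * (y - my) for x, y in zip(sparsities, ys)) / sum(
        (x - mx) ** 2 for x in sparsities
    )
    a = math.exp(my - b * mx)
    return a, b

def runtime_threshold(a, b, target_sparsity, seq_k):
    """Per the PR description: threshold = scale_factor / seq_k at runtime."""
    return a * math.exp(b * target_sparsity) / seq_k
```

Dividing by `seq_k` at runtime lets one calibrated `(a, b)` pair adapt the skip threshold to whatever key sequence length the current generation uses.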
Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅
Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests