Skip to content

Commit 4248327

Browse files
committed
Address PR review comments
- CLIP scoring: use CPU to avoid OOM with WAN pipeline on GPU; catch RuntimeError in addition to OSError
- --skip-threshold help: fix description to match actual exp(tile_max - running_max) < lambda criterion
- vLLM worker: reject unsupported sparse presets (non-triton backend or unknown method) with a clear ValueError instead of silently degrading to dense attention
- PYTHONPATH construction: use os.pathsep and skip empty entries to avoid CWD injection when PYTHONPATH is unset
- diffusers_triton backend: raise ValueError when mixed with other backends instead of silently skipping _attn_implementation setup
- _wan_forward_triton: fall back to SDPA when attention_mask is not None to preserve masking semantics

Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 8996ef1 commit 4248327

5 files changed

Lines changed: 37 additions & 7 deletions

File tree

examples/diffusers/quantization/wan2_sage_attention.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -794,8 +794,9 @@ def parse_args() -> argparse.Namespace:
794794
help=(
795795
"Override skip_softmax_threshold for triton-skip / triton-skip-nvfp4 kernels. "
796796
f"Default: {_TRITON_SKIP_DEFAULT_THRESHOLD}. "
797-
"A tile is skipped when its max attention score is less than LAMBDA times the "
798-
"running maximum (BLASST criterion). Lower = better quality, less speedup. "
797+
"A tile is skipped when exp(tile_max - running_max) < LAMBDA "
798+
"(equivalently: tile_max < running_max + log(LAMBDA)). "
799+
"Lower = better quality, less speedup. "
799800
"Typical sweep: 0.1 (aggressive), 0.01 (moderate), 0.001 (conservative)."
800801
),
801802
)
@@ -830,14 +831,14 @@ def main() -> None:
830831
disable_attention_kernel()
831832

832833
# --- CLIP scores (per-video semantic alignment with prompt) ---
833-
device = "cuda" if torch.cuda.is_available() else "cpu"
834+
# Use CPU to avoid OOM: the WAN pipeline already occupies GPU memory.
834835
print("\nComputing CLIP scores (prompt-video semantic alignment)...")
835836
try:
836837
clip_base = compute_clip_score(
837-
frames_base, args.prompt, clip_model_id=args.clip_model, device=device
838+
frames_base, args.prompt, clip_model_id=args.clip_model, device="cpu"
838839
)
839840
clip_quant = compute_clip_score(
840-
frames_quant, args.prompt, clip_model_id=args.clip_model, device=device
841+
frames_quant, args.prompt, clip_model_id=args.clip_model, device="cpu"
841842
)
842843
print(f" baseline CLIP: {clip_base:.4f}")
843844
print(f" {args.kernel} CLIP: {clip_quant:.4f} (delta {clip_quant - clip_base:+.4f})")
@@ -847,7 +848,7 @@ def main() -> None:
847848
print(
848849
" Tip: set HF_TOKEN env var or use --clip-model <local-path> to avoid rate limits"
849850
)
850-
except OSError as e:
851+
except (OSError, RuntimeError) as e:
851852
print(f" WARNING: CLIP scoring failed ({e})")
852853
print(" To fix: set HF_TOKEN env var or pass --clip-model <local-path-to-clip>")
853854

examples/vllm_serve/sparse_attn_worker.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,16 @@ def _replace_attention_impl(worker, config: dict):
143143
if layer_cfg is None or not layer_cfg.get("enable", True):
144144
continue
145145

146+
method = layer_cfg.get("method", "triton_sparse_softmax")
147+
backend = layer_cfg.get("backend", "triton")
148+
if backend != "triton" or method not in {"triton_sparse_softmax", "triton_skip_softmax"}:
149+
raise ValueError(
150+
f"{name}: unsupported sparse config for vLLM worker "
151+
f"(backend={backend!r}, method={method!r}). "
152+
"Only backend='triton' with method='triton_sparse_softmax' or "
153+
"'triton_skip_softmax' is supported."
154+
)
155+
146156
# Build per-layer sparse kwargs
147157
sparse_kw = {}
148158
sparsity_n = layer_cfg.get("sparsity_n", 0)

examples/vllm_serve/vllm_serve_sparse_attn.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,9 @@ def main():
7272
repo_root = str(Path(__file__).resolve().parent)
7373
if repo_root not in sys.path:
7474
sys.path.insert(0, repo_root)
75-
os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + f"{repo_root}"
75+
existing = os.environ.get("PYTHONPATH")
76+
parts = [p for p in [existing, repo_root] if p]
77+
os.environ["PYTHONPATH"] = os.pathsep.join(parts)
7678

7779
# Select worker based on env vars
7880
has_quant = os.environ.get("QUANT_CFG") or os.environ.get("KV_QUANT_CFG")

modelopt/torch/sparsity/attention_sparsity/conversion.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,11 @@ def _set_attn_implementation(model: nn.Module, config: SparseAttentionConfig) ->
6363
# diffusers_triton: the ModelOptWanAttnProcessor calls triton_fa directly.
6464
# No HF attention-function registration or _attn_implementation patching needed.
6565
if "diffusers_triton" in backends:
66+
if len(backends) > 1:
67+
raise ValueError(
68+
"Mixed backends including 'diffusers_triton' are not supported. "
69+
"All sparse attention layers must use the same backend."
70+
)
6671
return
6772

6873
model_config = getattr(model, "config", None)

modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -208,6 +208,18 @@ def _wan_forward_triton(
208208
rotary_emb: tuple[torch.Tensor, torch.Tensor] | None,
209209
) -> torch.Tensor:
210210
"""Triton-backed WAN attention (self-attention and I2V cross-attention)."""
211+
if attention_mask is not None:
212+
from diffusers.models.attention_dispatch import dispatch_attention_fn
213+
214+
return self._wan_forward_sdpa(
215+
attn,
216+
hidden_states,
217+
encoder_hidden_states,
218+
attention_mask,
219+
rotary_emb,
220+
dispatch_fn=dispatch_attention_fn,
221+
)
222+
211223
encoder_hidden_states_img = None
212224
if attn.add_k_proj is not None:
213225
# 512 is the text-encoder context length (WAN hardcoded constant)

0 commit comments

Comments (0)