Copilot/integrate spargeattn optimization again (#491)
# Changelog

All notable changes to the ComfyUI-SeedVR2.5 project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/), and this project adheres to [Semantic Versioning](https://semver.org/).

## [Unreleased]

### Added
- **SpargeAttn/Sage2 Block-Sparse Attention Integration**
  - New attention mode `sparge_sage2` - Block-sparse attention optimized for NVIDIA Blackwell (RTX 50xx) GPUs
  - `spas_sage2_attn_meansim_topk_cuda` and `block_sparse_sage2_attn_cuda` entry points, the latter with strict mask geometry (128x64 blocks)
  - Fallback chain: `sparge_sage2` → `sageattn_3` → `sageattn_2` → `sdpa`
- **Local SpargeAttn Module** (`src/optimization/spas_sage_attn/`)
  - `core.py` - Main API functions (`spas_sage2_attn_meansim_topk_cuda`, `block_sparse_sage2_attn_cuda`)
  - `utils.py` - Utility functions for block map computation
  - `quant_per_block.py` - INT8 quantization kernels
  - `autotune.py` - Triton autotuning utilities
- **Blackwell (RTX 50xx) Specific Optimizations**
  - `Sage2BlackwellConfig` class with Blackwell-tuned parameters
  - `get_blackwell_config()` function for architecture-specific kernel tuning
- **Verification & Benchmarking Scripts**
  - `scripts/sage2_verification.py` - Numerical parity verification against SDPA baseline
  - `scripts/sage2_benchmark.py` - Comprehensive performance benchmarking
- **Compatibility Layer Enhancements** (in `src/optimization/compatibility.py`)
  - `call_sparge_sage2_attn()` - Direct Sage2 attention call
  - `call_block_sparse_sage2_attn()` - Block-sparse with custom masks
  - `call_sparge_sage2_varlen()` - Variable-length sequence support
  - `Sage2BlackwellConfig.validate_mask_geometry()`

### Changed
- **Dependencies**
  - `torch>=2.3.0` - Minimum PyTorch version for CUDA 12.x compatibility
  - `ninja>=1.11` - Required for SpargeAttn Triton kernel compilation
- **Attention Backends**
  - Extended the `FlashAttentionVarlen` class (both dit_3b and dit_7b) to support the `sparge_sage2` mode

## Technical Details
### Sage2 API Usage
The Sage2 architecture provides two primary APIs:
**Plug-and-Play API** (recommended for most use cases):
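A minimal sketch of the plug-and-play path, falling back to PyTorch SDPA when the vendored module (or a CUDA device) is unavailable. The import path follows the vendored module location named in this changelog; the exact keyword arguments of the Sage2 call are assumptions.

```python
# Hedged sketch: plug-and-play Sage2 call with a dense SDPA fallback.
# The spas_sage2_attn_meansim_topk_cuda signature shown here is assumed.
import torch
import torch.nn.functional as F

def sage2_or_sdpa(q, k, v, topk=0.5):
    """q, k, v: (batch, heads, seq_len, head_dim) tensors."""
    try:
        from src.optimization.spas_sage_attn.core import (
            spas_sage2_attn_meansim_topk_cuda,
        )
        if q.is_cuda:
            # Mean-similarity top-k block selection on the GPU path.
            return spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=topk)
    except ImportError:
        pass  # vendored module not on sys.path
    # Dense fallback, numerically equivalent up to quantization error.
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 256, 64)
out = sage2_or_sdpa(q, k, v)
```

The output shape matches the query shape regardless of which backend ran, which is what makes the API drop-in.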
**Block-Sparse API** (for custom sparsity patterns):
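A sketch of driving the block-sparse API with a custom mask whose geometry matches the required `(batch, heads, ceil(seq/128), ceil(seq/64))` layout. Only the mask geometry is taken from this changelog; the `block_sparse_sage2_attn_cuda` call signature is an assumption.

```python
# Hedged sketch: building a custom block-sparse mask for the Sage2 API.
import math
import torch

batch, heads, seq_len, head_dim = 1, 8, 1024, 64
q = k = v = torch.randn(batch, heads, seq_len, head_dim)

# 128x64 block granularity: rows tile the queries, columns tile the keys.
mask = torch.ones(
    batch, heads, math.ceil(seq_len / 128), math.ceil(seq_len / 64),
    dtype=torch.bool,
)

if torch.cuda.is_available():
    try:
        from src.optimization.spas_sage_attn.core import (
            block_sparse_sage2_attn_cuda,
        )
        out = block_sparse_sage2_attn_cuda(q.cuda(), k.cuda(), v.cuda(),
                                           mask.cuda())
    except ImportError:
        pass  # vendored module not on sys.path; use the dense path instead
```

An all-`True` mask reproduces dense attention; zeroing entries skips the corresponding 128x64 blocks entirely.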
### Blackwell-Specific Tuning

- Kernel parameters: `num_warps=8`, `num_stages=4`, `block_m=128`, `block_n=64`
- Sparsity presets:
  - `TOPK_FAST = 0.3` - Maximum speed, some accuracy tradeoff
  - `TOPK_BALANCED = 0.5` - Default, balanced speed/accuracy
  - `TOPK_QUALITY = 0.7` - Higher quality, less speedup

### Block-Sparse Mask Geometry

The block-sparse API requires masks with a specific geometry:

- Shape: `(batch_size, num_heads, ceil(seq_len/128), ceil(seq_len/64))`
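The geometry rule above can be captured in a small helper. The preset names and values come from this changelog; the `expected_mask_shape` helper itself is illustrative.

```python
# Hedged sketch: the sparsity presets listed above, plus a helper computing
# the mask shape the block-sparse API expects.
import math

TOPK_FAST = 0.3      # maximum speed, some accuracy tradeoff
TOPK_BALANCED = 0.5  # default, balanced speed/accuracy
TOPK_QUALITY = 0.7   # higher quality, less speedup

def expected_mask_shape(batch_size, num_heads, seq_len,
                        block_m=128, block_n=64):
    """One mask row per 128-query block, one column per 64-key block."""
    return (batch_size, num_heads,
            math.ceil(seq_len / block_m), math.ceil(seq_len / block_n))

# e.g. a 1000-token sequence needs an 8 x 16 block map per head:
shape = expected_mask_shape(1, 8, 1000)
```

Note the ceiling division: sequence lengths that are not multiples of 128/64 still need a full trailing block row/column.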
## Installation

### Prerequisites
### Local Integration (Recommended - No Build Required)

The SpargeAttn implementation is now vendored locally in `src/optimization/spas_sage_attn/`. No separate installation is needed; the Triton kernels are compiled JIT on first use.
### Global Installation (Optional - For Full CUDA Kernel Support)
For maximum performance with precompiled CUDA kernels (if local JIT has issues):
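An environment-setup sketch for building the upstream SpargeAttn package from source; the repository URL and build flags are assumptions to verify against the upstream README for your CUDA/PyTorch combination.

```shell
# Hedged sketch: build SpargeAttn's precompiled CUDA kernels from source.
# Repository URL and flags are assumptions; check the upstream README.
git clone https://github.com/thu-ml/SpargeAttn.git
cd SpargeAttn
# ninja>=1.11 is required for kernel compilation (see Dependencies above)
pip install ninja
pip install -e . --no-build-isolation
```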
## Verification
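A sketch of the kind of parity check `scripts/sage2_verification.py` performs: compare a candidate attention implementation against the SDPA baseline. Here the candidate is a hand-rolled dense softmax attention so the check runs on CPU; with SpargeAttn available, substitute the `sparge_sage2` call as the candidate.

```python
# Hedged sketch: numerical parity check against the SDPA baseline.
import torch
import torch.nn.functional as F

def max_abs_error(q, k, v):
    """Max elementwise deviation of a candidate backend from SDPA."""
    baseline = F.scaled_dot_product_attention(q, k, v)
    # Candidate: explicit dense attention (stand-in for the Sage2 call).
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    candidate = attn @ v
    return (candidate - baseline).abs().max().item()

q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))
err = max_abs_error(q, k, v)
```

For a quantized block-sparse backend the tolerance must be looser than for this dense-vs-dense comparison; the verification script's thresholds should be taken as the reference.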
## Performance Notes

### Expected Performance (Blackwell GPUs)

Expected speedups follow from the Sage2 architecture characteristics; concrete figures depend on sequence length and the chosen `topk` preset.
### Fallback Behavior
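The documented fallback chain (`sparge_sage2` → `sageattn_3` → `sageattn_2` → `sdpa`) can be sketched as a first-available resolver. The availability probes here are illustrative; the real compatibility layer checks installed packages and GPU architecture.

```python
# Hedged sketch of the fallback chain; sdpa is always available as the
# terminal fallback.
FALLBACK_CHAIN = ["sparge_sage2", "sageattn_3", "sageattn_2", "sdpa"]

def resolve_backend(available):
    """Return the first usable backend from the chain."""
    for backend in FALLBACK_CHAIN:
        if backend == "sdpa" or available.get(backend, False):
            return backend

# e.g. SpargeAttn and SageAttention 3 missing, SageAttention 2 installed:
backend = resolve_backend({"sageattn_2": True})
```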
### Known Limitations
## Migration Guide

### Enabling SpargeAttn/Sage2
To use the new attention mode, set `attention_mode='sparge_sage2'` in your pipeline configuration:
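A minimal sketch, assuming the mode is passed as a plain configuration entry; the exact node/pipeline configuration surface is not shown in this changelog.

```python
# Hedged sketch: enabling the new attention mode via configuration.
config = {
    "attention_mode": "sparge_sage2",  # falls back to sageattn_3/2, then sdpa
}
```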
### Adjusting Sparsity

For custom sparsity levels, pass the `topk` parameter through kwargs:
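A sketch of the kwargs plumbing; the preset values come from this changelog, while the surrounding kwargs dict is assumed.

```python
# Hedged sketch: selecting a sparsity preset via the topk kwarg.
pipeline_kwargs = {
    "attention_mode": "sparge_sage2",
    "topk": 0.3,  # TOPK_FAST: maximum speed, some accuracy tradeoff
}
```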