
Commit 26ae8da

[2/3] Implicit Gemm NVFP4 (#1227)
### What does this PR do?

**Type of change:** new feature

- Add Conv3D implicit GEMM kernel with BF16 WMMA tensor cores and fused NVFP4 activation quantization for video diffusion VAE layers
- Integrate into `_QuantConv3d` via `QuantModuleRegistry` — automatically dispatched when NVFP4 quantization is applied to `nn.Conv3d`
- Move kernel from `experimental/conv/` to `modelopt/torch/kernels/conv/`; move tests to `tests/gpu/torch/quantization/kernels/`

### Testing

- Added test cases to measure the difference between cuDNN and our CUDA implicit GEMM kernel
- Added an NVFP4 fake quantization test using CUDA code

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: ✅
- Did you write any new necessary tests?: ✅
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅

## Summary by CodeRabbit

* **New Features**
  * Per-backbone quantization/export in a single run with per-backbone checkpoints and backbone-aware quant filters
  * Configurable NVFP4 block size via CLI/config; improved NVFP4 Conv3D inference path and Wan 2.2 quantization support
* **Bug Fixes**
  * Video-model calibration now respects extra params and forces video decoding during calibration
* **Documentation**
  * Added comprehensive Conv3D implicit-GEMM kernel documentation; removed experimental Conv3D prototype docs/benchmark
* **Tests**
  * New Wan 2.2 quantization/export tests and expanded Conv3D/FP4 kernel test coverage

---------

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent c20f9c4 commit 26ae8da

23 files changed

Lines changed: 1458 additions & 522 deletions
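The feature described above plugs into the standard ModelOpt PTQ flow. A minimal usage sketch follows; the toy module, shapes, and calibration loop are illustrative, while `mtq.quantize` and `NVFP4_DEFAULT_CFG` are the stock ModelOpt PTQ entry points, and the implicit GEMM dispatch is the behavior this PR claims for `nn.Conv3d` under NVFP4:

```python
import torch
import modelopt.torch.quantization as mtq

# Toy VAE-like block built from 3D convolutions (stand-in for the Wan 2.2 VAE).
model = torch.nn.Sequential(
    torch.nn.Conv3d(16, 32, kernel_size=3, padding=1),
    torch.nn.SiLU(),
    torch.nn.Conv3d(32, 16, kernel_size=3, padding=1),
).cuda().to(torch.bfloat16)

def forward_loop(m):
    # Calibration over a few representative video latents.
    for _ in range(8):
        m(torch.randn(1, 16, 4, 32, 32, device="cuda", dtype=torch.bfloat16))

# PTQ with NVFP4: nn.Conv3d layers become _QuantConv3d via the QuantModuleRegistry;
# at inference time they should route through the implicit GEMM CUDA kernel
# instead of cuDNN (grouped conv and training mode fall back to cuDNN).
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 16, 4, 32, 32, device="cuda", dtype=torch.bfloat16))
```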


.github/codecov.yml

Lines changed: 0 additions & 12 deletions
```diff
@@ -11,15 +11,3 @@ coverage:
         target: auto
         threshold: 1% # Allow atmost 1% coverage drop from main branch.
     patch: false
-
-# Exclude GPU-only Triton kernel files from ALL codecov calculations (project
-# and patch checks, all flags). Rationale: these files are dominated by
-# @triton.jit kernel bodies that CPU unit tests cannot exercise. GPU tests
-# cover them end-to-end (see tests/gpu/torch/sparsity/attention_sparsity/) but
-# the `gpu`-flag upload may race with the PR status check, so relying on flag
-# combination alone leaves the project check flaky. Dropping these files here
-# makes the check deterministic — local `pytest --cov` and GPU runs still
-# measure them; only the codecov PR status ignores them.
-ignore:
-  - "modelopt/torch/kernels/triton_fa.py"
-  - "modelopt/torch/kernels/hf_triton_attention.py"
```

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@ Changelog
 - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
 - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
+- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
 
 **Backward Breaking Changes**
 
```
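The changelog entry above describes a dispatch rule rather than a public API. Below is a rough sketch of that decision with a hypothetical helper name; the actual check lives inside `_QuantConv3d`:

```python
import warnings
import torch

def use_implicit_gemm_path(conv: torch.nn.Conv3d) -> bool:
    """Mirror the documented dispatch rule for the NVFP4 implicit GEMM path."""
    if conv.groups > 1:          # grouped convolution stays on cuDNN
        return False
    if conv.training:            # inference only: fall back to cuDNN with a warning
        warnings.warn("NVFP4 implicit GEMM is inference-only; falling back to cuDNN.")
        return False
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 8            # BF16 WMMA tensor cores require SM80+

print(use_implicit_gemm_path(torch.nn.Conv3d(8, 8, 3).eval()))
print(use_implicit_gemm_path(torch.nn.Conv3d(8, 8, 3, groups=8).eval()))  # always False
```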

examples/diffusers/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,20 @@ python quantize.py \
117117
--hf-ckpt-dir ./hf_ckpt
118118
```
119119

120+
#### Wan 2.2 VAE NVFP4 (Conv3D Implicit GEMM)
121+
122+
The Wan 2.2 VAE (`AutoencoderKLWan`, shared between the 5B and 14B pipelines) is built from 3D convolutions. When quantizing the VAE with NVFP4, the `Conv3d` layers are automatically dispatched through a custom BF16 WMMA implicit-GEMM kernel with fused FP4 activation quantization. Requires SM80+ (Ampere or newer). See [`modelopt/torch/quantization/src/conv/README.md`](../../modelopt/torch/quantization/src/conv/README.md) for kernel details.
123+
124+
```sh
125+
python quantize.py \
126+
--model {wan2.2-t2v-14b|wan2.2-t2v-5b} \
127+
--backbone vae \
128+
--format fp4 --quant-algo max --collect-method default \
129+
--model-dtype BFloat16 --trt-high-precision-dtype BFloat16 \
130+
--batch-size 1 --calib-size 32 --n-steps 30 \
131+
--quantized-torch-ckpt-save-path ./wan22_vae_fp4.pt
132+
```
133+
120134
#### [LTX-2](https://github.com/Lightricks/LTX-2) FP4
121135

122136
> [!WARNING]
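For context on the "fused FP4 activation quantization" mentioned in the README addition, here is a simplified sketch of NVFP4-style (E2M1) fake quantization for a single 16-element block. It ignores the FP8 (E4M3) rounding of the per-block scale and the global FP32 scale that the real fused kernel applies:

```python
import torch

# Non-negative magnitudes representable in FP4 E2M1; signs are handled separately.
E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quant_block(x: torch.Tensor) -> torch.Tensor:
    # Scale so the block max maps to the largest E2M1 magnitude (6.0),
    # snap each value to the nearest E2M1 code, then rescale back.
    scale = x.abs().max() / 6.0
    if scale == 0:
        return x
    mags = (x / scale).abs().unsqueeze(-1)
    snapped = E2M1[(mags - E2M1).abs().argmin(dim=-1)]
    return snapped * x.sign() * scale

x = torch.randn(16)          # one 16-element block, NVFP4's block size
print(fake_quant_block(x))
```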

examples/diffusers/quantization/calibration.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -108,11 +108,12 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None:
     def _run_wan_video_calibration(
         self, prompt_batch: list[str], extra_args: dict[str, Any]
     ) -> None:
+        extra_params = self.pipeline_manager.config.extra_params
         kwargs = {}
         kwargs["negative_prompt"] = extra_args["negative_prompt"]
-        kwargs["height"] = extra_args["height"]
-        kwargs["width"] = extra_args["width"]
-        kwargs["num_frames"] = extra_args["num_frames"]
+        kwargs["height"] = extra_params.get("height", extra_args["height"])
+        kwargs["width"] = extra_params.get("width", extra_args["width"])
+        kwargs["num_frames"] = extra_params.get("num_frames", extra_args["num_frames"])
         kwargs["guidance_scale"] = extra_args["guidance_scale"]
         if "guidance_scale_2" in extra_args:
             kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"]
@@ -154,7 +155,11 @@ def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, A
             "images": extra_params.get("images", []),
             "tiling_config": extra_params.get("tiling_config", TilingConfig.default()),
         }
-        self.pipe(prompt=prompt, **kwargs)
+        decoded_video, decoded_audio = self.pipe(prompt=prompt, **kwargs)
+        # vae_decode_video returns a lazy generator — consume it so the
+        # video decoder's forward() actually runs during calibration.
+        for _ in decoded_video:
+            pass
 
     def _run_ltx_video_calibration(
         self, prompt_batch: list[str], extra_args: dict[str, Any]
```
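The `for _ in decoded_video: pass` addition matters because generator-based decoding defers the forward pass until iteration, so calibration observers would otherwise see nothing. A toy illustration with stand-in names (not the LTX-2 API):

```python
calls = {"n": 0}

def decode_frames(latents):
    # Lazy generator: nothing inside runs until the caller iterates over it.
    for latent in latents:
        calls["n"] += 1      # stands in for the video decoder's forward() pass
        yield latent

frames = decode_frames([1, 2, 3])
print(calls["n"])            # 0 -> no forward passes yet, nothing to calibrate

for _ in frames:             # consuming the generator triggers the deferred work
    pass
print(calls["n"])            # 3 -> the decoder would now have been observed
```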

examples/diffusers/quantization/models_utils.py

Lines changed: 23 additions & 22 deletions
```diff
@@ -33,7 +33,9 @@
 from utils import (
     filter_func_default,
     filter_func_flux_dev,
+    filter_func_ltx2_vae,
     filter_func_ltx_video,
+    filter_func_wan_vae,
     filter_func_wan_video,
 )
 
@@ -54,31 +56,30 @@ class ModelType(str, Enum):
     WAN22_T2V_5b = "wan2.2-t2v-5b"
 
 
-def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
-    """
-    Get the appropriate filter function for a given model type.
+_FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = {
+    ModelType.FLUX_DEV: filter_func_flux_dev,
+    ModelType.FLUX2_DEV: filter_func_flux_dev,
+    ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
+    ModelType.LTX2: filter_func_ltx_video,
+    ModelType.WAN22_T2V_14b: filter_func_wan_video,
+    ModelType.WAN22_T2V_5b: filter_func_wan_video,
+}
 
-    Args:
-        model_type: The model type enum
+_VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = {
+    (ModelType.LTX2, "video_decoder"): filter_func_ltx2_vae,
+    (ModelType.WAN22_T2V_14b, "vae"): filter_func_wan_vae,
+    (ModelType.WAN22_T2V_5b, "vae"): filter_func_wan_vae,
+}
 
-    Returns:
-        A filter function appropriate for the model type
-    """
-    filter_func_map = {
-        ModelType.FLUX_DEV: filter_func_flux_dev,
-        ModelType.FLUX_SCHNELL: filter_func_default,
-        ModelType.FLUX2_DEV: filter_func_flux_dev,
-        ModelType.SDXL_BASE: filter_func_default,
-        ModelType.SDXL_TURBO: filter_func_default,
-        ModelType.SD3_MEDIUM: filter_func_default,
-        ModelType.SD35_MEDIUM: filter_func_default,
-        ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
-        ModelType.LTX2: filter_func_ltx_video,
-        ModelType.WAN22_T2V_14b: filter_func_wan_video,
-        ModelType.WAN22_T2V_5b: filter_func_wan_video,
-    }
 
-    return filter_func_map.get(model_type, filter_func_default)
+def get_model_filter_func(
+    model_type: ModelType, backbone_name: str = "transformer"
+) -> Callable[[str], bool]:
+    """Get the appropriate filter function for a given model type and backbone."""
+    vae_func = _VAE_FILTER_FUNC_MAP.get((model_type, backbone_name))
+    if vae_func is not None:
+        return vae_func
+    return _FILTER_FUNC_MAP.get(model_type, filter_func_default)
 
 
 # Model registry with HuggingFace model IDs
```
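A quick illustration of how the backbone-aware lookup above resolves, assuming the enum members and filter functions shown in this file (comments indicate the expected result):

```python
from models_utils import ModelType, get_model_filter_func

# Transformer backbones resolve through the per-model map; VAE/video-decoder
# backbones get their dedicated filters; anything else falls back to default.
get_model_filter_func(ModelType.WAN22_T2V_14b)                         # filter_func_wan_video
get_model_filter_func(ModelType.WAN22_T2V_14b, backbone_name="vae")    # filter_func_wan_vae
get_model_filter_func(ModelType.LTX2, backbone_name="video_decoder")   # filter_func_ltx2_vae
get_model_filter_func(ModelType.FLUX_SCHNELL)                          # filter_func_default (fallback)
```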

examples/diffusers/quantization/pipeline_manager.py

Lines changed: 41 additions & 44 deletions
```diff
@@ -42,6 +42,7 @@ def __init__(self, config: ModelConfig, logger: logging.Logger):
         self.pipe: Any | None = None
         self.pipe_upsample: LTXLatentUpsamplePipeline | None = None  # For LTX-Video upsampling
         self._transformer: torch.nn.Module | None = None
+        self._video_decoder: torch.nn.Module | None = None
 
     @staticmethod
     def create_pipeline_from(
@@ -58,23 +59,20 @@ def create_pipeline_from(
         Raises:
             ValueError: If model type is unsupported
         """
-        try:
-            pipeline_cls = MODEL_PIPELINE[model_type]
-            if pipeline_cls is None:
-                raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
-            model_id = (
-                MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
-            )
-            pipe = pipeline_cls.from_pretrained(
-                model_id,
-                torch_dtype=torch_dtype,
-                use_safetensors=True,
-                **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}),
-            )
-            pipe.set_progress_bar_config(disable=True)
-            return pipe
-        except Exception as e:
-            raise e
+        pipeline_cls = MODEL_PIPELINE[model_type]
+        if pipeline_cls is None:
+            raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
+        model_id = (
+            MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
+        )
+        pipe = pipeline_cls.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            use_safetensors=True,
+            **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}),
+        )
+        pipe.set_progress_bar_config(disable=True)
+        return pipe
 
     def create_pipeline(self) -> Any:
         """
@@ -157,42 +155,32 @@ def setup_device(self) -> None:
             self.logger.info("Enabling VAE tiling for LTX-Video")
             self.pipe.vae.enable_tiling()
 
-    def get_backbone(self) -> torch.nn.Module:
-        """
-        Get the backbone model (transformer or UNet).
-
-        Returns:
-            Backbone model module
-        """
-        if not self.pipe:
-            raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
-
-        backbone_pairs = list(self.iter_backbones())
-        if len(backbone_pairs) == 1:
-            return backbone_pairs[0][1]
-        return torch.nn.ModuleList([module for _, module in backbone_pairs])
-
     def iter_backbones(self) -> Iterator[tuple[str, torch.nn.Module]]:
         """
-        Yield backbone modules by name, based on a backbone spec.
-
-        Yields:
-            (backbone_name, module) pairs
+        Yield (backbone_name, module) pairs.
         """
         if not self.pipe:
             raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
 
         names = list(self.config.backbone)
+        if not names:
+            raise RuntimeError("No backbone names provided.")
 
         if self.config.model_type == ModelType.LTX2:
-            self._ensure_ltx2_transformer_cached()
-            name = names[0] if names else "transformer"
-            yield name, self._transformer
+            for name in names:
+                if name == "video_decoder":
+                    self._ensure_ltx2_video_decoder_cached()
+                    yield name, self._video_decoder
+                elif name == "transformer":
+                    self._ensure_ltx2_transformer_cached()
+                    yield name, self._transformer
+                else:
+                    raise ValueError(
+                        f"Unsupported LTX-2 backbone name '{name}'. "
+                        "Expected 'transformer' or 'video_decoder'."
+                    )
             return
 
-        if not names:
-            raise RuntimeError("No backbone names provided.")
-
         for name in names:
             module = getattr(self.pipe, name, None)
             if module is None:
@@ -207,6 +195,16 @@ def _ensure_ltx2_transformer_cached(self) -> None:
             self.pipe.stage_1_model_ledger.transformer = lambda: transformer
             self._transformer = transformer
 
+    def _ensure_ltx2_video_decoder_cached(self) -> None:
+        if not self.pipe:
+            raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
+        if self._video_decoder is None:
+            video_decoder = self.pipe.stage_1_model_ledger.video_decoder()
+            # Cache it so subsequent calls return the same (quantized) instance
+            self.pipe.stage_1_model_ledger.video_decoder = lambda: video_decoder
+            self.pipe.stage_2_model_ledger.video_decoder = lambda: video_decoder
+            self._video_decoder = video_decoder
+
     def _create_ltx2_pipeline(self) -> Any:
         params = dict(self.config.extra_params)
         checkpoint_path = params.pop("checkpoint_path", None)
@@ -261,7 +259,6 @@ def _create_ltx2_pipeline(self) -> Any:
         return TI2VidTwoStagesPipeline(**pipeline_kwargs)
 
     def print_quant_summary(self):
-        backbone_pairs = list(self.iter_backbones())
-        for name, backbone in backbone_pairs:
+        for name, backbone in self.iter_backbones():
             self.logger.info(f"{name} quantization info:")
             mtq.print_quant_summary(backbone)
```
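The `_ensure_ltx2_video_decoder_cached` change relies on replacing the ledger's factory with a lambda that returns a cached instance, so both pipeline stages see the same quantized decoder. A generic illustration of that pattern (the `Ledger` class here is a stand-in, not the LTX-2 ledger):

```python
class Ledger:
    def __init__(self):
        # Default factory builds a fresh decoder on every call.
        self.video_decoder = lambda: object()

stage_1, stage_2 = Ledger(), Ledger()

decoder = stage_1.video_decoder()           # build once (this is what gets quantized)
stage_1.video_decoder = lambda: decoder     # both stages now return the same
stage_2.video_decoder = lambda: decoder     # already-quantized instance
assert stage_1.video_decoder() is stage_2.video_decoder() is decoder
```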
