Skip to content

Commit 5b1a358

Browse files
committed
feat: NVFP4 Conv3D implicit GEMM kernel with end-to-end integration
- Move CUDA implicit GEMM kernel from experimental/ to modelopt/torch/quantization/src/conv/
- Extend QuantConv to dispatch into the implicit GEMM kernel for NVFP4
- Add diffusers plugin hook for Wan Conv3D
- Add unit, GPU kernel, and example integration tests
- Update examples/diffusers/quantization/ for the E2E flow

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent feec81a commit 5b1a358

21 files changed

Lines changed: 1469 additions & 510 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Changelog
1616
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
1717
- [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
1818
- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
19+
- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
1920

2021
**Backward Breaking Changes**
2122

examples/diffusers/quantization/calibration.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,12 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None:
108108
def _run_wan_video_calibration(
109109
self, prompt_batch: list[str], extra_args: dict[str, Any]
110110
) -> None:
111+
extra_params = self.pipeline_manager.config.extra_params
111112
kwargs = {}
112113
kwargs["negative_prompt"] = extra_args["negative_prompt"]
113-
kwargs["height"] = extra_args["height"]
114-
kwargs["width"] = extra_args["width"]
115-
kwargs["num_frames"] = extra_args["num_frames"]
114+
kwargs["height"] = extra_params.get("height", extra_args["height"])
115+
kwargs["width"] = extra_params.get("width", extra_args["width"])
116+
kwargs["num_frames"] = extra_params.get("num_frames", extra_args["num_frames"])
116117
kwargs["guidance_scale"] = extra_args["guidance_scale"]
117118
if "guidance_scale_2" in extra_args:
118119
kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"]
@@ -154,7 +155,11 @@ def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, A
154155
"images": extra_params.get("images", []),
155156
"tiling_config": extra_params.get("tiling_config", TilingConfig.default()),
156157
}
157-
self.pipe(prompt=prompt, **kwargs)
158+
decoded_video, decoded_audio = self.pipe(prompt=prompt, **kwargs)
159+
# vae_decode_video returns a lazy generator — consume it so the
160+
# video decoder's forward() actually runs during calibration.
161+
for _ in decoded_video:
162+
pass
158163

159164
def _run_ltx_video_calibration(
160165
self, prompt_batch: list[str], extra_args: dict[str, Any]

examples/diffusers/quantization/models_utils.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
from utils import (
3434
filter_func_default,
3535
filter_func_flux_dev,
36+
filter_func_ltx2_vae,
3637
filter_func_ltx_video,
38+
filter_func_wan_vae,
3739
filter_func_wan_video,
3840
)
3941

@@ -54,31 +56,30 @@ class ModelType(str, Enum):
5456
WAN22_T2V_5b = "wan2.2-t2v-5b"
5557

5658

57-
def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
58-
"""
59-
Get the appropriate filter function for a given model type.
59+
_FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = {
60+
ModelType.FLUX_DEV: filter_func_flux_dev,
61+
ModelType.FLUX2_DEV: filter_func_flux_dev,
62+
ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
63+
ModelType.LTX2: filter_func_ltx_video,
64+
ModelType.WAN22_T2V_14b: filter_func_wan_video,
65+
ModelType.WAN22_T2V_5b: filter_func_wan_video,
66+
}
6067

61-
Args:
62-
model_type: The model type enum
68+
_VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = {
69+
(ModelType.LTX2, "video_decoder"): filter_func_ltx2_vae,
70+
(ModelType.WAN22_T2V_14b, "vae"): filter_func_wan_vae,
71+
(ModelType.WAN22_T2V_5b, "vae"): filter_func_wan_vae,
72+
}
6373

64-
Returns:
65-
A filter function appropriate for the model type
66-
"""
67-
filter_func_map = {
68-
ModelType.FLUX_DEV: filter_func_flux_dev,
69-
ModelType.FLUX_SCHNELL: filter_func_default,
70-
ModelType.FLUX2_DEV: filter_func_flux_dev,
71-
ModelType.SDXL_BASE: filter_func_default,
72-
ModelType.SDXL_TURBO: filter_func_default,
73-
ModelType.SD3_MEDIUM: filter_func_default,
74-
ModelType.SD35_MEDIUM: filter_func_default,
75-
ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
76-
ModelType.LTX2: filter_func_ltx_video,
77-
ModelType.WAN22_T2V_14b: filter_func_wan_video,
78-
ModelType.WAN22_T2V_5b: filter_func_wan_video,
79-
}
8074

81-
return filter_func_map.get(model_type, filter_func_default)
75+
def get_model_filter_func(
76+
model_type: ModelType, backbone_name: str = "transformer"
77+
) -> Callable[[str], bool]:
78+
"""Get the appropriate filter function for a given model type and backbone."""
79+
vae_func = _VAE_FILTER_FUNC_MAP.get((model_type, backbone_name))
80+
if vae_func is not None:
81+
return vae_func
82+
return _FILTER_FUNC_MAP.get(model_type, filter_func_default)
8283

8384

8485
# Model registry with HuggingFace model IDs

examples/diffusers/quantization/pipeline_manager.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, config: ModelConfig, logger: logging.Logger):
4242
self.pipe: Any | None = None
4343
self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling
4444
self._transformer: torch.nn.Module | None = None
45+
self._video_decoder: torch.nn.Module | None = None
4546

4647
@staticmethod
4748
def create_pipeline_from(
@@ -58,23 +59,20 @@ def create_pipeline_from(
5859
Raises:
5960
ValueError: If model type is unsupported
6061
"""
61-
try:
62-
pipeline_cls = MODEL_PIPELINE[model_type]
63-
if pipeline_cls is None:
64-
raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
65-
model_id = (
66-
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
67-
)
68-
pipe = pipeline_cls.from_pretrained(
69-
model_id,
70-
torch_dtype=torch_dtype,
71-
use_safetensors=True,
72-
**MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}),
73-
)
74-
pipe.set_progress_bar_config(disable=True)
75-
return pipe
76-
except Exception as e:
77-
raise e
62+
pipeline_cls = MODEL_PIPELINE[model_type]
63+
if pipeline_cls is None:
64+
raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.")
65+
model_id = (
66+
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
67+
)
68+
pipe = pipeline_cls.from_pretrained(
69+
model_id,
70+
torch_dtype=torch_dtype,
71+
use_safetensors=True,
72+
**MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}),
73+
)
74+
pipe.set_progress_bar_config(disable=True)
75+
return pipe
7876

7977
def create_pipeline(self) -> Any:
8078
"""
@@ -157,42 +155,32 @@ def setup_device(self) -> None:
157155
self.logger.info("Enabling VAE tiling for LTX-Video")
158156
self.pipe.vae.enable_tiling()
159157

160-
def get_backbone(self) -> torch.nn.Module:
161-
"""
162-
Get the backbone model (transformer or UNet).
163-
164-
Returns:
165-
Backbone model module
166-
"""
167-
if not self.pipe:
168-
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
169-
170-
backbone_pairs = list(self.iter_backbones())
171-
if len(backbone_pairs) == 1:
172-
return backbone_pairs[0][1]
173-
return torch.nn.ModuleList([module for _, module in backbone_pairs])
174-
175158
def iter_backbones(self) -> Iterator[tuple[str, torch.nn.Module]]:
176159
"""
177-
Yield backbone modules by name, based on a backbone spec.
178-
179-
Yields:
180-
(backbone_name, module) pairs
160+
Yield (backbone_name, module) pairs.
181161
"""
182162
if not self.pipe:
183163
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
184164

185165
names = list(self.config.backbone)
166+
if not names:
167+
raise RuntimeError("No backbone names provided.")
186168

187169
if self.config.model_type == ModelType.LTX2:
188-
self._ensure_ltx2_transformer_cached()
189-
name = names[0] if names else "transformer"
190-
yield name, self._transformer
170+
for name in names:
171+
if name == "video_decoder":
172+
self._ensure_ltx2_video_decoder_cached()
173+
yield name, self._video_decoder
174+
elif name == "transformer":
175+
self._ensure_ltx2_transformer_cached()
176+
yield name, self._transformer
177+
else:
178+
raise ValueError(
179+
f"Unsupported LTX-2 backbone name '{name}'. "
180+
"Expected 'transformer' or 'video_decoder'."
181+
)
191182
return
192183

193-
if not names:
194-
raise RuntimeError("No backbone names provided.")
195-
196184
for name in names:
197185
module = getattr(self.pipe, name, None)
198186
if module is None:
@@ -207,6 +195,16 @@ def _ensure_ltx2_transformer_cached(self) -> None:
207195
self.pipe.stage_1_model_ledger.transformer = lambda: transformer
208196
self._transformer = transformer
209197

198+
def _ensure_ltx2_video_decoder_cached(self) -> None:
199+
if not self.pipe:
200+
raise RuntimeError("Pipeline not created. Call create_pipeline() first.")
201+
if self._video_decoder is None:
202+
video_decoder = self.pipe.stage_1_model_ledger.video_decoder()
203+
# Cache it so subsequent calls return the same (quantized) instance
204+
self.pipe.stage_1_model_ledger.video_decoder = lambda: video_decoder
205+
self.pipe.stage_2_model_ledger.video_decoder = lambda: video_decoder
206+
self._video_decoder = video_decoder
207+
210208
def _create_ltx2_pipeline(self) -> Any:
211209
params = dict(self.config.extra_params)
212210
checkpoint_path = params.pop("checkpoint_path", None)
@@ -261,7 +259,6 @@ def _create_ltx2_pipeline(self) -> Any:
261259
return TI2VidTwoStagesPipeline(**pipeline_kwargs)
262260

263261
def print_quant_summary(self):
264-
backbone_pairs = list(self.iter_backbones())
265-
for name, backbone in backbone_pairs:
262+
for name, backbone in self.iter_backbones():
266263
self.logger.info(f"{name} quantization info:")
267264
mtq.print_quant_summary(backbone)

0 commit comments

Comments (0)