### What does this PR do?
Type of change: new feature <!-- Use one of the following: Bug fix, new
feature, new example, new tests, documentation. -->
- Add Conv3D implicit GEMM kernel with BF16 WMMA tensor cores and fused
NVFP4 activation quantization for video diffusion VAE layers
- Integrate into `_QuantConv3d` via `QuantModuleRegistry`; the kernel is dispatched automatically when NVFP4 quantization is applied to `nn.Conv3d` (usage sketch below)
- Move kernel from `experimental/conv/` to `modelopt/torch/kernels/conv/`; move tests to `tests/gpu/torch/quantization/kernels/`
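For context, a minimal usage sketch of the path this PR enables. It assumes the standard ModelOpt PTQ entry point (`mtq.quantize` with `mtq.NVFP4_DEFAULT_CFG`); the toy model, shapes, and calibration loop are illustrative, not taken from this PR:

```python
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

# Toy stand-in for a VAE block built from 3D convolutions (illustrative only).
model = nn.Sequential(nn.Conv3d(16, 32, kernel_size=3, padding=1)).cuda().eval()

def forward_loop(m):
    # Calibration pass over representative (N, C, D, H, W) video activations.
    m(torch.randn(2, 16, 8, 32, 32, device="cuda"))

# After PTQ, nn.Conv3d is wrapped as _QuantConv3d via QuantModuleRegistry;
# inference on SM80+ then routes through the implicit GEMM kernel instead of cuDNN.
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)

with torch.no_grad():
    out = model(torch.randn(1, 16, 8, 32, 32, device="cuda"))
```

In eval mode on SM80+ the wrapped layer takes the implicit GEMM path; grouped convolutions and training mode keep the cuDNN fallback.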
### Testing
<!-- Mention how have you tested your change if applicable. -->
- Added test cases that measure the numerical difference between cuDNN and the CUDA implicit GEMM kernel (a sketch of this kind of check follows this list)
- Added an NVFP4 fake quantization test that exercises the CUDA code path
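One plausible shape for the cuDNN-vs-kernel comparison, sketched for readers. It leans on the training-mode cuDNN fallback described in the changelog and an arbitrary tolerance; it is not the test code added by this PR:

```python
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

# Quantize a small Conv3d model with NVFP4 so the implicit GEMM path is active.
model = nn.Sequential(nn.Conv3d(8, 8, kernel_size=3, padding=1)).cuda()
mtq.quantize(
    model,
    mtq.NVFP4_DEFAULT_CFG,
    lambda m: m(torch.randn(1, 8, 4, 16, 16, device="cuda")),
)

x = torch.randn(1, 8, 4, 16, 16, device="cuda")
with torch.no_grad():
    model.eval()            # eval mode: implicit GEMM kernel
    y_gemm = model(x)
    model.train()           # per the changelog, training mode falls back to cuDNN
    y_cudnn = model(x)

# Loose relative-error bound; the fused in-kernel FP4 quantization may differ
# slightly from the Python-side fake quant used on the cuDNN path.
rel_err = (y_gemm - y_cudnn).norm() / y_cudnn.norm()
assert rel_err < 1e-2, f"implicit GEMM deviates from cuDNN: {rel_err:.3e}"
```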
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅ <!--- If ❌, explain why. -->
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ <!---
Mandatory -->
- Did you write any new necessary tests?: ✅ <!--- Mandatory for new
features or examples. -->
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅ <!--- Only for new features, API changes, critical bug fixes or
backward incompatible changes. -->
### Additional Information
<!-- E.g. related issue. -->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Per-backbone quantization/export in a single run with per-backbone
checkpoints and backbone-aware quant filters
* Configurable NVFP4 block-size via CLI/config; improved NVFP4 Conv3D
inference path and Wan 2.2 quantization support
* **Bug Fixes**
* Video-model calibration now respects extra params and forces video
decoding during calibration
* **Documentation**
* Added comprehensive Conv3D implicit‑GEMM kernel documentation; removed
experimental Conv3D prototype docs/benchmark
* **Tests**
* New Wan 2.2 quantization/export tests and expanded Conv3D/FP4 kernel
test coverage
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
**CHANGELOG.rst** (1 addition, 0 deletions)
```diff
@@ -16,6 +16,7 @@ Changelog
 - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
 - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
 - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
+- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
```
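For readers skimming the entry above, a sketch of the dispatch rule it describes; the helper name and exact predicate are hypothetical, and the real check lives inside `_QuantConv3d`:

```python
import torch

def _should_use_implicit_gemm(module: torch.nn.Conv3d, x: torch.Tensor) -> bool:
    """Hypothetical helper mirroring the fallback rules in the changelog entry."""
    return (
        not module.training                                      # inference only
        and module.groups == 1                                   # grouped conv -> cuDNN
        and x.is_cuda
        and torch.cuda.get_device_capability(x.device)[0] >= 8   # BF16 WMMA needs SM80+
    )
```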
**examples/diffusers/README.md** (14 additions, 0 deletions)
````diff
@@ -117,6 +117,20 @@ python quantize.py \
     --hf-ckpt-dir ./hf_ckpt
 ```

+#### Wan 2.2 VAE NVFP4 (Conv3D Implicit GEMM)
+
+The Wan 2.2 VAE (`AutoencoderKLWan`, shared between the 5B and 14B pipelines) is built from 3D convolutions. When quantizing the VAE with NVFP4, the `Conv3d` layers are automatically dispatched through a custom BF16 WMMA implicit-GEMM kernel with fused FP4 activation quantization. Requires SM80+ (Ampere or newer). See [`modelopt/torch/quantization/src/conv/README.md`](../../modelopt/torch/quantization/src/conv/README.md) for kernel details.
+
+```sh
+python quantize.py \
+    --model {wan2.2-t2v-14b|wan2.2-t2v-5b} \
+    --backbone vae \
+    --format fp4 --quant-algo max --collect-method default \
````