Commit 6f32d24

Implicit Gemm NVFP4 on Conv3D (#886)
## What does this PR do?

**Type of change:** new feature

**Overview:** Experimental Conv3D implicit-GEMM CUDA kernel with optional NVFP4-style (E2M1 + FP8 E4M3 scale) fake quantization for activations. It is intended for research/prototyping and quantization-accuracy experiments only, not for production deployment. The implementation runs as a JIT-compiled PyTorch extension, mirrors the `conv3d` output shape, and provides quantized and non-quantized paths for comparing numerical behavior. There is currently no real quantized production kernel integration in the formal ModelOpt export/compress/runtime stack; this path is kept in `experimental/` for fake-quant accuracy validation and benchmarking.

## Usage

```python
import torch

from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda
from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op

x = torch.randn(1, 128, 21, 60, 106, device="cuda")
w = torch.randn(512, 128, 3, 3, 3, device="cuda")
block_size = 128

# Without FP4 activation quantization (drop-in-style Conv3D call)
out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1))

# Optional FP4 block quantization of weights along the GEMM K dimension.
# The kernel's A-tile (activations) is quantized along K = Cin*kD*kH*kW,
# so weights must be flattened to [Cout, K] before quantizing to match.
Cout, Cin = w.shape[:2]
K = Cin * w.shape[2] * w.shape[3] * w.shape[4]
w_flat = w.reshape(Cout, K)
w_q_flat = dynamic_block_quantize_op(
    w_flat,
    block_size,
    w_flat.abs().max().unsqueeze(0),
    4,  # num_bits
    2,  # exponent_bits
    8,  # scale_num_bits
    4,  # scale_exponent_bits
)
w_q = w_q_flat.reshape_as(w)

# With FP4 activation fake quantization
out_q = conv3d_implicit_gemm_cuda(
    x,
    w_q,
    stride=(1, 1, 1),
    padding=(1, 1, 1),
    act_amax=x.abs().max().unsqueeze(0),
    quant_act=True,
    fp4_block_size=block_size,  # 128 or 256
)
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

## Summary by CodeRabbit (Release Notes)

- **New Features**
  - Added experimental Conv3D implementation with implicit GEMM acceleration and optional FP4 quantization support
  - Added benchmarking tool to compare 3D convolution performance across implementations
  - Enhanced quantization framework integration for Conv3D operations
- **Documentation**
  - Added comprehensive guide for experimental Conv3D prototype, including supported scenarios, API reference, and current limitations

---------

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 812e8c6 commit 6f32d24

6 files changed

Lines changed: 2274 additions & 0 deletions

File tree

experimental/conv/README.md

Lines changed: 105 additions & 0 deletions
# Conv3D Implicit GEMM (Experimental)

Experimental Conv3D kernel prototype using implicit GEMM, with optional fused FP4 fake quantization for activations.

This code is kept under `experimental/` by design and is **not** part of the stable `modelopt.torch.quantization` API.

## Model Support

| Model/Framework | Supported | Notes |
|-----------------|-----------|-------|
| Video diffusion VAE Conv3D layers | Tested | Validated on VAE encoder/decoder Conv3D layers in video diffusion models |
| Generic LLM backbones | No | Conv3D path is not relevant |
| End-to-end ModelOpt PTQ/QAT pipeline | No | Not wired into formal quantization/export/compress flows |

## Deployment

| Framework | Supported | Notes |
|-----------|-----------|-------|
| TensorRT-LLM | No | No formal export integration for this kernel path |
| vLLM | No | No integration |
| SGLang | No | No integration |
| PyTorch runtime (CUDA) | Yes (experimental) | JIT-compiles CUDA extension on first use |

## Usage

```python
import torch

from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda
from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op

x = torch.randn(1, 128, 21, 60, 106, device="cuda")
w = torch.randn(512, 128, 3, 3, 3, device="cuda")
block_size = 128

# Without FP4 activation quantization (drop-in-style Conv3D call)
out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1))

# Optional FP4 block quantization of weights along the GEMM K dimension.
# The kernel's A-tile (activations) is quantized along K = Cin*kD*kH*kW,
# so weights must be flattened to [Cout, K] before quantizing to match.
Cout, Cin = w.shape[:2]
K = Cin * w.shape[2] * w.shape[3] * w.shape[4]
w_flat = w.reshape(Cout, K)
w_q_flat = dynamic_block_quantize_op(
    w_flat,
    block_size,
    w_flat.abs().max().unsqueeze(0),
    4,  # num_bits
    2,  # exponent_bits
    8,  # scale_num_bits
    4,  # scale_exponent_bits
)
w_q = w_q_flat.reshape_as(w)

# With FP4 activation fake quantization
out_q = conv3d_implicit_gemm_cuda(
    x,
    w_q,
    stride=(1, 1, 1),
    padding=(1, 1, 1),
    act_amax=x.abs().max().unsqueeze(0),
    quant_act=True,
    fp4_block_size=block_size,  # 16, 32, 64, 128, or 256
)
```
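The FP4 activation path is an NVFP4-style scheme: values are stored as 4-bit E2M1 codes with a per-block scale (FP8 E4M3 in the kernel). For intuition, here is a minimal pure-Python sketch of blockwise E2M1 fake quantization. It is illustrative only, not the kernel's implementation, and it keeps the scale in full precision rather than quantizing it to E4M3:

```python
# Positive E2M1 (FP4) code points; a sign bit covers the negatives.
E2M1_LEVELS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fp4_fake_quant_block(block, eps=1e-12):
    """Quantize-dequantize one block: scale so abs-max maps to 6.0, snap to E2M1."""
    amax = max(abs(v) for v in block)
    scale = max(amax, eps) / 6.0  # 6.0 is the largest E2M1 magnitude
    out = []
    for v in block:
        mag = abs(v) / scale
        q = min(E2M1_LEVELS, key=lambda lvl: abs(lvl - mag))  # round to nearest
        out.append(q * scale if v >= 0 else -q * scale)
    return out


print(fp4_fake_quant_block([0.1, -0.4, 0.75, 1.2]))
```

With this sample block, the abs-max element (1.2) round-trips exactly, while 0.75 (3.75 after scaling) rounds to the code point 4 and comes back as roughly 0.8 — the kind of quantization error the fake-quant path lets you measure.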
## API

Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm_cuda.py`

| Parameter | Description |
|-----------|-------------|
| `x` | Input tensor `[N, Cin, D, H, W]` |
| `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` |
| `bias` | Optional bias `[Cout]` |
| `stride` | Convolution stride `(D, H, W)` |
| `padding` | Convolution padding `(D, H, W)` |
| `dilation` | Convolution dilation `(D, H, W)` |
| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) |
| `quant_act` | Enable FP4 fake quantization on activations |
| `fp4_block_size` | FP4 quantization block size (`16`, `32`, `64`, `128`, or `256`) |

## Status

Current state: **Prototype**

Known limitations:

- API is unstable and may change without notice.
- Not registered in core quantization module registries.
- Not covered by formal export/compress integration.
- CUDA extension compile latency on first invocation.
- Validation and performance coverage are limited to local experiments.

## Notes

- The CUDA kernel is JIT-compiled on first call (this can take several seconds).
- Output shape matches `torch.nn.functional.conv3d`.
- The FP4 path applies quantize-dequantize in-kernel for activation tiles.

## References

- Implicit GEMM-based convolution design patterns in GPU kernels.
- ModelOpt FP4-related quantization utilities in `modelopt.torch.quantization.tensor_quant`.
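The "output shape matches `torch.nn.functional.conv3d`" note follows the standard convolution shape arithmetic. A small dependency-free helper (a sketch, not part of the extension) makes the expected dims easy to check:

```python
def conv3d_out_shape(in_dhw, kernel, stride, padding, dilation):
    """Expected (OD, OH, OW), using PyTorch's conv output-shape formula."""
    return tuple(
        (i + 2 * p - d * (k - 1) - 1) // s + 1
        for i, k, s, p, d in zip(in_dhw, kernel, stride, padding, dilation)
    )


# The wan22-style input from the usage snippet: a 3x3x3 kernel with stride 1
# and padding 1 is a "same" convolution, so spatial dims are preserved.
print(conv3d_out_shape((21, 60, 106), (3, 3, 3), (1, 1, 1), (1, 1, 1), (1, 1, 1)))
# -> (21, 60, 106)
```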
Lines changed: 208 additions & 0 deletions
experimental/conv/bench_implicit_gemm.py

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d.

Usage:
    python -m experimental.conv.bench_implicit_gemm
    python -m experimental.conv.bench_implicit_gemm --shapes wan22
    python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100
"""

import argparse

import torch
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Benchmark shapes
# ---------------------------------------------------------------------------

# (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation)
SHAPES = {
    "small": [
        ("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
    ],
    "medium": [
        ("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
        ("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
        ("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)),
    ],
    "wan22": [
        ("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
        ("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)),
        ("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
    ],
    "stride": [
        ("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
        ("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
    ],
}


def get_shapes(name: str):
    """Return list of benchmark shapes by name, or all shapes."""
    if name == "all":
        result = []
        for v in SHAPES.values():
            result.extend(v)
        return result
    return SHAPES[name]


# ---------------------------------------------------------------------------
# Timing utility
# ---------------------------------------------------------------------------


def bench_fn(fn, warmup: int, iters: int) -> float:
    """Benchmark a callable; return the median time in ms."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    times.sort()
    return times[len(times) // 2]  # median


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int):
    """Run latency benchmark for the given shapes."""
    from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda

    shapes = get_shapes(shapes_name)

    # Header
    print(f"\n{'=' * 100}")
    print(
        f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}"
    )
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"{'=' * 100}")
    print(
        f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} "
        f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} "
        f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}"
    )
    print("-" * 100)

    for name, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in shapes:
        torch.manual_seed(42)
        x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32)
        weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32)
        act_amax = x.abs().max().unsqueeze(0)

        # Compute GEMM dimensions for display
        sd, sh, sw = stride
        dd, dh, dw = dilation
        pd, ph, pw = padding
        od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1
        oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
        ow = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
        gemm_m = n * od * oh * ow
        gemm_k = cin * kd * kh * kw
        gemm_n = cout

        # cuDNN (torch.nn.functional.conv3d)
        t_cudnn = bench_fn(
            lambda: F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation),
            warmup,
            iters,
        )

        # Implicit GEMM (non-quantized)
        t_gemm = bench_fn(
            lambda: conv3d_implicit_gemm_cuda(
                x,
                weight,
                stride=stride,
                padding=padding,
                dilation=dilation,
                quant_act=False,
                fp4_block_size=fp4_block_size,
            ),
            warmup,
            iters,
        )

        # Implicit GEMM (FP4 quantized)
        t_fp4 = bench_fn(
            lambda: conv3d_implicit_gemm_cuda(
                x,
                weight,
                stride=stride,
                padding=padding,
                dilation=dilation,
                act_amax=act_amax,
                quant_act=True,
                fp4_block_size=fp4_block_size,
            ),
            warmup,
            iters,
        )

        ratio_gemm = t_gemm / t_cudnn
        ratio_fp4 = t_fp4 / t_cudnn

        print(
            f"{name:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} "
            f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms "
            f"{ratio_gemm:>10.2f}x {ratio_fp4:>9.2f}x"
        )

    print(f"{'=' * 100}")
    print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.")
    print()


def main():
    """Entry point for the benchmark CLI."""
    parser = argparse.ArgumentParser(description="Conv3D latency benchmark")
    parser.add_argument(
        "--shapes",
        default="all",
        choices=[*list(SHAPES.keys()), "all"],
        help="Which shape set to benchmark (default: all)",
    )
    parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations")
    parser.add_argument(
        "--fp4-block-size",
        type=int,
        default=128,
        choices=[128, 256],
        help="FP4 block size (default: 128)",
    )
    args = parser.parse_args()

    run_benchmark(args.shapes, args.warmup, args.iters, args.fp4_block_size)


if __name__ == "__main__":
    main()
```
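As a worked example of the M/K/N columns the benchmark prints, the `wan22_128x512` shape maps to these implicit-GEMM dimensions (same arithmetic as in `run_benchmark`):

```python
# wan22_128x512: N=1, Cin=128, D=21, H=60, W=106, Cout=512,
# 3x3x3 kernel, stride 1, padding 1, dilation 1.
n, cin, cout = 1, 128, 512
d, h, w = 21, 60, 106
k, s, p, dil = 3, 1, 1, 1

# "Same" convolution: output spatial dims equal the input dims.
od = (d + 2 * p - dil * (k - 1) - 1) // s + 1  # 21
oh = (h + 2 * p - dil * (k - 1) - 1) // s + 1  # 60
ow = (w + 2 * p - dil * (k - 1) - 1) // s + 1  # 106

gemm_m = n * od * oh * ow    # one GEMM row per output voxel
gemm_k = cin * k * k * k     # reduction over Cin and the 3x3x3 window
gemm_n = cout                # one GEMM column per output channel
print(gemm_m, gemm_k, gemm_n)  # -> 133560 3456 512
```

A tall, skinny GEMM like this (M >> N) is typical of implicit-GEMM convolution on high-resolution inputs.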
Lines changed: 37 additions & 0 deletions
```cpp
/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/extension.h>

torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_flat,
                                        torch::Tensor bias, torch::Tensor act_amax, int N_batch,
                                        int Cin, int Dp, int Hp, int Wp, int Cout, int OD, int OH,
                                        int OW, int kD, int kH, int kW, int sd, int sh, int sw,
                                        int dd, int dh, int dw, int M, int K, bool quant_act,
                                        bool has_bias, int fp4_block_size);

torch::Tensor fp4_fake_quant_cuda(torch::Tensor x, torch::Tensor global_amax, int block_size);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("conv3d_implicit_gemm_cuda", &conv3d_implicit_gemm_cuda,
        "Conv3D implicit GEMM with BF16 WMMA and optional FP4 quantization");
  m.def("fp4_fake_quant_cuda", &fp4_fake_quant_cuda,
        "Standalone FP4 fake quantization (blockwise, with FP8 scale quantization)");
}
```
