Skip to content

Commit 31c201b

Browse files
committed
Update the implicit gemm kernel
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 95511a0 commit 31c201b

2 files changed

Lines changed: 760 additions & 0 deletions

File tree

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Conv3D Implicit GEMM Kernels
2+
3+
CUDA and Triton kernels for Conv3D via implicit GEMM with optional FP4 fake quantization.
4+
5+
## Usage
6+
7+
```python
8+
import torch
9+
from modelopt.torch.quantization.conv_gemm.implicit_gemm_cuda import conv3d_implicit_gemm_cuda
10+
11+
x = torch.randn(1, 128, 21, 60, 106, device="cuda")
12+
w = torch.randn(512, 128, 3, 3, 3, device="cuda")
13+
14+
# Without quantization (drop-in replacement for F.conv3d)
15+
out = conv3d_implicit_gemm_cuda(x, w, stride=(1,1,1), padding=(1,1,1))
16+
17+
# With FP4 quantization
18+
out = conv3d_implicit_gemm_cuda(
19+
x, w,
20+
stride=(1,1,1),
21+
padding=(1,1,1),
22+
act_amax=x.abs().max().unsqueeze(0),
23+
quant_act=True,
24+
FP4_BLOCK_SIZE=128, # 128 or 256
25+
)
26+
```
27+
28+
The Triton kernel has the same API:
29+
30+
```python
31+
from modelopt.torch.quantization.conv_gemm.implicit_gemm import conv3d_implicit_gemm_triton
32+
33+
out = conv3d_implicit_gemm_triton(x, w, stride=(1,1,1), padding=(1,1,1))
34+
```
35+
36+
## Parameters
37+
38+
| Parameter | Description |
39+
|-----------|-------------|
40+
| `x` | Input tensor `[N, Cin, D, H, W]` |
41+
| `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` |
42+
| `bias` | Optional bias `[Cout]` |
43+
| `stride` | Convolution stride `(D, H, W)` |
44+
| `padding` | Convolution padding `(D, H, W)` |
45+
| `dilation` | Convolution dilation `(D, H, W)` |
46+
| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) |
47+
| `quant_act` | Enable FP4 fake quantization on activations |
48+
| `FP4_BLOCK_SIZE` | Quantization block size: `128` or `256` |
49+
50+
## Notes
51+
52+
- The CUDA kernel is JIT-compiled on first call (takes a few seconds).
53+
- Both kernels return the same shape as `torch.nn.functional.conv3d`.
54+
- FP4 quantization fuses the quantize-dequantize into the GEMM tile load, so there is minimal overhead vs the non-quantized path.

0 commit comments

Comments
 (0)