Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions experimental/conv/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Conv3D Implicit GEMM (Experimental)

Experimental Conv3D kernel prototype using implicit GEMM, with optional fused FP4 fake quantization for activations.

This code is kept under `experimental/` by design and is **not** part of the stable `modelopt.torch.quantization` API.

## Model Support

| Model/Framework | Supported | Notes |
|-----------------|-----------|-------|
| Video diffusion VAE Conv3D layers | Tested | Validated on VAE encoder/decoder Conv3D layers in video diffusion models |
| Generic LLM backbones | No | Conv3D path is not relevant |
| End-to-end ModelOpt PTQ/QAT pipeline | No | Not wired into formal quantization/export/compress flows |

## Deployment

| Framework | Supported | Notes |
|-----------|-----------|-------|
| TensorRT-LLM | No | No formal export integration for this kernel path |
| vLLM | No | No integration |
| SGLang | No | No integration |
| PyTorch runtime (CUDA) | Yes (experimental) | JIT-compiles CUDA extension on first use |

## Usage

```python
import torch

from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda
from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op

x = torch.randn(1, 128, 21, 60, 106, device="cuda")
w = torch.randn(512, 128, 3, 3, 3, device="cuda")
block_size = 128

# Without FP4 activation quantization (drop-in-style Conv3D call)
out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1))

# Optional FP4 block quantization of weights along the GEMM K dimension.
# The kernel's A-tile (activations) is quantized along K = Cin*kD*kH*kW,
# so weights must be flattened to [Cout, K] before quantizing to match.
Cout, Cin = w.shape[:2]
K = Cin * w.shape[2] * w.shape[3] * w.shape[4]
w_flat = w.reshape(Cout, K)
w_q_flat = dynamic_block_quantize_op(
w_flat,
block_size,
w_flat.abs().max().unsqueeze(0),
4, # num_bits
2, # exponent_bits
8, # scale_num_bits
4, # scale_exponent_bits
)
w_q = w_q_flat.reshape_as(w)

# With FP4 activation fake quantization
out_q = conv3d_implicit_gemm_cuda(
x,
w_q,
stride=(1, 1, 1),
padding=(1, 1, 1),
act_amax=x.abs().max().unsqueeze(0),
quant_act=True,
fp4_block_size=block_size, # 16, 32, 64, 128, or 256
)
```

## API

Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm_cuda.py`

| Parameter | Description |
|-----------|-------------|
| `x` | Input tensor `[N, Cin, D, H, W]` |
| `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` |
| `bias` | Optional bias `[Cout]` |
| `stride` | Convolution stride `(D, H, W)` |
| `padding` | Convolution padding `(D, H, W)` |
| `dilation` | Convolution dilation `(D, H, W)` |
| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) |
| `quant_act` | Enable FP4 fake quantization on activations |
| `fp4_block_size` | FP4 quantization block size (`16`, `32`, `64`, `128`, or `256`) |

## Status

Current state: **Prototype**

Known limitations:

- API is unstable and may change without notice.
- Not registered in core quantization module registries.
- Not covered by formal export/compress integration.
- CUDA extension compile latency on first invocation.
- Validation and performance coverage are limited to local experiments.

## Notes

- The CUDA kernel is JIT-compiled on first call (can take several seconds).
- Output shape matches `torch.nn.functional.conv3d`.
- FP4 path applies quantize-dequantize in-kernel for activation tiles.

## References

- Implicit GEMM-based convolution design patterns in GPU kernels.
- ModelOpt FP4-related quantization utilities in `modelopt.torch.quantization.tensor_quant`.
208 changes: 208 additions & 0 deletions experimental/conv/bench_implicit_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d.

Usage:
python -m experimental.conv.bench_implicit_gemm
python -m experimental.conv.bench_implicit_gemm --shapes wan22
python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100
"""

import argparse

import torch
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Benchmark shapes
# ---------------------------------------------------------------------------

# (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation)
SHAPES = {
"small": [
("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
],
"medium": [
("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)),
],
"wan22": [
("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)),
("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
],
"stride": [
("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
],
}


def get_shapes(name: str):
"""Return list of benchmark shapes by name or all shapes."""
if name == "all":
result = []
for v in SHAPES.values():
result.extend(v)
return result
return SHAPES[name]


# ---------------------------------------------------------------------------
# Timing utility
# ---------------------------------------------------------------------------


def bench_fn(fn, warmup: int, iters: int) -> float:
"""Benchmark a callable, return median time in ms."""
for _ in range(warmup):
fn()
torch.cuda.synchronize()

times = []
for _ in range(iters):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
fn()
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

times.sort()
return times[len(times) // 2] # median


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int):
"""Run latency benchmark for the given shapes."""
from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda

shapes = get_shapes(shapes_name)

# Header
print(f"\n{'=' * 100}")
print(
f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}"
)
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"{'=' * 100}")
print(
f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} "
f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} "
f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}"
)
print("-" * 100)

for name, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in shapes:
torch.manual_seed(42)
x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32)
weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32)
act_amax = x.abs().max().unsqueeze(0)

# Compute GEMM dimensions for display
sd, sh, sw = stride
dd, dh, dw = dilation
pd, ph, pw = padding
od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1
oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
ow = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
gemm_m = n * od * oh * ow
gemm_k = cin * kd * kh * kw
gemm_n = cout

# cuDNN (torch.nn.functional.conv3d)
t_cudnn = bench_fn(
lambda: F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation),
warmup,
iters,
)

# Implicit GEMM (non-quantized)
t_gemm = bench_fn(
lambda: conv3d_implicit_gemm_cuda(
x,
weight,
stride=stride,
padding=padding,
dilation=dilation,
quant_act=False,
fp4_block_size=fp4_block_size,
),
warmup,
iters,
)

# Implicit GEMM (FP4 quantized)
t_fp4 = bench_fn(
lambda: conv3d_implicit_gemm_cuda(
x,
weight,
stride=stride,
padding=padding,
dilation=dilation,
act_amax=act_amax,
quant_act=True,
fp4_block_size=fp4_block_size,
),
warmup,
iters,
)

ratio_gemm = t_gemm / t_cudnn
ratio_fp4 = t_fp4 / t_cudnn

print(
f"{name:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} "
f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms "
f"{ratio_gemm:>10.2f}x {ratio_fp4:>9.2f}x"
)

print(f"{'=' * 100}")
print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.")
print()


def main():
"""Entry point for the benchmark CLI."""
parser = argparse.ArgumentParser(description="Conv3D latency benchmark")
parser.add_argument(
"--shapes",
default="all",
choices=[*list(SHAPES.keys()), "all"],
help="Which shape set to benchmark (default: all)",
)
parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations")
parser.add_argument(
"--fp4-block-size",
type=int,
default=128,
choices=[128, 256],
help="FP4 block size (default: 128)",
)
args = parser.parse_args()

run_benchmark(args.shapes, args.warmup, args.iters, args.fp4_block_size)


if __name__ == "__main__":
main()
37 changes: 37 additions & 0 deletions experimental/conv/implicit_gemm_binding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#include <torch/extension.h>

torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_flat,
torch::Tensor bias, torch::Tensor act_amax, int N_batch,
int Cin, int Dp, int Hp, int Wp, int Cout, int OD, int OH,
int OW, int kD, int kH, int kW, int sd, int sh, int sw,
int dd, int dh, int dw, int M, int K, bool quant_act,
bool has_bias, int fp4_block_size);

torch::Tensor fp4_fake_quant_cuda(torch::Tensor x, torch::Tensor global_amax, int block_size);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("conv3d_implicit_gemm_cuda", &conv3d_implicit_gemm_cuda,
"Conv3D implicit GEMM with BF16 WMMA and optional FP4 quantization");
m.def("fp4_fake_quant_cuda", &fp4_fake_quant_cuda,
"Standalone FP4 fake quantization (blockwise, with FP8 scale quantization)");
}
Loading
Loading