-
Notifications
You must be signed in to change notification settings - Fork 420
Implicit Gemm NVFP4 on Conv3D #886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 12 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
31c201b
Update the implicit gemm kernel
jingyu-ml 9b278d8
Update the readme
jingyu-ml abd598f
Update
jingyu-ml fcb4571
Merge branch 'main' into jingyux/implicit-gemm-nvfp4
jingyu-ml c669198
Merge branch 'main' into jingyux/implicit-gemm-nvfp4
jingyu-ml 7ca8bd6
Add test case, and move the cuda code out of python script
jingyu-ml 66278df
Update the README
jingyu-ml e2c375d
E2E implicit gemm nvfp4 results
jingyu-ml 8dfe250
Update the LTX2 recipe
jingyu-ml 6ba7802
Merge branch 'main' into jingyux/implicit-gemm-nvfp4
jingyu-ml 7ead24d
revert the change
jingyu-ml 52bb60d
Update some of the code and checks
jingyu-ml 9067290
Update the README
jingyu-ml 40757f8
Update
jingyu-ml 863dd1a
Added the SM version constraint
jingyu-ml d15f926
Undo
jingyu-ml b0f6014
Merge branch 'main' into jingyux/implicit-gemm-nvfp4
jingyu-ml 9d425cc
Merge branch 'main' into jingyux/implicit-gemm-nvfp4
jingyu-ml File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| # Conv3D Implicit GEMM (Experimental) | ||
|
|
||
| Experimental Conv3D kernel prototype using implicit GEMM, with optional fused FP4 fake quantization for activations. | ||
|
|
||
| This code is kept under `experimental/` by design and is **not** part of the stable `modelopt.torch.quantization` API. | ||
|
|
||
| ## Model Support | ||
|
|
||
| | Model/Framework | Supported | Notes | | ||
| |-----------------|-----------|-------| | ||
| | Video diffusion backbones using Conv3D | Partial | Intended for experimentation and microbenchmarking | | ||
| | Generic LLM backbones | No | Conv3D path is not relevant | | ||
| | End-to-end ModelOpt PTQ/QAT pipeline | No | Not wired into formal quantization/export/compress flows | | ||
|
|
||
| ## Deployment | ||
|
|
||
| | Framework | Supported | Notes | | ||
| |-----------|-----------|-------| | ||
| | TensorRT-LLM | No | No formal export integration for this kernel path | | ||
| | vLLM | No | No integration | | ||
| | SGLang | No | No integration | | ||
| | PyTorch runtime (CUDA) | Yes (experimental) | JIT-compiles CUDA extension on first use | | ||
|
|
||
| ## Usage | ||
|
|
||
| ```python | ||
| import torch | ||
|
|
||
| from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda | ||
| from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op | ||
|
|
||
| x = torch.randn(1, 128, 21, 60, 106, device="cuda") | ||
| w = torch.randn(512, 128, 3, 3, 3, device="cuda") | ||
| block_size = 128 | ||
|
|
||
| # Without FP4 activation quantization (drop-in-style Conv3D call) | ||
| out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) | ||
|
|
||
| # Optional FP4 block quantization of weights along the GEMM K dimension. | ||
| # The kernel's A-tile (activations) is quantized along K = Cin*kD*kH*kW, | ||
| # so weights must be flattened to [Cout, K] before quantizing to match. | ||
| Cout, Cin = w.shape[:2] | ||
| K = Cin * w.shape[2] * w.shape[3] * w.shape[4] | ||
| w_flat = w.reshape(Cout, K) | ||
| w_q_flat = dynamic_block_quantize_op( | ||
| w_flat, | ||
| block_size, | ||
| w_flat.abs().max().unsqueeze(0), | ||
| 4, # num_bits | ||
| 2, # exponent_bits | ||
| 8, # scale_num_bits | ||
| 4, # scale_exponent_bits | ||
| ) | ||
| w_q = w_q_flat.reshape_as(w) | ||
|
|
||
| # With FP4 activation fake quantization | ||
| out_q = conv3d_implicit_gemm_cuda( | ||
| x, | ||
| w_q, | ||
| stride=(1, 1, 1), | ||
| padding=(1, 1, 1), | ||
| act_amax=x.abs().max().unsqueeze(0), | ||
| quant_act=True, | ||
| fp4_block_size=block_size, # 16, 32, 64, 128, or 256 | ||
| ) | ||
| ``` | ||
|
|
||
| ## API | ||
|
|
||
| Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm_cuda.py` | ||
|
|
||
| | Parameter | Description | | ||
| |-----------|-------------| | ||
| | `x` | Input tensor `[N, Cin, D, H, W]` | | ||
| | `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` | | ||
| | `bias` | Optional bias `[Cout]` | | ||
| | `stride` | Convolution stride `(D, H, W)` | | ||
| | `padding` | Convolution padding `(D, H, W)` | | ||
| | `dilation` | Convolution dilation `(D, H, W)` | | ||
| | `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) | | ||
| | `quant_act` | Enable FP4 fake quantization on activations | | ||
| | `fp4_block_size` | FP4 quantization block size (`16`, `32`, `64`, `128`, or `256`) | | ||
|
|
||
| ## Status | ||
|
|
||
| Current state: **Prototype** | ||
|
|
||
| Known limitations: | ||
|
|
||
| - API is unstable and may change without notice. | ||
| - Not registered in core quantization module registries. | ||
| - Not covered by formal export/compress integration. | ||
| - CUDA extension compile latency on first invocation. | ||
| - Validation and performance coverage are limited to local experiments. | ||
|
|
||
| ## Notes | ||
|
|
||
| - The CUDA kernel is JIT-compiled on first call (can take several seconds). | ||
| - Output shape matches `torch.nn.functional.conv3d`. | ||
| - FP4 path applies quantize-dequantize in-kernel for activation tiles. | ||
|
|
||
| ## References | ||
|
|
||
| - Implicit GEMM-based convolution design patterns in GPU kernels. | ||
| - ModelOpt FP4-related quantization utilities in `modelopt.torch.quantization.tensor_quant`. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,208 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d. | ||
|
|
||
| Usage: | ||
| python -m experimental.conv.bench_implicit_gemm | ||
| python -m experimental.conv.bench_implicit_gemm --shapes wan22 | ||
| python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100 | ||
| """ | ||
|
|
||
| import argparse | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Benchmark shapes | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
| # (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation) | ||
| SHAPES = { | ||
| "small": [ | ||
| ("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), | ||
| ], | ||
| "medium": [ | ||
| ("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), | ||
| ("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), | ||
| ("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)), | ||
| ], | ||
| "wan22": [ | ||
| ("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), | ||
| ("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)), | ||
| ("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), | ||
| ], | ||
| "stride": [ | ||
| ("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), | ||
| ("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), | ||
| ], | ||
| } | ||
|
|
||
|
|
||
| def get_shapes(name: str): | ||
| """Return list of benchmark shapes by name or all shapes.""" | ||
| if name == "all": | ||
| result = [] | ||
| for v in SHAPES.values(): | ||
| result.extend(v) | ||
| return result | ||
| return SHAPES[name] | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Timing utility | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def bench_fn(fn, warmup: int, iters: int) -> float: | ||
| """Benchmark a callable, return median time in ms.""" | ||
| for _ in range(warmup): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
|
|
||
| times = [] | ||
| for _ in range(iters): | ||
| start = torch.cuda.Event(enable_timing=True) | ||
| end = torch.cuda.Event(enable_timing=True) | ||
| start.record() | ||
| fn() | ||
| end.record() | ||
| torch.cuda.synchronize() | ||
| times.append(start.elapsed_time(end)) | ||
|
|
||
| times.sort() | ||
| return times[len(times) // 2] # median | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Main | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int): | ||
| """Run latency benchmark for the given shapes.""" | ||
| from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda | ||
|
|
||
| shapes = get_shapes(shapes_name) | ||
|
|
||
| # Header | ||
| print(f"\n{'=' * 100}") | ||
| print( | ||
| f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}" | ||
| ) | ||
| print(f"GPU: {torch.cuda.get_device_name()}") | ||
| print(f"{'=' * 100}") | ||
| print( | ||
| f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} " | ||
| f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} " | ||
| f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}" | ||
| ) | ||
| print("-" * 100) | ||
|
|
||
| for name, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in shapes: | ||
| torch.manual_seed(42) | ||
| x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32) | ||
| weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32) | ||
| act_amax = x.abs().max().unsqueeze(0) | ||
|
|
||
| # Compute GEMM dimensions for display | ||
| sd, sh, sw = stride | ||
| dd, dh, dw = dilation | ||
| pd, ph, pw = padding | ||
| od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1 | ||
| oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1 | ||
| ow = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1 | ||
| gemm_m = n * od * oh * ow | ||
| gemm_k = cin * kd * kh * kw | ||
| gemm_n = cout | ||
|
|
||
| # cuDNN (torch.nn.functional.conv3d) | ||
| t_cudnn = bench_fn( | ||
| lambda: F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation), | ||
| warmup, | ||
| iters, | ||
| ) | ||
|
|
||
| # Implicit GEMM (non-quantized) | ||
| t_gemm = bench_fn( | ||
| lambda: conv3d_implicit_gemm_cuda( | ||
| x, | ||
| weight, | ||
| stride=stride, | ||
| padding=padding, | ||
| dilation=dilation, | ||
| quant_act=False, | ||
| fp4_block_size=fp4_block_size, | ||
| ), | ||
| warmup, | ||
| iters, | ||
| ) | ||
|
|
||
| # Implicit GEMM (FP4 quantized) | ||
| t_fp4 = bench_fn( | ||
| lambda: conv3d_implicit_gemm_cuda( | ||
| x, | ||
| weight, | ||
| stride=stride, | ||
| padding=padding, | ||
| dilation=dilation, | ||
| act_amax=act_amax, | ||
| quant_act=True, | ||
| fp4_block_size=fp4_block_size, | ||
| ), | ||
| warmup, | ||
| iters, | ||
| ) | ||
|
|
||
| ratio_gemm = t_gemm / t_cudnn | ||
| ratio_fp4 = t_fp4 / t_cudnn | ||
|
|
||
| print( | ||
| f"{name:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} " | ||
| f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms " | ||
| f"{ratio_gemm:>10.2f}x {ratio_fp4:>9.2f}x" | ||
| ) | ||
|
|
||
| print(f"{'=' * 100}") | ||
| print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.") | ||
| print() | ||
|
|
||
|
|
||
| def main(): | ||
| """Entry point for the benchmark CLI.""" | ||
| parser = argparse.ArgumentParser(description="Conv3D latency benchmark") | ||
| parser.add_argument( | ||
| "--shapes", | ||
| default="all", | ||
| choices=[*list(SHAPES.keys()), "all"], | ||
| help="Which shape set to benchmark (default: all)", | ||
| ) | ||
| parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations") | ||
| parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations") | ||
| parser.add_argument( | ||
| "--fp4-block-size", | ||
| type=int, | ||
| default=128, | ||
| choices=[128, 256], | ||
| help="FP4 block size (default: 128)", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| run_benchmark(args.shapes, args.warmup, args.iters, args.fp4_block_size) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| // SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| #include <torch/extension.h> | ||
|
|
||
| torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_flat, | ||
| torch::Tensor bias, torch::Tensor act_amax, int N_batch, | ||
| int Cin, int Dp, int Hp, int Wp, int Cout, int OD, int OH, | ||
| int OW, int kD, int kH, int kW, int sd, int sh, int sw, | ||
| int dd, int dh, int dw, int M, int K, bool quant_act, | ||
| bool has_bias, int fp4_block_size); | ||
|
|
||
| torch::Tensor fp4_fake_quant_cuda(torch::Tensor x, torch::Tensor global_amax, int block_size); | ||
|
|
||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
| m.def("conv3d_implicit_gemm_cuda", &conv3d_implicit_gemm_cuda, | ||
| "Conv3D implicit GEMM with BF16 WMMA and optional FP4 quantization"); | ||
| m.def("fp4_fake_quant_cuda", &fp4_fake_quant_cuda, | ||
| "Standalone FP4 fake quantization (blockwise, with FP8 scale quantization)"); | ||
| } |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.