diff --git a/experimental/conv/README.md b/experimental/conv/README.md new file mode 100644 index 0000000000..65b7cc5563 --- /dev/null +++ b/experimental/conv/README.md @@ -0,0 +1,105 @@ +# Conv3D Implicit GEMM (Experimental) + +Experimental Conv3D kernel prototype using implicit GEMM, with optional fused FP4 fake quantization for activations. + +This code is kept under `experimental/` by design and is **not** part of the stable `modelopt.torch.quantization` API. + +## Model Support + +| Model/Framework | Supported | Notes | +|-----------------|-----------|-------| +| Video diffusion VAE Conv3D layers | Tested | Validated on VAE encoder/decoder Conv3D layers in video diffusion models | +| Generic LLM backbones | No | Conv3D path is not relevant | +| End-to-end ModelOpt PTQ/QAT pipeline | No | Not wired into formal quantization/export/compress flows | + +## Deployment + +| Framework | Supported | Notes | +|-----------|-----------|-------| +| TensorRT-LLM | No | No formal export integration for this kernel path | +| vLLM | No | No integration | +| SGLang | No | No integration | +| PyTorch runtime (CUDA) | Yes (experimental) | JIT-compiles CUDA extension on first use | + +## Usage + +```python +import torch + +from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda +from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op + +x = torch.randn(1, 128, 21, 60, 106, device="cuda") +w = torch.randn(512, 128, 3, 3, 3, device="cuda") +block_size = 128 + +# Without FP4 activation quantization (drop-in-style Conv3D call) +out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) + +# Optional FP4 block quantization of weights along the GEMM K dimension. +# The kernel's A-tile (activations) is quantized along K = Cin*kD*kH*kW, +# so weights must be flattened to [Cout, K] before quantizing to match. 
+Cout, Cin = w.shape[:2] +K = Cin * w.shape[2] * w.shape[3] * w.shape[4] +w_flat = w.reshape(Cout, K) +w_q_flat = dynamic_block_quantize_op( + w_flat, + block_size, + w_flat.abs().max().unsqueeze(0), + 4, # num_bits + 2, # exponent_bits + 8, # scale_num_bits + 4, # scale_exponent_bits +) +w_q = w_q_flat.reshape_as(w) + +# With FP4 activation fake quantization +out_q = conv3d_implicit_gemm_cuda( + x, + w_q, + stride=(1, 1, 1), + padding=(1, 1, 1), + act_amax=x.abs().max().unsqueeze(0), + quant_act=True, + fp4_block_size=block_size, # 16, 32, 64, 128, or 256 +) +``` + +## API + +Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm_cuda.py` + +| Parameter | Description | +|-----------|-------------| +| `x` | Input tensor `[N, Cin, D, H, W]` | +| `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` | +| `bias` | Optional bias `[Cout]` | +| `stride` | Convolution stride `(D, H, W)` | +| `padding` | Convolution padding `(D, H, W)` | +| `dilation` | Convolution dilation `(D, H, W)` | +| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) | +| `quant_act` | Enable FP4 fake quantization on activations | +| `fp4_block_size` | FP4 quantization block size (`16`, `32`, `64`, `128`, or `256`) | + +## Status + +Current state: **Prototype** + +Known limitations: + +- API is unstable and may change without notice. +- Not registered in core quantization module registries. +- Not covered by formal export/compress integration. +- CUDA extension compile latency on first invocation. +- Validation and performance coverage are limited to local experiments. + +## Notes + +- The CUDA kernel is JIT-compiled on first call (can take several seconds). +- Output shape matches `torch.nn.functional.conv3d`. +- FP4 path applies quantize-dequantize in-kernel for activation tiles. + +## References + +- Implicit GEMM-based convolution design patterns in GPU kernels. 
+- ModelOpt FP4-related quantization utilities in `modelopt.torch.quantization.tensor_quant`. diff --git a/experimental/conv/bench_implicit_gemm.py b/experimental/conv/bench_implicit_gemm.py new file mode 100644 index 0000000000..164c074467 --- /dev/null +++ b/experimental/conv/bench_implicit_gemm.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d. 
+ +Usage: + python -m experimental.conv.bench_implicit_gemm + python -m experimental.conv.bench_implicit_gemm --shapes wan22 + python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100 +""" + +import argparse + +import torch +import torch.nn.functional as F + +# --------------------------------------------------------------------------- +# Benchmark shapes +# --------------------------------------------------------------------------- + +# (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation) +SHAPES = { + "small": [ + ("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ], + "medium": [ + ("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)), + ], + "wan22": [ + ("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)), + ("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ], + "stride": [ + ("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), + ("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), + ], +} + + +def get_shapes(name: str): + """Return list of benchmark shapes by name or all shapes.""" + if name == "all": + result = [] + for v in SHAPES.values(): + result.extend(v) + return result + return SHAPES[name] + + +# --------------------------------------------------------------------------- +# Timing utility +# --------------------------------------------------------------------------- + + +def bench_fn(fn, warmup: int, iters: int) -> float: + """Benchmark a callable, return median time in ms.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() 
+ + times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + times.sort() + return times[len(times) // 2] # median + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int): + """Run latency benchmark for the given shapes.""" + from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + + shapes = get_shapes(shapes_name) + + # Header + print(f"\n{'=' * 100}") + print( + f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}" + ) + print(f"GPU: {torch.cuda.get_device_name()}") + print(f"{'=' * 100}") + print( + f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} " + f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} " + f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}" + ) + print("-" * 100) + + for name, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in shapes: + torch.manual_seed(42) + x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32) + weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + # Compute GEMM dimensions for display + sd, sh, sw = stride + dd, dh, dw = dilation + pd, ph, pw = padding + od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1 + oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1 + ow = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1 + gemm_m = n * od * oh * ow + gemm_k = cin * kd * kh * kw + gemm_n = cout + + # cuDNN (torch.nn.functional.conv3d) + t_cudnn = bench_fn( + lambda: F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation), + warmup, + iters, + ) + + # Implicit GEMM (non-quantized) + t_gemm = bench_fn( + lambda: 
conv3d_implicit_gemm_cuda( + x, + weight, + stride=stride, + padding=padding, + dilation=dilation, + quant_act=False, + fp4_block_size=fp4_block_size, + ), + warmup, + iters, + ) + + # Implicit GEMM (FP4 quantized) + t_fp4 = bench_fn( + lambda: conv3d_implicit_gemm_cuda( + x, + weight, + stride=stride, + padding=padding, + dilation=dilation, + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ), + warmup, + iters, + ) + + ratio_gemm = t_gemm / t_cudnn + ratio_fp4 = t_fp4 / t_cudnn + + print( + f"{name:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} " + f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms " + f"{ratio_gemm:>10.2f}x {ratio_fp4:>9.2f}x" + ) + + print(f"{'=' * 100}") + print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.") + print() + + +def main(): + """Entry point for the benchmark CLI.""" + parser = argparse.ArgumentParser(description="Conv3D latency benchmark") + parser.add_argument( + "--shapes", + default="all", + choices=[*list(SHAPES.keys()), "all"], + help="Which shape set to benchmark (default: all)", + ) + parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations") + parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations") + parser.add_argument( + "--fp4-block-size", + type=int, + default=128, + choices=[128, 256], + help="FP4 block size (default: 128)", + ) + args = parser.parse_args() + + run_benchmark(args.shapes, args.warmup, args.iters, args.fp4_block_size) + + +if __name__ == "__main__": + main() diff --git a/experimental/conv/implicit_gemm_binding.cpp b/experimental/conv/implicit_gemm_binding.cpp new file mode 100644 index 0000000000..b91650cd4e --- /dev/null +++ b/experimental/conv/implicit_gemm_binding.cpp @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include + +torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_flat, + torch::Tensor bias, torch::Tensor act_amax, int N_batch, + int Cin, int Dp, int Hp, int Wp, int Cout, int OD, int OH, + int OW, int kD, int kH, int kW, int sd, int sh, int sw, + int dd, int dh, int dw, int M, int K, bool quant_act, + bool has_bias, int fp4_block_size); + +torch::Tensor fp4_fake_quant_cuda(torch::Tensor x, torch::Tensor global_amax, int block_size); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("conv3d_implicit_gemm_cuda", &conv3d_implicit_gemm_cuda, + "Conv3D implicit GEMM with BF16 WMMA and optional FP4 quantization"); + m.def("fp4_fake_quant_cuda", &fp4_fake_quant_cuda, + "Standalone FP4 fake quantization (blockwise, with FP8 scale quantization)"); +} diff --git a/experimental/conv/implicit_gemm_cuda.py b/experimental/conv/implicit_gemm_cuda.py new file mode 100644 index 0000000000..713b5f82f9 --- /dev/null +++ b/experimental/conv/implicit_gemm_cuda.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conv3D Implicit GEMM with BF16 WMMA Tensor Cores and optional fused FP4 quantization. + +CUDA kernel source: implicit_gemm_kernel.cu +C++ binding: implicit_gemm_binding.cpp +""" + +import os + +import torch +import torch.nn.functional as F + +_KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) + +_cuda_module = None + + +_MIN_SM_MAJOR = 8 # BF16 WMMA tensor cores require SM80+ (Ampere and newer) + + +def _get_cuda_module(): + """Get or compile the CUDA module.""" + global _cuda_module + if _cuda_module is None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. This kernel requires a CUDA GPU.") + major, minor = torch.cuda.get_device_capability() + if major < _MIN_SM_MAJOR: + raise RuntimeError( + f"This kernel requires SM{_MIN_SM_MAJOR}0+ (Ampere or newer) for BF16 WMMA " + f"tensor cores, but the current GPU has SM{major}{minor}." 
+ ) + + from torch.utils.cpp_extension import load + + _cuda_module = load( + name="conv3d_implicit_gemm_cuda_v20_wmma", + sources=[ + os.path.join(_KERNEL_DIR, "implicit_gemm_binding.cpp"), + os.path.join(_KERNEL_DIR, "implicit_gemm_kernel.cu"), + ], + verbose=True, + extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-lineinfo", + "--ptxas-options=-v", + "-std=c++17", + ], + ) + return _cuda_module + + +def _triple(v) -> tuple[int, int, int]: + if isinstance(v, int): + return (v, v, v) + assert len(v) == 3 + return (int(v[0]), int(v[1]), int(v[2])) + + +def _pad6(padding) -> tuple[int, int, int, int, int, int]: + if isinstance(padding, int): + p = int(padding) + return (p, p, p, p, p, p) + if len(padding) == 3: + pd, ph, pw = map(int, padding) + return (pw, pw, ph, ph, pd, pd) + assert len(padding) == 6 + return tuple(map(int, padding)) # type: ignore[return-value] + + +@torch.no_grad() +def conv3d_implicit_gemm_cuda( + x: torch.Tensor, + w: torch.Tensor, + bias: torch.Tensor | None = None, + stride: tuple[int, int, int] = (1, 1, 1), + padding: tuple[int, int, int] = (0, 0, 0), + dilation: tuple[int, int, int] = (1, 1, 1), + act_amax: torch.Tensor | None = None, + quant_act: bool = False, + fp4_block_size: int = 256, +) -> torch.Tensor: + """Conv3D via implicit GEMM with BF16 WMMA tensor cores. + + Args: + x: Input tensor [N, Cin, D, H, W] + w: Weight tensor [Cout, Cin, kD, kH, kW] + bias: Optional bias tensor [Cout] + stride: Convolution stride (D, H, W) + padding: Convolution padding (D, H, W) + dilation: Convolution dilation (D, H, W) + act_amax: Activation max value for FP4 quantization + quant_act: Whether to apply FP4 quantization to activations + fp4_block_size: FP4 quantization block size (16, 32, 64, 128, or 256) + + Returns: + Output tensor [N, Cout, OD, OH, OW] + + Raises: + ValueError: If fp4_block_size is not one of {16, 32, 64, 128, 256}. 
+ """ + valid_block_sizes = {16, 32, 64, 128, 256} + if fp4_block_size not in valid_block_sizes: + raise ValueError( + f"fp4_block_size must be one of {sorted(valid_block_sizes)}, got {fp4_block_size}" + ) + + cuda_mod = _get_cuda_module() + + if x.ndim != 5 or w.ndim != 5: + raise ValueError(f"Expected 5D tensors, got x.ndim={x.ndim}, w.ndim={w.ndim}") + n_batch, cin, d, h, w_in = x.shape + cout, cin_w, kd, kh, kw = w.shape + if cin_w != cin: + raise ValueError( + f"Grouped convolution is not supported (x has {cin} input channels, " + f"w has {cin_w}). This kernel requires groups=1." + ) + + sd, sh, sw = _triple(stride) + dd, dh, dw = _triple(dilation) + pad_wl, pad_wr, pad_hl, pad_hr, pad_dl, pad_dr = _pad6(padding) + + x_pad = F.pad(x, (pad_wl, pad_wr, pad_hl, pad_hr, pad_dl, pad_dr)) + dp = d + pad_dl + pad_dr + hp = h + pad_hl + pad_hr + wp = w_in + pad_wl + pad_wr + + od = (dp - (dd * (kd - 1) + 1)) // sd + 1 + oh = (hp - (dh * (kh - 1) + 1)) // sh + 1 + ow = (wp - (dw * (kw - 1) + 1)) // sw + 1 + + m = n_batch * od * oh * ow + k = cin * kd * kh * kw + + w_flat = w.reshape(cout, k).transpose(0, 1).contiguous() + + x_pad = x_pad.float().contiguous() + w_flat = w_flat.float().contiguous() + + has_bias = bias is not None + bias_t = bias.float().contiguous() if has_bias else torch.empty(0, device=x.device) # type: ignore[union-attr] + + if quant_act and act_amax is None: + raise ValueError("act_amax is required when quant_act=True") + + do_quant = quant_act + amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device) # type: ignore[union-attr] + + y_flat = cuda_mod.conv3d_implicit_gemm_cuda( + x_pad, + w_flat, + bias_t, + amax_t, + n_batch, + cin, + dp, + hp, + wp, + cout, + od, + oh, + ow, + kd, + kh, + kw, + sd, + sh, + sw, + dd, + dh, + dw, + m, + k, + do_quant, + has_bias, + fp4_block_size, + ) + + y = y_flat.view(n_batch, od, oh, ow, cout).permute(0, 4, 1, 2, 3).contiguous() + return y.to(x.dtype) + + +@torch.no_grad() +def 
fp4_fake_quant( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, +) -> torch.Tensor: + """Standalone FP4 fake quantization using the same CUDA device functions as the GEMM kernel. + + Applies blockwise FP4 (E2M1) quantize-dequantize with FP8 E4M3 scale quantization. + + Args: + x: Input tensor (any shape, numel must be divisible by block_size). + global_amax: Scalar tensor — global abs max for scale computation. + block_size: Number of elements per FP4 quantization block. + + Returns: + Fake-quantized tensor with same shape and dtype as input. + """ + cuda_mod = _get_cuda_module() + + orig_shape = x.shape + orig_dtype = x.dtype + x_f32 = x.float().contiguous() + amax_f32 = global_amax.float().contiguous() + + assert x_f32.numel() % block_size == 0, ( + f"numel ({x_f32.numel()}) must be divisible by block_size ({block_size})" + ) + + y = cuda_mod.fp4_fake_quant_cuda(x_f32, amax_f32, block_size) + return y.view(orig_shape).to(orig_dtype) diff --git a/experimental/conv/implicit_gemm_kernel.cu b/experimental/conv/implicit_gemm_kernel.cu new file mode 100644 index 0000000000..a3b40f4848 --- /dev/null +++ b/experimental/conv/implicit_gemm_kernel.cu @@ -0,0 +1,628 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Conv3D Implicit GEMM with BF16 WMMA Tensor Cores and optional fused FP4 quantization. +// +// Key optimizations: +// 1. BF16 WMMA tensor core operations (m16n16k16) with FP32 accumulators +// 2. On-the-fly spatial index computation (no global memory lookup tables) +// 3. Dual FP4_BLOCK_SIZE support (128 and 256) with optimized tile configs +// 4. Register-fused FP4 quantization (quantize during A-tile load, eliminates sync) +// 5. Branchless FP4 quantization using predicated selects +// 6. BF16 shared memory (halves memory vs FP32) +// 7. L2-friendly block scheduling (swizzled grid) +// 8. FP8 E4M3 round-trip for scale quantization + +#include +#include +#include +#include +#include + +using namespace nvcuda; + +// ============================================================================= +// FP4 Quantization Helpers +// ============================================================================= + +__device__ __forceinline__ float fp4_quantize_value(float scaled) { + float q; + q = (scaled <= 5.0f) ? 4.0f : 6.0f; + q = (scaled < 3.5f) ? 3.0f : q; + q = (scaled <= 2.5f) ? 2.0f : q; + q = (scaled < 1.75f) ? 1.5f : q; + q = (scaled <= 1.25f) ? 1.0f : q; + q = (scaled < 0.75f) ? 0.5f : q; + q = (scaled <= 0.25f) ? 
0.0f : q; + return q; +} + +__device__ __forceinline__ float warp_reduce_max(float val) { +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset)); + } + return val; +} + +__device__ __forceinline__ float fp8_e4m3_round_trip(float x) { + if (x == 0.0f) + return 0.0f; + + unsigned int bits = __float_as_uint(x); + unsigned int sign = bits >> 31; + int exp = ((bits >> 23) & 0xff) - 127; + unsigned int mantissa = bits & 0x7fffff; + + if (exp > 8) + return sign ? -448.0f : 448.0f; + if (exp < -9) + return 0.0f; + + unsigned int mantissa_3bit = (mantissa + (1 << 19)) >> 20; + if (mantissa_3bit > 7) { + mantissa_3bit = 0; + exp += 1; + if (exp > 8) + return sign ? -448.0f : 448.0f; + } + + if (exp < -6) { + int shift = -6 - exp; + mantissa_3bit = (mantissa_3bit | 8) >> shift; + exp = -6; + } + + int fp32_exp = exp + 127; + unsigned int fp32_mantissa = mantissa_3bit << 20; + unsigned int fp32_bits = (sign << 31) | (fp32_exp << 23) | fp32_mantissa; + + return __uint_as_float(fp32_bits); +} + +__device__ __forceinline__ float quantize_scale_fp8(float block_max, float global_scale) { + float scaled = block_max / (6.0f * global_scale); + scaled = fminf(scaled, 448.0f); + float quantized = fp8_e4m3_round_trip(scaled); + return quantized * global_scale; +} + +// ============================================================================= +// BF16 WMMA Conv3D Implicit GEMM Kernel +// ============================================================================= +// Template parameters: +// QUANT_ACT - whether to apply FP4 quantization +// HAS_BIAS - whether bias is present +// BLOCK_M - M tile size (64) +// BLOCK_N - N tile size (64) +// BLOCK_K - K tile size (matches FP4_BLOCK_SIZE: 128 or 256) +// WARPS_M - warp tiling in M dimension (2) +// WARPS_N - warp tiling in N dimension (4) +// L2_SWIZZLE_GROUP - group size for L2-friendly block scheduling +// +// Each warp computes a (WARP_M x WARP_N) output tile 
using 16x16x16 WMMA. +// WARP_M = BLOCK_M / WARPS_M, WARP_N = BLOCK_N / WARPS_N +// WARP_TILES_M = WARP_M / 16, WARP_TILES_N = WARP_N / 16 +// +// Shared memory layout (BF16): +// As[BLOCK_M][BK_STRIDE] - M-major (row_major for WMMA A-fragments) +// Bs[BLOCK_K][BN_STRIDE] - K-major (row_major for WMMA B-fragments) + +template +__global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) + conv3d_implicit_gemm_wmma(const float *__restrict__ x_pad, const float *__restrict__ w_flat, + const float *__restrict__ bias, float *__restrict__ y, + const float *__restrict__ act_amax, int Cin, int Dp, int Hp, int Wp, + int Cout, int OD, int OH, int OW, int kD, int kH, int kW, int sd, + int sh, int sw, int dd, int dh, int dw, int M, int K) { + // Derived constants + constexpr int NUM_WARPS = WARPS_M * WARPS_N; + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int WARP_M = BLOCK_M / WARPS_M; // 32 + constexpr int WARP_N = BLOCK_N / WARPS_N; // 16 + constexpr int WARP_TILES_M = WARP_M / 16; // 2 + constexpr int WARP_TILES_N = WARP_N / 16; // 1 + + // BF16 shared memory strides with padding to avoid bank conflicts + // Pad by 8 BF16 elements (16 bytes) — keeps 16-byte alignment while breaking conflicts + constexpr int BK_STRIDE = BLOCK_K + 8; + constexpr int BN_STRIDE = BLOCK_N + 8; + + // Thread/warp indices + const int tid = threadIdx.x; + const int warp_id = tid / 32; + const int lane_id = tid % 32; + const int warp_m = warp_id / WARPS_N; // which M-warp (0..WARPS_M-1) + const int warp_n = warp_id % WARPS_N; // which N-warp (0..WARPS_N-1) + + // L2-friendly block scheduling (swizzle) + int bm, bn; + { + const int pid = blockIdx.x; + constexpr int GS = L2_SWIZZLE_GROUP; + const int grid_n = (Cout + BLOCK_N - 1) / BLOCK_N; + const int grid_m = (M + BLOCK_M - 1) / BLOCK_M; + const int tiles_per_group = GS * grid_n; + + const int group_row = pid / tiles_per_group; + const int group_rem = pid % tiles_per_group; + bn = group_rem / GS; + const int swizzle_lane = group_rem % 
GS; + bm = group_row * GS + swizzle_lane; + + if (bm >= grid_m || bn >= grid_n) + return; + } + + // Dynamic shared memory — BF16 tiles + extern __shared__ char smem_raw[]; + __nv_bfloat16 *As = reinterpret_cast<__nv_bfloat16 *>(smem_raw); + // As: [BLOCK_M][BK_STRIDE] — M-major + constexpr int A_SMEM_ELEMS = BLOCK_M * BK_STRIDE; + __nv_bfloat16 *Bs = As + A_SMEM_ELEMS; + // Bs: [BLOCK_K][BN_STRIDE] — K-major + + // WMMA accumulators — FP32 + wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + wmma::fill_fragment(acc[wm][wn], 0.0f); + } + } + + // Global scale for FP4 quantization + float global_scale = 1.0f; + if constexpr (QUANT_ACT) { + global_scale = act_amax[0] / (6.0f * 448.0f); + } + + // Precompute spatial constants + const int HpWp = Hp * Wp; + const int DpHpWp = Dp * HpWp; + const int kHW = kH * kW; + const int kDHW = kD * kHW; + const int OHW = OH * OW; + const int ODHW = OD * OHW; + + const int m_start = bm * BLOCK_M; + const int n_start = bn * BLOCK_N; + const int num_k_tiles = (K + BLOCK_K - 1) / BLOCK_K; + + // Total elements to load cooperatively + constexpr int A_ELEMS = BLOCK_M * BLOCK_K; + constexpr int B_ELEMS = BLOCK_K * BLOCK_N; + + // Main loop over K tiles + for (int k_tile = 0; k_tile < num_k_tiles; k_tile++) { + const int k_start_tile = k_tile * BLOCK_K; + + // ================================================================= + // Load A tile into BF16 shared memory (M-major layout) + // As[m][k] stored at As[m * BK_STRIDE + k] + // ================================================================= + if constexpr (QUANT_ACT) { + // Fused FP4 quantization: each warp handles M-rows + // FP4_BLOCK_SIZE can be smaller than BLOCK_K — we quantize in sub-blocks + static_assert(BLOCK_K % FP4_BLOCK_SIZE == 0, "BLOCK_K must be divisible by FP4_BLOCK_SIZE"); + static_assert(FP4_BLOCK_SIZE >= 16, "FP4_BLOCK_SIZE must be >= 16"); + 
constexpr int ELEMS_PER_LANE = (BLOCK_K + 31) / 32; + constexpr int NUM_FP4_BLOCKS = BLOCK_K / FP4_BLOCK_SIZE; + + for (int m = warp_id; m < BLOCK_M; m += NUM_WARPS) { + int m_idx = m_start + m; + + int n_batch, od_val, oh_val, ow_val; + if (m_idx < M) { + n_batch = m_idx / ODHW; + int rem = m_idx % ODHW; + od_val = rem / OHW; + rem = rem % OHW; + oh_val = rem / OW; + ow_val = rem % OW; + } else { + n_batch = 0; + od_val = 0; + oh_val = 0; + ow_val = 0; + } + + // Pass 1: Load all values from global memory + float vals[ELEMS_PER_LANE]; + +#pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + float val = 0.0f; + if (k < BLOCK_K && m_idx < M) { + int k_idx = k_start_tile + k; + if (k_idx < K) { + int c = k_idx / kDHW; + int remk = k_idx % kDHW; + int kd_v = remk / kHW; + remk = remk % kHW; + int kh_v = remk / kW; + int kw_v = remk % kW; + + int id = od_val * sd + kd_v * dd; + int ih = oh_val * sh + kh_v * dh; + int iw = ow_val * sw + kw_v * dw; + + val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; + } + } + vals[i] = val; + } + + // Pass 2: For each FP4 sub-block, find max → compute scale → quantize + // Pre-compute per-sub-block scales + float scales[NUM_FP4_BLOCKS]; + float inv_scales[NUM_FP4_BLOCKS]; + +#pragma unroll + for (int sb = 0; sb < NUM_FP4_BLOCKS; sb++) { + const int k_sb_start = sb * FP4_BLOCK_SIZE; + const int k_sb_end = k_sb_start + FP4_BLOCK_SIZE; + + // Each lane accumulates max over its elements in this sub-block + float local_max = 0.0f; +#pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + // Compiler resolves this at compile time for unrolled loops + if (k >= k_sb_start && k < k_sb_end) { + local_max = fmaxf(local_max, fabsf(vals[i])); + } + } + + // Warp reduce — lanes outside this sub-block contribute 0, which is correct + float block_max = warp_reduce_max(local_max); + float scale = quantize_scale_fp8(block_max, global_scale); + if (scale < 
1e-5f) + scale = 1.0f; + scales[sb] = scale; + inv_scales[sb] = 1.0f / scale; + } + + // Pass 3: Quantize and store to shared memory +#pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + if (k < BLOCK_K) { + int sb = k / FP4_BLOCK_SIZE; // compile-time shift for power-of-2 + float val = vals[i]; + float sign = (val >= 0.0f) ? 1.0f : -1.0f; + float q = fp4_quantize_value(fabsf(val) * inv_scales[sb]); + float result = sign * q * scales[sb]; + // M-major: As[m * BK_STRIDE + k] + As[m * BK_STRIDE + k] = __float2bfloat16(result); + } + } + } + } else { +// Non-quantized: cooperative load, store as BF16 in M-major +#pragma unroll 4 + for (int i = tid; i < A_ELEMS; i += NUM_THREADS) { + int local_m = i / BLOCK_K; + int local_k = i % BLOCK_K; + int m_idx = m_start + local_m; + int k_idx = k_start_tile + local_k; + + float val = 0.0f; + if (m_idx < M && k_idx < K) { + int n_batch = m_idx / ODHW; + int rem = m_idx % ODHW; + int od_val = rem / OHW; + rem = rem % OHW; + int oh_val = rem / OW; + int ow_val = rem % OW; + + int c = k_idx / kDHW; + int remk = k_idx % kDHW; + int kd_v = remk / kHW; + remk = remk % kHW; + int kh_v = remk / kW; + int kw_v = remk % kW; + + int id = od_val * sd + kd_v * dd; + int ih = oh_val * sh + kh_v * dh; + int iw = ow_val * sw + kw_v * dw; + + val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; + } + // M-major: As[m * BK_STRIDE + k] + As[local_m * BK_STRIDE + local_k] = __float2bfloat16(val); + } + } + +// ================================================================= +// Load B tile into BF16 shared memory (K-major layout) +// Bs[k][n] stored at Bs[k * BN_STRIDE + n] +// ================================================================= +#pragma unroll 4 + for (int i = tid; i < B_ELEMS; i += NUM_THREADS) { + int local_k = i / BLOCK_N; + int local_n = i % BLOCK_N; + int k_idx = k_start_tile + local_k; + int n_idx = n_start + local_n; + + float val = 0.0f; + if (k_idx < K && n_idx < 
Cout) { + val = w_flat[k_idx * Cout + n_idx]; + } + Bs[local_k * BN_STRIDE + local_n] = __float2bfloat16(val); + } + + __syncthreads(); + + // ================================================================= + // WMMA Compute: iterate over K in steps of 16 (WMMA K-dim) + // ================================================================= + constexpr int K_STEPS = BLOCK_K / 16; + +#pragma unroll + for (int kk = 0; kk < K_STEPS; kk++) { + // Load A and B fragments from shared memory + wmma::fragment + a_frag[WARP_TILES_M]; + wmma::fragment + b_frag[WARP_TILES_N]; + +// Load A fragments +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { + int a_row = warp_m * WARP_M + wm * 16; + int a_col = kk * 16; + wmma::load_matrix_sync(a_frag[wm], &As[a_row * BK_STRIDE + a_col], BK_STRIDE); + } + +// Load B fragments +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + int b_row = kk * 16; + int b_col = warp_n * WARP_N + wn * 16; + wmma::load_matrix_sync(b_frag[wn], &Bs[b_row * BN_STRIDE + b_col], BN_STRIDE); + } + +// MMA: acc[wm][wn] += a_frag[wm] * b_frag[wn] +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + wmma::mma_sync(acc[wm][wn], a_frag[wm], b_frag[wn], acc[wm][wn]); + } + } + } + + __syncthreads(); + } + + // ===================================================================== + // Store results: use shared memory as FP32 staging buffer + // Each warp stores its accumulator fragments, then all threads + // cooperatively copy to global memory with bounds checking and bias. 
+ // ===================================================================== + + // Reinterpret shared memory as FP32 for output staging + float *out_smem = reinterpret_cast(smem_raw); +// out_smem layout: [BLOCK_M][BLOCK_N], row-major + +// Each warp stores its accumulator fragments to shared memory +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + int out_row = warp_m * WARP_M + wm * 16; + int out_col = warp_n * WARP_N + wn * 16; + wmma::store_matrix_sync(&out_smem[out_row * BLOCK_N + out_col], acc[wm][wn], BLOCK_N, + wmma::mem_row_major); + } + } + + __syncthreads(); + + // Cooperatively copy from shared memory to global memory + constexpr int OUT_ELEMS = BLOCK_M * BLOCK_N; +#pragma unroll 4 + for (int i = tid; i < OUT_ELEMS; i += NUM_THREADS) { + int local_m = i / BLOCK_N; + int local_n = i % BLOCK_N; + int m_idx = m_start + local_m; + int n_idx = n_start + local_n; + + if (m_idx < M && n_idx < Cout) { + float result = out_smem[i]; + if constexpr (HAS_BIAS) { + result += bias[n_idx]; + } + y[m_idx * Cout + n_idx] = result; + } + } +} + +// ============================================================================= +// Standalone FP4 Fake Quantization Kernel (for testing) +// ============================================================================= +// Applies the same blockwise FP4 fake quant used in the GEMM A-tile loader, +// but on a flat 2D tensor [num_blocks, block_size]. +// Each warp processes one row (= one FP4 block). 
+ +__global__ void fp4_fake_quant_kernel(const float *__restrict__ x, float *__restrict__ y, + const float *__restrict__ global_amax_ptr, int num_blocks, + int block_size) { + const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int lane_id = threadIdx.x % 32; + + if (warp_id >= num_blocks) + return; + + float global_scale = global_amax_ptr[0] / (6.0f * 448.0f); + + const float *row = x + warp_id * block_size; + float *out_row = y + warp_id * block_size; + + // Pass 1: compute block max via warp reduction + float local_max = 0.0f; + for (int i = lane_id; i < block_size; i += 32) { + local_max = fmaxf(local_max, fabsf(row[i])); + } + float block_max = warp_reduce_max(local_max); + + // Quantize the scale via FP8 E4M3 round-trip + float scale = quantize_scale_fp8(block_max, global_scale); + if (scale < 1e-5f) + scale = 1.0f; + float inv_scale = 1.0f / scale; + + // Pass 2: quantize + dequantize each element + for (int i = lane_id; i < block_size; i += 32) { + float val = row[i]; + float sign = (val >= 0.0f) ? 
1.0f : -1.0f; + float q = fp4_quantize_value(fabsf(val) * inv_scale); + out_row[i] = sign * q * scale; + } +} + +torch::Tensor fp4_fake_quant_cuda(torch::Tensor x, torch::Tensor global_amax, int block_size) { + // x: [num_blocks, block_size] or flat [N] where N % block_size == 0 + auto x_flat = x.contiguous().view({-1}); + int N = x_flat.numel(); + int num_blocks = N / block_size; + + auto y = torch::empty_like(x_flat); + + // Launch: one warp (32 threads) per block + int threads_per_block = 256; // 8 warps per CUDA block + int warps_per_block = threads_per_block / 32; + int num_cuda_blocks = (num_blocks + warps_per_block - 1) / warps_per_block; + + fp4_fake_quant_kernel<<>>( + x_flat.data_ptr(), y.data_ptr(), global_amax.data_ptr(), num_blocks, + block_size); + + return y.view_as(x); +} + +// ============================================================================= +// Python Binding +// ============================================================================= + +torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_flat, + torch::Tensor bias, torch::Tensor act_amax, int N_batch, + int Cin, int Dp, int Hp, int Wp, int Cout, int OD, int OH, + int OW, int kD, int kH, int kW, int sd, int sh, int sw, + int dd, int dh, int dw, int M, int K, bool quant_act, + bool has_bias, int fp4_block_size) { + auto y = torch::zeros({M, Cout}, x_pad.options()); + + // Helper to compute padded 1D grid size for L2 swizzle + constexpr int GS = 8; // L2_SWIZZLE_GROUP + auto compute_grid = [&](int BM, int BN) -> dim3 { + int grid_m = (M + BM - 1) / BM; + int grid_n = (Cout + BN - 1) / BN; + int num_m_groups = (grid_m + GS - 1) / GS; + int total_blocks = num_m_groups * GS * grid_n; + return dim3(total_blocks, 1); + }; + +// Macro to dispatch kernel with all 4 template specializations +// FP4_BS is the FP4 quantization block size (independent of BK) +#define LAUNCH_WMMA_KERNEL(BM, BN, BK, WM, WN, FP4_BS) \ + { \ + constexpr int BK_S = BK + 8; \ + constexpr int 
BN_S = BN + 8; \ + constexpr size_t smem_a = BM * BK_S * sizeof(__nv_bfloat16); \ + constexpr size_t smem_b = BK * BN_S * sizeof(__nv_bfloat16); \ + constexpr size_t smem = smem_a + smem_b; \ + \ + dim3 block(WM * WN * 32); \ + dim3 grid = compute_grid(BM, BN); \ + \ + auto set_smem = [](auto kernel) { \ + constexpr size_t s_a = BM * (BK + 8) * sizeof(__nv_bfloat16); \ + constexpr size_t s_b = BK * (BN + 8) * sizeof(__nv_bfloat16); \ + constexpr size_t s = s_a + s_b; \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, s); \ + }; \ + \ + if (quant_act && has_bias) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } else if (quant_act) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } else if (has_bias) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } else { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } \ + } + + // BLOCK_K=256 always, FP4_BLOCK_SIZE varies + // BLOCK_M=64, BLOCK_N=64, WARPS_M=2, WARPS_N=4 (8 warps = 256 threads) + // Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB) + if (fp4_block_size == 16) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 16) + } else if 
(fp4_block_size == 32) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 32) + } else if (fp4_block_size == 64) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 64) + } else if (fp4_block_size == 128) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 128) + } else { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 256) + } + +#undef LAUNCH_WMMA_KERNEL + + return y; +} diff --git a/experimental/conv/test_implicit_gemm.py b/experimental/conv/test_implicit_gemm.py new file mode 100644 index 0000000000..af52660e42 --- /dev/null +++ b/experimental/conv/test_implicit_gemm.py @@ -0,0 +1,1067 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Conv3D implicit GEMM CUDA kernel. + +Tests both non-quantized path (vs cuDNN) and FP4-quantized path (vs Triton reference). 
+""" + +import pytest +import torch +import torch.nn.functional as F + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +@pytest.fixture(scope="module") +def cuda_conv3d(): + """Import and return the CUDA implicit GEMM conv3d function.""" + from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + + return conv3d_implicit_gemm_cuda + + +def _triton_fp4_available(): + """Check if the Triton FP4 fake quant kernel is available (requires compute >= 8.9).""" + try: + import modelopt.torch.quantization.triton as triton_kernel + + return hasattr(triton_kernel, "fp4_fake_quant_block") + except ImportError: + return False + + +requires_triton_fp4 = pytest.mark.skipif( + not _triton_fp4_available(), + reason="Triton fp4_fake_quant_block not available (requires compute >= 8.9)", +) + + +# BF16 WMMA accumulates in FP32 but inputs are rounded to BF16, so expect diffs. +# For large K (e.g. 3456 = 128*27), max abs diff can reach ~0.8 due to BF16 rounding +# and different accumulation order vs cuDNN's FP32 path. 
+ATOL = 1.0 +RTOL = 1e-3 + + +def _run_conv3d_test(cuda_conv3d, x, w, bias, stride, padding, dilation): + """Helper: run both cuDNN and implicit GEMM, compare results.""" + ref = F.conv3d(x, w, bias=bias, stride=stride, padding=padding, dilation=dilation) + out = cuda_conv3d( + x, w, bias=bias, stride=stride, padding=padding, dilation=dilation, quant_act=False + ) + assert out.shape == ref.shape, f"Shape mismatch: {out.shape} vs {ref.shape}" + abs_diff = (out - ref).abs() + max_diff = abs_diff.max().item() + # Scale tolerance with K (reduction dimension) — BF16 rounding accumulates + cin = w.shape[1] + k_size = cin * w.shape[2] * w.shape[3] * w.shape[4] + scaled_atol = ATOL * (k_size / 1000.0) ** 0.5 + assert max_diff < scaled_atol, ( + f"Max abs diff {max_diff:.6e} exceeds tolerance {scaled_atol:.4f} (K={k_size})" + ) + # Check mean diff is small (more robust than quantile for large tensors) + mean_diff = abs_diff.mean().item() + assert mean_diff < scaled_atol * 0.1, f"Mean diff {mean_diff:.6e} too high" + return max_diff + + +class TestConv3dBasic: + """Basic correctness tests with simple shapes.""" + + def test_minimal(self, cuda_conv3d): + """Smallest possible conv3d: 1x1x1 kernel, single channel.""" + x = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) + w = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) + diff = _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + # K=1, so BF16 rounding is the only source of error + assert diff < 1e-2 + + def test_single_channel_3x3x3(self, cuda_conv3d): + """Single input/output channel with 3x3x3 kernel.""" + x = torch.randn(1, 1, 5, 5, 5, device="cuda", dtype=torch.float32) + w = torch.randn(1, 1, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_multi_channel(self, cuda_conv3d): + """Multiple input and output channels.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", 
dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_with_bias(self, cuda_conv3d): + """Conv3d with bias.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + b = torch.randn(32, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, b, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_batch_size(self, cuda_conv3d): + """Batch size > 1.""" + x = torch.randn(4, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + +class TestConv3dStride: + """Tests with various stride configurations.""" + + def test_stride_2(self, cuda_conv3d): + """Uniform stride of 2.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (2, 2, 2), (1, 1, 1), (1, 1, 1)) + + def test_asymmetric_stride(self, cuda_conv3d): + """Different stride per dimension.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 2, 2), (1, 1, 1), (1, 1, 1)) + + +class TestConv3dPadding: + """Tests with various padding configurations.""" + + def test_no_padding(self, cuda_conv3d): + """Zero padding.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_large_padding(self, cuda_conv3d): + """Padding larger than kernel radius.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", 
dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1)) + + def test_asymmetric_padding(self, cuda_conv3d): + """Different padding per dimension.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 2), (1, 1, 1)) + + +class TestConv3dDilation: + """Tests with dilation.""" + + def test_dilation_2(self, cuda_conv3d): + """Uniform dilation of 2.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (2, 2, 2)) + + def test_asymmetric_dilation(self, cuda_conv3d): + """Different dilation per dimension.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 2, 2), (1, 2, 2)) + + +class TestConv3dKernelSizes: + """Tests with non-3x3x3 kernels.""" + + def test_1x1x1_kernel(self, cuda_conv3d): + """Pointwise 1x1x1 kernel.""" + x = torch.randn(1, 64, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(128, 64, 1, 1, 1, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_asymmetric_kernel(self, cuda_conv3d): + """Kernel with different sizes per dimension (e.g. 
1x3x3)."""
+        x = torch.randn(1, 16, 8, 16, 16, device="cuda", dtype=torch.float32)
+        w = torch.randn(32, 16, 1, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 1), (1, 1, 1))
+
+    def test_5x5x5_kernel(self, cuda_conv3d):
+        """Larger 5x5x5 kernel."""
+        x = torch.randn(1, 8, 16, 16, 16, device="cuda", dtype=torch.float32)
+        w = torch.randn(16, 8, 5, 5, 5, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1))
+
+
+class TestConv3dRealisticShapes:
+    """Tests with shapes resembling real video diffusion models."""
+
+    def test_wan22_shape(self, cuda_conv3d):
+        """Shape from a Wan2.2 video diffusion VAE Conv3D layer."""
+        x = torch.randn(1, 128, 21, 60, 106, device="cuda", dtype=torch.float32)
+        w = torch.randn(512, 128, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    def test_large_cout(self, cuda_conv3d):
+        """Large output channel count."""
+        x = torch.randn(1, 64, 8, 16, 16, device="cuda", dtype=torch.float32)
+        w = torch.randn(512, 64, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    def test_large_cin(self, cuda_conv3d):
+        """Large input channel count."""
+        x = torch.randn(1, 512, 8, 8, 8, device="cuda", dtype=torch.float32)
+        w = torch.randn(64, 512, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+
+class TestConv3dEdgeCases:
+    """Edge cases for tile boundary handling."""
+
+    def test_m_not_aligned_to_block(self, cuda_conv3d):
+        """M (N*OD*OH*OW) not a multiple of BLOCK_M=64."""
+        # With padding 1 output spatial dims equal input: M = 1*5*7*9 = 315, not divisible by 64
+        x = torch.randn(1, 8, 5, 7, 9, device="cuda", dtype=torch.float32)
+        w = torch.randn(16, 8, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    
def test_cout_not_aligned_to_block(self, cuda_conv3d): + """Cout not a multiple of BLOCK_N=64.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(17, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_k_not_aligned_to_block(self, cuda_conv3d): + """K (Cin*kD*kH*kW) not a multiple of BLOCK_K.""" + # Cin=7, kDHW=27, K=189 -- not a multiple of 128 or 256 + x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_output_size_1x1x1(self, cuda_conv3d): + """Output spatial dims are all 1.""" + x = torch.randn(1, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_single_output_element(self, cuda_conv3d): + """M=1: batch=1, output 1x1x1. + + With only one output element, mean diff == max diff, so the generic + helper's mean_diff < scaled_atol * 0.1 check is too tight. Use max diff only. + """ + x = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) + w = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1)) + out = cuda_conv3d( + x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), quant_act=False + ) + assert out.shape == ref.shape + max_diff = (out - ref).abs().max().item() + assert max_diff < ATOL, f"Max abs diff {max_diff:.6e} exceeds tolerance {ATOL}" + + +class TestConv3dFP4BlockSize: + """Test all FP4 block size configs (BLOCK_K=256 always, FP4_BLOCK_SIZE varies). + + Non-quantized path ignores FP4_BLOCK_SIZE, so all should match cuDNN. 
+ """ + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_non_quant_all_block_sizes(self, cuda_conv3d, fp4_block_size): + """Non-quant conv should match cuDNN regardless of fp4_block_size.""" + x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + quant_act=False, + fp4_block_size=fp4_block_size, + ) + assert out.shape == ref.shape + assert (out - ref).abs().max().item() < ATOL + + +class TestConv3dDeterminism: + """Verify deterministic output across repeated calls.""" + + def test_deterministic(self, cuda_conv3d): + """Repeated calls produce identical output.""" + torch.manual_seed(123) + x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) + out1 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + out2 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert torch.equal(out1, out2), "Kernel is not deterministic" + + +# ============================================================================= +# FP4 Quantized Conv3D Tests (fused activation quantization) +# ============================================================================= + + +@pytest.fixture(scope="module") +def cuda_fp4_quant(): + """Import FP4 fake quant for reference comparisons.""" + from experimental.conv.implicit_gemm_cuda import fp4_fake_quant + + return fp4_fake_quant + + +class TestConv3dFP4QuantBlockSizes: + """Test fused FP4 activation quantization with all supported block sizes. + + The kernel applies blockwise FP4 quantization to the im2col'd activation tiles + along the K dimension. 
We verify correctness by comparing the fused kernel output + against an unfused reference: fp4_fake_quant(im2col) @ fp4_fake_quant(weight). + """ + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_runs_all_block_sizes(self, cuda_conv3d, fp4_block_size): + """All FP4 block sizes should run without errors and produce valid output.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 32, 8, 8, 8) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + assert out.abs().max() > 0, "Output is all zeros" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_deterministic(self, cuda_conv3d, fp4_block_size): + """Quantized conv should be deterministic for all block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + kwargs = { + "stride": (1, 1, 1), + "padding": (1, 1, 1), + "dilation": (1, 1, 1), + "act_amax": act_amax, + "quant_act": True, + "fp4_block_size": fp4_block_size, + } + out1 = cuda_conv3d(x, w, **kwargs) + out2 = cuda_conv3d(x, w, **kwargs) + assert torch.equal(out1, out2), f"Non-deterministic for fp4_block_size={fp4_block_size}" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_vs_unfused_reference(self, cuda_conv3d, cuda_fp4_quant, fp4_block_size): + """Compare fused kernel vs unfused: fp4(im2col) @ fp4(weight). 
+ + Uses a shape where K is a multiple of 256 so all K-tiles are full + and block boundaries align perfectly between fused and unfused paths. + """ + torch.manual_seed(123) + # K = Cin * kD * kH * kW. Choose Cin so K is a multiple of 256. + # Cin=256, k=1x1x1 -> K=256 (exactly 1 full K-tile) + cin, cout = 256, 64 + x = torch.randn(1, cin, 4, 4, 4, device="cuda", dtype=torch.float32) + w = torch.randn(cout, cin, 1, 1, 1, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + w_amax = w.abs().max().unsqueeze(0) + + # Unfused reference: + # 1. Build im2col matrix (for 1x1x1 kernel, it's just reshape) + n, c, d, h, w_dim = x.shape + im2col = x.permute(0, 2, 3, 4, 1).reshape(-1, cin) # [M, K] + + # 2. FP4 fake-quant both matrices along K with the same block_size + im2col_q = cuda_fp4_quant(im2col, act_amax, fp4_block_size) + w_flat = w.reshape(cout, cin).transpose(0, 1).contiguous() # [K, Cout] + w_flat_q = cuda_fp4_quant(w_flat, w_amax, fp4_block_size) + + # 3. Matmul (in BF16 to match kernel's WMMA path) + ref_out = (im2col_q.bfloat16() @ w_flat_q.bfloat16()).float() + ref_out = ref_out.view(n, d, h, w_dim, cout).permute(0, 4, 1, 2, 3) + + # Note: the fused kernel does NOT quantize weights — weights are passed as-is. + # So for a proper comparison we need the fused kernel with pre-quantized weights. + fused_out_preq = cuda_conv3d( + x, + w_flat_q.transpose(0, 1).reshape(cout, cin, 1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + + # The fused kernel and unfused reference should match closely. + # Differences come from BF16 accumulation order (WMMA 16x16x16 tiles vs flat matmul). 
+ max_diff = (fused_out_preq - ref_out).abs().max().item() + mean_diff = (fused_out_preq - ref_out).abs().mean().item() + # Scale tolerance with K + scaled_atol = ATOL * (cin / 1000.0) ** 0.5 + assert max_diff < scaled_atol, ( + f"fp4_block_size={fp4_block_size}: fused vs unfused max diff {max_diff:.4f} " + f"exceeds tolerance {scaled_atol:.4f}" + ) + assert mean_diff < scaled_atol * 0.1, ( + f"fp4_block_size={fp4_block_size}: mean diff {mean_diff:.6e} too high" + ) + + def test_smaller_block_less_error(self, cuda_conv3d): + """Smaller FP4 block sizes should generally produce lower quantization error. + + Finer-grained blocks capture local ranges better, reducing quant error vs cuDNN. + Test monotonicity: error(16) <= error(32) <= ... <= error(256) (with some tolerance). + Reports detailed accuracy metrics for each block size vs cuDNN baseline. + """ + torch.manual_seed(42) + + # Test multiple shapes to get a comprehensive picture + configs = [ + ("Small K=432", 1, 16, 8, 8, 8, 32, 3, 3, 3), + ("Medium K=1728", 1, 64, 8, 8, 8, 64, 3, 3, 3), + ("Large K=3456", 1, 128, 5, 8, 8, 256, 3, 3, 3), + ] + + block_sizes = [16, 32, 64, 128, 256] + all_errors = {} + + for desc, n, cin, d, h, w_s, cout, kd, kh, kw in configs: + x = torch.randn(n, cin, d, h, w_s, device="cuda", dtype=torch.float32) + w = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + k_size = cin * kd * kh * kw + + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + ref_abs_mean = ref.abs().mean().item() + + # Also compute no-quant baseline (BF16 rounding only) + out_nq = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + quant_act=False, + ) + nq_diff = (out_nq - ref).abs() + + print( + f"\n {desc} (K={k_size}), output range [{ref.min().item():.1f}, {ref.max().item():.1f}]" + ) + print( + f" {'Block Size':>10} | {'Max Diff':>10} | {'Mean Diff':>10} | {'RMSE':>10} | {'Rel Err%':>8}" 
+ ) + print(f" {'-' * 10}-+-{'-' * 10}-+-{'-' * 10}-+-{'-' * 10}-+-{'-' * 8}") + print( + f" {'no-quant':>10} | {nq_diff.max().item():>10.4f} | " + f"{nq_diff.mean().item():>10.6f} | " + f"{((out_nq - ref) ** 2).mean().sqrt().item():>10.4f} | " + f"{nq_diff.mean().item() / ref_abs_mean * 100:>7.3f}%" + ) + + errors = {} + for bs in block_sizes: + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=bs, + ) + diff = (out - ref).abs() + max_d = diff.max().item() + mean_d = diff.mean().item() + rmse = ((out - ref) ** 2).mean().sqrt().item() + rel_err = mean_d / ref_abs_mean * 100 + errors[bs] = mean_d + print( + f" {bs:>10} | {max_d:>10.4f} | {mean_d:>10.6f} | " + f"{rmse:>10.4f} | {rel_err:>7.3f}%" + ) + all_errors[desc] = errors + + # Monotonicity check on the medium config + errors = all_errors["Medium K=1728"] + for smaller, larger in [(16, 64), (16, 256), (32, 256), (64, 256)]: + assert errors[smaller] <= errors[larger] * 1.2, ( + f"Expected error({smaller})={errors[smaller]:.6f} <= " + f"error({larger})={errors[larger]:.6f} * 1.2" + ) + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_with_bias(self, cuda_conv3d, fp4_block_size): + """FP4 quantized conv with bias for all block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + b = torch.randn(32, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + bias=b, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 32, 8, 8, 8) + assert not torch.isnan(out).any() + # Bias should shift output values + out_no_bias = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), 
            act_amax=act_amax,
            quant_act=True,
            fp4_block_size=fp4_block_size,
        )
        assert not torch.equal(out, out_no_bias), "Bias had no effect"

    @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256])
    def test_quant_k_not_aligned(self, cuda_conv3d, fp4_block_size):
        """FP4 quant with K not aligned to BLOCK_K or fp4_block_size.

        K = Cin * kDHW = 7 * 27 = 189. The last K-tile has partial data (zeros padded).
        """
        torch.manual_seed(42)
        x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32)
        w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32)
        act_amax = x.abs().max().unsqueeze(0)

        out = cuda_conv3d(
            x,
            w,
            stride=(1, 1, 1),
            padding=(1, 1, 1),
            dilation=(1, 1, 1),
            act_amax=act_amax,
            quant_act=True,
            fp4_block_size=fp4_block_size,
        )
        # Sanity only: shape matches F.conv3d with padding=1, and the quantized
        # path produced finite, non-trivial output despite the ragged K-tile.
        assert out.shape == (1, 16, 8, 8, 8)
        assert not torch.isnan(out).any()
        assert out.abs().max() > 0

    @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256])
    def test_quant_realistic_shape(self, cuda_conv3d, fp4_block_size):
        """Realistic video diffusion shape with all FP4 block sizes."""
        torch.manual_seed(42)
        x = torch.randn(1, 128, 5, 8, 8, device="cuda", dtype=torch.float32)
        w = torch.randn(256, 128, 3, 3, 3, device="cuda", dtype=torch.float32)
        act_amax = x.abs().max().unsqueeze(0)

        out = cuda_conv3d(
            x,
            w,
            stride=(1, 1, 1),
            padding=(1, 1, 1),
            dilation=(1, 1, 1),
            act_amax=act_amax,
            quant_act=True,
            fp4_block_size=fp4_block_size,
        )
        assert out.shape == (1, 256, 5, 8, 8)
        assert not torch.isnan(out).any()
        assert out.abs().max() > 0


# =============================================================================
# FP4 Fake Quantization Tests
# =============================================================================


@pytest.fixture(scope="module")
def cuda_fp4():
    """Import and return the CUDA FP4 fake quant function.

    NOTE: the import JIT-compiles the CUDA extension on first use; module
    scope amortizes that cost across all tests in this file.
    """
    from experimental.conv.implicit_gemm_cuda import fp4_fake_quant

    return fp4_fake_quant


def _py_fp4_fake_quant_ref(
    x_flat: "torch.Tensor", global_amax: "torch.Tensor", block_size: int
) -> "torch.Tensor":
    """Pure Python reference for FP4 fake quant (no BF16 rounding).

    This implements the exact same algorithm as the CUDA kernel:
    1. Compute global_scale = global_amax / (6 * 448)
    2. Per block: block_max = max(|x|), scale = fp8_e4m3_roundtrip(block_max / (6 * global_scale)) * global_scale
    3. Quantize each element to nearest E2M1 level, then dequantize.
    """
    import math

    # E2M1 quantization levels: {0, 0.5, 1, 1.5, 2, 3, 4, 6}
    # Boundaries (midpoints): <=0.25->0, <0.75->0.5, <=1.25->1, <1.75->1.5, <=2.5->2, <3.5->3, <=5->4, >5->6
    # The mixed <= / < at boundaries mirrors the CUDA kernel's tie-breaking;
    # do not "simplify" them or the exact-match tests below will diverge.
    def quantize_e2m1(scaled_abs: float) -> float:
        if scaled_abs <= 0.25:
            return 0.0
        elif scaled_abs < 0.75:
            return 0.5
        elif scaled_abs <= 1.25:
            return 1.0
        elif scaled_abs < 1.75:
            return 1.5
        elif scaled_abs <= 2.5:
            return 2.0
        elif scaled_abs < 3.5:
            return 3.0
        elif scaled_abs <= 5.0:
            return 4.0
        else:
            return 6.0

    def fp8_e4m3_roundtrip(val: float) -> float:
        """Simulate FP8 E4M3 round-trip in Python."""
        if val == 0.0:
            return 0.0
        sign = 1.0 if val >= 0 else -1.0
        val = abs(val)
        # FP8 E4M3: bias=7, 3 mantissa bits, max=448, no inf/nan
        if val > 448.0:
            return sign * 448.0
        # Compute exponent
        exp = math.floor(math.log2(val))
        exp = max(exp, -6)  # min normal exponent for E4M3
        # Compute mantissa (3 bits)
        mantissa = val / (2.0**exp)  # 1.xxx
        mantissa_bits = round((mantissa - 1.0) * 8.0)  # 3 bits
        if mantissa_bits > 7:
            # Mantissa rounding carried into the exponent: renormalize to 1.0
            # and bump the exponent by one.
            mantissa_bits = 0
            exp += 1
        if exp > 8:
            return sign * 448.0
        # Reconstruct
        result = (1.0 + mantissa_bits / 8.0) * (2.0**exp)
        return sign * result

    global_scale = float(global_amax) / (6.0 * 448.0)
    x_np = x_flat.cpu().float().numpy().copy()
    # NOTE: only full blocks are processed; trailing elements beyond the last
    # complete block of `block_size` are left unquantized in the output.
    num_blocks = len(x_np) // block_size

    for b in range(num_blocks):
        block = x_np[b * block_size : (b + 1) * block_size]
        block_max = float(max(abs(v) for v in block))

        # Scale quantization
        scaled = block_max / (6.0 * global_scale)
        scaled = min(scaled, 448.0)
        quantized_scale = fp8_e4m3_roundtrip(scaled) * global_scale
        if quantized_scale < 1e-5:
            # Guard against division by ~zero for all-zero (or near-zero)
            # blocks; with scale 1.0 such a block quantizes to all zeros.
            quantized_scale = 1.0
        inv_scale = 1.0 / quantized_scale

        for i in range(block_size):
            val = block[i]
            sign = 1.0 if val >= 0 else -1.0
            q = quantize_e2m1(abs(val) * inv_scale)
            x_np[b * block_size + i] = sign * q * quantized_scale

    return torch.tensor(x_np, device=x_flat.device)


class TestFP4FakeQuantValues:
    """Test FP4 fake quant with known E2M1 table values."""

    def test_exact_e2m1_values(self, cuda_fp4):
        """E2M1 representable values should round-trip exactly (when scale=1 via amax=6*448)."""
        # With global_amax = 6*448 = 2688, global_scale = 1.0
        # A single-block input with max=6 -> block_max=6, scaled=6/(6*1)=1.0
        # fp8_e4m3(1.0)=1.0, scale = 1.0*1.0 = 1.0
        block_size = 8
        vals = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], device="cuda", dtype=torch.float32)
        amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32)
        out = cuda_fp4(vals, amax, block_size)
        assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}"

    def test_exact_e2m1_negative(self, cuda_fp4):
        """Negative E2M1 values should also round-trip."""
        block_size = 8
        vals = torch.tensor([0, -0.5, -1, -1.5, -2, -3, -4, -6], device="cuda", dtype=torch.float32)
        amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32)
        out = cuda_fp4(vals, amax, block_size)
        assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}"

    def test_below_boundary(self, cuda_fp4):
        """Values slightly below E2M1 boundaries should quantize down."""
        block_size = 8
        # Boundaries: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
        # Slightly below -> quantize to lower level
        inp = torch.tensor(
            [0.15, 0.65, 1.15, 1.65, 2.4, 3.4, 4.9, 6.0], device="cuda", dtype=torch.float32
        )
        expected = torch.tensor(
            [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device="cuda", dtype=torch.float32
        )
        amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32)
        out = cuda_fp4(inp, amax, block_size)
        assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}"

    def test_above_boundary(self, cuda_fp4):
        """Values slightly above E2M1 boundaries should quantize up."""
        block_size = 8
        inp = torch.tensor(
            [0.35, 0.85, 1.35, 1.85, 2.6, 3.6, 5.1, 6.0], device="cuda", dtype=torch.float32
        )
        expected = torch.tensor(
            [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 6.0], device="cuda", dtype=torch.float32
        )
        amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32)
        out = cuda_fp4(inp, amax, block_size)
        assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}"

    def test_mixed_signs(self, cuda_fp4):
        """Mixed positive/negative values."""
        block_size = 8
        inp = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32)
        expected = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32)
        amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32)
        out = cuda_fp4(inp, amax, block_size)
        assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}"


class TestFP4FakeQuantScale:
    """Test FP4 scale computation and FP8 round-trip."""

    def test_scale_factor(self, cuda_fp4):
        """When amax != 6*448, scale should adjust values proportionally."""
        block_size = 8
        # global_amax = 12*448 = 5376, global_scale = 2.0
        # Input block max = 12 -> scaled = 12/(6*2) = 1.0 -> fp8(1.0) = 1.0 -> scale = 2.0
        # So input 12 -> |12|/2 = 6.0 -> q=6 -> 6*2 = 12
        inp = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", dtype=torch.float32)
        amax = torch.tensor([12.0 * 448.0], device="cuda", dtype=torch.float32)
        out = cuda_fp4(inp, amax, block_size)
        # Expected: each val/2.0 -> quantize to E2M1 -> * 2.0
        # (every input/2 lands exactly on an E2M1 level, so output == input)
        expected = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", dtype=torch.float32)
        assert torch.allclose(out, expected, atol=1e-4), f"Got {out} vs expected {expected}"

    def test_zero_block(self, cuda_fp4):
        """All-zero block should produce all zeros."""
        block_size = 16
        inp = torch.zeros(block_size, device="cuda", dtype=torch.float32)
        amax = torch.tensor([1.0], device="cuda", dtype=torch.float32)
        out = cuda_fp4(inp, amax, block_size)
        assert torch.equal(out, inp)

    def test_multiple_blocks(self, cuda_fp4):
        """Multiple blocks with different ranges."""
        block_size = 8
        # Block 0: small values, Block 1: large values
        block0 = torch.tensor([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], device="cuda")
        block1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 6], device="cuda")
        inp = torch.cat([block0, block1])
        amax = inp.abs().max().unsqueeze(0)
        out = cuda_fp4(inp, amax, block_size)
        # Each block should be independently quantized
        assert out.shape == inp.shape
        # Block 1 exact values should be close to E2M1 levels
        assert out[8:].abs().max() <= 6.0 + 1e-5


class TestFP4FakeQuantBlockSizes:
    """Test different block sizes."""

    @pytest.mark.parametrize("block_size", [8, 16, 32, 64, 128, 256])
    def test_block_sizes(self, cuda_fp4, block_size):
        """FP4 quant should work for various block sizes."""
        torch.manual_seed(42)
        num_blocks = 4
        inp = torch.randn(num_blocks * block_size, device="cuda", dtype=torch.float32) * 5
        amax = inp.abs().max().unsqueeze(0)
        out = cuda_fp4(inp, amax, block_size)
        assert out.shape == inp.shape
        # Output should not be all zeros for non-zero input
        assert out.abs().max() > 0
        # Output should be <= max possible after quant
        assert out.abs().max() <= inp.abs().max() * 1.5  # generous bound


class TestFP4FakeQuantVsReference:
    """Compare CUDA FP4 fake quant against Python reference implementation."""

    @pytest.mark.parametrize("block_size", [8, 16, 32])
    def test_vs_python_ref(self, cuda_fp4, block_size):
        """CUDA kernel should match the Python reference exactly."""
        torch.manual_seed(123)
        num_blocks = 8
        inp = torch.randn(num_blocks * block_size, device="cuda") * 10
        amax = inp.abs().max().unsqueeze(0)

        cuda_out = cuda_fp4(inp, amax, block_size)
        ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size)

        assert torch.allclose(cuda_out, ref_out, atol=1e-5), (
            f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}"
        )

    @pytest.mark.parametrize("block_size", [16, 32])
    def test_vs_python_ref_large(self, cuda_fp4, block_size):
        """Larger tensor test against Python reference."""
        torch.manual_seed(456)
        num_blocks = 64
        inp = torch.randn(num_blocks * block_size, device="cuda") * 20
        amax = inp.abs().max().unsqueeze(0)

        cuda_out = cuda_fp4(inp, amax, block_size)
        ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size)

        assert torch.allclose(cuda_out, ref_out, atol=1e-4), (
            f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}"
        )


class TestFP4FakeQuantVsTriton:
    """Compare CUDA FP4 fake quant against Triton fp4_fake_quant_block reference."""

    @requires_triton_fp4
    @pytest.mark.parametrize("block_size", [16, 32, 64])
    @pytest.mark.parametrize("num_blocks", [4, 16, 64])
    def test_vs_triton(self, cuda_fp4, block_size, num_blocks):
        """CUDA kernel should match the Triton fp4_fake_quant_block."""
        from modelopt.torch.quantization.triton import fp4_fake_quant_block

        torch.manual_seed(42)
        x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10
        global_amax = x.abs().max()

        cuda_out = cuda_fp4(x, global_amax.unsqueeze(0), block_size)
        triton_out = fp4_fake_quant_block(
            x,
            global_amax=global_amax,
            block_size=block_size,
            tile_rows=num_blocks,
            tile_cols=block_size,
        )

        assert torch.allclose(cuda_out, triton_out, atol=1e-5), (
            f"CUDA vs Triton max diff: {(cuda_out - triton_out).abs().max().item():.6e}\n"
            f"Mean diff: {(cuda_out - triton_out).abs().mean().item():.6e}"
        )


class TestFP4FakeQuantDeterminism:
    """Verify FP4 quant is deterministic."""

    def test_deterministic(self, cuda_fp4):
        """Repeated calls produce identical output."""
        torch.manual_seed(99)
        inp = torch.randn(256, device="cuda") * 5
        amax = inp.abs().max().unsqueeze(0)
        out1 = cuda_fp4(inp, amax, 16)
        out2 = cuda_fp4(inp, amax, 16)
        assert torch.equal(out1, out2), "FP4 fake quant is not deterministic"


# =============================================================================
# Cross-validation: experimental FP4 vs modelopt FP4 implementations
# =============================================================================


def _modelopt_cuda_ext_mx_available() -> bool:
    """Check if the modelopt CUDA MX extension is available."""
    # Broad except is deliberate: any import/build failure just means "skip".
    try:
        from modelopt.torch.quantization.extensions import get_cuda_ext_mx

        return get_cuda_ext_mx() is not None
    except Exception:
        return False


def _modelopt_dynamic_block_quantize_available() -> bool:
    """Check if dynamic_block_quantize_op is available."""
    try:
        from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op

        return dynamic_block_quantize_op is not None
    except Exception:
        return False


requires_cuda_ext_mx = pytest.mark.skipif(
    not _modelopt_cuda_ext_mx_available(),
    reason="modelopt cuda_ext_mx not available",
)

requires_dynamic_block_quantize = pytest.mark.skipif(
    not _modelopt_dynamic_block_quantize_available(),
    reason="modelopt dynamic_block_quantize_op not available",
)


class TestFP4FakeQuantVsModelopt:
    """Compare experimental CUDA FP4 fake quant against all modelopt FP4 implementations.

    This ensures the standalone FP4 kernel in experimental/conv produces the same
    results as the official modelopt quantization paths:
    1. Triton fp4_fake_quant_block (Hopper+ dynamic blockwise)
    2. cuda_ext_mx.fused_amax_convert (CUDA extension fallback)
    3. dynamic_block_quantize_op (high-level API that dispatches to either)
    """

    @requires_triton_fp4
    @pytest.mark.parametrize("block_size", [16, 32, 64])
    @pytest.mark.parametrize("seed", [42, 123, 999])
    def test_vs_triton_fp4_fake_quant_block(self, cuda_fp4, block_size, seed):
        """Compare against modelopt Triton fp4_fake_quant_block."""
        from modelopt.torch.quantization.triton import fp4_fake_quant_block

        torch.manual_seed(seed)
        num_blocks = 16
        x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10
        global_amax = x.abs().max()

        ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size)
        theirs = fp4_fake_quant_block(
            x,
            global_amax=global_amax,
            block_size=block_size,
            tile_rows=num_blocks,
            tile_cols=block_size,
        )

        assert torch.allclose(ours, theirs, atol=1e-5), (
            f"experimental vs modelopt Triton max diff: {(ours - theirs).abs().max().item():.6e}"
        )

    @requires_cuda_ext_mx
    @pytest.mark.parametrize("block_size", [16, 32])
    @pytest.mark.parametrize("seed", [42, 123])
    def test_vs_cuda_ext_mx(self, cuda_fp4, block_size, seed):
        """Compare against modelopt cuda_ext_mx.fused_amax_convert."""
        from modelopt.torch.quantization.extensions import get_cuda_ext_mx
        from modelopt.torch.quantization.tensor_quant import mx_format_map

        cuda_ext_mx = get_cuda_ext_mx()
        torch.manual_seed(seed)
        num_blocks = 16
        x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10
        global_amax = x.abs().max()

        ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size)
        # (2, 1) -> E2M1 element type, (4, 3) -> E4M3 scale type via mx_format_map
        theirs = cuda_ext_mx.fused_amax_convert(
            x,
            block_size,
            getattr(cuda_ext_mx.Types, mx_format_map[(2, 1)]),
            getattr(cuda_ext_mx.Types, mx_format_map[(4, 3)]),
            global_amax,
        )

        assert torch.allclose(ours, theirs, atol=1e-5), (
            f"experimental vs modelopt cuda_ext_mx max diff: "
            f"{(ours - theirs).abs().max().item():.6e}"
        )

    @requires_dynamic_block_quantize
    @pytest.mark.parametrize("seed", [42, 123, 999])
    def test_vs_dynamic_block_quantize_op(self, cuda_fp4, seed):
        """Compare against modelopt dynamic_block_quantize_op (high-level API).

        This is the function used by the actual quantization pipeline with
        num_bits=4 (E2M1) and scale_bits=8 (E4M3).
        Note: dynamic_block_quantize_op dispatches to Triton with default block_size=16.
        """
        from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op

        block_size = 16  # dynamic_block_quantize_op uses block_size=16 for Triton path
        torch.manual_seed(seed)
        num_blocks = 16
        x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10
        global_amax = x.abs().max()

        ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size)
        theirs = dynamic_block_quantize_op(
            x,
            block_size,
            global_amax,
            num_bits=4,  # total bits = 1 sign + 2 exp + 1 mantissa
            exponent_bits=2,
            scale_num_bits=8,  # FP8 E4M3 for scales
            scale_exponent_bits=4,
        )

        assert torch.allclose(ours, theirs, atol=1e-5), (
            f"experimental vs modelopt dynamic_block_quantize_op max diff: "
            f"{(ours - theirs).abs().max().item():.6e}"
        )

    @requires_triton_fp4
    def test_vs_triton_realistic_shape(self, cuda_fp4):
        """Realistic activation shape from a Conv3D layer (flattened)."""
        torch.manual_seed(42)
        block_size = 16
        # Simulate a large tensor: 256 blocks of 16 elements
        # (tile_rows must be power-of-2 for Triton block_ptr)
        num_blocks = 256
        x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 5
        global_amax = x.abs().max()

        from modelopt.torch.quantization.triton import fp4_fake_quant_block

        ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size)
        # NOTE(review): tile_rows=16 here while x has 256 rows — presumably the
        # Triton wrapper iterates over row tiles internally; confirm against
        # fp4_fake_quant_block's contract.
        theirs = fp4_fake_quant_block(
            x,
            global_amax=global_amax,
            block_size=block_size,
            tile_rows=16,
            tile_cols=block_size,
        )

        max_diff = (ours - theirs).abs().max().item()
        mean_diff = (ours - theirs).abs().mean().item()
        assert torch.allclose(ours, theirs, atol=1e-5), (
            f"Realistic shape: experimental vs Triton max diff: {max_diff:.6e}, "
            f"mean diff: {mean_diff:.6e}"
        )

    @requires_triton_fp4
    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_vs_triton_input_dtypes(self, cuda_fp4, dtype):
        """Test that our kernel handles different input dtypes correctly.

        Our kernel casts to float32 internally, so the result should match
        Triton's output when both receive the same dtype input.
        """
        from modelopt.torch.quantization.triton import fp4_fake_quant_block

        torch.manual_seed(42)
        block_size = 16
        num_blocks = 8
        x = (torch.randn(num_blocks, block_size, device="cuda") * 5).to(dtype)
        # amax computed in float32 so both implementations see the same scale
        global_amax = x.float().abs().max()

        ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size)
        theirs = fp4_fake_quant_block(
            x,
            global_amax=global_amax,
            block_size=block_size,
            tile_rows=num_blocks,
            tile_cols=block_size,
        )

        # Both should return the input dtype
        assert ours.dtype == dtype
        assert theirs.dtype == dtype

        # Compare in float32
        max_diff = (ours.float() - theirs.float()).abs().max().item()
        # BF16/FP16 input rounding may cause small diffs
        tol = 1e-2 if dtype != torch.float32 else 1e-5
        assert max_diff < tol, f"dtype={dtype}: experimental vs Triton max diff: {max_diff:.6e}"