From 31c201bf6b932b404d74cd88d2fa4dba0f08e518 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 13 Feb 2026 08:39:39 +0000 Subject: [PATCH 01/25] Update the implicit gemm kernel Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/conv/README.md | 54 ++ .../quantization/conv/implicit_gemm_cuda.py | 706 ++++++++++++++++++ 2 files changed, 760 insertions(+) create mode 100644 modelopt/torch/quantization/conv/README.md create mode 100644 modelopt/torch/quantization/conv/implicit_gemm_cuda.py diff --git a/modelopt/torch/quantization/conv/README.md b/modelopt/torch/quantization/conv/README.md new file mode 100644 index 0000000000..a76cbf895c --- /dev/null +++ b/modelopt/torch/quantization/conv/README.md @@ -0,0 +1,54 @@ +# Conv3D Implicit GEMM Kernels + +CUDA and Triton kernels for Conv3D via implicit GEMM with optional FP4 fake quantization. + +## Usage + +```python +import torch +from modelopt.torch.quantization.conv_gemm.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + +x = torch.randn(1, 128, 21, 60, 106, device="cuda") +w = torch.randn(512, 128, 3, 3, 3, device="cuda") + +# Without quantization (drop-in replacement for F.conv3d) +out = conv3d_implicit_gemm_cuda(x, w, stride=(1,1,1), padding=(1,1,1)) + +# With FP4 quantization +out = conv3d_implicit_gemm_cuda( + x, w, + stride=(1,1,1), + padding=(1,1,1), + act_amax=x.abs().max().unsqueeze(0), + quant_act=True, + FP4_BLOCK_SIZE=128, # 128 or 256 +) +``` + +The Triton kernel has the same API: + +```python +from modelopt.torch.quantization.conv_gemm.implicit_gemm import conv3d_implicit_gemm_triton + +out = conv3d_implicit_gemm_triton(x, w, stride=(1,1,1), padding=(1,1,1)) +``` + +## Parameters + +| Parameter | Description | +|-----------|-------------| +| `x` | Input tensor `[N, Cin, D, H, W]` | +| `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` | +| `bias` | Optional bias `[Cout]` | +| `stride` | Convolution stride `(D, H, W)` | +| `padding` | Convolution padding `(D, H, W)` | +| `dilation` | Convolution 
dilation `(D, H, W)` | +| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) | +| `quant_act` | Enable FP4 fake quantization on activations | +| `FP4_BLOCK_SIZE` | Quantization block size: `128` or `256` | + +## Notes + +- The CUDA kernel is JIT-compiled on first call (takes a few seconds). +- Both kernels return the same shape as `torch.nn.functional.conv3d`. +- FP4 quantization fuses the quantize-dequantize into the GEMM tile load, so there is minimal overhead vs the non-quantized path. \ No newline at end of file diff --git a/modelopt/torch/quantization/conv/implicit_gemm_cuda.py b/modelopt/torch/quantization/conv/implicit_gemm_cuda.py new file mode 100644 index 0000000000..c2bd4597df --- /dev/null +++ b/modelopt/torch/quantization/conv/implicit_gemm_cuda.py @@ -0,0 +1,706 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Optimized CUDA-based Conv3D Implicit GEMM with FP4 quantization using BF16 WMMA Tensor Cores. + +This module provides an optimized CUDA kernel for Conv3D via implicit GEMM with +fused FP4 fake quantization. The kernel is compiled just-in-time using +PyTorch's cpp_extension. + +Key optimizations: +1. BF16 WMMA tensor core operations (m16n16k16) with FP32 accumulators +2. On-the-fly spatial index computation (no global memory lookup tables) +3. 
Dual FP4_BLOCK_SIZE support (128 and 256) with optimized tile configs: + - FP4_BLOCK_SIZE=128: BM=64, BN=64, BK=128, 8 warps (256 threads), ~35KB shared + - FP4_BLOCK_SIZE=256: BM=64, BN=64, BK=256, 8 warps (256 threads), ~69KB shared +4. Register-fused FP4 quantization (quantize during A-tile load, eliminates sync) +5. Branchless FP4 quantization using predicated selects +6. BF16 shared memory (halves memory vs FP32) +7. L2-friendly block scheduling (swizzled grid) +8. FP8 E4M3 round-trip for scale quantization (matches Triton exactly) +""" + +import torch +import torch.nn.functional as F + +# C++ header for function declarations +CPP_SOURCE = r""" +torch::Tensor conv3d_implicit_gemm_cuda( + torch::Tensor x_pad, + torch::Tensor w_flat, + torch::Tensor bias, + torch::Tensor act_amax, + int N_batch, int Cin, int Dp, int Hp, int Wp, + int Cout, int OD, int OH, int OW, + int kD, int kH, int kW, + int sd, int sh, int sw, + int dd, int dh, int dw, + int M, int K, + bool quant_act, bool has_bias, + int fp4_block_size +); +""" + +# Optimized CUDA kernel with BF16 WMMA tensor cores +CUDA_KERNEL_SOURCE = r""" +#include +#include +#include +#include +#include + +using namespace nvcuda; + +// ============================================================================= +// FP4 Quantization Helpers +// ============================================================================= + +__device__ __forceinline__ float fp4_quantize_value(float scaled) { + float q; + q = (scaled <= 5.0f) ? 4.0f : 6.0f; + q = (scaled < 3.5f) ? 3.0f : q; + q = (scaled <= 2.5f) ? 2.0f : q; + q = (scaled < 1.75f) ? 1.5f : q; + q = (scaled <= 1.25f) ? 1.0f : q; + q = (scaled < 0.75f) ? 0.5f : q; + q = (scaled <= 0.25f) ? 
0.0f : q; + return q; +} + +__device__ __forceinline__ float warp_reduce_max(float val) { + #pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset)); + } + return val; +} + +__device__ __forceinline__ float fp8_e4m3_round_trip(float x) { + if (x == 0.0f) return 0.0f; + + unsigned int bits = __float_as_uint(x); + unsigned int sign = bits >> 31; + int exp = ((bits >> 23) & 0xff) - 127; + unsigned int mantissa = bits & 0x7fffff; + + if (exp > 8) return sign ? -448.0f : 448.0f; + if (exp < -9) return 0.0f; + + unsigned int mantissa_3bit = (mantissa + (1 << 19)) >> 20; + if (mantissa_3bit > 7) { + mantissa_3bit = 0; + exp += 1; + if (exp > 8) return sign ? -448.0f : 448.0f; + } + + if (exp < -6) { + int shift = -6 - exp; + mantissa_3bit = (mantissa_3bit | 8) >> shift; + exp = -6; + } + + int fp32_exp = exp + 127; + unsigned int fp32_mantissa = mantissa_3bit << 20; + unsigned int fp32_bits = (sign << 31) | (fp32_exp << 23) | fp32_mantissa; + + return __uint_as_float(fp32_bits); +} + +__device__ __forceinline__ float quantize_scale_fp8(float block_max, float global_scale) { + float scaled = block_max / (6.0f * global_scale); + scaled = fminf(scaled, 448.0f); + float quantized = fp8_e4m3_round_trip(scaled); + return quantized * global_scale; +} + +// ============================================================================= +// BF16 WMMA Conv3D Implicit GEMM Kernel +// ============================================================================= +// Template parameters: +// QUANT_ACT - whether to apply FP4 quantization +// HAS_BIAS - whether bias is present +// BLOCK_M - M tile size (64) +// BLOCK_N - N tile size (64) +// BLOCK_K - K tile size (matches FP4_BLOCK_SIZE: 128 or 256) +// WARPS_M - warp tiling in M dimension (2) +// WARPS_N - warp tiling in N dimension (4) +// L2_SWIZZLE_GROUP - group size for L2-friendly block scheduling +// +// Each warp computes a (WARP_M x WARP_N) output tile using 
16x16x16 WMMA. +// WARP_M = BLOCK_M / WARPS_M, WARP_N = BLOCK_N / WARPS_N +// WARP_TILES_M = WARP_M / 16, WARP_TILES_N = WARP_N / 16 +// +// Shared memory layout (BF16): +// As[BLOCK_M][BK_STRIDE] - M-major (row_major for WMMA A-fragments) +// Bs[BLOCK_K][BN_STRIDE] - K-major (row_major for WMMA B-fragments) + +template< + bool QUANT_ACT, bool HAS_BIAS, + int BLOCK_M, int BLOCK_N, int BLOCK_K, + int WARPS_M, int WARPS_N, + int L2_SWIZZLE_GROUP = 8 +> +__global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) +conv3d_implicit_gemm_wmma( + const float* __restrict__ x_pad, + const float* __restrict__ w_flat, + const float* __restrict__ bias, + float* __restrict__ y, + const float* __restrict__ act_amax, + int Cin, int Dp, int Hp, int Wp, + int Cout, int OD, int OH, int OW, + int kD, int kH, int kW, + int sd, int sh, int sw, + int dd, int dh, int dw, + int M, int K +) { + // Derived constants + constexpr int NUM_WARPS = WARPS_M * WARPS_N; + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int WARP_M = BLOCK_M / WARPS_M; // 32 + constexpr int WARP_N = BLOCK_N / WARPS_N; // 16 + constexpr int WARP_TILES_M = WARP_M / 16; // 2 + constexpr int WARP_TILES_N = WARP_N / 16; // 1 + + // BF16 shared memory strides with padding to avoid bank conflicts + // Pad by 8 BF16 elements (16 bytes) — keeps 16-byte alignment while breaking conflicts + constexpr int BK_STRIDE = BLOCK_K + 8; + constexpr int BN_STRIDE = BLOCK_N + 8; + + // Thread/warp indices + const int tid = threadIdx.x; + const int warp_id = tid / 32; + const int lane_id = tid % 32; + const int warp_m = warp_id / WARPS_N; // which M-warp (0..WARPS_M-1) + const int warp_n = warp_id % WARPS_N; // which N-warp (0..WARPS_N-1) + + // L2-friendly block scheduling (swizzle) + int bm, bn; + { + const int pid = blockIdx.x; + constexpr int GS = L2_SWIZZLE_GROUP; + const int grid_n = (Cout + BLOCK_N - 1) / BLOCK_N; + const int grid_m = (M + BLOCK_M - 1) / BLOCK_M; + const int tiles_per_group = GS * grid_n; + + const int 
group_row = pid / tiles_per_group; + const int group_rem = pid % tiles_per_group; + bn = group_rem / GS; + const int swizzle_lane = group_rem % GS; + bm = group_row * GS + swizzle_lane; + + if (bm >= grid_m || bn >= grid_n) return; + } + + // Dynamic shared memory — BF16 tiles + extern __shared__ char smem_raw[]; + __nv_bfloat16* As = reinterpret_cast<__nv_bfloat16*>(smem_raw); + // As: [BLOCK_M][BK_STRIDE] — M-major + constexpr int A_SMEM_ELEMS = BLOCK_M * BK_STRIDE; + __nv_bfloat16* Bs = As + A_SMEM_ELEMS; + // Bs: [BLOCK_K][BN_STRIDE] — K-major + + // WMMA accumulators — FP32 + wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + wmma::fill_fragment(acc[wm][wn], 0.0f); + } + } + + // Global scale for FP4 quantization + float global_scale = 1.0f; + if constexpr (QUANT_ACT) { + global_scale = act_amax[0] / (6.0f * 448.0f); + } + + // Precompute spatial constants + const int HpWp = Hp * Wp; + const int DpHpWp = Dp * HpWp; + const int kHW = kH * kW; + const int kDHW = kD * kHW; + const int OHW = OH * OW; + const int ODHW = OD * OHW; + + const int m_start = bm * BLOCK_M; + const int n_start = bn * BLOCK_N; + const int num_k_tiles = (K + BLOCK_K - 1) / BLOCK_K; + + // Total elements to load cooperatively + constexpr int A_ELEMS = BLOCK_M * BLOCK_K; + constexpr int B_ELEMS = BLOCK_K * BLOCK_N; + + // Main loop over K tiles + for (int k_tile = 0; k_tile < num_k_tiles; k_tile++) { + const int k_start_tile = k_tile * BLOCK_K; + + // ===================================================================== + // Load A tile into BF16 shared memory (M-major layout) + // As[m][k] stored at As[m * BK_STRIDE + k] + // ===================================================================== + if constexpr (QUANT_ACT) { + // Fused FP4 quantization: each warp handles M-rows + constexpr int ELEMS_PER_LANE = (BLOCK_K + 31) / 32; + + for (int m = warp_id; m < BLOCK_M; m 
+= NUM_WARPS) { + int m_idx = m_start + m; + + int n_batch, od_val, oh_val, ow_val; + if (m_idx < M) { + n_batch = m_idx / ODHW; + int rem = m_idx % ODHW; + od_val = rem / OHW; + rem = rem % OHW; + oh_val = rem / OW; + ow_val = rem % OW; + } else { + n_batch = 0; od_val = 0; oh_val = 0; ow_val = 0; + } + + float local_max = 0.0f; + float vals[ELEMS_PER_LANE]; + + #pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + float val = 0.0f; + if (k < BLOCK_K && m_idx < M) { + int k_idx = k_start_tile + k; + if (k_idx < K) { + int c = k_idx / kDHW; + int remk = k_idx % kDHW; + int kd_v = remk / kHW; + remk = remk % kHW; + int kh_v = remk / kW; + int kw_v = remk % kW; + + int id = od_val * sd + kd_v * dd; + int ih = oh_val * sh + kh_v * dh; + int iw = ow_val * sw + kw_v * dw; + + val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; + } + } + vals[i] = val; + local_max = fmaxf(local_max, fabsf(val)); + } + + float block_max = warp_reduce_max(local_max); + float scale = quantize_scale_fp8(block_max, global_scale); + if (scale < 1e-5f) scale = 1.0f; + float inv_scale = 1.0f / scale; + + #pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + if (k < BLOCK_K) { + float val = vals[i]; + float sign = (val >= 0.0f) ? 
1.0f : -1.0f; + float q = fp4_quantize_value(fabsf(val) * inv_scale); + float result = sign * q * scale; + // M-major: As[m * BK_STRIDE + k] + As[m * BK_STRIDE + k] = __float2bfloat16(result); + } + } + } + } else { + // Non-quantized: cooperative load, store as BF16 in M-major + #pragma unroll 4 + for (int i = tid; i < A_ELEMS; i += NUM_THREADS) { + int local_m = i / BLOCK_K; + int local_k = i % BLOCK_K; + int m_idx = m_start + local_m; + int k_idx = k_start_tile + local_k; + + float val = 0.0f; + if (m_idx < M && k_idx < K) { + int n_batch = m_idx / ODHW; + int rem = m_idx % ODHW; + int od_val = rem / OHW; + rem = rem % OHW; + int oh_val = rem / OW; + int ow_val = rem % OW; + + int c = k_idx / kDHW; + int remk = k_idx % kDHW; + int kd_v = remk / kHW; + remk = remk % kHW; + int kh_v = remk / kW; + int kw_v = remk % kW; + + int id = od_val * sd + kd_v * dd; + int ih = oh_val * sh + kh_v * dh; + int iw = ow_val * sw + kw_v * dw; + + val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; + } + // M-major: As[m * BK_STRIDE + k] + As[local_m * BK_STRIDE + local_k] = __float2bfloat16(val); + } + } + + // ===================================================================== + // Load B tile into BF16 shared memory (K-major layout) + // Bs[k][n] stored at Bs[k * BN_STRIDE + n] + // ===================================================================== + #pragma unroll 4 + for (int i = tid; i < B_ELEMS; i += NUM_THREADS) { + int local_k = i / BLOCK_N; + int local_n = i % BLOCK_N; + int k_idx = k_start_tile + local_k; + int n_idx = n_start + local_n; + + float val = 0.0f; + if (k_idx < K && n_idx < Cout) { + val = w_flat[k_idx * Cout + n_idx]; + } + Bs[local_k * BN_STRIDE + local_n] = __float2bfloat16(val); + } + + __syncthreads(); + + // ===================================================================== + // WMMA Compute: iterate over K in steps of 16 (WMMA K-dim) + // ===================================================================== + constexpr 
int K_STEPS = BLOCK_K / 16; + + #pragma unroll + for (int kk = 0; kk < K_STEPS; kk++) { + // Load A and B fragments from shared memory + wmma::fragment a_frag[WARP_TILES_M]; + wmma::fragment b_frag[WARP_TILES_N]; + + // Load A fragments: each from As[(warp_m * WARP_M + wm*16) * BK_STRIDE + kk*16] + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { + int a_row = warp_m * WARP_M + wm * 16; + int a_col = kk * 16; + wmma::load_matrix_sync(a_frag[wm], &As[a_row * BK_STRIDE + a_col], BK_STRIDE); + } + + // Load B fragments: each from Bs[(kk*16) * BN_STRIDE + (warp_n * WARP_N + wn*16)] + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + int b_row = kk * 16; + int b_col = warp_n * WARP_N + wn * 16; + wmma::load_matrix_sync(b_frag[wn], &Bs[b_row * BN_STRIDE + b_col], BN_STRIDE); + } + + // MMA: acc[wm][wn] += a_frag[wm] * b_frag[wn] + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + wmma::mma_sync(acc[wm][wn], a_frag[wm], b_frag[wn], acc[wm][wn]); + } + } + } + + __syncthreads(); + } + + // ========================================================================= + // Store results: use shared memory as FP32 staging buffer + // Each warp stores its accumulator fragments, then all threads cooperatively + // copy to global memory with bounds checking and bias addition. 
+ // ========================================================================= + + // Reinterpret shared memory as FP32 for output staging + // We need BLOCK_M * BLOCK_N floats = 64 * 64 * 4 = 16384 bytes + // This fits within our shared memory (>= 27KB) + float* out_smem = reinterpret_cast(smem_raw); + // out_smem layout: [BLOCK_M][BLOCK_N], row-major + + // Each warp stores its accumulator fragments to shared memory + #pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { + #pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + int out_row = warp_m * WARP_M + wm * 16; + int out_col = warp_n * WARP_N + wn * 16; + // Store to out_smem[out_row][out_col] with stride BLOCK_N + wmma::store_matrix_sync(&out_smem[out_row * BLOCK_N + out_col], acc[wm][wn], BLOCK_N, wmma::mem_row_major); + } + } + + __syncthreads(); + + // Cooperatively copy from shared memory to global memory + constexpr int OUT_ELEMS = BLOCK_M * BLOCK_N; + #pragma unroll 4 + for (int i = tid; i < OUT_ELEMS; i += NUM_THREADS) { + int local_m = i / BLOCK_N; + int local_n = i % BLOCK_N; + int m_idx = m_start + local_m; + int n_idx = n_start + local_n; + + if (m_idx < M && n_idx < Cout) { + float result = out_smem[i]; + if constexpr (HAS_BIAS) { + result += bias[n_idx]; + } + y[m_idx * Cout + n_idx] = result; + } + } +} + +// ============================================================================= +// Python Binding +// ============================================================================= + +torch::Tensor conv3d_implicit_gemm_cuda( + torch::Tensor x_pad, + torch::Tensor w_flat, + torch::Tensor bias, + torch::Tensor act_amax, + int N_batch, int Cin, int Dp, int Hp, int Wp, + int Cout, int OD, int OH, int OW, + int kD, int kH, int kW, + int sd, int sh, int sw, + int dd, int dh, int dw, + int M, int K, + bool quant_act, bool has_bias, + int fp4_block_size +) { + auto y = torch::zeros({M, Cout}, x_pad.options()); + + // Helper to compute padded 1D grid size for L2 swizzle + constexpr 
int GS = 8; // L2_SWIZZLE_GROUP + auto compute_grid = [&](int BM, int BN) -> dim3 { + int grid_m = (M + BM - 1) / BM; + int grid_n = (Cout + BN - 1) / BN; + int num_m_groups = (grid_m + GS - 1) / GS; + int total_blocks = num_m_groups * GS * grid_n; + return dim3(total_blocks, 1); + }; + + // Macro to dispatch kernel with all 4 template specializations + #define LAUNCH_WMMA_KERNEL(BM, BN, BK, WM, WN) \ + { \ + constexpr int BK_S = BK + 8; \ + constexpr int BN_S = BN + 8; \ + constexpr size_t smem_a = BM * BK_S * sizeof(__nv_bfloat16); \ + constexpr size_t smem_b = BK * BN_S * sizeof(__nv_bfloat16); \ + constexpr size_t smem = smem_a + smem_b; \ + \ + dim3 block(WM * WN * 32); \ + dim3 grid = compute_grid(BM, BN); \ + \ + auto set_smem = [](auto kernel) { \ + constexpr size_t s_a = BM * (BK + 8) * sizeof(__nv_bfloat16); \ + constexpr size_t s_b = BK * (BN + 8) * sizeof(__nv_bfloat16); \ + constexpr size_t s = s_a + s_b; \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, s); \ + }; \ + \ + if (quant_act && has_bias) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>( \ + x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), \ + Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ + sd, sh, sw, dd, dh, dw, M, K); \ + } else if (quant_act) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>( \ + x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), \ + Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ + sd, sh, sw, dd, dh, dw, M, K); \ + } else if (has_bias) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>( \ + x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), \ + Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ + sd, sh, sw, dd, dh, dw, M, K); \ + } else { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>( \ + 
x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), \ + Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ + sd, sh, sw, dd, dh, dw, M, K); \ + } \ + } + + if (fp4_block_size == 128) { + // BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, WARPS_M=2, WARPS_N=4 + // 8 warps = 256 threads -> faster cooperative loading + // WARP_M=32, WARP_N=16, WARP_TILES_M=2, WARP_TILES_N=1 -> 2 mma per warp per K-step + // Shared: 64*(128+8)*2 + 128*(64+8)*2 = 17,408 + 18,432 = 35,840 bytes (~35KB) + LAUNCH_WMMA_KERNEL(64, 64, 128, 2, 4) + } else { + // BLOCK_M=64, BLOCK_N=64, BLOCK_K=256, WARPS_M=2, WARPS_N=4 + // 8 warps = 256 threads -> faster cooperative loading + // Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB) + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4) + } + + #undef LAUNCH_WMMA_KERNEL + + return y; +} +""" + +# Compile the CUDA kernel +_cuda_module = None + + +def _get_cuda_module(): + """Get or compile the CUDA module.""" + global _cuda_module + if _cuda_module is None: + from torch.utils.cpp_extension import load_inline + + _cuda_module = load_inline( + name="conv3d_implicit_gemm_cuda_v19_wmma", + cpp_sources=CPP_SOURCE, + cuda_sources=CUDA_KERNEL_SOURCE, + functions=["conv3d_implicit_gemm_cuda"], + verbose=True, + extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-lineinfo", + "--ptxas-options=-v", + "-std=c++17", + ], + ) + return _cuda_module + + +def _triple(v) -> tuple[int, int, int]: + if isinstance(v, int): + return (v, v, v) + assert len(v) == 3 + return (int(v[0]), int(v[1]), int(v[2])) + + +def _pad6(padding) -> tuple[int, int, int, int, int, int]: + if isinstance(padding, int): + p = int(padding) + return (p, p, p, p, p, p) + if len(padding) == 3: + pd, ph, pw = map(int, padding) + return (pw, pw, ph, ph, pd, pd) + assert len(padding) == 6 + return tuple(map(int, padding)) + + +@torch.no_grad() +def conv3d_implicit_gemm_cuda( + x: torch.Tensor, + w: torch.Tensor, + bias: torch.Tensor | None = None, + 
stride: tuple[int, int, int] = (1, 1, 1), + padding: tuple[int, int, int] = (0, 0, 0), + dilation: tuple[int, int, int] = (1, 1, 1), + act_amax: torch.Tensor | None = None, + quant_act: bool = False, + FP4_BLOCK_SIZE: int = 256, +) -> torch.Tensor: + """Optimized CUDA-based Conv3D via implicit GEMM with BF16 WMMA tensor cores. + + Args: + x: Input tensor [N, Cin, D, H, W] + w: Weight tensor [Cout, Cin, kD, kH, kW] + bias: Optional bias tensor [Cout] + stride: Convolution stride (D, H, W) + padding: Convolution padding (D, H, W) + dilation: Convolution dilation (D, H, W) + act_amax: Activation max value for FP4 quantization + quant_act: Whether to apply FP4 quantization to activations + FP4_BLOCK_SIZE: FP4 quantization block size (128 or 256) + + Returns: + Output tensor [N, Cout, OD, OH, OW] + """ + cuda_mod = _get_cuda_module() + + assert x.ndim == 5 and w.ndim == 5 + N_batch, Cin, D, H, W = x.shape + Cout, Cin_w, kD, kH, kW = w.shape + assert Cin_w == Cin + + sd, sh, sw = _triple(stride) + dd, dh, dw = _triple(dilation) + pad_wl, pad_wr, pad_hl, pad_hr, pad_dl, pad_dr = _pad6(padding) + + x_pad = F.pad(x, (pad_wl, pad_wr, pad_hl, pad_hr, pad_dl, pad_dr)) + Dp = D + pad_dl + pad_dr + Hp = H + pad_hl + pad_hr + Wp = W + pad_wl + pad_wr + + OD = (Dp - (dd * (kD - 1) + 1)) // sd + 1 + OH = (Hp - (dh * (kH - 1) + 1)) // sh + 1 + OW = (Wp - (dw * (kW - 1) + 1)) // sw + 1 + + M = N_batch * OD * OH * OW + K = Cin * kD * kH * kW + + w_flat = w.reshape(Cout, K).transpose(0, 1).contiguous() + + x_pad = x_pad.float().contiguous() + w_flat = w_flat.float().contiguous() + + has_bias = bias is not None + bias_t = bias.float().contiguous() if has_bias else torch.empty(0, device=x.device) + + do_quant = quant_act and act_amax is not None + amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device) + + y_flat = cuda_mod.conv3d_implicit_gemm_cuda( + x_pad, + w_flat, + bias_t, + amax_t, + N_batch, + Cin, + Dp, + Hp, + Wp, + Cout, + OD, + OH, + OW, + kD, 
+ kH, + kW, + sd, + sh, + sw, + dd, + dh, + dw, + M, + K, + do_quant, + has_bias, + FP4_BLOCK_SIZE, + ) + + y = y_flat.view(N_batch, OD, OH, OW, Cout).permute(0, 4, 1, 2, 3).contiguous() + return y.to(x.dtype) From 9b278d86dd724cc5ba0a86fbb5d8286763d5eea9 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 13 Feb 2026 23:46:48 +0000 Subject: [PATCH 02/25] Update the readme Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/conv/README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/conv/README.md b/modelopt/torch/quantization/conv/README.md index a76cbf895c..f0ab1b9c50 100644 --- a/modelopt/torch/quantization/conv/README.md +++ b/modelopt/torch/quantization/conv/README.md @@ -7,13 +7,25 @@ CUDA and Triton kernels for Conv3D via implicit GEMM with optional FP4 fake quan ```python import torch from modelopt.torch.quantization.conv_gemm.implicit_gemm_cuda import conv3d_implicit_gemm_cuda +from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op x = torch.randn(1, 128, 21, 60, 106, device="cuda") w = torch.randn(512, 128, 3, 3, 3, device="cuda") +block_size=128 # Without quantization (drop-in replacement for F.conv3d) out = conv3d_implicit_gemm_cuda(x, w, stride=(1,1,1), padding=(1,1,1)) +w = dynamic_block_quantize_op( + w, + block_size, + w.abs().max().unsqueeze(0), # AMAX + 5, # num_bits + 2, # exponent_bits + 8, # scale_num_bits, + 4, # scale_exponent_bits +) + # With FP4 quantization out = conv3d_implicit_gemm_cuda( x, w, @@ -21,7 +33,7 @@ out = conv3d_implicit_gemm_cuda( padding=(1,1,1), act_amax=x.abs().max().unsqueeze(0), quant_act=True, - FP4_BLOCK_SIZE=128, # 128 or 256 + FP4_BLOCK_SIZE=block_size, # 128 or 256 ) ``` @@ -51,4 +63,4 @@ out = conv3d_implicit_gemm_triton(x, w, stride=(1,1,1), padding=(1,1,1)) - The CUDA kernel is JIT-compiled on first call (takes a few seconds). - Both kernels return the same shape as `torch.nn.functional.conv3d`. 
-- FP4 quantization fuses the quantize-dequantize into the GEMM tile load, so there is minimal overhead vs the non-quantized path. \ No newline at end of file +- FP4 quantization fuses the quantize-dequantize into the GEMM tile load, so there is minimal overhead vs the non-quantized path. From abd598f2b77aff46ac5e180dec3b3b475e582f1b Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Sat, 14 Feb 2026 00:16:02 +0000 Subject: [PATCH 03/25] Update Signed-off-by: Jingyu Xin --- experimental/conv/README.md | 99 +++++++++++++++++++ .../conv/implicit_gemm_cuda.py | 66 ++++++------- modelopt/torch/quantization/conv/README.md | 66 ------------- 3 files changed, 132 insertions(+), 99 deletions(-) create mode 100644 experimental/conv/README.md rename {modelopt/torch/quantization => experimental}/conv/implicit_gemm_cuda.py (96%) delete mode 100644 modelopt/torch/quantization/conv/README.md diff --git a/experimental/conv/README.md b/experimental/conv/README.md new file mode 100644 index 0000000000..55a85f875e --- /dev/null +++ b/experimental/conv/README.md @@ -0,0 +1,99 @@ +# Conv3D Implicit GEMM (Experimental) + +Experimental Conv3D kernel prototype using implicit GEMM, with optional fused FP4 fake quantization for activations. + +This code is kept under `experimental/` by design and is **not** part of the stable `modelopt.torch.quantization` API. 
+ +## Model Support + +| Model/Framework | Supported | Notes | +|-----------------|-----------|-------| +| Video diffusion backbones using Conv3D | Partial | Intended for experimentation and microbenchmarking | +| Generic LLM backbones | No | Conv3D path is not relevant | +| End-to-end ModelOpt PTQ/QAT pipeline | No | Not wired into formal quantization/export/compress flows | + +## Deployment + +| Framework | Supported | Notes | +|-----------|-----------|-------| +| TensorRT-LLM | No | No formal export integration for this kernel path | +| vLLM | No | No integration | +| SGLang | No | No integration | +| PyTorch runtime (CUDA) | Yes (experimental) | JIT-compiles CUDA extension on first use | + +## Usage + +```python +import torch + +from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda +from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op + +x = torch.randn(1, 128, 21, 60, 106, device="cuda") +w = torch.randn(512, 128, 3, 3, 3, device="cuda") +block_size = 128 + +# Without FP4 activation quantization (drop-in-style Conv3D call) +out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) + +# Optional block quantization of weights for experiments +w_q = dynamic_block_quantize_op( + w, + block_size, + w.abs().max().unsqueeze(0), + 4, # num_bits + 2, # exponent_bits + 8, # scale_num_bits + 4, # scale_exponent_bits +) + +# With FP4 activation fake quantization +out_q = conv3d_implicit_gemm_cuda( + x, + w_q, + stride=(1, 1, 1), + padding=(1, 1, 1), + act_amax=x.abs().max().unsqueeze(0), + quant_act=True, + fp4_block_size=block_size, # 128 or 256 +) +``` + +## API + +Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm_cuda.py` + +| Parameter | Description | +|-----------|-------------| +| `x` | Input tensor `[N, Cin, D, H, W]` | +| `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` | +| `bias` | Optional bias `[Cout]` | +| `stride` | Convolution stride `(D, H, W)` | +| `padding` | 
Convolution padding `(D, H, W)` | +| `dilation` | Convolution dilation `(D, H, W)` | +| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) | +| `quant_act` | Enable FP4 fake quantization on activations | +| `fp4_block_size` | FP4 quantization block size (`128` or `256`) | + +## Status + +Current state: **Prototype** + +Known limitations: + +- API is unstable and may change without notice. +- Not registered in core quantization module registries. +- Not covered by formal export/compress integration. +- CUDA extension compile latency on first invocation. +- Validation and performance coverage are limited to local experiments. + +## Notes + +- The CUDA kernel is JIT-compiled on first call (can take several seconds). +- Output shape matches `torch.nn.functional.conv3d`. +- FP4 path applies quantize-dequantize in-kernel for activation tiles. + +## References + +- Implicit GEMM-based convolution design patterns in GPU kernels. +- ModelOpt FP4-related quantization utilities in `modelopt.torch.quantization.tensor_quant`. 
diff --git a/modelopt/torch/quantization/conv/implicit_gemm_cuda.py b/experimental/conv/implicit_gemm_cuda.py similarity index 96% rename from modelopt/torch/quantization/conv/implicit_gemm_cuda.py rename to experimental/conv/implicit_gemm_cuda.py index c2bd4597df..dd8fe79541 100644 --- a/modelopt/torch/quantization/conv/implicit_gemm_cuda.py +++ b/experimental/conv/implicit_gemm_cuda.py @@ -607,7 +607,7 @@ def _pad6(padding) -> tuple[int, int, int, int, int, int]: pd, ph, pw = map(int, padding) return (pw, pw, ph, ph, pd, pd) assert len(padding) == 6 - return tuple(map(int, padding)) + return tuple(map(int, padding)) # type: ignore[return-value] @torch.no_grad() @@ -620,7 +620,7 @@ def conv3d_implicit_gemm_cuda( dilation: tuple[int, int, int] = (1, 1, 1), act_amax: torch.Tensor | None = None, quant_act: bool = False, - FP4_BLOCK_SIZE: int = 256, + fp4_block_size: int = 256, ) -> torch.Tensor: """Optimized CUDA-based Conv3D via implicit GEMM with BF16 WMMA tensor cores. @@ -633,7 +633,7 @@ def conv3d_implicit_gemm_cuda( dilation: Convolution dilation (D, H, W) act_amax: Activation max value for FP4 quantization quant_act: Whether to apply FP4 quantization to activations - FP4_BLOCK_SIZE: FP4 quantization block size (128 or 256) + fp4_block_size: FP4 quantization block size (128 or 256) Returns: Output tensor [N, Cout, OD, OH, OW] @@ -641,66 +641,66 @@ def conv3d_implicit_gemm_cuda( cuda_mod = _get_cuda_module() assert x.ndim == 5 and w.ndim == 5 - N_batch, Cin, D, H, W = x.shape - Cout, Cin_w, kD, kH, kW = w.shape - assert Cin_w == Cin + n_batch, cin, d, h, w_in = x.shape + cout, cin_w, kd, kh, kw = w.shape + assert cin_w == cin sd, sh, sw = _triple(stride) dd, dh, dw = _triple(dilation) pad_wl, pad_wr, pad_hl, pad_hr, pad_dl, pad_dr = _pad6(padding) x_pad = F.pad(x, (pad_wl, pad_wr, pad_hl, pad_hr, pad_dl, pad_dr)) - Dp = D + pad_dl + pad_dr - Hp = H + pad_hl + pad_hr - Wp = W + pad_wl + pad_wr + dp = d + pad_dl + pad_dr + hp = h + pad_hl + pad_hr + wp = w_in + 
pad_wl + pad_wr - OD = (Dp - (dd * (kD - 1) + 1)) // sd + 1 - OH = (Hp - (dh * (kH - 1) + 1)) // sh + 1 - OW = (Wp - (dw * (kW - 1) + 1)) // sw + 1 + od = (dp - (dd * (kd - 1) + 1)) // sd + 1 + oh = (hp - (dh * (kh - 1) + 1)) // sh + 1 + ow = (wp - (dw * (kw - 1) + 1)) // sw + 1 - M = N_batch * OD * OH * OW - K = Cin * kD * kH * kW + m = n_batch * od * oh * ow + k = cin * kd * kh * kw - w_flat = w.reshape(Cout, K).transpose(0, 1).contiguous() + w_flat = w.reshape(cout, k).transpose(0, 1).contiguous() x_pad = x_pad.float().contiguous() w_flat = w_flat.float().contiguous() has_bias = bias is not None - bias_t = bias.float().contiguous() if has_bias else torch.empty(0, device=x.device) + bias_t = bias.float().contiguous() if has_bias else torch.empty(0, device=x.device) # type: ignore[union-attr] do_quant = quant_act and act_amax is not None - amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device) + amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device) # type: ignore[union-attr] y_flat = cuda_mod.conv3d_implicit_gemm_cuda( x_pad, w_flat, bias_t, amax_t, - N_batch, - Cin, - Dp, - Hp, - Wp, - Cout, - OD, - OH, - OW, - kD, - kH, - kW, + n_batch, + cin, + dp, + hp, + wp, + cout, + od, + oh, + ow, + kd, + kh, + kw, sd, sh, sw, dd, dh, dw, - M, - K, + m, + k, do_quant, has_bias, - FP4_BLOCK_SIZE, + fp4_block_size, ) - y = y_flat.view(N_batch, OD, OH, OW, Cout).permute(0, 4, 1, 2, 3).contiguous() + y = y_flat.view(n_batch, od, oh, ow, cout).permute(0, 4, 1, 2, 3).contiguous() return y.to(x.dtype) diff --git a/modelopt/torch/quantization/conv/README.md b/modelopt/torch/quantization/conv/README.md deleted file mode 100644 index f0ab1b9c50..0000000000 --- a/modelopt/torch/quantization/conv/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# Conv3D Implicit GEMM Kernels - -CUDA and Triton kernels for Conv3D via implicit GEMM with optional FP4 fake quantization. 
- -## Usage - -```python -import torch -from modelopt.torch.quantization.conv_gemm.implicit_gemm_cuda import conv3d_implicit_gemm_cuda -from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op - -x = torch.randn(1, 128, 21, 60, 106, device="cuda") -w = torch.randn(512, 128, 3, 3, 3, device="cuda") -block_size=128 - -# Without quantization (drop-in replacement for F.conv3d) -out = conv3d_implicit_gemm_cuda(x, w, stride=(1,1,1), padding=(1,1,1)) - -w = dynamic_block_quantize_op( - w, - block_size, - w.abs().max().unsqueeze(0), # AMAX - 5, # num_bits - 2, # exponent_bits - 8, # scale_num_bits, - 4, # scale_exponent_bits -) - -# With FP4 quantization -out = conv3d_implicit_gemm_cuda( - x, w, - stride=(1,1,1), - padding=(1,1,1), - act_amax=x.abs().max().unsqueeze(0), - quant_act=True, - FP4_BLOCK_SIZE=block_size, # 128 or 256 -) -``` - -The Triton kernel has the same API: - -```python -from modelopt.torch.quantization.conv_gemm.implicit_gemm import conv3d_implicit_gemm_triton - -out = conv3d_implicit_gemm_triton(x, w, stride=(1,1,1), padding=(1,1,1)) -``` - -## Parameters - -| Parameter | Description | -|-----------|-------------| -| `x` | Input tensor `[N, Cin, D, H, W]` | -| `w` | Weight tensor `[Cout, Cin, kD, kH, kW]` | -| `bias` | Optional bias `[Cout]` | -| `stride` | Convolution stride `(D, H, W)` | -| `padding` | Convolution padding `(D, H, W)` | -| `dilation` | Convolution dilation `(D, H, W)` | -| `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) | -| `quant_act` | Enable FP4 fake quantization on activations | -| `FP4_BLOCK_SIZE` | Quantization block size: `128` or `256` | - -## Notes - -- The CUDA kernel is JIT-compiled on first call (takes a few seconds). -- Both kernels return the same shape as `torch.nn.functional.conv3d`. -- FP4 quantization fuses the quantize-dequantize into the GEMM tile load, so there is minimal overhead vs the non-quantized path. 
From 7ca8bd63685f093e21343bef8b5e9edb5d26c2ec Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 5 Mar 2026 02:37:07 +0000 Subject: [PATCH 04/25] Add test case, and move the cuda code out of python script Signed-off-by: Jingyu Xin --- experimental/conv/bench_implicit_gemm.py | 208 ++++++ experimental/conv/implicit_gemm_binding.cpp | 37 + experimental/conv/implicit_gemm_cuda.py | 600 ++------------- experimental/conv/implicit_gemm_kernel.cu | 595 +++++++++++++++ experimental/conv/test_implicit_gemm.py | 780 ++++++++++++++++++++ 5 files changed, 1667 insertions(+), 553 deletions(-) create mode 100644 experimental/conv/bench_implicit_gemm.py create mode 100644 experimental/conv/implicit_gemm_binding.cpp create mode 100644 experimental/conv/implicit_gemm_kernel.cu create mode 100644 experimental/conv/test_implicit_gemm.py diff --git a/experimental/conv/bench_implicit_gemm.py b/experimental/conv/bench_implicit_gemm.py new file mode 100644 index 0000000000..164c074467 --- /dev/null +++ b/experimental/conv/bench_implicit_gemm.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d. 
+ +Usage: + python -m experimental.conv.bench_implicit_gemm + python -m experimental.conv.bench_implicit_gemm --shapes wan22 + python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100 +""" + +import argparse + +import torch +import torch.nn.functional as F + +# --------------------------------------------------------------------------- +# Benchmark shapes +# --------------------------------------------------------------------------- + +# (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation) +SHAPES = { + "small": [ + ("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ], + "medium": [ + ("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)), + ], + "wan22": [ + ("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)), + ("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), + ], + "stride": [ + ("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), + ("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), + ], +} + + +def get_shapes(name: str): + """Return list of benchmark shapes by name or all shapes.""" + if name == "all": + result = [] + for v in SHAPES.values(): + result.extend(v) + return result + return SHAPES[name] + + +# --------------------------------------------------------------------------- +# Timing utility +# --------------------------------------------------------------------------- + + +def bench_fn(fn, warmup: int, iters: int) -> float: + """Benchmark a callable, return median time in ms.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() 
+ + times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + times.sort() + return times[len(times) // 2] # median + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int): + """Run latency benchmark for the given shapes.""" + from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + + shapes = get_shapes(shapes_name) + + # Header + print(f"\n{'=' * 100}") + print( + f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}" + ) + print(f"GPU: {torch.cuda.get_device_name()}") + print(f"{'=' * 100}") + print( + f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} " + f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} " + f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}" + ) + print("-" * 100) + + for name, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in shapes: + torch.manual_seed(42) + x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32) + weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + # Compute GEMM dimensions for display + sd, sh, sw = stride + dd, dh, dw = dilation + pd, ph, pw = padding + od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1 + oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1 + ow = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1 + gemm_m = n * od * oh * ow + gemm_k = cin * kd * kh * kw + gemm_n = cout + + # cuDNN (torch.nn.functional.conv3d) + t_cudnn = bench_fn( + lambda: F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation), + warmup, + iters, + ) + + # Implicit GEMM (non-quantized) + t_gemm = bench_fn( + lambda: 
conv3d_implicit_gemm_cuda( + x, + weight, + stride=stride, + padding=padding, + dilation=dilation, + quant_act=False, + fp4_block_size=fp4_block_size, + ), + warmup, + iters, + ) + + # Implicit GEMM (FP4 quantized) + t_fp4 = bench_fn( + lambda: conv3d_implicit_gemm_cuda( + x, + weight, + stride=stride, + padding=padding, + dilation=dilation, + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ), + warmup, + iters, + ) + + ratio_gemm = t_gemm / t_cudnn + ratio_fp4 = t_fp4 / t_cudnn + + print( + f"{name:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} " + f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms " + f"{ratio_gemm:>10.2f}x {ratio_fp4:>9.2f}x" + ) + + print(f"{'=' * 100}") + print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.") + print() + + +def main(): + """Entry point for the benchmark CLI.""" + parser = argparse.ArgumentParser(description="Conv3D latency benchmark") + parser.add_argument( + "--shapes", + default="all", + choices=[*list(SHAPES.keys()), "all"], + help="Which shape set to benchmark (default: all)", + ) + parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations") + parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations") + parser.add_argument( + "--fp4-block-size", + type=int, + default=128, + choices=[128, 256], + help="FP4 block size (default: 128)", + ) + args = parser.parse_args() + + run_benchmark(args.shapes, args.warmup, args.iters, args.fp4_block_size) + + +if __name__ == "__main__": + main() diff --git a/experimental/conv/implicit_gemm_binding.cpp b/experimental/conv/implicit_gemm_binding.cpp new file mode 100644 index 0000000000..b91650cd4e --- /dev/null +++ b/experimental/conv/implicit_gemm_binding.cpp @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+#include <torch/extension.h>
+
+torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_flat,
+                                        torch::Tensor bias, torch::Tensor act_amax, int N_batch,
+                                        int Cin, int Dp, int Hp, int Wp, int Cout, int OD, int OH,
+                                        int OW, int kD, int kH, int kW, int sd, int sh, int sw,
+                                        int dd, int dh, int dw, int M, int K, bool quant_act,
+                                        bool has_bias, int fp4_block_size);
+
+torch::Tensor fp4_fake_quant_cuda(torch::Tensor x, torch::Tensor global_amax, int block_size);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("conv3d_implicit_gemm_cuda", &conv3d_implicit_gemm_cuda,
+        "Conv3D implicit GEMM with BF16 WMMA and optional FP4 quantization");
+  m.def("fp4_fake_quant_cuda", &fp4_fake_quant_cuda,
+        "Standalone FP4 fake quantization (blockwise, with FP8 scale quantization)");
+}
diff --git a/experimental/conv/implicit_gemm_cuda.py b/experimental/conv/implicit_gemm_cuda.py
index dd8fe79541..b5cc6c435b 100644
--- a/experimental/conv/implicit_gemm_cuda.py
+++ b/experimental/conv/implicit_gemm_cuda.py
@@ -13,559 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Optimized CUDA-based Conv3D Implicit GEMM with FP4 quantization using BF16 WMMA Tensor Cores.
- -This module provides an optimized CUDA kernel for Conv3D via implicit GEMM with -fused FP4 fake quantization. The kernel is compiled just-in-time using -PyTorch's cpp_extension. - -Key optimizations: -1. BF16 WMMA tensor core operations (m16n16k16) with FP32 accumulators -2. On-the-fly spatial index computation (no global memory lookup tables) -3. Dual FP4_BLOCK_SIZE support (128 and 256) with optimized tile configs: - - FP4_BLOCK_SIZE=128: BM=64, BN=64, BK=128, 8 warps (256 threads), ~35KB shared - - FP4_BLOCK_SIZE=256: BM=64, BN=64, BK=256, 8 warps (256 threads), ~69KB shared -4. Register-fused FP4 quantization (quantize during A-tile load, eliminates sync) -5. Branchless FP4 quantization using predicated selects -6. BF16 shared memory (halves memory vs FP32) -7. L2-friendly block scheduling (swizzled grid) -8. FP8 E4M3 round-trip for scale quantization (matches Triton exactly) +"""Conv3D Implicit GEMM with BF16 WMMA Tensor Cores and optional fused FP4 quantization. + +CUDA kernel source: implicit_gemm_kernel.cu +C++ binding: implicit_gemm_binding.cpp """ +import os + import torch import torch.nn.functional as F -# C++ header for function declarations -CPP_SOURCE = r""" -torch::Tensor conv3d_implicit_gemm_cuda( - torch::Tensor x_pad, - torch::Tensor w_flat, - torch::Tensor bias, - torch::Tensor act_amax, - int N_batch, int Cin, int Dp, int Hp, int Wp, - int Cout, int OD, int OH, int OW, - int kD, int kH, int kW, - int sd, int sh, int sw, - int dd, int dh, int dw, - int M, int K, - bool quant_act, bool has_bias, - int fp4_block_size -); -""" - -# Optimized CUDA kernel with BF16 WMMA tensor cores -CUDA_KERNEL_SOURCE = r""" -#include -#include -#include -#include -#include - -using namespace nvcuda; - -// ============================================================================= -// FP4 Quantization Helpers -// ============================================================================= - -__device__ __forceinline__ float fp4_quantize_value(float scaled) { - 
float q; - q = (scaled <= 5.0f) ? 4.0f : 6.0f; - q = (scaled < 3.5f) ? 3.0f : q; - q = (scaled <= 2.5f) ? 2.0f : q; - q = (scaled < 1.75f) ? 1.5f : q; - q = (scaled <= 1.25f) ? 1.0f : q; - q = (scaled < 0.75f) ? 0.5f : q; - q = (scaled <= 0.25f) ? 0.0f : q; - return q; -} - -__device__ __forceinline__ float warp_reduce_max(float val) { - #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset)); - } - return val; -} - -__device__ __forceinline__ float fp8_e4m3_round_trip(float x) { - if (x == 0.0f) return 0.0f; - - unsigned int bits = __float_as_uint(x); - unsigned int sign = bits >> 31; - int exp = ((bits >> 23) & 0xff) - 127; - unsigned int mantissa = bits & 0x7fffff; - - if (exp > 8) return sign ? -448.0f : 448.0f; - if (exp < -9) return 0.0f; - - unsigned int mantissa_3bit = (mantissa + (1 << 19)) >> 20; - if (mantissa_3bit > 7) { - mantissa_3bit = 0; - exp += 1; - if (exp > 8) return sign ? -448.0f : 448.0f; - } - - if (exp < -6) { - int shift = -6 - exp; - mantissa_3bit = (mantissa_3bit | 8) >> shift; - exp = -6; - } - - int fp32_exp = exp + 127; - unsigned int fp32_mantissa = mantissa_3bit << 20; - unsigned int fp32_bits = (sign << 31) | (fp32_exp << 23) | fp32_mantissa; - - return __uint_as_float(fp32_bits); -} - -__device__ __forceinline__ float quantize_scale_fp8(float block_max, float global_scale) { - float scaled = block_max / (6.0f * global_scale); - scaled = fminf(scaled, 448.0f); - float quantized = fp8_e4m3_round_trip(scaled); - return quantized * global_scale; -} - -// ============================================================================= -// BF16 WMMA Conv3D Implicit GEMM Kernel -// ============================================================================= -// Template parameters: -// QUANT_ACT - whether to apply FP4 quantization -// HAS_BIAS - whether bias is present -// BLOCK_M - M tile size (64) -// BLOCK_N - N tile size (32) -// BLOCK_K - K tile size (matches 
FP4_BLOCK_SIZE: 128 or 256) -// WARPS_M - warp tiling in M dimension (2) -// WARPS_N - warp tiling in N dimension (2) -// L2_SWIZZLE_GROUP - group size for L2-friendly block scheduling -// -// Each warp computes a (WARP_M x WARP_N) output tile using 16x16x16 WMMA. -// WARP_M = BLOCK_M / WARPS_M, WARP_N = BLOCK_N / WARPS_N -// WARP_TILES_M = WARP_M / 16, WARP_TILES_N = WARP_N / 16 -// -// Shared memory layout (BF16): -// As[BLOCK_M][BK_STRIDE] - M-major (row_major for WMMA A-fragments) -// Bs[BLOCK_K][BN_STRIDE] - K-major (row_major for WMMA B-fragments) - -template< - bool QUANT_ACT, bool HAS_BIAS, - int BLOCK_M, int BLOCK_N, int BLOCK_K, - int WARPS_M, int WARPS_N, - int L2_SWIZZLE_GROUP = 8 -> -__global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) -conv3d_implicit_gemm_wmma( - const float* __restrict__ x_pad, - const float* __restrict__ w_flat, - const float* __restrict__ bias, - float* __restrict__ y, - const float* __restrict__ act_amax, - int Cin, int Dp, int Hp, int Wp, - int Cout, int OD, int OH, int OW, - int kD, int kH, int kW, - int sd, int sh, int sw, - int dd, int dh, int dw, - int M, int K -) { - // Derived constants - constexpr int NUM_WARPS = WARPS_M * WARPS_N; - constexpr int NUM_THREADS = NUM_WARPS * 32; - constexpr int WARP_M = BLOCK_M / WARPS_M; // 32 - constexpr int WARP_N = BLOCK_N / WARPS_N; // 16 - constexpr int WARP_TILES_M = WARP_M / 16; // 2 - constexpr int WARP_TILES_N = WARP_N / 16; // 1 - - // BF16 shared memory strides with padding to avoid bank conflicts - // Pad by 8 BF16 elements (16 bytes) — keeps 16-byte alignment while breaking conflicts - constexpr int BK_STRIDE = BLOCK_K + 8; - constexpr int BN_STRIDE = BLOCK_N + 8; - - // Thread/warp indices - const int tid = threadIdx.x; - const int warp_id = tid / 32; - const int lane_id = tid % 32; - const int warp_m = warp_id / WARPS_N; // which M-warp (0..WARPS_M-1) - const int warp_n = warp_id % WARPS_N; // which N-warp (0..WARPS_N-1) - - // L2-friendly block scheduling (swizzle) 
- int bm, bn; - { - const int pid = blockIdx.x; - constexpr int GS = L2_SWIZZLE_GROUP; - const int grid_n = (Cout + BLOCK_N - 1) / BLOCK_N; - const int grid_m = (M + BLOCK_M - 1) / BLOCK_M; - const int tiles_per_group = GS * grid_n; - - const int group_row = pid / tiles_per_group; - const int group_rem = pid % tiles_per_group; - bn = group_rem / GS; - const int swizzle_lane = group_rem % GS; - bm = group_row * GS + swizzle_lane; - - if (bm >= grid_m || bn >= grid_n) return; - } - - // Dynamic shared memory — BF16 tiles - extern __shared__ char smem_raw[]; - __nv_bfloat16* As = reinterpret_cast<__nv_bfloat16*>(smem_raw); - // As: [BLOCK_M][BK_STRIDE] — M-major - constexpr int A_SMEM_ELEMS = BLOCK_M * BK_STRIDE; - __nv_bfloat16* Bs = As + A_SMEM_ELEMS; - // Bs: [BLOCK_K][BN_STRIDE] — K-major - - // WMMA accumulators — FP32 - wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; - #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; wm++) { - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; wn++) { - wmma::fill_fragment(acc[wm][wn], 0.0f); - } - } - - // Global scale for FP4 quantization - float global_scale = 1.0f; - if constexpr (QUANT_ACT) { - global_scale = act_amax[0] / (6.0f * 448.0f); - } - - // Precompute spatial constants - const int HpWp = Hp * Wp; - const int DpHpWp = Dp * HpWp; - const int kHW = kH * kW; - const int kDHW = kD * kHW; - const int OHW = OH * OW; - const int ODHW = OD * OHW; - - const int m_start = bm * BLOCK_M; - const int n_start = bn * BLOCK_N; - const int num_k_tiles = (K + BLOCK_K - 1) / BLOCK_K; - - // Total elements to load cooperatively - constexpr int A_ELEMS = BLOCK_M * BLOCK_K; - constexpr int B_ELEMS = BLOCK_K * BLOCK_N; - - // Main loop over K tiles - for (int k_tile = 0; k_tile < num_k_tiles; k_tile++) { - const int k_start_tile = k_tile * BLOCK_K; - - // ===================================================================== - // Load A tile into BF16 shared memory (M-major layout) - // As[m][k] stored at As[m * BK_STRIDE + k] - // 
===================================================================== - if constexpr (QUANT_ACT) { - // Fused FP4 quantization: each warp handles M-rows - constexpr int ELEMS_PER_LANE = (BLOCK_K + 31) / 32; - - for (int m = warp_id; m < BLOCK_M; m += NUM_WARPS) { - int m_idx = m_start + m; - - int n_batch, od_val, oh_val, ow_val; - if (m_idx < M) { - n_batch = m_idx / ODHW; - int rem = m_idx % ODHW; - od_val = rem / OHW; - rem = rem % OHW; - oh_val = rem / OW; - ow_val = rem % OW; - } else { - n_batch = 0; od_val = 0; oh_val = 0; ow_val = 0; - } - - float local_max = 0.0f; - float vals[ELEMS_PER_LANE]; - - #pragma unroll - for (int i = 0; i < ELEMS_PER_LANE; i++) { - int k = lane_id + i * 32; - float val = 0.0f; - if (k < BLOCK_K && m_idx < M) { - int k_idx = k_start_tile + k; - if (k_idx < K) { - int c = k_idx / kDHW; - int remk = k_idx % kDHW; - int kd_v = remk / kHW; - remk = remk % kHW; - int kh_v = remk / kW; - int kw_v = remk % kW; - - int id = od_val * sd + kd_v * dd; - int ih = oh_val * sh + kh_v * dh; - int iw = ow_val * sw + kw_v * dw; - - val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; - } - } - vals[i] = val; - local_max = fmaxf(local_max, fabsf(val)); - } - - float block_max = warp_reduce_max(local_max); - float scale = quantize_scale_fp8(block_max, global_scale); - if (scale < 1e-5f) scale = 1.0f; - float inv_scale = 1.0f / scale; - - #pragma unroll - for (int i = 0; i < ELEMS_PER_LANE; i++) { - int k = lane_id + i * 32; - if (k < BLOCK_K) { - float val = vals[i]; - float sign = (val >= 0.0f) ? 
1.0f : -1.0f; - float q = fp4_quantize_value(fabsf(val) * inv_scale); - float result = sign * q * scale; - // M-major: As[m * BK_STRIDE + k] - As[m * BK_STRIDE + k] = __float2bfloat16(result); - } - } - } - } else { - // Non-quantized: cooperative load, store as BF16 in M-major - #pragma unroll 4 - for (int i = tid; i < A_ELEMS; i += NUM_THREADS) { - int local_m = i / BLOCK_K; - int local_k = i % BLOCK_K; - int m_idx = m_start + local_m; - int k_idx = k_start_tile + local_k; - - float val = 0.0f; - if (m_idx < M && k_idx < K) { - int n_batch = m_idx / ODHW; - int rem = m_idx % ODHW; - int od_val = rem / OHW; - rem = rem % OHW; - int oh_val = rem / OW; - int ow_val = rem % OW; - - int c = k_idx / kDHW; - int remk = k_idx % kDHW; - int kd_v = remk / kHW; - remk = remk % kHW; - int kh_v = remk / kW; - int kw_v = remk % kW; - - int id = od_val * sd + kd_v * dd; - int ih = oh_val * sh + kh_v * dh; - int iw = ow_val * sw + kw_v * dw; - - val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; - } - // M-major: As[m * BK_STRIDE + k] - As[local_m * BK_STRIDE + local_k] = __float2bfloat16(val); - } - } - - // ===================================================================== - // Load B tile into BF16 shared memory (K-major layout) - // Bs[k][n] stored at Bs[k * BN_STRIDE + n] - // ===================================================================== - #pragma unroll 4 - for (int i = tid; i < B_ELEMS; i += NUM_THREADS) { - int local_k = i / BLOCK_N; - int local_n = i % BLOCK_N; - int k_idx = k_start_tile + local_k; - int n_idx = n_start + local_n; - - float val = 0.0f; - if (k_idx < K && n_idx < Cout) { - val = w_flat[k_idx * Cout + n_idx]; - } - Bs[local_k * BN_STRIDE + local_n] = __float2bfloat16(val); - } - - __syncthreads(); - - // ===================================================================== - // WMMA Compute: iterate over K in steps of 16 (WMMA K-dim) - // ===================================================================== - constexpr 
int K_STEPS = BLOCK_K / 16; - - #pragma unroll - for (int kk = 0; kk < K_STEPS; kk++) { - // Load A and B fragments from shared memory - wmma::fragment a_frag[WARP_TILES_M]; - wmma::fragment b_frag[WARP_TILES_N]; - - // Load A fragments: each from As[(warp_m * WARP_M + wm*16) * BK_STRIDE + kk*16] - #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; wm++) { - int a_row = warp_m * WARP_M + wm * 16; - int a_col = kk * 16; - wmma::load_matrix_sync(a_frag[wm], &As[a_row * BK_STRIDE + a_col], BK_STRIDE); - } - - // Load B fragments: each from Bs[(kk*16) * BN_STRIDE + (warp_n * WARP_N + wn*16)] - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; wn++) { - int b_row = kk * 16; - int b_col = warp_n * WARP_N + wn * 16; - wmma::load_matrix_sync(b_frag[wn], &Bs[b_row * BN_STRIDE + b_col], BN_STRIDE); - } - - // MMA: acc[wm][wn] += a_frag[wm] * b_frag[wn] - #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; wm++) { - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; wn++) { - wmma::mma_sync(acc[wm][wn], a_frag[wm], b_frag[wn], acc[wm][wn]); - } - } - } - - __syncthreads(); - } - - // ========================================================================= - // Store results: use shared memory as FP32 staging buffer - // Each warp stores its accumulator fragments, then all threads cooperatively - // copy to global memory with bounds checking and bias addition. 
- // ========================================================================= - - // Reinterpret shared memory as FP32 for output staging - // We need BLOCK_M * BLOCK_N floats = 64 * 32 * 4 = 8192 bytes - // This fits within our shared memory (>= 27KB) - float* out_smem = reinterpret_cast(smem_raw); - // out_smem layout: [BLOCK_M][BLOCK_N], row-major - - // Each warp stores its accumulator fragments to shared memory - #pragma unroll - for (int wm = 0; wm < WARP_TILES_M; wm++) { - #pragma unroll - for (int wn = 0; wn < WARP_TILES_N; wn++) { - int out_row = warp_m * WARP_M + wm * 16; - int out_col = warp_n * WARP_N + wn * 16; - // Store to out_smem[out_row][out_col] with stride BLOCK_N - wmma::store_matrix_sync(&out_smem[out_row * BLOCK_N + out_col], acc[wm][wn], BLOCK_N, wmma::mem_row_major); - } - } - - __syncthreads(); - - // Cooperatively copy from shared memory to global memory - constexpr int OUT_ELEMS = BLOCK_M * BLOCK_N; - #pragma unroll 4 - for (int i = tid; i < OUT_ELEMS; i += NUM_THREADS) { - int local_m = i / BLOCK_N; - int local_n = i % BLOCK_N; - int m_idx = m_start + local_m; - int n_idx = n_start + local_n; - - if (m_idx < M && n_idx < Cout) { - float result = out_smem[i]; - if constexpr (HAS_BIAS) { - result += bias[n_idx]; - } - y[m_idx * Cout + n_idx] = result; - } - } -} - -// ============================================================================= -// Python Binding -// ============================================================================= - -torch::Tensor conv3d_implicit_gemm_cuda( - torch::Tensor x_pad, - torch::Tensor w_flat, - torch::Tensor bias, - torch::Tensor act_amax, - int N_batch, int Cin, int Dp, int Hp, int Wp, - int Cout, int OD, int OH, int OW, - int kD, int kH, int kW, - int sd, int sh, int sw, - int dd, int dh, int dw, - int M, int K, - bool quant_act, bool has_bias, - int fp4_block_size -) { - auto y = torch::zeros({M, Cout}, x_pad.options()); - - // Helper to compute padded 1D grid size for L2 swizzle - constexpr 
int GS = 8; // L2_SWIZZLE_GROUP - auto compute_grid = [&](int BM, int BN) -> dim3 { - int grid_m = (M + BM - 1) / BM; - int grid_n = (Cout + BN - 1) / BN; - int num_m_groups = (grid_m + GS - 1) / GS; - int total_blocks = num_m_groups * GS * grid_n; - return dim3(total_blocks, 1); - }; - - // Macro to dispatch kernel with all 4 template specializations - #define LAUNCH_WMMA_KERNEL(BM, BN, BK, WM, WN) \ - { \ - constexpr int BK_S = BK + 8; \ - constexpr int BN_S = BN + 8; \ - constexpr size_t smem_a = BM * BK_S * sizeof(__nv_bfloat16); \ - constexpr size_t smem_b = BK * BN_S * sizeof(__nv_bfloat16); \ - constexpr size_t smem = smem_a + smem_b; \ - \ - dim3 block(WM * WN * 32); \ - dim3 grid = compute_grid(BM, BN); \ - \ - auto set_smem = [](auto kernel) { \ - constexpr size_t s_a = BM * (BK + 8) * sizeof(__nv_bfloat16); \ - constexpr size_t s_b = BK * (BN + 8) * sizeof(__nv_bfloat16); \ - constexpr size_t s = s_a + s_b; \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, s); \ - }; \ - \ - if (quant_act && has_bias) { \ - auto kern = conv3d_implicit_gemm_wmma; \ - set_smem(kern); \ - kern<<>>( \ - x_pad.data_ptr(), w_flat.data_ptr(), \ - bias.data_ptr(), y.data_ptr(), \ - act_amax.data_ptr(), \ - Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ - sd, sh, sw, dd, dh, dw, M, K); \ - } else if (quant_act) { \ - auto kern = conv3d_implicit_gemm_wmma; \ - set_smem(kern); \ - kern<<>>( \ - x_pad.data_ptr(), w_flat.data_ptr(), \ - bias.data_ptr(), y.data_ptr(), \ - act_amax.data_ptr(), \ - Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ - sd, sh, sw, dd, dh, dw, M, K); \ - } else if (has_bias) { \ - auto kern = conv3d_implicit_gemm_wmma; \ - set_smem(kern); \ - kern<<>>( \ - x_pad.data_ptr(), w_flat.data_ptr(), \ - bias.data_ptr(), y.data_ptr(), \ - act_amax.data_ptr(), \ - Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ - sd, sh, sw, dd, dh, dw, M, K); \ - } else { \ - auto kern = conv3d_implicit_gemm_wmma; \ - set_smem(kern); \ - kern<<>>( \ - 
x_pad.data_ptr(), w_flat.data_ptr(), \ - bias.data_ptr(), y.data_ptr(), \ - act_amax.data_ptr(), \ - Cin, Dp, Hp, Wp, Cout, OD, OH, OW, kD, kH, kW, \ - sd, sh, sw, dd, dh, dw, M, K); \ - } \ - } - - if (fp4_block_size == 128) { - // BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, WARPS_M=2, WARPS_N=4 - // 8 warps = 256 threads -> faster cooperative loading - // WARP_M=32, WARP_N=16, WARP_TILES_M=2, WARP_TILES_N=1 -> 2 mma per warp per K-step - // Shared: 64*(128+8)*2 + 128*(64+8)*2 = 17,408 + 18,432 = 35,840 bytes (~35KB) - LAUNCH_WMMA_KERNEL(64, 64, 128, 2, 4) - } else { - // BLOCK_M=64, BLOCK_N=64, BLOCK_K=256, WARPS_M=2, WARPS_N=4 - // 8 warps = 256 threads -> faster cooperative loading - // Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB) - LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4) - } - - #undef LAUNCH_WMMA_KERNEL - - return y; -} -""" +_KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) -# Compile the CUDA kernel _cuda_module = None @@ -573,13 +33,14 @@ def _get_cuda_module(): """Get or compile the CUDA module.""" global _cuda_module if _cuda_module is None: - from torch.utils.cpp_extension import load_inline + from torch.utils.cpp_extension import load - _cuda_module = load_inline( + _cuda_module = load( name="conv3d_implicit_gemm_cuda_v19_wmma", - cpp_sources=CPP_SOURCE, - cuda_sources=CUDA_KERNEL_SOURCE, - functions=["conv3d_implicit_gemm_cuda"], + sources=[ + os.path.join(_KERNEL_DIR, "implicit_gemm_binding.cpp"), + os.path.join(_KERNEL_DIR, "implicit_gemm_kernel.cu"), + ], verbose=True, extra_cuda_cflags=[ "-O3", @@ -622,7 +83,7 @@ def conv3d_implicit_gemm_cuda( quant_act: bool = False, fp4_block_size: int = 256, ) -> torch.Tensor: - """Optimized CUDA-based Conv3D via implicit GEMM with BF16 WMMA tensor cores. + """Conv3D via implicit GEMM with BF16 WMMA tensor cores. 
Args: x: Input tensor [N, Cin, D, H, W] @@ -704,3 +165,36 @@ def conv3d_implicit_gemm_cuda( y = y_flat.view(n_batch, od, oh, ow, cout).permute(0, 4, 1, 2, 3).contiguous() return y.to(x.dtype) + + +@torch.no_grad() +def fp4_fake_quant( + x: torch.Tensor, + global_amax: torch.Tensor, + block_size: int = 16, +) -> torch.Tensor: + """Standalone FP4 fake quantization using the same CUDA device functions as the GEMM kernel. + + Applies blockwise FP4 (E2M1) quantize-dequantize with FP8 E4M3 scale quantization. + + Args: + x: Input tensor (any shape, numel must be divisible by block_size). + global_amax: Scalar tensor — global abs max for scale computation. + block_size: Number of elements per FP4 quantization block. + + Returns: + Fake-quantized tensor with same shape and dtype as input. + """ + cuda_mod = _get_cuda_module() + + orig_shape = x.shape + orig_dtype = x.dtype + x_f32 = x.float().contiguous() + amax_f32 = global_amax.float().contiguous() + + assert x_f32.numel() % block_size == 0, ( + f"numel ({x_f32.numel()}) must be divisible by block_size ({block_size})" + ) + + y = cuda_mod.fp4_fake_quant_cuda(x_f32, amax_f32, block_size) + return y.view(orig_shape).to(orig_dtype) diff --git a/experimental/conv/implicit_gemm_kernel.cu b/experimental/conv/implicit_gemm_kernel.cu new file mode 100644 index 0000000000..87dadf3b4b --- /dev/null +++ b/experimental/conv/implicit_gemm_kernel.cu @@ -0,0 +1,595 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Conv3D Implicit GEMM with BF16 WMMA Tensor Cores and optional fused FP4 quantization.
//
// Key optimizations:
// 1. BF16 WMMA tensor core operations (m16n16k16) with FP32 accumulators
// 2. On-the-fly spatial index computation (no global memory lookup tables)
// 3. Dual FP4_BLOCK_SIZE support (128 and 256) with optimized tile configs
// 4. Register-fused FP4 quantization (quantize during A-tile load, eliminates sync)
// 5. Branchless FP4 quantization using predicated selects
// 6. BF16 shared memory (halves memory vs FP32)
// 7. L2-friendly block scheduling (swizzled grid)
// 8.
FP8 E4M3 round-trip for scale quantization + +#include +#include +#include +#include +#include + +using namespace nvcuda; + +// ============================================================================= +// FP4 Quantization Helpers +// ============================================================================= + +__device__ __forceinline__ float fp4_quantize_value(float scaled) { + float q; + q = (scaled <= 5.0f) ? 4.0f : 6.0f; + q = (scaled < 3.5f) ? 3.0f : q; + q = (scaled <= 2.5f) ? 2.0f : q; + q = (scaled < 1.75f) ? 1.5f : q; + q = (scaled <= 1.25f) ? 1.0f : q; + q = (scaled < 0.75f) ? 0.5f : q; + q = (scaled <= 0.25f) ? 0.0f : q; + return q; +} + +__device__ __forceinline__ float warp_reduce_max(float val) { +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset)); + } + return val; +} + +__device__ __forceinline__ float fp8_e4m3_round_trip(float x) { + if (x == 0.0f) + return 0.0f; + + unsigned int bits = __float_as_uint(x); + unsigned int sign = bits >> 31; + int exp = ((bits >> 23) & 0xff) - 127; + unsigned int mantissa = bits & 0x7fffff; + + if (exp > 8) + return sign ? -448.0f : 448.0f; + if (exp < -9) + return 0.0f; + + unsigned int mantissa_3bit = (mantissa + (1 << 19)) >> 20; + if (mantissa_3bit > 7) { + mantissa_3bit = 0; + exp += 1; + if (exp > 8) + return sign ? 
-448.0f : 448.0f; + } + + if (exp < -6) { + int shift = -6 - exp; + mantissa_3bit = (mantissa_3bit | 8) >> shift; + exp = -6; + } + + int fp32_exp = exp + 127; + unsigned int fp32_mantissa = mantissa_3bit << 20; + unsigned int fp32_bits = (sign << 31) | (fp32_exp << 23) | fp32_mantissa; + + return __uint_as_float(fp32_bits); +} + +__device__ __forceinline__ float quantize_scale_fp8(float block_max, float global_scale) { + float scaled = block_max / (6.0f * global_scale); + scaled = fminf(scaled, 448.0f); + float quantized = fp8_e4m3_round_trip(scaled); + return quantized * global_scale; +} + +// ============================================================================= +// BF16 WMMA Conv3D Implicit GEMM Kernel +// ============================================================================= +// Template parameters: +// QUANT_ACT - whether to apply FP4 quantization +// HAS_BIAS - whether bias is present +// BLOCK_M - M tile size (64) +// BLOCK_N - N tile size (64) +// BLOCK_K - K tile size (matches FP4_BLOCK_SIZE: 128 or 256) +// WARPS_M - warp tiling in M dimension (2) +// WARPS_N - warp tiling in N dimension (4) +// L2_SWIZZLE_GROUP - group size for L2-friendly block scheduling +// +// Each warp computes a (WARP_M x WARP_N) output tile using 16x16x16 WMMA. 
+// WARP_M = BLOCK_M / WARPS_M, WARP_N = BLOCK_N / WARPS_N +// WARP_TILES_M = WARP_M / 16, WARP_TILES_N = WARP_N / 16 +// +// Shared memory layout (BF16): +// As[BLOCK_M][BK_STRIDE] - M-major (row_major for WMMA A-fragments) +// Bs[BLOCK_K][BN_STRIDE] - K-major (row_major for WMMA B-fragments) + +template +__global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) + conv3d_implicit_gemm_wmma(const float *__restrict__ x_pad, const float *__restrict__ w_flat, + const float *__restrict__ bias, float *__restrict__ y, + const float *__restrict__ act_amax, int Cin, int Dp, int Hp, int Wp, + int Cout, int OD, int OH, int OW, int kD, int kH, int kW, int sd, + int sh, int sw, int dd, int dh, int dw, int M, int K) { + // Derived constants + constexpr int NUM_WARPS = WARPS_M * WARPS_N; + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int WARP_M = BLOCK_M / WARPS_M; // 32 + constexpr int WARP_N = BLOCK_N / WARPS_N; // 16 + constexpr int WARP_TILES_M = WARP_M / 16; // 2 + constexpr int WARP_TILES_N = WARP_N / 16; // 1 + + // BF16 shared memory strides with padding to avoid bank conflicts + // Pad by 8 BF16 elements (16 bytes) — keeps 16-byte alignment while breaking conflicts + constexpr int BK_STRIDE = BLOCK_K + 8; + constexpr int BN_STRIDE = BLOCK_N + 8; + + // Thread/warp indices + const int tid = threadIdx.x; + const int warp_id = tid / 32; + const int lane_id = tid % 32; + const int warp_m = warp_id / WARPS_N; // which M-warp (0..WARPS_M-1) + const int warp_n = warp_id % WARPS_N; // which N-warp (0..WARPS_N-1) + + // L2-friendly block scheduling (swizzle) + int bm, bn; + { + const int pid = blockIdx.x; + constexpr int GS = L2_SWIZZLE_GROUP; + const int grid_n = (Cout + BLOCK_N - 1) / BLOCK_N; + const int grid_m = (M + BLOCK_M - 1) / BLOCK_M; + const int tiles_per_group = GS * grid_n; + + const int group_row = pid / tiles_per_group; + const int group_rem = pid % tiles_per_group; + bn = group_rem / GS; + const int swizzle_lane = group_rem % GS; + bm = group_row * 
GS + swizzle_lane; + + if (bm >= grid_m || bn >= grid_n) + return; + } + + // Dynamic shared memory — BF16 tiles + extern __shared__ char smem_raw[]; + __nv_bfloat16 *As = reinterpret_cast<__nv_bfloat16 *>(smem_raw); + // As: [BLOCK_M][BK_STRIDE] — M-major + constexpr int A_SMEM_ELEMS = BLOCK_M * BK_STRIDE; + __nv_bfloat16 *Bs = As + A_SMEM_ELEMS; + // Bs: [BLOCK_K][BN_STRIDE] — K-major + + // WMMA accumulators — FP32 + wmma::fragment acc[WARP_TILES_M][WARP_TILES_N]; +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + wmma::fill_fragment(acc[wm][wn], 0.0f); + } + } + + // Global scale for FP4 quantization + float global_scale = 1.0f; + if constexpr (QUANT_ACT) { + global_scale = act_amax[0] / (6.0f * 448.0f); + } + + // Precompute spatial constants + const int HpWp = Hp * Wp; + const int DpHpWp = Dp * HpWp; + const int kHW = kH * kW; + const int kDHW = kD * kHW; + const int OHW = OH * OW; + const int ODHW = OD * OHW; + + const int m_start = bm * BLOCK_M; + const int n_start = bn * BLOCK_N; + const int num_k_tiles = (K + BLOCK_K - 1) / BLOCK_K; + + // Total elements to load cooperatively + constexpr int A_ELEMS = BLOCK_M * BLOCK_K; + constexpr int B_ELEMS = BLOCK_K * BLOCK_N; + + // Main loop over K tiles + for (int k_tile = 0; k_tile < num_k_tiles; k_tile++) { + const int k_start_tile = k_tile * BLOCK_K; + + // ================================================================= + // Load A tile into BF16 shared memory (M-major layout) + // As[m][k] stored at As[m * BK_STRIDE + k] + // ================================================================= + if constexpr (QUANT_ACT) { + // Fused FP4 quantization: each warp handles M-rows + constexpr int ELEMS_PER_LANE = (BLOCK_K + 31) / 32; + + for (int m = warp_id; m < BLOCK_M; m += NUM_WARPS) { + int m_idx = m_start + m; + + int n_batch, od_val, oh_val, ow_val; + if (m_idx < M) { + n_batch = m_idx / ODHW; + int rem = m_idx % ODHW; + od_val = rem / 
OHW; + rem = rem % OHW; + oh_val = rem / OW; + ow_val = rem % OW; + } else { + n_batch = 0; + od_val = 0; + oh_val = 0; + ow_val = 0; + } + + float local_max = 0.0f; + float vals[ELEMS_PER_LANE]; + +#pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + float val = 0.0f; + if (k < BLOCK_K && m_idx < M) { + int k_idx = k_start_tile + k; + if (k_idx < K) { + int c = k_idx / kDHW; + int remk = k_idx % kDHW; + int kd_v = remk / kHW; + remk = remk % kHW; + int kh_v = remk / kW; + int kw_v = remk % kW; + + int id = od_val * sd + kd_v * dd; + int ih = oh_val * sh + kh_v * dh; + int iw = ow_val * sw + kw_v * dw; + + val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; + } + } + vals[i] = val; + local_max = fmaxf(local_max, fabsf(val)); + } + + float block_max = warp_reduce_max(local_max); + float scale = quantize_scale_fp8(block_max, global_scale); + if (scale < 1e-5f) + scale = 1.0f; + float inv_scale = 1.0f / scale; + +#pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + if (k < BLOCK_K) { + float val = vals[i]; + float sign = (val >= 0.0f) ? 
1.0f : -1.0f; + float q = fp4_quantize_value(fabsf(val) * inv_scale); + float result = sign * q * scale; + // M-major: As[m * BK_STRIDE + k] + As[m * BK_STRIDE + k] = __float2bfloat16(result); + } + } + } + } else { +// Non-quantized: cooperative load, store as BF16 in M-major +#pragma unroll 4 + for (int i = tid; i < A_ELEMS; i += NUM_THREADS) { + int local_m = i / BLOCK_K; + int local_k = i % BLOCK_K; + int m_idx = m_start + local_m; + int k_idx = k_start_tile + local_k; + + float val = 0.0f; + if (m_idx < M && k_idx < K) { + int n_batch = m_idx / ODHW; + int rem = m_idx % ODHW; + int od_val = rem / OHW; + rem = rem % OHW; + int oh_val = rem / OW; + int ow_val = rem % OW; + + int c = k_idx / kDHW; + int remk = k_idx % kDHW; + int kd_v = remk / kHW; + remk = remk % kHW; + int kh_v = remk / kW; + int kw_v = remk % kW; + + int id = od_val * sd + kd_v * dd; + int ih = oh_val * sh + kh_v * dh; + int iw = ow_val * sw + kw_v * dw; + + val = x_pad[n_batch * Cin * DpHpWp + c * DpHpWp + id * HpWp + ih * Wp + iw]; + } + // M-major: As[m * BK_STRIDE + k] + As[local_m * BK_STRIDE + local_k] = __float2bfloat16(val); + } + } + +// ================================================================= +// Load B tile into BF16 shared memory (K-major layout) +// Bs[k][n] stored at Bs[k * BN_STRIDE + n] +// ================================================================= +#pragma unroll 4 + for (int i = tid; i < B_ELEMS; i += NUM_THREADS) { + int local_k = i / BLOCK_N; + int local_n = i % BLOCK_N; + int k_idx = k_start_tile + local_k; + int n_idx = n_start + local_n; + + float val = 0.0f; + if (k_idx < K && n_idx < Cout) { + val = w_flat[k_idx * Cout + n_idx]; + } + Bs[local_k * BN_STRIDE + local_n] = __float2bfloat16(val); + } + + __syncthreads(); + + // ================================================================= + // WMMA Compute: iterate over K in steps of 16 (WMMA K-dim) + // ================================================================= + constexpr int K_STEPS = BLOCK_K 
/ 16; + +#pragma unroll + for (int kk = 0; kk < K_STEPS; kk++) { + // Load A and B fragments from shared memory + wmma::fragment + a_frag[WARP_TILES_M]; + wmma::fragment + b_frag[WARP_TILES_N]; + +// Load A fragments +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { + int a_row = warp_m * WARP_M + wm * 16; + int a_col = kk * 16; + wmma::load_matrix_sync(a_frag[wm], &As[a_row * BK_STRIDE + a_col], BK_STRIDE); + } + +// Load B fragments +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + int b_row = kk * 16; + int b_col = warp_n * WARP_N + wn * 16; + wmma::load_matrix_sync(b_frag[wn], &Bs[b_row * BN_STRIDE + b_col], BN_STRIDE); + } + +// MMA: acc[wm][wn] += a_frag[wm] * b_frag[wn] +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + wmma::mma_sync(acc[wm][wn], a_frag[wm], b_frag[wn], acc[wm][wn]); + } + } + } + + __syncthreads(); + } + + // ===================================================================== + // Store results: use shared memory as FP32 staging buffer + // Each warp stores its accumulator fragments, then all threads + // cooperatively copy to global memory with bounds checking and bias. 
+ // ===================================================================== + + // Reinterpret shared memory as FP32 for output staging + float *out_smem = reinterpret_cast(smem_raw); +// out_smem layout: [BLOCK_M][BLOCK_N], row-major + +// Each warp stores its accumulator fragments to shared memory +#pragma unroll + for (int wm = 0; wm < WARP_TILES_M; wm++) { +#pragma unroll + for (int wn = 0; wn < WARP_TILES_N; wn++) { + int out_row = warp_m * WARP_M + wm * 16; + int out_col = warp_n * WARP_N + wn * 16; + wmma::store_matrix_sync(&out_smem[out_row * BLOCK_N + out_col], acc[wm][wn], BLOCK_N, + wmma::mem_row_major); + } + } + + __syncthreads(); + + // Cooperatively copy from shared memory to global memory + constexpr int OUT_ELEMS = BLOCK_M * BLOCK_N; +#pragma unroll 4 + for (int i = tid; i < OUT_ELEMS; i += NUM_THREADS) { + int local_m = i / BLOCK_N; + int local_n = i % BLOCK_N; + int m_idx = m_start + local_m; + int n_idx = n_start + local_n; + + if (m_idx < M && n_idx < Cout) { + float result = out_smem[i]; + if constexpr (HAS_BIAS) { + result += bias[n_idx]; + } + y[m_idx * Cout + n_idx] = result; + } + } +} + +// ============================================================================= +// Standalone FP4 Fake Quantization Kernel (for testing) +// ============================================================================= +// Applies the same blockwise FP4 fake quant used in the GEMM A-tile loader, +// but on a flat 2D tensor [num_blocks, block_size]. +// Each warp processes one row (= one FP4 block). 
+ +__global__ void fp4_fake_quant_kernel(const float *__restrict__ x, float *__restrict__ y, + const float *__restrict__ global_amax_ptr, int num_blocks, + int block_size) { + const int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int lane_id = threadIdx.x % 32; + + if (warp_id >= num_blocks) + return; + + float global_scale = global_amax_ptr[0] / (6.0f * 448.0f); + + const float *row = x + warp_id * block_size; + float *out_row = y + warp_id * block_size; + + // Pass 1: compute block max via warp reduction + float local_max = 0.0f; + for (int i = lane_id; i < block_size; i += 32) { + local_max = fmaxf(local_max, fabsf(row[i])); + } + float block_max = warp_reduce_max(local_max); + + // Quantize the scale via FP8 E4M3 round-trip + float scale = quantize_scale_fp8(block_max, global_scale); + if (scale < 1e-5f) + scale = 1.0f; + float inv_scale = 1.0f / scale; + + // Pass 2: quantize + dequantize each element + for (int i = lane_id; i < block_size; i += 32) { + float val = row[i]; + float sign = (val >= 0.0f) ? 
1.0f : -1.0f; + float q = fp4_quantize_value(fabsf(val) * inv_scale); + out_row[i] = sign * q * scale; + } +} + +torch::Tensor fp4_fake_quant_cuda(torch::Tensor x, torch::Tensor global_amax, int block_size) { + // x: [num_blocks, block_size] or flat [N] where N % block_size == 0 + auto x_flat = x.contiguous().view({-1}); + int N = x_flat.numel(); + int num_blocks = N / block_size; + + auto y = torch::empty_like(x_flat); + + // Launch: one warp (32 threads) per block + int threads_per_block = 256; // 8 warps per CUDA block + int warps_per_block = threads_per_block / 32; + int num_cuda_blocks = (num_blocks + warps_per_block - 1) / warps_per_block; + + fp4_fake_quant_kernel<<>>( + x_flat.data_ptr(), y.data_ptr(), global_amax.data_ptr(), num_blocks, + block_size); + + return y.view_as(x); +} + +// ============================================================================= +// Python Binding +// ============================================================================= + +torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_flat, + torch::Tensor bias, torch::Tensor act_amax, int N_batch, + int Cin, int Dp, int Hp, int Wp, int Cout, int OD, int OH, + int OW, int kD, int kH, int kW, int sd, int sh, int sw, + int dd, int dh, int dw, int M, int K, bool quant_act, + bool has_bias, int fp4_block_size) { + auto y = torch::zeros({M, Cout}, x_pad.options()); + + // Helper to compute padded 1D grid size for L2 swizzle + constexpr int GS = 8; // L2_SWIZZLE_GROUP + auto compute_grid = [&](int BM, int BN) -> dim3 { + int grid_m = (M + BM - 1) / BM; + int grid_n = (Cout + BN - 1) / BN; + int num_m_groups = (grid_m + GS - 1) / GS; + int total_blocks = num_m_groups * GS * grid_n; + return dim3(total_blocks, 1); + }; + +// Macro to dispatch kernel with all 4 template specializations +#define LAUNCH_WMMA_KERNEL(BM, BN, BK, WM, WN) \ + { \ + constexpr int BK_S = BK + 8; \ + constexpr int BN_S = BN + 8; \ + constexpr size_t smem_a = BM * BK_S * 
sizeof(__nv_bfloat16); \ + constexpr size_t smem_b = BK * BN_S * sizeof(__nv_bfloat16); \ + constexpr size_t smem = smem_a + smem_b; \ + \ + dim3 block(WM * WN * 32); \ + dim3 grid = compute_grid(BM, BN); \ + \ + auto set_smem = [](auto kernel) { \ + constexpr size_t s_a = BM * (BK + 8) * sizeof(__nv_bfloat16); \ + constexpr size_t s_b = BK * (BN + 8) * sizeof(__nv_bfloat16); \ + constexpr size_t s = s_a + s_b; \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, s); \ + }; \ + \ + if (quant_act && has_bias) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } else if (quant_act) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } else if (has_bias) { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } else { \ + auto kern = conv3d_implicit_gemm_wmma; \ + set_smem(kern); \ + kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ + bias.data_ptr(), y.data_ptr(), \ + act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ + kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ + } \ + } + + if (fp4_block_size == 128) { + // BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, WARPS_M=2, WARPS_N=4 + // 8 warps = 256 threads -> faster cooperative loading + // Shared: 64*(128+8)*2 + 128*(64+8)*2 = 17,408 + 18,432 = 35,840 bytes (~35KB) + LAUNCH_WMMA_KERNEL(64, 64, 128, 2, 4) + } else { + // BLOCK_M=64, BLOCK_N=64, BLOCK_K=256, WARPS_M=2, WARPS_N=4 + // 8 
warps = 256 threads -> faster cooperative loading + // Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB) + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4) + } + +#undef LAUNCH_WMMA_KERNEL + + return y; +} diff --git a/experimental/conv/test_implicit_gemm.py b/experimental/conv/test_implicit_gemm.py new file mode 100644 index 0000000000..b548ed3f0a --- /dev/null +++ b/experimental/conv/test_implicit_gemm.py @@ -0,0 +1,780 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Conv3D implicit GEMM CUDA kernel. + +Tests both non-quantized path (vs cuDNN) and FP4-quantized path (vs Triton reference). 
+""" + +import pytest +import torch +import torch.nn.functional as F + + +@pytest.fixture(scope="module") +def cuda_conv3d(): + """Import and return the CUDA implicit GEMM conv3d function.""" + from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + + return conv3d_implicit_gemm_cuda + + +def _triton_fp4_available(): + """Check if the Triton FP4 fake quant kernel is available (requires compute >= 8.9).""" + try: + import modelopt.torch.quantization.triton as triton_kernel + + return hasattr(triton_kernel, "fp4_fake_quant_block") + except ImportError: + return False + + +requires_triton_fp4 = pytest.mark.skipif( + not _triton_fp4_available(), + reason="Triton fp4_fake_quant_block not available (requires compute >= 8.9)", +) + + +# BF16 WMMA accumulates in FP32 but inputs are rounded to BF16, so expect diffs. +# For large K (e.g. 3456 = 128*27), max abs diff can reach ~0.8 due to BF16 rounding +# and different accumulation order vs cuDNN's FP32 path. +ATOL = 1.0 +RTOL = 1e-3 + + +def _run_conv3d_test(cuda_conv3d, x, w, bias, stride, padding, dilation): + """Helper: run both cuDNN and implicit GEMM, compare results.""" + ref = F.conv3d(x, w, bias=bias, stride=stride, padding=padding, dilation=dilation) + out = cuda_conv3d( + x, w, bias=bias, stride=stride, padding=padding, dilation=dilation, quant_act=False + ) + assert out.shape == ref.shape, f"Shape mismatch: {out.shape} vs {ref.shape}" + abs_diff = (out - ref).abs() + max_diff = abs_diff.max().item() + # Scale tolerance with K (reduction dimension) — BF16 rounding accumulates + cin = w.shape[1] + k_size = cin * w.shape[2] * w.shape[3] * w.shape[4] + scaled_atol = ATOL * (k_size / 1000.0) ** 0.5 + assert max_diff < scaled_atol, ( + f"Max abs diff {max_diff:.6e} exceeds tolerance {scaled_atol:.4f} (K={k_size})" + ) + # Check mean diff is small (more robust than quantile for large tensors) + mean_diff = abs_diff.mean().item() + assert mean_diff < scaled_atol * 0.1, f"Mean diff {mean_diff:.6e} too 
high" + return max_diff + + +class TestConv3dBasic: + """Basic correctness tests with simple shapes.""" + + def test_minimal(self, cuda_conv3d): + """Smallest possible conv3d: 1x1x1 kernel, single channel.""" + x = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) + w = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) + diff = _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + # K=1, so BF16 rounding is the only source of error + assert diff < 1e-2 + + def test_single_channel_3x3x3(self, cuda_conv3d): + """Single input/output channel with 3x3x3 kernel.""" + x = torch.randn(1, 1, 5, 5, 5, device="cuda", dtype=torch.float32) + w = torch.randn(1, 1, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_multi_channel(self, cuda_conv3d): + """Multiple input and output channels.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_with_bias(self, cuda_conv3d): + """Conv3d with bias.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + b = torch.randn(32, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, b, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_batch_size(self, cuda_conv3d): + """Batch size > 1.""" + x = torch.randn(4, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + +class TestConv3dStride: + """Tests with various stride configurations.""" + + def test_stride_2(self, cuda_conv3d): + """Uniform stride of 2.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = 
torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (2, 2, 2), (1, 1, 1), (1, 1, 1)) + + def test_asymmetric_stride(self, cuda_conv3d): + """Different stride per dimension.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 2, 2), (1, 1, 1), (1, 1, 1)) + + +class TestConv3dPadding: + """Tests with various padding configurations.""" + + def test_no_padding(self, cuda_conv3d): + """Zero padding.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_large_padding(self, cuda_conv3d): + """Padding larger than kernel radius.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1)) + + def test_asymmetric_padding(self, cuda_conv3d): + """Different padding per dimension.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 2), (1, 1, 1)) + + +class TestConv3dDilation: + """Tests with dilation.""" + + def test_dilation_2(self, cuda_conv3d): + """Uniform dilation of 2.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (2, 2, 2)) + + def test_asymmetric_dilation(self, cuda_conv3d): + """Different dilation per dimension.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, 
device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 2, 2), (1, 2, 2)) + + +class TestConv3dKernelSizes: + """Tests with non-3x3x3 kernels.""" + + def test_1x1x1_kernel(self, cuda_conv3d): + """Pointwise 1x1x1 kernel.""" + x = torch.randn(1, 64, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(128, 64, 1, 1, 1, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_asymmetric_kernel(self, cuda_conv3d): + """Kernel with different sizes per dimension (e.g. 1x3x3).""" + x = torch.randn(1, 16, 8, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 1, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 1), (1, 1, 1)) + + def test_5x5x5_kernel(self, cuda_conv3d): + """Larger 5x5x5 kernel.""" + x = torch.randn(1, 8, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(16, 8, 5, 5, 5, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1)) + + +class TestConv3dRealisticShapes: + """Tests with shapes resembling real video diffusion models.""" + + def test_wan22_shape(self, cuda_conv3d): + """Shape from Wan2.2 video diffusion backbone.""" + x = torch.randn(1, 128, 21, 60, 106, device="cuda", dtype=torch.float32) + w = torch.randn(512, 128, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_large_cout(self, cuda_conv3d): + """Large output channel count.""" + x = torch.randn(1, 64, 8, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(512, 64, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_large_cin(self, cuda_conv3d): + """Large input channel count.""" + x = torch.randn(1, 512, 8, 8, 8, device="cuda", dtype=torch.float32) + w = 
torch.randn(64, 512, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + +class TestConv3dEdgeCases: + """Edge cases for tile boundary handling.""" + + def test_m_not_aligned_to_block(self, cuda_conv3d): + """M (N*OD*OH*OW) not a multiple of BLOCK_M=64.""" + # Output spatial dims equal input (k=3, s=1, p=1): 1*5*7*9 = 315, not divisible by 64 + x = torch.randn(1, 8, 5, 7, 9, device="cuda", dtype=torch.float32) + w = torch.randn(16, 8, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_cout_not_aligned_to_block(self, cuda_conv3d): + """Cout not a multiple of BLOCK_N=64.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(17, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_k_not_aligned_to_block(self, cuda_conv3d): + """K (Cin*kD*kH*kW) not a multiple of BLOCK_K.""" + # Cin=7, kDHW=27, K=189 -- not a multiple of 128 or 256 + x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_output_size_1x1x1(self, cuda_conv3d): + """Output spatial dims are all 1.""" + x = torch.randn(1, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_single_output_element(self, cuda_conv3d): + """M=1: batch=1, output 1x1x1.""" + x = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) + w = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + +class TestConv3dFP4BlockSize: + """Test both FP4 block size configs (affects tile shapes 
even for non-quant).""" + + def test_block_size_128(self, cuda_conv3d): + """Non-quant conv with BK=128 tile config.""" + x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + quant_act=False, + fp4_block_size=128, + ) + assert out.shape == ref.shape + assert (out - ref).abs().max().item() < ATOL + + def test_block_size_256(self, cuda_conv3d): + """Non-quant conv with BK=256 tile config.""" + x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + quant_act=False, + fp4_block_size=256, + ) + assert out.shape == ref.shape + assert (out - ref).abs().max().item() < ATOL + + +class TestConv3dDeterminism: + """Verify deterministic output across repeated calls.""" + + def test_deterministic(self, cuda_conv3d): + """Repeated calls produce identical output.""" + torch.manual_seed(123) + x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) + out1 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + out2 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert torch.equal(out1, out2), "Kernel is not deterministic" + + +# ============================================================================= +# FP4 Fake Quantization Tests +# ============================================================================= + + +@pytest.fixture(scope="module") +def cuda_fp4(): + """Import and return the CUDA FP4 fake quant function.""" + from 
experimental.conv.implicit_gemm_cuda import fp4_fake_quant + + return fp4_fake_quant + + +def _py_fp4_fake_quant_ref(x_flat, global_amax, block_size): + """Pure Python reference for FP4 fake quant (no BF16 rounding). + + This implements the exact same algorithm as the CUDA kernel: + 1. Compute global_scale = global_amax / (6 * 448) + 2. Per block: block_max = max(|x|), scale = fp8_e4m3_roundtrip(block_max / (6 * global_scale)) * global_scale + 3. Quantize each element to nearest E2M1 level, then dequantize. + """ + import math + + # E2M1 quantization levels: {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # Boundaries (midpoints): <=0.25->0, <0.75->0.5, <=1.25->1, <1.75->1.5, <=2.5->2, <3.5->3, <=5->4, >5->6 + def quantize_e2m1(scaled_abs): + if scaled_abs <= 0.25: + return 0.0 + elif scaled_abs < 0.75: + return 0.5 + elif scaled_abs <= 1.25: + return 1.0 + elif scaled_abs < 1.75: + return 1.5 + elif scaled_abs <= 2.5: + return 2.0 + elif scaled_abs < 3.5: + return 3.0 + elif scaled_abs <= 5.0: + return 4.0 + else: + return 6.0 + + def fp8_e4m3_roundtrip(val): + """Simulate FP8 E4M3 round-trip in Python.""" + if val == 0.0: + return 0.0 + sign = 1.0 if val >= 0 else -1.0 + val = abs(val) + # FP8 E4M3: bias=7, 3 mantissa bits, max=448, no inf/nan + if val > 448.0: + return sign * 448.0 + # Compute exponent + exp = math.floor(math.log2(val)) + exp = max(exp, -6) # min normal exponent for E4M3 + # Compute mantissa (3 bits) + mantissa = val / (2.0**exp) # 1.xxx + mantissa_bits = round((mantissa - 1.0) * 8.0) # 3 bits + if mantissa_bits > 7: + mantissa_bits = 0 + exp += 1 + if exp > 8: + return sign * 448.0 + # Reconstruct + result = (1.0 + mantissa_bits / 8.0) * (2.0**exp) + return sign * result + + global_scale = float(global_amax) / (6.0 * 448.0) + x_np = x_flat.cpu().float().numpy().copy() + num_blocks = len(x_np) // block_size + + for b in range(num_blocks): + block = x_np[b * block_size : (b + 1) * block_size] + block_max = float(max(abs(v) for v in block)) + + # Scale 
quantization + scaled = block_max / (6.0 * global_scale) + scaled = min(scaled, 448.0) + quantized_scale = fp8_e4m3_roundtrip(scaled) * global_scale + if quantized_scale < 1e-5: + quantized_scale = 1.0 + inv_scale = 1.0 / quantized_scale + + for i in range(block_size): + val = block[i] + sign = 1.0 if val >= 0 else -1.0 + q = quantize_e2m1(abs(val) * inv_scale) + x_np[b * block_size + i] = sign * q * quantized_scale + + return torch.tensor(x_np, device=x_flat.device) + + +class TestFP4FakeQuantValues: + """Test FP4 fake quant with known E2M1 table values.""" + + def test_exact_e2m1_values(self, cuda_fp4): + """E2M1 representable values should round-trip exactly (when scale=1 via amax=6*448).""" + # With global_amax = 6*448 = 2688, global_scale = 1.0 + # A single-block input with max=6 -> block_max=6, scaled=6/(6*1)=1.0 + # fp8_e4m3(1.0)=1.0, scale = 1.0*1.0 = 1.0 + block_size = 8 + vals = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], device="cuda", dtype=torch.float32) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(vals, amax, block_size) + assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}" + + def test_exact_e2m1_negative(self, cuda_fp4): + """Negative E2M1 values should also round-trip.""" + block_size = 8 + vals = torch.tensor([0, -0.5, -1, -1.5, -2, -3, -4, -6], device="cuda", dtype=torch.float32) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(vals, amax, block_size) + assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}" + + def test_below_boundary(self, cuda_fp4): + """Values slightly below E2M1 boundaries should quantize down.""" + block_size = 8 + # Boundaries: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + # Slightly below -> quantize to lower level + inp = torch.tensor( + [0.15, 0.65, 1.15, 1.65, 2.4, 3.4, 4.9, 6.0], device="cuda", dtype=torch.float32 + ) + expected = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], 
device="cuda", dtype=torch.float32 + ) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" + + def test_above_boundary(self, cuda_fp4): + """Values slightly above E2M1 boundaries should quantize up.""" + block_size = 8 + inp = torch.tensor( + [0.35, 0.85, 1.35, 1.85, 2.6, 3.6, 5.1, 6.0], device="cuda", dtype=torch.float32 + ) + expected = torch.tensor( + [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 6.0], device="cuda", dtype=torch.float32 + ) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" + + def test_mixed_signs(self, cuda_fp4): + """Mixed positive/negative values.""" + block_size = 8 + inp = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32) + expected = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" + + +class TestFP4FakeQuantScale: + """Test FP4 scale computation and FP8 round-trip.""" + + def test_scale_factor(self, cuda_fp4): + """When amax != 6*448, scale should adjust values proportionally.""" + block_size = 8 + # global_amax = 12*448 = 5376, global_scale = 2.0 + # Input block max = 12 -> scaled = 12/(6*2) = 1.0 -> fp8(1.0) = 1.0 -> scale = 2.0 + # So input 12 -> |12|/2 = 6.0 -> q=6 -> 6*2 = 12 + inp = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", dtype=torch.float32) + amax = torch.tensor([12.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + # Expected: each val/2.0 -> quantize to E2M1 -> * 2.0 + expected = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", 
dtype=torch.float32) + assert torch.allclose(out, expected, atol=1e-4), f"Got {out} vs expected {expected}" + + def test_zero_block(self, cuda_fp4): + """All-zero block should produce all zeros.""" + block_size = 16 + inp = torch.zeros(block_size, device="cuda", dtype=torch.float32) + amax = torch.tensor([1.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.equal(out, inp) + + def test_multiple_blocks(self, cuda_fp4): + """Multiple blocks with different ranges.""" + block_size = 8 + # Block 0: small values, Block 1: large values + block0 = torch.tensor([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], device="cuda") + block1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 6], device="cuda") + inp = torch.cat([block0, block1]) + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size) + # Each block should be independently quantized + assert out.shape == inp.shape + # Block 1 exact values should be close to E2M1 levels + assert out[8:].abs().max() <= 6.0 + 1e-5 + + +class TestFP4FakeQuantBlockSizes: + """Test different block sizes.""" + + @pytest.mark.parametrize("block_size", [8, 16, 32, 64, 128, 256]) + def test_block_sizes(self, cuda_fp4, block_size): + """FP4 quant should work for various block sizes.""" + torch.manual_seed(42) + num_blocks = 4 + inp = torch.randn(num_blocks * block_size, device="cuda", dtype=torch.float32) * 5 + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size) + assert out.shape == inp.shape + # Output should not be all zeros for non-zero input + assert out.abs().max() > 0 + # Output should be <= max possible after quant + assert out.abs().max() <= inp.abs().max() * 1.5 # generous bound + + +class TestFP4FakeQuantVsReference: + """Compare CUDA FP4 fake quant against Python reference implementation.""" + + @pytest.mark.parametrize("block_size", [8, 16, 32]) + def test_vs_python_ref(self, cuda_fp4, block_size): + """CUDA kernel should match the Python reference exactly.""" + 
torch.manual_seed(123) + num_blocks = 8 + inp = torch.randn(num_blocks * block_size, device="cuda") * 10 + amax = inp.abs().max().unsqueeze(0) + + cuda_out = cuda_fp4(inp, amax, block_size) + ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size) + + assert torch.allclose(cuda_out, ref_out, atol=1e-5), ( + f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}" + ) + + @pytest.mark.parametrize("block_size", [16, 32]) + def test_vs_python_ref_large(self, cuda_fp4, block_size): + """Larger tensor test against Python reference.""" + torch.manual_seed(456) + num_blocks = 64 + inp = torch.randn(num_blocks * block_size, device="cuda") * 20 + amax = inp.abs().max().unsqueeze(0) + + cuda_out = cuda_fp4(inp, amax, block_size) + ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size) + + assert torch.allclose(cuda_out, ref_out, atol=1e-4), ( + f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}" + ) + + +class TestFP4FakeQuantVsTriton: + """Compare CUDA FP4 fake quant against Triton fp4_fake_quant_block reference.""" + + @requires_triton_fp4 + @pytest.mark.parametrize("block_size", [16, 32, 64]) + @pytest.mark.parametrize("num_blocks", [4, 16, 64]) + def test_vs_triton(self, cuda_fp4, block_size, num_blocks): + """CUDA kernel should match the Triton fp4_fake_quant_block.""" + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + torch.manual_seed(42) + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + cuda_out = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + triton_out = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + assert torch.allclose(cuda_out, triton_out, atol=1e-5), ( + f"CUDA vs Triton max diff: {(cuda_out - triton_out).abs().max().item():.6e}\n" + f"Mean diff: {(cuda_out - triton_out).abs().mean().item():.6e}" + ) + + +class 
TestFP4FakeQuantDeterminism: + """Verify FP4 quant is deterministic.""" + + def test_deterministic(self, cuda_fp4): + """Repeated calls produce identical output.""" + torch.manual_seed(99) + inp = torch.randn(256, device="cuda") * 5 + amax = inp.abs().max().unsqueeze(0) + out1 = cuda_fp4(inp, amax, 16) + out2 = cuda_fp4(inp, amax, 16) + assert torch.equal(out1, out2), "FP4 fake quant is not deterministic" + + +# ============================================================================= +# Cross-validation: experimental FP4 vs modelopt FP4 implementations +# ============================================================================= + + +def _modelopt_cuda_ext_mx_available(): + """Check if the modelopt CUDA MX extension is available.""" + try: + from modelopt.torch.quantization.extensions import get_cuda_ext_mx + + return get_cuda_ext_mx() is not None + except Exception: + return False + + +def _modelopt_dynamic_block_quantize_available(): + """Check if dynamic_block_quantize_op is available.""" + try: + from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op + + return dynamic_block_quantize_op is not None + except Exception: + return False + + +requires_cuda_ext_mx = pytest.mark.skipif( + not _modelopt_cuda_ext_mx_available(), + reason="modelopt cuda_ext_mx not available", +) + +requires_dynamic_block_quantize = pytest.mark.skipif( + not _modelopt_dynamic_block_quantize_available(), + reason="modelopt dynamic_block_quantize_op not available", +) + + +class TestFP4FakeQuantVsModelopt: + """Compare experimental CUDA FP4 fake quant against all modelopt FP4 implementations. + + This ensures the standalone FP4 kernel in experimental/conv produces the same + results as the official modelopt quantization paths: + 1. Triton fp4_fake_quant_block (Hopper+ dynamic blockwise) + 2. cuda_ext_mx.fused_amax_convert (CUDA extension fallback) + 3. 
dynamic_block_quantize_op (high-level API that dispatches to either) + """ + + @requires_triton_fp4 + @pytest.mark.parametrize("block_size", [16, 32, 64]) + @pytest.mark.parametrize("seed", [42, 123, 999]) + def test_vs_triton_fp4_fake_quant_block(self, cuda_fp4, block_size, seed): + """Compare against modelopt Triton fp4_fake_quant_block.""" + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + torch.manual_seed(seed) + num_blocks = 16 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"experimental vs modelopt Triton max diff: {(ours - theirs).abs().max().item():.6e}" + ) + + @requires_cuda_ext_mx + @pytest.mark.parametrize("block_size", [16, 32]) + @pytest.mark.parametrize("seed", [42, 123]) + def test_vs_cuda_ext_mx(self, cuda_fp4, block_size, seed): + """Compare against modelopt cuda_ext_mx.fused_amax_convert.""" + from modelopt.torch.quantization.extensions import get_cuda_ext_mx + from modelopt.torch.quantization.tensor_quant import mx_format_map + + cuda_ext_mx = get_cuda_ext_mx() + torch.manual_seed(seed) + num_blocks = 16 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = cuda_ext_mx.fused_amax_convert( + x, + block_size, + getattr(cuda_ext_mx.Types, mx_format_map[(2, 1)]), + getattr(cuda_ext_mx.Types, mx_format_map[(4, 3)]), + global_amax, + ) + + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"experimental vs modelopt cuda_ext_mx max diff: " + f"{(ours - theirs).abs().max().item():.6e}" + ) + + @requires_dynamic_block_quantize + @pytest.mark.parametrize("seed", [42, 123, 999]) + 
def test_vs_dynamic_block_quantize_op(self, cuda_fp4, seed): + """Compare against modelopt dynamic_block_quantize_op (high-level API). + + This is the function used by the actual quantization pipeline with + num_bits=4 (E2M1) and scale_bits=8 (E4M3). + Note: dynamic_block_quantize_op dispatches to Triton with default block_size=16. + """ + from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op + + block_size = 16 # dynamic_block_quantize_op uses block_size=16 for Triton path + torch.manual_seed(seed) + num_blocks = 16 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = dynamic_block_quantize_op( + x, + block_size, + global_amax, + num_bits=4, # total bits = 1 sign + 2 exp + 1 mantissa + exponent_bits=2, + scale_num_bits=8, # FP8 E4M3 for scales + scale_exponent_bits=4, + ) + + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"experimental vs modelopt dynamic_block_quantize_op max diff: " + f"{(ours - theirs).abs().max().item():.6e}" + ) + + @requires_triton_fp4 + def test_vs_triton_realistic_shape(self, cuda_fp4): + """Realistic activation shape from a Conv3D layer (flattened).""" + torch.manual_seed(42) + block_size = 16 + # Simulate a large tensor: 256 blocks of 16 elements + # (tile_rows must be power-of-2 for Triton block_ptr) + num_blocks = 256 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 5 + global_amax = x.abs().max() + + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=16, + tile_cols=block_size, + ) + + max_diff = (ours - theirs).abs().max().item() + mean_diff = (ours - theirs).abs().mean().item() + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"Realistic shape: experimental vs 
Triton max diff: {max_diff:.6e}, " + f"mean diff: {mean_diff:.6e}" + ) + + @requires_triton_fp4 + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_vs_triton_input_dtypes(self, cuda_fp4, dtype): + """Test that our kernel handles different input dtypes correctly. + + Our kernel casts to float32 internally, so the result should match + Triton's output when both receive the same dtype input. + """ + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + torch.manual_seed(42) + block_size = 16 + num_blocks = 8 + x = (torch.randn(num_blocks, block_size, device="cuda") * 5).to(dtype) + global_amax = x.float().abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + # Both should return the input dtype + assert ours.dtype == dtype + assert theirs.dtype == dtype + + # Compare in float32 + max_diff = (ours.float() - theirs.float()).abs().max().item() + # BF16/FP16 input rounding may cause small diffs + tol = 1e-2 if dtype != torch.float32 else 1e-5 + assert max_diff < tol, f"dtype={dtype}: experimental vs Triton max diff: {max_diff:.6e}" From 66278df79396f11e043fc61d2a614df6ddac2b97 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 5 Mar 2026 05:01:38 +0000 Subject: [PATCH 05/25] Update the README Signed-off-by: Jingyu Xin --- experimental/conv/README.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/experimental/conv/README.md b/experimental/conv/README.md index 55a85f875e..49582f03a3 100644 --- a/experimental/conv/README.md +++ b/experimental/conv/README.md @@ -36,16 +36,22 @@ block_size = 128 # Without FP4 activation quantization (drop-in-style Conv3D call) out = conv3d_implicit_gemm_cuda(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) -# Optional block quantization of weights for experiments -w_q = 
dynamic_block_quantize_op( - w, +# Optional FP4 block quantization of weights along the GEMM K dimension. +# The kernel's A-tile (activations) is quantized along K = Cin*kD*kH*kW, +# so weights must be flattened to [Cout, K] before quantizing to match. +Cout, Cin = w.shape[:2] +K = Cin * w.shape[2] * w.shape[3] * w.shape[4] +w_flat = w.reshape(Cout, K) +w_q_flat = dynamic_block_quantize_op( + w_flat, block_size, - w.abs().max().unsqueeze(0), + w_flat.abs().max().unsqueeze(0), 4, # num_bits 2, # exponent_bits 8, # scale_num_bits 4, # scale_exponent_bits ) +w_q = w_q_flat.reshape_as(w) # With FP4 activation fake quantization out_q = conv3d_implicit_gemm_cuda( @@ -73,7 +79,7 @@ Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm | `dilation` | Convolution dilation `(D, H, W)` | | `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) | | `quant_act` | Enable FP4 fake quantization on activations | -| `FP4_BLOCK_SIZE` | FP4 quantization block size (`128` or `256`) | +| `fp4_block_size` | FP4 quantization block size (`128` or `256`) | ## Status From e2c375d7fbfa1841154d43d1243f48c04c3d4cf4 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 5 Mar 2026 05:33:34 +0000 Subject: [PATCH 06/25] E2E implicit gemm nvfp4 results Signed-off-by: Jingyu Xin --- .../quantization/nn/modules/quant_conv.py | 95 ++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/nn/modules/quant_conv.py b/modelopt/torch/quantization/nn/modules/quant_conv.py index 44f0ae663c..3b9c435ae2 100644 --- a/modelopt/torch/quantization/nn/modules/quant_conv.py +++ b/modelopt/torch/quantization/nn/modules/quant_conv.py @@ -15,11 +15,15 @@ """Quantized convolution.""" +import logging + import torch.nn as nn from ... 
import tensor_quant from .quant_module import QuantLinearConvBase, QuantModuleRegistry, _LegacyQuantLinearConvBaseMixin +logger = logging.getLogger(__name__) + __all__ = [ "Conv1d", "Conv2d", @@ -62,12 +66,101 @@ class QuantConv2d(_LegacyQuantLinearConvBaseMixin, nn.Conv2d): default_quant_desc_weight = _QuantConv2d.default_quant_desc_weight +def _is_nvfp4_quantizer(quantizer) -> bool: + """Check if a TensorQuantizer is configured for NVFP4 dynamic block quantization.""" + return ( + hasattr(quantizer, "_num_bits") + and quantizer._num_bits == (2, 1) + and hasattr(quantizer, "_block_sizes") + and quantizer._block_sizes is not None + and quantizer._block_sizes.get("scale_bits") == (4, 3) + and quantizer._block_sizes.get("type") == "dynamic" + ) + + +def _nvfp4_quantize_weight_along_k(weight, weight_quantizer): + """Apply NVFP4 fake quantization to Conv3D weight along the GEMM K dimension. + + The implicit GEMM maps K = Cin * kD * kH * kW. The default quantizer would + quantize along the last dim (kW), which is wrong. We reshape to [Cout, K] + so blocks are along the contraction dimension. + """ + cout = weight.shape[0] + k = weight[0].numel() # Cin * kD * kH * kW + w_flat = weight.reshape(cout, k) + w_q_flat = weight_quantizer(w_flat) + return w_q_flat.reshape_as(weight) + + @QuantModuleRegistry.register({nn.Conv3d: "nn.Conv3d"}) class _QuantConv3d(QuantLinearConvBase): - """Quantized 3D convolution.""" + """Quantized 3D convolution. + + When both input and weight quantizers are configured for NVFP4, the forward + uses the fused implicit GEMM kernel from experimental/conv which performs + activation FP4 quantization inside the kernel. The weight is FP4-quantized + along the GEMM K dimension (Cin*kD*kH*kW) before being passed to the kernel. + + For all other quantization configs, the default cuDNN path is used. 
+ """ default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_CONV3D_WEIGHT_PER_CHANNEL + def _should_use_implicit_gemm(self): + """Check if both quantizers are NVFP4 and the implicit GEMM kernel is available.""" + if not ( + hasattr(self, "input_quantizer") + and hasattr(self, "weight_quantizer") + and _is_nvfp4_quantizer(self.input_quantizer) + and _is_nvfp4_quantizer(self.weight_quantizer) + ): + return False + try: + from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda # noqa: F401 + + return True + except ImportError: + return False + + def forward(self, input, *args, **kwargs): + """Forward with implicit GEMM for NVFP4, default path otherwise.""" + if not self._should_use_implicit_gemm(): + return super().forward(input, *args, **kwargs) + + from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + + # --- Calibration: let input_quantizer collect amax without fake-quantizing --- + if self.input_quantizer._if_calib: + self.input_quantizer.collect(input) + + # --- Get activation amax for the kernel --- + act_amax = self.input_quantizer._get_amax(input) + + # --- Quantize weight along K dimension --- + weight = _nvfp4_quantize_weight_along_k(self.weight, self.weight_quantizer) + + # --- Get fp4_block_size from the input quantizer config --- + fp4_block_size = self.input_quantizer._block_sizes.get(-1, 16) + + # --- Call implicit GEMM kernel --- + output = conv3d_implicit_gemm_cuda( + input, + weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + act_amax=act_amax, + quant_act=self.input_quantizer._if_quant and not self.input_quantizer._disabled, + fp4_block_size=fp4_block_size, + ) + + # --- Output quantizer (usually disabled for NVFP4) --- + if hasattr(self, "output_quantizer"): + output = self.output_quantizer(output) + + return output + class QuantConv3d(_LegacyQuantLinearConvBaseMixin, nn.Conv3d): """Quantized 3D convolution.""" From a8174b8671848586548fe4d85d9f55cd40560b03 
Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 5 Mar 2026 05:44:23 +0000 Subject: [PATCH 07/25] Update the quant_conv Signed-off-by: Jingyu Xin --- modelopt/torch/quantization/nn/modules/quant_conv.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/quant_conv.py b/modelopt/torch/quantization/nn/modules/quant_conv.py index 3b9c435ae2..fa9caca669 100644 --- a/modelopt/torch/quantization/nn/modules/quant_conv.py +++ b/modelopt/torch/quantization/nn/modules/quant_conv.py @@ -127,11 +127,12 @@ def forward(self, input, *args, **kwargs): if not self._should_use_implicit_gemm(): return super().forward(input, *args, **kwargs) - from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + # During calibration we only need to collect amax — use the faster + # default cuDNN path since the conv output itself doesn't matter. + if self.input_quantizer._if_calib and not self.input_quantizer._if_quant: + return super().forward(input, *args, **kwargs) - # --- Calibration: let input_quantizer collect amax without fake-quantizing --- - if self.input_quantizer._if_calib: - self.input_quantizer.collect(input) + from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda # --- Get activation amax for the kernel --- act_amax = self.input_quantizer._get_amax(input) @@ -151,7 +152,7 @@ def forward(self, input, *args, **kwargs): padding=self.padding, dilation=self.dilation, act_amax=act_amax, - quant_act=self.input_quantizer._if_quant and not self.input_quantizer._disabled, + quant_act=not self.input_quantizer._disabled, fp4_block_size=fp4_block_size, ) From 8dfe2508a7ff1f687670ed90a1477fc7dd4196f3 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 9 Mar 2026 22:49:18 +0000 Subject: [PATCH 08/25] Update the LTX2 recipe Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/models_utils.py | 4 ++-- examples/diffusers/quantization/utils.py | 2 +- 2 files changed, 3 insertions(+), 
3 deletions(-) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index d8ca11ed3f..8344e97178 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -151,7 +151,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_SCHNELL: _FLUX_BASE_CONFIG, ModelType.LTX_VIDEO_DEV: { "backbone": "transformer", - "dataset": _SD_PROMPTS_DATASET, + "dataset": _OPENVID_DATASET, "inference_extra_args": { "height": 512, "width": 704, @@ -161,7 +161,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: }, ModelType.LTX2: { "backbone": "transformer", - "dataset": _SD_PROMPTS_DATASET, + "dataset": _OPENVID_DATASET, "inference_extra_args": { "height": 768, "width": 1280, diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index d89915da28..f88c55a539 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -70,7 +70,7 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha): def filter_func_ltx_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" pattern = re.compile( - r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single).*" + r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single|transformer_blocks\.(0|1|2|45|46|47)\.).*" ) return pattern.match(name) is not None From 7ead24d090c641edac38c4bc4d5a09a540e72c8f Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 9 Mar 2026 23:06:43 +0000 Subject: [PATCH 09/25] revert the change Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/models_utils.py | 4 ++-- examples/diffusers/quantization/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 
8344e97178..d8ca11ed3f 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -151,7 +151,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_SCHNELL: _FLUX_BASE_CONFIG, ModelType.LTX_VIDEO_DEV: { "backbone": "transformer", - "dataset": _OPENVID_DATASET, + "dataset": _SD_PROMPTS_DATASET, "inference_extra_args": { "height": 512, "width": 704, @@ -161,7 +161,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: }, ModelType.LTX2: { "backbone": "transformer", - "dataset": _OPENVID_DATASET, + "dataset": _SD_PROMPTS_DATASET, "inference_extra_args": { "height": 768, "width": 1280, diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index f88c55a539..d89915da28 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -70,7 +70,7 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha): def filter_func_ltx_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" pattern = re.compile( - r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single|transformer_blocks\.(0|1|2|45|46|47)\.).*" + r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single).*" ) return pattern.match(name) is not None From 52bb60d08fb1caba82f57465cd1693712fca2775 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 9 Mar 2026 23:51:33 +0000 Subject: [PATCH 10/25] Update some of the code and checks Signed-off-by: Jingyu Xin --- experimental/conv/README.md | 4 +- experimental/conv/implicit_gemm_cuda.py | 18 +- experimental/conv/implicit_gemm_kernel.cu | 81 ++++-- experimental/conv/test_implicit_gemm.py | 315 ++++++++++++++++++++-- 4 files changed, 369 insertions(+), 49 deletions(-) diff --git a/experimental/conv/README.md b/experimental/conv/README.md index 49582f03a3..9fa0d80f64 100644 --- 
a/experimental/conv/README.md +++ b/experimental/conv/README.md @@ -61,7 +61,7 @@ out_q = conv3d_implicit_gemm_cuda( padding=(1, 1, 1), act_amax=x.abs().max().unsqueeze(0), quant_act=True, - fp4_block_size=block_size, # 128 or 256 + fp4_block_size=block_size, # 16, 32, 64, 128, or 256 ) ``` @@ -79,7 +79,7 @@ Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm | `dilation` | Convolution dilation `(D, H, W)` | | `act_amax` | Activation abs-max scalar tensor (required when `quant_act=True`) | | `quant_act` | Enable FP4 fake quantization on activations | -| `fp4_block_size` | FP4 quantization block size (`128` or `256`) | +| `fp4_block_size` | FP4 quantization block size (`16`, `32`, `64`, `128`, or `256`) | ## Status diff --git a/experimental/conv/implicit_gemm_cuda.py b/experimental/conv/implicit_gemm_cuda.py index b5cc6c435b..b85b373dd5 100644 --- a/experimental/conv/implicit_gemm_cuda.py +++ b/experimental/conv/implicit_gemm_cuda.py @@ -36,7 +36,7 @@ def _get_cuda_module(): from torch.utils.cpp_extension import load _cuda_module = load( - name="conv3d_implicit_gemm_cuda_v19_wmma", + name="conv3d_implicit_gemm_cuda_v20_wmma", sources=[ os.path.join(_KERNEL_DIR, "implicit_gemm_binding.cpp"), os.path.join(_KERNEL_DIR, "implicit_gemm_kernel.cu"), @@ -94,11 +94,20 @@ def conv3d_implicit_gemm_cuda( dilation: Convolution dilation (D, H, W) act_amax: Activation max value for FP4 quantization quant_act: Whether to apply FP4 quantization to activations - fp4_block_size: FP4 quantization block size (128 or 256) + fp4_block_size: FP4 quantization block size (16, 32, 64, 128, or 256) Returns: Output tensor [N, Cout, OD, OH, OW] + + Raises: + ValueError: If fp4_block_size is not one of {16, 32, 64, 128, 256}. 
""" + valid_block_sizes = {16, 32, 64, 128, 256} + if fp4_block_size not in valid_block_sizes: + raise ValueError( + f"fp4_block_size must be one of {sorted(valid_block_sizes)}, got {fp4_block_size}" + ) + cuda_mod = _get_cuda_module() assert x.ndim == 5 and w.ndim == 5 @@ -130,7 +139,10 @@ def conv3d_implicit_gemm_cuda( has_bias = bias is not None bias_t = bias.float().contiguous() if has_bias else torch.empty(0, device=x.device) # type: ignore[union-attr] - do_quant = quant_act and act_amax is not None + if quant_act and act_amax is None: + raise ValueError("act_amax is required when quant_act=True") + + do_quant = quant_act amax_t = act_amax.float().contiguous() if do_quant else torch.empty(0, device=x.device) # type: ignore[union-attr] y_flat = cuda_mod.conv3d_implicit_gemm_cuda( diff --git a/experimental/conv/implicit_gemm_kernel.cu b/experimental/conv/implicit_gemm_kernel.cu index 87dadf3b4b..a3b40f4848 100644 --- a/experimental/conv/implicit_gemm_kernel.cu +++ b/experimental/conv/implicit_gemm_kernel.cu @@ -138,7 +138,7 @@ __device__ __forceinline__ float quantize_scale_fp8(float block_max, float globa // Bs[BLOCK_K][BN_STRIDE] - K-major (row_major for WMMA B-fragments) template + int WARPS_N, int FP4_BLOCK_SIZE = BLOCK_K, int L2_SWIZZLE_GROUP = 8> __global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) conv3d_implicit_gemm_wmma(const float *__restrict__ x_pad, const float *__restrict__ w_flat, const float *__restrict__ bias, float *__restrict__ y, @@ -234,7 +234,11 @@ __global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) // ================================================================= if constexpr (QUANT_ACT) { // Fused FP4 quantization: each warp handles M-rows + // FP4_BLOCK_SIZE can be smaller than BLOCK_K — we quantize in sub-blocks + static_assert(BLOCK_K % FP4_BLOCK_SIZE == 0, "BLOCK_K must be divisible by FP4_BLOCK_SIZE"); + static_assert(FP4_BLOCK_SIZE >= 16, "FP4_BLOCK_SIZE must be >= 16"); constexpr int ELEMS_PER_LANE = (BLOCK_K + 
31) / 32; + constexpr int NUM_FP4_BLOCKS = BLOCK_K / FP4_BLOCK_SIZE; for (int m = warp_id; m < BLOCK_M; m += NUM_WARPS) { int m_idx = m_start + m; @@ -254,7 +258,7 @@ __global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) ow_val = 0; } - float local_max = 0.0f; + // Pass 1: Load all values from global memory float vals[ELEMS_PER_LANE]; #pragma unroll @@ -279,23 +283,48 @@ __global__ void __launch_bounds__(WARPS_M * WARPS_N * 32, 2) } } vals[i] = val; - local_max = fmaxf(local_max, fabsf(val)); } - float block_max = warp_reduce_max(local_max); - float scale = quantize_scale_fp8(block_max, global_scale); - if (scale < 1e-5f) - scale = 1.0f; - float inv_scale = 1.0f / scale; + // Pass 2: For each FP4 sub-block, find max → compute scale → quantize + // Pre-compute per-sub-block scales + float scales[NUM_FP4_BLOCKS]; + float inv_scales[NUM_FP4_BLOCKS]; +#pragma unroll + for (int sb = 0; sb < NUM_FP4_BLOCKS; sb++) { + const int k_sb_start = sb * FP4_BLOCK_SIZE; + const int k_sb_end = k_sb_start + FP4_BLOCK_SIZE; + + // Each lane accumulates max over its elements in this sub-block + float local_max = 0.0f; +#pragma unroll + for (int i = 0; i < ELEMS_PER_LANE; i++) { + int k = lane_id + i * 32; + // Compiler resolves this at compile time for unrolled loops + if (k >= k_sb_start && k < k_sb_end) { + local_max = fmaxf(local_max, fabsf(vals[i])); + } + } + + // Warp reduce — lanes outside this sub-block contribute 0, which is correct + float block_max = warp_reduce_max(local_max); + float scale = quantize_scale_fp8(block_max, global_scale); + if (scale < 1e-5f) + scale = 1.0f; + scales[sb] = scale; + inv_scales[sb] = 1.0f / scale; + } + + // Pass 3: Quantize and store to shared memory #pragma unroll for (int i = 0; i < ELEMS_PER_LANE; i++) { int k = lane_id + i * 32; if (k < BLOCK_K) { + int sb = k / FP4_BLOCK_SIZE; // compile-time shift for power-of-2 float val = vals[i]; float sign = (val >= 0.0f) ? 
1.0f : -1.0f; - float q = fp4_quantize_value(fabsf(val) * inv_scale); - float result = sign * q * scale; + float q = fp4_quantize_value(fabsf(val) * inv_scales[sb]); + float result = sign * q * scales[sb]; // M-major: As[m * BK_STRIDE + k] As[m * BK_STRIDE + k] = __float2bfloat16(result); } @@ -528,7 +557,8 @@ torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_fla }; // Macro to dispatch kernel with all 4 template specializations -#define LAUNCH_WMMA_KERNEL(BM, BN, BK, WM, WN) \ +// FP4_BS is the FP4 quantization block size (independent of BK) +#define LAUNCH_WMMA_KERNEL(BM, BN, BK, WM, WN, FP4_BS) \ { \ constexpr int BK_S = BK + 8; \ constexpr int BN_S = BN + 8; \ @@ -547,28 +577,28 @@ torch::Tensor conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_fla }; \ \ if (quant_act && has_bias) { \ - auto kern = conv3d_implicit_gemm_wmma; \ + auto kern = conv3d_implicit_gemm_wmma; \ set_smem(kern); \ kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ bias.data_ptr(), y.data_ptr(), \ act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ } else if (quant_act) { \ - auto kern = conv3d_implicit_gemm_wmma; \ + auto kern = conv3d_implicit_gemm_wmma; \ set_smem(kern); \ kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ bias.data_ptr(), y.data_ptr(), \ act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ } else if (has_bias) { \ - auto kern = conv3d_implicit_gemm_wmma; \ + auto kern = conv3d_implicit_gemm_wmma; \ set_smem(kern); \ kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ bias.data_ptr(), y.data_ptr(), \ act_amax.data_ptr(), Cin, Dp, Hp, Wp, Cout, OD, OH, OW, \ kD, kH, kW, sd, sh, sw, dd, dh, dw, M, K); \ } else { \ - auto kern = conv3d_implicit_gemm_wmma; \ + auto kern = conv3d_implicit_gemm_wmma; \ set_smem(kern); \ kern<<>>(x_pad.data_ptr(), w_flat.data_ptr(), \ bias.data_ptr(), y.data_ptr(), \ @@ -577,16 +607,19 @@ torch::Tensor 
conv3d_implicit_gemm_cuda(torch::Tensor x_pad, torch::Tensor w_fla } \ } - if (fp4_block_size == 128) { - // BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, WARPS_M=2, WARPS_N=4 - // 8 warps = 256 threads -> faster cooperative loading - // Shared: 64*(128+8)*2 + 128*(64+8)*2 = 17,408 + 18,432 = 35,840 bytes (~35KB) - LAUNCH_WMMA_KERNEL(64, 64, 128, 2, 4) + // BLOCK_K=256 always, FP4_BLOCK_SIZE varies + // BLOCK_M=64, BLOCK_N=64, WARPS_M=2, WARPS_N=4 (8 warps = 256 threads) + // Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB) + if (fp4_block_size == 16) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 16) + } else if (fp4_block_size == 32) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 32) + } else if (fp4_block_size == 64) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 64) + } else if (fp4_block_size == 128) { + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 128) } else { - // BLOCK_M=64, BLOCK_N=64, BLOCK_K=256, WARPS_M=2, WARPS_N=4 - // 8 warps = 256 threads -> faster cooperative loading - // Shared: 64*(256+8)*2 + 256*(64+8)*2 = 33,792 + 36,864 = 70,656 bytes (~69KB) - LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4) + LAUNCH_WMMA_KERNEL(64, 64, 256, 2, 4, 256) } #undef LAUNCH_WMMA_KERNEL diff --git a/experimental/conv/test_implicit_gemm.py b/experimental/conv/test_implicit_gemm.py index b548ed3f0a..45c4d888ff 100644 --- a/experimental/conv/test_implicit_gemm.py +++ b/experimental/conv/test_implicit_gemm.py @@ -248,27 +248,14 @@ def test_single_output_element(self, cuda_conv3d): class TestConv3dFP4BlockSize: - """Test both FP4 block size configs (affects tile shapes even for non-quant).""" + """Test all FP4 block size configs (BLOCK_K=256 always, FP4_BLOCK_SIZE varies). 
- def test_block_size_128(self, cuda_conv3d): - """Non-quant conv with BK=128 tile config.""" - x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) - ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) - out = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - quant_act=False, - fp4_block_size=128, - ) - assert out.shape == ref.shape - assert (out - ref).abs().max().item() < ATOL + Non-quantized path ignores FP4_BLOCK_SIZE, so all should match cuDNN. + """ - def test_block_size_256(self, cuda_conv3d): - """Non-quant conv with BK=256 tile config.""" + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_non_quant_all_block_sizes(self, cuda_conv3d, fp4_block_size): + """Non-quant conv should match cuDNN regardless of fp4_block_size.""" x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) @@ -279,7 +266,7 @@ def test_block_size_256(self, cuda_conv3d): padding=(1, 1, 1), dilation=(1, 1, 1), quant_act=False, - fp4_block_size=256, + fp4_block_size=fp4_block_size, ) assert out.shape == ref.shape assert (out - ref).abs().max().item() < ATOL @@ -298,6 +285,294 @@ def test_deterministic(self, cuda_conv3d): assert torch.equal(out1, out2), "Kernel is not deterministic" +# ============================================================================= +# FP4 Quantized Conv3D Tests (fused activation quantization) +# ============================================================================= + + +@pytest.fixture(scope="module") +def cuda_fp4_quant(): + """Import FP4 fake quant for reference comparisons.""" + from experimental.conv.implicit_gemm_cuda import fp4_fake_quant + + return fp4_fake_quant + + +class TestConv3dFP4QuantBlockSizes: + """Test 
fused FP4 activation quantization with all supported block sizes. + + The kernel applies blockwise FP4 quantization to the im2col'd activation tiles + along the K dimension. We verify correctness by comparing the fused kernel output + against an unfused reference: fp4_fake_quant(im2col) @ fp4_fake_quant(weight). + """ + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_runs_all_block_sizes(self, cuda_conv3d, fp4_block_size): + """All FP4 block sizes should run without errors and produce valid output.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 32, 8, 8, 8) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + assert out.abs().max() > 0, "Output is all zeros" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_deterministic(self, cuda_conv3d, fp4_block_size): + """Quantized conv should be deterministic for all block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + kwargs = { + "stride": (1, 1, 1), + "padding": (1, 1, 1), + "dilation": (1, 1, 1), + "act_amax": act_amax, + "quant_act": True, + "fp4_block_size": fp4_block_size, + } + out1 = cuda_conv3d(x, w, **kwargs) + out2 = cuda_conv3d(x, w, **kwargs) + assert torch.equal(out1, out2), f"Non-deterministic for fp4_block_size={fp4_block_size}" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_vs_unfused_reference(self, cuda_conv3d, cuda_fp4_quant, 
fp4_block_size): + """Compare fused kernel vs unfused: fp4(im2col) @ fp4(weight). + + Uses a shape where K is a multiple of 256 so all K-tiles are full + and block boundaries align perfectly between fused and unfused paths. + """ + torch.manual_seed(123) + # K = Cin * kD * kH * kW. Choose Cin so K is a multiple of 256. + # Cin=256, k=1x1x1 -> K=256 (exactly 1 full K-tile) + cin, cout = 256, 64 + x = torch.randn(1, cin, 4, 4, 4, device="cuda", dtype=torch.float32) + w = torch.randn(cout, cin, 1, 1, 1, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + w_amax = w.abs().max().unsqueeze(0) + + # Unfused reference: + # 1. Build im2col matrix (for 1x1x1 kernel, it's just reshape) + n, c, d, h, w_dim = x.shape + im2col = x.permute(0, 2, 3, 4, 1).reshape(-1, cin) # [M, K] + + # 2. FP4 fake-quant both matrices along K with the same block_size + im2col_q = cuda_fp4_quant(im2col, act_amax, fp4_block_size) + w_flat = w.reshape(cout, cin).transpose(0, 1).contiguous() # [K, Cout] + w_flat_q = cuda_fp4_quant(w_flat, w_amax, fp4_block_size) + + # 3. Matmul (in BF16 to match kernel's WMMA path) + ref_out = (im2col_q.bfloat16() @ w_flat_q.bfloat16()).float() + ref_out = ref_out.view(n, d, h, w_dim, cout).permute(0, 4, 1, 2, 3) + + # Note: the fused kernel does NOT quantize weights — weights are passed as-is. + # So for a proper comparison we need the fused kernel with pre-quantized weights. + fused_out_preq = cuda_conv3d( + x, + w_flat_q.transpose(0, 1).reshape(cout, cin, 1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + + # The fused kernel and unfused reference should match closely. + # Differences come from BF16 accumulation order (WMMA 16x16x16 tiles vs flat matmul). 
+ max_diff = (fused_out_preq - ref_out).abs().max().item() + mean_diff = (fused_out_preq - ref_out).abs().mean().item() + # Scale tolerance with K + scaled_atol = ATOL * (cin / 1000.0) ** 0.5 + assert max_diff < scaled_atol, ( + f"fp4_block_size={fp4_block_size}: fused vs unfused max diff {max_diff:.4f} " + f"exceeds tolerance {scaled_atol:.4f}" + ) + assert mean_diff < scaled_atol * 0.1, ( + f"fp4_block_size={fp4_block_size}: mean diff {mean_diff:.6e} too high" + ) + + def test_smaller_block_less_error(self, cuda_conv3d): + """Smaller FP4 block sizes should generally produce lower quantization error. + + Finer-grained blocks capture local ranges better, reducing quant error vs cuDNN. + Test monotonicity: error(16) <= error(32) <= ... <= error(256) (with some tolerance). + Reports detailed accuracy metrics for each block size vs cuDNN baseline. + """ + torch.manual_seed(42) + + # Test multiple shapes to get a comprehensive picture + configs = [ + ("Small K=432", 1, 16, 8, 8, 8, 32, 3, 3, 3), + ("Medium K=1728", 1, 64, 8, 8, 8, 64, 3, 3, 3), + ("Large K=3456", 1, 128, 5, 8, 8, 256, 3, 3, 3), + ] + + block_sizes = [16, 32, 64, 128, 256] + all_errors = {} + + for desc, n, cin, d, h, w_s, cout, kd, kh, kw in configs: + x = torch.randn(n, cin, d, h, w_s, device="cuda", dtype=torch.float32) + w = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + k_size = cin * kd * kh * kw + + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + ref_abs_mean = ref.abs().mean().item() + + # Also compute no-quant baseline (BF16 rounding only) + out_nq = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + quant_act=False, + ) + nq_diff = (out_nq - ref).abs() + + print( + f"\n {desc} (K={k_size}), output range [{ref.min().item():.1f}, {ref.max().item():.1f}]" + ) + print( + f" {'Block Size':>10} | {'Max Diff':>10} | {'Mean Diff':>10} | {'RMSE':>10} | {'Rel Err%':>8}" 
+ ) + print(f" {'-' * 10}-+-{'-' * 10}-+-{'-' * 10}-+-{'-' * 10}-+-{'-' * 8}") + print( + f" {'no-quant':>10} | {nq_diff.max().item():>10.4f} | " + f"{nq_diff.mean().item():>10.6f} | " + f"{((out_nq - ref) ** 2).mean().sqrt().item():>10.4f} | " + f"{nq_diff.mean().item() / ref_abs_mean * 100:>7.3f}%" + ) + + errors = {} + for bs in block_sizes: + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=bs, + ) + diff = (out - ref).abs() + max_d = diff.max().item() + mean_d = diff.mean().item() + rmse = ((out - ref) ** 2).mean().sqrt().item() + rel_err = mean_d / ref_abs_mean * 100 + errors[bs] = mean_d + print( + f" {bs:>10} | {max_d:>10.4f} | {mean_d:>10.6f} | " + f"{rmse:>10.4f} | {rel_err:>7.3f}%" + ) + all_errors[desc] = errors + + # Monotonicity check on the medium config + errors = all_errors["Medium K=1728"] + for smaller, larger in [(16, 64), (16, 256), (32, 256), (64, 256)]: + assert errors[smaller] <= errors[larger] * 1.2, ( + f"Expected error({smaller})={errors[smaller]:.6f} <= " + f"error({larger})={errors[larger]:.6f} * 1.2" + ) + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_with_bias(self, cuda_conv3d, fp4_block_size): + """FP4 quantized conv with bias for all block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + b = torch.randn(32, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + bias=b, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 32, 8, 8, 8) + assert not torch.isnan(out).any() + # Bias should shift output values + out_no_bias = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), 
+ act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert not torch.equal(out, out_no_bias), "Bias had no effect" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_k_not_aligned(self, cuda_conv3d, fp4_block_size): + """FP4 quant with K not aligned to BLOCK_K or fp4_block_size. + + K = Cin * kDHW = 7 * 27 = 189. The last K-tile has partial data (zeros padded). + """ + torch.manual_seed(42) + x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 16, 8, 8, 8) + assert not torch.isnan(out).any() + assert out.abs().max() > 0 + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_realistic_shape(self, cuda_conv3d, fp4_block_size): + """Realistic video diffusion shape with all FP4 block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 128, 5, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(256, 128, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 256, 5, 8, 8) + assert not torch.isnan(out).any() + assert out.abs().max() > 0 + + # ============================================================================= # FP4 Fake Quantization Tests # ============================================================================= From 906729015a3d4c27a22178619fd18b6ea5fc8f47 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 9 Mar 2026 23:55:38 +0000 Subject: [PATCH 11/25] Update the README Signed-off-by: Jingyu Xin --- 
experimental/conv/README.md | 2 +- experimental/conv/test_implicit_gemm.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/experimental/conv/README.md b/experimental/conv/README.md index 9fa0d80f64..65b7cc5563 100644 --- a/experimental/conv/README.md +++ b/experimental/conv/README.md @@ -8,7 +8,7 @@ This code is kept under `experimental/` by design and is **not** part of the sta | Model/Framework | Supported | Notes | |-----------------|-----------|-------| -| Video diffusion backbones using Conv3D | Partial | Intended for experimentation and microbenchmarking | +| Video diffusion VAE Conv3D layers | Tested | Validated on VAE encoder/decoder Conv3D layers in video diffusion models | | Generic LLM backbones | No | Conv3D path is not relevant | | End-to-end ModelOpt PTQ/QAT pipeline | No | Not wired into formal quantization/export/compress flows | diff --git a/experimental/conv/test_implicit_gemm.py b/experimental/conv/test_implicit_gemm.py index 45c4d888ff..7f0531b47f 100644 --- a/experimental/conv/test_implicit_gemm.py +++ b/experimental/conv/test_implicit_gemm.py @@ -22,6 +22,8 @@ import torch import torch.nn.functional as F +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.fixture(scope="module") def cuda_conv3d(): From 40757f83a1e737ad450efe262129a1edf6868090 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Tue, 10 Mar 2026 00:00:48 +0000 Subject: [PATCH 12/25] Update Signed-off-by: Jingyu Xin --- experimental/conv/implicit_gemm_cuda.py | 9 +++++++-- experimental/conv/test_implicit_gemm.py | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/experimental/conv/implicit_gemm_cuda.py b/experimental/conv/implicit_gemm_cuda.py index b85b373dd5..6bb49afa51 100644 --- a/experimental/conv/implicit_gemm_cuda.py +++ b/experimental/conv/implicit_gemm_cuda.py @@ -110,10 +110,15 @@ def conv3d_implicit_gemm_cuda( cuda_mod = _get_cuda_module() - assert x.ndim == 5 and w.ndim 
== 5 + if x.ndim != 5 or w.ndim != 5: + raise ValueError(f"Expected 5D tensors, got x.ndim={x.ndim}, w.ndim={w.ndim}") n_batch, cin, d, h, w_in = x.shape cout, cin_w, kd, kh, kw = w.shape - assert cin_w == cin + if cin_w != cin: + raise ValueError( + f"Grouped convolution is not supported (x has {cin} input channels, " + f"w has {cin_w}). This kernel requires groups=1." + ) sd, sh, sw = _triple(stride) dd, dh, dw = _triple(dilation) diff --git a/experimental/conv/test_implicit_gemm.py b/experimental/conv/test_implicit_gemm.py index 7f0531b47f..af52660e42 100644 --- a/experimental/conv/test_implicit_gemm.py +++ b/experimental/conv/test_implicit_gemm.py @@ -243,10 +243,20 @@ def test_output_size_1x1x1(self, cuda_conv3d): _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) def test_single_output_element(self, cuda_conv3d): - """M=1: batch=1, output 1x1x1.""" + """M=1: batch=1, output 1x1x1. + + With only one output element, mean diff == max diff, so the generic + helper's mean_diff < scaled_atol * 0.1 check is too tight. Use max diff only. 
+ """ x = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) w = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1)) + out = cuda_conv3d( + x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), quant_act=False + ) + assert out.shape == ref.shape + max_diff = (out - ref).abs().max().item() + assert max_diff < ATOL, f"Max abs diff {max_diff:.6e} exceeds tolerance {ATOL}" class TestConv3dFP4BlockSize: From 863dd1a3a5e8688fc3967c7fd9179c52529ce941 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Tue, 10 Mar 2026 02:09:34 +0000 Subject: [PATCH 13/25] Added the SM version constraint Signed-off-by: Jingyu Xin --- experimental/conv/implicit_gemm_cuda.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/experimental/conv/implicit_gemm_cuda.py b/experimental/conv/implicit_gemm_cuda.py index 6bb49afa51..713b5f82f9 100644 --- a/experimental/conv/implicit_gemm_cuda.py +++ b/experimental/conv/implicit_gemm_cuda.py @@ -29,10 +29,22 @@ _cuda_module = None +_MIN_SM_MAJOR = 8 # BF16 WMMA tensor cores require SM80+ (Ampere and newer) + + def _get_cuda_module(): """Get or compile the CUDA module.""" global _cuda_module if _cuda_module is None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. This kernel requires a CUDA GPU.") + major, minor = torch.cuda.get_device_capability() + if major < _MIN_SM_MAJOR: + raise RuntimeError( + f"This kernel requires SM{_MIN_SM_MAJOR}0+ (Ampere or newer) for BF16 WMMA " + f"tensor cores, but the current GPU has SM{major}{minor}." 
+ ) + from torch.utils.cpp_extension import load _cuda_module = load( From abad5e93e537e52fafb1811cb8b6996c33daccfa Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 1 Apr 2026 23:44:47 -0700 Subject: [PATCH 14/25] Update the e2e conv3d test Signed-off-by: Jingyu Xin --- .../diffusers/quantization/calibration.py | 13 +- .../diffusers/quantization/models_utils.py | 50 +++++--- .../quantization/pipeline_manager.py | 41 ++++--- examples/diffusers/quantization/quantize.py | 114 ++++++++++++------ .../diffusers/quantization/quantize_config.py | 3 +- examples/diffusers/quantization/utils.py | 40 ++++++ .../plugins/diffusion/diffusers.py | 45 +++++++ 7 files changed, 229 insertions(+), 77 deletions(-) diff --git a/examples/diffusers/quantization/calibration.py b/examples/diffusers/quantization/calibration.py index d51523e575..172173288b 100644 --- a/examples/diffusers/quantization/calibration.py +++ b/examples/diffusers/quantization/calibration.py @@ -107,11 +107,12 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None: def _run_wan_video_calibration( self, prompt_batch: list[str], extra_args: dict[str, Any] ) -> None: + extra_params = self.pipeline_manager.config.extra_params kwargs = {} kwargs["negative_prompt"] = extra_args["negative_prompt"] - kwargs["height"] = extra_args["height"] - kwargs["width"] = extra_args["width"] - kwargs["num_frames"] = extra_args["num_frames"] + kwargs["height"] = extra_params.get("height", extra_args["height"]) + kwargs["width"] = extra_params.get("width", extra_args["width"]) + kwargs["num_frames"] = extra_params.get("num_frames", extra_args["num_frames"]) kwargs["guidance_scale"] = extra_args["guidance_scale"] if "guidance_scale_2" in extra_args: kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"] @@ -143,7 +144,11 @@ def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, A "images": extra_params.get("images", []), "tiling_config": extra_params.get("tiling_config", 
TilingConfig.default()), } - self.pipe(prompt=prompt, **kwargs) + decoded_video, decoded_audio = self.pipe(prompt=prompt, **kwargs) + # vae_decode_video returns a lazy generator — consume it so the + # video decoder's forward() actually runs during calibration. + for _ in decoded_video: + pass def _run_ltx_video_calibration( self, prompt_batch: list[str], extra_args: dict[str, Any] diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 1e90c54714..076bb05bbe 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -33,7 +33,9 @@ from utils import ( filter_func_default, filter_func_flux_dev, + filter_func_ltx2_vae, filter_func_ltx_video, + filter_func_wan_vae, filter_func_wan_video, ) @@ -54,31 +56,43 @@ class ModelType(str, Enum): WAN22_T2V_5b = "wan2.2-t2v-5b" -def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: +# Filter function registry keyed by (ModelType, backbone_name). +# Backbone names like "unet", "transformer" are DiT/UNet backbones; +# "video_decoder", "vae" etc. are VAE backbones. 
+_FILTER_FUNC_REGISTRY: dict[tuple[ModelType, str], Callable[[str], bool]] = { + # ---- DiT / UNet backbones ---- + (ModelType.SDXL_BASE, "unet"): filter_func_default, + (ModelType.SDXL_TURBO, "unet"): filter_func_default, + (ModelType.SD3_MEDIUM, "transformer"): filter_func_default, + (ModelType.SD35_MEDIUM, "transformer"): filter_func_default, + (ModelType.FLUX_DEV, "transformer"): filter_func_flux_dev, + (ModelType.FLUX_SCHNELL, "transformer"): filter_func_default, + (ModelType.FLUX2_DEV, "transformer"): filter_func_flux_dev, + (ModelType.LTX_VIDEO_DEV, "transformer"): filter_func_ltx_video, + (ModelType.LTX2, "transformer"): filter_func_ltx_video, + (ModelType.WAN22_T2V_14b, "transformer"): filter_func_wan_video, + (ModelType.WAN22_T2V_5b, "transformer"): filter_func_wan_video, + # ---- VAE backbones ---- + (ModelType.LTX2, "video_decoder"): filter_func_ltx2_vae, + (ModelType.WAN22_T2V_14b, "vae"): filter_func_wan_vae, + (ModelType.WAN22_T2V_5b, "vae"): filter_func_wan_vae, +} + + +def get_model_filter_func( + model_type: ModelType, backbone_name: str = "transformer" +) -> Callable[[str], bool]: """ - Get the appropriate filter function for a given model type. + Get the appropriate filter function for a given model type and backbone. 
Args: model_type: The model type enum + backbone_name: Name of the backbone being quantized (e.g., "transformer", "video_decoder") Returns: - A filter function appropriate for the model type + A filter function appropriate for the model type and backbone """ - filter_func_map = { - ModelType.FLUX_DEV: filter_func_flux_dev, - ModelType.FLUX_SCHNELL: filter_func_default, - ModelType.FLUX2_DEV: filter_func_flux_dev, - ModelType.SDXL_BASE: filter_func_default, - ModelType.SDXL_TURBO: filter_func_default, - ModelType.SD3_MEDIUM: filter_func_default, - ModelType.SD35_MEDIUM: filter_func_default, - ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, - ModelType.LTX2: filter_func_ltx_video, - ModelType.WAN22_T2V_14b: filter_func_wan_video, - ModelType.WAN22_T2V_5b: filter_func_wan_video, - } - - return filter_func_map.get(model_type, filter_func_default) + return _FILTER_FUNC_REGISTRY.get((model_type, backbone_name), filter_func_default) # Model registry with HuggingFace model IDs diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index 6575e7e83d..bfd3b2544e 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -41,6 +41,7 @@ def __init__(self, config: ModelConfig, logger: logging.Logger): self.pipe: Any | None = None self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling self._transformer: torch.nn.Module | None = None + self._video_decoder: torch.nn.Module | None = None @staticmethod def create_pipeline_from( @@ -158,10 +159,10 @@ def setup_device(self) -> None: def get_backbone(self) -> torch.nn.Module: """ - Get the backbone model (transformer or UNet). + Get the backbone model(s). Returns: - Backbone model module + Single module if one backbone, ModuleList if multiple. """ if not self.pipe: raise RuntimeError("Pipeline not created. 
Call create_pipeline() first.") @@ -173,25 +174,25 @@ def get_backbone(self) -> torch.nn.Module: def iter_backbones(self) -> Iterator[tuple[str, torch.nn.Module]]: """ - Yield backbone modules by name, based on a backbone spec. - - Yields: - (backbone_name, module) pairs + Yield (backbone_name, module) pairs. """ if not self.pipe: raise RuntimeError("Pipeline not created. Call create_pipeline() first.") names = list(self.config.backbone) + if not names: + raise RuntimeError("No backbone names provided.") if self.config.model_type == ModelType.LTX2: - self._ensure_ltx2_transformer_cached() - name = names[0] if names else "transformer" - yield name, self._transformer + for name in names: + if name == "video_decoder": + self._ensure_ltx2_video_decoder_cached() + yield name, self._video_decoder + else: + self._ensure_ltx2_transformer_cached() + yield name, self._transformer return - if not names: - raise RuntimeError("No backbone names provided.") - for name in names: module = getattr(self.pipe, name, None) if module is None: @@ -206,6 +207,16 @@ def _ensure_ltx2_transformer_cached(self) -> None: self.pipe.stage_1_model_ledger.transformer = lambda: transformer self._transformer = transformer + def _ensure_ltx2_video_decoder_cached(self) -> None: + if not self.pipe: + raise RuntimeError("Pipeline not created. 
Call create_pipeline() first.") + if self._video_decoder is None: + video_decoder = self.pipe.stage_1_model_ledger.video_decoder() + # Cache it so subsequent calls return the same (quantized) instance + self.pipe.stage_1_model_ledger.video_decoder = lambda: video_decoder + self.pipe.stage_2_model_ledger.video_decoder = lambda: video_decoder + self._video_decoder = video_decoder + def _create_ltx2_pipeline(self) -> Any: params = dict(self.config.extra_params) checkpoint_path = params.pop("checkpoint_path", None) @@ -228,7 +239,6 @@ def _create_ltx2_pipeline(self) -> Any: raise ValueError("Missing required extra_param: gemma_root.") from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps - from ltx_core.quantization import QuantizationPolicy from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline distilled_lora = [ @@ -244,13 +254,12 @@ def _create_ltx2_pipeline(self) -> Any: "spatial_upsampler_path": str(spatial_upsampler_path), "gemma_root": str(gemma_root), "loras": [], - "quantization": QuantizationPolicy.fp8_cast() if fp8_quantization else None, + "fp8transformer": bool(fp8_quantization), } pipeline_kwargs.update(params) return TI2VidTwoStagesPipeline(**pipeline_kwargs) def print_quant_summary(self): - backbone_pairs = list(self.iter_backbones()) - for name, backbone in backbone_pairs: + for name, backbone in self.iter_backbones(): self.logger.info(f"{name} quantization info:") mtq.print_quant_summary(backbone) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 612357f6ea..f829aef981 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -134,6 +134,13 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: quant_config = NVFP4_FP8_MHA_CONFIG else: quant_config = NVFP4_DEFAULT_CONFIG + # Override block size if non-default + if self.config.block_size != 16: + import copy + quant_config = 
copy.deepcopy(quant_config) + for key in ("*weight_quantizer", "*input_quantizer", "**weight_quantizer", "**input_quantizer"): + if key in quant_config["quant_cfg"] and "block_sizes" in quant_config["quant_cfg"][key]: + quant_config["quant_cfg"][key]["block_sizes"][-1] = self.config.block_size else: raise NotImplementedError(f"Unknown format {self.config.format}") if self.config.quantize_mha: @@ -153,6 +160,7 @@ def quantize_model( backbone: torch.nn.Module, quant_config: Any, forward_loop: callable, # type: ignore[valid-type] + backbone_name: str = "transformer", ) -> torch.nn.Module: """ Apply quantization to the model. @@ -161,15 +169,16 @@ def quantize_model( backbone: Model backbone to quantize quant_config: Quantization configuration forward_loop: Forward pass function for calibration + backbone_name: Name of the backbone being quantized """ self.logger.info("Checking for LoRA layers...") check_lora(backbone) - self.logger.info("Starting model quantization...") + self.logger.info(f"Starting model quantization for {backbone_name}...") mtq.quantize(backbone, quant_config, forward_loop) # Get model-specific filter function - model_filter_func = get_model_filter_func(self.model_config.model_type) - self.logger.info(f"Using filter function for {self.model_config.model_type.value}") + model_filter_func = get_model_filter_func(self.model_config.model_type, backbone_name) + self.logger.info(f"Using filter function for {self.model_config.model_type.value}/{backbone_name}") self.logger.info("Disabling specific quantizers...") mtq.disable_quantizer(backbone, model_filter_func) @@ -216,20 +225,26 @@ def _has_conv_layers(self, model: torch.nn.Module) -> bool: return True return False - def save_checkpoint(self, backbone: torch.nn.Module) -> None: + def save_checkpoint( + self, + backbone: torch.nn.Module, + backbone_name: str | None = None, + ) -> None: """ Save quantized model checkpoint. 
Args: backbone: The quantized backbone module to save (must be the same instance that was passed to mtq.quantize, as it carries the _modelopt_state). + backbone_name: Optional name for the backbone file (defaults to "backbone"). """ if not self.config.quantized_torch_ckpt_path: return ckpt_path = self.config.quantized_torch_ckpt_path ckpt_path.mkdir(parents=True, exist_ok=True) - target_path = ckpt_path / "backbone.pt" + filename = f"{backbone_name}.pt" if backbone_name else "backbone.pt" + target_path = ckpt_path / filename self.logger.info(f"Saving backbone to {target_path}") mto.save(backbone, str(target_path)) @@ -287,14 +302,20 @@ def restore_checkpoint(self) -> None: if self.pipeline_manager is None: raise RuntimeError("Pipeline manager is required for per-backbone checkpoints.") - backbone = self.pipeline_manager.get_backbone() if restore_path.exists() and restore_path.is_dir(): - source_path = restore_path / "backbone.pt" - if not source_path.exists(): - raise FileNotFoundError(f"Backbone checkpoint not found: {source_path}") - self.logger.info(f"Restoring backbone from {source_path}") - mto.restore(backbone, str(source_path)) - self.logger.info("Backbone checkpoints restored successfully") + for backbone_name, backbone in self.pipeline_manager.iter_backbones(): + # Try named checkpoint first, fall back to legacy "backbone.pt" + source_path = restore_path / f"{backbone_name}.pt" + if not source_path.exists(): + source_path = restore_path / "backbone.pt" + if not source_path.exists(): + raise FileNotFoundError( + f"Checkpoint not found for {backbone_name}: tried " + f"{restore_path / f'{backbone_name}.pt'} and {restore_path / 'backbone.pt'}" + ) + self.logger.info(f"Restoring {backbone_name} from {source_path}") + mto.restore(backbone, str(source_path)) + self.logger.info("Checkpoints restored successfully") # TODO: should not do the any data type def export_hf_ckpt(self, pipe: Any, model_config: ModelConfig | None = None) -> None: @@ -363,9 +384,9 @@ def 
create_argument_parser() -> argparse.ArgumentParser: nargs="+", default=None, help=( - "Model backbone(s) in the DiffusionPipeline to work on. " - "Provide one name or multiple names separated by space or comma. " - "If not provided use default based on model type." + "Model backbone(s) in the DiffusionPipeline to quantize. " + "Provide one or more names (e.g., 'transformer', 'video_decoder'). " + "If not provided, uses default based on model type." ), ) model_group.add_argument( @@ -443,6 +464,12 @@ def create_argument_parser() -> argparse.ArgumentParser: action="store_true", help="Compress quantized weights to reduce memory footprint (FP8/FP4 only)", ) + quant_group.add_argument( + "--block-size", + type=int, + default=16, + help="Block size for NVFP4 quantization (default: 16)", + ) calib_group = parser.add_argument_group("Calibration Configuration") calib_group.add_argument("--batch-size", type=int, default=2, help="Batch size for calibration") @@ -530,6 +557,7 @@ def main() -> None: lowrank=args.lowrank, quantize_mha=args.quantize_mha, compress=args.compress, + block_size=args.block_size, ) if args.prompts_file is not None: @@ -566,7 +594,6 @@ def main() -> None: pipe = pipeline_manager.create_pipeline() pipeline_manager.setup_device() - backbone = pipeline_manager.get_backbone() export_manager = ExportManager(export_config, logger, pipeline_manager) if export_config.restore_from and export_config.restore_from.exists(): @@ -576,37 +603,48 @@ def main() -> None: logger.info("Initializing calibration...") calibrator = Calibrator(pipeline_manager, calib_config, model_config.model_type, logger) batched_prompts = calibrator.load_and_batch_prompts() - quantizer = Quantizer(quant_config, model_config, logger) - backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) - # Pipe loads the ckpt just before the inference. 
- def forward_loop(mod): - calibrator.run_calibration(batched_prompts) + for backbone_name, backbone in pipeline_manager.iter_backbones(): + logger.info(f"Quantizing backbone: {backbone_name}") + backbone_quant_config = quantizer.get_quant_config( + calib_config.n_steps, backbone + ) - quantizer.quantize_model(backbone, backbone_quant_config, forward_loop) + # Full pipeline inference for calibration — cached modules + # (transformer, video_decoder) are exercised during __call__ + def forward_loop(mod): + calibrator.run_calibration(batched_prompts) - # Compress model weights if requested (only for FP8/FP4) - if quant_config.compress: - logger.info("Compressing model weights to reduce memory footprint...") - mtq.compress(backbone) - logger.info("Model compression completed") + quantizer.quantize_model( + backbone, backbone_quant_config, forward_loop, + backbone_name=backbone_name, + ) - export_manager.save_checkpoint(backbone) + # Compress model weights if requested (only for FP8/FP4) + if quant_config.compress: + logger.info(f"Compressing {backbone_name} weights...") + mtq.compress(backbone) + logger.info(f"{backbone_name} compression completed") - # TODO (Jingyu): To update this function, as we are focusing more on the torch deployment side. - check_conv_and_mha( - backbone, quant_config.format == QuantFormat.FP4, quant_config.quantize_mha - ) + # For VAE backbones, skip check_conv_and_mha — the whole point + # of VAE quantization is to quantize Conv layers. 
+ if backbone_name not in ("video_decoder", "vae"): + check_conv_and_mha( + backbone, quant_config.format == QuantFormat.FP4, quant_config.quantize_mha + ) + + export_manager.save_checkpoint(backbone, backbone_name) pipeline_manager.print_quant_summary() - export_manager.export_onnx( - pipe, - backbone, - model_config.model_type, - quant_config.format, - ) + for backbone_name, backbone in pipeline_manager.iter_backbones(): + export_manager.export_onnx( + pipe, + backbone, + model_config.model_type, + quant_config.format, + ) export_manager.export_hf_ckpt(pipe, model_config) diff --git a/examples/diffusers/quantization/quantize_config.py b/examples/diffusers/quantization/quantize_config.py index 980d39f31f..98ea716ab5 100644 --- a/examples/diffusers/quantization/quantize_config.py +++ b/examples/diffusers/quantization/quantize_config.py @@ -78,6 +78,7 @@ class QuantizationConfig: lowrank: int = 32 # SVDQuant lowrank quantize_mha: bool = False compress: bool = False + block_size: int = 16 # NVFP4 block size def validate(self) -> None: """Validate configuration consistency.""" @@ -119,7 +120,7 @@ class ModelConfig: model_type: ModelType = ModelType.FLUX_DEV model_dtype: dict[str, torch.dtype] = field(default_factory=lambda: {"default": torch.float16}) - backbone: str = "" + backbone: list[str] = field(default_factory=list) trt_high_precision_dtype: DataType = DataType.HALF override_model_path: Path | None = None cpu_offloading: bool = False diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 0c38fc2860..9ffc580784 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -87,6 +87,46 @@ def filter_func_flux_dev(name: str) -> bool: return pattern.match(name) is not None +def filter_func_ltx2_vae(name: str) -> bool: + """Filter function for LTX-2 VAE decoder layers that should NOT be quantized. + + Only keeps conv1/conv2 inside resnet blocks within up_blocks. 
+ Disables everything else: conv_in, conv_out, conv_shortcut, upsamplers, + per_channel_statistics, timestep layers, etc. + """ + # Pattern matching conv1/conv2 inside up_blocks at any nesting depth + # e.g. up_blocks.0.conv1, up_blocks.2.res_blocks.0.conv2 + keep_pattern = re.compile(r".*up_blocks\.\d+\.(?:res_blocks\.\d+\.)?conv[12](?:\.|$)") + # If the layer matches the keep pattern, do NOT disable it (return False) + if keep_pattern.match(name): + return False + # Disable everything else + return True + + +def filter_func_wan_vae(name: str) -> bool: + """Filter function for WAN 2.2 VAE layers that should NOT be quantized. + + Only keeps conv1/conv2 inside WanResidualBlock resnets within + encoder.down_blocks, encoder.mid_block, decoder.mid_block, decoder.up_blocks. + Disables: conv_in, conv_out, conv_shortcut, upsamplers, downsamplers (resample), + time_conv, quant_conv, post_quant_conv, attention layers (to_qkv, proj). + """ + # Keep conv1/conv2 in resnet blocks: + # 14B flat: encoder.down_blocks.N.conv1/conv2 + # 5B nested: encoder.down_blocks.N.resnets.M.conv1/conv2 + # Both: encoder/decoder.mid_block.resnets.N.conv1/conv2 + # Both: decoder.up_blocks.N.resnets.M.conv1/conv2 + keep_pattern = re.compile( + r".*(down_blocks\.\d+\.(?:resnets\.\d+\.)?conv[12]" + r"|mid_block\.resnets\.\d+\.conv[12]" + r"|up_blocks\.\d+\.resnets\.\d+\.conv[12])(?:\.|$)" + ) + if keep_pattern.match(name): + return False + return True + + def filter_func_wan_video(name: str) -> bool: """Filter function specifically for WAN-Video models.""" pattern = re.compile( diff --git a/modelopt/torch/quantization/plugins/diffusion/diffusers.py b/modelopt/torch/quantization/plugins/diffusion/diffusers.py index f9ae55b3e2..336d09c30e 100644 --- a/modelopt/torch/quantization/plugins/diffusion/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusion/diffusers.py @@ -63,6 +63,8 @@ QuantModuleRegistry, TensorQuantizer, ) +from ...nn.modules.quant_conv import _QuantConv3d +from ...utils import 
is_torch_export_mode from ..custom import _QuantFunctionalMixin onnx_dtype_map = { @@ -278,3 +280,46 @@ def symbolic( high_precision_flag, disable_fp8_mha, ) + + +# --------------------------------------------------------------------------- +# WanCausalConv3d quantization support (diffusers VAE) +# --------------------------------------------------------------------------- +try: + from diffusers.models.autoencoders.autoencoder_kl_wan import WanCausalConv3d + + @QuantModuleRegistry.register({WanCausalConv3d: "WanCausalConv3d"}) + class _QuantDiffusersWanCausalConv3d(_QuantConv3d): + """Quantized diffusers WanCausalConv3d. + + WanCausalConv3d inherits from nn.Conv3d directly, so we override + forward to apply causal padding + cache before the quantized conv. + """ + + @staticmethod + def _get_quantized_weight( + module: "QuantLinearConvBase", weight: torch.Tensor + ) -> torch.Tensor: + if module._enable_weight_quantization or is_torch_export_mode(): + return module.weight_quantizer(weight) + return weight + + def forward(self, x, cache_x=None): + with self.quantize_weight(): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + input = self.input_quantizer(x) + # Call nn.Conv3d.forward (grandparent), skipping WanCausalConv3d.forward + output = torch.nn.Conv3d.forward(self, input) + + if isinstance(output, tuple): + return (self.output_quantizer(output[0]), *output[1:]) + return self.output_quantizer(output) + +except ImportError: + pass From b5c89ed11b6a53e76932188298fe654bc28029d8 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 9 Apr 2026 22:28:12 +0000 Subject: [PATCH 15/25] Refresh the code Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 43 +++------ .../quantization/pipeline_manager.py | 49 ++++------ examples/diffusers/quantization/quantize.py | 47 +++++----- 
examples/diffusers/quantization/utils.py | 36 ++----- .../quantization/nn/modules/quant_conv.py | 94 ++++++++----------- .../plugins/diffusion/diffusers.py | 40 +++----- 6 files changed, 115 insertions(+), 194 deletions(-) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 076bb05bbe..b59744282f 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -56,23 +56,16 @@ class ModelType(str, Enum): WAN22_T2V_5b = "wan2.2-t2v-5b" -# Filter function registry keyed by (ModelType, backbone_name). -# Backbone names like "unet", "transformer" are DiT/UNet backbones; -# "video_decoder", "vae" etc. are VAE backbones. -_FILTER_FUNC_REGISTRY: dict[tuple[ModelType, str], Callable[[str], bool]] = { - # ---- DiT / UNet backbones ---- - (ModelType.SDXL_BASE, "unet"): filter_func_default, - (ModelType.SDXL_TURBO, "unet"): filter_func_default, - (ModelType.SD3_MEDIUM, "transformer"): filter_func_default, - (ModelType.SD35_MEDIUM, "transformer"): filter_func_default, - (ModelType.FLUX_DEV, "transformer"): filter_func_flux_dev, - (ModelType.FLUX_SCHNELL, "transformer"): filter_func_default, - (ModelType.FLUX2_DEV, "transformer"): filter_func_flux_dev, - (ModelType.LTX_VIDEO_DEV, "transformer"): filter_func_ltx_video, - (ModelType.LTX2, "transformer"): filter_func_ltx_video, - (ModelType.WAN22_T2V_14b, "transformer"): filter_func_wan_video, - (ModelType.WAN22_T2V_5b, "transformer"): filter_func_wan_video, - # ---- VAE backbones ---- +_FILTER_FUNC_MAP: dict[ModelType, Callable[[str], bool]] = { + ModelType.FLUX_DEV: filter_func_flux_dev, + ModelType.FLUX2_DEV: filter_func_flux_dev, + ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, + ModelType.LTX2: filter_func_ltx_video, + ModelType.WAN22_T2V_14b: filter_func_wan_video, + ModelType.WAN22_T2V_5b: filter_func_wan_video, +} + +_VAE_FILTER_FUNC_MAP: dict[tuple[ModelType, str], Callable[[str], bool]] = { 
(ModelType.LTX2, "video_decoder"): filter_func_ltx2_vae, (ModelType.WAN22_T2V_14b, "vae"): filter_func_wan_vae, (ModelType.WAN22_T2V_5b, "vae"): filter_func_wan_vae, @@ -82,17 +75,11 @@ class ModelType(str, Enum): def get_model_filter_func( model_type: ModelType, backbone_name: str = "transformer" ) -> Callable[[str], bool]: - """ - Get the appropriate filter function for a given model type and backbone. - - Args: - model_type: The model type enum - backbone_name: Name of the backbone being quantized (e.g., "transformer", "video_decoder") - - Returns: - A filter function appropriate for the model type and backbone - """ - return _FILTER_FUNC_REGISTRY.get((model_type, backbone_name), filter_func_default) + """Get the appropriate filter function for a given model type and backbone.""" + vae_func = _VAE_FILTER_FUNC_MAP.get((model_type, backbone_name)) + if vae_func is not None: + return vae_func + return _FILTER_FUNC_MAP.get(model_type, filter_func_default) # Model registry with HuggingFace model IDs diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index bfd3b2544e..b52ef4852e 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -58,23 +58,20 @@ def create_pipeline_from( Raises: ValueError: If model type is unsupported """ - try: - pipeline_cls = MODEL_PIPELINE[model_type] - if pipeline_cls is None: - raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.") - model_id = ( - MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path - ) - pipe = pipeline_cls.from_pretrained( - model_id, - torch_dtype=torch_dtype, - use_safetensors=True, - **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}), - ) - pipe.set_progress_bar_config(disable=True) - return pipe - except Exception as e: - raise e + pipeline_cls = MODEL_PIPELINE[model_type] + if pipeline_cls is None: + 
raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.") + model_id = ( + MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path + ) + pipe = pipeline_cls.from_pretrained( + model_id, + torch_dtype=torch_dtype, + use_safetensors=True, + **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}), + ) + pipe.set_progress_bar_config(disable=True) + return pipe def create_pipeline(self) -> Any: """ @@ -157,21 +154,6 @@ def setup_device(self) -> None: self.logger.info("Enabling VAE tiling for LTX-Video") self.pipe.vae.enable_tiling() - def get_backbone(self) -> torch.nn.Module: - """ - Get the backbone model(s). - - Returns: - Single module if one backbone, ModuleList if multiple. - """ - if not self.pipe: - raise RuntimeError("Pipeline not created. Call create_pipeline() first.") - - backbone_pairs = list(self.iter_backbones()) - if len(backbone_pairs) == 1: - return backbone_pairs[0][1] - return torch.nn.ModuleList([module for _, module in backbone_pairs]) - def iter_backbones(self) -> Iterator[tuple[str, torch.nn.Module]]: """ Yield (backbone_name, module) pairs. 
@@ -239,6 +221,7 @@ def _create_ltx2_pipeline(self) -> Any: raise ValueError("Missing required extra_param: gemma_root.") from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps + from ltx_core.quantization import QuantizationPolicy from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline distilled_lora = [ @@ -254,7 +237,7 @@ def _create_ltx2_pipeline(self) -> Any: "spatial_upsampler_path": str(spatial_upsampler_path), "gemma_root": str(gemma_root), "loras": [], - "fp8transformer": bool(fp8_quantization), + "quantization": QuantizationPolicy.fp8_cast() if fp8_quantization else None, } pipeline_kwargs.update(params) return TI2VidTwoStagesPipeline(**pipeline_kwargs) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 5f320cc425..944b395536 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -137,10 +137,11 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: # Override block size if non-default if self.config.block_size != 16: import copy + quant_config = copy.deepcopy(quant_config) - for key in ("*weight_quantizer", "*input_quantizer", "**weight_quantizer", "**input_quantizer"): - if key in quant_config["quant_cfg"] and "block_sizes" in quant_config["quant_cfg"][key]: - quant_config["quant_cfg"][key]["block_sizes"][-1] = self.config.block_size + for entry in quant_config["quant_cfg"]: + if isinstance(entry, dict) and "block_sizes" in entry.get("cfg", {}): + entry["cfg"]["block_sizes"][-1] = self.config.block_size else: raise NotImplementedError(f"Unknown format {self.config.format}") if self.config.quantize_mha: @@ -183,7 +184,9 @@ def quantize_model( mtq.quantize(backbone, quant_config, forward_loop) # Get model-specific filter function model_filter_func = get_model_filter_func(self.model_config.model_type, backbone_name) - self.logger.info(f"Using filter function for 
{self.model_config.model_type.value}/{backbone_name}") + self.logger.info( + f"Using filter function for {self.model_config.model_type.value}/{backbone_name}" + ) self.logger.info("Disabling specific quantizers...") mtq.disable_quantizer(backbone, model_filter_func) @@ -307,19 +310,21 @@ def restore_checkpoint(self) -> None: if self.pipeline_manager is None: raise RuntimeError("Pipeline manager is required for per-backbone checkpoints.") - if restore_path.exists() and restore_path.is_dir(): - for backbone_name, backbone in self.pipeline_manager.iter_backbones(): - # Try named checkpoint first, fall back to legacy "backbone.pt" - source_path = restore_path / f"{backbone_name}.pt" - if not source_path.exists(): - source_path = restore_path / "backbone.pt" - if not source_path.exists(): - raise FileNotFoundError( - f"Checkpoint not found for {backbone_name}: tried " - f"{restore_path / f'{backbone_name}.pt'} and {restore_path / 'backbone.pt'}" - ) - self.logger.info(f"Restoring {backbone_name} from {source_path}") - mto.restore(backbone, str(source_path)) + if not restore_path.exists() or not restore_path.is_dir(): + raise FileNotFoundError(f"Checkpoint directory not found: {restore_path}") + + backbones = list(self.pipeline_manager.iter_backbones()) + for backbone_name, backbone in backbones: + source_path = restore_path / f"{backbone_name}.pt" + # Legacy fallback: only safe when there is exactly one backbone + if not source_path.exists() and len(backbones) == 1: + source_path = restore_path / "backbone.pt" + if not source_path.exists(): + raise FileNotFoundError( + f"Checkpoint not found for '{backbone_name}' in {restore_path}" + ) + self.logger.info(f"Restoring {backbone_name} from {source_path}") + mto.restore(backbone, str(source_path)) self.logger.info("Checkpoints restored successfully") # TODO: should not do the any data type @@ -612,9 +617,7 @@ def main() -> None: for backbone_name, backbone in pipeline_manager.iter_backbones(): logger.info(f"Quantizing 
backbone: {backbone_name}") - backbone_quant_config = quantizer.get_quant_config( - calib_config.n_steps, backbone - ) + backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) # Full pipeline inference for calibration — cached modules # (transformer, video_decoder) are exercised during __call__ @@ -622,7 +625,9 @@ def forward_loop(mod): calibrator.run_calibration(batched_prompts) quantizer.quantize_model( - backbone, backbone_quant_config, forward_loop, + backbone, + backbone_quant_config, + forward_loop, backbone_name=backbone_name, ) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 9ffc580784..2dea3fccc4 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -88,43 +88,19 @@ def filter_func_flux_dev(name: str) -> bool: def filter_func_ltx2_vae(name: str) -> bool: - """Filter function for LTX-2 VAE decoder layers that should NOT be quantized. - - Only keeps conv1/conv2 inside resnet blocks within up_blocks. - Disables everything else: conv_in, conv_out, conv_shortcut, upsamplers, - per_channel_statistics, timestep layers, etc. - """ - # Pattern matching conv1/conv2 inside up_blocks at any nesting depth - # e.g. up_blocks.0.conv1, up_blocks.2.res_blocks.0.conv2 - keep_pattern = re.compile(r".*up_blocks\.\d+\.(?:res_blocks\.\d+\.)?conv[12](?:\.|$)") - # If the layer matches the keep pattern, do NOT disable it (return False) - if keep_pattern.match(name): - return False - # Disable everything else - return True + """Filter for LTX-2 VAE: keeps only conv1/conv2 in up_blocks resnets.""" + keep = re.compile(r".*up_blocks\.\d+\.resnets\.\d+\.conv[12](?:\.|$)") + return not keep.match(name) def filter_func_wan_vae(name: str) -> bool: - """Filter function for WAN 2.2 VAE layers that should NOT be quantized. 
- - Only keeps conv1/conv2 inside WanResidualBlock resnets within - encoder.down_blocks, encoder.mid_block, decoder.mid_block, decoder.up_blocks. - Disables: conv_in, conv_out, conv_shortcut, upsamplers, downsamplers (resample), - time_conv, quant_conv, post_quant_conv, attention layers (to_qkv, proj). - """ - # Keep conv1/conv2 in resnet blocks: - # 14B flat: encoder.down_blocks.N.conv1/conv2 - # 5B nested: encoder.down_blocks.N.resnets.M.conv1/conv2 - # Both: encoder/decoder.mid_block.resnets.N.conv1/conv2 - # Both: decoder.up_blocks.N.resnets.M.conv1/conv2 - keep_pattern = re.compile( + """Filter for Wan 2.2 VAE: keeps only conv1/conv2 in resnet blocks.""" + keep = re.compile( r".*(down_blocks\.\d+\.(?:resnets\.\d+\.)?conv[12]" r"|mid_block\.resnets\.\d+\.conv[12]" r"|up_blocks\.\d+\.resnets\.\d+\.conv[12])(?:\.|$)" ) - if keep_pattern.match(name): - return False - return True + return not keep.match(name) def filter_func_wan_video(name: str) -> bool: diff --git a/modelopt/torch/quantization/nn/modules/quant_conv.py b/modelopt/torch/quantization/nn/modules/quant_conv.py index fa9caca669..3afddd2841 100644 --- a/modelopt/torch/quantization/nn/modules/quant_conv.py +++ b/modelopt/torch/quantization/nn/modules/quant_conv.py @@ -15,15 +15,11 @@ """Quantized convolution.""" -import logging - import torch.nn as nn from ... 
import tensor_quant from .quant_module import QuantLinearConvBase, QuantModuleRegistry, _LegacyQuantLinearConvBaseMixin -logger = logging.getLogger(__name__) - __all__ = [ "Conv1d", "Conv2d", @@ -69,81 +65,59 @@ class QuantConv2d(_LegacyQuantLinearConvBaseMixin, nn.Conv2d): def _is_nvfp4_quantizer(quantizer) -> bool: """Check if a TensorQuantizer is configured for NVFP4 dynamic block quantization.""" return ( - hasattr(quantizer, "_num_bits") - and quantizer._num_bits == (2, 1) - and hasattr(quantizer, "_block_sizes") - and quantizer._block_sizes is not None - and quantizer._block_sizes.get("scale_bits") == (4, 3) - and quantizer._block_sizes.get("type") == "dynamic" + quantizer.num_bits == (2, 1) + and quantizer.block_sizes is not None + and quantizer.block_sizes.get("scale_bits") == (4, 3) + and quantizer.block_sizes.get("type") == "dynamic" ) def _nvfp4_quantize_weight_along_k(weight, weight_quantizer): - """Apply NVFP4 fake quantization to Conv3D weight along the GEMM K dimension. - - The implicit GEMM maps K = Cin * kD * kH * kW. The default quantizer would - quantize along the last dim (kW), which is wrong. We reshape to [Cout, K] - so blocks are along the contraction dimension. - """ - cout = weight.shape[0] - k = weight[0].numel() # Cin * kD * kH * kW - w_flat = weight.reshape(cout, k) - w_q_flat = weight_quantizer(w_flat) - return w_q_flat.reshape_as(weight) + """Apply NVFP4 fake quantization to Conv3D weight along the GEMM K dimension.""" + w_flat = weight.reshape(weight.shape[0], -1) + return weight_quantizer(w_flat).reshape_as(weight) @QuantModuleRegistry.register({nn.Conv3d: "nn.Conv3d"}) class _QuantConv3d(QuantLinearConvBase): """Quantized 3D convolution. - When both input and weight quantizers are configured for NVFP4, the forward - uses the fused implicit GEMM kernel from experimental/conv which performs - activation FP4 quantization inside the kernel. 
The weight is FP4-quantized - along the GEMM K dimension (Cin*kD*kH*kW) before being passed to the kernel. - - For all other quantization configs, the default cuDNN path is used. + For NVFP4, uses a fused implicit GEMM kernel with activation FP4 quantization + inside the kernel. For all other configs, the default cuDNN path is used. """ default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_CONV3D_WEIGHT_PER_CHANNEL def _should_use_implicit_gemm(self): """Check if both quantizers are NVFP4 and the implicit GEMM kernel is available.""" - if not ( + if hasattr(self, "_use_implicit_gemm"): + return self._use_implicit_gemm + result = False + if ( hasattr(self, "input_quantizer") and hasattr(self, "weight_quantizer") and _is_nvfp4_quantizer(self.input_quantizer) and _is_nvfp4_quantizer(self.weight_quantizer) ): - return False - try: - from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda # noqa: F401 - - return True - except ImportError: - return False - - def forward(self, input, *args, **kwargs): - """Forward with implicit GEMM for NVFP4, default path otherwise.""" - if not self._should_use_implicit_gemm(): - return super().forward(input, *args, **kwargs) - - # During calibration we only need to collect amax — use the faster - # default cuDNN path since the conv output itself doesn't matter. - if self.input_quantizer._if_calib and not self.input_quantizer._if_quant: - return super().forward(input, *args, **kwargs) - + try: + from experimental.conv.implicit_gemm_cuda import ( + conv3d_implicit_gemm_cuda, # noqa: F401 + ) + + result = True + except ImportError: + pass + self._use_implicit_gemm = result + return result + + def _implicit_gemm_forward(self, input): + """Run NVFP4 implicit GEMM kernel. 
Input may already be padded.""" from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda - # --- Get activation amax for the kernel --- act_amax = self.input_quantizer._get_amax(input) - - # --- Quantize weight along K dimension --- weight = _nvfp4_quantize_weight_along_k(self.weight, self.weight_quantizer) - - # --- Get fp4_block_size from the input quantizer config --- fp4_block_size = self.input_quantizer._block_sizes.get(-1, 16) - # --- Call implicit GEMM kernel --- output = conv3d_implicit_gemm_cuda( input, weight, @@ -152,15 +126,21 @@ def forward(self, input, *args, **kwargs): padding=self.padding, dilation=self.dilation, act_amax=act_amax, - quant_act=not self.input_quantizer._disabled, + quant_act=self.input_quantizer.is_enabled, fp4_block_size=fp4_block_size, ) + return self.output_quantizer(output) + + def forward(self, input, *args, **kwargs): + """Forward with implicit GEMM for NVFP4, default path otherwise.""" + if not self._should_use_implicit_gemm(): + return super().forward(input, *args, **kwargs) - # --- Output quantizer (usually disabled for NVFP4) --- - if hasattr(self, "output_quantizer"): - output = self.output_quantizer(output) + # During calibration, only collect amax — use the faster cuDNN path. 
+ if self.input_quantizer._if_calib and not self.input_quantizer._if_quant: + return super().forward(input, *args, **kwargs) - return output + return self._implicit_gemm_forward(input) class QuantConv3d(_LegacyQuantLinearConvBaseMixin, nn.Conv3d): diff --git a/modelopt/torch/quantization/plugins/diffusion/diffusers.py b/modelopt/torch/quantization/plugins/diffusion/diffusers.py index 336d09c30e..f2f6a70247 100644 --- a/modelopt/torch/quantization/plugins/diffusion/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusion/diffusers.py @@ -64,7 +64,6 @@ TensorQuantizer, ) from ...nn.modules.quant_conv import _QuantConv3d -from ...utils import is_torch_export_mode from ..custom import _QuantFunctionalMixin onnx_dtype_map = { @@ -282,41 +281,32 @@ def symbolic( ) -# --------------------------------------------------------------------------- # WanCausalConv3d quantization support (diffusers VAE) -# --------------------------------------------------------------------------- try: from diffusers.models.autoencoders.autoencoder_kl_wan import WanCausalConv3d @QuantModuleRegistry.register({WanCausalConv3d: "WanCausalConv3d"}) class _QuantDiffusersWanCausalConv3d(_QuantConv3d): - """Quantized diffusers WanCausalConv3d. - - WanCausalConv3d inherits from nn.Conv3d directly, so we override - forward to apply causal padding + cache before the quantized conv. 
- """ - - @staticmethod - def _get_quantized_weight( - module: "QuantLinearConvBase", weight: torch.Tensor - ) -> torch.Tensor: - if module._enable_weight_quantization or is_torch_export_mode(): - return module.weight_quantizer(weight) - return weight + """Quantized WanCausalConv3d — applies causal padding before quantized conv.""" def forward(self, x, cache_x=None): + # Apply WanCausalConv3d-specific causal padding + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + # NVFP4 implicit GEMM path (self.padding is (0,0,0) since padding already applied) + if self._should_use_implicit_gemm(): + if not (self.input_quantizer._if_calib and not self.input_quantizer._if_quant): + return self._implicit_gemm_forward(x) + + # Default quantized conv path (skip WanCausalConv3d.forward to avoid double-padding) with self.quantize_weight(): - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) - input = self.input_quantizer(x) - # Call nn.Conv3d.forward (grandparent), skipping WanCausalConv3d.forward output = torch.nn.Conv3d.forward(self, input) - if isinstance(output, tuple): return (self.output_quantizer(output[0]), *output[1:]) return self.output_quantizer(output) From c9392355deaa512ac7486c2e8b961765487748e4 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 10 Apr 2026 04:11:17 +0000 Subject: [PATCH 16/25] Update Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 2dea3fccc4..d102e83e06 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ 
-46,7 +46,7 @@ def filter_func_default(name: str) -> bool: def check_conv_and_mha(backbone, if_fp4, quantize_mha): for name, module in backbone.named_modules(): - if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)) and if_fp4: + if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d)) and if_fp4: module.weight_quantizer.disable() module.input_quantizer.disable() From 818b4bce20cd7203895ebba27d1eb030c62d5125 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 10 Apr 2026 21:57:48 +0000 Subject: [PATCH 17/25] Move the code from experimental to core code Signed-off-by: Jingyu Xin --- experimental/conv/bench_implicit_gemm.py | 208 ---- experimental/conv/test_implicit_gemm.py | 1067 ----------------- .../torch/kernels}/conv/README.md | 2 +- .../kernels}/conv/implicit_gemm_binding.cpp | 0 .../torch/kernels}/conv/implicit_gemm_cuda.py | 0 .../kernels}/conv/implicit_gemm_kernel.cu | 0 .../quantization/nn/modules/quant_conv.py | 4 +- 7 files changed, 3 insertions(+), 1278 deletions(-) delete mode 100644 experimental/conv/bench_implicit_gemm.py delete mode 100644 experimental/conv/test_implicit_gemm.py rename {experimental => modelopt/torch/kernels}/conv/README.md (97%) rename {experimental => modelopt/torch/kernels}/conv/implicit_gemm_binding.cpp (100%) rename {experimental => modelopt/torch/kernels}/conv/implicit_gemm_cuda.py (100%) rename {experimental => modelopt/torch/kernels}/conv/implicit_gemm_kernel.cu (100%) diff --git a/experimental/conv/bench_implicit_gemm.py b/experimental/conv/bench_implicit_gemm.py deleted file mode 100644 index 164c074467..0000000000 --- a/experimental/conv/bench_implicit_gemm.py +++ /dev/null @@ -1,208 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d. - -Usage: - python -m experimental.conv.bench_implicit_gemm - python -m experimental.conv.bench_implicit_gemm --shapes wan22 - python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100 -""" - -import argparse - -import torch -import torch.nn.functional as F - -# --------------------------------------------------------------------------- -# Benchmark shapes -# --------------------------------------------------------------------------- - -# (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation) -SHAPES = { - "small": [ - ("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), - ], - "medium": [ - ("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), - ("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), - ("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)), - ], - "wan22": [ - ("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), - ("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)), - ("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)), - ], - "stride": [ - ("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), - ("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)), - ], -} - - -def get_shapes(name: str): - """Return list of benchmark shapes by 
name or all shapes.""" - if name == "all": - result = [] - for v in SHAPES.values(): - result.extend(v) - return result - return SHAPES[name] - - -# --------------------------------------------------------------------------- -# Timing utility -# --------------------------------------------------------------------------- - - -def bench_fn(fn, warmup: int, iters: int) -> float: - """Benchmark a callable, return median time in ms.""" - for _ in range(warmup): - fn() - torch.cuda.synchronize() - - times = [] - for _ in range(iters): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - - times.sort() - return times[len(times) // 2] # median - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int): - """Run latency benchmark for the given shapes.""" - from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda - - shapes = get_shapes(shapes_name) - - # Header - print(f"\n{'=' * 100}") - print( - f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}" - ) - print(f"GPU: {torch.cuda.get_device_name()}") - print(f"{'=' * 100}") - print( - f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} " - f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} " - f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}" - ) - print("-" * 100) - - for name, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in shapes: - torch.manual_seed(42) - x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32) - weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - - # Compute GEMM dimensions for display - sd, sh, sw = stride - dd, dh, dw = dilation - pd, 
ph, pw = padding - od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1 - oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1 - ow = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1 - gemm_m = n * od * oh * ow - gemm_k = cin * kd * kh * kw - gemm_n = cout - - # cuDNN (torch.nn.functional.conv3d) - t_cudnn = bench_fn( - lambda: F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation), - warmup, - iters, - ) - - # Implicit GEMM (non-quantized) - t_gemm = bench_fn( - lambda: conv3d_implicit_gemm_cuda( - x, - weight, - stride=stride, - padding=padding, - dilation=dilation, - quant_act=False, - fp4_block_size=fp4_block_size, - ), - warmup, - iters, - ) - - # Implicit GEMM (FP4 quantized) - t_fp4 = bench_fn( - lambda: conv3d_implicit_gemm_cuda( - x, - weight, - stride=stride, - padding=padding, - dilation=dilation, - act_amax=act_amax, - quant_act=True, - fp4_block_size=fp4_block_size, - ), - warmup, - iters, - ) - - ratio_gemm = t_gemm / t_cudnn - ratio_fp4 = t_fp4 / t_cudnn - - print( - f"{name:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} " - f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms " - f"{ratio_gemm:>10.2f}x {ratio_fp4:>9.2f}x" - ) - - print(f"{'=' * 100}") - print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.") - print() - - -def main(): - """Entry point for the benchmark CLI.""" - parser = argparse.ArgumentParser(description="Conv3D latency benchmark") - parser.add_argument( - "--shapes", - default="all", - choices=[*list(SHAPES.keys()), "all"], - help="Which shape set to benchmark (default: all)", - ) - parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations") - parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations") - parser.add_argument( - "--fp4-block-size", - type=int, - default=128, - choices=[128, 256], - help="FP4 block size (default: 128)", - ) - args = parser.parse_args() - - run_benchmark(args.shapes, args.warmup, args.iters, args.fp4_block_size) - - -if __name__ == "__main__": - 
main() diff --git a/experimental/conv/test_implicit_gemm.py b/experimental/conv/test_implicit_gemm.py deleted file mode 100644 index af52660e42..0000000000 --- a/experimental/conv/test_implicit_gemm.py +++ /dev/null @@ -1,1067 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for Conv3D implicit GEMM CUDA kernel. - -Tests both non-quantized path (vs cuDNN) and FP4-quantized path (vs Triton reference). 
-""" - -import pytest -import torch -import torch.nn.functional as F - -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - - -@pytest.fixture(scope="module") -def cuda_conv3d(): - """Import and return the CUDA implicit GEMM conv3d function.""" - from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda - - return conv3d_implicit_gemm_cuda - - -def _triton_fp4_available(): - """Check if the Triton FP4 fake quant kernel is available (requires compute >= 8.9).""" - try: - import modelopt.torch.quantization.triton as triton_kernel - - return hasattr(triton_kernel, "fp4_fake_quant_block") - except ImportError: - return False - - -requires_triton_fp4 = pytest.mark.skipif( - not _triton_fp4_available(), - reason="Triton fp4_fake_quant_block not available (requires compute >= 8.9)", -) - - -# BF16 WMMA accumulates in FP32 but inputs are rounded to BF16, so expect diffs. -# For large K (e.g. 3456 = 128*27), max abs diff can reach ~0.8 due to BF16 rounding -# and different accumulation order vs cuDNN's FP32 path. 
-ATOL = 1.0 -RTOL = 1e-3 - - -def _run_conv3d_test(cuda_conv3d, x, w, bias, stride, padding, dilation): - """Helper: run both cuDNN and implicit GEMM, compare results.""" - ref = F.conv3d(x, w, bias=bias, stride=stride, padding=padding, dilation=dilation) - out = cuda_conv3d( - x, w, bias=bias, stride=stride, padding=padding, dilation=dilation, quant_act=False - ) - assert out.shape == ref.shape, f"Shape mismatch: {out.shape} vs {ref.shape}" - abs_diff = (out - ref).abs() - max_diff = abs_diff.max().item() - # Scale tolerance with K (reduction dimension) — BF16 rounding accumulates - cin = w.shape[1] - k_size = cin * w.shape[2] * w.shape[3] * w.shape[4] - scaled_atol = ATOL * (k_size / 1000.0) ** 0.5 - assert max_diff < scaled_atol, ( - f"Max abs diff {max_diff:.6e} exceeds tolerance {scaled_atol:.4f} (K={k_size})" - ) - # Check mean diff is small (more robust than quantile for large tensors) - mean_diff = abs_diff.mean().item() - assert mean_diff < scaled_atol * 0.1, f"Mean diff {mean_diff:.6e} too high" - return max_diff - - -class TestConv3dBasic: - """Basic correctness tests with simple shapes.""" - - def test_minimal(self, cuda_conv3d): - """Smallest possible conv3d: 1x1x1 kernel, single channel.""" - x = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) - w = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) - diff = _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) - # K=1, so BF16 rounding is the only source of error - assert diff < 1e-2 - - def test_single_channel_3x3x3(self, cuda_conv3d): - """Single input/output channel with 3x3x3 kernel.""" - x = torch.randn(1, 1, 5, 5, 5, device="cuda", dtype=torch.float32) - w = torch.randn(1, 1, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - def test_multi_channel(self, cuda_conv3d): - """Multiple input and output channels.""" - x = torch.randn(1, 16, 8, 8, 8, device="cuda", 
dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - def test_with_bias(self, cuda_conv3d): - """Conv3d with bias.""" - x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - b = torch.randn(32, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, b, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - def test_batch_size(self, cuda_conv3d): - """Batch size > 1.""" - x = torch.randn(4, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - -class TestConv3dStride: - """Tests with various stride configurations.""" - - def test_stride_2(self, cuda_conv3d): - """Uniform stride of 2.""" - x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (2, 2, 2), (1, 1, 1), (1, 1, 1)) - - def test_asymmetric_stride(self, cuda_conv3d): - """Different stride per dimension.""" - x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 2, 2), (1, 1, 1), (1, 1, 1)) - - -class TestConv3dPadding: - """Tests with various padding configurations.""" - - def test_no_padding(self, cuda_conv3d): - """Zero padding.""" - x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) - - def test_large_padding(self, cuda_conv3d): - """Padding larger than kernel radius.""" - x = torch.randn(1, 16, 8, 8, 8, device="cuda", 
dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1)) - - def test_asymmetric_padding(self, cuda_conv3d): - """Different padding per dimension.""" - x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 2), (1, 1, 1)) - - -class TestConv3dDilation: - """Tests with dilation.""" - - def test_dilation_2(self, cuda_conv3d): - """Uniform dilation of 2.""" - x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (2, 2, 2)) - - def test_asymmetric_dilation(self, cuda_conv3d): - """Different dilation per dimension.""" - x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 2, 2), (1, 2, 2)) - - -class TestConv3dKernelSizes: - """Tests with non-3x3x3 kernels.""" - - def test_1x1x1_kernel(self, cuda_conv3d): - """Pointwise 1x1x1 kernel.""" - x = torch.randn(1, 64, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(128, 64, 1, 1, 1, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) - - def test_asymmetric_kernel(self, cuda_conv3d): - """Kernel with different sizes per dimension (e.g. 
1x3x3).""" - x = torch.randn(1, 16, 8, 16, 16, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 1, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 1), (1, 1, 1)) - - def test_5x5x5_kernel(self, cuda_conv3d): - """Larger 5x5x5 kernel.""" - x = torch.randn(1, 8, 16, 16, 16, device="cuda", dtype=torch.float32) - w = torch.randn(16, 8, 5, 5, 5, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1)) - - -class TestConv3dRealisticShapes: - """Tests with shapes resembling real video diffusion models.""" - - def test_wan22_shape(self, cuda_conv3d): - """Shape from Wan2.2 video diffusion backbone.""" - x = torch.randn(1, 128, 21, 60, 106, device="cuda", dtype=torch.float32) - w = torch.randn(512, 128, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - def test_large_cout(self, cuda_conv3d): - """Large output channel count.""" - x = torch.randn(1, 64, 8, 16, 16, device="cuda", dtype=torch.float32) - w = torch.randn(512, 64, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - def test_large_cin(self, cuda_conv3d): - """Large input channel count.""" - x = torch.randn(1, 512, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(64, 512, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - -class TestConv3dEdgeCases: - """Edge cases for tile boundary handling.""" - - def test_m_not_aligned_to_block(self, cuda_conv3d): - """M (N*OD*OH*OW) not a multiple of BLOCK_M=64.""" - # 1*3*5*7 = 105, not divisible by 64 - x = torch.randn(1, 8, 5, 7, 9, device="cuda", dtype=torch.float32) - w = torch.randn(16, 8, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - 
def test_cout_not_aligned_to_block(self, cuda_conv3d): - """Cout not a multiple of BLOCK_N=64.""" - x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(17, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - def test_k_not_aligned_to_block(self, cuda_conv3d): - """K (Cin*kD*kH*kW) not a multiple of BLOCK_K.""" - # Cin=7, kDHW=27, K=189 -- not a multiple of 128 or 256 - x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) - - def test_output_size_1x1x1(self, cuda_conv3d): - """Output spatial dims are all 1.""" - x = torch.randn(1, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) - - def test_single_output_element(self, cuda_conv3d): - """M=1: batch=1, output 1x1x1. - - With only one output element, mean diff == max diff, so the generic - helper's mean_diff < scaled_atol * 0.1 check is too tight. Use max diff only. - """ - x = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) - w = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32) - ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1)) - out = cuda_conv3d( - x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), quant_act=False - ) - assert out.shape == ref.shape - max_diff = (out - ref).abs().max().item() - assert max_diff < ATOL, f"Max abs diff {max_diff:.6e} exceeds tolerance {ATOL}" - - -class TestConv3dFP4BlockSize: - """Test all FP4 block size configs (BLOCK_K=256 always, FP4_BLOCK_SIZE varies). - - Non-quantized path ignores FP4_BLOCK_SIZE, so all should match cuDNN. 
- """ - - @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) - def test_non_quant_all_block_sizes(self, cuda_conv3d, fp4_block_size): - """Non-quant conv should match cuDNN regardless of fp4_block_size.""" - x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) - ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) - out = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - quant_act=False, - fp4_block_size=fp4_block_size, - ) - assert out.shape == ref.shape - assert (out - ref).abs().max().item() < ATOL - - -class TestConv3dDeterminism: - """Verify deterministic output across repeated calls.""" - - def test_deterministic(self, cuda_conv3d): - """Repeated calls produce identical output.""" - torch.manual_seed(123) - x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) - out1 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) - out2 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) - assert torch.equal(out1, out2), "Kernel is not deterministic" - - -# ============================================================================= -# FP4 Quantized Conv3D Tests (fused activation quantization) -# ============================================================================= - - -@pytest.fixture(scope="module") -def cuda_fp4_quant(): - """Import FP4 fake quant for reference comparisons.""" - from experimental.conv.implicit_gemm_cuda import fp4_fake_quant - - return fp4_fake_quant - - -class TestConv3dFP4QuantBlockSizes: - """Test fused FP4 activation quantization with all supported block sizes. - - The kernel applies blockwise FP4 quantization to the im2col'd activation tiles - along the K dimension. 
We verify correctness by comparing the fused kernel output - against an unfused reference: fp4_fake_quant(im2col) @ fp4_fake_quant(weight). - """ - - @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) - def test_quant_runs_all_block_sizes(self, cuda_conv3d, fp4_block_size): - """All FP4 block sizes should run without errors and produce valid output.""" - x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - - out = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - act_amax=act_amax, - quant_act=True, - fp4_block_size=fp4_block_size, - ) - assert out.shape == (1, 32, 8, 8, 8) - assert not torch.isnan(out).any(), "Output contains NaN" - assert not torch.isinf(out).any(), "Output contains Inf" - assert out.abs().max() > 0, "Output is all zeros" - - @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) - def test_quant_deterministic(self, cuda_conv3d, fp4_block_size): - """Quantized conv should be deterministic for all block sizes.""" - torch.manual_seed(42) - x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - - kwargs = { - "stride": (1, 1, 1), - "padding": (1, 1, 1), - "dilation": (1, 1, 1), - "act_amax": act_amax, - "quant_act": True, - "fp4_block_size": fp4_block_size, - } - out1 = cuda_conv3d(x, w, **kwargs) - out2 = cuda_conv3d(x, w, **kwargs) - assert torch.equal(out1, out2), f"Non-deterministic for fp4_block_size={fp4_block_size}" - - @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) - def test_quant_vs_unfused_reference(self, cuda_conv3d, cuda_fp4_quant, fp4_block_size): - """Compare fused kernel vs unfused: fp4(im2col) @ fp4(weight). 
- - Uses a shape where K is a multiple of 256 so all K-tiles are full - and block boundaries align perfectly between fused and unfused paths. - """ - torch.manual_seed(123) - # K = Cin * kD * kH * kW. Choose Cin so K is a multiple of 256. - # Cin=256, k=1x1x1 -> K=256 (exactly 1 full K-tile) - cin, cout = 256, 64 - x = torch.randn(1, cin, 4, 4, 4, device="cuda", dtype=torch.float32) - w = torch.randn(cout, cin, 1, 1, 1, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - w_amax = w.abs().max().unsqueeze(0) - - # Unfused reference: - # 1. Build im2col matrix (for 1x1x1 kernel, it's just reshape) - n, c, d, h, w_dim = x.shape - im2col = x.permute(0, 2, 3, 4, 1).reshape(-1, cin) # [M, K] - - # 2. FP4 fake-quant both matrices along K with the same block_size - im2col_q = cuda_fp4_quant(im2col, act_amax, fp4_block_size) - w_flat = w.reshape(cout, cin).transpose(0, 1).contiguous() # [K, Cout] - w_flat_q = cuda_fp4_quant(w_flat, w_amax, fp4_block_size) - - # 3. Matmul (in BF16 to match kernel's WMMA path) - ref_out = (im2col_q.bfloat16() @ w_flat_q.bfloat16()).float() - ref_out = ref_out.view(n, d, h, w_dim, cout).permute(0, 4, 1, 2, 3) - - # Note: the fused kernel does NOT quantize weights — weights are passed as-is. - # So for a proper comparison we need the fused kernel with pre-quantized weights. - fused_out_preq = cuda_conv3d( - x, - w_flat_q.transpose(0, 1).reshape(cout, cin, 1, 1, 1), - stride=(1, 1, 1), - padding=(0, 0, 0), - dilation=(1, 1, 1), - act_amax=act_amax, - quant_act=True, - fp4_block_size=fp4_block_size, - ) - - # The fused kernel and unfused reference should match closely. - # Differences come from BF16 accumulation order (WMMA 16x16x16 tiles vs flat matmul). 
- max_diff = (fused_out_preq - ref_out).abs().max().item() - mean_diff = (fused_out_preq - ref_out).abs().mean().item() - # Scale tolerance with K - scaled_atol = ATOL * (cin / 1000.0) ** 0.5 - assert max_diff < scaled_atol, ( - f"fp4_block_size={fp4_block_size}: fused vs unfused max diff {max_diff:.4f} " - f"exceeds tolerance {scaled_atol:.4f}" - ) - assert mean_diff < scaled_atol * 0.1, ( - f"fp4_block_size={fp4_block_size}: mean diff {mean_diff:.6e} too high" - ) - - def test_smaller_block_less_error(self, cuda_conv3d): - """Smaller FP4 block sizes should generally produce lower quantization error. - - Finer-grained blocks capture local ranges better, reducing quant error vs cuDNN. - Test monotonicity: error(16) <= error(32) <= ... <= error(256) (with some tolerance). - Reports detailed accuracy metrics for each block size vs cuDNN baseline. - """ - torch.manual_seed(42) - - # Test multiple shapes to get a comprehensive picture - configs = [ - ("Small K=432", 1, 16, 8, 8, 8, 32, 3, 3, 3), - ("Medium K=1728", 1, 64, 8, 8, 8, 64, 3, 3, 3), - ("Large K=3456", 1, 128, 5, 8, 8, 256, 3, 3, 3), - ] - - block_sizes = [16, 32, 64, 128, 256] - all_errors = {} - - for desc, n, cin, d, h, w_s, cout, kd, kh, kw in configs: - x = torch.randn(n, cin, d, h, w_s, device="cuda", dtype=torch.float32) - w = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - k_size = cin * kd * kh * kw - - ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) - ref_abs_mean = ref.abs().mean().item() - - # Also compute no-quant baseline (BF16 rounding only) - out_nq = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - quant_act=False, - ) - nq_diff = (out_nq - ref).abs() - - print( - f"\n {desc} (K={k_size}), output range [{ref.min().item():.1f}, {ref.max().item():.1f}]" - ) - print( - f" {'Block Size':>10} | {'Max Diff':>10} | {'Mean Diff':>10} | {'RMSE':>10} | {'Rel Err%':>8}" 
- ) - print(f" {'-' * 10}-+-{'-' * 10}-+-{'-' * 10}-+-{'-' * 10}-+-{'-' * 8}") - print( - f" {'no-quant':>10} | {nq_diff.max().item():>10.4f} | " - f"{nq_diff.mean().item():>10.6f} | " - f"{((out_nq - ref) ** 2).mean().sqrt().item():>10.4f} | " - f"{nq_diff.mean().item() / ref_abs_mean * 100:>7.3f}%" - ) - - errors = {} - for bs in block_sizes: - out = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - act_amax=act_amax, - quant_act=True, - fp4_block_size=bs, - ) - diff = (out - ref).abs() - max_d = diff.max().item() - mean_d = diff.mean().item() - rmse = ((out - ref) ** 2).mean().sqrt().item() - rel_err = mean_d / ref_abs_mean * 100 - errors[bs] = mean_d - print( - f" {bs:>10} | {max_d:>10.4f} | {mean_d:>10.6f} | " - f"{rmse:>10.4f} | {rel_err:>7.3f}%" - ) - all_errors[desc] = errors - - # Monotonicity check on the medium config - errors = all_errors["Medium K=1728"] - for smaller, larger in [(16, 64), (16, 256), (32, 256), (64, 256)]: - assert errors[smaller] <= errors[larger] * 1.2, ( - f"Expected error({smaller})={errors[smaller]:.6f} <= " - f"error({larger})={errors[larger]:.6f} * 1.2" - ) - - @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) - def test_quant_with_bias(self, cuda_conv3d, fp4_block_size): - """FP4 quantized conv with bias for all block sizes.""" - torch.manual_seed(42) - x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) - b = torch.randn(32, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - - out = cuda_conv3d( - x, - w, - bias=b, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - act_amax=act_amax, - quant_act=True, - fp4_block_size=fp4_block_size, - ) - assert out.shape == (1, 32, 8, 8, 8) - assert not torch.isnan(out).any() - # Bias should shift output values - out_no_bias = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), 
- act_amax=act_amax, - quant_act=True, - fp4_block_size=fp4_block_size, - ) - assert not torch.equal(out, out_no_bias), "Bias had no effect" - - @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) - def test_quant_k_not_aligned(self, cuda_conv3d, fp4_block_size): - """FP4 quant with K not aligned to BLOCK_K or fp4_block_size. - - K = Cin * kDHW = 7 * 27 = 189. The last K-tile has partial data (zeros padded). - """ - torch.manual_seed(42) - x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - - out = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - act_amax=act_amax, - quant_act=True, - fp4_block_size=fp4_block_size, - ) - assert out.shape == (1, 16, 8, 8, 8) - assert not torch.isnan(out).any() - assert out.abs().max() > 0 - - @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) - def test_quant_realistic_shape(self, cuda_conv3d, fp4_block_size): - """Realistic video diffusion shape with all FP4 block sizes.""" - torch.manual_seed(42) - x = torch.randn(1, 128, 5, 8, 8, device="cuda", dtype=torch.float32) - w = torch.randn(256, 128, 3, 3, 3, device="cuda", dtype=torch.float32) - act_amax = x.abs().max().unsqueeze(0) - - out = cuda_conv3d( - x, - w, - stride=(1, 1, 1), - padding=(1, 1, 1), - dilation=(1, 1, 1), - act_amax=act_amax, - quant_act=True, - fp4_block_size=fp4_block_size, - ) - assert out.shape == (1, 256, 5, 8, 8) - assert not torch.isnan(out).any() - assert out.abs().max() > 0 - - -# ============================================================================= -# FP4 Fake Quantization Tests -# ============================================================================= - - -@pytest.fixture(scope="module") -def cuda_fp4(): - """Import and return the CUDA FP4 fake quant function.""" - from experimental.conv.implicit_gemm_cuda import fp4_fake_quant - - return 
fp4_fake_quant - - -def _py_fp4_fake_quant_ref(x_flat, global_amax, block_size): - """Pure Python reference for FP4 fake quant (no BF16 rounding). - - This implements the exact same algorithm as the CUDA kernel: - 1. Compute global_scale = global_amax / (6 * 448) - 2. Per block: block_max = max(|x|), scale = fp8_e4m3_roundtrip(block_max / (6 * global_scale)) * global_scale - 3. Quantize each element to nearest E2M1 level, then dequantize. - """ - import math - - # E2M1 quantization levels: {0, 0.5, 1, 1.5, 2, 3, 4, 6} - # Boundaries (midpoints): <=0.25->0, <0.75->0.5, <=1.25->1, <1.75->1.5, <=2.5->2, <3.5->3, <=5->4, >5->6 - def quantize_e2m1(scaled_abs): - if scaled_abs <= 0.25: - return 0.0 - elif scaled_abs < 0.75: - return 0.5 - elif scaled_abs <= 1.25: - return 1.0 - elif scaled_abs < 1.75: - return 1.5 - elif scaled_abs <= 2.5: - return 2.0 - elif scaled_abs < 3.5: - return 3.0 - elif scaled_abs <= 5.0: - return 4.0 - else: - return 6.0 - - def fp8_e4m3_roundtrip(val): - """Simulate FP8 E4M3 round-trip in Python.""" - if val == 0.0: - return 0.0 - sign = 1.0 if val >= 0 else -1.0 - val = abs(val) - # FP8 E4M3: bias=7, 3 mantissa bits, max=448, no inf/nan - if val > 448.0: - return sign * 448.0 - # Compute exponent - exp = math.floor(math.log2(val)) - exp = max(exp, -6) # min normal exponent for E4M3 - # Compute mantissa (3 bits) - mantissa = val / (2.0**exp) # 1.xxx - mantissa_bits = round((mantissa - 1.0) * 8.0) # 3 bits - if mantissa_bits > 7: - mantissa_bits = 0 - exp += 1 - if exp > 8: - return sign * 448.0 - # Reconstruct - result = (1.0 + mantissa_bits / 8.0) * (2.0**exp) - return sign * result - - global_scale = float(global_amax) / (6.0 * 448.0) - x_np = x_flat.cpu().float().numpy().copy() - num_blocks = len(x_np) // block_size - - for b in range(num_blocks): - block = x_np[b * block_size : (b + 1) * block_size] - block_max = float(max(abs(v) for v in block)) - - # Scale quantization - scaled = block_max / (6.0 * global_scale) - scaled = min(scaled, 
448.0) - quantized_scale = fp8_e4m3_roundtrip(scaled) * global_scale - if quantized_scale < 1e-5: - quantized_scale = 1.0 - inv_scale = 1.0 / quantized_scale - - for i in range(block_size): - val = block[i] - sign = 1.0 if val >= 0 else -1.0 - q = quantize_e2m1(abs(val) * inv_scale) - x_np[b * block_size + i] = sign * q * quantized_scale - - return torch.tensor(x_np, device=x_flat.device) - - -class TestFP4FakeQuantValues: - """Test FP4 fake quant with known E2M1 table values.""" - - def test_exact_e2m1_values(self, cuda_fp4): - """E2M1 representable values should round-trip exactly (when scale=1 via amax=6*448).""" - # With global_amax = 6*448 = 2688, global_scale = 1.0 - # A single-block input with max=6 -> block_max=6, scaled=6/(6*1)=1.0 - # fp8_e4m3(1.0)=1.0, scale = 1.0*1.0 = 1.0 - block_size = 8 - vals = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], device="cuda", dtype=torch.float32) - amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) - out = cuda_fp4(vals, amax, block_size) - assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}" - - def test_exact_e2m1_negative(self, cuda_fp4): - """Negative E2M1 values should also round-trip.""" - block_size = 8 - vals = torch.tensor([0, -0.5, -1, -1.5, -2, -3, -4, -6], device="cuda", dtype=torch.float32) - amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) - out = cuda_fp4(vals, amax, block_size) - assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}" - - def test_below_boundary(self, cuda_fp4): - """Values slightly below E2M1 boundaries should quantize down.""" - block_size = 8 - # Boundaries: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 - # Slightly below -> quantize to lower level - inp = torch.tensor( - [0.15, 0.65, 1.15, 1.65, 2.4, 3.4, 4.9, 6.0], device="cuda", dtype=torch.float32 - ) - expected = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device="cuda", dtype=torch.float32 - ) - amax = torch.tensor([6.0 * 448.0], 
device="cuda", dtype=torch.float32) - out = cuda_fp4(inp, amax, block_size) - assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" - - def test_above_boundary(self, cuda_fp4): - """Values slightly above E2M1 boundaries should quantize up.""" - block_size = 8 - inp = torch.tensor( - [0.35, 0.85, 1.35, 1.85, 2.6, 3.6, 5.1, 6.0], device="cuda", dtype=torch.float32 - ) - expected = torch.tensor( - [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 6.0], device="cuda", dtype=torch.float32 - ) - amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) - out = cuda_fp4(inp, amax, block_size) - assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" - - def test_mixed_signs(self, cuda_fp4): - """Mixed positive/negative values.""" - block_size = 8 - inp = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32) - expected = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32) - amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) - out = cuda_fp4(inp, amax, block_size) - assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" - - -class TestFP4FakeQuantScale: - """Test FP4 scale computation and FP8 round-trip.""" - - def test_scale_factor(self, cuda_fp4): - """When amax != 6*448, scale should adjust values proportionally.""" - block_size = 8 - # global_amax = 12*448 = 5376, global_scale = 2.0 - # Input block max = 12 -> scaled = 12/(6*2) = 1.0 -> fp8(1.0) = 1.0 -> scale = 2.0 - # So input 12 -> |12|/2 = 6.0 -> q=6 -> 6*2 = 12 - inp = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", dtype=torch.float32) - amax = torch.tensor([12.0 * 448.0], device="cuda", dtype=torch.float32) - out = cuda_fp4(inp, amax, block_size) - # Expected: each val/2.0 -> quantize to E2M1 -> * 2.0 - expected = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", dtype=torch.float32) - assert torch.allclose(out, expected, atol=1e-4), f"Got 
{out} vs expected {expected}" - - def test_zero_block(self, cuda_fp4): - """All-zero block should produce all zeros.""" - block_size = 16 - inp = torch.zeros(block_size, device="cuda", dtype=torch.float32) - amax = torch.tensor([1.0], device="cuda", dtype=torch.float32) - out = cuda_fp4(inp, amax, block_size) - assert torch.equal(out, inp) - - def test_multiple_blocks(self, cuda_fp4): - """Multiple blocks with different ranges.""" - block_size = 8 - # Block 0: small values, Block 1: large values - block0 = torch.tensor([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], device="cuda") - block1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 6], device="cuda") - inp = torch.cat([block0, block1]) - amax = inp.abs().max().unsqueeze(0) - out = cuda_fp4(inp, amax, block_size) - # Each block should be independently quantized - assert out.shape == inp.shape - # Block 1 exact values should be close to E2M1 levels - assert out[8:].abs().max() <= 6.0 + 1e-5 - - -class TestFP4FakeQuantBlockSizes: - """Test different block sizes.""" - - @pytest.mark.parametrize("block_size", [8, 16, 32, 64, 128, 256]) - def test_block_sizes(self, cuda_fp4, block_size): - """FP4 quant should work for various block sizes.""" - torch.manual_seed(42) - num_blocks = 4 - inp = torch.randn(num_blocks * block_size, device="cuda", dtype=torch.float32) * 5 - amax = inp.abs().max().unsqueeze(0) - out = cuda_fp4(inp, amax, block_size) - assert out.shape == inp.shape - # Output should not be all zeros for non-zero input - assert out.abs().max() > 0 - # Output should be <= max possible after quant - assert out.abs().max() <= inp.abs().max() * 1.5 # generous bound - - -class TestFP4FakeQuantVsReference: - """Compare CUDA FP4 fake quant against Python reference implementation.""" - - @pytest.mark.parametrize("block_size", [8, 16, 32]) - def test_vs_python_ref(self, cuda_fp4, block_size): - """CUDA kernel should match the Python reference exactly.""" - torch.manual_seed(123) - num_blocks = 8 - inp = torch.randn(num_blocks * 
block_size, device="cuda") * 10 - amax = inp.abs().max().unsqueeze(0) - - cuda_out = cuda_fp4(inp, amax, block_size) - ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size) - - assert torch.allclose(cuda_out, ref_out, atol=1e-5), ( - f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}" - ) - - @pytest.mark.parametrize("block_size", [16, 32]) - def test_vs_python_ref_large(self, cuda_fp4, block_size): - """Larger tensor test against Python reference.""" - torch.manual_seed(456) - num_blocks = 64 - inp = torch.randn(num_blocks * block_size, device="cuda") * 20 - amax = inp.abs().max().unsqueeze(0) - - cuda_out = cuda_fp4(inp, amax, block_size) - ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size) - - assert torch.allclose(cuda_out, ref_out, atol=1e-4), ( - f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}" - ) - - -class TestFP4FakeQuantVsTriton: - """Compare CUDA FP4 fake quant against Triton fp4_fake_quant_block reference.""" - - @requires_triton_fp4 - @pytest.mark.parametrize("block_size", [16, 32, 64]) - @pytest.mark.parametrize("num_blocks", [4, 16, 64]) - def test_vs_triton(self, cuda_fp4, block_size, num_blocks): - """CUDA kernel should match the Triton fp4_fake_quant_block.""" - from modelopt.torch.quantization.triton import fp4_fake_quant_block - - torch.manual_seed(42) - x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 - global_amax = x.abs().max() - - cuda_out = cuda_fp4(x, global_amax.unsqueeze(0), block_size) - triton_out = fp4_fake_quant_block( - x, - global_amax=global_amax, - block_size=block_size, - tile_rows=num_blocks, - tile_cols=block_size, - ) - - assert torch.allclose(cuda_out, triton_out, atol=1e-5), ( - f"CUDA vs Triton max diff: {(cuda_out - triton_out).abs().max().item():.6e}\n" - f"Mean diff: {(cuda_out - triton_out).abs().mean().item():.6e}" - ) - - -class TestFP4FakeQuantDeterminism: - """Verify FP4 quant is deterministic.""" - - def 
test_deterministic(self, cuda_fp4): - """Repeated calls produce identical output.""" - torch.manual_seed(99) - inp = torch.randn(256, device="cuda") * 5 - amax = inp.abs().max().unsqueeze(0) - out1 = cuda_fp4(inp, amax, 16) - out2 = cuda_fp4(inp, amax, 16) - assert torch.equal(out1, out2), "FP4 fake quant is not deterministic" - - -# ============================================================================= -# Cross-validation: experimental FP4 vs modelopt FP4 implementations -# ============================================================================= - - -def _modelopt_cuda_ext_mx_available(): - """Check if the modelopt CUDA MX extension is available.""" - try: - from modelopt.torch.quantization.extensions import get_cuda_ext_mx - - return get_cuda_ext_mx() is not None - except Exception: - return False - - -def _modelopt_dynamic_block_quantize_available(): - """Check if dynamic_block_quantize_op is available.""" - try: - from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op - - return dynamic_block_quantize_op is not None - except Exception: - return False - - -requires_cuda_ext_mx = pytest.mark.skipif( - not _modelopt_cuda_ext_mx_available(), - reason="modelopt cuda_ext_mx not available", -) - -requires_dynamic_block_quantize = pytest.mark.skipif( - not _modelopt_dynamic_block_quantize_available(), - reason="modelopt dynamic_block_quantize_op not available", -) - - -class TestFP4FakeQuantVsModelopt: - """Compare experimental CUDA FP4 fake quant against all modelopt FP4 implementations. - - This ensures the standalone FP4 kernel in experimental/conv produces the same - results as the official modelopt quantization paths: - 1. Triton fp4_fake_quant_block (Hopper+ dynamic blockwise) - 2. cuda_ext_mx.fused_amax_convert (CUDA extension fallback) - 3. 
dynamic_block_quantize_op (high-level API that dispatches to either) - """ - - @requires_triton_fp4 - @pytest.mark.parametrize("block_size", [16, 32, 64]) - @pytest.mark.parametrize("seed", [42, 123, 999]) - def test_vs_triton_fp4_fake_quant_block(self, cuda_fp4, block_size, seed): - """Compare against modelopt Triton fp4_fake_quant_block.""" - from modelopt.torch.quantization.triton import fp4_fake_quant_block - - torch.manual_seed(seed) - num_blocks = 16 - x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 - global_amax = x.abs().max() - - ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) - theirs = fp4_fake_quant_block( - x, - global_amax=global_amax, - block_size=block_size, - tile_rows=num_blocks, - tile_cols=block_size, - ) - - assert torch.allclose(ours, theirs, atol=1e-5), ( - f"experimental vs modelopt Triton max diff: {(ours - theirs).abs().max().item():.6e}" - ) - - @requires_cuda_ext_mx - @pytest.mark.parametrize("block_size", [16, 32]) - @pytest.mark.parametrize("seed", [42, 123]) - def test_vs_cuda_ext_mx(self, cuda_fp4, block_size, seed): - """Compare against modelopt cuda_ext_mx.fused_amax_convert.""" - from modelopt.torch.quantization.extensions import get_cuda_ext_mx - from modelopt.torch.quantization.tensor_quant import mx_format_map - - cuda_ext_mx = get_cuda_ext_mx() - torch.manual_seed(seed) - num_blocks = 16 - x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 - global_amax = x.abs().max() - - ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) - theirs = cuda_ext_mx.fused_amax_convert( - x, - block_size, - getattr(cuda_ext_mx.Types, mx_format_map[(2, 1)]), - getattr(cuda_ext_mx.Types, mx_format_map[(4, 3)]), - global_amax, - ) - - assert torch.allclose(ours, theirs, atol=1e-5), ( - f"experimental vs modelopt cuda_ext_mx max diff: " - f"{(ours - theirs).abs().max().item():.6e}" - ) - - @requires_dynamic_block_quantize - @pytest.mark.parametrize("seed", [42, 123, 999]) - 
def test_vs_dynamic_block_quantize_op(self, cuda_fp4, seed): - """Compare against modelopt dynamic_block_quantize_op (high-level API). - - This is the function used by the actual quantization pipeline with - num_bits=4 (E2M1) and scale_bits=8 (E4M3). - Note: dynamic_block_quantize_op dispatches to Triton with default block_size=16. - """ - from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op - - block_size = 16 # dynamic_block_quantize_op uses block_size=16 for Triton path - torch.manual_seed(seed) - num_blocks = 16 - x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 - global_amax = x.abs().max() - - ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) - theirs = dynamic_block_quantize_op( - x, - block_size, - global_amax, - num_bits=4, # total bits = 1 sign + 2 exp + 1 mantissa - exponent_bits=2, - scale_num_bits=8, # FP8 E4M3 for scales - scale_exponent_bits=4, - ) - - assert torch.allclose(ours, theirs, atol=1e-5), ( - f"experimental vs modelopt dynamic_block_quantize_op max diff: " - f"{(ours - theirs).abs().max().item():.6e}" - ) - - @requires_triton_fp4 - def test_vs_triton_realistic_shape(self, cuda_fp4): - """Realistic activation shape from a Conv3D layer (flattened).""" - torch.manual_seed(42) - block_size = 16 - # Simulate a large tensor: 256 blocks of 16 elements - # (tile_rows must be power-of-2 for Triton block_ptr) - num_blocks = 256 - x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 5 - global_amax = x.abs().max() - - from modelopt.torch.quantization.triton import fp4_fake_quant_block - - ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) - theirs = fp4_fake_quant_block( - x, - global_amax=global_amax, - block_size=block_size, - tile_rows=16, - tile_cols=block_size, - ) - - max_diff = (ours - theirs).abs().max().item() - mean_diff = (ours - theirs).abs().mean().item() - assert torch.allclose(ours, theirs, atol=1e-5), ( - f"Realistic shape: experimental vs 
Triton max diff: {max_diff:.6e}, " - f"mean diff: {mean_diff:.6e}" - ) - - @requires_triton_fp4 - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) - def test_vs_triton_input_dtypes(self, cuda_fp4, dtype): - """Test that our kernel handles different input dtypes correctly. - - Our kernel casts to float32 internally, so the result should match - Triton's output when both receive the same dtype input. - """ - from modelopt.torch.quantization.triton import fp4_fake_quant_block - - torch.manual_seed(42) - block_size = 16 - num_blocks = 8 - x = (torch.randn(num_blocks, block_size, device="cuda") * 5).to(dtype) - global_amax = x.float().abs().max() - - ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) - theirs = fp4_fake_quant_block( - x, - global_amax=global_amax, - block_size=block_size, - tile_rows=num_blocks, - tile_cols=block_size, - ) - - # Both should return the input dtype - assert ours.dtype == dtype - assert theirs.dtype == dtype - - # Compare in float32 - max_diff = (ours.float() - theirs.float()).abs().max().item() - # BF16/FP16 input rounding may cause small diffs - tol = 1e-2 if dtype != torch.float32 else 1e-5 - assert max_diff < tol, f"dtype={dtype}: experimental vs Triton max diff: {max_diff:.6e}" diff --git a/experimental/conv/README.md b/modelopt/torch/kernels/conv/README.md similarity index 97% rename from experimental/conv/README.md rename to modelopt/torch/kernels/conv/README.md index 65b7cc5563..a87f8fab1b 100644 --- a/experimental/conv/README.md +++ b/modelopt/torch/kernels/conv/README.md @@ -26,7 +26,7 @@ This code is kept under `experimental/` by design and is **not** part of the sta ```python import torch -from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda +from modelopt.torch.kernels.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op x = torch.randn(1, 128, 21, 60, 106, device="cuda") diff --git 
a/experimental/conv/implicit_gemm_binding.cpp b/modelopt/torch/kernels/conv/implicit_gemm_binding.cpp similarity index 100% rename from experimental/conv/implicit_gemm_binding.cpp rename to modelopt/torch/kernels/conv/implicit_gemm_binding.cpp diff --git a/experimental/conv/implicit_gemm_cuda.py b/modelopt/torch/kernels/conv/implicit_gemm_cuda.py similarity index 100% rename from experimental/conv/implicit_gemm_cuda.py rename to modelopt/torch/kernels/conv/implicit_gemm_cuda.py diff --git a/experimental/conv/implicit_gemm_kernel.cu b/modelopt/torch/kernels/conv/implicit_gemm_kernel.cu similarity index 100% rename from experimental/conv/implicit_gemm_kernel.cu rename to modelopt/torch/kernels/conv/implicit_gemm_kernel.cu diff --git a/modelopt/torch/quantization/nn/modules/quant_conv.py b/modelopt/torch/quantization/nn/modules/quant_conv.py index 3afddd2841..cca3f06730 100644 --- a/modelopt/torch/quantization/nn/modules/quant_conv.py +++ b/modelopt/torch/quantization/nn/modules/quant_conv.py @@ -100,7 +100,7 @@ def _should_use_implicit_gemm(self): and _is_nvfp4_quantizer(self.weight_quantizer) ): try: - from experimental.conv.implicit_gemm_cuda import ( + from modelopt.torch.kernels.conv.implicit_gemm_cuda import ( conv3d_implicit_gemm_cuda, # noqa: F401 ) @@ -112,7 +112,7 @@ def _should_use_implicit_gemm(self): def _implicit_gemm_forward(self, input): """Run NVFP4 implicit GEMM kernel. 
Input may already be padded.""" - from experimental.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + from modelopt.torch.kernels.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda act_amax = self.input_quantizer._get_amax(input) weight = _nvfp4_quantize_weight_along_k(self.weight, self.weight_quantizer) From f210d234411c4c1aa68800a90cc75e5cc3970cf9 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 10 Apr 2026 22:05:33 +0000 Subject: [PATCH 18/25] Update the README Signed-off-by: Jingyu Xin --- modelopt/torch/kernels/conv/README.md | 90 +++++++++++++++++++-------- 1 file changed, 65 insertions(+), 25 deletions(-) diff --git a/modelopt/torch/kernels/conv/README.md b/modelopt/torch/kernels/conv/README.md index a87f8fab1b..417b03c811 100644 --- a/modelopt/torch/kernels/conv/README.md +++ b/modelopt/torch/kernels/conv/README.md @@ -1,25 +1,31 @@ -# Conv3D Implicit GEMM (Experimental) +# Conv3D Implicit GEMM -Experimental Conv3D kernel prototype using implicit GEMM, with optional fused FP4 fake quantization for activations. +Conv3D kernel using implicit GEMM with BF16 WMMA tensor cores and optional fused FP4 (E2M1) fake quantization. -This code is kept under `experimental/` by design and is **not** part of the stable `modelopt.torch.quantization` API. +This kernel is integrated into `modelopt.torch.quantization` via `_QuantConv3d` — when NVFP4 quantization is applied to an `nn.Conv3d` layer through ModelOpt PTQ, the implicit GEMM path is used automatically. We have only tested it on VAE Conv3D layers from video generation models (e.g. Wan2.2). 
-## Model Support +## Requirements -| Model/Framework | Supported | Notes | -|-----------------|-----------|-------| -| Video diffusion VAE Conv3D layers | Tested | Validated on VAE encoder/decoder Conv3D layers in video diffusion models | -| Generic LLM backbones | No | Conv3D path is not relevant | -| End-to-end ModelOpt PTQ/QAT pipeline | No | Not wired into formal quantization/export/compress flows | +- **GPU:** SM80+ (Ampere or newer) for BF16 WMMA tensor cores +- **PyTorch:** CUDA toolkit with JIT C++ extension support (`torch.utils.cpp_extension`) +- **Grouped convolution is not supported** (groups must be 1) -## Deployment +## Data Types -| Framework | Supported | Notes | -|-----------|-----------|-------| -| TensorRT-LLM | No | No formal export integration for this kernel path | -| vLLM | No | No integration | -| SGLang | No | No integration | -| PyTorch runtime (CUDA) | Yes (experimental) | JIT-compiles CUDA extension on first use | +| Stage | Precision | +|-------|-----------| +| Input / output tensors | FP32 (user-facing) | +| Internal compute | BF16 via WMMA m16n16k16 tensor cores | +| Accumulation | FP32 | +| FP4 activation quantization | E2M1 values, FP8 E4M3 scales | + +## Integration with ModelOpt Quantization + +When NVFP4 quantization is configured on a `Conv3d` layer via ModelOpt PTQ, the implicit GEMM kernel is used automatically during quantized inference. The integration is in `_QuantConv3d` (`modelopt/torch/quantization/nn/modules/quant_conv.py`): + +- During **calibration**, the standard cuDNN path is used (faster). +- During **quantized inference** with NVFP4 input and weight quantizers, the kernel fuses activation FP4 quantization inside the GEMM. +- For all other quantization configs, the default cuDNN path is used as fallback. 
## Usage @@ -67,7 +73,9 @@ out_q = conv3d_implicit_gemm_cuda( ## API -Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm_cuda.py` +### `conv3d_implicit_gemm_cuda` + +`from modelopt.torch.kernels.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda` | Parameter | Description | |-----------|-------------| @@ -81,23 +89,55 @@ Function: `conv3d_implicit_gemm_cuda(...)` from `experimental/conv/implicit_gemm | `quant_act` | Enable FP4 fake quantization on activations | | `fp4_block_size` | FP4 quantization block size (`16`, `32`, `64`, `128`, or `256`) | +### `fp4_fake_quant` + +`from modelopt.torch.kernels.conv.implicit_gemm_cuda import fp4_fake_quant` + +Standalone FP4 (E2M1) blockwise fake quantization with FP8 E4M3 scale quantization. Uses the same CUDA device functions as the fused path inside the GEMM kernel. + +| Parameter | Description | +|-----------|-------------| +| `x` | Input tensor (any shape; `numel` must be divisible by `block_size`) | +| `global_amax` | Scalar tensor — global abs max for scale computation | +| `block_size` | Number of elements per FP4 quantization block (default `16`) | + +## Testing and Benchmarking + +```bash +# Run tests (requires GPU) +python -m pytest modelopt/torch/kernels/conv/test_implicit_gemm.py -v + +# Run benchmarks +python -m modelopt.torch.kernels.conv.bench_implicit_gemm # default shapes +python -m modelopt.torch.kernels.conv.bench_implicit_gemm --shapes wan22 # Wan2.2 VAE shapes +python -m modelopt.torch.kernels.conv.bench_implicit_gemm --shapes all # all presets +``` + ## Status -Current state: **Prototype** +Current state: **Integrated** (registered in `QuantModuleRegistry`, auto-dispatched for NVFP4 Conv3D) Known limitations: -- API is unstable and may change without notice. -- Not registered in core quantization module registries. -- Not covered by formal export/compress integration. -- CUDA extension compile latency on first invocation. 
-- Validation and performance coverage are limited to local experiments. +- CUDA extension compile latency on first invocation (~seconds). +- Grouped convolution (`groups > 1`) is not supported. +- BF16 rounding error accumulates with the K dimension — expect max abs diff scaling roughly as `sqrt(K)` compared to cuDNN FP32. +- Inference only (`@torch.no_grad`) — not suitable for QAT backward pass. ## Notes -- The CUDA kernel is JIT-compiled on first call (can take several seconds). +- The CUDA kernel is JIT-compiled on first call via `torch.utils.cpp_extension.load()`. - Output shape matches `torch.nn.functional.conv3d`. -- FP4 path applies quantize-dequantize in-kernel for activation tiles. +- FP4 path applies quantize-dequantize in-kernel for activation tiles (no extra global memory pass). +- Tile config: BLOCK_M=64, BLOCK_N=64, BLOCK_K=256, 8 warps (256 threads), ~70 KB shared memory per block. + +## Files + +| File | Role | +|------|------| +| `implicit_gemm_cuda.py` | Python API and JIT compilation | +| `implicit_gemm_kernel.cu` | CUDA kernel (BF16 WMMA + FP4 quantization) | +| `implicit_gemm_binding.cpp` | PyTorch C++ extension binding | ## References From ab71984ef85b3ff809c31b7ed9aa282dbff58e34 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 10 Apr 2026 22:20:22 +0000 Subject: [PATCH 19/25] Added the test case Signed-off-by: Jingyu Xin --- .../kernels/test_implicit_gemm.py | 1011 +++++++++++++++++ 1 file changed, 1011 insertions(+) create mode 100644 tests/gpu/torch/quantization/kernels/test_implicit_gemm.py diff --git a/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py b/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py new file mode 100644 index 0000000000..f3ac8439ea --- /dev/null +++ b/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py @@ -0,0 +1,1011 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Conv3D implicit GEMM CUDA kernel. + +Tests both non-quantized path (vs cuDNN) and FP4-quantized path (vs Triton reference). +""" + +import pytest +import torch +import torch.nn.functional as F + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +@pytest.fixture(scope="module") +def cuda_conv3d(): + """Import and return the CUDA implicit GEMM conv3d function.""" + from modelopt.torch.kernels.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + + return conv3d_implicit_gemm_cuda + + +def _triton_fp4_available(): + """Check if the Triton FP4 fake quant kernel is available (requires compute >= 8.9).""" + try: + import modelopt.torch.quantization.triton as triton_kernel + + return hasattr(triton_kernel, "fp4_fake_quant_block") + except ImportError: + return False + + +requires_triton_fp4 = pytest.mark.skipif( + not _triton_fp4_available(), + reason="Triton fp4_fake_quant_block not available (requires compute >= 8.9)", +) + + +# BF16 WMMA accumulates in FP32 but inputs are rounded to BF16, so expect diffs. +# For large K (e.g. 3456 = 128*27), max abs diff can reach ~0.8 due to BF16 rounding +# and different accumulation order vs cuDNN's FP32 path. 
+ATOL = 1.0 +RTOL = 1e-3 + + +def _run_conv3d_test(cuda_conv3d, x, w, bias, stride, padding, dilation): + """Helper: run both cuDNN and implicit GEMM, compare results.""" + ref = F.conv3d(x, w, bias=bias, stride=stride, padding=padding, dilation=dilation) + out = cuda_conv3d( + x, w, bias=bias, stride=stride, padding=padding, dilation=dilation, quant_act=False + ) + assert out.shape == ref.shape, f"Shape mismatch: {out.shape} vs {ref.shape}" + abs_diff = (out - ref).abs() + max_diff = abs_diff.max().item() + # Scale tolerance with K (reduction dimension) — BF16 rounding accumulates + cin = w.shape[1] + k_size = cin * w.shape[2] * w.shape[3] * w.shape[4] + scaled_atol = ATOL * (k_size / 1000.0) ** 0.5 + assert max_diff < scaled_atol, ( + f"Max abs diff {max_diff:.6e} exceeds tolerance {scaled_atol:.4f} (K={k_size})" + ) + # Check mean diff is small (more robust than quantile for large tensors) + mean_diff = abs_diff.mean().item() + assert mean_diff < scaled_atol * 0.1, f"Mean diff {mean_diff:.6e} too high" + return max_diff + + +class TestConv3dBasic: + """Basic correctness tests with simple shapes.""" + + def test_minimal(self, cuda_conv3d): + """Smallest possible conv3d: 1x1x1 kernel, single channel.""" + x = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) + w = torch.randn(1, 1, 1, 1, 1, device="cuda", dtype=torch.float32) + diff = _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + # K=1, so BF16 rounding is the only source of error + assert diff < 1e-2 + + def test_single_channel_3x3x3(self, cuda_conv3d): + """Single input/output channel with 3x3x3 kernel.""" + x = torch.randn(1, 1, 5, 5, 5, device="cuda", dtype=torch.float32) + w = torch.randn(1, 1, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_multi_channel(self, cuda_conv3d): + """Multiple input and output channels.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", 
dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_with_bias(self, cuda_conv3d): + """Conv3d with bias.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + b = torch.randn(32, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, b, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_batch_size(self, cuda_conv3d): + """Batch size > 1.""" + x = torch.randn(4, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + +class TestConv3dStride: + """Tests with various stride configurations.""" + + def test_stride_2(self, cuda_conv3d): + """Uniform stride of 2.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (2, 2, 2), (1, 1, 1), (1, 1, 1)) + + def test_asymmetric_stride(self, cuda_conv3d): + """Different stride per dimension.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 2, 2), (1, 1, 1), (1, 1, 1)) + + +class TestConv3dPadding: + """Tests with various padding configurations.""" + + def test_no_padding(self, cuda_conv3d): + """Zero padding.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_large_padding(self, cuda_conv3d): + """Padding larger than kernel radius.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", 
dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1)) + + def test_asymmetric_padding(self, cuda_conv3d): + """Different padding per dimension.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 2), (1, 1, 1)) + + +class TestConv3dDilation: + """Tests with dilation.""" + + def test_dilation_2(self, cuda_conv3d): + """Uniform dilation of 2.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (2, 2, 2)) + + def test_asymmetric_dilation(self, cuda_conv3d): + """Different dilation per dimension.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 2, 2), (1, 2, 2)) + + +class TestConv3dKernelSizes: + """Tests with non-3x3x3 kernels.""" + + def test_1x1x1_kernel(self, cuda_conv3d): + """Pointwise 1x1x1 kernel.""" + x = torch.randn(1, 64, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(128, 64, 1, 1, 1, device="cuda", dtype=torch.float32) + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_asymmetric_kernel(self, cuda_conv3d): + """Kernel with different sizes per dimension (e.g. 
1x3x3)."""
+        x = torch.randn(1, 16, 8, 16, 16, device="cuda", dtype=torch.float32)
+        w = torch.randn(32, 16, 1, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 1), (1, 1, 1))
+
+    def test_5x5x5_kernel(self, cuda_conv3d):
+        """Larger 5x5x5 kernel."""
+        x = torch.randn(1, 8, 16, 16, 16, device="cuda", dtype=torch.float32)
+        w = torch.randn(16, 8, 5, 5, 5, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (1, 1, 1))
+
+
+class TestConv3dRealisticShapes:
+    """Tests with shapes resembling real video diffusion models."""
+
+    def test_wan22_shape(self, cuda_conv3d):
+        """Shape from Wan2.2 video diffusion backbone."""
+        x = torch.randn(1, 128, 21, 60, 106, device="cuda", dtype=torch.float32)
+        w = torch.randn(512, 128, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    def test_large_cout(self, cuda_conv3d):
+        """Large output channel count."""
+        x = torch.randn(1, 64, 8, 16, 16, device="cuda", dtype=torch.float32)
+        w = torch.randn(512, 64, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    def test_large_cin(self, cuda_conv3d):
+        """Large input channel count."""
+        x = torch.randn(1, 512, 8, 8, 8, device="cuda", dtype=torch.float32)
+        w = torch.randn(64, 512, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+
+class TestConv3dEdgeCases:
+    """Edge cases for tile boundary handling."""
+
+    def test_m_not_aligned_to_block(self, cuda_conv3d):
+        """M (N*OD*OH*OW) not a multiple of BLOCK_M=64."""
+        # With padding=1, output dims equal input dims: M = 1*5*7*9 = 315, not divisible by 64
+        x = torch.randn(1, 8, 5, 7, 9, device="cuda", dtype=torch.float32)
+        w = torch.randn(16, 8, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    def test_cout_not_aligned_to_block(self, cuda_conv3d):
+        """Cout not a multiple of BLOCK_N=64."""
+        x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32)
+        w = torch.randn(17, 16, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    def test_k_not_aligned_to_block(self, cuda_conv3d):
+        """K (Cin*kD*kH*kW) not a multiple of BLOCK_K."""
+        # Cin=7, kDHW=27, K=189 -- not a multiple of 128 or 256
+        x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32)
+        w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1))
+
+    def test_output_size_1x1x1(self, cuda_conv3d):
+        """Output spatial dims are all 1."""
+        x = torch.randn(1, 16, 3, 3, 3, device="cuda", dtype=torch.float32)
+        w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32)
+        _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1))
+
+    def test_single_output_element(self, cuda_conv3d):
+        """M=1: batch=1, output 1x1x1.
+
+        With only one output element, mean diff == max diff, so the generic
+        helper's mean_diff < scaled_atol * 0.1 check is too tight. Use max diff only.
+        """
+        x = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32)
+        w = torch.randn(1, 4, 3, 3, 3, device="cuda", dtype=torch.float32)
+        ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1))
+        out = cuda_conv3d(
+            x, w, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), quant_act=False
+        )
+        assert out.shape == ref.shape
+        max_diff = (out - ref).abs().max().item()
+        assert max_diff < ATOL, f"Max abs diff {max_diff:.6e} exceeds tolerance {ATOL}"
+
+
+class TestConv3dFP4BlockSize:
+    """Test all FP4 block size configs (BLOCK_K=256 always, FP4_BLOCK_SIZE varies).
+
+    Non-quantized path ignores FP4_BLOCK_SIZE, so all should match cuDNN.
+ """ + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_non_quant_all_block_sizes(self, cuda_conv3d, fp4_block_size): + """Non-quant conv should match cuDNN regardless of fp4_block_size.""" + x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + quant_act=False, + fp4_block_size=fp4_block_size, + ) + assert out.shape == ref.shape + assert (out - ref).abs().max().item() < ATOL + + +class TestConv3dDeterminism: + """Verify deterministic output across repeated calls.""" + + def test_deterministic(self, cuda_conv3d): + """Repeated calls produce identical output.""" + torch.manual_seed(123) + x = torch.randn(1, 32, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(64, 32, 3, 3, 3, device="cuda", dtype=torch.float32) + out1 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + out2 = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert torch.equal(out1, out2), "Kernel is not deterministic" + + +# ============================================================================= +# FP4 Quantized Conv3D Tests (fused activation quantization) +# ============================================================================= + + +@pytest.fixture(scope="module") +def cuda_fp4(): + """Import and return the CUDA FP4 fake quant function.""" + from modelopt.torch.kernels.conv.implicit_gemm_cuda import fp4_fake_quant + + return fp4_fake_quant + + +class TestConv3dFP4QuantBlockSizes: + """Test fused FP4 activation quantization with all supported block sizes. + + The kernel applies blockwise FP4 quantization to the im2col'd activation tiles + along the K dimension. 
We verify correctness by comparing the fused kernel output + against an unfused reference: fp4_fake_quant(im2col) @ fp4_fake_quant(weight). + """ + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_runs_all_block_sizes(self, cuda_conv3d, fp4_block_size): + """All FP4 block sizes should run without errors and produce valid output.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 32, 8, 8, 8) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + assert out.abs().max() > 0, "Output is all zeros" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_deterministic(self, cuda_conv3d, fp4_block_size): + """Quantized conv should be deterministic for all block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + kwargs = { + "stride": (1, 1, 1), + "padding": (1, 1, 1), + "dilation": (1, 1, 1), + "act_amax": act_amax, + "quant_act": True, + "fp4_block_size": fp4_block_size, + } + out1 = cuda_conv3d(x, w, **kwargs) + out2 = cuda_conv3d(x, w, **kwargs) + assert torch.equal(out1, out2), f"Non-deterministic for fp4_block_size={fp4_block_size}" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_vs_unfused_reference(self, cuda_conv3d, cuda_fp4, fp4_block_size): + """Compare fused kernel vs unfused: fp4(im2col) @ fp4(weight). 
+ + Uses a shape where K is a multiple of 256 so all K-tiles are full + and block boundaries align perfectly between fused and unfused paths. + """ + torch.manual_seed(123) + # K = Cin * kD * kH * kW. Choose Cin so K is a multiple of 256. + # Cin=256, k=1x1x1 -> K=256 (exactly 1 full K-tile) + cin, cout = 256, 64 + x = torch.randn(1, cin, 4, 4, 4, device="cuda", dtype=torch.float32) + w = torch.randn(cout, cin, 1, 1, 1, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + w_amax = w.abs().max().unsqueeze(0) + + # Unfused reference: + # 1. Build im2col matrix (for 1x1x1 kernel, it's just reshape) + n, c, d, h, w_dim = x.shape + im2col = x.permute(0, 2, 3, 4, 1).reshape(-1, cin) # [M, K] + + # 2. FP4 fake-quant both matrices along K with the same block_size + im2col_q = cuda_fp4(im2col, act_amax, fp4_block_size) + w_flat = w.reshape(cout, cin).transpose(0, 1).contiguous() # [K, Cout] + w_flat_q = cuda_fp4(w_flat, w_amax, fp4_block_size) + + # 3. Matmul (in BF16 to match kernel's WMMA path) + ref_out = (im2col_q.bfloat16() @ w_flat_q.bfloat16()).float() + ref_out = ref_out.view(n, d, h, w_dim, cout).permute(0, 4, 1, 2, 3) + + # Note: the fused kernel does NOT quantize weights — weights are passed as-is. + # So for a proper comparison we need the fused kernel with pre-quantized weights. + fused_out_preq = cuda_conv3d( + x, + w_flat_q.transpose(0, 1).reshape(cout, cin, 1, 1, 1), + stride=(1, 1, 1), + padding=(0, 0, 0), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + + # The fused kernel and unfused reference should match closely. + # Differences come from BF16 accumulation order (WMMA 16x16x16 tiles vs flat matmul). 
+ max_diff = (fused_out_preq - ref_out).abs().max().item() + mean_diff = (fused_out_preq - ref_out).abs().mean().item() + # Scale tolerance with K + scaled_atol = ATOL * (cin / 1000.0) ** 0.5 + assert max_diff < scaled_atol, ( + f"fp4_block_size={fp4_block_size}: fused vs unfused max diff {max_diff:.4f} " + f"exceeds tolerance {scaled_atol:.4f}" + ) + assert mean_diff < scaled_atol * 0.1, ( + f"fp4_block_size={fp4_block_size}: mean diff {mean_diff:.6e} too high" + ) + + def test_smaller_block_less_error(self, cuda_conv3d): + """Smaller FP4 block sizes should generally produce lower quantization error. + + Finer-grained blocks capture local ranges better, reducing quant error vs cuDNN. + Test monotonicity on a medium config: error(16) <= error(64) <= error(256) (with 1.2x slack). + """ + torch.manual_seed(42) + + # Medium K=1728: Cin=64, 3x3x3 kernel + cin, cout = 64, 64 + x = torch.randn(1, cin, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(cout, cin, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)) + + block_sizes = [16, 32, 64, 128, 256] + errors = {} + for bs in block_sizes: + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=bs, + ) + errors[bs] = (out - ref).abs().mean().item() + + # Monotonicity: smaller blocks should have equal or lower error + for smaller, larger in [(16, 64), (16, 256), (32, 256), (64, 256)]: + assert errors[smaller] <= errors[larger] * 1.2, ( + f"Expected error({smaller})={errors[smaller]:.6f} <= " + f"error({larger})={errors[larger]:.6f} * 1.2" + ) + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_with_bias(self, cuda_conv3d, fp4_block_size): + """FP4 quantized conv with bias for all block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda", 
dtype=torch.float32) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=torch.float32) + b = torch.randn(32, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + bias=b, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 32, 8, 8, 8) + assert not torch.isnan(out).any() + # Bias should shift output values + out_no_bias = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert not torch.equal(out, out_no_bias), "Bias had no effect" + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_k_not_aligned(self, cuda_conv3d, fp4_block_size): + """FP4 quant with K not aligned to BLOCK_K or fp4_block_size. + + K = Cin * kDHW = 7 * 27 = 189. The last K-tile has partial data (zeros padded). 
+ """ + torch.manual_seed(42) + x = torch.randn(1, 7, 8, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(16, 7, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 16, 8, 8, 8) + assert not torch.isnan(out).any() + assert out.abs().max() > 0 + + @pytest.mark.parametrize("fp4_block_size", [16, 32, 64, 128, 256]) + def test_quant_realistic_shape(self, cuda_conv3d, fp4_block_size): + """Realistic video diffusion shape with all FP4 block sizes.""" + torch.manual_seed(42) + x = torch.randn(1, 128, 5, 8, 8, device="cuda", dtype=torch.float32) + w = torch.randn(256, 128, 3, 3, 3, device="cuda", dtype=torch.float32) + act_amax = x.abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + assert out.shape == (1, 256, 5, 8, 8) + assert not torch.isnan(out).any() + assert out.abs().max() > 0 + + +# ============================================================================= +# FP4 Fake Quantization Tests +# ============================================================================= + + +def _py_fp4_fake_quant_ref(x_flat, global_amax, block_size): + """Pure Python reference for FP4 fake quant (no BF16 rounding). + + This implements the exact same algorithm as the CUDA kernel: + 1. Compute global_scale = global_amax / (6 * 448) + 2. Per block: block_max = max(|x|), scale = fp8_e4m3_roundtrip(block_max / (6 * global_scale)) * global_scale + 3. Quantize each element to nearest E2M1 level, then dequantize. 
+ """ + import math + + # E2M1 quantization levels: {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # Boundaries (midpoints): <=0.25->0, <0.75->0.5, <=1.25->1, <1.75->1.5, <=2.5->2, <3.5->3, <=5->4, >5->6 + def quantize_e2m1(scaled_abs): + if scaled_abs <= 0.25: + return 0.0 + elif scaled_abs < 0.75: + return 0.5 + elif scaled_abs <= 1.25: + return 1.0 + elif scaled_abs < 1.75: + return 1.5 + elif scaled_abs <= 2.5: + return 2.0 + elif scaled_abs < 3.5: + return 3.0 + elif scaled_abs <= 5.0: + return 4.0 + else: + return 6.0 + + def fp8_e4m3_roundtrip(val): + """Simulate FP8 E4M3 round-trip in Python.""" + if val == 0.0: + return 0.0 + sign = 1.0 if val >= 0 else -1.0 + val = abs(val) + # FP8 E4M3: bias=7, 3 mantissa bits, max=448, no inf/nan + if val > 448.0: + return sign * 448.0 + # Compute exponent + exp = math.floor(math.log2(val)) + exp = max(exp, -6) # min normal exponent for E4M3 + # Compute mantissa (3 bits) + mantissa = val / (2.0**exp) # 1.xxx + mantissa_bits = round((mantissa - 1.0) * 8.0) # 3 bits + if mantissa_bits > 7: + mantissa_bits = 0 + exp += 1 + if exp > 8: + return sign * 448.0 + # Reconstruct + result = (1.0 + mantissa_bits / 8.0) * (2.0**exp) + return sign * result + + global_scale = float(global_amax) / (6.0 * 448.0) + x_np = x_flat.cpu().float().numpy().copy() + num_blocks = len(x_np) // block_size + + for b in range(num_blocks): + block = x_np[b * block_size : (b + 1) * block_size] + block_max = float(max(abs(v) for v in block)) + + # Scale quantization + scaled = block_max / (6.0 * global_scale) + scaled = min(scaled, 448.0) + quantized_scale = fp8_e4m3_roundtrip(scaled) * global_scale + if quantized_scale < 1e-5: + quantized_scale = 1.0 + inv_scale = 1.0 / quantized_scale + + for i in range(block_size): + val = block[i] + sign = 1.0 if val >= 0 else -1.0 + q = quantize_e2m1(abs(val) * inv_scale) + x_np[b * block_size + i] = sign * q * quantized_scale + + return torch.tensor(x_np, device=x_flat.device) + + +class TestFP4FakeQuantValues: + """Test FP4 
fake quant with known E2M1 table values.""" + + def test_exact_e2m1_values(self, cuda_fp4): + """E2M1 representable values should round-trip exactly (when scale=1 via amax=6*448).""" + # With global_amax = 6*448 = 2688, global_scale = 1.0 + # A single-block input with max=6 -> block_max=6, scaled=6/(6*1)=1.0 + # fp8_e4m3(1.0)=1.0, scale = 1.0*1.0 = 1.0 + block_size = 8 + vals = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], device="cuda", dtype=torch.float32) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(vals, amax, block_size) + assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}" + + def test_exact_e2m1_negative(self, cuda_fp4): + """Negative E2M1 values should also round-trip.""" + block_size = 8 + vals = torch.tensor([0, -0.5, -1, -1.5, -2, -3, -4, -6], device="cuda", dtype=torch.float32) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(vals, amax, block_size) + assert torch.allclose(out, vals, atol=1e-5), f"Got {out} vs expected {vals}" + + def test_below_boundary(self, cuda_fp4): + """Values slightly below E2M1 boundaries should quantize down.""" + block_size = 8 + # Boundaries: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + # Slightly below -> quantize to lower level + inp = torch.tensor( + [0.15, 0.65, 1.15, 1.65, 2.4, 3.4, 4.9, 6.0], device="cuda", dtype=torch.float32 + ) + expected = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device="cuda", dtype=torch.float32 + ) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" + + def test_above_boundary(self, cuda_fp4): + """Values slightly above E2M1 boundaries should quantize up.""" + block_size = 8 + inp = torch.tensor( + [0.35, 0.85, 1.35, 1.85, 2.6, 3.6, 5.1, 6.0], device="cuda", dtype=torch.float32 + ) + expected = torch.tensor( + [0.5, 1.0, 1.5, 2.0, 
3.0, 4.0, 6.0, 6.0], device="cuda", dtype=torch.float32 + ) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" + + def test_mixed_signs(self, cuda_fp4): + """Mixed positive/negative values.""" + block_size = 8 + inp = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32) + expected = torch.tensor([-6, -3, -1, 0, 0.5, 2, 4, 6], device="cuda", dtype=torch.float32) + amax = torch.tensor([6.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.allclose(out, expected, atol=1e-5), f"Got {out} vs expected {expected}" + + +class TestFP4FakeQuantScale: + """Test FP4 scale computation and FP8 round-trip.""" + + def test_scale_factor(self, cuda_fp4): + """When amax != 6*448, scale should adjust values proportionally.""" + block_size = 8 + # global_amax = 12*448 = 5376, global_scale = 2.0 + # Input block max = 12 -> scaled = 12/(6*2) = 1.0 -> fp8(1.0) = 1.0 -> scale = 2.0 + # So input 12 -> |12|/2 = 6.0 -> q=6 -> 6*2 = 12 + inp = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", dtype=torch.float32) + amax = torch.tensor([12.0 * 448.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + # Expected: each val/2.0 -> quantize to E2M1 -> * 2.0 + expected = torch.tensor([0, 1, 2, 3, 4, 6, 8, 12], device="cuda", dtype=torch.float32) + assert torch.allclose(out, expected, atol=1e-4), f"Got {out} vs expected {expected}" + + def test_zero_block(self, cuda_fp4): + """All-zero block should produce all zeros.""" + block_size = 16 + inp = torch.zeros(block_size, device="cuda", dtype=torch.float32) + amax = torch.tensor([1.0], device="cuda", dtype=torch.float32) + out = cuda_fp4(inp, amax, block_size) + assert torch.equal(out, inp) + + def test_multiple_blocks(self, cuda_fp4): + """Multiple blocks with different ranges.""" + block_size = 8 
+ # Block 0: small values, Block 1: large values + block0 = torch.tensor([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], device="cuda") + block1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 6], device="cuda") + inp = torch.cat([block0, block1]) + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size) + # Each block should be independently quantized + assert out.shape == inp.shape + # Block 1 exact values should be close to E2M1 levels + assert out[8:].abs().max() <= 6.0 + 1e-5 + + +class TestFP4FakeQuantBlockSizes: + """Test different block sizes.""" + + @pytest.mark.parametrize("block_size", [8, 16, 32, 64, 128, 256]) + def test_block_sizes(self, cuda_fp4, block_size): + """FP4 quant should work for various block sizes.""" + torch.manual_seed(42) + num_blocks = 4 + inp = torch.randn(num_blocks * block_size, device="cuda", dtype=torch.float32) * 5 + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size) + assert out.shape == inp.shape + # Output should not be all zeros for non-zero input + assert out.abs().max() > 0 + # Output should be <= max possible after quant + assert out.abs().max() <= inp.abs().max() * 1.5 # generous bound + + +class TestFP4FakeQuantVsReference: + """Compare CUDA FP4 fake quant against Python reference implementation.""" + + @pytest.mark.parametrize("block_size", [8, 16, 32]) + def test_vs_python_ref(self, cuda_fp4, block_size): + """CUDA kernel should match the Python reference exactly.""" + torch.manual_seed(123) + num_blocks = 8 + inp = torch.randn(num_blocks * block_size, device="cuda") * 10 + amax = inp.abs().max().unsqueeze(0) + + cuda_out = cuda_fp4(inp, amax, block_size) + ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size) + + assert torch.allclose(cuda_out, ref_out, atol=1e-5), ( + f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}" + ) + + @pytest.mark.parametrize("block_size", [16, 32]) + def test_vs_python_ref_large(self, cuda_fp4, block_size): + """Larger tensor test 
against Python reference.""" + torch.manual_seed(456) + num_blocks = 64 + inp = torch.randn(num_blocks * block_size, device="cuda") * 20 + amax = inp.abs().max().unsqueeze(0) + + cuda_out = cuda_fp4(inp, amax, block_size) + ref_out = _py_fp4_fake_quant_ref(inp, amax, block_size) + + assert torch.allclose(cuda_out, ref_out, atol=1e-4), ( + f"CUDA vs Python ref max diff: {(cuda_out - ref_out).abs().max().item():.6e}" + ) + + +class TestFP4FakeQuantVsTriton: + """Compare CUDA FP4 fake quant against Triton fp4_fake_quant_block reference.""" + + @requires_triton_fp4 + @pytest.mark.parametrize("block_size", [16, 32, 64]) + @pytest.mark.parametrize("num_blocks", [4, 16, 64]) + def test_vs_triton(self, cuda_fp4, block_size, num_blocks): + """CUDA kernel should match the Triton fp4_fake_quant_block.""" + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + torch.manual_seed(42) + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + cuda_out = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + triton_out = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + assert torch.allclose(cuda_out, triton_out, atol=1e-5), ( + f"CUDA vs Triton max diff: {(cuda_out - triton_out).abs().max().item():.6e}\n" + f"Mean diff: {(cuda_out - triton_out).abs().mean().item():.6e}" + ) + + +class TestFP4FakeQuantDeterminism: + """Verify FP4 quant is deterministic.""" + + def test_deterministic(self, cuda_fp4): + """Repeated calls produce identical output.""" + torch.manual_seed(99) + inp = torch.randn(256, device="cuda") * 5 + amax = inp.abs().max().unsqueeze(0) + out1 = cuda_fp4(inp, amax, 16) + out2 = cuda_fp4(inp, amax, 16) + assert torch.equal(out1, out2), "FP4 fake quant is not deterministic" + + +# ============================================================================= +# Cross-validation: experimental FP4 vs modelopt FP4 
implementations +# ============================================================================= + + +def _modelopt_cuda_ext_mx_available(): + """Check if the modelopt CUDA MX extension is available.""" + try: + from modelopt.torch.quantization.extensions import get_cuda_ext_mx + + return get_cuda_ext_mx() is not None + except Exception: + return False + + +def _modelopt_dynamic_block_quantize_available(): + """Check if dynamic_block_quantize_op is available.""" + try: + from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op + + return dynamic_block_quantize_op is not None + except Exception: + return False + + +requires_cuda_ext_mx = pytest.mark.skipif( + not _modelopt_cuda_ext_mx_available(), + reason="modelopt cuda_ext_mx not available", +) + +requires_dynamic_block_quantize = pytest.mark.skipif( + not _modelopt_dynamic_block_quantize_available(), + reason="modelopt dynamic_block_quantize_op not available", +) + + +class TestFP4FakeQuantVsModelopt: + """Compare experimental CUDA FP4 fake quant against all modelopt FP4 implementations. + + This ensures the standalone FP4 kernel produces the same results as the + other modelopt quantization paths: + 1. Triton fp4_fake_quant_block (Hopper+ dynamic blockwise) + 2. cuda_ext_mx.fused_amax_convert (CUDA extension fallback) + 3. 
dynamic_block_quantize_op (high-level API that dispatches to either) + """ + + @requires_triton_fp4 + @pytest.mark.parametrize("block_size", [16, 32, 64]) + @pytest.mark.parametrize("seed", [42, 123, 999]) + def test_vs_triton_fp4_fake_quant_block(self, cuda_fp4, block_size, seed): + """Compare against modelopt Triton fp4_fake_quant_block.""" + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + torch.manual_seed(seed) + num_blocks = 16 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"experimental vs modelopt Triton max diff: {(ours - theirs).abs().max().item():.6e}" + ) + + @requires_cuda_ext_mx + @pytest.mark.parametrize("block_size", [16, 32]) + @pytest.mark.parametrize("seed", [42, 123]) + def test_vs_cuda_ext_mx(self, cuda_fp4, block_size, seed): + """Compare against modelopt cuda_ext_mx.fused_amax_convert.""" + from modelopt.torch.quantization.extensions import get_cuda_ext_mx + from modelopt.torch.quantization.tensor_quant import mx_format_map + + cuda_ext_mx = get_cuda_ext_mx() + torch.manual_seed(seed) + num_blocks = 16 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = cuda_ext_mx.fused_amax_convert( + x, + block_size, + getattr(cuda_ext_mx.Types, mx_format_map[(2, 1)]), + getattr(cuda_ext_mx.Types, mx_format_map[(4, 3)]), + global_amax, + ) + + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"experimental vs modelopt cuda_ext_mx max diff: " + f"{(ours - theirs).abs().max().item():.6e}" + ) + + @requires_dynamic_block_quantize + @pytest.mark.parametrize("seed", [42, 123, 999]) + 
def test_vs_dynamic_block_quantize_op(self, cuda_fp4, seed): + """Compare against modelopt dynamic_block_quantize_op (high-level API). + + This is the function used by the actual quantization pipeline with + num_bits=4 (E2M1) and scale_bits=8 (E4M3). + Note: dynamic_block_quantize_op dispatches to Triton with default block_size=16. + """ + from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op + + block_size = 16 # dynamic_block_quantize_op uses block_size=16 for Triton path + torch.manual_seed(seed) + num_blocks = 16 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 + global_amax = x.abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = dynamic_block_quantize_op( + x, + block_size, + global_amax, + num_bits=4, # total bits = 1 sign + 2 exp + 1 mantissa + exponent_bits=2, + scale_num_bits=8, # FP8 E4M3 for scales + scale_exponent_bits=4, + ) + + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"experimental vs modelopt dynamic_block_quantize_op max diff: " + f"{(ours - theirs).abs().max().item():.6e}" + ) + + @requires_triton_fp4 + def test_vs_triton_realistic_shape(self, cuda_fp4): + """Realistic activation shape from a Conv3D layer (flattened).""" + torch.manual_seed(42) + block_size = 16 + # Simulate a large tensor: 256 blocks of 16 elements + # (tile_rows must be power-of-2 for Triton block_ptr) + num_blocks = 256 + x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 5 + global_amax = x.abs().max() + + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=16, + tile_cols=block_size, + ) + + max_diff = (ours - theirs).abs().max().item() + mean_diff = (ours - theirs).abs().mean().item() + assert torch.allclose(ours, theirs, atol=1e-5), ( + f"Realistic shape: experimental vs 
Triton max diff: {max_diff:.6e}, " + f"mean diff: {mean_diff:.6e}" + ) + + @requires_triton_fp4 + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) + def test_vs_triton_input_dtypes(self, cuda_fp4, dtype): + """Test that our kernel handles different input dtypes correctly. + + Our kernel casts to float32 internally, so the result should match + Triton's output when both receive the same dtype input. + """ + from modelopt.torch.quantization.triton import fp4_fake_quant_block + + torch.manual_seed(42) + block_size = 16 + num_blocks = 8 + x = (torch.randn(num_blocks, block_size, device="cuda") * 5).to(dtype) + global_amax = x.float().abs().max() + + ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) + theirs = fp4_fake_quant_block( + x, + global_amax=global_amax, + block_size=block_size, + tile_rows=num_blocks, + tile_cols=block_size, + ) + + # Both should return the input dtype + assert ours.dtype == dtype + assert theirs.dtype == dtype + + # Compare in float32 + max_diff = (ours.float() - theirs.float()).abs().max().item() + # BF16/FP16 input rounding may cause small diffs + tol = 1e-2 if dtype != torch.float32 else 1e-5 + assert max_diff < tol, f"dtype={dtype}: experimental vs Triton max diff: {max_diff:.6e}" From e061a387813c5bb43a6b5e9c5f097312aa057997 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 10 Apr 2026 22:25:56 +0000 Subject: [PATCH 20/25] Update the readme and the test case Signed-off-by: Jingyu Xin --- modelopt/torch/kernels/conv/README.md | 4 +- .../kernels/test_implicit_gemm.py | 481 ++++++++++++++++++ 2 files changed, 483 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/kernels/conv/README.md b/modelopt/torch/kernels/conv/README.md index 417b03c811..9221c73191 100644 --- a/modelopt/torch/kernels/conv/README.md +++ b/modelopt/torch/kernels/conv/README.md @@ -14,7 +14,7 @@ This kernel is integrated into `modelopt.torch.quantization` via `_QuantConv3d` | Stage | Precision | |-------|-----------| -| 
Input / output tensors | FP32 (user-facing) | +| Input / output tensors | FP32, FP16, or BF16 (dtype is preserved) | | Internal compute | BF16 via WMMA m16n16k16 tensor cores | | Accumulation | FP32 | | FP4 activation quantization | E2M1 values, FP8 E4M3 scales | @@ -105,7 +105,7 @@ Standalone FP4 (E2M1) blockwise fake quantization with FP8 E4M3 scale quantizati ```bash # Run tests (requires GPU) -python -m pytest modelopt/torch/kernels/conv/test_implicit_gemm.py -v +python -m pytest tests/gpu/torch/quantization/kernels/test_implicit_gemm.py -v # Run benchmarks python -m modelopt.torch.kernels.conv.bench_implicit_gemm # default shapes diff --git a/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py b/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py index f3ac8439ea..26dd995f56 100644 --- a/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py +++ b/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py @@ -1009,3 +1009,484 @@ def test_vs_triton_input_dtypes(self, cuda_fp4, dtype): # BF16/FP16 input rounding may cause small diffs tol = 1e-2 if dtype != torch.float32 else 1e-5 assert max_diff < tol, f"dtype={dtype}: experimental vs Triton max diff: {max_diff:.6e}" + + +# ============================================================================= +# Input Validation / Error Path Tests +# ============================================================================= + + +class TestConv3dInputValidation: + """Verify error paths raise appropriate exceptions.""" + + def test_invalid_fp4_block_size(self, cuda_conv3d): + """fp4_block_size not in {16, 32, 64, 128, 256} should raise ValueError.""" + x = torch.randn(1, 4, 4, 4, 4, device="cuda") + w = torch.randn(8, 4, 3, 3, 3, device="cuda") + with pytest.raises(ValueError, match="fp4_block_size"): + cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), fp4_block_size=7) + + def test_non_5d_input(self, cuda_conv3d): + """Non-5D tensors should raise ValueError.""" + x = torch.randn(1, 4, 4, 4, 
device="cuda") # 4D + w = torch.randn(8, 4, 3, 3, 3, device="cuda") + with pytest.raises(ValueError, match="5D"): + cuda_conv3d(x, w) + + def test_non_5d_weight(self, cuda_conv3d): + """Non-5D weight should raise ValueError.""" + x = torch.randn(1, 4, 4, 4, 4, device="cuda") + w = torch.randn(8, 4, 3, 3, device="cuda") # 4D + with pytest.raises(ValueError, match="5D"): + cuda_conv3d(x, w) + + def test_grouped_conv_error(self, cuda_conv3d): + """Mismatched Cin (groups > 1) should raise ValueError.""" + x = torch.randn(1, 8, 4, 4, 4, device="cuda") + w = torch.randn(8, 4, 3, 3, 3, device="cuda") # Cin=4 != x.Cin=8 + with pytest.raises(ValueError, match="Grouped convolution"): + cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) + + def test_quant_act_without_amax(self, cuda_conv3d): + """quant_act=True without act_amax should raise ValueError.""" + x = torch.randn(1, 4, 4, 4, 4, device="cuda") + w = torch.randn(8, 4, 3, 3, 3, device="cuda") + with pytest.raises(ValueError, match="act_amax"): + cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=True, act_amax=None) + + def test_fp4_numel_not_divisible(self, cuda_fp4): + """fp4_fake_quant should error when numel is not divisible by block_size.""" + inp = torch.randn(17, device="cuda") + amax = torch.tensor([1.0], device="cuda") + with pytest.raises(AssertionError, match="divisible"): + cuda_fp4(inp, amax, block_size=16) + + +# ============================================================================= +# Input Dtype Tests +# ============================================================================= + + +class TestConv3dInputDtypes: + """Verify conv3d works with non-float32 inputs and preserves output dtype.""" + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_dtype_preservation(self, cuda_conv3d, dtype): + """Output dtype should match input dtype.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=dtype) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", 
dtype=dtype) + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert out.dtype == dtype, f"Expected {dtype}, got {out.dtype}" + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_dtype_correctness(self, cuda_conv3d, dtype): + """Non-float32 inputs should produce correct results (vs F.conv3d in float32).""" + torch.manual_seed(42) + x_fp32 = torch.randn(1, 16, 8, 8, 8, device="cuda") + w_fp32 = torch.randn(32, 16, 3, 3, 3, device="cuda") + x = x_fp32.to(dtype) + w = w_fp32.to(dtype) + + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + ref = F.conv3d(x_fp32, w_fp32, stride=(1, 1, 1), padding=(1, 1, 1)) + + # Both BF16 input rounding and internal BF16 WMMA contribute to error + max_diff = (out.float() - ref).abs().max().item() + k_size = 16 * 27 + scaled_atol = ATOL * (k_size / 1000.0) ** 0.5 * 2 # extra slack for input rounding + assert max_diff < scaled_atol, f"dtype={dtype}: max diff {max_diff:.4f} > {scaled_atol:.4f}" + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_dtype_quant_path(self, cuda_conv3d, dtype): + """FP4 quantized path should also work with non-float32 inputs.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda", dtype=dtype) + w = torch.randn(32, 16, 3, 3, 3, device="cuda", dtype=dtype) + act_amax = x.float().abs().max().unsqueeze(0) + + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=16, + ) + assert out.dtype == dtype + assert not torch.isnan(out).any() + assert out.abs().max() > 0 + + +# ============================================================================= +# Non-Contiguous Input Tests +# ============================================================================= + + +class TestConv3dNonContiguous: + """Verify kernel handles non-contiguous tensors (via internal .contiguous() calls).""" + + def test_non_contiguous_input(self, cuda_conv3d): 
+ """Permuted (non-contiguous) input should produce correct results.""" + torch.manual_seed(42) + # Create non-contiguous tensor via permute + permute back + x_base = torch.randn(1, 8, 8, 8, 16, device="cuda") + x = x_base.permute(0, 4, 1, 2, 3) # [1, 16, 8, 8, 8] but non-contiguous + assert not x.is_contiguous() + + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + x_contig = x.contiguous() + + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + ref = cuda_conv3d(x_contig, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert torch.equal(out, ref), "Non-contiguous input produced different results" + + def test_non_contiguous_weight(self, cuda_conv3d): + """Transposed (non-contiguous) weight should produce correct results.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda") + # Create non-contiguous weight + w_base = torch.randn(16, 32, 3, 3, 3, device="cuda") + w = w_base.transpose(0, 1) # [32, 16, ...] but non-contiguous + assert not w.is_contiguous() + + w_contig = w.contiguous() + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + ref = cuda_conv3d(x, w_contig, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert torch.equal(out, ref), "Non-contiguous weight produced different results" + + +# ============================================================================= +# Combined Conv Parameter Tests +# ============================================================================= + + +class TestConv3dCombinedParams: + """Test combinations of stride + dilation + padding that were never combined.""" + + def test_stride_and_dilation(self, cuda_conv3d): + """Stride > 1 and dilation > 1 simultaneously.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, None, (2, 2, 2), (1, 1, 1), (2, 2, 2)) + + def test_asymmetric_stride_and_padding(self, cuda_conv3d): + """Asymmetric 
stride with asymmetric padding.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 2, 2), (0, 1, 2), (1, 1, 1)) + + def test_all_non_default(self, cuda_conv3d): + """Non-default stride + padding + dilation all at once.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 2, 1), (1, 0, 1), (1, 2, 1)) + + def test_bias_with_stride(self, cuda_conv3d): + """Bias with non-default stride.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + b = torch.randn(32, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, b, (2, 2, 2), (1, 1, 1), (1, 1, 1)) + + def test_bias_with_dilation(self, cuda_conv3d): + """Bias with non-default dilation.""" + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + b = torch.randn(32, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, b, (1, 1, 1), (2, 2, 2), (2, 2, 2)) + + +# ============================================================================= +# FP4 Quantized Path: Advanced Conv Params +# ============================================================================= + + +def _run_quant_smoke_test(cuda_conv3d, x, w, bias, stride, padding, dilation, fp4_block_size=16): + """Helper: run FP4-quantized conv and verify basic sanity.""" + act_amax = x.abs().max().unsqueeze(0) + out = cuda_conv3d( + x, + w, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + act_amax=act_amax, + quant_act=True, + fp4_block_size=fp4_block_size, + ) + ref = F.conv3d(x, w, bias=bias, stride=stride, padding=padding, dilation=dilation) + assert out.shape == ref.shape, f"Shape mismatch: {out.shape} vs {ref.shape}" + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + # Quantized 
output should be in a reasonable range relative to reference + if ref.abs().max() > 0: + ratio = out.abs().max().item() / ref.abs().max().item() + assert 0.01 < ratio < 100, f"Output magnitude ratio {ratio:.2f} is unreasonable" + return out + + +class TestConv3dFP4QuantAdvanced: + """FP4 quantized path with non-trivial stride, dilation, and kernel shapes.""" + + def test_quant_with_stride(self, cuda_conv3d): + """FP4 quant with stride=(2,2,2).""" + torch.manual_seed(42) + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + _run_quant_smoke_test(cuda_conv3d, x, w, None, (2, 2, 2), (1, 1, 1), (1, 1, 1)) + + def test_quant_with_dilation(self, cuda_conv3d): + """FP4 quant with dilation=(2,2,2).""" + torch.manual_seed(42) + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + _run_quant_smoke_test(cuda_conv3d, x, w, None, (1, 1, 1), (2, 2, 2), (2, 2, 2)) + + def test_quant_with_asymmetric_kernel(self, cuda_conv3d): + """FP4 quant with 1x3x3 kernel.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 16, 16, device="cuda") + w = torch.randn(32, 16, 1, 3, 3, device="cuda") + _run_quant_smoke_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 1, 1), (1, 1, 1)) + + def test_quant_with_stride_and_dilation(self, cuda_conv3d): + """FP4 quant with both stride>1 and dilation>1.""" + torch.manual_seed(42) + x = torch.randn(1, 16, 16, 16, 16, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + _run_quant_smoke_test(cuda_conv3d, x, w, None, (2, 2, 2), (1, 1, 1), (2, 2, 2)) + + def test_quant_with_no_padding(self, cuda_conv3d): + """FP4 quant with padding=(0,0,0).""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + _run_quant_smoke_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_quant_bias_reference(self, cuda_conv3d, cuda_fp4): + """FP4 quant + bias: verify bias is added 
correctly by comparing with/without. + + The difference between bias and no-bias output should equal the bias broadcast. + """ + torch.manual_seed(42) + cin, cout = 256, 64 + x = torch.randn(1, cin, 4, 4, 4, device="cuda") + w = torch.randn(cout, cin, 1, 1, 1, device="cuda") + b = torch.randn(cout, device="cuda") + act_amax = x.abs().max().unsqueeze(0) + + kwargs = { + "stride": (1, 1, 1), + "padding": (0, 0, 0), + "dilation": (1, 1, 1), + "act_amax": act_amax, + "quant_act": True, + "fp4_block_size": 16, + } + out_bias = cuda_conv3d(x, w, bias=b, **kwargs) + out_no_bias = cuda_conv3d(x, w, bias=None, **kwargs) + + # Difference should be the bias broadcast over spatial dims + diff = out_bias - out_no_bias # [1, Cout, D, H, W] + expected_bias = b.view(1, -1, 1, 1, 1).expand_as(diff) + assert torch.allclose(diff, expected_bias, atol=1e-5), ( + f"Bias diff mismatch: max {(diff - expected_bias).abs().max().item():.6e}" + ) + + +# ============================================================================= +# Zero / Degenerate Input Tests +# ============================================================================= + + +class TestConv3dZeroInputs: + """Tests with zero and degenerate inputs.""" + + def test_zero_input(self, cuda_conv3d): + """Zero activation tensor should produce zero (or bias-only) output.""" + x = torch.zeros(1, 16, 8, 8, 8, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert torch.allclose(out, ref, atol=1e-5), f"Max diff: {(out - ref).abs().max().item()}" + + def test_zero_weight(self, cuda_conv3d): + """Zero weight tensor should produce zero output.""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda") + w = torch.zeros(32, 16, 3, 3, 3, device="cuda") + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert torch.allclose(out, torch.zeros_like(out), 
atol=1e-5) + + def test_zero_input_quant(self, cuda_conv3d): + """Zero input with FP4 quant should not produce NaN.""" + x = torch.zeros(1, 16, 8, 8, 8, device="cuda") + w = torch.randn(32, 16, 3, 3, 3, device="cuda") + # act_amax=0 is a tricky edge case — the kernel's scale guard should handle it + act_amax = torch.tensor([1e-10], device="cuda") # near-zero but not exactly 0 + out = cuda_conv3d( + x, + w, + stride=(1, 1, 1), + padding=(1, 1, 1), + act_amax=act_amax, + quant_act=True, + fp4_block_size=16, + ) + assert not torch.isnan(out).any(), "Zero input with quant produced NaN" + + +# ============================================================================= +# Numerical Stability Tests +# ============================================================================= + + +class TestConv3dNumericalStability: + """Test with extreme value ranges.""" + + def test_large_values(self, cuda_conv3d): + """Large input values (randn * 100).""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda") * 100 + w = torch.randn(32, 16, 3, 3, 3, device="cuda") * 100 + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + # With large values, BF16 rounding error scales proportionally + rel_err = (out - ref).abs().max().item() / ref.abs().max().item() + assert rel_err < 0.05, f"Relative error {rel_err:.4f} too high for large values" + + def test_small_values(self, cuda_conv3d): + """Small input values (randn * 1e-3).""" + torch.manual_seed(42) + x = torch.randn(1, 16, 8, 8, 8, device="cuda") * 1e-3 + w = torch.randn(32, 16, 3, 3, 3, device="cuda") * 1e-3 + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + assert out.shape == ref.shape + # Small values: absolute error is small, relative error may be larger due to BF16 + max_diff = (out - ref).abs().max().item() + assert max_diff < 1e-5, 
f"Max diff {max_diff:.6e} for small values" + + def test_uniform_input(self, cuda_conv3d): + """Uniform input (all ones) — exposes accumulation patterns.""" + x = torch.ones(1, 16, 8, 8, 8, device="cuda") + w = torch.ones(32, 16, 3, 3, 3, device="cuda") + ref = F.conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1)) + out = cuda_conv3d(x, w, stride=(1, 1, 1), padding=(1, 1, 1), quant_act=False) + k_size = 16 * 27 + scaled_atol = ATOL * (k_size / 1000.0) ** 0.5 + max_diff = (out - ref).abs().max().item() + assert max_diff < scaled_atol, f"Uniform input: max diff {max_diff:.4f}" + + +# ============================================================================= +# Exact Block Boundary Tests +# ============================================================================= + + +class TestConv3dExactBoundaries: + """Shapes that land exactly on BLOCK_M=64, BLOCK_N=64, BLOCK_K=256 boundaries.""" + + def test_m_exact_128(self, cuda_conv3d): + """M = 128 = 2 * BLOCK_M (exactly 2 M-tiles, no remainder).""" + # batch=1, output 4x4x8 = 128 with kernel 1x1x1 + x = torch.randn(1, 32, 4, 4, 8, device="cuda") + w = torch.randn(64, 32, 1, 1, 1, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_k_exact_512(self, cuda_conv3d): + """K = 512 = 2 * BLOCK_K (exactly 2 K-tiles, no remainder).""" + # Cin=512, kernel 1x1x1 -> K=512 + x = torch.randn(1, 512, 4, 4, 4, device="cuda") + w = torch.randn(64, 512, 1, 1, 1, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + def test_cout_exact_64(self, cuda_conv3d): + """Cout = 64 = 1 * BLOCK_N (exactly 1 N-tile).""" + x = torch.randn(1, 16, 8, 8, 8, device="cuda") + w = torch.randn(64, 16, 3, 3, 3, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (1, 1, 1), (1, 1, 1)) + + def test_all_exact_multiples(self, cuda_conv3d): + """M=64, N=64, K=256 — single tile in each dimension.""" + # batch=1, Cin=256, kernel 1x1x1 -> K=256; output 
4x4x4=64; Cout=64 + x = torch.randn(1, 256, 4, 4, 4, device="cuda") + w = torch.randn(64, 256, 1, 1, 1, device="cuda") + _run_conv3d_test(cuda_conv3d, x, w, None, (1, 1, 1), (0, 0, 0), (1, 1, 1)) + + +# ============================================================================= +# FP4 Fake Quant: Shape and Edge Case Tests +# ============================================================================= + + +class TestFP4FakeQuantShapes: + """Test fp4_fake_quant with multi-dimensional inputs.""" + + def test_3d_shape_preservation(self, cuda_fp4): + """3D input should preserve shape after quantization.""" + inp = torch.randn(4, 8, 32, device="cuda") # numel=1024 + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size=16) + assert out.shape == (4, 8, 32) + + def test_4d_shape_preservation(self, cuda_fp4): + """4D input should preserve shape.""" + inp = torch.randn(2, 4, 8, 16, device="cuda") # numel=1024 + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size=16) + assert out.shape == (2, 4, 8, 16) + + def test_5d_shape_preservation(self, cuda_fp4): + """5D input (like a Conv3D activation) should preserve shape.""" + inp = torch.randn(1, 4, 4, 4, 16, device="cuda") # numel=1024 + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size=16) + assert out.shape == (1, 4, 4, 4, 16) + + def test_multidim_correctness(self, cuda_fp4): + """Multi-dim quantization should equal flatten -> quant -> reshape.""" + torch.manual_seed(42) + inp = torch.randn(4, 8, 32, device="cuda") + amax = inp.abs().max().unsqueeze(0) + + out_3d = cuda_fp4(inp, amax, block_size=16) + out_flat = cuda_fp4(inp.reshape(-1), amax, block_size=16).reshape(4, 8, 32) + assert torch.equal(out_3d, out_flat) + + +class TestFP4FakeQuantEdgeCases: + """Edge cases for fp4_fake_quant.""" + + def test_very_large_values(self, cuda_fp4): + """Very large input values should saturate to max E2M1 level, not produce NaN.""" + inp = torch.tensor([1e6, -1e6, 5e5, 
-5e5, 1e4, -1e4, 100, -100], device="cuda") + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size=8) + assert not torch.isnan(out).any() + assert not torch.isinf(out).any() + + def test_very_small_values(self, cuda_fp4): + """Very small input values should quantize to zero or near-zero.""" + inp = torch.tensor([1e-8, -1e-8, 1e-10, -1e-10, 1e-6, -1e-6, 0, 0], device="cuda") + amax = torch.tensor([1.0], device="cuda") + out = cuda_fp4(inp, amax, block_size=8) + assert not torch.isnan(out).any() + # Very small values relative to amax should quantize to ~0 + assert out.abs().max() < 1e-3 + + def test_uniform_block(self, cuda_fp4): + """All-same-value block.""" + inp = torch.full((16,), 3.0, device="cuda") + amax = inp.abs().max().unsqueeze(0) + out = cuda_fp4(inp, amax, block_size=16) + # All elements are the same, so they should all quantize to the same E2M1 level + assert (out == out[0]).all(), f"Uniform block produced non-uniform output: {out}" + + def test_near_zero_amax(self, cuda_fp4): + """Very small global_amax should not produce NaN/Inf.""" + inp = torch.randn(16, device="cuda") * 1e-8 + amax = torch.tensor([1e-10], device="cuda") + out = cuda_fp4(inp, amax, block_size=16) + assert not torch.isnan(out).any(), "Near-zero amax produced NaN" + assert not torch.isinf(out).any(), "Near-zero amax produced Inf" From 58a0421b0445ad2aa5ab61c90a7236e223c21f0a Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Sun, 12 Apr 2026 04:43:54 +0000 Subject: [PATCH 21/25] Update the test case Signed-off-by: Jingyu Xin --- tests/examples/diffusers/conftest.py | 31 ++++++ tests/examples/diffusers/test_diffusers.py | 99 +++++++++++++++++++ .../test_export_diffusers_hf_ckpt.py | 65 ++++++++++++ 3 files changed, 195 insertions(+) create mode 100644 tests/examples/diffusers/conftest.py diff --git a/tests/examples/diffusers/conftest.py b/tests/examples/diffusers/conftest.py new file mode 100644 index 0000000000..8893d188d9 --- /dev/null +++ 
b/tests/examples/diffusers/conftest.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +@pytest.fixture(scope="session") +def tiny_wan22_path(tmp_path_factory): + """Create a tiny Wan 2.2 (14B-style) pipeline and return its path. + + Built once per session and shared across all tests that need it. + """ + try: + from _test_utils.torch.diffusers_models import create_tiny_wan22_pipeline_dir + except ImportError: + pytest.skip("Wan 2.2 diffusers models not available (requires diffusers with WanPipeline)") + + tmp_path = tmp_path_factory.mktemp("wan22") + return str(create_tiny_wan22_pipeline_dir(tmp_path)) diff --git a/tests/examples/diffusers/test_diffusers.py b/tests/examples/diffusers/test_diffusers.py index 5bc8f981ec..396776f878 100644 --- a/tests/examples/diffusers/test_diffusers.py +++ b/tests/examples/diffusers/test_diffusers.py @@ -21,6 +21,27 @@ from _test_utils.examples.run_command import run_example_command from _test_utils.torch.misc import minimum_sm +# Tiny video model args — override MODEL_DEFAULTS for fast CI +_WAN22_TINY_EXTRA_PARAMS = [ + "--extra-param", + "height=16", + "--extra-param", + "width=16", + "--extra-param", + "num_frames=5", +] + +_WAN22_FAST_CALIB_ARGS = [ + "--calib-size", + "2", + "--batch-size", + "1", + "--n-steps", + "2", + "--model-dtype", + "BFloat16", +] + 
class DiffuserModel(NamedTuple): dtype: str @@ -151,6 +172,84 @@ def test_diffusers_quantization( model.inference(tmp_path) +def _run_wan22_quantize( + tiny_wan22_path: str, tmp_path: Path, format_type: str, quant_algo: str, collect_method: str +) -> None: + """Run quantize.py for Wan 2.2 with the tiny model.""" + ckpt_path = str(tmp_path / f"wan22_{format_type}.pt") + cmd_args = [ + "python", + "quantize.py", + "--model", + "wan2.2-t2v-14b", + "--override-model-path", + tiny_wan22_path, + "--format", + format_type, + "--quant-algo", + quant_algo, + "--collect-method", + collect_method, + "--trt-high-precision-dtype", + "BFloat16", + "--quantized-torch-ckpt-save-path", + ckpt_path, + *_WAN22_FAST_CALIB_ARGS, + *_WAN22_TINY_EXTRA_PARAMS, + ] + run_example_command(cmd_args, "diffusers/quantization") + + +def _run_wan22_restore( + tiny_wan22_path: str, tmp_path: Path, format_type: str, quant_algo: str, collect_method: str +) -> None: + """Restore a Wan 2.2 quantized checkpoint.""" + ckpt_path = str(tmp_path / f"wan22_{format_type}.pt") + cmd_args = [ + "python", + "quantize.py", + "--model", + "wan2.2-t2v-14b", + "--override-model-path", + tiny_wan22_path, + "--format", + format_type, + "--quant-algo", + quant_algo, + "--collect-method", + collect_method, + "--trt-high-precision-dtype", + "BFloat16", + "--restore-from", + ckpt_path, + *_WAN22_FAST_CALIB_ARGS, + *_WAN22_TINY_EXTRA_PARAMS, + ] + run_example_command(cmd_args, "diffusers/quantization") + + +def test_wan22_int8_smoothquant(tiny_wan22_path: str, tmp_path: Path) -> None: + """Wan 2.2 INT8 SmoothQuant: quantize + restore.""" + _run_wan22_quantize(tiny_wan22_path, tmp_path, "int8", "smoothquant", "min-mean") + _run_wan22_restore(tiny_wan22_path, tmp_path, "int8", "smoothquant", "min-mean") + + +@pytest.mark.parametrize( + ("format_type", "quant_algo"), + [ + pytest.param("fp8", "max", marks=minimum_sm(89)), + pytest.param("fp4", "max", marks=minimum_sm(89)), + ], + ids=["wan22_fp8_max", "wan22_fp4_max"], +) 
+def test_wan22_fp8_fp4( + tiny_wan22_path: str, tmp_path: Path, format_type: str, quant_algo: str +) -> None: + """Wan 2.2 FP8/FP4: quantize + restore (requires SM89+).""" + _run_wan22_quantize(tiny_wan22_path, tmp_path, format_type, quant_algo, "default") + _run_wan22_restore(tiny_wan22_path, tmp_path, format_type, quant_algo, "default") + + @pytest.mark.parametrize( ("model_name", "model_path", "torch_compile"), [ diff --git a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py index 6db1eaeb68..5dca512b5b 100644 --- a/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py +++ b/tests/examples/diffusers/test_export_diffusers_hf_ckpt.py @@ -21,6 +21,16 @@ from _test_utils.examples.run_command import run_example_command from _test_utils.torch.misc import minimum_sm +# Tiny video model args — override MODEL_DEFAULTS for fast CI +_WAN22_TINY_EXTRA_PARAMS = [ + "--extra-param", + "height=16", + "--extra-param", + "width=16", + "--extra-param", + "num_frames=5", +] + class DiffuserHfExportModel(NamedTuple): name: str @@ -128,3 +138,58 @@ def test_diffusers_hf_ckpt_export(model: DiffuserHfExportModel, tmp_path: Path) weight_files = list(hf_ckpt_dir.rglob("*.safetensors")) + list(hf_ckpt_dir.rglob("*.bin")) assert len(weight_files) > 0, f"No weight files (.safetensors or .bin) found in {hf_ckpt_dir}" + + +@pytest.mark.parametrize( + ("format_type", "quant_algo", "collect_method"), + [ + ("int8", "smoothquant", "min-mean"), + pytest.param("fp8", "max", "default", marks=minimum_sm(89)), + ], + ids=["wan22_int8_smoothquant", "wan22_fp8_max"], +) +def test_wan22_hf_ckpt_export( + tiny_wan22_path: str, + tmp_path: Path, + format_type: str, + quant_algo: str, + collect_method: str, +) -> None: + """Quantize tiny Wan 2.2 and export to HF checkpoint.""" + hf_ckpt_dir = tmp_path / f"wan22_{format_type}_hf_ckpt" + cmd_args = [ + "python", + "quantize.py", + "--model", + "wan2.2-t2v-14b", + 
"--override-model-path", + tiny_wan22_path, + "--format", + format_type, + "--quant-algo", + quant_algo, + "--collect-method", + collect_method, + "--model-dtype", + "BFloat16", + "--trt-high-precision-dtype", + "BFloat16", + "--calib-size", + "2", + "--batch-size", + "1", + "--n-steps", + "2", + "--hf-ckpt-dir", + str(hf_ckpt_dir), + *_WAN22_TINY_EXTRA_PARAMS, + ] + run_example_command(cmd_args, "diffusers/quantization") + + assert hf_ckpt_dir.exists(), f"HF checkpoint directory was not created: {hf_ckpt_dir}" + + config_files = list(hf_ckpt_dir.rglob("config.json")) + assert len(config_files) > 0, f"No config.json found in {hf_ckpt_dir}" + + weight_files = list(hf_ckpt_dir.rglob("*.safetensors")) + list(hf_ckpt_dir.rglob("*.bin")) + assert len(weight_files) > 0, f"No weight files (.safetensors or .bin) found in {hf_ckpt_dir}" From 257b8e68052df14cfc75519ce20b16235db3d241 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 13 Apr 2026 20:18:23 +0000 Subject: [PATCH 22/25] Updated the README Signed-off-by: Jingyu Xin --- modelopt/torch/kernels/conv/README.md | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/modelopt/torch/kernels/conv/README.md b/modelopt/torch/kernels/conv/README.md index 9221c73191..edb072e9cd 100644 --- a/modelopt/torch/kernels/conv/README.md +++ b/modelopt/torch/kernels/conv/README.md @@ -101,16 +101,11 @@ Standalone FP4 (E2M1) blockwise fake quantization with FP8 E4M3 scale quantizati | `global_amax` | Scalar tensor — global abs max for scale computation | | `block_size` | Number of elements per FP4 quantization block (default `16`) | -## Testing and Benchmarking +## Testing ```bash # Run tests (requires GPU) python -m pytest tests/gpu/torch/quantization/kernels/test_implicit_gemm.py -v - -# Run benchmarks -python -m modelopt.torch.kernels.conv.bench_implicit_gemm # default shapes -python -m modelopt.torch.kernels.conv.bench_implicit_gemm --shapes wan22 # Wan2.2 VAE shapes -python -m 
modelopt.torch.kernels.conv.bench_implicit_gemm --shapes all # all presets ``` ## Status From f98c14da428c7a950419b9dc1018986a46ad462b Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 13 Apr 2026 20:23:43 +0000 Subject: [PATCH 23/25] Some clean ups Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/pipeline_manager.py | 7 ++++++- examples/diffusers/quantization/quantize.py | 8 +++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index b52ef4852e..1d79df978b 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -170,9 +170,14 @@ def iter_backbones(self) -> Iterator[tuple[str, torch.nn.Module]]: if name == "video_decoder": self._ensure_ltx2_video_decoder_cached() yield name, self._video_decoder - else: + elif name == "transformer": self._ensure_ltx2_transformer_cached() yield name, self._transformer + else: + raise ValueError( + f"Unsupported LTX-2 backbone name '{name}'. " + "Expected 'transformer' or 'video_decoder'." 
+ ) return for name in names: diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 944b395536..24f283e365 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -253,6 +253,7 @@ def save_checkpoint( ckpt_path.mkdir(parents=True, exist_ok=True) filename = f"{backbone_name}.pt" if backbone_name else "backbone.pt" target_path = ckpt_path / filename + self.logger.info(f"Saving backbone to {target_path}") mto.save(backbone, str(target_path)) @@ -313,18 +314,15 @@ def restore_checkpoint(self) -> None: if not restore_path.exists() or not restore_path.is_dir(): raise FileNotFoundError(f"Checkpoint directory not found: {restore_path}") - backbones = list(self.pipeline_manager.iter_backbones()) - for backbone_name, backbone in backbones: + for backbone_name, backbone in self.pipeline_manager.iter_backbones(): source_path = restore_path / f"{backbone_name}.pt" - # Legacy fallback: only safe when there is exactly one backbone - if not source_path.exists() and len(backbones) == 1: - source_path = restore_path / "backbone.pt" if not source_path.exists(): raise FileNotFoundError( f"Checkpoint not found for '{backbone_name}' in {restore_path}" ) self.logger.info(f"Restoring {backbone_name} from {source_path}") mto.restore(backbone, str(source_path)) + self.logger.info("Checkpoints restored successfully") # TODO: should not do the any data type From 075eae25cbe43325679f53f0fd58de79c6fe1c24 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 13 Apr 2026 20:29:58 +0000 Subject: [PATCH 24/25] Update the README + Code, add warning for the groups > 1 Signed-off-by: Jingyu Xin --- modelopt/torch/kernels/conv/README.md | 2 +- modelopt/torch/quantization/nn/modules/quant_conv.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/kernels/conv/README.md b/modelopt/torch/kernels/conv/README.md index edb072e9cd..00d1e3f8c5 100644 --- 
a/modelopt/torch/kernels/conv/README.md +++ b/modelopt/torch/kernels/conv/README.md @@ -115,7 +115,7 @@ Current state: **Integrated** (registered in `QuantModuleRegistry`, auto-dispatc Known limitations: - CUDA extension compile latency on first invocation (~seconds). -- Grouped convolution (`groups > 1`) is not supported. +- Grouped convolution (`groups > 1`) is not supported. In the ModelOpt E2E flow, `_QuantConv3d` automatically falls back to the default cuDNN path for grouped convolutions. - BF16 rounding error accumulates with the K dimension — expect max abs diff scaling roughly as `sqrt(K)` compared to cuDNN FP32. - Inference only (`@torch.no_grad`) — not suitable for QAT backward pass. diff --git a/modelopt/torch/quantization/nn/modules/quant_conv.py b/modelopt/torch/quantization/nn/modules/quant_conv.py index cca3f06730..596f947c75 100644 --- a/modelopt/torch/quantization/nn/modules/quant_conv.py +++ b/modelopt/torch/quantization/nn/modules/quant_conv.py @@ -98,6 +98,7 @@ def _should_use_implicit_gemm(self): and hasattr(self, "weight_quantizer") and _is_nvfp4_quantizer(self.input_quantizer) and _is_nvfp4_quantizer(self.weight_quantizer) + and self.groups == 1 ): try: from modelopt.torch.kernels.conv.implicit_gemm_cuda import ( From e6b9cd6a28e742481cf430238378311d6ef6c09d Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 17 Apr 2026 00:27:26 +0000 Subject: [PATCH 25/25] Address some comments Signed-off-by: Jingyu Xin --- CHANGELOG.rst | 1 + examples/diffusers/quantization/quantize.py | 41 +++++++++++-------- .../quantization/nn/modules/quant_conv.py | 34 +++++++-------- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fdd738590a..b15301d41d 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,7 @@ Changelog - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. 
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution. +- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.kernels.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning. 
**Backward Breaking Changes** diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 24f283e365..2a3c947a2d 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -116,41 +116,48 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: if self.config.format == QuantFormat.INT8: if self.config.algo == QuantAlgo.SMOOTHQUANT: - quant_config = mtq.INT8_SMOOTHQUANT_CFG + base_cfg = mtq.INT8_SMOOTHQUANT_CFG else: - quant_config = INT8_DEFAULT_CONFIG + base_cfg = INT8_DEFAULT_CONFIG if self.config.collect_method != CollectMethod.DEFAULT: reset_set_int8_config( - quant_config, + base_cfg, self.config.percentile, n_steps, collect_method=self.config.collect_method.value, backbone=backbone, ) elif self.config.format == QuantFormat.FP8: - quant_config = FP8_DEFAULT_CONFIG + base_cfg = FP8_DEFAULT_CONFIG elif self.config.format == QuantFormat.FP4: if self.model_config.model_type.value.startswith("flux"): - quant_config = NVFP4_FP8_MHA_CONFIG + base_cfg = NVFP4_FP8_MHA_CONFIG else: - quant_config = NVFP4_DEFAULT_CONFIG - # Override block size if non-default - if self.config.block_size != 16: - import copy - - quant_config = copy.deepcopy(quant_config) - for entry in quant_config["quant_cfg"]: - if isinstance(entry, dict) and "block_sizes" in entry.get("cfg", {}): - entry["cfg"]["block_sizes"][-1] = self.config.block_size + base_cfg = NVFP4_DEFAULT_CONFIG else: raise NotImplementedError(f"Unknown format {self.config.format}") + + # Build a fresh config dict so we never mutate the global constants. 
+ quant_cfg_list = list(base_cfg["quant_cfg"]) + + if self.config.format == QuantFormat.FP4: + for i, entry in enumerate(quant_cfg_list): + if isinstance(entry, dict) and "block_sizes" in entry.get("cfg", {}): + new_block_sizes = {**entry["cfg"]["block_sizes"], -1: self.config.block_size} + quant_cfg_list[i] = { + **entry, + "cfg": {**entry["cfg"], "block_sizes": new_block_sizes}, + } + if self.config.quantize_mha: - quant_config["quant_cfg"].append( + quant_cfg_list.append( { "quantizer_name": "*[qkv]_bmm_quantizer", "cfg": {"num_bits": (4, 3), "axis": None}, } ) + + quant_config = {**base_cfg, "quant_cfg": quant_cfg_list} set_quant_config_attr( quant_config, self.model_config.trt_high_precision_dtype.value, @@ -617,8 +624,8 @@ def main() -> None: logger.info(f"Quantizing backbone: {backbone_name}") backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) - # Full pipeline inference for calibration — cached modules - # (transformer, video_decoder) are exercised during __call__ + # Calibration runs the full pipeline (not just `mod`), so the + # closure intentionally ignores the backbone argument. def forward_loop(mod): calibrator.run_calibration(batched_prompts) diff --git a/modelopt/torch/quantization/nn/modules/quant_conv.py b/modelopt/torch/quantization/nn/modules/quant_conv.py index 596f947c75..cd1be47ead 100644 --- a/modelopt/torch/quantization/nn/modules/quant_conv.py +++ b/modelopt/torch/quantization/nn/modules/quant_conv.py @@ -15,6 +15,8 @@ """Quantized convolution.""" +import warnings + import torch.nn as nn from ... 
import tensor_quant @@ -90,34 +92,23 @@ class _QuantConv3d(QuantLinearConvBase): def _should_use_implicit_gemm(self): """Check if both quantizers are NVFP4 and the implicit GEMM kernel is available.""" - if hasattr(self, "_use_implicit_gemm"): - return self._use_implicit_gemm - result = False - if ( + return ( hasattr(self, "input_quantizer") and hasattr(self, "weight_quantizer") and _is_nvfp4_quantizer(self.input_quantizer) and _is_nvfp4_quantizer(self.weight_quantizer) and self.groups == 1 - ): - try: - from modelopt.torch.kernels.conv.implicit_gemm_cuda import ( - conv3d_implicit_gemm_cuda, # noqa: F401 - ) - - result = True - except ImportError: - pass - self._use_implicit_gemm = result - return result + ) def _implicit_gemm_forward(self, input): """Run NVFP4 implicit GEMM kernel. Input may already be padded.""" from modelopt.torch.kernels.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + # _get_amax is an internal TensorQuantizer method with no public equivalent; + # block_sizes is a public property. act_amax = self.input_quantizer._get_amax(input) weight = _nvfp4_quantize_weight_along_k(self.weight, self.weight_quantizer) - fp4_block_size = self.input_quantizer._block_sizes.get(-1, 16) + fp4_block_size = self.input_quantizer.block_sizes.get(-1, 16) output = conv3d_implicit_gemm_cuda( input, @@ -137,7 +128,18 @@ def forward(self, input, *args, **kwargs): if not self._should_use_implicit_gemm(): return super().forward(input, *args, **kwargs) + if self.training: + warnings.warn( + "Implicit GEMM Conv3D kernel is inference-only and does not support training. " + "Falling back to the default cuDNN quantization path, which could produce " + "different numerics.", + stacklevel=2, + ) + return super().forward(input, *args, **kwargs) + # During calibration, only collect amax — use the faster cuDNN path. 
+ # _if_calib/_if_quant are internal TensorQuantizer state with no public property; + # toggled via enable_calib()/disable_calib()/enable_quant()/disable_quant(). if self.input_quantizer._if_calib and not self.input_quantizer._if_quant: return super().forward(input, *args, **kwargs)