|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d. |
| 17 | +
|
| 18 | +Usage: |
| 19 | + python -m experimental.conv.bench_implicit_gemm |
| 20 | + python -m experimental.conv.bench_implicit_gemm --shapes wan22 |
| 21 | + python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100 |
| 22 | +""" |
| 23 | + |
import argparse
import statistics

import torch
import torch.nn.functional as F
| 28 | + |
| 29 | +# --------------------------------------------------------------------------- |
| 30 | +# Benchmark shapes |
| 31 | +# --------------------------------------------------------------------------- |
| 32 | + |
# Benchmark shape registry, keyed by shape-set name (selectable via --shapes).
# Each entry is a tuple:
# (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation)
SHAPES = {
    # Tiny smoke-test shape; useful for a quick sanity run.
    "small": [
        ("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
    ],
    # Mid-sized channel counts, including a 1x3x3 (spatial-only) kernel.
    "medium": [
        ("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
        ("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
        ("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)),
    ],
    # Shapes taken from a WAN 2.2-style VAE workload (21x60x106 volume) — presumably; verify against the model config.
    "wan22": [
        ("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
        ("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)),
        ("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
    ],
    # Strided (downsampling) convolutions.
    "stride": [
        ("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
        ("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
    ],
}
| 53 | + |
| 54 | + |
def get_shapes(name: str):
    """Return the benchmark shapes registered under ``name``.

    ``"all"`` concatenates every registered shape set, in registry order.
    Raises KeyError for an unknown set name.
    """
    if name != "all":
        return SHAPES[name]
    # Flatten every shape set into a single list.
    return [shape for group in SHAPES.values() for shape in group]
| 63 | + |
| 64 | + |
| 65 | +# --------------------------------------------------------------------------- |
| 66 | +# Timing utility |
| 67 | +# --------------------------------------------------------------------------- |
| 68 | + |
| 69 | + |
def bench_fn(fn, warmup: int, iters: int) -> float:
    """Benchmark a callable on the current CUDA device.

    Args:
        fn: Zero-argument callable to time (typically a closure over tensors).
        warmup: Number of untimed warm-up calls (autotuning, caches, lazy init).
        iters: Number of timed iterations.

    Returns:
        Median wall-clock time per call in milliseconds, measured with CUDA
        events so kernel execution (not just launch) is captured.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        # elapsed_time is only valid once both events have completed.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    # BUGFIX: the previous `times[len(times) // 2]` picked the upper-middle
    # sample for an even iteration count; statistics.median averages the two
    # middle samples, matching the documented "median" contract.
    return statistics.median(times)
| 88 | + |
| 89 | + |
| 90 | +# --------------------------------------------------------------------------- |
| 91 | +# Main |
| 92 | +# --------------------------------------------------------------------------- |
| 93 | + |
| 94 | + |
def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int):
    """Run latency benchmark for the given shapes.

    Prints one table row per shape comparing cuDNN conv3d against the
    implicit-GEMM kernel in non-quantized and FP4-quantized modes.
    """
    from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda

    def out_dim(size: int, kernel: int, s: int, pad: int, dil: int) -> int:
        # Standard convolution output-size formula.
        return (size + 2 * pad - dil * (kernel - 1) - 1) // s + 1

    bench_shapes = get_shapes(shapes_name)

    # Header
    print(f"\n{'=' * 100}")
    print(
        f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}"
    )
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"{'=' * 100}")
    print(
        f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} "
        f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} "
        f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}"
    )
    print("-" * 100)

    for label, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in bench_shapes:
        # Fixed seed so every shape uses reproducible random tensors.
        torch.manual_seed(42)
        x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32)
        weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32)
        # Per-tensor activation amax, shaped (1,), for the FP4 path.
        act_amax = x.abs().max().unsqueeze(0)

        # GEMM dimensions, for display only: M = output voxels, K = patch
        # size (Cin * kernel volume), N = output channels.
        od, oh, ow = (
            out_dim(size, kernel, s, pad, dil)
            for size, kernel, s, pad, dil in zip(
                (d, h, w), (kd, kh, kw), stride, padding, dilation
            )
        )
        gemm_m = n * od * oh * ow
        gemm_k = cin * kd * kh * kw
        gemm_n = cout

        # Baseline: cuDNN via torch.nn.functional.conv3d.
        def run_cudnn():
            return F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation)

        # Implicit GEMM, non-quantized path.
        def run_gemm():
            return conv3d_implicit_gemm_cuda(
                x,
                weight,
                stride=stride,
                padding=padding,
                dilation=dilation,
                quant_act=False,
                fp4_block_size=fp4_block_size,
            )

        # Implicit GEMM, FP4-quantized activations.
        def run_gemm_fp4():
            return conv3d_implicit_gemm_cuda(
                x,
                weight,
                stride=stride,
                padding=padding,
                dilation=dilation,
                act_amax=act_amax,
                quant_act=True,
                fp4_block_size=fp4_block_size,
            )

        t_cudnn = bench_fn(run_cudnn, warmup, iters)
        t_gemm = bench_fn(run_gemm, warmup, iters)
        t_fp4 = bench_fn(run_gemm_fp4, warmup, iters)

        print(
            f"{label:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} "
            f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms "
            f"{t_gemm / t_cudnn:>10.2f}x {t_fp4 / t_cudnn:>9.2f}x"
        )

    print(f"{'=' * 100}")
    print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.")
    print()
| 182 | + |
| 183 | + |
def main():
    """Entry point for the benchmark CLI: parse flags and run the benchmark."""
    parser = argparse.ArgumentParser(description="Conv3D latency benchmark")
    parser.add_argument(
        "--shapes",
        default="all",
        choices=[*SHAPES, "all"],
        help="Which shape set to benchmark (default: all)",
    )
    # The two plain integer knobs share a common shape.
    for flag, default, help_text in (
        ("--warmup", 20, "Warmup iterations"),
        ("--iters", 100, "Benchmark iterations"),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument(
        "--fp4-block-size",
        type=int,
        default=128,
        choices=[128, 256],
        help="FP4 block size (default: 128)",
    )
    opts = parser.parse_args()

    run_benchmark(opts.shapes, opts.warmup, opts.iters, opts.fp4_block_size)


if __name__ == "__main__":
    main()
0 commit comments