Skip to content

Commit 0622007

Browse files
committed
Add more test cases
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent a76f561 commit 0622007

11 files changed

Lines changed: 1874 additions & 2 deletions

File tree

examples/diffusers/sparsity/wan22_skip_softmax.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,17 @@ def main() -> None:
464464
pipe_kwargs["guidance_scale_2"] = args.guidance_scale_2
465465
output = pipe(**pipe_kwargs)
466466

467-
export_to_video(output.frames[0], args.output, fps=16)
468-
print(f"Saved to {args.output}")
467+
try:
468+
export_to_video(output.frames[0], args.output, fps=16)
469+
print(f"Saved to {args.output}")
470+
except ImportError as exc:
471+
# Fall back to saving the first frame as PNG if no video backend
472+
# (opencv / imageio) is installed — useful in minimal CI envs.
473+
print(f"Video export skipped ({exc}); saving first frame as PNG")
474+
frame0 = output.frames[0][0]
475+
png_path = str(args.output).rsplit(".", 1)[0] + ".png"
476+
frame0.save(png_path)
477+
print(f"Saved first frame to {png_path}")
469478

470479
# ---- Print stats ----
471480
if not args.baseline:
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Latency benchmark: implicit GEMM (quant / non-quant) vs cuDNN conv3d.
17+
18+
Usage:
19+
python -m experimental.conv.bench_implicit_gemm
20+
python -m experimental.conv.bench_implicit_gemm --shapes wan22
21+
python -m experimental.conv.bench_implicit_gemm --shapes all --warmup 20 --iters 100
22+
"""
23+
24+
import argparse
25+
26+
import torch
27+
import torch.nn.functional as F
28+
29+
# ---------------------------------------------------------------------------
30+
# Benchmark shapes
31+
# ---------------------------------------------------------------------------
32+
33+
# (name, N, Cin, D, H, W, Cout, kD, kH, kW, stride, padding, dilation)
34+
SHAPES = {
35+
"small": [
36+
("small_16x32_3x3x3", 1, 16, 8, 8, 8, 32, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
37+
],
38+
"medium": [
39+
("med_64x128_3x3x3", 1, 64, 16, 32, 32, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
40+
("med_128x256_3x3x3", 1, 128, 8, 16, 16, 256, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
41+
("med_128x128_1x3x3", 1, 128, 16, 32, 32, 128, 1, 3, 3, (1, 1, 1), (0, 1, 1), (1, 1, 1)),
42+
],
43+
"wan22": [
44+
("wan22_128x512", 1, 128, 21, 60, 106, 512, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
45+
("wan22_512x512", 1, 512, 21, 60, 106, 512, 1, 1, 1, (1, 1, 1), (0, 0, 0), (1, 1, 1)),
46+
("wan22_512x128", 1, 512, 21, 60, 106, 128, 3, 3, 3, (1, 1, 1), (1, 1, 1), (1, 1, 1)),
47+
],
48+
"stride": [
49+
("stride2_64x128", 1, 64, 16, 32, 32, 128, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
50+
("stride2_128x256", 1, 128, 16, 32, 32, 256, 3, 3, 3, (2, 2, 2), (1, 1, 1), (1, 1, 1)),
51+
],
52+
}
53+
54+
55+
def get_shapes(name: str):
56+
"""Return list of benchmark shapes by name or all shapes."""
57+
if name == "all":
58+
result = []
59+
for v in SHAPES.values():
60+
result.extend(v)
61+
return result
62+
return SHAPES[name]
63+
64+
65+
# ---------------------------------------------------------------------------
66+
# Timing utility
67+
# ---------------------------------------------------------------------------
68+
69+
70+
def bench_fn(fn, warmup: int, iters: int) -> float:
    """Time *fn* on the current CUDA device and return its median latency in ms."""
    # Warm-up passes: let clocks, caches and any lazy initialization settle
    # before we start recording.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    samples = []
    for _ in range(iters):
        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)
        tic.record()
        fn()
        toc.record()
        # Block until both events have completed so elapsed_time() is valid.
        torch.cuda.synchronize()
        samples.append(tic.elapsed_time(toc))

    # Median (upper-middle sample for an even iteration count).
    return sorted(samples)[len(samples) // 2]
88+
89+
90+
# ---------------------------------------------------------------------------
91+
# Main
92+
# ---------------------------------------------------------------------------
93+
94+
95+
def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int) -> None:
    """Run latency benchmark for the given shapes.

    For each shape, times cuDNN's ``F.conv3d`` against the project's implicit
    GEMM kernel in both non-quantized and FP4-quantized modes, and prints one
    table row per shape with slowdown ratios relative to cuDNN.

    Args:
        shapes_name: Key into ``SHAPES``, or ``"all"`` for every group.
        warmup: Warmup iterations forwarded to ``bench_fn``.
        iters: Timed iterations forwarded to ``bench_fn``.
        fp4_block_size: Block size forwarded to ``conv3d_implicit_gemm_cuda``.
    """
    # Imported lazily so the module can be imported/inspected even when the
    # compiled CUDA extension is unavailable.
    from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda

    shapes = get_shapes(shapes_name)

    # Header
    print(f"\n{'=' * 100}")
    print(
        f"Conv3D Latency Benchmark | warmup={warmup} iters={iters} fp4_block_size={fp4_block_size}"
    )
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"{'=' * 100}")
    print(
        f"{'Shape':<25} {'M':>10} {'K':>8} {'N':>6} "
        f"{'cuDNN':>9} {'GEMM':>9} {'GEMM+FP4':>9} "
        f"{'GEMM/cuDNN':>11} {'FP4/cuDNN':>10}"
    )
    print("-" * 100)

    for name, n, cin, d, h, w, cout, kd, kh, kw, stride, padding, dilation in shapes:
        # Fixed seed so every kernel variant sees identical input data.
        torch.manual_seed(42)
        x = torch.randn(n, cin, d, h, w, device="cuda", dtype=torch.float32)
        weight = torch.randn(cout, cin, kd, kh, kw, device="cuda", dtype=torch.float32)
        # NOTE(review): presumably the per-tensor activation amax consumed by
        # the FP4 quantization path — confirm against conv3d_implicit_gemm_cuda.
        act_amax = x.abs().max().unsqueeze(0)

        # Compute GEMM dimensions for display
        sd, sh, sw = stride
        dd, dh, dw = dilation
        pd, ph, pw = padding
        # Standard conv output-size formula: (in + 2*pad - dil*(k-1) - 1) // stride + 1.
        od = (d + 2 * pd - dd * (kd - 1) - 1) // sd + 1
        oh = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
        ow = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
        gemm_m = n * od * oh * ow
        gemm_k = cin * kd * kh * kw
        gemm_n = cout

        # The lambdas below capture loop variables, but each one is fully
        # consumed by bench_fn before the loop advances, so late binding is
        # harmless here.
        # cuDNN (torch.nn.functional.conv3d)
        t_cudnn = bench_fn(
            lambda: F.conv3d(x, weight, stride=stride, padding=padding, dilation=dilation),
            warmup,
            iters,
        )

        # Implicit GEMM (non-quantized)
        t_gemm = bench_fn(
            lambda: conv3d_implicit_gemm_cuda(
                x,
                weight,
                stride=stride,
                padding=padding,
                dilation=dilation,
                quant_act=False,
                fp4_block_size=fp4_block_size,
            ),
            warmup,
            iters,
        )

        # Implicit GEMM (FP4 quantized)
        t_fp4 = bench_fn(
            lambda: conv3d_implicit_gemm_cuda(
                x,
                weight,
                stride=stride,
                padding=padding,
                dilation=dilation,
                act_amax=act_amax,
                quant_act=True,
                fp4_block_size=fp4_block_size,
            ),
            warmup,
            iters,
        )

        # Slowdown relative to the cuDNN baseline (>1.0 means slower).
        ratio_gemm = t_gemm / t_cudnn
        ratio_fp4 = t_fp4 / t_cudnn

        print(
            f"{name:<25} {gemm_m:>10,} {gemm_k:>8,} {gemm_n:>6,} "
            f"{t_cudnn:>8.3f}ms {t_gemm:>8.3f}ms {t_fp4:>8.3f}ms "
            f"{ratio_gemm:>10.2f}x {ratio_fp4:>9.2f}x"
        )

    print(f"{'=' * 100}")
    print("Ratios > 1.0x mean slower than cuDNN; < 1.0x mean faster.")
    print()
182+
183+
184+
def main():
    """Entry point for the benchmark CLI."""
    cli = argparse.ArgumentParser(description="Conv3D latency benchmark")
    # Shape-group selector: any key of SHAPES, or "all" to run every group.
    cli.add_argument(
        "--shapes",
        default="all",
        choices=[*SHAPES, "all"],
        help="Which shape set to benchmark (default: all)",
    )
    cli.add_argument("--warmup", type=int, default=20, help="Warmup iterations")
    cli.add_argument("--iters", type=int, default=100, help="Benchmark iterations")
    cli.add_argument(
        "--fp4-block-size",
        type=int,
        default=128,
        choices=[128, 256],
        help="FP4 block size (default: 128)",
    )
    opts = cli.parse_args()

    run_benchmark(opts.shapes, opts.warmup, opts.iters, opts.fp4_block_size)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)