Skip to content

Commit abcdc7c

Browse files
committed
[Benchmark] Improve benchmark test reliability using cudagraph
Using CUDA graphs eliminates host CPU overhead and jitter from the measurements. Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 4f9e5c9 commit abcdc7c

9 files changed

Lines changed: 132 additions & 53 deletions

test/bench_attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,13 @@ def bench_fmha(qkv_shape, dtype, backend, benchmark):
6767

6868
warmup_rounds, iterations, rounds = estimate_bench_iter(
6969
backend, (q, k, v, o, is_causal, enable_gqa),
70+
cudagraph=True
7071
)
7172

7273
benchmark.pedantic(
7374
backend, (q, k, v, o, is_causal, enable_gqa),
7475
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
76+
cudagraph=True
7577
)
7678

7779
B, H, L, D = q.shape

test/bench_fft.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,11 @@ def bench_fft(shape, dtype, fft_backend, benchmark):
6363
y_ref = torch_fft(*args)
6464
l2error = (y_ref - y_test).norm() / (y_ref).norm()
6565
assert l2error < tolerance_map[dtype]
66-
warmup_rounds, iterations, rounds = estimate_bench_iter(fft_backend, args)
66+
warmup_rounds, iterations, rounds = estimate_bench_iter(fft_backend, args, cudagraph=True)
6767
benchmark.pedantic(
6868
fft_backend, args,
6969
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
70+
cudagraph=True
7071
)
7172

7273
flop_count = 0 # TODO

test/bench_layer_norm.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -49,31 +49,37 @@ def bench_layer_norm(shape, dtype, mode, backend, benchmark):
4949
torch.bfloat16: (1e-2, 1e-2),
5050
}[dtype]
5151

52-
y = backend(x, weight, bias, eps)
53-
y_ref = torch_layer_norm(x, weight, bias, eps)
54-
if mode == "forward":
55-
torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
56-
bench_f, bench_args = backend, (x, weight, bias, eps)
57-
else:
58-
y.backward(dy, retain_graph=True)
59-
dx, dw, db = [_.grad.clone() for _ in [x, weight, bias]]
60-
x.grad, weight.grad, bias.grad = None, None, None
61-
62-
y_ref.backward(dy, retain_graph=True)
63-
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
64-
65-
torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)
66-
torch.testing.assert_close(dw, dw_ref, atol=atol, rtol=rtol)
67-
torch.testing.assert_close(db, db_ref, atol=atol, rtol=rtol)
68-
69-
bench_f, bench_args = partial(y.backward, retain_graph=True), (dy,)
70-
71-
warmup_rounds, iterations, rounds = estimate_bench_iter(bench_f, bench_args)
72-
73-
benchmark.pedantic(
74-
bench_f, bench_args,
75-
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
76-
)
52+
# Run in non default stream so backward graph can be captured without
53+
# sync with default stream
54+
s = torch.cuda.Stream()
55+
s.wait_stream(torch.cuda.current_stream())
56+
with torch.cuda.stream(s):
57+
y = backend(x, weight, bias, eps)
58+
y_ref = torch_layer_norm(x, weight, bias, eps)
59+
if mode == "forward":
60+
torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
61+
bench_f, bench_args = backend, (x, weight, bias, eps)
62+
else:
63+
y.backward(dy, retain_graph=True)
64+
dx, dw, db = [_.grad.clone() for _ in [x, weight, bias]]
65+
x.grad, weight.grad, bias.grad = None, None, None
66+
67+
y_ref.backward(dy, retain_graph=True)
68+
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
69+
70+
torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)
71+
torch.testing.assert_close(dw, dw_ref, atol=atol, rtol=rtol)
72+
torch.testing.assert_close(db, db_ref, atol=atol, rtol=rtol)
73+
74+
bench_f, bench_args = partial(y.backward, retain_graph=True), (dy,)
75+
76+
warmup_rounds, iterations, rounds = estimate_bench_iter(bench_f, bench_args, cudagraph=True)
77+
78+
benchmark.pedantic(
79+
bench_f, bench_args,
80+
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
81+
cudagraph=True
82+
)
7783

7884

7985
class CuTileLayerNorm(torch.autograd.Function):

test/bench_matmul.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def dtype(request):
2121
return request.param
2222

2323

24-
def _run_matmul_benchmark(shape, dtype, backend, benchmark, extra_args=(), atol=1e-3, rtol=1e-3):
24+
def _run_matmul_benchmark(shape, dtype, backend, benchmark,
25+
extra_args=(), atol=1e-3, rtol=1e-3):
2526
m, n, k = shape
2627
A = torch.rand((m, k), dtype=dtype, device="cuda")
2728
B = torch.rand((k, n), dtype=dtype, device="cuda")
@@ -34,10 +35,11 @@ def _run_matmul_benchmark(shape, dtype, backend, benchmark, extra_args=(), atol=
3435
torch.testing.assert_close(C, A @ B, atol=atol, rtol=rtol)
3536

3637
torch.cuda.synchronize()
37-
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, args)
38+
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, args, cudagraph=True)
3839
benchmark.pedantic(
3940
backend, args,
4041
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
42+
cudagraph=True
4143
)
4244

4345
flop_count = 2 * m * n * k
@@ -63,12 +65,12 @@ def _run_batch_matmul_benchmark(
6365
torch.testing.assert_close(C, ref, atol=atol, rtol=rtol)
6466

6567
torch.cuda.synchronize()
66-
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, args)
68+
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, args, cudagraph=True)
6769
benchmark.pedantic(
6870
backend, args,
6971
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
72+
cudagraph=True,
7073
)
71-
7274
flop_count = 2 * b * m * n * k
7375
bytes_rw = sum([t.numel() * t.dtype.itemsize for t in (A, B, C)])
7476
benchmark.extra_info['flop_count'] = flop_count
@@ -122,7 +124,8 @@ def bench_matmul_split_k(split_k_shape, dtype, backend, benchmark):
122124
dtype=torch.int32, device="cuda")
123125
COUNTS = torch.zeros_like(LOCKS)
124126
extra_args = (LOCKS, COUNTS, tile_sizes)
125-
_run_matmul_benchmark(split_k_shape, dtype, backend, benchmark, extra_args, rtol=2e-3)
127+
_run_matmul_benchmark(split_k_shape, dtype, backend, benchmark,
128+
extra_args, rtol=2e-3)
126129

127130

128131
def cutile_matmul_split_k(A, B, C, LOCKS, COUNTS, tile_sizes):
@@ -172,8 +175,8 @@ def cutile_batch_matmul(bs, A, B, C):
172175
def torch_batch_matmul(bs, A, B, C):
173176
if A.dtype == torch.float8_e5m2:
174177
pytest.skip("float8_e5m2 matmul on torch is not supported")
175-
inv_sa = torch.tensor(1.0, device=A.device, dtype=torch.float32)
176-
inv_sb = torch.tensor(1.0, device=B.device, dtype=torch.float32)
178+
inv_sa = torch.full((), 1.0, device=A.device, dtype=torch.float32)
179+
inv_sb = torch.full((), 1.0, device=B.device, dtype=torch.float32)
177180
with torch_use_tf32_matmul():
178181
for i in range(bs):
179182
# Only multiplication of row-major and column-major matrices is supported by cuBLASLt

test/bench_rms_norm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@ def bench_rms_norm(shape, dtype, algo, backend, benchmark):
7272

7373
warmup_rounds, iterations, rounds = estimate_bench_iter(
7474
backend, (x, weight, eps, static_persistent, gather),
75+
cudagraph=True
7576
)
7677

7778
benchmark.pedantic(
7879
backend, (x, weight, eps, static_persistent, gather),
7980
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
81+
cudagraph=True
8082
)
8183

8284
M, N = x.shape

test/bench_transpose.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ def _run_transpose_benchmark(shape, dtype, backend, benchmark, atol=1e-3, rtol=1
2727
backend(A, B)
2828
torch.testing.assert_close(B, A.T, atol=atol, rtol=rtol)
2929
torch.cuda.synchronize()
30-
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, (A, B))
30+
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, (A, B), cudagraph=True)
3131
benchmark.pedantic(
3232
backend, (A, B),
3333
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
34+
cudagraph=True,
3435
)
3536

3637
flop_count = m * n

test/bench_vec_add.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,13 @@ def bench_vec_add(shape, dtype, backend, use_gather, benchmark):
5353
torch.testing.assert_close(c, ref, atol=1e-3, rtol=1e-3)
5454
torch.cuda.synchronize()
5555

56-
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, (a, b, use_gather))
56+
warmup_rounds, iterations, rounds = estimate_bench_iter(backend, (a, b, use_gather),
57+
cudagraph=True)
5758

5859
benchmark.pedantic(
5960
backend, (a, b, use_gather),
6061
rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
62+
cudagraph=True
6163
)
6264

6365
flop_count = 0

test/conftest.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,20 @@
44

55
import torch
66
import pytest
7-
import cuda_timer
87
import subprocess
98
import sys
109
import math
1110
import tempfile
12-
from functools import cache
11+
from functools import cache, partial
1312

1413
from cuda.tile._bytecode.version import BytecodeVersion
1514
from cuda.tile._compile import (
1615
_get_max_supported_bytecode_version,
1716
_SUPPORTED_VERSIONS,
1817
_find_compiler_bin)
1918
from cuda.tile._cext import dev_features_enabled
20-
from util import require_blackwell_or_newer, require_hopper_or_newer
19+
from util import (require_blackwell_or_newer, require_hopper_or_newer,
20+
benchmark_cudagraph_runner, benchmark_eager_runner)
2121

2222

2323
def pytest_addoption(parser):
@@ -160,11 +160,27 @@ def uint_dtype(request):
160160
return request.param
161161

162162

163+
def patch_benchmark_fixture(benchmark):
    """Patch a BenchmarkFixture to use a custom runner: eager or cudagraph.

    Replaces ``benchmark._make_runner`` with ``benchmark_eager_runner`` and
    wraps ``benchmark.pedantic`` so that it accepts one extra keyword
    argument, ``cudagraph``. All other positional and keyword arguments are
    forwarded unchanged to the original ``pedantic``.

    Args:
        benchmark: the fixture object to patch in place.
    """
    original_pedantic = benchmark.pedantic
    benchmark._make_runner = benchmark_eager_runner

    def pedantic(*args, cudagraph=False, **kwargs):
        # Pick the runner on every call so that a previous ``cudagraph=True``
        # call cannot leak into a later eager measurement.
        benchmark._make_runner = (benchmark_cudagraph_runner if cudagraph
                                  else benchmark_eager_runner)
        return original_pedantic(*args, **kwargs)

    benchmark.pedantic = pedantic
178+
179+
163180
# ----- For pytest benchmark
164181
@pytest.fixture
def benchmark(benchmark):
    """Return the upstream ``benchmark`` fixture after patching its runner."""
    fixture = benchmark
    patch_benchmark_fixture(fixture)
    return fixture
169185

170186

test/util.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,23 +148,69 @@ def raises_autocast_error(launch, from_ty, to_ty) -> bool:
148148
return False
149149

150150

151-
def estimate_bench_iter(f, tuple_of_args):
151+
def benchmark_cudagraph_runner(f, args, kwargs):
    """Build a runner that times ``f(*args, **kwargs)`` via CUDA graph replay.

    Meant for patching ``BenchmarkFixture._make_runner``.

    Args:
        f: callable to benchmark.
        args: positional arguments passed to ``f``.
        kwargs: keyword arguments passed to ``f``.

    Returns:
        ``runner(loops_range, **unused)``, which returns the summed time of
        one graph replay per element of ``loops_range``, in seconds.
    """
    def runner(loops_range, **unused) -> float:
        # Fail before doing any GPU work if there is nothing to iterate over.
        assert loops_range is not None

        # Run the regular function a few times to ensure kernel and memory
        # states are stable before graph capture.
        for _ in range(3):
            f(*args, **kwargs)

        # CUDA graph capture must happen on a non-default stream.
        current = torch.cuda.current_stream()
        if current == torch.cuda.default_stream():
            stream = torch.cuda.Stream()
            stream.wait_stream(current)
        else:
            stream = current

        with torch.cuda.stream(stream):
            g = torch.cuda.CUDAGraph()
            ev_start = torch.cuda.Event(enable_timing=True, external=True)
            ev_end = torch.cuda.Event(enable_timing=True, external=True)

            # Size the flush buffer from the device it is allocated on, not
            # from a hard-coded device 0.
            device = torch.cuda.current_device()
            l2_size = torch.cuda.get_device_properties(device).L2_cache_size
            cache_flush_tensor = torch.empty(l2_size, dtype=torch.uint8, device="cuda")

            with torch.cuda.graph(g):
                # The buffer write sits before the start event, so it is
                # replayed every iteration but excluded from the timing.
                cache_flush_tensor.zero_()
                ev_start.record()
                f(*args, **kwargs)
                ev_end.record()

        torch.cuda.synchronize()
        total_ms = 0.0
        for _ in loops_range:
            g.replay()
            ev_end.synchronize()
            total_ms += ev_start.elapsed_time(ev_end)
        return total_ms / 1000  # secs
    return runner
188+
189+
190+
def benchmark_eager_runner(f, args, kwargs):
    """Build a runner that times eager calls of ``f(*args, **kwargs)``.

    The returned ``runner(loops_range, **unused)`` calls ``f`` once per
    element of ``loops_range`` between two CUDA timing events and returns
    the elapsed time for the whole loop in seconds.
    """
    def runner(loops_range, **unused) -> float:
        assert loops_range is not None
        torch.cuda.synchronize()
        begin = torch.cuda.Event(enable_timing=True)
        finish = torch.cuda.Event(enable_timing=True)

        begin.record()
        for _ in loops_range:
            f(*args, **kwargs)
        finish.record()

        # Wait for the closing event before reading the elapsed time.
        finish.synchronize()
        elapsed_ms = begin.elapsed_time(finish)
        return elapsed_ms / 1000
    return runner
203+
204+
205+
def estimate_bench_iter(f, tuple_of_args, cudagraph=False):
    """Estimate benchmark settings for ``f(*tuple_of_args)``.

    Times ``warmup_iter_guess`` iterations with the eager or CUDA-graph
    runner and derives how many iterations fill ``min_round_time_ms``.

    Args:
        f: callable to benchmark.
        tuple_of_args: positional arguments for ``f``.
        cudagraph: when true, measure with ``benchmark_cudagraph_runner``;
            otherwise with ``benchmark_eager_runner``.

    Returns:
        ``(warmup_rounds, iterations, rounds)``, with ``iterations`` clamped
        to the range ``[warmup_iter_guess, max_iter]``.
    """
    warmup_iter_guess = 5
    min_round_time_ms = 100
    max_iter = 200
    rounds = 5
    warmup_rounds = 1

    make_runner = benchmark_cudagraph_runner if cudagraph else benchmark_eager_runner
    runner = make_runner(f, tuple_of_args, {})

    # The runners return seconds; convert to milliseconds per iteration.
    time_per_iter = runner(range(warmup_iter_guess)) / warmup_iter_guess
    time_per_iter_ms = time_per_iter * 1000

    if time_per_iter_ms > 0:
        main_iter = min(ceil(min_round_time_ms / time_per_iter_ms), max_iter)
    else:
        # A zero measurement would divide by zero; fall back to the cap.
        main_iter = max_iter
    main_iter = max(main_iter, warmup_iter_guess)
    return warmup_rounds, main_iter, rounds
169215

170216

0 commit comments

Comments
 (0)