|
23 | 23 |
|
# Command-line options for the benchmark run.
parser = argparse.ArgumentParser()
parser.add_argument("--ncu", action="store_true", help="NCU mode: single iteration, no timing")
parser.add_argument("--graph", action="store_true", help="Use CUDA graph replay for accurate kernel timing")
parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=100, help="Timed iterations per trial")
parser.add_argument("--trials", type=int, default=5, help="Number of trials for stddev (graph mode)")
args = parser.parse_args()
29 | 31 |
|
# Benchmark shapes as (label, K_dim, N); K_dim and N are forwarded straight to the
# GEMV ops below — presumably reduction dim and output dim of the weight. TODO confirm.
SHAPES = [
    ("gateup", 2048, 5120),
    ("down", 5120, 2048),
    ("Q", 2048, 4096),
    ("O", 4096, 2048),
    ("KV", 2048, 512),
]
# Bit-widths `k` to sweep (passed to the kbit kernels).
K_VALUES = [2, 3, 4, 5]
# M values to sweep — presumably the batch/row count of A; confirm in the cut setup code.
M_VALUES = [1, 2, 4]
38 | 41 |
|
39 | | -print(f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2} {'flat_us':>8} {'tiled_us':>8} {'diff%':>7}") |
40 | | -print("-" * 60) |
# Emit the results-table header; graph mode adds ±stddev columns.
_hdr = f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2}"
if args.graph:
    print(_hdr + f" {'flat_us':>8} {'±flat':>6} {'tiled_us':>8} {'±tiled':>6} {'diff%':>7}")
    print("-" * 76)
else:
    print(_hdr + f" {'flat_us':>8} {'tiled_us':>8} {'diff%':>7}")
    print("-" * 60)
41 | 48 |
|
42 | 49 | for name, K_dim, N in SHAPES: |
43 | 50 | for k in K_VALUES: |
|
68 | 75 | print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {'ncu':>8} {'ncu':>8} {'ncu':>7}") |
69 | 76 | continue |
70 | 77 |
|
71 | | - # CUDA events timing |
            # CUDA events reused by the timing paths below.
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
74 | 80 |
|
75 | | - # --- Flat --- |
76 | | - for _ in range(args.warmup): |
            def call_flat():
                # One launch of the flat-layout kbit scalar GEMV, writing into out_flat.
                torch.ops.bitsandbytes.kbit_scalar_gemv.out(
                    A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat
                )
80 | | - torch.cuda.synchronize() |
81 | 85 |
|
82 | | - start.record() |
83 | | - for _ in range(args.iters): |
84 | | - torch.ops.bitsandbytes.kbit_scalar_gemv.out( |
85 | | - A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat |
86 | | - ) |
87 | | - end.record() |
88 | | - torch.cuda.synchronize() |
89 | | - flat_us = start.elapsed_time(end) * 1000 / args.iters # ms -> us |
90 | | - |
91 | | - # --- Tiled --- |
92 | | - for _ in range(args.warmup): |
            def call_tiled():
                # One launch of the tiled-layout kbit scalar GEMV, writing into out_tiled
                # (trailing underscore suggests an in-place/out-writing op — matches the
                # out_tiled argument).
                torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
                    A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
                )
96 | | - torch.cuda.synchronize() |
97 | 90 |
|
98 | | - start.record() |
99 | | - for _ in range(args.iters): |
100 | | - torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_( |
101 | | - A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled |
102 | | - ) |
103 | | - end.record() |
104 | | - torch.cuda.synchronize() |
105 | | - tiled_us = start.elapsed_time(end) * 1000 / args.iters |
            if args.graph:
                import statistics

                # CUDA graph replay — measures kernel-only time.
                # NOTE(review): warmup here is fixed at 3 launches per kernel and
                # ignores --warmup; confirm that is intentional for capture priming.
                for fn in (call_flat, call_tiled):
                    for _ in range(3):
                        fn()
                torch.cuda.synchronize()
| 100 | + def bench_graph(fn, trials, iters): |
| 101 | + s = torch.cuda.Stream() |
| 102 | + s.wait_stream(torch.cuda.current_stream()) |
| 103 | + with torch.cuda.stream(s): |
| 104 | + g = torch.cuda.CUDAGraph() |
| 105 | + with torch.cuda.graph(g, stream=s): |
| 106 | + fn() |
| 107 | + torch.cuda.synchronize() |
| 108 | + times = [] |
| 109 | + for _ in range(trials): |
| 110 | + start.record() |
| 111 | + for _ in range(iters): |
| 112 | + g.replay() |
| 113 | + end.record() |
| 114 | + torch.cuda.synchronize() |
| 115 | + times.append(start.elapsed_time(end) * 1000 / iters) |
| 116 | + return statistics.mean(times), statistics.stdev(times) if len(times) > 1 else 0.0 |
| 117 | + |
| 118 | + flat_us, flat_std = bench_graph(call_flat, args.trials, args.iters) |
| 119 | + tiled_us, tiled_std = bench_graph(call_tiled, args.trials, args.iters) |
| 120 | + else: |
| 121 | + # CUDA events timing (includes Python dispatch overhead) |
| 122 | + for _ in range(args.warmup): |
| 123 | + call_flat() |
| 124 | + torch.cuda.synchronize() |
| 125 | + start.record() |
| 126 | + for _ in range(args.iters): |
| 127 | + call_flat() |
| 128 | + end.record() |
| 129 | + torch.cuda.synchronize() |
| 130 | + flat_us = start.elapsed_time(end) * 1000 / args.iters |
| 131 | + |
| 132 | + for _ in range(args.warmup): |
| 133 | + call_tiled() |
| 134 | + torch.cuda.synchronize() |
| 135 | + start.record() |
| 136 | + for _ in range(args.iters): |
| 137 | + call_tiled() |
| 138 | + end.record() |
| 139 | + torch.cuda.synchronize() |
| 140 | + tiled_us = start.elapsed_time(end) * 1000 / args.iters |
106 | 141 |
|
107 | 142 | diff_pct = (tiled_us - flat_us) / flat_us * 100 |
108 | | - print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {flat_us:>8.1f} {tiled_us:>8.1f} {diff_pct:>+7.1f}%") |
| 143 | + if args.graph: |
| 144 | + print( |
| 145 | + f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2}" |
| 146 | + f" {flat_us:>8.1f} {flat_std:>5.1f}σ {tiled_us:>8.1f} {tiled_std:>5.1f}σ {diff_pct:>+7.1f}%" |
| 147 | + ) |
| 148 | + else: |
| 149 | + print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {flat_us:>8.1f} {tiled_us:>8.1f} {diff_pct:>+7.1f}%") |
109 | 150 |
|
110 | 151 | # Correctness check (once per shape/k) |
111 | 152 | assert torch.equal(out_flat, out_tiled) or torch.allclose(out_flat, out_tiled, rtol=0.05, atol=0.1), ( |
|
0 commit comments