Skip to content

Commit 0dfaeb0

Browse files
committed
Use time as the metric for measuring performance
1 parent aafa34a commit 0dfaeb0

7 files changed

Lines changed: 50 additions & 117 deletions

File tree

add.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,14 @@ def grid(meta):
8787
line_vals=["ninetoothed", "torch", "triton"],
8888
line_names=["NineToothed", "PyTorch", "Triton"],
8989
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
90-
ylabel="GB/s",
90+
ylabel="ms",
9191
plot_name="vector-addition-performance",
9292
args={},
9393
)
9494
)
9595
def benchmark(size, provider):
9696
lhs = torch.randn(size, device="cuda", dtype=torch.float16)
9797
rhs = torch.randn(size, device="cuda", dtype=torch.float16)
98-
quantiles = [0.5, 0.2, 0.8]
9998

10099
ninetoothed_output = add(lhs, rhs)
101100
torch_output = lhs + rhs
@@ -104,21 +103,12 @@ def benchmark(size, provider):
104103
assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
105104

106105
if provider == "ninetoothed":
107-
ms, min_ms, max_ms = triton.testing.do_bench(
108-
lambda: add(lhs, rhs), quantiles=quantiles
109-
)
106+
ms = triton.testing.do_bench(lambda: add(lhs, rhs))
110107
elif provider == "torch":
111-
ms, min_ms, max_ms = triton.testing.do_bench(
112-
lambda: lhs + rhs, quantiles=quantiles
113-
)
108+
ms = triton.testing.do_bench(lambda: lhs + rhs)
114109
elif provider == "triton":
115-
ms, min_ms, max_ms = triton.testing.do_bench(
116-
lambda: triton_add(lhs, rhs), quantiles=quantiles
117-
)
110+
ms = triton.testing.do_bench(lambda: triton_add(lhs, rhs))
118111

119-
def gbps(ms):
120-
return 3 * lhs.numel() * lhs.element_size() / ms * 1e-6
121-
122-
return gbps(ms), gbps(max_ms), gbps(min_ms)
112+
return ms
123113

124114
benchmark.run(print_data=True, show_plots=True, save_path=".")

attention.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def grid(meta):
230230
line_vals=["ninetoothed", "torch", "triton"],
231231
line_names=["NineToothed", "PyTorch", "Triton"],
232232
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
233-
ylabel="TFLOPS",
233+
ylabel="ms",
234234
plot_name="attention-performance",
235235
args={},
236236
)
@@ -258,12 +258,6 @@ def benchmark(seq_len, provider):
258258
elif provider == "triton":
259259
ms = triton.testing.do_bench(lambda: triton_attention(q, k, v))
260260

261-
def perf(ms):
262-
flops_per_matmul = 2 * batch_size * num_heads * seq_len * seq_len * emb_dim
263-
total_flops = 2 * flops_per_matmul
264-
265-
return total_flops * 1e-12 / (ms * 1e-3)
266-
267-
return perf(ms)
261+
return ms
268262

269263
benchmark.run(show_plots=True, print_data=True, save_path=".")

conv2d.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def grid(meta):
222222
line_vals=["ninetoothed", "torch", "triton"],
223223
line_names=["NineToothed", "PyTorch", "Triton"],
224224
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
225-
ylabel="TFLOPS",
225+
ylabel="ms",
226226
plot_name="2d-convolution-performance",
227227
args={},
228228
)
@@ -247,12 +247,6 @@ def benchmark(n, provider):
247247
elif provider == "triton":
248248
ms = triton.testing.do_bench(lambda: triton_conv2d(input, filter))
249249

250-
def perf(ms):
251-
p = h - r + 1
252-
q = w - s + 1
253-
254-
return 2 * n * k * p * q * c * r * s * 1e-12 / (ms * 1e-3)
255-
256-
return perf(ms)
250+
return ms
257251

258252
benchmark.run(show_plots=True, print_data=True, save_path=".")

matmul.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,14 @@ def grid(meta):
173173
line_vals=["ninetoothed", "torch", "triton"],
174174
line_names=["NineToothed", "PyTorch", "Triton"],
175175
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
176-
ylabel="TFLOPS",
176+
ylabel="ms",
177177
plot_name="matrix-multiplication-performance",
178178
args={},
179179
)
180180
)
181181
def benchmark(m, n, k, provider):
182182
lhs = torch.randn((m, k), device="cuda", dtype=torch.float16)
183183
rhs = torch.randn((k, n), device="cuda", dtype=torch.float16)
184-
quantiles = [0.5, 0.2, 0.8]
185184

186185
ninetoothed_output = matmul(lhs, rhs)
187186
torch_output = torch.matmul(lhs, rhs)
@@ -190,21 +189,12 @@ def benchmark(m, n, k, provider):
190189
assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
191190

192191
if provider == "ninetoothed":
193-
ms, min_ms, max_ms = triton.testing.do_bench(
194-
lambda: matmul(lhs, rhs), quantiles=quantiles
195-
)
192+
ms = triton.testing.do_bench(lambda: matmul(lhs, rhs))
196193
elif provider == "torch":
197-
ms, min_ms, max_ms = triton.testing.do_bench(
198-
lambda: torch.matmul(lhs, rhs), quantiles=quantiles
199-
)
194+
ms = triton.testing.do_bench(lambda: torch.matmul(lhs, rhs))
200195
elif provider == "triton":
201-
ms, min_ms, max_ms = triton.testing.do_bench(
202-
lambda: triton_matmul(lhs, rhs), quantiles=quantiles
203-
)
196+
ms = triton.testing.do_bench(lambda: triton_matmul(lhs, rhs))
204197

205-
def perf(ms):
206-
return 2 * m * n * k * 1e-12 / (ms * 1e-3)
207-
208-
return perf(ms), perf(max_ms), perf(min_ms)
198+
return ms
209199

210200
benchmark.run(show_plots=True, print_data=True, save_path=".")

performance_comparison.py

Lines changed: 32 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,79 +15,50 @@
1515
@dataclass
1616
class KernelInformation:
1717
name: str
18-
memory_bound: bool
19-
compute_bound: bool
2018
perf_report_path: str
2119
independent_variable: str
2220

2321

24-
@dataclass
25-
class CategoryInformation:
26-
kernels: tuple
27-
y_label: str
28-
29-
3022
kernels = (
31-
KernelInformation("add", True, False, "vector-addition-performance.csv", "Length"),
32-
KernelInformation(
33-
"softmax", True, False, "softmax-performance.csv", "Number of Columns"
34-
),
35-
KernelInformation(
36-
"rms_norm", True, False, "rms-norm-performance.csv", "Number of Columns"
37-
),
38-
KernelInformation(
39-
"matmul", False, True, "matrix-multiplication-performance.csv", "Sizes"
40-
),
41-
KernelInformation(
42-
"conv2d", False, True, "2d-convolution-performance.csv", "Batch Size"
43-
),
44-
KernelInformation(
45-
"attention", False, True, "attention-performance.csv", "Sequence Length"
46-
),
23+
KernelInformation("add", "vector-addition-performance.csv", "Length"),
24+
KernelInformation("softmax", "softmax-performance.csv", "Number of Columns"),
25+
KernelInformation("rms_norm", "rms-norm-performance.csv", "Number of Columns"),
26+
KernelInformation("matmul", "matrix-multiplication-performance.csv", "Sizes"),
27+
KernelInformation("conv2d", "2d-convolution-performance.csv", "Batch Size"),
28+
KernelInformation("attention", "attention-performance.csv", "Sequence Length"),
4729
)
4830

4931
providers = ("Triton", "NineToothed")
5032

51-
categories = (
52-
CategoryInformation(
53-
tuple(kernel for kernel in kernels if kernel.memory_bound), "GB/s"
54-
),
55-
CategoryInformation(
56-
tuple(kernel for kernel in kernels if kernel.compute_bound), "TFLOPS"
57-
),
58-
)
59-
60-
num_rows = len(categories)
61-
num_cols = max(len(category.kernels) for category in categories)
33+
num_rows = 2
34+
num_cols = 3
6235

6336
fig, axs = plt.subplots(num_rows, num_cols)
6437

65-
performance_differences = []
66-
67-
for row, category in enumerate(categories):
68-
axs[row, 0].set_ylabel(category.y_label)
38+
performance_changes = []
6939

70-
for col, kernel in enumerate(category.kernels):
71-
df = pd.read_csv(kernel.perf_report_path)
72-
ax = axs[row, col]
40+
for i, kernel in enumerate(kernels):
41+
df = pd.read_csv(kernel.perf_report_path)
42+
ax = axs[i // num_cols, i % num_cols]
7343

74-
x = df.iloc[:, 0]
44+
x = df.iloc[:, 0]
7545

76-
performance_differences.append((kernel, []))
46+
performance_changes.append((kernel, []))
7747

78-
for provider in providers:
79-
y = df[provider]
48+
for provider in providers:
49+
y = df[provider]
8050

81-
ax.plot(x, y, label=provider)
51+
ax.plot(x, y, label=provider)
8252

83-
if provider == "NineToothed":
84-
y_triton = df["Triton"]
85-
diff = (y - y_triton) / y_triton * 100
86-
performance_differences[-1][-1].append(diff)
53+
if provider == "NineToothed":
54+
y_triton = df["Triton"]
55+
change = (y - y_triton) / y_triton * 100
56+
performance_changes[-1][-1].append(change)
8757

88-
ax.set_title(kernel.name)
89-
ax.set_xlabel(kernel.independent_variable)
90-
ax.set_xscale("log", base=2)
58+
ax.set_title(kernel.name)
59+
ax.set_xlabel(kernel.independent_variable)
60+
ax.set_ylabel("Execution Time (ms)")
61+
ax.set_xscale("log", base=2)
9162

9263
fig.legend(providers, loc="upper center", ncols=len(providers))
9364
fig.tight_layout()
@@ -96,24 +67,24 @@ class CategoryInformation:
9667
plt.show()
9768
plt.savefig("performance-comparison.png")
9869

99-
all_differences = []
70+
all_changes = []
10071
stats_data = []
10172

102-
for kernel, diffs in performance_differences:
103-
all_differences.extend(diffs)
73+
for kernel, changes in performance_changes:
74+
all_changes.extend(changes)
10475

10576
kernel_stats = {
10677
"Kernel": kernel.name,
107-
"Mean": np.mean(diffs),
108-
"Median": np.median(diffs),
78+
"Mean": np.mean(changes),
79+
"Median": np.median(changes),
10980
}
11081

11182
stats_data.append(kernel_stats)
11283

11384
overall_stats = {
11485
"Kernel": "Overall",
115-
"Mean": np.mean(all_differences),
116-
"Median": np.median(all_differences),
86+
"Mean": np.mean(all_changes),
87+
"Median": np.median(all_changes),
11788
}
11889

11990
stats_data.append(overall_stats)

rms_norm.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def triton_rms_norm(input, eps=1e-5):
9595
line_vals=["ninetoothed", "torch", "triton"],
9696
line_names=["NineToothed", "PyTorch", "Triton"],
9797
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
98-
ylabel="GB/s",
98+
ylabel="ms",
9999
plot_name="rms-norm-performance",
100100
args={"m": 4096},
101101
)
@@ -118,9 +118,6 @@ def benchmark(m, n, provider):
118118
elif provider == "triton":
119119
ms = triton.testing.do_bench(lambda: triton_rms_norm(input))
120120

121-
def gbps(ms):
122-
return 2 * input.numel() * input.element_size() * 1e-6 / ms
123-
124-
return gbps(ms)
121+
return ms
125122

126123
benchmark.run(show_plots=True, print_data=True, save_path=".")

softmax.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def triton_softmax(input):
9696
line_vals=["ninetoothed", "torch", "triton"],
9797
line_names=["NineToothed", "PyTorch", "Triton"],
9898
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
99-
ylabel="GB/s",
99+
ylabel="ms",
100100
plot_name="softmax-performance",
101101
args={"m": 4096},
102102
)
@@ -117,9 +117,6 @@ def benchmark(m, n, provider):
117117
elif provider == "triton":
118118
ms = triton.testing.do_bench(lambda: triton_softmax(input))
119119

120-
def gbps(ms):
121-
return 2 * input.numel() * input.element_size() * 1e-6 / ms
122-
123-
return gbps(ms)
120+
return ms
124121

125122
benchmark.run(show_plots=True, print_data=True, save_path=".")

0 commit comments

Comments
 (0)