Skip to content

Commit 0f8477c

Browse files
committed
Run tasks in run_experiments.py instead of compare_performance_metrics.py
1 parent 123bd37 commit 0f8477c

2 files changed

Lines changed: 112 additions & 109 deletions

File tree

compare_performance_metrics.py

Lines changed: 3 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,125 +1,19 @@
1-
import functools
2-
import random
3-
41
import matplotlib.pyplot as plt
52
import pandas as pd
6-
import torch
7-
import torch.nn.functional
8-
import triton
93

10-
import ops.ninetoothed.torch
11-
import ops.triton.torch
12-
import rotary_position_embedding
134
from compare_code_metrics import _BACKSLASH_CHAR
145

15-
16-
def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes):
17-
ninetoothed_op = getattr(ops.ninetoothed.torch, op_name)
18-
triton_op = getattr(ops.triton.torch, op_name)
19-
20-
if op_name == "rotary_position_embedding":
21-
torch_op = rotary_position_embedding.torch_rotary_position_embedding
22-
else:
23-
torch_op = (
24-
getattr(torch, op_name)
25-
if hasattr(torch, op_name)
26-
else getattr(torch.nn.functional, op_name)
27-
)
28-
29-
if op_name == "rms_norm":
30-
torch_op = functools.partial(torch_op, normalized_shape=arg_shapes[0][-1:])
31-
elif op_name == "softmax":
32-
torch_op = functools.partial(torch_op, dim=-1)
33-
34-
args = tuple(
35-
torch.randn(shape, dtype=dtype, device=device) if shape else random.gauss(0, 1)
36-
for shape in arg_shapes
37-
)
38-
kwargs = {
39-
key: torch.randn(shape, dtype=dtype, device=device)
40-
if shape
41-
else random.gauss(0, 1)
42-
for key, shape in kwarg_shapes.items()
43-
}
44-
45-
arg_shape_string = ", ".join(str(shape) for shape in arg_shapes)
46-
kwarg_shape_string = ", ".join(
47-
f"{key}={shape}" for key, shape in kwarg_shapes.items()
48-
)
49-
shape_string = (
50-
f"{arg_shape_string}, {kwarg_shape_string}"
51-
if kwarg_shape_string
52-
else arg_shape_string
53-
)
54-
55-
task_description = f"{op_name}({shape_string})"
56-
57-
return task_description, _benchmark_ops(
58-
(ninetoothed_op, triton_op, torch_op), *args, **kwargs
59-
)
60-
61-
62-
def _benchmark_ops(ops, *args, **kwargs):
63-
assert all(
64-
torch.allclose(
65-
op(*args, **kwargs), ops[0](*args, **kwargs), rtol=0.01, atol=0.01
66-
)
67-
for op in ops[1:]
68-
)
69-
70-
return tuple(triton.testing.do_bench(lambda: op(*args, **kwargs)) for op in ops)
71-
72-
736
if __name__ == "__main__":
74-
random.seed(0)
75-
torch.manual_seed(0)
76-
777
plt.rcParams["figure.dpi"] = 600
788
plt.rcParams["font.family"] = "Linux Biolinum"
799

80-
dtype = torch.float16
81-
device = "cuda"
82-
83-
tasks = (
84-
("add", ((4096 * 4096,), (4096 * 4096,)), {}),
85-
(
86-
"addmm",
87-
((4096, 4096), (4096, 4096), (4096, 4096)),
88-
{"beta": (), "alpha": ()},
89-
),
90-
("bmm", ((4, 2048, 2048), (4, 2048, 2048)), {}),
91-
("conv2d", ((4, 512, 14, 14), (512, 512, 3, 3)), {}),
92-
("mm", ((4096, 4096), (4096, 4096)), {}),
93-
("rms_norm", ((4096, 4096),), {}),
94-
("rotary_position_embedding", ((4, 1024, 48, 64), (1024, 32), (1024, 32)), {}),
95-
(
96-
"scaled_dot_product_attention",
97-
((4, 48, 1024, 64), (4, 48, 1024, 64), (4, 48, 1024, 64)),
98-
{},
99-
),
100-
("silu", ((4096 * 4096,),), {}),
101-
("softmax", ((4096, 4096),), {}),
102-
)
103-
104-
data = {"Task": [], "NineToothed": [], "Triton": [], "PyTorch": []}
105-
106-
for name, args, kwargs in tasks:
107-
description, results = _run_task(name, dtype, device, *args, **kwargs)
108-
109-
latex_item = f"\item {_BACKSLASH_CHAR}texttt{{{description.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}"
10+
df = pd.read_csv("performance-metrics.csv")
11011

12+
for task in df["Task"]:
13+
latex_item = f"\item {_BACKSLASH_CHAR}texttt{{{task.replace('scaled_dot_product_attention', 'sdpa').replace('rotary_position_embedding', 'rope').replace('_', f'{_BACKSLASH_CHAR}_')}}}"
11114
print(latex_item)
11215

113-
data["Task"].append(description)
114-
115-
for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")):
116-
data[provider].append(results[i])
117-
118-
df = pd.DataFrame(data)
11916
df.index += 1
120-
121-
df.set_index("Task").to_csv("performance-metrics.csv")
122-
12317
df.plot(kind="bar", rot=0)
12418
plt.ylabel("Execution Time (ms)")
12519
plt.xlabel("Task")

run_experiments.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
import argparse
2+
import functools
3+
import random
24
import subprocess
35

6+
import pandas as pd
7+
import torch
8+
import torch.nn.functional
9+
import triton
10+
11+
import ops.ninetoothed.torch
12+
import ops.triton.torch
13+
import rotary_position_embedding
14+
415
PROMPTS = (
516
"The emergence of deep learning domain-specific languages (DSLs) has substantially reduced the obstacles in developing high-performance, cross-platform compute kernels, but current DSLs",
617
"Driven by recent advancements in the AI industry, the AI accelerator sector has increasingly diversified, with vendors developing their own hardware architectures and programming models, such as NVIDIA",
@@ -15,6 +26,63 @@
1526
ALL_MAX_NEW_TOKENS = (128, 512, 2048)
1627

1728

29+
def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes):
30+
ninetoothed_op = getattr(ops.ninetoothed.torch, op_name)
31+
triton_op = getattr(ops.triton.torch, op_name)
32+
33+
if op_name == "rotary_position_embedding":
34+
torch_op = rotary_position_embedding.torch_rotary_position_embedding
35+
else:
36+
torch_op = (
37+
getattr(torch, op_name)
38+
if hasattr(torch, op_name)
39+
else getattr(torch.nn.functional, op_name)
40+
)
41+
42+
if op_name == "rms_norm":
43+
torch_op = functools.partial(torch_op, normalized_shape=arg_shapes[0][-1:])
44+
elif op_name == "softmax":
45+
torch_op = functools.partial(torch_op, dim=-1)
46+
47+
args = tuple(
48+
torch.randn(shape, dtype=dtype, device=device) if shape else random.gauss(0, 1)
49+
for shape in arg_shapes
50+
)
51+
kwargs = {
52+
key: torch.randn(shape, dtype=dtype, device=device)
53+
if shape
54+
else random.gauss(0, 1)
55+
for key, shape in kwarg_shapes.items()
56+
}
57+
58+
arg_shape_string = ", ".join(str(shape) for shape in arg_shapes)
59+
kwarg_shape_string = ", ".join(
60+
f"{key}={shape}" for key, shape in kwarg_shapes.items()
61+
)
62+
shape_string = (
63+
f"{arg_shape_string}, {kwarg_shape_string}"
64+
if kwarg_shape_string
65+
else arg_shape_string
66+
)
67+
68+
task_description = f"{op_name}({shape_string})"
69+
70+
return task_description, _benchmark_ops(
71+
(ninetoothed_op, triton_op, torch_op), *args, **kwargs
72+
)
73+
74+
75+
def _benchmark_ops(ops, *args, **kwargs):
76+
assert all(
77+
torch.allclose(
78+
op(*args, **kwargs), ops[0](*args, **kwargs), rtol=0.01, atol=0.01
79+
)
80+
for op in ops[1:]
81+
)
82+
83+
return tuple(triton.testing.do_bench(lambda: op(*args, **kwargs)) for op in ops)
84+
85+
1886
if __name__ == "__main__":
1987
parser = argparse.ArgumentParser(description="Run experiments.")
2088

@@ -29,6 +97,9 @@
2997

3098
model_name_or_path = args.model
3199

100+
random.seed(0)
101+
torch.manual_seed(0)
102+
32103
radon_commands = (
33104
(
34105
"radon",
@@ -50,6 +121,44 @@
50121
with open("code_metrics.tex", "w") as f:
51122
subprocess.run(("python", "compare_code_metrics.py"), stdout=f, check=True)
52123

124+
dtype = torch.float16
125+
device = "cuda"
126+
127+
tasks = (
128+
("add", ((4096 * 4096,), (4096 * 4096,)), {}),
129+
(
130+
"addmm",
131+
((4096, 4096), (4096, 4096), (4096, 4096)),
132+
{"beta": (), "alpha": ()},
133+
),
134+
("bmm", ((4, 2048, 2048), (4, 2048, 2048)), {}),
135+
("conv2d", ((4, 512, 14, 14), (512, 512, 3, 3)), {}),
136+
("mm", ((4096, 4096), (4096, 4096)), {}),
137+
("rms_norm", ((4096, 4096),), {}),
138+
("rotary_position_embedding", ((4, 1024, 48, 64), (1024, 32), (1024, 32)), {}),
139+
(
140+
"scaled_dot_product_attention",
141+
((4, 48, 1024, 64), (4, 48, 1024, 64), (4, 48, 1024, 64)),
142+
{},
143+
),
144+
("silu", ((4096 * 4096,),), {}),
145+
("softmax", ((4096, 4096),), {}),
146+
)
147+
148+
data = {"Task": [], "NineToothed": [], "Triton": [], "PyTorch": []}
149+
150+
for name, args, kwargs in tasks:
151+
description, results = _run_task(name, dtype, device, *args, **kwargs)
152+
153+
data["Task"].append(description)
154+
155+
for i, provider in enumerate(("NineToothed", "Triton", "PyTorch")):
156+
data[provider].append(results[i])
157+
158+
df = pd.DataFrame(data)
159+
160+
df.set_index("Task").to_csv("performance-metrics.csv")
161+
53162
for max_new_tokens in ALL_MAX_NEW_TOKENS:
54163
for backend in BACKENDS:
55164
with open(f"infer_{max_new_tokens}_{backend}.json", "w") as f:

0 commit comments

Comments
 (0)