Skip to content

Commit f87582f

Browse files
committed
Add compare_code_metrics.py
1 parent 29bf2f9 commit f87582f

2 files changed

Lines changed: 196 additions & 57 deletions

File tree

code_size_comparison.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

compare_code_metrics.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import json
2+
import os.path
3+
from pathlib import Path
4+
5+
import pandas as pd
6+
7+
_PARENT_PATH = Path(__file__).parent
8+
9+
_OPS_PATH = _PARENT_PATH / "ops"
10+
11+
_NINETOOTHED_KERNELS_PATH = _OPS_PATH / "ninetoothed" / "kernels"
12+
13+
_TRITON_KERNELS_PATH = _OPS_PATH / "triton" / "kernels"
14+
15+
16+
def _generate_cc_table():
17+
path = _PARENT_PATH / "cc.json"
18+
19+
metric_names = {"complexity": "$G$"}
20+
21+
data = json.loads(path.read_text())
22+
23+
data = {
24+
kernel: {
25+
metric_names["complexity"]: sum(block["complexity"] for block in blocks)
26+
}
27+
for kernel, blocks in data.items()
28+
if "torch" not in kernel
29+
}
30+
31+
df = _generate_table(data, metric_names.values())
32+
33+
styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2)
34+
35+
return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True)
36+
37+
38+
def _generate_mi_table():
39+
path = _PARENT_PATH / "mi.json"
40+
41+
metric_names = {"mi": "$MI$"}
42+
43+
data = json.loads(path.read_text())
44+
45+
data = {
46+
kernel: {
47+
latex_name: metrics[raw_name]
48+
for raw_name, latex_name in metric_names.items()
49+
}
50+
for kernel, metrics in data.items()
51+
if "torch" not in kernel
52+
}
53+
54+
df = _generate_table(data, metric_names.values())
55+
56+
styled_df = df.style.apply(_highlight_maximum, axis=None).format(precision=2)
57+
58+
return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True)
59+
60+
61+
def _generate_raw_table():
62+
path = _PARENT_PATH / "raw.json"
63+
64+
metric_names = {"loc": "LOC", "lloc": "LLOC", "sloc": "SLOC"}
65+
66+
data = json.loads(path.read_text())
67+
68+
data = {
69+
kernel: {
70+
latex_name: metrics[raw_name]
71+
for raw_name, latex_name in metric_names.items()
72+
}
73+
for kernel, metrics in data.items()
74+
if "torch" not in kernel
75+
}
76+
77+
df = _generate_table(data, metric_names.values())
78+
79+
styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2)
80+
81+
return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True)
82+
83+
84+
def _generate_hal_table():
85+
path = _PARENT_PATH / "hal.json"
86+
87+
metric_names = {
88+
"h1": "$\\eta_1$",
89+
"h2": "$\\eta_2$",
90+
"N1": "$N_1$",
91+
"N2": "$N_2$",
92+
"vocabulary": "$\\eta$",
93+
"length": "$N$",
94+
"calculated_length": "$\\hat{N}$",
95+
"volume": "$V$",
96+
"difficulty": "$D$",
97+
"effort": "$E$",
98+
"time": "$T$",
99+
"bugs": "$B$",
100+
}
101+
102+
data = json.loads(path.read_text())
103+
104+
data = {
105+
kernel: {
106+
latex_name: metrics["total"][raw_name]
107+
for raw_name, latex_name in metric_names.items()
108+
}
109+
for kernel, metrics in data.items()
110+
if "torch" not in kernel
111+
}
112+
113+
df = _generate_table(data, metric_names.values())
114+
115+
styled_df = df.style.apply(_highlight_minimum, axis=None).format(precision=2)
116+
117+
return styled_df.to_latex(hrules=True, multicol_align="c", convert_css=True)
118+
119+
120+
def _generate_table(data, metric_names):
121+
kernel_names = sorted(
122+
set(
123+
os.path.splitext(os.path.basename(kernel_name))[0]
124+
for kernel_name in data.keys()
125+
)
126+
)
127+
128+
def _key_from_kernel_name(path, kernel_name):
129+
return str(path / f"{kernel_name}.py").removeprefix(str(_PARENT_PATH))[1:]
130+
131+
data = {
132+
f"\\texttt{{{kernel_name.replace('scaled_dot_product_attention', 'sdpa').replace('_', '\\_')}}}": {
133+
"Triton": {
134+
metric_name: data[
135+
_key_from_kernel_name(_TRITON_KERNELS_PATH, kernel_name)
136+
][metric_name]
137+
for metric_name in metric_names
138+
},
139+
"NineToothed": {
140+
metric_name: data[
141+
_key_from_kernel_name(_NINETOOTHED_KERNELS_PATH, kernel_name)
142+
][metric_name]
143+
for metric_name in metric_names
144+
},
145+
}
146+
for kernel_name in kernel_names
147+
}
148+
149+
df = pd.DataFrame.from_dict(
150+
{
151+
(outer_key, inner_key): value
152+
for outer_key, inner_dict in data.items()
153+
for inner_key, value in inner_dict.items()
154+
},
155+
orient="index",
156+
)
157+
158+
df.index = pd.MultiIndex.from_tuples(df.index)
159+
160+
return df
161+
162+
163+
def _highlight_minimum(df):
164+
styles = pd.DataFrame("", index=df.index, columns=df.columns)
165+
166+
for kernel, group in df.groupby(level=0):
167+
mask = group == group.min()
168+
169+
styles.update(
170+
mask.replace(True, "background-color: green!20").replace(False, "")
171+
)
172+
173+
return styles
174+
175+
176+
def _highlight_maximum(df):
177+
styles = pd.DataFrame("", index=df.index, columns=df.columns)
178+
179+
for kernel, group in df.groupby(level=0):
180+
mask = group == group.max()
181+
182+
styles.update(
183+
mask.replace(True, "background-color: green!20").replace(False, "")
184+
)
185+
186+
return styles
187+
188+
189+
if __name__ == "__main__":
190+
for latex_code in (
191+
_generate_cc_table(),
192+
_generate_mi_table(),
193+
_generate_raw_table(),
194+
_generate_hal_table(),
195+
):
196+
print(latex_code)

0 commit comments

Comments
 (0)