1515@dataclass
1616class KernelInformation :
1717 name : str
18- memory_bound : bool
19- compute_bound : bool
2018 perf_report_path : str
2119 independent_variable : str
2220
2321
24- @dataclass
25- class CategoryInformation :
26- kernels : tuple
27- y_label : str
28-
29-
3022kernels = (
31- KernelInformation ("add" , True , False , "vector-addition-performance.csv" , "Length" ),
32- KernelInformation (
33- "softmax" , True , False , "softmax-performance.csv" , "Number of Columns"
34- ),
35- KernelInformation (
36- "rms_norm" , True , False , "rms-norm-performance.csv" , "Number of Columns"
37- ),
38- KernelInformation (
39- "matmul" , False , True , "matrix-multiplication-performance.csv" , "Sizes"
40- ),
41- KernelInformation (
42- "conv2d" , False , True , "2d-convolution-performance.csv" , "Batch Size"
43- ),
44- KernelInformation (
45- "attention" , False , True , "attention-performance.csv" , "Sequence Length"
46- ),
23+ KernelInformation ("add" , "vector-addition-performance.csv" , "Length" ),
24+ KernelInformation ("softmax" , "softmax-performance.csv" , "Number of Columns" ),
25+ KernelInformation ("rms_norm" , "rms-norm-performance.csv" , "Number of Columns" ),
26+ KernelInformation ("matmul" , "matrix-multiplication-performance.csv" , "Sizes" ),
27+ KernelInformation ("conv2d" , "2d-convolution-performance.csv" , "Batch Size" ),
28+ KernelInformation ("attention" , "attention-performance.csv" , "Sequence Length" ),
4729)
4830
4931providers = ("Triton" , "NineToothed" )
5032
51- categories = (
52- CategoryInformation (
53- tuple (kernel for kernel in kernels if kernel .memory_bound ), "GB/s"
54- ),
55- CategoryInformation (
56- tuple (kernel for kernel in kernels if kernel .compute_bound ), "TFLOPS"
57- ),
58- )
59-
60- num_rows = len (categories )
61- num_cols = max (len (category .kernels ) for category in categories )
33+ num_rows = 2
34+ num_cols = 3
6235
6336fig , axs = plt .subplots (num_rows , num_cols )
6437
65- performance_differences = []
66-
67- for row , category in enumerate (categories ):
68- axs [row , 0 ].set_ylabel (category .y_label )
38+ performance_changes = []
6939
70- for col , kernel in enumerate (category . kernels ):
71- df = pd .read_csv (kernel .perf_report_path )
72- ax = axs [row , col ]
40+ for i , kernel in enumerate (kernels ):
41+ df = pd .read_csv (kernel .perf_report_path )
42+ ax = axs [i // num_cols , i % num_cols ]
7343
74- x = df .iloc [:, 0 ]
44+ x = df .iloc [:, 0 ]
7545
76- performance_differences .append ((kernel , []))
46+ performance_changes .append ((kernel , []))
7747
78- for provider in providers :
79- y = df [provider ]
48+ for provider in providers :
49+ y = df [provider ]
8050
81- ax .plot (x , y , label = provider )
51+ ax .plot (x , y , label = provider )
8252
83- if provider == "NineToothed" :
84- y_triton = df ["Triton" ]
85- diff = (y - y_triton ) / y_triton * 100
86- performance_differences [- 1 ][- 1 ].append (diff )
53+ if provider == "NineToothed" :
54+ y_triton = df ["Triton" ]
55+ change = (y - y_triton ) / y_triton * 100
56+ performance_changes [- 1 ][- 1 ].append (change )
8757
88- ax .set_title (kernel .name )
89- ax .set_xlabel (kernel .independent_variable )
90- ax .set_xscale ("log" , base = 2 )
58+ ax .set_title (kernel .name )
59+ ax .set_xlabel (kernel .independent_variable )
60+ ax .set_ylabel ("Execution Time (ms)" )
61+ ax .set_xscale ("log" , base = 2 )
9162
9263fig .legend (providers , loc = "upper center" , ncols = len (providers ))
9364fig .tight_layout ()
@@ -96,24 +67,24 @@ class CategoryInformation:
9667plt .show ()
9768plt .savefig ("performance-comparison.png" )
9869
99- all_differences = []
70+ all_changes = []
10071stats_data = []
10172
102- for kernel , diffs in performance_differences :
103- all_differences .extend (diffs )
73+ for kernel , changes in performance_changes :
74+ all_changes .extend (changes )
10475
10576 kernel_stats = {
10677 "Kernel" : kernel .name ,
107- "Mean" : np .mean (diffs ),
108- "Median" : np .median (diffs ),
78+ "Mean" : np .mean (changes ),
79+ "Median" : np .median (changes ),
10980 }
11081
11182 stats_data .append (kernel_stats )
11283
11384overall_stats = {
11485 "Kernel" : "Overall" ,
115- "Mean" : np .mean (all_differences ),
116- "Median" : np .median (all_differences ),
86+ "Mean" : np .mean (all_changes ),
87+ "Median" : np .median (all_changes ),
11788}
11889
11990stats_data .append (overall_stats )
0 commit comments