File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -126,8 +126,8 @@ def torch_swiglu(
126126
127127 @triton .testing .perf_report (
128128 triton .testing .Benchmark (
129- x_names = ["size " ],
130- x_vals = [2 ** i for i in range (12 , 28 , 1 )],
129+ x_names = ["m" , "n " ],
130+ x_vals = [128 * i for i in range (2 , 50 )],
131131 x_log = True ,
132132 line_arg = "provider" ,
133133 line_vals = ["ninetoothed" , "torch" , "triton" ],
@@ -138,11 +138,13 @@ def torch_swiglu(
138138 args = {},
139139 )
140140 )
141- def benchmark (size , provider ):
141+ def benchmark (m , n , provider ):
142+ shape = (m , n )
142143 dtype = torch .float16
143144 device = "cuda"
144- a = torch .rand (size , dtype = dtype , device = device )
145- b = torch .rand (size , dtype = dtype , device = device )
145+
146+ a = torch .rand (shape , dtype = dtype , device = device )
147+ b = torch .rand (shape , dtype = dtype , device = device )
146148 quantiles = [0.5 , 0.2 , 0.8 ]
147149
148150 if provider == "ninetoothed" :
You can’t perform that action at this time.
0 commit comments