1818
1919 ninetoothed_output = ops .ninetoothed .torch .conv2d (input , filter )
2020 torch_output = F .conv2d (input , filter )
21- triton_output = ops .triton .torch .triton_conv2d (input , filter )
21+ triton_output = ops .triton .torch .conv2d (input , filter )
2222
2323 print (ninetoothed_output )
2424 print (torch_output )
@@ -56,7 +56,7 @@ def benchmark(n, provider):
5656
5757 ninetoothed_output = ops .ninetoothed .torch .conv2d (input , filter )
5858 torch_output = F .conv2d (input , filter )
59- triton_output = ops .triton .torch .triton_conv2d (input , filter )
59+ triton_output = ops .triton .torch .conv2d (input , filter )
6060
6161 assert torch .allclose (ninetoothed_output , torch_output , atol = 0.01 , rtol = 0.01 )
6262 assert torch .allclose (ninetoothed_output , triton_output , atol = 0 , rtol = 0 )
@@ -68,9 +68,7 @@ def benchmark(n, provider):
6868 elif provider == "torch" :
6969 ms = triton .testing .do_bench (lambda : F .conv2d (input , filter ))
7070 elif provider == "triton" :
71- ms = triton .testing .do_bench (
72- lambda : ops .triton .torch .triton_conv2d (input , filter )
73- )
71+ ms = triton .testing .do_bench (lambda : ops .triton .torch .conv2d (input , filter ))
7472
7573 return ms
7674
0 commit comments