Skip to content

Commit a33a14c

Browse files
committed
Rename ops.triton.torch.triton_conv2d to ops.triton.torch.conv2d
1 parent 236add4 commit a33a14c

2 files changed

Lines changed: 4 additions & 6 deletions

File tree

conv2d.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter)
2020
torch_output = F.conv2d(input, filter)
21-
triton_output = ops.triton.torch.triton_conv2d(input, filter)
21+
triton_output = ops.triton.torch.conv2d(input, filter)
2222

2323
print(ninetoothed_output)
2424
print(torch_output)
@@ -56,7 +56,7 @@ def benchmark(n, provider):
5656

5757
ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter)
5858
torch_output = F.conv2d(input, filter)
59-
triton_output = ops.triton.torch.triton_conv2d(input, filter)
59+
triton_output = ops.triton.torch.conv2d(input, filter)
6060

6161
assert torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01)
6262
assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
@@ -68,9 +68,7 @@ def benchmark(n, provider):
6868
elif provider == "torch":
6969
ms = triton.testing.do_bench(lambda: F.conv2d(input, filter))
7070
elif provider == "triton":
71-
ms = triton.testing.do_bench(
72-
lambda: ops.triton.torch.triton_conv2d(input, filter)
73-
)
71+
ms = triton.testing.do_bench(lambda: ops.triton.torch.conv2d(input, filter))
7472

7573
return ms
7674

ops/triton/torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def grid(meta):
5757
return output
5858

5959

60-
def triton_conv2d(input, filter):
60+
def conv2d(input, filter):
6161
n, c, h, w = input.shape
6262
k, _, r, s = filter.shape
6363
p = h - r + 1

0 commit comments

Comments
 (0)