Skip to content

Commit d9b4a82

Browse files
committed
Add ops/tilelang/torch.py
1 parent 8fcb667 commit d9b4a82

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

ops/tilelang/torch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
3+
import ops.tilelang.kernels.mm
4+
5+
6+
def mm(input, other):
7+
output_shape = (input.shape[0], other.shape[1])
8+
output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
9+
10+
ops.tilelang.kernels.mm.kernel(input, other, output)
11+
12+
return output

0 commit comments

Comments
 (0)