Skip to content

Commit db86285

Browse files
committed
Improve the Triton conv2d implementation
1 parent fc45cea commit db86285

1 file changed

Lines changed: 2 additions & 17 deletions

File tree

ops/triton/kernels/conv2d.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
"BLOCK_SIZE_M": 128,
1010
"BLOCK_SIZE_N": 256,
1111
"BLOCK_SIZE_K": 64,
12-
"GROUP_SIZE_M": 8,
1312
},
1413
num_stages=3,
1514
num_warps=8,
@@ -19,7 +18,6 @@
1918
"BLOCK_SIZE_M": 64,
2019
"BLOCK_SIZE_N": 256,
2120
"BLOCK_SIZE_K": 32,
22-
"GROUP_SIZE_M": 8,
2321
},
2422
num_stages=4,
2523
num_warps=4,
@@ -29,7 +27,6 @@
2927
"BLOCK_SIZE_M": 128,
3028
"BLOCK_SIZE_N": 128,
3129
"BLOCK_SIZE_K": 32,
32-
"GROUP_SIZE_M": 8,
3330
},
3431
num_stages=4,
3532
num_warps=4,
@@ -39,7 +36,6 @@
3936
"BLOCK_SIZE_M": 128,
4037
"BLOCK_SIZE_N": 64,
4138
"BLOCK_SIZE_K": 32,
42-
"GROUP_SIZE_M": 8,
4339
},
4440
num_stages=4,
4541
num_warps=4,
@@ -49,7 +45,6 @@
4945
"BLOCK_SIZE_M": 64,
5046
"BLOCK_SIZE_N": 128,
5147
"BLOCK_SIZE_K": 32,
52-
"GROUP_SIZE_M": 8,
5348
},
5449
num_stages=4,
5550
num_warps=4,
@@ -59,7 +54,6 @@
5954
"BLOCK_SIZE_M": 128,
6055
"BLOCK_SIZE_N": 32,
6156
"BLOCK_SIZE_K": 32,
62-
"GROUP_SIZE_M": 8,
6357
},
6458
num_stages=4,
6559
num_warps=4,
@@ -69,7 +63,6 @@
6963
"BLOCK_SIZE_M": 64,
7064
"BLOCK_SIZE_N": 32,
7165
"BLOCK_SIZE_K": 32,
72-
"GROUP_SIZE_M": 8,
7366
},
7467
num_stages=5,
7568
num_warps=2,
@@ -79,7 +72,6 @@
7972
"BLOCK_SIZE_M": 32,
8073
"BLOCK_SIZE_N": 64,
8174
"BLOCK_SIZE_K": 32,
82-
"GROUP_SIZE_M": 8,
8375
},
8476
num_stages=5,
8577
num_warps=2,
@@ -114,24 +106,17 @@ def kernel(
114106
BLOCK_SIZE_M: tl.constexpr,
115107
BLOCK_SIZE_N: tl.constexpr,
116108
BLOCK_SIZE_K: tl.constexpr,
117-
GROUP_SIZE_M: tl.constexpr,
118109
):
119110
P: tl.constexpr = H - R + 1
120111
Q: tl.constexpr = W - S + 1
121112

122-
GEMM_M: tl.constexpr = N * P * Q
123113
GEMM_N: tl.constexpr = K
124114
GEMM_K: tl.constexpr = C * R * S
125115

126116
pid = tl.program_id(0)
127-
num_pid_gemm_m = tl.cdiv(GEMM_M, BLOCK_SIZE_M)
128117
num_pid_gemm_n = tl.cdiv(GEMM_N, BLOCK_SIZE_N)
129-
num_pid_in_group = GROUP_SIZE_M * num_pid_gemm_n
130-
group_id = pid // num_pid_in_group
131-
first_pid_gemm_m = group_id * GROUP_SIZE_M
132-
group_size_m = min(num_pid_gemm_m - first_pid_gemm_m, GROUP_SIZE_M)
133-
pid_gemm_m = first_pid_gemm_m + ((pid % num_pid_in_group) % group_size_m)
134-
pid_gemm_n = (pid % num_pid_in_group) // group_size_m
118+
pid_gemm_m = pid // num_pid_gemm_n
119+
pid_gemm_n = pid % num_pid_gemm_n
135120

136121
offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
137122
offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

0 commit comments

Comments
 (0)