File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 99 "BLOCK_SIZE_M" : 128 ,
1010 "BLOCK_SIZE_N" : 256 ,
1111 "BLOCK_SIZE_K" : 64 ,
12- "GROUP_SIZE_M" : 8 ,
1312 },
1413 num_stages = 3 ,
1514 num_warps = 8 ,
1918 "BLOCK_SIZE_M" : 64 ,
2019 "BLOCK_SIZE_N" : 256 ,
2120 "BLOCK_SIZE_K" : 32 ,
22- "GROUP_SIZE_M" : 8 ,
2321 },
2422 num_stages = 4 ,
2523 num_warps = 4 ,
2927 "BLOCK_SIZE_M" : 128 ,
3028 "BLOCK_SIZE_N" : 128 ,
3129 "BLOCK_SIZE_K" : 32 ,
32- "GROUP_SIZE_M" : 8 ,
3330 },
3431 num_stages = 4 ,
3532 num_warps = 4 ,
3936 "BLOCK_SIZE_M" : 128 ,
4037 "BLOCK_SIZE_N" : 64 ,
4138 "BLOCK_SIZE_K" : 32 ,
42- "GROUP_SIZE_M" : 8 ,
4339 },
4440 num_stages = 4 ,
4541 num_warps = 4 ,
4945 "BLOCK_SIZE_M" : 64 ,
5046 "BLOCK_SIZE_N" : 128 ,
5147 "BLOCK_SIZE_K" : 32 ,
52- "GROUP_SIZE_M" : 8 ,
5348 },
5449 num_stages = 4 ,
5550 num_warps = 4 ,
5954 "BLOCK_SIZE_M" : 128 ,
6055 "BLOCK_SIZE_N" : 32 ,
6156 "BLOCK_SIZE_K" : 32 ,
62- "GROUP_SIZE_M" : 8 ,
6357 },
6458 num_stages = 4 ,
6559 num_warps = 4 ,
6963 "BLOCK_SIZE_M" : 64 ,
7064 "BLOCK_SIZE_N" : 32 ,
7165 "BLOCK_SIZE_K" : 32 ,
72- "GROUP_SIZE_M" : 8 ,
7366 },
7467 num_stages = 5 ,
7568 num_warps = 2 ,
7972 "BLOCK_SIZE_M" : 32 ,
8073 "BLOCK_SIZE_N" : 64 ,
8174 "BLOCK_SIZE_K" : 32 ,
82- "GROUP_SIZE_M" : 8 ,
8375 },
8476 num_stages = 5 ,
8577 num_warps = 2 ,
@@ -114,24 +106,17 @@ def kernel(
114106 BLOCK_SIZE_M : tl .constexpr ,
115107 BLOCK_SIZE_N : tl .constexpr ,
116108 BLOCK_SIZE_K : tl .constexpr ,
117- GROUP_SIZE_M : tl .constexpr ,
118109):
119110 P : tl .constexpr = H - R + 1
120111 Q : tl .constexpr = W - S + 1
121112
122- GEMM_M : tl .constexpr = N * P * Q
123113 GEMM_N : tl .constexpr = K
124114 GEMM_K : tl .constexpr = C * R * S
125115
126116 pid = tl .program_id (0 )
127- num_pid_gemm_m = tl .cdiv (GEMM_M , BLOCK_SIZE_M )
128117 num_pid_gemm_n = tl .cdiv (GEMM_N , BLOCK_SIZE_N )
129- num_pid_in_group = GROUP_SIZE_M * num_pid_gemm_n
130- group_id = pid // num_pid_in_group
131- first_pid_gemm_m = group_id * GROUP_SIZE_M
132- group_size_m = min (num_pid_gemm_m - first_pid_gemm_m , GROUP_SIZE_M )
133- pid_gemm_m = first_pid_gemm_m + ((pid % num_pid_in_group ) % group_size_m )
134- pid_gemm_n = (pid % num_pid_in_group ) // group_size_m
118+ pid_gemm_m = pid // num_pid_gemm_n
119+ pid_gemm_n = pid % num_pid_gemm_n
135120
136121 offs_gemm_i = pid_gemm_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
137122 offs_gemm_j = pid_gemm_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )
You can’t perform that action at this time.
0 commit comments