Skip to content

Commit f3e2f05

Browse files
committed
[CuTe][SM70] Add comment explaining why int() cast is required for blockIdx coords
1 parent f74fea9 commit f3e2f05

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

include/cutlass/gemm/kernel/sm70_gemm.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializin
210210
int thread_idx = int(threadIdx.x);
211211
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
212212
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
213-
auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l)
213+
auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l) NOTE: the int() casts are required. blockIdx is a uint3 (unsigned components), and passing unsigned coords to make_coord can cause unsigned wrap-around when computing tile residues for predication on small problem shapes (e.g. shape < TileShape), where the residue should be negative.
214214
215215
// Represent the full tensors
216216
Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l)

include/cutlass/gemm/kernel/sm70_gemm_array.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializin
217217
int thread_idx = int(threadIdx.x);
218218
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
219219
auto [m_coord, n_coord, l_coord] = static_cast<uint3>(blockIdx);
220-
auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l)
220+
auto blk_coord_mnkl = make_coord(int(m_coord), int(n_coord), _, int(l_coord)); // (m,n,k,l) NOTE: the int() casts are required. blockIdx is a uint3 (unsigned components), and passing unsigned coords to make_coord can cause unsigned wrap-around when computing tile residues for predication on small problem shapes (e.g. shape < TileShape), where the residue should be negative.
221221

222222
// Represent the full tensors
223223
Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A[l_coord]), make_shape(M,K,1), params.mainloop.dA); //(m,k,l)

0 commit comments

Comments
 (0)