Skip to content

Commit e1cf8fb

Browse files
q10, meta-codesync[bot]
authored and committed
Fix HIP grid overflow in permute_1D_sparse_data_cuda (#5763)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2693 Pull Request resolved: #5763 A MAST training job (`fire-zzt-ESFM-MI350X-20260508-1651-3488ad1a`) was failing on ROCm in `permute_1D_sparse_data_cuda` with: ``` sparse_permute_1d.hip(339:8) [(permute_1D_data_kernel_vec<false, offsets_t, indices_t, std::nullptr_t>)] [grid dim 6252106 x 1 x 1] [block dim 64 x 16 x 1]: Total number of threads 6402156544 is greater than the HIP limit of 2^32 ``` The launch site uses `dim3(64, BT_blocks=16)` (block size 1024) and `blocks = cuda_calc_xblock_count(permuted_lengths_size, BT_blocks)`, so once `permuted_lengths_size > 2^26 ≈ 67M` segments, total threads exceed `2^32` and HIP refuses the launch (CUDA's runtime silently handles the wrap; ROCm does not — see ROCm/hip#2253). The MAST log shows ~100M segments, well past the limit. The kernels `permute_1D_data_kernel_vec` and `permute_1D_data_kernel` already implement a grid-stride loop over `b_t`, so no kernel-side changes are needed — only the launch site needs to cap the grid. The lengths kernel uses `CUDA_KERNEL_LOOP`, which also already grid-strides. Apply the D94944619 conditional-cap pattern at both kernel launch sites in `permute_1D_sparse_data_cuda`: - Compute `total_threads` as a `uint64_t` from the unconstrained grid. - If `total_threads >= numeric_limits<uint32_t>::max()`, cap the grid to `min(num_threadblocks, utils::cuda::get_max_thread_blocks(stream))`. - Otherwise pass through the existing value (no perf change for the common case, including NVIDIA — the generated launch is bit-identical). (Note: in the diff as landed, the condition is a compile-time one rather than a runtime `total_threads` check: under `#ifdef USE_ROCM` the grid is capped unconditionally via `std::min<uint32_t>(blocks_uncapped, utils::cuda::get_max_thread_blocks(at::cuda::getCurrentCUDAStream()))`, and on non-ROCm builds the uncapped value is passed through unchanged. The commit also updates `permute_1D_sparse_data_meta` in `sparse_ops.py` to preserve trailing weight dimensions and adds the regression test `test_permute_1D_sparse_data_large_grid`.) Same family of fix as: - D65009966 (bounds_check_indices) - D75543767 (TBE forward) - D94944619 (TBE forward V2 — conditional cap) Out of scope: `sparse_permute_2d.cu` (`permute_2D_data_kernel_vec`) has the same pattern at line 253 with `dim3(32, 32)` and is a candidate for the same fix as a follow-up. 
Reviewed By: spcyppt Differential Revision: D104903707 fbshipit-source-id: 049a7f70ceacd6d7cfd63fa305976e9a95978e01
1 parent 3203889 commit e1cf8fb

3 files changed

Lines changed: 117 additions & 6 deletions

File tree

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,16 @@ def permute_1D_sparse_data_meta(
297297
permuted_indices = indices.new_empty(permuted_indices_size)
298298
permuted_weights = None
299299
if weights is not None:
300-
# pyre-fixme
301-
permuted_weights = weights.new_empty(permuted_indices_size)
300+
# Preserve trailing dimensions for N-D weights so the meta function
301+
# matches the concrete kernel's output shape (e.g. [total, W] for
302+
# 2D weights consumed by the vec kernel). Previously this always
303+
# returned a 1D tensor, which broke the faketensor opcheck on the
304+
# 2D-weights tests.
305+
permuted_weights = (
306+
weights.new_empty(permuted_indices_size)
307+
if weights.dim() <= 1
308+
else weights.new_empty([permuted_indices_size, *weights.shape[1:]])
309+
)
302310
return permuted_lengths, permuted_indices, permuted_weights
303311

304312

fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,19 @@ permute_1D_sparse_data_cuda(
233233
permuted_lengths = at::empty({permuted_lengths_size}, lengths.options());
234234

235235
constexpr int32_t threads_1 = kMaxThreads;
236-
const auto blocks_1 =
236+
const auto blocks_1_uncapped =
237237
cuda_calc_xblock_count(permuted_lengths_size, threads_1);
238+
#ifdef USE_ROCM
239+
// HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA,
240+
// which silently wraps). Cap the grid unconditionally on ROCm;
241+
// permute_1D_lengths_kernel uses CUDA_KERNEL_LOOP, which already
242+
// grid-strides, so capping is correctness-preserving.
243+
const auto blocks_1 = std::min<uint32_t>(
244+
blocks_1_uncapped,
245+
utils::cuda::get_max_thread_blocks(at::cuda::getCurrentCUDAStream()));
246+
#else
247+
const auto blocks_1 = blocks_1_uncapped;
248+
#endif
238249
AT_DISPATCH_INDEX_TYPES(
239250
lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
240251
FBGEMM_LAUNCH_KERNEL(
@@ -262,8 +273,19 @@ permute_1D_sparse_data_cuda(
262273

263274
constexpr int32_t BT_blocks = 16;
264275
dim3 threads_2(64, BT_blocks);
265-
const auto blocks_2 =
276+
const auto blocks_2_uncapped =
266277
cuda_calc_xblock_count(permuted_lengths_size, BT_blocks);
278+
#ifdef USE_ROCM
279+
// HIP enforces a hard limit of 2^32 total threads per launch (unlike CUDA,
280+
// which silently wraps). Cap the grid unconditionally on ROCm; the
281+
// kernel's grid-striding loop over b_t handles the overflow, so capping is
282+
// correctness-preserving.
283+
const auto blocks_2 = std::min<uint32_t>(
284+
blocks_2_uncapped,
285+
utils::cuda::get_max_thread_blocks(at::cuda::getCurrentCUDAStream()));
286+
#else
287+
const auto blocks_2 = blocks_2_uncapped;
288+
#endif
267289
permuted_indices = at::empty(permuted_indices_size, indices.options());
268290

269291
AT_DISPATCH_INDEX_TYPES(

fbgemm_gpu/test/sparse/permute_indices_test.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,20 @@
2727

2828
if open_source:
2929
# pyre-ignore[21]
30-
from test_utils import gpu_available, gpu_unavailable, on_oss_clang
30+
from test_utils import (
31+
gpu_available,
32+
gpu_memory_lt_gb,
33+
gpu_unavailable,
34+
on_oss_clang,
35+
)
3136
else:
3237
import fbgemm_gpu.sparse_ops # noqa: F401, E402
33-
from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable, on_oss_clang
38+
from fbgemm_gpu.test.test_utils import (
39+
gpu_available,
40+
gpu_memory_lt_gb,
41+
gpu_unavailable,
42+
on_oss_clang,
43+
)
3444

3545

3646
class PermuteIndicesTest(unittest.TestCase):
@@ -786,6 +796,77 @@ def test_permute_2D_indices_large_segments(
786796
else:
787797
self.assertIsNone(permuted_weights_gpu)
788798

799+
@unittest.skipIf(*gpu_unavailable)
800+
# Skip on GPUs with insufficient HBM (need a few hundred MB for the
801+
# int32 N-element tensors).
802+
@unittest.skipIf(*gpu_memory_lt_gb(4))
803+
def test_permute_1D_sparse_data_large_grid(self) -> None:
804+
"""
805+
Reproduces the HIP grid-overflow bug in permute_1D_sparse_data_cuda
806+
and verifies output correctness at the same scale.
807+
808+
With BT_blocks=16 and dim3(64, 16) (block size 1024), the launch grid
809+
is cuda_calc_xblock_count(N, 16). For N > 2**26, total threads exceed
810+
the HIP 2**32 limit, causing FBGEMM_LAUNCH_KERNEL ->
811+
KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail on
812+
ROCm pre-fix. With the production fix in place, this test additionally
813+
validates output correctness against the CPU dispatch of the same op
814+
— the GPU output must match the CPU reference element-for-element.
815+
816+
``lengths`` is sparse: all zero except for three known non-zero
817+
positions (start / middle / end of the logical range), so HBM usage
818+
stays bounded (~few hundred MB int32) while the permutation logic is
819+
still exercised. ``permute`` is a deterministic non-identity circular
820+
shift (``perm[i] != i`` everywhere), so any "kernel computed identity
821+
instead of permutation" bug surfaces in the assertion below.
822+
"""
823+
824+
# Choose N so that total threads strictly exceeds 2**32:
825+
# cuda_calc_xblock_count(N, 16) * 1024 ~= N * 64; need N > 2**26.
826+
N = (1 << 26) + 1
827+
828+
device = torch.device(torch.accelerator.current_accelerator() or "cuda")
829+
830+
# Deterministic non-identity permute: circular shift by +1.
831+
# perm_cpu[0] == N - 1 and perm_cpu[i] == i - 1 for i >= 1, so
832+
# perm_cpu[i] != i for every i.
833+
perm_cpu = torch.roll(torch.arange(N, dtype=torch.int32), 1)
834+
permute = perm_cpu.to(device)
835+
836+
# Sparse non-zero lengths at start / middle / end. Total = 10.
837+
lengths_cpu = torch.zeros(N, dtype=torch.int32)
838+
lengths_cpu[0] = 3
839+
lengths_cpu[N // 2] = 5
840+
lengths_cpu[N - 1] = 2
841+
lengths = lengths_cpu.to(device)
842+
843+
# Distinct indices per segment so the permutation is fully observable.
844+
indices_cpu = torch.arange(10, dtype=torch.int32)
845+
indices = indices_cpu.to(device)
846+
847+
# CPU reference oracle — same op, different dispatch.
848+
(
849+
permuted_lengths_cpu,
850+
permuted_indices_cpu,
851+
_permuted_weights_cpu,
852+
) = torch.ops.fbgemm.permute_1D_sparse_data(
853+
perm_cpu, lengths_cpu, indices_cpu, None, None
854+
)
855+
856+
# GPU op under test. Pre-fix, this launch trips
857+
# KernelLauncher::checkThreadCountNotExceeded on ROCm.
858+
(
859+
permuted_lengths_gpu,
860+
permuted_indices_gpu,
861+
permuted_weights_gpu,
862+
) = torch.ops.fbgemm.permute_1D_sparse_data(
863+
permute, lengths, indices, None, None
864+
)
865+
866+
torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
867+
torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
868+
self.assertIsNone(permuted_weights_gpu)
869+
789870

790871
extend_test_class(PermuteIndicesTest)
791872

0 commit comments

Comments
 (0)