Skip to content

Commit 7f8a801

Browse files
q10 authored and meta-codesync[bot] committed
Fix HIP grid overflow in pack_segments_cuda{,_v2} (#5907)
Summary:

Pull Request resolved: #5907

Apply the same `#ifdef USE_ROCM` cap pattern used in D104903707 / D104937969 / parent diffs to the two launch sites in `pack_segments_forward_cuda` and `pack_segments_forward_cuda_v2` in `sparse_pack_segments_forward.cu`.

Both ops launch their kernels with block size 128 and grid `cuda_calc_xblock_count(num_seq * max_length * cell_size, 128)`. Once the product `num_seq * max_length * cell_size` exceeds `2^32`, the total thread count exceeds the HIP `2^32` limit and the launch is rejected on ROCm.

Both `pack_segments_cuda_kernel` (uses `CUDA_KERNEL_LOOP`) and `pack_segments_cuda_v2_kernel` (uses `CUDA_KERNEL_LOOP_TYPE`) already grid-stride, so capping the grid is correctness-preserving. The `#ifdef USE_ROCM / #else / #endif` selector keeps NVIDIA codegen bit-identical and unconditionally caps on ROCm.

Same family of fix as:
- D104903707 (permute_1D_sparse_data — parent diff)
- D104937969 (permute_2D_sparse_data — parent diff)

Reviewed By: henrylhtsang

Differential Revision: D104950916

fbshipit-source-id: e8999860e7b7f64250e61daffbfae00ae71ee36a
1 parent d334a2f commit 7f8a801

2 files changed

Lines changed: 124 additions & 4 deletions

File tree

fbgemm_gpu/src/sparse_ops/sparse_pack_segments_forward.cu

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -136,9 +136,19 @@ DLL_PUBLIC Tensor pack_segments_forward_cuda(
136136
const auto num_seq = lengths_c.size(0);
137137
const auto cell_size = t_in_c.numel() / t_in_c.size(0);
138138

139+
// HIP enforces a hard limit of 2^32 total threads per launch
140+
// (unlike CUDA, which silently wraps). pack_segments_cuda_kernel
141+
// uses CUDA_KERNEL_LOOP, which already grid-strides, so capping is
142+
// correctness-preserving.
143+
// See: https://github.com/ROCm/hip/issues/2253
144+
const auto blocks = utils::cuda::cap_grid_dim_x_from_workload(
145+
num_seq * max_length * cell_size,
146+
128,
147+
at::cuda::getCurrentCUDAStream());
148+
139149
FBGEMM_LAUNCH_DSA_KERNEL(
140150
(pack_segments_cuda_kernel<index_t, scalar_t>),
141-
cuda_calc_xblock_count(num_seq * max_length * cell_size, 128),
151+
blocks,
142152
128,
143153
0,
144154
at::cuda::getCurrentCUDAStream(),
@@ -233,9 +243,19 @@ pack_segments_forward_cuda_v2(
233243
const auto num_seq = lengths_c.size(0);
234244
const auto cell_size = t_in_c.numel() / t_in_c.size(0);
235245

246+
// HIP enforces a hard limit of 2^32 total threads per launch
247+
// (unlike CUDA, which silently wraps). pack_segments_cuda_v2_kernel
248+
// uses CUDA_KERNEL_LOOP_TYPE, which already grid-strides, so capping
249+
// is correctness-preserving.
250+
// See: https://github.com/ROCm/hip/issues/2253
251+
const auto blocks = utils::cuda::cap_grid_dim_x_from_workload(
252+
num_seq * max_length * cell_size,
253+
128,
254+
at::cuda::getCurrentCUDAStream());
255+
236256
FBGEMM_LAUNCH_DSA_KERNEL(
237257
(pack_segments_cuda_v2_kernel<index_t, scalar_t>),
238-
cuda_calc_xblock_count(num_seq * max_length * cell_size, 128),
258+
blocks,
239259
128,
240260
0,
241261
at::cuda::getCurrentCUDAStream(),

fbgemm_gpu/test/sparse/pack_segments_test.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@
2121

2222
if open_source:
2323
# pyre-ignore[21]
24-
from test_utils import gpu_available, gpu_unavailable
24+
from test_utils import gpu_available, gpu_memory_lt_gb, gpu_unavailable
2525
else:
26-
from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable
26+
from fbgemm_gpu.test.test_utils import (
27+
gpu_available,
28+
gpu_memory_lt_gb,
29+
gpu_unavailable,
30+
)
2731

2832

2933
def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray:
@@ -623,6 +627,102 @@ def test_pack_segments_backward_truncated(self, dtype: torch.dtype) -> None:
623627
)
624628
cumsum += L
625629

630+
@unittest.skipIf(*gpu_unavailable)
631+
# Skip on GPUs with insufficient HBM. The test allocates the packed
632+
# output of shape (num_seq, max_length) at fp16, ~8 GiB at the chosen
633+
# max_length.
634+
@unittest.skipIf(*gpu_memory_lt_gb(12))
635+
def test_pack_segments_large_grid(self) -> None:
636+
"""
637+
Reproduces the HIP grid-overflow bug in pack_segments_cuda{,_v2}
638+
and verifies output correctness via a downsampled CPU oracle.
639+
640+
With block size 128, the launch grid is
641+
cuda_calc_xblock_count(num_seq * max_length * cell_size, 128).
642+
For num_seq * max_length * cell_size > 2**32, total threads
643+
exceed the HIP 2**32 limit, causing
644+
FBGEMM_LAUNCH_DSA_KERNEL ->
645+
KernelLauncher::checkThreadCountNotExceeded to TORCH_CHECK-fail
646+
on ROCm pre-fix. Both pack_segments_cuda_kernel (uses
647+
CUDA_KERNEL_LOOP) and pack_segments_cuda_v2_kernel (uses
648+
CUDA_KERNEL_LOOP_TYPE) already grid-stride, so capping the grid
649+
is correctness-preserving for the launcher.
650+
651+
Verification strategy (per master plan's downsampled-oracle
652+
guidance for ops where the full-scale CPU oracle is impractical):
653+
654+
1. Full-scale invocation of v1 and v2 to verify the launch
655+
survives the production cap. Only shape is asserted because
656+
v1 uses ``CUDA_KERNEL_LOOP`` with an int32 loop index, which
657+
overflows for output linear indices >= 2**31; element-wise
658+
comparison would surface this pre-existing kernel bug, which
659+
is out of scope for this diff (the diff only caps the grid).
660+
2. Small-scale invocation of v1 and v2 vs CPU dispatch to
661+
validate kernel correctness end-to-end at a scale where the
662+
int32 loop index does not overflow. This catches kernel
663+
correctness regressions introduced by the cap fix.
664+
"""
665+
666+
# Choose num_seq * max_length so that total threads strictly
667+
# exceed 2**32. With cell_size=1: total threads ~= num_seq *
668+
# max_length; need product > 2**32.
669+
num_seq = 2
670+
max_length = (1 << 31) + 1
671+
672+
device = torch.device(torch.accelerator.current_accelerator() or "cuda")
673+
674+
# ---- Step 1: full-scale launch survival (cap-trip detection). ----
675+
# Sparse non-zero lengths: only the last segment is non-empty.
676+
# t_in has a single sentinel value.
677+
lengths_large = torch.zeros(num_seq, dtype=torch.int32, device=device)
678+
lengths_large[-1] = 1
679+
t_in_large = torch.tensor([3.5], dtype=torch.float16, device=device)
680+
681+
# Pre-fix, this launch trips KernelLauncher::checkThreadCountNotExceeded.
682+
packed_v1 = torch.ops.fbgemm.pack_segments(
683+
t_in_large, lengths_large, max_length
684+
)
685+
self.assertEqual(packed_v1.shape, (num_seq, max_length))
686+
del packed_v1
687+
688+
packed_v2, _ = torch.ops.fbgemm.pack_segments_v2(
689+
t_in_large, lengths_large, max_length
690+
)
691+
self.assertEqual(packed_v2.shape, (num_seq, max_length))
692+
del packed_v2
693+
694+
# ---- Step 2: downsampled CPU-oracle correctness check. ----
695+
# Same kernel code path, smaller scale to keep the int32 loop
696+
# index of v1 in range and the CPU oracle cheap.
697+
small_max_length = 16
698+
small_lengths_cpu = torch.tensor([0, 3, 0, 2], dtype=torch.int32)
699+
# Total non-zero lengths = 5; t_in is a sequence of distinct
700+
# values so any "wrong row/col" bug surfaces.
701+
small_t_in_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float16)
702+
703+
# CPU oracles.
704+
small_packed_cpu = torch.ops.fbgemm.pack_segments(
705+
small_t_in_cpu, small_lengths_cpu, small_max_length
706+
)
707+
small_packed_cpu_v2, _ = torch.ops.fbgemm.pack_segments_v2(
708+
small_t_in_cpu, small_lengths_cpu, small_max_length
709+
)
710+
711+
# GPU under test.
712+
small_packed_gpu = torch.ops.fbgemm.pack_segments(
713+
small_t_in_cpu.to(device),
714+
small_lengths_cpu.to(device),
715+
small_max_length,
716+
)
717+
small_packed_gpu_v2, _ = torch.ops.fbgemm.pack_segments_v2(
718+
small_t_in_cpu.to(device),
719+
small_lengths_cpu.to(device),
720+
small_max_length,
721+
)
722+
723+
torch.testing.assert_close(small_packed_gpu.cpu(), small_packed_cpu)
724+
torch.testing.assert_close(small_packed_gpu_v2.cpu(), small_packed_cpu_v2)
725+
626726

627727
extend_test_class(PackedSegmentsTest)
628728

0 commit comments

Comments
 (0)