Skip to content

Commit fa211b0

Browse files
q10meta-codesync[bot]
authored andcommitted
Add grid-stride loop and ROCm cap to index_add_2d_with_unique_indices_kernel (#5934)
Summary: Pull Request resolved: #5934 X-link: https://github.com/facebookresearch/FBGEMM/pull/2852 Tier-2 fix for HIP grid-overflow in `sparse_ops/sparse_index_add.cu`. `index_add_2d_with_unique_indices_kernel` previously used `blockIdx.x` directly to index unique indices. Capping the host-side grid without first adding a grid-stride loop would silently drop work. Changes: - Add `const int num_unique_indices` as a new kernel parameter. - Convert kernel to a grid-stride loop over `u = blockIdx.x; u < num_unique_indices; u += gridDim.x` (Pattern C). All `blockIdx.x` references replaced with `u`. Hoist `start_D` and `has_remainder` outside the loop since they depend only on `blockIdx.y` / `threadIdx.x`. - RESET per-iteration register state at the top of each iteration: `sum[MAX_ELEMENTS_PER_THREAD]` re-zeroed and `sum_remainder = 0`. - Apply standard `#ifdef USE_ROCM min(blocks_x_uncapped, get_max_thread_blocks(stream)) #else blocks_x_uncapped #endif` cap to the x-dim of the launch grid. y dim is bounded by D/stride_D and needs no cap. Stacked on top of D105029028 (Tier-2 Diff 5/7). Plan: `/home/bensonma415/.llms/plans/sparse_ops_rocm_grid_overflow_tier2_fix.plan.md` (Diff 6/7). Reviewed By: henrylhtsang Differential Revision: D105029511 fbshipit-source-id: 2a33c6218d6b3d1c9c39ca301a1d451f09a39308
1 parent 4bb8e6f commit fa211b0

2 files changed

Lines changed: 162 additions & 35 deletions

File tree

fbgemm_gpu/src/sparse_ops/sparse_index_add.cu

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,53 +29,59 @@ __launch_bounds__(kMaxThreads) void index_add_2d_with_unique_indices_kernel(
2929
const int rounded_D,
3030
const int remaining_D,
3131
const bool consecutive_indices,
32-
const int consecutive_range_start) {
33-
const auto start_offset = blockIdx.x == 0 ? 0 : offsets[blockIdx.x - 1];
34-
const int end_offset = offsets[blockIdx.x];
35-
index_t dst_idx = consecutive_indices ? blockIdx.x + consecutive_range_start
36-
: unique_indices[blockIdx.x];
32+
const int consecutive_range_start,
33+
const int num_unique_indices) {
34+
// Each thread block processes max of stride_D elements
35+
const int start_D = (blockIdx.y * stride_D) + (threadIdx.x * UNROLL_FACTOR);
3736
const bool has_remainder = blockIdx.y == blockDim.y - 1 && remaining_D > 0 &&
3837
threadIdx.x < remaining_D;
3938

40-
// Buffer for storing temporary results
41-
scalar_t sum[MAX_ELEMENTS_PER_THREAD];
42-
for (int i = 0; i < MAX_ELEMENTS_PER_THREAD; i++) {
43-
sum[i] = 0;
44-
}
39+
// Grid-stride over unique indices (the saturating x dim) so a capped grid
40+
// (used on ROCm to avoid the 2^32 launch-side limit) still covers all
41+
// unique indices. blockIdx.y is bounded by D/stride_D and needs no cap.
42+
for (auto u = blockIdx.x; u < num_unique_indices; u += gridDim.x) {
43+
const auto start_offset = u == 0 ? 0 : offsets[u - 1];
44+
const int end_offset = offsets[u];
45+
index_t dst_idx =
46+
consecutive_indices ? u + consecutive_range_start : unique_indices[u];
47+
48+
// RESET per-iteration register state.
49+
scalar_t sum[MAX_ELEMENTS_PER_THREAD];
50+
for (int i = 0; i < MAX_ELEMENTS_PER_THREAD; i++) {
51+
sum[i] = 0;
52+
}
4553

46-
scalar_t sum_remainder = 0;
54+
scalar_t sum_remainder = 0;
4755

48-
// Each thread block processes max of stride_D elements
49-
int start_D = (blockIdx.y * stride_D) + (threadIdx.x * UNROLL_FACTOR);
56+
// For each row
57+
for (int row = start_offset; row < end_offset; row++) {
58+
int64_t src_idx = orig_indices[row];
59+
int col, i;
60+
for (col = start_D, i = 0; col < start_D + stride_D && col < rounded_D;
61+
col += blockDim.x * UNROLL_FACTOR, i += UNROLL_FACTOR) {
62+
#pragma unroll
63+
for (int j = 0; j < UNROLL_FACTOR; j++) {
64+
sum[i + j] += LDG(&out_grad[src_idx][col + j]);
65+
}
66+
}
67+
if (has_remainder) {
68+
sum_remainder += LDG(&out_grad[src_idx][rounded_D + threadIdx.x]);
69+
}
70+
} // for each row
5071

51-
// For each row
52-
for (int row = start_offset; row < end_offset; row++) {
53-
int64_t src_idx = orig_indices[row];
72+
// Write results to global memory
5473
int col, i;
5574
for (col = start_D, i = 0; col < start_D + stride_D && col < rounded_D;
5675
col += blockDim.x * UNROLL_FACTOR, i += UNROLL_FACTOR) {
5776
#pragma unroll
5877
for (int j = 0; j < UNROLL_FACTOR; j++) {
59-
sum[i + j] += LDG(&out_grad[src_idx][col + j]);
78+
in_deduped_grad[dst_idx][col + j] = sum[i + j];
6079
}
6180
}
6281
if (has_remainder) {
63-
sum_remainder += LDG(&out_grad[src_idx][rounded_D + threadIdx.x]);
64-
}
65-
} // for each row
66-
67-
// Write results to global memory
68-
int col, i;
69-
for (col = start_D, i = 0; col < start_D + stride_D && col < rounded_D;
70-
col += blockDim.x * UNROLL_FACTOR, i += UNROLL_FACTOR) {
71-
#pragma unroll
72-
for (int j = 0; j < UNROLL_FACTOR; j++) {
73-
in_deduped_grad[dst_idx][col + j] = sum[i + j];
82+
in_deduped_grad[dst_idx][rounded_D + threadIdx.x] += sum_remainder;
7483
}
7584
}
76-
if (has_remainder) {
77-
in_deduped_grad[dst_idx][rounded_D + threadIdx.x] += sum_remainder;
78-
}
7985
}
8086

8187
DLL_PUBLIC Tensor index_add_with_unique_indices_cuda(
@@ -146,10 +152,21 @@ DLL_PUBLIC Tensor index_add_with_unique_indices_cuda(
146152
offsets = unique_count.cumsum(0);
147153
}
148154

149-
const dim3 grid_size(
155+
const int num_y_blocks = (D + stride_D - 1) / stride_D;
156+
// HIP enforces a hard limit of 2^32 total threads per launch
157+
// (unlike CUDA, which silently wraps).
158+
// index_add_2d_with_unique_indices_kernel grid-strides over the
159+
// unique index (x) dim, so capping x is correctness-preserving.
160+
// The y dim is not grid-strided, so fold num_y_blocks into the
161+
// per-launch thread count used for the overflow check, keeping
162+
// the cap accounting consistent with the launcher's total-thread
163+
// check (grid.x * grid.y * block_size). See:
164+
// https://github.com/ROCm/hip/issues/2253
165+
const auto blocks_x = utils::cuda::cap_grid_dim_x(
150166
cuda_calc_xblock_count(num_unique_indices, 1),
151-
(D + stride_D - 1) / stride_D,
152-
1);
167+
static_cast<int64_t>(block_size) * num_y_blocks,
168+
at::cuda::getCurrentCUDAStream());
169+
const dim3 grid_size(blocks_x, num_y_blocks, 1);
153170

154171
const auto unique_indices_ = consecutive_indices
155172
? at::empty(
@@ -177,7 +194,8 @@ DLL_PUBLIC Tensor index_add_with_unique_indices_cuda(
177194
rounded_D,
178195
remaining_D,
179196
consecutive_indices,
180-
consecutive_range_start);
197+
consecutive_range_start,
198+
num_unique_indices);
181199
});
182200
});
183201
return input_grad.reshape(input_shape);

fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,5 +262,114 @@ def test_jagged_index_add_2d_forward_negative_rows_errors(self) -> None:
262262
)
263263

264264

265+
class JaggedIndexSelect2DLargeGridTest(unittest.TestCase):
266+
"""
267+
Retro: regression tests for the HIP grid-overflow bug in
268+
``index_add_2d_with_unique_indices_kernel`` (D105029511 /
269+
Subplan B Diff #10), which lacked its own test method when
270+
landed.
271+
272+
Block: dim3(stride_D / UNROLL_FACTOR, num_y_blocks).
273+
Grid: dim3(num_unique_indices, ceil(D / stride_D), 1).
274+
The production cap is `blocks_x = min(num_unique_indices,
275+
get_max_thread_blocks(stream))` (~16384 on MI300/MI350); the
276+
kernel grid-strides over the unique-index axis post-fix.
277+
"""
278+
279+
@classmethod
280+
def _has_gpu(cls) -> bool:
281+
return torch.cuda.is_available()
282+
283+
@classmethod
284+
def _gpu_memory_lt(cls, gb: int) -> bool:
285+
if not cls._has_gpu():
286+
return True
287+
return torch.cuda.get_device_properties(0).total_memory < gb * (1 << 30)
288+
289+
@unittest.skipUnless(torch.cuda.is_available(), "GPU not available")
290+
def test_index_add_2d_with_unique_indices_correctness(self) -> None:
291+
"""
292+
Multi-block correctness check at small scale via the autograd
293+
backward of ``jagged_index_select`` (which dispatches to
294+
``index_add_2d_with_unique_indices_kernel``). Sentinel non-zero
295+
values at start / middle / end of the unique-index axis force
296+
the grid-stride outer loop to iterate.
297+
"""
298+
if self._gpu_memory_lt(4):
299+
self.skipTest("Requires >= 4 GiB GPU memory")
300+
device = torch.accelerator.current_accelerator()
301+
# num_unique_indices > 2 * 1024 so the grid-stride loop iterates.
302+
N = 2 * 1024 + 3
303+
D = 16
304+
# Sparse lengths: most entries 0, sentinel non-zero at start /
305+
# middle / end so the kernel produces non-trivial backward grad.
306+
lengths_cpu = torch.zeros(N, dtype=torch.int64)
307+
lengths_cpu[0] = 1
308+
lengths_cpu[N // 2] = 2
309+
lengths_cpu[N - 1] = 3
310+
total = int(lengths_cpu.sum().item())
311+
# All unique inverse_lookup values so dedup keeps every batch.
312+
inverse_lookup_cpu = torch.arange(N, dtype=torch.int64)
313+
314+
values_init = torch.arange(total * D, dtype=torch.float32).reshape(total, D)
315+
316+
# GPU forward + backward.
317+
values_gpu = values_init.detach().clone().to(device).requires_grad_(True)
318+
output_gpu, _ = torch.ops.fbgemm.jagged_index_select(
319+
values_gpu, lengths_cpu.to(device), inverse_lookup_cpu.to(device)
320+
)
321+
output_gpu.sum().backward()
322+
323+
# CPU reference: backward of jagged_index_select with a permutation
324+
# `inverse_lookup` is a scatter_add of grad over unique indices.
325+
# With identity inverse_lookup and grad = ones, the expected
326+
# gradient is `ones` for every selected row.
327+
# pyre-ignore[16]
328+
self.assertEqual(values_gpu.grad.shape, values_init.shape)
329+
torch.testing.assert_close(
330+
values_gpu.grad.cpu(),
331+
torch.ones_like(values_init),
332+
)
333+
334+
@unittest.skipUnless(torch.cuda.is_available(), "GPU not available")
335+
def test_index_add_2d_with_unique_indices_large_grid(self) -> None:
336+
"""
337+
Launch-survival regression test at the cap-trip scale.
338+
339+
Pre-fix, ``index_add_2d_with_unique_indices_kernel`` launches
340+
with grid_x = num_unique_indices and per-block thread count
341+
determined by stride_D / UNROLL_FACTOR. With D = 8 and
342+
num_unique_indices = (1 << 22) + 1 the cap-trip path on ROCm
343+
would TORCH_CHECK-fail; post-fix the host caps grid_x to
344+
``get_max_thread_blocks(stream)`` and the kernel grid-strides.
345+
346+
Memory budget: values ~ N * D * 4B = 128 MiB per copy;
347+
the fwd output is the same size. Skip if HBM < 4 GiB.
348+
"""
349+
if self._gpu_memory_lt(4):
350+
self.skipTest("Requires >= 4 GiB GPU memory")
351+
device = torch.accelerator.current_accelerator()
352+
N = (1 << 22) + 1
353+
D = 8
354+
# All-zero lengths with one non-zero entry so the backward
355+
# kernel still launches over all unique indices.
356+
lengths_cpu = torch.zeros(N, dtype=torch.int64)
357+
lengths_cpu[0] = 1
358+
total = int(lengths_cpu.sum().item())
359+
inverse_lookup_cpu = torch.arange(N, dtype=torch.int64)
360+
361+
values = torch.zeros(
362+
(total, D), dtype=torch.float32, device=device, requires_grad=True
363+
)
364+
# Pre-fix this trips KernelLauncher::checkThreadCountNotExceeded
365+
# on ROCm at the index_add_2d_with_unique_indices launch.
366+
output, _ = torch.ops.fbgemm.jagged_index_select(
367+
values, lengths_cpu.to(device), inverse_lookup_cpu.to(device)
368+
)
369+
output.sum().backward()
370+
# pyre-ignore[16]
371+
self.assertEqual(values.grad.shape, values.shape)
372+
373+
265374
if __name__ == "__main__":
266375
unittest.main()

0 commit comments

Comments
 (0)