From a7b4c75cf3206b0d368ebf444b6dce3e5bfa777f Mon Sep 17 00:00:00 2001 From: Albert Chen Date: Mon, 18 May 2026 06:31:24 -0700 Subject: [PATCH 1/2] Speedup unique_indices_length_kernel via binary search (#5766) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2695 The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape). The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed. Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize. The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why. Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.** No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs. Reviewed By: q10 Differential Revision: D104827588 --- .../jagged_unique_indices.cu | 152 ++++++++++++------ fbgemm_gpu/test/jagged/failures_dict.json | 8 + fbgemm_gpu/test/jagged/unique_indices_test.py | 68 ++++++++ 3 files changed, 175 insertions(+), 53 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu index 74fe59bc36..0c7e6482a5 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu @@ -22,6 +22,11 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +// Block size for the flat-grid kernels in this file. 256 = 8 warps, chosen +// so the SM scheduler can pack ~4-8 blocks per SM on H100 given the +// flat-grid launches with grid = total_B (or num_unique). +static constexpr int32_t kFlatBlockSize = 256; + // Linearzie the index with the cumsum of hash size so that linearized indices // can be sorted together. template @@ -79,69 +84,112 @@ __global__ __launch_bounds__(kMaxThreads) void delinearize_unique_index_kernel( } } -// Compute the lengths for each feature in the unique indices. The range of -// indices for each feature equals to the difference between the max and min -// values in the reverse index array. -template -__global__ __launch_bounds__(kMaxThreads) void unique_indices_length_kernel( +// Device-side lower_bound over a PackedTensorAccessor32. +// Returns the first position whose value is >= `value`, equivalent to +// std::lower_bound on the underlying sorted array. +template +__device__ __forceinline__ int32_t device_lower_bound( + const pta::PackedTensorAccessor32& arr, + const index_t value) { + int32_t lo = 0; + int32_t hi = arr.size(0); + while (lo < hi) { + const int32_t mid = lo + ((hi - lo) >> 1); + if (arr[mid] < value) { + lo = mid + 1; + } else { + hi = mid; + } + } + return lo; +} + +// Compute the per-(feature, batch) lengths for the unique indices. +// +// Caller-provided invariant (see jagged_unique_indices_cuda for the +// pipeline contract that establishes it): `linear_unique_indices` is +// sorted ascending, and feature t's values occupy a contiguous slice +// [lower_bound(linear_unique_indices, hash_size_cumsum[t]), +// lower_bound(linear_unique_indices, hash_size_cumsum[t+1])). +// The slice length equals num_unique_t for feature t. +template +__global__ __launch_bounds__(kFlatBlockSize) void unique_indices_length_kernel( const pta::PackedTensorAccessor32 hash_size_offsets, const pta::PackedTensorAccessor32 - reverse_index, + hash_size_cumsum, const pta::PackedTensorAccessor32 - offsets, - pta::PackedTensorAccessor32 lengths) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage_max; - __shared__ typename BlockReduce::TempStorage temp_storage_min; - __shared__ index_t block_results[2]; - + linear_unique_indices, + pta::PackedTensorAccessor32 lengths, + const int32_t batch_size) { const auto tid = threadIdx.x; const auto bid = blockIdx.x; - const auto num_blocks = gridDim.x; - const int32_t batch_size = (offsets.size(0) - 1) / num_blocks; - - const auto offset_begin = hash_size_offsets[bid] * batch_size; - const auto offset_end = hash_size_offsets[bid + 1] * batch_size; - const auto num_lengths = (offset_end - offset_begin); - - const auto reverse_index_begin = offsets[offset_begin]; - const auto reverse_index_end = offsets[offset_end]; - if (reverse_index_begin == reverse_index_end) { + const auto hash_begin = hash_size_offsets[bid]; + const auto hash_end = hash_size_offsets[bid + 1]; + const auto offset_begin = hash_begin * batch_size; + const auto offset_end = hash_end * batch_size; + const auto num_lengths = offset_end - offset_begin; + if (num_lengths == 0) { return; } - index_t t_max = min_value; - index_t t_min = max_value; - for (index_t i = (reverse_index_begin + tid); i < reverse_index_end; - i += kMaxThreads) { - const index_t value = reverse_index[i]; - t_max = (value > t_max) ? value : t_max; - t_min = (value < t_min) ? value : t_min; - } - - index_t block_max = - BlockReduce(temp_storage_max).Reduce(t_max, Max()); - index_t block_min = - BlockReduce(temp_storage_min).Reduce(t_min, Min()); + __shared__ index_t s_div_length; + __shared__ index_t s_r_length; if (tid == 0) { - block_results[0] = block_max; - block_results[1] = block_min; + const auto low = hash_size_cumsum[hash_begin]; + const auto high = hash_size_cumsum[hash_end]; + if (low == high) { + // Empty feature group. Output is pre-zeroed by at::zeros at the + // launch site; nothing to write. + s_div_length = 0; + s_r_length = 0; + } else { + const int32_t lo_pos = + device_lower_bound(linear_unique_indices, low); + const int32_t hi_pos = + device_lower_bound(linear_unique_indices, high); + const index_t total_length = static_cast(hi_pos - lo_pos); + s_div_length = total_length / static_cast(num_lengths); + s_r_length = total_length % static_cast(num_lengths); + } } __syncthreads(); - t_max = block_results[0]; - t_min = block_results[1]; - const index_t total_length = (t_max - t_min) + 1; - const index_t div_length = total_length / num_lengths; - const index_t r_length = total_length % num_lengths; - for (int32_t i = tid; i < num_lengths; i += kMaxThreads) { - index_t seg_length = (i < r_length) ? (div_length + 1) : div_length; + const index_t div_length = s_div_length; + const index_t r_length = s_r_length; + if (div_length == 0 && r_length == 0) { + return; + } + for (int32_t i = tid; i < num_lengths; i += blockDim.x) { + const index_t seg_length = + (static_cast(i) < r_length) ? (div_length + 1) : div_length; lengths[offset_begin + i] = seg_length; } } +// Pipeline (cross-kernel data flow that ties the four steps together): +// +// 1. linearize_index_wo_infos_kernel writes +// linear_indices[i] = hash_size_cumsum[t] + indices[i] +// so feature t's linearized values lie in +// [hash_size_cumsum[t], hash_size_cumsum[t+1]). +// +// 2. at::_unique(linear_indices, sorted=True, return_inverse=True) +// returns (linear_unique_indices, reverse_index) where +// linear_unique_indices is sorted ascending. Combined with (1), this +// means feature t's unique linearized values occupy a contiguous +// slice of linear_unique_indices. +// +// 3. delinearize_unique_index_kernel scatters the original +// (pre-linearization) per-feature index values back into +// unique_indices via reverse_index. +// +// 4. unique_indices_length_kernel relies on (1)+(2) to compute +// num_unique per feature group via two binary searches over +// linear_unique_indices, instead of an O(N) reduction over +// reverse_index. See the kernel's docstring for the local form of +// the invariant. std::tuple jagged_unique_indices_cuda( const Tensor& hash_size_cumsum, const Tensor& hash_size_offsets, @@ -192,18 +240,16 @@ std::tuple jagged_unique_indices_cuda( Tensor output_lengths = at::zeros({total_B}, offsets.options()); AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "unique_indices_length", ([&] { FBGEMM_LAUNCH_KERNEL( - (unique_indices_length_kernel< - index_t, - std::numeric_limits::max(), - std::numeric_limits::min()>), + (unique_indices_length_kernel), T, - kMaxThreads, + kFlatBlockSize, 0, at::cuda::getCurrentCUDAStream(), PTA_B(hash_size_offsets, index_t, 1, 32), - PTA_B(reverse_index, index_t, 1, 32), - PTA_B(offsets, index_t, 1, 32), - PTA_B(output_lengths, index_t, 1, 32)); + PTA_B(hash_size_cumsum, index_t, 1, 32), + PTA_B(linear_unique_indices, index_t, 1, 32), + PTA_B(output_lengths, index_t, 1, 32), + static_cast(total_B / T)); })); Tensor output_offsets; diff --git a/fbgemm_gpu/test/jagged/failures_dict.json b/fbgemm_gpu/test/jagged/failures_dict.json index 9e94727ce4..bd8135ca19 100644 --- a/fbgemm_gpu/test/jagged/failures_dict.json +++ b/fbgemm_gpu/test/jagged/failures_dict.json @@ -107,6 +107,10 @@ "comment": "", "status": "xfail" }, + "UniqueIndicesTest.test_aot_dispatch_dynamic__test_jagged_unique_indices_zch_huge_hash_size": { + "comment": "Test uses .item()/.tolist() for host-side comparison; incompatible with dynamic dispatch.", + "status": "xfail" + }, "UniqueIndicesTest.test_faketensor__test_jagged_unique_indices": { "comment": "", "status": "xfail" @@ -118,6 +122,10 @@ "UniqueIndicesTest.test_faketensor__test_jagged_unique_indices_multi_keys": { "comment": "", "status": "xfail" + }, + "UniqueIndicesTest.test_faketensor__test_jagged_unique_indices_zch_huge_hash_size": { + "comment": "Test uses .item()/.tolist() for host-side comparison; incompatible with fake tensors.", + "status": "xfail" } }, "fbgemm::keyed_jagged_index_select_dim1": {}, diff --git a/fbgemm_gpu/test/jagged/unique_indices_test.py b/fbgemm_gpu/test/jagged/unique_indices_test.py index 1af4460d5c..525493a439 100644 --- a/fbgemm_gpu/test/jagged/unique_indices_test.py +++ b/fbgemm_gpu/test/jagged/unique_indices_test.py @@ -269,6 +269,74 @@ def test_jagged_unique_indices_empty( self.assertEqual(torch.sum(output_lengths).item(), 0) self.assertEqual(torch.sum(output_offsets).item(), 0) + @unittest.skipIf(*gpu_unavailable) + def test_jagged_unique_indices_zch_huge_hash_size(self) -> None: + """Exercise the op with a hash_size_cumsum entry at INT64_MAX - + the shape produced by ZCH callers that leave per-feature hash size + unbounded. The op must handle hash boundaries spanning the full + int64 range without overflow in any internal arithmetic. + """ + T = 2 + B = 64 + max_length = 5 + int64_max = torch.iinfo(torch.int64).max + hash_size_cumsum_list = [0, 0, int64_max] + hash_size_offsets_list = [0, 0, 2] + # Per-feature linearized values lie in [0, INT64_MAX). The kernels + # under test are boundary-value-sensitive on hash_size_cumsum, not + # on the indices themselves, so a small index range is sufficient + # and keeps the reference comparison fast. + per_feature_value_cap = 1024 + lengths_list: list[int] = [] + indices_list: list[int] = [] + for _ in range(T): + for _ in range(B): + length = random.randint(0, max_length) + lengths_list.append(length) + if length > 0: + indices_list.extend( + np.random.randint( + 0, per_feature_value_cap, size=length + ).tolist() + ) + + device = torch.accelerator.current_accelerator() + assert device is not None + dtype = torch.int64 + hash_size_cumsum = torch.as_tensor( + hash_size_cumsum_list, dtype=dtype, device=device + ) + hash_size_offsets = torch.as_tensor( + hash_size_offsets_list, dtype=dtype, device=device + ) + lengths = torch.as_tensor(lengths_list, dtype=dtype, device=device) + indices = torch.as_tensor(indices_list, dtype=dtype, device=device) + offsets = torch.zeros(T * B + 1, dtype=dtype, device=device) + offsets[1:] = torch.cumsum(lengths, dim=0) + + ( + output_lengths, + output_offsets, + unique_indices, + reverse_index, + ) = torch.ops.fbgemm.jagged_unique_indices( + hash_size_cumsum, hash_size_offsets, offsets, indices + ) + + # Both features share the same hash space (hash_offset = 0 for + # both, since hash_size_cumsum[0] == hash_size_cumsum[1] == 0), + # so the global unique set is the union of all input indices. + expected_unique = sorted(set(indices_list)) + self.assertEqual(unique_indices.numel(), len(expected_unique)) + self.assertEqual(int(torch.sum(output_lengths).item()), unique_indices.numel()) + # Inverse-index round-trip: unique_indices[reverse_index[i]] == indices[i]. + rev_list = reverse_index.tolist() + uniq_list = unique_indices.tolist() + self.assertEqual(len(rev_list), len(indices_list)) + for i, rev in enumerate(rev_list): + self.assertTrue(0 <= rev < len(uniq_list)) + self.assertEqual(uniq_list[rev], indices_list[i]) + @given( num_elements=st.integers(min_value=100, max_value=10000), num_unique_indices=st.integers(min_value=5, max_value=100), From c3891edb495281f0c7a3e4e9b486de76efcd3cd7 Mon Sep 17 00:00:00 2001 From: Albert Chen Date: Mon, 18 May 2026 06:31:24 -0700 Subject: [PATCH 2/2] Speedup linearize_index via flat-grid kernel (#5768) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2696 linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call. Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel). No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format. Reviewed By: q10 Differential Revision: D105005594 --- .../jagged_unique_indices.cu | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu index 0c7e6482a5..94e6f5777e 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu @@ -27,10 +27,17 @@ namespace fbgemm_gpu { // flat-grid launches with grid = total_B (or num_unique). static constexpr int32_t kFlatBlockSize = 256; -// Linearzie the index with the cumsum of hash size so that linearized indices -// can be sorted together. +// Linearize the index with the cumsum of hash size so that linearized indices +// can be sorted together. Flat-grid: one block per (t, b) sample. +// +// Replaces the prior warp-cooperative kernel which was launched as +// grid = ceil(total_B / kMaxThreads) +// On production shapes total_B is in the low thousands and kMaxThreads = 1024, +// so the prior launch consumed only ~5 SMs out of 132 on H100 with each warp +// shuffling work between lanes. The flat grid uses one block per sample, +// dispatching all SMs and removing the intra-warp shuffle dance. template -__global__ __launch_bounds__(kMaxThreads) void linearize_index_wo_infos_kernel( +__global__ __launch_bounds__(kFlatBlockSize) void linearize_index_flat_kernel( const pta::PackedTensorAccessor32 hash_size_cumsum, const pta::PackedTensorAccessor32 @@ -40,27 +47,21 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_wo_infos_kernel( pta::PackedTensorAccessor32 linear_indices, FixedDivisor fd) { - const auto b_t = blockIdx.x * blockDim.x + threadIdx.x; + const auto b_t = blockIdx.x; int32_t b; int32_t t; - const auto total_B = offsets.size(0) - 1; - const auto valid = b_t < total_B; - - fd.DivMod(b_t, &t, &b); - - const auto hash_offset = valid ? hash_size_cumsum[t] : -1; - const auto indices_start = valid ? offsets[b_t] : -1; - const int32_t L = valid ? offsets[b_t + 1] - indices_start : 0; - const auto lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; - - for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { - const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j); - const auto L_warp = fbgemm_gpu::shfl_sync(L, j); - const auto hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j); - for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { - const auto idx = __ldg(&indices[indices_start_warp + i]); - linear_indices[indices_start_warp + i] = hash_offset_warp + idx; - } + fd.DivMod(static_cast(b_t), &t, &b); + + const auto indices_start = offsets[b_t]; + const auto L = offsets[b_t + 1] - indices_start; + if (L == 0) { + return; + } + const auto hash_offset = hash_size_cumsum[t]; + + for (auto i = threadIdx.x; i < L; i += blockDim.x) { + const auto idx = __ldg(&indices[indices_start + i]); + linear_indices[indices_start + i] = hash_offset + idx; } } @@ -170,7 +171,7 @@ __global__ __launch_bounds__(kFlatBlockSize) void unique_indices_length_kernel( // Pipeline (cross-kernel data flow that ties the four steps together): // -// 1. linearize_index_wo_infos_kernel writes +// 1. linearize_index_flat_kernel writes // linear_indices[i] = hash_size_cumsum[t] + indices[i] // so feature t's linearized values lie in // [hash_size_cumsum[t], hash_size_cumsum[t+1]). @@ -204,9 +205,9 @@ std::tuple jagged_unique_indices_cuda( AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "linearize_index", ([&] { FBGEMM_LAUNCH_KERNEL( - (linearize_index_wo_infos_kernel), - div_round_up(total_B, kMaxThreads), - kMaxThreads, + (linearize_index_flat_kernel), + total_B, + kFlatBlockSize, 0, at::cuda::getCurrentCUDAStream(), PTA_B(hash_size_cumsum, index_t, 1, 32),