Speedup linearize_index via flat-grid kernel (#5768)#5768
Closed
AlbertDachiChen wants to merge 2 commits into
Closed
Speedup linearize_index via flat-grid kernel (#5768)#5768AlbertDachiChen wants to merge 2 commits into
AlbertDachiChen wants to merge 2 commits into
Conversation
Contributor
|
@AlbertDachiChen has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105005594. |
7865c9d to
cdc2f79
Compare
84c133f to
7c8843e
Compare
AlbertDachiChen
added a commit
to AlbertDachiChen/FBGEMM
that referenced
this pull request
May 15, 2026
Summary: Pull Request resolved: pytorch#5768 X-link: https://github.com/facebookresearch/FBGEMM/pull/2696 linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call. Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel). No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format. Reviewed By: q10 Differential Revision: D105005594
Summary: X-link: facebookresearch/FBGEMM#2695 The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape). The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed. Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize. The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why. Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.** No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs. Reviewed By: q10 Differential Revision: D104827588
Summary: X-link: facebookresearch/FBGEMM#2696 linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call. Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel). No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format. Reviewed By: q10 Differential Revision: D105005594
AlbertDachiChen
added a commit
to AlbertDachiChen/FBGEMM
that referenced
this pull request
May 18, 2026
Summary: X-link: facebookresearch/FBGEMM#2696 linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call. Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel). No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format. Reviewed By: q10 Differential Revision: D105005594
7c8843e to
c3891ed
Compare
AlbertDachiChen
added a commit
to AlbertDachiChen/FBGEMM
that referenced
this pull request
May 18, 2026
Summary: X-link: facebookresearch/FBGEMM#2696 linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call. Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel). No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format. Reviewed By: q10 Differential Revision: D105005594
Contributor
|
This pull request has been merged in 931739f. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2696
linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (
shfl_syncx kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call.Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel).
No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format.
Reviewed By: q10
Differential Revision: D105005594