Skip to content

Speedup linearize_index via flat-grid kernel (#5768)#5768

Closed
AlbertDachiChen wants to merge 2 commits into
pytorch:mainfrom
AlbertDachiChen:export-D105005594
Closed

Speedup linearize_index via flat-grid kernel (#5768)#5768
AlbertDachiChen wants to merge 2 commits into
pytorch:mainfrom
AlbertDachiChen:export-D105005594

Conversation

@AlbertDachiChen
Copy link
Copy Markdown
Contributor

@AlbertDachiChen AlbertDachiChen commented May 15, 2026

Summary:

X-link: https://github.com/facebookresearch/FBGEMM/pull/2696

linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (shfl_sync x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call.

Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel).

No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format.

Reviewed By: q10

Differential Revision: D105005594

@meta-cla meta-cla Bot added the cla signed label May 15, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented May 15, 2026

@AlbertDachiChen has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105005594.

@meta-codesync meta-codesync Bot changed the title Speedup linearize_index via flat-grid kernel Speedup linearize_index via flat-grid kernel (#5768) May 15, 2026
@AlbertDachiChen AlbertDachiChen force-pushed the export-D105005594 branch 3 times, most recently from 84c133f to 7c8843e Compare May 15, 2026 22:48
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 15, 2026
Summary:
Pull Request resolved: pytorch#5768

X-link: https://github.com/facebookresearch/FBGEMM/pull/2696

linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call.

Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel).

No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format.

Reviewed By: q10

Differential Revision: D105005594
Summary:

X-link: facebookresearch/FBGEMM#2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
Summary:

X-link: facebookresearch/FBGEMM#2696

linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call.

Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel).

No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format.

Reviewed By: q10

Differential Revision: D105005594
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 18, 2026
Summary:

X-link: facebookresearch/FBGEMM#2696

linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call.

Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel).

No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format.

Reviewed By: q10

Differential Revision: D105005594
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 18, 2026
Summary:

X-link: facebookresearch/FBGEMM#2696

linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), kMaxThreads = 1024. On the prod IFR-MTML mc7 shape total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also did intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. Bench measures this kernel at ~1.7 ms / call.

Replace with a flat-grid kernel launched as grid = total_B (one block per (t, b) sample), threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. Bench measures the new kernel at ~140 us / call (~12x speedup on this kernel).

No public API change. Output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format.

Reviewed By: q10

Differential Revision: D105005594
@meta-codesync meta-codesync Bot closed this in 931739f May 18, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented May 18, 2026

This pull request has been merged in 931739f.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant