Skip to content

Speedup unique_indices_length_kernel via binary search (#5766)#5766

Closed
AlbertDachiChen wants to merge 1 commit into
pytorch:mainfrom
AlbertDachiChen:export-D104827588
Closed

Speedup unique_indices_length_kernel via binary search (#5766)#5766
AlbertDachiChen wants to merge 1 commit into
pytorch:mainfrom
AlbertDachiChen:export-D104827588

Conversation

@AlbertDachiChen
Copy link
Copy Markdown
Contributor

@AlbertDachiChen AlbertDachiChen commented May 15, 2026

Summary:

X-link: https://github.com/facebookresearch/FBGEMM/pull/2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in linear_unique_indices: since at::_unique is called with sorted=True and linearize_index_wo_infos_kernel writes linear_indices[i] = hash_size_cumsum[t] + indices[i], feature t's unique linearized values occupy a contiguous slice of linear_unique_indices, namely [lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1])). The slice length is num_unique_t, which equals the (max - min + 1) reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side device_lower_bound helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of jagged_unique_indices_cuda together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds test_jagged_unique_indices_zch_huge_hash_size, a regression test for the ManagedCollisionCollection shape that exposes total_hash_size = INT64_MAX. This shape is produced when a sharding group contains a single HashZchManagedCollisionModule with the default input_hash_size=0. mc_modules._create_input_dists then expands per-table hash size to 2**(63 - N) - 1 (per torchrec/distributed/mc_modules.py:643); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the cudaErrorIllegalInstruction in S660690. The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on total_hash_size.

No public API change. Outputs of jagged_unique_indices are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588

@meta-cla meta-cla Bot added the cla signed label May 15, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented May 15, 2026

@AlbertDachiChen has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104827588.

@meta-codesync meta-codesync Bot changed the title Speedup unique_indices_length_kernel via binary search Speedup unique_indices_length_kernel via binary search (#5766) May 15, 2026
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 15, 2026
Summary:

X-link: facebookresearch/FBGEMM#2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
@AlbertDachiChen AlbertDachiChen force-pushed the export-D104827588 branch 2 times, most recently from 91b67c0 to 3a6d8ad Compare May 15, 2026 17:57
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 15, 2026
Summary:

X-link: facebookresearch/FBGEMM#2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
@AlbertDachiChen AlbertDachiChen force-pushed the export-D104827588 branch 2 times, most recently from cc3955f to 662ccc9 Compare May 15, 2026 22:45
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 15, 2026
Summary:
Pull Request resolved: pytorch#5766

X-link: https://github.com/facebookresearch/FBGEMM/pull/2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
@AlbertDachiChen AlbertDachiChen force-pushed the export-D104827588 branch 2 times, most recently from cfae1b3 to 402ab8c Compare May 15, 2026 22:48
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 18, 2026
Summary:

X-link: facebookresearch/FBGEMM#2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 18, 2026
Summary:

X-link: facebookresearch/FBGEMM#2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
Summary:

X-link: facebookresearch/FBGEMM#2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request May 18, 2026
Summary:

X-link: facebookresearch/FBGEMM#2695

The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).

The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.

Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize.

The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.

Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**

No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.

Reviewed By: q10

Differential Revision: D104827588
@meta-codesync meta-codesync Bot closed this in 7121bf0 May 18, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented May 18, 2026

This pull request has been merged in 7121bf0.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant