Fix jagged_unique_indices OOB regression from D104827588 (#5799)

Open
AlbertDachiChen wants to merge 1 commit into pytorch:main from AlbertDachiChen:export-D106305472

AlbertDachiChen commented May 29, 2026

Summary:

X-link: https://github.com/facebookresearch/FBGEMM/pull/2727

D104827588 (length kernel) and f2e9cba2e279 (delinearize kernel) both introduced an unstated assumption that indices lie in `[0, per_feature_hash)`, which the op had never enforced. ZCH callers pass raw IDs to `jagged_unique_indices` (dedup runs before the hash-to-bucket remap) and tripped this on the last feature in prod. The result was `sum(output_lengths) < unique_indices.numel()`, which crashed the downstream `dist.all_to_all_single` with "Split sizes doesn't match total dim 0 size".
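The mismatch can be reproduced with a small host-side model. This is only a sketch: it assumes the op linearizes each index as `idx + hash_size_cumsum[t]`, deduplicates the sorted keys, and counts each feature's uniques by binary search against `hash_size_cumsum`. The helper names and the data are hypothetical and stand in for the CUDA kernels, which are not shown in this PR description.

```python
from bisect import bisect_left

def linearize(indices_per_feature, hash_size_cumsum):
    """Assumed linearization: key = idx + hash_size_cumsum[t]."""
    return [idx + hash_size_cumsum[t]
            for t, idxs in enumerate(indices_per_feature)
            for idx in idxs]

def lengths_by_binary_search(sorted_unique_keys, hash_size_cumsum):
    """Count uniques per feature as keys in [cumsum[t], cumsum[t + 1])."""
    num_features = len(hash_size_cumsum) - 1
    lengths = []
    for t in range(num_features):
        lo_pos = bisect_left(sorted_unique_keys, hash_size_cumsum[t])
        hi_pos = bisect_left(sorted_unique_keys, hash_size_cumsum[t + 1])
        lengths.append(hi_pos - lo_pos)
    return lengths

# Two features, each with hash size 10, so the cumsum is [0, 10, 20].
hash_size_cumsum = [0, 10, 20]
# The last feature carries a raw ID (57) outside [0, 10).
indices_per_feature = [[1, 3, 3], [2, 57]]

keys = sorted(set(linearize(indices_per_feature, hash_size_cumsum)))
lengths = lengths_by_binary_search(keys, hash_size_cumsum)

# The key for raw ID 57 sorts past hash_size_cumsum[-1], so no feature counts it.
print(keys, lengths, sum(lengths), len(keys))
```

In this model the per-feature lengths sum to one less than the number of unique keys, which is the shape of mismatch that a split-size check downstream would reject.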

Two fixes in `jagged_unique_indices.cu`:

1. `unique_indices_length_kernel`: when the group's `hash_end == hash_size_cumsum.size(0) - 1` (last feature), set `hi_pos = linear_unique_indices.size(0)` instead of doing the upper-bound binary search. The OOB tail of the last feature sorts past `hash_size_cumsum[T]` and the binary search would otherwise stop short of it.

2. Replace `delinearize_unique_from_sorted_kernel` (gather: `unique_indices[i] = v - hash_size_cumsum[t]`) with the original `delinearize_unique_index_kernel` (scatter: `unique_indices[reverse_index[i]] = indices[i]`). The gather form misattributes OOB values whose linearized key exceeds `hash_size_cumsum[T]` to a phantom feature `t = T` and emits `v - hash_size_cumsum[T]` instead of the original index value. The scatter form is index-bound-agnostic by construction. Note `reverse_index` is `int64_t` (from the cub pipeline) rather than templated `index_t` (which was the case under `at::_unique`).
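Both fixes can be illustrated with a host-side model. The sketch below is not the kernel code: it assumes one feature per group, a linearized key of `idx + hash_size_cumsum[t]`, and a `reverse_index` that maps each input position to its slot in the sorted unique keys. All function names and data are hypothetical.

```python
from bisect import bisect_left, bisect_right

def unique_lengths(sorted_unique_keys, hash_size_cumsum):
    """Per-feature unique counts; the last feature takes everything to the end."""
    last_hash_end = len(hash_size_cumsum) - 1
    lengths = []
    for hash_end in range(1, last_hash_end + 1):
        lo_pos = bisect_left(sorted_unique_keys, hash_size_cumsum[hash_end - 1])
        if hash_end == last_hash_end:
            # Fix 1: skip the search so the OOB tail is counted.
            hi_pos = len(sorted_unique_keys)
        else:
            hi_pos = bisect_left(sorted_unique_keys, hash_size_cumsum[hash_end])
        lengths.append(hi_pos - lo_pos)
    return lengths

def delinearize_gather(sorted_unique_keys, hash_size_cumsum):
    """Gather form: recover the index by subtracting the feature offset."""
    out = []
    for v in sorted_unique_keys:
        t = bisect_right(hash_size_cumsum, v) - 1  # lands on t = T for OOB keys
        out.append(v - hash_size_cumsum[t])
    return out

def delinearize_scatter(flat_indices, reverse_index, num_unique):
    """Scatter form (fix 2): copy the original index into its unique slot."""
    unique_indices = [0] * num_unique
    for i, idx in enumerate(flat_indices):
        unique_indices[reverse_index[i]] = idx
    return unique_indices

hash_size_cumsum = [0, 10, 20]
flat_indices = [1, 3, 3, 2, 57]   # 57 is a raw ID in the last feature
feature_of = [0, 0, 0, 1, 1]

keys = [idx + hash_size_cumsum[t] for idx, t in zip(flat_indices, feature_of)]
sorted_unique_keys = sorted(set(keys))
position = {key: pos for pos, key in enumerate(sorted_unique_keys)}
reverse_index = [position[key] for key in keys]

lengths = unique_lengths(sorted_unique_keys, hash_size_cumsum)
gathered = delinearize_gather(sorted_unique_keys, hash_size_cumsum)
scattered = delinearize_scatter(flat_indices, reverse_index, len(sorted_unique_keys))
print(lengths, gathered, scattered)
```

In this model the lengths sum to the number of unique keys again, the gather form turns the raw ID 57 into 47, and the scatter form returns 57 unchanged because it never subtracts an offset.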

Contract enforcement: added a `CUDA_KERNEL_ASSERT` in `linearize_index_flat_kernel` that fires when an intermediate feature (`t < T - 1`) with `per_feature_hash > 0` has `idx >= per_feature_hash`. Intermediate-feature OOB causes silent per-feature count drift: counts leak from `t` to `t+1` while the total is preserved. This has been broken since the op was written, but no caller hit it. The assert now surfaces such violations instead of letting them silently corrupt downstream embedding lookups. The assert exempts (a) the last feature (legitimately supported) and (b) merged/masked features with `per_feature_hash == 0` (the `hash_size_offsets` indirection pattern used by `multi_keys` and `zch_huge_hash_size`).
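A small host-side model shows the count drift and the condition the assert checks. This is a sketch under the same assumptions as the description (linearized key `idx + hash_size_cumsum[t]`, lengths taken from positions in the sorted unique keys). `contract_holds` is a hypothetical helper that mirrors the stated predicate, not the kernel's actual code.

```python
from bisect import bisect_left

def contract_holds(idx, t, num_features, per_feature_hash):
    """True unless the assert described above would fire."""
    is_last_feature = t == num_features - 1   # exemption (a)
    is_masked = per_feature_hash == 0         # exemption (b)
    return is_last_feature or is_masked or idx < per_feature_hash

hash_size_cumsum = [0, 10, 20]
# Index 15 is out of bounds for feature 0, an intermediate feature.
indices_per_feature = [[1, 15], [2]]

keys = sorted({idx + hash_size_cumsum[t]
               for t, idxs in enumerate(indices_per_feature)
               for idx in idxs})

# Feature 0 is counted up to hash_size_cumsum[1]; the last feature takes the rest.
boundary = bisect_left(keys, hash_size_cumsum[1])
observed_lengths = [boundary, len(keys) - boundary]
true_lengths = [len(set(idxs)) for idxs in indices_per_feature]

print(observed_lengths, true_lengths)
print(contract_holds(15, 0, 2, 10))   # intermediate feature, OOB
print(contract_holds(57, 1, 2, 10))   # last feature, exempt
print(contract_holds(15, 0, 2, 0))    # masked feature, exempt
```

Here the key for index 15 falls into the next feature's range, so one unique moves from feature 0 to feature 1 while the total stays at 3. Only the first `contract_holds` call reports a violation.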

The pipeline contract doc comment above `jagged_unique_indices_cuda` is updated to enumerate the three cases.

Differential Revision: D106305472

@meta-cla meta-cla Bot added the cla signed label May 29, 2026

meta-codesync Bot commented May 29, 2026

@AlbertDachiChen has exported this pull request. If you are a Meta employee, you can view the originating Diff in D106305472.

AlbertDachiChen added a commit to AlbertDachiChen/FBGEMM that referenced this pull request Jun 1, 2026
meta-codesync Bot changed the title from "Fix jagged_unique_indices OOB regression from D104827588" to "Fix jagged_unique_indices OOB regression from D104827588 (#5799)" on Jun 1, 2026