Speedup delinearize via feature-lookup over num_unique (#5770)
Summary:
Pull Request resolved: #5770
X-link: https://github.com/facebookresearch/FBGEMM/pull/2699
delinearize_unique_index_kernel iterated over total_indices (~24M on the prod IFR-MTML mc7 shape), reading both `indices` and `reverse_index` per element and scatter-writing `unique_indices[reverse_index[i]] = indices[i]`. The scatter hits random L2 lines, and its cost scales with all N=24M elements.
Replace it with delinearize_unique_from_sorted_kernel, which iterates over num_unique (~1-2M, i.e. roughly 12-24x fewer threads). Each thread reads one sorted unique key v and recovers its (feature_t, per_feature_idx) by binary-searching hash_size_cumsum:
    t = lower_bound(hash_size_cumsum, v + 1) - 1   // largest t s.t. hash_size_cumsum[t] <= v
    unique_indices[i] = v - hash_size_cumsum[t]
The writes to unique_indices are sequential (no scatter). The probe is O(log T), where T is the number of features (typically <= 256), so it executes in <= 8 iterations against the L1-resident hash_size_cumsum.
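As a host-side illustration of the lookup described above, here is a minimal C++ sketch. It is not the CUDA kernel from this change: the function name `delinearize_unique_from_sorted_ref` is hypothetical, and it assumes hash_size_cumsum holds T + 1 non-decreasing entries starting at 0 and that every key is below the last entry.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the per-key lookup; one loop iteration
// stands in for one GPU thread. Assumes hash_size_cumsum has T + 1
// non-decreasing entries starting at 0, and every key < hash_size_cumsum.back().
std::vector<int64_t> delinearize_unique_from_sorted_ref(
    const std::vector<int64_t>& unique_keys,
    const std::vector<int64_t>& hash_size_cumsum) {
  std::vector<int64_t> unique_indices(unique_keys.size());
  for (std::size_t i = 0; i < unique_keys.size(); ++i) {
    const int64_t v = unique_keys[i];
    // First entry >= v + 1, i.e. the first entry strictly greater than v.
    const auto it = std::lower_bound(
        hash_size_cumsum.begin(), hash_size_cumsum.end(), v + 1);
    // Stepping back one slot gives the largest t with hash_size_cumsum[t] <= v.
    const auto t = static_cast<std::size_t>(it - hash_size_cumsum.begin()) - 1;
    // Sequential write: slot i depends only on key i.
    unique_indices[i] = v - hash_size_cumsum[t];
  }
  return unique_indices;
}
```

For example, with hash_size_cumsum = {0, 10, 25} the key 10 falls in feature 1 and maps to per-feature index 0, while the key 9 stays in feature 0 with index 9.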
The `v + 1` does not overflow at the ZCH boundary: unique_keys[i] is bounded by hash_size_cumsum[T-1] <= total_hash_size, and feature t's linearized values lie in [hash_size_cumsum[t], hash_size_cumsum[t+1]), so v <= total_hash_size - 1 even when total_hash_size = INT64_MAX.
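The boundary case can be exercised on the host with a small C++ sketch. The helper name `feature_of_key` and the three-entry cumsum in the usage note are illustrative assumptions, not code or shapes from this change.

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper: returns the feature index t of linearized key v.
// Precondition (from the argument above): v < hash_size_cumsum.back(),
// so v + 1 is at most INT64_MAX and the addition cannot overflow.
int64_t feature_of_key(int64_t v, const std::vector<int64_t>& hash_size_cumsum) {
  const auto it = std::lower_bound(
      hash_size_cumsum.begin(), hash_size_cumsum.end(), v + 1);
  return static_cast<int64_t>(it - hash_size_cumsum.begin()) - 1;
}

// Largest representable total hash size, used to probe the boundary.
constexpr int64_t kMaxHashSize = std::numeric_limits<int64_t>::max();
```

With hash_size_cumsum = {0, 1000, kMaxHashSize}, the largest valid key is kMaxHashSize - 1; `feature_of_key` computes v + 1 = kMaxHashSize without wrapping and returns feature 1.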
Bench delta on prod IFR-MTML mc7 shape (T=2, B=2560, ~23M int64 indices):
```
                                D-3        D-4
delinearize kernel              ~1.62 ms   ~13 us    (~120x)
fbgemm::jagged_unique_indices   ~5.4 ms    ~3.8 ms   (-1.6 ms)
```
No public API change. Output is bit-identical to the prior kernel for all valid inputs (verified by all unit + opcheck tests).
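The equivalence of the two approaches can be sketched on the host in C++. This is an illustration under assumptions, not the test code from this change: `compare_delinearize` and `feature_of_index` are hypothetical names, and linearization is assumed to be `indices[i] + hash_size_cumsum[feature]`, with unique keys produced in sorted order.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical host-side check: returns {scatter result, lookup result}.
// Assumes per-feature indices are linearized by adding hash_size_cumsum[feature].
std::pair<std::vector<int64_t>, std::vector<int64_t>> compare_delinearize(
    const std::vector<int64_t>& indices,
    const std::vector<int64_t>& feature_of_index,
    const std::vector<int64_t>& hash_size_cumsum) {
  const std::size_t n = indices.size();
  std::vector<int64_t> linear(n);
  for (std::size_t i = 0; i < n; ++i) {
    linear[i] = indices[i] + hash_size_cumsum[feature_of_index[i]];
  }
  // Sorted unique keys plus the reverse index of every input element.
  std::vector<int64_t> unique_keys = linear;
  std::sort(unique_keys.begin(), unique_keys.end());
  unique_keys.erase(std::unique(unique_keys.begin(), unique_keys.end()),
                    unique_keys.end());
  std::vector<std::size_t> reverse_index(n);
  for (std::size_t i = 0; i < n; ++i) {
    reverse_index[i] = static_cast<std::size_t>(
        std::lower_bound(unique_keys.begin(), unique_keys.end(), linear[i]) -
        unique_keys.begin());
  }
  // Prior approach: one scatter write per input element.
  std::vector<int64_t> scatter(unique_keys.size());
  for (std::size_t i = 0; i < n; ++i) {
    scatter[reverse_index[i]] = indices[i];
  }
  // New approach: one binary search and one sequential write per unique key.
  std::vector<int64_t> lookup(unique_keys.size());
  for (std::size_t j = 0; j < unique_keys.size(); ++j) {
    const int64_t v = unique_keys[j];
    const auto it = std::lower_bound(
        hash_size_cumsum.begin(), hash_size_cumsum.end(), v + 1);
    const auto t = static_cast<std::size_t>(it - hash_size_cumsum.begin()) - 1;
    lookup[j] = v - hash_size_cumsum[t];
  }
  return {scatter, lookup};
}
```

Duplicates in the input all write the same value into the same scatter slot, which is why the lookup over unique keys can reproduce the scatter output exactly.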
Reviewed By: q10
Differential Revision: D105005652
fbshipit-source-id: f2e9cba2e279105a0c38ff4ed1f39a9397c7622e