Commit cc3955f
Speedup unique_indices_length_kernel via binary search (#5766)
Summary:
X-link: facebookresearch/FBGEMM#2695
The previous `unique_indices_length_kernel` computed the per-feature unique count via a BlockReduce-based min/max scan over the entire `reverse_index` array. With grid size = T (number of feature groups, typically 1-2 in production), only T of the 132 SMs on an H100 did any work. Each block scanned the full per-feature slice of `reverse_index` (~12M int64 = ~93 MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock was ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape).
With T = 2, the kernel was reading 186 MB (2 x ~93 MB) to compute 4 numbers (a min and a max per feature group). This is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed.
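The equivalence claimed above can be checked on the host with a small model. This is a sketch under stated assumptions, not the FBGEMM code: `linearize`, `sorted_unique`, `counts_by_lower_bound`, `counts_by_minmax`, and the `feature_of` array are hypothetical names introduced here, `std::sort` plus `std::unique` stands in for `at::_unique(sorted=True)`, and every index is assumed to be smaller than its feature's hash size.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<int64_t>;

// linear[i] = hash_size_cumsum[t] + indices[i], where t = feature_of[i].
Vec linearize(const Vec& indices, const std::vector<int>& feature_of,
              const Vec& hash_size_cumsum) {
  assert(indices.size() == feature_of.size());
  Vec linear(indices.size());
  for (std::size_t i = 0; i < indices.size(); ++i) {
    linear[i] = hash_size_cumsum[feature_of[i]] + indices[i];
  }
  return linear;
}

// Host stand-in for a sorted unique: ascending order, duplicates removed.
Vec sorted_unique(Vec v) {
  std::sort(v.begin(), v.end());
  v.erase(std::unique(v.begin(), v.end()), v.end());
  return v;
}

// Binary-search form: feature t owns the half-open slice
// [lower_bound(cumsum[t]), lower_bound(cumsum[t + 1])) of the unique array.
Vec counts_by_lower_bound(const Vec& unique, const Vec& hash_size_cumsum) {
  Vec counts;
  for (std::size_t t = 0; t + 1 < hash_size_cumsum.size(); ++t) {
    auto lo = std::lower_bound(unique.begin(), unique.end(), hash_size_cumsum[t]);
    auto hi = std::lower_bound(unique.begin(), unique.end(), hash_size_cumsum[t + 1]);
    counts.push_back(hi - lo);
  }
  return counts;
}

// Model of the min/max form: scan every element's position in the unique
// array (its reverse index) and report max - min + 1 per feature.
Vec counts_by_minmax(const Vec& linear, const std::vector<int>& feature_of,
                     const Vec& unique, std::size_t num_features) {
  Vec lo(num_features, INT64_MAX), hi(num_features, -1), counts(num_features, 0);
  for (std::size_t i = 0; i < linear.size(); ++i) {
    const int64_t r =
        std::lower_bound(unique.begin(), unique.end(), linear[i]) - unique.begin();
    const int t = feature_of[i];
    lo[t] = std::min(lo[t], r);
    hi[t] = std::max(hi[t], r);
  }
  for (std::size_t t = 0; t < num_features; ++t) {
    if (hi[t] >= 0) counts[t] = hi[t] - lo[t] + 1;  // empty features stay at 0
  }
  return counts;
}
```

The min/max model touches every input element, while the binary-search form only touches the unique array, and only logarithmically many entries of it.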
Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. The block size drops from 1024 to 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work left to parallelize.
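A minimal sketch of what such a lower-bound helper might look like follows. The actual `device_lower_bound` signature is not shown in this commit message, so `lower_bound_sketch`, `unique_count_for_feature`, and the optional `iterations` counter are assumptions made for illustration. The CUDA qualifiers (`__device__` or `__host__ __device__`) are left out so the code compiles as plain C++.

```cpp
#include <cstdint>

// Returns the first position in sorted[0, n) whose value is >= target,
// or n if there is none. In CUDA this would carry a __device__ qualifier.
inline int64_t lower_bound_sketch(const int64_t* sorted, int64_t n,
                                  int64_t target, int* iterations = nullptr) {
  int64_t lo = 0;
  int64_t hi = n;
  int steps = 0;
  while (lo < hi) {
    const int64_t mid = lo + (hi - lo) / 2;  // avoids overflow of lo + hi
    if (sorted[mid] < target) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
    ++steps;  // one 8-byte read per iteration
  }
  if (iterations != nullptr) {
    *iterations = steps;
  }
  return lo;
}

// Unique count for feature t: the length of the slice bounded by the two
// searches on hash_size_cumsum[t] and hash_size_cumsum[t + 1].
inline int64_t unique_count_for_feature(const int64_t* unique, int64_t num_unique,
                                        const int64_t* hash_size_cumsum, int t) {
  const int64_t begin = lower_bound_sketch(unique, num_unique, hash_size_cumsum[t]);
  const int64_t end = lower_bound_sketch(unique, num_unique, hash_size_cumsum[t + 1]);
  return end - begin;
}
```

Each search performs at most floor(log2(n)) + 1 iterations, and only integer comparisons are involved, so a boundary value such as INT64_MAX needs no special handling.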
The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why.
Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.**
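To illustrate why float math on `total_hash_size` is risky at this boundary, here is a small host-side sketch. The helper names (`expanded_hash_size`, `floor_log2_int`, `log2_float`) are hypothetical and do not come from FBGEMM or torchrec; the sketch only shows the rounding behaviour and makes no claim about the root cause of the incident mentioned above.

```cpp
#include <cmath>
#include <cstdint>

// 2**(63 - n) - 1 for 0 <= n <= 62, computed in unsigned arithmetic so that
// n = 0 does not overflow a signed shift. For n = 0 the result is INT64_MAX.
inline int64_t expanded_hash_size(int n) {
  return static_cast<int64_t>((uint64_t{1} << (63 - n)) - 1);
}

// Integer-only floor(log2(x)) for x > 0.
inline int floor_log2_int(int64_t x) {
  int bits = 0;
  while (x >>= 1) {
    ++bits;
  }
  return bits;
}

// Float path: INT64_MAX is not representable as a double and rounds up to
// 2**63, so the logarithm comes out one higher than the integer answer.
inline double log2_float(int64_t x) {
  return std::log2(static_cast<double>(x));
}
```

For INT64_MAX the integer path reports 62 while the float path reports about 63, because the conversion to `double` has already rounded the value up to 2**63. Any bit-width or shift amount derived from the float result would be off by one at exactly this shape.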
No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs.
Reviewed By: q10
Differential Revision: D1048275881
parent 092b281 commit cc3955f
3 files changed
Lines changed: 175 additions & 53 deletions
File tree
- fbgemm_gpu
- src/jagged_tensor_ops
- test/jagged
Lines changed: 99 additions & 53 deletions