Skip to content

Commit 409bbf3

Browse files
Speedup unique_indices_length_kernel via binary search (pytorch#5766)
Summary: X-link: facebookresearch/FBGEMM#2695 The previous unique_indices_length_kernel computes per-feature unique-count via a BlockReduce-based min/max scan over the entire reverse_index array. With grid size = T (number of feature groups, typically 1-2 in production), only T SMs do work out of 132 on H100. Each block scans the full per-feature slice of reverse_index (~12M int64 = ~93MB for the prod IFR-MTML mc7 shape), bandwidth-bound on a single SM at ~30-50 GB/s. Total wall-clock is ~2-3 ms, dominating this op end-to-end (~60% of the ~5 ms baseline on the prod shape). The kernel was reading 186 MB to compute 4 numbers (a min and max per feature group). It is wasteful because the information is already implicit in `linear_unique_indices`: since `at::_unique` is called with `sorted=True` and `linearize_index_wo_infos_kernel` writes `linear_indices[i] = hash_size_cumsum[t] + indices[i]`, feature t's unique linearized values occupy a contiguous slice of `linear_unique_indices`, namely `[lower_bound(unique, hash_size_cumsum[t]), lower_bound(unique, hash_size_cumsum[t+1]))`. The slice length is `num_unique_t`, which equals the `(max - min + 1)` reduction the old kernel computed. Replace the O(N) reduction with two O(log U) binary searches per feature group via a new device-side `device_lower_bound` helper. Block size 1024 -> 256 (no shared-memory reduction, no per-thread scratch). The per-block work is now ~336 B of reads (two binary searches, ~21 iterations each, 8 B per iteration), which trivially fits in cache; the T-block grid stops mattering because there is no work to parallelize. The pipeline contract that ties the four kernels of `jagged_unique_indices_cuda` together (linearize -> at::_unique -> delinearize -> length) is documented above the function so the next reader does not have to reverse-engineer it from the kernel bodies. The length kernel docstring states the local form of the invariant and points at the orchestrator for the why. Also adds `test_jagged_unique_indices_zch_huge_hash_size`, a regression test for the `ManagedCollisionCollection` shape that exposes `total_hash_size = INT64_MAX`. This shape is produced when a sharding group contains a single `HashZchManagedCollisionModule` with the default `input_hash_size=0`. `mc_modules._create_input_dists` then expands per-table hash size to `2**(63 - N) - 1` (per `torchrec/distributed/mc_modules.py:643`); for N=0 (single-table group) that lands at INT64_MAX. This shape was not exercised by any existing test and was the trigger for the `cudaErrorIllegalInstruction` in S660690. **The new length kernel handles it correctly (integer-only arithmetic at the boundary), but the test also serves as a trip-wire for downstream optimizations that introduce float-log2 math on `total_hash_size`.** No public API change. Outputs of `jagged_unique_indices` are bit-identical to the previous version for all valid inputs. Reviewed By: q10 Differential Revision: D104827588
1 parent 3203889 commit 409bbf3

3 files changed

Lines changed: 175 additions & 53 deletions

File tree

fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu

Lines changed: 99 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ using Tensor = at::Tensor;
2222

2323
namespace fbgemm_gpu {
2424

25+
// Block size for the flat-grid kernels in this file. 256 = 8 warps, chosen
26+
// so the SM scheduler can pack ~4-8 blocks per SM on H100 given the
27+
// flat-grid launches with grid = total_B (or num_unique).
28+
static constexpr int32_t kFlatBlockSize = 256;
29+
2530
// Linearzie the index with the cumsum of hash size so that linearized indices
2631
// can be sorted together.
2732
template <typename index_t>
@@ -79,69 +84,112 @@ __global__ __launch_bounds__(kMaxThreads) void delinearize_unique_index_kernel(
7984
}
8085
}
8186

82-
// Compute the lengths for each feature in the unique indices. The range of
83-
// indices for each feature equals to the difference between the max and min
84-
// values in the reverse index array.
85-
template <typename index_t, auto max_value, auto min_value>
86-
__global__ __launch_bounds__(kMaxThreads) void unique_indices_length_kernel(
87+
// Device-side lower_bound over a PackedTensorAccessor32<index_t, 1>.
88+
// Returns the first position whose value is >= `value`, equivalent to
89+
// std::lower_bound on the underlying sorted array.
90+
template <typename index_t>
91+
__device__ __forceinline__ int32_t device_lower_bound(
92+
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>& arr,
93+
const index_t value) {
94+
int32_t lo = 0;
95+
int32_t hi = arr.size(0);
96+
while (lo < hi) {
97+
const int32_t mid = lo + ((hi - lo) >> 1);
98+
if (arr[mid] < value) {
99+
lo = mid + 1;
100+
} else {
101+
hi = mid;
102+
}
103+
}
104+
return lo;
105+
}
106+
107+
// Compute the per-(feature, batch) lengths for the unique indices.
108+
//
109+
// Caller-provided invariant (see jagged_unique_indices_cuda for the
110+
// pipeline contract that establishes it): `linear_unique_indices` is
111+
// sorted ascending, and feature t's values occupy a contiguous slice
112+
// [lower_bound(linear_unique_indices, hash_size_cumsum[t]),
113+
// lower_bound(linear_unique_indices, hash_size_cumsum[t+1])).
114+
// The slice length equals num_unique_t for feature t.
115+
template <typename index_t>
116+
__global__ __launch_bounds__(kFlatBlockSize) void unique_indices_length_kernel(
87117
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
88118
hash_size_offsets,
89119
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
90-
reverse_index,
120+
hash_size_cumsum,
91121
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
92-
offsets,
93-
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> lengths) {
94-
typedef cub::BlockReduce<index_t, kMaxThreads> BlockReduce;
95-
__shared__ typename BlockReduce::TempStorage temp_storage_max;
96-
__shared__ typename BlockReduce::TempStorage temp_storage_min;
97-
__shared__ index_t block_results[2];
98-
122+
linear_unique_indices,
123+
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> lengths,
124+
const int32_t batch_size) {
99125
const auto tid = threadIdx.x;
100126
const auto bid = blockIdx.x;
101-
const auto num_blocks = gridDim.x;
102-
const int32_t batch_size = (offsets.size(0) - 1) / num_blocks;
103-
104-
const auto offset_begin = hash_size_offsets[bid] * batch_size;
105-
const auto offset_end = hash_size_offsets[bid + 1] * batch_size;
106-
const auto num_lengths = (offset_end - offset_begin);
107-
108-
const auto reverse_index_begin = offsets[offset_begin];
109-
const auto reverse_index_end = offsets[offset_end];
110127

111-
if (reverse_index_begin == reverse_index_end) {
128+
const auto hash_begin = hash_size_offsets[bid];
129+
const auto hash_end = hash_size_offsets[bid + 1];
130+
const auto offset_begin = hash_begin * batch_size;
131+
const auto offset_end = hash_end * batch_size;
132+
const auto num_lengths = offset_end - offset_begin;
133+
if (num_lengths == 0) {
112134
return;
113135
}
114136

115-
index_t t_max = min_value;
116-
index_t t_min = max_value;
117-
for (index_t i = (reverse_index_begin + tid); i < reverse_index_end;
118-
i += kMaxThreads) {
119-
const index_t value = reverse_index[i];
120-
t_max = (value > t_max) ? value : t_max;
121-
t_min = (value < t_min) ? value : t_min;
122-
}
123-
124-
index_t block_max =
125-
BlockReduce(temp_storage_max).Reduce(t_max, Max<index_t>());
126-
index_t block_min =
127-
BlockReduce(temp_storage_min).Reduce(t_min, Min<index_t>());
137+
__shared__ index_t s_div_length;
138+
__shared__ index_t s_r_length;
128139
if (tid == 0) {
129-
block_results[0] = block_max;
130-
block_results[1] = block_min;
140+
const auto low = hash_size_cumsum[hash_begin];
141+
const auto high = hash_size_cumsum[hash_end];
142+
if (low == high) {
143+
// Empty feature group. Output is pre-zeroed by at::zeros at the
144+
// launch site; nothing to write.
145+
s_div_length = 0;
146+
s_r_length = 0;
147+
} else {
148+
const int32_t lo_pos =
149+
device_lower_bound<index_t>(linear_unique_indices, low);
150+
const int32_t hi_pos =
151+
device_lower_bound<index_t>(linear_unique_indices, high);
152+
const index_t total_length = static_cast<index_t>(hi_pos - lo_pos);
153+
s_div_length = total_length / static_cast<index_t>(num_lengths);
154+
s_r_length = total_length % static_cast<index_t>(num_lengths);
155+
}
131156
}
132157
__syncthreads();
133158

134-
t_max = block_results[0];
135-
t_min = block_results[1];
136-
const index_t total_length = (t_max - t_min) + 1;
137-
const index_t div_length = total_length / num_lengths;
138-
const index_t r_length = total_length % num_lengths;
139-
for (int32_t i = tid; i < num_lengths; i += kMaxThreads) {
140-
index_t seg_length = (i < r_length) ? (div_length + 1) : div_length;
159+
const index_t div_length = s_div_length;
160+
const index_t r_length = s_r_length;
161+
if (div_length == 0 && r_length == 0) {
162+
return;
163+
}
164+
for (int32_t i = tid; i < num_lengths; i += blockDim.x) {
165+
const index_t seg_length =
166+
(static_cast<index_t>(i) < r_length) ? (div_length + 1) : div_length;
141167
lengths[offset_begin + i] = seg_length;
142168
}
143169
}
144170

171+
// Pipeline (cross-kernel data flow that ties the four steps together):
172+
//
173+
// 1. linearize_index_wo_infos_kernel writes
174+
// linear_indices[i] = hash_size_cumsum[t] + indices[i]
175+
// so feature t's linearized values lie in
176+
// [hash_size_cumsum[t], hash_size_cumsum[t+1]).
177+
//
178+
// 2. at::_unique(linear_indices, sorted=True, return_inverse=True)
179+
// returns (linear_unique_indices, reverse_index) where
180+
// linear_unique_indices is sorted ascending. Combined with (1), this
181+
// means feature t's unique linearized values occupy a contiguous
182+
// slice of linear_unique_indices.
183+
//
184+
// 3. delinearize_unique_index_kernel scatters the original
185+
// (pre-linearization) per-feature index values back into
186+
// unique_indices via reverse_index.
187+
//
188+
// 4. unique_indices_length_kernel relies on (1)+(2) to compute
189+
// num_unique per feature group via two binary searches over
190+
// linear_unique_indices, instead of an O(N) reduction over
191+
// reverse_index. See the kernel's docstring for the local form of
192+
// the invariant.
145193
std::tuple<Tensor, Tensor, Tensor, Tensor> jagged_unique_indices_cuda(
146194
const Tensor& hash_size_cumsum,
147195
const Tensor& hash_size_offsets,
@@ -192,18 +240,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> jagged_unique_indices_cuda(
192240
Tensor output_lengths = at::zeros({total_B}, offsets.options());
193241
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "unique_indices_length", ([&] {
194242
FBGEMM_LAUNCH_KERNEL(
195-
(unique_indices_length_kernel<
196-
index_t,
197-
std::numeric_limits<index_t>::max(),
198-
std::numeric_limits<index_t>::min()>),
243+
(unique_indices_length_kernel<index_t>),
199244
T,
200-
kMaxThreads,
245+
kFlatBlockSize,
201246
0,
202247
at::cuda::getCurrentCUDAStream(),
203248
PTA_B(hash_size_offsets, index_t, 1, 32),
204-
PTA_B(reverse_index, index_t, 1, 32),
205-
PTA_B(offsets, index_t, 1, 32),
206-
PTA_B(output_lengths, index_t, 1, 32));
249+
PTA_B(hash_size_cumsum, index_t, 1, 32),
250+
PTA_B(linear_unique_indices, index_t, 1, 32),
251+
PTA_B(output_lengths, index_t, 1, 32),
252+
static_cast<int32_t>(total_B / T));
207253
}));
208254

209255
Tensor output_offsets;

fbgemm_gpu/test/jagged/failures_dict.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@
107107
"comment": "",
108108
"status": "xfail"
109109
},
110+
"UniqueIndicesTest.test_aot_dispatch_dynamic__test_jagged_unique_indices_zch_huge_hash_size": {
111+
"comment": "Test uses .item()/.tolist() for host-side comparison; incompatible with dynamic dispatch.",
112+
"status": "xfail"
113+
},
110114
"UniqueIndicesTest.test_faketensor__test_jagged_unique_indices": {
111115
"comment": "",
112116
"status": "xfail"
@@ -118,6 +122,10 @@
118122
"UniqueIndicesTest.test_faketensor__test_jagged_unique_indices_multi_keys": {
119123
"comment": "",
120124
"status": "xfail"
125+
},
126+
"UniqueIndicesTest.test_faketensor__test_jagged_unique_indices_zch_huge_hash_size": {
127+
"comment": "Test uses .item()/.tolist() for host-side comparison; incompatible with fake tensors.",
128+
"status": "xfail"
121129
}
122130
},
123131
"fbgemm::keyed_jagged_index_select_dim1": {},

fbgemm_gpu/test/jagged/unique_indices_test.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,74 @@ def test_jagged_unique_indices_empty(
269269
self.assertEqual(torch.sum(output_lengths).item(), 0)
270270
self.assertEqual(torch.sum(output_offsets).item(), 0)
271271

272+
@unittest.skipIf(*gpu_unavailable)
273+
def test_jagged_unique_indices_zch_huge_hash_size(self) -> None:
274+
"""Exercise the op with a hash_size_cumsum entry at INT64_MAX -
275+
the shape produced by ZCH callers that leave per-feature hash size
276+
unbounded. The op must handle hash boundaries spanning the full
277+
int64 range without overflow in any internal arithmetic.
278+
"""
279+
T = 2
280+
B = 64
281+
max_length = 5
282+
int64_max = torch.iinfo(torch.int64).max
283+
hash_size_cumsum_list = [0, 0, int64_max]
284+
hash_size_offsets_list = [0, 0, 2]
285+
# Per-feature linearized values lie in [0, INT64_MAX). The kernels
286+
# under test are boundary-value-sensitive on hash_size_cumsum, not
287+
# on the indices themselves, so a small index range is sufficient
288+
# and keeps the reference comparison fast.
289+
per_feature_value_cap = 1024
290+
lengths_list: list[int] = []
291+
indices_list: list[int] = []
292+
for _ in range(T):
293+
for _ in range(B):
294+
length = random.randint(0, max_length)
295+
lengths_list.append(length)
296+
if length > 0:
297+
indices_list.extend(
298+
np.random.randint(
299+
0, per_feature_value_cap, size=length
300+
).tolist()
301+
)
302+
303+
device = torch.accelerator.current_accelerator()
304+
assert device is not None
305+
dtype = torch.int64
306+
hash_size_cumsum = torch.as_tensor(
307+
hash_size_cumsum_list, dtype=dtype, device=device
308+
)
309+
hash_size_offsets = torch.as_tensor(
310+
hash_size_offsets_list, dtype=dtype, device=device
311+
)
312+
lengths = torch.as_tensor(lengths_list, dtype=dtype, device=device)
313+
indices = torch.as_tensor(indices_list, dtype=dtype, device=device)
314+
offsets = torch.zeros(T * B + 1, dtype=dtype, device=device)
315+
offsets[1:] = torch.cumsum(lengths, dim=0)
316+
317+
(
318+
output_lengths,
319+
output_offsets,
320+
unique_indices,
321+
reverse_index,
322+
) = torch.ops.fbgemm.jagged_unique_indices(
323+
hash_size_cumsum, hash_size_offsets, offsets, indices
324+
)
325+
326+
# Both features share the same hash space (hash_offset = 0 for
327+
# both, since hash_size_cumsum[0] == hash_size_cumsum[1] == 0),
328+
# so the global unique set is the union of all input indices.
329+
expected_unique = sorted(set(indices_list))
330+
self.assertEqual(unique_indices.numel(), len(expected_unique))
331+
self.assertEqual(int(torch.sum(output_lengths).item()), unique_indices.numel())
332+
# Inverse-index round-trip: unique_indices[reverse_index[i]] == indices[i].
333+
rev_list = reverse_index.tolist()
334+
uniq_list = unique_indices.tolist()
335+
self.assertEqual(len(rev_list), len(indices_list))
336+
for i, rev in enumerate(rev_list):
337+
self.assertTrue(0 <= rev < len(uniq_list))
338+
self.assertEqual(uniq_list[rev], indices_list[i])
339+
272340
@given(
273341
num_elements=st.integers(min_value=100, max_value=10000),
274342
num_unique_indices=st.integers(min_value=5, max_value=100),

0 commit comments

Comments
 (0)