Skip to content

Commit 7d00c88

Browse files
AlbertDachiChenmeta-codesync[bot]
authored andcommitted
Speedup delinearize via feature-lookup over num_unique (#5770)
Summary: Pull Request resolved: #5770 X-link: https://github.com/facebookresearch/FBGEMM/pull/2699 delinearize_unique_index_kernel iterated over total_indices (~24M on the prod IFR-MTML mc7 shape), reading both `indices` and `reverse_index` per element and scatter-writing to `unique_indices[reverse_index[i]] = indices[i]`. The scatter pattern hits random L2 lines, and the scatter is gated on N=24M elements. Replace with delinearize_unique_from_sorted_kernel which iterates over num_unique (~1-2M, ~24x fewer threads). Each thread reads one sorted unique key v and recovers its (feature_t, per_feature_idx) by binary-searching hash_size_cumsum: t = lower_bound(hash_size_cumsum, v + 1) - 1 // largest t s.t. cumsum[t] <= v unique_indices[i] = v - hash_size_cumsum[t] The writes are sequential over unique_indices (no scatter). The probe is O(log T) where T is the number of features (typically <= 256 in practice), executing in <= 8 iterations against L1-resident hash_size_cumsum. The `v + 1` does not overflow at the ZCH boundary: unique_keys[i] is bounded by hash_size_cumsum[T-1] <= total_hash_size, and feature t's linearized values lie in [hash_size_cumsum[t], hash_size_cumsum[t+1]), so v <= total_hash_size - 1 even when total_hash_size = INT64_MAX. Bench delta on prod IFR-MTML mc7 shape (T=2, B=2560, ~23M int64 indices): ``` D-3 D-4 delinearize kernel ~1.62 ms ~13 us (~120x) fbgemm::jagged_unique_indices ~5.4 ms ~3.8 ms (-1.6 ms) ``` No public API change. Output is bit-identical to the prior kernel for all valid inputs (verified by all unit + opcheck tests). Reviewed By: q10 Differential Revision: D105005652 fbshipit-source-id: f2e9cba2e279105a0c38ff4ed1f39a9397c7622e
1 parent 89b3a95 commit 7d00c88

1 file changed

Lines changed: 53 additions & 41 deletions

File tree

fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -67,32 +67,6 @@ __global__ __launch_bounds__(kFlatBlockSize) void linearize_index_flat_kernel(
6767
}
6868
}
6969

70-
// Delinearize the unique indices from the reverse index info and the original
71-
// indices. For each element in the input indices, the value should equal to
72-
// the element from the unique indices according to the reverse index info.
73-
//
74-
// reverse_index is always int64 to match the public contract of
75-
// jagged_unique_indices (see jagged_unique_scatter_kernel), independent of
76-
// indices.scalar_type(); typing it as int64_t here (instead of reusing
77-
// index_t) keeps PTA_B's runtime dtype check from firing when index_t is
78-
// int32_t.
79-
template <typename index_t>
80-
__global__ __launch_bounds__(kMaxThreads) void delinearize_unique_index_kernel(
81-
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
82-
indices,
83-
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
84-
reverse_index,
85-
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
86-
unique_indices) {
87-
const auto total_indices = indices.size(0);
88-
const auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
89-
if (b_t < total_indices) {
90-
const auto original_index = indices[b_t];
91-
const auto pos = reverse_index[b_t];
92-
unique_indices[pos] = original_index;
93-
}
94-
}
95-
9670
// Adjacent-difference over sorted keys. out[0] = 0; out[i > 0] = 1 if
9771
// sorted[i] != sorted[i-1] else 0. Mirrors
9872
// caffe2/aten/src/ATen/native/cuda/UniqueCub.cu.
@@ -284,6 +258,37 @@ __device__ __forceinline__ int32_t device_lower_bound(
284258
return lo;
285259
}
286260

261+
// Delinearize the unique indices via per-feature lookup over hash_size_cumsum.
262+
// Each thread handles one sorted unique key v: locate its feature t via
263+
// device_lower_bound(hash_size_cumsum, v + 1) - 1 (largest t with
264+
// hash_size_cumsum[t] <= v) and emit unique_indices[i] = v -
265+
// hash_size_cumsum[t].
266+
//
267+
// Replaces delinearize_unique_index_kernel which iterated over total_indices
268+
// (~24M on the prod IFR-MTML mc7 shape) and scatter-wrote via reverse_index.
269+
// The new form iterates over num_unique (~1-2M) with a small O(log T) probe
270+
// per thread - 24x fewer threads, and the writes are sequential.
271+
template <typename index_t>
272+
__global__
273+
__launch_bounds__(kFlatBlockSize) void delinearize_unique_from_sorted_kernel(
274+
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
275+
hash_size_cumsum,
276+
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
277+
unique_keys,
278+
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
279+
unique_indices) {
280+
const auto num_unique = unique_keys.size(0);
281+
const auto i = blockIdx.x * blockDim.x + threadIdx.x;
282+
if (i >= static_cast<uint64_t>(num_unique)) {
283+
return;
284+
}
285+
const index_t v = unique_keys[i];
286+
// unique_keys[i] < hash_size_cumsum[T] <= total_hash_size, so v + 1
287+
// does not overflow even at the ZCH boundary (total_hash_size = INT64_MAX).
288+
const int32_t t = device_lower_bound<index_t>(hash_size_cumsum, v + 1) - 1;
289+
unique_indices[i] = v - hash_size_cumsum[t];
290+
}
291+
287292
// Compute the per-(feature, batch) lengths for the unique indices.
288293
//
289294
// Caller-provided invariant (see jagged_unique_indices_cuda for the
@@ -364,9 +369,12 @@ __global__ __launch_bounds__(kFlatBlockSize) void unique_indices_length_kernel(
364369
// which on production shapes (hash_size ~1M) reduces 8 radix passes
365370
// to ~3.
366371
//
367-
// 3. delinearize_unique_index_kernel scatters the original
368-
// (pre-linearization) per-feature index values back into
369-
// unique_indices via reverse_index.
372+
// 3. delinearize_unique_from_sorted_kernel iterates over num_unique
373+
// (~1-2M) instead of total_indices (~24M). Each thread reads one
374+
// sorted unique key v and recovers its (feature, per-feature index)
375+
// via a binary search on hash_size_cumsum. Writes are sequential
376+
// over the num_unique output, eliminating the scatter-via-
377+
// reverse_index pattern of the prior kernel.
370378
//
371379
// 4. unique_indices_length_kernel relies on (1)+(2) to compute
372380
// num_unique per feature group via two binary searches over
@@ -462,18 +470,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> jagged_unique_indices_cuda(
462470
const auto total_indices = indices.size(0);
463471
Tensor unique_indices = at::empty_like(linear_unique_indices);
464472

465-
AT_DISPATCH_INDEX_TYPES(
466-
indices.scalar_type(), "delinearize_unique_index", ([&] {
467-
FBGEMM_LAUNCH_KERNEL(
468-
(delinearize_unique_index_kernel<index_t>),
469-
div_round_up(total_indices + 1, kMaxThreads),
470-
kMaxThreads,
471-
0,
472-
at::cuda::getCurrentCUDAStream(),
473-
PTA_B(indices, index_t, 1, 32),
474-
PTA_B(reverse_index, int64_t, 1, 32),
475-
PTA_B(unique_indices, index_t, 1, 32));
476-
}));
473+
if (total_indices > 0 && unique_indices.numel() > 0) {
474+
AT_DISPATCH_INDEX_TYPES(
475+
indices.scalar_type(), "delinearize_unique_index", ([&] {
476+
FBGEMM_LAUNCH_KERNEL(
477+
(delinearize_unique_from_sorted_kernel<index_t>),
478+
static_cast<int32_t>(
479+
(unique_indices.numel() + kFlatBlockSize - 1) /
480+
kFlatBlockSize),
481+
kFlatBlockSize,
482+
0,
483+
at::cuda::getCurrentCUDAStream(),
484+
PTA_B(hash_size_cumsum, index_t, 1, 32),
485+
PTA_B(linear_unique_indices, index_t, 1, 32),
486+
PTA_B(unique_indices, index_t, 1, 32));
487+
}));
488+
}
477489

478490
Tensor output_lengths = at::zeros({total_B}, offsets.options());
479491
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "unique_indices_length", ([&] {

0 commit comments

Comments
 (0)