Skip to content

Commit 7c8843e

Browse files
AlbertDachiChen authored and meta-codesync[bot] committed
Speedup linearize_index via flat-grid kernel (#5768)
Summary:
Pull Request resolved: #5768
X-link: https://github.com/facebookresearch/FBGEMM/pull/2696

linearize_index_wo_infos_kernel was launched with grid = ceil(total_B / kMaxThreads), where kMaxThreads = 1024. On the prod IFR-MTML mc7 shape, total_B is in the low thousands, so the launch consumed only ~5 SMs out of 132 on H100. Each warp also performed an intra-warp shuffle (`shfl_sync` x kWarpSize) to redistribute work among lanes, adding latency on top of the SM under-utilization. The benchmark measures this kernel at ~1.7 ms / call.

Replace it with a flat-grid kernel launched with grid = total_B (one block per (t, b) sample) and threads = 256. Each block recovers (t, b) from blockIdx.x via the existing FixedDivisor argument and stripes the per-sample reads/writes across its threads. No shuffle, full SM utilization. The benchmark measures the new kernel at ~140 us / call (~12x speedup on this kernel).

No public API change. The output dtype is unchanged (still index_t); downstream kernels (at::_unique, delinearize, length kernel) consume the same buffer in the same format.

Reviewed By: q10
Differential Revision: D105005594
1 parent 14b9ebc commit 7c8843e

1 file changed

Lines changed: 27 additions & 26 deletions

File tree

fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@ namespace fbgemm_gpu {
2727
// flat-grid launches with grid = total_B (or num_unique).
2828
static constexpr int32_t kFlatBlockSize = 256;
2929

30-
// Linearzie the index with the cumsum of hash size so that linearized indices
31-
// can be sorted together.
30+
// Linearize the index with the cumsum of hash size so that linearized indices
31+
// can be sorted together. Flat-grid: one block per (t, b) sample.
32+
//
33+
// Replaces the prior warp-cooperative kernel which was launched as
34+
// grid = ceil(total_B / kMaxThreads)
35+
// On production shapes total_B is in the low thousands and kMaxThreads = 1024,
36+
// so the prior launch consumed only ~5 SMs out of 132 on H100 with each warp
37+
// shuffling work between lanes. The flat grid uses one block per sample,
38+
// dispatching all SMs and removing the intra-warp shuffle dance.
3239
template <typename index_t>
33-
__global__ __launch_bounds__(kMaxThreads) void linearize_index_wo_infos_kernel(
40+
__global__ __launch_bounds__(kFlatBlockSize) void linearize_index_flat_kernel(
3441
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
3542
hash_size_cumsum,
3643
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
@@ -40,27 +47,21 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_wo_infos_kernel(
4047
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
4148
linear_indices,
4249
FixedDivisor fd) {
43-
const auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
50+
const auto b_t = blockIdx.x;
4451
int32_t b;
4552
int32_t t;
46-
const auto total_B = offsets.size(0) - 1;
47-
const auto valid = b_t < total_B;
48-
49-
fd.DivMod(b_t, &t, &b);
50-
51-
const auto hash_offset = valid ? hash_size_cumsum[t] : -1;
52-
const auto indices_start = valid ? offsets[b_t] : -1;
53-
const int32_t L = valid ? offsets[b_t + 1] - indices_start : 0;
54-
const auto lane_id = threadIdx.x % fbgemm_gpu::kWarpSize;
55-
56-
for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
57-
const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
58-
const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
59-
const auto hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
60-
for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
61-
const auto idx = __ldg(&indices[indices_start_warp + i]);
62-
linear_indices[indices_start_warp + i] = hash_offset_warp + idx;
63-
}
53+
fd.DivMod(static_cast<int32_t>(b_t), &t, &b);
54+
55+
const auto indices_start = offsets[b_t];
56+
const auto L = offsets[b_t + 1] - indices_start;
57+
if (L == 0) {
58+
return;
59+
}
60+
const auto hash_offset = hash_size_cumsum[t];
61+
62+
for (auto i = threadIdx.x; i < L; i += blockDim.x) {
63+
const auto idx = __ldg(&indices[indices_start + i]);
64+
linear_indices[indices_start + i] = hash_offset + idx;
6465
}
6566
}
6667

@@ -170,7 +171,7 @@ __global__ __launch_bounds__(kFlatBlockSize) void unique_indices_length_kernel(
170171

171172
// Pipeline (cross-kernel data flow that ties the four steps together):
172173
//
173-
// 1. linearize_index_wo_infos_kernel writes
174+
// 1. linearize_index_flat_kernel writes
174175
// linear_indices[i] = hash_size_cumsum[t] + indices[i]
175176
// so feature t's linearized values lie in
176177
// [hash_size_cumsum[t], hash_size_cumsum[t+1]).
@@ -204,9 +205,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> jagged_unique_indices_cuda(
204205

205206
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "linearize_index", ([&] {
206207
FBGEMM_LAUNCH_KERNEL(
207-
(linearize_index_wo_infos_kernel<index_t>),
208-
div_round_up(total_B, kMaxThreads),
209-
kMaxThreads,
208+
(linearize_index_flat_kernel<index_t>),
209+
total_B,
210+
kFlatBlockSize,
210211
0,
211212
at::cuda::getCurrentCUDAStream(),
212213
PTA_B(hash_size_cumsum, index_t, 1, 32),

0 commit comments

Comments
 (0)