Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 99 additions & 53 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Block size for the flat-grid kernels in this file. 256 = 8 warps, chosen
// so the SM scheduler can pack ~4-8 blocks per SM on H100 given the
// flat-grid launches with grid = total_B (or num_unique).
static constexpr int32_t kFlatBlockSize = 256;

// Linearize the index with the cumsum of hash size so that linearized indices
// can be sorted together.
template <typename index_t>
Expand Down Expand Up @@ -79,69 +84,112 @@ __global__ __launch_bounds__(kMaxThreads) void delinearize_unique_index_kernel(
}
}

// Compute the lengths for each feature in the unique indices. The number of
// unique indices for each feature equals the difference between the max and
// min values in the reverse index array, plus one.
template <typename index_t, auto max_value, auto min_value>
__global__ __launch_bounds__(kMaxThreads) void unique_indices_length_kernel(
// Binary search over a 1-D PackedTensorAccessor32<index_t> holding values in
// ascending order. Yields the smallest position p with arr[p] >= value, or
// arr.size(0) when every element is smaller than `value` -- the same contract
// as std::lower_bound applied to the underlying sorted array.
template <typename index_t>
__device__ __forceinline__ int32_t device_lower_bound(
    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>& arr,
    const index_t value) {
  // The search window is tracked as (first, remaining) instead of (lo, hi);
  // the probed positions are the midpoints of that window.
  int32_t first = 0;
  int32_t remaining = arr.size(0);
  while (remaining > 0) {
    const int32_t half = remaining >> 1;
    const int32_t probe = first + half;
    if (arr[probe] < value) {
      // Everything up to and including `probe` is too small; drop it.
      first = probe + 1;
      remaining -= half + 1;
    } else {
      // `probe` is still a candidate; shrink the window to end at it.
      remaining = half;
    }
  }
  return first;
}

// Compute the per-(feature, batch) lengths for the unique indices.
//
// Caller-provided invariant (see jagged_unique_indices_cuda for the
// pipeline contract that establishes it): `linear_unique_indices` is
// sorted ascending, and feature t's values occupy a contiguous slice
// [lower_bound(linear_unique_indices, hash_size_cumsum[t]),
// lower_bound(linear_unique_indices, hash_size_cumsum[t+1])).
// The slice length equals num_unique_t for feature t.
template <typename index_t>
__global__ __launch_bounds__(kFlatBlockSize) void unique_indices_length_kernel(
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
hash_size_offsets,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
reverse_index,
hash_size_cumsum,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
offsets,
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> lengths) {
typedef cub::BlockReduce<index_t, kMaxThreads> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage_max;
__shared__ typename BlockReduce::TempStorage temp_storage_min;
__shared__ index_t block_results[2];

linear_unique_indices,
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> lengths,
const int32_t batch_size) {
const auto tid = threadIdx.x;
const auto bid = blockIdx.x;
const auto num_blocks = gridDim.x;
const int32_t batch_size = (offsets.size(0) - 1) / num_blocks;

const auto offset_begin = hash_size_offsets[bid] * batch_size;
const auto offset_end = hash_size_offsets[bid + 1] * batch_size;
const auto num_lengths = (offset_end - offset_begin);

const auto reverse_index_begin = offsets[offset_begin];
const auto reverse_index_end = offsets[offset_end];

if (reverse_index_begin == reverse_index_end) {
const auto hash_begin = hash_size_offsets[bid];
const auto hash_end = hash_size_offsets[bid + 1];
const auto offset_begin = hash_begin * batch_size;
const auto offset_end = hash_end * batch_size;
const auto num_lengths = offset_end - offset_begin;
if (num_lengths == 0) {
return;
}

index_t t_max = min_value;
index_t t_min = max_value;
for (index_t i = (reverse_index_begin + tid); i < reverse_index_end;
i += kMaxThreads) {
const index_t value = reverse_index[i];
t_max = (value > t_max) ? value : t_max;
t_min = (value < t_min) ? value : t_min;
}

index_t block_max =
BlockReduce(temp_storage_max).Reduce(t_max, Max<index_t>());
index_t block_min =
BlockReduce(temp_storage_min).Reduce(t_min, Min<index_t>());
__shared__ index_t s_div_length;
__shared__ index_t s_r_length;
if (tid == 0) {
block_results[0] = block_max;
block_results[1] = block_min;
const auto low = hash_size_cumsum[hash_begin];
const auto high = hash_size_cumsum[hash_end];
if (low == high) {
// Empty feature group. Output is pre-zeroed by at::zeros at the
// launch site; nothing to write.
s_div_length = 0;
s_r_length = 0;
} else {
const int32_t lo_pos =
device_lower_bound<index_t>(linear_unique_indices, low);
const int32_t hi_pos =
device_lower_bound<index_t>(linear_unique_indices, high);
const index_t total_length = static_cast<index_t>(hi_pos - lo_pos);
s_div_length = total_length / static_cast<index_t>(num_lengths);
s_r_length = total_length % static_cast<index_t>(num_lengths);
}
}
__syncthreads();

t_max = block_results[0];
t_min = block_results[1];
const index_t total_length = (t_max - t_min) + 1;
const index_t div_length = total_length / num_lengths;
const index_t r_length = total_length % num_lengths;
for (int32_t i = tid; i < num_lengths; i += kMaxThreads) {
index_t seg_length = (i < r_length) ? (div_length + 1) : div_length;
const index_t div_length = s_div_length;
const index_t r_length = s_r_length;
if (div_length == 0 && r_length == 0) {
return;
}
for (int32_t i = tid; i < num_lengths; i += blockDim.x) {
const index_t seg_length =
(static_cast<index_t>(i) < r_length) ? (div_length + 1) : div_length;
lengths[offset_begin + i] = seg_length;
}
}

// Pipeline (cross-kernel data flow that ties the four steps together):
//
// 1. linearize_index_wo_infos_kernel writes
// linear_indices[i] = hash_size_cumsum[t] + indices[i]
// so feature t's linearized values lie in
// [hash_size_cumsum[t], hash_size_cumsum[t+1]).
//
// 2. at::_unique(linear_indices, sorted=True, return_inverse=True)
// returns (linear_unique_indices, reverse_index) where
// linear_unique_indices is sorted ascending. Combined with (1), this
// means feature t's unique linearized values occupy a contiguous
// slice of linear_unique_indices.
//
// 3. delinearize_unique_index_kernel scatters the original
// (pre-linearization) per-feature index values back into
// unique_indices via reverse_index.
//
// 4. unique_indices_length_kernel relies on (1)+(2) to compute
// num_unique per feature group via two binary searches over
// linear_unique_indices, instead of an O(N) reduction over
// reverse_index. See the kernel's docstring for the local form of
// the invariant.
std::tuple<Tensor, Tensor, Tensor, Tensor> jagged_unique_indices_cuda(
const Tensor& hash_size_cumsum,
const Tensor& hash_size_offsets,
Expand Down Expand Up @@ -192,18 +240,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> jagged_unique_indices_cuda(
Tensor output_lengths = at::zeros({total_B}, offsets.options());
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "unique_indices_length", ([&] {
FBGEMM_LAUNCH_KERNEL(
(unique_indices_length_kernel<
index_t,
std::numeric_limits<index_t>::max(),
std::numeric_limits<index_t>::min()>),
(unique_indices_length_kernel<index_t>),
T,
kMaxThreads,
kFlatBlockSize,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(hash_size_offsets, index_t, 1, 32),
PTA_B(reverse_index, index_t, 1, 32),
PTA_B(offsets, index_t, 1, 32),
PTA_B(output_lengths, index_t, 1, 32));
PTA_B(hash_size_cumsum, index_t, 1, 32),
PTA_B(linear_unique_indices, index_t, 1, 32),
PTA_B(output_lengths, index_t, 1, 32),
static_cast<int32_t>(total_B / T));
}));

Tensor output_offsets;
Expand Down
8 changes: 8 additions & 0 deletions fbgemm_gpu/test/jagged/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@
"comment": "",
"status": "xfail"
},
"UniqueIndicesTest.test_aot_dispatch_dynamic__test_jagged_unique_indices_zch_huge_hash_size": {
"comment": "Test uses .item()/.tolist() for host-side comparison; incompatible with dynamic dispatch.",
"status": "xfail"
},
"UniqueIndicesTest.test_faketensor__test_jagged_unique_indices": {
"comment": "",
"status": "xfail"
Expand All @@ -118,6 +122,10 @@
"UniqueIndicesTest.test_faketensor__test_jagged_unique_indices_multi_keys": {
"comment": "",
"status": "xfail"
},
"UniqueIndicesTest.test_faketensor__test_jagged_unique_indices_zch_huge_hash_size": {
"comment": "Test uses .item()/.tolist() for host-side comparison; incompatible with fake tensors.",
"status": "xfail"
}
},
"fbgemm::keyed_jagged_index_select_dim1": {},
Expand Down
68 changes: 68 additions & 0 deletions fbgemm_gpu/test/jagged/unique_indices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,74 @@ def test_jagged_unique_indices_empty(
self.assertEqual(torch.sum(output_lengths).item(), 0)
self.assertEqual(torch.sum(output_offsets).item(), 0)

@unittest.skipIf(*gpu_unavailable)
def test_jagged_unique_indices_zch_huge_hash_size(self) -> None:
    """Run the op with INT64_MAX as the final hash_size_cumsum entry.

    This is the shape produced by ZCH callers that leave the per-feature
    hash size unbounded. The op has to cope with hash boundaries that span
    the whole int64 range without overflowing in its internal arithmetic.
    """
    T = 2
    B = 64
    max_length = 5
    # Linearized values per feature live in [0, INT64_MAX). What is being
    # stressed is the boundary values in hash_size_cumsum, not the indices,
    # so a small index range is enough and keeps the host-side checks cheap.
    per_feature_value_cap = 1024
    int64_max = torch.iinfo(torch.int64).max

    # One random bag length per (feature, batch) slot, followed by that
    # many random index values whenever the bag is non-empty.
    lengths_list: list[int] = []
    indices_list: list[int] = []
    for _ in range(T * B):
        bag_len = random.randint(0, max_length)
        lengths_list.append(bag_len)
        if bag_len == 0:
            continue
        drawn = np.random.randint(0, per_feature_value_cap, size=bag_len)
        indices_list.extend(drawn.tolist())

    device = torch.accelerator.current_accelerator()
    assert device is not None
    dtype = torch.int64

    def on_device(values: list[int]) -> torch.Tensor:
        # Every input of the op is an int64 tensor on the accelerator.
        return torch.as_tensor(values, dtype=dtype, device=device)

    hash_size_cumsum = on_device([0, 0, int64_max])
    hash_size_offsets = on_device([0, 0, 2])
    lengths = on_device(lengths_list)
    indices = on_device(indices_list)
    # offsets[0] stays 0; the remainder is the inclusive prefix sum of lengths.
    offsets = torch.zeros(T * B + 1, dtype=dtype, device=device)
    offsets[1:] = torch.cumsum(lengths, dim=0)

    (
        output_lengths,
        output_offsets,
        unique_indices,
        reverse_index,
    ) = torch.ops.fbgemm.jagged_unique_indices(
        hash_size_cumsum, hash_size_offsets, offsets, indices
    )

    # hash_size_cumsum[0] == hash_size_cumsum[1] == 0, so both features get
    # hash offset 0 and share one hash space: the global unique set is the
    # union of every input index.
    num_expected_unique = len(set(indices_list))
    num_unique = unique_indices.numel()
    self.assertEqual(num_unique, num_expected_unique)
    self.assertEqual(int(torch.sum(output_lengths).item()), num_unique)

    # Round-trip through the inverse mapping:
    # unique_indices[reverse_index[i]] must reproduce indices[i].
    rev_list = reverse_index.tolist()
    uniq_list = unique_indices.tolist()
    self.assertEqual(len(rev_list), len(indices_list))
    for original_value, rev in zip(indices_list, rev_list):
        self.assertTrue(0 <= rev < len(uniq_list))
        self.assertEqual(uniq_list[rev], original_value)

@given(
num_elements=st.integers(min_value=100, max_value=10000),
num_unique_indices=st.integers(min_value=5, max_value=100),
Expand Down
Loading