Skip to content

Commit 0b8730f

Browse files
q10meta-codesync[bot]
authored andcommitted
Add INT8 (uint8) support to FBGEMM index_select_dim0 forward (#5782)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2712 Pull Request resolved: #5782 Adds INT8 (uint8/Byte) type support to the index_select_dim0 GPU forward path for UMIA inference use cases. The kernel template already handles arbitrary scalar_t types for the gather/copy operation, so this change extends the type dispatch macro from FBGEMM_DISPATCH_FLOAT_AND_HALF to FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE which adds at::ScalarType::Byte (uint8_t). For non-floating-point types (e.g., INT8), the autograd wrapper is bypassed since gradient computation is meaningless for integer types. The forward call goes directly to index_select_cuda without index sorting overhead, as INT8 is expected to be used in inference-only scenarios. ## Detailed Changes ### `fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu` - Swaps the dispatch macro on `index_add_2d_kernel_2` from `FBGEMM_DISPATCH_FLOAT_AND_HALF` to `FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE`, instantiating the existing scalar-generic gather/copy kernel template for `at::ScalarType::Byte` (`uint8_t`) in addition to `float` and `half`. - No kernel-logic change — the kernel was already templated on `scalar_t`, so the addition is purely a type-dispatch extension. ### `fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp` - In `index_select_dim0_gpu`, adds an early-return fast path when `input.scalar_type() == at::kByte`: - Validates `indices` is on the same CUDA device as `input` and is 1-D. - Calls `index_select_cuda(input, indices, /*orig_indices=*/empty_long_tensor, /*indices_sorted=*/false)` directly, bypassing `IndexSelectDim0GPUOp::apply` (the autograd-aware wrapper). - This is correct because integer dtypes have no gradient; the autograd `Function` wrapper is meaningless for INT8. - `consecutive_range_start`, `consecutive_range_length`, and `skip_indices_sorting_fwd` are intentionally ignored on this path — they only affect (a) the unused backward and (b) the index-sort optimization, both of which are skipped for inference-only INT8 inputs. ### `fbcode/deeplearning/fbgemm/fbgemm_gpu/test/sparse/index_select_test.py` - Extends the `dtype` Hypothesis sampler in `IndexSelectTest.test_index_select_dim0` from `[torch.float, torch.half]` to `[torch.float, torch.half, torch.uint8]`. - Generates input via `torch.randint(0, 256, ...)` for non-floating-point dtypes (since `torch.rand` only supports floating dtypes); keeps `torch.rand` for float/half. - Tightens the equality check to `atol=0, rtol=0` — gather/copy is bit-exact, so any tolerance would mask correctness regressions. - Skips the gradcheck block entirely when `dtype` is non-floating-point (autograd is not applicable for INT8). Reviewed By: spcyppt Differential Revision: D103495542 fbshipit-source-id: c6ec6b5d5be02ad34f790d9e6904bbb09e39f7f9
1 parent 0fcdfc0 commit 0b8730f

5 files changed

Lines changed: 83 additions & 18 deletions

File tree

fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -329,12 +329,13 @@
329329
AT_DISPATCH_SWITCH( \
330330
TYPE, NAME, FBGEMM_DISPATCH_FLOAT_AND_HALF_CASE(__VA_ARGS__))
331331

332-
#define FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, NAME, ...) \
333-
AT_DISPATCH_SWITCH( \
334-
TYPE, \
335-
NAME, \
336-
FBGEMM_DISPATCH_FLOAT_AND_HALF_CASE(__VA_ARGS__) \
337-
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__))
332+
#define FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, NAME, ...) \
333+
AT_DISPATCH_SWITCH( \
334+
TYPE, \
335+
NAME, \
336+
FBGEMM_DISPATCH_FLOAT_AND_HALF_CASE(__VA_ARGS__) \
337+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
338+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__))
338339

339340
#define FBGEMM_DISPATCH_FLOAT_HALF_FP8_AND_BYTE(TYPE, NAME, ...) \
340341
AT_DISPATCH_SWITCH( \

fbgemm_gpu/src/sparse_ops/sparse_index_select.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ DLL_PUBLIC Tensor index_select_cuda(
9696
}
9797

9898
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "index_add_2d_kernel_1", [&] {
99-
FBGEMM_DISPATCH_FLOAT_AND_HALF(
99+
FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
100100
input_reshaped.scalar_type(), "index_add_2d_kernel_2", [&] {
101101
if (indices_sorted) {
102102
LAUNCH_INDEX_SELECT(true)

fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -653,21 +653,62 @@ std::tuple<Tensor, std::optional<Tensor>> pack_segments_cuda_v2(
653653
t_in, lengths, max_length, pad_minf, return_presence_mask);
654654
}
655655

656+
namespace {
657+
658+
// Helper to prepare indices for index_select operation.
659+
// Returns a tuple of (indices_to_use, orig_indices, indices_sorted).
660+
// If skip_indices_sorting_fwd is true and not in inference mode, returns the
661+
// original indices with empty orig_indices and indices_sorted=false.
662+
// Otherwise, sorts indices and returns (sorted_indices, orig_indices, true).
663+
std::tuple<Tensor, Tensor, bool> prepare_index_select_indices(
664+
const Tensor& indices,
665+
std::optional<bool> skip_indices_sorting_fwd) {
666+
const bool skip_sort = skip_indices_sorting_fwd.value_or(false) &&
667+
!c10::InferenceMode::is_enabled();
668+
669+
if (skip_sort) {
670+
return {indices, at::empty({0}, indices.options().dtype(at::kLong)), false};
671+
}
672+
Tensor sorted_indices, orig_indices;
673+
std::tie(sorted_indices, orig_indices) = indices.sort();
674+
return {sorted_indices, orig_indices, true};
675+
}
676+
677+
} // namespace
678+
656679
Tensor index_select_dim0_gpu(
657680
const Tensor& input,
658681
const Tensor& indices,
659682
std::optional<int64_t> consecutive_range_start,
660683
std::optional<int64_t> consecutive_range_length,
661684
std::optional<bool> skip_indices_sorting_fwd) {
662-
bool user_skip_indices_sorting_fwd =
663-
skip_indices_sorting_fwd ? *skip_indices_sorting_fwd : false;
685+
// 8-bit integer dtypes (uint8/Byte and int8/Char) are inference-only and do
686+
// not support autograd, so we bypass IndexSelectDim0GPUOp::apply (which
687+
// wires up the autograd Function) and call index_select_cuda directly.
688+
// consecutive_range_start and consecutive_range_length are intentionally
689+
// ignored on this path — they optimize the backward pass for consecutive
690+
// indices, but integer dtypes have no backward pass (no gradients).
691+
if (input.scalar_type() == at::kByte || input.scalar_type() == at::kChar) {
692+
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices);
693+
TORCH_CHECK_VALUE(
694+
indices.dim() == 1, "Index tensor must be 1D, but got ", indices.dim());
695+
696+
auto [indices_to_use, orig_indices, indices_sorted] =
697+
prepare_index_select_indices(indices, skip_indices_sorting_fwd);
698+
return index_select_cuda(
699+
input, indices_to_use, orig_indices, indices_sorted);
700+
}
701+
664702
return IndexSelectDim0GPUOp::apply(
665703
input,
666704
indices,
667-
consecutive_range_start ? *consecutive_range_start : 0,
668-
consecutive_range_length ? *consecutive_range_length : 0,
669-
// Always skip indices sorting if doing forward only
670-
user_skip_indices_sorting_fwd && !c10::InferenceMode::is_enabled())[0];
705+
consecutive_range_start.value_or(0),
706+
consecutive_range_length.value_or(0),
707+
// Sorting is skipped only when the user requested it AND we are NOT in
708+
// inference mode. In inference mode we always sort for cache-friendlier
709+
// gathers.
710+
skip_indices_sorting_fwd.value_or(false) &&
711+
!c10::InferenceMode::is_enabled())[0];
671712
}
672713

673714
} // namespace fbgemm_gpu

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,17 @@ vec4_copy(uint8_t* dst, const uint8_t* src, const int32_t D) {
7474
}
7575
}
7676

77+
template <>
78+
DEVICE_INLINE void vec4_copy(int8_t* dst, const int8_t* src, const int32_t D) {
79+
// each row is padded with row_alignment (16 bytes on GPUs), so each row will
80+
// be multiple of 16 bytes (uint4 = 32bit x 4 = 16 bytes).
81+
const uint4* __restrict__ src_ = reinterpret_cast<const uint4*>(src);
82+
uint4* __restrict__ dst_ = reinterpret_cast<uint4*>(dst);
83+
for (auto d = threadIdx.x; d * sizeof(uint4) < D; d += blockDim.x) {
84+
dst_[d] = src_[d];
85+
}
86+
}
87+
7788
template <typename value_t, typename index_t, bool is_index_put>
7889
__global__ __launch_bounds__(kMaxThreads) void masked_index_kernel(
7990
pta::PackedTensorAccessor64<value_t, 2, at::RestrictPtrTraits> self,
@@ -144,7 +155,9 @@ Tensor masked_index_impl(
144155
is_index_put ? "masked_index_put" : "masked_index_select",
145156
[&] {
146157
using value_t = scalar_t;
147-
if constexpr (std::is_same_v<value_t, uint8_t>) {
158+
if constexpr (
159+
std::is_same_v<value_t, uint8_t> ||
160+
std::is_same_v<value_t, int8_t>) {
148161
TORCH_CHECK(D % 16 == 0, "D needs to be padded to be multiple of 16");
149162
}
150163
FBGEMM_DISPATCH_INTEGRAL_TYPES(

fbgemm_gpu/test/sparse/index_select_test.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class IndexSelectTest(unittest.TestCase):
3838
st.lists(st.integers(1, 128), max_size=1),
3939
st.lists(st.integers(1, 16), min_size=2, max_size=2),
4040
),
41-
dtype=st.sampled_from([torch.float, torch.half]),
41+
dtype=st.sampled_from([torch.float, torch.half, torch.uint8, torch.int8]),
4242
use_cpu=st.booleans() if gpu_available else st.just(True),
4343
consecutive_indices=st.booleans(),
4444
skip_indices_sorting_fwd=st.booleans(),
@@ -76,15 +76,25 @@ def test_index_select_dim0(
7676

7777
kwargs["skip_indices_sorting_fwd"] = skip_indices_sorting_fwd
7878

79-
input = torch.rand((U,) + tuple(shape), dtype=dtype, device=device)
79+
if dtype.is_floating_point:
80+
input = torch.rand((U,) + tuple(shape), dtype=dtype, device=device)
81+
else:
82+
iinfo = torch.iinfo(dtype)
83+
input = torch.randint(
84+
iinfo.min,
85+
iinfo.max + 1,
86+
(U,) + tuple(shape),
87+
dtype=dtype,
88+
device=device,
89+
)
8090

8191
with torch.inference_mode() if use_inference_mode else contextlib.nullcontext():
8292
output_ref = torch.ops.fbgemm.index_select_dim0(input, indices, **kwargs)
8393
output = torch.index_select(input, 0, indices)
8494

85-
torch.testing.assert_close(output, output_ref)
95+
torch.testing.assert_close(output, output_ref, atol=0, rtol=0)
8696

87-
if not use_inference_mode:
97+
if not use_inference_mode and dtype.is_floating_point:
8898
gradcheck_args = [
8999
input.clone().detach().float().requires_grad_(True),
90100
indices,

0 commit comments

Comments
 (0)