Add INT8 (uint8) support to FBGEMM index_select_dim0 forward#5782
Closed
q10 wants to merge 1 commit into
Summary:

Adds INT8 (uint8/Byte) type support to the index_select_dim0 GPU forward path for UMIA inference use cases. The kernel template already handles arbitrary scalar_t types for the gather/copy operation, so this change extends the type dispatch macro from FBGEMM_DISPATCH_FLOAT_AND_HALF to FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE, which adds at::ScalarType::Byte (uint8_t).

For non-floating-point types (e.g., INT8), the autograd wrapper is bypassed since gradient computation is meaningless for integer types. The forward call goes directly to index_select_cuda without index sorting overhead, as INT8 is expected to be used in inference-only scenarios.

## Detailed Changes

### `fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu`

- Swaps the dispatch macro on `index_add_2d_kernel_2` from `FBGEMM_DISPATCH_FLOAT_AND_HALF` to `FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE`, instantiating the existing scalar-generic gather/copy kernel template for `at::ScalarType::Byte` (`uint8_t`) in addition to `float` and `half`.
- No kernel-logic change: the kernel was already templated on `scalar_t`, so the addition is purely a type-dispatch extension.

### `fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp`

- In `index_select_dim0_gpu`, adds an early-return fast path when `input.scalar_type() == at::kByte`:
  - Validates `indices` is on the same CUDA device as `input` and is 1-D.
  - Calls `index_select_cuda(input, indices, /*orig_indices=*/empty_long_tensor, /*indices_sorted=*/false)` directly, bypassing `IndexSelectDim0GPUOp::apply` (the autograd-aware wrapper).
  - This is correct because integer dtypes have no gradient; the autograd `Function` wrapper is meaningless for INT8.
- `consecutive_range_start`, `consecutive_range_length`, and `skip_indices_sorting_fwd` are intentionally ignored on this path. They only affect (a) the unused backward and (b) the index-sort optimization, both of which are skipped for inference-only INT8 inputs.
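The early-return routing described for `index_select_dim0_gpu` can be sketched in Python. This is only an illustration of the control flow, not the FBGEMM implementation: `index_select_dim0_sketch` is a hypothetical name, `torch.index_select` stands in for both `index_select_cuda` and the autograd-aware `IndexSelectDim0GPUOp::apply`, and the sketch runs on CPU whereas the real path requires CUDA tensors.

```python
import torch


def index_select_dim0_sketch(input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the dtype-based routing, not the real op."""
    if input.dtype == torch.uint8:
        # Fast path: validate indices, then gather directly.
        if indices.dim() != 1:
            raise ValueError("indices must be 1-D")
        if indices.device != input.device:
            raise ValueError("indices must be on the same device as input")
        # No autograd bookkeeping and no index sorting on this path.
        with torch.no_grad():
            return torch.index_select(input, 0, indices)
    # Floating-point path: stand-in for the autograd-aware wrapper.
    return torch.index_select(input, 0, indices)


# uint8 rows are gathered bit-for-bit and carry no gradient.
table = torch.arange(12, dtype=torch.uint8).reshape(4, 3)
rows = index_select_dim0_sketch(table, torch.tensor([3, 0, 3]))

# Floating-point inputs still take part in autograd.
weights = torch.rand(4, 3, requires_grad=True)
picked = index_select_dim0_sketch(weights, torch.tensor([1, 2]))
```

Keying the branch on the dtype keeps the floating-point behavior unchanged, so only the newly supported uint8 inputs take the direct route.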
### `fbcode/deeplearning/fbgemm/fbgemm_gpu/test/sparse/index_select_test.py`

- Extends the `dtype` Hypothesis sampler in `IndexSelectTest.test_index_select_dim0` from `[torch.float, torch.half]` to `[torch.float, torch.half, torch.uint8]`.
- Generates input via `torch.randint(0, 256, ...)` for non-floating-point dtypes (since `torch.rand` only supports floating dtypes); keeps `torch.rand` for float/half.
- Tightens the equality check to `atol=0, rtol=0`: gather/copy is bit-exact, so any tolerance would mask correctness regressions.
- Skips the gradcheck block entirely when `dtype` is non-floating-point (autograd is not applicable for INT8).

Reviewed By: spcyppt

Differential Revision: D103495542
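The test changes described above (dtype-dependent input generation and a zero-tolerance comparison) can be sketched roughly as follows. The helper names `make_input` and `check_gather` are hypothetical, and `torch.index_select` on CPU stands in for the FBGEMM op, so this shows the pattern rather than the actual test code.

```python
import torch


def make_input(shape, dtype):
    """Hypothetical helper: random data suited to the dtype."""
    if dtype.is_floating_point:
        # torch.rand only produces floating-point values.
        return torch.rand(shape).to(dtype)
    # Integer dtypes: sample the full uint8 range [0, 256).
    return torch.randint(0, 256, shape, dtype=dtype)


def check_gather(dtype, num_rows=8, num_cols=4, num_indices=16):
    """Hypothetical check: gather rows and compare with zero tolerance."""
    data = make_input((num_rows, num_cols), dtype)
    indices = torch.randint(0, num_rows, (num_indices,), dtype=torch.long)
    actual = torch.index_select(data, 0, indices)  # stand-in for the op under test
    expected = data[indices]                       # reference via advanced indexing
    # A gather copies values unchanged, so the match must be exact.
    torch.testing.assert_close(actual, expected, atol=0, rtol=0)
    return actual


for dtype in (torch.float, torch.uint8):
    result = check_gather(dtype)
```

A gradient check would apply only to the floating-point case, which is why the PR skips it for `torch.uint8`.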
@q10 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D103495542.

This pull request has been merged in 0b8730f.