Add INT8 (uint8) support to FBGEMM index_select_dim0 forward (#5782)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2712
Pull Request resolved: #5782
Adds INT8 (uint8/Byte) type support to the index_select_dim0 GPU forward path for UMIA inference use cases. The kernel template already handles arbitrary scalar_t types for the gather/copy operation, so this change extends the type dispatch macro from FBGEMM_DISPATCH_FLOAT_AND_HALF to FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE, which adds at::ScalarType::Byte (uint8_t).
For non-floating-point types (e.g., INT8), the autograd wrapper is bypassed since gradient computation is meaningless for integer types. The forward call goes directly to index_select_cuda without index sorting overhead, as INT8 is expected to be used in inference-only scenarios.
## Detailed Changes
### `fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu`
- Swaps the dispatch macro on `index_add_2d_kernel_2` from `FBGEMM_DISPATCH_FLOAT_AND_HALF` to `FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE`, instantiating the existing scalar-generic gather/copy kernel template for `at::ScalarType::Byte` (`uint8_t`) in addition to `float` and `half`.
- No kernel-logic change — the kernel was already templated on `scalar_t`, so the addition is purely a type-dispatch extension.
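The "dispatch-only" nature of this change can be illustrated with a small CPU stand-in. This is a sketch, not FBGEMM code: `DISPATCH`, `gather_rows`, and the use of Python's `array` typecodes are assumptions made for illustration. The point is that the copy loop never inspects the element type, so supporting a new dtype means adding one entry to the dispatch table.

```python
from array import array

# Hypothetical dispatch table (not an FBGEMM API): dtype name -> array typecode.
# Adding "uint8" mirrors extending the dispatch macro; gather_rows is unchanged.
DISPATCH = {"float": "f", "uint8": "B"}


def gather_rows(values, num_cols, indices, dtype):
    """Copy the rows of a flat row-major buffer selected by `indices`."""
    if dtype not in DISPATCH:
        raise TypeError(f"unsupported dtype: {dtype}")
    src = array(DISPATCH[dtype], values)
    num_rows = len(src) // num_cols
    out = array(DISPATCH[dtype])
    for row in indices:
        if not 0 <= row < num_rows:
            raise IndexError(f"row index out of range: {row}")
        start = row * num_cols
        # Pure copy: no arithmetic is applied to the elements.
        out.extend(src[start:start + num_cols])
    return out


table = [10, 11, 20, 21, 30, 31]  # 3 rows x 2 columns
print(gather_rows(table, 2, [2, 0, 2], "uint8").tolist())
# prints [30, 31, 10, 11, 30, 31]
```

Because the loop only copies elements, the same function body serves both entries of the table.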
### `fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp`
- In `index_select_dim0_gpu`, adds an early-return fast path when `input.scalar_type() == at::kByte`:
  - Validates `indices` is on the same CUDA device as `input` and is 1-D.
  - Calls `index_select_cuda(input, indices, /*orig_indices=*/empty_long_tensor, /*indices_sorted=*/false)` directly, bypassing `IndexSelectDim0GPUOp::apply` (the autograd-aware wrapper).
  - This is correct because integer dtypes have no gradient; the autograd `Function` wrapper is meaningless for INT8.
- `consecutive_range_start`, `consecutive_range_length`, and `skip_indices_sorting_fwd` are intentionally ignored on this path — they only affect (a) the unused backward and (b) the index-sort optimization, both of which are skipped for inference-only INT8 inputs.
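The routing decision described above can be summarized in a few lines. This sketch models only the control flow. `TensorMeta` and `choose_forward_path` are hypothetical names introduced here, and the returned strings merely label the two entry points named in this commit message. The real check lives in C++ inside `index_select_dim0_gpu`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for the tensor properties the fast path inspects."""
    dtype: str
    device: str
    ndim: int


def choose_forward_path(input_meta, indices_meta):
    """Return a label for the forward entry point the described logic would pick."""
    if input_meta.dtype == "uint8":
        # Fast path: validate indices, then skip the autograd wrapper.
        if indices_meta.device != input_meta.device:
            raise ValueError("indices must be on the same device as input")
        if indices_meta.ndim != 1:
            raise ValueError("indices must be 1-D")
        return "index_select_cuda"
    # Float/half inputs keep going through the autograd-aware wrapper.
    return "IndexSelectDim0GPUOp::apply"


byte_input = TensorMeta("uint8", "cuda:0", 2)
float_input = TensorMeta("float32", "cuda:0", 2)
indices = TensorMeta("int64", "cuda:0", 1)
print(choose_forward_path(byte_input, indices))   # prints index_select_cuda
print(choose_forward_path(float_input, indices))  # prints IndexSelectDim0GPUOp::apply
```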
### `fbcode/deeplearning/fbgemm/fbgemm_gpu/test/sparse/index_select_test.py`
- Extends the `dtype` Hypothesis sampler in `IndexSelectTest.test_index_select_dim0` from `[torch.float, torch.half]` to `[torch.float, torch.half, torch.uint8]`.
- Generates input via `torch.randint(0, 256, ...)` for non-floating-point dtypes (since `torch.rand` only supports floating dtypes); keeps `torch.rand` for float/half.
- Tightens the equality check to `atol=0, rtol=0` — gather/copy is bit-exact, so any tolerance would mask correctness regressions.
- Skips the gradcheck block entirely when `dtype` is non-floating-point (autograd is not applicable for INT8).
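A small stdlib sketch of why zero tolerance matters for a copy-only operation. It is not the actual test code: `make_input` and `is_close` are hypothetical helpers, and the loose `atol=1` value is an assumed example rather than the previous setting. `is_close` uses the usual `|actual - expected| <= atol + rtol * |expected|` criterion.

```python
import random


def make_input(dtype, size, rng):
    """Integer dtypes need integer sampling; floats use uniform [0, 1)."""
    if dtype == "uint8":
        return [rng.randrange(0, 256) for _ in range(size)]
    return [rng.random() for _ in range(size)]


def is_close(actual, expected, atol, rtol):
    """Element-wise |a - e| <= atol + rtol * |e| over equal-length sequences."""
    return len(actual) == len(expected) and all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )


rng = random.Random(0)
expected = make_input("uint8", 8, rng)

# Simulate a gather bug that perturbs one byte by 1 (kept within 0..255).
corrupted = list(expected)
corrupted[3] = expected[3] + 1 if expected[3] < 255 else expected[3] - 1

print(is_close(corrupted, expected, atol=1, rtol=0))  # prints True: the bug is masked
print(is_close(corrupted, expected, atol=0, rtol=0))  # prints False: the bug is caught
```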
Reviewed By: spcyppt
Differential Revision: D103495542
fbshipit-source-id: c6ec6b5d5be02ad34f790d9e6904bbb09e39f7f9