Add INT8 (uint8) support to FBGEMM index_select_dim0 forward#5782

Closed
q10 wants to merge 1 commit into pytorch:main from q10:export-D103495542

@q10 q10 commented May 26, 2026

Summary:
Adds INT8 (uint8/Byte) type support to the index_select_dim0 GPU forward path for UMIA inference use cases. The kernel template already handles arbitrary scalar_t types for the gather/copy operation, so this change only switches the type-dispatch macro from FBGEMM_DISPATCH_FLOAT_AND_HALF to FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE, which adds at::ScalarType::Byte (uint8_t).

For non-floating-point types (e.g., INT8), the autograd wrapper is bypassed since gradient computation is meaningless for integer types. The forward call goes directly to index_select_cuda without index sorting overhead, as INT8 is expected to be used in inference-only scenarios.

Detailed Changes

fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_index_select.cu

  • Swaps the dispatch macro on index_add_2d_kernel_2 from FBGEMM_DISPATCH_FLOAT_AND_HALF to FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE, instantiating the existing scalar-generic gather/copy kernel template for at::ScalarType::Byte (uint8_t) in addition to float and half.
  • No kernel-logic change — the kernel was already templated on scalar_t, so the addition is purely a type-dispatch extension.
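As a rough illustration of why a dispatch-only change is enough, the sketch below pairs a gather routine templated on `scalar_t` with a runtime dtype switch. It is a CPU stand-in, not the FBGEMM kernel or macro: `DType`, `gather_rows`, and `gather_rows_dispatch` are hypothetical names, and half precision is left out to keep the example stdlib-only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for at::ScalarType; only two tags are modelled here.
enum class DType { Float, Byte };

// Element-type-generic row gather: out[i, :] = in[idx[i], :].
// Nothing in the body depends on scalar_t beyond its size.
template <typename scalar_t>
void gather_rows(const scalar_t* in, const std::int64_t* idx,
                 std::size_t num_idx, std::size_t row_len, scalar_t* out) {
  for (std::size_t i = 0; i < num_idx; ++i) {
    const std::size_t src_row = static_cast<std::size_t>(idx[i]);
    std::memcpy(out + i * row_len, in + src_row * row_len,
                row_len * sizeof(scalar_t));
  }
}

// Runtime-to-compile-time dispatch. Supporting a new dtype means adding
// one case that instantiates the same template.
void gather_rows_dispatch(DType dtype, const void* in, const std::int64_t* idx,
                          std::size_t num_idx, std::size_t row_len, void* out) {
  switch (dtype) {
    case DType::Float:
      gather_rows(static_cast<const float*>(in), idx, num_idx, row_len,
                  static_cast<float*>(out));
      return;
    case DType::Byte:  // the newly enabled instantiation
      gather_rows(static_cast<const std::uint8_t*>(in), idx, num_idx, row_len,
                  static_cast<std::uint8_t*>(out));
      return;
  }
  throw std::invalid_argument("unsupported dtype");
}
```

Calling `gather_rows_dispatch(DType::Byte, ...)` on a 3x2 byte matrix with indices `{2, 0}` copies row 2 and then row 0 into the output, using the same code path as the float case.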

fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp

  • In index_select_dim0_gpu, adds an early-return fast path when input.scalar_type() == at::kByte:
    • Validates indices is on the same CUDA device as input and is 1-D.
    • Calls index_select_cuda(input, indices, /*orig_indices=*/empty_long_tensor, /*indices_sorted=*/false) directly, bypassing IndexSelectDim0GPUOp::apply (the autograd-aware wrapper).
    • This is correct because integer dtypes have no gradient; the autograd Function wrapper is meaningless for INT8.
    • consecutive_range_start, consecutive_range_length, and skip_indices_sorting_fwd are intentionally ignored on this path — they only affect (a) the unused backward and (b) the index-sort optimization, both of which are skipped for inference-only INT8 inputs.
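The routing decision described above can be sketched as a small pure function. This is a simplified model under stated assumptions, not the code in `index_select_dim0_gpu`: `TensorMeta`, `Path`, and `choose_path` are hypothetical names, and tensors are reduced to the three properties the fast path inspects.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-ins for the tensor properties the fast path looks at.
enum class DType { Float, Half, Byte, Long };
enum class Path { DirectForward, AutogradWrapper };

struct TensorMeta {
  DType dtype;
  int device;  // device ordinal
  int ndim;    // number of dimensions
};

// Byte inputs are validated and routed straight to the forward kernel;
// every other dtype keeps going through the autograd-aware wrapper.
Path choose_path(const TensorMeta& input, const TensorMeta& indices) {
  if (input.dtype == DType::Byte) {
    if (indices.device != input.device) {
      throw std::invalid_argument("indices must be on the same device as input");
    }
    if (indices.ndim != 1) {
      throw std::invalid_argument("indices must be 1-D");
    }
    return Path::DirectForward;  // no sorting, no backward bookkeeping
  }
  return Path::AutogradWrapper;
}
```

In this model the sorting-related arguments never appear on the byte branch, which mirrors the description's point that they only matter for the backward pass and the index-sort optimization.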

fbcode/deeplearning/fbgemm/fbgemm_gpu/test/sparse/index_select_test.py

  • Extends the dtype Hypothesis sampler in IndexSelectTest.test_index_select_dim0 from [torch.float, torch.half] to [torch.float, torch.half, torch.uint8].
  • Generates input via torch.randint(0, 256, ...) for non-floating-point dtypes (since torch.rand only supports floating dtypes); keeps torch.rand for float/half.
  • Tightens the equality check to atol=0, rtol=0 — gather/copy is bit-exact, so any tolerance would mask correctness regressions.
  • Skips the gradcheck block entirely when dtype is non-floating-point (autograd is not applicable for INT8).
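To see why a nonzero tolerance could hide a bug in a pure copy, consider the sketch below. It is a hypothetical C++ analogue of a tolerance-based comparison, assuming the common `|a - b| <= atol + rtol * |b|` criterion; `all_close` is an illustrative helper and not part of the actual Python test.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Element-wise tolerance check: |a - b| <= atol + rtol * |b|.
// With atol == 0 and rtol == 0 this reduces to exact equality.
bool all_close(const std::vector<std::uint8_t>& a,
               const std::vector<std::uint8_t>& b,
               double atol, double rtol) {
  if (a.size() != b.size()) {
    return false;
  }
  for (std::size_t i = 0; i < a.size(); ++i) {
    const double lhs = static_cast<double>(a[i]);
    const double rhs = static_cast<double>(b[i]);
    if (std::abs(lhs - rhs) > atol + rtol * std::abs(rhs)) {
      return false;
    }
  }
  return true;
}
```

With expected bytes `{10, 20, 30}` and a corrupted result `{10, 21, 30}`, a check with `atol = 1` accepts the wrong value, while `atol = 0, rtol = 0` rejects it.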

Reviewed By: spcyppt

Differential Revision: D103495542

@meta-cla meta-cla Bot added the cla signed label May 26, 2026

meta-codesync Bot commented May 26, 2026

@q10 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D103495542.

meta-codesync Bot commented May 27, 2026

This pull request has been merged in 0b8730f.
