Skip to content

Commit f29805c

Browse files
danieldksayakpaul
andauthored
Support the Torch stable ABI (#575)
* kernel-builder: add `stable-abi` option This change adds an option `stable-abi` to `build.toml`. When setting ```toml [torch] stable-abi = "2.11" ``` The extension will be built targeting the Torch 2.11 stable ABI. In this case, the variant naming pattern will also change from `torch<version>` to `torch-stable-abi<abi-version>`. * nix-builder: wire up torch-stable-abi variants Wire up the `torch-stable-abi` variant naming. For a given ABI version, we build with the newest Torch version. * kernels: add support `torch-stable-abi` variants This change adds support for `torch-stable-abi` variants: * We now give preference to stable ABI variants. * If there is no exact match for the Torch version, we back off to an older ABI version. * examples: add Torch stable ABI example and test in CI * docs: document the `stable-abi` option * Documentation fixes Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent a008524 commit f29805c

41 files changed

Lines changed: 964 additions & 92 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/build_kernel.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ jobs:
5555
cutlass-gemm-tvm-ffi-kernel
5656
extra-data
5757
relu-kernel
58+
relu-torch-stable-abi-kernel
5859
relu-tvm-ffi-kernel
5960
relu-kernel-cpu
6061
relu-backprop-compile-kernel

.github/workflows/build_kernel_rocm.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
# For now we only test that there are no regressions in building ROCm
3535
# kernels. Also run tests once we have a ROCm runner.
3636
- name: Build relu kernel
37-
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-cxx11-rocm71-x86_64-linux -L )
37+
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-rocm71-x86_64-linux -L )
3838

3939
- name: Build relu kernel (compiler flags)
40-
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-cxx11-rocm71-x86_64-linux )
40+
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-rocm71-x86_64-linux )

.github/workflows/build_kernel_xpu.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ jobs:
3434
# For now we only test that there are no regressions in building XPU
3535
# kernels. Also run tests once we have an XPU runner.
3636
- name: Build relu kernel
37-
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-cxx11-xpu20253-x86_64-linux -L )
37+
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-xpu20253-x86_64-linux -L )
3838

3939
- name: Build relu tvm-ffi kernel
4040
run: ( cd examples/kernels/relu-tvm-ffi && nix build .\#redistributable.tvm-ffi01-xpu20253-x86_64-linux -L )
4141

4242
- name: Build relu kernel (compiler flags)
43-
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-cxx11-xpu20253-x86_64-linux )
43+
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-xpu20253-x86_64-linux )
4444

4545
- name: Build cutlass-gemm kernel
46-
run: ( cd examples/kernels/cutlass-gemm && nix build .\#redistributable.torch211-cxx11-xpu20253-x86_64-linux -L )
46+
run: ( cd examples/kernels/cutlass-gemm && nix build .\#redistributable.torch211-xpu20253-x86_64-linux -L )

docs/source/builder/writing-kernels.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@ options:
227227
non-compliant kernels if the version range does not correspond to the [required variants](build-variants.md).
228228
- `minver` (optional): only build for this Torch version and later. Use cautiously, since this option produces
229229
non-compliant kernels if the version range does not correspond to the [required variants](build-variants.md).
230+
- `stable-abi` (**experimental**): when set to a Torch version (e.g.
231+
`"2.11"`), the kernel is built using the Torch stable ABI. This
232+
requires that the kernel itself only use
233+
[stable ABI headers](https://docs.pytorch.org/docs/2.12/notes/libtorch_stable_abi.html).
234+
For an example, see the [`relu-torch-stable-abi`](https://github.com/huggingface/kernels/tree/main/examples/kernels/relu-torch-stable-abi)
235+
example kernel.
230236

231237
### `kernel.<name>`
232238

@@ -277,7 +283,6 @@ are available:
277283
- `cxx-flags`: a list of additional flags to be passed to the C++
278284
compiler.
279285

280-
281286
## Torch bindings
282287

283288
### Defining bindings

examples/kernels/flake.nix

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,22 @@
2626
# system and flake outputs.
2727
# - torchVersions: optional override for the torchVersions argument
2828
ciKernels = [
29+
{
30+
name = "cpp20-symbols-kernel";
31+
path = ./cpp20-symbols;
32+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cpu-${sys}"};
33+
}
2934
{
3035
name = "relu-kernel";
3136
path = ./relu;
32-
drv =
33-
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
37+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
3438
}
3539
{
36-
name = "cpp20-symbols-kernel";
37-
path = ./cpp20-symbols;
38-
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-cpu-${sys}"};
40+
name = "relu-torch-stable-abi-kernel";
41+
path = ./relu-torch-stable-abi;
42+
drv =
43+
sys: out:
44+
out.packages.${sys}.redistributable.${"torch-stable-abi${torchVersion}-${cudaVersion}-${sys}"};
3945
}
4046
{
4147
name = "relu-tvm-ffi-kernel";
@@ -46,19 +52,17 @@
4652
{
4753
name = "extra-data";
4854
path = ./extra-data;
49-
drv =
50-
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
55+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
5156
}
5257
{
5358
name = "relu-kernel-cpu";
5459
path = ./relu;
55-
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-cpu-${sys}"};
60+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cpu-${sys}"};
5661
}
5762
{
5863
name = "cutlass-gemm-kernel";
5964
path = ./cutlass-gemm;
60-
drv =
61-
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
65+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
6266
}
6367
{
6468
name = "cutlass-gemm-tvm-ffi-kernel";
@@ -69,8 +73,7 @@
6973
{
7074
name = "relu-backprop-compile-kernel";
7175
path = ./relu-backprop-compile;
72-
drv =
73-
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
76+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
7477
}
7578
{
7679
name = "silu-and-mul-kernel";
@@ -97,8 +100,7 @@
97100
{
98101
name = "relu-compiler-flags";
99102
path = ./relu-compiler-flags;
100-
drv =
101-
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
103+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
102104
}
103105
{
104106
# Check that we can build an arch dev shell.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
---
2+
library_name: kernels
3+
{% if license %}license: {{ license }}
4+
{% endif %}---
5+
6+
This is the repository card of {{ repo_id }} that has been pushed on the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated.
7+
8+
## How to use
9+
{% if functions %}
10+
11+
```python
12+
# make sure `kernels` is installed: `pip install -U kernels`
13+
from kernels import get_kernel
14+
15+
kernel_module = get_kernel("{{ repo_id }}", version={{ version }})
16+
{{ functions[0] }} = kernel_module.{{ functions[0] }}
17+
18+
{{ functions[0] }}(...)
19+
```
20+
{% else %}
21+
22+
Usage example not available.
23+
{% endif %}
24+
25+
## Available functions
26+
{% if functions %}
27+
{% for func in functions %}
28+
- `{{ func }}`
29+
{% endfor %}
30+
{% else %}
31+
32+
Function list not available.
33+
{% endif %}
34+
{% if layers %}
35+
36+
## Available layers
37+
{% for layer in layers %}
38+
- `{{ layer }}`
39+
{% endfor %}
40+
{% endif %}
41+
42+
## Benchmarks
43+
{% if has_benchmark %}
44+
45+
Benchmarking script is available for this kernel. Run `kernels benchmark {{ repo_id }} --version {{ version }}`.
46+
{% else %}
47+
48+
No benchmark available yet.
49+
{% endif %}
50+
{% if upstream %}
51+
52+
## Source code
53+
54+
Source code of this kernel originally comes from {{ upstream }} and it was repurposed for compatibility with `kernels`.
55+
{% endif %}
56+
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
[general]
2+
name = "relu-torch-stable-abi"
3+
version = 1
4+
license = "Apache-2.0"
5+
backends = [
6+
"cpu",
7+
"cuda",
8+
"metal",
9+
"rocm",
10+
"xpu",
11+
]
12+
13+
[general.hub]
14+
repo-id = "kernels-test/relu-torch-stable-abi"
15+
16+
[torch]
17+
stable-abi = "2.11"
18+
src = [
19+
"torch-ext/torch_binding.cpp",
20+
"torch-ext/torch_binding.h",
21+
]
22+
23+
[kernel.relu_xpu]
24+
backend = "xpu"
25+
depends = ["torch"]
26+
src = ["relu_xpu/relu.cpp"]
27+
28+
29+
# Converting metal to the Torch stable ABI requires additional APIs to get
30+
# the command buffer and dispatch queue.
31+
#
32+
# [kernel.relu_metal]
33+
# backend = "metal"
34+
# depends = ["torch"]
35+
# src = [
36+
# "relu_metal/relu.mm",
37+
# "relu_metal/relu.metal",
38+
# "relu_metal/common.h",
39+
#]
40+
41+
[kernel.relu_rocm]
42+
backend = "rocm"
43+
depends = ["torch"]
44+
rocm-archs = [
45+
"gfx906",
46+
"gfx908",
47+
"gfx90a",
48+
"gfx940",
49+
"gfx941",
50+
"gfx942",
51+
"gfx1030",
52+
"gfx1100",
53+
"gfx1101",
54+
]
55+
src = ["relu_cuda/relu.cu"]
56+
57+
[kernel.relu_cpu]
58+
backend = "cpu"
59+
depends = ["torch"]
60+
src = ["relu_cpu/relu_cpu.cpp"]
61+
62+
[kernel.relu]
63+
backend = "cuda"
64+
depends = ["torch"]
65+
src = ["relu_cuda/relu.cu"]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
description = "Flake for ReLU kernel";
3+
4+
inputs = {
5+
kernel-builder.url = "path:../../..";
6+
};
7+
8+
outputs =
9+
{
10+
self,
11+
kernel-builder,
12+
}:
13+
kernel-builder.lib.genKernelFlakeOutputs {
14+
inherit self;
15+
path = ./.;
16+
};
17+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include <torch/csrc/stable/tensor.h>
2+
3+
#ifdef __SSE__
4+
#include <xmmintrin.h>
5+
#endif
6+
7+
#ifdef __ARM_NEON
8+
#include <arm_neon.h>
9+
#endif
10+
11+
// NOTE: This is a minimal example kernel that is not optimized for
12+
// performance, so we do not care about unaligned loads/stores.
13+
14+
#ifdef __SSE__
15+
// Element-wise ReLU over `size` floats using SSE intrinsics.
//
// Writes max(input[k], 0) to out[k] for every k in [0, size). Full groups of
// four floats go through unaligned SSE loads/stores; the remaining 0-3
// elements are handled by a scalar loop.
void relu_forward_sse(float* out, const float* input, size_t size) {
  const size_t lanes = 4;
  // First index that no longer has a complete group of four after it.
  const size_t vector_end = size - (size % lanes);
  const __m128 zeros = _mm_setzero_ps();

  size_t pos = 0;
  while (pos < vector_end) {
    const __m128 loaded = _mm_loadu_ps(input + pos);
    _mm_storeu_ps(out + pos, _mm_max_ps(loaded, zeros));
    pos += lanes;
  }

  // Scalar tail for the elements that did not fill a whole vector.
  while (pos < size) {
    const float value = input[pos];
    out[pos] = value > 0 ? value : 0;
    ++pos;
  }
}
29+
#endif
30+
31+
#ifdef __ARM_NEON
32+
// Element-wise ReLU over `size` floats using ARM NEON intrinsics.
//
// Writes max(input[k], 0) to out[k] for every k in [0, size). Full groups of
// four floats are processed with NEON vector loads/stores; the remaining 0-3
// elements are handled by a scalar loop.
void relu_forward_neon(float* out, const float* input, size_t size) {
  const size_t lanes = 4;
  // First index that no longer has a complete group of four after it.
  const size_t vector_end = size - (size % lanes);
  const float32x4_t zeros = vdupq_n_f32(0);

  size_t pos = 0;
  while (pos < vector_end) {
    const float32x4_t loaded = vld1q_f32(input + pos);
    vst1q_f32(out + pos, vmaxq_f32(loaded, zeros));
    pos += lanes;
  }

  // Scalar tail for the elements that did not fill a whole vector.
  while (pos < size) {
    const float value = input[pos];
    out[pos] = value > 0 ? value : 0;
    ++pos;
  }
}
45+
#endif
46+
47+
// CPU entry point for the ReLU op: writes max(input, 0) element-wise into
// `out`.
//
// Preconditions (violations raise through STD_TORCH_CHECK):
//   - both tensors have dtype float32,
//   - both tensors hold the same number of elements,
//   - both tensors are contiguous.
//
// The contiguity requirement exists because the SIMD helpers treat each
// tensor's data_ptr() as a flat array of numel() floats; a strided view would
// be read or written at the wrong addresses.
void relu(torch::stable::Tensor &out, torch::stable::Tensor const &input) {
  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float, "Output tensor must be of dtype float");
  STD_TORCH_CHECK(input.scalar_type() == torch::headeronly::ScalarType::Float, "Input tensor must be of dtype float");
  STD_TORCH_CHECK(out.numel() == input.numel(), "Input and output tensors must have the same number of elements");
  STD_TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");
  STD_TORCH_CHECK(out.is_contiguous(), "Output tensor must be contiguous");

  float* out_data = static_cast<float*>(out.data_ptr());
  const float* input_data = static_cast<const float*>(input.data_ptr());
  const size_t count = static_cast<size_t>(input.numel());

  // Pick the SIMD implementation at compile time; building for a target with
  // neither SSE nor NEON is a hard error.
#if defined(__SSE__)
  relu_forward_sse(out_data, input_data, count);
#elif defined(__ARM_NEON)
  relu_forward_neon(out_data, input_data, count);
#else
#error "Unsupported architecture; please use a CPU with SSE or ARM NEON support."
#endif
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include <torch/csrc/stable/accelerator.h>
2+
#include <torch/csrc/stable/tensor.h>
3+
4+
// The shim's declaration is guarded by USE_CUDA, so declare it here.
5+
extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);
6+
7+
#include <cmath>
8+
9+
// ReLU over flat float buffers viewed as rows of `d` elements.
//
// Launch layout: the block with index blockIdx.x handles the row starting at
// element blockIdx.x * d, and its threads step through that row in strides of
// blockDim.x. No shared memory is used. `out` and `input` are declared
// __restrict__, so callers must not pass overlapping buffers.
__global__ void relu_kernel(float *__restrict__ out,
                            float const *__restrict__ input, const int d) {
  // 64-bit row offset so that row * d cannot wrap for large buffers.
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * d;
  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const int64_t offset = row_start + col;
    const float value = input[offset];
    if (value > 0.0f) {
      out[offset] = value;
    } else {
      out[offset] = 0.0f;
    }
  }
}
17+
18+
// Host-side launcher for relu_kernel: writes max(input, 0) element-wise into
// `out` on the current stream of input's device.
//
// Preconditions (violations raise through STD_TORCH_CHECK):
//   - input is a CUDA tensor and out lives on the same device,
//   - both tensors are contiguous float32 tensors of identical shape,
//   - the last dimension fits in an int (the kernel's `d` parameter),
//   - the number of rows fits in a single grid dimension.
//
// Launch layout: one block per row (all leading dimensions flattened), with
// min(d, 1024) threads per block. An empty input returns without launching.
void relu(torch::stable::Tensor &out, torch::stable::Tensor const &input) {
  STD_TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  STD_TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // The kernel writes `out` with the same flat indexing it uses to read
  // `input`, so the output has to be contiguous as well.
  STD_TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  // Requiring both dtypes to be float32 also guarantees that they match.
  STD_TORCH_CHECK(input.scalar_type() == torch::headeronly::ScalarType::Float &&
                      out.scalar_type() == torch::headeronly::ScalarType::Float,
                  "relu_kernel only supports float32");

  STD_TORCH_CHECK(input.sizes().equals(out.sizes()),
                  "Tensors must have the same shape.");

  STD_TORCH_CHECK(input.device() == out.device(),
                  "Tensors must be on the same device.");

  if (input.numel() == 0) {
    return;
  }

  // The kernel takes the row length as an int; refuse to narrow silently.
  const int64_t last_dim = input.size(-1);
  const int d = static_cast<int>(last_dim);
  STD_TORCH_CHECK(static_cast<int64_t>(d) == last_dim,
                  "Last dimension is too large for relu_kernel.");

  // One block per row. 2^31 - 1 is CUDA's documented upper bound for
  // gridDim.x; without this check a larger row count would be silently
  // truncated when converted to dim3's unsigned int.
  const int64_t num_tokens = input.numel() / last_dim;
  constexpr int64_t kMaxGridDimX = 2147483647;
  STD_TORCH_CHECK(num_tokens <= kMaxGridDimX,
                  "Too many rows for a single relu_kernel launch.");

  const dim3 grid(static_cast<unsigned int>(num_tokens));
  const dim3 block(static_cast<unsigned int>(d < 1024 ? d : 1024));

  // Make input's device current for the duration of the launch and fetch the
  // stream Torch considers current for that device.
  const torch::stable::accelerator::DeviceGuard device_guard(input.get_device_index());
  void* stream_ptr = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(input.get_device_index(), &stream_ptr));
  const cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);

  relu_kernel<<<grid, block, 0, stream>>>(static_cast<float*>(out.data_ptr()),
                                          static_cast<const float*>(input.data_ptr()), d);

  // Kernel launches do not return a status; launch-configuration errors only
  // surface through cudaGetLastError().
  const cudaError_t launch_status = cudaGetLastError();
  STD_TORCH_CHECK(launch_status == cudaSuccess,
                  "relu_kernel launch failed: ", cudaGetErrorString(launch_status));
}

0 commit comments

Comments
 (0)