Skip to content

Commit 9cfa642

Browse files
authored
kernel-builder: reject empty capabilities/archs list (#590)
* kernel-builder: reject empty capabilities/archs list We compute a kernel component's capabilities by intersecting the capabilities that are specified for the kernel and the capabilities that are supported by CUDA/ROCm. Before this change, we would silently set an empty list if this intersection was empty. This resulted in CMake falling back to an old capability. This change fixes that by erroring out when the capability list is empty. * Hook up ROCm test and make it concurrent
1 parent c358366 commit 9cfa642

11 files changed

Lines changed: 273 additions & 28 deletions

File tree

.github/workflows/build_kernel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
run: nix-shell -p nix-info --run "nix-info -m"
4141

4242
- name: Build all example kernels
43-
run: nix build -L ./examples/kernels#ci-build
43+
run: nix build -L ./examples/kernels#ci-build-cuda
4444
- name: Copy kernel artifacts
4545
run: cp -rL result/* .
4646

.github/workflows/build_kernel_rocm.yaml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212

1313
jobs:
1414
build:
15-
name: Build kernel
15+
name: Build kernels (ROCm)
1616
runs-on:
1717
group: aws-highmemory-32-plus-nix
1818
steps:
@@ -33,8 +33,5 @@ jobs:
3333
run: nix-shell -p nix-info --run "nix-info -m"
3434
# For now we only test that there are no regressions in building ROCm
3535
# kernels. Also run tests once we have a ROCm runner.
36-
- name: Build relu kernel
37-
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-rocm71-x86_64-linux -L )
38-
39-
- name: Build relu kernel (compiler flags)
40-
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-rocm71-x86_64-linux )
36+
- name: Build all ROCm example kernels
37+
run: nix build -L ./examples/kernels#ci-build-rocm

examples/kernels/flake.nix

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
inherit (kernel-builder.inputs.nixpkgs) lib;
1616

1717
cudaVersion = "cu126";
18+
rocmVersion = "rocm71";
1819
torchVersion = "211";
1920
tvmFfiVersion = "01";
2021

@@ -102,6 +103,13 @@
102103
path = ./relu-compiler-flags;
103104
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
104105
}
106+
{
107+
name = "relu-invalid-capability";
108+
path = ./relu-invalid-capability;
109+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
110+
assertFail = true;
111+
assertFailLogs = [ "empty set of capabilities" ];
112+
}
105113
{
106114
# Check that we can build an arch dev shell.
107115
name = "relu-dev-shell";
@@ -129,6 +137,27 @@
129137
}
130138
];
131139

140+
# ROCm kernels to build in CI.
141+
ciRocmKernels = [
142+
{
143+
name = "relu-invalid-capability";
144+
path = ./relu-invalid-capability;
145+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
146+
assertFail = true;
147+
assertFailLogs = [ "empty set of architectures" ];
148+
}
149+
{
150+
name = "relu-kernel";
151+
path = ./relu;
152+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
153+
}
154+
{
155+
name = "relu-compiler-flags";
156+
path = ./relu-compiler-flags;
157+
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
158+
}
159+
];
160+
132161
mkKernelOutputs =
133162
{
134163
path,
@@ -141,16 +170,21 @@
141170
// lib.optionalAttrs (torchVersions != null) { inherit torchVersions; }
142171
);
143172

144-
ciKernelOutputs = map (
145-
kernel:
146-
kernel
147-
// {
148-
outputs = mkKernelOutputs {
149-
inherit (kernel) path;
150-
torchVersions = kernel.torchVersions or null;
151-
};
152-
}
153-
) ciKernels;
173+
mkKernelOutputs' =
174+
kernels:
175+
map (
176+
kernel:
177+
kernel
178+
// {
179+
outputs = mkKernelOutputs {
180+
inherit (kernel) path;
181+
torchVersions = kernel.torchVersions or null;
182+
};
183+
}
184+
) kernels;
185+
186+
ciKernelOutputs = mkKernelOutputs' ciKernels;
187+
ciRocmKernelOutputs = mkKernelOutputs' ciRocmKernels;
154188
in
155189
flake-utils.lib.eachSystem
156190
[
@@ -162,22 +196,39 @@
162196
let
163197
pkgs = nixpkgs.legacyPackages.${system};
164198

165-
resolvedKernels = map (kernel: {
166-
inherit (kernel) name;
167-
drv = kernel.drv system kernel.outputs;
168-
}) ciKernelOutputs;
169-
170-
ci-build = pkgs.linkFarm "ci-kernels" (
199+
resolveKernels =
200+
kernelOutputsList:
171201
map (kernel: {
172202
inherit (kernel) name;
173-
path = kernel.drv;
174-
}) resolvedKernels
175-
);
203+
drv =
204+
let
205+
baseDrv = kernel.drv system kernel.outputs;
206+
in
207+
if kernel.assertFail or false then
208+
pkgs.testers.testBuildFailure' {
209+
drv = baseDrv;
210+
expectedBuilderLogEntries = kernel.assertFailLogs or [ ];
211+
}
212+
else
213+
baseDrv;
214+
}) kernelOutputsList;
215+
216+
mkCiBuild =
217+
name: kernelOutputsList:
218+
pkgs.linkFarm name (
219+
map (kernel: {
220+
inherit (kernel) name;
221+
path = kernel.drv;
222+
}) (resolveKernels kernelOutputsList)
223+
);
224+
225+
ci-build-cuda = mkCiBuild "ci-kernels-cuda" ciKernelOutputs;
226+
ci-build-rocm = mkCiBuild "ci-kernels-rocm" ciRocmKernelOutputs;
176227
in
177228
{
178229
packages = {
179-
inherit ci-build;
180-
default = ci-build;
230+
inherit ci-build-cuda ci-build-rocm;
231+
default = ci-build-cuda;
181232
};
182233
}
183234
);
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
---
2+
library_name: kernels
3+
{% if license %}license: {{ license }}
4+
{% endif %}---
5+
6+
This is the repository card of {{ repo_id }} that has been pushed on the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated.
7+
8+
## How to use
9+
{% if functions %}
10+
11+
```python
12+
# make sure `kernels` is installed: `pip install -U kernels`
13+
from kernels import get_kernel
14+
15+
kernel_module = get_kernel("{{ repo_id }}", version={{ version }})
16+
{{ functions[0] }} = kernel_module.{{ functions[0] }}
17+
18+
{{ functions[0] }}(...)
19+
```
20+
{% else %}
21+
22+
Usage example not available.
23+
{% endif %}
24+
25+
## Available functions
26+
{% if functions %}
27+
{% for func in functions %}
28+
- `{{ func }}`
29+
{% endfor %}
30+
{% else %}
31+
32+
Function list not available.
33+
{% endif %}
34+
{% if layers %}
35+
36+
## Available layers
37+
{% for layer in layers %}
38+
- `{{ layer }}`
39+
{% endfor %}
40+
{% endif %}
41+
42+
## Benchmarks
43+
{% if has_benchmark %}
44+
45+
Benchmarking script is available for this kernel. Run `kernels benchmark {{ repo_id }} --version {{ version }}`.
46+
{% else %}
47+
48+
No benchmark available yet.
49+
{% endif %}
50+
{% if upstream %}
51+
52+
## Source code
53+
54+
Source code of this kernel originally comes from {{ upstream }} and it was repurposed for compatibility with `kernels`.
55+
{% endif %}
56+
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
[general]
2+
name = "relu-invalid-capability"
3+
version = 1
4+
license = "Apache-2.0"
5+
backends = [
6+
"cpu",
7+
"cuda",
8+
"metal",
9+
"rocm",
10+
"xpu",
11+
]
12+
13+
[general.hub]
14+
repo-id = "kernels-test/relu-invalid-capability"
15+
16+
[torch]
17+
src = [
18+
"torch-ext/torch_binding.cpp",
19+
"torch-ext/torch_binding.h",
20+
]
21+
22+
[kernel.relu_rocm]
23+
backend = "rocm"
24+
depends = ["torch"]
25+
rocm-archs = [ "gfx99999" ]
26+
src = ["relu_cuda/relu.cu"]
27+
28+
[kernel.relu]
29+
backend = "cuda"
30+
depends = ["torch"]
31+
cuda-capabilities = [ "99999.0" ]
32+
src = ["relu_cuda/relu.cu"]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
description = "Flake for ReLU kernel";
3+
4+
inputs = {
5+
kernel-builder.url = "path:../../..";
6+
};
7+
8+
outputs =
9+
{
10+
self,
11+
kernel-builder,
12+
}:
13+
kernel-builder.lib.genKernelFlakeOutputs {
14+
inherit self;
15+
path = ./.;
16+
};
17+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include <ATen/cuda/CUDAContext.h>
2+
#include <c10/cuda/CUDAGuard.h>
3+
#include <torch/all.h>
4+
5+
#include <cmath>
6+
7+
__global__ void relu_kernel(float *__restrict__ out,
8+
float const *__restrict__ input, const int d) {
9+
const int64_t token_idx = blockIdx.x;
10+
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
11+
auto x = input[token_idx * d + idx];
12+
out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
13+
}
14+
}
15+
16+
void relu(torch::Tensor &out, torch::Tensor const &input) {
17+
TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
18+
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
19+
TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
20+
input.scalar_type() == at::ScalarType::Float,
21+
"relu_kernel only supports float32");
22+
23+
TORCH_CHECK(input.sizes() == out.sizes(),
24+
"Tensors must have the same shape. Got input shape: ",
25+
input.sizes(), " and output shape: ", out.sizes());
26+
27+
TORCH_CHECK(input.scalar_type() == out.scalar_type(),
28+
"Tensors must have the same data type. Got input dtype: ",
29+
input.scalar_type(), " and output dtype: ", out.scalar_type());
30+
31+
TORCH_CHECK(input.device() == out.device(),
32+
"Tensors must be on the same device. Got input device: ",
33+
input.device(), " and output device: ", out.device());
34+
35+
if (input.numel() == 0) {
36+
return;
37+
}
38+
39+
int d = input.size(-1);
40+
int64_t num_tokens = input.numel() / d;
41+
dim3 grid(num_tokens);
42+
dim3 block(std::min(d, 1024));
43+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
44+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
45+
relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
46+
input.data_ptr<float>(), d);
47+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from typing import Optional
2+
3+
import torch
4+
5+
from ._ops import ops
6+
7+
8+
def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
9+
if out is None:
10+
out = torch.empty_like(x)
11+
ops.relu(out, x)
12+
return out
13+
14+
15+
__all__ = ["relu"]
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#include <torch/library.h>
2+
3+
#include "registration.h"
4+
#include "torch_binding.h"
5+
6+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7+
ops.def("relu(Tensor! out, Tensor input) -> ()");
8+
#if defined(CPU_KERNEL)
9+
ops.impl("relu", torch::kCPU, &relu);
10+
#elif defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
11+
ops.impl("relu", torch::kCUDA, &relu);
12+
#elif defined(METAL_KERNEL)
13+
ops.impl("relu", torch::kMPS, relu);
14+
#elif defined(XPU_KERNEL)
15+
ops.impl("relu", torch::kXPU, &relu);
16+
#endif
17+
}
18+
19+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
5+
void relu(torch::Tensor &out, torch::Tensor const &input);

0 commit comments

Comments
 (0)