Skip to content

Commit c90c2d1

Browse files
danieldk and sayakpaul authored
nix-builder: add a hook to detect incorrect op registrations (#569)
* nix-builder: add a hook to detect incorrect op registrations
* ci: test that the hook fails on incorrect registrations
* doc fixes

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Use full kernel build for torch-ops-check-hook

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent e6def64 commit c90c2d1

16 files changed

Lines changed: 338 additions & 2 deletions

File tree

docs/source/builder/writing-kernels.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,47 @@ def mykernel(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tenso
351351
return out
352352
```
353353

354+
## Registering Torch operators
355+
356+
You may want to register Torch ops from your kernel's Python code or
357+
register fake ops for `torch.compile` support. It is important to register
358+
such ops in the namespace that kernel-builder makes for your kernel
359+
build. This is required for compliant kernels to ensure that multiple
360+
versions of the same kernel can be loaded at the same time without
361+
namespace conflicts.
362+
363+
You can use the `add_op_namespace_prefix` function to prefix an op name with the
364+
correct prefix. So for instance, replace
365+
366+
```python
367+
@torch.library.register_fake("relu::relu_fwd")
368+
def relu_fwd_fake(input: torch.Tensor) -> torch.Tensor:
369+
return torch.empty_like(input)
370+
```
371+
372+
by
373+
374+
```python
375+
from ._ops import add_op_namespace_prefix
376+
377+
@torch.library.register_fake(add_op_namespace_prefix("relu_fwd"))
378+
def relu_fwd_fake(input: torch.Tensor) -> torch.Tensor:
379+
return torch.empty_like(input)
380+
```
381+
382+
As mentioned above, the `_ops` module is generated by kernel-builder.
383+
384+
kernel-builder uses a hook to reject incorrect usage of Torch op registration
385+
functions. However, it can only catch direct use of certain `torch.library`
386+
decorators. For instance, the hook would not reject the following decorator,
387+
so it should be seen as a last-resort check for cases where human review fails:
388+
389+
```python
390+
@some_indirection_for_register_fake("relu::relu_fwd")
391+
def relu_fwd_fake(input: torch.Tensor) -> torch.Tensor:
392+
return torch.empty_like(input)
393+
```
394+
354395
## Kernel tests
355396

356397
Kernel tests are stored in the `tests` directory. Since running all
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
---
2+
library_name: kernels
3+
{% if license %}license: {{ license }}
4+
{% endif %}---
5+
6+
This is the repository card of {{ repo_id }} that has been pushed to the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated.
7+
8+
## How to use
9+
{% if functions %}
10+
11+
```python
12+
# make sure `kernels` is installed: `pip install -U kernels`
13+
from kernels import get_kernel
14+
15+
kernel_module = get_kernel("{{ repo_id }}", version={{ version }})
16+
{{ functions[0] }} = kernel_module.{{ functions[0] }}
17+
18+
{{ functions[0] }}(...)
19+
```
20+
{% else %}
21+
22+
Usage example not available.
23+
{% endif %}
24+
25+
## Available functions
26+
{% if functions %}
27+
{% for func in functions %}
28+
- `{{ func }}`
29+
{% endfor %}
30+
{% else %}
31+
32+
Function list not available.
33+
{% endif %}
34+
{% if layers %}
35+
36+
## Available layers
37+
{% for layer in layers %}
38+
- `{{ layer }}`
39+
{% endfor %}
40+
{% endif %}
41+
42+
## Benchmarks
43+
{% if has_benchmark %}
44+
45+
A benchmarking script is available for this kernel. Run `kernels benchmark {{ repo_id }} --version {{ version }}`.
46+
{% else %}
47+
48+
No benchmark available yet.
49+
{% endif %}
50+
{% if upstream %}
51+
52+
## Source code
53+
54+
The source code of this kernel originally comes from {{ upstream }} and was repurposed for compatibility with `kernels`.
55+
{% endif %}
56+
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[general]
2+
name = "silu-and-mul-bad-registration"
3+
version = 1
4+
license = "apache-2.0"
5+
backends = [
6+
"cpu",
7+
"cuda",
8+
"metal",
9+
"rocm",
10+
"xpu",
11+
]
12+
13+
[general.hub]
14+
repo-id = "kernels-test/silu-and-mul-bad-registration"
15+
16+
[torch-noarch]
17+
18+
[kernel]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
description = "Flake for kernels tests";
3+
4+
inputs = {
5+
kernel-builder.url = "path:../../..";
6+
};
7+
8+
outputs =
9+
{
10+
self,
11+
kernel-builder,
12+
}:
13+
kernel-builder.lib.genKernelFlakeOutputs {
14+
inherit self;
15+
path = ./.;
16+
};
17+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
3+
from ._ops import ops
4+
from .op import _silu_and_mul
5+
from . import layers
6+
7+
8+
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
9+
return ops.silu_and_mul(x)
10+
11+
12+
__all__ = ["layers", "silu_and_mul"]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from ._ops import ops
5+
6+
7+
class SiluAndMul(nn.Module):
8+
"""
9+
Apply SiLU to one half of the array and use it as a multiplicative
10+
gate for the other half.
11+
12+
Shapes:
13+
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
14+
return: (num_tokens, d) or (batch_size, seq_len, d)
15+
"""
16+
17+
can_torch_compile: bool = True
18+
19+
def forward(self, x: torch.Tensor):
20+
return ops.silu_and_mul(x)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
5+
@torch.library.custom_op("myns::silu_and_mul", mutates_args=())
6+
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
7+
d = x.shape[-1] // 2
8+
return F.silu(x[..., :d]) * x[..., d:]
9+
10+
11+
def backward(ctx, grad_output):
12+
x = ctx.saved_tensors[0]
13+
d = x.shape[-1] // 2
14+
x1, x2 = x[..., :d], x[..., d:]
15+
sigmoid_x1 = torch.sigmoid(x1)
16+
silu_x1 = F.silu(x1)
17+
dsilu_dx1 = sigmoid_x1 + silu_x1 * (1 - sigmoid_x1)
18+
dx1 = grad_output * x2 * dsilu_dx1
19+
dx2 = grad_output * silu_x1
20+
return torch.cat([dx1, dx2], dim=-1)
21+
22+
23+
def setup_context(ctx, inputs, output):
24+
(x,) = inputs
25+
ctx.save_for_backward(x)
26+
27+
28+
_silu_and_mul.register_autograd(backward, setup_context=setup_context)
29+
30+
31+
@_silu_and_mul.register_fake
32+
def _(x: torch.Tensor) -> torch.Tensor:
33+
return x.new_empty(x.shape[0], x.shape[1] // 2)

flake.nix

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@
201201
in
202202
rec {
203203
checks.default = pkgs.callPackage ./nix-builder/lib/checks.nix {
204-
inherit buildSets;
204+
inherit buildSets self;
205+
inherit (self.lib) genKernelFlakeOutputs;
205206
build = buildPerSystem.${system};
206207
};
207208

nix-builder/lib/checks.nix

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,38 @@
11
{
2+
self,
23
lib,
34
runCommand,
5+
testers,
6+
python3,
7+
stdenv,
48

59
build,
610
buildSets,
11+
genKernelFlakeOutputs,
712
}:
813

914
let
1015
kernelBuildSets = build.applicableBuildSets {
1116
inherit buildSets;
1217
path = ../../examples/kernels/relu-torch-bounds;
1318
};
19+
20+
badRegistrationCheck = testers.testBuildFailure' {
21+
drv =
22+
(genKernelFlakeOutputs {
23+
inherit self;
24+
path = ../../examples/kernels/silu-and-mul-bad-registration;
25+
}).packages.${stdenv.hostPlatform.system}.redistributable.torch-cuda;
26+
expectedBuilderExitCode = 1;
27+
expectedBuilderLogEntries = [
28+
"Found Torch library registrations that do not use `add_op_namespace_prefix`:"
29+
];
30+
};
1431
in
1532
assert lib.assertMsg (builtins.all (buildSet: buildSet.torch.version == "2.10.0") kernelBuildSets)
1633
''
1734
Torch minver/maxver filtering does not work.
1835
'';
19-
runCommand "builder-nix-checks" { } ''
36+
runCommand "builder-nix-checks" { buildInputs = [ badRegistrationCheck ]; } ''
2037
touch $out
2138
''

nix-builder/lib/extension/torch/arch.nix

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get-kernel-check,
1717
kernel-abi-check,
1818
kernel-layout-check,
19+
torch-ops-check,
1920
ninja,
2021
python3,
2122
remove-bytecode-hook,
@@ -163,6 +164,7 @@ stdenv.mkDerivation (prevAttrs: {
163164
kernel-abi-check
164165
kernel-layout-check
165166
remove-bytecode-hook
167+
torch-ops-check
166168
]
167169
++ lib.optionals doGetKernelCheck [
168170
(get-kernel-check.override { python3 = python3.withPackages (ps: dependencies); })

0 commit comments

Comments
 (0)