diff --git a/.github/workflows/test_kernels.yaml b/.github/workflows/test_kernels.yaml index 861d1892..7c2875c5 100644 --- a/.github/workflows/test_kernels.yaml +++ b/.github/workflows/test_kernels.yaml @@ -27,27 +27,31 @@ jobs: max-parallel: 4 matrix: python-version: ["3.10", "3.12"] - torch-version: ["2.9.0", "2.10.0"] + torch-version: ["2.10.0", "2.11.0"] env: UV_PYTHON_PREFERENCE: only-managed steps: - name: Checkout code - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install uv and set the python version - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: python-version: ${{ matrix.python-version }} - - name: Install the project + - name: Override kernels-data source to local bindings working-directory: ./kernels - run: uv sync --all-extras --dev + run: uv add ../kernels-data/bindings/python --no-sync - - name: Install Torch version + - name: Override the Torch version working-directory: ./kernels - run: uv pip install "torch==${{ matrix.torch-version }}" + run: uv add "torch==${{ matrix.torch-version }}" --no-sync + + - name: Install the project + working-directory: ./kernels + run: uv sync --all-extras --dev - name: Install setuptools for Triton-based test working-directory: ./kernels @@ -72,7 +76,6 @@ jobs: uv pip install einops nvidia-cutlass-dsl uv run pytest tests/test_deps.py - - name: Check kernel check working-directory: ./kernels run: | diff --git a/docs/source/api/kernels.md b/docs/source/api/kernels.md index 773ebbcc..4946b024 100644 --- a/docs/source/api/kernels.md +++ b/docs/source/api/kernels.md @@ -26,4 +26,14 @@ ### get_locked_kernel -[[autodoc]] kernels.get_locked_kernel \ No newline at end of file +[[autodoc]] kernels.get_locked_kernel + +## Classes + +### LoadedKernel + +[[autodoc]] 
kernels.LoadedKernel + +### RepoInfo + +[[autodoc]] kernels.RepoInfo diff --git a/docs/source/builder-cli.md b/docs/source/builder-cli.md index 2630d5c7..18fa61f2 100644 --- a/docs/source/builder-cli.md +++ b/docs/source/builder-cli.md @@ -75,6 +75,9 @@ Initialize a new kernel project from template ###### **Options:** +* `--license ` — The kernel's license + + Default value: `Apache-2.0` * `--name ` — Name of the kernel repo (e.g. `drbh/my-kernel`) * `--backends ` — Backends to enable (`all`, `cpu`, `cuda`, `metal`, `neuron`, `rocm`, `xpu`) diff --git a/docs/source/builder/build-variants.md b/docs/source/builder/build-variants.md index a3f74c51..710c4bad 100644 --- a/docs/source/builder/build-variants.md +++ b/docs/source/builder/build-variants.md @@ -28,7 +28,6 @@ available. This list will be updated as new PyTorch versions are released. - `torch211-cxx11-cu126-aarch64-linux` - `torch211-cxx11-cu128-aarch64-linux` - `torch211-cxx11-cu130-aarch64-linux` -- `torch29-cxx11-cu129-aarch64-linux` ## CPU x86_64-linux @@ -43,7 +42,6 @@ available. This list will be updated as new PyTorch versions are released. - `torch211-cxx11-cu126-x86_64-linux` - `torch211-cxx11-cu128-x86_64-linux` - `torch211-cxx11-cu130-x86_64-linux` -- `torch29-cxx11-cu129-x86_64-linux` ## ROCm x86_64-linux diff --git a/docs/source/kernel-requirements.md b/docs/source/kernel-requirements.md index 2463027b..5b1c401f 100644 --- a/docs/source/kernel-requirements.md +++ b/docs/source/kernel-requirements.md @@ -42,7 +42,11 @@ metadata. Currently the following top-level keys are supported: - `id` (`str`, required): a unique identifier for the kernel. This identifier must also be a valid Python module name. If the kernel registers Torch ops, they must be registered as `torch.ops.` +- `name` (`str`, required): the name of the kernel. Replacing dashes + by underscores should result in the module name of the kernel. 
- `version` (`int`, required): the kernel version number. +- `license` (`str`, required): the kernel license. Refer to the + list of [supported license identifiers](https://huggingface.co/docs/hub/repositories-licenses). - `backend` (`dict`, required): information about the compute backend that this build variant supports. - `python-depends` (`list[str]`, optional): list of Python dependencies @@ -52,9 +56,11 @@ Example `metadata.json`: ```json { - "id": "_mykernel_cuda_be238e4", - "python-depends": ["einops"], + "name": "mykernel", + "id": "_mykernel_cuda_7a4e5a7", "version": 1, + "license": "Apache-2.0", + "python-depends": ["einops"], "backend": { "type": "cuda", "archs": ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0+PTX"] diff --git a/examples/kernels/cutlass-gemm-tvm-ffi/build.toml b/examples/kernels/cutlass-gemm-tvm-ffi/build.toml index 2537338c..5f92b414 100644 --- a/examples/kernels/cutlass-gemm-tvm-ffi/build.toml +++ b/examples/kernels/cutlass-gemm-tvm-ffi/build.toml @@ -1,5 +1,7 @@ [general] name = "cutlass-gemm-tvm-ffi" +version = 1 +license = "Apache-2.0" backends = [ "cuda", "xpu", @@ -9,20 +11,20 @@ backends = [ repo-id = "kernels-test/cutlass-gemm-tvm-ffi" [tvm-ffi] -src = [ - "tvm-ffi-ext/tvm_ffi_binding.cpp", -] +src = ["tvm-ffi-ext/tvm_ffi_binding.cpp"] [kernel.gemm] backend = "cuda" -depends = [ - "cutlass_3_6", +depends = ["cutlass_3_6"] +src = [ + "gemm.cu", + "util.hh", ] -src = ["gemm.cu", "util.hh"] [kernel.gemm_xpu] backend = "xpu" -depends = [ - "sycl_tla", +depends = ["sycl_tla"] +src = [ + "gemm_sycl.cpp", + "util.hh", ] -src = ["gemm_sycl.cpp", "util.hh"] diff --git a/examples/kernels/cutlass-gemm/build.toml b/examples/kernels/cutlass-gemm/build.toml index 7fcf90ae..f4021dd6 100644 --- a/examples/kernels/cutlass-gemm/build.toml +++ b/examples/kernels/cutlass-gemm/build.toml @@ -1,5 +1,7 @@ [general] name = "cutlass-gemm" +version = 1 +license = "Apache-2.0" backends = [ "cuda", "xpu", @@ -14,14 +16,6 @@ src = [ 
"torch-ext/torch_binding.h", ] -[kernel.gemm] -backend = "cuda" -depends = [ - "torch", - "cutlass_3_6", -] -src = ["gemm.cu"] - [kernel.gemm_xpu] backend = "xpu" depends = [ @@ -29,3 +23,11 @@ depends = [ "sycl_tla", ] src = ["gemm_sycl.cpp"] + +[kernel.gemm] +backend = "cuda" +depends = [ + "torch", + "cutlass_3_6", +] +src = ["gemm.cu"] diff --git a/examples/kernels/extra-data/build.toml b/examples/kernels/extra-data/build.toml index 566edbb3..5ebf677a 100644 --- a/examples/kernels/extra-data/build.toml +++ b/examples/kernels/extra-data/build.toml @@ -1,5 +1,7 @@ [general] name = "extra-data" +version = 1 +license = "Apache-2.0" backends = [ "cpu", "cuda", @@ -12,28 +14,18 @@ backends = [ repo-id = "kernels-test/extra-data" [torch] +pyext = [ + "json", + "py", +] src = [ "torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h", ] -pyext = ["json", "py"] - -[kernel.relu] -backend = "cuda" -depends = ["torch"] -src = ["relu_cuda/relu.cu"] - -[kernel.relu_metal] -backend = "metal" -src = [ - "relu_metal/relu.mm", - "relu_metal/relu.metal", - "relu_metal/common.h", -] -depends = [ "torch" ] [kernel.relu_rocm] backend = "rocm" +depends = ["torch"] rocm-archs = [ "gfx906", "gfx908", @@ -45,7 +37,6 @@ rocm-archs = [ "gfx1100", "gfx1101", ] -depends = ["torch"] src = ["relu_cuda/relu.cu"] [kernel.relu_xpu] @@ -53,6 +44,20 @@ backend = "xpu" depends = ["torch"] src = ["relu_xpu/relu.cpp"] +[kernel.relu_metal] +backend = "metal" +depends = ["torch"] +src = [ + "relu_metal/relu.mm", + "relu_metal/relu.metal", + "relu_metal/common.h", +] + +[kernel.relu] +backend = "cuda" +depends = ["torch"] +src = ["relu_cuda/relu.cu"] + [kernel.relu_cpu] backend = "cpu" depends = ["torch"] diff --git a/examples/kernels/relu-backprop-compile/build.toml b/examples/kernels/relu-backprop-compile/build.toml index 402684bb..3de9dc31 100644 --- a/examples/kernels/relu-backprop-compile/build.toml +++ b/examples/kernels/relu-backprop-compile/build.toml @@ -1,5 +1,7 @@ [general] name = 
"relu-backprop-compile" +version = 1 +license = "Apache-2.0" backends = [ "cuda", "rocm", @@ -21,6 +23,7 @@ src = ["relu_cuda/relu.cu"] [kernel.relu_rocm] backend = "rocm" +depends = ["torch"] rocm-archs = [ "gfx906", "gfx908", @@ -32,5 +35,4 @@ rocm-archs = [ "gfx1100", "gfx1101", ] -depends = ["torch"] src = ["relu_cuda/relu.cu"] diff --git a/examples/kernels/relu-compiler-flags/build.toml b/examples/kernels/relu-compiler-flags/build.toml index 00adf665..8b699dfe 100644 --- a/examples/kernels/relu-compiler-flags/build.toml +++ b/examples/kernels/relu-compiler-flags/build.toml @@ -1,5 +1,7 @@ [general] name = "relu-compiler-flags" +version = 1 +license = "Apache-2.0" backends = [ "cuda", "rocm", @@ -10,16 +12,26 @@ backends = [ repo-id = "kernels-test/relu-compiler-flags" [torch] -src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", +] + +[kernel.activation_xpu] +backend = "xpu" +depends = ["torch"] +sycl-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] +src = ["relu_xpu/relu.cpp"] [kernel.activation] backend = "cuda" +cuda-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] depends = ["torch"] src = ["relu_cuda/relu.cu"] -cuda-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] [kernel.activation_rocm] backend = "rocm" +depends = ["torch"] rocm-archs = [ "gfx906", "gfx908", @@ -31,12 +43,5 @@ rocm-archs = [ "gfx1100", "gfx1101", ] -depends = ["torch"] -src = ["relu_cuda/relu.cu"] hip-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] - -[kernel.activation_xpu] -backend = "xpu" -depends = ["torch"] -src = ["relu_xpu/relu.cpp"] -sycl-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"] +src = ["relu_cuda/relu.cu"] diff --git a/examples/kernels/relu-metal-cpp/build.toml b/examples/kernels/relu-metal-cpp/build.toml index aa5d4ec5..ffef4ff3 100644 --- a/examples/kernels/relu-metal-cpp/build.toml +++ b/examples/kernels/relu-metal-cpp/build.toml @@ -1,5 +1,7 @@ [general] name = "relu" +version = 1 +license = "Apache-2.0" backends = 
["metal"] [general.hub] @@ -11,13 +13,15 @@ src = [ "torch-ext/torch_binding.h", ] - [kernel.relu_metal] backend = "metal" +depends = [ + "torch", + "metal-cpp", +] src = [ - "relu/relu.cpp", - "relu/metallib_loader.mm", - "relu/relu_cpp.metal", - "relu/common.h", + "relu/relu.cpp", + "relu/metallib_loader.mm", + "relu/relu_cpp.metal", + "relu/common.h", ] -depends = [ "torch", "metal-cpp" ] diff --git a/examples/kernels/relu-nki/build.toml b/examples/kernels/relu-nki/build.toml index 86f2c106..7ff09f88 100644 --- a/examples/kernels/relu-nki/build.toml +++ b/examples/kernels/relu-nki/build.toml @@ -1,12 +1,15 @@ [general] name = "relu-nki" version = 1 -backends = [ - "neuron", -] +license = "Apache-2.0" +backends = ["neuron"] [general.hub] repo-id = "kernels-test/relu-nki" [general.neuron] python-depends = ["nki"] + +[torch-noarch] + +[kernel] diff --git a/examples/kernels/relu-specific-torch/build.toml b/examples/kernels/relu-specific-torch/build.toml index cb0a4c7b..0d17aad1 100644 --- a/examples/kernels/relu-specific-torch/build.toml +++ b/examples/kernels/relu-specific-torch/build.toml @@ -1,5 +1,7 @@ [general] name = "relu-specific-torch" +version = 1 +license = "Apache-2.0" backends = [ "cuda", "rocm", @@ -21,6 +23,7 @@ src = ["relu_cuda/relu.cu"] [kernel.relu_rocm] backend = "rocm" +depends = ["torch"] rocm-archs = [ "gfx906", "gfx908", @@ -32,5 +35,4 @@ rocm-archs = [ "gfx1100", "gfx1101", ] -depends = ["torch"] src = ["relu_cuda/relu.cu"] diff --git a/examples/kernels/relu-torch-bounds/build.toml b/examples/kernels/relu-torch-bounds/build.toml index 7d621e34..c054d276 100644 --- a/examples/kernels/relu-torch-bounds/build.toml +++ b/examples/kernels/relu-torch-bounds/build.toml @@ -1,5 +1,7 @@ [general] name = "relu" +version = 1 +license = "Apache-2.0" backends = [ "cuda", "rocm", diff --git a/examples/kernels/relu-tvm-ffi/build.toml b/examples/kernels/relu-tvm-ffi/build.toml index c93278f5..ee858e91 100644 --- a/examples/kernels/relu-tvm-ffi/build.toml 
+++ b/examples/kernels/relu-tvm-ffi/build.toml @@ -1,7 +1,7 @@ [general] name = "relu-tvm-ffi" version = 1 - +license = "Apache-2.0" backends = [ "cpu", "cuda", @@ -12,21 +12,28 @@ backends = [ repo-id = "kernels-test/relu-tvm-ffi" [tvm-ffi] +src = ["tvm-ffi-ext/tvm_ffi_binding.cpp"] + +[kernel.relu_cuda] +backend = "cuda" +depends = [] src = [ - "tvm-ffi-ext/tvm_ffi_binding.cpp", + "relu_cuda/relu.cu", + "util.hh", ] [kernel.relu_cpu] backend = "cpu" -src = ["relu_cpu/relu_cpu.cpp", "util.hh"] -depends = [] - -[kernel.relu_cuda] -backend = "cuda" -src = ["relu_cuda/relu.cu", "util.hh"] depends = [] +src = [ + "relu_cpu/relu_cpu.cpp", + "util.hh", +] [kernel.relu_xpu] backend = "xpu" depends = ["torch"] -src = ["relu_xpu/relu.cpp", "util.hh"] +src = [ + "relu_xpu/relu.cpp", + "util.hh", +] diff --git a/examples/kernels/relu/build.toml b/examples/kernels/relu/build.toml index 51d8a9be..e6741d3c 100644 --- a/examples/kernels/relu/build.toml +++ b/examples/kernels/relu/build.toml @@ -1,5 +1,7 @@ [general] name = "relu" +version = 1 +license = "Apache-2.0" backends = [ "cpu", "cuda", @@ -17,22 +19,23 @@ src = [ "torch-ext/torch_binding.h", ] -[kernel.relu] -backend = "cuda" +[kernel.relu_xpu] +backend = "xpu" depends = ["torch"] -src = ["relu_cuda/relu.cu"] +src = ["relu_xpu/relu.cpp"] [kernel.relu_metal] backend = "metal" +depends = ["torch"] src = [ - "relu_metal/relu.mm", - "relu_metal/relu.metal", - "relu_metal/common.h", + "relu_metal/relu.mm", + "relu_metal/relu.metal", + "relu_metal/common.h", ] -depends = [ "torch" ] [kernel.relu_rocm] backend = "rocm" +depends = ["torch"] rocm-archs = [ "gfx906", "gfx908", @@ -44,15 +47,14 @@ rocm-archs = [ "gfx1100", "gfx1101", ] -depends = ["torch"] src = ["relu_cuda/relu.cu"] -[kernel.relu_xpu] -backend = "xpu" -depends = ["torch"] -src = ["relu_xpu/relu.cpp"] - [kernel.relu_cpu] backend = "cpu" depends = ["torch"] src = ["relu_cpu/relu_cpu.cpp"] + +[kernel.relu] +backend = "cuda" +depends = ["torch"] +src = 
["relu_cuda/relu.cu"] diff --git a/examples/kernels/silu-and-mul/build.toml b/examples/kernels/silu-and-mul/build.toml index 49708d96..33df1c4b 100644 --- a/examples/kernels/silu-and-mul/build.toml +++ b/examples/kernels/silu-and-mul/build.toml @@ -1,5 +1,6 @@ [general] name = "silu-and-mul" +version = 1 license = "apache-2.0" backends = [ "cpu", @@ -10,4 +11,8 @@ backends = [ ] [general.hub] -repo-id = "kernels-tests/silu-and-mul" +repo-id = "kernels-test/silu-and-mul" + +[torch-noarch] + +[kernel] diff --git a/examples/kernels/silu-and-mul/tests/test_silu_and_mul.py b/examples/kernels/silu-and-mul/tests/test_silu_and_mul.py index d98cf408..adb19bc2 100644 --- a/examples/kernels/silu-and-mul/tests/test_silu_and_mul.py +++ b/examples/kernels/silu-and-mul/tests/test_silu_and_mul.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from torch.library import opcheck -from silu_and_mul import ops, silu_and_mul +from silu_and_mul import layers, ops, silu_and_mul def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor: @@ -44,3 +44,29 @@ def test_silu_and_mul(device, requires_grad, dtype): y_ref.backward(d_y) y.backward(d_y) torch.testing.assert_close(x_ref.grad, x.grad) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("requires_grad", [False, True]) +# Only do float32, the numerical instabilities of float16 and bfloat16 +# are too large with the different orderings of computing the gradients. 
+@pytest.mark.parametrize("dtype", [torch.float32]) +def test_silu_and_mul_layer(device, requires_grad, dtype): + torch.manual_seed(42) + x_ref = torch.randn( + 32, 128, device=device, requires_grad=requires_grad, dtype=dtype + ) + x = torch.empty(32, 128, device=device, requires_grad=requires_grad, dtype=dtype) + with torch.no_grad(): + x.copy_(x_ref) + + y_ref = silu_and_mul_ref(x_ref) + y = layers.SiluAndMul()(x) + + torch.testing.assert_close(y_ref, y) + + if requires_grad: + d_y = torch.randn((32, 64), device=device, dtype=dtype) + y_ref.backward(d_y) + y.backward(d_y) + torch.testing.assert_close(x_ref.grad, x.grad) diff --git a/examples/kernels/silu-and-mul/torch-ext/silu_and_mul/__init__.py b/examples/kernels/silu-and-mul/torch-ext/silu_and_mul/__init__.py index dd9ea8f7..a7102ebe 100644 --- a/examples/kernels/silu-and-mul/torch-ext/silu_and_mul/__init__.py +++ b/examples/kernels/silu-and-mul/torch-ext/silu_and_mul/__init__.py @@ -2,10 +2,11 @@ from ._ops import ops from .op import _silu_and_mul +from . import layers def silu_and_mul(x: torch.Tensor) -> torch.Tensor: return ops.silu_and_mul(x) -__all__ = ["silu_and_mul"] +__all__ = ["layers", "silu_and_mul"] diff --git a/examples/kernels/silu-and-mul/torch-ext/silu_and_mul/layers.py b/examples/kernels/silu-and-mul/torch-ext/silu_and_mul/layers.py new file mode 100644 index 00000000..996f6415 --- /dev/null +++ b/examples/kernels/silu-and-mul/torch-ext/silu_and_mul/layers.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn + +from ._ops import ops + + +class SiluAndMul(nn.Module): + """ + Apply SiLU to one half of the array and use it as a multiplicative + gate for the other half. 
+ + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + can_torch_compile: bool = True + + def forward(self, x: torch.Tensor): + return ops.silu_and_mul(x) diff --git a/kernel-builder/src/card.rs b/kernel-builder/src/card.rs index 0223b671..0fab8861 100644 --- a/kernel-builder/src/card.rs +++ b/kernel-builder/src/card.rs @@ -130,7 +130,7 @@ fn render_card(build: &Build, kernel_dir: &Path) -> Result { layers => layers, has_benchmark => has_benchmark, upstream => build.general.upstream.as_ref().map(|u| u.to_string()), - license => build.general.license.as_ref().map(|l| l.to_lowercase()), + license => build.general.license.to_lowercase(), }) .wrap_err("Cannot render card template") } diff --git a/kernel-builder/src/init.rs b/kernel-builder/src/init.rs index 16cefc20..8ce37f48 100644 --- a/kernel-builder/src/init.rs +++ b/kernel-builder/src/init.rs @@ -36,6 +36,10 @@ pub struct InitArgs { #[arg(value_name = "PATH")] pub path: Option, + /// The kernel's license. + #[arg(long, value_name = "LICENSE", default_value = "Apache-2.0")] + pub license: String, + /// Name of the kernel repo (e.g. `drbh/my-kernel`). 
#[arg(long, value_name = "OWNER/REPO")] pub name: Option, @@ -172,7 +176,7 @@ pub fn run_init(args: InitArgs) -> Result<()> { load_init_templates(&mut env); // Build FileSet in memory (atomic preparation) - let file_set = build_init_fileset(&env, &repo_info, &enabled_backends)?; + let file_set = build_init_fileset(&env, &repo_info, &args.license, &enabled_backends)?; // Atomic write - validates first, then writes all files file_set.write(&target_dir, args.overwrite)?; @@ -315,6 +319,7 @@ fn load_init_templates(env: &mut Environment) { fn build_init_fileset( env: &Environment, repo_info: &RepoInfo, + license: &str, enabled_backends: &[Backend], ) -> Result { let has_cpu = enabled_backends.contains(&Backend::Cpu); @@ -332,6 +337,7 @@ fn build_init_fileset( kernel_name => &repo_info.name, kernel_name_normalized => &repo_info.normalized_name, kernel_name_class => &repo_info.class_name, + license => license, repo_id => &repo_info.repo_id, backends => &backend_strings, has_cpu => has_cpu, diff --git a/kernel-builder/src/init/templates/build.toml b/kernel-builder/src/init/templates/build.toml index 739779be..fbb5cea7 100644 --- a/kernel-builder/src/init/templates/build.toml +++ b/kernel-builder/src/init/templates/build.toml @@ -1,10 +1,11 @@ [general] +name = "{{ kernel_name }}" +license = "{{ license }}" backends = [ {% for backend in backends %} "{{ backend }}", {% endfor %} ] -name = "{{ kernel_name }}" version = 1 [general.hub] diff --git a/kernel-builder/src/main.rs b/kernel-builder/src/main.rs index 440f3f5e..3e447590 100644 --- a/kernel-builder/src/main.rs +++ b/kernel-builder/src/main.rs @@ -31,7 +31,7 @@ use upload::{run_upload, RepoTypeArg, UploadArgs}; mod pyproject; use pyproject::{clean_pyproject, create_pyproject}; -use kernels_data::config::{v3, Build, BuildCompat}; +use kernels_data::config::{v4, Build, BuildCompat}; mod nix; @@ -413,15 +413,15 @@ fn update_build(kernel_dir: Option) -> Result<()> { let kernel_dir = check_or_infer_kernel_dir(kernel_dir)?; 
let build_compat: BuildCompat = parse_and_validate(&kernel_dir)?; - if matches!(build_compat, BuildCompat::V3(_)) { + if matches!(build_compat, BuildCompat::V4(_)) { return Ok(()); } let build: Build = build_compat .try_into() .context("Cannot update build configuration")?; - let v3_build: v3::Build = build.into(); - let pretty_toml = toml::to_string_pretty(&v3_build)?; + let v4_build: v4::Build = build.into(); + let pretty_toml = toml::to_string_pretty(&v4_build)?; let build_toml = kernel_dir.join("build.toml"); let mut writer = diff --git a/kernel-builder/src/pyproject/common.rs b/kernel-builder/src/pyproject/common.rs index c451f0c2..74a46235 100644 --- a/kernel-builder/src/pyproject/common.rs +++ b/kernel-builder/src/pyproject/common.rs @@ -39,7 +39,8 @@ pub fn write_metadata( .collect::>>()?; let metadata = Metadata { - id: Some(kernel_id.to_string_for_backend(*backend)), + id: kernel_id.to_string_for_backend(*backend), + name: general.name.clone(), version: general.version, license: general.license.clone(), upstream: general.upstream.clone(), diff --git a/kernel-builder/src/pyproject/mod.rs b/kernel-builder/src/pyproject/mod.rs index a6f26dd4..5582d5b6 100644 --- a/kernel-builder/src/pyproject/mod.rs +++ b/kernel-builder/src/pyproject/mod.rs @@ -22,7 +22,6 @@ mod torch; mod tvm_ffi; pub use fileset::FileSet; -pub use kernels_data::metadata::parse_metadata; pub fn create_pyproject_file_set(build: Build, kernel_id: &KernelIdentifier) -> Result { let mut env = Environment::new(); diff --git a/kernel-builder/src/upload.rs b/kernel-builder/src/upload.rs index 7b9be741..47dacb3e 100644 --- a/kernel-builder/src/upload.rs +++ b/kernel-builder/src/upload.rs @@ -1,6 +1,7 @@ use std::{ collections::{BTreeMap, HashSet}, - fs, + fs::{self, File}, + io::BufReader, path::{Path, PathBuf}, }; @@ -10,11 +11,11 @@ use huggingface_hub::{ AddSource, CommitOperation, CreateRepoParams, RepoCreateBranchParams, RepoCreateCommitParams, RepoListFilesParams, RepoListRefsParams, 
RepoType, }; +use kernels_data::metadata::Metadata; use walkdir::WalkDir; use crate::{ hf::{self, repo_handle}, - pyproject::parse_metadata, util::{check_or_infer_kernel_dir, discover_variants, parse_build}, }; @@ -399,29 +400,28 @@ fn discover_build_file( /// Determine the branch name (`v{version}`) from variant metadata. fn detect_branch_from_metadata(variants: &[PathBuf]) -> Result> { - let mut versions: HashSet> = HashSet::new(); + let mut versions: HashSet = HashSet::new(); for variant in variants { - let metadata = parse_metadata(variant.join("metadata.json"))?; + let metadata_path = variant.join("metadata.json"); + let metadata = Metadata::from_reader(BufReader::new(File::open(&metadata_path).context( + format!( + "Cannot read metadata from: {}", + metadata_path.to_string_lossy() + ), + )?))?; versions.insert(metadata.version); } if versions.len() > 1 { - let strs: Vec<_> = versions - .iter() - .map(|v| v.map_or("none".into(), |n| n.to_string())) - .collect(); + let strs: Vec<_> = versions.iter().map(ToString::to_string).collect(); bail!( "Found multiple versions in build variants: {}", strs.join(", ") ); } - Ok(versions - .into_iter() - .next() - .flatten() - .map(|v| format!("v{v}"))) + Ok(versions.into_iter().next().map(|v| format!("v{v}"))) } /// Recursively walk a directory and return all file paths. 
@@ -437,6 +437,7 @@ fn walk_files(dir: &Path) -> impl Iterator { mod tests { use super::*; + #[test] fn test_collect_readme_commit_ops() { let temp_dir = tempfile::tempdir().unwrap(); let kernel_dir = temp_dir.path(); diff --git a/kernel-builder/src/util.rs b/kernel-builder/src/util.rs index 1388649f..7049a19c 100644 --- a/kernel-builder/src/util.rs +++ b/kernel-builder/src/util.rs @@ -10,12 +10,6 @@ use kernels_data::config::{Build, BuildCompat}; pub(crate) fn parse_build(kernel_dir: impl AsRef) -> Result { let build_compat = parse_and_validate(kernel_dir)?; - if matches!(build_compat, BuildCompat::V1(_) | BuildCompat::V2(_)) { - eprintln!( - "build.toml is in the deprecated V1 or V2 format, use `kernel-builder update-build` to update." - ) - } - let build: Build = build_compat .try_into() .context("Cannot update build configuration")?; diff --git a/kernels-data/bindings/python/kernels_data.pyi b/kernels-data/bindings/python/kernels_data.pyi index d4d5ba95..3aef9502 100644 --- a/kernels-data/bindings/python/kernels_data.pyi +++ b/kernels-data/bindings/python/kernels_data.pyi @@ -104,7 +104,7 @@ class Metadata: """Parsed ``metadata.json`` for a kernel build variant.""" @staticmethod - def load(metadata_path: os.PathLike[str] | str) -> "Metadata": + def read_from_file(metadata_path: os.PathLike[str] | str) -> "Metadata": """Parse ``metadata.json`` at the given path. Raises: @@ -112,6 +112,10 @@ class Metadata: """ ... + @property + def id(self) -> str: ... + @property + def name(self) -> KernelName: ... @property def version(self) -> Optional[int]: ... 
@property diff --git a/kernels-data/bindings/python/src/lib.rs b/kernels-data/bindings/python/src/lib.rs index 87aea2a6..e5475d27 100644 --- a/kernels-data/bindings/python/src/lib.rs +++ b/kernels-data/bindings/python/src/lib.rs @@ -1,11 +1,13 @@ +use std::fs::File; +use std::io::BufReader; use std::path::PathBuf; use std::str::FromStr; use kernels_data::config::{Backend, KernelName}; -use kernels_data::metadata::{BackendInfo, Metadata, parse_metadata}; +use kernels_data::metadata::{BackendInfo, Metadata}; use kernels_data::version::Version; use pyo3::Bound as PyBound; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyOSError, PyValueError}; use pyo3::prelude::*; /// A dotted numeric version (e.g. `12.8.0`). Trailing zeros are stripped @@ -188,8 +190,10 @@ impl PyBackendInfo { #[pyclass(name = "Metadata", frozen)] #[derive(Clone, Debug)] struct PyMetadata { - version: Option, - license: Option, + id: String, + name: PyKernelName, + version: usize, + license: String, upstream: Option, python_depends: Vec, backend: PyBackendInfo, @@ -198,6 +202,8 @@ struct PyMetadata { impl From for PyMetadata { fn from(m: Metadata) -> Self { Self { + id: m.id, + name: PyKernelName { inner: m.name }, version: m.version, license: m.license, upstream: m.upstream.map(|u| u.to_string()), @@ -213,25 +219,42 @@ impl PyMetadata { /// /// Raises `ValueError` on any I/O or parse error. 
#[staticmethod] - fn load(metadata_path: PathBuf) -> PyResult { - parse_metadata(&metadata_path) + fn read_from_file(metadata_path: PathBuf) -> PyResult { + let f = File::open(&metadata_path).map_err(|err| { + PyOSError::new_err(format!("Failed to open `{metadata_path:?}`: {err:#}")) + })?; + Metadata::from_reader(BufReader::new(f)) .map(Into::into) - .map_err(|err| PyValueError::new_err(format!("{err:#}"))) + .map_err(|err| { + PyValueError::new_err(format!( + "Cannot parse metadata from `{metadata_path:?}`: {err:#}" + )) + }) } #[getter] - fn version(&self) -> Option { + fn id(&self) -> &str { + &self.id + } + + #[getter] + fn name(&self) -> PyKernelName { + self.name.clone() + } + + #[getter] + fn version(&self) -> usize { self.version } #[getter] - fn license(&self) -> Option<&String> { - self.license.as_ref() + fn license(&self) -> &str { + &self.license } #[getter] - fn upstream(&self) -> Option<&String> { - self.upstream.as_ref() + fn upstream(&self) -> Option<&str> { + self.upstream.as_deref() } #[getter] @@ -246,7 +269,9 @@ impl PyMetadata { fn __repr__(&self) -> String { format!( - "Metadata(version={:?}, license={:?}, upstream={:?}, python_depends={:?}, backend={})", + "Metadata(id={}, name={:?}, version={:?}, license={:?}, upstream={:?}, python_depends={:?}, backend={})", + self.id, + self.name, self.version, self.license, self.upstream, diff --git a/kernels-data/bindings/python/tests/test_kernels_data.py b/kernels-data/bindings/python/tests/test_kernels_data.py index 71a103f7..2d49cc0b 100644 --- a/kernels-data/bindings/python/tests/test_kernels_data.py +++ b/kernels-data/bindings/python/tests/test_kernels_data.py @@ -90,7 +90,9 @@ def test_metadata_load_full(tmp_path): path.write_text( json.dumps( { + "id": "_my_kernel_8a3be8f", "version": 1, + "name": "my-kernel", "license": "Apache-2.0", "upstream": "https://github.com/example/kernel", "python-depends": ["torch"], @@ -98,7 +100,9 @@ def test_metadata_load_full(tmp_path): } ) ) - m = 
Metadata.load(path) + m = Metadata.read_from_file(path) + assert m.id == "_my_kernel_8a3be8f" + assert m.name == KernelName("my-kernel") assert m.version == 1 assert m.license == "Apache-2.0" assert m.upstream == "https://github.com/example/kernel" @@ -109,10 +113,21 @@ def test_metadata_load_full(tmp_path): def test_metadata_load_minimal(tmp_path): path = tmp_path / "metadata.json" - path.write_text(json.dumps({"python-depends": [], "backend": {"type": "cpu"}})) - m = Metadata.load(path) - assert m.version is None - assert m.license is None + path.write_text( + json.dumps( + { + "id": "_my_kernel_8a3be8f", + "version": 1, + "name": "my-kernel", + "license": "Apache-2.0", + "python-depends": [], + "backend": {"type": "cpu"}, + } + ) + ) + m = Metadata.read_from_file(path) + assert m.version == 1 + assert m.license == "Apache-2.0" assert m.upstream is None assert m.python_depends == [] assert m.backend.backend_type == Backend.CPU @@ -120,8 +135,19 @@ def test_metadata_load_minimal(tmp_path): def test_metadata_load_cann(tmp_path): path = tmp_path / "metadata.json" - path.write_text(json.dumps({"python-depends": [], "backend": {"type": "cann"}})) - assert Metadata.load(path).backend.backend_type == Backend.CANN + path.write_text( + json.dumps( + { + "id": "_my_kernel_8a3be8f", + "version": 1, + "name": "my-kernel", + "license": "Apache-2.0", + "python-depends": [], + "backend": {"type": "cann"}, + } + ) + ) + assert Metadata.read_from_file(path).backend.backend_type == Backend.CANN def test_metadata_load_unknown_field_accepted(tmp_path): @@ -129,31 +155,42 @@ def test_metadata_load_unknown_field_accepted(tmp_path): path.write_text( json.dumps( { + "id": "_my_kernel_8a3be8f", + "version": 1, + "name": "my-kernel", + "license": "Apache-2.0", "python-depends": [], "backend": {"type": "cpu"}, "surprise": "not allowed", } ) ) - Metadata.load(path) + Metadata.read_from_file(path) def test_metadata_load_malformed(tmp_path): path = tmp_path / "metadata.json" 
path.write_text("{not json") with pytest.raises(ValueError): - Metadata.load(path) + Metadata.read_from_file(path) def test_metadata_load(tmp_path): path = _write_metadata( tmp_path / "variant" / "metadata.json", - **{"python-depends": ["torch"], "backend": {"type": "cuda"}}, + **{ + "id": "_my_kernel_8a3be8f", + "version": 1, + "name": "my-kernel", + "license": "Apache-2.0", + "python-depends": ["torch"], + "backend": {"type": "cuda"}, + }, ) - m = Metadata.load(path) + m = Metadata.read_from_file(path) assert m.backend.backend_type == Backend.CUDA def test_metadata_load_missing_file(tmp_path): - with pytest.raises(ValueError): - Metadata.load(tmp_path / "does-not-exist.json") + with pytest.raises(OSError): + Metadata.read_from_file(tmp_path / "does-not-exist.json") diff --git a/kernels-data/src/config/compat.rs b/kernels-data/src/config/compat.rs index 9e1696c1..1e6abb89 100644 --- a/kernels-data/src/config/compat.rs +++ b/kernels-data/src/config/compat.rs @@ -2,14 +2,15 @@ use eyre::Result; use serde::Deserialize; use serde_value::Value; -use super::{Build, v1, v2, v3}; +use crate::config::ConfigError; + +use super::{Build, v3, v4}; #[derive(Debug)] #[allow(clippy::large_enum_variant)] pub enum BuildCompat { - V1(v1::Build), - V2(v2::Build), V3(v3::Build), + V4(v4::Build), } impl<'de> Deserialize<'de> for BuildCompat { @@ -19,22 +20,20 @@ impl<'de> Deserialize<'de> for BuildCompat { { let value = Value::deserialize(deserializer)?; - v1::Build::deserialize(value.clone()) - .map(BuildCompat::V1) - .or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2)) - .or_else(|_| v3::Build::deserialize(value.clone()).map(BuildCompat::V3)) + v3::Build::deserialize(value.clone()) + .map(BuildCompat::V3) + .or_else(|_| v4::Build::deserialize(value.clone()).map(BuildCompat::V4)) .map_err(serde::de::Error::custom) } } impl TryFrom for Build { - type Error = eyre::Error; + type Error = ConfigError; - fn try_from(compat: BuildCompat) -> Result { + fn try_from(compat: 
BuildCompat) -> Result { match compat { - BuildCompat::V1(v1_build) => v1_build.try_into(), - BuildCompat::V2(v2_build) => v2_build.try_into(), - BuildCompat::V3(v3_build) => Ok(v3_build.into()), + BuildCompat::V3(v3_build) => v3_build.try_into(), + BuildCompat::V4(v4_build) => Ok(v4_build.into()), } } } diff --git a/kernels-data/src/config/mod.rs b/kernels-data/src/config/mod.rs index bd9c86a9..2aa497d3 100644 --- a/kernels-data/src/config/mod.rs +++ b/kernels-data/src/config/mod.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr}; use eyre::Result; use serde::{Deserialize, Serialize}; +use thiserror::Error; mod deps; pub use deps::{Dependency, PythonDependency}; @@ -12,9 +13,8 @@ pub use compat::BuildCompat; mod name; pub use name::KernelName; -pub mod v1; -pub mod v2; pub mod v3; +pub mod v4; use itertools::Itertools; @@ -28,7 +28,7 @@ pub struct Build { pub enum Framework { Torch(Torch), - TorchNoarch, + TorchNoarch(TorchNoarch), TvmFfi(TvmFfi), } @@ -64,10 +64,12 @@ impl Build { pub struct General { pub name: KernelName, - pub version: Option, + + /// Kernel API/ABI version. + pub version: usize, /// Hugging Face Hub license identifier. - pub license: Option, + pub license: String, /// Source repository or reference for the kernel code. 
pub upstream: Option, @@ -183,6 +185,7 @@ impl Torch { data_extensions(self.pyext.as_deref()) } } +pub struct TorchNoarch {} pub struct TvmFfi { pub include: Option>, @@ -355,3 +358,9 @@ impl FromStr for Backend { } } } + +#[derive(Debug, Error)] +pub enum ConfigError { + #[error("Cannot migrate configuration: {reason:?}")] + Migration { reason: String }, +} diff --git a/kernels-data/src/config/v3.rs b/kernels-data/src/config/v3.rs index 5a4c4be9..f12a3dac 100644 --- a/kernels-data/src/config/v3.rs +++ b/kernels-data/src/config/v3.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use serde::{Deserialize, Serialize}; use super::{Dependency, KernelName}; -use crate::version::Version; +use crate::{config::ConfigError, version::Version}; #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] @@ -158,8 +158,10 @@ pub enum Backend { Xpu, } -impl From for super::Build { - fn from(build: Build) -> Self { +impl TryFrom for super::Build { + type Error = ConfigError; + + fn try_from(build: Build) -> Result { let kernels: HashMap = build .kernels .into_iter() @@ -169,23 +171,29 @@ impl From for super::Build { let framework = match build.framework { Some(Framework::Torch(torch)) => super::Framework::Torch(torch.into()), Some(Framework::TvmFfi(tvm_ffi)) => super::Framework::TvmFfi(tvm_ffi.into()), - None => super::Framework::TorchNoarch, + None => super::Framework::TorchNoarch(super::TorchNoarch {}), }; - Self { - general: build.general.into(), + Ok(Self { + general: build.general.try_into()?, framework, kernels, - } + }) } } -impl From for super::General { - fn from(general: General) -> Self { - Self { +impl TryFrom for super::General { + type Error = ConfigError; + + fn try_from(general: General) -> Result { + let license = general.license.ok_or_else(|| ConfigError::Migration { + reason: "The `license` key is required in the `general` section".to_string(), + })?; + + Ok(Self { name: general.name, - version: general.version, - license: 
general.license, + version: general.version.unwrap_or(1), + license, upstream: general.upstream, backends: general.backends.into_iter().map(Into::into).collect(), cuda: general.cuda.map(Into::into), @@ -193,7 +201,7 @@ impl From for super::General { neuron: general.neuron.map(Into::into), python_depends: general.python_depends, xpu: general.xpu.map(Into::into), - } + }) } } @@ -341,185 +349,3 @@ impl From for super::Kernel { } } } - -impl From for Build { - fn from(build: super::Build) -> Self { - let framework = match build.framework { - super::Framework::Torch(torch) => Some(Framework::Torch(torch.into())), - super::Framework::TorchNoarch => None, - super::Framework::TvmFfi(tvm_ffi) => Some(Framework::TvmFfi(tvm_ffi.into())), - }; - - Self { - general: build.general.into(), - framework, - kernels: build - .kernels - .into_iter() - .map(|(k, v)| (k, v.into())) - .collect(), - } - } -} - -impl From for General { - fn from(general: super::General) -> Self { - Self { - name: general.name, - version: general.version, - license: general.license, - upstream: general.upstream, - backends: general.backends.into_iter().map(Into::into).collect(), - cuda: general.cuda.map(Into::into), - hub: general.hub.map(Into::into), - neuron: general.neuron.map(Into::into), - python_depends: general.python_depends, - xpu: general.xpu.map(Into::into), - } - } -} - -impl From for CudaGeneral { - fn from(cuda: super::CudaGeneral) -> Self { - Self { - minver: cuda.minver, - maxver: cuda.maxver, - python_depends: cuda.python_depends, - } - } -} - -impl From for NeuronGeneral { - fn from(neuron: super::NeuronGeneral) -> Self { - Self { - python_depends: neuron.python_depends, - } - } -} - -impl From for XpuGeneral { - fn from(xpu: super::XpuGeneral) -> Self { - Self { - python_depends: xpu.python_depends, - } - } -} - -impl From for Hub { - fn from(hub: super::Hub) -> Self { - Self { - repo_id: hub.repo_id, - branch: hub.branch, - } - } -} - -impl From for Torch { - fn from(torch: 
super::Torch) -> Self { - Self { - include: torch.include, - minver: torch.minver, - maxver: torch.maxver, - pyext: torch.pyext, - src: torch.src, - } - } -} - -impl From for TvmFfi { - fn from(tvm_ffi: super::TvmFfi) -> Self { - Self { - include: tvm_ffi.include, - pyext: tvm_ffi.pyext, - src: tvm_ffi.src, - } - } -} - -impl From for Backend { - fn from(backend: super::Backend) -> Self { - match backend { - super::Backend::Cann => Backend::Cann, - super::Backend::Cpu => Backend::Cpu, - super::Backend::Cuda => Backend::Cuda, - super::Backend::Metal => Backend::Metal, - super::Backend::Neuron => Backend::Neuron, - super::Backend::Rocm => Backend::Rocm, - super::Backend::Xpu => Backend::Xpu, - } - } -} - -impl From for Kernel { - fn from(kernel: super::Kernel) -> Self { - match kernel { - super::Kernel::Cpu { - cxx_flags, - depends, - include, - src, - } => Kernel::Cpu { - cxx_flags, - depends, - include, - src, - }, - super::Kernel::Cuda { - cuda_capabilities, - cuda_flags, - cuda_minver, - cxx_flags, - depends, - include, - src, - } => Kernel::Cuda { - cuda_capabilities, - cuda_flags, - cuda_minver, - cxx_flags, - depends, - include, - src, - }, - super::Kernel::Metal { - cxx_flags, - depends, - include, - src, - } => Kernel::Metal { - cxx_flags, - depends, - include, - src, - }, - super::Kernel::Rocm { - cxx_flags, - depends, - rocm_archs, - hip_flags, - include, - src, - } => Kernel::Rocm { - cxx_flags, - depends, - rocm_archs, - hip_flags, - include, - src, - }, - super::Kernel::Xpu { - cxx_flags, - depends, - sycl_flags, - include, - src, - } => Kernel::Xpu { - cxx_flags, - depends, - sycl_flags, - include, - src, - }, - } - } -} diff --git a/kernels-data/src/config/v4.rs b/kernels-data/src/config/v4.rs new file mode 100644 index 00000000..14386b35 --- /dev/null +++ b/kernels-data/src/config/v4.rs @@ -0,0 +1,549 @@ +use std::collections::HashMap; +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +use super::{Dependency, KernelName}; +use 
crate::version::Version; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct Build { + pub general: General, + + #[serde(flatten)] + pub framework: Framework, + + #[serde(rename = "kernel", default)] + pub kernels: HashMap, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum Framework { + Torch(Torch), + TorchNoarch(TorchNoarch), + TvmFfi(TvmFfi), +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct General { + pub name: KernelName, + + pub version: usize, + + pub license: String, + + pub upstream: Option, + + pub backends: Vec, + + pub cuda: Option, + + pub hub: Option, + + pub neuron: Option, + + pub python_depends: Option>, + + pub xpu: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct CudaGeneral { + pub minver: Option, + pub maxver: Option, + pub python_depends: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct NeuronGeneral { + pub python_depends: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct XpuGeneral { + pub python_depends: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub struct Hub { + pub repo_id: Option, + pub branch: Option, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Torch { + pub include: Option>, + pub minver: Option, + pub maxver: Option, + pub pyext: Option>, + + #[serde(default)] + pub src: Vec, +} + +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct TorchNoarch {} + +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct TvmFfi { + pub include: Option>, + 
pub pyext: Option>, + pub src: Vec, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")] +pub enum Kernel { + #[serde(rename_all = "kebab-case")] + Cpu { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Cuda { + cuda_capabilities: Option>, + cuda_flags: Option>, + cuda_minver: Option, + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Metal { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Rocm { + cxx_flags: Option>, + depends: Vec, + rocm_archs: Option>, + hip_flags: Option>, + include: Option>, + src: Vec, + }, + #[serde(rename_all = "kebab-case")] + Xpu { + cxx_flags: Option>, + depends: Vec, + sycl_flags: Option>, + include: Option>, + src: Vec, + }, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum Backend { + Cann, + Cpu, + Cuda, + Metal, + Neuron, + Rocm, + Xpu, +} + +impl From for super::Build { + fn from(build: Build) -> Self { + let kernels: HashMap = build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(); + + Self { + general: build.general.into(), + framework: build.framework.into(), + kernels, + } + } +} + +impl From for super::General { + fn from(general: General) -> Self { + Self { + name: general.name, + version: general.version, + license: general.license, + upstream: general.upstream, + backends: general.backends.into_iter().map(Into::into).collect(), + cuda: general.cuda.map(Into::into), + hub: general.hub.map(Into::into), + neuron: general.neuron.map(Into::into), + python_depends: general.python_depends, + xpu: general.xpu.map(Into::into), + } + } +} + +impl From for super::Framework { + fn from(framework: Framework) -> Self { + match framework { + 
Framework::Torch(torch) => super::Framework::Torch(torch.into()), + Framework::TorchNoarch(torch_noarch) => { + super::Framework::TorchNoarch(torch_noarch.into()) + } + Framework::TvmFfi(tvm_ffi) => super::Framework::TvmFfi(tvm_ffi.into()), + } + } +} + +impl From for super::CudaGeneral { + fn from(cuda: CudaGeneral) -> Self { + Self { + minver: cuda.minver, + maxver: cuda.maxver, + python_depends: cuda.python_depends, + } + } +} + +impl From for super::NeuronGeneral { + fn from(neuron: NeuronGeneral) -> Self { + Self { + python_depends: neuron.python_depends, + } + } +} + +impl From for super::XpuGeneral { + fn from(xpu: XpuGeneral) -> Self { + Self { + python_depends: xpu.python_depends, + } + } +} + +impl From for super::Hub { + fn from(hub: Hub) -> Self { + Self { + repo_id: hub.repo_id, + branch: hub.branch, + } + } +} + +impl From for super::Torch { + fn from(torch: Torch) -> Self { + Self { + include: torch.include, + minver: torch.minver, + maxver: torch.maxver, + pyext: torch.pyext, + src: torch.src, + } + } +} + +impl From for super::TorchNoarch { + fn from(_torch_noarch: TorchNoarch) -> Self { + Self {} + } +} + +impl From for super::TvmFfi { + fn from(tvm_ffi: TvmFfi) -> Self { + Self { + include: tvm_ffi.include, + pyext: tvm_ffi.pyext, + src: tvm_ffi.src, + } + } +} + +impl From for super::Backend { + fn from(backend: Backend) -> Self { + match backend { + Backend::Cann => super::Backend::Cann, + Backend::Cpu => super::Backend::Cpu, + Backend::Cuda => super::Backend::Cuda, + Backend::Metal => super::Backend::Metal, + Backend::Neuron => super::Backend::Neuron, + Backend::Rocm => super::Backend::Rocm, + Backend::Xpu => super::Backend::Xpu, + } + } +} + +impl From for super::Kernel { + fn from(kernel: Kernel) -> Self { + match kernel { + Kernel::Cpu { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cpu { + cxx_flags, + depends, + include, + src, + }, + Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + 
depends, + include, + src, + } => super::Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + }, + Kernel::Metal { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Metal { + cxx_flags, + depends, + include, + src, + }, + Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + } => super::Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + }, + Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + } => super::Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + }, + } + } +} + +impl From for Build { + fn from(build: super::Build) -> Self { + Self { + general: build.general.into(), + framework: build.framework.into(), + kernels: build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), + } + } +} + +impl From for General { + fn from(general: super::General) -> Self { + Self { + name: general.name, + version: general.version, + license: general.license, + upstream: general.upstream, + backends: general.backends.into_iter().map(Into::into).collect(), + cuda: general.cuda.map(Into::into), + hub: general.hub.map(Into::into), + neuron: general.neuron.map(Into::into), + python_depends: general.python_depends, + xpu: general.xpu.map(Into::into), + } + } +} + +impl From for Framework { + fn from(framework: super::Framework) -> Self { + match framework { + super::Framework::Torch(torch) => Framework::Torch(torch.into()), + super::Framework::TorchNoarch(torch_noarch) => { + Framework::TorchNoarch(torch_noarch.into()) + } + super::Framework::TvmFfi(tvm_ffi) => Framework::TvmFfi(tvm_ffi.into()), + } + } +} + +impl From for CudaGeneral { + fn from(cuda: super::CudaGeneral) -> Self { + Self { + minver: cuda.minver, + maxver: cuda.maxver, + python_depends: cuda.python_depends, + } + } +} + +impl From for NeuronGeneral { + fn from(neuron: super::NeuronGeneral) -> Self { + Self { + 
python_depends: neuron.python_depends, + } + } +} + +impl From for XpuGeneral { + fn from(xpu: super::XpuGeneral) -> Self { + Self { + python_depends: xpu.python_depends, + } + } +} + +impl From for Hub { + fn from(hub: super::Hub) -> Self { + Self { + repo_id: hub.repo_id, + branch: hub.branch, + } + } +} + +impl From for Torch { + fn from(torch: super::Torch) -> Self { + Self { + include: torch.include, + minver: torch.minver, + maxver: torch.maxver, + pyext: torch.pyext, + src: torch.src, + } + } +} +impl From for TorchNoarch { + fn from(_torch_noarch: super::TorchNoarch) -> Self { + Self {} + } +} + +impl From for TvmFfi { + fn from(tvm_ffi: super::TvmFfi) -> Self { + Self { + include: tvm_ffi.include, + pyext: tvm_ffi.pyext, + src: tvm_ffi.src, + } + } +} + +impl From for Backend { + fn from(backend: super::Backend) -> Self { + match backend { + super::Backend::Cann => Backend::Cann, + super::Backend::Cpu => Backend::Cpu, + super::Backend::Cuda => Backend::Cuda, + super::Backend::Metal => Backend::Metal, + super::Backend::Neuron => Backend::Neuron, + super::Backend::Rocm => Backend::Rocm, + super::Backend::Xpu => Backend::Xpu, + } + } +} + +impl From for Kernel { + fn from(kernel: super::Kernel) -> Self { + match kernel { + super::Kernel::Cpu { + cxx_flags, + depends, + include, + src, + } => Kernel::Cpu { + cxx_flags, + depends, + include, + src, + }, + super::Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + } => Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + }, + super::Kernel::Metal { + cxx_flags, + depends, + include, + src, + } => Kernel::Metal { + cxx_flags, + depends, + include, + src, + }, + super::Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + } => Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + }, + super::Kernel::Xpu { + cxx_flags, + depends, + 
sycl_flags, + include, + src, + } => Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + }, + } + } +} diff --git a/kernels-data/src/metadata.rs b/kernels-data/src/metadata.rs index 559d0e24..34d3d1cd 100644 --- a/kernels-data/src/metadata.rs +++ b/kernels-data/src/metadata.rs @@ -1,9 +1,9 @@ -use std::{fs, path::Path}; +use std::str::FromStr; -use eyre::{Context, Result}; +use eyre::Result; use serde::{Deserialize, Serialize}; -use crate::config::Backend; +use crate::config::{Backend, KernelName}; #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] @@ -18,20 +18,26 @@ pub struct BackendInfo { #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub struct Metadata { - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub version: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub license: Option, + pub name: KernelName, + pub id: String, + pub version: usize, + pub license: String, #[serde(skip_serializing_if = "Option::is_none")] pub upstream: Option, pub python_depends: Vec, pub backend: BackendInfo, } -pub fn parse_metadata(path: impl AsRef) -> Result { - let path = path.as_ref(); - let data = - fs::read_to_string(path).wrap_err_with(|| format!("Cannot read `{}`", path.display()))?; - serde_json::from_str(&data).wrap_err_with(|| format!("Cannot parse `{}`", path.display())) +impl Metadata { + pub fn from_reader(reader: R) -> Result { + Ok(serde_json::from_reader(reader)?) + } +} + +impl FromStr for Metadata { + type Err = eyre::Report; + + fn from_str(s: &str) -> Result { + Ok(serde_json::from_str(s)?) 
+ } } diff --git a/kernels/pyproject.toml b/kernels/pyproject.toml index 0a4f3acc..e43311b4 100644 --- a/kernels/pyproject.toml +++ b/kernels/pyproject.toml @@ -11,6 +11,7 @@ readme = "README.md" requires-python = ">= 3.10" dependencies = [ "huggingface-hub>=1.10.0", + "kernels-data>=0.14.0.dev0", "packaging>=20.0", "pyyaml>=6", "tomli>=2.0; python_version<'3.11'", diff --git a/kernels/src/kernels/__init__.py b/kernels/src/kernels/__init__.py index 17f7cef8..8ef7add4 100644 --- a/kernels/src/kernels/__init__.py +++ b/kernels/src/kernels/__init__.py @@ -2,6 +2,8 @@ __version__ = importlib.metadata.version("kernels") +from kernels_data import Metadata + from kernels._windows import _add_additional_dll_paths from kernels.benchmark import Benchmark from kernels.layer import ( @@ -22,6 +24,8 @@ use_kernel_mapping, ) from kernels.utils import ( + LoadedKernel, + RepoInfo, get_kernel, get_loaded_kernels, get_local_kernel, @@ -40,11 +44,14 @@ "Device", "FuncRepository", "LayerRepository", + "LoadedKernel", "LocalFuncRepository", "LocalLayerRepository", "LockedFuncRepository", "LockedLayerRepository", + "Metadata", "Mode", + "RepoInfo", "get_kernel", "get_loaded_kernels", "get_local_kernel", diff --git a/kernels/src/kernels/layer/func.py b/kernels/src/kernels/layer/func.py index a4b3e699..d98551e3 100644 --- a/kernels/src/kernels/layer/func.py +++ b/kernels/src/kernels/layer/func.py @@ -111,8 +111,6 @@ class LocalFuncRepository: Args: repo_path (`Path`): The local repository containing the layer. - package_name (`str`): - Package name of the kernel. func_name (`str`): The name of the function within the kernel repository. 
@@ -125,7 +123,6 @@ class LocalFuncRepository: # Reference a specific layer by revision layer_repo = LocalFuncRepository( repo_path=Path("/home/daniel/kernels/activation"), - package_name="activation", func_name="silu_and_mul", ) ``` @@ -135,15 +132,13 @@ def __init__( self, repo_path: Path, *, - package_name: str, func_name: str, ): self._repo_path = repo_path - self._package_name = package_name self.func_name = func_name def load(self) -> Type["nn.Module"]: - kernel = get_local_kernel(self._repo_path, self._package_name) + kernel = get_local_kernel(self._repo_path) return _get_kernel_func(self, kernel) def __eq__(self, other): @@ -151,14 +146,13 @@ def __eq__(self, other): isinstance(other, LocalFuncRepository) and self.func_name == other.func_name and self._repo_path == other._repo_path - and self._package_name == other._package_name ) def __hash__(self): - return hash((self.func_name, self._repo_path, self._package_name)) + return hash((self.func_name, self._repo_path)) def __str__(self) -> str: - return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.func_name}`" + return f"`{self._repo_path}` (layer `{self.func_name}`" def use_kernel_func_from_hub(func_name: str): diff --git a/kernels/src/kernels/layer/kernelize.py b/kernels/src/kernels/layer/kernelize.py index 60c9525d..39a13ce4 100644 --- a/kernels/src/kernels/layer/kernelize.py +++ b/kernels/src/kernels/layer/kernelize.py @@ -206,7 +206,7 @@ def kernelize( import torch import torch.nn as nn - from kernels import kernelize, Mode, register_kernel_mapping, LayerRepository + from kernels import kernelize, Mode, use_kernel_mapping, LayerRepository from kernels import use_kernel_forward_from_hub @use_kernel_forward_from_hub("SiluAndMul") @@ -220,10 +220,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: "cuda": LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", + version=1, ) } } - register_kernel_mapping(mapping) # Create and kernelize a model model = 
nn.Sequential( @@ -232,7 +232,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # Kernelize for inference - kernelized_model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) + with use_kernel_mapping(mapping): + kernelized_model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) ``` """ diff --git a/kernels/src/kernels/layer/layer.py b/kernels/src/kernels/layer/layer.py index dad80bd6..505aacb8 100644 --- a/kernels/src/kernels/layer/layer.py +++ b/kernels/src/kernels/layer/layer.py @@ -110,8 +110,6 @@ class LocalLayerRepository: Args: repo_path (`Path`): The local repository containing the layer. - package_name (`str`): - Package name of the kernel. layer_name (`str`): The name of the layer within the kernel repository. @@ -124,7 +122,6 @@ class LocalLayerRepository: # Reference a specific layer by revision layer_repo = LocalLayerRepository( repo_path=Path("/home/daniel/kernels/activation"), - package_name="activation", layer_name="SiluAndMul", ) ``` @@ -134,15 +131,13 @@ def __init__( self, repo_path: Path, *, - package_name: str, layer_name: str, ): self._repo_path = repo_path - self._package_name = package_name self.layer_name = layer_name def load(self) -> Type["nn.Module"]: - kernel = get_local_kernel(self._repo_path, self._package_name) + kernel = get_local_kernel(self._repo_path) return _get_kernel_layer(self, kernel) def __eq__(self, other): @@ -150,14 +145,13 @@ def __eq__(self, other): isinstance(other, LocalLayerRepository) and self.layer_name == other.layer_name and self._repo_path == other._repo_path - and self._package_name == other._package_name ) def __hash__(self): - return hash((self.layer_name, self._repo_path, self._package_name)) + return hash((self.layer_name, self._repo_path)) def __str__(self) -> str: - return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.layer_name}`" + return f"`{self._repo_path}` (layer `{self.layer_name}`" class LockedLayerRepository: diff --git 
a/kernels/src/kernels/metadata.py b/kernels/src/kernels/metadata.py deleted file mode 100644 index bd9538da..00000000 --- a/kernels/src/kernels/metadata.py +++ /dev/null @@ -1,44 +0,0 @@ -import json -import warnings -from dataclasses import dataclass -from pathlib import Path - -from huggingface_hub.dataclasses import strict - - -@strict -@dataclass -class Metadata: - id: str | None - python_depends: list[str] - version: int | None - - @staticmethod - def load_from_variant(variant_path: Path) -> "Metadata": - metadata_path = variant_path / "metadata.json" - if metadata_path.exists(): - with open(metadata_path, "r") as f: - metadata_dict = json.load(f) - if (kernel_id := metadata_dict.get("id", None)) is None: - warnings.warn( - f"Metadata for kernel loaded from `{variant_path}` does have an identifier," - " identifiers will become required in kernels >= 0.15\n" - "Run `nix flake update in your kernel directory and rebuild to generate metadata.", - UserWarning, - stacklevel=2, - ) - return Metadata( - id=kernel_id, - python_depends=metadata_dict.get("python-depends", []), - version=metadata_dict.get("version", None), - ) - - warnings.warn( - f"Kernel loaded from `{variant_path}` does not have metadata," - " metadata will be required in kernels >= 0.15\n" - "Run `nix flake update in your kernel directory and rebuild to generate metadata.", - UserWarning, - stacklevel=2, - ) - - return Metadata(id=None, version=None, python_depends=[]) diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index eea224e9..829c740f 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -1,4 +1,3 @@ -import ctypes import functools import hashlib import importlib @@ -14,6 +13,7 @@ from types import ModuleType from huggingface_hub import HfApi, constants +from kernels_data import Metadata from kernels._system import glibc_version from kernels._versions import select_revision_or_version @@ -21,7 +21,6 @@ from kernels.compat import has_torch, 
has_tvm_ffi from kernels.deps import validate_dependencies from kernels.lockfile import KernelLock, VariantLock -from kernels.metadata import Metadata from kernels.status import resolve_status from kernels.variants import ( Variant, @@ -34,18 +33,46 @@ @dataclass(frozen=True) -class RepoInfos: +class RepoInfo: + """ + This dataclass stores the origin of the kernel. + + The following fields are available: + + - `repo_id` (`str`): the Hub repository containing the kernel. + - `revision` (`str`): the specific revision of the kernel. + """ + repo_id: str revision: str - backend: str | None @dataclass(frozen=True) class LoadedKernel: - kernel_id: str + """ + This dataclass provides information about a loaded kernel: + + - `metadata` (`Metadata`): kernel metadata. + - `module` (`ModuleType`): the imported kernel module. + - `repo_info` (`kernels.utils.RepoInfo | None`): populated only for + kernels loaded via `get_kernel`. Loaders that work from a local path + (`get_local_kernel`) or a lockfile (`get_locked_kernel`, `load_kernel`) + leave this as `None`. + + The metadata includes the following properties that describe a kernel: + + - `id` (`str`): kernel identifier that is unique to the kernel version + backend. + - `name` (`str`): the name of the kernel. + - `version` (`int`): the version of the kernel. + - `license` (`str`): the license of the kernel. + - `upstream` (`str | None`): the upstream repository of the kernel. + - `python_depends` (`list[str]`): required Python dependencies. + - `backend`: information about the kernel's backend. + """ + + metadata: Metadata module: ModuleType - module_name: str - repo_infos: RepoInfos | None + repo_info: RepoInfo | None _loaded_kernels: dict[Path, LoadedKernel] = {} @@ -55,28 +82,10 @@ def get_loaded_kernels() -> list[LoadedKernel]: """ Return a snapshot of every kernel that has been loaded into the current process. 
- Each entry is a `kernels.utils.LoadedKernel` dataclass with fields: - - - `kernel_id` (`str`): unique identifier used as the `sys.modules` key - for this variant (either `metadata.id` or a hash-suffixed module name). - - `module` (`ModuleType`): the imported kernel module. - - `module_name` (`str`): the kernel's module name. - - `repo_infos` (`kernels.utils.RepoInfos | None`): populated only for - kernels loaded via `get_kernel`. Loaders that work from a local path - (`get_local_kernel`) or a lockfile (`get_locked_kernel`, `load_kernel`) - leave this as `None`. - - `RepoInfos` has `repo_id`, `revision`, and `backend` fields. `backend` - reflects the value passed by the caller — it is `None` when the caller - relied on backend auto-detection. - The returned list is a new list; mutating it does not affect the registry. - > [!NOTE] - > These arguments might be renamed / changed a bit. - Returns: - `list[LoadedKernel]`: one entry per distinct kernel variant path + `list[LoadedKernel]`: One [`LoadedKernel`] per distinct kernel variant path loaded in this process. 
Example: @@ -85,7 +94,7 @@ def get_loaded_kernels() -> list[LoadedKernel]: get_kernel("kernels-community/activation", version=1) for loaded in get_loaded_kernels(): - print(loaded.module_name, loaded.repo_infos) + print(loaded.metadata.name, loaded.repo_info) ``` """ return list(_loaded_kernels.values()) @@ -122,41 +131,33 @@ def _parse_local_kernel_overrides(local_kernels: str) -> dict[str, Path]: CACHE_DIR: str | None = _get_cache_dir() -def _import_from_path(module_name: str, variant_path: Path, _repo_infos: RepoInfos | None = None) -> ModuleType: +def _import_from_path(variant_path: Path, repo_info: RepoInfo | None = None) -> ModuleType: if (loaded_kernel := _loaded_kernels.get(variant_path)) is not None: return loaded_kernel.module - metadata = Metadata.load_from_variant(variant_path) + metadata = Metadata.read_from_file(variant_path / "metadata.json") + module_name = metadata.name.python_name validate_dependencies(module_name, metadata.python_depends, _backend()) file_path = variant_path / "__init__.py" if not file_path.exists(): file_path = variant_path / module_name / "__init__.py" + if not file_path.exists(): + raise FileNotFoundError(f"No kernel module found at: `{variant_path}`") - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- if metadata.id is None: - path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value) - kernel_id = f"{module_name}_{path_hash}" - else: - kernel_id = metadata.id - - spec = importlib.util.spec_from_file_location(kernel_id, file_path) + spec = importlib.util.spec_from_file_location(metadata.id, file_path) if spec is None: raise ImportError(f"Cannot load spec for {module_name} from {file_path}") module = importlib.util.module_from_spec(spec) if module is None: raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[kernel_id] = module + sys.modules[metadata.id] = module spec.loader.exec_module(module) # type: ignore _loaded_kernels[variant_path] = LoadedKernel( - kernel_id=kernel_id, + metadata=metadata, module=module, - module_name=module_name, - repo_infos=_repo_infos, + repo_info=repo_info, ) return module @@ -169,7 +170,7 @@ def install_kernel( backend: str | None = None, variant_locks: dict[str, VariantLock] | None = None, user_agent: str | dict | None = None, -) -> tuple[str, Path]: +) -> Path: """ Download a kernel for the current environment to the cache. @@ -191,15 +192,13 @@ def install_kernel( The `user_agent` info to pass to `snapshot_download()` for internal telemetry. Returns: - `tuple[str, Path]`: A tuple containing the package name and the path to the variant directory. + `Path`: The path to the variant directory. 
""" api = _get_hf_api(user_agent=user_agent) if not local_files_only: repo_id, revision = resolve_status(api, repo_id, revision) - package_name = package_name_from_repo_id(repo_id) - variants = get_variants(api, repo_id=repo_id, revision=revision) variant = resolve_variant(variants, backend) @@ -226,7 +225,6 @@ def install_kernel( try: return _find_kernel_in_repo_path( repo_path, - package_name, variant=variant, variant_locks=variant_locks, ) @@ -236,11 +234,10 @@ def install_kernel( def _find_kernel_in_repo_path( repo_path: Path, - package_name: str, *, variant: Variant, variant_locks: dict[str, VariantLock] | None = None, -) -> tuple[str, Path]: +) -> Path: variant_str = variant.variant_str variant_path = repo_path / "build" / variant_str if not variant_path.exists(): @@ -252,15 +249,7 @@ def _find_kernel_in_repo_path( raise ValueError(f"No lock found for build variant: {variant}") validate_kernel(repo_path=repo_path, variant=variant_str, hash=variant_lock.hash) - module_init_path = variant_path / "__init__.py" - if not os.path.exists(module_init_path): - # Compatibility with older kernels. 
- module_init_path = variant_path / package_name / "__init__.py" - - if not os.path.exists(module_init_path): - raise FileNotFoundError(f"No kernel module found at: `{variant_path}`") - - return package_name, variant_path + return variant_path def install_kernel_all_variants( @@ -340,26 +329,24 @@ def get_kernel( """ override = _get_local_kernel_overrides().get(repo_id, None) if override is not None: - return get_local_kernel(override, package_name_from_repo_id(repo_id)) + return get_local_kernel(override) revision = select_revision_or_version(repo_id, revision=revision, version=version) - repo_infos = RepoInfos( + repo_info = RepoInfo( repo_id=repo_id, revision=revision, - backend=backend, ) - package_name, variant_path = install_kernel( + variant_path = install_kernel( repo_id, - revision=revision, backend=backend, + revision=revision, user_agent=user_agent, ) - return _import_from_path(package_name, variant_path, _repo_infos=repo_infos) + return _import_from_path(variant_path, repo_info=repo_info) def get_local_kernel( repo_path: Path, - package_name: str, backend: str | None = None, ) -> ModuleType: """ @@ -368,8 +355,6 @@ def get_local_kernel( Args: repo_path (`Path`): The local path to the kernel repository. - package_name (`str`): - The name of the package to import from the repository. backend (`str`, *optional*): The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for. The backend will be detected automatically if not provided. @@ -382,15 +367,15 @@ def get_local_kernel( variant = resolve_variant(variants, backend) if variant is not None: - return _import_from_path(package_name, base_path / variant.variant_str) + return _import_from_path(base_path / variant.variant_str) # If we didn't find the package in the repo we may have a explicit # package path. 
variant_path = repo_path if variant_path.exists(): - return _import_from_path(package_name, variant_path) + return _import_from_path(variant_path) - raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}") + raise FileNotFoundError(f"Could not find kernel in {repo_path}") def has_kernel( @@ -417,7 +402,6 @@ def has_kernel( `bool`: `True` if a kernel is available for the current environment. """ revision = select_revision_or_version(repo_id, revision=revision, version=version) - package_name = package_name_from_repo_id(repo_id) api = _get_hf_api() variants = get_variants(api, repo_id=repo_id, revision=revision) @@ -426,16 +410,12 @@ def has_kernel( if variant is None: return False - for init_file in ["__init__.py", f"{package_name}/__init__.py"]: - if api.file_exists( - repo_id, - repo_type="kernel", - revision=revision, - filename=f"build/{variant.variant_str}/{init_file}", - ): - return True - - return False + return api.file_exists( + repo_id, + repo_type="kernel", + revision=revision, + filename=f"build/{variant.variant_str}/metadata.json", + ) def load_kernel( @@ -472,8 +452,6 @@ def load_kernel( f"Kernel `{repo_id}` is not locked. Please lock it with `kernels lock ` and then reinstall the project." 
) - package_name = package_name_from_repo_id(repo_id) - api = _get_hf_api() variants = get_variants(api, repo_id=repo_id, revision=locked_sha) variant = resolve_variant(variants, backend) @@ -498,13 +476,12 @@ def load_kernel( ) try: - package_name, variant_path = _find_kernel_in_repo_path( + variant_path = _find_kernel_in_repo_path( repo_path, - package_name, variant=variant, variant_locks=None, ) - return _import_from_path(package_name, variant_path) + return _import_from_path(variant_path) except FileNotFoundError: raise FileNotFoundError( f"Locked kernel `{repo_id}` does not have applicable variant or was not downloaded with `kernels download `" @@ -529,9 +506,9 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp if locked_sha is None: raise ValueError(f"Kernel `{repo_id}` is not locked") - package_name, variant_path = install_kernel(repo_id, revision=locked_sha, local_files_only=local_files_only) + variant_path = install_kernel(repo_id, revision=locked_sha, local_files_only=local_files_only) - return _import_from_path(package_name, variant_path) + return _import_from_path(variant_path) def _get_caller_locked_kernel(repo_id: str) -> str | None: @@ -631,10 +608,6 @@ def git_hash_object(data: bytes, object_type: str = "blob"): return m.digest() -def package_name_from_repo_id(repo_id: str) -> str: - return repo_id.split("/")[-1].replace("-", "_") - - def _platform() -> str: cpu = platform.machine() os = platform.system().lower() diff --git a/kernels/tests/kernel_locking/kernels.lock b/kernels/tests/kernel_locking/kernels.lock index 1244cc61..a56eec14 100644 --- a/kernels/tests/kernel_locking/kernels.lock +++ b/kernels/tests/kernel_locking/kernels.lock @@ -1,14 +1,14 @@ [ { "repo_id": "kernels-community/relu", - "sha": "c5c04da39d361647f23f4a2c85839354395265b3", + "sha": "468c458e7cfb386377d1f521bc2b8c77d1eb2562", "variants": { "torch-ext": { "hash": "sha256-fa95388531c6280130219f0e73e5daf521116da6c841fa5ab6a190c7994767d8", 
"hash_type": "git_lfs_concat" }, "torch210-cpu-aarch64-darwin": { - "hash": "sha256-5b3593363e1d23ba1eaada7bf3a2a2c0f569556b31f754568b0c22b7b6d50945", + "hash": "sha256-e0492802726713ff5ec3362350d95ddde207c39c16a2ebad4bfc42a51b15b593", "hash_type": "git_lfs_concat" }, "torch210-cu128-x86_64-windows": { @@ -16,111 +16,111 @@ "hash_type": "git_lfs_concat" }, "torch210-cxx11-cpu-aarch64-linux": { - "hash": "sha256-9719e00060fa402ac8ca6f45ca6cb26ca3c541006b093c31ae4df941a801195b", + "hash": "sha256-e2c191aa0bd6b761867f17c1264547b73f63bffda0f0e8131d3f253935e1c4a1", "hash_type": "git_lfs_concat" }, "torch210-cxx11-cpu-x86_64-linux": { - "hash": "sha256-51a0d5c8b4ab86c80f5dd5521018e5572aed8ce04e2e3dcc4a6e0bba095835dd", + "hash": "sha256-12818400aa094e343ed02d94d61243fa5a81e222b1c3ef81d5a1b46026914e4c", "hash_type": "git_lfs_concat" }, "torch210-cxx11-cu126-aarch64-linux": { - "hash": "sha256-aafe1cbbb8f18e53b93fcb17661e17437b637e690012f7dc52bafadcfb2459a7", + "hash": "sha256-6b20165a41b9cd4fc72722d66dd9f12d9bdfee1eb5c3059cd115b11af906759f", "hash_type": "git_lfs_concat" }, "torch210-cxx11-cu126-x86_64-linux": { - "hash": "sha256-dc7cb0ced47ab05b1fc79ef2942499ab99567acb349505e84815a6940e1ba261", + "hash": "sha256-bc8ec6e982d9fdff7366090a6623ae1ad38f6f35eaacd792e23f10aa9e9ab7b6", "hash_type": "git_lfs_concat" }, "torch210-cxx11-cu128-aarch64-linux": { - "hash": "sha256-6c19801c1292707dd4643a587cfb38458f125f4370a6ee10b324925b352f2c88", + "hash": "sha256-8cd1a1e36477a224b4de6441440fa3265672a7cfbb9289f57b11800ef3e809bd", "hash_type": "git_lfs_concat" }, "torch210-cxx11-cu128-x86_64-linux": { - "hash": "sha256-e65f32a8cfdd2576fb0e05d7d67c47d2457bf92277cdff9310c6e2b8a3da7e7c", + "hash": "sha256-c0be23298e650e705b5b7cb6db00fb65b8beeee39dcbc0718221f16802eccc07", "hash_type": "git_lfs_concat" }, "torch210-cxx11-cu130-aarch64-linux": { - "hash": "sha256-ca5d1a1b4981677fb39c22cc1c59362e2d295ca5dd4a167f1220efb95f4947f0", + "hash": 
"sha256-a1bc582a87834c6728c50f69f8a954d03e3dc73cce2f775d0bc6f1c42a1d40dd", "hash_type": "git_lfs_concat" }, "torch210-cxx11-cu130-x86_64-linux": { - "hash": "sha256-3a655b2cfc890df92909774afdeea57210c612a13daee5bd39f0fe21f2803e3e", + "hash": "sha256-cfbcb3fe7c3417c64502f84b97ddeb4d4f6c2e9bd1bc7d155accbf19682e5467", "hash_type": "git_lfs_concat" }, "torch210-cxx11-rocm70-x86_64-linux": { - "hash": "sha256-df7c10b302c8b3a3eec4181ebdd7afc8c9638ba6af8967e880ed88cb9625e2f9", + "hash": "sha256-ef3061f180a935be046caaedbe8eb3c9611e5a4228d88d1100f6357a3b5aa46c", "hash_type": "git_lfs_concat" }, "torch210-cxx11-rocm71-x86_64-linux": { - "hash": "sha256-80e9d8b738e4f373218ab4993bde6d2d7218769af0547519cebfcfeea7b7de6e", + "hash": "sha256-4ec58bc0dd36b98f1153ad2c870906995a5f13278273a5ce9eda861047380ec7", "hash_type": "git_lfs_concat" }, "torch210-cxx11-xpu20253-x86_64-linux": { - "hash": "sha256-ec52122ac5310193e6f0fc312497b4acc5b7bc14cd4d91f090b1f4ea3b6a278c", + "hash": "sha256-aa4eaa81b2f260f6e35d8e7ab4f1aca5cdb91bcb49ed8c8f641820919c8209ab", "hash_type": "git_lfs_concat" }, "torch210-metal-aarch64-darwin": { - "hash": "sha256-36761eb7b0bc39d8f0576a8c813baa83012aead788fae41d7273a34863a9d3da", + "hash": "sha256-424c70191e4808aec776987b3b5ce9ce184348513ce703e200593a71b40cbea3", "hash_type": "git_lfs_concat" }, "torch210-xpu20253-x86_64-windows": { - "hash": "sha256-b8fe104cac2ddb517a8618a952daa1b7390895e00f0a8ed80e2af4939a3f37f7", + "hash": "sha256-47f3e41d16251a45edaf5d80c8de531cf037bf7d64cb25c6f9632a33d33b669e", "hash_type": "git_lfs_concat" }, "torch211-cpu-aarch64-darwin": { - "hash": "sha256-16a94ae38dcf0d819e7d1e2d8d9d3fb6ad3ef18939160052466d54008b20086e", + "hash": "sha256-c6cda4b2f0f83f4a86939575e9dffcd74d6cdaa6b89da540e86e8d536156245b", "hash_type": "git_lfs_concat" }, "torch211-cu128-x86_64-windows": { - "hash": "sha256-8b6e99eb670c97610daf24c0e1ab2b6cbe3f2278ee02e2d0a9ed27d5946f70b9", + "hash": 
"sha256-f2367fb89a5995458ed0c5011c7ba69b9b0f2fc43eb434b5b217a7b5ee890783", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cpu-aarch64-linux": { - "hash": "sha256-e8bb563a19386ce4147d500b25af632d8e86addcbacf2074f4bc6051beee2b62", + "hash": "sha256-2ec67a9846bba2ac8a053b4f11d00e086315d4e028b94e5e8168ed00444ee5da", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cpu-x86_64-linux": { - "hash": "sha256-84b7da0d60a86c2548042fc33229fa4c4c038edb65cc2b5fba64ba8561631492", + "hash": "sha256-4f9c7dcb0e97ccb0507dd80881c7fc5446fb3626856a42f172c40c65557b064f", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cu126-aarch64-linux": { - "hash": "sha256-b6f243f6bf67053051508909bdb70fc32c49ebc6644f50d45fb0a990755df02e", + "hash": "sha256-7b9ec9ae8d9ceb5af1c9be1517cd0de989b1f1516bb6ff956e435224302f3cf6", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cu126-x86_64-linux": { - "hash": "sha256-885e5fe3b3f4d7718cfded3bca3e80e248fa0708e508026b838f702f60bb53b2", + "hash": "sha256-de3d3cbdb9817773637da80fab88870e7ad9ce7d7e40899b3a4e5cd7182dc645", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cu128-aarch64-linux": { - "hash": "sha256-a58b8fb0675f0fcd26f541b4afdb4a4cb214d7e7af725f8edf1dd7ac928172a0", + "hash": "sha256-559aeba2d57dd3125984c8dd1b1e5358b6807d9b8f1fdc165519e6f1c5e33984", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cu128-x86_64-linux": { - "hash": "sha256-11f3e44fb46edae700319c978fdc8055a69eb4db9b43a0cb3ebb76eb7b3d3791", + "hash": "sha256-f078aa837db738094599b395372dfb50982dfbeb2538b36f2f2d2b3eec127a51", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cu130-aarch64-linux": { - "hash": "sha256-599f61b3519ff7104f7120e92da02446b6e3f20418cb457a8bbdb3f5516d99ed", + "hash": "sha256-7ed13391cef18c36cc446d210d2fd8592993d45228a674182ee2cff375fa05ae", "hash_type": "git_lfs_concat" }, "torch211-cxx11-cu130-x86_64-linux": { - "hash": "sha256-6e061a814a6788af63203e34ff120d1dec7b14f3f9acb5341308b3cf80dbcb66", + "hash": 
"sha256-42019ed67eff74f5168a2e9d6bf276d047b29c11d5e3395a342893a3f50daae1", "hash_type": "git_lfs_concat" }, "torch211-cxx11-rocm71-x86_64-linux": { - "hash": "sha256-587f23dc8b41f4a64be22ed36da47bb581e8c222edb8c0dbe3cd89e94107d403", + "hash": "sha256-d82aff83823ac9ee81c78db67b860b40d9cbbd6b7efe8640693afdbe16f3e52c", "hash_type": "git_lfs_concat" }, "torch211-cxx11-rocm72-x86_64-linux": { - "hash": "sha256-36f47c6789c7070733490b2f52757978e1fc138852f1e41cb94099755b1df07b", + "hash": "sha256-772ec7f9b0917421f562d162b7db8535cac09b92cc2ce670ba467a08f1dfb9b0", "hash_type": "git_lfs_concat" }, "torch211-cxx11-xpu20253-x86_64-linux": { - "hash": "sha256-0bf2266e40b31dc00c7c47cb070d9e1e7ea93a3a379632c977331fae1605097a", + "hash": "sha256-4466ae2451d960a6c57a456d0c6a5f7397ccabcc6fb43fc3c0dcf26aecda5cfa", "hash_type": "git_lfs_concat" }, "torch211-metal-aarch64-darwin": { - "hash": "sha256-d129fab84a4acb7be9f5024b66d03f28e7fed84acdd4f00919263f5142a84313", + "hash": "sha256-b9bb76b444aae35782626327647f77f651f67ec666b58f1d3a8f2f251af88a50", "hash_type": "git_lfs_concat" }, "torch27-cxx11-cu118-x86_64-linux": { @@ -196,11 +196,11 @@ "hash_type": "git_lfs_concat" }, "torch29-cxx11-cu129-aarch64-linux": { - "hash": "sha256-bff0cab8ed842ffcd0bce140aa0c2b32fc782b6dd4f8d6c512bf8f3720d97e50", + "hash": "sha256-5cf7aba2ae1c07e0634c705c3598b646d690c876dde8ca0159c989486f4b8d7f", "hash_type": "git_lfs_concat" }, "torch29-cxx11-cu129-x86_64-linux": { - "hash": "sha256-b603cde7f38ec007c1b036f91a2e611c422d18cc848fd5932937f2565d4f71d1", + "hash": "sha256-06ac0e6e8b044de8741f38ea1b03a35eedbd83a272f19ffc13e33f45fa008e49", "hash_type": "git_lfs_concat" }, "torch29-cxx11-cu130-aarch64-linux": { @@ -231,26 +231,22 @@ }, { "repo_id": "kernels-test/versions", - "sha": "f97b8dea3cf2b9b7a38759bc74c6014a3d7f1c19", + "sha": "f609e51b856b3d874b0ae8445913e200f02c1735", "variants": { "torch-cpu": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash": 
"sha256-d9fced49c3beeb47ce5b996c28f760bec579cf444f9fe1040725686b13b1336e", "hash_type": "git_lfs_concat" }, "torch-cuda": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash": "sha256-b6b786e496802bc3d1ae59cb48acb419d99ca1ffbeee0fecbf16238e36c00842", "hash_type": "git_lfs_concat" }, "torch-rocm": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", - "hash_type": "git_lfs_concat" - }, - "torch-universal": { - "hash": "sha256-57de77a9bde54f52d0a67eb9e5d259d223cae66963f644ed2b7386f59f7e2d23", + "hash": "sha256-0ca53208b80507e0b592cbfb7c0c9b5f22a4def6413b809c551dabbea1e2ac86", "hash_type": "git_lfs_concat" }, "torch-xpu": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash": "sha256-09b9d84cb273d71a081daef373c42bec0691620591822e44a6811833f896b810", "hash_type": "git_lfs_concat" } } diff --git a/kernels/tests/layer_locking/kernels.lock b/kernels/tests/layer_locking/kernels.lock index 1796bd89..a102598c 100644 --- a/kernels/tests/layer_locking/kernels.lock +++ b/kernels/tests/layer_locking/kernels.lock @@ -1,26 +1,22 @@ [ { "repo_id": "kernels-test/versions", - "sha": "31a8142c476b8933320aae9a198945dcebe20f45", + "sha": "f609e51b856b3d874b0ae8445913e200f02c1735", "variants": { "torch-cpu": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash": "sha256-d9fced49c3beeb47ce5b996c28f760bec579cf444f9fe1040725686b13b1336e", "hash_type": "git_lfs_concat" }, "torch-cuda": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash": "sha256-b6b786e496802bc3d1ae59cb48acb419d99ca1ffbeee0fecbf16238e36c00842", "hash_type": "git_lfs_concat" }, "torch-rocm": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", - "hash_type": "git_lfs_concat" - }, - "torch-universal": { - "hash": "sha256-cf447f2d128fc60937aee4e2cfe2a915df0829b170a86e5760783726de86a731", + 
"hash": "sha256-0ca53208b80507e0b592cbfb7c0c9b5f22a4def6413b809c551dabbea1e2ac86", "hash_type": "git_lfs_concat" }, "torch-xpu": { - "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash": "sha256-09b9d84cb273d71a081daef373c42bec0691620591822e44a6811833f896b810", "hash_type": "git_lfs_concat" } } diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index 889889b6..c2978a04 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -10,21 +10,19 @@ @pytest.fixture def kernel(): - return get_kernel("kernels-community/activation") + return get_kernel("kernels-community/relu", version=1) @pytest.fixture def local_kernel_path(): - package_name, path = install_kernel("kernels-community/activation", revision="main") - # Path is the build variant path (build/torch-<...>), so the grandparent - # is the kernel repository path. - return package_name, path + # install_kernel works with resolved revisions, so explicitly use v1 here. 
+ return install_kernel("kernels-community/relu", revision="v1") @pytest.fixture def local_kernel(local_kernel_path): - package_name, path = local_kernel_path - return get_local_kernel(path.parent.parent, package_name) + path = local_kernel_path + return get_local_kernel(path.parent.parent) @pytest.fixture @@ -32,11 +30,6 @@ def metal_kernel(): return get_kernel("kernels-test/relu-metal") -@pytest.fixture -def universal_kernel(): - return get_kernel("kernels-community/triton-scaled-mm") - - @pytest.fixture def device(): if not torch.cuda.is_available(): @@ -45,35 +38,17 @@ def device(): @pytest.mark.cuda_only -def test_gelu_fast(kernel, device): - x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) - y = torch.empty_like(x) - - kernel.gelu_fast(y, x) - - expected = torch.tensor( - [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], - device=device, - dtype=torch.float16, - ) - - assert torch.allclose(y, expected) +def test_relu(kernel, device): + x = torch.arange(-4, 5, dtype=torch.float32, device=device).view(3, 3) + y = kernel.relu(x) + torch.testing.assert_close(y, F.relu(x)) @pytest.mark.cuda_only def test_local_kernel(local_kernel, device): - x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) - y = torch.empty_like(x) - - local_kernel.gelu_fast(y, x) - - expected = torch.tensor( - [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], - device=device, - dtype=torch.float16, - ) - - assert torch.allclose(y, expected) + x = torch.arange(-4, 5, dtype=torch.float32, device=device).view(3, 3) + y = local_kernel.relu(x) + torch.testing.assert_close(y, F.relu(x)) @pytest.mark.parametrize( @@ -86,22 +61,22 @@ def test_local_kernel(local_kernel, device): ) def test_local_kernel_path_types(repo_revision, device): repo_id, revision = repo_revision - package_name, path = install_kernel(repo_id, revision=revision) + path = install_kernel(repo_id, revision=revision) # Top-level repo 
path # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071 - kernel = get_local_kernel(path.parent.parent, package_name) + kernel = get_local_kernel(path.parent.parent) x = torch.arange(0, 32, dtype=torch.float16, device=device).view(2, 16) torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x)) # Build directory path # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build - kernel = get_local_kernel(path.parent.parent / "build", package_name) + kernel = get_local_kernel(path.parent.parent / "build") torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x)) # Explicit package path # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux - kernel = get_local_kernel(path, package_name) + kernel = get_local_kernel(path) torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x)) @@ -117,8 +92,8 @@ def test_relu_metal(metal_kernel, dtype): @pytest.mark.parametrize( "kernel_exists", [ - ("kernels-community/activation", "main", True), - ("kernels-community/triton-layer-norm", "main", True), + ("kernels-community/relu", "main", True), + ("kernels-test/silu-and-mul", "v1", True), # Repo only contains Torch 2.4 kernels (and we don't # support/test against this version). ("kernels-test/only-torch-2.4", "main", False), @@ -131,25 +106,6 @@ def test_has_kernel(kernel_exists): assert has_kernel(repo_id, revision=revision) == kernel -@pytest.mark.skip(reason="Tags are not supported on kernel repos") -def test_version_old(): - # Remove once we drop support for version specs. 
- kernel = get_kernel("kernels-test/versions") - assert kernel.version() == "0.2.0" - kernel = get_kernel("kernels-test/versions", version="<1.0.0") - assert kernel.version() == "0.2.0" - kernel = get_kernel("kernels-test/versions", version="<0.2.0") - assert kernel.version() == "0.1.1" - kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0") - assert kernel.version() == "0.1.1" - - with pytest.raises(ValueError, match=r"No version.*satisfies requirement"): - get_kernel("kernels-test/versions", version=">0.2.0") - - with pytest.raises(ValueError, match=r"Only one of"): - kernel = get_kernel("kernels-test/versions", revision="v0.1.0", version="<1.0.0") - - def test_version(): kernel = get_kernel("kernels-test/versions", version=1) assert kernel.version() == 1 @@ -187,26 +143,11 @@ def test_no_version_or_revision_warning(): get_kernel("kernels-test/versions") -@pytest.mark.cuda_only -def test_universal_kernel(universal_kernel): - torch.manual_seed(0) - A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") - B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda") - scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda") - scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda") - - out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16) - out_check = (A * scale_a) @ (B * scale_b) - out_check = out_check.to(torch.float16) - - torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1) - - def test_noarch_kernel(device): supported_devices = ["cpu", "cuda", "xpu"] if device not in supported_devices: pytest.skip(f"Device is not one of: {','.join(supported_devices)}") - get_kernel("kernels-test/silu-and-mul-noarch") + get_kernel("kernels-test/silu-and-mul", version=1) def test_get_kernel_with_backend(device): @@ -221,8 +162,6 @@ def test_get_kernel_with_backend(device): with pytest.raises(ValueError, match="Invalid backend 'xpu'"): get_kernel("kernels-community/relu", version=1, 
backend="xpu") - with pytest.raises(ValueError, match="Invalid backend 'xpu'"): - has_kernel("kernels-community/relu", version=1, backend="xpu") assert has_kernel("kernels-community/relu", version=1, backend="cpu") relu = get_kernel("kernels-community/relu", version=1, backend="cpu") @@ -247,24 +186,24 @@ def test_flattened_build(repo_revision, device): def test_local_overrides(monkeypatch, local_kernel_path): - package_name, kernel_path = local_kernel_path + kernel_path = local_kernel_path # Ensure that we are testing with a non-existing kernel, so that we know # that the kernel must be local. with pytest.raises(HfHubHTTPError): - get_kernel(f"kernels-test/{package_name}") + get_kernel("kernels-test/activation", revision="main") with monkeypatch.context() as m: m.setenv( "LOCAL_KERNELS", - f"kernels-test/{package_name}={str(kernel_path)}:kernels-test/non-existing2=/non/existing", + f"kernels-test/activation={str(kernel_path)}:kernels-test/non-existing2=/non/existing", ) get_kernel("kernels-test/activation") with monkeypatch.context() as m: m.setenv( "LOCAL_KERNELS", - f"kernels-test/non-existing2=/non/existing:kernels-test/{package_name}={str(kernel_path)}", + f"kernels-test/non-existing2=/non/existing:kernels-test/activation={str(kernel_path)}", ) get_kernel("kernels-test/activation") @@ -272,16 +211,16 @@ def test_local_overrides(monkeypatch, local_kernel_path): # Using a non-existing path should error. m.setenv( "LOCAL_KERNELS", - f"kernels-test/non-existing2=/non/existing:kernels-test/{package_name}=/non/existing", + "kernels-test/non-existing2=/non/existing:kernels-test/activation=/non/existing", ) - with pytest.raises(FileNotFoundError, match=r"Could not find.*activation"): + with pytest.raises(FileNotFoundError, match=r"Could not find kernel in /non/existing"): get_kernel("kernels-test/activation") with monkeypatch.context() as m: # Malformed entries must be rejected. 
m.setenv( "LOCAL_KERNELS", - f"kernels-test/non-existing2=/non/existing:kernels-test/{package_name}", + "kernels-test/non-existing2=/non/existing:kernels-test/activation", ) with pytest.raises(ValueError, match=r"Invalid LOCAL_KERNELS entry"): get_kernel("kernels-test/activation") diff --git a/kernels/tests/test_deps.py b/kernels/tests/test_deps.py index 9f2aa1d8..a7536b31 100644 --- a/kernels/tests/test_deps.py +++ b/kernels/tests/test_deps.py @@ -5,6 +5,7 @@ from kernels import get_kernel +@pytest.mark.cuda_only @pytest.mark.parametrize("dependency", ["einops", "nvidia-cutlass-dsl"]) def test_python_deps(dependency): must_raise = find_spec(dependency.replace("-", "_")) is None @@ -13,11 +14,11 @@ def test_python_deps(dependency): ImportError, match=r"Kernel module `python_dep` requires Python dependency `(einops|nvidia-cutlass-dsl)`", ): - get_kernel("kernels-test/python-dep") + get_kernel("kernels-test/python-dep", revision="main") else: - get_kernel("kernels-test/python-dep") + get_kernel("kernels-test/python-dep", revision="main") def test_illegal_dep(): with pytest.raises(ValueError, match=r"Kernel module `python_invalid_dep` uses.*kepler-22b"): - get_kernel("kernels-test/python-invalid-dep") + get_kernel("kernels-test/python-invalid-dep", revision="main") diff --git a/kernels/tests/test_func.py b/kernels/tests/test_func.py index 336eb99c..4b8c5d5e 100644 --- a/kernels/tests/test_func.py +++ b/kernels/tests/test_func.py @@ -78,8 +78,9 @@ def test_kernel_func_with_layer(): { "surprise_me": { "cuda": LayerRepository( - "kernels-community/activation", + "kernels-test/silu-and-mul", layer_name="SiluAndMul", + version=1, ) } } @@ -101,14 +102,13 @@ def test_local_kernel_func(device): x = torch.arange(-10, 10).float() assert model(x) is x - package_name, path = install_kernel("kernels-test/flattened-build", revision="main") + path = install_kernel("kernels-test/flattened-build", revision="main") with use_kernel_mapping( { "surprise_me": { device: 
LocalFuncRepository( repo_path=path.parent.parent, - package_name=package_name, func_name="silu_and_mul", ) } diff --git a/kernels/tests/test_kernel_locking.py b/kernels/tests/test_kernel_locking.py index 9e812913..9dc32ece 100644 --- a/kernels/tests/test_kernel_locking.py +++ b/kernels/tests/test_kernel_locking.py @@ -59,7 +59,7 @@ def forward(self) -> str: } ): version = kernelize(version, device=device, mode=Mode.INFERENCE) - assert version() == "2" + assert version() == 2 def test_func_locked(device): @@ -67,7 +67,7 @@ def test_func_locked(device): @use_kernel_func_from_hub("version") def version(): - return "0.0.0" + return 0 class Version(nn.Module): def __init__(self): @@ -92,9 +92,9 @@ def forward(self) -> str: ): model = kernelize(model, device=device, mode=Mode.INFERENCE) - assert version() == "2" + assert version() == 2 with use_kernel_mapping({"version": {}}): model = kernelize(model, mode=Mode.INFERENCE, device=device) - assert version() == "0.0.0" + assert version() == 0 diff --git a/kernels/tests/test_layer.py b/kernels/tests/test_layer.py index 9fde588f..696ad4d9 100644 --- a/kernels/tests/test_layer.py +++ b/kernels/tests/test_layer.py @@ -30,17 +30,16 @@ @pytest.fixture def local_kernel_path(): - package_name, path = install_kernel("kernels-community/activation", revision="main") - # Path is the build variant path (build/torch-<...>), so the grandparent - # is the kernel repository path. - return package_name, path + # install_kernel only works with resolved revisions. 
+ return install_kernel("kernels-test/silu-and-mul", revision="v1") kernel_layer_mapping = { "SiluAndMul": { Device(type="cuda"): LayerRepository( - repo_id="kernels-community/activation", + repo_id="kernels-test/silu-and-mul", layer_name="SiluAndMul", + version=1, ), "npu": LayerRepository( repo_id="kernels-ext-npu/SwiGlu", @@ -59,8 +58,9 @@ def local_kernel_path(): }, "SiluAndMulStringDevice": { "cuda": LayerRepository( - repo_id="kernels-community/activation", + repo_id="kernels-test/silu-and-mul", layer_name="SiluAndMul", + version=1, ) }, "LigerRMSNorm": { @@ -320,7 +320,7 @@ def test_rocm_kernel_mapping(device): kernel_layer_mapping = { "SiluAndMul": { "rocm": LayerRepository( - repo_id="kernels-community/activation", + repo_id="kernels-test/silu-and-mul", layer_name="SiluAndMul", ) } @@ -337,7 +337,7 @@ def test_rocm_kernel_mapping(device): # Verify the repository is correctly stored rocm_repos = mapping["SiluAndMul"]["rocm"] assert rocm_repos is not None - assert rocm_repos.repos[Mode.FALLBACK]._repo_id == "kernels-community/activation" + assert rocm_repos.repos[Mode.FALLBACK]._repo_id == "kernels-test/silu-and-mul" assert rocm_repos.repos[Mode.FALLBACK].layer_name == "SiluAndMul" @@ -398,7 +398,7 @@ class SiluAndMulWithKernelFallback(SiluAndMul): def test_local_layer_repo(device): # Fetch a kernel to the local cache. - package_name, path = install_kernel("kernels-test/backward-marker-test", revision="main") + path = install_kernel("kernels-test/backward-marker-test", revision="main") linear = TorchLinearWithCounter(32, 32).to(device) @@ -408,7 +408,6 @@ def test_local_layer_repo(device): device: LocalLayerRepository( # install_kernel will give the fully-resolved path. 
repo_path=path.parent.parent, - package_name=package_name, layer_name="LinearBackward", ) } @@ -489,7 +488,7 @@ def test_mapping_contexts(): extra_mapping1 = { "TestKernel": { Device(type="cuda"): LayerRepository( - repo_id="kernels-community/activation", + repo_id="kernels-test/silu-and-mul", layer_name="SiluAndMul", revision="layers", ) @@ -535,9 +534,7 @@ def test_mapping_contexts(): "LigerRMSNorm", "TestKernel", } - assert ( - _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id == "kernels-community/activation" - ) + assert _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id == "kernels-test/silu-and-mul" with use_kernel_mapping(extra_mapping2, inherit_mapping=False): assert set(_KERNEL_MAPPING.get().keys()) == { @@ -555,9 +552,7 @@ def test_mapping_contexts(): "LigerRMSNorm", "TestKernel", } - assert ( - _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id == "kernels-community/activation" - ) + assert _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id == "kernels-test/silu-and-mul" assert set(_KERNEL_MAPPING.get().keys()) == { "SiluAndMul", @@ -1092,8 +1087,8 @@ def test_kernel_modes_cross_fallback(): def test_layer_versions(device): @use_kernel_forward_from_hub("Version") class Version(nn.Module): - def forward(self) -> str: - return "0.0.0" + def forward(self) -> int: + return 0 version = Version() @@ -1172,12 +1167,10 @@ def test_local_overrides_layer(monkeypatch, local_kernel_path): # The primary validation is in the get_kernel tests. Here we just want # to ensure that the lookups also happen in layers. 
- package_name, kernel_path = local_kernel_path - mapping = { "SiluAndMul": { Device(type="cuda"): LayerRepository( - repo_id="kernels-test/activation", + repo_id="kernels-test/silu-and-mul", layer_name="SiluAndMul", ), }, @@ -1192,6 +1185,6 @@ def test_local_overrides_layer(monkeypatch, local_kernel_path): with monkeypatch.context() as m: m.setenv( "LOCAL_KERNELS", - f"kernels-test/{package_name}={str(kernel_path)}:kernels-test/non-existing2=/non/existing", + f"kernels-test/silu-and-mul={str(local_kernel_path)}:kernels-test/non-existing2=/non/existing", ) kernelize(model, device="cuda", mode=Mode.INFERENCE) diff --git a/kernels/tests/test_loaded_kernels.py b/kernels/tests/test_loaded_kernels.py index 5999d941..ac22363e 100644 --- a/kernels/tests/test_loaded_kernels.py +++ b/kernels/tests/test_loaded_kernels.py @@ -3,7 +3,7 @@ import pytest from kernels import get_kernel, get_loaded_kernels, get_local_kernel, install_kernel -from kernels.utils import LoadedKernel, RepoInfos, _loaded_kernels +from kernels.utils import LoadedKernel, RepoInfo, _loaded_kernels _REPO_ID = "kernels-community/relu" _PACKAGE_NAME = "relu" @@ -22,15 +22,13 @@ def fresh_registry(): def test_dataclass_shape(): assert tuple(f.name for f in fields(LoadedKernel)) == ( - "kernel_id", + "metadata", "module", - "module_name", - "repo_infos", + "repo_info", ) - assert tuple(f.name for f in fields(RepoInfos)) == ( + assert tuple(f.name for f in fields(RepoInfo)) == ( "repo_id", "revision", - "backend", ) @@ -56,11 +54,10 @@ def test_get_kernel_registers_loaded_kernel(fresh_registry): entry = loaded[0] assert entry.module is kernel - assert entry.module_name == _PACKAGE_NAME - assert entry.repo_infos is not None - assert entry.repo_infos.repo_id == _REPO_ID - assert isinstance(entry.repo_infos.revision, str) and entry.repo_infos.revision - assert entry.repo_infos.backend == "cpu" + assert entry.metadata.name.python_name == _PACKAGE_NAME + assert entry.repo_info is not None + assert 
entry.repo_info.repo_id == _REPO_ID + assert isinstance(entry.repo_info.revision, str) and entry.repo_info.revision def test_repeated_get_kernel_is_cached(fresh_registry): @@ -71,7 +68,7 @@ def test_repeated_get_kernel_is_cached(fresh_registry): assert len(get_loaded_kernels()) == 1 -def test_get_local_kernel_registers_with_null_repo_infos(fresh_registry): +def test_get_local_kernel_registers_with_null_repo_info(fresh_registry): # Populate the HF cache via get_kernel, grab the variant path it registered, # then clear the registry and exercise get_local_kernel against that path. get_kernel(_REPO_ID, version=_VERSION, backend="cpu") @@ -79,24 +76,23 @@ def test_get_local_kernel_registers_with_null_repo_infos(fresh_registry): _loaded_kernels.clear() - kernel = get_local_kernel(variant_path, _PACKAGE_NAME, backend="cpu") + kernel = get_local_kernel(variant_path, backend="cpu") loaded = get_loaded_kernels() assert len(loaded) == 1 entry = loaded[0] assert entry.module is kernel - assert entry.module_name == _PACKAGE_NAME - assert entry.repo_infos is None + assert entry.metadata.name.python_name == _PACKAGE_NAME + assert entry.repo_info is None -def test_install_kernel_plus_import_does_not_set_repo_infos(fresh_registry): +def test_install_kernel_plus_import_does_not_set_repo_info(fresh_registry): # install_kernel alone does not import; it returns a path. Any loader - # that does not go through get_kernel must leave repo_infos as None. - package_name, variant_path = install_kernel(_REPO_ID, revision="main", backend="cpu") - assert package_name == _PACKAGE_NAME + # that does not go through get_kernel must leave repo_info as None. 
+ variant_path = install_kernel(_REPO_ID, revision="main", backend="cpu") assert get_loaded_kernels() == [] - get_local_kernel(variant_path, package_name, backend="cpu") + get_local_kernel(variant_path, backend="cpu") (entry,) = get_loaded_kernels() - assert entry.repo_infos is None + assert entry.repo_info is None diff --git a/kernels/tests/test_tvm_ffi.py b/kernels/tests/test_tvm_ffi.py index 8f15fa8f..7d68fddf 100644 --- a/kernels/tests/test_tvm_ffi.py +++ b/kernels/tests/test_tvm_ffi.py @@ -29,8 +29,8 @@ def test_local_load(device): if device not in relu_supported_devices: pytest.skip(f"Device is not one of: {','.join(relu_supported_devices)}") - package_name, path = install_kernel("kernels-test/relu-tvm-ffi", revision="v1") - get_local_kernel(path.parent.parent, package_name) + path = install_kernel("kernels-test/relu-tvm-ffi", revision="v1") + get_local_kernel(path.parent.parent) @pytest.mark.cuda_only diff --git a/nix-builder/build-variants.json b/nix-builder/build-variants.json index e6d76346..f2a4dc1e 100644 --- a/nix-builder/build-variants.json +++ b/nix-builder/build-variants.json @@ -20,8 +20,7 @@ "torch210-cxx11-cu130-aarch64-linux", "torch211-cxx11-cu126-aarch64-linux", "torch211-cxx11-cu128-aarch64-linux", - "torch211-cxx11-cu130-aarch64-linux", - "torch29-cxx11-cu129-aarch64-linux" + "torch211-cxx11-cu130-aarch64-linux" ] }, "x86_64-linux": { @@ -35,8 +34,7 @@ "torch210-cxx11-cu130-x86_64-linux", "torch211-cxx11-cu126-x86_64-linux", "torch211-cxx11-cu128-x86_64-linux", - "torch211-cxx11-cu130-x86_64-linux", - "torch29-cxx11-cu129-x86_64-linux" + "torch211-cxx11-cu130-x86_64-linux" ], "rocm": [ "torch210-cxx11-rocm70-x86_64-linux", diff --git a/nix-builder/pkgs/get-kernel-check/get-kernel-check-hook.sh b/nix-builder/pkgs/get-kernel-check/get-kernel-check-hook.sh index 4a845ec9..23d4501a 100755 --- a/nix-builder/pkgs/get-kernel-check/get-kernel-check-hook.sh +++ b/nix-builder/pkgs/get-kernel-check/get-kernel-check-hook.sh @@ -26,7 +26,7 @@ 
_getKernelCheckHook() { trap "rm -rf '$HOME'" EXIT PYTHONPATH="@kernels@" \ - @python3@ -c "from pathlib import Path; import kernels; kernels.get_local_kernel(Path('${out}'), '${moduleName}')" + @python3@ -c "from pathlib import Path; import kernels; kernels.get_local_kernel(Path('${out}'))" } postInstallCheckHooks+=(_getKernelCheckHook) diff --git a/nix-builder/pkgs/python-modules/kernels/default.nix b/nix-builder/pkgs/python-modules/kernels/default.nix index 4268957f..0dfdd91b 100644 --- a/nix-builder/pkgs/python-modules/kernels/default.nix +++ b/nix-builder/pkgs/python-modules/kernels/default.nix @@ -5,6 +5,7 @@ huggingface-hub, kernel-abi-check, + kernels-data, pyyaml, tomlkit, torch, @@ -34,6 +35,7 @@ buildPythonPackage { dependencies = [ huggingface-hub kernel-abi-check + kernels-data pyyaml tomlkit torch diff --git a/nix-builder/versions.nix b/nix-builder/versions.nix index 78d3a5ad..e097d4e2 100644 --- a/nix-builder/versions.nix +++ b/nix-builder/versions.nix @@ -1,14 +1,4 @@ [ - { - torchVersion = "2.9"; - cudaVersion = "12.9"; - systems = [ - "x86_64-linux" - "aarch64-linux" - ]; - bundleBuild = true; - } - { torchVersion = "2.10"; cudaVersion = "12.6";