Skip to content

Commit f00aae9

Browse files
authored
get_kernel: support specifying the backend (#268)
* `get_kernel`: support specifying the backend

  So far, `get_kernel` has always auto-detected the backend from the Torch build. However, this causes issues in fixed-device settings where we may want to run some ops on the CPU and others on the GPU. This change allows specifying `backend` as an argument to `get_kernel` and friends.

  Currently, only the `cpu` backend and the backend that Torch was built for are supported (since we cannot determine the version, etc. of a backend that Torch was not built for). However, we decided to make it a string-based argument, as opposed to e.g. a `cpu` bool, in case Torch supports multi-backend builds in the future.

* Fix typing and incorrect `hip` backend name
* Test kernels package with Torch 2.9 and 2.10
* Do not attempt to lock Torch version
* Simplify `_select_backend`
* CI: try to fix
1 parent 7a7c381 commit f00aae9

6 files changed

Lines changed: 152 additions & 52 deletions

File tree

.github/workflows/test_kernels.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
max-parallel: 4
2525
matrix:
2626
python-version: ["3.10", "3.12"]
27-
torch-version: ["2.7.0", "2.8.0"]
27+
torch-version: ["2.9.0", "2.10.0"]
2828

2929
env:
3030
UV_PYTHON_PREFERENCE: only-managed
@@ -38,14 +38,14 @@ jobs:
3838
with:
3939
python-version: ${{ matrix.python-version }}
4040

41-
- name: Lock Torch version
42-
working-directory: ./kernels
43-
run: uv lock --upgrade-package "torch==${{ matrix.torch-version }}"
44-
4541
- name: Install the project
4642
working-directory: ./kernels
4743
run: uv sync --all-extras --dev
4844

45+
- name: Install Torch version
46+
working-directory: ./kernels
47+
run: uv pip install "torch==${{ matrix.torch-version }}"
48+
4949
- name: Install setuptools for Triton-based test
5050
working-directory: ./kernels
5151
run: uv pip install setuptools
@@ -76,7 +76,7 @@ jobs:
7676
HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
7777
run: |
7878
HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
79-
if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0'
79+
# Run the staging test for a single matrix combination only. The matrix key is
# `python-version` (hyphen); `matrix.python_version` does not exist and would
# make this condition always false, silently skipping the step.
if: matrix.python-version == '3.10' && matrix.torch-version == '2.9.0'
8080

8181
- name: Check README generation
8282
working-directory: ./kernels

kernels/src/kernels/cli/benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121

2222
from kernels.benchmark import Benchmark
23-
from kernels.utils import _get_hf_api, backend
23+
from kernels.utils import _get_hf_api, _backend
2424

2525
MISSING_DEPS: list[str] = []
2626

@@ -387,8 +387,8 @@ def collect_machine_info() -> MachineInfo:
387387

388388
if TORCH_AVAILABLE:
389389
pytorch_version = torch.__version__
390-
backend_name = backend()
391-
if backend_name in {"cuda", "hip"}:
390+
backend_name = _backend()
391+
if backend_name in {"cuda", "rocm"}:
392392
gpu = torch.cuda.get_device_name(0)
393393
# ROCm uses the CUDA API but has torch.version.hip
394394
if hasattr(torch.version, "hip") and torch.version.hip:
@@ -465,10 +465,10 @@ def run_benchmark_class(
465465
kernel = get_kernel(repo_id, revision=revision)
466466

467467
kernel_sha = get_kernel_sha_from_build_name(kernel)
468-
backend_name = backend() if TORCH_AVAILABLE else "cpu"
468+
backend_name = _backend() if TORCH_AVAILABLE else "cpu"
469469
# Map backend names to torch device names
470470
device_map = {
471-
"hip": "cuda",
471+
"rocm": "cuda",
472472
"metal": "mps",
473473
"cann": "npu",
474474
}

kernels/src/kernels/cli/versions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from huggingface_hub import HfApi
44

55
from kernels._versions import _get_available_versions
6-
from kernels.utils import _get_hf_api, build_variants
6+
from kernels.utils import _get_hf_api, _build_variants
77
from kernels.variants import BUILD_VARIANT_REGEX
88

99

@@ -14,7 +14,7 @@ def print_kernel_versions(repo_id: str):
1414
# Do not mark compatible variants when Torch is not available.
1515
compatible_variants = set()
1616
else:
17-
compatible_variants = set(build_variants())
17+
compatible_variants = set(_build_variants(None))
1818

1919
versions = _get_available_versions(repo_id).items()
2020
if not versions:

kernels/src/kernels/utils.py

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ def _get_privateuse_backend_name() -> str | None:
7171
return None
7272

7373

74-
def backend() -> str:
74+
def _backend() -> str:
7575
import torch
7676

7777
if torch.version.cuda is not None:
7878
return "cuda"
7979
elif torch.version.hip is not None:
80-
return "hip"
80+
return "rocm"
8181
elif torch.backends.mps.is_available():
8282
return "metal"
8383
elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
@@ -88,21 +88,23 @@ def backend() -> str:
8888
return "cpu"
8989

9090

91-
def build_variant() -> str:
91+
def _build_variant(backend: str | None) -> str:
92+
backend = _select_backend(backend)
93+
9294
import torch
9395

94-
if torch.version.cuda is not None:
96+
if backend == "cuda" and torch.version.cuda is not None:
9597
cuda_version = parse(torch.version.cuda)
9698
compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
97-
elif torch.version.hip is not None:
99+
elif backend == "rocm" and torch.version.hip is not None:
98100
rocm_version = parse(torch.version.hip.split("-")[0])
99101
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
100-
elif torch.backends.mps.is_available():
102+
elif backend == "metal":
101103
compute_framework = "metal"
102-
elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
104+
elif backend == "xpu" and torch.version.xpu is not None:
103105
version = torch.version.xpu
104106
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
105-
elif _get_privateuse_backend_name() == "npu":
107+
elif backend == "cann":
106108
from torch_npu.utils.collect_env import get_cann_version # type: ignore[import-not-found]
107109

108110
cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
@@ -125,36 +127,57 @@ def build_variant() -> str:
125127
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
126128

127129

128-
def build_variant_noarch() -> str:
129-
import torch
130+
def _supported_backends() -> set[str]:
131+
return {"cpu", _backend()}
130132

131-
if torch.version.cuda is not None:
133+
134+
def _select_backend(backend: str | None) -> str:
135+
if backend is None:
136+
return _backend()
137+
138+
supported = _supported_backends()
139+
if backend in supported:
140+
return backend
141+
142+
raise ValueError(
143+
f"Invalid backend '{backend}', system supported backends: {', '.join(sorted(supported))}"
144+
)
145+
146+
147+
def _build_variant_noarch(backend: str | None) -> str:
148+
backend = _select_backend(backend)
149+
150+
if backend == "cuda":
132151
return "torch-cuda"
133-
elif torch.version.hip is not None:
152+
elif backend == "rocm":
134153
return "torch-rocm"
135-
elif torch.backends.mps.is_available():
154+
elif backend == "metal":
136155
return "torch-metal"
137-
elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
156+
elif backend == "xpu":
138157
return "torch-xpu"
139-
elif _get_privateuse_backend_name() == "npu":
158+
elif backend == "cann":
140159
return "torch-npu"
141160
else:
142161
return "torch-cpu"
143162

144163

145-
def build_variant_universal() -> str:
164+
def _build_variant_universal() -> str:
146165
# Once we support other frameworks, detection goes here.
147166
return "torch-universal"
148167

149168

150-
def build_variants() -> list[str]:
169+
def _build_variants(backend: str | None) -> list[str]:
151170
"""Return compatible build variants in preferred order."""
152-
return [build_variant(), build_variant_noarch(), build_variant_universal()]
171+
return [
172+
_build_variant(backend),
173+
_build_variant_noarch(backend),
174+
_build_variant_universal(),
175+
]
153176

154177

155178
def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
156179
metadata = Metadata.load_from_variant(variant_path)
157-
validate_dependencies(metadata.python_depends, backend())
180+
validate_dependencies(metadata.python_depends, _backend())
158181

159182
file_path = variant_path / "__init__.py"
160183
if not file_path.exists():
@@ -181,6 +204,7 @@ def install_kernel(
181204
repo_id: str,
182205
revision: str,
183206
local_files_only: bool = False,
207+
backend: str | None = None,
184208
variant_locks: dict[str, VariantLock] | None = None,
185209
user_agent: str | dict | None = None,
186210
) -> tuple[str, Path]:
@@ -196,6 +220,9 @@ def install_kernel(
196220
The specific revision (branch, tag, or commit) to download.
197221
local_files_only (`bool`, *optional*, defaults to `False`):
198222
Whether to only use local files and not download from the Hub.
223+
backend (`str`, *optional*):
224+
The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
225+
The backend will be detected automatically if not provided.
199226
variant_locks (`dict[str, VariantLock]`, *optional*):
200227
Optional dictionary of variant locks for validation.
201228
user_agent (`Union[str, dict]`, *optional*):
@@ -205,7 +232,7 @@ def install_kernel(
205232
`tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
206233
"""
207234
package_name = package_name_from_repo_id(repo_id)
208-
allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
235+
allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)]
209236
api = _get_hf_api(user_agent=user_agent)
210237
repo_path = Path(
211238
str(
@@ -220,7 +247,9 @@ def install_kernel(
220247
)
221248

222249
try:
223-
return _find_kernel_in_repo_path(repo_path, package_name, variant_locks)
250+
return _find_kernel_in_repo_path(
251+
repo_path, package_name, backend=backend, variant_locks=variant_locks
252+
)
224253
except FileNotFoundError:
225254
raise FileNotFoundError(
226255
f"Cannot install kernel from repo {repo_id} (revision: {revision})"
@@ -230,9 +259,11 @@ def install_kernel(
230259
def _find_kernel_in_repo_path(
231260
repo_path: Path,
232261
package_name: str,
262+
*,
263+
backend: str | None = None,
233264
variant_locks: dict[str, VariantLock] | None = None,
234265
) -> tuple[str, Path]:
235-
variants = build_variants()
266+
variants = _build_variants(backend)
236267
variant = None
237268
variant_path = None
238269
for candidate_variant in variants:
@@ -303,6 +334,7 @@ def get_kernel(
303334
repo_id: str,
304335
revision: str | None = None,
305336
version: int | str | None = None,
337+
backend: str | None = None,
306338
user_agent: str | dict | None = None,
307339
) -> ModuleType:
308340
"""
@@ -319,6 +351,9 @@ def get_kernel(
319351
version (`int|str`, *optional*):
320352
The kernel version to download as an integer. The `str` variant is deprecated and will be
321353
removed in a future release. Cannot be used together with `revision`.
354+
backend (`str`, *optional*):
355+
The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
356+
The backend will be detected automatically if not provided.
322357
user_agent (`Union[str, dict]`, *optional*):
323358
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
324359
@@ -342,12 +377,16 @@ def get_kernel(
342377

343378
revision = select_revision_or_version(repo_id, revision=revision, version=version)
344379
package_name, variant_path = install_kernel(
345-
repo_id, revision=revision, user_agent=user_agent
380+
repo_id, revision=revision, backend=backend, user_agent=user_agent
346381
)
347382
return _import_from_path(package_name, variant_path)
348383

349384

350-
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
385+
def get_local_kernel(
386+
repo_path: Path,
387+
package_name: str,
388+
backend: str | None = None,
389+
) -> ModuleType:
351390
"""
352391
Import a kernel from a local kernel repository path.
353392
@@ -356,13 +395,16 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
356395
The local path to the kernel repository.
357396
package_name (`str`):
358397
The name of the package to import from the repository.
398+
backend (`str`, *optional*):
399+
The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
400+
The backend will be detected automatically if not provided.
359401
360402
Returns:
361403
`ModuleType`: The imported kernel module.
362404
"""
363405
# Presume we were given the top level path of the kernel repository.
364406
for base_path in [repo_path, repo_path / "build"]:
365-
for v in build_variants():
407+
for v in _build_variants(backend):
366408
variant_path = base_path / v
367409
if variant_path.exists():
368410
return _import_from_path(package_name, variant_path)
@@ -377,7 +419,10 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
377419

378420

379421
def has_kernel(
380-
repo_id: str, revision: str | None = None, version: int | str | None = None
422+
repo_id: str,
423+
revision: str | None = None,
424+
version: int | str | None = None,
425+
backend: str | None = None,
381426
) -> bool:
382427
"""
383428
Check whether a kernel build exists for the current environment (Torch version and compute framework).
@@ -390,17 +435,19 @@ def has_kernel(
390435
version (`int|str`, *optional*):
391436
The kernel version to download as an integer. The `str` variant is deprecated and will be
392437
removed in a future release. Cannot be used together with `revision`.
438+
backend (`str`, *optional*):
439+
The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
440+
The backend will be detected automatically if not provided.
393441
394442
Returns:
395443
`bool`: `True` if a kernel is available for the current environment.
396444
"""
397445
revision = select_revision_or_version(repo_id, revision=revision, version=version)
398446

399447
package_name = package_name_from_repo_id(repo_id)
400-
variant = build_variant()
401448

402449
api = _get_hf_api()
403-
for variant in build_variants():
450+
for variant in _build_variants(backend):
404451
for init_file in ["__init__.py", f"{package_name}/__init__.py"]:
405452
if api.file_exists(
406453
repo_id, revision=revision, filename=f"build/{variant}/{init_file}"
@@ -410,7 +457,12 @@ def has_kernel(
410457
return False
411458

412459

413-
def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
460+
def load_kernel(
461+
repo_id: str,
462+
*,
463+
lockfile: Path | None,
464+
backend: str | None = None,
465+
) -> ModuleType:
414466
"""
415467
Get a pre-downloaded, locked kernel.
416468
@@ -421,6 +473,9 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
421473
The Hub repository containing the kernel.
422474
lockfile (`Path`, *optional*):
423475
Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
476+
backend (`str`, *optional*):
477+
The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
478+
The backend will be detected automatically if not provided.
424479
425480
Returns:
426481
`ModuleType`: The imported kernel module.
@@ -439,7 +494,7 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
439494
package_name = package_name_from_repo_id(repo_id)
440495

441496
api = _get_hf_api()
442-
allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
497+
allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)]
443498
repo_path = Path(
444499
str(
445500
api.snapshot_download(
@@ -454,7 +509,7 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
454509

455510
try:
456511
package_name, variant_path = _find_kernel_in_repo_path(
457-
repo_path, package_name, variant_locks=None
512+
repo_path, package_name, backend=backend, variant_locks=None
458513
)
459514
return _import_from_path(package_name, variant_path)
460515
except FileNotFoundError:
@@ -605,7 +660,7 @@ def _get_hf_api(user_agent: str | dict | None = None) -> HfApi:
605660

606661
# System info
607662
python = ".".join(platform.python_version_tuple()[:2])
608-
user_agent_str += f"; kernels/{__version__}; python/{python}; torch/{torch.__version__}; build_variant/{build_variant()}; file_type/kernel"
663+
user_agent_str += f"; kernels/{__version__}; python/{python}; torch/{torch.__version__}; build_variant/{_build_variant(None)}; file_type/kernel"
609664

610665
# Add glibc version if available
611666
glibc = glibc_version()

0 commit comments

Comments
 (0)