@@ -71,13 +71,13 @@ def _get_privateuse_backend_name() -> str | None:
7171 return None
7272
7373
74- def backend () -> str :
74+ def _backend () -> str :
7575 import torch
7676
7777 if torch .version .cuda is not None :
7878 return "cuda"
7979 elif torch .version .hip is not None :
80- return "hip "
80+ return "rocm "
8181 elif torch .backends .mps .is_available ():
8282 return "metal"
8383 elif hasattr (torch .version , "xpu" ) and torch .version .xpu is not None :
@@ -88,21 +88,23 @@ def backend() -> str:
8888 return "cpu"
8989
9090
91- def build_variant () -> str :
91+ def _build_variant (backend : str | None ) -> str :
92+ backend = _select_backend (backend )
93+
9294 import torch
9395
94- if torch .version .cuda is not None :
96+ if backend == "cuda" and torch .version .cuda is not None :
9597 cuda_version = parse (torch .version .cuda )
9698 compute_framework = f"cu{ cuda_version .major } { cuda_version .minor } "
97- elif torch .version .hip is not None :
99+ elif backend == "rocm" and torch .version .hip is not None :
98100 rocm_version = parse (torch .version .hip .split ("-" )[0 ])
99101 compute_framework = f"rocm{ rocm_version .major } { rocm_version .minor } "
100- elif torch . backends . mps . is_available () :
102+ elif backend == "metal" :
101103 compute_framework = "metal"
102- elif hasattr ( torch . version , "xpu" ) and torch .version .xpu is not None :
104+ elif backend == "xpu" and torch .version .xpu is not None :
103105 version = torch .version .xpu
104106 compute_framework = f"xpu{ version [0 :4 ]} { version [5 :6 ]} "
105- elif _get_privateuse_backend_name () == "npu " :
107+ elif backend == "cann " :
106108 from torch_npu .utils .collect_env import get_cann_version # type: ignore[import-not-found]
107109
108110 cann_major , cann_minor = get_cann_version ()[0 ], get_cann_version ()[2 ]
@@ -125,36 +127,57 @@ def build_variant() -> str:
125127 return f"torch{ torch_version .major } { torch_version .minor } -{ cxxabi } -{ compute_framework } -{ cpu } -{ os } "
126128
127129
128- def build_variant_noarch () -> str :
129- import torch
130+ def _supported_backends () -> set [ str ] :
131+ return { "cpu" , _backend ()}
130132
131- if torch .version .cuda is not None :
133+
134+ def _select_backend (backend : str | None ) -> str :
135+ if backend is None :
136+ return _backend ()
137+
138+ supported = _supported_backends ()
139+ if backend in supported :
140+ return backend
141+
142+ raise ValueError (
143+ f"Invalid backend '{ backend } ', system supported backends: { ', ' .join (sorted (supported ))} "
144+ )
145+
146+
147+ def _build_variant_noarch (backend : str | None ) -> str :
148+ backend = _select_backend (backend )
149+
150+ if backend == "cuda" :
132151 return "torch-cuda"
133- elif torch . version . hip is not None :
152+ elif backend == "rocm" :
134153 return "torch-rocm"
135- elif torch . backends . mps . is_available () :
154+ elif backend == "metal" :
136155 return "torch-metal"
137- elif hasattr ( torch . version , "xpu" ) and torch . version . xpu is not None :
156+ elif backend == "xpu" :
138157 return "torch-xpu"
139- elif _get_privateuse_backend_name () == "npu " :
158+ elif backend == "cann " :
140159 return "torch-npu"
141160 else :
142161 return "torch-cpu"
143162
144163
145- def build_variant_universal () -> str :
164+ def _build_variant_universal () -> str :
146165 # Once we support other frameworks, detection goes here.
147166 return "torch-universal"
148167
149168
150- def build_variants ( ) -> list [str ]:
169+ def _build_variants ( backend : str | None ) -> list [str ]:
151170 """Return compatible build variants in preferred order."""
152- return [build_variant (), build_variant_noarch (), build_variant_universal ()]
171+ return [
172+ _build_variant (backend ),
173+ _build_variant_noarch (backend ),
174+ _build_variant_universal (),
175+ ]
153176
154177
155178def _import_from_path (module_name : str , variant_path : Path ) -> ModuleType :
156179 metadata = Metadata .load_from_variant (variant_path )
157- validate_dependencies (metadata .python_depends , backend ())
180+ validate_dependencies (metadata .python_depends , _backend ())
158181
159182 file_path = variant_path / "__init__.py"
160183 if not file_path .exists ():
@@ -181,6 +204,7 @@ def install_kernel(
181204 repo_id : str ,
182205 revision : str ,
183206 local_files_only : bool = False ,
207+ backend : str | None = None ,
184208 variant_locks : dict [str , VariantLock ] | None = None ,
185209 user_agent : str | dict | None = None ,
186210) -> tuple [str , Path ]:
@@ -196,6 +220,9 @@ def install_kernel(
196220 The specific revision (branch, tag, or commit) to download.
197221 local_files_only (`bool`, *optional*, defaults to `False`):
198222 Whether to only use local files and not download from the Hub.
223+ backend (`str`, *optional*):
224+ The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
225+ The backend will be detected automatically if not provided.
199226 variant_locks (`dict[str, VariantLock]`, *optional*):
200227 Optional dictionary of variant locks for validation.
201228 user_agent (`Union[str, dict]`, *optional*):
@@ -205,7 +232,7 @@ def install_kernel(
205232 `tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
206233 """
207234 package_name = package_name_from_repo_id (repo_id )
208- allow_patterns = [f"build/{ variant } /*" for variant in build_variants ( )]
235+ allow_patterns = [f"build/{ variant } /*" for variant in _build_variants ( backend )]
209236 api = _get_hf_api (user_agent = user_agent )
210237 repo_path = Path (
211238 str (
@@ -220,7 +247,9 @@ def install_kernel(
220247 )
221248
222249 try :
223- return _find_kernel_in_repo_path (repo_path , package_name , variant_locks )
250+ return _find_kernel_in_repo_path (
251+ repo_path , package_name , backend = backend , variant_locks = variant_locks
252+ )
224253 except FileNotFoundError :
225254 raise FileNotFoundError (
226255 f"Cannot install kernel from repo { repo_id } (revision: { revision } )"
@@ -230,9 +259,11 @@ def install_kernel(
230259def _find_kernel_in_repo_path (
231260 repo_path : Path ,
232261 package_name : str ,
262+ * ,
263+ backend : str | None = None ,
233264 variant_locks : dict [str , VariantLock ] | None = None ,
234265) -> tuple [str , Path ]:
235- variants = build_variants ( )
266+ variants = _build_variants ( backend )
236267 variant = None
237268 variant_path = None
238269 for candidate_variant in variants :
@@ -303,6 +334,7 @@ def get_kernel(
303334 repo_id : str ,
304335 revision : str | None = None ,
305336 version : int | str | None = None ,
337+ backend : str | None = None ,
306338 user_agent : str | dict | None = None ,
307339) -> ModuleType :
308340 """
@@ -319,6 +351,9 @@ def get_kernel(
319351 version (`int|str`, *optional*):
320352 The kernel version to download as an integer. The `str` variant is deprecated and will be
321353 removed in a future release. Cannot be used together with `revision`.
354+ backend (`str`, *optional*):
355+ The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
356+ The backend will be detected automatically if not provided.
322357 user_agent (`Union[str, dict]`, *optional*):
323358 The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
324359
@@ -342,12 +377,16 @@ def get_kernel(
342377
343378 revision = select_revision_or_version (repo_id , revision = revision , version = version )
344379 package_name , variant_path = install_kernel (
345- repo_id , revision = revision , user_agent = user_agent
380+ repo_id , revision = revision , backend = backend , user_agent = user_agent
346381 )
347382 return _import_from_path (package_name , variant_path )
348383
349384
350- def get_local_kernel (repo_path : Path , package_name : str ) -> ModuleType :
385+ def get_local_kernel (
386+ repo_path : Path ,
387+ package_name : str ,
388+ backend : str | None = None ,
389+ ) -> ModuleType :
351390 """
352391 Import a kernel from a local kernel repository path.
353392
@@ -356,13 +395,16 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
356395 The local path to the kernel repository.
357396 package_name (`str`):
358397 The name of the package to import from the repository.
398+ backend (`str`, *optional*):
399+ The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
400+ The backend will be detected automatically if not provided.
359401
360402 Returns:
361403 `ModuleType`: The imported kernel module.
362404 """
363405 # Presume we were given the top level path of the kernel repository.
364406 for base_path in [repo_path , repo_path / "build" ]:
365- for v in build_variants ( ):
407+ for v in _build_variants ( backend ):
366408 variant_path = base_path / v
367409 if variant_path .exists ():
368410 return _import_from_path (package_name , variant_path )
@@ -377,7 +419,10 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
377419
378420
379421def has_kernel (
380- repo_id : str , revision : str | None = None , version : int | str | None = None
422+ repo_id : str ,
423+ revision : str | None = None ,
424+ version : int | str | None = None ,
425+ backend : str | None = None ,
381426) -> bool :
382427 """
383428 Check whether a kernel build exists for the current environment (Torch version and compute framework).
@@ -390,17 +435,19 @@ def has_kernel(
390435 version (`int|str`, *optional*):
391436 The kernel version to download as an integer. The `str` variant is deprecated and will be
392437 removed in a future release. Cannot be used together with `revision`.
438+ backend (`str`, *optional*):
439+ The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
440+ The backend will be detected automatically if not provided.
393441
394442 Returns:
395443 `bool`: `True` if a kernel is available for the current environment.
396444 """
397445 revision = select_revision_or_version (repo_id , revision = revision , version = version )
398446
399447 package_name = package_name_from_repo_id (repo_id )
400- variant = build_variant ()
401448
402449 api = _get_hf_api ()
403- for variant in build_variants ( ):
450+ for variant in _build_variants ( backend ):
404451 for init_file in ["__init__.py" , f"{ package_name } /__init__.py" ]:
405452 if api .file_exists (
406453 repo_id , revision = revision , filename = f"build/{ variant } /{ init_file } "
@@ -410,7 +457,12 @@ def has_kernel(
410457 return False
411458
412459
413- def load_kernel (repo_id : str , * , lockfile : Path | None ) -> ModuleType :
460+ def load_kernel (
461+ repo_id : str ,
462+ * ,
463+ lockfile : Path | None ,
464+ backend : str | None = None ,
465+ ) -> ModuleType :
414466 """
415467 Get a pre-downloaded, locked kernel.
416468
@@ -421,6 +473,9 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
421473 The Hub repository containing the kernel.
422474 lockfile (`Path`, *optional*):
423475 Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
476+ backend (`str`, *optional*):
477+ The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
478+ The backend will be detected automatically if not provided.
424479
425480 Returns:
426481 `ModuleType`: The imported kernel module.
@@ -439,7 +494,7 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
439494 package_name = package_name_from_repo_id (repo_id )
440495
441496 api = _get_hf_api ()
442- allow_patterns = [f"build/{ variant } /*" for variant in build_variants ( )]
497+ allow_patterns = [f"build/{ variant } /*" for variant in _build_variants ( backend )]
443498 repo_path = Path (
444499 str (
445500 api .snapshot_download (
@@ -454,7 +509,7 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
454509
455510 try :
456511 package_name , variant_path = _find_kernel_in_repo_path (
457- repo_path , package_name , variant_locks = None
512+ repo_path , package_name , backend = backend , variant_locks = None
458513 )
459514 return _import_from_path (package_name , variant_path )
460515 except FileNotFoundError :
@@ -605,7 +660,7 @@ def _get_hf_api(user_agent: str | dict | None = None) -> HfApi:
605660
606661 # System info
607662 python = "." .join (platform .python_version_tuple ()[:2 ])
608- user_agent_str += f"; kernels/{ __version__ } ; python/{ python } ; torch/{ torch .__version__ } ; build_variant/{ build_variant ( )} ; file_type/kernel"
663+ user_agent_str += f"; kernels/{ __version__ } ; python/{ python } ; torch/{ torch .__version__ } ; build_variant/{ _build_variant ( None )} ; file_type/kernel"
609664
610665 # Add glibc version if available
611666 glibc = glibc_version ()
0 commit comments