1212from importlib .metadata import Distribution
1313from pathlib import Path
1414from types import ModuleType
15+ from typing import NamedTuple
1516
1617from huggingface_hub import HfApi , constants
1718
3334KNOWN_BACKENDS = {"cpu" , "cuda" , "metal" , "neuron" , "rocm" , "xpu" , "npu" }
3435
3536
37+ class RepoInfos (NamedTuple ):
38+ repo_id : str
39+ revision : str
40+ backend : str | None
41+
42+
43+ class LoadedKernel (NamedTuple ):
44+ module : ModuleType
45+ package_name : str
46+ repo_infos : RepoInfos | None
47+
48+
49+ _loaded_kernels : dict [Path , LoadedKernel ] = {}
50+
51+
52+ def get_loaded_kernels () -> list [LoadedKernel ]:
53+ """Returns a copy of the loaded kernels registry (see `kernels.utils.LoadedKernel` NamedTuple)."""
54+ return list (_loaded_kernels .values ())
55+
56+
3657def _get_cache_dir () -> str | None :
3758 """Returns the kernels cache directory."""
3859 cache_dir = os .environ .get ("HF_KERNELS_CACHE" , None )
@@ -71,7 +92,12 @@ def _parse_local_kernel_overrides(local_kernels: str) -> dict[str, Path]:
7192CACHE_DIR : str | None = _get_cache_dir ()
7293
7394
74- def _import_from_path (module_name : str , variant_path : Path ) -> ModuleType :
95+ def _import_from_path (
96+ module_name : str , variant_path : Path , _repo_infos : RepoInfos | None = None
97+ ) -> ModuleType :
98+ if (loaded_kernel := _loaded_kernels .get (variant_path )) is not None :
99+ return loaded_kernel .module
100+
75101 metadata = Metadata .load_from_variant (variant_path )
76102 validate_dependencies (module_name , metadata .python_depends , _backend ())
77103
@@ -83,6 +109,7 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
83109 # it would also be used for other imports. So, we make a module name that
84110 # depends on the path for it to be unique using the hex-encoded hash of
85111 # the path.
112+ package_name = module_name
86113 path_hash = "{:x}" .format (ctypes .c_size_t (hash (file_path )).value )
87114 module_name = f"{ module_name } _{ path_hash } "
88115 spec = importlib .util .spec_from_file_location (module_name , file_path )
@@ -93,6 +120,12 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
93120 raise ImportError (f"Cannot load module { module_name } from spec" )
94121 sys .modules [module_name ] = module
95122 spec .loader .exec_module (module ) # type: ignore
123+
124+ _loaded_kernels [variant_path ] = LoadedKernel (
125+ module = module ,
126+ package_name = package_name ,
127+ repo_infos = _repo_infos ,
128+ )
96129 return module
97130
98131
@@ -279,10 +312,15 @@ def get_kernel(
279312 return get_local_kernel (override , package_name_from_repo_id (repo_id ))
280313
281314 revision = select_revision_or_version (repo_id , revision = revision , version = version )
315+ repo_infos = RepoInfos (
316+ repo_id = repo_id ,
317+ revision = revision ,
318+ backend = backend ,
319+ )
282320 package_name , variant_path = install_kernel (
283321 repo_id , revision = revision , backend = backend , user_agent = user_agent
284322 )
285- return _import_from_path (package_name , variant_path )
323+ return _import_from_path (package_name , variant_path , _repo_infos = repo_infos )
286324
287325
288326def get_local_kernel (
0 commit comments