Skip to content

Commit 0d6cb13

Browse files
authored
feat: get_loaded_kernels() (#428)
* get_laoded_kernels * Fix op_namespace * Improve docstring * Only for custom-op kernels * Revert "Improve docstring" This reverts commit b70f746. * Improve doc on the right function * Black * Index by module_name + best-effort op namespace * Only return values + get_kernel caching * Docstring * revision can't be None * op namespace from .so name (works for both torch and tvm) * module version + addressing comments
1 parent afe4033 commit 0d6cb13

2 files changed

Lines changed: 42 additions & 2 deletions

File tree

kernels/src/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from kernels.utils import (
2525
get_kernel,
26+
get_loaded_kernels,
2627
get_local_kernel,
2728
get_locked_kernel,
2829
has_kernel,
@@ -45,6 +46,7 @@
4546
"LockedLayerRepository",
4647
"Mode",
4748
"get_kernel",
49+
"get_loaded_kernels",
4850
"get_local_kernel",
4951
"get_locked_kernel",
5052
"has_kernel",

kernels/src/kernels/utils.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from importlib.metadata import Distribution
1313
from pathlib import Path
1414
from types import ModuleType
15+
from typing import NamedTuple
1516

1617
from huggingface_hub import HfApi, constants
1718

@@ -33,6 +34,26 @@
3334
KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"}
3435

3536

37+
class RepoInfos(NamedTuple):
38+
repo_id: str
39+
revision: str
40+
backend: str | None
41+
42+
43+
class LoadedKernel(NamedTuple):
44+
module: ModuleType
45+
package_name: str
46+
repo_infos: RepoInfos | None
47+
48+
49+
_loaded_kernels: dict[Path, LoadedKernel] = {}
50+
51+
52+
def get_loaded_kernels() -> list[LoadedKernel]:
53+
"""Returns a copy of the loaded kernels registry (see `kernels.utils.LoadedKernel` NamedTuple)."""
54+
return list(_loaded_kernels.values())
55+
56+
3657
def _get_cache_dir() -> str | None:
3758
"""Returns the kernels cache directory."""
3859
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
@@ -71,7 +92,12 @@ def _parse_local_kernel_overrides(local_kernels: str) -> dict[str, Path]:
7192
CACHE_DIR: str | None = _get_cache_dir()
7293

7394

74-
def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
95+
def _import_from_path(
96+
module_name: str, variant_path: Path, _repo_infos: RepoInfos | None = None
97+
) -> ModuleType:
98+
if (loaded_kernel := _loaded_kernels.get(variant_path)) is not None:
99+
return loaded_kernel.module
100+
75101
metadata = Metadata.load_from_variant(variant_path)
76102
validate_dependencies(module_name, metadata.python_depends, _backend())
77103

@@ -83,6 +109,7 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
83109
# it would also be used for other imports. So, we make a module name that
84110
# depends on the path for it to be unique using the hex-encoded hash of
85111
# the path.
112+
package_name = module_name
86113
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
87114
module_name = f"{module_name}_{path_hash}"
88115
spec = importlib.util.spec_from_file_location(module_name, file_path)
@@ -93,6 +120,12 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
93120
raise ImportError(f"Cannot load module {module_name} from spec")
94121
sys.modules[module_name] = module
95122
spec.loader.exec_module(module) # type: ignore
123+
124+
_loaded_kernels[variant_path] = LoadedKernel(
125+
module=module,
126+
package_name=package_name,
127+
repo_infos=_repo_infos,
128+
)
96129
return module
97130

98131

@@ -279,10 +312,15 @@ def get_kernel(
279312
return get_local_kernel(override, package_name_from_repo_id(repo_id))
280313

281314
revision = select_revision_or_version(repo_id, revision=revision, version=version)
315+
repo_infos = RepoInfos(
316+
repo_id=repo_id,
317+
revision=revision,
318+
backend=backend,
319+
)
282320
package_name, variant_path = install_kernel(
283321
repo_id, revision=revision, backend=backend, user_agent=user_agent
284322
)
285-
return _import_from_path(package_name, variant_path)
323+
return _import_from_path(package_name, variant_path, _repo_infos=repo_infos)
286324

287325

288326
def get_local_kernel(

0 commit comments

Comments
 (0)