diff --git a/docs/source/api/kernels.md b/docs/source/api/kernels.md index 7d4ebc0f..afff8f6e 100644 --- a/docs/source/api/kernels.md +++ b/docs/source/api/kernels.md @@ -14,6 +14,10 @@ [[autodoc]] kernels.has_kernel +### get_loaded_kernels + +[[autodoc]] kernels.get_loaded_kernels + ## Loading locked kernels ### load_kernel diff --git a/docs/source/basic-usage.md b/docs/source/basic-usage.md index a93a47d6..1e622d8c 100644 --- a/docs/source/basic-usage.md +++ b/docs/source/basic-usage.md @@ -43,3 +43,23 @@ from kernels import has_kernel is_available = has_kernel("kernels-community/activation", version=1) print(f"Kernel available: {is_available}") ``` + +## Inspecting Loaded Kernels + +`get_loaded_kernels()` returns a snapshot of every kernel that has been loaded +into the current process. Each entry is a `LoadedKernel` namedtuple with the +imported `module`, the `package_name`, and `repo_infos` (repo id, resolved +revision, and the backend argument that was passed). + +```python +from kernels import get_kernel, get_loaded_kernels + +get_kernel("kernels-community/activation", version=1) + +for loaded in get_loaded_kernels(): + print(loaded.package_name, loaded.repo_infos) +``` + +`repo_infos` is populated only for kernels loaded with `get_kernel`. Kernels +loaded from a local path (`get_local_kernel`) or via a lockfile +(`get_locked_kernel`, `load_kernel`) have `repo_infos=None`. diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 831d86e2..eea224e9 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -52,7 +52,42 @@ class LoadedKernel: def get_loaded_kernels() -> list[LoadedKernel]: - """Returns a copy of the loaded kernels registry (see `kernels.utils.LoadedKernel` NamedTuple).""" + """ + Return a snapshot of every kernel that has been loaded into the current process. + + Each entry is a `kernels.utils.LoadedKernel` dataclass with fields: + + - `kernel_id` (`str`): unique identifier used as the `sys.modules` key + for this variant (either `metadata.id` or a hash-suffixed module name). + - `module` (`ModuleType`): the imported kernel module. + - `module_name` (`str`): the kernel's module name. + - `repo_infos` (`kernels.utils.RepoInfos | None`): populated only for + kernels loaded via `get_kernel`. Loaders that work from a local path + (`get_local_kernel`) or a lockfile (`get_locked_kernel`, `load_kernel`) + leave this as `None`. + + `RepoInfos` has `repo_id`, `revision`, and `backend` fields. `backend` + reflects the value passed by the caller — it is `None` when the caller + relied on backend auto-detection. + + The returned list is a new list; mutating it does not affect the registry. + + > [!NOTE] + > These arguments might be renamed / changed a bit. + + Returns: + `list[LoadedKernel]`: one entry per distinct kernel variant path + loaded in this process. + + Example: + ```python + from kernels import get_kernel, get_loaded_kernels + + get_kernel("kernels-community/activation", version=1) + for loaded in get_loaded_kernels(): + print(loaded.module_name, loaded.repo_infos) + ``` + """ return list(_loaded_kernels.values()) diff --git a/kernels/tests/test_loaded_kernels.py b/kernels/tests/test_loaded_kernels.py new file mode 100644 index 00000000..87e13a84 --- /dev/null +++ b/kernels/tests/test_loaded_kernels.py @@ -0,0 +1,98 @@ +from dataclasses import fields + +import pytest + +from kernels import get_kernel, get_loaded_kernels, get_local_kernel, install_kernel +from kernels.utils import LoadedKernel, RepoInfos, _loaded_kernels + +_REPO_ID = "kernels-community/relu" +_PACKAGE_NAME = "relu" +_VERSION = 1 + + +@pytest.fixture +def fresh_registry(): + """Snapshot the process-wide registry, run the test with a clean one, restore on teardown.""" + saved = _loaded_kernels.copy() + _loaded_kernels.clear() + yield + _loaded_kernels.clear() + _loaded_kernels.update(saved) + + +def test_dataclass_shape(): + assert tuple(f.name for f in fields(LoadedKernel)) == ( + "kernel_id", + "module", + "module_name", + "repo_infos", + ) + assert tuple(f.name for f in fields(RepoInfos)) == ("repo_id", "revision", "backend") + + +def test_get_loaded_kernels_returns_copy(fresh_registry): + kernel = get_kernel(_REPO_ID, version=_VERSION, backend="cpu") + + snapshot = get_loaded_kernels() + assert len(snapshot) == 1 + + snapshot.clear() + snapshot.append("garbage") # type: ignore[arg-type] + + again = get_loaded_kernels() + assert len(again) == 1 + assert again[0].module is kernel + + +def test_get_kernel_registers_loaded_kernel(fresh_registry): + kernel = get_kernel(_REPO_ID, version=_VERSION, backend="cpu") + + loaded = get_loaded_kernels() + assert len(loaded) == 1 + + entry = loaded[0] + assert entry.module is kernel + assert entry.module_name == _PACKAGE_NAME + assert entry.repo_infos is not None + assert entry.repo_infos.repo_id == _REPO_ID + assert isinstance(entry.repo_infos.revision, str) and entry.repo_infos.revision + assert entry.repo_infos.backend == "cpu" + + +def test_repeated_get_kernel_is_cached(fresh_registry): + first = get_kernel(_REPO_ID, version=_VERSION, backend="cpu") + second = get_kernel(_REPO_ID, version=_VERSION, backend="cpu") + + assert first is second + assert len(get_loaded_kernels()) == 1 + + +def test_get_local_kernel_registers_with_null_repo_infos(fresh_registry): + # Populate the HF cache via get_kernel, grab the variant path it registered, + # then clear the registry and exercise get_local_kernel against that path. + get_kernel(_REPO_ID, version=_VERSION, backend="cpu") + (variant_path,) = list(_loaded_kernels.keys()) + + _loaded_kernels.clear() + + kernel = get_local_kernel(variant_path, _PACKAGE_NAME, backend="cpu") + + loaded = get_loaded_kernels() + assert len(loaded) == 1 + + entry = loaded[0] + assert entry.module is kernel + assert entry.module_name == _PACKAGE_NAME + assert entry.repo_infos is None + + +def test_install_kernel_plus_import_does_not_set_repo_infos(fresh_registry): + # install_kernel alone does not import; it returns a path. Any loader + # that does not go through get_kernel must leave repo_infos as None. + package_name, variant_path = install_kernel(_REPO_ID, revision="main", backend="cpu") + assert package_name == _PACKAGE_NAME + assert get_loaded_kernels() == [] + + get_local_kernel(variant_path, package_name, backend="cpu") + (entry,) = get_loaded_kernels() + assert entry.repo_infos is None