From b3a857f1d158e972d653a98ec0bec2016103e443 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 30 Apr 2026 12:18:20 +0530 Subject: [PATCH] up --- kernels/src/kernels/utils.py | 7 +++++++ kernels/tests/test_basic.py | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index eea224e9..582ce255 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -31,6 +31,7 @@ ) KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"} +TRUSTED_PUBLISHERS = {"kernels-community", "kernels-test", "kernels-staging", "sgl-project"} @dataclass(frozen=True) @@ -304,6 +305,7 @@ def get_kernel( version: int | None = None, backend: str | None = None, user_agent: str | dict | None = None, + trust_remote_code: bool = False, ) -> ModuleType: """ Load a kernel from the kernel hub. @@ -323,6 +325,8 @@ def get_kernel( The backend will be detected automatically if not provided. user_agent (`Union[str, dict]`, *optional*): The `user_agent` info to pass to `snapshot_download()` for internal telemetry. + trust_remote_code (`bool`): Boolean flag indicating if the kernel should be trusted. If the kernel is + from a trusted publisher then this flag is ignored. Returns: `ModuleType`: The imported kernel module. @@ -338,6 +342,9 @@ def get_kernel( result = activation.relu(out, x) ``` """ + if repo_id.split("/")[0] not in TRUSTED_PUBLISHERS and not trust_remote_code: + raise ValueError("You must set `trust_remote_code=True` to use this kernel. Make sure you trust its binary!") + override = _get_local_kernel_overrides().get(repo_id, None) if override is not None: return get_local_kernel(override, package_name_from_repo_id(repo_id)) diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index 889889b6..4665db27 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -287,6 +287,19 @@ def test_local_overrides(monkeypatch, local_kernel_path): get_kernel("kernels-test/activation") +def test_trust_remote_code(monkeypatch): + repo_id = "kernels-test/versions" + + with monkeypatch.context() as m: + m.setattr("kernels.utils.TRUSTED_PUBLISHERS", set()) + with pytest.raises(ValueError, match="trust_remote_code=True"): + get_kernel(repo_id, version=1) + + get_kernel(repo_id, version=1, trust_remote_code=True) + + get_kernel(repo_id, version=1) + + @pytest.mark.neuron_only def test_neuron(): relu = get_kernel("kernels-test/relu-nki", version=1)