diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 829c740f..4e0b918f 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -30,6 +30,7 @@ ) KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"} +TRUSTED_PUBLISHERS = {"kernels-community", "kernels-test", "kernels-staging", "sgl-project"} @dataclass(frozen=True) @@ -293,6 +294,7 @@ def get_kernel( version: int | None = None, backend: str | None = None, user_agent: str | dict | None = None, + trust_remote_code: bool = False, ) -> ModuleType: """ Load a kernel from the kernel hub. @@ -312,6 +314,8 @@ def get_kernel( The backend will be detected automatically if not provided. user_agent (`Union[str, dict]`, *optional*): The `user_agent` info to pass to `snapshot_download()` for internal telemetry. + trust_remote_code (`bool`): Boolean flag indicating if the kernel should be trusted. If the kernel is + from a trusted publisher then this flag is ignored. Returns: `ModuleType`: The imported kernel module. @@ -327,6 +331,9 @@ def get_kernel( result = activation.relu(out, x) ``` """ + if repo_id.split("/")[0] not in TRUSTED_PUBLISHERS and not trust_remote_code: + raise ValueError("You must set `trust_remote_code=True` to use this kernel. Make sure you trust its binary!") + override = _get_local_kernel_overrides().get(repo_id, None) if override is not None: return get_local_kernel(override) diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index c2978a04..a1c3d1bc 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -226,6 +226,19 @@ def test_local_overrides(monkeypatch, local_kernel_path): get_kernel("kernels-test/activation") +def test_trust_remote_code(monkeypatch): + repo_id = "kernels-test/versions" + + with monkeypatch.context() as m: + m.setattr("kernels.utils.TRUSTED_PUBLISHERS", set()) + with pytest.raises(ValueError, match="trust_remote_code=True"): + get_kernel(repo_id, version=1) + + get_kernel(repo_id, version=1, trust_remote_code=True) + + get_kernel(repo_id, version=1) + + @pytest.mark.neuron_only def test_neuron(): relu = get_kernel("kernels-test/relu-nki", version=1)