Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)

KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"}
TRUSTED_PUBLISHERS = {"kernels-community", "kernels-test", "kernels-staging", "sgl-project"}


@dataclass(frozen=True)
Expand Down Expand Up @@ -293,6 +294,7 @@ def get_kernel(
version: int | None = None,
backend: str | None = None,
user_agent: str | dict | None = None,
trust_remote_code: bool = False,
) -> ModuleType:
"""
Load a kernel from the kernel hub.
Expand All @@ -312,6 +314,8 @@ def get_kernel(
The backend will be detected automatically if not provided.
user_agent (`Union[str, dict]`, *optional*):
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
trust_remote_code (`bool`): Boolean flag indicating if the kernel should be trusted. If the kernel is
from a trusted publisher then this flag is ignored.

Returns:
`ModuleType`: The imported kernel module.
Expand All @@ -327,6 +331,9 @@ def get_kernel(
result = activation.relu(out, x)
```
"""
if repo_id.split("/")[0] not in TRUSTED_PUBLISHERS and not trust_remote_code:
raise ValueError("You must set `trust_remote_code=True` to use this kernel. Make sure you trust its binary!")

override = _get_local_kernel_overrides().get(repo_id, None)
if override is not None:
return get_local_kernel(override)
Expand Down
13 changes: 13 additions & 0 deletions kernels/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,19 @@ def test_local_overrides(monkeypatch, local_kernel_path):
get_kernel("kernels-test/activation")


def test_trust_remote_code(monkeypatch):
repo_id = "kernels-test/versions"

with monkeypatch.context() as m:
m.setattr("kernels.utils.TRUSTED_PUBLISHERS", set())
with pytest.raises(ValueError, match="trust_remote_code=True"):
get_kernel(repo_id, version=1)

get_kernel(repo_id, version=1, trust_remote_code=True)

get_kernel(repo_id, version=1)


@pytest.mark.neuron_only
def test_neuron():
relu = get_kernel("kernels-test/relu-nki", version=1)
Expand Down
Loading