Skip to content
18 changes: 14 additions & 4 deletions kernels/src/kernels/layer/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ def __init__(
func_name: str,
revision: str | None = None,
version: int | None = None,
trust_remote_code: bool | list[str] = False,
):
if revision is not None and version is not None:
raise ValueError("Either a revision or a version must be specified, not both.")

self._repo_id = repo_id
self.func_name = func_name
self._trust_remote_code = trust_remote_code

# We are going to resolve these lazily, since we do not want
# to do a network request for every registered FuncRepository.
Expand All @@ -85,7 +87,9 @@ def _resolve_revision(self) -> str:
)

def load(self) -> Type["nn.Module"]:
    """Fetch the kernel for this repository and extract the configured function.

    The kernel is requested exactly once, at the lazily resolved revision,
    and the repository's own ``trust_remote_code`` setting is forwarded to
    ``get_kernel`` so that its trust check honours the caller's opt-in.

    Returns:
        Whatever ``_get_kernel_func`` produces for this repository and the
        loaded kernel module.
    """
    # A single call only: an extra call without ``trust_remote_code`` would
    # run the trust check with the default (``False``) before the real one.
    kernel = get_kernel(
        self._repo_id,
        revision=self._resolve_revision(),
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_func(self, kernel)

def __eq__(self, other):
Expand All @@ -95,10 +99,11 @@ def __eq__(self, other):
and self._repo_id == other._repo_id
and self._revision == other._revision
and self._version == other._version
and self._trust_remote_code == other._trust_remote_code
)

def __hash__(self):
    """Hash the repository's identifying fields, including the trust setting.

    ``_trust_remote_code`` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing; any other
    value is hashed as-is.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() on a tuple containing a list raises TypeError.
        trust = tuple(trust)
    return hash((self.func_name, self._repo_id, self._revision, self._version, trust))

def __str__(self) -> str:
    """Render the repository id, its resolved revision and the function name."""
    description = f"`{self._repo_id}` (revision: {self._resolve_revision()}), function `{self.func_name}`"
    return description
Expand Down Expand Up @@ -224,6 +229,7 @@ def __init__(
*,
lockfile: Path | None = None,
func_name: str,
trust_remote_code: bool | list[str] = False,
):
"""
Construct a function repository.
Expand All @@ -233,10 +239,13 @@ def __init__(
lockfile (`Path`, *optional*): Path to the lockfile. If not provided,
the lockfile will be inferred from the caller's context.
func_name (`str`): The name of the function within the kernel repository.
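trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`):
Whether to allow loading kernels from untrusted organisations. A list
of signing identities can be provided for future verification support;
until then it warns and falls back to the default trust check.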
"""
self._repo_id = repo_id
self._lockfile = lockfile
self.func_name = func_name
self._trust_remote_code = trust_remote_code
self._revision = self._resolve_revision()

def _resolve_revision(self) -> str:
Expand All @@ -252,7 +261,7 @@ def _resolve_revision(self) -> str:
return locked_sha

def load(self) -> Type["nn.Module"]:
    """Fetch the kernel at the locked revision and extract the configured function.

    The kernel is requested exactly once and the repository's own
    ``trust_remote_code`` setting is forwarded to ``get_kernel`` so that its
    trust check honours the caller's opt-in.

    Returns:
        Whatever ``_get_kernel_func`` produces for this repository and the
        loaded kernel module.
    """
    # A single call only: an extra call without ``trust_remote_code`` would
    # run the trust check with the default (``False``) before the real one.
    kernel = get_kernel(
        repo_id=self._repo_id,
        revision=self._revision,
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_func(self, kernel)

def __eq__(self, other):
Expand All @@ -261,10 +270,11 @@ def __eq__(self, other):
and self.func_name == other.func_name
and self._repo_id == other._repo_id
and self._revision == other._revision
and self._trust_remote_code == other._trust_remote_code
)

def __hash__(self):
    """Hash the repository's identifying fields, including the trust setting.

    ``_trust_remote_code`` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing; any other
    value is hashed as-is.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() on a tuple containing a list raises TypeError.
        trust = tuple(trust)
    return hash((self.func_name, self._repo_id, self._revision, trust))

def __str__(self) -> str:
    """Render the repository id, its locked revision and the function name."""
    description = f"`{self._repo_id}` (revision: {self._revision}), function `{self.func_name}`"
    return description
Expand Down
20 changes: 16 additions & 4 deletions kernels/src/kernels/layer/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class LayerRepository:
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
version (`int`, *optional*):
The kernel version to download. Cannot be used together with `revision`.
trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`):
Whether to allow loading kernels from untrusted organisations. A list
of signing identities can be provided for future verification support;
until then it warns and falls back to the default trust check.

Example:
```python
Expand All @@ -63,12 +67,14 @@ def __init__(
layer_name: str,
revision: str | None = None,
version: int | None = None,
trust_remote_code: bool | list[str] = False,
):
if revision is not None and version is not None:
raise ValueError("Either a revision or a version must be specified, not both.")

self._repo_id = repo_id
self.layer_name = layer_name
self._trust_remote_code = trust_remote_code

# We are going to resolve these lazily, since we do not want
# to do a network request for every registered LayerRepository.
Expand All @@ -84,7 +90,9 @@ def _resolve_revision(self) -> str:
)

def load(self) -> Type["nn.Module"]:
    """Fetch the kernel for this repository and extract the configured layer.

    The kernel is requested exactly once, at the lazily resolved revision,
    and the repository's own ``trust_remote_code`` setting is forwarded to
    ``get_kernel`` so that its trust check honours the caller's opt-in.

    Returns:
        Whatever ``_get_kernel_layer`` produces for this repository and the
        loaded kernel module.
    """
    # A single call only: an extra call without ``trust_remote_code`` would
    # run the trust check with the default (``False``) before the real one.
    kernel = get_kernel(
        self._repo_id,
        revision=self._resolve_revision(),
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_layer(self, kernel)

def __eq__(self, other):
Expand All @@ -94,10 +102,11 @@ def __eq__(self, other):
and self._repo_id == other._repo_id
and self._revision == other._revision
and self._version == other._version
and self._trust_remote_code == other._trust_remote_code
)

def __hash__(self):
    """Hash the repository's identifying fields, including the trust setting.

    ``_trust_remote_code`` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing; any other
    value is hashed as-is.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() on a tuple containing a list raises TypeError.
        trust = tuple(trust)
    return hash((self.layer_name, self._repo_id, self._revision, self._version, trust))

def __str__(self) -> str:
    """Render the repository id, its resolved revision and the layer name."""
    description = f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"
    return description
Expand Down Expand Up @@ -168,6 +177,7 @@ def __init__(
*,
lockfile: Path | None = None,
layer_name: str,
trust_remote_code: bool | list[str] = False,
):
"""
Construct a layer repository.
Expand All @@ -178,6 +188,7 @@ def __init__(
self._repo_id = repo_id
self._lockfile = lockfile
self.layer_name = layer_name
self._trust_remote_code = trust_remote_code
self._revision = self._resolve_revision()

def _resolve_revision(self) -> str:
Expand All @@ -193,7 +204,7 @@ def _resolve_revision(self) -> str:
return locked_sha

def load(self) -> Type["nn.Module"]:
    """Fetch the kernel at the locked revision and extract the configured layer.

    The kernel is requested exactly once and the repository's own
    ``trust_remote_code`` setting is forwarded to ``get_kernel`` so that its
    trust check honours the caller's opt-in.

    Returns:
        Whatever ``_get_kernel_layer`` produces for this repository and the
        loaded kernel module.
    """
    # A single call only: an extra call without ``trust_remote_code`` would
    # run the trust check with the default (``False``) before the real one.
    kernel = get_kernel(
        repo_id=self._repo_id,
        revision=self._revision,
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_layer(self, kernel)

def __eq__(self, other):
Expand All @@ -202,10 +213,11 @@ def __eq__(self, other):
and self.layer_name == other.layer_name
and self._repo_id == other._repo_id
and self._revision == other._revision
and self._trust_remote_code == other._trust_remote_code
)

def __hash__(self):
    """Hash the repository's identifying fields, including the trust setting.

    ``_trust_remote_code`` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing; any other
    value is hashed as-is.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() on a tuple containing a list raises TypeError.
        trust = tuple(trust)
    return hash((self.layer_name, self._repo_id, self._revision, trust))

def __str__(self) -> str:
    """Render the repository id, its locked revision and the layer name."""
    description = f"`{self._repo_id}` (revision: {self._revision}), layer `{self.layer_name}`"
    return description
Expand Down
67 changes: 67 additions & 0 deletions kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,64 @@

KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"}

_ALWAYS_TRUSTED_ORGS = {"kernels-community", "kernels-staging", "kernels-test", "sglang"}


def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str]) -> None:
"""Check whether a kernel repository is trusted.

When ``trust_remote_code`` is ``False`` (the default), only repositories
whose publisher is marked as trusted on the Hub are allowed. Repositories
from untrusted publishers will raise a ``ValueError``.

When ``trust_remote_code`` is ``True``, all repositories are allowed.

When ``trust_remote_code`` is a list of strings, it is treated as a list
of signing identities to verify against. Signing verification is not yet
implemented, so passing a list currently emits a warning and falls back
to the default trust check (i.e. only trusted publishers are allowed).
"""
if trust_remote_code is True:
return

if isinstance(trust_remote_code, list):
import warnings

warnings.warn(
"Signing identity verification is not yet implemented. "
"The provided signing identities will be ignored and the "
"kernel will be treated as untrusted. Use trust_remote_code=True "
"to bypass trust checks.",
stacklevel=3,
)

org = repo_id.split("/", 1)[0]
if org in _ALWAYS_TRUSTED_ORGS:
return

raise ValueError(
f"Kernel repository '{repo_id}' is not from a trusted publisher. "
f"Set trust_remote_code=True to allow loading kernels from untrusted sources."
)

# TODO: revisit and update logic when we can check trusted publishers at the
# user/organization level
#
# api = _get_hf_api()
# try:
# info = api.repo_info(repo_id, repo_type="kernel")
# except Exception as e:
# raise ValueError(
# f"Could not verify publisher trust status for kernel repository '{repo_id}'. "
# "Set trust_remote_code=True to allow loading kernels from untrusted sources."
# ) from e

# if not getattr(info, "trustedPublisher", False):
# raise ValueError(
# f"Kernel repository '{repo_id}' is not from a trusted publisher. "
# f"Set trust_remote_code=True to allow loading kernels from untrusted sources."
# )


@dataclass(frozen=True)
class RepoInfo:
Expand Down Expand Up @@ -293,6 +351,7 @@ def get_kernel(
version: int | None = None,
backend: str | None = None,
user_agent: str | dict | None = None,
trust_remote_code: bool | list[str] = False,
) -> ModuleType:
"""
Load a kernel from the kernel hub.
Expand All @@ -312,6 +371,12 @@ def get_kernel(
The backend will be detected automatically if not provided.
user_agent (`Union[str, dict]`, *optional*):
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`):
Whether to allow loading kernels from untrusted organisations. When ``False``,
only kernels from trusted organisations are allowed. When ``True``, all
repositories are allowed. A list of strings will be used to verify signing
identities in a future release; for now it emits a warning and falls
back to the default trust check.

Returns:
`ModuleType`: The imported kernel module.
Expand All @@ -331,6 +396,8 @@ def get_kernel(
if override is not None:
return get_local_kernel(override)

_check_trust_remote_code(repo_id, trust_remote_code)

revision = select_revision_or_version(repo_id, revision=revision, version=version)
repo_info = RepoInfo(
repo_id=repo_id,
Expand Down
16 changes: 16 additions & 0 deletions kernels/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,22 @@ def test_neuron():
torch.testing.assert_close(relu.relu(x), x.relu())


def test_trust_remote_code_blocks_untrusted_org():
    """Without an explicit opt-in, a repository outside the trusted orgs is refused."""
    untrusted_repo = "kernels-test-untrusted/not-a-trused-org-kernel"
    expected_error = pytest.raises(ValueError, match=r"not from a trusted publisher")
    with expected_error:
        get_kernel(untrusted_repo, version=1)


def test_trust_remote_code_allows_trusted_org():
    """Loading from a trusted org with default settings must not raise a trust error."""
    trusted_repo = "kernels-community/relu"
    get_kernel(trusted_repo, version=1)


def test_trust_remote_code_flag_allows_untrusted():
    """Passing trust_remote_code=True must skip the organisation check."""
    untrusted_repo = "kernels-test-untrusted/ci-test-kernel"
    get_kernel(untrusted_repo, version=1, trust_remote_code=True)


def silu_and_mul_torch(x: torch.Tensor):
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
Loading