diff --git a/kernels/src/kernels/layer/func.py b/kernels/src/kernels/layer/func.py index d98551e3..b7cfd567 100644 --- a/kernels/src/kernels/layer/func.py +++ b/kernels/src/kernels/layer/func.py @@ -64,12 +64,14 @@ def __init__( func_name: str, revision: str | None = None, version: int | None = None, + trust_remote_code: bool | list[str] = False, ): if revision is not None and version is not None: raise ValueError("Either a revision or a version must be specified, not both.") self._repo_id = repo_id self.func_name = func_name + self._trust_remote_code = trust_remote_code # We are going to resolve these lazily, since we do not want # to do a network request for every registered FuncRepository. @@ -85,7 +87,9 @@ def _resolve_revision(self) -> str: ) def load(self) -> Type["nn.Module"]: - kernel = get_kernel(self._repo_id, revision=self._resolve_revision()) + kernel = get_kernel( + self._repo_id, revision=self._resolve_revision(), trust_remote_code=self._trust_remote_code + ) return _get_kernel_func(self, kernel) def __eq__(self, other): @@ -95,10 +99,11 @@ def __eq__(self, other): and self._repo_id == other._repo_id and self._revision == other._revision and self._version == other._version + and self._trust_remote_code == other._trust_remote_code ) def __hash__(self): - return hash((self.func_name, self._repo_id, self._revision, self._version)) + return hash((self.func_name, self._repo_id, self._revision, self._version, tuple(self._trust_remote_code) if isinstance(self._trust_remote_code, list) else self._trust_remote_code)) def __str__(self) -> str: return f"`{self._repo_id}` (revision: {self._resolve_revision()}), function `{self.func_name}`" @@ -224,6 +229,7 @@ def __init__( *, lockfile: Path | None = None, func_name: str, + trust_remote_code: bool | list[str] = False, ): """ Construct a function repository. @@ -233,10 +239,13 @@ def __init__( lockfile (`Path`, *optional*): Path to the lockfile. If not provided, the lockfile will be inferred from the caller's context. func_name (`str`): The name of the function within the kernel repository. 
+ trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`): + Whether to allow loading kernels from untrusted organisations. """ self._repo_id = repo_id self._lockfile = lockfile self.func_name = func_name + self._trust_remote_code = trust_remote_code self._revision = self._resolve_revision() def _resolve_revision(self) -> str: @@ -252,7 +261,7 @@ def _resolve_revision(self) -> str: return locked_sha def load(self) -> Type["nn.Module"]: - kernel = get_kernel(repo_id=self._repo_id, revision=self._revision) + kernel = get_kernel(repo_id=self._repo_id, revision=self._revision, trust_remote_code=self._trust_remote_code) return _get_kernel_func(self, kernel) def __eq__(self, other): @@ -261,10 +270,11 @@ def __eq__(self, other): and self.func_name == other.func_name and self._repo_id == other._repo_id and self._revision == other._revision + and self._trust_remote_code == other._trust_remote_code ) def __hash__(self): - return hash((self.func_name, self._repo_id, self._revision)) + return hash((self.func_name, self._repo_id, self._revision, tuple(self._trust_remote_code) if isinstance(self._trust_remote_code, list) else self._trust_remote_code)) def __str__(self) -> str: return f"`{self._repo_id}` (revision: {self._revision}), function `{self.func_name}`" diff --git a/kernels/src/kernels/layer/layer.py b/kernels/src/kernels/layer/layer.py index 505aacb8..d280b6b9 100644 --- a/kernels/src/kernels/layer/layer.py +++ b/kernels/src/kernels/layer/layer.py @@ -42,6 +42,10 @@ class LayerRepository: The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. version (`int`, *optional*): The kernel version to download. Cannot be used together with `revision`. + trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`): + Whether to allow loading kernels from untrusted organisations. A list + of signing identities can be provided for future verification support; + until then it warns and falls back to the default trust check. 
Example: ```python @@ -63,12 +67,14 @@ def __init__( layer_name: str, revision: str | None = None, version: int | None = None, + trust_remote_code: bool | list[str] = False, ): if revision is not None and version is not None: raise ValueError("Either a revision or a version must be specified, not both.") self._repo_id = repo_id self.layer_name = layer_name + self._trust_remote_code = trust_remote_code # We are going to resolve these lazily, since we do not want # to do a network request for every registered LayerRepository. @@ -84,7 +90,9 @@ def _resolve_revision(self) -> str: ) def load(self) -> Type["nn.Module"]: - kernel = get_kernel(self._repo_id, revision=self._resolve_revision()) + kernel = get_kernel( + self._repo_id, revision=self._resolve_revision(), trust_remote_code=self._trust_remote_code + ) return _get_kernel_layer(self, kernel) def __eq__(self, other): @@ -94,10 +102,11 @@ def __eq__(self, other): and self._repo_id == other._repo_id and self._revision == other._revision and self._version == other._version + and self._trust_remote_code == other._trust_remote_code ) def __hash__(self): - return hash((self.layer_name, self._repo_id, self._revision, self._version)) + return hash((self.layer_name, self._repo_id, self._revision, self._version, tuple(self._trust_remote_code) if isinstance(self._trust_remote_code, list) else self._trust_remote_code)) def __str__(self) -> str: return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`" @@ -168,6 +177,7 @@ def __init__( *, lockfile: Path | None = None, layer_name: str, + trust_remote_code: bool | list[str] = False, ): """ Construct a layer repository. 
@@ -178,6 +188,7 @@ def __init__( self._repo_id = repo_id self._lockfile = lockfile self.layer_name = layer_name + self._trust_remote_code = trust_remote_code self._revision = self._resolve_revision() def _resolve_revision(self) -> str: @@ -193,7 +204,7 @@ def _resolve_revision(self) -> str: return locked_sha def load(self) -> Type["nn.Module"]: - kernel = get_kernel(repo_id=self._repo_id, revision=self._revision) + kernel = get_kernel(repo_id=self._repo_id, revision=self._revision, trust_remote_code=self._trust_remote_code) return _get_kernel_layer(self, kernel) def __eq__(self, other): @@ -202,10 +213,11 @@ def __eq__(self, other): and self.layer_name == other.layer_name and self._repo_id == other._repo_id and self._revision == other._revision + and self._trust_remote_code == other._trust_remote_code ) def __hash__(self): - return hash((self.layer_name, self._repo_id, self._revision)) + return hash((self.layer_name, self._repo_id, self._revision, tuple(self._trust_remote_code) if isinstance(self._trust_remote_code, list) else self._trust_remote_code)) def __str__(self) -> str: return f"`{self._repo_id}` (revision: {self._revision}), layer `{self.layer_name}`" diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 829c740f..84fb05d0 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -31,6 +31,64 @@ KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"} +_ALWAYS_TRUSTED_ORGS = {"kernels-community", "kernels-staging", "kernels-test", "sglang"} + + +def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str]) -> None: + """Check whether a kernel repository is trusted. + + When ``trust_remote_code`` is ``False`` (the default), only repositories + whose owner is listed in ``_ALWAYS_TRUSTED_ORGS`` are allowed. Repositories + from any other owner will raise a ``ValueError``. + + When ``trust_remote_code`` is ``True``, all repositories are allowed. 
+ + When ``trust_remote_code`` is a list of strings, it is treated as a list + of signing identities to verify against. Signing verification is not yet + implemented, so passing a list currently emits a warning and falls back + to the default trust check (i.e. only trusted publishers are allowed). + """ + if trust_remote_code is True: + return + + if isinstance(trust_remote_code, list): + import warnings + + warnings.warn( + "Signing identity verification is not yet implemented. " + "The provided signing identities will be ignored and the " + "kernel will be treated as untrusted. Use trust_remote_code=True " + "to bypass trust checks.", + stacklevel=3, + ) + + org = repo_id.split("/", 1)[0] + if org in _ALWAYS_TRUSTED_ORGS: + return + + raise ValueError( + f"Kernel repository '{repo_id}' is not from a trusted publisher. " + f"Set trust_remote_code=True to allow loading kernels from untrusted sources." + ) + + # TODO: revisit and update logic when we can check trusted publishers at the + # user/organization level + # + # api = _get_hf_api() + # try: + # info = api.repo_info(repo_id, repo_type="kernel") + # except Exception as e: + # raise ValueError( + # f"Could not verify publisher trust status for kernel repository '{repo_id}'. " + # "Set trust_remote_code=True to allow loading kernels from untrusted sources." + # ) from e + + # if not getattr(info, "trustedPublisher", False): + # raise ValueError( + # f"Kernel repository '{repo_id}' is not from a trusted publisher. " + # f"Set trust_remote_code=True to allow loading kernels from untrusted sources." + # ) + @dataclass(frozen=True) class RepoInfo: @@ -293,6 +351,7 @@ def get_kernel( version: int | None = None, backend: str | None = None, user_agent: str | dict | None = None, + trust_remote_code: bool | list[str] = False, ) -> ModuleType: """ Load a kernel from the kernel hub. @@ -312,6 +371,12 @@ def get_kernel( The backend will be detected automatically if not provided. 
user_agent (`Union[str, dict]`, *optional*): The `user_agent` info to pass to `snapshot_download()` for internal telemetry. + trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`): + Whether to allow loading kernels from untrusted organisations. When ``False``, + only kernels from trusted organisations are allowed. When ``True``, all + repositories are allowed. A list of strings will be used to verify signing + identities in a future release; for now it emits a warning and falls + back to the default trust check. Returns: `ModuleType`: The imported kernel module. @@ -331,6 +396,8 @@ def get_kernel( if override is not None: return get_local_kernel(override) + _check_trust_remote_code(repo_id, trust_remote_code) + revision = select_revision_or_version(repo_id, revision=revision, version=version) repo_info = RepoInfo( repo_id=repo_id, diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index c2978a04..fd1d9f01 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -233,6 +233,22 @@ def test_neuron(): torch.testing.assert_close(relu.relu(x), x.relu()) +def test_trust_remote_code_blocks_untrusted_org(): + """Kernels from untrusted orgs should be rejected by default.""" + with pytest.raises(ValueError, match=r"not from a trusted publisher"): + get_kernel("kernels-test-untrusted/not-a-trused-org-kernel", version=1) + + +def test_trust_remote_code_allows_trusted_org(): + """Kernels from trusted orgs should not raise a trust error.""" + get_kernel("kernels-community/relu", version=1) + + +def test_trust_remote_code_flag_allows_untrusted(): + """trust_remote_code=True should bypass the org check.""" + get_kernel("kernels-test-untrusted/ci-test-kernel", version=1, trust_remote_code=True) + + def silu_and_mul_torch(x: torch.Tensor): d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:]