Skip to content

Commit 658395e

Browse files
drbh and danieldk authored
feat: add trusted orgs and flag (#512)
* feat: add trusted orgs and flag * fix: format changes * Update kernels/tests/test_basic.py Co-authored-by: Daniël de Kok <me@danieldk.eu> * fix: remove public TRUSTED_KERNEL_ORGS and update remote code tests * feat: prefer checking against the hub info response rather than hardcoded list * fix: prefer using _get_hf_api * fix: short circuit remote check for kernels-test org * fix: adjust error text in kernels/src/kernels/utils.py Co-authored-by: Daniël de Kok <me@danieldk.eu> * fix: remove always trusted orgs and update test org * fix: add hardcoded trusted orgs for now * fix: avoid repo info fetch path --------- Co-authored-by: Daniël de Kok <me@danieldk.eu>
1 parent b2358ea commit 658395e

4 files changed

Lines changed: 113 additions & 8 deletions

File tree

kernels/src/kernels/layer/func.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,14 @@ def __init__(
6464
func_name: str,
6565
revision: str | None = None,
6666
version: int | None = None,
67+
trust_remote_code: bool | list[str] = False,
6768
):
6869
if revision is not None and version is not None:
6970
raise ValueError("Either a revision or a version must be specified, not both.")
7071

7172
self._repo_id = repo_id
7273
self.func_name = func_name
74+
self._trust_remote_code = trust_remote_code
7375

7476
# We are going to resolve these lazily, since we do not want
7577
# to do a network request for every registered FuncRepository.
@@ -85,7 +87,9 @@ def _resolve_revision(self) -> str:
8587
)
8688

8789
def load(self) -> Type["nn.Module"]:
    """
    Fetch the kernel for this repository and extract the configured function.

    The revision is resolved at call time via `_resolve_revision()`, and the
    stored `trust_remote_code` setting is forwarded to `get_kernel`. The
    loaded kernel is then handed to `_get_kernel_func`.
    """
    resolved_revision = self._resolve_revision()
    loaded = get_kernel(
        self._repo_id,
        revision=resolved_revision,
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_func(self, loaded)
9094

9195
def __eq__(self, other):
@@ -95,10 +99,11 @@ def __eq__(self, other):
9599
and self._repo_id == other._repo_id
96100
and self._revision == other._revision
97101
and self._version == other._version
102+
and self._trust_remote_code == other._trust_remote_code
98103
)
99104

100105
def __hash__(self):
    """
    Hash the repository from its function name, repo id, revision, version
    and trust setting.

    `trust_remote_code` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing. Boolean
    values produce exactly the same hash as before.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() raises TypeError on a tuple that contains a list.
        trust = tuple(trust)
    return hash((self.func_name, self._repo_id, self._revision, self._version, trust))
102107

103108
def __str__(self) -> str:
    """Describe the repository by repo id, resolved revision and function name."""
    revision = self._resolve_revision()
    return f"`{self._repo_id}` (revision: {revision}), function `{self.func_name}`"
@@ -224,6 +229,7 @@ def __init__(
224229
*,
225230
lockfile: Path | None = None,
226231
func_name: str,
232+
trust_remote_code: bool | list[str] = False,
227233
):
228234
"""
229235
Construct a function repository.
@@ -233,10 +239,13 @@ def __init__(
233239
lockfile (`Path`, *optional*): Path to the lockfile. If not provided,
234240
the lockfile will be inferred from the caller's context.
235241
func_name (`str`): The name of the function within the kernel repository.
242+
trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`):
243+
Whether to allow loading kernels from untrusted organisations.
236244
"""
237245
self._repo_id = repo_id
238246
self._lockfile = lockfile
239247
self.func_name = func_name
248+
self._trust_remote_code = trust_remote_code
240249
self._revision = self._resolve_revision()
241250

242251
def _resolve_revision(self) -> str:
@@ -252,7 +261,7 @@ def _resolve_revision(self) -> str:
252261
return locked_sha
253262

254263
def load(self) -> Type["nn.Module"]:
    """
    Fetch the kernel at the revision stored in `self._revision` and extract
    the configured function.

    The stored `trust_remote_code` setting is forwarded to `get_kernel`, and
    the loaded kernel is handed to `_get_kernel_func`.
    """
    loaded = get_kernel(
        repo_id=self._repo_id,
        revision=self._revision,
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_func(self, loaded)
257266

258267
def __eq__(self, other):
@@ -261,10 +270,11 @@ def __eq__(self, other):
261270
and self.func_name == other.func_name
262271
and self._repo_id == other._repo_id
263272
and self._revision == other._revision
273+
and self._trust_remote_code == other._trust_remote_code
264274
)
265275

266276
def __hash__(self):
    """
    Hash the repository from its function name, repo id, revision and
    trust setting.

    `trust_remote_code` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing. Boolean
    values produce exactly the same hash as before.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() raises TypeError on a tuple that contains a list.
        trust = tuple(trust)
    return hash((self.func_name, self._repo_id, self._revision, trust))
268278

269279
def __str__(self) -> str:
    """Describe the repository by repo id, stored revision and function name."""
    revision = self._revision
    return f"`{self._repo_id}` (revision: {revision}), function `{self.func_name}`"

kernels/src/kernels/layer/layer.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class LayerRepository:
4242
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
4343
version (`int`, *optional*):
4444
The kernel version to download. Cannot be used together with `revision`.
45+
trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`):
46+
Whether to allow loading kernels from untrusted organisations. A list
47+
of signing identities can be provided for future verification support;
48+
until then it warns and falls back to the default trust check.
4549
4650
Example:
4751
```python
@@ -63,12 +67,14 @@ def __init__(
6367
layer_name: str,
6468
revision: str | None = None,
6569
version: int | None = None,
70+
trust_remote_code: bool | list[str] = False,
6671
):
6772
if revision is not None and version is not None:
6873
raise ValueError("Either a revision or a version must be specified, not both.")
6974

7075
self._repo_id = repo_id
7176
self.layer_name = layer_name
77+
self._trust_remote_code = trust_remote_code
7278

7379
# We are going to resolve these lazily, since we do not want
7480
# to do a network request for every registered LayerRepository.
@@ -84,7 +90,9 @@ def _resolve_revision(self) -> str:
8490
)
8591

8692
def load(self) -> Type["nn.Module"]:
    """
    Fetch the kernel for this repository and extract the configured layer.

    The revision is resolved at call time via `_resolve_revision()`, and the
    stored `trust_remote_code` setting is forwarded to `get_kernel`. The
    loaded kernel is then handed to `_get_kernel_layer`.
    """
    resolved_revision = self._resolve_revision()
    loaded = get_kernel(
        self._repo_id,
        revision=resolved_revision,
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_layer(self, loaded)
8997

9098
def __eq__(self, other):
@@ -94,10 +102,11 @@ def __eq__(self, other):
94102
and self._repo_id == other._repo_id
95103
and self._revision == other._revision
96104
and self._version == other._version
105+
and self._trust_remote_code == other._trust_remote_code
97106
)
98107

99108
def __hash__(self):
    """
    Hash the repository from its layer name, repo id, revision, version
    and trust setting.

    `trust_remote_code` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing. Boolean
    values produce exactly the same hash as before.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() raises TypeError on a tuple that contains a list.
        trust = tuple(trust)
    return hash((self.layer_name, self._repo_id, self._revision, self._version, trust))
101110

102111
def __str__(self) -> str:
    """Describe the repository by repo id, resolved revision and layer name."""
    revision = self._resolve_revision()
    return f"`{self._repo_id}` (revision: {revision}), layer `{self.layer_name}`"
@@ -168,6 +177,7 @@ def __init__(
168177
*,
169178
lockfile: Path | None = None,
170179
layer_name: str,
180+
trust_remote_code: bool | list[str] = False,
171181
):
172182
"""
173183
Construct a layer repository.
@@ -178,6 +188,7 @@ def __init__(
178188
self._repo_id = repo_id
179189
self._lockfile = lockfile
180190
self.layer_name = layer_name
191+
self._trust_remote_code = trust_remote_code
181192
self._revision = self._resolve_revision()
182193

183194
def _resolve_revision(self) -> str:
@@ -193,7 +204,7 @@ def _resolve_revision(self) -> str:
193204
return locked_sha
194205

195206
def load(self) -> Type["nn.Module"]:
    """
    Fetch the kernel at the revision stored in `self._revision` and extract
    the configured layer.

    The stored `trust_remote_code` setting is forwarded to `get_kernel`, and
    the loaded kernel is handed to `_get_kernel_layer`.
    """
    loaded = get_kernel(
        repo_id=self._repo_id,
        revision=self._revision,
        trust_remote_code=self._trust_remote_code,
    )
    return _get_kernel_layer(self, loaded)
198209

199210
def __eq__(self, other):
@@ -202,10 +213,11 @@ def __eq__(self, other):
202213
and self.layer_name == other.layer_name
203214
and self._repo_id == other._repo_id
204215
and self._revision == other._revision
216+
and self._trust_remote_code == other._trust_remote_code
205217
)
206218

207219
def __hash__(self):
    """
    Hash the repository from its layer name, repo id, revision and
    trust setting.

    `trust_remote_code` may be a list of signing identities. Lists are
    unhashable, so a list is converted to a tuple before hashing. Boolean
    values produce exactly the same hash as before.
    """
    trust = self._trust_remote_code
    if isinstance(trust, list):
        # hash() raises TypeError on a tuple that contains a list.
        trust = tuple(trust)
    return hash((self.layer_name, self._repo_id, self._revision, trust))
209221

210222
def __str__(self) -> str:
    """Describe the repository by repo id, stored revision and layer name."""
    revision = self._revision
    return f"`{self._repo_id}` (revision: {revision}), layer `{self.layer_name}`"

kernels/src/kernels/utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,64 @@
3131

3232
KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"}
3333

34+
_ALWAYS_TRUSTED_ORGS = {"kernels-community", "kernels-staging", "kernels-test", "sglang"}
35+
36+
37+
def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str]) -> None:
38+
"""Check whether a kernel repository is trusted.
39+
40+
When ``trust_remote_code`` is ``False`` (the default), only repositories
41+
whose publisher is marked as trusted on the Hub are allowed. Repositories
42+
from untrusted publishers will raise a ``ValueError``.
43+
44+
When ``trust_remote_code`` is ``True``, all repositories are allowed.
45+
46+
When ``trust_remote_code`` is a list of strings, it is treated as a list
47+
of signing identities to verify against. Signing verification is not yet
48+
implemented, so passing a list currently emits a warning and falls back
49+
to the default trust check (i.e. only trusted publishers are allowed).
50+
"""
51+
if trust_remote_code is True:
52+
return
53+
54+
if isinstance(trust_remote_code, list):
55+
import warnings
56+
57+
warnings.warn(
58+
"Signing identity verification is not yet implemented. "
59+
"The provided signing identities will be ignored and the "
60+
"kernel will be treated as untrusted. Use trust_remote_code=True "
61+
"to bypass trust checks.",
62+
stacklevel=3,
63+
)
64+
65+
org = repo_id.split("/", 1)[0]
66+
if org in _ALWAYS_TRUSTED_ORGS:
67+
return
68+
69+
raise ValueError(
70+
f"Kernel repository '{repo_id}' is not from a trusted publisher. "
71+
f"Set trust_remote_code=True to allow loading kernels from untrusted sources."
72+
)
73+
74+
# TODO: revisit and update logic when we can check trusted publishers at the
75+
# user/organization level
76+
#
77+
# api = _get_hf_api()
78+
# try:
79+
# info = api.repo_info(repo_id, repo_type="kernel")
80+
# except Exception as e:
81+
# raise ValueError(
82+
# f"Could not verify publisher trust status for kernel repository '{repo_id}'. "
83+
# "Set trust_remote_code=True to allow loading kernels from untrusted sources."
84+
# ) from e
85+
86+
# if not getattr(info, "trustedPublisher", False):
87+
# raise ValueError(
88+
# f"Kernel repository '{repo_id}' is not from a trusted publisher. "
89+
# f"Set trust_remote_code=True to allow loading kernels from untrusted sources."
90+
# )
91+
3492

3593
@dataclass(frozen=True)
3694
class RepoInfo:
@@ -293,6 +351,7 @@ def get_kernel(
293351
version: int | None = None,
294352
backend: str | None = None,
295353
user_agent: str | dict | None = None,
354+
trust_remote_code: bool | list[str] = False,
296355
) -> ModuleType:
297356
"""
298357
Load a kernel from the kernel hub.
@@ -312,6 +371,12 @@ def get_kernel(
312371
The backend will be detected automatically if not provided.
313372
user_agent (`Union[str, dict]`, *optional*):
314373
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
374+
trust_remote_code (`bool | list[str]`, *optional*, defaults to `False`):
375+
Whether to allow loading kernels from untrusted organisations. When ``False``,
376+
only kernels from trusted organisations are allowed. When ``True``, all
377+
repositories are allowed. A list of strings will be used to verify signing
378+
identities in a future release; for now it emits a warning and falls
379+
back to the default trust check.
315380
316381
Returns:
317382
`ModuleType`: The imported kernel module.
@@ -331,6 +396,8 @@ def get_kernel(
331396
if override is not None:
332397
return get_local_kernel(override)
333398

399+
_check_trust_remote_code(repo_id, trust_remote_code)
400+
334401
revision = select_revision_or_version(repo_id, revision=revision, version=version)
335402
repo_info = RepoInfo(
336403
repo_id=repo_id,

kernels/tests/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,22 @@ def test_neuron():
233233
torch.testing.assert_close(relu.relu(x), x.relu())
234234

235235

236+
def test_trust_remote_code_blocks_untrusted_org():
    """By default, loading a kernel from an untrusted org raises ValueError."""
    untrusted_repo = "kernels-test-untrusted/not-a-trused-org-kernel"
    expected_error = pytest.raises(ValueError, match=r"not from a trusted publisher")
    with expected_error:
        get_kernel(untrusted_repo, version=1)
240+
241+
242+
def test_trust_remote_code_allows_trusted_org():
    """Loading a kernel from a trusted org must not raise a trust error."""
    trusted_repo = "kernels-community/relu"
    get_kernel(trusted_repo, version=1)
245+
246+
247+
def test_trust_remote_code_flag_allows_untrusted():
    """Passing trust_remote_code=True skips the organisation check."""
    untrusted_repo = "kernels-test-untrusted/ci-test-kernel"
    get_kernel(untrusted_repo, version=1, trust_remote_code=True)
250+
251+
236252
def silu_and_mul_torch(x: torch.Tensor):
    """
    Split `x` in two along its last dimension at `x.shape[-1] // 2`, apply
    SiLU to the first part, and multiply elementwise by the second part.
    """
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.silu(gate) * up

0 commit comments

Comments
 (0)