Skip to content

Commit b45883d

Browse files
sayakpauldanieldk
andauthored
feat: make kernel loading work when offline mode ie enabled. (#580)
* feat: make kernel loading work when offline mode ie enabled. * Apply suggestions from code review Co-authored-by: Daniël de Kok <me@danieldk.eu> * address reviewer comments. * up * Apply suggestions from code review Co-authored-by: Daniël de Kok <me@danieldk.eu> --------- Co-authored-by: Daniël de Kok <me@danieldk.eu>
1 parent 566fd6f commit b45883d

3 files changed

Lines changed: 207 additions & 41 deletions

File tree

kernels/src/kernels/_versions.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import logging
2+
import os
3+
from pathlib import Path
24

5+
from huggingface_hub import constants
6+
from huggingface_hub.file_download import repo_folder_name
37
from huggingface_hub.hf_api import GitRefInfo
48

59
logger = logging.getLogger(__name__)
@@ -9,6 +13,9 @@ def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
913
"""Get kernel versions that are available in the repository."""
1014
from kernels.utils import _get_hf_api
1115

16+
if constants.HF_HUB_OFFLINE:
17+
return _get_available_versions_from_cache(repo_id)
18+
1219
refs = _get_hf_api().list_repo_refs(repo_id=repo_id, repo_type="kernel")
1320

1421
versions = {}
@@ -23,6 +30,36 @@ def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
2330
return versions
2431

2532

33+
def _get_available_versions_from_cache(repo_id: str) -> dict[int, GitRefInfo]:
34+
"""Get kernel versions from the local Hugging Face cache."""
35+
cache_dir = os.environ.get("KERNELS_CACHE") or constants.HF_HUB_CACHE
36+
37+
versions: dict[int, GitRefInfo] = {}
38+
# Tolerate both layouts: the "kernel" repo type used by newer
39+
# huggingface_hub, and the legacy "model" prefix that older caches use.
40+
for repo_type in ("kernel", "model"):
41+
refs_dir = Path(cache_dir) / repo_folder_name(repo_id=repo_id, repo_type=repo_type) / "refs"
42+
if not refs_dir.is_dir():
43+
continue
44+
for ref_path in refs_dir.iterdir():
45+
if not ref_path.is_file():
46+
continue
47+
ref_name = ref_path.name
48+
if not ref_name.startswith("v"):
49+
continue
50+
try:
51+
version = int(ref_name[1:])
52+
except ValueError:
53+
continue
54+
try:
55+
commit = ref_path.read_text().strip()
56+
except OSError:
57+
continue
58+
versions[version] = GitRefInfo(name=ref_name, ref=ref_name, target_commit=commit)
59+
60+
return versions
61+
62+
2663
def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo:
2764
"""
2865
Get the ref for a kernel with the given version.
@@ -31,6 +68,12 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo:
3168

3269
ref = versions.get(version_spec, None)
3370
if ref is None:
71+
if constants.HF_HUB_OFFLINE and not versions:
72+
raise ValueError(
73+
f"Version {version_spec} of '{repo_id}' is not available in the local cache "
74+
"and Hugging Face Hub is in offline mode. Download the kernel "
75+
"while online first, or pass an explicit `revision=<commit>`."
76+
)
3477
raise ValueError(
3578
f"Version {version_spec} not found, available versions: {', '.join(str(v) for v in sorted(versions.keys()))}"
3679
)

kernels/src/kernels/utils.py

Lines changed: 101 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import os
88
import platform
99
import sys
10+
import warnings
1011
from dataclasses import dataclass
1112
from importlib.metadata import Distribution
1213
from pathlib import Path
1314
from types import ModuleType
1415

1516
from huggingface_hub import HfApi, constants
17+
from huggingface_hub.errors import LocalEntryNotFoundError
1618
from kernels_data import Metadata
1719

1820
from kernels._system import glibc_version
@@ -52,8 +54,6 @@ def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str])
5254
return
5355

5456
if isinstance(trust_remote_code, list):
55-
import warnings
56-
5757
warnings.warn(
5858
"Signing identity verification is not yet implemented. "
5959
"The provided signing identities will be ignored and the "
@@ -62,6 +62,16 @@ def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str])
6262
stacklevel=3,
6363
)
6464

65+
if constants.HF_HUB_OFFLINE:
66+
# Publisher trust cannot be verified offline. The user opted into
67+
# offline mode and the kernel must already be in the local cache,
68+
# so trust was established when it was originally downloaded.
69+
warnings.warn(
70+
f"Skipping publisher trust check for '{repo_id}' because Hugging Face Hub is in offline mode.",
71+
stacklevel=3,
72+
)
73+
return
74+
6575
publisher = repo_id.split("/", 1)[0]
6676

6777
try:
@@ -244,10 +254,19 @@ def install_kernel(
244254
`Path`: The path to the variant directory.
245255
"""
246256
api = _get_hf_api(user_agent=user_agent)
257+
if local_files_only or constants.HF_HUB_OFFLINE:
258+
# Same local-cache resolution path used by `load_kernel`, which is
259+
# always offline. Sharing the helper avoids the network dependency
260+
# that `get_variants` would otherwise introduce.
261+
return _resolve_local_variant_path(
262+
api,
263+
repo_id,
264+
revision=revision,
265+
backend=backend,
266+
variant_locks=variant_locks,
267+
)
247268

248-
if not local_files_only:
249-
repo_id, revision = resolve_status(api, repo_id, revision)
250-
269+
repo_id, revision = resolve_status(api, repo_id, revision)
251270
variants = get_variants(api, repo_id=repo_id, revision=revision)
252271
variant, trace = resolve_variant(variants, backend)
253272

@@ -266,7 +285,7 @@ def install_kernel(
266285
allow_patterns=allow_patterns,
267286
cache_dir=CACHE_DIR,
268287
revision=revision,
269-
local_files_only=local_files_only,
288+
local_files_only=False,
270289
)
271290
)
272291
)
@@ -281,6 +300,61 @@ def install_kernel(
281300
raise FileNotFoundError(f"Cannot install kernel from repo {repo_id} (revision: {revision})")
282301

283302

303+
def _resolve_local_variant_path(
304+
api: HfApi,
305+
repo_id: str,
306+
*,
307+
revision: str,
308+
backend: str | None = None,
309+
variant_locks: dict[str, VariantLock] | None = None,
310+
) -> Path:
311+
"""Resolve a kernel variant path from the local Hugging Face cache only.
312+
313+
Used by `load_kernel` (which always operates on a pre-downloaded, locked
314+
kernel) and by the offline branch of `install_kernel`.
315+
"""
316+
try:
317+
local_repo_path = Path(
318+
str(
319+
api.snapshot_download(
320+
repo_id,
321+
repo_type="kernel",
322+
cache_dir=CACHE_DIR,
323+
revision=revision,
324+
local_files_only=True,
325+
)
326+
)
327+
)
328+
except LocalEntryNotFoundError as e:
329+
raise FileNotFoundError(
330+
f"Cannot find a local snapshot for {repo_id} (revision: {revision}). "
331+
"When Hugging Face Hub is in offline mode the kernel must already "
332+
"be present in the local cache."
333+
) from e
334+
335+
variants = get_variants_local(local_repo_path / "build")
336+
variant, status = resolve_variant(variants, backend)
337+
if variant is None:
338+
raise FileNotFoundError(
339+
f"Cannot find a build variant for this system in {repo_id} (revision: {revision}):\n\n{variants_trace_str(status)}"
340+
)
341+
342+
allow_patterns = [f"build/{variant.variant_str}/*"]
343+
repo_path = Path(
344+
str(
345+
api.snapshot_download(
346+
repo_id,
347+
repo_type="kernel",
348+
allow_patterns=allow_patterns,
349+
cache_dir=CACHE_DIR,
350+
revision=revision,
351+
local_files_only=True,
352+
)
353+
)
354+
)
355+
return _find_kernel_in_repo_path(repo_path, variant=variant, variant_locks=variant_locks)
356+
357+
284358
def _find_kernel_in_repo_path(
285359
repo_path: Path,
286360
*,
@@ -479,10 +553,7 @@ def has_kernel(
479553

480554

481555
def load_kernel(
482-
repo_id: str,
483-
*,
484-
lockfile: Path | None,
485-
backend: str | None = None,
556+
repo_id: str, *, lockfile: Path | None, backend: str | None = None, revision: str | None = None
486557
) -> ModuleType:
487558
"""
488559
Get a pre-downloaded, locked kernel.
@@ -497,13 +568,20 @@ def load_kernel(
497568
backend (`str`, *optional*):
498569
The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
499570
The backend will be detected automatically if not provided.
571+
revision (`str`, *optional*):
572+
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
500573
501574
Returns:
502575
`ModuleType`: The imported kernel module.
503576
"""
504-
if lockfile is None:
577+
if lockfile is not None and revision is not None:
578+
raise ValueError("`lockfile` and `revision` both cannot be specified at the same time.")
579+
580+
if lockfile is None and revision is None:
505581
locked_sha = _get_caller_locked_kernel(repo_id)
506-
else:
582+
elif revision is not None:
583+
locked_sha = revision
584+
elif lockfile is not None:
507585
with open(lockfile, "r") as f:
508586
locked_sha = _get_locked_kernel(repo_id, f.read())
509587

@@ -513,39 +591,21 @@ def load_kernel(
513591
)
514592

515593
api = _get_hf_api()
516-
variants = get_variants(api, repo_id=repo_id, revision=locked_sha)
517-
variant, status = resolve_variant(variants, backend)
518-
519-
if variant is None:
520-
raise FileNotFoundError(
521-
f"Cannot find a build variant for this system in {repo_id} (revision: {locked_sha}):\n\n{variants_trace_str(status)}"
522-
)
523-
524-
allow_patterns = [f"build/{variant.variant_str}/*"]
525-
repo_path = Path(
526-
str(
527-
api.snapshot_download(
528-
repo_id,
529-
repo_type="kernel",
530-
allow_patterns=allow_patterns,
531-
cache_dir=CACHE_DIR,
532-
revision=locked_sha,
533-
local_files_only=True,
534-
)
535-
)
536-
)
537594

538595
try:
539-
variant_path = _find_kernel_in_repo_path(
540-
repo_path,
541-
variant=variant,
542-
variant_locks=None,
596+
variant_path = _resolve_local_variant_path(
597+
api,
598+
repo_id,
599+
revision=locked_sha,
600+
backend=backend,
543601
)
544-
return _import_from_path(variant_path)
545-
except FileNotFoundError:
602+
except FileNotFoundError as e:
546603
raise FileNotFoundError(
547-
f"Locked kernel `{repo_id}` does not have applicable variant or was not downloaded with `kernels download <project>`"
548-
)
604+
f"Locked kernel `{repo_id}` was not downloaded or does not have an "
605+
"applicable variant. Make sure it's downloaded locally via "
606+
"`kernels download <project>`."
607+
) from e
608+
return _import_from_path(variant_path)
549609

550610

551611
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:

kernels/tests/test_basic.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import pytest
44
import torch
55
import torch.nn.functional as F
6+
from huggingface_hub import constants
67
from huggingface_hub.errors import HfHubHTTPError
78

89
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
10+
from kernels._versions import resolve_version_spec_as_ref, select_revision_or_version
911

1012

1113
@pytest.fixture
@@ -243,6 +245,67 @@ def test_trust_remote_code_flag_allows_untrusted():
243245
get_kernel("kernels-test-untrusted/ci-test-kernel", version=1, trust_remote_code=True)
244246

245247

248+
def test_install_kernel_offline_with_revision(monkeypatch, local_kernel_path):
249+
"""install_kernel should resolve a cached snapshot when HF_HUB_OFFLINE=1."""
250+
expected_path = local_kernel_path
251+
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)
252+
253+
path = install_kernel("kernels-community/relu", revision="v1")
254+
assert path == expected_path
255+
256+
257+
def test_install_kernel_offline_avoids_network(monkeypatch, local_kernel_path):
258+
"""When HF_HUB_OFFLINE=1, install_kernel must not make any Hub requests."""
259+
expected_path = local_kernel_path
260+
261+
class _NoNetwork(RuntimeError):
262+
pass
263+
264+
def _fail(*_args, **_kwargs):
265+
raise _NoNetwork("Hub access attempted in offline test")
266+
267+
monkeypatch.setattr("huggingface_hub.hf_api.get_session", _fail)
268+
269+
# Online path must touch the Hub via get_session and therefore fail.
270+
with pytest.raises(_NoNetwork):
271+
install_kernel("kernels-community/relu", revision="v1")
272+
273+
# Offline mode resolves entirely from the local cache, so get_session is
274+
# never called.
275+
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)
276+
path = install_kernel("kernels-community/relu", revision="v1")
277+
assert path == expected_path
278+
279+
280+
def test_install_kernel_offline_with_version(monkeypatch, local_kernel_path):
281+
"""get_kernel(version=) should resolve via local refs when HF_HUB_OFFLINE=1."""
282+
expected_path = local_kernel_path
283+
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)
284+
285+
commit = select_revision_or_version("kernels-community/relu", revision=None, version=1)
286+
path = install_kernel("kernels-community/relu", revision=commit)
287+
assert path == expected_path
288+
289+
290+
def test_install_kernel_offline_uncached_revision(monkeypatch):
291+
"""install_kernel should fail with a helpful error when offline and uncached."""
292+
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)
293+
294+
with pytest.raises(FileNotFoundError, match=r"local snapshot"):
295+
install_kernel(
296+
"kernels-test/this-repo-should-not-exist",
297+
revision="0000000000000000000000000000000000000000",
298+
)
299+
300+
301+
def test_version_resolution_offline_missing(monkeypatch):
302+
"""resolve_version_spec_as_ref should raise a clear error when offline and no cache."""
303+
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)
304+
305+
with pytest.raises(ValueError, match=r"offline mode"):
306+
resolve_version_spec_as_ref("kernels-test/this-repo-should-not-exist", 1)
307+
308+
246309
def silu_and_mul_torch(x: torch.Tensor):
247310
d = x.shape[-1] // 2
248311
return F.silu(x[..., :d]) * x[..., d:]

0 commit comments

Comments
 (0)