Skip to content

Commit 84629d2

Browse files
authored
feat: resolve repo type and fetch accordingly (#435)
* feat: resolve repo type and fetch accordingly * fix: lint and format and prefer pinning hfh commit * fix: pin hub library from git commit * fix: update tests to specify repo type * fix: bump overlay * fix: prefer the hfh release * fix: bump hf xet * fix: improve _resolve_repo_type * fix: remove debug line * fix: prefer repo_info in _resolve_repo_type and handle 401 * fix: expect repo type model for the kernels-test org kernels * fix: adjust error * fix: adjust tests and repo type resolution * fix: accept revision in _resolve_repo_type * fix: get file metadata in repo info and remove extra try * feat: run lock pytest as subprocess * feat: prefer locking kernel repo hashes * fix: avoid deprecation warning for now * feat: make repo type optional and comment for siblings * fix: remove repo_type from all public interfaces * fix: remove backward compat of model repo types * fix: improve error cases * fix: bump lock files for kernel repos * fix: avoid non version locking tests * fix: avoid upload changes * fix: apply lints after rebase * fix: adjust repo not found error regression * fix: adjust repo not found error regression in layer test too * fix: revert back to hfhub error until 401 resolved * fix: avoid the unneeded removal of repo type * feat: enforce kwargs * fix: adjust for kwarg * fix: update test for kwarg changes
1 parent 0dff455 commit 84629d2

25 files changed

Lines changed: 536 additions & 1426 deletions

File tree

kernels/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ authors = [
88
]
99
license = { text = "Apache-2.0" }
1010
readme = "README.md"
11-
requires-python = ">= 3.9"
11+
requires-python = ">= 3.10"
1212
dependencies = [
13-
"huggingface_hub>=1.3.0,<2.0",
13+
"huggingface-hub>=1.10.0",
1414
"packaging>=20.0",
1515
"pyyaml>=6",
1616
"tomli>=2.0; python_version<'3.11'",

kernels/src/kernels/_versions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
1212
"""Get kernel versions that are available in the repository."""
1313
from kernels.utils import _get_hf_api
1414

15+
refs = _get_hf_api().list_repo_refs(repo_id=repo_id, repo_type="kernel")
16+
1517
versions = {}
16-
for branch in _get_hf_api().list_repo_refs(repo_id).branches:
18+
for branch in refs.branches:
1719
if not branch.name.startswith("v"):
1820
continue
1921
try:
@@ -33,7 +35,7 @@ def _get_available_versions_old(repo_id: str) -> dict[Version, GitRefInfo]:
3335
from kernels.utils import _get_hf_api
3436

3537
versions = {}
36-
for tag in _get_hf_api().list_repo_refs(repo_id).tags:
38+
for tag in _get_hf_api().list_repo_refs(repo_id, repo_type="kernel").tags:
3739
if not tag.name.startswith("v"):
3840
continue
3941
try:
@@ -46,13 +48,14 @@ def _get_available_versions_old(repo_id: str) -> dict[Version, GitRefInfo]:
4648

4749
def resolve_version_spec_as_ref(repo_id: str, version_spec: int | str) -> GitRefInfo:
4850
"""
49-
Get the locks for a kernel with the given version spec.
51+
Get the ref for a kernel with the given version spec.
5052
5153
The version specifier can be any valid Python version specifier:
5254
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
5355
"""
5456
if isinstance(version_spec, int):
5557
versions = _get_available_versions(repo_id)
58+
5659
ref = versions.get(version_spec, None)
5760
if ref is None:
5861
raise ValueError(

kernels/src/kernels/cli/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,14 @@ def download_kernels(args):
164164
if args.all_variants:
165165
install_kernel_all_variants(
166166
kernel_lock.repo_id,
167-
kernel_lock.sha,
167+
revision=kernel_lock.sha,
168168
variant_locks=kernel_lock.variants,
169169
)
170170
else:
171171
try:
172172
install_kernel(
173173
kernel_lock.repo_id,
174-
kernel_lock.sha,
174+
revision=kernel_lock.sha,
175175
variant_locks=kernel_lock.variants,
176176
)
177177
except FileNotFoundError as e:

kernels/src/kernels/cli/check.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,19 @@
1515

1616

1717
def check_kernel(
18-
*, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
18+
*,
19+
macos: str,
20+
manylinux: str,
21+
python_abi: str,
22+
repo_id: str,
23+
revision: str,
1924
):
2025
variants_path = (
2126
Path(
2227
str(
2328
_get_hf_api().snapshot_download(
2429
repo_id,
30+
repo_type="kernel",
2531
allow_patterns=["build/*"],
2632
cache_dir=CACHE_DIR,
2733
revision=revision,

kernels/src/kernels/cli/upload.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def _upload_build_dir(
7070
repo_id=repo_id,
7171
operations=list(chunk),
7272
revision=revision,
73-
repo_type="model",
7473
commit_message=commit_message,
7574
)
7675

kernels/src/kernels/cli/versions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88

99
def print_kernel_versions(repo_id: str):
1010
api = _get_hf_api()
11-
12-
versions = _get_available_versions(repo_id).items()
11+
versions = _get_available_versions(repo_id)
1312
if not versions:
1413
print(f"Repository does not support kernel versions: {repo_id}")
1514
return
1615

17-
for version, ref in sorted(versions, key=lambda x: x[0]):
16+
for version, ref in sorted(versions.items(), key=lambda x: x[0]):
1817
variants = get_variants(api, repo_id=repo_id, revision=ref.ref)
1918
resolved = resolve_variants(variants, None)
2019
best = resolved[0] if resolved else None

kernels/src/kernels/lockfile.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44

55
from huggingface_hub.dataclasses import strict
6+
from huggingface_hub.hf_api import RepoFile
67

78
from kernels._versions import resolve_version_spec_as_ref
89
from kernels.compat import tomllib
@@ -49,21 +50,31 @@ def get_kernel_locks(repo_id: str, version_spec: int | str) -> KernelLock:
4950

5051
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
5152

53+
revision = tag_for_newest.target_commit
54+
5255
r = api.repo_info(
53-
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
56+
repo_id=repo_id,
57+
repo_type="kernel",
58+
revision=revision,
5459
)
5560
if r.sha is None:
5661
raise ValueError(
5762
f"Cannot get commit SHA for repo {repo_id} for tag {tag_for_newest.name}"
5863
)
5964

60-
if r.siblings is None:
61-
raise ValueError(
62-
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
65+
siblings = [
66+
f
67+
for f in api.list_repo_tree(
68+
repo_id=repo_id,
69+
repo_type="kernel",
70+
revision=revision,
71+
recursive=True,
6372
)
73+
if isinstance(f, RepoFile)
74+
]
6475

6576
variant_files: dict[str, list[tuple[bytes, str]]] = {}
66-
for sibling in r.siblings:
77+
for sibling in siblings:
6778
if sibling.rfilename.startswith("build/torch"):
6879
if sibling.blob_id is None:
6980
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")

kernels/src/kernels/status.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from huggingface_hub import HfApi
66
from huggingface_hub.dataclasses import strict
7+
from huggingface_hub.errors import RepositoryNotFoundError
78
from huggingface_hub.utils import EntryNotFoundError
89

910
from kernels.compat import tomllib
@@ -54,11 +55,14 @@ def check_status(
5455
) -> KernelStatusKind | None:
5556
try:
5657
path = api.hf_hub_download(
57-
repo_id=repo_id, filename="kernel-status.toml", revision=revision
58+
repo_id=repo_id,
59+
repo_type="kernel",
60+
filename="kernel-status.toml",
61+
revision=revision,
5862
)
5963
with open(path, "r") as f:
6064
return KernelStatus.from_toml(f.read())
61-
except EntryNotFoundError:
65+
except (EntryNotFoundError, RepositoryNotFoundError):
6266
return None
6367

6468

kernels/src/kernels/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def _import_from_path(
131131

132132
def install_kernel(
133133
repo_id: str,
134+
*,
134135
revision: str,
135136
local_files_only: bool = False,
136137
backend: str | None = None,
@@ -176,10 +177,12 @@ def install_kernel(
176177
)
177178

178179
allow_patterns = [f"build/{variant.variant_str}/*"]
180+
179181
repo_path = Path(
180182
str(
181183
api.snapshot_download(
182184
repo_id,
185+
repo_type="kernel",
183186
allow_patterns=allow_patterns,
184187
cache_dir=CACHE_DIR,
185188
revision=revision,
@@ -234,15 +237,18 @@ def _find_kernel_in_repo_path(
234237

235238
def install_kernel_all_variants(
236239
repo_id: str,
240+
*,
237241
revision: str,
238242
local_files_only: bool = False,
239243
variant_locks: dict[str, VariantLock] | None = None,
240244
) -> Path:
241245
api = _get_hf_api()
246+
242247
repo_path = Path(
243248
str(
244249
api.snapshot_download(
245250
repo_id,
251+
repo_type="kernel",
246252
allow_patterns="build/*",
247253
cache_dir=CACHE_DIR,
248254
revision=revision,
@@ -318,7 +324,10 @@ def get_kernel(
318324
backend=backend,
319325
)
320326
package_name, variant_path = install_kernel(
321-
repo_id, revision=revision, backend=backend, user_agent=user_agent
327+
repo_id,
328+
revision=revision,
329+
backend=backend,
330+
user_agent=user_agent,
322331
)
323332
return _import_from_path(package_name, variant_path, _repo_infos=repo_infos)
324333

@@ -396,6 +405,7 @@ def has_kernel(
396405
for init_file in ["__init__.py", f"{package_name}/__init__.py"]:
397406
if api.file_exists(
398407
repo_id,
408+
repo_type="kernel",
399409
revision=revision,
400410
filename=f"build/{variant.variant_str}/{init_file}",
401411
):
@@ -454,6 +464,7 @@ def load_kernel(
454464
str(
455465
api.snapshot_download(
456466
repo_id,
467+
repo_type="kernel",
457468
allow_patterns=allow_patterns,
458469
cache_dir=CACHE_DIR,
459470
revision=locked_sha,
@@ -495,7 +506,7 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
495506
raise ValueError(f"Kernel `{repo_id}` is not locked")
496507

497508
package_name, variant_path = install_kernel(
498-
repo_id, locked_sha, local_files_only=local_files_only
509+
repo_id, revision=locked_sha, local_files_only=local_files_only
499510
)
500511

501512
return _import_from_path(package_name, variant_path)

kernels/src/kernels/variants.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,10 @@ def parse_variant(variant_str: str) -> Variant:
295295
def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]:
296296
"""Get all the build variants available from a kernel repository."""
297297

298-
tree = api.list_repo_tree(repo_id, path_in_repo="build", revision=revision)
298+
tree = api.list_repo_tree(
299+
repo_id, path_in_repo="build", repo_type="kernel", revision=revision
300+
)
301+
299302
variant_strs = {
300303
item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder)
301304
}
@@ -333,6 +336,7 @@ def resolve_variant(
333336
) -> Variant | None:
334337
"""Return the best matching variant for the current system."""
335338
resolved = resolve_variants(variants, backend)
339+
336340
return resolved[0] if resolved else None
337341

338342

0 commit comments

Comments
 (0)