Skip to content

Commit 334be5e

Browse files
committed
fix: simplify init logic and add test for init
1 parent e7d3188 commit 334be5e

6 files changed

Lines changed: 160 additions & 113 deletions

File tree

kernels/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"packaging>=20.0",
1717
"pyyaml>=6",
1818
"tomli>=2.0; python_version<'3.11'",
19+
"tomlkit>=0.14.0",
1920
]
2021

2122
[build-system]

kernels/src/kernels/cli.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
from kernels.compat import tomllib
88
from kernels.lockfile import KernelLock, get_kernel_locks
99
from kernels.upload import upload_kernels_dir
10-
from kernels.utils import install_kernel, install_kernel_all_variants
10+
from kernels.utils import (
11+
install_kernel,
12+
install_kernel_all_variants,
13+
KNOWN_BACKENDS,
14+
)
1115
from kernels.versions_cli import print_kernel_versions
12-
from kernels.init import run_init
16+
from kernels.init import run_init, parse_kernel_name
1317

1418
from .doc import generate_readme_for_kernel
1519

@@ -153,7 +157,7 @@ def main():
153157
)
154158
init_parser.add_argument(
155159
"kernel_name",
156-
type=str,
160+
type=parse_kernel_name,
157161
help="Name of the kernel repo (e.g., drbh/my-kernel)",
158162
)
159163
init_parser.add_argument(
@@ -165,8 +169,15 @@ def main():
165169
init_parser.add_argument(
166170
"--backends",
167171
nargs="+",
168-
default=None,
169-
help="Backends to include ('all' or list like: cpu cuda metal rocm xpu). Defaults: cuda on Linux/Windows, metal on macOS.",
172+
choices={"all"} | KNOWN_BACKENDS,
173+
default=["metal"] if sys.platform == "darwin" else ["cuda"],
174+
metavar="BACKEND",
175+
help=f"Backends to enable (all, {', '.join(KNOWN_BACKENDS)}). Defaults: cuda on Linux/Windows, metal on macOS.",
176+
)
177+
init_parser.add_argument(
178+
"--overwrite",
179+
action="store_true",
180+
help="Overwrite existing directory if it exists",
170181
)
171182
init_parser.set_defaults(func=run_init)
172183

kernels/src/kernels/init.py

Lines changed: 69 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,70 @@
1+
import argparse
12
import os
2-
import re
33
import shutil
44
import subprocess
55
import sys
66
from argparse import Namespace
77
from pathlib import Path
8+
from typing import NamedTuple
89

910
from huggingface_hub import snapshot_download
1011
from huggingface_hub.utils import disable_progress_bars
1112

12-
from kernels.compat import tomllib
13+
import tomlkit
1314
from kernels.utils import KNOWN_BACKENDS
1415

1516

16-
def run_init(args: Namespace) -> None:
17-
kernel_name = args.kernel_name
18-
if args.backends is None:
19-
backends = ["metal"] if sys.platform == "darwin" else ["cuda"]
20-
else:
21-
backends = [
22-
v.strip().lower()
23-
for item in args.backends
24-
for v in item.split(",")
25-
if v.strip()
26-
]
27-
if "all" in backends:
28-
if len(backends) > 1:
29-
print(
30-
"Error: --backends must be either 'all' or a list of backends.",
31-
file=sys.stderr,
32-
)
33-
sys.exit(1)
34-
backends = []
35-
else:
36-
valid = set(KNOWN_BACKENDS)
37-
invalid = sorted(set(backends) - valid)
38-
if invalid:
39-
print(
40-
f"Error: invalid backend(s): {', '.join(invalid)}. Valid values are: {', '.join(sorted(valid))}.",
41-
file=sys.stderr,
42-
)
43-
sys.exit(1)
44-
seen: set[str] = set()
45-
backends = [b for b in backends if not (b in seen or seen.add(b))]
46-
# must be fully qualified repo name <owner>/<repo>
47-
owner_repo = kernel_name.split("/")
48-
if len(owner_repo) != 2:
49-
print(
50-
f"Error: kernel_name must be in the format <owner>/<repo> (e.g., drbh/my-kernel)",
51-
file=sys.stderr,
52-
)
53-
sys.exit(1)
54-
owner, kernel_name = owner_repo
55-
kernel_name_normalized = kernel_name.replace("-", "_")
56-
repo_id = f"{owner}/{kernel_name}"
57-
output_dir = Path.cwd()
17+
class RepoInfo(NamedTuple):
    """Parsed ``<owner>/<repo>`` kernel identifier produced by ``parse_kernel_name``."""

    name: str  # repo part, lowercased with "-" replaced by "_"
    owner: str  # owner part, exactly as given
    repo_id: str  # "<owner>/<name>", built from the normalized name


def parse_kernel_name(value: str) -> RepoInfo:
    """Parse and normalize a ``<owner>/<repo>`` kernel name.

    Written to be usable as an argparse ``type=`` callable: invalid input is
    reported by raising ``argparse.ArgumentTypeError``.

    Args:
        value: Raw name, e.g. ``drbh/my-kernel``.

    Returns:
        A ``RepoInfo`` whose ``name`` is lowercased with hyphens replaced by
        underscores, and whose ``repo_id`` is ``"<owner>/<name>"`` using that
        normalized name.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not exactly two non-empty
            ``/``-separated parts, or if the repo part contains a backslash.
    """
    parts = value.split("/")
    if len(parts) != 2 or not all(parts):  # validate format
        raise argparse.ArgumentTypeError("must be <owner>/<repo>")
    owner, name = parts

    # After splitting on "/" into exactly two parts, the repo part cannot
    # contain "/"; the backslash is the only separator left to reject.
    if "\\" in name:
        raise argparse.ArgumentTypeError("repo name cannot contain path separators")

    name = name.lower().replace("-", "_")  # normalize name
    return RepoInfo(name=name, owner=owner, repo_id=f"{owner}/{name}")
5829

59-
# Validate kernel name
60-
if "/" in kernel_name or "\\" in kernel_name:
61-
print(
62-
f"Error: Kernel name cannot contain path separators: {kernel_name}",
63-
file=sys.stderr,
64-
)
65-
sys.exit(1)
30+
31+
def run_init(args: Namespace) -> None:
32+
kernel_name = args.kernel_name.name
33+
repo_id = args.kernel_name.repo_id
34+
backends = KNOWN_BACKENDS if "all" in args.backends else set(args.backends)
6635

6736
# Target directory
68-
target_dir = output_dir / kernel_name
37+
target_dir = Path.cwd() / kernel_name
38+
39+
if args.overwrite:
40+
if target_dir.exists():
41+
shutil.rmtree(target_dir)
42+
6943
if target_dir.exists() and any(target_dir.iterdir()):
7044
print(
7145
f"Error: Directory already exists and is not empty: {target_dir}",
7246
file=sys.stderr,
7347
)
7448
sys.exit(1)
7549

76-
# Download template from HuggingFace
77-
template_repo = args.template_repo
78-
7950
# Suppress progress bars for cleaner output (files are often cached)
8051
disable_progress_bars()
8152

82-
print(f"Downloading template from {template_repo}...", file=sys.stderr)
83-
template_dir = Path(snapshot_download(repo_id=template_repo, repo_type="model"))
84-
_init_from_local_template(
85-
template_dir, target_dir, kernel_name, kernel_name_normalized, repo_id
53+
print(f"Downloading template from {args.template_repo}...", file=sys.stderr)
54+
template_dir = Path(
55+
snapshot_download(repo_id=args.template_repo, repo_type="model")
8656
)
57+
_init_from_local_template(template_dir, target_dir, kernel_name, repo_id)
58+
8759
if backends:
8860
_update_build_backends(target_dir / "build.toml", backends)
61+
62+
# replacement logic
63+
# - rocm uses cuda source so we need to replace the rocm with cuda
64+
if "rocm" in backends:
65+
backends.remove("rocm")
66+
backends.add("cuda")
67+
8968
_remove_backend_dirs(target_dir, backends)
9069

9170
# Initialize git repo (required for Nix flakes)
@@ -94,24 +73,23 @@ def run_init(args: Namespace) -> None:
9473

9574
print(f"Initialized kernel project: {target_dir}")
9675
_print_tree(target_dir)
97-
print("\nNext steps:")
98-
print(f" cd {kernel_name}")
99-
print(" cachix use huggingface")
100-
print(" nix run -L --max-jobs 1 --cores 8 .#build-and-copy")
101-
print(" uv run example.py")
76+
print("\nNext steps:\n")
77+
print(f"cd {kernel_name}")
78+
print("cachix use huggingface")
79+
print("nix run -L --max-jobs 1 --cores 8 .#build-and-copy")
80+
print("uv run example.py")
10281

10382

10483
def _init_from_local_template(
10584
template_dir: Path,
10685
target_dir: Path,
10786
kernel_name: str,
108-
kernel_name_normalized: str,
10987
repo_id: str,
11088
) -> None:
11189
# Placeholder mappings
11290
replacements = {
11391
"__KERNEL_NAME__": kernel_name,
114-
"__KERNEL_NAME_NORMALIZED__": kernel_name_normalized,
92+
"__KERNEL_NAME_NORMALIZED__": kernel_name,
11593
"__REPO_ID__": repo_id,
11694
}
11795

@@ -177,50 +155,34 @@ def _print_tree(directory: Path, prefix: str = "") -> None:
177155
_print_tree(entry, prefix + extension)
178156

179157

180-
def _update_build_backends(build_toml_path: Path, backends: list[str]) -> None:
158+
def _update_build_backends(build_toml_path: Path, backends: set[str]) -> None:
181159
if not build_toml_path.exists():
182160
return
183-
text = build_toml_path.read_text()
161+
184162
with open(build_toml_path, "rb") as f:
185-
data = tomllib.load(f)
186-
if "general" not in data:
187-
return
188-
kernel_table = data.get("kernel", {})
189-
if not isinstance(kernel_table, dict):
190-
kernel_table = {}
191-
remove_kernels = {
192-
name
193-
for name, cfg in kernel_table.items()
194-
if isinstance(cfg, dict) and cfg.get("backend") not in set(backends)
195-
}
196-
backends_list = ", ".join(f'"{b}"' for b in backends)
197-
new_line = f"backends = [{backends_list}]"
198-
pattern = r"(\[general\][\s\S]*?)^\s*backends\s*=\s*\[[^\]]*\]"
199-
new_text, count = re.subn(pattern, r"\1" + new_line, text, count=1, flags=re.M)
200-
if remove_kernels:
201-
new_text = _remove_kernel_sections(new_text, remove_kernels)
202-
if count or remove_kernels:
203-
build_toml_path.write_text(new_text)
204-
205-
206-
def _remove_kernel_sections(text: str, remove_kernels: set[str]) -> str:
207-
lines = text.splitlines(keepends=True)
208-
output: list[str] = []
209-
skip = False
210-
for line in lines:
211-
match = re.match(r"^\s*\[kernel\.([^\]]+)\]\s*$", line)
212-
if match:
213-
skip = match.group(1).strip() in remove_kernels
214-
if skip:
215-
continue
216-
if skip and re.match(r"^\s*\[[^\]]+\]\s*$", line):
217-
skip = False
218-
if not skip:
219-
output.append(line)
220-
return "".join(output)
221-
222-
223-
def _remove_backend_dirs(target_dir: Path, backends: list[str]) -> None:
163+
build_contents = tomlkit.parse(f.read())
164+
165+
# update backends
166+
if "general" not in build_contents:
167+
return
168+
build_contents["general"]["backends"] = list(backends)
169+
170+
# update kernel sections
171+
if "kernel" in build_contents:
172+
kernel_table = build_contents["kernel"]
173+
remove_kernels = []
174+
for name, cfg in kernel_table.items():
175+
if isinstance(cfg, dict) and cfg.get("backend") not in set(backends):
176+
remove_kernels.append(name)
177+
for name in remove_kernels:
178+
del kernel_table[name]
179+
180+
# write back to file
181+
with open(build_toml_path, "wb") as f:
182+
f.write(tomlkit.dumps(build_contents).encode("utf-8"))
183+
184+
185+
def _remove_backend_dirs(target_dir: Path, backends: set[str]) -> None:
224186
keep = set(backends)
225187
known = set(KNOWN_BACKENDS)
226188
for entry in target_dir.iterdir():

kernels/src/kernels/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from kernels.metadata import Metadata
2323

2424
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
25-
KNOWN_BACKENDS = ("cpu", "cuda", "metal", "rocm", "xpu", "npu")
25+
KNOWN_BACKENDS = {"cpu", "cuda", "metal", "rocm", "xpu", "npu"}
2626

2727

2828
def _get_cache_dir() -> str | None:

kernels/tests/test_init.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import tempfile
2+
from pathlib import Path
3+
import argparse
4+
import os
5+
6+
from kernels.init import run_init, parse_kernel_name
7+
from kernels.utils import KNOWN_BACKENDS
8+
9+
def e2e_init(backends: list[str]) -> None:
    """Run ``run_init`` in a temporary working directory and check its output.

    Verifies that the project directory is created under the normalized
    kernel name and that one ``<name>_<backend>`` directory exists for each
    expected backend.

    Raises:
        AssertionError: If the project directory or an expected backend
            directory is missing after ``run_init`` returns.
    """
    expected_normalized_name = "test_kernel"
    args = argparse.Namespace(
        kernel_name=parse_kernel_name("testuser/test-kernel"),
        template_repo="drbh/template",
        backends=backends,
        overwrite=False,
    )

    # Work out which backend source trees should exist. "rocm" uses the
    # "cuda" source, so it is always swapped for "cuda".
    if "all" in args.backends:
        selected = set(KNOWN_BACKENDS)
        selected.remove("rocm")
        selected.add("cuda")
    else:
        selected = set(args.backends)
        if "rocm" in selected:
            selected.remove("rocm")
            selected.add("cuda")

    # TODO: npu is not yet supported in the template
    selected.discard("npu")

    expected_backend_dirs = {
        Path(f"{expected_normalized_name}_{backend}") for backend in selected
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        previous_cwd = Path.cwd()
        os.chdir(tmpdir)
        try:
            run_init(args)

            # the project must live under the normalized name
            target_dir = Path(tmpdir) / expected_normalized_name
            if not target_dir.exists():
                raise AssertionError(f"Target directory was not created: {target_dir}")

            # every expected backend directory must be present
            for expected_backend_dir in expected_backend_dirs:
                if (target_dir / expected_backend_dir).exists():
                    continue
                raise AssertionError(
                    f"Expected backend directory was not created: {expected_backend_dir}"
                )
        finally:
            # leave the temp dir before it is cleaned up
            os.chdir(previous_cwd)
57+
58+
59+
def test_end_to_end_init() -> None:
    """Exercise ``e2e_init`` for explicit backend lists and for ``all``."""
    scenarios = (
        ["cuda", "rocm"],
        ["metal", "cpu"],
        ["all"],
    )
    for selected_backends in scenarios:
        e2e_init(backends=selected_backends)

kernels/uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)