1+ import argparse
12import os
2- import re
33import shutil
44import subprocess
55import sys
66from argparse import Namespace
77from pathlib import Path
8+ from typing import NamedTuple
89
910from huggingface_hub import snapshot_download
1011from huggingface_hub .utils import disable_progress_bars
1112
12- from kernels . compat import tomllib
13+ import tomlkit
1314from kernels .utils import KNOWN_BACKENDS
1415
1516
16- def run_init (args : Namespace ) -> None :
17- kernel_name = args .kernel_name
18- if args .backends is None :
19- backends = ["metal" ] if sys .platform == "darwin" else ["cuda" ]
20- else :
21- backends = [
22- v .strip ().lower ()
23- for item in args .backends
24- for v in item .split ("," )
25- if v .strip ()
26- ]
27- if "all" in backends :
28- if len (backends ) > 1 :
29- print (
30- "Error: --backends must be either 'all' or a list of backends." ,
31- file = sys .stderr ,
32- )
33- sys .exit (1 )
34- backends = []
35- else :
36- valid = set (KNOWN_BACKENDS )
37- invalid = sorted (set (backends ) - valid )
38- if invalid :
39- print (
40- f"Error: invalid backend(s): { ', ' .join (invalid )} . Valid values are: { ', ' .join (sorted (valid ))} ." ,
41- file = sys .stderr ,
42- )
43- sys .exit (1 )
44- seen : set [str ] = set ()
45- backends = [b for b in backends if not (b in seen or seen .add (b ))]
46- # must be fully qualified repo name <owner>/<repo>
47- owner_repo = kernel_name .split ("/" )
48- if len (owner_repo ) != 2 :
49- print (
50- f"Error: kernel_name must be in the format <owner>/<repo> (e.g., drbh/my-kernel)" ,
51- file = sys .stderr ,
52- )
53- sys .exit (1 )
54- owner , kernel_name = owner_repo
55- kernel_name_normalized = kernel_name .replace ("-" , "_" )
56- repo_id = f"{ owner } /{ kernel_name } "
57- output_dir = Path .cwd ()
17+ def parse_kernel_name (value : str ) -> NamedTuple :
18+ parts = value .split ("/" )
19+ if len (parts ) != 2 or not all (parts ): # validate format
20+ raise argparse .ArgumentTypeError ("must be <owner>/<repo>" )
21+ owner , name = parts
22+
23+ if "/" in name or "\\ " in name : # validate kernel name
24+ raise argparse .ArgumentTypeError ("repo name cannot contain path separators" )
25+
26+ name = name .lower ().replace ("-" , "_" ) # normalize name
27+ RepoInfo = NamedTuple ("RepoInfo" , [("name" , str ), ("owner" , str ), ("repo_id" , str )])
28+ return RepoInfo (name = name , owner = owner , repo_id = f"{ owner } /{ name } " )
5829
59- # Validate kernel name
60- if "/" in kernel_name or "\\ " in kernel_name :
61- print (
62- f"Error: Kernel name cannot contain path separators: { kernel_name } " ,
63- file = sys .stderr ,
64- )
65- sys .exit (1 )
30+
31+ def run_init (args : Namespace ) -> None :
32+ kernel_name = args .kernel_name .name
33+ repo_id = args .kernel_name .repo_id
34+ backends = KNOWN_BACKENDS if "all" in args .backends else set (args .backends )
6635
6736 # Target directory
68- target_dir = output_dir / kernel_name
37+ target_dir = Path .cwd () / kernel_name
38+
39+ if args .overwrite :
40+ if target_dir .exists ():
41+ shutil .rmtree (target_dir )
42+
6943 if target_dir .exists () and any (target_dir .iterdir ()):
7044 print (
7145 f"Error: Directory already exists and is not empty: { target_dir } " ,
7246 file = sys .stderr ,
7347 )
7448 sys .exit (1 )
7549
76- # Download template from HuggingFace
77- template_repo = args .template_repo
78-
7950 # Suppress progress bars for cleaner output (files are often cached)
8051 disable_progress_bars ()
8152
82- print (f"Downloading template from { template_repo } ..." , file = sys .stderr )
83- template_dir = Path (snapshot_download (repo_id = template_repo , repo_type = "model" ))
84- _init_from_local_template (
85- template_dir , target_dir , kernel_name , kernel_name_normalized , repo_id
53+ print (f"Downloading template from { args .template_repo } ..." , file = sys .stderr )
54+ template_dir = Path (
55+ snapshot_download (repo_id = args .template_repo , repo_type = "model" )
8656 )
57+ _init_from_local_template (template_dir , target_dir , kernel_name , repo_id )
58+
8759 if backends :
8860 _update_build_backends (target_dir / "build.toml" , backends )
61+
62+ # replacement logic
63+ # - rocm uses cuda source so we need to replace the rocm with cuda
64+ if "rocm" in backends :
65+ backends .remove ("rocm" )
66+ backends .add ("cuda" )
67+
8968 _remove_backend_dirs (target_dir , backends )
9069
9170 # Initialize git repo (required for Nix flakes)
@@ -94,24 +73,23 @@ def run_init(args: Namespace) -> None:
9473
9574 print (f"Initialized kernel project: { target_dir } " )
9675 _print_tree (target_dir )
97- print ("\n Next steps:" )
98- print (f" cd { kernel_name } " )
99- print (" cachix use huggingface" )
100- print (" nix run -L --max-jobs 1 --cores 8 .#build-and-copy" )
101- print (" uv run example.py" )
76+ print ("\n Next steps:\n " )
77+ print (f"cd { kernel_name } " )
78+ print ("cachix use huggingface" )
79+ print ("nix run -L --max-jobs 1 --cores 8 .#build-and-copy" )
80+ print ("uv run example.py" )
10281
10382
10483def _init_from_local_template (
10584 template_dir : Path ,
10685 target_dir : Path ,
10786 kernel_name : str ,
108- kernel_name_normalized : str ,
10987 repo_id : str ,
11088) -> None :
11189 # Placeholder mappings
11290 replacements = {
11391 "__KERNEL_NAME__" : kernel_name ,
114- "__KERNEL_NAME_NORMALIZED__" : kernel_name_normalized ,
92+ "__KERNEL_NAME_NORMALIZED__" : kernel_name ,
11593 "__REPO_ID__" : repo_id ,
11694 }
11795
@@ -177,50 +155,34 @@ def _print_tree(directory: Path, prefix: str = "") -> None:
177155 _print_tree (entry , prefix + extension )
178156
179157
180- def _update_build_backends (build_toml_path : Path , backends : list [str ]) -> None :
158+ def _update_build_backends (build_toml_path : Path , backends : set [str ]) -> None :
181159 if not build_toml_path .exists ():
182160 return
183- text = build_toml_path . read_text ()
161+
184162 with open (build_toml_path , "rb" ) as f :
185- data = tomllib .load (f )
186- if "general" not in data :
187- return
188- kernel_table = data .get ("kernel" , {})
189- if not isinstance (kernel_table , dict ):
190- kernel_table = {}
191- remove_kernels = {
192- name
193- for name , cfg in kernel_table .items ()
194- if isinstance (cfg , dict ) and cfg .get ("backend" ) not in set (backends )
195- }
196- backends_list = ", " .join (f'"{ b } "' for b in backends )
197- new_line = f"backends = [{ backends_list } ]"
198- pattern = r"(\[general\][\s\S]*?)^\s*backends\s*=\s*\[[^\]]*\]"
199- new_text , count = re .subn (pattern , r"\1" + new_line , text , count = 1 , flags = re .M )
200- if remove_kernels :
201- new_text = _remove_kernel_sections (new_text , remove_kernels )
202- if count or remove_kernels :
203- build_toml_path .write_text (new_text )
204-
205-
206- def _remove_kernel_sections (text : str , remove_kernels : set [str ]) -> str :
207- lines = text .splitlines (keepends = True )
208- output : list [str ] = []
209- skip = False
210- for line in lines :
211- match = re .match (r"^\s*\[kernel\.([^\]]+)\]\s*$" , line )
212- if match :
213- skip = match .group (1 ).strip () in remove_kernels
214- if skip :
215- continue
216- if skip and re .match (r"^\s*\[[^\]]+\]\s*$" , line ):
217- skip = False
218- if not skip :
219- output .append (line )
220- return "" .join (output )
221-
222-
223- def _remove_backend_dirs (target_dir : Path , backends : list [str ]) -> None :
163+ build_contents = tomlkit .parse (f .read ())
164+
165+ # update backends
166+ if "general" not in build_contents :
167+ return
168+ build_contents ["general" ]["backends" ] = list (backends )
169+
170+ # update kernel sections
171+ if "kernel" in build_contents :
172+ kernel_table = build_contents ["kernel" ]
173+ remove_kernels = []
174+ for name , cfg in kernel_table .items ():
175+ if isinstance (cfg , dict ) and cfg .get ("backend" ) not in set (backends ):
176+ remove_kernels .append (name )
177+ for name in remove_kernels :
178+ del kernel_table [name ]
179+
180+ # write back to file
181+ with open (build_toml_path , "wb" ) as f :
182+ f .write (tomlkit .dumps (build_contents ).encode ("utf-8" ))
183+
184+
185+ def _remove_backend_dirs (target_dir : Path , backends : set [str ]) -> None :
224186 keep = set (backends )
225187 known = set (KNOWN_BACKENDS )
226188 for entry in target_dir .iterdir ():
0 commit comments