11#!/usr/bin/env python3
22"""
3- Dispatch release workflows for a kernel.
3+ Dispatch build workflows for a kernel.
44
5- Three entrypoints call this script:
6- 1. The PR-merge dummy workflow (via CLI)
7- 2. The comment bot (via import)
8- 3. Local CLI invocation
5+ Four entrypoints call this script:
6+ 1. The PR-merge dispatch workflow (via CLI)
7+ 2. The PR-open dispatch workflow (via CLI)
8+ 3. The comment bot (via import)
9+ 4. Local CLI invocation
910"""
1011
1112import argparse
2324
2425
2526RELEASE_WORKFLOWS = [
26- "build-release .yaml" ,
27- "build-release- mac.yaml" ,
28- "build-release- windows.yaml" ,
27+ "build.yaml" ,
28+ "build-mac.yaml" ,
29+ "build-windows.yaml" ,
2930]
3031
3132KERNEL_NAME_RE = re .compile (r"^[A-Za-z0-9_-]+$" )
@@ -65,6 +66,9 @@ def run_local(
6566 workflow : str ,
6667 kernel_name : str ,
6768 * ,
69+ mode : str = "release" ,
70+ backends : str = "" ,
71+ repo_prefix : str = "kernels-community" ,
6872 skip_build : bool = False ,
6973 pr_number : str = "" ,
7074 target_branch : str = "" ,
@@ -76,6 +80,9 @@ def run_local(
7680 "--container-options" , "--privileged" ,
7781 "-W" , f".github/workflows/{ workflow } " ,
7882 "--input" , f"kernel_name={ kernel_name } " ,
83+ "--input" , f"mode={ mode } " ,
84+ "--input" , f"backends={ backends } " ,
85+ "--input" , f"repo_prefix={ repo_prefix } " ,
7986 ]
8087 if skip_build :
8188 cmd .extend (["--input" , "skip_build=true" ])
@@ -129,11 +136,19 @@ def get_repo() -> str | None:
129136
130137
131138BACKEND_TO_WORKFLOWS = {
132- "cuda" : {"build-release.yaml" , "build-release-windows.yaml" },
133- "cpu" : {"build-release.yaml" },
134- "rocm" : {"build-release.yaml" },
135- "metal" : {"build-release-mac.yaml" },
136- "xpu" : {"build-release.yaml" , "build-release-windows.yaml" },
139+ "cuda" : {"build.yaml" , "build-windows.yaml" },
140+ "cpu" : {"build.yaml" },
141+ "rocm" : {"build.yaml" },
142+ "metal" : {"build-mac.yaml" },
143+ "xpu" : {"build.yaml" , "build-windows.yaml" },
144+ }
145+
146+ # Only these kernels are known to build successfully on Windows.
147+ # Add new entries here as Windows support is validated for a kernel.
148+ WINDOWS_KERNELS = {
149+ "relu" ,
150+ "activation" ,
151+ "flash-attn2" ,
137152}
138153
139154
@@ -152,15 +167,15 @@ def read_backends(kernel_name: str) -> list[str] | None:
152167 return None
153168
154169
155- def select_workflows (kernel_name : str ) -> list [str ]:
170+ def select_workflows (kernel_name : str ) -> set [str ]:
156171 """
157- Determine which release workflows to dispatch based on the kernel's
172+ Determine which build workflows to dispatch based on the kernel's
158173 backends declared in build.toml.
159174
160175 Mapping:
161- cuda, cpu, rocm -> build-release .yaml (Linux)
162- metal -> build-release- mac.yaml (macOS)
163- xpu -> build-release- windows.yaml (Windows)
176+ cuda, cpu, rocm -> build.yaml (Linux)
177+ metal -> build-mac.yaml (macOS)
178+ cuda, xpu -> build-windows.yaml (Windows, allowlisted kernels only )
164179
165180 Falls back to all workflows if build.toml can't be read.
166181 """
@@ -177,6 +192,11 @@ def select_workflows(kernel_name: str) -> list[str]:
177192 print (f"No known backends found for { kernel_name } : { backends } , dispatching all workflows" )
178193 return set (RELEASE_WORKFLOWS )
179194
195+ # Only dispatch Windows builds for kernels known to build there.
196+ if "build-windows.yaml" in workflows and kernel_name not in WINDOWS_KERNELS :
197+ workflows .discard ("build-windows.yaml" )
198+ print (f"Skipping Windows build for { kernel_name } (not in WINDOWS_KERNELS allowlist)" )
199+
180200 return workflows
181201
182202
@@ -186,23 +206,29 @@ def dispatch_release(
186206 token : str ,
187207 repo : str ,
188208 ref : str = "main" ,
209+ mode : str = "release" ,
210+ repo_prefix : str = "kernels-community" ,
189211 dispatch_key_prefix : str = "" ,
190212 local : bool = False ,
213+ dry_run : bool = False ,
191214 skip_build : bool = False ,
192215 pr_number : str = "" ,
193216 target_branch : str = "" ,
194217 upload : bool = True ,
195218) -> ReleaseDispatchResult :
196219 """
197- Dispatch the appropriate release workflows for a kernel.
220+ Dispatch the appropriate build workflows for a kernel.
198221
199222 Args:
200223 kernel_name: Name of the kernel directory.
201224 token: GitHub API token.
202225 repo: GitHub repository in "owner/repo" format.
203226 ref: Git ref to dispatch against (default "main").
227+ mode: Build mode - "pr" for CI builds, "release" for full builds.
228+ repo_prefix: Hub org prefix for uploads (default "kernels-community").
204229 dispatch_key_prefix: Optional prefix for dispatch keys (e.g. "pr42-").
205230 local: Run locally via act instead of remote dispatch.
231+ dry_run: Print what would be dispatched without actually dispatching.
206232 skip_build: Skip build and upload steps.
207233 pr_number: Optional PR number to checkout before building.
208234 target_branch: Target branch for upload.
@@ -220,18 +246,54 @@ def dispatch_release(
220246
221247 result = ReleaseDispatchResult (kernel_name = kernel_name )
222248
249+ backends = read_backends (kernel_name ) or []
223250 workflows = select_workflows (kernel_name )
251+
252+ # Invert BACKEND_TO_WORKFLOWS so we can scope backends per workflow.
253+ workflow_to_backends : dict [str , set [str ]] = {}
254+ for backend , wfs in BACKEND_TO_WORKFLOWS .items ():
255+ for wf in wfs :
256+ workflow_to_backends .setdefault (wf , set ()).add (backend )
257+
224258 skipped_workflows = set (RELEASE_WORKFLOWS ) - workflows
225259 result .skipped = sorted (skipped_workflows )
226260
227261 api_base = f"https://api.github.com/repos/{ repo } "
228262 for workflow in workflows :
263+ # Only pass backends that this workflow can actually build.
264+ scoped = sorted (b for b in backends if b in workflow_to_backends .get (workflow , set ()))
265+ backends_csv = "," .join (scoped )
266+
229267 dispatch_key = (
230268 f"{ dispatch_key_prefix } { kernel_name } -{ workflow } -{ uuid .uuid4 ().hex [:12 ]} "
231269 )
270+ if dry_run :
271+ inputs = {
272+ "kernel_name" : kernel_name ,
273+ "dispatch_key" : dispatch_key ,
274+ "mode" : mode ,
275+ "backends" : backends_csv ,
276+ "repo_prefix" : repo_prefix ,
277+ }
278+ if skip_build :
279+ inputs ["skip_build" ] = "true"
280+ if pr_number :
281+ inputs ["pr_number" ] = pr_number
282+ if target_branch :
283+ inputs ["target_branch" ] = target_branch
284+ if not upload :
285+ inputs ["upload" ] = "false"
286+ dispatch_body = {"ref" : ref , "inputs" : inputs }
287+ print (f"\n [dry-run] { workflow } :" )
288+ print (json .dumps (dispatch_body , indent = 2 ))
289+ result .dispatched .append ((workflow , dispatch_key ))
290+ continue
232291 if local :
233292 if run_local (
234293 workflow , kernel_name ,
294+ mode = mode ,
295+ backends = backends_csv ,
296+ repo_prefix = repo_prefix ,
235297 skip_build = skip_build ,
236298 pr_number = pr_number ,
237299 target_branch = target_branch ,
@@ -245,6 +307,9 @@ def dispatch_release(
245307 inputs = {
246308 "kernel_name" : kernel_name ,
247309 "dispatch_key" : dispatch_key ,
310+ "mode" : mode ,
311+ "backends" : backends_csv ,
312+ "repo_prefix" : repo_prefix ,
248313 }
249314 if skip_build :
250315 inputs ["skip_build" ] = "true"
@@ -278,6 +343,10 @@ def main() -> int:
278343 parser .add_argument (
279344 "--ref" , default = "main" , help = "Git ref to dispatch on (default: main)"
280345 )
346+ parser .add_argument (
347+ "--mode" , default = "release" , choices = ["pr" , "release" ],
348+ help = "Build mode: pr (CI only) or release (build + upload) (default: release)" ,
349+ )
281350 parser .add_argument (
282351 "--repo" , default = None , help = "GitHub repo in owner/repo format (default: auto-detect)"
283352 )
@@ -301,22 +370,33 @@ def main() -> int:
301370 "--no-upload" , action = "store_true" ,
302371 help = "Build only, do not upload" ,
303372 )
373+ parser .add_argument (
374+ "--dry-run" , action = "store_true" ,
375+ help = "Print the dispatch payloads without actually dispatching" ,
376+ )
377+ parser .add_argument (
378+ "--repo-prefix" , default = "kernels-community" ,
379+ help = "Hub org prefix for uploads (default: kernels-community)" ,
380+ )
304381 args = parser .parse_args ()
305382
306383 common = dict (
384+ mode = args .mode ,
385+ repo_prefix = args .repo_prefix ,
386+ dry_run = args .dry_run ,
307387 skip_build = args .skip_build ,
308388 pr_number = args .pr_number ,
309389 target_branch = args .target_branch ,
310390 upload = not args .no_upload ,
311391 )
312392
313- if args .local :
393+ if args .dry_run or args . local :
314394 result = dispatch_release (
315395 args .kernel_name ,
316396 token = "" ,
317- repo = "" ,
397+ repo = args . repo or "" ,
318398 ref = args .ref ,
319- local = True ,
399+ local = args . local ,
320400 ** common ,
321401 )
322402 else :
0 commit comments