diff --git a/aiter/aot/flydsl/common.py b/aiter/aot/flydsl/common.py
index 61ef2f478f..b02b752647 100644
--- a/aiter/aot/flydsl/common.py
+++ b/aiter/aot/flydsl/common.py
@@ -63,6 +63,25 @@ def cu_num_to_arch(cu_num: int, default: str = "gfx950") -> str:
     return _CU_NUM_TO_ARCH.get(cu_num, default)
 
 
+def _job_arch(job: dict[str, Any]) -> str | None:
+    """Return the target arch of a job dict (any OpKind), or None if unset.
+
+    An explicit ``"arch"`` value is returned unchanged; otherwise a non-zero
+    ``cu_num`` is mapped through cu_num_to_arch(). With neither, None marks
+    an arch-agnostic untuned job that is compiled regardless of GPU_ARCHS.
+
+    GEMM and MoE tuned job dicts carry ``cu_num`` > 0 (arch derived via
+    cu_num_to_arch). CHUNK_GDN_H job dicts carry an explicit ``"arch"``
+    string and have no ``cu_num`` field.
+    """
+    if "arch" in job:
+        return job["arch"]
+    cu_num = job.get("cu_num", 0)
+    if not cu_num:
+        return None  # untuned / arch-agnostic job — always include
+    return cu_num_to_arch(cu_num)
+
+
 def job_identity(job: dict[str, Any]) -> tuple:
     return tuple(sorted(job.items()))
 
@@ -301,6 +320,25 @@ def start_aot(
         for job in _collect_aot_jobs_for(kind):
             all_jobs.append((kind, job))
 
+    gpu_archs_env = os.environ.get("GPU_ARCHS", "").strip()
+    if gpu_archs_env and gpu_archs_env.lower() != "native":
+        from aiter.jit.utils.build_targets import _parse_gpu_archs_env
+        # Keep jobs with no arch (None) and jobs whose arch was requested.
+        requested = set(_parse_gpu_archs_env(gpu_archs_env))
+        before = len(all_jobs)
+        all_jobs = [
+            (kind, job)
+            for kind, job in all_jobs
+            if (arch := _job_arch(job)) is None or arch in requested
+        ]
+        filtered = before - len(all_jobs)
+        if filtered:
+            print(
+                f"[aiter] FlyDSL AOT: GPU_ARCHS={gpu_archs_env!r} skipped"
+                f" {filtered} kernels for unrequested arches"
+                f" ({len(all_jobs)} remaining)"
+            )
+
     if not all_jobs:
         print("[aiter] FlyDSL AOT: no kernels to compile, skipping")
         return None, {}