Skip to content
38 changes: 38 additions & 0 deletions aiter/aot/flydsl/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ def cu_num_to_arch(cu_num: int, default: str = "gfx950") -> str:
return _CU_NUM_TO_ARCH.get(cu_num, default)


def _job_arch(job: dict[str, Any]) -> str | None:
"""Extract the target arch string from a job dict, regardless of OpKind.

Returns None when the job carries no arch signal (cu_num absent or 0,
no explicit "arch" key) — these are arch-agnostic untuned jobs that
must be compiled regardless of GPU_ARCHS.

GEMM and MoE tuned job dicts carry ``cu_num`` > 0 (arch derived via
cu_num_to_arch). CHUNK_GDN_H job dicts carry an explicit ``"arch"``
string and have no ``cu_num`` field.
"""
if "arch" in job:
return job["arch"]
cu_num = job.get("cu_num", 0)
if not cu_num:
return None # untuned / arch-agnostic job — always include
return cu_num_to_arch(cu_num)


def job_identity(job: dict[str, Any]) -> tuple:
    """Return the job's ``(key, value)`` pairs as a sorted tuple.

    Sorting makes the result independent of the dict's insertion order, so
    two jobs with the same contents produce equal tuples.
    """
    pairs = list(job.items())
    pairs.sort()
    return tuple(pairs)

Expand Down Expand Up @@ -301,6 +320,25 @@ def start_aot(
for job in _collect_aot_jobs_for(kind):
all_jobs.append((kind, job))

gpu_archs_env = os.environ.get("GPU_ARCHS", "").strip()
if gpu_archs_env and gpu_archs_env.lower() != "native":
from aiter.jit.utils.build_targets import _parse_gpu_archs_env

requested = set(_parse_gpu_archs_env(gpu_archs_env))
before = len(all_jobs)
all_jobs = [
(kind, job)
for kind, job in all_jobs
if (arch := _job_arch(job)) is None or arch in requested
]
filtered = before - len(all_jobs)
if filtered:
print(
f"[aiter] FlyDSL AOT: GPU_ARCHS={gpu_archs_env!r} skipped"
f" {filtered} kernels for unrequested arches"
f" ({len(all_jobs)} remaining)"
)

Comment on lines +323 to +341
if not all_jobs:
print("[aiter] FlyDSL AOT: no kernels to compile, skipping")
return None, {}
Expand Down
Loading