From 997c2e2873f5f11c4357fd5f184e4901793b5ea3 Mon Sep 17 00:00:00 2001 From: SiriusNEO Date: Tue, 12 May 2026 17:49:55 +0800 Subject: [PATCH] [Backend] Refactor Pipeline to support different backends Introduce Pipeline abstraction in backend/pipeline.py with per-backend registration. Each backend (cuda, hip, c, llvm) now registers its own compilation pass pipeline. engine/lower.py resolves the pipeline from the target instead of hardcoding phase imports. Co-Authored-By: Claude Opus 4.6 --- tilelang/backend/__init__.py | 4 ++ tilelang/backend/cpu/__init__.py | 1 + tilelang/backend/cpu/pipeline.py | 111 ++++++++++++++++++++++++++++++ tilelang/backend/cuda/__init__.py | 1 + tilelang/backend/cuda/pipeline.py | 17 +++++ tilelang/backend/pipeline.py | 85 +++++++++++++++++++++++ tilelang/backend/rocm/__init__.py | 1 + tilelang/backend/rocm/pipeline.py | 17 +++++ tilelang/engine/lower.py | 15 ++-- 9 files changed, 244 insertions(+), 8 deletions(-) create mode 100644 tilelang/backend/__init__.py create mode 100644 tilelang/backend/cpu/__init__.py create mode 100644 tilelang/backend/cpu/pipeline.py create mode 100644 tilelang/backend/cuda/__init__.py create mode 100644 tilelang/backend/cuda/pipeline.py create mode 100644 tilelang/backend/pipeline.py create mode 100644 tilelang/backend/rocm/__init__.py create mode 100644 tilelang/backend/rocm/pipeline.py diff --git a/tilelang/backend/__init__.py b/tilelang/backend/__init__.py new file mode 100644 index 0000000000..5f12b420e0 --- /dev/null +++ b/tilelang/backend/__init__.py @@ -0,0 +1,4 @@ +# Import built-in backend packages so their pipelines register. +from . import cpu as _cpu # noqa: F401,E402 +from . import cuda as _cuda # noqa: F401,E402 +from . import rocm as _rocm # noqa: F401,E402 diff --git a/tilelang/backend/cpu/__init__.py b/tilelang/backend/cpu/__init__.py new file mode 100644 index 0000000000..d3ea8d0f7b --- /dev/null +++ b/tilelang/backend/cpu/__init__.py @@ -0,0 +1 @@ +from . 
import pipeline # noqa: F401 diff --git a/tilelang/backend/cpu/pipeline.py b/tilelang/backend/cpu/pipeline.py new file mode 100644 index 0000000000..ba07e0d3ad --- /dev/null +++ b/tilelang/backend/cpu/pipeline.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from tvm import tir, IRModule +from tvm.target import Target + +from tilelang.backend.pipeline import Pipeline, register_pipeline +from tilelang.engine.phase import PreLowerSemanticCheck + + +def cpu_lower_and_legalize(mod: IRModule, target: Target) -> IRModule: + """CPU-specific lower and legalize pipeline. + + A simplified version of the GPU pipeline that skips TMA, warp + specialization, Blackwell-2SM, pipeline planning, and other + GPU-only passes. + """ + import tilelang + + mod = tir.transform.BindTarget(target)(mod) + + if tilelang.engine.phase.should_force_let_inline(): + mod = tilelang.transform.LetInline()(mod) + mod = tilelang.transform.AddWrapperForSingleBufStore()(mod) + mod = tilelang.transform.LegalizeNegativeIndex()(mod) + # Verify parallel loop correctness + if tilelang.engine.phase.should_enable_race_check(): + mod = tilelang.transform.VerifyParallelLoop()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) + # Simplify the IR expressions + mod = tilelang.transform.Simplify()(mod) + # Set layouts for reducers + mod = tilelang.transform.LayoutReducer()(mod) + # Infer memory layouts for fragments and shared memory + mod = tilelang.transform.LayoutInference()(mod) + # Visualize the layout + tilelang.engine.phase.LayoutVisual(mod) + # Lower high-level tile operations to low-level operations + mod = tilelang.transform.LowerTileOp()(mod) + # Decouple type cast vectorization constraints before vectorization + mod = tilelang.transform.DecoupleTypeCast()(mod) + # Legalize vectorized loops to ensure they are valid + mod = tilelang.transform.LegalizeVectorizedLoop()(mod) + # Add safety checks for memory accesses + mod = 
tilelang.transform.LegalizeSafeMemoryAccess()(mod) + # Lower frontend pointer metadata op to standard tvm_access_ptr + mod = tilelang.transform.LowerAccessPtr()(mod) + # Simplify again to clean up any duplicated conditions + mod = tilelang.transform.Simplify()(mod) + # Hoist any root-block annotations to PrimFunc attrs + mod = tilelang.transform.HoistNonRestrictParams()(mod) + return mod + + +def cpu_optimize_for_target(mod: IRModule, target: Target) -> IRModule: + """CPU-specific optimize for target pipeline. + + Intended to skip GPU-only passes, but currently still runs thread sync, + shared-memory merging and fence injection -- TODO confirm intent. + """ + import tilelang + + pass_ctx = tilelang.transform.get_pass_context() + + mod = tilelang.transform.LowerSharedTmem()(mod) + mod = tilelang.transform.IfStmtBinding()(mod) + mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tilelang.transform.LowerSharedBarrier()(mod) + mod = tilelang.transform.HoistGlobalBufferAllocations()(mod) + mod = tilelang.transform.LowerOpaqueBlock()(mod) + mod = tilelang.transform.Simplify()(mod) + mod = tir.transform.NarrowDataType(32)(mod) + mod = tilelang.transform.FlattenBuffer()(mod) + mod = tilelang.transform.ConfigIndexBitwidth()(mod) + mod = tir.transform.Simplify()(mod) + mod = tilelang.transform.VectorizeLoop(enable_vectorize=tilelang.engine.phase.allow_vectorize(pass_ctx=pass_ctx))(mod) + mod = tilelang.transform.StorageRewrite()(mod) + mod = tilelang.transform.LoopUnswitching()(mod) + mod = tilelang.transform.UnrollLoop()(mod) + mod = tir.transform.RenormalizeSplitPattern()(mod) + mod = tir.transform.Simplify()(mod) + mod = tir.transform.RemoveNoOp()(mod) + mod = tir.transform.HoistIfThenElse()(mod) + mod = tir.transform.VerifyMemory()(mod) + mod = tir.transform.AnnotateEntryFunc()(mod) + mod = tir.transform.InferFragment()(mod) + mod = tilelang.transform.LowerThreadAllreduce()(mod) + mod = tilelang.transform.LowerLDGSTG()(mod) + mod = tilelang.transform.AnnotateDeviceRegions()(mod) + mod =
tilelang.transform.SplitHostDevice()(mod) + mod = tilelang.transform.MarkCudaSyncCalls(False)(mod) + mod = tilelang.transform.AnnotateReadOnlyParams()(mod) + mod = tilelang.transform.MergeSharedMemoryAllocations()(mod) + mod = tilelang.transform.InjectFenceProxy()(mod) + mod = tilelang.transform.ThreadSync("shared")(mod) + mod = tilelang.transform.InjectTcgen05Fence()(mod) + mod = tilelang.transform.MergeIfStmt()(mod) + mod = tilelang.transform.MakePackedAPI()(mod) + mod = tilelang.transform.Simplify()(mod) + mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) + return mod + + +# Register CPU pipelines for both "c" and "llvm" target kinds +for _kind in ("c", "llvm"): + register_pipeline( + Pipeline(_kind) + .set_pre_lower_semantic_check(PreLowerSemanticCheck) + .set_lower_and_legalize(cpu_lower_and_legalize) + .set_optimize_for_target(cpu_optimize_for_target) + ) diff --git a/tilelang/backend/cuda/__init__.py b/tilelang/backend/cuda/__init__.py new file mode 100644 index 0000000000..d3ea8d0f7b --- /dev/null +++ b/tilelang/backend/cuda/__init__.py @@ -0,0 +1 @@ +from . 
import pipeline # noqa: F401 diff --git a/tilelang/backend/cuda/pipeline.py b/tilelang/backend/cuda/pipeline.py new file mode 100644 index 0000000000..d50a5faf4f --- /dev/null +++ b/tilelang/backend/cuda/pipeline.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from tilelang.backend.pipeline import Pipeline, register_pipeline +from tilelang.engine.phase import ( + PreLowerSemanticCheck, + LowerAndLegalize, + OptimizeForTarget, +) + +cuda_pipeline = ( + Pipeline("cuda") + .set_pre_lower_semantic_check(PreLowerSemanticCheck) + .set_lower_and_legalize(LowerAndLegalize) + .set_optimize_for_target(OptimizeForTarget) +) + +register_pipeline(cuda_pipeline) diff --git a/tilelang/backend/pipeline.py b/tilelang/backend/pipeline.py new file mode 100644 index 0000000000..323b244947 --- /dev/null +++ b/tilelang/backend/pipeline.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from typing import Callable + +from tvm import IRModule +from tvm.target import Target + +# Phase function signatures: +# PreLowerFunc: (mod: IRModule) -> None (validation only) +# LowerFunc: (mod: IRModule, target: Target) -> IRModule +# OptimizeFunc: (mod: IRModule, target: Target) -> IRModule + +PreLowerFunc = Callable[[IRModule], None] +LowerFunc = Callable[[IRModule, Target], IRModule] +OptimizeFunc = Callable[[IRModule, Target], IRModule] + + +class Pipeline: + """Compilation pipeline for a specific backend. + + A Pipeline encapsulates three compilation phases: + 1. Pre-lower semantic check -- validate the IR before lowering + 2. Lower and legalize -- bind target, legalize frontend IR, lower tile ops + 3. Optimize for target -- target-specific optimization and codegen preparation + + Each backend registers its own Pipeline so that the compiler can + resolve the correct pass sequence from the target at runtime. 
+ """ + + def __init__(self, name: str): + self.name = name + self._pre_lower: PreLowerFunc | None = None + self._lower_and_legalize: LowerFunc | None = None + self._optimize_for_target: OptimizeFunc | None = None + + def set_pre_lower_semantic_check(self, func: PreLowerFunc) -> Pipeline: + self._pre_lower = func + return self + + def set_lower_and_legalize(self, func: LowerFunc) -> Pipeline: + self._lower_and_legalize = func + return self + + def set_optimize_for_target(self, func: OptimizeFunc) -> Pipeline: + self._optimize_for_target = func + return self + + def pre_lower_semantic_check(self, mod: IRModule) -> None: + if self._pre_lower is not None: + self._pre_lower(mod) + + def lower_and_legalize(self, mod: IRModule, target: Target) -> IRModule: + if self._lower_and_legalize is not None: + return self._lower_and_legalize(mod, target) + return mod + + def optimize_for_target(self, mod: IRModule, target: Target) -> IRModule: + if self._optimize_for_target is not None: + return self._optimize_for_target(mod, target) + return mod + + +_PIPELINES: dict[str, Pipeline] = {} + + +def register_pipeline(pipeline: Pipeline) -> Pipeline: + """Register a compilation pipeline for a backend. + + The pipeline name should match ``target.kind.name`` (e.g. ``"cuda"``, + ``"hip"``, ``"c"``, ``"llvm"``). + """ + _PIPELINES[pipeline.name] = pipeline + return pipeline + + +def get_pipeline(name: str) -> Pipeline: + """Return the registered Pipeline for *name*.""" + if name not in _PIPELINES: + raise ValueError(f"No pipeline registered for backend '{name}'. 
Available backends: {list(_PIPELINES.keys())}") + return _PIPELINES[name] + + +def resolve_pipeline(target: Target) -> Pipeline: + """Resolve the compilation pipeline from a TVM target.""" + return get_pipeline(target.kind.name) diff --git a/tilelang/backend/rocm/__init__.py b/tilelang/backend/rocm/__init__.py new file mode 100644 index 0000000000..d3ea8d0f7b --- /dev/null +++ b/tilelang/backend/rocm/__init__.py @@ -0,0 +1 @@ +from . import pipeline # noqa: F401 diff --git a/tilelang/backend/rocm/pipeline.py b/tilelang/backend/rocm/pipeline.py new file mode 100644 index 0000000000..c88664e544 --- /dev/null +++ b/tilelang/backend/rocm/pipeline.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from tilelang.backend.pipeline import Pipeline, register_pipeline +from tilelang.engine.phase import ( + PreLowerSemanticCheck, + LowerAndLegalize, + OptimizeForTarget, +) + +rocm_pipeline = ( + Pipeline("hip") + .set_pre_lower_semantic_check(PreLowerSemanticCheck) + .set_lower_and_legalize(LowerAndLegalize) + .set_optimize_for_target(OptimizeForTarget) +) + +register_pipeline(rocm_pipeline) diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 96cb8e841a..a907b49f8e 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -16,11 +16,7 @@ from tilelang.transform.metal import MarkHostMetalContext from tilelang.engine.param import KernelParam, CompiledArtifact from tilelang.utils.target import determine_target, target_get_mcpu -from tilelang.engine.phase import ( - PreLowerSemanticCheck, - LowerAndLegalize, - OptimizeForTarget, -) +from tilelang.backend.pipeline import resolve_pipeline def is_cpu_device_backend(target: Target): @@ -298,14 +294,17 @@ def lower_to_host_device_ir( _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) + # Resolve the compilation pipeline for the target backend + pipeline = resolve_pipeline(target) + # Before 
lowering, do semantic check - PreLowerSemanticCheck(mod) + pipeline.pre_lower_semantic_check(mod) # Phase 1: Lower and legalize the IR - mod = LowerAndLegalize(mod, target) + mod = pipeline.lower_and_legalize(mod, target) # Phase 2: Optimize the IR for the target - mod = OptimizeForTarget(mod, target) + mod = pipeline.optimize_for_target(mod, target) host_mod = tir.transform.Filter(_is_host_call)(mod) device_mod = tir.transform.Filter(_is_device_call)(mod)