Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tilelang/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Import built-in backend packages so their pipelines register.
from . import cpu as _cpu # noqa: F401,E402
from . import cuda as _cuda # noqa: F401,E402
from . import rocm as _rocm # noqa: F401,E402
1 change: 1 addition & 0 deletions tilelang/backend/cpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import pipeline # noqa: F401
111 changes: 111 additions & 0 deletions tilelang/backend/cpu/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations

from tvm import tir, IRModule
from tvm.target import Target

from tilelang.backend.pipeline import Pipeline, register_pipeline
from tilelang.engine.phase import PreLowerSemanticCheck


def cpu_lower_and_legalize(mod: IRModule, target: Target) -> IRModule:
"""CPU-specific lower and legalize pipeline.

A simplified version of the GPU pipeline that skips TMA, warp
specialization, Blackwell-2SM, pipeline planning, and other
GPU-only passes.
"""
import tilelang

mod = tir.transform.BindTarget(target)(mod)

if tilelang.engine.phase.should_force_let_inline():
mod = tilelang.transform.LetInline()(mod)
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
mod = tilelang.transform.LegalizeNegativeIndex()(mod)
# Verify parallel loop correctness
if tilelang.engine.phase.should_enable_race_check():
mod = tilelang.transform.VerifyParallelLoop()(mod)
# Inject assumes to speedup tvm prover
mod = tilelang.transform.InjectAssumes()(mod)
# Simplify the IR expressions
mod = tilelang.transform.Simplify()(mod)
# Set layouts for reducers
mod = tilelang.transform.LayoutReducer()(mod)
# Infer memory layouts for fragments and shared memory
mod = tilelang.transform.LayoutInference()(mod)
# Visualize the layout
tilelang.engine.phase.LayoutVisual(mod)
# Lower high-level tile operations to low-level operations
mod = tilelang.transform.LowerTileOp()(mod)
# Decouple type cast vectorization constraints before vectorization
mod = tilelang.transform.DecoupleTypeCast()(mod)
# Legalize vectorized loops to ensure they are valid
mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses
mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
# Lower frontend pointer metadata op to standard tvm_access_ptr
mod = tilelang.transform.LowerAccessPtr()(mod)
# Simplify again to clean up any duplicated conditions
mod = tilelang.transform.Simplify()(mod)
# Hoist any root-block annotations to PrimFunc attrs
mod = tilelang.transform.HoistNonRestrictParams()(mod)
return mod


def cpu_optimize_for_target(mod: IRModule, target: Target) -> IRModule:
"""CPU-specific optimize for target pipeline.

Skips GPU-only passes (TMA, thread sync, shared memory merging,
Hopper intrinsics, etc.).
"""
import tilelang

pass_ctx = tilelang.transform.get_pass_context()

mod = tilelang.transform.LowerSharedTmem()(mod)
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.LowerSharedBarrier()(mod)
mod = tilelang.transform.HoistGlobalBufferAllocations()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=tilelang.engine.phase.allow_vectorize(pass_ctx=pass_ctx))(mod)
mod = tilelang.transform.StorageRewrite()(mod)
mod = tilelang.transform.LoopUnswitching()(mod)
mod = tilelang.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
mod = tir.transform.InferFragment()(mod)
mod = tilelang.transform.LowerThreadAllreduce()(mod)
mod = tilelang.transform.LowerLDGSTG()(mod)
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tilelang.transform.SplitHostDevice()(mod)
mod = tilelang.transform.MarkCudaSyncCalls(False)(mod)
mod = tilelang.transform.AnnotateReadOnlyParams()(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.InjectTcgen05Fence()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
return mod


# Register CPU pipelines for both "c" and "llvm" target kinds
for _kind in ("c", "llvm"):
register_pipeline(
Pipeline(_kind)
.set_pre_lower_semantic_check(PreLowerSemanticCheck)
.set_lower_and_legalize(cpu_lower_and_legalize)
.set_optimize_for_target(cpu_optimize_for_target)
)
1 change: 1 addition & 0 deletions tilelang/backend/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import pipeline # noqa: F401
17 changes: 17 additions & 0 deletions tilelang/backend/cuda/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from tilelang.backend.pipeline import Pipeline, register_pipeline
from tilelang.engine.phase import (
PreLowerSemanticCheck,
LowerAndLegalize,
OptimizeForTarget,
)

cuda_pipeline = (
Pipeline("cuda")
.set_pre_lower_semantic_check(PreLowerSemanticCheck)
.set_lower_and_legalize(LowerAndLegalize)
.set_optimize_for_target(OptimizeForTarget)
)

register_pipeline(cuda_pipeline)
85 changes: 85 additions & 0 deletions tilelang/backend/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

from typing import Callable

from tvm import IRModule
from tvm.target import Target

# Phase function signatures:
# PreLowerFunc: (mod: IRModule) -> None (validation only)
# LowerFunc: (mod: IRModule, target: Target) -> IRModule
# OptimizeFunc: (mod: IRModule, target: Target) -> IRModule

PreLowerFunc = Callable[[IRModule], None]
LowerFunc = Callable[[IRModule, Target], IRModule]
OptimizeFunc = Callable[[IRModule, Target], IRModule]


class Pipeline:
"""Compilation pipeline for a specific backend.

A Pipeline encapsulates three compilation phases:
1. Pre-lower semantic check -- validate the IR before lowering
2. Lower and legalize -- bind target, legalize frontend IR, lower tile ops
3. Optimize for target -- target-specific optimization and codegen preparation

Each backend registers its own Pipeline so that the compiler can
resolve the correct pass sequence from the target at runtime.
"""

def __init__(self, name: str):
self.name = name
self._pre_lower: PreLowerFunc | None = None
self._lower_and_legalize: LowerFunc | None = None
self._optimize_for_target: OptimizeFunc | None = None

def set_pre_lower_semantic_check(self, func: PreLowerFunc) -> Pipeline:
self._pre_lower = func
return self

def set_lower_and_legalize(self, func: LowerFunc) -> Pipeline:
self._lower_and_legalize = func
return self

def set_optimize_for_target(self, func: OptimizeFunc) -> Pipeline:
self._optimize_for_target = func
return self

def pre_lower_semantic_check(self, mod: IRModule) -> None:
if self._pre_lower is not None:
self._pre_lower(mod)

def lower_and_legalize(self, mod: IRModule, target: Target) -> IRModule:
if self._lower_and_legalize is not None:
return self._lower_and_legalize(mod, target)
return mod

def optimize_for_target(self, mod: IRModule, target: Target) -> IRModule:
if self._optimize_for_target is not None:
return self._optimize_for_target(mod, target)
return mod


_PIPELINES: dict[str, Pipeline] = {}


def register_pipeline(pipeline: Pipeline) -> Pipeline:
"""Register a compilation pipeline for a backend.

The pipeline name should match ``target.kind.name`` (e.g. ``"cuda"``,
``"hip"``, ``"c"``, ``"llvm"``).
"""
_PIPELINES[pipeline.name] = pipeline
return pipeline


def get_pipeline(name: str) -> Pipeline:
"""Return the registered Pipeline for *name*."""
if name not in _PIPELINES:
raise ValueError(f"No pipeline registered for backend '{name}'. Available backends: {list(_PIPELINES.keys())}")
return _PIPELINES[name]


def resolve_pipeline(target: Target) -> Pipeline:
"""Resolve the compilation pipeline from a TVM target."""
return get_pipeline(target.kind.name)
1 change: 1 addition & 0 deletions tilelang/backend/rocm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import pipeline # noqa: F401
17 changes: 17 additions & 0 deletions tilelang/backend/rocm/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from tilelang.backend.pipeline import Pipeline, register_pipeline
from tilelang.engine.phase import (
PreLowerSemanticCheck,
LowerAndLegalize,
OptimizeForTarget,
)

rocm_pipeline = (
Pipeline("hip")
.set_pre_lower_semantic_check(PreLowerSemanticCheck)
.set_lower_and_legalize(LowerAndLegalize)
.set_optimize_for_target(OptimizeForTarget)
)

register_pipeline(rocm_pipeline)
15 changes: 7 additions & 8 deletions tilelang/engine/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
from tilelang.transform.metal import MarkHostMetalContext
from tilelang.engine.param import KernelParam, CompiledArtifact
from tilelang.utils.target import determine_target, target_get_mcpu
from tilelang.engine.phase import (
PreLowerSemanticCheck,
LowerAndLegalize,
OptimizeForTarget,
)
from tilelang.backend.pipeline import resolve_pipeline


def is_cpu_device_backend(target: Target):
Expand Down Expand Up @@ -298,14 +294,17 @@ def lower_to_host_device_ir(
_is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
_is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))

# Resolve the compilation pipeline for the target backend
pipeline = resolve_pipeline(target)

# Before lowering, do semantic check
PreLowerSemanticCheck(mod)
pipeline.pre_lower_semantic_check(mod)

# Phase 1: Lower and legalize the IR
mod = LowerAndLegalize(mod, target)
mod = pipeline.lower_and_legalize(mod, target)

# Phase 2: Optimize the IR for the target
mod = OptimizeForTarget(mod, target)
mod = pipeline.optimize_for_target(mod, target)

host_mod = tir.transform.Filter(_is_host_call)(mod)
device_mod = tir.transform.Filter(_is_device_call)(mod)
Expand Down