Skip to content

Commit 0a9b651

Browse files
authored
[Backend] Refactor Transform Pipeline to support different backends (#2189)
* [Backend] Refactor Pipeline to support different backends

  Introduce a Pipeline abstraction in backend/pipeline.py with per-backend registration. Each backend (cuda, hip, c, llvm) now registers its own compilation pass pipeline. engine/lower.py resolves the pipeline from the target instead of hardcoding phase imports.
* enhance the structure
* fuse two stages
* fix after rebase
* update template pass pipeline
* pipeline
* fix test
* fix test
* fix cuda pipeline prologue
1 parent 362d9e0 commit 0a9b651

20 files changed

Lines changed: 697 additions & 29 deletions

testing/python/issue/test_tilelang_issue_2123.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tilelang import tvm
55
from tvm import tirx
66
from tvm.tirx import op
7-
from tilelang.engine.phase import LowerAndLegalize
7+
from tilelang.backend.cuda.pipeline import CUDAPassPipelineBodyPrologue
88
from tilelang.transform import LowerAccessPtr
99

1010

@@ -65,7 +65,7 @@ def test_issue_2123_atomic_load_lower_access_ptr_pipeline():
6565
func = issue_2123_atomic_load_repro(4).with_attr("global_symbol", "main")
6666
mod = tvm.IRModule.from_expr(func)
6767

68-
lowered = LowerAndLegalize(mod, target)
68+
lowered = CUDAPassPipelineBodyPrologue(mod, target)
6969

7070
_assert_access_ptr_lowered(lowered)
7171

testing/python/transform/test_tilelang_transform_inject_tcgen05_fence.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from tilelang import tvm as tvm
33
import tilelang as tl
44
import tilelang.language as T
5-
from tilelang.engine.phase import LowerAndLegalize
5+
from tilelang.backend.cuda.pipeline import CUDAPassPipelineBodyPrologue
66
from tvm import tirx
77

88

@@ -118,7 +118,7 @@ def func(X: T.Tensor((256, 256), T.float16), Y: T.Tensor((256, 256), T.float16))
118118

119119
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
120120
with sm100_target:
121-
mod = LowerAndLegalize(mod, sm100_target)
121+
mod = CUDAPassPipelineBodyPrologue(mod, sm100_target)
122122
mod = tl.transform.LowerSharedTmem()(mod)
123123

124124
body = mod["main"].body
@@ -166,7 +166,7 @@ def func(X: T.Tensor((256, 256), T.bfloat16)):
166166

167167
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
168168
with sm100_target:
169-
mod = LowerAndLegalize(mod, sm100_target)
169+
mod = CUDAPassPipelineBodyPrologue(mod, sm100_target)
170170
mod = tl.transform.LowerSharedTmem()(mod)
171171

172172
body = mod["main"].body

testing/python/transform/test_tilelang_transform_lexical_alloc_scope.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tilelang as tl
1313
import tilelang.language as T
1414
from tilelang import tvm
15-
from tilelang.engine.phase import LowerAndLegalize
15+
from tilelang.backend.cuda.pipeline import CUDAPassPipelineBodyPrologue
1616
from tvm.tirx.stmt_functor import post_order_visit
1717
import tilelang.testing
1818

@@ -51,7 +51,7 @@ def _apply_lower_opaque_pipeline(func, target, pass_configs=None):
5151
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
5252
pass_configs = pass_configs or {}
5353
with target, tvm.transform.PassContext(config=pass_configs):
54-
mod = LowerAndLegalize(mod, target)
54+
mod = CUDAPassPipelineBodyPrologue(mod, target)
5555
mod = tl.transform.LowerSharedTmem()(mod)
5656
mod = tl.transform.IfStmtBinding()(mod)
5757
mod = tl.transform.PlanAndUpdateBufferAllocationLocation()(mod)

testing/python/transform/test_tilelang_transform_lower_shared_barrier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tilelang.utils.target import determine_target
55
import tilelang.language as T
66
import tilelang.testing
7-
from tilelang.engine.phase import LowerAndLegalize
7+
from tilelang.backend.cuda.pipeline import CUDAPassPipelineBodyPrologue
88
from tvm import tirx
99

1010
auto_target = tvm.target.Target(determine_target("auto"))
@@ -158,7 +158,7 @@ def func(
158158
target = tvm.target.Target({"kind": "cuda", "arch": "sm_100"})
159159
with tvm.transform.PassContext(config=pass_configs), target:
160160
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
161-
mod = LowerAndLegalize(mod, target)
161+
mod = CUDAPassPipelineBodyPrologue(mod, target)
162162
mod = tl.transform.LowerSharedTmem()(mod)
163163
mod = tl.transform.IfStmtBinding()(mod)
164164
mod = tl.transform.PlanAndUpdateBufferAllocationLocation()(mod)

testing/python/transform/test_tilelang_transform_plan_update_buffer_allocation_location.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
import tilelang.language as T
33
import tilelang.testing
44
from tilelang import tvm
5-
from tilelang.engine.phase import LowerAndLegalize
5+
from tilelang.backend.cuda.pipeline import CUDAPassPipelineBodyPrologue
66

77

88
def _apply_plan_update(func: tvm.tirx.PrimFunc) -> tvm.IRModule:
99
target = tvm.target.Target("cuda")
1010
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
1111
with target:
12-
mod = LowerAndLegalize(mod, target)
12+
mod = CUDAPassPipelineBodyPrologue(mod, target)
1313
mod = tl.transform.LowerSharedTmem()(mod)
1414
mod = tl.transform.IfStmtBinding()(mod)
1515
mod = tl.transform.PlanAndUpdateBufferAllocationLocation()(mod)

tilelang/backend/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
1-
# Backend packages for Metal (other backends live in cpu/, cuda/, rocm/).
1+
# Import built-in backend packages so their pipelines register.
2+
from . import cpu as _cpu # noqa: F401,E402
3+
from . import common as _common # noqa: F401,E402
4+
from . import cuda as _cuda # noqa: F401,E402
5+
from . import metal as _metal # noqa: F401,E402
6+
from . import rocm as _rocm # noqa: F401,E402

tilelang/backend/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from __future__ import annotations
2+
3+
from tilelang.backend.pipeline import Pipeline, register_pipeline
4+
from tilelang.backend.cpu.pipeline import CPUPassPipelineBody
5+
6+
7+
register_pipeline(Pipeline("webgpu", CPUPassPipelineBody))

tilelang/backend/cpu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import pipeline # noqa: F401

tilelang/backend/cpu/pipeline.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
from tvm import IRModule, s_tir, tirx
4+
from tvm.target import Target
5+
6+
import tilelang
7+
from tilelang.backend.pipeline import Pipeline, register_pipeline
8+
from tilelang.backend.pipeline_utils import (
9+
LayoutVisual,
10+
allow_vectorize,
11+
should_disable_shared_memory_reuse,
12+
should_enable_aggressive_merge,
13+
should_enable_race_check,
14+
should_force_let_inline,
15+
)
16+
17+
18+
def CPUPassPipelineBody(mod: IRModule, target: Target) -> IRModule:
19+
mod = tirx.transform.BindTarget(target)(mod)
20+
pass_ctx = tilelang.transform.get_pass_context()
21+
22+
if should_force_let_inline():
23+
mod = tilelang.transform.LetInline()(mod)
24+
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
25+
mod = tilelang.transform.LegalizeNegativeIndex()(mod)
26+
if should_enable_race_check():
27+
mod = tilelang.transform.VerifyParallelLoop()(mod)
28+
mod = tilelang.transform.InjectAssumes()(mod)
29+
mod = tilelang.transform.Simplify()(mod)
30+
mod = tilelang.transform.LayoutReducer()(mod)
31+
32+
mod = tilelang.transform.IfStmtBinding()(mod)
33+
mod = tilelang.transform.PipelinePlanning()(mod)
34+
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
35+
mod = tilelang.transform.Simplify()(mod)
36+
37+
mod = tilelang.transform.LayoutInference()(mod)
38+
LayoutVisual(mod)
39+
mod = tilelang.transform.LowerTileOp()(mod)
40+
41+
mod = tilelang.transform.DecoupleTypeCast()(mod)
42+
mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
43+
mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
44+
mod = tilelang.transform.LowerAccessPtr()(mod)
45+
mod = tilelang.transform.Simplify()(mod)
46+
mod = tilelang.transform.HoistNonRestrictParams()(mod)
47+
48+
mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod)
49+
mod = tilelang.transform.HoistGlobalBufferAllocations()(mod)
50+
mod = tilelang.transform.LowerOpaqueBlock()(mod)
51+
mod = tilelang.transform.Simplify()(mod)
52+
mod = tirx.transform.NarrowDataType(32)(mod)
53+
mod = tilelang.transform.FlattenBuffer()(mod)
54+
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
55+
mod = tirx.transform.Simplify()(mod)
56+
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
57+
mod = tilelang.transform.StorageRewrite()(mod)
58+
mod = tilelang.transform.LoopUnswitching()(mod)
59+
mod = tilelang.transform.UnrollLoop()(mod)
60+
mod = s_tir.transform.RenormalizeSplitPattern()(mod)
61+
mod = tirx.transform.Simplify()(mod)
62+
mod = tirx.transform.RemoveNoOp()(mod)
63+
mod = s_tir.transform.HoistIfThenElse()(mod)
64+
65+
mod = tirx.transform.VerifyMemory()(mod)
66+
mod = tirx.transform.AnnotateEntryFunc()(mod)
67+
mod = s_tir.transform.InferFragment()(mod)
68+
mod = tilelang.transform.LowerThreadAllreduce()(mod)
69+
70+
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
71+
mod = tilelang.transform.SplitHostDevice()(mod)
72+
mod = tilelang.transform.AnnotateReadOnlyParams()(mod)
73+
74+
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
75+
disable_reuse = should_disable_shared_memory_reuse(pass_ctx=pass_ctx)
76+
mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge, disable_reuse=disable_reuse)(mod)
77+
78+
mod = tilelang.transform.ThreadSync("shared")(mod)
79+
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
80+
mod = tilelang.transform.MergeIfStmt()(mod)
81+
mod = tilelang.transform.MakePackedAPI()(mod)
82+
mod = tilelang.transform.Simplify()(mod)
83+
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
84+
return mod
85+
86+
87+
for _kind in ("c", "llvm"):
88+
register_pipeline(Pipeline(_kind, CPUPassPipelineBody))

tilelang/backend/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import pipeline # noqa: F401

0 commit comments

Comments
 (0)