Skip to content

Commit ec3171a

Browse files
authored
[REFACTOR][TIR] Tie AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch together (#19605)
## Summary These three passes are logically a single host/device split step; having intermediaries between them obscures the model and blocks folding them into one pass. This PR moves each intermediary to the position its actual ordering constraint allows, so that `AnnotateDeviceRegions`, `SplitHostDevice`, and `LowerDeviceKernelLaunch` run consecutively in every pipeline. ## Rationale - `MergeSharedMemoryAllocations` moves **before** `AnnotateDeviceRegions` (the only legal position: `LowerDeviceKernelLaunch` requires at most one dyn-shmem allocation per kernel, so Merge cannot move past Lower). - `MakePackedAPI` moves **after** `LowerDeviceKernelLaunch` (Lower's `kCallingConv = kDeviceKernelLaunch` flag causes `MakePackedAPI` to correctly skip device kernels; the host body's lowered `tvm_call_packed` is transparent to `MakePackedAPI`'s subroutine rewriter). - `FP8StorageLegalize` / `BF16StorageLegalize` move **after** `MakePackedAPI` (their `buffer_map.size()==0` ICHECK requires `MakePackedAPI` to have cleared the map). Prereq for Phase 2: collapsing the three consecutive passes into a single `tirx.transform.SplitHostDevice` with three commented regions. ## Test plan - [x] tests/python/tirx-transform/ target-pass unit tests (25 pass) - [x] tests/python/s_tir/transform/test_merge_dynamic_shared_memory_allocations.py (5 pass) - [x] tests/python/tirx-transform/test_tir_transform_fp8_legalize.py / test_tir_transform_bf16_legalize.py (13 pass) - [x] tests/python/codegen/test_target_codegen_c_host.py / test_target_codegen_device.py (6 pass including test_subroutine_call — verifies Risk #2) - [x] pre-commit run --all-files clean - [ ] CI: lint / Windows / MacOS
1 parent e159487 commit ec3171a

6 files changed

Lines changed: 434 additions & 215 deletions

File tree

python/tvm/s_tir/backend/adreno/pipeline.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
108108
passes.append(s_tir.transform.InjectPTXLDG32())
109109
passes.extend(
110110
[
111+
s_tir.transform.MergeSharedMemoryAllocations(),
111112
tirx.transform.AnnotateDeviceRegions(),
112113
tirx.transform.SplitHostDevice(),
113-
# MergeSharedMemoryAllocations must follow SplitHostDevice.
114-
s_tir.transform.MergeSharedMemoryAllocations(),
114+
tirx.transform.LowerDeviceKernelLaunch(),
115115
tirx.transform.MakePackedAPI(),
116116
tirx.transform.FP8StorageLegalize(),
117117
tirx.transform.BF16StorageLegalize(),
118-
tirx.transform.LowerDeviceKernelLaunch(),
119118
]
120119
)
121120
mod = tvm.ir.transform.Sequential(passes)(mod)

python/tvm/s_tir/pipeline.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,13 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
108108
passes.append(s_tir.transform.InjectPTXLDG32())
109109
passes.extend(
110110
[
111+
s_tir.transform.MergeSharedMemoryAllocations(),
111112
tirx.transform.AnnotateDeviceRegions(),
112113
tirx.transform.SplitHostDevice(),
113-
# MergeSharedMemoryAllocations must follow SplitHostDevice.
114-
s_tir.transform.MergeSharedMemoryAllocations(),
114+
tirx.transform.LowerDeviceKernelLaunch(),
115115
tirx.transform.MakePackedAPI(),
116116
tirx.transform.FP8StorageLegalize(),
117117
tirx.transform.BF16StorageLegalize(),
118-
tirx.transform.LowerDeviceKernelLaunch(),
119118
]
120119
)
121120
mod = tvm.ir.transform.Sequential(passes)(mod)

python/tvm/tirx/compilation_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
5050
tirx.transform.AnnotateEntryFunc(),
5151
tirx.transform.AnnotateDeviceRegions(),
5252
tirx.transform.SplitHostDevice(),
53+
tirx.transform.LowerDeviceKernelLaunch(),
5354
tirx.transform.MakePackedAPI(),
5455
tirx.transform.FP8StorageLegalize(),
5556
tirx.transform.BF16StorageLegalize(),
56-
tirx.transform.LowerDeviceKernelLaunch(),
5757
]
5858
)
5959
mod = tvm.ir.transform.Sequential(passes)(mod)
@@ -91,10 +91,10 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
9191
tirx.transform.AnnotateEntryFunc(),
9292
tirx.transform.AnnotateDeviceRegions(),
9393
tirx.transform.SplitHostDevice(),
94+
tirx.transform.LowerDeviceKernelLaunch(),
9495
tirx.transform.MakePackedAPI(),
9596
tirx.transform.FP8StorageLegalize(),
9697
tirx.transform.BF16StorageLegalize(),
97-
tirx.transform.LowerDeviceKernelLaunch(),
9898
]
9999
)
100100
mod = tvm.ir.transform.Sequential(passes)(mod)
@@ -124,8 +124,8 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
124124
tirx.transform.AnnotateEntryFunc(),
125125
tirx.transform.AnnotateDeviceRegions(),
126126
tirx.transform.SplitHostDevice(),
127-
tirx.transform.MakePackedAPI(),
128127
tirx.transform.LowerDeviceKernelLaunch(),
128+
tirx.transform.MakePackedAPI(),
129129
]
130130
return tvm.ir.transform.Sequential(passes)(mod)
131131

0 commit comments

Comments
 (0)