-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Expand file tree
/
Copy pathbind_parallel_loops_to_threads.cc
More file actions
161 lines (139 loc) · 5.95 KB
/
bind_parallel_loops_to_threads.cc
File metadata and controls
161 lines (139 loc) · 5.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file bind_parallel_loops_to_threads.cc
* \brief Convert ForKind::kParallel loops to GPU thread bindings.
*
* Semantics:
* - Only runs when the PrimFunc carries a `tvm::attr::kTarget` that refers to a GPU device.
* Functions without a target attribute are left unchanged (no ambient `Target::Current` guess).
* - The outermost `kParallel` loop in the function is rewritten to `blockIdx.x` / `threadIdx.x`
* `thread_extent` scopes, with a guard `if (global_idx < extent)` and no else-branch.
* - Nested `kParallel` loops (parallel inside parallel) are rejected: binding only the outer
* parallel nest would leave inner `kParallel` serial within the mapped kernel, which is
* almost never what users intend.
* - A `kParallel` that appears inside an existing thread environment (`thread_extent` /
* `virtual_thread`) is left unchanged so it does not introduce conflicting thread bindings.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/stmt.h>
#include <tvm/target/target.h>
#include <tvm/tirx/op.h>
#include <tvm/tirx/stmt.h>
#include <tvm/tirx/stmt_functor.h>
#include <tvm/tirx/transform.h>
namespace tvm {
namespace tirx {
namespace {
/*!
 * \brief Report whether a DLPack device type is one of the GPU backends this pass
 *        binds threads for.
 * \param dev_type Device type code, as returned by Target::GetTargetDeviceType().
 * \return true for CUDA, ROCm, OpenCL, Vulkan, Metal and WebGPU; false for anything else.
 */
static bool IsGpuDeviceType(int dev_type) {
  switch (dev_type) {
    case kDLCUDA:
    case kDLROCM:
    case kDLOpenCL:
    case kDLVulkan:
    case kDLMetal:
    case kDLWebGPU:
      return true;
    default:
      return false;
  }
}
class ParallelLoopToThreadBindingMutator : public StmtExprMutator {
public:
explicit ParallelLoopToThreadBindingMutator(int64_t max_threads_per_block)
: max_threads_per_block_(max_threads_per_block) {}
private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) {
bool prev = in_thread_env_;
in_thread_env_ = true;
Stmt ret = StmtExprMutator::VisitStmt_(op);
in_thread_env_ = prev;
return ret;
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt TransformParallelFor(const ForNode* for_node) {
if (in_thread_env_) {
return ffi::GetRef<Stmt>(for_node);
}
DataType dtype = for_node->loop_var.dtype();
PrimExpr min = cast(dtype, for_node->min);
PrimExpr extent = cast(dtype, for_node->extent);
PrimExpr max_threads = IntImm(dtype, max_threads_per_block_);
PrimExpr num_blocks = ceildiv(extent, max_threads);
Var tx_var("threadIdx.x", dtype);
Var bx_var("blockIdx.x", dtype);
IterVar tx_iter(Range::FromMinExtent(IntImm(dtype, 0), max_threads), tx_var,
IterVarType::kThreadIndex, "threadIdx.x");
IterVar bx_iter(Range::FromMinExtent(IntImm(dtype, 0), num_blocks), bx_var,
IterVarType::kThreadIndex, "blockIdx.x");
PrimExpr global_idx = cast(dtype, bx_var * max_threads + tx_var);
PrimExpr mapped_idx = cast(dtype, min + global_idx);
Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), mapped_idx}});
mapped_body = IfThenElse(global_idx < extent, mapped_body);
Stmt body_with_tx = AttrStmt(tx_iter, tirx::attr::thread_extent, max_threads, mapped_body);
Stmt body_with_bx = AttrStmt(bx_iter, tirx::attr::thread_extent, num_blocks, body_with_tx);
return body_with_bx;
}
Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kThreadBinding) {
bool prev = in_thread_env_;
in_thread_env_ = true;
Stmt ret = StmtExprMutator::VisitStmt_(op);
in_thread_env_ = prev;
return ret;
}
if (op->kind != ForKind::kParallel) {
return StmtExprMutator::VisitStmt_(op);
}
if (in_parallel_loop_) {
TVM_FFI_THROW(InternalError)
<< "BindParallelLoopsToThreads does not support nested parallel loops. "
<< "Inner parallel loops become serial once bound into a GPU kernel. "
<< "Please rewrite the TIR to avoid nested T.parallel.";
}
bool prev_in_parallel = in_parallel_loop_;
in_parallel_loop_ = true;
For updated = Downcast<For>(StmtExprMutator::VisitStmt_(op));
in_parallel_loop_ = prev_in_parallel;
return TransformParallelFor(updated.get());
}
int64_t max_threads_per_block_;
bool in_thread_env_{false};
bool in_parallel_loop_{false};
};
} // namespace
namespace transform {
/*!
 * \brief Create the pass that binds ForKind::kParallel loops to GPU threads.
 *
 * Only PrimFuncs whose `tvm::attr::kTarget` attribute names a GPU device type are
 * rewritten; every other function is returned unchanged. The per-block thread count
 * comes from the target's `max_num_threads` attribute and defaults to 1024 when that
 * attribute is absent.
 *
 * \return The PrimFunc pass "tirx.BindParallelLoopsToThreads" (opt level 0).
 * \throws ValueError (when the pass runs) if the target's `max_num_threads` is not positive.
 */
Pass BindParallelLoopsToThreads() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
    if (!opt_target || !IsGpuDeviceType(opt_target.value()->GetTargetDeviceType())) {
      return f;
    }
    Target target = opt_target.value();
    int64_t max_threads_per_block = 1024;
    if (auto opt_max_threads = target->GetAttr<Integer>("max_num_threads")) {
      max_threads_per_block = opt_max_threads.value()->value;
    }
    // A non-positive value would produce a zero-extent threadIdx.x and a division by
    // zero when computing the number of blocks; fail loudly instead.
    if (max_threads_per_block <= 0) {
      TVM_FFI_THROW(ValueError)
          << "BindParallelLoopsToThreads requires the target attribute max_num_threads "
          << "to be positive, but got " << max_threads_per_block;
    }
    PrimFuncNode* n = f.CopyOnWrite();
    n->body = ParallelLoopToThreadBindingMutator(max_threads_per_block)(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tirx.BindParallelLoopsToThreads", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
  // Register the pass factory under a global FFI name so it can be looked up by string.
  tvm::ffi::reflection::GlobalDef().def("tirx.transform.BindParallelLoopsToThreads",
                                        BindParallelLoopsToThreads);
}
} // namespace transform
} // namespace tirx
} // namespace tvm