Skip to content

Commit 5bc9ddf

Browse files
committed
Split fast compile legalization from defaults
1 parent 4fce38c commit 5bc9ddf

5 files changed

Lines changed: 160 additions & 75 deletions

File tree

include/spirv-tools/optimizer.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,10 @@ class SPIRV_TOOLS_EXPORT Optimizer {
134134
// interface are considered live and are not eliminated.
135135
Optimizer& RegisterLegalizationPasses();
136136
Optimizer& RegisterLegalizationPasses(bool preserve_interface);
137-
Optimizer& RegisterLegalizationPasses(bool preserve_interface,
138-
bool include_loop_unroll,
139-
SSARewriteMode ssa_rewrite_mode);
137+
Optimizer& RegisterLegalizationPassesFastCompile();
138+
Optimizer& RegisterLegalizationPassesFastCompile(
139+
bool preserve_interface, bool include_loop_unroll,
140+
SSARewriteMode ssa_rewrite_mode);
140141

141142
// Register passes specified in the list of |flags|. Each flag must be a
142143
// string of a form accepted by Optimizer::FlagHasValidForm().
@@ -657,6 +658,11 @@ Optimizer::PassToken CreateLoopPeelingPass();
657658
// Works best after LICM and local multi store elimination pass.
658659
Optimizer::PassToken CreateLoopUnswitchPass();
659660

661+
// Creates a pass to legalize multidimensional arrays for Vulkan.
662+
// This pass will replace multidimensional arrays of resources with a single
663+
// dimensional array. Combine-access-chains should be run before this pass.
664+
Optimizer::PassToken CreateLegalizeMultidimArrayPass();
665+
660666
// Create global value numbering pass.
661667
// This pass will look for instructions where the same value is computed on all
662668
// paths leading to the instruction. Those instructions are deleted.
@@ -716,8 +722,8 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0);
716722
// operations on SSA IDs. This allows SSA optimizers to act on these variables.
717723
// Only variables that are local to the function and of supported types are
718724
// processed (see IsSSATargetVar for details).
719-
Optimizer::PassToken CreateSSARewritePass(
720-
SSARewriteMode mode = SSARewriteMode::All);
725+
Optimizer::PassToken CreateSSARewritePass();
726+
Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode);
721727

722728
// Create pass to convert relaxed precision instructions to half precision.
723729
// This pass converts as many relaxed float32 arithmetic operations to half as

source/opt/mem_pass.cpp

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,28 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
5353
}
5454

5555
bool MemPass::IsTargetType(const Instruction* typeInst) const {
56-
switch (ssa_rewrite_mode_) {
57-
case SSARewriteMode::None:
58-
return false;
59-
case SSARewriteMode::OpaqueOnly:
60-
if (typeInst->IsOpaqueType()) return true;
61-
break;
62-
case SSARewriteMode::SpecialTypes:
63-
if (typeInst->IsOpaqueType()) return true;
64-
switch (typeInst->opcode()) {
65-
case spv::Op::OpTypePointer:
66-
case spv::Op::OpTypeCooperativeMatrixNV:
67-
case spv::Op::OpTypeCooperativeMatrixKHR:
68-
return true;
69-
default:
70-
break;
71-
}
72-
break;
73-
case SSARewriteMode::All:
74-
if (IsBaseTargetType(typeInst)) return true;
75-
break;
76-
}
56+
switch (ssa_rewrite_mode_) {
57+
case SSARewriteMode::None:
58+
return false;
59+
case SSARewriteMode::OpaqueOnly:
60+
if (typeInst->IsOpaqueType()) return true;
61+
break;
62+
case SSARewriteMode::SpecialTypes:
63+
if (typeInst->IsOpaqueType()) return true;
64+
switch (typeInst->opcode()) {
65+
case spv::Op::OpTypePointer:
66+
case spv::Op::OpTypeUntypedPointerKHR:
67+
case spv::Op::OpTypeCooperativeMatrixNV:
68+
case spv::Op::OpTypeCooperativeMatrixKHR:
69+
return true;
70+
default:
71+
break;
72+
}
73+
break;
74+
case SSARewriteMode::All:
75+
if (IsBaseTargetType(typeInst)) return true;
76+
break;
77+
}
7778
if (typeInst->opcode() == spv::Op::OpTypeArray) {
7879
if (!IsTargetType(
7980
get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
@@ -92,7 +93,8 @@ bool MemPass::IsTargetType(const Instruction* typeInst) const {
9293

9394
bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
9495
return opcode == spv::Op::OpAccessChain ||
95-
opcode == spv::Op::OpInBoundsAccessChain;
96+
opcode == spv::Op::OpInBoundsAccessChain ||
97+
opcode == spv::Op::OpUntypedAccessChainKHR;
9698
}
9799

98100
bool MemPass::IsPtr(uint32_t ptrId) {
@@ -108,11 +110,14 @@ bool MemPass::IsPtr(uint32_t ptrId) {
108110
ptrInst = get_def_use_mgr()->GetDef(varId);
109111
}
110112
const spv::Op op = ptrInst->opcode();
111-
if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
113+
if (op == spv::Op::OpVariable || op == spv::Op::OpUntypedVariableKHR ||
114+
IsNonPtrAccessChain(op))
115+
return true;
112116
const uint32_t varTypeId = ptrInst->type_id();
113117
if (varTypeId == 0) return false;
114118
const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
115-
return varTypeInst->opcode() == spv::Op::OpTypePointer;
119+
return varTypeInst->opcode() == spv::Op::OpTypePointer ||
120+
varTypeInst->opcode() == spv::Op::OpTypeUntypedPointerKHR;
116121
}
117122

118123
Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
@@ -122,11 +127,13 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
122127

123128
switch (ptrInst->opcode()) {
124129
case spv::Op::OpVariable:
130+
case spv::Op::OpUntypedVariableKHR:
125131
case spv::Op::OpFunctionParameter:
126132
varInst = ptrInst;
127133
break;
128134
case spv::Op::OpAccessChain:
129135
case spv::Op::OpInBoundsAccessChain:
136+
case spv::Op::OpUntypedAccessChainKHR:
130137
case spv::Op::OpPtrAccessChain:
131138
case spv::Op::OpInBoundsPtrAccessChain:
132139
case spv::Op::OpImageTexelPointer:
@@ -139,7 +146,8 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
139146
break;
140147
}
141148

142-
if (varInst->opcode() == spv::Op::OpVariable) {
149+
if (varInst->opcode() == spv::Op::OpVariable ||
150+
varInst->opcode() == spv::Op::OpUntypedVariableKHR) {
143151
*varId = varInst->result_id();
144152
} else {
145153
*varId = 0;
@@ -254,8 +262,10 @@ void MemPass::DCEInst(Instruction* inst,
254262
}
255263
}
256264

257-
MemPass::MemPass(SSARewriteMode ssa_rewrite_mode)
258-
: ssa_rewrite_mode_(ssa_rewrite_mode) {}
265+
MemPass::MemPass() {}
266+
267+
MemPass::MemPass(SSARewriteMode ssa_rewrite_mode)
268+
: ssa_rewrite_mode_(ssa_rewrite_mode) {}
259269

260270
bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
261271
return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {

source/opt/mem_pass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ class MemPass : public Pass {
6969
void CollectTargetVars(Function* func);
7070

7171
protected:
72-
explicit MemPass(SSARewriteMode ssa_rewrite_mode = SSARewriteMode::All);
72+
MemPass();
73+
explicit MemPass(SSARewriteMode ssa_rewrite_mode);
7374

7475
// Returns true if |typeInst| is a scalar type
7576
// or a vector or matrix

source/opt/optimizer.cpp

Lines changed: 108 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,69 @@ Optimizer& Optimizer::RegisterPass(PassToken&& p) {
120120
// The legalization problem is essentially a very general copy propagation
121121
// problem. The optimizations we use are all used to either do copy propagation
122122
// or enable more copy propagation.
123-
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface,
124-
bool include_loop_unroll,
125-
SSARewriteMode ssa_rewrite_mode) {
123+
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
124+
return
125+
// Wrap OpKill instructions so all other code can be inlined.
126+
RegisterPass(CreateWrapOpKillPass())
127+
// Remove unreachable blocks so that merge return works.
128+
.RegisterPass(CreateDeadBranchElimPass())
129+
// Merge the returns so we can inline.
130+
.RegisterPass(CreateMergeReturnPass())
131+
// Make sure uses and definitions are in the same function.
132+
.RegisterPass(CreateInlineExhaustivePass())
133+
// Make private variable function scope
134+
.RegisterPass(CreateEliminateDeadFunctionsPass())
135+
.RegisterPass(CreatePrivateToLocalPass())
136+
// Fix up the storage classes that DXC may have purposely generated
137+
// incorrectly. All functions are inlined, and a lot of dead code has
138+
// been removed.
139+
.RegisterPass(CreateFixStorageClassPass())
140+
// Propagate the value stored to the loads in very simple cases.
141+
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
142+
.RegisterPass(CreateLocalSingleStoreElimPass())
143+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
144+
// Split up aggregates so they are easier to deal with.
145+
.RegisterPass(CreateScalarReplacementPass(0))
146+
// Remove loads and stores so everything is in intermediate values.
147+
// Takes care of copy propagation of non-members.
148+
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
149+
.RegisterPass(CreateLocalSingleStoreElimPass())
150+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
151+
.RegisterPass(CreateLocalMultiStoreElimPass())
152+
.RegisterPass(CreateCombineAccessChainsPass())
153+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
154+
.RegisterPass(CreateLegalizeMultidimArrayPass())
155+
// Propagate constants to get as many constant conditions on branches
156+
// as possible.
157+
.RegisterPass(CreateCCPPass())
158+
.RegisterPass(CreateLoopUnrollPass(true))
159+
.RegisterPass(CreateDeadBranchElimPass())
160+
// Copy propagate members. Cleans up code sequences generated by
161+
// scalar replacement. Also important for removing OpPhi nodes.
162+
.RegisterPass(CreateSimplificationPass())
163+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
164+
.RegisterPass(CreateCopyPropagateArraysPass())
165+
// May need loop unrolling here; see
166+
// https://github.com/Microsoft/DirectXShaderCompiler/pull/930
167+
// Get rid of unused code that contains traces of illegal code
168+
// or unused references to unbound external objects
169+
.RegisterPass(CreateVectorDCEPass())
170+
.RegisterPass(CreateDeadInsertElimPass())
171+
.RegisterPass(CreateReduceLoadSizePass())
172+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
173+
.RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
174+
.RegisterPass(CreateInterpolateFixupPass())
175+
.RegisterPass(CreateInvocationInterlockPlacementPass())
176+
.RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass());
177+
}
178+
179+
Optimizer& Optimizer::RegisterLegalizationPasses() {
180+
return RegisterLegalizationPasses(false);
181+
}
182+
183+
Optimizer& Optimizer::RegisterLegalizationPassesFastCompile(
184+
bool preserve_interface, bool include_loop_unroll,
185+
SSARewriteMode ssa_rewrite_mode) {
126186
auto& optimizer =
127187
// Wrap OpKill instructions so all other code can be inlined.
128188
RegisterPass(CreateWrapOpKillPass())
@@ -132,38 +192,38 @@ Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface,
132192
.RegisterPass(CreateMergeReturnPass())
133193
// Make sure uses and definitions are in the same function.
134194
.RegisterPass(CreateInlineExhaustivePass())
135-
.RegisterPass(CreateEliminateDeadFunctionsPass());
136-
optimizer.RegisterPass(CreatePrivateToLocalPass());
137-
// Fix up the storage classes that DXC may have purposely generated
138-
// incorrectly. All functions are inlined, and a lot of dead code has
139-
// been removed.
140-
optimizer.RegisterPass(CreateFixStorageClassPass());
141-
// Propagate the value stored to the loads in very simple cases.
142-
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
143-
.RegisterPass(CreateLocalSingleStoreElimPass())
144-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
145-
optimizer
146-
// Split up aggregates so they are easier to deal with.
147-
.RegisterPass(CreateScalarReplacementPass(0));
148-
// Remove loads and stores so everything is in intermediate values.
149-
// Takes care of copy propagation of non-members.
150-
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
151-
.RegisterPass(CreateLocalSingleStoreElimPass())
152-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
153-
if (ssa_rewrite_mode != SSARewriteMode::None) {
154-
optimizer.RegisterPass(CreateSSARewritePass(ssa_rewrite_mode));
195+
.RegisterPass(CreateEliminateDeadFunctionsPass());
196+
optimizer.RegisterPass(CreatePrivateToLocalPass());
197+
// Fix up the storage classes that DXC may have purposely generated
198+
// incorrectly. All functions are inlined, and a lot of dead code has
199+
// been removed.
200+
optimizer.RegisterPass(CreateFixStorageClassPass());
201+
// Propagate the value stored to the loads in very simple cases.
202+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
203+
.RegisterPass(CreateLocalSingleStoreElimPass())
204+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
205+
optimizer
206+
// Split up aggregates so they are easier to deal with.
207+
.RegisterPass(CreateScalarReplacementPass(0));
208+
// Remove loads and stores so everything is in intermediate values.
209+
// Takes care of copy propagation of non-members.
210+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
211+
.RegisterPass(CreateLocalSingleStoreElimPass())
212+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
213+
if (ssa_rewrite_mode != SSARewriteMode::None) {
214+
optimizer.RegisterPass(CreateSSARewritePass(ssa_rewrite_mode));
155215
}
156-
optimizer
157-
// Propagate constants to get as many constant conditions on branches
158-
// as possible.
159-
.RegisterPass(CreateCCPPass());
216+
optimizer
217+
// Propagate constants to get as many constant conditions on branches
218+
// as possible.
219+
.RegisterPass(CreateCCPPass());
160220
if (include_loop_unroll) {
161221
optimizer.RegisterPass(CreateLoopUnrollPass(true));
162222
}
163-
optimizer.RegisterPass(CreateDeadBranchElimPass())
164-
// Copy propagate members. Cleans up code sequences generated by scalar
165-
// replacement. Also important for removing OpPhi nodes.
166-
.RegisterPass(CreateSimplificationPass());
223+
optimizer.RegisterPass(CreateDeadBranchElimPass())
224+
// Copy propagate members. Cleans up code sequences generated by scalar
225+
// replacement. Also important for removing OpPhi nodes.
226+
.RegisterPass(CreateSimplificationPass());
167227
return optimizer
168228
// May need loop unrolling here; see
169229
// https://github.com/Microsoft/DirectXShaderCompiler/pull/930
@@ -179,13 +239,9 @@ Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface,
179239
.RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass());
180240
}
181241

182-
Optimizer& Optimizer::RegisterLegalizationPasses() {
183-
return RegisterLegalizationPasses(false, true, SSARewriteMode::All);
184-
}
185-
186-
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
187-
return RegisterLegalizationPasses(preserve_interface, true,
188-
SSARewriteMode::All);
242+
Optimizer& Optimizer::RegisterLegalizationPassesFastCompile() {
243+
return RegisterLegalizationPassesFastCompile(false, true,
244+
SSARewriteMode::All);
189245
}
190246

191247
Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
@@ -461,6 +517,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
461517
RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
462518
} else if (pass_name == "loop-unswitch") {
463519
RegisterPass(CreateLoopUnswitchPass());
520+
} else if (pass_name == "legalize-multidim-array") {
521+
RegisterPass(CreateLegalizeMultidimArrayPass());
464522
} else if (pass_name == "scalar-replacement") {
465523
if (pass_args.size() == 0) {
466524
RegisterPass(CreateScalarReplacementPass(0));
@@ -1023,6 +1081,11 @@ Optimizer::PassToken CreateLoopUnswitchPass() {
10231081
MakeUnique<opt::LoopUnswitchPass>());
10241082
}
10251083

1084+
Optimizer::PassToken CreateLegalizeMultidimArrayPass() {
1085+
return MakeUnique<Optimizer::PassToken::Impl>(
1086+
MakeUnique<opt::LegalizeMultidimArrayPass>());
1087+
}
1088+
10261089
Optimizer::PassToken CreateRedundancyEliminationPass() {
10271090
return MakeUnique<Optimizer::PassToken::Impl>(
10281091
MakeUnique<opt::RedundancyEliminationPass>());
@@ -1072,9 +1135,14 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) {
10721135
MakeUnique<opt::LoopUnroller>(fully_unroll, factor));
10731136
}
10741137

1075-
Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode) {
1138+
Optimizer::PassToken CreateSSARewritePass() {
1139+
return MakeUnique<Optimizer::PassToken::Impl>(
1140+
MakeUnique<opt::SSARewritePass>());
1141+
}
1142+
1143+
Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode) {
10761144
return MakeUnique<Optimizer::PassToken::Impl>(
1077-
MakeUnique<opt::SSARewritePass>(mode));
1145+
MakeUnique<opt::SSARewritePass>(mode));
10781146
}
10791147

10801148
Optimizer::PassToken CreateCopyPropagateArraysPass() {

source/opt/ssa_rewrite_pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ class SSARewriter {
294294

295295
class SSARewritePass : public MemPass {
296296
public:
297-
explicit SSARewritePass(SSARewriteMode mode = SSARewriteMode::All)
298-
: MemPass(mode) {}
297+
SSARewritePass() = default;
298+
explicit SSARewritePass(SSARewriteMode mode) : MemPass(mode) {}
299299

300300
const char* name() const override { return "ssa-rewrite"; }
301301
Status Process() override;

0 commit comments

Comments
 (0)