Skip to content

Commit 0855509

Browse files
authored
[SPIR-V] Fix bool cast on buffers with swizzle (#7497)
HLSL resources can store booleans. SPIR-V resources can't. We handle this by using integers in resources, and casting at the interface. Swizzle path was handled a bit differently, and was not going through the common load/store path which handles the cast. Fixes #7475
1 parent 20f291e commit 0855509

3 files changed

Lines changed: 36 additions & 9 deletions

File tree

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,8 @@ SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
12811281
}
12821282

12831283
SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
1284-
SpirvInstruction *info) {
1284+
SpirvInstruction *info,
1285+
SourceRange rangeOverride) {
12851286
const auto exprType = expr->getType();
12861287

12871288
// Do nothing if this is already rvalue
@@ -1316,9 +1317,11 @@ SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
13161317
return info;
13171318
}
13181319

1320+
SourceRange range =
1321+
(rangeOverride != SourceRange()) ? rangeOverride : expr->getSourceRange();
13191322
SpirvInstruction *loadedInstr = nullptr;
1320-
loadedInstr = spvBuilder.createLoad(exprType, info, expr->getExprLoc(),
1321-
expr->getSourceRange());
1323+
loadedInstr =
1324+
spvBuilder.createLoad(exprType, info, expr->getExprLoc(), range);
13221325
assert(loadedInstr);
13231326

13241327
// Special-case: According to the SPIR-V Spec: There is no physical size or
@@ -7969,15 +7972,12 @@ SpirvInstruction *SpirvEmitter::tryToAssignToVectorElements(
79697972
}
79707973

79717974
auto *vec1 = doExpr(base, range);
7972-
auto *vec1Val =
7973-
vec1->isRValue()
7974-
? vec1
7975-
: spvBuilder.createLoad(baseType, vec1, base->getLocStart(), range);
7975+
auto *vec1Val = vec1->isRValue() ? vec1 : loadIfGLValue(base, vec1, range);
79767976
auto *shuffle = spvBuilder.createVectorShuffle(
79777977
baseType, vec1Val, rhs, selectors, lhs->getLocStart(), range);
79787978

79797979
if (!tryToAssignToRWBufferRWTexture(base, shuffle))
7980-
spvBuilder.createStore(vec1, shuffle, lhs->getLocStart(), range);
7980+
storeValue(vec1, shuffle, base->getType(), lhs->getLocStart(), range);
79817981

79827982
// TODO: OK, this return value is incorrect for compound assignments, for
79837983
// which cases we should return lvalues. Should at least emit errors if

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ class SpirvEmitter : public ASTConsumer {
176176
/// Overload with pre computed SpirvEvalInfo.
177177
///
178178
/// The given expr will not be evaluated again.
179-
SpirvInstruction *loadIfGLValue(const Expr *expr, SpirvInstruction *info);
179+
SpirvInstruction *loadIfGLValue(const Expr *expr, SpirvInstruction *info,
180+
SourceRange rangeOverride = {});
180181

181182
/// Loads the pointer of the aliased-to-variable if the given expression is a
182183
/// DeclRefExpr referencing an alias variable. See DeclResultIdMapper for
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %dxc -T cs_6_0 -E main -fcgl %s -spirv | FileCheck %s
2+
3+
RWStructuredBuffer<bool4> buffer;
4+
5+
// CHECK-DAG: [[v4_0:%[0-9]+]] = OpConstantComposite %v4uint %uint_0 %uint_0 %uint_0 %uint_0
6+
// CHECK-DAG: [[v4_1:%[0-9]+]] = OpConstantComposite %v4uint %uint_1 %uint_1 %uint_1 %uint_1
7+
8+
[numthreads(1, 1, 1)]
9+
void main()
10+
{
11+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v4uint %buffer %int_0 %uint_0
12+
// CHECK: [[load:%[0-9]+]] = OpLoad %v4uint [[ptr]]
13+
// CHECK: [[cast:%[0-9]+]] = OpINotEqual %v4bool [[load]] [[v4_0]]
14+
// CHECK: [[shuf:%[0-9]+]] = OpVectorShuffle %v3bool [[cast]] [[cast]] 0 1 2
15+
// CHECK: OpStore %a [[shuf]]
16+
bool3 a = buffer[0].xyz;
17+
18+
// CHECK: [[a:%[0-9]+]] = OpLoad %v3bool %a
19+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_v4uint %buffer %int_0 %uint_1
20+
// CHECK: [[load:%[0-9]+]] = OpLoad %v4uint [[ptr]]
21+
// CHECK: [[cast:%[0-9]+]] = OpINotEqual %v4bool [[load]] [[v4_0]]
22+
// CHECK: [[shuf:%[0-9]+]] = OpVectorShuffle %v4bool [[cast]] [[a]] 4 5 6 3
23+
// CHECK: [[cast:%[0-9]+]] = OpSelect %v4uint [[shuf]] [[v4_1]] [[v4_0]]
24+
// CHECK: OpStore [[ptr]] [[cast]]
25+
buffer[1].xyz = a;
26+
}

0 commit comments

Comments
 (0)