Skip to content

Commit 58bdbcb

Browse files
committed
Use OpSelect
1 parent 301012e commit 58bdbcb

2 files changed

Lines changed: 93 additions & 43 deletions

File tree

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9353,13 +9353,81 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
93539353
retVal = processIntrinsicPrintf(callExpr);
93549354
break;
93559355
case hlsl::IntrinsicOp::IOP_usign: {
9356-
// Do SAbs followed by SSign
9357-
auto *absVal = processIntrinsicUsingGLSLInst(
9358-
callExpr, GLSLstd450::GLSLstd450SAbs,
9359-
/*actPerRowForMatrices*/ true, srcLoc, srcRange);
9360-
retVal = spvBuilder.createGLSLExtInst(callExpr->getType(),
9361-
GLSLstd450::GLSLstd450SSign, {absVal},
9362-
srcLoc, srcRange);
9356+
auto *zero = getValueZero(callExpr->getType());
9357+
auto *one = getValueOne(callExpr->getType());
9358+
// QualType elemType = {};
9359+
// if (isVectorType(callExpr->getArg(0)->getType(), &elemType) ||
9360+
// isMxNMatrix(callExpr->getArg(0)->getType(), &elemType, nullptr,
9361+
// nullptr)) {
9362+
// // For vector types, we need to create vector constants for zero and
9363+
// one. const uint32_t size =
9364+
// hlsl::GetHLSLVecSize(callExpr->getArg(0)->getType());
9365+
// auto zeroUint = getValueZero(elemType);
9366+
// // Create `size` vector of uint zeros.
9367+
// zerosUint = spvBuilder.createCompositeConstruct(
9368+
// elemType,
9369+
// std::vector<SpirvInstruction *>(size, zeroUint), srcLoc, srcRange);
9370+
// zeros = spvBuilder.createCompositeConstruct(
9371+
// elemType,
9372+
// std::vector<SpirvInstruction *>(size, zero), srcLoc, srcRange);
9373+
// ones = spvBuilder.createCompositeConstruct(
9374+
// elemType,
9375+
// std::vector<SpirvInstruction *>(size, one), srcLoc, srcRange);
9376+
// }
9377+
auto firstArgType = callExpr->getArg(0)->getType();
9378+
bool isNonTrivialVecOrMat =
9379+
(isVectorType(firstArgType) &&
9380+
hlsl::GetHLSLVecSize(firstArgType) != 1) ||
9381+
(isMxNMatrix(firstArgType) && !is1x1Matrix(firstArgType));
9382+
if (isNonTrivialVecOrMat) {
9383+
auto loc = callExpr->getArg(0)->getExprLoc();
9384+
auto range = callExpr->getArg(0)->getSourceRange();
9385+
auto *arg = callExpr->getArg(0);
9386+
auto argId = doExpr(arg);
9387+
9388+
const auto actOnEachVec = [this, loc, range, zero,
9389+
one](uint32_t index, QualType inType,
9390+
QualType outType,
9391+
SpirvInstruction *curRow) {
9392+
// Get size of the vector/matrix row.
9393+
const uint32_t size = hlsl::GetHLSLVecSize(outType);
9394+
const auto matElemType = hlsl::GetHLSLVecElementType(inType);
9395+
auto zeroUint = getValueZero(matElemType);
9396+
// Create `size` vector of uint zeros.
9397+
auto *zerosUint = spvBuilder.createCompositeConstruct(
9398+
astContext.getExtVectorType(matElemType, size),
9399+
std::vector<SpirvInstruction *>(size, zeroUint), loc, range);
9400+
// Compare if they are greater than zero.
9401+
auto *cmp = spvBuilder.createBinaryOp(
9402+
spv::Op::OpUGreaterThan,
9403+
astContext.getExtVectorType(astContext.BoolTy, size), curRow,
9404+
zerosUint, loc, range);
9405+
9406+
// Select between one and zero.
9407+
auto *zeros = spvBuilder.createCompositeConstruct(
9408+
astContext.getExtVectorType(astContext.IntTy, size),
9409+
std::vector<SpirvInstruction *>(size, zero), loc, range);
9410+
auto *ones = spvBuilder.createCompositeConstruct(
9411+
astContext.getExtVectorType(astContext.IntTy, size),
9412+
std::vector<SpirvInstruction *>(size, one), loc, range);
9413+
return spvBuilder.createSelect(
9414+
astContext.getExtVectorType(astContext.IntTy, size), cmp, ones,
9415+
zeros, loc, range);
9416+
};
9417+
if (isMxNMatrix(arg->getType()))
9418+
retVal =
9419+
processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
9420+
else
9421+
retVal = actOnEachVec(0, arg->getType(), callExpr->getType(), argId);
9422+
} else {
9423+
auto *argVal = doExpr(callExpr->getArg(0));
9424+
auto *zeroUint = getValueZero(callExpr->getArg(0)->getType());
9425+
auto *cmp =
9426+
spvBuilder.createBinaryOp(spv::Op::OpUGreaterThan, astContext.BoolTy,
9427+
argVal, zeroUint, srcLoc, srcRange);
9428+
retVal = spvBuilder.createSelect(astContext.IntTy, cmp, one, zero, srcLoc,
9429+
srcRange);
9430+
}
93639431
} break;
93649432
case hlsl::IntrinsicOp::IOP_sign: {
93659433
if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,34 @@
11
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
22

3-
// CHECK: [[glsl:%[0-9]+]] = OpExtInstImport "GLSL.std.450"
3+
// CHECK-DAG: %int_0 = OpConstant %int 0
4+
// CHECK-DAG: %int_1 = OpConstant %int 1
5+
// CHECK-DAG: %uint_0 = OpConstant %uint 0
46

57
void main() {
68
int result;
79
int3 result3;
810

911
// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
10-
// CHECK: [[abs_a:%[0-9]+]] = OpExtInst %int [[glsl]] SAbs [[a]]
11-
// CHECK-NEXT: [[sign_a:%[0-9]+]] = OpExtInst %int [[glsl]] SSign [[abs_a]]
12-
// CHECK-NEXT: OpStore %result [[sign_a]]
12+
// CHECK-NEXT: [[cmp_a:%[0-9]+]] = OpUGreaterThan %bool [[a]] %uint_0
13+
// CHECK-NEXT: [[select_a:%[0-9]+]] = OpSelect %int [[cmp_a]] %int_1 %int_0
14+
// CHECK-NEXT: OpStore %result [[select_a]]
1315
uint a;
1416
result = sign(a);
1517

16-
// CHECK-NEXT: [[b:%[0-9]+]] = OpLoad %uint %b
17-
// CHECK: [[abs_b:%[0-9]+]] = OpExtInst %int [[glsl]] SAbs [[b]]
18-
// CHECK-NEXT: [[sign_b:%[0-9]+]] = OpExtInst %int [[glsl]] SSign [[abs_b]]
19-
// CHECK-NEXT: OpStore %result [[sign_b]]
18+
// CHECK: [[b:%[0-9]+]] = OpLoad %uint %b
19+
// CHECK-NEXT: [[cmp_b:%[0-9]+]] = OpUGreaterThan %bool [[b]] %uint_0
20+
// CHECK-NEXT: [[select_b:%[0-9]+]] = OpSelect %int [[cmp_b]] %int_1 %int_0
21+
// CHECK-NEXT: OpStore %result [[select_b]]
2022
uint1 b;
2123
result = sign(b);
2224

23-
// CHECK-NEXT: [[c:%[0-9]+]] = OpLoad %v3uint %c
24-
// CHECK: [[abs_c:%[0-9]+]] = OpExtInst %v3int [[glsl]] SAbs [[c]]
25-
// CHECK-NEXT: [[sign_c:%[0-9]+]] = OpExtInst %v3int [[glsl]] SSign [[abs_c]]
26-
// CHECK-NEXT: OpStore %result3 [[sign_c]]
25+
// CHECK: [[c:%[0-9]+]] = OpLoad %v3uint %c
26+
// CHECK-NEXT: [[zeros_uint_c:%[0-9]+]] = OpConstantComposite %v3uint %uint_0 %uint_0 %uint_0
27+
// CHECK-NEXT: [[cmp_c:%[0-9]+]] = OpUGreaterThan %v3bool [[c]] [[zeros_uint_c]]
28+
// CHECK-NEXT: [[zeros_int_c:%[0-9]+]] = OpConstantComposite %v3int %int_0 %int_0 %int_0
29+
// CHECK-NEXT: [[ones_int_c:%[0-9]+]] = OpConstantComposite %v3int %int_1 %int_1 %int_1
30+
// CHECK-NEXT: [[select_c:%[0-9]+]] = OpSelect %v3int [[cmp_c]] [[ones_int_c]] [[zeros_int_c]]
31+
// CHECK-NEXT: OpStore %result3 [[select_c]]
2732
uint3 c;
2833
result3 = sign(c);
29-
30-
// CHECK: [[d:%[0-9]+]] = OpLoad %uint %d
31-
// CHECK: [[abs_d:%[0-9]+]] = OpExtInst %int [[glsl]] SAbs [[d]]
32-
// CHECK-NEXT: [[sign_d:%[0-9]+]] = OpExtInst %int [[glsl]] SSign [[abs_d]]
33-
// CHECK-NEXT: OpStore %result [[sign_d]]
34-
uint1x1 d;
35-
result = sign(d);
36-
37-
// CHECK-NEXT: [[e:%[0-9]+]] = OpLoad %v2uint %e
38-
// CHECK-NEXT: [[abs_e:%[0-9]+]] = OpExtInst %v2int [[glsl]] SAbs [[e]]
39-
// CHECK-NEXT: [[sign_e:%[0-9]+]] = OpExtInst %v2int [[glsl]] SSign [[abs_e]]
40-
// CHECK-NEXT: OpStore %result2 [[sign_e]]
41-
uint1x2 e;
42-
int2 result2 = sign(e);
43-
44-
// CHECK-NEXT: [[f:%[0-9]+]] = OpLoad %v4uint %f
45-
// CHECK-NEXT: [[abs_f:%[0-9]+]] = OpExtInst %v4int [[glsl]] SAbs [[f]]
46-
// CHECK-NEXT: [[sign_f:%[0-9]+]] = OpExtInst %v4int [[glsl]] SSign [[abs_f]]
47-
// CHECK-NEXT: OpStore %result4 [[sign_f]]
48-
uint4x1 f;
49-
int4 result4 = sign(f);
50-
51-
// TODO: Integer matrices are not supported yet. See intrinsics.intsign.hlsl
52-
}
34+
}

0 commit comments

Comments
 (0)