Skip to content

Commit bf0b93d

Browse files
committed
Use OpSelect
1 parent 301012e commit bf0b93d

3 files changed

Lines changed: 135 additions & 42 deletions

File tree

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9352,15 +9352,9 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
93529352
case hlsl::IntrinsicOp::IOP_printf:
93539353
retVal = processIntrinsicPrintf(callExpr);
93549354
break;
9355-
case hlsl::IntrinsicOp::IOP_usign: {
9356-
// Do SAbs followed by SSign
9357-
auto *absVal = processIntrinsicUsingGLSLInst(
9358-
callExpr, GLSLstd450::GLSLstd450SAbs,
9359-
/*actPerRowForMatrices*/ true, srcLoc, srcRange);
9360-
retVal = spvBuilder.createGLSLExtInst(callExpr->getType(),
9361-
GLSLstd450::GLSLstd450SSign, {absVal},
9362-
srcLoc, srcRange);
9363-
} break;
9355+
case hlsl::IntrinsicOp::IOP_usign:
9356+
retVal = processIntrinsicSignUnsignedInt(callExpr);
9357+
break;
93649358
case hlsl::IntrinsicOp::IOP_sign: {
93659359
if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
93669360
retVal = processIntrinsicFloatSign(callExpr);
@@ -12665,6 +12659,80 @@ SpirvEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
1266512659
return nullptr;
1266612660
}
1266712661

12662+
SpirvInstruction *
12663+
SpirvEmitter::processIntrinsicSignUnsignedInt(const CallExpr *callExpr) {
12664+
const auto srcLoc = callExpr->getExprLoc();
12665+
const auto srcRange = callExpr->getSourceRange();
12666+
12667+
const Expr *firstArg = callExpr->getArg(0);
12668+
const QualType firstArgType = firstArg->getType();
12669+
auto elemType = QualType{};
12670+
uint32_t numRows;
12671+
uint32_t numCols;
12672+
uint32_t count;
12673+
bool isScalar =
12674+
isScalarType(firstArgType, &elemType) ||
12675+
(isVectorType(firstArgType, &elemType, &count) && count == 1) ||
12676+
(isMxNMatrix(firstArgType, &elemType, &numRows, &numCols) &&
12677+
(numRows == 1 && numCols == 1));
12678+
12679+
auto *zero = getValueZero(astContext.IntTy);
12680+
auto *one = getValueOne(astContext.IntTy);
12681+
if (isScalar) {
12682+
auto *argVal = doExpr(callExpr->getArg(0));
12683+
auto *zeroUint = getValueZero(callExpr->getArg(0)->getType());
12684+
auto *cmp =
12685+
spvBuilder.createBinaryOp(spv::Op::OpUGreaterThan, astContext.BoolTy,
12686+
argVal, zeroUint, srcLoc, srcRange);
12687+
return spvBuilder.createSelect(astContext.IntTy, cmp, one, zero, srcLoc,
12688+
srcRange);
12689+
}
12690+
12691+
uint32_t size;
12692+
if (isVectorType(firstArgType)) {
12693+
size = count;
12694+
} else if (is1xNMatrix(firstArgType)) {
12695+
size = numCols;
12696+
} else if (isMx1Matrix(firstArgType)) {
12697+
size = numRows;
12698+
} else {
12699+
size = numRows;
12700+
}
12701+
12702+
const auto actOnEachVec = [this, srcLoc, srcRange, zero, one, elemType,
12703+
size](uint32_t index, QualType inType,
12704+
QualType outType, SpirvInstruction *curRow) {
12705+
auto zeroUint = getValueZero(elemType);
12706+
// Create `size` vector of uint zeros.
12707+
auto *zerosUint = spvBuilder.getConstantComposite(
12708+
astContext.getExtVectorType(elemType, size),
12709+
std::vector<clang::spirv::SpirvConstant *>(size, zeroUint));
12710+
// Compare if they are greater than zero.
12711+
auto *cmp = spvBuilder.createBinaryOp(
12712+
spv::Op::OpUGreaterThan,
12713+
astContext.getExtVectorType(astContext.BoolTy, size), curRow, zerosUint,
12714+
srcLoc, srcRange);
12715+
12716+
// Create a vector of int ones and zeros.
12717+
auto *zeros = spvBuilder.getConstantComposite(
12718+
astContext.getExtVectorType(astContext.IntTy, size),
12719+
std::vector<clang::spirv::SpirvConstant *>(size, zero));
12720+
auto *ones = spvBuilder.getConstantComposite(
12721+
astContext.getExtVectorType(astContext.IntTy, size),
12722+
std::vector<clang::spirv::SpirvConstant *>(size, one));
12723+
// Select between ones and zeros based on the comparison.
12724+
return spvBuilder.createSelect(
12725+
astContext.getExtVectorType(astContext.IntTy, size), cmp, ones, zeros,
12726+
srcLoc, srcRange);
12727+
};
12728+
12729+
if (isVectorType(firstArgType)) {
12730+
return actOnEachVec(0, firstArgType, callExpr->getType(), doExpr(firstArg));
12731+
}
12732+
return processEachVectorInMatrix(firstArg, doExpr(firstArg), actOnEachVec,
12733+
srcLoc, srcRange);
12734+
}
12735+
1266812736
SpirvInstruction *
1266912737
SpirvEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
1267012738
// Import the GLSL.std.450 extended instruction set.

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,9 @@ class SpirvEmitter : public ASTConsumer {
648648
/// Processes the 'ReadClock' intrinsic function.
649649
SpirvInstruction *processIntrinsicReadClock(const CallExpr *);
650650

651+
/// Processes the 'sign' intrinsic function for unsigned integer types.
652+
SpirvInstruction *processIntrinsicSignUnsignedInt(const CallExpr *callExpr);
653+
651654
/// Processes the 'sign' intrinsic function for float types.
652655
/// The FSign instruction in the GLSL instruction set returns a floating point
653656
/// result. The HLSL sign function, however, returns an integer. An extra
Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,74 @@
1-
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
1+
// RUN: %dxc -T vs_6_0 -E main -fcgl -Vd %s -spirv | FileCheck %s
22

3-
// CHECK: [[glsl:%[0-9]+]] = OpExtInstImport "GLSL.std.450"
3+
// CHECK-DAG: %int_0 = OpConstant %int 0
4+
// CHECK-DAG: %int_1 = OpConstant %int 1
5+
// CHECK-DAG: %uint_0 = OpConstant %uint 0
6+
// CHECK-DAG: %v3int = OpTypeVector %int 3
7+
// CHECK-DAG: %v3uint = OpTypeVector %uint 3
8+
// CHECK-DAG: [[zeros_uint3:%[0-9]+]] = OpConstantComposite %v3uint %uint_0 %uint_0 %uint_0
9+
// CHECK-DAG: [[zeros_int3:%[0-9]+]] = OpConstantComposite %v3int %int_0 %int_0 %int_0
10+
// CHECK-DAG: [[ones_int3:%[0-9]+]] = OpConstantComposite %v3int %int_1 %int_1 %int_1
411

512
void main() {
613
int result;
714
int3 result3;
15+
int3x3 result3x3;
816

917
// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
10-
// CHECK: [[abs_a:%[0-9]+]] = OpExtInst %int [[glsl]] SAbs [[a]]
11-
// CHECK-NEXT: [[sign_a:%[0-9]+]] = OpExtInst %int [[glsl]] SSign [[abs_a]]
12-
// CHECK-NEXT: OpStore %result [[sign_a]]
18+
// CHECK-NEXT: [[cmp_a:%[0-9]+]] = OpUGreaterThan %bool [[a]] %uint_0
19+
// CHECK-NEXT: [[select_a:%[0-9]+]] = OpSelect %int [[cmp_a]] %int_1 %int_0
20+
// CHECK-NEXT: OpStore %result [[select_a]]
1321
uint a;
1422
result = sign(a);
1523

16-
// CHECK-NEXT: [[b:%[0-9]+]] = OpLoad %uint %b
17-
// CHECK: [[abs_b:%[0-9]+]] = OpExtInst %int [[glsl]] SAbs [[b]]
18-
// CHECK-NEXT: [[sign_b:%[0-9]+]] = OpExtInst %int [[glsl]] SSign [[abs_b]]
19-
// CHECK-NEXT: OpStore %result [[sign_b]]
24+
// CHECK: [[b:%[0-9]+]] = OpLoad %uint %b
25+
// CHECK-NEXT: [[cmp_b:%[0-9]+]] = OpUGreaterThan %bool [[b]] %uint_0
26+
// CHECK-NEXT: [[select_b:%[0-9]+]] = OpSelect %int [[cmp_b]] %int_1 %int_0
27+
// CHECK-NEXT: OpStore %result [[select_b]]
2028
uint1 b;
2129
result = sign(b);
2230

23-
// CHECK-NEXT: [[c:%[0-9]+]] = OpLoad %v3uint %c
24-
// CHECK: [[abs_c:%[0-9]+]] = OpExtInst %v3int [[glsl]] SAbs [[c]]
25-
// CHECK-NEXT: [[sign_c:%[0-9]+]] = OpExtInst %v3int [[glsl]] SSign [[abs_c]]
26-
// CHECK-NEXT: OpStore %result3 [[sign_c]]
31+
// CHECK: [[c:%[0-9]+]] = OpLoad %v3uint %c
32+
// CHECK-NEXT: [[cmp_c:%[0-9]+]] = OpUGreaterThan %v3bool [[c]] [[zeros_uint3]]
33+
// CHECK-NEXT: [[select_c:%[0-9]+]] = OpSelect %v3int [[cmp_c]] [[ones_int3]] [[zeros_int3]]
34+
// CHECK-NEXT: OpStore %result3 [[select_c]]
2735
uint3 c;
2836
result3 = sign(c);
2937

30-
// CHECK: [[d:%[0-9]+]] = OpLoad %uint %d
31-
// CHECK: [[abs_d:%[0-9]+]] = OpExtInst %int [[glsl]] SAbs [[d]]
32-
// CHECK-NEXT: [[sign_d:%[0-9]+]] = OpExtInst %int [[glsl]] SSign [[abs_d]]
33-
// CHECK-NEXT: OpStore %result [[sign_d]]
38+
39+
// CHECK: [[d:%[0-9]+]] = OpLoad %uint %d
40+
// CHECK-NEXT: [[cmp_d:%[0-9]+]] = OpUGreaterThan %bool [[d]] %uint_0
41+
// CHECK-NEXT: [[select_d:%[0-9]+]] = OpSelect %int [[cmp_d]] %int_1 %int_0
42+
// CHECK-NEXT: OpStore %result [[select_d]]
3443
uint1x1 d;
3544
result = sign(d);
3645

37-
// CHECK-NEXT: [[e:%[0-9]+]] = OpLoad %v2uint %e
38-
// CHECK-NEXT: [[abs_e:%[0-9]+]] = OpExtInst %v2int [[glsl]] SAbs [[e]]
39-
// CHECK-NEXT: [[sign_e:%[0-9]+]] = OpExtInst %v2int [[glsl]] SSign [[abs_e]]
40-
// CHECK-NEXT: OpStore %result2 [[sign_e]]
41-
uint1x2 e;
42-
int2 result2 = sign(e);
43-
44-
// CHECK-NEXT: [[f:%[0-9]+]] = OpLoad %v4uint %f
45-
// CHECK-NEXT: [[abs_f:%[0-9]+]] = OpExtInst %v4int [[glsl]] SAbs [[f]]
46-
// CHECK-NEXT: [[sign_f:%[0-9]+]] = OpExtInst %v4int [[glsl]] SSign [[abs_f]]
47-
// CHECK-NEXT: OpStore %result4 [[sign_f]]
48-
uint4x1 f;
49-
int4 result4 = sign(f);
50-
51-
// TODO: Integer matrices are not supported yet. See intrinsics.intsign.hlsl
52-
}
46+
// CHECK: [[e:%[0-9]+]] = OpLoad %v3uint %e
47+
// CHECK-NEXT: [[cmp_e:%[0-9]+]] = OpUGreaterThan %v3bool [[e]] [[zeros_uint3]]
48+
// CHECK-NEXT: [[select_e:%[0-9]+]] = OpSelect %v3int [[cmp_e]] [[ones_int3]] [[zeros_int3]]
49+
// CHECK-NEXT: OpStore %result3 [[select_e]]
50+
uint1x3 e;
51+
result3 = sign(e);
52+
53+
// CHECK: [[f:%[0-9]+]] = OpLoad %v3uint %f
54+
// CHECK-NEXT: [[cmp_f:%[0-9]+]] = OpUGreaterThan %v3bool [[f]] [[zeros_uint3]]
55+
// CHECK-NEXT: [[select_f:%[0-9]+]] = OpSelect %v3int [[cmp_f]] [[ones_int3]] [[zeros_int3]]
56+
// CHECK-NEXT: OpStore %result3 [[select_f]]
57+
uint3x1 f;
58+
result3 = sign(f);
59+
60+
// CHECK: [[h:%[0-9]+]] = OpLoad %_arr_v3uint_uint_3 %h
61+
// CHECK-NEXT: [[h_row0:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 0
62+
// CHECK-NEXT: [[cmp_h_row0:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row0]] [[zeros_uint3]]
63+
// CHECK-NEXT: [[select_h_row0:%[0-9]+]] = OpSelect %v3int [[cmp_h_row0]] [[ones_int3]] [[zeros_int3]]
64+
// CHECK-NEXT: [[h_row1:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 1
65+
// CHECK-NEXT: [[cmp_h_row1:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row1]] [[zeros_uint3]]
66+
// CHECK-NEXT: [[select_h_row1:%[0-9]+]] = OpSelect %v3int [[cmp_h_row1]] [[ones_int3]] [[zeros_int3]]
67+
// CHECK-NEXT: [[h_row2:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 2
68+
// CHECK-NEXT: [[cmp_h_row2:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row2]] [[zeros_uint3]]
69+
// CHECK-NEXT: [[select_h_row2:%[0-9]+]] = OpSelect %v3int [[cmp_h_row2]] [[ones_int3]] [[zeros_int3]]
70+
// CHECK-NEXT: [[select_h:%[0-9]+]] = OpCompositeConstruct %_arr_v3uint_uint_3 [[select_h_row0]] [[select_h_row1]] [[select_h_row2]]
71+
// CHECK-NEXT: OpStore %result3x3 [[select_h]]
72+
uint3x3 h;
73+
result3x3 = sign(h);
74+
}

0 commit comments

Comments
 (0)