Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9352,6 +9352,9 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
case hlsl::IntrinsicOp::IOP_printf:
retVal = processIntrinsicPrintf(callExpr);
break;
case hlsl::IntrinsicOp::IOP_usign:
retVal = processIntrinsicSignUnsignedInt(callExpr);
break;
case hlsl::IntrinsicOp::IOP_sign: {
if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
retVal = processIntrinsicFloatSign(callExpr);
Expand Down Expand Up @@ -12656,6 +12659,80 @@ SpirvEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicSignUnsignedInt(const CallExpr *callExpr) {
const auto srcLoc = callExpr->getExprLoc();
const auto srcRange = callExpr->getSourceRange();

const Expr *firstArg = callExpr->getArg(0);
const QualType firstArgType = firstArg->getType();
auto elemType = QualType{};
uint32_t numRows;
uint32_t numCols;
uint32_t count;
bool isScalar =
isScalarType(firstArgType, &elemType) ||
(isVectorType(firstArgType, &elemType, &count) && count == 1) ||
(isMxNMatrix(firstArgType, &elemType, &numRows, &numCols) &&
(numRows == 1 && numCols == 1));

auto *zero = getValueZero(astContext.IntTy);
auto *one = getValueOne(astContext.IntTy);
if (isScalar) {
auto *argVal = doExpr(callExpr->getArg(0));
auto *zeroUint = getValueZero(callExpr->getArg(0)->getType());
auto *cmp =
spvBuilder.createBinaryOp(spv::Op::OpUGreaterThan, astContext.BoolTy,
argVal, zeroUint, srcLoc, srcRange);
return spvBuilder.createSelect(astContext.IntTy, cmp, one, zero, srcLoc,
srcRange);
}

uint32_t size;
if (isVectorType(firstArgType)) {
size = count;
} else if (is1xNMatrix(firstArgType)) {
size = numCols;
} else if (isMx1Matrix(firstArgType)) {
size = numRows;
} else {
size = numRows;
}

const auto actOnEachVec = [this, srcLoc, srcRange, zero, one, elemType,
size](uint32_t index, QualType inType,
QualType outType, SpirvInstruction *curRow) {
auto zeroUint = getValueZero(elemType);
// Create `size` vector of uint zeros.
auto *zerosUint = spvBuilder.getConstantComposite(
astContext.getExtVectorType(elemType, size),
std::vector<clang::spirv::SpirvConstant *>(size, zeroUint));
// Compare if they are greater than zero.
auto *cmp = spvBuilder.createBinaryOp(
spv::Op::OpUGreaterThan,
astContext.getExtVectorType(astContext.BoolTy, size), curRow, zerosUint,
srcLoc, srcRange);

// Create a vector of int ones and zeros.
auto *zeros = spvBuilder.getConstantComposite(
astContext.getExtVectorType(astContext.IntTy, size),
std::vector<clang::spirv::SpirvConstant *>(size, zero));
auto *ones = spvBuilder.getConstantComposite(
astContext.getExtVectorType(astContext.IntTy, size),
std::vector<clang::spirv::SpirvConstant *>(size, one));
// Select between ones and zeros based on the comparison.
return spvBuilder.createSelect(
astContext.getExtVectorType(astContext.IntTy, size), cmp, ones, zeros,
srcLoc, srcRange);
};

if (isVectorType(firstArgType)) {
return actOnEachVec(0, firstArgType, callExpr->getType(), doExpr(firstArg));
}
return processEachVectorInMatrix(firstArg, doExpr(firstArg), actOnEachVec,
srcLoc, srcRange);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
// Import the GLSL.std.450 extended instruction set.
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,9 @@ class SpirvEmitter : public ASTConsumer {
/// Processes the 'ReadClock' intrinsic function.
SpirvInstruction *processIntrinsicReadClock(const CallExpr *);

/// Processes the 'sign' intrinsic function for unsigned integer types.
SpirvInstruction *processIntrinsicSignUnsignedInt(const CallExpr *callExpr);

/// Processes the 'sign' intrinsic function for float types.
/// The FSign instruction in the GLSL instruction set returns a floating point
/// result. The HLSL sign function, however, returns an integer. An extra
Expand Down
74 changes: 74 additions & 0 deletions tools/clang/test/CodeGenSPIRV/intrinsics.uintsign.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: %dxc -T vs_6_0 -E main -fcgl -Vd %s -spirv | FileCheck %s

// CHECK-DAG: %int_0 = OpConstant %int 0
// CHECK-DAG: %int_1 = OpConstant %int 1
// CHECK-DAG: %uint_0 = OpConstant %uint 0
// CHECK-DAG: %v3int = OpTypeVector %int 3
// CHECK-DAG: %v3uint = OpTypeVector %uint 3
// CHECK-DAG: [[zeros_uint3:%[0-9]+]] = OpConstantComposite %v3uint %uint_0 %uint_0 %uint_0
// CHECK-DAG: [[zeros_int3:%[0-9]+]] = OpConstantComposite %v3int %int_0 %int_0 %int_0
// CHECK-DAG: [[ones_int3:%[0-9]+]] = OpConstantComposite %v3int %int_1 %int_1 %int_1

void main() {
int result;
int3 result3;
int3x3 result3x3;

// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
// CHECK-NEXT: [[cmp_a:%[0-9]+]] = OpUGreaterThan %bool [[a]] %uint_0
// CHECK-NEXT: [[select_a:%[0-9]+]] = OpSelect %int [[cmp_a]] %int_1 %int_0
// CHECK-NEXT: OpStore %result [[select_a]]
uint a;
result = sign(a);

// CHECK: [[b:%[0-9]+]] = OpLoad %uint %b
// CHECK-NEXT: [[cmp_b:%[0-9]+]] = OpUGreaterThan %bool [[b]] %uint_0
// CHECK-NEXT: [[select_b:%[0-9]+]] = OpSelect %int [[cmp_b]] %int_1 %int_0
// CHECK-NEXT: OpStore %result [[select_b]]
uint1 b;
result = sign(b);

// CHECK: [[c:%[0-9]+]] = OpLoad %v3uint %c
// CHECK-NEXT: [[cmp_c:%[0-9]+]] = OpUGreaterThan %v3bool [[c]] [[zeros_uint3]]
// CHECK-NEXT: [[select_c:%[0-9]+]] = OpSelect %v3int [[cmp_c]] [[ones_int3]] [[zeros_int3]]
// CHECK-NEXT: OpStore %result3 [[select_c]]
uint3 c;
result3 = sign(c);


// CHECK: [[d:%[0-9]+]] = OpLoad %uint %d
// CHECK-NEXT: [[cmp_d:%[0-9]+]] = OpUGreaterThan %bool [[d]] %uint_0
// CHECK-NEXT: [[select_d:%[0-9]+]] = OpSelect %int [[cmp_d]] %int_1 %int_0
// CHECK-NEXT: OpStore %result [[select_d]]
uint1x1 d;
result = sign(d);

// CHECK: [[e:%[0-9]+]] = OpLoad %v3uint %e
// CHECK-NEXT: [[cmp_e:%[0-9]+]] = OpUGreaterThan %v3bool [[e]] [[zeros_uint3]]
// CHECK-NEXT: [[select_e:%[0-9]+]] = OpSelect %v3int [[cmp_e]] [[ones_int3]] [[zeros_int3]]
// CHECK-NEXT: OpStore %result3 [[select_e]]
uint1x3 e;
result3 = sign(e);

// CHECK: [[f:%[0-9]+]] = OpLoad %v3uint %f
// CHECK-NEXT: [[cmp_f:%[0-9]+]] = OpUGreaterThan %v3bool [[f]] [[zeros_uint3]]
// CHECK-NEXT: [[select_f:%[0-9]+]] = OpSelect %v3int [[cmp_f]] [[ones_int3]] [[zeros_int3]]
// CHECK-NEXT: OpStore %result3 [[select_f]]
uint3x1 f;
result3 = sign(f);

// CHECK: [[h:%[0-9]+]] = OpLoad %_arr_v3uint_uint_3 %h
// CHECK-NEXT: [[h_row0:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 0
// CHECK-NEXT: [[cmp_h_row0:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row0]] [[zeros_uint3]]
// CHECK-NEXT: [[select_h_row0:%[0-9]+]] = OpSelect %v3int [[cmp_h_row0]] [[ones_int3]] [[zeros_int3]]
// CHECK-NEXT: [[h_row1:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 1
// CHECK-NEXT: [[cmp_h_row1:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row1]] [[zeros_uint3]]
// CHECK-NEXT: [[select_h_row1:%[0-9]+]] = OpSelect %v3int [[cmp_h_row1]] [[ones_int3]] [[zeros_int3]]
// CHECK-NEXT: [[h_row2:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 2
// CHECK-NEXT: [[cmp_h_row2:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row2]] [[zeros_uint3]]
// CHECK-NEXT: [[select_h_row2:%[0-9]+]] = OpSelect %v3int [[cmp_h_row2]] [[ones_int3]] [[zeros_int3]]
// CHECK-NEXT: [[select_h:%[0-9]+]] = OpCompositeConstruct %_arr_v3uint_uint_3 [[select_h_row0]] [[select_h_row1]] [[select_h_row2]]
// CHECK-NEXT: OpStore %result3x3 [[select_h]]
uint3x3 h;
result3x3 = sign(h);
}