@@ -9353,13 +9353,81 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
93539353 retVal = processIntrinsicPrintf(callExpr);
93549354 break;
93559355 case hlsl::IntrinsicOp::IOP_usign: {
9356- // Do SAbs followed by SSign
9357- auto *absVal = processIntrinsicUsingGLSLInst(
9358- callExpr, GLSLstd450::GLSLstd450SAbs,
9359- /*actPerRowForMatrices*/ true, srcLoc, srcRange);
9360- retVal = spvBuilder.createGLSLExtInst(callExpr->getType(),
9361- GLSLstd450::GLSLstd450SSign, {absVal},
9362- srcLoc, srcRange);
9356+ auto *zero = getValueZero(callExpr->getType());
9357+ auto *one = getValueOne(callExpr->getType());
9358+ // QualType elemType = {};
9359+ // if (isVectorType(callExpr->getArg(0)->getType(), &elemType) ||
9360+ // isMxNMatrix(callExpr->getArg(0)->getType(), &elemType, nullptr,
9361+ // nullptr)) {
9362+ // // For vector types, we need to create vector constants for zero and
9363+ // one. const uint32_t size =
9364+ // hlsl::GetHLSLVecSize(callExpr->getArg(0)->getType());
9365+ // auto zeroUint = getValueZero(elemType);
9366+ // // Create `size` vector of uint zeros.
9367+ // zerosUint = spvBuilder.createCompositeConstruct(
9368+ // elemType,
9369+ // std::vector<SpirvInstruction *>(size, zeroUint), srcLoc, srcRange);
9370+ // zeros = spvBuilder.createCompositeConstruct(
9371+ // elemType,
9372+ // std::vector<SpirvInstruction *>(size, zero), srcLoc, srcRange);
9373+ // ones = spvBuilder.createCompositeConstruct(
9374+ // elemType,
9375+ // std::vector<SpirvInstruction *>(size, one), srcLoc, srcRange);
9376+ // }
9377+ auto firstArgType = callExpr->getArg(0)->getType();
9378+ bool isNonTrivialVecOrMat =
9379+ (isVectorType(firstArgType) &&
9380+ hlsl::GetHLSLVecSize(firstArgType) != 1) ||
9381+ (isMxNMatrix(firstArgType) && !is1x1Matrix(firstArgType));
9382+ if (isNonTrivialVecOrMat) {
9383+ auto loc = callExpr->getArg(0)->getExprLoc();
9384+ auto range = callExpr->getArg(0)->getSourceRange();
9385+ auto *arg = callExpr->getArg(0);
9386+ auto argId = doExpr(arg);
9387+
9388+ const auto actOnEachVec = [this, loc, range, zero,
9389+ one](uint32_t index, QualType inType,
9390+ QualType outType,
9391+ SpirvInstruction *curRow) {
9392+ // Get size of the vector/matrix row.
9393+ const uint32_t size = hlsl::GetHLSLVecSize(outType);
9394+ const auto matElemType = hlsl::GetHLSLVecElementType(inType);
9395+ auto zeroUint = getValueZero(matElemType);
9396+ // Create `size` vector of uint zeros.
9397+ auto *zerosUint = spvBuilder.createCompositeConstruct(
9398+ astContext.getExtVectorType(matElemType, size),
9399+ std::vector<SpirvInstruction *>(size, zeroUint), loc, range);
9400+ // Compare if they are greater than zero.
9401+ auto *cmp = spvBuilder.createBinaryOp(
9402+ spv::Op::OpUGreaterThan,
9403+ astContext.getExtVectorType(astContext.BoolTy, size), curRow,
9404+ zerosUint, loc, range);
9405+
9406+ // Select between one and zero.
9407+ auto *zeros = spvBuilder.createCompositeConstruct(
9408+ astContext.getExtVectorType(astContext.IntTy, size),
9409+ std::vector<SpirvInstruction *>(size, zero), loc, range);
9410+ auto *ones = spvBuilder.createCompositeConstruct(
9411+ astContext.getExtVectorType(astContext.IntTy, size),
9412+ std::vector<SpirvInstruction *>(size, one), loc, range);
9413+ return spvBuilder.createSelect(
9414+ astContext.getExtVectorType(astContext.IntTy, size), cmp, ones,
9415+ zeros, loc, range);
9416+ };
9417+ if (isMxNMatrix(arg->getType()))
9418+ retVal =
9419+ processEachVectorInMatrix(arg, argId, actOnEachVec, loc, range);
9420+ else
9421+ retVal = actOnEachVec(0, arg->getType(), callExpr->getType(), argId);
9422+ } else {
9423+ auto *argVal = doExpr(callExpr->getArg(0));
9424+ auto *zeroUint = getValueZero(callExpr->getArg(0)->getType());
9425+ auto *cmp =
9426+ spvBuilder.createBinaryOp(spv::Op::OpUGreaterThan, astContext.BoolTy,
9427+ argVal, zeroUint, srcLoc, srcRange);
9428+ retVal = spvBuilder.createSelect(astContext.IntTy, cmp, one, zero, srcLoc,
9429+ srcRange);
9430+ }
93639431 } break;
93649432 case hlsl::IntrinsicOp::IOP_sign: {
93659433 if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
0 commit comments