Skip to content

Commit 7a5ba65

Browse files
authored
[AArch64] optimize vselect of bitcast (llvm#180375)
Using code/ideas from the x86 backend to optimize a select on a bitcast integer. The previous aarch64 approach was to individually extract the bits from the mask, which is kind of terrible. https://rust.godbolt.org/z/576sndT66 ```llvm define void @if_then_else8(ptr %out, i8 %mask, ptr %if_true, ptr %if_false) { start: %t = load <8 x i32>, ptr %if_true, align 4 %f = load <8 x i32>, ptr %if_false, align 4 %m = bitcast i8 %mask to <8 x i1> %s = select <8 x i1> %m, <8 x i32> %t, <8 x i32> %f store <8 x i32> %s, ptr %out, align 4 ret void } ``` turned into ```asm if_then_else8: // @if_then_else8 sub sp, sp, llvm#16 ubfx w8, w1, llvm#4, #1 and w11, w1, #0x1 ubfx w9, w1, llvm#5, #1 fmov s1, w11 ubfx w10, w1, #1, #1 fmov s0, w8 ubfx w8, w1, llvm#6, #1 ldp q5, q2, [x3] mov v1.h[1], w10 ldp q4, q3, [x2] mov v0.h[1], w9 ubfx w9, w1, #2, #1 mov v1.h[2], w9 ubfx w9, w1, llvm#3, #1 mov v0.h[2], w8 ubfx w8, w1, llvm#7, #1 mov v1.h[3], w9 mov v0.h[3], w8 ushll v1.4s, v1.4h, #0 ushll v0.4s, v0.4h, #0 shl v1.4s, v1.4s, llvm#31 shl v0.4s, v0.4s, llvm#31 cmlt v1.4s, v1.4s, #0 cmlt v0.4s, v0.4s, #0 bsl v1.16b, v4.16b, v5.16b bsl v0.16b, v3.16b, v2.16b stp q1, q0, [x0] add sp, sp, llvm#16 ret ``` With this PR that instead emits ```asm if_then_else8: adrp x8, .LCPI0_1 dup v0.4s, w1 ldr q1, [x8, :lo12:.LCPI0_1] adrp x8, .LCPI0_0 ldr q2, [x8, :lo12:.LCPI0_0] ldp q4, q3, [x2] and v1.16b, v0.16b, v1.16b and v0.16b, v0.16b, v2.16b ldp q5, q2, [x3] cmeq v1.4s, v1.4s, #0 cmeq v0.4s, v0.4s, #0 bsl v1.16b, v2.16b, v3.16b bsl v0.16b, v5.16b, v4.16b stp q0, q1, [x0] ret ``` So substantially shorter. Instead of building the mask element-by-element, this approach (by virtue of not splitting) instead splats the mask value into all vector lanes, performs a bitwise and with powers of 2, and compares with zero to construct the mask vector. cc rust-lang/rust#122376 cc llvm#175769
1 parent 9e95cff commit 7a5ba65

File tree

2 files changed

+1220
-4
lines changed

2 files changed

+1220
-4
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24368,6 +24368,89 @@ static SDValue performZExtUZPCombine(SDNode *N, SelectionDAG &DAG) {
2436824368
return DAG.getNode(ISD::AND, DL, VT, BC, DAG.getConstant(Mask, DL, VT));
2436924369
}
2437024370

24371+
// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
24372+
static SDValue combineToExtendBoolVectorInReg(
24373+
unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
24374+
TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget &Subtarget) {
24375+
if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
24376+
Opcode != ISD::ANY_EXTEND)
24377+
return SDValue();
24378+
if (!DCI.isBeforeLegalizeOps())
24379+
return SDValue();
24380+
if (!Subtarget.hasNEON())
24381+
return SDValue();
24382+
24383+
EVT SVT = VT.getScalarType();
24384+
EVT InSVT = N0.getValueType().getScalarType();
24385+
unsigned EltSizeInBits = SVT.getSizeInBits();
24386+
24387+
// Input type must be extending a bool vector (bit-casted from a scalar
24388+
// integer) to legal integer types.
24389+
if (!VT.isVector())
24390+
return SDValue();
24391+
if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
24392+
return SDValue();
24393+
if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
24394+
return SDValue();
24395+
24396+
SDValue N00 = N0.getOperand(0);
24397+
EVT SclVT = N00.getValueType();
24398+
if (!SclVT.isScalarInteger())
24399+
return SDValue();
24400+
24401+
SDValue Vec;
24402+
SmallVector<int> ShuffleMask;
24403+
unsigned NumElts = VT.getVectorNumElements();
24404+
assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
24405+
24406+
// Broadcast the scalar integer to the vector elements.
24407+
bool IsBE = DAG.getDataLayout().isBigEndian();
24408+
if (NumElts > EltSizeInBits) {
24409+
// If the scalar integer is greater than the vector element size, then we
24410+
// must split it down into sub-sections for broadcasting. For example:
24411+
// i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
24412+
// i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
24413+
assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
24414+
unsigned Scale = NumElts / EltSizeInBits;
24415+
EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
24416+
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
24417+
Vec = DAG.getBitcast(VT, Vec);
24418+
24419+
for (unsigned I = 0; I != Scale; ++I)
24420+
ShuffleMask.append(EltSizeInBits, (int)I);
24421+
24422+
Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
24423+
} else {
24424+
// For smaller scalar integers, we can simply any-extend it to the vector
24425+
// element size (we don't care about the upper bits) and broadcast it to all
24426+
// elements.
24427+
Vec = DAG.getSplat(VT, DL, DAG.getAnyExtOrTrunc(N00, DL, SVT));
24428+
}
24429+
24430+
// Now, mask the relevant bit in each element.
24431+
SmallVector<SDValue, 32> Bits;
24432+
for (unsigned I = 0; I != NumElts; ++I) {
24433+
unsigned ScalarBit = IsBE ? (NumElts - 1 - I) : I;
24434+
int BitIdx = ScalarBit % EltSizeInBits;
24435+
APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
24436+
Bits.push_back(DAG.getConstant(Bit, DL, SVT));
24437+
}
24438+
SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
24439+
Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
24440+
24441+
// Compare against the bitmask and extend the result.
24442+
EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
24443+
Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
24444+
Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
24445+
24446+
// For SEXT, this is now done, otherwise shift the result down for
24447+
// zero-extension.
24448+
if (Opcode == ISD::SIGN_EXTEND)
24449+
return Vec;
24450+
return DAG.getNode(ISD::SRL, DL, VT, Vec,
24451+
DAG.getConstant(EltSizeInBits - 1, DL, VT));
24452+
}
24453+
2437124454
// Combine:
2437224455
// ext(duplane(insert_subvector(undef, trunc(X), 0), idx))
2437324456
// Into:
@@ -24432,7 +24515,8 @@ static SDValue performExtendDuplaneTruncCombine(SDNode *N, SelectionDAG &DAG) {
2443224515

2443324516
static SDValue performExtendCombine(SDNode *N,
2443424517
TargetLowering::DAGCombinerInfo &DCI,
24435-
SelectionDAG &DAG) {
24518+
SelectionDAG &DAG,
24519+
const AArch64Subtarget *Subtarget) {
2443624520
// If we see something like (zext (sabd (extract_high ...), (DUP ...))) then
2443724521
// we can convert that DUP into another extract_high (of a bigger DUP), which
2443824522
// helps the backend to decide that an sabdl2 would be useful, saving a real
@@ -24455,6 +24539,13 @@ static SDValue performExtendCombine(SDNode *N,
2445524539
if (SDValue R = performZExtUZPCombine(N, DAG))
2445624540
return R;
2445724541

24542+
SDLoc dl(N);
24543+
SDValue N0 = N->getOperand(0);
24544+
EVT VT = N->getValueType(0);
24545+
if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), dl, VT, N0,
24546+
DAG, DCI, *Subtarget))
24547+
return V;
24548+
2445824549
if (N->getValueType(0).isFixedLengthVector() &&
2445924550
N->getOpcode() == ISD::SIGN_EXTEND &&
2446024551
N->getOperand(0)->getOpcode() == ISD::SETCC)
@@ -27712,7 +27803,11 @@ static SDValue trySwapVSelectOperands(SDNode *N, SelectionDAG &DAG) {
2771227803
// FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
2771327804
// condition. If it can legalize "VSELECT v1i1" correctly, no need to combine
2771427805
// such VSELECT.
27715-
static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
27806+
static SDValue performVSelectCombine(SDNode *N,
27807+
TargetLowering::DAGCombinerInfo &DCI,
27808+
const AArch64Subtarget *Subtarget) {
27809+
SelectionDAG &DAG = DCI.DAG;
27810+
2771627811
if (auto SwapResult = trySwapVSelectOperands(N, DAG))
2771727812
return SwapResult;
2771827813

@@ -27776,6 +27871,20 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
2777627871
}
2777727872
}
2777827873

27874+
// Attempt to convert a (vXi1 bitcast(iX N0)) selection mask before it might
27875+
// get split by legalization.
27876+
if (N0.getOpcode() == ISD::BITCAST && CCVT.isVector() &&
27877+
CCVT.getVectorElementType() == MVT::i1) {
27878+
SDLoc DL(N);
27879+
EVT ExtCondVT = ResVT.changeVectorElementTypeToInteger();
27880+
27881+
if (SDValue ExtCond = combineToExtendBoolVectorInReg(
27882+
ISD::SIGN_EXTEND, DL, ExtCondVT, N0, DAG, DCI, *Subtarget)) {
27883+
ExtCond = DAG.getNode(ISD::TRUNCATE, DL, CCVT, ExtCond);
27884+
return DAG.getSelect(DL, ResVT, ExtCond, IfTrue, IfFalse);
27885+
}
27886+
}
27887+
2777927888
EVT CmpVT = N0.getOperand(0).getValueType();
2778027889
if (N0.getOpcode() != ISD::SETCC ||
2778127890
CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
@@ -29188,7 +29297,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2918829297
case ISD::ANY_EXTEND:
2918929298
case ISD::ZERO_EXTEND:
2919029299
case ISD::SIGN_EXTEND:
29191-
return performExtendCombine(N, DCI, DAG);
29300+
return performExtendCombine(N, DCI, DAG, Subtarget);
2919229301
case ISD::SIGN_EXTEND_INREG:
2919329302
return performSignExtendInRegCombine(N, DCI, DAG);
2919429303
case ISD::CONCAT_VECTORS:
@@ -29200,7 +29309,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2920029309
case ISD::SELECT:
2920129310
return performSelectCombine(N, DCI);
2920229311
case ISD::VSELECT:
29203-
return performVSelectCombine(N, DCI.DAG);
29312+
return performVSelectCombine(N, DCI, Subtarget);
2920429313
case ISD::SETCC:
2920529314
return performSETCCCombine(N, DCI, DAG);
2920629315
case ISD::LOAD:

0 commit comments

Comments
 (0)