109467b48Spatrick //===- InstCombineSelect.cpp ----------------------------------------------===//
209467b48Spatrick //
309467b48Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409467b48Spatrick // See https://llvm.org/LICENSE.txt for license information.
509467b48Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609467b48Spatrick //
709467b48Spatrick //===----------------------------------------------------------------------===//
809467b48Spatrick //
909467b48Spatrick // This file implements the visitSelect function.
1009467b48Spatrick //
1109467b48Spatrick //===----------------------------------------------------------------------===//
1209467b48Spatrick
1309467b48Spatrick #include "InstCombineInternal.h"
1409467b48Spatrick #include "llvm/ADT/APInt.h"
1509467b48Spatrick #include "llvm/ADT/STLExtras.h"
1609467b48Spatrick #include "llvm/ADT/SmallVector.h"
1709467b48Spatrick #include "llvm/Analysis/AssumptionCache.h"
1809467b48Spatrick #include "llvm/Analysis/CmpInstAnalysis.h"
1909467b48Spatrick #include "llvm/Analysis/InstructionSimplify.h"
2073471bf0Spatrick #include "llvm/Analysis/OverflowInstAnalysis.h"
2109467b48Spatrick #include "llvm/Analysis/ValueTracking.h"
22*d415bd75Srobert #include "llvm/Analysis/VectorUtils.h"
2309467b48Spatrick #include "llvm/IR/BasicBlock.h"
2409467b48Spatrick #include "llvm/IR/Constant.h"
25*d415bd75Srobert #include "llvm/IR/ConstantRange.h"
2609467b48Spatrick #include "llvm/IR/Constants.h"
2709467b48Spatrick #include "llvm/IR/DerivedTypes.h"
2809467b48Spatrick #include "llvm/IR/IRBuilder.h"
2909467b48Spatrick #include "llvm/IR/InstrTypes.h"
3009467b48Spatrick #include "llvm/IR/Instruction.h"
3109467b48Spatrick #include "llvm/IR/Instructions.h"
3209467b48Spatrick #include "llvm/IR/IntrinsicInst.h"
3309467b48Spatrick #include "llvm/IR/Intrinsics.h"
3409467b48Spatrick #include "llvm/IR/Operator.h"
3509467b48Spatrick #include "llvm/IR/PatternMatch.h"
3609467b48Spatrick #include "llvm/IR/Type.h"
3709467b48Spatrick #include "llvm/IR/User.h"
3809467b48Spatrick #include "llvm/IR/Value.h"
3909467b48Spatrick #include "llvm/Support/Casting.h"
4009467b48Spatrick #include "llvm/Support/ErrorHandling.h"
4109467b48Spatrick #include "llvm/Support/KnownBits.h"
4273471bf0Spatrick #include "llvm/Transforms/InstCombine/InstCombiner.h"
4309467b48Spatrick #include <cassert>
4409467b48Spatrick #include <utility>
4509467b48Spatrick
46*d415bd75Srobert #define DEBUG_TYPE "instcombine"
47*d415bd75Srobert #include "llvm/Transforms/Utils/InstructionWorklist.h"
48*d415bd75Srobert
4909467b48Spatrick using namespace llvm;
5009467b48Spatrick using namespace PatternMatch;
5109467b48Spatrick
5209467b48Spatrick
5309467b48Spatrick /// Replace a select operand based on an equality comparison with the identity
5409467b48Spatrick /// constant of a binop.
foldSelectBinOpIdentity(SelectInst & Sel,const TargetLibraryInfo & TLI,InstCombinerImpl & IC)5509467b48Spatrick static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
56097a140dSpatrick const TargetLibraryInfo &TLI,
5773471bf0Spatrick InstCombinerImpl &IC) {
5809467b48Spatrick // The select condition must be an equality compare with a constant operand.
5909467b48Spatrick Value *X;
6009467b48Spatrick Constant *C;
6109467b48Spatrick CmpInst::Predicate Pred;
6209467b48Spatrick if (!match(Sel.getCondition(), m_Cmp(Pred, m_Value(X), m_Constant(C))))
6309467b48Spatrick return nullptr;
6409467b48Spatrick
6509467b48Spatrick bool IsEq;
6609467b48Spatrick if (ICmpInst::isEquality(Pred))
6709467b48Spatrick IsEq = Pred == ICmpInst::ICMP_EQ;
6809467b48Spatrick else if (Pred == FCmpInst::FCMP_OEQ)
6909467b48Spatrick IsEq = true;
7009467b48Spatrick else if (Pred == FCmpInst::FCMP_UNE)
7109467b48Spatrick IsEq = false;
7209467b48Spatrick else
7309467b48Spatrick return nullptr;
7409467b48Spatrick
7509467b48Spatrick // A select operand must be a binop.
7609467b48Spatrick BinaryOperator *BO;
7709467b48Spatrick if (!match(Sel.getOperand(IsEq ? 1 : 2), m_BinOp(BO)))
7809467b48Spatrick return nullptr;
7909467b48Spatrick
8009467b48Spatrick // The compare constant must be the identity constant for that binop.
8109467b48Spatrick // If this a floating-point compare with 0.0, any zero constant will do.
8209467b48Spatrick Type *Ty = BO->getType();
8309467b48Spatrick Constant *IdC = ConstantExpr::getBinOpIdentity(BO->getOpcode(), Ty, true);
8409467b48Spatrick if (IdC != C) {
8509467b48Spatrick if (!IdC || !CmpInst::isFPPredicate(Pred))
8609467b48Spatrick return nullptr;
8709467b48Spatrick if (!match(IdC, m_AnyZeroFP()) || !match(C, m_AnyZeroFP()))
8809467b48Spatrick return nullptr;
8909467b48Spatrick }
9009467b48Spatrick
9109467b48Spatrick // Last, match the compare variable operand with a binop operand.
9209467b48Spatrick Value *Y;
9309467b48Spatrick if (!BO->isCommutative() && !match(BO, m_BinOp(m_Value(Y), m_Specific(X))))
9409467b48Spatrick return nullptr;
9509467b48Spatrick if (!match(BO, m_c_BinOp(m_Value(Y), m_Specific(X))))
9609467b48Spatrick return nullptr;
9709467b48Spatrick
9809467b48Spatrick // +0.0 compares equal to -0.0, and so it does not behave as required for this
9909467b48Spatrick // transform. Bail out if we can not exclude that possibility.
10009467b48Spatrick if (isa<FPMathOperator>(BO))
10109467b48Spatrick if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI))
10209467b48Spatrick return nullptr;
10309467b48Spatrick
10409467b48Spatrick // BO = binop Y, X
10509467b48Spatrick // S = { select (cmp eq X, C), BO, ? } or { select (cmp ne X, C), ?, BO }
10609467b48Spatrick // =>
10709467b48Spatrick // S = { select (cmp eq X, C), Y, ? } or { select (cmp ne X, C), ?, Y }
108097a140dSpatrick return IC.replaceOperand(Sel, IsEq ? 1 : 2, Y);
10909467b48Spatrick }
11009467b48Spatrick
11109467b48Spatrick /// This folds:
11209467b48Spatrick /// select (icmp eq (and X, C1)), TC, FC
11309467b48Spatrick /// iff C1 is a power 2 and the difference between TC and FC is a power-of-2.
11409467b48Spatrick /// To something like:
11509467b48Spatrick /// (shr (and (X, C1)), (log2(C1) - log2(TC-FC))) + FC
11609467b48Spatrick /// Or:
11709467b48Spatrick /// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
11809467b48Spatrick /// With some variations depending if FC is larger than TC, or the shift
11909467b48Spatrick /// isn't needed, or the bit widths don't match.
foldSelectICmpAnd(SelectInst & Sel,ICmpInst * Cmp,InstCombiner::BuilderTy & Builder)12009467b48Spatrick static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
12109467b48Spatrick InstCombiner::BuilderTy &Builder) {
12209467b48Spatrick const APInt *SelTC, *SelFC;
12309467b48Spatrick if (!match(Sel.getTrueValue(), m_APInt(SelTC)) ||
12409467b48Spatrick !match(Sel.getFalseValue(), m_APInt(SelFC)))
12509467b48Spatrick return nullptr;
12609467b48Spatrick
12709467b48Spatrick // If this is a vector select, we need a vector compare.
12809467b48Spatrick Type *SelType = Sel.getType();
12909467b48Spatrick if (SelType->isVectorTy() != Cmp->getType()->isVectorTy())
13009467b48Spatrick return nullptr;
13109467b48Spatrick
13209467b48Spatrick Value *V;
13309467b48Spatrick APInt AndMask;
13409467b48Spatrick bool CreateAnd = false;
13509467b48Spatrick ICmpInst::Predicate Pred = Cmp->getPredicate();
13609467b48Spatrick if (ICmpInst::isEquality(Pred)) {
13709467b48Spatrick if (!match(Cmp->getOperand(1), m_Zero()))
13809467b48Spatrick return nullptr;
13909467b48Spatrick
14009467b48Spatrick V = Cmp->getOperand(0);
14109467b48Spatrick const APInt *AndRHS;
14209467b48Spatrick if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
14309467b48Spatrick return nullptr;
14409467b48Spatrick
14509467b48Spatrick AndMask = *AndRHS;
14609467b48Spatrick } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1),
14709467b48Spatrick Pred, V, AndMask)) {
14809467b48Spatrick assert(ICmpInst::isEquality(Pred) && "Not equality test?");
14909467b48Spatrick if (!AndMask.isPowerOf2())
15009467b48Spatrick return nullptr;
15109467b48Spatrick
15209467b48Spatrick CreateAnd = true;
15309467b48Spatrick } else {
15409467b48Spatrick return nullptr;
15509467b48Spatrick }
15609467b48Spatrick
15709467b48Spatrick // In general, when both constants are non-zero, we would need an offset to
15809467b48Spatrick // replace the select. This would require more instructions than we started
15909467b48Spatrick // with. But there's one special-case that we handle here because it can
16009467b48Spatrick // simplify/reduce the instructions.
16109467b48Spatrick APInt TC = *SelTC;
16209467b48Spatrick APInt FC = *SelFC;
163*d415bd75Srobert if (!TC.isZero() && !FC.isZero()) {
16409467b48Spatrick // If the select constants differ by exactly one bit and that's the same
16509467b48Spatrick // bit that is masked and checked by the select condition, the select can
16609467b48Spatrick // be replaced by bitwise logic to set/clear one bit of the constant result.
16709467b48Spatrick if (TC.getBitWidth() != AndMask.getBitWidth() || (TC ^ FC) != AndMask)
16809467b48Spatrick return nullptr;
16909467b48Spatrick if (CreateAnd) {
17009467b48Spatrick // If we have to create an 'and', then we must kill the cmp to not
17109467b48Spatrick // increase the instruction count.
17209467b48Spatrick if (!Cmp->hasOneUse())
17309467b48Spatrick return nullptr;
17409467b48Spatrick V = Builder.CreateAnd(V, ConstantInt::get(SelType, AndMask));
17509467b48Spatrick }
17609467b48Spatrick bool ExtraBitInTC = TC.ugt(FC);
17709467b48Spatrick if (Pred == ICmpInst::ICMP_EQ) {
17809467b48Spatrick // If the masked bit in V is clear, clear or set the bit in the result:
17909467b48Spatrick // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) ^ TC
18009467b48Spatrick // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) | TC
18109467b48Spatrick Constant *C = ConstantInt::get(SelType, TC);
18209467b48Spatrick return ExtraBitInTC ? Builder.CreateXor(V, C) : Builder.CreateOr(V, C);
18309467b48Spatrick }
18409467b48Spatrick if (Pred == ICmpInst::ICMP_NE) {
18509467b48Spatrick // If the masked bit in V is set, set or clear the bit in the result:
18609467b48Spatrick // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) | FC
18709467b48Spatrick // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) ^ FC
18809467b48Spatrick Constant *C = ConstantInt::get(SelType, FC);
18909467b48Spatrick return ExtraBitInTC ? Builder.CreateOr(V, C) : Builder.CreateXor(V, C);
19009467b48Spatrick }
19109467b48Spatrick llvm_unreachable("Only expecting equality predicates");
19209467b48Spatrick }
19309467b48Spatrick
19409467b48Spatrick // Make sure one of the select arms is a power-of-2.
19509467b48Spatrick if (!TC.isPowerOf2() && !FC.isPowerOf2())
19609467b48Spatrick return nullptr;
19709467b48Spatrick
19809467b48Spatrick // Determine which shift is needed to transform result of the 'and' into the
19909467b48Spatrick // desired result.
200*d415bd75Srobert const APInt &ValC = !TC.isZero() ? TC : FC;
20109467b48Spatrick unsigned ValZeros = ValC.logBase2();
20209467b48Spatrick unsigned AndZeros = AndMask.logBase2();
20309467b48Spatrick
20409467b48Spatrick // Insert the 'and' instruction on the input to the truncate.
20509467b48Spatrick if (CreateAnd)
20609467b48Spatrick V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
20709467b48Spatrick
20809467b48Spatrick // If types don't match, we can still convert the select by introducing a zext
20909467b48Spatrick // or a trunc of the 'and'.
21009467b48Spatrick if (ValZeros > AndZeros) {
21109467b48Spatrick V = Builder.CreateZExtOrTrunc(V, SelType);
21209467b48Spatrick V = Builder.CreateShl(V, ValZeros - AndZeros);
21309467b48Spatrick } else if (ValZeros < AndZeros) {
21409467b48Spatrick V = Builder.CreateLShr(V, AndZeros - ValZeros);
21509467b48Spatrick V = Builder.CreateZExtOrTrunc(V, SelType);
21609467b48Spatrick } else {
21709467b48Spatrick V = Builder.CreateZExtOrTrunc(V, SelType);
21809467b48Spatrick }
21909467b48Spatrick
22009467b48Spatrick // Okay, now we know that everything is set up, we just don't know whether we
22109467b48Spatrick // have a icmp_ne or icmp_eq and whether the true or false val is the zero.
222*d415bd75Srobert bool ShouldNotVal = !TC.isZero();
22309467b48Spatrick ShouldNotVal ^= Pred == ICmpInst::ICMP_NE;
22409467b48Spatrick if (ShouldNotVal)
22509467b48Spatrick V = Builder.CreateXor(V, ValC);
22609467b48Spatrick
22709467b48Spatrick return V;
22809467b48Spatrick }
22909467b48Spatrick
23009467b48Spatrick /// We want to turn code that looks like this:
23109467b48Spatrick /// %C = or %A, %B
23209467b48Spatrick /// %D = select %cond, %C, %A
23309467b48Spatrick /// into:
23409467b48Spatrick /// %C = select %cond, %B, 0
23509467b48Spatrick /// %D = or %A, %C
23609467b48Spatrick ///
23709467b48Spatrick /// Assuming that the specified instruction is an operand to the select, return
23809467b48Spatrick /// a bitmask indicating which operands of this instruction are foldable if they
23909467b48Spatrick /// equal the other incoming value of the select.
getSelectFoldableOperands(BinaryOperator * I)24009467b48Spatrick static unsigned getSelectFoldableOperands(BinaryOperator *I) {
24109467b48Spatrick switch (I->getOpcode()) {
24209467b48Spatrick case Instruction::Add:
243*d415bd75Srobert case Instruction::FAdd:
24409467b48Spatrick case Instruction::Mul:
245*d415bd75Srobert case Instruction::FMul:
24609467b48Spatrick case Instruction::And:
24709467b48Spatrick case Instruction::Or:
24809467b48Spatrick case Instruction::Xor:
24909467b48Spatrick return 3; // Can fold through either operand.
25009467b48Spatrick case Instruction::Sub: // Can only fold on the amount subtracted.
251*d415bd75Srobert case Instruction::FSub:
252*d415bd75Srobert case Instruction::FDiv: // Can only fold on the divisor amount.
25309467b48Spatrick case Instruction::Shl: // Can only fold on the shift amount.
25409467b48Spatrick case Instruction::LShr:
25509467b48Spatrick case Instruction::AShr:
25609467b48Spatrick return 1;
25709467b48Spatrick default:
25809467b48Spatrick return 0; // Cannot fold
25909467b48Spatrick }
26009467b48Spatrick }
26109467b48Spatrick
26209467b48Spatrick /// We have (select c, TI, FI), and we know that TI and FI have the same opcode.
foldSelectOpOp(SelectInst & SI,Instruction * TI,Instruction * FI)26373471bf0Spatrick Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
26409467b48Spatrick Instruction *FI) {
26509467b48Spatrick // Don't break up min/max patterns. The hasOneUse checks below prevent that
26609467b48Spatrick // for most cases, but vector min/max with bitcasts can be transformed. If the
26709467b48Spatrick // one-use restrictions are eased for other patterns, we still don't want to
26809467b48Spatrick // obfuscate min/max.
26909467b48Spatrick if ((match(&SI, m_SMin(m_Value(), m_Value())) ||
27009467b48Spatrick match(&SI, m_SMax(m_Value(), m_Value())) ||
27109467b48Spatrick match(&SI, m_UMin(m_Value(), m_Value())) ||
27209467b48Spatrick match(&SI, m_UMax(m_Value(), m_Value()))))
27309467b48Spatrick return nullptr;
27409467b48Spatrick
27509467b48Spatrick // If this is a cast from the same type, merge.
27609467b48Spatrick Value *Cond = SI.getCondition();
27709467b48Spatrick Type *CondTy = Cond->getType();
27809467b48Spatrick if (TI->getNumOperands() == 1 && TI->isCast()) {
27909467b48Spatrick Type *FIOpndTy = FI->getOperand(0)->getType();
28009467b48Spatrick if (TI->getOperand(0)->getType() != FIOpndTy)
28109467b48Spatrick return nullptr;
28209467b48Spatrick
28309467b48Spatrick // The select condition may be a vector. We may only change the operand
28409467b48Spatrick // type if the vector width remains the same (and matches the condition).
285097a140dSpatrick if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) {
28673471bf0Spatrick if (!FIOpndTy->isVectorTy() ||
28773471bf0Spatrick CondVTy->getElementCount() !=
28873471bf0Spatrick cast<VectorType>(FIOpndTy)->getElementCount())
28909467b48Spatrick return nullptr;
29009467b48Spatrick
29109467b48Spatrick // TODO: If the backend knew how to deal with casts better, we could
29209467b48Spatrick // remove this limitation. For now, there's too much potential to create
29309467b48Spatrick // worse codegen by promoting the select ahead of size-altering casts
29409467b48Spatrick // (PR28160).
29509467b48Spatrick //
29609467b48Spatrick // Note that ValueTracking's matchSelectPattern() looks through casts
29709467b48Spatrick // without checking 'hasOneUse' when it matches min/max patterns, so this
29809467b48Spatrick // transform may end up happening anyway.
29909467b48Spatrick if (TI->getOpcode() != Instruction::BitCast &&
30009467b48Spatrick (!TI->hasOneUse() || !FI->hasOneUse()))
30109467b48Spatrick return nullptr;
30209467b48Spatrick } else if (!TI->hasOneUse() || !FI->hasOneUse()) {
30309467b48Spatrick // TODO: The one-use restrictions for a scalar select could be eased if
30409467b48Spatrick // the fold of a select in visitLoadInst() was enhanced to match a pattern
30509467b48Spatrick // that includes a cast.
30609467b48Spatrick return nullptr;
30709467b48Spatrick }
30809467b48Spatrick
30909467b48Spatrick // Fold this by inserting a select from the input values.
31009467b48Spatrick Value *NewSI =
31109467b48Spatrick Builder.CreateSelect(Cond, TI->getOperand(0), FI->getOperand(0),
31209467b48Spatrick SI.getName() + ".v", &SI);
31309467b48Spatrick return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI,
31409467b48Spatrick TI->getType());
31509467b48Spatrick }
31609467b48Spatrick
317*d415bd75Srobert Value *OtherOpT, *OtherOpF;
318*d415bd75Srobert bool MatchIsOpZero;
319*d415bd75Srobert auto getCommonOp = [&](Instruction *TI, Instruction *FI, bool Commute,
320*d415bd75Srobert bool Swapped = false) -> Value * {
321*d415bd75Srobert assert(!(Commute && Swapped) &&
322*d415bd75Srobert "Commute and Swapped can't set at the same time");
323*d415bd75Srobert if (!Swapped) {
324*d415bd75Srobert if (TI->getOperand(0) == FI->getOperand(0)) {
325*d415bd75Srobert OtherOpT = TI->getOperand(1);
326*d415bd75Srobert OtherOpF = FI->getOperand(1);
327*d415bd75Srobert MatchIsOpZero = true;
328*d415bd75Srobert return TI->getOperand(0);
329*d415bd75Srobert } else if (TI->getOperand(1) == FI->getOperand(1)) {
330*d415bd75Srobert OtherOpT = TI->getOperand(0);
331*d415bd75Srobert OtherOpF = FI->getOperand(0);
332*d415bd75Srobert MatchIsOpZero = false;
333*d415bd75Srobert return TI->getOperand(1);
334*d415bd75Srobert }
33509467b48Spatrick }
33609467b48Spatrick
337*d415bd75Srobert if (!Commute && !Swapped)
338*d415bd75Srobert return nullptr;
339*d415bd75Srobert
340*d415bd75Srobert // If we are allowing commute or swap of operands, then
341*d415bd75Srobert // allow a cross-operand match. In that case, MatchIsOpZero
342*d415bd75Srobert // means that TI's operand 0 (FI's operand 1) is the common op.
343*d415bd75Srobert if (TI->getOperand(0) == FI->getOperand(1)) {
344*d415bd75Srobert OtherOpT = TI->getOperand(1);
345*d415bd75Srobert OtherOpF = FI->getOperand(0);
346*d415bd75Srobert MatchIsOpZero = true;
347*d415bd75Srobert return TI->getOperand(0);
348*d415bd75Srobert } else if (TI->getOperand(1) == FI->getOperand(0)) {
349*d415bd75Srobert OtherOpT = TI->getOperand(0);
350*d415bd75Srobert OtherOpF = FI->getOperand(1);
351*d415bd75Srobert MatchIsOpZero = false;
352*d415bd75Srobert return TI->getOperand(1);
353*d415bd75Srobert }
354*d415bd75Srobert return nullptr;
355*d415bd75Srobert };
356*d415bd75Srobert
357*d415bd75Srobert if (TI->hasOneUse() || FI->hasOneUse()) {
358*d415bd75Srobert // Cond ? -X : -Y --> -(Cond ? X : Y)
359*d415bd75Srobert Value *X, *Y;
360*d415bd75Srobert if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y)))) {
361*d415bd75Srobert // Intersect FMF from the fneg instructions and union those with the
362*d415bd75Srobert // select.
363*d415bd75Srobert FastMathFlags FMF = TI->getFastMathFlags();
364*d415bd75Srobert FMF &= FI->getFastMathFlags();
365*d415bd75Srobert FMF |= SI.getFastMathFlags();
366*d415bd75Srobert Value *NewSel =
367*d415bd75Srobert Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI);
368*d415bd75Srobert if (auto *NewSelI = dyn_cast<Instruction>(NewSel))
369*d415bd75Srobert NewSelI->setFastMathFlags(FMF);
370*d415bd75Srobert Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel);
371*d415bd75Srobert NewFNeg->setFastMathFlags(FMF);
372*d415bd75Srobert return NewFNeg;
373*d415bd75Srobert }
374*d415bd75Srobert
375*d415bd75Srobert // Min/max intrinsic with a common operand can have the common operand
376*d415bd75Srobert // pulled after the select. This is the same transform as below for binops,
377*d415bd75Srobert // but specialized for intrinsic matching and without the restrictive uses
378*d415bd75Srobert // clause.
37973471bf0Spatrick auto *TII = dyn_cast<IntrinsicInst>(TI);
38073471bf0Spatrick auto *FII = dyn_cast<IntrinsicInst>(FI);
381*d415bd75Srobert if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID()) {
382*d415bd75Srobert if (match(TII, m_MaxOrMin(m_Value(), m_Value()))) {
383*d415bd75Srobert if (Value *MatchOp = getCommonOp(TI, FI, true)) {
384*d415bd75Srobert Value *NewSel =
385*d415bd75Srobert Builder.CreateSelect(Cond, OtherOpT, OtherOpF, "minmaxop", &SI);
386*d415bd75Srobert return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp});
38773471bf0Spatrick }
38873471bf0Spatrick }
38973471bf0Spatrick }
390*d415bd75Srobert
391*d415bd75Srobert // icmp with a common operand also can have the common operand
392*d415bd75Srobert // pulled after the select.
393*d415bd75Srobert ICmpInst::Predicate TPred, FPred;
394*d415bd75Srobert if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) &&
395*d415bd75Srobert match(FI, m_ICmp(FPred, m_Value(), m_Value()))) {
396*d415bd75Srobert if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) {
397*d415bd75Srobert bool Swapped = TPred != FPred;
398*d415bd75Srobert if (Value *MatchOp =
399*d415bd75Srobert getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) {
400*d415bd75Srobert Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
401*d415bd75Srobert SI.getName() + ".v", &SI);
402*d415bd75Srobert return new ICmpInst(
403*d415bd75Srobert MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred),
404*d415bd75Srobert MatchOp, NewSel);
405*d415bd75Srobert }
40673471bf0Spatrick }
40773471bf0Spatrick }
40873471bf0Spatrick }
40973471bf0Spatrick
41009467b48Spatrick // Only handle binary operators (including two-operand getelementptr) with
41109467b48Spatrick // one-use here. As with the cast case above, it may be possible to relax the
41209467b48Spatrick // one-use constraint, but that needs be examined carefully since it may not
41309467b48Spatrick // reduce the total number of instructions.
41409467b48Spatrick if (TI->getNumOperands() != 2 || FI->getNumOperands() != 2 ||
415*d415bd75Srobert !TI->isSameOperationAs(FI) ||
41609467b48Spatrick (!isa<BinaryOperator>(TI) && !isa<GetElementPtrInst>(TI)) ||
41709467b48Spatrick !TI->hasOneUse() || !FI->hasOneUse())
41809467b48Spatrick return nullptr;
41909467b48Spatrick
42009467b48Spatrick // Figure out if the operations have any operands in common.
421*d415bd75Srobert Value *MatchOp = getCommonOp(TI, FI, TI->isCommutative());
422*d415bd75Srobert if (!MatchOp)
42309467b48Spatrick return nullptr;
42409467b48Spatrick
42509467b48Spatrick // If the select condition is a vector, the operands of the original select's
42609467b48Spatrick // operands also must be vectors. This may not be the case for getelementptr
42709467b48Spatrick // for example.
42809467b48Spatrick if (CondTy->isVectorTy() && (!OtherOpT->getType()->isVectorTy() ||
42909467b48Spatrick !OtherOpF->getType()->isVectorTy()))
43009467b48Spatrick return nullptr;
43109467b48Spatrick
43209467b48Spatrick // If we reach here, they do have operations in common.
43309467b48Spatrick Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
43409467b48Spatrick SI.getName() + ".v", &SI);
43509467b48Spatrick Value *Op0 = MatchIsOpZero ? MatchOp : NewSI;
43609467b48Spatrick Value *Op1 = MatchIsOpZero ? NewSI : MatchOp;
43709467b48Spatrick if (auto *BO = dyn_cast<BinaryOperator>(TI)) {
43809467b48Spatrick BinaryOperator *NewBO = BinaryOperator::Create(BO->getOpcode(), Op0, Op1);
43909467b48Spatrick NewBO->copyIRFlags(TI);
44009467b48Spatrick NewBO->andIRFlags(FI);
44109467b48Spatrick return NewBO;
44209467b48Spatrick }
44309467b48Spatrick if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) {
44409467b48Spatrick auto *FGEP = cast<GetElementPtrInst>(FI);
44509467b48Spatrick Type *ElementType = TGEP->getResultElementType();
44609467b48Spatrick return TGEP->isInBounds() && FGEP->isInBounds()
44709467b48Spatrick ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1})
44809467b48Spatrick : GetElementPtrInst::Create(ElementType, Op0, {Op1});
44909467b48Spatrick }
45009467b48Spatrick llvm_unreachable("Expected BinaryOperator or GEP");
45109467b48Spatrick return nullptr;
45209467b48Spatrick }
45309467b48Spatrick
isSelect01(const APInt & C1I,const APInt & C2I)45409467b48Spatrick static bool isSelect01(const APInt &C1I, const APInt &C2I) {
455*d415bd75Srobert if (!C1I.isZero() && !C2I.isZero()) // One side must be zero.
45609467b48Spatrick return false;
457*d415bd75Srobert return C1I.isOne() || C1I.isAllOnes() || C2I.isOne() || C2I.isAllOnes();
45809467b48Spatrick }
45909467b48Spatrick
46009467b48Spatrick /// Try to fold the select into one of the operands to allow further
46109467b48Spatrick /// optimization.
foldSelectIntoOp(SelectInst & SI,Value * TrueVal,Value * FalseVal)46273471bf0Spatrick Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
46309467b48Spatrick Value *FalseVal) {
46409467b48Spatrick // See the comment above GetSelectFoldableOperands for a description of the
46509467b48Spatrick // transformation we are doing here.
466*d415bd75Srobert auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal,
467*d415bd75Srobert Value *FalseVal,
468*d415bd75Srobert bool Swapped) -> Instruction * {
469*d415bd75Srobert auto *TVI = dyn_cast<BinaryOperator>(TrueVal);
470*d415bd75Srobert if (!TVI || !TVI->hasOneUse() || isa<Constant>(FalseVal))
471*d415bd75Srobert return nullptr;
47209467b48Spatrick
473*d415bd75Srobert unsigned SFO = getSelectFoldableOperands(TVI);
474*d415bd75Srobert unsigned OpToFold = 0;
475*d415bd75Srobert if ((SFO & 1) && FalseVal == TVI->getOperand(0))
476*d415bd75Srobert OpToFold = 1;
477*d415bd75Srobert else if ((SFO & 2) && FalseVal == TVI->getOperand(1))
478*d415bd75Srobert OpToFold = 2;
479*d415bd75Srobert
480*d415bd75Srobert if (!OpToFold)
481*d415bd75Srobert return nullptr;
482*d415bd75Srobert
483*d415bd75Srobert // TODO: We probably ought to revisit cases where the select and FP
484*d415bd75Srobert // instructions have different flags and add tests to ensure the
485*d415bd75Srobert // behaviour is correct.
486*d415bd75Srobert FastMathFlags FMF;
487*d415bd75Srobert if (isa<FPMathOperator>(&SI))
488*d415bd75Srobert FMF = SI.getFastMathFlags();
489*d415bd75Srobert Constant *C = ConstantExpr::getBinOpIdentity(
490*d415bd75Srobert TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros());
49109467b48Spatrick Value *OOp = TVI->getOperand(2 - OpToFold);
49209467b48Spatrick // Avoid creating select between 2 constants unless it's selecting
49309467b48Spatrick // between 0, 1 and -1.
49409467b48Spatrick const APInt *OOpC;
49509467b48Spatrick bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
49673471bf0Spatrick if (!isa<Constant>(OOp) ||
49773471bf0Spatrick (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
498*d415bd75Srobert Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
499*d415bd75Srobert Swapped ? OOp : C);
500*d415bd75Srobert if (isa<FPMathOperator>(&SI))
501*d415bd75Srobert cast<Instruction>(NewSel)->setFastMathFlags(FMF);
50209467b48Spatrick NewSel->takeName(TVI);
503*d415bd75Srobert BinaryOperator *BO =
504*d415bd75Srobert BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel);
50509467b48Spatrick BO->copyIRFlags(TVI);
50609467b48Spatrick return BO;
50709467b48Spatrick }
508*d415bd75Srobert return nullptr;
509*d415bd75Srobert };
51009467b48Spatrick
511*d415bd75Srobert if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false))
512*d415bd75Srobert return R;
51309467b48Spatrick
514*d415bd75Srobert if (Instruction *R = TryFoldSelectIntoOp(SI, FalseVal, TrueVal, true))
515*d415bd75Srobert return R;
51609467b48Spatrick
51709467b48Spatrick return nullptr;
51809467b48Spatrick }
51909467b48Spatrick
52009467b48Spatrick /// We want to turn:
52109467b48Spatrick /// (select (icmp eq (and X, Y), 0), (and (lshr X, Z), 1), 1)
52209467b48Spatrick /// into:
52309467b48Spatrick /// zext (icmp ne i32 (and X, (or Y, (shl 1, Z))), 0)
52409467b48Spatrick /// Note:
52509467b48Spatrick /// Z may be 0 if lshr is missing.
52609467b48Spatrick /// Worst-case scenario is that we will replace 5 instructions with 5 different
52709467b48Spatrick /// instructions, but we got rid of select.
foldSelectICmpAndAnd(Type * SelType,const ICmpInst * Cmp,Value * TVal,Value * FVal,InstCombiner::BuilderTy & Builder)52809467b48Spatrick static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp,
52909467b48Spatrick Value *TVal, Value *FVal,
53009467b48Spatrick InstCombiner::BuilderTy &Builder) {
53109467b48Spatrick if (!(Cmp->hasOneUse() && Cmp->getOperand(0)->hasOneUse() &&
53209467b48Spatrick Cmp->getPredicate() == ICmpInst::ICMP_EQ &&
53309467b48Spatrick match(Cmp->getOperand(1), m_Zero()) && match(FVal, m_One())))
53409467b48Spatrick return nullptr;
53509467b48Spatrick
53609467b48Spatrick // The TrueVal has general form of: and %B, 1
53709467b48Spatrick Value *B;
53809467b48Spatrick if (!match(TVal, m_OneUse(m_And(m_Value(B), m_One()))))
53909467b48Spatrick return nullptr;
54009467b48Spatrick
54109467b48Spatrick // Where %B may be optionally shifted: lshr %X, %Z.
54209467b48Spatrick Value *X, *Z;
54309467b48Spatrick const bool HasShift = match(B, m_OneUse(m_LShr(m_Value(X), m_Value(Z))));
544*d415bd75Srobert
545*d415bd75Srobert // The shift must be valid.
546*d415bd75Srobert // TODO: This restricts the fold to constant shift amounts. Is there a way to
547*d415bd75Srobert // handle variable shifts safely? PR47012
548*d415bd75Srobert if (HasShift &&
549*d415bd75Srobert !match(Z, m_SpecificInt_ICMP(CmpInst::ICMP_ULT,
550*d415bd75Srobert APInt(SelType->getScalarSizeInBits(),
551*d415bd75Srobert SelType->getScalarSizeInBits()))))
552*d415bd75Srobert return nullptr;
553*d415bd75Srobert
55409467b48Spatrick if (!HasShift)
55509467b48Spatrick X = B;
55609467b48Spatrick
55709467b48Spatrick Value *Y;
55809467b48Spatrick if (!match(Cmp->getOperand(0), m_c_And(m_Specific(X), m_Value(Y))))
55909467b48Spatrick return nullptr;
56009467b48Spatrick
56109467b48Spatrick // ((X & Y) == 0) ? ((X >> Z) & 1) : 1 --> (X & (Y | (1 << Z))) != 0
56209467b48Spatrick // ((X & Y) == 0) ? (X & 1) : 1 --> (X & (Y | 1)) != 0
56309467b48Spatrick Constant *One = ConstantInt::get(SelType, 1);
56409467b48Spatrick Value *MaskB = HasShift ? Builder.CreateShl(One, Z) : One;
56509467b48Spatrick Value *FullMask = Builder.CreateOr(Y, MaskB);
56609467b48Spatrick Value *MaskedX = Builder.CreateAnd(X, FullMask);
56709467b48Spatrick Value *ICmpNeZero = Builder.CreateIsNotNull(MaskedX);
56809467b48Spatrick return new ZExtInst(ICmpNeZero, SelType);
56909467b48Spatrick }
57009467b48Spatrick
57109467b48Spatrick /// We want to turn:
57209467b48Spatrick /// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1
57309467b48Spatrick /// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0
57409467b48Spatrick /// into:
57509467b48Spatrick /// ashr (X, Y)
foldSelectICmpLshrAshr(const ICmpInst * IC,Value * TrueVal,Value * FalseVal,InstCombiner::BuilderTy & Builder)57609467b48Spatrick static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
57709467b48Spatrick Value *FalseVal,
57809467b48Spatrick InstCombiner::BuilderTy &Builder) {
57909467b48Spatrick ICmpInst::Predicate Pred = IC->getPredicate();
58009467b48Spatrick Value *CmpLHS = IC->getOperand(0);
58109467b48Spatrick Value *CmpRHS = IC->getOperand(1);
58209467b48Spatrick if (!CmpRHS->getType()->isIntOrIntVectorTy())
58309467b48Spatrick return nullptr;
58409467b48Spatrick
58509467b48Spatrick Value *X, *Y;
58609467b48Spatrick unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits();
58709467b48Spatrick if ((Pred != ICmpInst::ICMP_SGT ||
58809467b48Spatrick !match(CmpRHS,
58909467b48Spatrick m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) &&
59009467b48Spatrick (Pred != ICmpInst::ICMP_SLT ||
59109467b48Spatrick !match(CmpRHS,
59209467b48Spatrick m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0)))))
59309467b48Spatrick return nullptr;
59409467b48Spatrick
59509467b48Spatrick // Canonicalize so that ashr is in FalseVal.
59609467b48Spatrick if (Pred == ICmpInst::ICMP_SLT)
59709467b48Spatrick std::swap(TrueVal, FalseVal);
59809467b48Spatrick
59909467b48Spatrick if (match(TrueVal, m_LShr(m_Value(X), m_Value(Y))) &&
60009467b48Spatrick match(FalseVal, m_AShr(m_Specific(X), m_Specific(Y))) &&
60109467b48Spatrick match(CmpLHS, m_Specific(X))) {
60209467b48Spatrick const auto *Ashr = cast<Instruction>(FalseVal);
60309467b48Spatrick // if lshr is not exact and ashr is, this new ashr must not be exact.
60409467b48Spatrick bool IsExact = Ashr->isExact() && cast<Instruction>(TrueVal)->isExact();
60509467b48Spatrick return Builder.CreateAShr(X, Y, IC->getName(), IsExact);
60609467b48Spatrick }
60709467b48Spatrick
60809467b48Spatrick return nullptr;
60909467b48Spatrick }
61009467b48Spatrick
61109467b48Spatrick /// We want to turn:
61209467b48Spatrick /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2))
61309467b48Spatrick /// into:
61409467b48Spatrick /// (or (shl (and X, C1), C3), Y)
61509467b48Spatrick /// iff:
61609467b48Spatrick /// C1 and C2 are both powers of 2
61709467b48Spatrick /// where:
61809467b48Spatrick /// C3 = Log(C2) - Log(C1)
61909467b48Spatrick ///
62009467b48Spatrick /// This transform handles cases where:
62109467b48Spatrick /// 1. The icmp predicate is inverted
62209467b48Spatrick /// 2. The select operands are reversed
62309467b48Spatrick /// 3. The magnitude of C2 and C1 are flipped
foldSelectICmpAndOr(const ICmpInst * IC,Value * TrueVal,Value * FalseVal,InstCombiner::BuilderTy & Builder)62409467b48Spatrick static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
62509467b48Spatrick Value *FalseVal,
62609467b48Spatrick InstCombiner::BuilderTy &Builder) {
62709467b48Spatrick // Only handle integer compares. Also, if this is a vector select, we need a
62809467b48Spatrick // vector compare.
62909467b48Spatrick if (!TrueVal->getType()->isIntOrIntVectorTy() ||
63009467b48Spatrick TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
63109467b48Spatrick return nullptr;
63209467b48Spatrick
63309467b48Spatrick Value *CmpLHS = IC->getOperand(0);
63409467b48Spatrick Value *CmpRHS = IC->getOperand(1);
63509467b48Spatrick
63609467b48Spatrick Value *V;
63709467b48Spatrick unsigned C1Log;
63809467b48Spatrick bool IsEqualZero;
63909467b48Spatrick bool NeedAnd = false;
64009467b48Spatrick if (IC->isEquality()) {
64109467b48Spatrick if (!match(CmpRHS, m_Zero()))
64209467b48Spatrick return nullptr;
64309467b48Spatrick
64409467b48Spatrick const APInt *C1;
64509467b48Spatrick if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
64609467b48Spatrick return nullptr;
64709467b48Spatrick
64809467b48Spatrick V = CmpLHS;
64909467b48Spatrick C1Log = C1->logBase2();
65009467b48Spatrick IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ;
65109467b48Spatrick } else if (IC->getPredicate() == ICmpInst::ICMP_SLT ||
65209467b48Spatrick IC->getPredicate() == ICmpInst::ICMP_SGT) {
65309467b48Spatrick // We also need to recognize (icmp slt (trunc (X)), 0) and
65409467b48Spatrick // (icmp sgt (trunc (X)), -1).
65509467b48Spatrick IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT;
65609467b48Spatrick if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) ||
65709467b48Spatrick (!IsEqualZero && !match(CmpRHS, m_Zero())))
65809467b48Spatrick return nullptr;
65909467b48Spatrick
66009467b48Spatrick if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V)))))
66109467b48Spatrick return nullptr;
66209467b48Spatrick
66309467b48Spatrick C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1;
66409467b48Spatrick NeedAnd = true;
66509467b48Spatrick } else {
66609467b48Spatrick return nullptr;
66709467b48Spatrick }
66809467b48Spatrick
66909467b48Spatrick const APInt *C2;
67009467b48Spatrick bool OrOnTrueVal = false;
67109467b48Spatrick bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2)));
67209467b48Spatrick if (!OrOnFalseVal)
67309467b48Spatrick OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2)));
67409467b48Spatrick
67509467b48Spatrick if (!OrOnFalseVal && !OrOnTrueVal)
67609467b48Spatrick return nullptr;
67709467b48Spatrick
67809467b48Spatrick Value *Y = OrOnFalseVal ? TrueVal : FalseVal;
67909467b48Spatrick
68009467b48Spatrick unsigned C2Log = C2->logBase2();
68109467b48Spatrick
68209467b48Spatrick bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);
68309467b48Spatrick bool NeedShift = C1Log != C2Log;
68409467b48Spatrick bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() !=
68509467b48Spatrick V->getType()->getScalarSizeInBits();
68609467b48Spatrick
68709467b48Spatrick // Make sure we don't create more instructions than we save.
68809467b48Spatrick Value *Or = OrOnFalseVal ? FalseVal : TrueVal;
68909467b48Spatrick if ((NeedShift + NeedXor + NeedZExtTrunc) >
69009467b48Spatrick (IC->hasOneUse() + Or->hasOneUse()))
69109467b48Spatrick return nullptr;
69209467b48Spatrick
69309467b48Spatrick if (NeedAnd) {
69409467b48Spatrick // Insert the AND instruction on the input to the truncate.
69509467b48Spatrick APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log);
69609467b48Spatrick V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1));
69709467b48Spatrick }
69809467b48Spatrick
69909467b48Spatrick if (C2Log > C1Log) {
70009467b48Spatrick V = Builder.CreateZExtOrTrunc(V, Y->getType());
70109467b48Spatrick V = Builder.CreateShl(V, C2Log - C1Log);
70209467b48Spatrick } else if (C1Log > C2Log) {
70309467b48Spatrick V = Builder.CreateLShr(V, C1Log - C2Log);
70409467b48Spatrick V = Builder.CreateZExtOrTrunc(V, Y->getType());
70509467b48Spatrick } else
70609467b48Spatrick V = Builder.CreateZExtOrTrunc(V, Y->getType());
70709467b48Spatrick
70809467b48Spatrick if (NeedXor)
70909467b48Spatrick V = Builder.CreateXor(V, *C2);
71009467b48Spatrick
71109467b48Spatrick return Builder.CreateOr(V, Y);
71209467b48Spatrick }
71309467b48Spatrick
714097a140dSpatrick /// Canonicalize a set or clear of a masked set of constant bits to
715097a140dSpatrick /// select-of-constants form.
foldSetClearBits(SelectInst & Sel,InstCombiner::BuilderTy & Builder)716097a140dSpatrick static Instruction *foldSetClearBits(SelectInst &Sel,
717097a140dSpatrick InstCombiner::BuilderTy &Builder) {
718097a140dSpatrick Value *Cond = Sel.getCondition();
719097a140dSpatrick Value *T = Sel.getTrueValue();
720097a140dSpatrick Value *F = Sel.getFalseValue();
721097a140dSpatrick Type *Ty = Sel.getType();
722097a140dSpatrick Value *X;
723097a140dSpatrick const APInt *NotC, *C;
724097a140dSpatrick
725097a140dSpatrick // Cond ? (X & ~C) : (X | C) --> (X & ~C) | (Cond ? 0 : C)
726097a140dSpatrick if (match(T, m_And(m_Value(X), m_APInt(NotC))) &&
727097a140dSpatrick match(F, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) {
728097a140dSpatrick Constant *Zero = ConstantInt::getNullValue(Ty);
729097a140dSpatrick Constant *OrC = ConstantInt::get(Ty, *C);
730097a140dSpatrick Value *NewSel = Builder.CreateSelect(Cond, Zero, OrC, "masksel", &Sel);
731097a140dSpatrick return BinaryOperator::CreateOr(T, NewSel);
732097a140dSpatrick }
733097a140dSpatrick
734097a140dSpatrick // Cond ? (X | C) : (X & ~C) --> (X & ~C) | (Cond ? C : 0)
735097a140dSpatrick if (match(F, m_And(m_Value(X), m_APInt(NotC))) &&
736097a140dSpatrick match(T, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) {
737097a140dSpatrick Constant *Zero = ConstantInt::getNullValue(Ty);
738097a140dSpatrick Constant *OrC = ConstantInt::get(Ty, *C);
739097a140dSpatrick Value *NewSel = Builder.CreateSelect(Cond, OrC, Zero, "masksel", &Sel);
740097a140dSpatrick return BinaryOperator::CreateOr(F, NewSel);
741097a140dSpatrick }
742097a140dSpatrick
743097a140dSpatrick return nullptr;
744097a140dSpatrick }
745097a140dSpatrick
746*d415bd75Srobert // select (x == 0), 0, x * y --> freeze(y) * x
747*d415bd75Srobert // select (y == 0), 0, x * y --> freeze(x) * y
748*d415bd75Srobert // select (x == 0), undef, x * y --> freeze(y) * x
749*d415bd75Srobert // select (x == undef), 0, x * y --> freeze(y) * x
750*d415bd75Srobert // Usage of mul instead of 0 will make the result more poisonous,
751*d415bd75Srobert // so the operand that was not checked in the condition should be frozen.
752*d415bd75Srobert // The latter folding is applied only when a constant compared with x is
753*d415bd75Srobert // is a vector consisting of 0 and undefs. If a constant compared with x
754*d415bd75Srobert // is a scalar undefined value or undefined vector then an expression
755*d415bd75Srobert // should be already folded into a constant.
foldSelectZeroOrMul(SelectInst & SI,InstCombinerImpl & IC)756*d415bd75Srobert static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
757*d415bd75Srobert auto *CondVal = SI.getCondition();
758*d415bd75Srobert auto *TrueVal = SI.getTrueValue();
759*d415bd75Srobert auto *FalseVal = SI.getFalseValue();
760*d415bd75Srobert Value *X, *Y;
761*d415bd75Srobert ICmpInst::Predicate Predicate;
762*d415bd75Srobert
763*d415bd75Srobert // Assuming that constant compared with zero is not undef (but it may be
764*d415bd75Srobert // a vector with some undef elements). Otherwise (when a constant is undef)
765*d415bd75Srobert // the select expression should be already simplified.
766*d415bd75Srobert if (!match(CondVal, m_ICmp(Predicate, m_Value(X), m_Zero())) ||
767*d415bd75Srobert !ICmpInst::isEquality(Predicate))
768*d415bd75Srobert return nullptr;
769*d415bd75Srobert
770*d415bd75Srobert if (Predicate == ICmpInst::ICMP_NE)
771*d415bd75Srobert std::swap(TrueVal, FalseVal);
772*d415bd75Srobert
773*d415bd75Srobert // Check that TrueVal is a constant instead of matching it with m_Zero()
774*d415bd75Srobert // to handle the case when it is a scalar undef value or a vector containing
775*d415bd75Srobert // non-zero elements that are masked by undef elements in the compare
776*d415bd75Srobert // constant.
777*d415bd75Srobert auto *TrueValC = dyn_cast<Constant>(TrueVal);
778*d415bd75Srobert if (TrueValC == nullptr ||
779*d415bd75Srobert !match(FalseVal, m_c_Mul(m_Specific(X), m_Value(Y))) ||
780*d415bd75Srobert !isa<Instruction>(FalseVal))
781*d415bd75Srobert return nullptr;
782*d415bd75Srobert
783*d415bd75Srobert auto *ZeroC = cast<Constant>(cast<Instruction>(CondVal)->getOperand(1));
784*d415bd75Srobert auto *MergedC = Constant::mergeUndefsWith(TrueValC, ZeroC);
785*d415bd75Srobert // If X is compared with 0 then TrueVal could be either zero or undef.
786*d415bd75Srobert // m_Zero match vectors containing some undef elements, but for scalars
787*d415bd75Srobert // m_Undef should be used explicitly.
788*d415bd75Srobert if (!match(MergedC, m_Zero()) && !match(MergedC, m_Undef()))
789*d415bd75Srobert return nullptr;
790*d415bd75Srobert
791*d415bd75Srobert auto *FalseValI = cast<Instruction>(FalseVal);
792*d415bd75Srobert auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"),
793*d415bd75Srobert *FalseValI);
794*d415bd75Srobert IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY);
795*d415bd75Srobert return IC.replaceInstUsesWith(SI, FalseValI);
796*d415bd75Srobert }
797*d415bd75Srobert
79809467b48Spatrick /// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b).
79909467b48Spatrick /// There are 8 commuted/swapped variants of this pattern.
80009467b48Spatrick /// TODO: Also support a - UMIN(a,b) patterns.
canonicalizeSaturatedSubtract(const ICmpInst * ICI,const Value * TrueVal,const Value * FalseVal,InstCombiner::BuilderTy & Builder)80109467b48Spatrick static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
80209467b48Spatrick const Value *TrueVal,
80309467b48Spatrick const Value *FalseVal,
80409467b48Spatrick InstCombiner::BuilderTy &Builder) {
80509467b48Spatrick ICmpInst::Predicate Pred = ICI->getPredicate();
806*d415bd75Srobert Value *A = ICI->getOperand(0);
807*d415bd75Srobert Value *B = ICI->getOperand(1);
80809467b48Spatrick
80909467b48Spatrick // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0
810*d415bd75Srobert // (a == 0) ? 0 : a - 1 -> (a != 0) ? a - 1 : 0
81109467b48Spatrick if (match(TrueVal, m_Zero())) {
81209467b48Spatrick Pred = ICmpInst::getInversePredicate(Pred);
81309467b48Spatrick std::swap(TrueVal, FalseVal);
81409467b48Spatrick }
815*d415bd75Srobert
81609467b48Spatrick if (!match(FalseVal, m_Zero()))
81709467b48Spatrick return nullptr;
81809467b48Spatrick
819*d415bd75Srobert // ugt 0 is canonicalized to ne 0 and requires special handling
820*d415bd75Srobert // (a != 0) ? a + -1 : 0 -> usub.sat(a, 1)
821*d415bd75Srobert if (Pred == ICmpInst::ICMP_NE) {
822*d415bd75Srobert if (match(B, m_Zero()) && match(TrueVal, m_Add(m_Specific(A), m_AllOnes())))
823*d415bd75Srobert return Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A,
824*d415bd75Srobert ConstantInt::get(A->getType(), 1));
825*d415bd75Srobert return nullptr;
826*d415bd75Srobert }
827*d415bd75Srobert
828*d415bd75Srobert if (!ICmpInst::isUnsigned(Pred))
829*d415bd75Srobert return nullptr;
830*d415bd75Srobert
83109467b48Spatrick if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) {
83209467b48Spatrick // (b < a) ? a - b : 0 -> (a > b) ? a - b : 0
83309467b48Spatrick std::swap(A, B);
83409467b48Spatrick Pred = ICmpInst::getSwappedPredicate(Pred);
83509467b48Spatrick }
83609467b48Spatrick
83709467b48Spatrick assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) &&
83809467b48Spatrick "Unexpected isUnsigned predicate!");
83909467b48Spatrick
84009467b48Spatrick // Ensure the sub is of the form:
84109467b48Spatrick // (a > b) ? a - b : 0 -> usub.sat(a, b)
84209467b48Spatrick // (a > b) ? b - a : 0 -> -usub.sat(a, b)
84309467b48Spatrick // Checking for both a-b and a+(-b) as a constant.
84409467b48Spatrick bool IsNegative = false;
84509467b48Spatrick const APInt *C;
84609467b48Spatrick if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) ||
84709467b48Spatrick (match(A, m_APInt(C)) &&
84809467b48Spatrick match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C)))))
84909467b48Spatrick IsNegative = true;
85009467b48Spatrick else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) &&
85109467b48Spatrick !(match(B, m_APInt(C)) &&
85209467b48Spatrick match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C)))))
85309467b48Spatrick return nullptr;
85409467b48Spatrick
85509467b48Spatrick // If we are adding a negate and the sub and icmp are used anywhere else, we
85609467b48Spatrick // would end up with more instructions.
85709467b48Spatrick if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse())
85809467b48Spatrick return nullptr;
85909467b48Spatrick
86009467b48Spatrick // (a > b) ? a - b : 0 -> usub.sat(a, b)
86109467b48Spatrick // (a > b) ? b - a : 0 -> -usub.sat(a, b)
86209467b48Spatrick Value *Result = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, B);
86309467b48Spatrick if (IsNegative)
86409467b48Spatrick Result = Builder.CreateNeg(Result);
86509467b48Spatrick return Result;
86609467b48Spatrick }
86709467b48Spatrick
canonicalizeSaturatedAdd(ICmpInst * Cmp,Value * TVal,Value * FVal,InstCombiner::BuilderTy & Builder)86809467b48Spatrick static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
86909467b48Spatrick InstCombiner::BuilderTy &Builder) {
87009467b48Spatrick if (!Cmp->hasOneUse())
87109467b48Spatrick return nullptr;
87209467b48Spatrick
87309467b48Spatrick // Match unsigned saturated add with constant.
87409467b48Spatrick Value *Cmp0 = Cmp->getOperand(0);
87509467b48Spatrick Value *Cmp1 = Cmp->getOperand(1);
87609467b48Spatrick ICmpInst::Predicate Pred = Cmp->getPredicate();
87709467b48Spatrick Value *X;
87809467b48Spatrick const APInt *C, *CmpC;
87909467b48Spatrick if (Pred == ICmpInst::ICMP_ULT &&
88009467b48Spatrick match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 &&
88109467b48Spatrick match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) {
88209467b48Spatrick // (X u< ~C) ? (X + C) : -1 --> uadd.sat(X, C)
88309467b48Spatrick return Builder.CreateBinaryIntrinsic(
88409467b48Spatrick Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C));
88509467b48Spatrick }
88609467b48Spatrick
88709467b48Spatrick // Match unsigned saturated add of 2 variables with an unnecessary 'not'.
88809467b48Spatrick // There are 8 commuted variants.
889097a140dSpatrick // Canonicalize -1 (saturated result) to true value of the select.
89009467b48Spatrick if (match(FVal, m_AllOnes())) {
89109467b48Spatrick std::swap(TVal, FVal);
892097a140dSpatrick Pred = CmpInst::getInversePredicate(Pred);
89309467b48Spatrick }
89409467b48Spatrick if (!match(TVal, m_AllOnes()))
89509467b48Spatrick return nullptr;
89609467b48Spatrick
897097a140dSpatrick // Canonicalize predicate to less-than or less-or-equal-than.
898097a140dSpatrick if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
89909467b48Spatrick std::swap(Cmp0, Cmp1);
900097a140dSpatrick Pred = CmpInst::getSwappedPredicate(Pred);
90109467b48Spatrick }
902097a140dSpatrick if (Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE)
90309467b48Spatrick return nullptr;
90409467b48Spatrick
90509467b48Spatrick // Match unsigned saturated add of 2 variables with an unnecessary 'not'.
906097a140dSpatrick // Strictness of the comparison is irrelevant.
90709467b48Spatrick Value *Y;
90809467b48Spatrick if (match(Cmp0, m_Not(m_Value(X))) &&
90909467b48Spatrick match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) {
91009467b48Spatrick // (~X u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y)
91109467b48Spatrick // (~X u< Y) ? -1 : (Y + X) --> uadd.sat(X, Y)
91209467b48Spatrick return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y);
91309467b48Spatrick }
91409467b48Spatrick // The 'not' op may be included in the sum but not the compare.
915097a140dSpatrick // Strictness of the comparison is irrelevant.
91609467b48Spatrick X = Cmp0;
91709467b48Spatrick Y = Cmp1;
91809467b48Spatrick if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) {
91909467b48Spatrick // (X u< Y) ? -1 : (~X + Y) --> uadd.sat(~X, Y)
92009467b48Spatrick // (X u< Y) ? -1 : (Y + ~X) --> uadd.sat(Y, ~X)
92109467b48Spatrick BinaryOperator *BO = cast<BinaryOperator>(FVal);
92209467b48Spatrick return Builder.CreateBinaryIntrinsic(
92309467b48Spatrick Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1));
92409467b48Spatrick }
92509467b48Spatrick // The overflow may be detected via the add wrapping round.
926097a140dSpatrick // This is only valid for strict comparison!
927097a140dSpatrick if (Pred == ICmpInst::ICMP_ULT &&
928097a140dSpatrick match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) &&
92909467b48Spatrick match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) {
93009467b48Spatrick // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y)
93109467b48Spatrick // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y)
93209467b48Spatrick return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y);
93309467b48Spatrick }
93409467b48Spatrick
93509467b48Spatrick return nullptr;
93609467b48Spatrick }
93709467b48Spatrick
93809467b48Spatrick /// Fold the following code sequence:
93909467b48Spatrick /// \code
94009467b48Spatrick /// int a = ctlz(x & -x);
94109467b48Spatrick // x ? 31 - a : a;
94209467b48Spatrick /// \code
94309467b48Spatrick ///
94409467b48Spatrick /// into:
94509467b48Spatrick /// cttz(x)
foldSelectCtlzToCttz(ICmpInst * ICI,Value * TrueVal,Value * FalseVal,InstCombiner::BuilderTy & Builder)94609467b48Spatrick static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal,
94709467b48Spatrick Value *FalseVal,
94809467b48Spatrick InstCombiner::BuilderTy &Builder) {
94909467b48Spatrick unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits();
95009467b48Spatrick if (!ICI->isEquality() || !match(ICI->getOperand(1), m_Zero()))
95109467b48Spatrick return nullptr;
95209467b48Spatrick
95309467b48Spatrick if (ICI->getPredicate() == ICmpInst::ICMP_NE)
95409467b48Spatrick std::swap(TrueVal, FalseVal);
95509467b48Spatrick
95609467b48Spatrick if (!match(FalseVal,
95709467b48Spatrick m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1))))
95809467b48Spatrick return nullptr;
95909467b48Spatrick
96009467b48Spatrick if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>()))
96109467b48Spatrick return nullptr;
96209467b48Spatrick
96309467b48Spatrick Value *X = ICI->getOperand(0);
96409467b48Spatrick auto *II = cast<IntrinsicInst>(TrueVal);
96509467b48Spatrick if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X)))))
96609467b48Spatrick return nullptr;
96709467b48Spatrick
96809467b48Spatrick Function *F = Intrinsic::getDeclaration(II->getModule(), Intrinsic::cttz,
96909467b48Spatrick II->getType());
97009467b48Spatrick return CallInst::Create(F, {X, II->getArgOperand(1)});
97109467b48Spatrick }
97209467b48Spatrick
97309467b48Spatrick /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single
974*d415bd75Srobert /// call to cttz/ctlz with flag 'is_zero_poison' cleared.
97509467b48Spatrick ///
97609467b48Spatrick /// For example, we can fold the following code sequence:
97709467b48Spatrick /// \code
97809467b48Spatrick /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 true)
97909467b48Spatrick /// %1 = icmp ne i32 %x, 0
98009467b48Spatrick /// %2 = select i1 %1, i32 %0, i32 32
98109467b48Spatrick /// \code
98209467b48Spatrick ///
98309467b48Spatrick /// into:
98409467b48Spatrick /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false)
foldSelectCttzCtlz(ICmpInst * ICI,Value * TrueVal,Value * FalseVal,InstCombiner::BuilderTy & Builder)98509467b48Spatrick static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
98609467b48Spatrick InstCombiner::BuilderTy &Builder) {
98709467b48Spatrick ICmpInst::Predicate Pred = ICI->getPredicate();
98809467b48Spatrick Value *CmpLHS = ICI->getOperand(0);
98909467b48Spatrick Value *CmpRHS = ICI->getOperand(1);
99009467b48Spatrick
991*d415bd75Srobert // Check if the select condition compares a value for equality.
992*d415bd75Srobert if (!ICI->isEquality())
99309467b48Spatrick return nullptr;
99409467b48Spatrick
995097a140dSpatrick Value *SelectArg = FalseVal;
99609467b48Spatrick Value *ValueOnZero = TrueVal;
99709467b48Spatrick if (Pred == ICmpInst::ICMP_NE)
998097a140dSpatrick std::swap(SelectArg, ValueOnZero);
99909467b48Spatrick
100009467b48Spatrick // Skip zero extend/truncate.
1001097a140dSpatrick Value *Count = nullptr;
1002097a140dSpatrick if (!match(SelectArg, m_ZExt(m_Value(Count))) &&
1003097a140dSpatrick !match(SelectArg, m_Trunc(m_Value(Count))))
1004097a140dSpatrick Count = SelectArg;
100509467b48Spatrick
100609467b48Spatrick // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the
100709467b48Spatrick // input to the cttz/ctlz is used as LHS for the compare instruction.
1008*d415bd75Srobert Value *X;
1009*d415bd75Srobert if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Value(X))) &&
1010*d415bd75Srobert !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Value(X))))
1011*d415bd75Srobert return nullptr;
1012*d415bd75Srobert
1013*d415bd75Srobert // (X == 0) ? BitWidth : ctz(X)
1014*d415bd75Srobert // (X == -1) ? BitWidth : ctz(~X)
1015*d415bd75Srobert if ((X != CmpLHS || !match(CmpRHS, m_Zero())) &&
1016*d415bd75Srobert (!match(X, m_Not(m_Specific(CmpLHS))) || !match(CmpRHS, m_AllOnes())))
101709467b48Spatrick return nullptr;
101809467b48Spatrick
101909467b48Spatrick IntrinsicInst *II = cast<IntrinsicInst>(Count);
102009467b48Spatrick
102109467b48Spatrick // Check if the value propagated on zero is a constant number equal to the
102209467b48Spatrick // sizeof in bits of 'Count'.
102309467b48Spatrick unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits();
102409467b48Spatrick if (match(ValueOnZero, m_SpecificInt(SizeOfInBits))) {
1025*d415bd75Srobert // Explicitly clear the 'is_zero_poison' flag. It's always valid to go from
1026097a140dSpatrick // true to false on this flag, so we can replace it for all users.
1027097a140dSpatrick II->setArgOperand(1, ConstantInt::getFalse(II->getContext()));
1028097a140dSpatrick return SelectArg;
102909467b48Spatrick }
103009467b48Spatrick
1031097a140dSpatrick // The ValueOnZero is not the bitwidth. But if the cttz/ctlz (and optional
1032097a140dSpatrick // zext/trunc) have one use (ending at the select), the cttz/ctlz result will
1033*d415bd75Srobert // not be used if the input is zero. Relax to 'zero is poison' for that case.
1034097a140dSpatrick if (II->hasOneUse() && SelectArg->hasOneUse() &&
1035097a140dSpatrick !match(II->getArgOperand(1), m_One()))
103609467b48Spatrick II->setArgOperand(1, ConstantInt::getTrue(II->getContext()));
103709467b48Spatrick
103809467b48Spatrick return nullptr;
103909467b48Spatrick }
104009467b48Spatrick
104109467b48Spatrick /// Return true if we find and adjust an icmp+select pattern where the compare
104209467b48Spatrick /// is with a constant that can be incremented or decremented to match the
104309467b48Spatrick /// minimum or maximum idiom.
adjustMinMax(SelectInst & Sel,ICmpInst & Cmp)104409467b48Spatrick static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
104509467b48Spatrick ICmpInst::Predicate Pred = Cmp.getPredicate();
104609467b48Spatrick Value *CmpLHS = Cmp.getOperand(0);
104709467b48Spatrick Value *CmpRHS = Cmp.getOperand(1);
104809467b48Spatrick Value *TrueVal = Sel.getTrueValue();
104909467b48Spatrick Value *FalseVal = Sel.getFalseValue();
105009467b48Spatrick
105109467b48Spatrick // We may move or edit the compare, so make sure the select is the only user.
105209467b48Spatrick const APInt *CmpC;
105309467b48Spatrick if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC)))
105409467b48Spatrick return false;
105509467b48Spatrick
105609467b48Spatrick // These transforms only work for selects of integers or vector selects of
105709467b48Spatrick // integer vectors.
105809467b48Spatrick Type *SelTy = Sel.getType();
105909467b48Spatrick auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType());
106009467b48Spatrick if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy())
106109467b48Spatrick return false;
106209467b48Spatrick
106309467b48Spatrick Constant *AdjustedRHS;
106409467b48Spatrick if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT)
106509467b48Spatrick AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1);
106609467b48Spatrick else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT)
106709467b48Spatrick AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1);
106809467b48Spatrick else
106909467b48Spatrick return false;
107009467b48Spatrick
107109467b48Spatrick // X > C ? X : C+1 --> X < C+1 ? C+1 : X
107209467b48Spatrick // X < C ? X : C-1 --> X > C-1 ? C-1 : X
107309467b48Spatrick if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) ||
107409467b48Spatrick (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) {
107509467b48Spatrick ; // Nothing to do here. Values match without any sign/zero extension.
107609467b48Spatrick }
107709467b48Spatrick // Types do not match. Instead of calculating this with mixed types, promote
107809467b48Spatrick // all to the larger type. This enables scalar evolution to analyze this
107909467b48Spatrick // expression.
108009467b48Spatrick else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) {
108109467b48Spatrick Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy);
108209467b48Spatrick
108309467b48Spatrick // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X
108409467b48Spatrick // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X
108509467b48Spatrick // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X
108609467b48Spatrick // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X
108709467b48Spatrick if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) {
108809467b48Spatrick CmpLHS = TrueVal;
108909467b48Spatrick AdjustedRHS = SextRHS;
109009467b48Spatrick } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) &&
109109467b48Spatrick SextRHS == TrueVal) {
109209467b48Spatrick CmpLHS = FalseVal;
109309467b48Spatrick AdjustedRHS = SextRHS;
109409467b48Spatrick } else if (Cmp.isUnsigned()) {
109509467b48Spatrick Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy);
109609467b48Spatrick // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X
109709467b48Spatrick // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X
109809467b48Spatrick // zext + signed compare cannot be changed:
109909467b48Spatrick // 0xff <s 0x00, but 0x00ff >s 0x0000
110009467b48Spatrick if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) {
110109467b48Spatrick CmpLHS = TrueVal;
110209467b48Spatrick AdjustedRHS = ZextRHS;
110309467b48Spatrick } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) &&
110409467b48Spatrick ZextRHS == TrueVal) {
110509467b48Spatrick CmpLHS = FalseVal;
110609467b48Spatrick AdjustedRHS = ZextRHS;
110709467b48Spatrick } else {
110809467b48Spatrick return false;
110909467b48Spatrick }
111009467b48Spatrick } else {
111109467b48Spatrick return false;
111209467b48Spatrick }
111309467b48Spatrick } else {
111409467b48Spatrick return false;
111509467b48Spatrick }
111609467b48Spatrick
111709467b48Spatrick Pred = ICmpInst::getSwappedPredicate(Pred);
111809467b48Spatrick CmpRHS = AdjustedRHS;
111909467b48Spatrick std::swap(FalseVal, TrueVal);
112009467b48Spatrick Cmp.setPredicate(Pred);
112109467b48Spatrick Cmp.setOperand(0, CmpLHS);
112209467b48Spatrick Cmp.setOperand(1, CmpRHS);
112309467b48Spatrick Sel.setOperand(1, TrueVal);
112409467b48Spatrick Sel.setOperand(2, FalseVal);
112509467b48Spatrick Sel.swapProfMetadata();
112609467b48Spatrick
112709467b48Spatrick // Move the compare instruction right before the select instruction. Otherwise
112809467b48Spatrick // the sext/zext value may be defined after the compare instruction uses it.
112909467b48Spatrick Cmp.moveBefore(&Sel);
113009467b48Spatrick
113109467b48Spatrick return true;
113209467b48Spatrick }
113309467b48Spatrick
canonicalizeSPF(SelectInst & Sel,ICmpInst & Cmp,InstCombinerImpl & IC)1134*d415bd75Srobert static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
113573471bf0Spatrick InstCombinerImpl &IC) {
113609467b48Spatrick Value *LHS, *RHS;
1137*d415bd75Srobert // TODO: What to do with pointer min/max patterns?
1138*d415bd75Srobert if (!Sel.getType()->isIntOrIntVectorTy())
113909467b48Spatrick return nullptr;
114009467b48Spatrick
114109467b48Spatrick SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor;
1142*d415bd75Srobert if (SPF == SelectPatternFlavor::SPF_ABS ||
1143*d415bd75Srobert SPF == SelectPatternFlavor::SPF_NABS) {
1144*d415bd75Srobert if (!Cmp.hasOneUse() && !RHS->hasOneUse())
1145*d415bd75Srobert return nullptr; // TODO: Relax this restriction.
114609467b48Spatrick
114773471bf0Spatrick // Note that NSW flag can only be propagated for normal, non-negated abs!
114873471bf0Spatrick bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS &&
114973471bf0Spatrick match(RHS, m_NSWNeg(m_Specific(LHS)));
115073471bf0Spatrick Constant *IntMinIsPoisonC =
115173471bf0Spatrick ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison);
115273471bf0Spatrick Instruction *Abs =
115373471bf0Spatrick IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC);
115409467b48Spatrick
115573471bf0Spatrick if (SPF == SelectPatternFlavor::SPF_NABS)
115673471bf0Spatrick return BinaryOperator::CreateNeg(Abs); // Always without NSW flag!
115773471bf0Spatrick return IC.replaceInstUsesWith(Sel, Abs);
115809467b48Spatrick }
115909467b48Spatrick
1160*d415bd75Srobert if (SelectPatternResult::isMinOrMax(SPF)) {
1161*d415bd75Srobert Intrinsic::ID IntrinsicID;
1162*d415bd75Srobert switch (SPF) {
1163*d415bd75Srobert case SelectPatternFlavor::SPF_UMIN:
1164*d415bd75Srobert IntrinsicID = Intrinsic::umin;
1165*d415bd75Srobert break;
1166*d415bd75Srobert case SelectPatternFlavor::SPF_UMAX:
1167*d415bd75Srobert IntrinsicID = Intrinsic::umax;
1168*d415bd75Srobert break;
1169*d415bd75Srobert case SelectPatternFlavor::SPF_SMIN:
1170*d415bd75Srobert IntrinsicID = Intrinsic::smin;
1171*d415bd75Srobert break;
1172*d415bd75Srobert case SelectPatternFlavor::SPF_SMAX:
1173*d415bd75Srobert IntrinsicID = Intrinsic::smax;
1174*d415bd75Srobert break;
1175*d415bd75Srobert default:
1176*d415bd75Srobert llvm_unreachable("Unexpected SPF");
1177*d415bd75Srobert }
1178*d415bd75Srobert return IC.replaceInstUsesWith(
1179*d415bd75Srobert Sel, IC.Builder.CreateBinaryIntrinsic(IntrinsicID, LHS, RHS));
1180*d415bd75Srobert }
1181*d415bd75Srobert
1182*d415bd75Srobert return nullptr;
1183*d415bd75Srobert }
1184*d415bd75Srobert
replaceInInstruction(Value * V,Value * Old,Value * New,InstCombiner & IC,unsigned Depth=0)1185*d415bd75Srobert static bool replaceInInstruction(Value *V, Value *Old, Value *New,
1186*d415bd75Srobert InstCombiner &IC, unsigned Depth = 0) {
1187*d415bd75Srobert // Conservatively limit replacement to two instructions upwards.
1188*d415bd75Srobert if (Depth == 2)
1189*d415bd75Srobert return false;
1190*d415bd75Srobert
1191*d415bd75Srobert auto *I = dyn_cast<Instruction>(V);
1192*d415bd75Srobert if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I))
1193*d415bd75Srobert return false;
1194*d415bd75Srobert
1195*d415bd75Srobert bool Changed = false;
1196*d415bd75Srobert for (Use &U : I->operands()) {
1197*d415bd75Srobert if (U == Old) {
1198*d415bd75Srobert IC.replaceUse(U, New);
1199*d415bd75Srobert Changed = true;
1200*d415bd75Srobert } else {
1201*d415bd75Srobert Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1);
1202*d415bd75Srobert }
1203*d415bd75Srobert }
1204*d415bd75Srobert return Changed;
1205*d415bd75Srobert }
1206*d415bd75Srobert
120709467b48Spatrick /// If we have a select with an equality comparison, then we know the value in
120809467b48Spatrick /// one of the arms of the select. See if substituting this value into an arm
120909467b48Spatrick /// and simplifying the result yields the same value as the other arm.
121009467b48Spatrick ///
121109467b48Spatrick /// To make this transform safe, we must drop poison-generating flags
121209467b48Spatrick /// (nsw, etc) if we simplified to a binop because the select may be guarding
121309467b48Spatrick /// that poison from propagating. If the existing binop already had no
121409467b48Spatrick /// poison-generating flags, then this transform can be done by instsimplify.
121509467b48Spatrick ///
121609467b48Spatrick /// Consider:
121709467b48Spatrick /// %cmp = icmp eq i32 %x, 2147483647
121809467b48Spatrick /// %add = add nsw i32 %x, 1
121909467b48Spatrick /// %sel = select i1 %cmp, i32 -2147483648, i32 %add
122009467b48Spatrick ///
122109467b48Spatrick /// We can't replace %sel with %add unless we strip away the flags.
122209467b48Spatrick /// TODO: Wrapping flags could be preserved in some cases with better analysis.
foldSelectValueEquivalence(SelectInst & Sel,ICmpInst & Cmp)122373471bf0Spatrick Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
122473471bf0Spatrick ICmpInst &Cmp) {
1225*d415bd75Srobert if (!Cmp.isEquality())
122609467b48Spatrick return nullptr;
122709467b48Spatrick
122809467b48Spatrick // Canonicalize the pattern to ICMP_EQ by swapping the select operands.
122909467b48Spatrick Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
123073471bf0Spatrick bool Swapped = false;
123173471bf0Spatrick if (Cmp.getPredicate() == ICmpInst::ICMP_NE) {
123209467b48Spatrick std::swap(TrueVal, FalseVal);
123373471bf0Spatrick Swapped = true;
123473471bf0Spatrick }
123573471bf0Spatrick
123673471bf0Spatrick // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
123773471bf0Spatrick // Make sure Y cannot be undef though, as we might pick different values for
123873471bf0Spatrick // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
123973471bf0Spatrick // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
124073471bf0Spatrick // replacement cycle.
124173471bf0Spatrick Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
124273471bf0Spatrick if (TrueVal != CmpLHS &&
124373471bf0Spatrick isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
124473471bf0Spatrick if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
124573471bf0Spatrick /* AllowRefinement */ true))
124673471bf0Spatrick return replaceOperand(Sel, Swapped ? 2 : 1, V);
124773471bf0Spatrick
124873471bf0Spatrick // Even if TrueVal does not simplify, we can directly replace a use of
124973471bf0Spatrick // CmpLHS with CmpRHS, as long as the instruction is not used anywhere
125073471bf0Spatrick // else and is safe to speculatively execute (we may end up executing it
125173471bf0Spatrick // with different operands, which should not cause side-effects or trigger
125273471bf0Spatrick // undefined behavior). Only do this if CmpRHS is a constant, as
125373471bf0Spatrick // profitability is not clear for other cases.
1254*d415bd75Srobert // FIXME: Support vectors.
1255*d415bd75Srobert if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
1256*d415bd75Srobert !Cmp.getType()->isVectorTy())
1257*d415bd75Srobert if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this))
125873471bf0Spatrick return &Sel;
125973471bf0Spatrick }
126073471bf0Spatrick if (TrueVal != CmpRHS &&
126173471bf0Spatrick isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
126273471bf0Spatrick if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
126373471bf0Spatrick /* AllowRefinement */ true))
126473471bf0Spatrick return replaceOperand(Sel, Swapped ? 2 : 1, V);
126509467b48Spatrick
1266097a140dSpatrick auto *FalseInst = dyn_cast<Instruction>(FalseVal);
1267097a140dSpatrick if (!FalseInst)
1268097a140dSpatrick return nullptr;
1269097a140dSpatrick
1270097a140dSpatrick // InstSimplify already performed this fold if it was possible subject to
1271097a140dSpatrick // current poison-generating flags. Try the transform again with
1272097a140dSpatrick // poison-generating flags temporarily dropped.
127373471bf0Spatrick bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false;
1274097a140dSpatrick if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
1275097a140dSpatrick WasNUW = OBO->hasNoUnsignedWrap();
1276097a140dSpatrick WasNSW = OBO->hasNoSignedWrap();
1277097a140dSpatrick FalseInst->setHasNoUnsignedWrap(false);
1278097a140dSpatrick FalseInst->setHasNoSignedWrap(false);
1279097a140dSpatrick }
1280097a140dSpatrick if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
1281097a140dSpatrick WasExact = PEO->isExact();
1282097a140dSpatrick FalseInst->setIsExact(false);
1283097a140dSpatrick }
128473471bf0Spatrick if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) {
128573471bf0Spatrick WasInBounds = GEP->isInBounds();
128673471bf0Spatrick GEP->setIsInBounds(false);
128773471bf0Spatrick }
1288097a140dSpatrick
128909467b48Spatrick // Try each equivalence substitution possibility.
129009467b48Spatrick // We have an 'EQ' comparison, so the select's false value will propagate.
129109467b48Spatrick // Example:
129209467b48Spatrick // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
129373471bf0Spatrick if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
1294097a140dSpatrick /* AllowRefinement */ false) == TrueVal ||
129573471bf0Spatrick simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
1296097a140dSpatrick /* AllowRefinement */ false) == TrueVal) {
129773471bf0Spatrick return replaceInstUsesWith(Sel, FalseVal);
129809467b48Spatrick }
1299097a140dSpatrick
1300097a140dSpatrick // Restore poison-generating flags if the transform did not apply.
1301097a140dSpatrick if (WasNUW)
1302097a140dSpatrick FalseInst->setHasNoUnsignedWrap();
1303097a140dSpatrick if (WasNSW)
1304097a140dSpatrick FalseInst->setHasNoSignedWrap();
1305097a140dSpatrick if (WasExact)
1306097a140dSpatrick FalseInst->setIsExact();
130773471bf0Spatrick if (WasInBounds)
130873471bf0Spatrick cast<GetElementPtrInst>(FalseInst)->setIsInBounds();
1309097a140dSpatrick
131009467b48Spatrick return nullptr;
131109467b48Spatrick }
131209467b48Spatrick
131309467b48Spatrick // See if this is a pattern like:
131409467b48Spatrick // %old_cmp1 = icmp slt i32 %x, C2
131509467b48Spatrick // %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high
131609467b48Spatrick // %old_x_offseted = add i32 %x, C1
131709467b48Spatrick // %old_cmp0 = icmp ult i32 %old_x_offseted, C0
131809467b48Spatrick // %r = select i1 %old_cmp0, i32 %x, i32 %old_replacement
131909467b48Spatrick // This can be rewritten as more canonical pattern:
132009467b48Spatrick // %new_cmp1 = icmp slt i32 %x, -C1
132109467b48Spatrick // %new_cmp2 = icmp sge i32 %x, C0-C1
132209467b48Spatrick // %new_clamped_low = select i1 %new_cmp1, i32 %target_low, i32 %x
132309467b48Spatrick // %r = select i1 %new_cmp2, i32 %target_high, i32 %new_clamped_low
132409467b48Spatrick // Iff -C1 s<= C2 s<= C0-C1
132509467b48Spatrick // Also ULT predicate can also be UGT iff C0 != -1 (+invert result)
132609467b48Spatrick // SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.)
canonicalizeClampLike(SelectInst & Sel0,ICmpInst & Cmp0,InstCombiner::BuilderTy & Builder)1327*d415bd75Srobert static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
132809467b48Spatrick InstCombiner::BuilderTy &Builder) {
132909467b48Spatrick Value *X = Sel0.getTrueValue();
133009467b48Spatrick Value *Sel1 = Sel0.getFalseValue();
133109467b48Spatrick
133209467b48Spatrick // First match the condition of the outermost select.
133309467b48Spatrick // Said condition must be one-use.
133409467b48Spatrick if (!Cmp0.hasOneUse())
133509467b48Spatrick return nullptr;
1336*d415bd75Srobert ICmpInst::Predicate Pred0 = Cmp0.getPredicate();
133709467b48Spatrick Value *Cmp00 = Cmp0.getOperand(0);
133809467b48Spatrick Constant *C0;
133909467b48Spatrick if (!match(Cmp0.getOperand(1),
134009467b48Spatrick m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0))))
134109467b48Spatrick return nullptr;
1342*d415bd75Srobert
1343*d415bd75Srobert if (!isa<SelectInst>(Sel1)) {
1344*d415bd75Srobert Pred0 = ICmpInst::getInversePredicate(Pred0);
1345*d415bd75Srobert std::swap(X, Sel1);
1346*d415bd75Srobert }
1347*d415bd75Srobert
1348*d415bd75Srobert // Canonicalize Cmp0 into ult or uge.
134909467b48Spatrick // FIXME: we shouldn't care about lanes that are 'undef' in the end?
1350*d415bd75Srobert switch (Pred0) {
135109467b48Spatrick case ICmpInst::Predicate::ICMP_ULT:
1352*d415bd75Srobert case ICmpInst::Predicate::ICMP_UGE:
1353*d415bd75Srobert // Although icmp ult %x, 0 is an unusual thing to try and should generally
1354*d415bd75Srobert // have been simplified, it does not verify with undef inputs so ensure we
1355*d415bd75Srobert // are not in a strange state.
1356*d415bd75Srobert if (!match(C0, m_SpecificInt_ICMP(
1357*d415bd75Srobert ICmpInst::Predicate::ICMP_NE,
1358*d415bd75Srobert APInt::getZero(C0->getType()->getScalarSizeInBits()))))
1359*d415bd75Srobert return nullptr;
136009467b48Spatrick break; // Great!
136109467b48Spatrick case ICmpInst::Predicate::ICMP_ULE:
136209467b48Spatrick case ICmpInst::Predicate::ICMP_UGT:
1363*d415bd75Srobert // We want to canonicalize it to 'ult' or 'uge', so we'll need to increment
1364*d415bd75Srobert // C0, which again means it must not have any all-ones elements.
136509467b48Spatrick if (!match(C0,
1366*d415bd75Srobert m_SpecificInt_ICMP(
1367*d415bd75Srobert ICmpInst::Predicate::ICMP_NE,
1368*d415bd75Srobert APInt::getAllOnes(C0->getType()->getScalarSizeInBits()))))
136909467b48Spatrick return nullptr; // Can't do, have all-ones element[s].
1370*d415bd75Srobert Pred0 = ICmpInst::getFlippedStrictnessPredicate(Pred0);
137173471bf0Spatrick C0 = InstCombiner::AddOne(C0);
137209467b48Spatrick break;
137309467b48Spatrick default:
137409467b48Spatrick return nullptr; // Unknown predicate.
137509467b48Spatrick }
137609467b48Spatrick
137709467b48Spatrick // Now that we've canonicalized the ICmp, we know the X we expect;
137809467b48Spatrick // the select in other hand should be one-use.
137909467b48Spatrick if (!Sel1->hasOneUse())
138009467b48Spatrick return nullptr;
138109467b48Spatrick
1382*d415bd75Srobert // If the types do not match, look through any truncs to the underlying
1383*d415bd75Srobert // instruction.
1384*d415bd75Srobert if (Cmp00->getType() != X->getType() && X->hasOneUse())
1385*d415bd75Srobert match(X, m_TruncOrSelf(m_Value(X)));
1386*d415bd75Srobert
138709467b48Spatrick // We now can finish matching the condition of the outermost select:
138809467b48Spatrick // it should either be the X itself, or an addition of some constant to X.
138909467b48Spatrick Constant *C1;
139009467b48Spatrick if (Cmp00 == X)
1391*d415bd75Srobert C1 = ConstantInt::getNullValue(X->getType());
139209467b48Spatrick else if (!match(Cmp00,
139309467b48Spatrick m_Add(m_Specific(X),
139409467b48Spatrick m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1)))))
139509467b48Spatrick return nullptr;
139609467b48Spatrick
139709467b48Spatrick Value *Cmp1;
139809467b48Spatrick ICmpInst::Predicate Pred1;
139909467b48Spatrick Constant *C2;
140009467b48Spatrick Value *ReplacementLow, *ReplacementHigh;
140109467b48Spatrick if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow),
140209467b48Spatrick m_Value(ReplacementHigh))) ||
140309467b48Spatrick !match(Cmp1,
140409467b48Spatrick m_ICmp(Pred1, m_Specific(X),
140509467b48Spatrick m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C2)))))
140609467b48Spatrick return nullptr;
140709467b48Spatrick
140809467b48Spatrick if (!Cmp1->hasOneUse() && (Cmp00 == X || !Cmp00->hasOneUse()))
140909467b48Spatrick return nullptr; // Not enough one-use instructions for the fold.
141009467b48Spatrick // FIXME: this restriction could be relaxed if Cmp1 can be reused as one of
141109467b48Spatrick // two comparisons we'll need to build.
141209467b48Spatrick
141309467b48Spatrick // Canonicalize Cmp1 into the form we expect.
141409467b48Spatrick // FIXME: we shouldn't care about lanes that are 'undef' in the end?
141509467b48Spatrick switch (Pred1) {
141609467b48Spatrick case ICmpInst::Predicate::ICMP_SLT:
141709467b48Spatrick break;
141809467b48Spatrick case ICmpInst::Predicate::ICMP_SLE:
141909467b48Spatrick // We'd have to increment C2 by one, and for that it must not have signed
142009467b48Spatrick // max element, but then it would have been canonicalized to 'slt' before
142109467b48Spatrick // we get here. So we can't do anything useful with 'sle'.
142209467b48Spatrick return nullptr;
142309467b48Spatrick case ICmpInst::Predicate::ICMP_SGT:
142409467b48Spatrick // We want to canonicalize it to 'slt', so we'll need to increment C2,
142509467b48Spatrick // which again means it must not have any signed max elements.
142609467b48Spatrick if (!match(C2,
142709467b48Spatrick m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE,
142809467b48Spatrick APInt::getSignedMaxValue(
142909467b48Spatrick C2->getType()->getScalarSizeInBits()))))
143009467b48Spatrick return nullptr; // Can't do, have signed max element[s].
143173471bf0Spatrick C2 = InstCombiner::AddOne(C2);
1432*d415bd75Srobert [[fallthrough]];
143309467b48Spatrick case ICmpInst::Predicate::ICMP_SGE:
143409467b48Spatrick // Also non-canonical, but here we don't need to change C2,
143509467b48Spatrick // so we don't have any restrictions on C2, so we can just handle it.
1436*d415bd75Srobert Pred1 = ICmpInst::Predicate::ICMP_SLT;
143709467b48Spatrick std::swap(ReplacementLow, ReplacementHigh);
143809467b48Spatrick break;
143909467b48Spatrick default:
144009467b48Spatrick return nullptr; // Unknown predicate.
144109467b48Spatrick }
1442*d415bd75Srobert assert(Pred1 == ICmpInst::Predicate::ICMP_SLT &&
1443*d415bd75Srobert "Unexpected predicate type.");
144409467b48Spatrick
144509467b48Spatrick // The thresholds of this clamp-like pattern.
144609467b48Spatrick auto *ThresholdLowIncl = ConstantExpr::getNeg(C1);
144709467b48Spatrick auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1);
144809467b48Spatrick
1449*d415bd75Srobert assert((Pred0 == ICmpInst::Predicate::ICMP_ULT ||
1450*d415bd75Srobert Pred0 == ICmpInst::Predicate::ICMP_UGE) &&
1451*d415bd75Srobert "Unexpected predicate type.");
1452*d415bd75Srobert if (Pred0 == ICmpInst::Predicate::ICMP_UGE)
1453*d415bd75Srobert std::swap(ThresholdLowIncl, ThresholdHighExcl);
1454*d415bd75Srobert
145509467b48Spatrick // The fold has a precondition 1: C2 s>= ThresholdLow
145609467b48Spatrick auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
145709467b48Spatrick ThresholdLowIncl);
145809467b48Spatrick if (!match(Precond1, m_One()))
145909467b48Spatrick return nullptr;
146009467b48Spatrick // The fold has a precondition 2: C2 s<= ThresholdHigh
146109467b48Spatrick auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
146209467b48Spatrick ThresholdHighExcl);
146309467b48Spatrick if (!match(Precond2, m_One()))
146409467b48Spatrick return nullptr;
146509467b48Spatrick
1466*d415bd75Srobert // If we are matching from a truncated input, we need to sext the
1467*d415bd75Srobert // ReplacementLow and ReplacementHigh values. Only do the transform if they
1468*d415bd75Srobert // are free to extend due to being constants.
1469*d415bd75Srobert if (X->getType() != Sel0.getType()) {
1470*d415bd75Srobert Constant *LowC, *HighC;
1471*d415bd75Srobert if (!match(ReplacementLow, m_ImmConstant(LowC)) ||
1472*d415bd75Srobert !match(ReplacementHigh, m_ImmConstant(HighC)))
1473*d415bd75Srobert return nullptr;
1474*d415bd75Srobert ReplacementLow = ConstantExpr::getSExt(LowC, X->getType());
1475*d415bd75Srobert ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType());
1476*d415bd75Srobert }
1477*d415bd75Srobert
147809467b48Spatrick // All good, finally emit the new pattern.
147909467b48Spatrick Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl);
148009467b48Spatrick Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl);
148109467b48Spatrick Value *MaybeReplacedLow =
148209467b48Spatrick Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X);
148309467b48Spatrick
1484*d415bd75Srobert // Create the final select. If we looked through a truncate above, we will
1485*d415bd75Srobert // need to retruncate the result.
1486*d415bd75Srobert Value *MaybeReplacedHigh = Builder.CreateSelect(
1487*d415bd75Srobert ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow);
1488*d415bd75Srobert return Builder.CreateTrunc(MaybeReplacedHigh, Sel0.getType());
148909467b48Spatrick }
149009467b48Spatrick
149109467b48Spatrick // If we have
149209467b48Spatrick // %cmp = icmp [canonical predicate] i32 %x, C0
149309467b48Spatrick // %r = select i1 %cmp, i32 %y, i32 C1
149409467b48Spatrick // Where C0 != C1 and %x may be different from %y, see if the constant that we
149509467b48Spatrick // will have if we flip the strictness of the predicate (i.e. without changing
149609467b48Spatrick // the result) is identical to the C1 in select. If it matches we can change
149709467b48Spatrick // original comparison to one with swapped predicate, reuse the constant,
149809467b48Spatrick // and swap the hands of select.
149909467b48Spatrick static Instruction *
tryToReuseConstantFromSelectInComparison(SelectInst & Sel,ICmpInst & Cmp,InstCombinerImpl & IC)150009467b48Spatrick tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
150173471bf0Spatrick InstCombinerImpl &IC) {
150209467b48Spatrick ICmpInst::Predicate Pred;
150309467b48Spatrick Value *X;
150409467b48Spatrick Constant *C0;
150509467b48Spatrick if (!match(&Cmp, m_OneUse(m_ICmp(
150609467b48Spatrick Pred, m_Value(X),
150709467b48Spatrick m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0))))))
150809467b48Spatrick return nullptr;
150909467b48Spatrick
151009467b48Spatrick // If comparison predicate is non-relational, we won't be able to do anything.
151109467b48Spatrick if (ICmpInst::isEquality(Pred))
151209467b48Spatrick return nullptr;
151309467b48Spatrick
151409467b48Spatrick // If comparison predicate is non-canonical, then we certainly won't be able
151509467b48Spatrick // to make it canonical; canonicalizeCmpWithConstant() already tried.
151673471bf0Spatrick if (!InstCombiner::isCanonicalPredicate(Pred))
151709467b48Spatrick return nullptr;
151809467b48Spatrick
151909467b48Spatrick // If the [input] type of comparison and select type are different, lets abort
152009467b48Spatrick // for now. We could try to compare constants with trunc/[zs]ext though.
152109467b48Spatrick if (C0->getType() != Sel.getType())
152209467b48Spatrick return nullptr;
152309467b48Spatrick
1524*d415bd75Srobert // ULT with 'add' of a constant is canonical. See foldICmpAddConstant().
1525*d415bd75Srobert // FIXME: Are there more magic icmp predicate+constant pairs we must avoid?
1526*d415bd75Srobert // Or should we just abandon this transform entirely?
1527*d415bd75Srobert if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant())))
1528*d415bd75Srobert return nullptr;
1529*d415bd75Srobert
153009467b48Spatrick
153109467b48Spatrick Value *SelVal0, *SelVal1; // We do not care which one is from where.
153209467b48Spatrick match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
153309467b48Spatrick // At least one of these values we are selecting between must be a constant
153409467b48Spatrick // else we'll never succeed.
153509467b48Spatrick if (!match(SelVal0, m_AnyIntegralConstant()) &&
153609467b48Spatrick !match(SelVal1, m_AnyIntegralConstant()))
153709467b48Spatrick return nullptr;
153809467b48Spatrick
153909467b48Spatrick // Does this constant C match any of the `select` values?
154009467b48Spatrick auto MatchesSelectValue = [SelVal0, SelVal1](Constant *C) {
154109467b48Spatrick return C->isElementWiseEqual(SelVal0) || C->isElementWiseEqual(SelVal1);
154209467b48Spatrick };
154309467b48Spatrick
154409467b48Spatrick // If C0 *already* matches true/false value of select, we are done.
154509467b48Spatrick if (MatchesSelectValue(C0))
154609467b48Spatrick return nullptr;
154709467b48Spatrick
154809467b48Spatrick // Check the constant we'd have with flipped-strictness predicate.
154973471bf0Spatrick auto FlippedStrictness =
155073471bf0Spatrick InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
155109467b48Spatrick if (!FlippedStrictness)
155209467b48Spatrick return nullptr;
155309467b48Spatrick
155409467b48Spatrick // If said constant doesn't match either, then there is no hope,
155509467b48Spatrick if (!MatchesSelectValue(FlippedStrictness->second))
155609467b48Spatrick return nullptr;
155709467b48Spatrick
155809467b48Spatrick // It matched! Lets insert the new comparison just before select.
1559097a140dSpatrick InstCombiner::BuilderTy::InsertPointGuard Guard(IC.Builder);
1560097a140dSpatrick IC.Builder.SetInsertPoint(&Sel);
156109467b48Spatrick
156209467b48Spatrick Pred = ICmpInst::getSwappedPredicate(Pred); // Yes, swapped.
1563097a140dSpatrick Value *NewCmp = IC.Builder.CreateICmp(Pred, X, FlippedStrictness->second,
156409467b48Spatrick Cmp.getName() + ".inv");
1565097a140dSpatrick IC.replaceOperand(Sel, 0, NewCmp);
156609467b48Spatrick Sel.swapValues();
156709467b48Spatrick Sel.swapProfMetadata();
156809467b48Spatrick
156909467b48Spatrick return &Sel;
157009467b48Spatrick }
157109467b48Spatrick
foldSelectZeroOrOnes(ICmpInst * Cmp,Value * TVal,Value * FVal,InstCombiner::BuilderTy & Builder)1572*d415bd75Srobert static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal,
1573*d415bd75Srobert Value *FVal,
1574*d415bd75Srobert InstCombiner::BuilderTy &Builder) {
1575*d415bd75Srobert if (!Cmp->hasOneUse())
1576*d415bd75Srobert return nullptr;
1577*d415bd75Srobert
1578*d415bd75Srobert const APInt *CmpC;
1579*d415bd75Srobert if (!match(Cmp->getOperand(1), m_APIntAllowUndef(CmpC)))
1580*d415bd75Srobert return nullptr;
1581*d415bd75Srobert
1582*d415bd75Srobert // (X u< 2) ? -X : -1 --> sext (X != 0)
1583*d415bd75Srobert Value *X = Cmp->getOperand(0);
1584*d415bd75Srobert if (Cmp->getPredicate() == ICmpInst::ICMP_ULT && *CmpC == 2 &&
1585*d415bd75Srobert match(TVal, m_Neg(m_Specific(X))) && match(FVal, m_AllOnes()))
1586*d415bd75Srobert return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType());
1587*d415bd75Srobert
1588*d415bd75Srobert // (X u> 1) ? -1 : -X --> sext (X != 0)
1589*d415bd75Srobert if (Cmp->getPredicate() == ICmpInst::ICMP_UGT && *CmpC == 1 &&
1590*d415bd75Srobert match(FVal, m_Neg(m_Specific(X))) && match(TVal, m_AllOnes()))
1591*d415bd75Srobert return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType());
1592*d415bd75Srobert
1593*d415bd75Srobert return nullptr;
1594*d415bd75Srobert }
1595*d415bd75Srobert
foldSelectInstWithICmpConst(SelectInst & SI,ICmpInst * ICI)1596*d415bd75Srobert static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) {
1597*d415bd75Srobert const APInt *CmpC;
1598*d415bd75Srobert Value *V;
1599*d415bd75Srobert CmpInst::Predicate Pred;
1600*d415bd75Srobert if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC))))
1601*d415bd75Srobert return nullptr;
1602*d415bd75Srobert
1603*d415bd75Srobert BinaryOperator *BO;
1604*d415bd75Srobert const APInt *C;
1605*d415bd75Srobert CmpInst::Predicate CPred;
1606*d415bd75Srobert if (match(&SI, m_Select(m_Specific(ICI), m_APInt(C), m_BinOp(BO))))
1607*d415bd75Srobert CPred = ICI->getPredicate();
1608*d415bd75Srobert else if (match(&SI, m_Select(m_Specific(ICI), m_BinOp(BO), m_APInt(C))))
1609*d415bd75Srobert CPred = ICI->getInversePredicate();
1610*d415bd75Srobert else
1611*d415bd75Srobert return nullptr;
1612*d415bd75Srobert
1613*d415bd75Srobert const APInt *BinOpC;
1614*d415bd75Srobert if (!match(BO, m_BinOp(m_Specific(V), m_APInt(BinOpC))))
1615*d415bd75Srobert return nullptr;
1616*d415bd75Srobert
1617*d415bd75Srobert ConstantRange R = ConstantRange::makeExactICmpRegion(CPred, *CmpC)
1618*d415bd75Srobert .binaryOp(BO->getOpcode(), *BinOpC);
1619*d415bd75Srobert if (R == *C) {
1620*d415bd75Srobert BO->dropPoisonGeneratingFlags();
1621*d415bd75Srobert return BO;
1622*d415bd75Srobert }
1623*d415bd75Srobert return nullptr;
1624*d415bd75Srobert }
1625*d415bd75Srobert
162609467b48Spatrick /// Visit a SelectInst that has an ICmpInst as its first operand.
foldSelectInstWithICmp(SelectInst & SI,ICmpInst * ICI)162773471bf0Spatrick Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
162809467b48Spatrick ICmpInst *ICI) {
162973471bf0Spatrick if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
163073471bf0Spatrick return NewSel;
163109467b48Spatrick
1632*d415bd75Srobert if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this))
1633*d415bd75Srobert return NewSPF;
163409467b48Spatrick
1635*d415bd75Srobert if (Value *V = foldSelectInstWithICmpConst(SI, ICI))
1636*d415bd75Srobert return replaceInstUsesWith(SI, V);
163709467b48Spatrick
1638*d415bd75Srobert if (Value *V = canonicalizeClampLike(SI, *ICI, Builder))
1639*d415bd75Srobert return replaceInstUsesWith(SI, V);
164009467b48Spatrick
164109467b48Spatrick if (Instruction *NewSel =
1642097a140dSpatrick tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
164309467b48Spatrick return NewSel;
164409467b48Spatrick
164509467b48Spatrick bool Changed = adjustMinMax(SI, *ICI);
164609467b48Spatrick
164709467b48Spatrick if (Value *V = foldSelectICmpAnd(SI, ICI, Builder))
164809467b48Spatrick return replaceInstUsesWith(SI, V);
164909467b48Spatrick
165009467b48Spatrick // NOTE: if we wanted to, this is where to detect integer MIN/MAX
165109467b48Spatrick Value *TrueVal = SI.getTrueValue();
165209467b48Spatrick Value *FalseVal = SI.getFalseValue();
165309467b48Spatrick ICmpInst::Predicate Pred = ICI->getPredicate();
165409467b48Spatrick Value *CmpLHS = ICI->getOperand(0);
165509467b48Spatrick Value *CmpRHS = ICI->getOperand(1);
165609467b48Spatrick if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
165709467b48Spatrick if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
165809467b48Spatrick // Transform (X == C) ? X : Y -> (X == C) ? C : Y
165909467b48Spatrick SI.setOperand(1, CmpRHS);
166009467b48Spatrick Changed = true;
166109467b48Spatrick } else if (CmpLHS == FalseVal && Pred == ICmpInst::ICMP_NE) {
166209467b48Spatrick // Transform (X != C) ? Y : X -> (X != C) ? Y : C
166309467b48Spatrick SI.setOperand(2, CmpRHS);
166409467b48Spatrick Changed = true;
166509467b48Spatrick }
166609467b48Spatrick }
166709467b48Spatrick
1668*d415bd75Srobert // Canonicalize a signbit condition to use zero constant by swapping:
1669*d415bd75Srobert // (CmpLHS > -1) ? TV : FV --> (CmpLHS < 0) ? FV : TV
1670*d415bd75Srobert // To avoid conflicts (infinite loops) with other canonicalizations, this is
1671*d415bd75Srobert // not applied with any constant select arm.
1672*d415bd75Srobert if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes()) &&
1673*d415bd75Srobert !match(TrueVal, m_Constant()) && !match(FalseVal, m_Constant()) &&
1674*d415bd75Srobert ICI->hasOneUse()) {
1675*d415bd75Srobert InstCombiner::BuilderTy::InsertPointGuard Guard(Builder);
1676*d415bd75Srobert Builder.SetInsertPoint(&SI);
1677*d415bd75Srobert Value *IsNeg = Builder.CreateIsNeg(CmpLHS, ICI->getName());
1678*d415bd75Srobert replaceOperand(SI, 0, IsNeg);
1679*d415bd75Srobert SI.swapValues();
1680*d415bd75Srobert SI.swapProfMetadata();
1681*d415bd75Srobert return &SI;
1682*d415bd75Srobert }
1683*d415bd75Srobert
168409467b48Spatrick // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring
168509467b48Spatrick // decomposeBitTestICmp() might help.
168609467b48Spatrick {
168709467b48Spatrick unsigned BitWidth =
168809467b48Spatrick DL.getTypeSizeInBits(TrueVal->getType()->getScalarType());
168909467b48Spatrick APInt MinSignedValue = APInt::getSignedMinValue(BitWidth);
169009467b48Spatrick Value *X;
169109467b48Spatrick const APInt *Y, *C;
169209467b48Spatrick bool TrueWhenUnset;
169309467b48Spatrick bool IsBitTest = false;
169409467b48Spatrick if (ICmpInst::isEquality(Pred) &&
169509467b48Spatrick match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) &&
169609467b48Spatrick match(CmpRHS, m_Zero())) {
169709467b48Spatrick IsBitTest = true;
169809467b48Spatrick TrueWhenUnset = Pred == ICmpInst::ICMP_EQ;
169909467b48Spatrick } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) {
170009467b48Spatrick X = CmpLHS;
170109467b48Spatrick Y = &MinSignedValue;
170209467b48Spatrick IsBitTest = true;
170309467b48Spatrick TrueWhenUnset = false;
170409467b48Spatrick } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) {
170509467b48Spatrick X = CmpLHS;
170609467b48Spatrick Y = &MinSignedValue;
170709467b48Spatrick IsBitTest = true;
170809467b48Spatrick TrueWhenUnset = true;
170909467b48Spatrick }
171009467b48Spatrick if (IsBitTest) {
171109467b48Spatrick Value *V = nullptr;
171209467b48Spatrick // (X & Y) == 0 ? X : X ^ Y --> X & ~Y
171309467b48Spatrick if (TrueWhenUnset && TrueVal == X &&
171409467b48Spatrick match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
171509467b48Spatrick V = Builder.CreateAnd(X, ~(*Y));
171609467b48Spatrick // (X & Y) != 0 ? X ^ Y : X --> X & ~Y
171709467b48Spatrick else if (!TrueWhenUnset && FalseVal == X &&
171809467b48Spatrick match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
171909467b48Spatrick V = Builder.CreateAnd(X, ~(*Y));
172009467b48Spatrick // (X & Y) == 0 ? X ^ Y : X --> X | Y
172109467b48Spatrick else if (TrueWhenUnset && FalseVal == X &&
172209467b48Spatrick match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
172309467b48Spatrick V = Builder.CreateOr(X, *Y);
172409467b48Spatrick // (X & Y) != 0 ? X : X ^ Y --> X | Y
172509467b48Spatrick else if (!TrueWhenUnset && TrueVal == X &&
172609467b48Spatrick match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
172709467b48Spatrick V = Builder.CreateOr(X, *Y);
172809467b48Spatrick
172909467b48Spatrick if (V)
173009467b48Spatrick return replaceInstUsesWith(SI, V);
173109467b48Spatrick }
173209467b48Spatrick }
173309467b48Spatrick
173409467b48Spatrick if (Instruction *V =
173509467b48Spatrick foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder))
173609467b48Spatrick return V;
173709467b48Spatrick
173809467b48Spatrick if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder))
173909467b48Spatrick return V;
174009467b48Spatrick
1741*d415bd75Srobert if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder))
1742*d415bd75Srobert return V;
1743*d415bd75Srobert
174409467b48Spatrick if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))
174509467b48Spatrick return replaceInstUsesWith(SI, V);
174609467b48Spatrick
174709467b48Spatrick if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
174809467b48Spatrick return replaceInstUsesWith(SI, V);
174909467b48Spatrick
175009467b48Spatrick if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder))
175109467b48Spatrick return replaceInstUsesWith(SI, V);
175209467b48Spatrick
175309467b48Spatrick if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder))
175409467b48Spatrick return replaceInstUsesWith(SI, V);
175509467b48Spatrick
175609467b48Spatrick if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder))
175709467b48Spatrick return replaceInstUsesWith(SI, V);
175809467b48Spatrick
175909467b48Spatrick return Changed ? &SI : nullptr;
176009467b48Spatrick }
176109467b48Spatrick
176209467b48Spatrick /// SI is a select whose condition is a PHI node (but the two may be in
176309467b48Spatrick /// different blocks). See if the true/false values (V) are live in all of the
176409467b48Spatrick /// predecessor blocks of the PHI. For example, cases like this can't be mapped:
176509467b48Spatrick ///
176609467b48Spatrick /// X = phi [ C1, BB1], [C2, BB2]
176709467b48Spatrick /// Y = add
176809467b48Spatrick /// Z = select X, Y, 0
176909467b48Spatrick ///
177009467b48Spatrick /// because Y is not live in BB1/BB2.
canSelectOperandBeMappingIntoPredBlock(const Value * V,const SelectInst & SI)177109467b48Spatrick static bool canSelectOperandBeMappingIntoPredBlock(const Value *V,
177209467b48Spatrick const SelectInst &SI) {
177309467b48Spatrick // If the value is a non-instruction value like a constant or argument, it
177409467b48Spatrick // can always be mapped.
177509467b48Spatrick const Instruction *I = dyn_cast<Instruction>(V);
177609467b48Spatrick if (!I) return true;
177709467b48Spatrick
177809467b48Spatrick // If V is a PHI node defined in the same block as the condition PHI, we can
177909467b48Spatrick // map the arguments.
178009467b48Spatrick const PHINode *CondPHI = cast<PHINode>(SI.getCondition());
178109467b48Spatrick
178209467b48Spatrick if (const PHINode *VP = dyn_cast<PHINode>(I))
178309467b48Spatrick if (VP->getParent() == CondPHI->getParent())
178409467b48Spatrick return true;
178509467b48Spatrick
178609467b48Spatrick // Otherwise, if the PHI and select are defined in the same block and if V is
178709467b48Spatrick // defined in a different block, then we can transform it.
178809467b48Spatrick if (SI.getParent() == CondPHI->getParent() &&
178909467b48Spatrick I->getParent() != CondPHI->getParent())
179009467b48Spatrick return true;
179109467b48Spatrick
179209467b48Spatrick // Otherwise we have a 'hard' case and we can't tell without doing more
179309467b48Spatrick // detailed dominator based analysis, punt.
179409467b48Spatrick return false;
179509467b48Spatrick }
179609467b48Spatrick
179709467b48Spatrick /// We have an SPF (e.g. a min or max) of an SPF of the form:
179809467b48Spatrick /// SPF2(SPF1(A, B), C)
foldSPFofSPF(Instruction * Inner,SelectPatternFlavor SPF1,Value * A,Value * B,Instruction & Outer,SelectPatternFlavor SPF2,Value * C)179973471bf0Spatrick Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner,
180073471bf0Spatrick SelectPatternFlavor SPF1, Value *A,
180173471bf0Spatrick Value *B, Instruction &Outer,
180273471bf0Spatrick SelectPatternFlavor SPF2,
180373471bf0Spatrick Value *C) {
180409467b48Spatrick if (Outer.getType() != Inner->getType())
180509467b48Spatrick return nullptr;
180609467b48Spatrick
180709467b48Spatrick if (C == A || C == B) {
180809467b48Spatrick // MAX(MAX(A, B), B) -> MAX(A, B)
180909467b48Spatrick // MIN(MIN(a, b), a) -> MIN(a, b)
181009467b48Spatrick // TODO: This could be done in instsimplify.
181109467b48Spatrick if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1))
181209467b48Spatrick return replaceInstUsesWith(Outer, Inner);
181309467b48Spatrick }
181409467b48Spatrick
181509467b48Spatrick return nullptr;
181609467b48Spatrick }
181709467b48Spatrick
181809467b48Spatrick /// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))).
181909467b48Spatrick /// This is even legal for FP.
foldAddSubSelect(SelectInst & SI,InstCombiner::BuilderTy & Builder)182009467b48Spatrick static Instruction *foldAddSubSelect(SelectInst &SI,
182109467b48Spatrick InstCombiner::BuilderTy &Builder) {
182209467b48Spatrick Value *CondVal = SI.getCondition();
182309467b48Spatrick Value *TrueVal = SI.getTrueValue();
182409467b48Spatrick Value *FalseVal = SI.getFalseValue();
182509467b48Spatrick auto *TI = dyn_cast<Instruction>(TrueVal);
182609467b48Spatrick auto *FI = dyn_cast<Instruction>(FalseVal);
182709467b48Spatrick if (!TI || !FI || !TI->hasOneUse() || !FI->hasOneUse())
182809467b48Spatrick return nullptr;
182909467b48Spatrick
183009467b48Spatrick Instruction *AddOp = nullptr, *SubOp = nullptr;
183109467b48Spatrick if ((TI->getOpcode() == Instruction::Sub &&
183209467b48Spatrick FI->getOpcode() == Instruction::Add) ||
183309467b48Spatrick (TI->getOpcode() == Instruction::FSub &&
183409467b48Spatrick FI->getOpcode() == Instruction::FAdd)) {
183509467b48Spatrick AddOp = FI;
183609467b48Spatrick SubOp = TI;
183709467b48Spatrick } else if ((FI->getOpcode() == Instruction::Sub &&
183809467b48Spatrick TI->getOpcode() == Instruction::Add) ||
183909467b48Spatrick (FI->getOpcode() == Instruction::FSub &&
184009467b48Spatrick TI->getOpcode() == Instruction::FAdd)) {
184109467b48Spatrick AddOp = TI;
184209467b48Spatrick SubOp = FI;
184309467b48Spatrick }
184409467b48Spatrick
184509467b48Spatrick if (AddOp) {
184609467b48Spatrick Value *OtherAddOp = nullptr;
184709467b48Spatrick if (SubOp->getOperand(0) == AddOp->getOperand(0)) {
184809467b48Spatrick OtherAddOp = AddOp->getOperand(1);
184909467b48Spatrick } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) {
185009467b48Spatrick OtherAddOp = AddOp->getOperand(0);
185109467b48Spatrick }
185209467b48Spatrick
185309467b48Spatrick if (OtherAddOp) {
185409467b48Spatrick // So at this point we know we have (Y -> OtherAddOp):
185509467b48Spatrick // select C, (add X, Y), (sub X, Z)
185609467b48Spatrick Value *NegVal; // Compute -Z
185709467b48Spatrick if (SI.getType()->isFPOrFPVectorTy()) {
185809467b48Spatrick NegVal = Builder.CreateFNeg(SubOp->getOperand(1));
185909467b48Spatrick if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) {
186009467b48Spatrick FastMathFlags Flags = AddOp->getFastMathFlags();
186109467b48Spatrick Flags &= SubOp->getFastMathFlags();
186209467b48Spatrick NegInst->setFastMathFlags(Flags);
186309467b48Spatrick }
186409467b48Spatrick } else {
186509467b48Spatrick NegVal = Builder.CreateNeg(SubOp->getOperand(1));
186609467b48Spatrick }
186709467b48Spatrick
186809467b48Spatrick Value *NewTrueOp = OtherAddOp;
186909467b48Spatrick Value *NewFalseOp = NegVal;
187009467b48Spatrick if (AddOp != TI)
187109467b48Spatrick std::swap(NewTrueOp, NewFalseOp);
187209467b48Spatrick Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp,
187309467b48Spatrick SI.getName() + ".p", &SI);
187409467b48Spatrick
187509467b48Spatrick if (SI.getType()->isFPOrFPVectorTy()) {
187609467b48Spatrick Instruction *RI =
187709467b48Spatrick BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel);
187809467b48Spatrick
187909467b48Spatrick FastMathFlags Flags = AddOp->getFastMathFlags();
188009467b48Spatrick Flags &= SubOp->getFastMathFlags();
188109467b48Spatrick RI->setFastMathFlags(Flags);
188209467b48Spatrick return RI;
188309467b48Spatrick } else
188409467b48Spatrick return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel);
188509467b48Spatrick }
188609467b48Spatrick }
188709467b48Spatrick return nullptr;
188809467b48Spatrick }
188909467b48Spatrick
189009467b48Spatrick /// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
189109467b48Spatrick /// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y
189209467b48Spatrick /// Along with a number of patterns similar to:
189309467b48Spatrick /// X + Y overflows ? (X < 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
189409467b48Spatrick /// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
189509467b48Spatrick static Instruction *
foldOverflowingAddSubSelect(SelectInst & SI,InstCombiner::BuilderTy & Builder)189609467b48Spatrick foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
189709467b48Spatrick Value *CondVal = SI.getCondition();
189809467b48Spatrick Value *TrueVal = SI.getTrueValue();
189909467b48Spatrick Value *FalseVal = SI.getFalseValue();
190009467b48Spatrick
190109467b48Spatrick WithOverflowInst *II;
190209467b48Spatrick if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) ||
190309467b48Spatrick !match(FalseVal, m_ExtractValue<0>(m_Specific(II))))
190409467b48Spatrick return nullptr;
190509467b48Spatrick
190609467b48Spatrick Value *X = II->getLHS();
190709467b48Spatrick Value *Y = II->getRHS();
190809467b48Spatrick
190909467b48Spatrick auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) {
191009467b48Spatrick Type *Ty = Limit->getType();
191109467b48Spatrick
191209467b48Spatrick ICmpInst::Predicate Pred;
191309467b48Spatrick Value *TrueVal, *FalseVal, *Op;
191409467b48Spatrick const APInt *C;
191509467b48Spatrick if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)),
191609467b48Spatrick m_Value(TrueVal), m_Value(FalseVal))))
191709467b48Spatrick return false;
191809467b48Spatrick
1919*d415bd75Srobert auto IsZeroOrOne = [](const APInt &C) { return C.isZero() || C.isOne(); };
192009467b48Spatrick auto IsMinMax = [&](Value *Min, Value *Max) {
192109467b48Spatrick APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
192209467b48Spatrick APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits());
192309467b48Spatrick return match(Min, m_SpecificInt(MinVal)) &&
192409467b48Spatrick match(Max, m_SpecificInt(MaxVal));
192509467b48Spatrick };
192609467b48Spatrick
192709467b48Spatrick if (Op != X && Op != Y)
192809467b48Spatrick return false;
192909467b48Spatrick
193009467b48Spatrick if (IsAdd) {
193109467b48Spatrick // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
193209467b48Spatrick // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
193309467b48Spatrick // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
193409467b48Spatrick // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
193509467b48Spatrick if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
193609467b48Spatrick IsMinMax(TrueVal, FalseVal))
193709467b48Spatrick return true;
193809467b48Spatrick // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
193909467b48Spatrick // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
194009467b48Spatrick // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
194109467b48Spatrick // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
194209467b48Spatrick if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
194309467b48Spatrick IsMinMax(FalseVal, TrueVal))
194409467b48Spatrick return true;
194509467b48Spatrick } else {
194609467b48Spatrick // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
194709467b48Spatrick // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
194809467b48Spatrick if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) &&
194909467b48Spatrick IsMinMax(TrueVal, FalseVal))
195009467b48Spatrick return true;
195109467b48Spatrick // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
195209467b48Spatrick // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
195309467b48Spatrick if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) &&
195409467b48Spatrick IsMinMax(FalseVal, TrueVal))
195509467b48Spatrick return true;
195609467b48Spatrick // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
195709467b48Spatrick // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
195809467b48Spatrick if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) &&
195909467b48Spatrick IsMinMax(FalseVal, TrueVal))
196009467b48Spatrick return true;
196109467b48Spatrick // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
196209467b48Spatrick // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
196309467b48Spatrick if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) &&
196409467b48Spatrick IsMinMax(TrueVal, FalseVal))
196509467b48Spatrick return true;
196609467b48Spatrick }
196709467b48Spatrick
196809467b48Spatrick return false;
196909467b48Spatrick };
197009467b48Spatrick
197109467b48Spatrick Intrinsic::ID NewIntrinsicID;
197209467b48Spatrick if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow &&
197309467b48Spatrick match(TrueVal, m_AllOnes()))
197409467b48Spatrick // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
197509467b48Spatrick NewIntrinsicID = Intrinsic::uadd_sat;
197609467b48Spatrick else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow &&
197709467b48Spatrick match(TrueVal, m_Zero()))
197809467b48Spatrick // X - Y overflows ? 0 : X - Y -> usub_sat X, Y
197909467b48Spatrick NewIntrinsicID = Intrinsic::usub_sat;
198009467b48Spatrick else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow &&
198109467b48Spatrick IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true))
198209467b48Spatrick // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
198309467b48Spatrick // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
198409467b48Spatrick // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
198509467b48Spatrick // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
198609467b48Spatrick // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
198709467b48Spatrick // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y
198809467b48Spatrick // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
198909467b48Spatrick // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y
199009467b48Spatrick NewIntrinsicID = Intrinsic::sadd_sat;
199109467b48Spatrick else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow &&
199209467b48Spatrick IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false))
199309467b48Spatrick // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
199409467b48Spatrick // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
199509467b48Spatrick // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
199609467b48Spatrick // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
199709467b48Spatrick // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
199809467b48Spatrick // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y
199909467b48Spatrick // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
200009467b48Spatrick // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y
200109467b48Spatrick NewIntrinsicID = Intrinsic::ssub_sat;
200209467b48Spatrick else
200309467b48Spatrick return nullptr;
200409467b48Spatrick
200509467b48Spatrick Function *F =
200609467b48Spatrick Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType());
200709467b48Spatrick return CallInst::Create(F, {X, Y});
200809467b48Spatrick }
200909467b48Spatrick
foldSelectExtConst(SelectInst & Sel)201073471bf0Spatrick Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
201109467b48Spatrick Constant *C;
201209467b48Spatrick if (!match(Sel.getTrueValue(), m_Constant(C)) &&
201309467b48Spatrick !match(Sel.getFalseValue(), m_Constant(C)))
201409467b48Spatrick return nullptr;
201509467b48Spatrick
201609467b48Spatrick Instruction *ExtInst;
201709467b48Spatrick if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) &&
201809467b48Spatrick !match(Sel.getFalseValue(), m_Instruction(ExtInst)))
201909467b48Spatrick return nullptr;
202009467b48Spatrick
202109467b48Spatrick auto ExtOpcode = ExtInst->getOpcode();
202209467b48Spatrick if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt)
202309467b48Spatrick return nullptr;
202409467b48Spatrick
202509467b48Spatrick // If we are extending from a boolean type or if we can create a select that
202609467b48Spatrick // has the same size operands as its condition, try to narrow the select.
202709467b48Spatrick Value *X = ExtInst->getOperand(0);
202809467b48Spatrick Type *SmallType = X->getType();
202909467b48Spatrick Value *Cond = Sel.getCondition();
203009467b48Spatrick auto *Cmp = dyn_cast<CmpInst>(Cond);
203109467b48Spatrick if (!SmallType->isIntOrIntVectorTy(1) &&
203209467b48Spatrick (!Cmp || Cmp->getOperand(0)->getType() != SmallType))
203309467b48Spatrick return nullptr;
203409467b48Spatrick
203509467b48Spatrick // If the constant is the same after truncation to the smaller type and
203609467b48Spatrick // extension to the original type, we can narrow the select.
203709467b48Spatrick Type *SelType = Sel.getType();
203809467b48Spatrick Constant *TruncC = ConstantExpr::getTrunc(C, SmallType);
203909467b48Spatrick Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType);
2040097a140dSpatrick if (ExtC == C && ExtInst->hasOneUse()) {
204109467b48Spatrick Value *TruncCVal = cast<Value>(TruncC);
204209467b48Spatrick if (ExtInst == Sel.getFalseValue())
204309467b48Spatrick std::swap(X, TruncCVal);
204409467b48Spatrick
204509467b48Spatrick // select Cond, (ext X), C --> ext(select Cond, X, C')
204609467b48Spatrick // select Cond, C, (ext X) --> ext(select Cond, C', X)
204709467b48Spatrick Value *NewSel = Builder.CreateSelect(Cond, X, TruncCVal, "narrow", &Sel);
204809467b48Spatrick return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType);
204909467b48Spatrick }
205009467b48Spatrick
205109467b48Spatrick // If one arm of the select is the extend of the condition, replace that arm
205209467b48Spatrick // with the extension of the appropriate known bool value.
205309467b48Spatrick if (Cond == X) {
205409467b48Spatrick if (ExtInst == Sel.getTrueValue()) {
205509467b48Spatrick // select X, (sext X), C --> select X, -1, C
205609467b48Spatrick // select X, (zext X), C --> select X, 1, C
205709467b48Spatrick Constant *One = ConstantInt::getTrue(SmallType);
205809467b48Spatrick Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType);
205909467b48Spatrick return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel);
206009467b48Spatrick } else {
206109467b48Spatrick // select X, C, (sext X) --> select X, C, 0
206209467b48Spatrick // select X, C, (zext X) --> select X, C, 0
206309467b48Spatrick Constant *Zero = ConstantInt::getNullValue(SelType);
206409467b48Spatrick return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel);
206509467b48Spatrick }
206609467b48Spatrick }
206709467b48Spatrick
206809467b48Spatrick return nullptr;
206909467b48Spatrick }
207009467b48Spatrick
207109467b48Spatrick /// Try to transform a vector select with a constant condition vector into a
207209467b48Spatrick /// shuffle for easier combining with other shuffles and insert/extract.
canonicalizeSelectToShuffle(SelectInst & SI)207309467b48Spatrick static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
207409467b48Spatrick Value *CondVal = SI.getCondition();
207509467b48Spatrick Constant *CondC;
207673471bf0Spatrick auto *CondValTy = dyn_cast<FixedVectorType>(CondVal->getType());
207773471bf0Spatrick if (!CondValTy || !match(CondVal, m_Constant(CondC)))
207809467b48Spatrick return nullptr;
207909467b48Spatrick
208073471bf0Spatrick unsigned NumElts = CondValTy->getNumElements();
2081097a140dSpatrick SmallVector<int, 16> Mask;
208209467b48Spatrick Mask.reserve(NumElts);
208309467b48Spatrick for (unsigned i = 0; i != NumElts; ++i) {
208409467b48Spatrick Constant *Elt = CondC->getAggregateElement(i);
208509467b48Spatrick if (!Elt)
208609467b48Spatrick return nullptr;
208709467b48Spatrick
208809467b48Spatrick if (Elt->isOneValue()) {
208909467b48Spatrick // If the select condition element is true, choose from the 1st vector.
2090097a140dSpatrick Mask.push_back(i);
209109467b48Spatrick } else if (Elt->isNullValue()) {
209209467b48Spatrick // If the select condition element is false, choose from the 2nd vector.
2093097a140dSpatrick Mask.push_back(i + NumElts);
209409467b48Spatrick } else if (isa<UndefValue>(Elt)) {
209509467b48Spatrick // Undef in a select condition (choose one of the operands) does not mean
209609467b48Spatrick // the same thing as undef in a shuffle mask (any value is acceptable), so
209709467b48Spatrick // give up.
209809467b48Spatrick return nullptr;
209909467b48Spatrick } else {
210009467b48Spatrick // Bail out on a constant expression.
210109467b48Spatrick return nullptr;
210209467b48Spatrick }
210309467b48Spatrick }
210409467b48Spatrick
2105097a140dSpatrick return new ShuffleVectorInst(SI.getTrueValue(), SI.getFalseValue(), Mask);
210609467b48Spatrick }
210709467b48Spatrick
210809467b48Spatrick /// If we have a select of vectors with a scalar condition, try to convert that
210909467b48Spatrick /// to a vector select by splatting the condition. A splat may get folded with
211009467b48Spatrick /// other operations in IR and having all operands of a select be vector types
211109467b48Spatrick /// is likely better for vector codegen.
canonicalizeScalarSelectOfVecs(SelectInst & Sel,InstCombinerImpl & IC)211273471bf0Spatrick static Instruction *canonicalizeScalarSelectOfVecs(SelectInst &Sel,
211373471bf0Spatrick InstCombinerImpl &IC) {
2114097a140dSpatrick auto *Ty = dyn_cast<VectorType>(Sel.getType());
2115097a140dSpatrick if (!Ty)
211609467b48Spatrick return nullptr;
211709467b48Spatrick
211809467b48Spatrick // We can replace a single-use extract with constant index.
211909467b48Spatrick Value *Cond = Sel.getCondition();
2120097a140dSpatrick if (!match(Cond, m_OneUse(m_ExtractElt(m_Value(), m_ConstantInt()))))
212109467b48Spatrick return nullptr;
212209467b48Spatrick
212309467b48Spatrick // select (extelt V, Index), T, F --> select (splat V, Index), T, F
212409467b48Spatrick // Splatting the extracted condition reduces code (we could directly create a
212509467b48Spatrick // splat shuffle of the source vector to eliminate the intermediate step).
212673471bf0Spatrick return IC.replaceOperand(
212773471bf0Spatrick Sel, 0, IC.Builder.CreateVectorSplat(Ty->getElementCount(), Cond));
212809467b48Spatrick }
212909467b48Spatrick
213009467b48Spatrick /// Reuse bitcasted operands between a compare and select:
213109467b48Spatrick /// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) -->
213209467b48Spatrick /// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D))
foldSelectCmpBitcasts(SelectInst & Sel,InstCombiner::BuilderTy & Builder)213309467b48Spatrick static Instruction *foldSelectCmpBitcasts(SelectInst &Sel,
213409467b48Spatrick InstCombiner::BuilderTy &Builder) {
213509467b48Spatrick Value *Cond = Sel.getCondition();
213609467b48Spatrick Value *TVal = Sel.getTrueValue();
213709467b48Spatrick Value *FVal = Sel.getFalseValue();
213809467b48Spatrick
213909467b48Spatrick CmpInst::Predicate Pred;
214009467b48Spatrick Value *A, *B;
214109467b48Spatrick if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B))))
214209467b48Spatrick return nullptr;
214309467b48Spatrick
214409467b48Spatrick // The select condition is a compare instruction. If the select's true/false
214509467b48Spatrick // values are already the same as the compare operands, there's nothing to do.
214609467b48Spatrick if (TVal == A || TVal == B || FVal == A || FVal == B)
214709467b48Spatrick return nullptr;
214809467b48Spatrick
214909467b48Spatrick Value *C, *D;
215009467b48Spatrick if (!match(A, m_BitCast(m_Value(C))) || !match(B, m_BitCast(m_Value(D))))
215109467b48Spatrick return nullptr;
215209467b48Spatrick
215309467b48Spatrick // select (cmp (bitcast C), (bitcast D)), (bitcast TSrc), (bitcast FSrc)
215409467b48Spatrick Value *TSrc, *FSrc;
215509467b48Spatrick if (!match(TVal, m_BitCast(m_Value(TSrc))) ||
215609467b48Spatrick !match(FVal, m_BitCast(m_Value(FSrc))))
215709467b48Spatrick return nullptr;
215809467b48Spatrick
215909467b48Spatrick // If the select true/false values are *different bitcasts* of the same source
216009467b48Spatrick // operands, make the select operands the same as the compare operands and
216109467b48Spatrick // cast the result. This is the canonical select form for min/max.
216209467b48Spatrick Value *NewSel;
216309467b48Spatrick if (TSrc == C && FSrc == D) {
216409467b48Spatrick // select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) -->
216509467b48Spatrick // bitcast (select (cmp A, B), A, B)
216609467b48Spatrick NewSel = Builder.CreateSelect(Cond, A, B, "", &Sel);
216709467b48Spatrick } else if (TSrc == D && FSrc == C) {
216809467b48Spatrick // select (cmp (bitcast C), (bitcast D)), (bitcast' D), (bitcast' C) -->
216909467b48Spatrick // bitcast (select (cmp A, B), B, A)
217009467b48Spatrick NewSel = Builder.CreateSelect(Cond, B, A, "", &Sel);
217109467b48Spatrick } else {
217209467b48Spatrick return nullptr;
217309467b48Spatrick }
217409467b48Spatrick return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType());
217509467b48Spatrick }
217609467b48Spatrick
217709467b48Spatrick /// Try to eliminate select instructions that test the returned flag of cmpxchg
217809467b48Spatrick /// instructions.
217909467b48Spatrick ///
218009467b48Spatrick /// If a select instruction tests the returned flag of a cmpxchg instruction and
218109467b48Spatrick /// selects between the returned value of the cmpxchg instruction its compare
218209467b48Spatrick /// operand, the result of the select will always be equal to its false value.
218309467b48Spatrick /// For example:
218409467b48Spatrick ///
218509467b48Spatrick /// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
218609467b48Spatrick /// %1 = extractvalue { i64, i1 } %0, 1
218709467b48Spatrick /// %2 = extractvalue { i64, i1 } %0, 0
218809467b48Spatrick /// %3 = select i1 %1, i64 %compare, i64 %2
218909467b48Spatrick /// ret i64 %3
219009467b48Spatrick ///
219109467b48Spatrick /// The returned value of the cmpxchg instruction (%2) is the original value
219209467b48Spatrick /// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2
219309467b48Spatrick /// must have been equal to %compare. Thus, the result of the select is always
219409467b48Spatrick /// equal to %2, and the code can be simplified to:
219509467b48Spatrick ///
219609467b48Spatrick /// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
219709467b48Spatrick /// %1 = extractvalue { i64, i1 } %0, 0
219809467b48Spatrick /// ret i64 %1
219909467b48Spatrick ///
foldSelectCmpXchg(SelectInst & SI)2200097a140dSpatrick static Value *foldSelectCmpXchg(SelectInst &SI) {
220109467b48Spatrick // A helper that determines if V is an extractvalue instruction whose
220209467b48Spatrick // aggregate operand is a cmpxchg instruction and whose single index is equal
220309467b48Spatrick // to I. If such conditions are true, the helper returns the cmpxchg
220409467b48Spatrick // instruction; otherwise, a nullptr is returned.
220509467b48Spatrick auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * {
220609467b48Spatrick auto *Extract = dyn_cast<ExtractValueInst>(V);
220709467b48Spatrick if (!Extract)
220809467b48Spatrick return nullptr;
220909467b48Spatrick if (Extract->getIndices()[0] != I)
221009467b48Spatrick return nullptr;
221109467b48Spatrick return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand());
221209467b48Spatrick };
221309467b48Spatrick
221409467b48Spatrick // If the select has a single user, and this user is a select instruction that
221509467b48Spatrick // we can simplify, skip the cmpxchg simplification for now.
221609467b48Spatrick if (SI.hasOneUse())
221709467b48Spatrick if (auto *Select = dyn_cast<SelectInst>(SI.user_back()))
221809467b48Spatrick if (Select->getCondition() == SI.getCondition())
221909467b48Spatrick if (Select->getFalseValue() == SI.getTrueValue() ||
222009467b48Spatrick Select->getTrueValue() == SI.getFalseValue())
222109467b48Spatrick return nullptr;
222209467b48Spatrick
222309467b48Spatrick // Ensure the select condition is the returned flag of a cmpxchg instruction.
222409467b48Spatrick auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1);
222509467b48Spatrick if (!CmpXchg)
222609467b48Spatrick return nullptr;
222709467b48Spatrick
222809467b48Spatrick // Check the true value case: The true value of the select is the returned
222909467b48Spatrick // value of the same cmpxchg used by the condition, and the false value is the
223009467b48Spatrick // cmpxchg instruction's compare operand.
223109467b48Spatrick if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0))
2232097a140dSpatrick if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue())
2233097a140dSpatrick return SI.getFalseValue();
223409467b48Spatrick
223509467b48Spatrick // Check the false value case: The false value of the select is the returned
223609467b48Spatrick // value of the same cmpxchg used by the condition, and the true value is the
223709467b48Spatrick // cmpxchg instruction's compare operand.
223809467b48Spatrick if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0))
2239097a140dSpatrick if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue())
2240097a140dSpatrick return SI.getFalseValue();
224109467b48Spatrick
224209467b48Spatrick return nullptr;
224309467b48Spatrick }
224409467b48Spatrick
/// Try to reduce a funnel/rotate pattern that includes a compare and select
/// into a funnel shift intrinsic. Example:
///     rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b)))
///                  --> call llvm.fshl.i32(a, a, b)
///     fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c)))
///                     --> call llvm.fshl.i32(a, b, c)
///     fshr32(a, b, c) --> (c == 0 ? b : ((a >> (32 - c)) | (b << c)))
///                     --> call llvm.fshr.i32(a, b, c)
///
/// Requirements checked below:
///  * The select's false arm is a one-use `or` of a one-use `shl` and a
///    one-use `lshr`. Each shift amount may sit behind a zext.
///  * One shift amount is S and the other is a one-use (Width - S).
///  * The true arm is the value the funnel shift returns for a zero amount
///    (the shl operand for fshl, the lshr operand for fshr).
///  * The condition is a one-use `icmp eq S, 0`.
///
/// \param Sel     The select being visited.
/// \param Builder Used to emit any needed freeze and the zext of the shift
///                amount.
/// \returns A new, not-yet-inserted call to llvm.fshl / llvm.fshr, or nullptr
///          if the pattern does not match.
static Instruction *foldSelectFunnelShift(SelectInst &Sel,
                                          InstCombiner::BuilderTy &Builder) {
  // This must be a power-of-2 type for a bitmasking transform to be valid.
  unsigned Width = Sel.getType()->getScalarSizeInBits();
  if (!isPowerOf2_32(Width))
    return nullptr;

  // The false arm must be a one-use 'or' of two binary operators.
  BinaryOperator *Or0, *Or1;
  if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
    return nullptr;

  // Both 'or' operands must be one-use logical shifts whose amounts may be
  // zero-extended. They must be different shifts (one shl, one lshr).
  Value *SV0, *SV1, *SA0, *SA1;
  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0),
                                          m_ZExtOrSelf(m_Value(SA0))))) ||
      !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1),
                                          m_ZExtOrSelf(m_Value(SA1))))) ||
      Or0->getOpcode() == Or1->getOpcode())
    return nullptr;

  // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)).
  if (Or0->getOpcode() == BinaryOperator::LShr) {
    std::swap(Or0, Or1);
    std::swap(SV0, SV1);
    std::swap(SA0, SA1);
  }
  assert(Or0->getOpcode() == BinaryOperator::Shl &&
         Or1->getOpcode() == BinaryOperator::LShr &&
         "Illegal or(shift,shift) pair");

  // Check the shift amounts to see if they are an opposite pair.
  // ShAmt is the amount from which the other one is computed as
  // (Width - ShAmt). Whether it belongs to the shl or the lshr decides
  // fshl vs. fshr.
  Value *ShAmt;
  if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0)))))
    ShAmt = SA0;
  else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1)))))
    ShAmt = SA1;
  else
    return nullptr;

  // We should now have this pattern:
  // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1))
  // The false value of the select must be a funnel-shift of the true value:
  // IsFShl -> TVal must be SV0 else TVal must be SV1.
  bool IsFshl = (ShAmt == SA0);
  Value *TVal = Sel.getTrueValue();
  if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1))
    return nullptr;

  // Finally, see if the select is filtering out a shift-by-zero.
  Value *Cond = Sel.getCondition();
  ICmpInst::Predicate Pred;
  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) ||
      Pred != ICmpInst::ICMP_EQ)
    return nullptr;

  // If this is not a rotate then the select was blocking poison from the
  // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
  if (SV0 != SV1) {
    if (IsFshl && !llvm::isGuaranteedNotToBePoison(SV1))
      SV1 = Builder.CreateFreeze(SV1);
    else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(SV0))
      SV0 = Builder.CreateFreeze(SV0);
  }

  // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way.
  // Convert to funnel shift intrinsic.
  Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
  Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType());
  // The matched amount may be narrower than the select type (m_ZExtOrSelf
  // above). Widen it to the intrinsic's operand type.
  ShAmt = Builder.CreateZExt(ShAmt, Sel.getType());
  return CallInst::Create(F, { SV0, SV1, ShAmt });
}
232309467b48Spatrick
foldSelectToCopysign(SelectInst & Sel,InstCombiner::BuilderTy & Builder)2324097a140dSpatrick static Instruction *foldSelectToCopysign(SelectInst &Sel,
2325097a140dSpatrick InstCombiner::BuilderTy &Builder) {
2326097a140dSpatrick Value *Cond = Sel.getCondition();
2327097a140dSpatrick Value *TVal = Sel.getTrueValue();
2328097a140dSpatrick Value *FVal = Sel.getFalseValue();
2329097a140dSpatrick Type *SelType = Sel.getType();
2330097a140dSpatrick
2331097a140dSpatrick // Match select ?, TC, FC where the constants are equal but negated.
2332097a140dSpatrick // TODO: Generalize to handle a negated variable operand?
2333097a140dSpatrick const APFloat *TC, *FC;
2334*d415bd75Srobert if (!match(TVal, m_APFloatAllowUndef(TC)) ||
2335*d415bd75Srobert !match(FVal, m_APFloatAllowUndef(FC)) ||
2336097a140dSpatrick !abs(*TC).bitwiseIsEqual(abs(*FC)))
2337097a140dSpatrick return nullptr;
2338097a140dSpatrick
2339097a140dSpatrick assert(TC != FC && "Expected equal select arms to simplify");
2340097a140dSpatrick
2341097a140dSpatrick Value *X;
2342097a140dSpatrick const APInt *C;
2343097a140dSpatrick bool IsTrueIfSignSet;
2344097a140dSpatrick ICmpInst::Predicate Pred;
2345097a140dSpatrick if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
234673471bf0Spatrick !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
234773471bf0Spatrick X->getType() != SelType)
2348097a140dSpatrick return nullptr;
2349097a140dSpatrick
2350097a140dSpatrick // If needed, negate the value that will be the sign argument of the copysign:
2351097a140dSpatrick // (bitcast X) < 0 ? -TC : TC --> copysign(TC, X)
2352097a140dSpatrick // (bitcast X) < 0 ? TC : -TC --> copysign(TC, -X)
2353097a140dSpatrick // (bitcast X) >= 0 ? -TC : TC --> copysign(TC, -X)
2354097a140dSpatrick // (bitcast X) >= 0 ? TC : -TC --> copysign(TC, X)
2355*d415bd75Srobert // Note: FMF from the select can not be propagated to the new instructions.
2356097a140dSpatrick if (IsTrueIfSignSet ^ TC->isNegative())
2357*d415bd75Srobert X = Builder.CreateFNeg(X);
2358097a140dSpatrick
2359097a140dSpatrick // Canonicalize the magnitude argument as the positive constant since we do
2360097a140dSpatrick // not care about its sign.
2361*d415bd75Srobert Value *MagArg = ConstantFP::get(SelType, abs(*TC));
2362097a140dSpatrick Function *F = Intrinsic::getDeclaration(Sel.getModule(), Intrinsic::copysign,
2363097a140dSpatrick Sel.getType());
2364*d415bd75Srobert return CallInst::Create(F, { MagArg, X });
2365097a140dSpatrick }
2366097a140dSpatrick
/// Vector-only select folds:
///  * Hoist a vector reverse of the condition (and of reversed or splat arms)
///    out of the select: select rev(C), rev(X), rev(Y) --> rev(select C, X, Y).
///  * For fixed-width vectors, run demanded-elements simplification with
///    every lane demanded.
///  * Reassociate a select whose arm is a one-use "select shuffle" sharing an
///    operand with the other arm.
/// \returns the replacement instruction, &Sel when Sel itself was updated,
///          or nullptr if nothing applied.
Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
  if (!isa<VectorType>(Sel.getType()))
    return nullptr;

  Value *Cond = Sel.getCondition();
  Value *TVal = Sel.getTrueValue();
  Value *FVal = Sel.getFalseValue();
  Value *C, *X, *Y;

  if (match(Cond, m_VecReverse(m_Value(C)))) {
    // Build select(C, X, Y) next to Sel (taking its name and IR flags), then
    // return a new, not-yet-inserted vector reverse of it.
    auto createSelReverse = [&](Value *C, Value *X, Value *Y) {
      Value *V = Builder.CreateSelect(C, X, Y, Sel.getName(), &Sel);
      if (auto *I = dyn_cast<Instruction>(V))
        I->copyIRFlags(&Sel);
      Module *M = Sel.getModule();
      Function *F = Intrinsic::getDeclaration(
          M, Intrinsic::experimental_vector_reverse, V->getType());
      return CallInst::Create(F, V);
    };

    if (match(TVal, m_VecReverse(m_Value(X)))) {
      // select rev(C), rev(X), rev(Y) --> rev(select C, X, Y)
      if (match(FVal, m_VecReverse(m_Value(Y))) &&
          (Cond->hasOneUse() || TVal->hasOneUse() || FVal->hasOneUse()))
        return createSelReverse(C, X, Y);

      // select rev(C), rev(X), FValSplat --> rev(select C, X, FValSplat)
      // (reversing a splat leaves it unchanged, so no reverse is needed on it).
      if ((Cond->hasOneUse() || TVal->hasOneUse()) && isSplatValue(FVal))
        return createSelReverse(C, X, FVal);
    }
    // select rev(C), TValSplat, rev(Y) --> rev(select C, TValSplat, Y)
    else if (isSplatValue(TVal) && match(FVal, m_VecReverse(m_Value(Y))) &&
             (Cond->hasOneUse() || FVal->hasOneUse()))
      return createSelReverse(C, TVal, Y);
  }

  // The remaining folds need a known lane count: scalable vectors stop here.
  auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType());
  if (!VecTy)
    return nullptr;

  // Demand every lane. A result other than Sel replaces it. A result equal to
  // Sel means Sel was simplified in place (presumably operands were
  // rewritten; see SimplifyDemandedVectorElts).
  unsigned NumElts = VecTy->getNumElements();
  APInt UndefElts(NumElts, 0);
  APInt AllOnesEltMask(APInt::getAllOnes(NumElts));
  if (Value *V = SimplifyDemandedVectorElts(&Sel, AllOnesEltMask, UndefElts)) {
    if (V != &Sel)
      return replaceInstUsesWith(Sel, V);
    return &Sel;
  }

  // A select of a "select shuffle" with a common operand can be rearranged
  // to select followed by "select shuffle". Because of poison, this only works
  // in the case of a shuffle with no undefined mask elements.
  ArrayRef<int> Mask;
  if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
      !is_contained(Mask, UndefMaskElem) &&
      cast<ShuffleVectorInst>(TVal)->isSelect()) {
    if (X == FVal) {
      // select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X)
      Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel);
      return new ShuffleVectorInst(X, NewSel, Mask);
    }
    if (Y == FVal) {
      // select Cond, (shuf_sel X, Y), Y --> shuf_sel (select Cond, X, Y), Y
      Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel);
      return new ShuffleVectorInst(NewSel, Y, Mask);
    }
  }
  if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
      !is_contained(Mask, UndefMaskElem) &&
      cast<ShuffleVectorInst>(FVal)->isSelect()) {
    if (X == TVal) {
      // select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y)
      Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel);
      return new ShuffleVectorInst(X, NewSel, Mask);
    }
    if (Y == TVal) {
      // select Cond, Y, (shuf_sel X, Y) --> shuf_sel (select Cond, Y, X), Y
      Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel);
      return new ShuffleVectorInst(NewSel, Y, Mask);
    }
  }

  return nullptr;
}
2451097a140dSpatrick
foldSelectToPhiImpl(SelectInst & Sel,BasicBlock * BB,const DominatorTree & DT,InstCombiner::BuilderTy & Builder)2452097a140dSpatrick static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB,
2453097a140dSpatrick const DominatorTree &DT,
2454097a140dSpatrick InstCombiner::BuilderTy &Builder) {
2455097a140dSpatrick // Find the block's immediate dominator that ends with a conditional branch
2456097a140dSpatrick // that matches select's condition (maybe inverted).
2457097a140dSpatrick auto *IDomNode = DT[BB]->getIDom();
2458097a140dSpatrick if (!IDomNode)
2459097a140dSpatrick return nullptr;
2460097a140dSpatrick BasicBlock *IDom = IDomNode->getBlock();
2461097a140dSpatrick
2462097a140dSpatrick Value *Cond = Sel.getCondition();
2463097a140dSpatrick Value *IfTrue, *IfFalse;
2464097a140dSpatrick BasicBlock *TrueSucc, *FalseSucc;
2465097a140dSpatrick if (match(IDom->getTerminator(),
2466097a140dSpatrick m_Br(m_Specific(Cond), m_BasicBlock(TrueSucc),
2467097a140dSpatrick m_BasicBlock(FalseSucc)))) {
2468097a140dSpatrick IfTrue = Sel.getTrueValue();
2469097a140dSpatrick IfFalse = Sel.getFalseValue();
2470097a140dSpatrick } else if (match(IDom->getTerminator(),
2471097a140dSpatrick m_Br(m_Not(m_Specific(Cond)), m_BasicBlock(TrueSucc),
2472097a140dSpatrick m_BasicBlock(FalseSucc)))) {
2473097a140dSpatrick IfTrue = Sel.getFalseValue();
2474097a140dSpatrick IfFalse = Sel.getTrueValue();
2475097a140dSpatrick } else
2476097a140dSpatrick return nullptr;
2477097a140dSpatrick
2478097a140dSpatrick // Make sure the branches are actually different.
2479097a140dSpatrick if (TrueSucc == FalseSucc)
2480097a140dSpatrick return nullptr;
2481097a140dSpatrick
2482097a140dSpatrick // We want to replace select %cond, %a, %b with a phi that takes value %a
2483097a140dSpatrick // for all incoming edges that are dominated by condition `%cond == true`,
2484097a140dSpatrick // and value %b for edges dominated by condition `%cond == false`. If %a
2485097a140dSpatrick // or %b are also phis from the same basic block, we can go further and take
2486097a140dSpatrick // their incoming values from the corresponding blocks.
2487097a140dSpatrick BasicBlockEdge TrueEdge(IDom, TrueSucc);
2488097a140dSpatrick BasicBlockEdge FalseEdge(IDom, FalseSucc);
2489097a140dSpatrick DenseMap<BasicBlock *, Value *> Inputs;
2490097a140dSpatrick for (auto *Pred : predecessors(BB)) {
2491097a140dSpatrick // Check implication.
2492097a140dSpatrick BasicBlockEdge Incoming(Pred, BB);
2493097a140dSpatrick if (DT.dominates(TrueEdge, Incoming))
2494097a140dSpatrick Inputs[Pred] = IfTrue->DoPHITranslation(BB, Pred);
2495097a140dSpatrick else if (DT.dominates(FalseEdge, Incoming))
2496097a140dSpatrick Inputs[Pred] = IfFalse->DoPHITranslation(BB, Pred);
2497097a140dSpatrick else
2498097a140dSpatrick return nullptr;
2499097a140dSpatrick // Check availability.
2500097a140dSpatrick if (auto *Insn = dyn_cast<Instruction>(Inputs[Pred]))
2501097a140dSpatrick if (!DT.dominates(Insn, Pred->getTerminator()))
2502097a140dSpatrick return nullptr;
2503097a140dSpatrick }
2504097a140dSpatrick
2505097a140dSpatrick Builder.SetInsertPoint(&*BB->begin());
2506097a140dSpatrick auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size());
2507097a140dSpatrick for (auto *Pred : predecessors(BB))
2508097a140dSpatrick PN->addIncoming(Inputs[Pred], Pred);
2509097a140dSpatrick PN->takeName(&Sel);
2510097a140dSpatrick return PN;
2511097a140dSpatrick }
2512097a140dSpatrick
foldSelectToPhi(SelectInst & Sel,const DominatorTree & DT,InstCombiner::BuilderTy & Builder)2513097a140dSpatrick static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
2514097a140dSpatrick InstCombiner::BuilderTy &Builder) {
2515097a140dSpatrick // Try to replace this select with Phi in one of these blocks.
2516097a140dSpatrick SmallSetVector<BasicBlock *, 4> CandidateBlocks;
2517097a140dSpatrick CandidateBlocks.insert(Sel.getParent());
2518097a140dSpatrick for (Value *V : Sel.operands())
2519097a140dSpatrick if (auto *I = dyn_cast<Instruction>(V))
2520097a140dSpatrick CandidateBlocks.insert(I->getParent());
2521097a140dSpatrick
2522097a140dSpatrick for (BasicBlock *BB : CandidateBlocks)
2523097a140dSpatrick if (auto *PN = foldSelectToPhiImpl(Sel, BB, DT, Builder))
2524097a140dSpatrick return PN;
2525097a140dSpatrick return nullptr;
2526097a140dSpatrick }
2527097a140dSpatrick
foldSelectWithFrozenICmp(SelectInst & Sel,InstCombiner::BuilderTy & Builder)252873471bf0Spatrick static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
252973471bf0Spatrick FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
253073471bf0Spatrick if (!FI)
253173471bf0Spatrick return nullptr;
253273471bf0Spatrick
253373471bf0Spatrick Value *Cond = FI->getOperand(0);
253473471bf0Spatrick Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
253573471bf0Spatrick
253673471bf0Spatrick // select (freeze(x == y)), x, y --> y
253773471bf0Spatrick // select (freeze(x != y)), x, y --> x
253873471bf0Spatrick // The freeze should be only used by this select. Otherwise, remaining uses of
253973471bf0Spatrick // the freeze can observe a contradictory value.
254073471bf0Spatrick // c = freeze(x == y) ; Let's assume that y = poison & x = 42; c is 0 or 1
254173471bf0Spatrick // a = select c, x, y ;
254273471bf0Spatrick // f(a, c) ; f(poison, 1) cannot happen, but if a is folded
254373471bf0Spatrick // ; to y, this can happen.
254473471bf0Spatrick CmpInst::Predicate Pred;
254573471bf0Spatrick if (FI->hasOneUse() &&
254673471bf0Spatrick match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) &&
254773471bf0Spatrick (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) {
254873471bf0Spatrick return Pred == ICmpInst::ICMP_EQ ? FalseVal : TrueVal;
254973471bf0Spatrick }
255073471bf0Spatrick
255173471bf0Spatrick return nullptr;
255273471bf0Spatrick }
255373471bf0Spatrick
foldAndOrOfSelectUsingImpliedCond(Value * Op,SelectInst & SI,bool IsAnd)255473471bf0Spatrick Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
255573471bf0Spatrick SelectInst &SI,
255673471bf0Spatrick bool IsAnd) {
255773471bf0Spatrick Value *CondVal = SI.getCondition();
255873471bf0Spatrick Value *A = SI.getTrueValue();
255973471bf0Spatrick Value *B = SI.getFalseValue();
256073471bf0Spatrick
256173471bf0Spatrick assert(Op->getType()->isIntOrIntVectorTy(1) &&
256273471bf0Spatrick "Op must be either i1 or vector of i1.");
256373471bf0Spatrick
2564*d415bd75Srobert std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd);
256573471bf0Spatrick if (!Res)
256673471bf0Spatrick return nullptr;
256773471bf0Spatrick
256873471bf0Spatrick Value *Zero = Constant::getNullValue(A->getType());
256973471bf0Spatrick Value *One = Constant::getAllOnesValue(A->getType());
257073471bf0Spatrick
257173471bf0Spatrick if (*Res == true) {
257273471bf0Spatrick if (IsAnd)
257373471bf0Spatrick // select op, (select cond, A, B), false => select op, A, false
257473471bf0Spatrick // and op, (select cond, A, B) => select op, A, false
257573471bf0Spatrick // if op = true implies condval = true.
257673471bf0Spatrick return SelectInst::Create(Op, A, Zero);
257773471bf0Spatrick else
257873471bf0Spatrick // select op, true, (select cond, A, B) => select op, true, A
257973471bf0Spatrick // or op, (select cond, A, B) => select op, true, A
258073471bf0Spatrick // if op = false implies condval = true.
258173471bf0Spatrick return SelectInst::Create(Op, One, A);
258273471bf0Spatrick } else {
258373471bf0Spatrick if (IsAnd)
258473471bf0Spatrick // select op, (select cond, A, B), false => select op, B, false
258573471bf0Spatrick // and op, (select cond, A, B) => select op, B, false
258673471bf0Spatrick // if op = true implies condval = false.
258773471bf0Spatrick return SelectInst::Create(Op, B, Zero);
258873471bf0Spatrick else
258973471bf0Spatrick // select op, true, (select cond, A, B) => select op, true, B
259073471bf0Spatrick // or op, (select cond, A, B) => select op, true, B
259173471bf0Spatrick // if op = false implies condval = false.
259273471bf0Spatrick return SelectInst::Create(Op, One, B);
259373471bf0Spatrick }
259473471bf0Spatrick }
259573471bf0Spatrick
2596*d415bd75Srobert // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
2597*d415bd75Srobert // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work.
foldSelectWithFCmpToFabs(SelectInst & SI,InstCombinerImpl & IC)2598*d415bd75Srobert static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
2599*d415bd75Srobert InstCombinerImpl &IC) {
2600*d415bd75Srobert Value *CondVal = SI.getCondition();
2601*d415bd75Srobert
2602*d415bd75Srobert bool ChangedFMF = false;
2603*d415bd75Srobert for (bool Swap : {false, true}) {
2604*d415bd75Srobert Value *TrueVal = SI.getTrueValue();
2605*d415bd75Srobert Value *X = SI.getFalseValue();
2606*d415bd75Srobert CmpInst::Predicate Pred;
2607*d415bd75Srobert
2608*d415bd75Srobert if (Swap)
2609*d415bd75Srobert std::swap(TrueVal, X);
2610*d415bd75Srobert
2611*d415bd75Srobert if (!match(CondVal, m_FCmp(Pred, m_Specific(X), m_AnyZeroFP())))
2612*d415bd75Srobert continue;
2613*d415bd75Srobert
2614*d415bd75Srobert // fold (X <= +/-0.0) ? (0.0 - X) : X to fabs(X), when 'Swap' is false
2615*d415bd75Srobert // fold (X > +/-0.0) ? X : (0.0 - X) to fabs(X), when 'Swap' is true
2616*d415bd75Srobert if (match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) {
2617*d415bd75Srobert if (!Swap && (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) {
2618*d415bd75Srobert Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI);
2619*d415bd75Srobert return IC.replaceInstUsesWith(SI, Fabs);
2620*d415bd75Srobert }
2621*d415bd75Srobert if (Swap && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) {
2622*d415bd75Srobert Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI);
2623*d415bd75Srobert return IC.replaceInstUsesWith(SI, Fabs);
2624*d415bd75Srobert }
2625*d415bd75Srobert }
2626*d415bd75Srobert
2627*d415bd75Srobert if (!match(TrueVal, m_FNeg(m_Specific(X))))
2628*d415bd75Srobert return nullptr;
2629*d415bd75Srobert
2630*d415bd75Srobert // Forward-propagate nnan and ninf from the fneg to the select.
2631*d415bd75Srobert // If all inputs are not those values, then the select is not either.
2632*d415bd75Srobert // Note: nsz is defined differently, so it may not be correct to propagate.
2633*d415bd75Srobert FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags();
2634*d415bd75Srobert if (FMF.noNaNs() && !SI.hasNoNaNs()) {
2635*d415bd75Srobert SI.setHasNoNaNs(true);
2636*d415bd75Srobert ChangedFMF = true;
2637*d415bd75Srobert }
2638*d415bd75Srobert if (FMF.noInfs() && !SI.hasNoInfs()) {
2639*d415bd75Srobert SI.setHasNoInfs(true);
2640*d415bd75Srobert ChangedFMF = true;
2641*d415bd75Srobert }
2642*d415bd75Srobert
2643*d415bd75Srobert // With nsz, when 'Swap' is false:
2644*d415bd75Srobert // fold (X < +/-0.0) ? -X : X or (X <= +/-0.0) ? -X : X to fabs(X)
2645*d415bd75Srobert // fold (X > +/-0.0) ? -X : X or (X >= +/-0.0) ? -X : X to -fabs(x)
2646*d415bd75Srobert // when 'Swap' is true:
2647*d415bd75Srobert // fold (X > +/-0.0) ? X : -X or (X >= +/-0.0) ? X : -X to fabs(X)
2648*d415bd75Srobert // fold (X < +/-0.0) ? X : -X or (X <= +/-0.0) ? X : -X to -fabs(X)
2649*d415bd75Srobert //
2650*d415bd75Srobert // Note: We require "nnan" for this fold because fcmp ignores the signbit
2651*d415bd75Srobert // of NAN, but IEEE-754 specifies the signbit of NAN values with
2652*d415bd75Srobert // fneg/fabs operations.
2653*d415bd75Srobert if (!SI.hasNoSignedZeros() || !SI.hasNoNaNs())
2654*d415bd75Srobert return nullptr;
2655*d415bd75Srobert
2656*d415bd75Srobert if (Swap)
2657*d415bd75Srobert Pred = FCmpInst::getSwappedPredicate(Pred);
2658*d415bd75Srobert
2659*d415bd75Srobert bool IsLTOrLE = Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE ||
2660*d415bd75Srobert Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE;
2661*d415bd75Srobert bool IsGTOrGE = Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE ||
2662*d415bd75Srobert Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE;
2663*d415bd75Srobert
2664*d415bd75Srobert if (IsLTOrLE) {
2665*d415bd75Srobert Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI);
2666*d415bd75Srobert return IC.replaceInstUsesWith(SI, Fabs);
2667*d415bd75Srobert }
2668*d415bd75Srobert if (IsGTOrGE) {
2669*d415bd75Srobert Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI);
2670*d415bd75Srobert Instruction *NewFNeg = UnaryOperator::CreateFNeg(Fabs);
2671*d415bd75Srobert NewFNeg->setFastMathFlags(SI.getFastMathFlags());
2672*d415bd75Srobert return NewFNeg;
2673*d415bd75Srobert }
2674*d415bd75Srobert }
2675*d415bd75Srobert
2676*d415bd75Srobert return ChangedFMF ? &SI : nullptr;
2677*d415bd75Srobert }
2678*d415bd75Srobert
2679*d415bd75Srobert // Match the following IR pattern:
2680*d415bd75Srobert // %x.lowbits = and i8 %x, %lowbitmask
2681*d415bd75Srobert // %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0
2682*d415bd75Srobert // %x.biased = add i8 %x, %bias
2683*d415bd75Srobert // %x.biased.highbits = and i8 %x.biased, %highbitmask
2684*d415bd75Srobert // %x.roundedup = select i1 %x.lowbits.are.zero, i8 %x, i8 %x.biased.highbits
2685*d415bd75Srobert // Define:
2686*d415bd75Srobert // %alignment = add i8 %lowbitmask, 1
2687*d415bd75Srobert // Iff 1. an %alignment is a power-of-two (aka, %lowbitmask is a low bit mask)
2688*d415bd75Srobert // and 2. %bias is equal to either %lowbitmask or %alignment,
2689*d415bd75Srobert // and 3. %highbitmask is equal to ~%lowbitmask (aka, to -%alignment)
2690*d415bd75Srobert // then this pattern can be transformed into:
2691*d415bd75Srobert // %x.offset = add i8 %x, %lowbitmask
2692*d415bd75Srobert // %x.roundedup = and i8 %x.offset, %highbitmask
2693*d415bd75Srobert static Value *
foldRoundUpIntegerWithPow2Alignment(SelectInst & SI,InstCombiner::BuilderTy & Builder)2694*d415bd75Srobert foldRoundUpIntegerWithPow2Alignment(SelectInst &SI,
2695*d415bd75Srobert InstCombiner::BuilderTy &Builder) {
2696*d415bd75Srobert Value *Cond = SI.getCondition();
2697*d415bd75Srobert Value *X = SI.getTrueValue();
2698*d415bd75Srobert Value *XBiasedHighBits = SI.getFalseValue();
2699*d415bd75Srobert
2700*d415bd75Srobert ICmpInst::Predicate Pred;
2701*d415bd75Srobert Value *XLowBits;
2702*d415bd75Srobert if (!match(Cond, m_ICmp(Pred, m_Value(XLowBits), m_ZeroInt())) ||
2703*d415bd75Srobert !ICmpInst::isEquality(Pred))
2704*d415bd75Srobert return nullptr;
2705*d415bd75Srobert
2706*d415bd75Srobert if (Pred == ICmpInst::Predicate::ICMP_NE)
2707*d415bd75Srobert std::swap(X, XBiasedHighBits);
2708*d415bd75Srobert
2709*d415bd75Srobert // FIXME: we could support non non-splats here.
2710*d415bd75Srobert
2711*d415bd75Srobert const APInt *LowBitMaskCst;
2712*d415bd75Srobert if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst))))
2713*d415bd75Srobert return nullptr;
2714*d415bd75Srobert
2715*d415bd75Srobert // Match even if the AND and ADD are swapped.
2716*d415bd75Srobert const APInt *BiasCst, *HighBitMaskCst;
2717*d415bd75Srobert if (!match(XBiasedHighBits,
2718*d415bd75Srobert m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)),
2719*d415bd75Srobert m_APIntAllowUndef(HighBitMaskCst))) &&
2720*d415bd75Srobert !match(XBiasedHighBits,
2721*d415bd75Srobert m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)),
2722*d415bd75Srobert m_APIntAllowUndef(BiasCst))))
2723*d415bd75Srobert return nullptr;
2724*d415bd75Srobert
2725*d415bd75Srobert if (!LowBitMaskCst->isMask())
2726*d415bd75Srobert return nullptr;
2727*d415bd75Srobert
2728*d415bd75Srobert APInt InvertedLowBitMaskCst = ~*LowBitMaskCst;
2729*d415bd75Srobert if (InvertedLowBitMaskCst != *HighBitMaskCst)
2730*d415bd75Srobert return nullptr;
2731*d415bd75Srobert
2732*d415bd75Srobert APInt AlignmentCst = *LowBitMaskCst + 1;
2733*d415bd75Srobert
2734*d415bd75Srobert if (*BiasCst != AlignmentCst && *BiasCst != *LowBitMaskCst)
2735*d415bd75Srobert return nullptr;
2736*d415bd75Srobert
2737*d415bd75Srobert if (!XBiasedHighBits->hasOneUse()) {
2738*d415bd75Srobert if (*BiasCst == *LowBitMaskCst)
2739*d415bd75Srobert return XBiasedHighBits;
2740*d415bd75Srobert return nullptr;
2741*d415bd75Srobert }
2742*d415bd75Srobert
2743*d415bd75Srobert // FIXME: could we preserve undef's here?
2744*d415bd75Srobert Type *Ty = X->getType();
2745*d415bd75Srobert Value *XOffset = Builder.CreateAdd(X, ConstantInt::get(Ty, *LowBitMaskCst),
2746*d415bd75Srobert X->getName() + ".biased");
2747*d415bd75Srobert Value *R = Builder.CreateAnd(XOffset, ConstantInt::get(Ty, *HighBitMaskCst));
2748*d415bd75Srobert R->takeName(&SI);
2749*d415bd75Srobert return R;
2750*d415bd75Srobert }
2751*d415bd75Srobert
2752*d415bd75Srobert namespace {
2753*d415bd75Srobert struct DecomposedSelect {
2754*d415bd75Srobert Value *Cond = nullptr;
2755*d415bd75Srobert Value *TrueVal = nullptr;
2756*d415bd75Srobert Value *FalseVal = nullptr;
2757*d415bd75Srobert };
2758*d415bd75Srobert } // namespace
2759*d415bd75Srobert
/// Look for patterns like
///   %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false
///   %inner.sel  = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f
///   %outer.sel  = select i1 %outer.cond, i8 %outer.sel.t, i8 %inner.sel
/// and rewrite it as
///   %inner.sel  = select i1 %alt.cond, i8 %outer.sel.t, i8 %inner.sel.t
///   %outer.sel  = select i1 %inner.cond, i8 %inner.sel, i8 %inner.sel.f
///
/// The logical-or form is handled symmetrically. There the outer condition
/// is %inner.cond || %alt.cond, the inner select is the outer true arm, and
/// the result is select %inner.cond, %inner.sel.t,
/// (select %alt.cond, %inner.sel.f, %outer.sel.f).
/// Either select's condition may also appear wrapped in a `not`.
///
/// \returns the new outermost select (not yet inserted), or nullptr.
static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
                                      InstCombiner::BuilderTy &Builder) {
  // We must start with a `select`.
  DecomposedSelect OuterSel;
  match(&OuterSelVal,
        m_Select(m_Value(OuterSel.Cond), m_Value(OuterSel.TrueVal),
                 m_Value(OuterSel.FalseVal)));

  // Canonicalize inversion of the outermost `select`'s condition.
  if (match(OuterSel.Cond, m_Not(m_Value(OuterSel.Cond))))
    std::swap(OuterSel.TrueVal, OuterSel.FalseVal);

  // The condition of the outermost select must be an `and`/`or`.
  if (!match(OuterSel.Cond, m_c_LogicalOp(m_Value(), m_Value())))
    return nullptr;

  // Depending on the logical op, inner select might be in different hand.
  bool IsAndVariant = match(OuterSel.Cond, m_LogicalAnd());
  Value *InnerSelVal = IsAndVariant ? OuterSel.FalseVal : OuterSel.TrueVal;

  // Profitability check - avoid increasing instruction count.
  // At least one of the original condition and the inner select must die.
  if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}),
              [](Value *V) { return V->hasOneUse(); }))
    return nullptr;

  // The appropriate hand of the outermost `select` must be a select itself.
  DecomposedSelect InnerSel;
  if (!match(InnerSelVal,
             m_Select(m_Value(InnerSel.Cond), m_Value(InnerSel.TrueVal),
                      m_Value(InnerSel.FalseVal))))
    return nullptr;

  // Canonicalize inversion of the innermost `select`'s condition.
  if (match(InnerSel.Cond, m_Not(m_Value(InnerSel.Cond))))
    std::swap(InnerSel.TrueVal, InnerSel.FalseVal);

  // Match the outer condition as a logical op of the given inner-condition
  // pattern and some other value, which is captured into AltCond.
  Value *AltCond = nullptr;
  auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) {
    return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond)));
  };

  // Finally, match the condition that was driving the outermost `select`,
  // it should be a logical operation between the condition that was driving
  // the innermost `select` (after accounting for the possible inversions
  // of the condition), and some other condition.
  if (matchOuterCond(m_Specific(InnerSel.Cond))) {
    // Done!
  } else if (Value * NotInnerCond; matchOuterCond(m_CombineAnd(
                 m_Not(m_Specific(InnerSel.Cond)), m_Value(NotInnerCond)))) {
    // Done! The outer condition uses the inverted inner condition, so
    // re-express the inner select in terms of that `not`.
    std::swap(InnerSel.TrueVal, InnerSel.FalseVal);
    InnerSel.Cond = NotInnerCond;
  } else // Not the pattern we were looking for.
    return nullptr;

  Value *SelInner = Builder.CreateSelect(
      AltCond, IsAndVariant ? OuterSel.TrueVal : InnerSel.FalseVal,
      IsAndVariant ? InnerSel.TrueVal : OuterSel.FalseVal);
  SelInner->takeName(InnerSelVal);
  return SelectInst::Create(InnerSel.Cond,
                            IsAndVariant ? SelInner : InnerSel.TrueVal,
                            !IsAndVariant ? SelInner : InnerSel.FalseVal);
}
2830*d415bd75Srobert
foldSelectOfBools(SelectInst & SI)2831*d415bd75Srobert Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
283209467b48Spatrick Value *CondVal = SI.getCondition();
283309467b48Spatrick Value *TrueVal = SI.getTrueValue();
283409467b48Spatrick Value *FalseVal = SI.getFalseValue();
283509467b48Spatrick Type *SelType = SI.getType();
283609467b48Spatrick
283773471bf0Spatrick // Avoid potential infinite loops by checking for non-constant condition.
283873471bf0Spatrick // TODO: Can we assert instead by improving canonicalizeSelectToShuffle()?
283973471bf0Spatrick // Scalar select must have simplified?
2840*d415bd75Srobert if (!SelType->isIntOrIntVectorTy(1) || isa<Constant>(CondVal) ||
2841*d415bd75Srobert TrueVal->getType() != CondVal->getType())
2842*d415bd75Srobert return nullptr;
2843*d415bd75Srobert
2844*d415bd75Srobert auto *One = ConstantInt::getTrue(SelType);
2845*d415bd75Srobert auto *Zero = ConstantInt::getFalse(SelType);
2846*d415bd75Srobert Value *A, *B, *C, *D;
2847*d415bd75Srobert
284873471bf0Spatrick // Folding select to and/or i1 isn't poison safe in general. impliesPoison
284973471bf0Spatrick // checks whether folding it does not convert a well-defined value into
285073471bf0Spatrick // poison.
2851*d415bd75Srobert if (match(TrueVal, m_One())) {
2852*d415bd75Srobert if (impliesPoison(FalseVal, CondVal)) {
285309467b48Spatrick // Change: A = select B, true, C --> A = or B, C
285409467b48Spatrick return BinaryOperator::CreateOr(CondVal, FalseVal);
285509467b48Spatrick }
2856*d415bd75Srobert
2857*d415bd75Srobert if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
2858*d415bd75Srobert if (auto *RHS = dyn_cast<FCmpInst>(FalseVal))
2859*d415bd75Srobert if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false,
2860*d415bd75Srobert /*IsSelectLogical*/ true))
2861*d415bd75Srobert return replaceInstUsesWith(SI, V);
2862*d415bd75Srobert
2863*d415bd75Srobert // (A && B) || (C && B) --> (A || C) && B
2864*d415bd75Srobert if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) &&
2865*d415bd75Srobert match(FalseVal, m_LogicalAnd(m_Value(C), m_Value(D))) &&
2866*d415bd75Srobert (CondVal->hasOneUse() || FalseVal->hasOneUse())) {
2867*d415bd75Srobert bool CondLogicAnd = isa<SelectInst>(CondVal);
2868*d415bd75Srobert bool FalseLogicAnd = isa<SelectInst>(FalseVal);
2869*d415bd75Srobert auto AndFactorization = [&](Value *Common, Value *InnerCond,
2870*d415bd75Srobert Value *InnerVal,
2871*d415bd75Srobert bool SelFirst = false) -> Instruction * {
2872*d415bd75Srobert Value *InnerSel = Builder.CreateSelect(InnerCond, One, InnerVal);
2873*d415bd75Srobert if (SelFirst)
2874*d415bd75Srobert std::swap(Common, InnerSel);
2875*d415bd75Srobert if (FalseLogicAnd || (CondLogicAnd && Common == A))
2876*d415bd75Srobert return SelectInst::Create(Common, InnerSel, Zero);
2877*d415bd75Srobert else
2878*d415bd75Srobert return BinaryOperator::CreateAnd(Common, InnerSel);
2879*d415bd75Srobert };
2880*d415bd75Srobert
2881*d415bd75Srobert if (A == C)
2882*d415bd75Srobert return AndFactorization(A, B, D);
2883*d415bd75Srobert if (A == D)
2884*d415bd75Srobert return AndFactorization(A, B, C);
2885*d415bd75Srobert if (B == C)
2886*d415bd75Srobert return AndFactorization(B, A, D);
2887*d415bd75Srobert if (B == D)
2888*d415bd75Srobert return AndFactorization(B, A, C, CondLogicAnd && FalseLogicAnd);
2889*d415bd75Srobert }
2890*d415bd75Srobert }
2891*d415bd75Srobert
2892*d415bd75Srobert if (match(FalseVal, m_Zero())) {
2893*d415bd75Srobert if (impliesPoison(TrueVal, CondVal)) {
289409467b48Spatrick // Change: A = select B, C, false --> A = and B, C
289509467b48Spatrick return BinaryOperator::CreateAnd(CondVal, TrueVal);
289609467b48Spatrick }
289773471bf0Spatrick
2898*d415bd75Srobert if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
2899*d415bd75Srobert if (auto *RHS = dyn_cast<FCmpInst>(TrueVal))
2900*d415bd75Srobert if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true,
2901*d415bd75Srobert /*IsSelectLogical*/ true))
2902*d415bd75Srobert return replaceInstUsesWith(SI, V);
2903*d415bd75Srobert
2904*d415bd75Srobert // (A || B) && (C || B) --> (A && C) || B
2905*d415bd75Srobert if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) &&
2906*d415bd75Srobert match(TrueVal, m_LogicalOr(m_Value(C), m_Value(D))) &&
2907*d415bd75Srobert (CondVal->hasOneUse() || TrueVal->hasOneUse())) {
2908*d415bd75Srobert bool CondLogicOr = isa<SelectInst>(CondVal);
2909*d415bd75Srobert bool TrueLogicOr = isa<SelectInst>(TrueVal);
2910*d415bd75Srobert auto OrFactorization = [&](Value *Common, Value *InnerCond,
2911*d415bd75Srobert Value *InnerVal,
2912*d415bd75Srobert bool SelFirst = false) -> Instruction * {
2913*d415bd75Srobert Value *InnerSel = Builder.CreateSelect(InnerCond, InnerVal, Zero);
2914*d415bd75Srobert if (SelFirst)
2915*d415bd75Srobert std::swap(Common, InnerSel);
2916*d415bd75Srobert if (TrueLogicOr || (CondLogicOr && Common == A))
2917*d415bd75Srobert return SelectInst::Create(Common, One, InnerSel);
2918*d415bd75Srobert else
2919*d415bd75Srobert return BinaryOperator::CreateOr(Common, InnerSel);
2920*d415bd75Srobert };
2921*d415bd75Srobert
2922*d415bd75Srobert if (A == C)
2923*d415bd75Srobert return OrFactorization(A, B, D);
2924*d415bd75Srobert if (A == D)
2925*d415bd75Srobert return OrFactorization(A, B, C);
2926*d415bd75Srobert if (B == C)
2927*d415bd75Srobert return OrFactorization(B, A, D);
2928*d415bd75Srobert if (B == D)
2929*d415bd75Srobert return OrFactorization(B, A, C, CondLogicOr && TrueLogicOr);
2930*d415bd75Srobert }
2931*d415bd75Srobert }
293273471bf0Spatrick
293373471bf0Spatrick // We match the "full" 0 or 1 constant here to avoid a potential infinite
293473471bf0Spatrick // loop with vectors that may have undefined/poison elements.
293573471bf0Spatrick // select a, false, b -> select !a, b, false
293673471bf0Spatrick if (match(TrueVal, m_Specific(Zero))) {
293709467b48Spatrick Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
293873471bf0Spatrick return SelectInst::Create(NotCond, FalseVal, Zero);
293973471bf0Spatrick }
294073471bf0Spatrick // select a, b, true -> select !a, true, b
294173471bf0Spatrick if (match(FalseVal, m_Specific(One))) {
294273471bf0Spatrick Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
294373471bf0Spatrick return SelectInst::Create(NotCond, One, TrueVal);
294409467b48Spatrick }
294509467b48Spatrick
294673471bf0Spatrick // DeMorgan in select form: !a && !b --> !(a || b)
294773471bf0Spatrick // select !a, !b, false --> not (select a, true, b)
294873471bf0Spatrick if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
294973471bf0Spatrick (CondVal->hasOneUse() || TrueVal->hasOneUse()) &&
295073471bf0Spatrick !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
295173471bf0Spatrick return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B));
295273471bf0Spatrick
295373471bf0Spatrick // DeMorgan in select form: !a || !b --> !(a && b)
295473471bf0Spatrick // select !a, true, !b --> not (select a, b, false)
295573471bf0Spatrick if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
295673471bf0Spatrick (CondVal->hasOneUse() || FalseVal->hasOneUse()) &&
295773471bf0Spatrick !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
295873471bf0Spatrick return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero));
295973471bf0Spatrick
296073471bf0Spatrick // select (select a, true, b), true, b -> select a, true, b
296173471bf0Spatrick if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) &&
296273471bf0Spatrick match(TrueVal, m_One()) && match(FalseVal, m_Specific(B)))
296373471bf0Spatrick return replaceOperand(SI, 0, A);
296473471bf0Spatrick // select (select a, b, false), b, false -> select a, b, false
296573471bf0Spatrick if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
296673471bf0Spatrick match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
296773471bf0Spatrick return replaceOperand(SI, 0, A);
296873471bf0Spatrick
2969*d415bd75Srobert // ~(A & B) & (A | B) --> A ^ B
2970*d415bd75Srobert if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))),
2971*d415bd75Srobert m_c_LogicalOr(m_Deferred(A), m_Deferred(B)))))
2972*d415bd75Srobert return BinaryOperator::CreateXor(A, B);
2973*d415bd75Srobert
2974*d415bd75Srobert // select (~a | c), a, b -> and a, (or c, freeze(b))
2975*d415bd75Srobert if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) &&
2976*d415bd75Srobert CondVal->hasOneUse()) {
2977*d415bd75Srobert FalseVal = Builder.CreateFreeze(FalseVal);
2978*d415bd75Srobert return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal));
2979*d415bd75Srobert }
2980*d415bd75Srobert // select (~c & b), a, b -> and b, (or freeze(a), c)
2981*d415bd75Srobert if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) &&
2982*d415bd75Srobert CondVal->hasOneUse()) {
2983*d415bd75Srobert TrueVal = Builder.CreateFreeze(TrueVal);
2984*d415bd75Srobert return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal));
298573471bf0Spatrick }
298673471bf0Spatrick
298773471bf0Spatrick if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
298873471bf0Spatrick Use *Y = nullptr;
298973471bf0Spatrick bool IsAnd = match(FalseVal, m_Zero()) ? true : false;
299073471bf0Spatrick Value *Op1 = IsAnd ? TrueVal : FalseVal;
299173471bf0Spatrick if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) {
299273471bf0Spatrick auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr");
299373471bf0Spatrick InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser()));
299473471bf0Spatrick replaceUse(*Y, FI);
299573471bf0Spatrick return replaceInstUsesWith(SI, Op1);
299673471bf0Spatrick }
299773471bf0Spatrick
299873471bf0Spatrick if (auto *Op1SI = dyn_cast<SelectInst>(Op1))
299973471bf0Spatrick if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI,
300073471bf0Spatrick /* IsAnd */ IsAnd))
300173471bf0Spatrick return I;
300273471bf0Spatrick
300373471bf0Spatrick if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
300473471bf0Spatrick if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
3005*d415bd75Srobert if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd,
300673471bf0Spatrick /* IsLogical */ true))
300773471bf0Spatrick return replaceInstUsesWith(SI, V);
300873471bf0Spatrick }
300973471bf0Spatrick
3010*d415bd75Srobert // select (a || b), c, false -> select a, c, false
3011*d415bd75Srobert // select c, (a || b), false -> select c, a, false
301273471bf0Spatrick // if c implies that b is false.
3013*d415bd75Srobert if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) &&
301473471bf0Spatrick match(FalseVal, m_Zero())) {
3015*d415bd75Srobert std::optional<bool> Res = isImpliedCondition(TrueVal, B, DL);
301673471bf0Spatrick if (Res && *Res == false)
301773471bf0Spatrick return replaceOperand(SI, 0, A);
301873471bf0Spatrick }
3019*d415bd75Srobert if (match(TrueVal, m_LogicalOr(m_Value(A), m_Value(B))) &&
302073471bf0Spatrick match(FalseVal, m_Zero())) {
3021*d415bd75Srobert std::optional<bool> Res = isImpliedCondition(CondVal, B, DL);
302273471bf0Spatrick if (Res && *Res == false)
302373471bf0Spatrick return replaceOperand(SI, 1, A);
302473471bf0Spatrick }
3025*d415bd75Srobert // select c, true, (a && b) -> select c, true, a
3026*d415bd75Srobert // select (a && b), true, c -> select a, true, c
302773471bf0Spatrick // if c = false implies that b = true
302873471bf0Spatrick if (match(TrueVal, m_One()) &&
3029*d415bd75Srobert match(FalseVal, m_LogicalAnd(m_Value(A), m_Value(B)))) {
3030*d415bd75Srobert std::optional<bool> Res = isImpliedCondition(CondVal, B, DL, false);
303173471bf0Spatrick if (Res && *Res == true)
303273471bf0Spatrick return replaceOperand(SI, 2, A);
303373471bf0Spatrick }
3034*d415bd75Srobert if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) &&
303573471bf0Spatrick match(TrueVal, m_One())) {
3036*d415bd75Srobert std::optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false);
303773471bf0Spatrick if (Res && *Res == true)
303873471bf0Spatrick return replaceOperand(SI, 0, A);
303973471bf0Spatrick }
304073471bf0Spatrick
3041*d415bd75Srobert if (match(TrueVal, m_One())) {
3042*d415bd75Srobert Value *C;
3043*d415bd75Srobert
3044*d415bd75Srobert // (C && A) || (!C && B) --> sel C, A, B
3045*d415bd75Srobert // (A && C) || (!C && B) --> sel C, A, B
3046*d415bd75Srobert // (C && A) || (B && !C) --> sel C, A, B
3047*d415bd75Srobert // (A && C) || (B && !C) --> sel C, A, B (may require freeze)
3048*d415bd75Srobert if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(B))) &&
3049*d415bd75Srobert match(CondVal, m_c_LogicalAnd(m_Specific(C), m_Value(A)))) {
3050*d415bd75Srobert auto *SelCond = dyn_cast<SelectInst>(CondVal);
3051*d415bd75Srobert auto *SelFVal = dyn_cast<SelectInst>(FalseVal);
3052*d415bd75Srobert bool MayNeedFreeze = SelCond && SelFVal &&
3053*d415bd75Srobert match(SelFVal->getTrueValue(),
3054*d415bd75Srobert m_Not(m_Specific(SelCond->getTrueValue())));
3055*d415bd75Srobert if (MayNeedFreeze)
3056*d415bd75Srobert C = Builder.CreateFreeze(C);
3057*d415bd75Srobert return SelectInst::Create(C, A, B);
3058*d415bd75Srobert }
3059*d415bd75Srobert
3060*d415bd75Srobert // (!C && A) || (C && B) --> sel C, B, A
3061*d415bd75Srobert // (A && !C) || (C && B) --> sel C, B, A
3062*d415bd75Srobert // (!C && A) || (B && C) --> sel C, B, A
3063*d415bd75Srobert // (A && !C) || (B && C) --> sel C, B, A (may require freeze)
3064*d415bd75Srobert if (match(CondVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(A))) &&
3065*d415bd75Srobert match(FalseVal, m_c_LogicalAnd(m_Specific(C), m_Value(B)))) {
3066*d415bd75Srobert auto *SelCond = dyn_cast<SelectInst>(CondVal);
3067*d415bd75Srobert auto *SelFVal = dyn_cast<SelectInst>(FalseVal);
3068*d415bd75Srobert bool MayNeedFreeze = SelCond && SelFVal &&
3069*d415bd75Srobert match(SelCond->getTrueValue(),
3070*d415bd75Srobert m_Not(m_Specific(SelFVal->getTrueValue())));
3071*d415bd75Srobert if (MayNeedFreeze)
3072*d415bd75Srobert C = Builder.CreateFreeze(C);
3073*d415bd75Srobert return SelectInst::Create(C, B, A);
307473471bf0Spatrick }
307509467b48Spatrick }
307609467b48Spatrick
3077*d415bd75Srobert return nullptr;
3078*d415bd75Srobert }
3079*d415bd75Srobert
visitSelectInst(SelectInst & SI)3080*d415bd75Srobert Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
3081*d415bd75Srobert Value *CondVal = SI.getCondition();
3082*d415bd75Srobert Value *TrueVal = SI.getTrueValue();
3083*d415bd75Srobert Value *FalseVal = SI.getFalseValue();
3084*d415bd75Srobert Type *SelType = SI.getType();
3085*d415bd75Srobert
3086*d415bd75Srobert if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal,
3087*d415bd75Srobert SQ.getWithInstruction(&SI)))
3088*d415bd75Srobert return replaceInstUsesWith(SI, V);
3089*d415bd75Srobert
3090*d415bd75Srobert if (Instruction *I = canonicalizeSelectToShuffle(SI))
3091*d415bd75Srobert return I;
3092*d415bd75Srobert
3093*d415bd75Srobert if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this))
3094*d415bd75Srobert return I;
3095*d415bd75Srobert
3096*d415bd75Srobert // If the type of select is not an integer type or if the condition and
3097*d415bd75Srobert // the selection type are not both scalar nor both vector types, there is no
3098*d415bd75Srobert // point in attempting to match these patterns.
3099*d415bd75Srobert Type *CondType = CondVal->getType();
3100*d415bd75Srobert if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() &&
3101*d415bd75Srobert CondType->isVectorTy() == SelType->isVectorTy()) {
3102*d415bd75Srobert if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal,
3103*d415bd75Srobert ConstantInt::getTrue(CondType), SQ,
3104*d415bd75Srobert /* AllowRefinement */ true))
3105*d415bd75Srobert return replaceOperand(SI, 1, S);
3106*d415bd75Srobert
3107*d415bd75Srobert if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal,
3108*d415bd75Srobert ConstantInt::getFalse(CondType), SQ,
3109*d415bd75Srobert /* AllowRefinement */ true))
3110*d415bd75Srobert return replaceOperand(SI, 2, S);
3111*d415bd75Srobert
3112*d415bd75Srobert // Handle patterns involving sext/zext + not explicitly,
3113*d415bd75Srobert // as simplifyWithOpReplaced() only looks past one instruction.
3114*d415bd75Srobert Value *NotCond;
3115*d415bd75Srobert
3116*d415bd75Srobert // select a, sext(!a), b -> select !a, b, 0
3117*d415bd75Srobert // select a, zext(!a), b -> select !a, b, 0
3118*d415bd75Srobert if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond),
3119*d415bd75Srobert m_Not(m_Specific(CondVal))))))
3120*d415bd75Srobert return SelectInst::Create(NotCond, FalseVal,
3121*d415bd75Srobert Constant::getNullValue(SelType));
3122*d415bd75Srobert
3123*d415bd75Srobert // select a, b, zext(!a) -> select !a, 1, b
3124*d415bd75Srobert if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond),
3125*d415bd75Srobert m_Not(m_Specific(CondVal))))))
3126*d415bd75Srobert return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal);
3127*d415bd75Srobert
3128*d415bd75Srobert // select a, b, sext(!a) -> select !a, -1, b
3129*d415bd75Srobert if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond),
3130*d415bd75Srobert m_Not(m_Specific(CondVal))))))
3131*d415bd75Srobert return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType),
3132*d415bd75Srobert TrueVal);
3133*d415bd75Srobert }
3134*d415bd75Srobert
3135*d415bd75Srobert if (Instruction *R = foldSelectOfBools(SI))
3136*d415bd75Srobert return R;
3137*d415bd75Srobert
313809467b48Spatrick // Selecting between two integer or vector splat integer constants?
313909467b48Spatrick //
314009467b48Spatrick // Note that we don't handle a scalar select of vectors:
314109467b48Spatrick // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0>
314209467b48Spatrick // because that may need 3 instructions to splat the condition value:
314309467b48Spatrick // extend, insertelement, shufflevector.
314473471bf0Spatrick //
314573471bf0Spatrick // Do not handle i1 TrueVal and FalseVal otherwise would result in
314673471bf0Spatrick // zext/sext i1 to i1.
314773471bf0Spatrick if (SelType->isIntOrIntVectorTy() && !SelType->isIntOrIntVectorTy(1) &&
314809467b48Spatrick CondVal->getType()->isVectorTy() == SelType->isVectorTy()) {
314909467b48Spatrick // select C, 1, 0 -> zext C to int
315009467b48Spatrick if (match(TrueVal, m_One()) && match(FalseVal, m_Zero()))
315109467b48Spatrick return new ZExtInst(CondVal, SelType);
315209467b48Spatrick
315309467b48Spatrick // select C, -1, 0 -> sext C to int
315409467b48Spatrick if (match(TrueVal, m_AllOnes()) && match(FalseVal, m_Zero()))
315509467b48Spatrick return new SExtInst(CondVal, SelType);
315609467b48Spatrick
315709467b48Spatrick // select C, 0, 1 -> zext !C to int
315809467b48Spatrick if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) {
315909467b48Spatrick Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
316009467b48Spatrick return new ZExtInst(NotCond, SelType);
316109467b48Spatrick }
316209467b48Spatrick
316309467b48Spatrick // select C, 0, -1 -> sext !C to int
316409467b48Spatrick if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) {
316509467b48Spatrick Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
316609467b48Spatrick return new SExtInst(NotCond, SelType);
316709467b48Spatrick }
316809467b48Spatrick }
316909467b48Spatrick
317073471bf0Spatrick if (auto *FCmp = dyn_cast<FCmpInst>(CondVal)) {
317173471bf0Spatrick Value *Cmp0 = FCmp->getOperand(0), *Cmp1 = FCmp->getOperand(1);
317273471bf0Spatrick // Are we selecting a value based on a comparison of the two values?
317309467b48Spatrick if ((Cmp0 == TrueVal && Cmp1 == FalseVal) ||
317409467b48Spatrick (Cmp0 == FalseVal && Cmp1 == TrueVal)) {
317509467b48Spatrick // Canonicalize to use ordered comparisons by swapping the select
317609467b48Spatrick // operands.
317709467b48Spatrick //
317809467b48Spatrick // e.g.
317909467b48Spatrick // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X
318073471bf0Spatrick if (FCmp->hasOneUse() && FCmpInst::isUnordered(FCmp->getPredicate())) {
318173471bf0Spatrick FCmpInst::Predicate InvPred = FCmp->getInversePredicate();
318209467b48Spatrick IRBuilder<>::FastMathFlagGuard FMFG(Builder);
318309467b48Spatrick // FIXME: The FMF should propagate from the select, not the fcmp.
318473471bf0Spatrick Builder.setFastMathFlags(FCmp->getFastMathFlags());
318509467b48Spatrick Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1,
318673471bf0Spatrick FCmp->getName() + ".inv");
318709467b48Spatrick Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal);
318809467b48Spatrick return replaceInstUsesWith(SI, NewSel);
318909467b48Spatrick }
319009467b48Spatrick }
319109467b48Spatrick }
319209467b48Spatrick
3193*d415bd75Srobert if (isa<FPMathOperator>(SI)) {
3194*d415bd75Srobert // TODO: Try to forward-propagate FMF from select arms to the select.
3195*d415bd75Srobert
3196*d415bd75Srobert // Canonicalize select of FP values where NaN and -0.0 are not valid as
3197*d415bd75Srobert // minnum/maxnum intrinsics.
3198*d415bd75Srobert if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) {
3199*d415bd75Srobert Value *X, *Y;
3200*d415bd75Srobert if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y))))
3201*d415bd75Srobert return replaceInstUsesWith(
3202*d415bd75Srobert SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI));
3203*d415bd75Srobert
3204*d415bd75Srobert if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y))))
3205*d415bd75Srobert return replaceInstUsesWith(
3206*d415bd75Srobert SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI));
320709467b48Spatrick }
320809467b48Spatrick }
3209*d415bd75Srobert
3210*d415bd75Srobert // Fold selecting to fabs.
3211*d415bd75Srobert if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this))
3212*d415bd75Srobert return Fabs;
321309467b48Spatrick
321409467b48Spatrick // See if we are selecting two values based on a comparison of the two values.
321509467b48Spatrick if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal))
321609467b48Spatrick if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
321709467b48Spatrick return Result;
321809467b48Spatrick
321909467b48Spatrick if (Instruction *Add = foldAddSubSelect(SI, Builder))
322009467b48Spatrick return Add;
322109467b48Spatrick if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))
322209467b48Spatrick return Add;
3223097a140dSpatrick if (Instruction *Or = foldSetClearBits(SI, Builder))
3224097a140dSpatrick return Or;
3225*d415bd75Srobert if (Instruction *Mul = foldSelectZeroOrMul(SI, *this))
3226*d415bd75Srobert return Mul;
322709467b48Spatrick
322809467b48Spatrick // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
322909467b48Spatrick auto *TI = dyn_cast<Instruction>(TrueVal);
323009467b48Spatrick auto *FI = dyn_cast<Instruction>(FalseVal);
323109467b48Spatrick if (TI && FI && TI->getOpcode() == FI->getOpcode())
323209467b48Spatrick if (Instruction *IV = foldSelectOpOp(SI, TI, FI))
323309467b48Spatrick return IV;
323409467b48Spatrick
323509467b48Spatrick if (Instruction *I = foldSelectExtConst(SI))
323609467b48Spatrick return I;
323709467b48Spatrick
323873471bf0Spatrick // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
323973471bf0Spatrick // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
324073471bf0Spatrick auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,
324173471bf0Spatrick bool Swap) -> GetElementPtrInst * {
324273471bf0Spatrick Value *Ptr = Gep->getPointerOperand();
324373471bf0Spatrick if (Gep->getNumOperands() != 2 || Gep->getPointerOperand() != Base ||
324473471bf0Spatrick !Gep->hasOneUse())
324573471bf0Spatrick return nullptr;
324673471bf0Spatrick Value *Idx = Gep->getOperand(1);
3247*d415bd75Srobert if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType()))
3248*d415bd75Srobert return nullptr;
3249*d415bd75Srobert Type *ElementType = Gep->getResultElementType();
325073471bf0Spatrick Value *NewT = Idx;
325173471bf0Spatrick Value *NewF = Constant::getNullValue(Idx->getType());
325273471bf0Spatrick if (Swap)
325373471bf0Spatrick std::swap(NewT, NewF);
325473471bf0Spatrick Value *NewSI =
325573471bf0Spatrick Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI);
325673471bf0Spatrick return GetElementPtrInst::Create(ElementType, Ptr, {NewSI});
325773471bf0Spatrick };
325873471bf0Spatrick if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal))
325973471bf0Spatrick if (auto *NewGep = SelectGepWithBase(TrueGep, FalseVal, false))
326073471bf0Spatrick return NewGep;
326173471bf0Spatrick if (auto *FalseGep = dyn_cast<GetElementPtrInst>(FalseVal))
326273471bf0Spatrick if (auto *NewGep = SelectGepWithBase(FalseGep, TrueVal, true))
326373471bf0Spatrick return NewGep;
326473471bf0Spatrick
326509467b48Spatrick // See if we can fold the select into one of our operands.
326609467b48Spatrick if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) {
326709467b48Spatrick if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal))
326809467b48Spatrick return FoldI;
326909467b48Spatrick
327009467b48Spatrick Value *LHS, *RHS;
327109467b48Spatrick Instruction::CastOps CastOp;
327209467b48Spatrick SelectPatternResult SPR = matchSelectPattern(&SI, LHS, RHS, &CastOp);
327309467b48Spatrick auto SPF = SPR.Flavor;
327409467b48Spatrick if (SPF) {
327509467b48Spatrick Value *LHS2, *RHS2;
327609467b48Spatrick if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor)
327709467b48Spatrick if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS), SPF2, LHS2,
327809467b48Spatrick RHS2, SI, SPF, RHS))
327909467b48Spatrick return R;
328009467b48Spatrick if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor)
328109467b48Spatrick if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS), SPF2, LHS2,
328209467b48Spatrick RHS2, SI, SPF, LHS))
328309467b48Spatrick return R;
328409467b48Spatrick }
328509467b48Spatrick
328609467b48Spatrick if (SelectPatternResult::isMinOrMax(SPF)) {
328709467b48Spatrick // Canonicalize so that
328809467b48Spatrick // - type casts are outside select patterns.
328909467b48Spatrick // - float clamp is transformed to min/max pattern
329009467b48Spatrick
329109467b48Spatrick bool IsCastNeeded = LHS->getType() != SelType;
329209467b48Spatrick Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0);
329309467b48Spatrick Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1);
329409467b48Spatrick if (IsCastNeeded ||
329509467b48Spatrick (LHS->getType()->isFPOrFPVectorTy() &&
329609467b48Spatrick ((CmpLHS != LHS && CmpLHS != RHS) ||
329709467b48Spatrick (CmpRHS != LHS && CmpRHS != RHS)))) {
329809467b48Spatrick CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered);
329909467b48Spatrick
330009467b48Spatrick Value *Cmp;
330109467b48Spatrick if (CmpInst::isIntPredicate(MinMaxPred)) {
330209467b48Spatrick Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS);
330309467b48Spatrick } else {
330409467b48Spatrick IRBuilder<>::FastMathFlagGuard FMFG(Builder);
330509467b48Spatrick auto FMF =
330609467b48Spatrick cast<FPMathOperator>(SI.getCondition())->getFastMathFlags();
330709467b48Spatrick Builder.setFastMathFlags(FMF);
330809467b48Spatrick Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS);
330909467b48Spatrick }
331009467b48Spatrick
331109467b48Spatrick Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI);
331209467b48Spatrick if (!IsCastNeeded)
331309467b48Spatrick return replaceInstUsesWith(SI, NewSI);
331409467b48Spatrick
331509467b48Spatrick Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType);
331609467b48Spatrick return replaceInstUsesWith(SI, NewCast);
331709467b48Spatrick }
331809467b48Spatrick }
331909467b48Spatrick }
332009467b48Spatrick
332109467b48Spatrick // See if we can fold the select into a phi node if the condition is a select.
332209467b48Spatrick if (auto *PN = dyn_cast<PHINode>(SI.getCondition()))
332309467b48Spatrick // The true/false values have to be live in the PHI predecessor's blocks.
332409467b48Spatrick if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) &&
332509467b48Spatrick canSelectOperandBeMappingIntoPredBlock(FalseVal, SI))
332609467b48Spatrick if (Instruction *NV = foldOpIntoPhi(SI, PN))
332709467b48Spatrick return NV;
332809467b48Spatrick
332909467b48Spatrick if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
333009467b48Spatrick if (TrueSI->getCondition()->getType() == CondVal->getType()) {
333109467b48Spatrick // select(C, select(C, a, b), c) -> select(C, a, c)
333209467b48Spatrick if (TrueSI->getCondition() == CondVal) {
333309467b48Spatrick if (SI.getTrueValue() == TrueSI->getTrueValue())
333409467b48Spatrick return nullptr;
3335097a140dSpatrick return replaceOperand(SI, 1, TrueSI->getTrueValue());
333609467b48Spatrick }
333709467b48Spatrick // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
333873471bf0Spatrick // We choose this as normal form to enable folding on the And and
333973471bf0Spatrick // shortening paths for the values (this helps getUnderlyingObjects() for
334073471bf0Spatrick // example).
334109467b48Spatrick if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) {
334273471bf0Spatrick Value *And = Builder.CreateLogicalAnd(CondVal, TrueSI->getCondition());
3343097a140dSpatrick replaceOperand(SI, 0, And);
3344097a140dSpatrick replaceOperand(SI, 1, TrueSI->getTrueValue());
334509467b48Spatrick return &SI;
334609467b48Spatrick }
334709467b48Spatrick }
334809467b48Spatrick }
334909467b48Spatrick if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
335009467b48Spatrick if (FalseSI->getCondition()->getType() == CondVal->getType()) {
335109467b48Spatrick // select(C, a, select(C, b, c)) -> select(C, a, c)
335209467b48Spatrick if (FalseSI->getCondition() == CondVal) {
335309467b48Spatrick if (SI.getFalseValue() == FalseSI->getFalseValue())
335409467b48Spatrick return nullptr;
3355097a140dSpatrick return replaceOperand(SI, 2, FalseSI->getFalseValue());
335609467b48Spatrick }
335709467b48Spatrick // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
335809467b48Spatrick if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
335973471bf0Spatrick Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition());
3360097a140dSpatrick replaceOperand(SI, 0, Or);
3361097a140dSpatrick replaceOperand(SI, 2, FalseSI->getFalseValue());
336209467b48Spatrick return &SI;
336309467b48Spatrick }
336409467b48Spatrick }
336509467b48Spatrick }
336609467b48Spatrick
336709467b48Spatrick auto canMergeSelectThroughBinop = [](BinaryOperator *BO) {
336809467b48Spatrick // The select might be preventing a division by 0.
336909467b48Spatrick switch (BO->getOpcode()) {
337009467b48Spatrick default:
337109467b48Spatrick return true;
337209467b48Spatrick case Instruction::SRem:
337309467b48Spatrick case Instruction::URem:
337409467b48Spatrick case Instruction::SDiv:
337509467b48Spatrick case Instruction::UDiv:
337609467b48Spatrick return false;
337709467b48Spatrick }
337809467b48Spatrick };
337909467b48Spatrick
338009467b48Spatrick // Try to simplify a binop sandwiched between 2 selects with the same
338109467b48Spatrick // condition.
338209467b48Spatrick // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z)
338309467b48Spatrick BinaryOperator *TrueBO;
338409467b48Spatrick if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) &&
338509467b48Spatrick canMergeSelectThroughBinop(TrueBO)) {
338609467b48Spatrick if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) {
338709467b48Spatrick if (TrueBOSI->getCondition() == CondVal) {
3388097a140dSpatrick replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue());
3389097a140dSpatrick Worklist.push(TrueBO);
339009467b48Spatrick return &SI;
339109467b48Spatrick }
339209467b48Spatrick }
339309467b48Spatrick if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) {
339409467b48Spatrick if (TrueBOSI->getCondition() == CondVal) {
3395097a140dSpatrick replaceOperand(*TrueBO, 1, TrueBOSI->getTrueValue());
3396097a140dSpatrick Worklist.push(TrueBO);
339709467b48Spatrick return &SI;
339809467b48Spatrick }
339909467b48Spatrick }
340009467b48Spatrick }
340109467b48Spatrick
340209467b48Spatrick // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W))
340309467b48Spatrick BinaryOperator *FalseBO;
340409467b48Spatrick if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) &&
340509467b48Spatrick canMergeSelectThroughBinop(FalseBO)) {
340609467b48Spatrick if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) {
340709467b48Spatrick if (FalseBOSI->getCondition() == CondVal) {
3408097a140dSpatrick replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue());
3409097a140dSpatrick Worklist.push(FalseBO);
341009467b48Spatrick return &SI;
341109467b48Spatrick }
341209467b48Spatrick }
341309467b48Spatrick if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) {
341409467b48Spatrick if (FalseBOSI->getCondition() == CondVal) {
3415097a140dSpatrick replaceOperand(*FalseBO, 1, FalseBOSI->getFalseValue());
3416097a140dSpatrick Worklist.push(FalseBO);
341709467b48Spatrick return &SI;
341809467b48Spatrick }
341909467b48Spatrick }
342009467b48Spatrick }
342109467b48Spatrick
342209467b48Spatrick Value *NotCond;
342373471bf0Spatrick if (match(CondVal, m_Not(m_Value(NotCond))) &&
342473471bf0Spatrick !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) {
3425097a140dSpatrick replaceOperand(SI, 0, NotCond);
3426097a140dSpatrick SI.swapValues();
342709467b48Spatrick SI.swapProfMetadata();
342809467b48Spatrick return &SI;
342909467b48Spatrick }
343009467b48Spatrick
3431097a140dSpatrick if (Instruction *I = foldVectorSelect(SI))
3432097a140dSpatrick return I;
343309467b48Spatrick
343409467b48Spatrick // If we can compute the condition, there's no need for a select.
343509467b48Spatrick // Like the above fold, we are attempting to reduce compile-time cost by
343609467b48Spatrick // putting this fold here with limitations rather than in InstSimplify.
343709467b48Spatrick // The motivation for this call into value tracking is to take advantage of
343809467b48Spatrick // the assumption cache, so make sure that is populated.
343909467b48Spatrick if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) {
344009467b48Spatrick KnownBits Known(1);
344109467b48Spatrick computeKnownBits(CondVal, Known, 0, &SI);
3442*d415bd75Srobert if (Known.One.isOne())
344309467b48Spatrick return replaceInstUsesWith(SI, TrueVal);
3444*d415bd75Srobert if (Known.Zero.isOne())
344509467b48Spatrick return replaceInstUsesWith(SI, FalseVal);
344609467b48Spatrick }
344709467b48Spatrick
344809467b48Spatrick if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder))
344909467b48Spatrick return BitCastSel;
345009467b48Spatrick
345109467b48Spatrick // Simplify selects that test the returned flag of cmpxchg instructions.
3452097a140dSpatrick if (Value *V = foldSelectCmpXchg(SI))
3453097a140dSpatrick return replaceInstUsesWith(SI, V);
345409467b48Spatrick
3455097a140dSpatrick if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this))
345609467b48Spatrick return Select;
345709467b48Spatrick
345873471bf0Spatrick if (Instruction *Funnel = foldSelectFunnelShift(SI, Builder))
345973471bf0Spatrick return Funnel;
346009467b48Spatrick
3461097a140dSpatrick if (Instruction *Copysign = foldSelectToCopysign(SI, Builder))
3462097a140dSpatrick return Copysign;
3463097a140dSpatrick
3464097a140dSpatrick if (Instruction *PN = foldSelectToPhi(SI, DT, Builder))
3465097a140dSpatrick return replaceInstUsesWith(SI, PN);
3466097a140dSpatrick
346773471bf0Spatrick if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder))
346873471bf0Spatrick return replaceInstUsesWith(SI, Fr);
346973471bf0Spatrick
3470*d415bd75Srobert if (Value *V = foldRoundUpIntegerWithPow2Alignment(SI, Builder))
3471*d415bd75Srobert return replaceInstUsesWith(SI, V);
3472*d415bd75Srobert
347373471bf0Spatrick // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
347473471bf0Spatrick // Load inst is intentionally not checked for hasOneUse()
347573471bf0Spatrick if (match(FalseVal, m_Zero()) &&
3476*d415bd75Srobert (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal),
3477*d415bd75Srobert m_CombineOr(m_Undef(), m_Zero()))) ||
3478*d415bd75Srobert match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal),
3479*d415bd75Srobert m_CombineOr(m_Undef(), m_Zero()))))) {
3480*d415bd75Srobert auto *MaskedInst = cast<IntrinsicInst>(TrueVal);
3481*d415bd75Srobert if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
3482*d415bd75Srobert MaskedInst->setArgOperand(3, FalseVal /* Zero */);
3483*d415bd75Srobert return replaceInstUsesWith(SI, MaskedInst);
348473471bf0Spatrick }
348573471bf0Spatrick
348673471bf0Spatrick Value *Mask;
348773471bf0Spatrick if (match(TrueVal, m_Zero()) &&
3488*d415bd75Srobert (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask),
3489*d415bd75Srobert m_CombineOr(m_Undef(), m_Zero()))) ||
3490*d415bd75Srobert match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask),
3491*d415bd75Srobert m_CombineOr(m_Undef(), m_Zero())))) &&
349273471bf0Spatrick (CondVal->getType() == Mask->getType())) {
349373471bf0Spatrick // We can remove the select by ensuring the load zeros all lanes the
349473471bf0Spatrick // select would have. We determine this by proving there is no overlap
349573471bf0Spatrick // between the load and select masks.
349673471bf0Spatrick // (i.e (load_mask & select_mask) == 0 == no overlap)
349773471bf0Spatrick bool CanMergeSelectIntoLoad = false;
3498*d415bd75Srobert if (Value *V = simplifyAndInst(CondVal, Mask, SQ.getWithInstruction(&SI)))
349973471bf0Spatrick CanMergeSelectIntoLoad = match(V, m_Zero());
350073471bf0Spatrick
350173471bf0Spatrick if (CanMergeSelectIntoLoad) {
3502*d415bd75Srobert auto *MaskedInst = cast<IntrinsicInst>(FalseVal);
3503*d415bd75Srobert if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
3504*d415bd75Srobert MaskedInst->setArgOperand(3, TrueVal /* Zero */);
3505*d415bd75Srobert return replaceInstUsesWith(SI, MaskedInst);
350673471bf0Spatrick }
350773471bf0Spatrick }
350873471bf0Spatrick
3509*d415bd75Srobert if (Instruction *I = foldNestedSelects(SI, Builder))
3510*d415bd75Srobert return I;
3511*d415bd75Srobert
3512*d415bd75Srobert // Match logical variants of the pattern,
3513*d415bd75Srobert // and transform them iff that gets rid of inversions.
3514*d415bd75Srobert // (~x) | y --> ~(x & (~y))
3515*d415bd75Srobert // (~x) & y --> ~(x | (~y))
3516*d415bd75Srobert if (sinkNotIntoOtherHandOfLogicalOp(SI))
3517*d415bd75Srobert return &SI;
3518*d415bd75Srobert
351909467b48Spatrick return nullptr;
352009467b48Spatrick }
3521