1*5ffd83dbSDimitry Andric //===-- X86PartialReduction.cpp -------------------------------------------===// 2*5ffd83dbSDimitry Andric // 3*5ffd83dbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*5ffd83dbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*5ffd83dbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*5ffd83dbSDimitry Andric // 7*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 8*5ffd83dbSDimitry Andric // 9*5ffd83dbSDimitry Andric // This pass looks for add instructions used by a horizontal reduction to see 10*5ffd83dbSDimitry Andric // if we might be able to use pmaddwd or psadbw. Some cases of this require 11*5ffd83dbSDimitry Andric // cross basic block knowledge and can't be done in SelectionDAG. 12*5ffd83dbSDimitry Andric // 13*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 14*5ffd83dbSDimitry Andric 15*5ffd83dbSDimitry Andric #include "X86.h" 16*5ffd83dbSDimitry Andric #include "llvm/Analysis/ValueTracking.h" 17*5ffd83dbSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 18*5ffd83dbSDimitry Andric #include "llvm/IR/Constants.h" 19*5ffd83dbSDimitry Andric #include "llvm/IR/Instructions.h" 20*5ffd83dbSDimitry Andric #include "llvm/IR/IntrinsicsX86.h" 21*5ffd83dbSDimitry Andric #include "llvm/IR/IRBuilder.h" 22*5ffd83dbSDimitry Andric #include "llvm/IR/Operator.h" 23*5ffd83dbSDimitry Andric #include "llvm/Pass.h" 24*5ffd83dbSDimitry Andric #include "X86TargetMachine.h" 25*5ffd83dbSDimitry Andric 26*5ffd83dbSDimitry Andric using namespace llvm; 27*5ffd83dbSDimitry Andric 28*5ffd83dbSDimitry Andric #define DEBUG_TYPE "x86-partial-reduction" 29*5ffd83dbSDimitry Andric 30*5ffd83dbSDimitry Andric namespace { 31*5ffd83dbSDimitry Andric 32*5ffd83dbSDimitry Andric class X86PartialReduction : public FunctionPass { 33*5ffd83dbSDimitry Andric const DataLayout *DL; 34*5ffd83dbSDimitry Andric const X86Subtarget *ST; 35*5ffd83dbSDimitry Andric 36*5ffd83dbSDimitry Andric public: 37*5ffd83dbSDimitry Andric static char ID; // Pass identification, replacement for typeid. 38*5ffd83dbSDimitry Andric 39*5ffd83dbSDimitry Andric X86PartialReduction() : FunctionPass(ID) { } 40*5ffd83dbSDimitry Andric 41*5ffd83dbSDimitry Andric bool runOnFunction(Function &Fn) override; 42*5ffd83dbSDimitry Andric 43*5ffd83dbSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 44*5ffd83dbSDimitry Andric AU.setPreservesCFG(); 45*5ffd83dbSDimitry Andric } 46*5ffd83dbSDimitry Andric 47*5ffd83dbSDimitry Andric StringRef getPassName() const override { 48*5ffd83dbSDimitry Andric return "X86 Partial Reduction"; 49*5ffd83dbSDimitry Andric } 50*5ffd83dbSDimitry Andric 51*5ffd83dbSDimitry Andric private: 52*5ffd83dbSDimitry Andric bool tryMAddReplacement(Instruction *Op); 53*5ffd83dbSDimitry Andric bool trySADReplacement(Instruction *Op); 54*5ffd83dbSDimitry Andric }; 55*5ffd83dbSDimitry Andric } 56*5ffd83dbSDimitry Andric 57*5ffd83dbSDimitry Andric FunctionPass *llvm::createX86PartialReductionPass() { 58*5ffd83dbSDimitry Andric return new X86PartialReduction(); 59*5ffd83dbSDimitry Andric } 60*5ffd83dbSDimitry Andric 61*5ffd83dbSDimitry Andric char X86PartialReduction::ID = 0; 62*5ffd83dbSDimitry Andric 63*5ffd83dbSDimitry Andric INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE, 64*5ffd83dbSDimitry Andric "X86 Partial Reduction", false, false) 65*5ffd83dbSDimitry Andric 66*5ffd83dbSDimitry Andric bool X86PartialReduction::tryMAddReplacement(Instruction *Op) { 67*5ffd83dbSDimitry Andric if (!ST->hasSSE2()) 68*5ffd83dbSDimitry Andric return false; 69*5ffd83dbSDimitry Andric 70*5ffd83dbSDimitry Andric // Need at least 8 elements. 71*5ffd83dbSDimitry Andric if (cast<FixedVectorType>(Op->getType())->getNumElements() < 8) 72*5ffd83dbSDimitry Andric return false; 73*5ffd83dbSDimitry Andric 74*5ffd83dbSDimitry Andric // Element type should be i32. 75*5ffd83dbSDimitry Andric if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32)) 76*5ffd83dbSDimitry Andric return false; 77*5ffd83dbSDimitry Andric 78*5ffd83dbSDimitry Andric auto *Mul = dyn_cast<BinaryOperator>(Op); 79*5ffd83dbSDimitry Andric if (!Mul || Mul->getOpcode() != Instruction::Mul) 80*5ffd83dbSDimitry Andric return false; 81*5ffd83dbSDimitry Andric 82*5ffd83dbSDimitry Andric Value *LHS = Mul->getOperand(0); 83*5ffd83dbSDimitry Andric Value *RHS = Mul->getOperand(1); 84*5ffd83dbSDimitry Andric 85*5ffd83dbSDimitry Andric // LHS and RHS should be only used once or if they are the same then only 86*5ffd83dbSDimitry Andric // used twice. Only check this when SSE4.1 is enabled and we have zext/sext 87*5ffd83dbSDimitry Andric // instructions, otherwise we use punpck to emulate zero extend in stages. The 88*5ffd83dbSDimitry Andric // trunc/ we need to do likely won't introduce new instructions in that case. 89*5ffd83dbSDimitry Andric if (ST->hasSSE41()) { 90*5ffd83dbSDimitry Andric if (LHS == RHS) { 91*5ffd83dbSDimitry Andric if (!isa<Constant>(LHS) && !LHS->hasNUses(2)) 92*5ffd83dbSDimitry Andric return false; 93*5ffd83dbSDimitry Andric } else { 94*5ffd83dbSDimitry Andric if (!isa<Constant>(LHS) && !LHS->hasOneUse()) 95*5ffd83dbSDimitry Andric return false; 96*5ffd83dbSDimitry Andric if (!isa<Constant>(RHS) && !RHS->hasOneUse()) 97*5ffd83dbSDimitry Andric return false; 98*5ffd83dbSDimitry Andric } 99*5ffd83dbSDimitry Andric } 100*5ffd83dbSDimitry Andric 101*5ffd83dbSDimitry Andric auto CanShrinkOp = [&](Value *Op) { 102*5ffd83dbSDimitry Andric auto IsFreeTruncation = [&](Value *Op) { 103*5ffd83dbSDimitry Andric if (auto *Cast = dyn_cast<CastInst>(Op)) { 104*5ffd83dbSDimitry Andric if (Cast->getParent() == Mul->getParent() && 105*5ffd83dbSDimitry Andric (Cast->getOpcode() == Instruction::SExt || 106*5ffd83dbSDimitry Andric Cast->getOpcode() == Instruction::ZExt) && 107*5ffd83dbSDimitry Andric Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16) 108*5ffd83dbSDimitry Andric return true; 109*5ffd83dbSDimitry Andric } 110*5ffd83dbSDimitry Andric 111*5ffd83dbSDimitry Andric return isa<Constant>(Op); 112*5ffd83dbSDimitry Andric }; 113*5ffd83dbSDimitry Andric 114*5ffd83dbSDimitry Andric // If the operation can be freely truncated and has enough sign bits we 115*5ffd83dbSDimitry Andric // can shrink. 116*5ffd83dbSDimitry Andric if (IsFreeTruncation(Op) && 117*5ffd83dbSDimitry Andric ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16) 118*5ffd83dbSDimitry Andric return true; 119*5ffd83dbSDimitry Andric 120*5ffd83dbSDimitry Andric // SelectionDAG has limited support for truncating through an add or sub if 121*5ffd83dbSDimitry Andric // the inputs are freely truncatable. 122*5ffd83dbSDimitry Andric if (auto *BO = dyn_cast<BinaryOperator>(Op)) { 123*5ffd83dbSDimitry Andric if (BO->getParent() == Mul->getParent() && 124*5ffd83dbSDimitry Andric IsFreeTruncation(BO->getOperand(0)) && 125*5ffd83dbSDimitry Andric IsFreeTruncation(BO->getOperand(1)) && 126*5ffd83dbSDimitry Andric ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16) 127*5ffd83dbSDimitry Andric return true; 128*5ffd83dbSDimitry Andric } 129*5ffd83dbSDimitry Andric 130*5ffd83dbSDimitry Andric return false; 131*5ffd83dbSDimitry Andric }; 132*5ffd83dbSDimitry Andric 133*5ffd83dbSDimitry Andric // Both Ops need to be shrinkable. 134*5ffd83dbSDimitry Andric if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS)) 135*5ffd83dbSDimitry Andric return false; 136*5ffd83dbSDimitry Andric 137*5ffd83dbSDimitry Andric IRBuilder<> Builder(Mul); 138*5ffd83dbSDimitry Andric 139*5ffd83dbSDimitry Andric auto *MulTy = cast<FixedVectorType>(Op->getType()); 140*5ffd83dbSDimitry Andric unsigned NumElts = MulTy->getNumElements(); 141*5ffd83dbSDimitry Andric 142*5ffd83dbSDimitry Andric // Extract even elements and odd elements and add them together. This will 143*5ffd83dbSDimitry Andric // be pattern matched by SelectionDAG to pmaddwd. This instruction will be 144*5ffd83dbSDimitry Andric // half the original width. 145*5ffd83dbSDimitry Andric SmallVector<int, 16> EvenMask(NumElts / 2); 146*5ffd83dbSDimitry Andric SmallVector<int, 16> OddMask(NumElts / 2); 147*5ffd83dbSDimitry Andric for (int i = 0, e = NumElts / 2; i != e; ++i) { 148*5ffd83dbSDimitry Andric EvenMask[i] = i * 2; 149*5ffd83dbSDimitry Andric OddMask[i] = i * 2 + 1; 150*5ffd83dbSDimitry Andric } 151*5ffd83dbSDimitry Andric // Creating a new mul so the replaceAllUsesWith below doesn't replace the 152*5ffd83dbSDimitry Andric // uses in the shuffles we're creating. 153*5ffd83dbSDimitry Andric Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1)); 154*5ffd83dbSDimitry Andric Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask); 155*5ffd83dbSDimitry Andric Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask); 156*5ffd83dbSDimitry Andric Value *MAdd = Builder.CreateAdd(EvenElts, OddElts); 157*5ffd83dbSDimitry Andric 158*5ffd83dbSDimitry Andric // Concatenate zeroes to extend back to the original type. 159*5ffd83dbSDimitry Andric SmallVector<int, 32> ConcatMask(NumElts); 160*5ffd83dbSDimitry Andric std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 161*5ffd83dbSDimitry Andric Value *Zero = Constant::getNullValue(MAdd->getType()); 162*5ffd83dbSDimitry Andric Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask); 163*5ffd83dbSDimitry Andric 164*5ffd83dbSDimitry Andric Mul->replaceAllUsesWith(Concat); 165*5ffd83dbSDimitry Andric Mul->eraseFromParent(); 166*5ffd83dbSDimitry Andric 167*5ffd83dbSDimitry Andric return true; 168*5ffd83dbSDimitry Andric } 169*5ffd83dbSDimitry Andric 170*5ffd83dbSDimitry Andric bool X86PartialReduction::trySADReplacement(Instruction *Op) { 171*5ffd83dbSDimitry Andric if (!ST->hasSSE2()) 172*5ffd83dbSDimitry Andric return false; 173*5ffd83dbSDimitry Andric 174*5ffd83dbSDimitry Andric // TODO: There's nothing special about i32, any integer type above i16 should 175*5ffd83dbSDimitry Andric // work just as well. 176*5ffd83dbSDimitry Andric if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32)) 177*5ffd83dbSDimitry Andric return false; 178*5ffd83dbSDimitry Andric 179*5ffd83dbSDimitry Andric // Operand should be a select. 180*5ffd83dbSDimitry Andric auto *SI = dyn_cast<SelectInst>(Op); 181*5ffd83dbSDimitry Andric if (!SI) 182*5ffd83dbSDimitry Andric return false; 183*5ffd83dbSDimitry Andric 184*5ffd83dbSDimitry Andric // Select needs to implement absolute value. 185*5ffd83dbSDimitry Andric Value *LHS, *RHS; 186*5ffd83dbSDimitry Andric auto SPR = matchSelectPattern(SI, LHS, RHS); 187*5ffd83dbSDimitry Andric if (SPR.Flavor != SPF_ABS) 188*5ffd83dbSDimitry Andric return false; 189*5ffd83dbSDimitry Andric 190*5ffd83dbSDimitry Andric // Need a subtract of two values. 191*5ffd83dbSDimitry Andric auto *Sub = dyn_cast<BinaryOperator>(LHS); 192*5ffd83dbSDimitry Andric if (!Sub || Sub->getOpcode() != Instruction::Sub) 193*5ffd83dbSDimitry Andric return false; 194*5ffd83dbSDimitry Andric 195*5ffd83dbSDimitry Andric // Look for zero extend from i8. 196*5ffd83dbSDimitry Andric auto getZeroExtendedVal = [](Value *Op) -> Value * { 197*5ffd83dbSDimitry Andric if (auto *ZExt = dyn_cast<ZExtInst>(Op)) 198*5ffd83dbSDimitry Andric if (cast<VectorType>(ZExt->getOperand(0)->getType()) 199*5ffd83dbSDimitry Andric ->getElementType() 200*5ffd83dbSDimitry Andric ->isIntegerTy(8)) 201*5ffd83dbSDimitry Andric return ZExt->getOperand(0); 202*5ffd83dbSDimitry Andric 203*5ffd83dbSDimitry Andric return nullptr; 204*5ffd83dbSDimitry Andric }; 205*5ffd83dbSDimitry Andric 206*5ffd83dbSDimitry Andric // Both operands of the subtract should be extends from vXi8. 207*5ffd83dbSDimitry Andric Value *Op0 = getZeroExtendedVal(Sub->getOperand(0)); 208*5ffd83dbSDimitry Andric Value *Op1 = getZeroExtendedVal(Sub->getOperand(1)); 209*5ffd83dbSDimitry Andric if (!Op0 || !Op1) 210*5ffd83dbSDimitry Andric return false; 211*5ffd83dbSDimitry Andric 212*5ffd83dbSDimitry Andric IRBuilder<> Builder(SI); 213*5ffd83dbSDimitry Andric 214*5ffd83dbSDimitry Andric auto *OpTy = cast<FixedVectorType>(Op->getType()); 215*5ffd83dbSDimitry Andric unsigned NumElts = OpTy->getNumElements(); 216*5ffd83dbSDimitry Andric 217*5ffd83dbSDimitry Andric unsigned IntrinsicNumElts; 218*5ffd83dbSDimitry Andric Intrinsic::ID IID; 219*5ffd83dbSDimitry Andric if (ST->hasBWI() && NumElts >= 64) { 220*5ffd83dbSDimitry Andric IID = Intrinsic::x86_avx512_psad_bw_512; 221*5ffd83dbSDimitry Andric IntrinsicNumElts = 64; 222*5ffd83dbSDimitry Andric } else if (ST->hasAVX2() && NumElts >= 32) { 223*5ffd83dbSDimitry Andric IID = Intrinsic::x86_avx2_psad_bw; 224*5ffd83dbSDimitry Andric IntrinsicNumElts = 32; 225*5ffd83dbSDimitry Andric } else { 226*5ffd83dbSDimitry Andric IID = Intrinsic::x86_sse2_psad_bw; 227*5ffd83dbSDimitry Andric IntrinsicNumElts = 16; 228*5ffd83dbSDimitry Andric } 229*5ffd83dbSDimitry Andric 230*5ffd83dbSDimitry Andric Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID); 231*5ffd83dbSDimitry Andric 232*5ffd83dbSDimitry Andric if (NumElts < 16) { 233*5ffd83dbSDimitry Andric // Pad input with zeroes. 234*5ffd83dbSDimitry Andric SmallVector<int, 32> ConcatMask(16); 235*5ffd83dbSDimitry Andric for (unsigned i = 0; i != NumElts; ++i) 236*5ffd83dbSDimitry Andric ConcatMask[i] = i; 237*5ffd83dbSDimitry Andric for (unsigned i = NumElts; i != 16; ++i) 238*5ffd83dbSDimitry Andric ConcatMask[i] = (i % NumElts) + NumElts; 239*5ffd83dbSDimitry Andric 240*5ffd83dbSDimitry Andric Value *Zero = Constant::getNullValue(Op0->getType()); 241*5ffd83dbSDimitry Andric Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask); 242*5ffd83dbSDimitry Andric Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask); 243*5ffd83dbSDimitry Andric NumElts = 16; 244*5ffd83dbSDimitry Andric } 245*5ffd83dbSDimitry Andric 246*5ffd83dbSDimitry Andric // Intrinsics produce vXi64 and need to be casted to vXi32. 247*5ffd83dbSDimitry Andric auto *I32Ty = 248*5ffd83dbSDimitry Andric FixedVectorType::get(Builder.getInt32Ty(), IntrinsicNumElts / 4); 249*5ffd83dbSDimitry Andric 250*5ffd83dbSDimitry Andric assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!"); 251*5ffd83dbSDimitry Andric unsigned NumSplits = NumElts / IntrinsicNumElts; 252*5ffd83dbSDimitry Andric 253*5ffd83dbSDimitry Andric // First collect the pieces we need. 254*5ffd83dbSDimitry Andric SmallVector<Value *, 4> Ops(NumSplits); 255*5ffd83dbSDimitry Andric for (unsigned i = 0; i != NumSplits; ++i) { 256*5ffd83dbSDimitry Andric SmallVector<int, 64> ExtractMask(IntrinsicNumElts); 257*5ffd83dbSDimitry Andric std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts); 258*5ffd83dbSDimitry Andric Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask); 259*5ffd83dbSDimitry Andric Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op0, ExtractMask); 260*5ffd83dbSDimitry Andric Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1}); 261*5ffd83dbSDimitry Andric Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty); 262*5ffd83dbSDimitry Andric } 263*5ffd83dbSDimitry Andric 264*5ffd83dbSDimitry Andric assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits"); 265*5ffd83dbSDimitry Andric unsigned Stages = Log2_32(NumSplits); 266*5ffd83dbSDimitry Andric for (unsigned s = Stages; s > 0; --s) { 267*5ffd83dbSDimitry Andric unsigned NumConcatElts = 268*5ffd83dbSDimitry Andric cast<FixedVectorType>(Ops[0]->getType())->getNumElements() * 2; 269*5ffd83dbSDimitry Andric for (unsigned i = 0; i != 1U << (s - 1); ++i) { 270*5ffd83dbSDimitry Andric SmallVector<int, 64> ConcatMask(NumConcatElts); 271*5ffd83dbSDimitry Andric std::iota(ConcatMask.begin(), ConcatMask.end(), 0); 272*5ffd83dbSDimitry Andric Ops[i] = Builder.CreateShuffleVector(Ops[i*2], Ops[i*2+1], ConcatMask); 273*5ffd83dbSDimitry Andric } 274*5ffd83dbSDimitry Andric } 275*5ffd83dbSDimitry Andric 276*5ffd83dbSDimitry Andric // At this point the final value should be in Ops[0]. Now we need to adjust 277*5ffd83dbSDimitry Andric // it to the final original type. 278*5ffd83dbSDimitry Andric NumElts = cast<FixedVectorType>(OpTy)->getNumElements(); 279*5ffd83dbSDimitry Andric if (NumElts == 2) { 280*5ffd83dbSDimitry Andric // Extract down to 2 elements. 281*5ffd83dbSDimitry Andric Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef<int>{0, 1}); 282*5ffd83dbSDimitry Andric } else if (NumElts >= 8) { 283*5ffd83dbSDimitry Andric SmallVector<int, 32> ConcatMask(NumElts); 284*5ffd83dbSDimitry Andric unsigned SubElts = 285*5ffd83dbSDimitry Andric cast<FixedVectorType>(Ops[0]->getType())->getNumElements(); 286*5ffd83dbSDimitry Andric for (unsigned i = 0; i != SubElts; ++i) 287*5ffd83dbSDimitry Andric ConcatMask[i] = i; 288*5ffd83dbSDimitry Andric for (unsigned i = SubElts; i != NumElts; ++i) 289*5ffd83dbSDimitry Andric ConcatMask[i] = (i % SubElts) + SubElts; 290*5ffd83dbSDimitry Andric 291*5ffd83dbSDimitry Andric Value *Zero = Constant::getNullValue(Ops[0]->getType()); 292*5ffd83dbSDimitry Andric Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask); 293*5ffd83dbSDimitry Andric } 294*5ffd83dbSDimitry Andric 295*5ffd83dbSDimitry Andric SI->replaceAllUsesWith(Ops[0]); 296*5ffd83dbSDimitry Andric SI->eraseFromParent(); 297*5ffd83dbSDimitry Andric 298*5ffd83dbSDimitry Andric return true; 299*5ffd83dbSDimitry Andric } 300*5ffd83dbSDimitry Andric 301*5ffd83dbSDimitry Andric // Walk backwards from the ExtractElementInst and determine if it is the end of 302*5ffd83dbSDimitry Andric // a horizontal reduction. Return the input to the reduction if we find one. 303*5ffd83dbSDimitry Andric static Value *matchAddReduction(const ExtractElementInst &EE) { 304*5ffd83dbSDimitry Andric // Make sure we're extracting index 0. 305*5ffd83dbSDimitry Andric auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand()); 306*5ffd83dbSDimitry Andric if (!Index || !Index->isNullValue()) 307*5ffd83dbSDimitry Andric return nullptr; 308*5ffd83dbSDimitry Andric 309*5ffd83dbSDimitry Andric const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand()); 310*5ffd83dbSDimitry Andric if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse()) 311*5ffd83dbSDimitry Andric return nullptr; 312*5ffd83dbSDimitry Andric 313*5ffd83dbSDimitry Andric unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements(); 314*5ffd83dbSDimitry Andric // Ensure the reduction size is a power of 2. 315*5ffd83dbSDimitry Andric if (!isPowerOf2_32(NumElems)) 316*5ffd83dbSDimitry Andric return nullptr; 317*5ffd83dbSDimitry Andric 318*5ffd83dbSDimitry Andric const Value *Op = BO; 319*5ffd83dbSDimitry Andric unsigned Stages = Log2_32(NumElems); 320*5ffd83dbSDimitry Andric for (unsigned i = 0; i != Stages; ++i) { 321*5ffd83dbSDimitry Andric const auto *BO = dyn_cast<BinaryOperator>(Op); 322*5ffd83dbSDimitry Andric if (!BO || BO->getOpcode() != Instruction::Add) 323*5ffd83dbSDimitry Andric return nullptr; 324*5ffd83dbSDimitry Andric 325*5ffd83dbSDimitry Andric // If this isn't the first add, then it should only have 2 users, the 326*5ffd83dbSDimitry Andric // shuffle and another add which we checked in the previous iteration. 327*5ffd83dbSDimitry Andric if (i != 0 && !BO->hasNUses(2)) 328*5ffd83dbSDimitry Andric return nullptr; 329*5ffd83dbSDimitry Andric 330*5ffd83dbSDimitry Andric Value *LHS = BO->getOperand(0); 331*5ffd83dbSDimitry Andric Value *RHS = BO->getOperand(1); 332*5ffd83dbSDimitry Andric 333*5ffd83dbSDimitry Andric auto *Shuffle = dyn_cast<ShuffleVectorInst>(LHS); 334*5ffd83dbSDimitry Andric if (Shuffle) { 335*5ffd83dbSDimitry Andric Op = RHS; 336*5ffd83dbSDimitry Andric } else { 337*5ffd83dbSDimitry Andric Shuffle = dyn_cast<ShuffleVectorInst>(RHS); 338*5ffd83dbSDimitry Andric Op = LHS; 339*5ffd83dbSDimitry Andric } 340*5ffd83dbSDimitry Andric 341*5ffd83dbSDimitry Andric // The first operand of the shuffle should be the same as the other operand 342*5ffd83dbSDimitry Andric // of the bin op. 343*5ffd83dbSDimitry Andric if (!Shuffle || Shuffle->getOperand(0) != Op) 344*5ffd83dbSDimitry Andric return nullptr; 345*5ffd83dbSDimitry Andric 346*5ffd83dbSDimitry Andric // Verify the shuffle has the expected (at this stage of the pyramid) mask. 347*5ffd83dbSDimitry Andric unsigned MaskEnd = 1 << i; 348*5ffd83dbSDimitry Andric for (unsigned Index = 0; Index < MaskEnd; ++Index) 349*5ffd83dbSDimitry Andric if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index)) 350*5ffd83dbSDimitry Andric return nullptr; 351*5ffd83dbSDimitry Andric } 352*5ffd83dbSDimitry Andric 353*5ffd83dbSDimitry Andric return const_cast<Value *>(Op); 354*5ffd83dbSDimitry Andric } 355*5ffd83dbSDimitry Andric 356*5ffd83dbSDimitry Andric // See if this BO is reachable from this Phi by walking forward through single 357*5ffd83dbSDimitry Andric // use BinaryOperators with the same opcode. If we get back then we know we've 358*5ffd83dbSDimitry Andric // found a loop and it is safe to step through this Add to find more leaves. 359*5ffd83dbSDimitry Andric static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) { 360*5ffd83dbSDimitry Andric // The PHI itself should only have one use. 361*5ffd83dbSDimitry Andric if (!Phi->hasOneUse()) 362*5ffd83dbSDimitry Andric return false; 363*5ffd83dbSDimitry Andric 364*5ffd83dbSDimitry Andric Instruction *U = cast<Instruction>(*Phi->user_begin()); 365*5ffd83dbSDimitry Andric if (U == BO) 366*5ffd83dbSDimitry Andric return true; 367*5ffd83dbSDimitry Andric 368*5ffd83dbSDimitry Andric while (U->hasOneUse() && U->getOpcode() == BO->getOpcode()) 369*5ffd83dbSDimitry Andric U = cast<Instruction>(*U->user_begin()); 370*5ffd83dbSDimitry Andric 371*5ffd83dbSDimitry Andric return U == BO; 372*5ffd83dbSDimitry Andric } 373*5ffd83dbSDimitry Andric 374*5ffd83dbSDimitry Andric // Collect all the leaves of the tree of adds that feeds into the horizontal 375*5ffd83dbSDimitry Andric // reduction. Root is the Value that is used by the horizontal reduction. 376*5ffd83dbSDimitry Andric // We look through single use phis, single use adds, or adds that are used by 377*5ffd83dbSDimitry Andric // a phi that forms a loop with the add. 378*5ffd83dbSDimitry Andric static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) { 379*5ffd83dbSDimitry Andric SmallPtrSet<Value *, 8> Visited; 380*5ffd83dbSDimitry Andric SmallVector<Value *, 8> Worklist; 381*5ffd83dbSDimitry Andric Worklist.push_back(Root); 382*5ffd83dbSDimitry Andric 383*5ffd83dbSDimitry Andric while (!Worklist.empty()) { 384*5ffd83dbSDimitry Andric Value *V = Worklist.pop_back_val(); 385*5ffd83dbSDimitry Andric if (!Visited.insert(V).second) 386*5ffd83dbSDimitry Andric continue; 387*5ffd83dbSDimitry Andric 388*5ffd83dbSDimitry Andric if (auto *PN = dyn_cast<PHINode>(V)) { 389*5ffd83dbSDimitry Andric // PHI node should have single use unless it is the root node, then it 390*5ffd83dbSDimitry Andric // has 2 uses. 391*5ffd83dbSDimitry Andric if (!PN->hasNUses(PN == Root ? 2 : 1)) 392*5ffd83dbSDimitry Andric break; 393*5ffd83dbSDimitry Andric 394*5ffd83dbSDimitry Andric // Push incoming values to the worklist. 395*5ffd83dbSDimitry Andric for (Value *InV : PN->incoming_values()) 396*5ffd83dbSDimitry Andric Worklist.push_back(InV); 397*5ffd83dbSDimitry Andric 398*5ffd83dbSDimitry Andric continue; 399*5ffd83dbSDimitry Andric } 400*5ffd83dbSDimitry Andric 401*5ffd83dbSDimitry Andric if (auto *BO = dyn_cast<BinaryOperator>(V)) { 402*5ffd83dbSDimitry Andric if (BO->getOpcode() == Instruction::Add) { 403*5ffd83dbSDimitry Andric // Simple case. Single use, just push its operands to the worklist. 404*5ffd83dbSDimitry Andric if (BO->hasNUses(BO == Root ? 2 : 1)) { 405*5ffd83dbSDimitry Andric for (Value *Op : BO->operands()) 406*5ffd83dbSDimitry Andric Worklist.push_back(Op); 407*5ffd83dbSDimitry Andric continue; 408*5ffd83dbSDimitry Andric } 409*5ffd83dbSDimitry Andric 410*5ffd83dbSDimitry Andric // If there is additional use, make sure it is an unvisited phi that 411*5ffd83dbSDimitry Andric // gets us back to this node. 412*5ffd83dbSDimitry Andric if (BO->hasNUses(BO == Root ? 3 : 2)) { 413*5ffd83dbSDimitry Andric PHINode *PN = nullptr; 414*5ffd83dbSDimitry Andric for (auto *U : Root->users()) 415*5ffd83dbSDimitry Andric if (auto *P = dyn_cast<PHINode>(U)) 416*5ffd83dbSDimitry Andric if (!Visited.count(P)) 417*5ffd83dbSDimitry Andric PN = P; 418*5ffd83dbSDimitry Andric 419*5ffd83dbSDimitry Andric // If we didn't find a 2-input PHI then this isn't a case we can 420*5ffd83dbSDimitry Andric // handle. 421*5ffd83dbSDimitry Andric if (!PN || PN->getNumIncomingValues() != 2) 422*5ffd83dbSDimitry Andric continue; 423*5ffd83dbSDimitry Andric 424*5ffd83dbSDimitry Andric // Walk forward from this phi to see if it reaches back to this add. 425*5ffd83dbSDimitry Andric if (!isReachableFromPHI(PN, BO)) 426*5ffd83dbSDimitry Andric continue; 427*5ffd83dbSDimitry Andric 428*5ffd83dbSDimitry Andric // The phi forms a loop with this Add, push its operands. 429*5ffd83dbSDimitry Andric for (Value *Op : BO->operands()) 430*5ffd83dbSDimitry Andric Worklist.push_back(Op); 431*5ffd83dbSDimitry Andric } 432*5ffd83dbSDimitry Andric } 433*5ffd83dbSDimitry Andric } 434*5ffd83dbSDimitry Andric 435*5ffd83dbSDimitry Andric // Not an add or phi, make it a leaf. 436*5ffd83dbSDimitry Andric if (auto *I = dyn_cast<Instruction>(V)) { 437*5ffd83dbSDimitry Andric if (!V->hasNUses(I == Root ? 2 : 1)) 438*5ffd83dbSDimitry Andric continue; 439*5ffd83dbSDimitry Andric 440*5ffd83dbSDimitry Andric // Add this as a leaf. 441*5ffd83dbSDimitry Andric Leaves.push_back(I); 442*5ffd83dbSDimitry Andric } 443*5ffd83dbSDimitry Andric } 444*5ffd83dbSDimitry Andric } 445*5ffd83dbSDimitry Andric 446*5ffd83dbSDimitry Andric bool X86PartialReduction::runOnFunction(Function &F) { 447*5ffd83dbSDimitry Andric if (skipFunction(F)) 448*5ffd83dbSDimitry Andric return false; 449*5ffd83dbSDimitry Andric 450*5ffd83dbSDimitry Andric auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); 451*5ffd83dbSDimitry Andric if (!TPC) 452*5ffd83dbSDimitry Andric return false; 453*5ffd83dbSDimitry Andric 454*5ffd83dbSDimitry Andric auto &TM = TPC->getTM<X86TargetMachine>(); 455*5ffd83dbSDimitry Andric ST = TM.getSubtargetImpl(F); 456*5ffd83dbSDimitry Andric 457*5ffd83dbSDimitry Andric DL = &F.getParent()->getDataLayout(); 458*5ffd83dbSDimitry Andric 459*5ffd83dbSDimitry Andric bool MadeChange = false; 460*5ffd83dbSDimitry Andric for (auto &BB : F) { 461*5ffd83dbSDimitry Andric for (auto &I : BB) { 462*5ffd83dbSDimitry Andric auto *EE = dyn_cast<ExtractElementInst>(&I); 463*5ffd83dbSDimitry Andric if (!EE) 464*5ffd83dbSDimitry Andric continue; 465*5ffd83dbSDimitry Andric 466*5ffd83dbSDimitry Andric // First find a reduction tree. 467*5ffd83dbSDimitry Andric // FIXME: Do we need to handle other opcodes than Add? 468*5ffd83dbSDimitry Andric Value *Root = matchAddReduction(*EE); 469*5ffd83dbSDimitry Andric if (!Root) 470*5ffd83dbSDimitry Andric continue; 471*5ffd83dbSDimitry Andric 472*5ffd83dbSDimitry Andric SmallVector<Instruction *, 8> Leaves; 473*5ffd83dbSDimitry Andric collectLeaves(Root, Leaves); 474*5ffd83dbSDimitry Andric 475*5ffd83dbSDimitry Andric for (Instruction *I : Leaves) { 476*5ffd83dbSDimitry Andric if (tryMAddReplacement(I)) { 477*5ffd83dbSDimitry Andric MadeChange = true; 478*5ffd83dbSDimitry Andric continue; 479*5ffd83dbSDimitry Andric } 480*5ffd83dbSDimitry Andric 481*5ffd83dbSDimitry Andric // Don't do SAD matching on the root node. SelectionDAG already 482*5ffd83dbSDimitry Andric // has support for that and currently generates better code. 483*5ffd83dbSDimitry Andric if (I != Root && trySADReplacement(I)) 484*5ffd83dbSDimitry Andric MadeChange = true; 485*5ffd83dbSDimitry Andric } 486*5ffd83dbSDimitry Andric } 487*5ffd83dbSDimitry Andric } 488*5ffd83dbSDimitry Andric 489*5ffd83dbSDimitry Andric return MadeChange; 490*5ffd83dbSDimitry Andric } 491