xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/X86/X86PartialReduction.cpp (revision 5ffd83dbcc34f10e07f6d3e968ae6365869615f4)
1*5ffd83dbSDimitry Andric //===-- X86PartialReduction.cpp -------------------------------------------===//
2*5ffd83dbSDimitry Andric //
3*5ffd83dbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*5ffd83dbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*5ffd83dbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*5ffd83dbSDimitry Andric //
7*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===//
8*5ffd83dbSDimitry Andric //
9*5ffd83dbSDimitry Andric // This pass looks for add instructions used by a horizontal reduction to see
10*5ffd83dbSDimitry Andric // if we might be able to use pmaddwd or psadbw. Some cases of this require
11*5ffd83dbSDimitry Andric // cross basic block knowledge and can't be done in SelectionDAG.
12*5ffd83dbSDimitry Andric //
13*5ffd83dbSDimitry Andric //===----------------------------------------------------------------------===//
14*5ffd83dbSDimitry Andric 
15*5ffd83dbSDimitry Andric #include "X86.h"
16*5ffd83dbSDimitry Andric #include "llvm/Analysis/ValueTracking.h"
17*5ffd83dbSDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
18*5ffd83dbSDimitry Andric #include "llvm/IR/Constants.h"
19*5ffd83dbSDimitry Andric #include "llvm/IR/Instructions.h"
20*5ffd83dbSDimitry Andric #include "llvm/IR/IntrinsicsX86.h"
21*5ffd83dbSDimitry Andric #include "llvm/IR/IRBuilder.h"
22*5ffd83dbSDimitry Andric #include "llvm/IR/Operator.h"
23*5ffd83dbSDimitry Andric #include "llvm/Pass.h"
24*5ffd83dbSDimitry Andric #include "X86TargetMachine.h"
25*5ffd83dbSDimitry Andric 
26*5ffd83dbSDimitry Andric using namespace llvm;
27*5ffd83dbSDimitry Andric 
28*5ffd83dbSDimitry Andric #define DEBUG_TYPE "x86-partial-reduction"
29*5ffd83dbSDimitry Andric 
30*5ffd83dbSDimitry Andric namespace {
31*5ffd83dbSDimitry Andric 
32*5ffd83dbSDimitry Andric class X86PartialReduction : public FunctionPass {
33*5ffd83dbSDimitry Andric   const DataLayout *DL;
34*5ffd83dbSDimitry Andric   const X86Subtarget *ST;
35*5ffd83dbSDimitry Andric 
36*5ffd83dbSDimitry Andric public:
37*5ffd83dbSDimitry Andric   static char ID; // Pass identification, replacement for typeid.
38*5ffd83dbSDimitry Andric 
39*5ffd83dbSDimitry Andric   X86PartialReduction() : FunctionPass(ID) { }
40*5ffd83dbSDimitry Andric 
41*5ffd83dbSDimitry Andric   bool runOnFunction(Function &Fn) override;
42*5ffd83dbSDimitry Andric 
43*5ffd83dbSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
44*5ffd83dbSDimitry Andric     AU.setPreservesCFG();
45*5ffd83dbSDimitry Andric   }
46*5ffd83dbSDimitry Andric 
47*5ffd83dbSDimitry Andric   StringRef getPassName() const override {
48*5ffd83dbSDimitry Andric     return "X86 Partial Reduction";
49*5ffd83dbSDimitry Andric   }
50*5ffd83dbSDimitry Andric 
51*5ffd83dbSDimitry Andric private:
52*5ffd83dbSDimitry Andric   bool tryMAddReplacement(Instruction *Op);
53*5ffd83dbSDimitry Andric   bool trySADReplacement(Instruction *Op);
54*5ffd83dbSDimitry Andric };
55*5ffd83dbSDimitry Andric }
56*5ffd83dbSDimitry Andric 
57*5ffd83dbSDimitry Andric FunctionPass *llvm::createX86PartialReductionPass() {
58*5ffd83dbSDimitry Andric   return new X86PartialReduction();
59*5ffd83dbSDimitry Andric }
60*5ffd83dbSDimitry Andric 
61*5ffd83dbSDimitry Andric char X86PartialReduction::ID = 0;
62*5ffd83dbSDimitry Andric 
63*5ffd83dbSDimitry Andric INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
64*5ffd83dbSDimitry Andric                 "X86 Partial Reduction", false, false)
65*5ffd83dbSDimitry Andric 
66*5ffd83dbSDimitry Andric bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
67*5ffd83dbSDimitry Andric   if (!ST->hasSSE2())
68*5ffd83dbSDimitry Andric     return false;
69*5ffd83dbSDimitry Andric 
70*5ffd83dbSDimitry Andric   // Need at least 8 elements.
71*5ffd83dbSDimitry Andric   if (cast<FixedVectorType>(Op->getType())->getNumElements() < 8)
72*5ffd83dbSDimitry Andric     return false;
73*5ffd83dbSDimitry Andric 
74*5ffd83dbSDimitry Andric   // Element type should be i32.
75*5ffd83dbSDimitry Andric   if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
76*5ffd83dbSDimitry Andric     return false;
77*5ffd83dbSDimitry Andric 
78*5ffd83dbSDimitry Andric   auto *Mul = dyn_cast<BinaryOperator>(Op);
79*5ffd83dbSDimitry Andric   if (!Mul || Mul->getOpcode() != Instruction::Mul)
80*5ffd83dbSDimitry Andric     return false;
81*5ffd83dbSDimitry Andric 
82*5ffd83dbSDimitry Andric   Value *LHS = Mul->getOperand(0);
83*5ffd83dbSDimitry Andric   Value *RHS = Mul->getOperand(1);
84*5ffd83dbSDimitry Andric 
85*5ffd83dbSDimitry Andric   // LHS and RHS should be only used once or if they are the same then only
86*5ffd83dbSDimitry Andric   // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
87*5ffd83dbSDimitry Andric   // instructions, otherwise we use punpck to emulate zero extend in stages. The
88*5ffd83dbSDimitry Andric   // trunc/ we need to do likely won't introduce new instructions in that case.
89*5ffd83dbSDimitry Andric   if (ST->hasSSE41()) {
90*5ffd83dbSDimitry Andric     if (LHS == RHS) {
91*5ffd83dbSDimitry Andric       if (!isa<Constant>(LHS) && !LHS->hasNUses(2))
92*5ffd83dbSDimitry Andric         return false;
93*5ffd83dbSDimitry Andric     } else {
94*5ffd83dbSDimitry Andric       if (!isa<Constant>(LHS) && !LHS->hasOneUse())
95*5ffd83dbSDimitry Andric         return false;
96*5ffd83dbSDimitry Andric       if (!isa<Constant>(RHS) && !RHS->hasOneUse())
97*5ffd83dbSDimitry Andric         return false;
98*5ffd83dbSDimitry Andric     }
99*5ffd83dbSDimitry Andric   }
100*5ffd83dbSDimitry Andric 
101*5ffd83dbSDimitry Andric   auto CanShrinkOp = [&](Value *Op) {
102*5ffd83dbSDimitry Andric     auto IsFreeTruncation = [&](Value *Op) {
103*5ffd83dbSDimitry Andric       if (auto *Cast = dyn_cast<CastInst>(Op)) {
104*5ffd83dbSDimitry Andric         if (Cast->getParent() == Mul->getParent() &&
105*5ffd83dbSDimitry Andric             (Cast->getOpcode() == Instruction::SExt ||
106*5ffd83dbSDimitry Andric              Cast->getOpcode() == Instruction::ZExt) &&
107*5ffd83dbSDimitry Andric             Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16)
108*5ffd83dbSDimitry Andric           return true;
109*5ffd83dbSDimitry Andric       }
110*5ffd83dbSDimitry Andric 
111*5ffd83dbSDimitry Andric       return isa<Constant>(Op);
112*5ffd83dbSDimitry Andric     };
113*5ffd83dbSDimitry Andric 
114*5ffd83dbSDimitry Andric     // If the operation can be freely truncated and has enough sign bits we
115*5ffd83dbSDimitry Andric     // can shrink.
116*5ffd83dbSDimitry Andric     if (IsFreeTruncation(Op) &&
117*5ffd83dbSDimitry Andric         ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
118*5ffd83dbSDimitry Andric       return true;
119*5ffd83dbSDimitry Andric 
120*5ffd83dbSDimitry Andric     // SelectionDAG has limited support for truncating through an add or sub if
121*5ffd83dbSDimitry Andric     // the inputs are freely truncatable.
122*5ffd83dbSDimitry Andric     if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
123*5ffd83dbSDimitry Andric       if (BO->getParent() == Mul->getParent() &&
124*5ffd83dbSDimitry Andric           IsFreeTruncation(BO->getOperand(0)) &&
125*5ffd83dbSDimitry Andric           IsFreeTruncation(BO->getOperand(1)) &&
126*5ffd83dbSDimitry Andric           ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
127*5ffd83dbSDimitry Andric         return true;
128*5ffd83dbSDimitry Andric     }
129*5ffd83dbSDimitry Andric 
130*5ffd83dbSDimitry Andric     return false;
131*5ffd83dbSDimitry Andric   };
132*5ffd83dbSDimitry Andric 
133*5ffd83dbSDimitry Andric   // Both Ops need to be shrinkable.
134*5ffd83dbSDimitry Andric   if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS))
135*5ffd83dbSDimitry Andric     return false;
136*5ffd83dbSDimitry Andric 
137*5ffd83dbSDimitry Andric   IRBuilder<> Builder(Mul);
138*5ffd83dbSDimitry Andric 
139*5ffd83dbSDimitry Andric   auto *MulTy = cast<FixedVectorType>(Op->getType());
140*5ffd83dbSDimitry Andric   unsigned NumElts = MulTy->getNumElements();
141*5ffd83dbSDimitry Andric 
142*5ffd83dbSDimitry Andric   // Extract even elements and odd elements and add them together. This will
143*5ffd83dbSDimitry Andric   // be pattern matched by SelectionDAG to pmaddwd. This instruction will be
144*5ffd83dbSDimitry Andric   // half the original width.
145*5ffd83dbSDimitry Andric   SmallVector<int, 16> EvenMask(NumElts / 2);
146*5ffd83dbSDimitry Andric   SmallVector<int, 16> OddMask(NumElts / 2);
147*5ffd83dbSDimitry Andric   for (int i = 0, e = NumElts / 2; i != e; ++i) {
148*5ffd83dbSDimitry Andric     EvenMask[i] = i * 2;
149*5ffd83dbSDimitry Andric     OddMask[i] = i * 2 + 1;
150*5ffd83dbSDimitry Andric   }
151*5ffd83dbSDimitry Andric   // Creating a new mul so the replaceAllUsesWith below doesn't replace the
152*5ffd83dbSDimitry Andric   // uses in the shuffles we're creating.
153*5ffd83dbSDimitry Andric   Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1));
154*5ffd83dbSDimitry Andric   Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask);
155*5ffd83dbSDimitry Andric   Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask);
156*5ffd83dbSDimitry Andric   Value *MAdd = Builder.CreateAdd(EvenElts, OddElts);
157*5ffd83dbSDimitry Andric 
158*5ffd83dbSDimitry Andric   // Concatenate zeroes to extend back to the original type.
159*5ffd83dbSDimitry Andric   SmallVector<int, 32> ConcatMask(NumElts);
160*5ffd83dbSDimitry Andric   std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
161*5ffd83dbSDimitry Andric   Value *Zero = Constant::getNullValue(MAdd->getType());
162*5ffd83dbSDimitry Andric   Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask);
163*5ffd83dbSDimitry Andric 
164*5ffd83dbSDimitry Andric   Mul->replaceAllUsesWith(Concat);
165*5ffd83dbSDimitry Andric   Mul->eraseFromParent();
166*5ffd83dbSDimitry Andric 
167*5ffd83dbSDimitry Andric   return true;
168*5ffd83dbSDimitry Andric }
169*5ffd83dbSDimitry Andric 
170*5ffd83dbSDimitry Andric bool X86PartialReduction::trySADReplacement(Instruction *Op) {
171*5ffd83dbSDimitry Andric   if (!ST->hasSSE2())
172*5ffd83dbSDimitry Andric     return false;
173*5ffd83dbSDimitry Andric 
174*5ffd83dbSDimitry Andric   // TODO: There's nothing special about i32, any integer type above i16 should
175*5ffd83dbSDimitry Andric   // work just as well.
176*5ffd83dbSDimitry Andric   if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
177*5ffd83dbSDimitry Andric     return false;
178*5ffd83dbSDimitry Andric 
179*5ffd83dbSDimitry Andric   // Operand should be a select.
180*5ffd83dbSDimitry Andric   auto *SI = dyn_cast<SelectInst>(Op);
181*5ffd83dbSDimitry Andric   if (!SI)
182*5ffd83dbSDimitry Andric     return false;
183*5ffd83dbSDimitry Andric 
184*5ffd83dbSDimitry Andric   // Select needs to implement absolute value.
185*5ffd83dbSDimitry Andric   Value *LHS, *RHS;
186*5ffd83dbSDimitry Andric   auto SPR = matchSelectPattern(SI, LHS, RHS);
187*5ffd83dbSDimitry Andric   if (SPR.Flavor != SPF_ABS)
188*5ffd83dbSDimitry Andric     return false;
189*5ffd83dbSDimitry Andric 
190*5ffd83dbSDimitry Andric   // Need a subtract of two values.
191*5ffd83dbSDimitry Andric   auto *Sub = dyn_cast<BinaryOperator>(LHS);
192*5ffd83dbSDimitry Andric   if (!Sub || Sub->getOpcode() != Instruction::Sub)
193*5ffd83dbSDimitry Andric     return false;
194*5ffd83dbSDimitry Andric 
195*5ffd83dbSDimitry Andric   // Look for zero extend from i8.
196*5ffd83dbSDimitry Andric   auto getZeroExtendedVal = [](Value *Op) -> Value * {
197*5ffd83dbSDimitry Andric     if (auto *ZExt = dyn_cast<ZExtInst>(Op))
198*5ffd83dbSDimitry Andric       if (cast<VectorType>(ZExt->getOperand(0)->getType())
199*5ffd83dbSDimitry Andric               ->getElementType()
200*5ffd83dbSDimitry Andric               ->isIntegerTy(8))
201*5ffd83dbSDimitry Andric         return ZExt->getOperand(0);
202*5ffd83dbSDimitry Andric 
203*5ffd83dbSDimitry Andric     return nullptr;
204*5ffd83dbSDimitry Andric   };
205*5ffd83dbSDimitry Andric 
206*5ffd83dbSDimitry Andric   // Both operands of the subtract should be extends from vXi8.
207*5ffd83dbSDimitry Andric   Value *Op0 = getZeroExtendedVal(Sub->getOperand(0));
208*5ffd83dbSDimitry Andric   Value *Op1 = getZeroExtendedVal(Sub->getOperand(1));
209*5ffd83dbSDimitry Andric   if (!Op0 || !Op1)
210*5ffd83dbSDimitry Andric     return false;
211*5ffd83dbSDimitry Andric 
212*5ffd83dbSDimitry Andric   IRBuilder<> Builder(SI);
213*5ffd83dbSDimitry Andric 
214*5ffd83dbSDimitry Andric   auto *OpTy = cast<FixedVectorType>(Op->getType());
215*5ffd83dbSDimitry Andric   unsigned NumElts = OpTy->getNumElements();
216*5ffd83dbSDimitry Andric 
217*5ffd83dbSDimitry Andric   unsigned IntrinsicNumElts;
218*5ffd83dbSDimitry Andric   Intrinsic::ID IID;
219*5ffd83dbSDimitry Andric   if (ST->hasBWI() && NumElts >= 64) {
220*5ffd83dbSDimitry Andric     IID = Intrinsic::x86_avx512_psad_bw_512;
221*5ffd83dbSDimitry Andric     IntrinsicNumElts = 64;
222*5ffd83dbSDimitry Andric   } else if (ST->hasAVX2() && NumElts >= 32) {
223*5ffd83dbSDimitry Andric     IID = Intrinsic::x86_avx2_psad_bw;
224*5ffd83dbSDimitry Andric     IntrinsicNumElts = 32;
225*5ffd83dbSDimitry Andric   } else {
226*5ffd83dbSDimitry Andric     IID = Intrinsic::x86_sse2_psad_bw;
227*5ffd83dbSDimitry Andric     IntrinsicNumElts = 16;
228*5ffd83dbSDimitry Andric   }
229*5ffd83dbSDimitry Andric 
230*5ffd83dbSDimitry Andric   Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID);
231*5ffd83dbSDimitry Andric 
232*5ffd83dbSDimitry Andric   if (NumElts < 16) {
233*5ffd83dbSDimitry Andric     // Pad input with zeroes.
234*5ffd83dbSDimitry Andric     SmallVector<int, 32> ConcatMask(16);
235*5ffd83dbSDimitry Andric     for (unsigned i = 0; i != NumElts; ++i)
236*5ffd83dbSDimitry Andric       ConcatMask[i] = i;
237*5ffd83dbSDimitry Andric     for (unsigned i = NumElts; i != 16; ++i)
238*5ffd83dbSDimitry Andric       ConcatMask[i] = (i % NumElts) + NumElts;
239*5ffd83dbSDimitry Andric 
240*5ffd83dbSDimitry Andric     Value *Zero = Constant::getNullValue(Op0->getType());
241*5ffd83dbSDimitry Andric     Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask);
242*5ffd83dbSDimitry Andric     Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask);
243*5ffd83dbSDimitry Andric     NumElts = 16;
244*5ffd83dbSDimitry Andric   }
245*5ffd83dbSDimitry Andric 
246*5ffd83dbSDimitry Andric   // Intrinsics produce vXi64 and need to be casted to vXi32.
247*5ffd83dbSDimitry Andric   auto *I32Ty =
248*5ffd83dbSDimitry Andric       FixedVectorType::get(Builder.getInt32Ty(), IntrinsicNumElts / 4);
249*5ffd83dbSDimitry Andric 
250*5ffd83dbSDimitry Andric   assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!");
251*5ffd83dbSDimitry Andric   unsigned NumSplits = NumElts / IntrinsicNumElts;
252*5ffd83dbSDimitry Andric 
253*5ffd83dbSDimitry Andric   // First collect the pieces we need.
254*5ffd83dbSDimitry Andric   SmallVector<Value *, 4> Ops(NumSplits);
255*5ffd83dbSDimitry Andric   for (unsigned i = 0; i != NumSplits; ++i) {
256*5ffd83dbSDimitry Andric     SmallVector<int, 64> ExtractMask(IntrinsicNumElts);
257*5ffd83dbSDimitry Andric     std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts);
258*5ffd83dbSDimitry Andric     Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask);
259*5ffd83dbSDimitry Andric     Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op0, ExtractMask);
260*5ffd83dbSDimitry Andric     Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1});
261*5ffd83dbSDimitry Andric     Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty);
262*5ffd83dbSDimitry Andric   }
263*5ffd83dbSDimitry Andric 
264*5ffd83dbSDimitry Andric   assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits");
265*5ffd83dbSDimitry Andric   unsigned Stages = Log2_32(NumSplits);
266*5ffd83dbSDimitry Andric   for (unsigned s = Stages; s > 0; --s) {
267*5ffd83dbSDimitry Andric     unsigned NumConcatElts =
268*5ffd83dbSDimitry Andric         cast<FixedVectorType>(Ops[0]->getType())->getNumElements() * 2;
269*5ffd83dbSDimitry Andric     for (unsigned i = 0; i != 1U << (s - 1); ++i) {
270*5ffd83dbSDimitry Andric       SmallVector<int, 64> ConcatMask(NumConcatElts);
271*5ffd83dbSDimitry Andric       std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
272*5ffd83dbSDimitry Andric       Ops[i] = Builder.CreateShuffleVector(Ops[i*2], Ops[i*2+1], ConcatMask);
273*5ffd83dbSDimitry Andric     }
274*5ffd83dbSDimitry Andric   }
275*5ffd83dbSDimitry Andric 
276*5ffd83dbSDimitry Andric   // At this point the final value should be in Ops[0]. Now we need to adjust
277*5ffd83dbSDimitry Andric   // it to the final original type.
278*5ffd83dbSDimitry Andric   NumElts = cast<FixedVectorType>(OpTy)->getNumElements();
279*5ffd83dbSDimitry Andric   if (NumElts == 2) {
280*5ffd83dbSDimitry Andric     // Extract down to 2 elements.
281*5ffd83dbSDimitry Andric     Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef<int>{0, 1});
282*5ffd83dbSDimitry Andric   } else if (NumElts >= 8) {
283*5ffd83dbSDimitry Andric     SmallVector<int, 32> ConcatMask(NumElts);
284*5ffd83dbSDimitry Andric     unsigned SubElts =
285*5ffd83dbSDimitry Andric         cast<FixedVectorType>(Ops[0]->getType())->getNumElements();
286*5ffd83dbSDimitry Andric     for (unsigned i = 0; i != SubElts; ++i)
287*5ffd83dbSDimitry Andric       ConcatMask[i] = i;
288*5ffd83dbSDimitry Andric     for (unsigned i = SubElts; i != NumElts; ++i)
289*5ffd83dbSDimitry Andric       ConcatMask[i] = (i % SubElts) + SubElts;
290*5ffd83dbSDimitry Andric 
291*5ffd83dbSDimitry Andric     Value *Zero = Constant::getNullValue(Ops[0]->getType());
292*5ffd83dbSDimitry Andric     Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask);
293*5ffd83dbSDimitry Andric   }
294*5ffd83dbSDimitry Andric 
295*5ffd83dbSDimitry Andric   SI->replaceAllUsesWith(Ops[0]);
296*5ffd83dbSDimitry Andric   SI->eraseFromParent();
297*5ffd83dbSDimitry Andric 
298*5ffd83dbSDimitry Andric   return true;
299*5ffd83dbSDimitry Andric }
300*5ffd83dbSDimitry Andric 
301*5ffd83dbSDimitry Andric // Walk backwards from the ExtractElementInst and determine if it is the end of
302*5ffd83dbSDimitry Andric // a horizontal reduction. Return the input to the reduction if we find one.
303*5ffd83dbSDimitry Andric static Value *matchAddReduction(const ExtractElementInst &EE) {
304*5ffd83dbSDimitry Andric   // Make sure we're extracting index 0.
305*5ffd83dbSDimitry Andric   auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
306*5ffd83dbSDimitry Andric   if (!Index || !Index->isNullValue())
307*5ffd83dbSDimitry Andric     return nullptr;
308*5ffd83dbSDimitry Andric 
309*5ffd83dbSDimitry Andric   const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
310*5ffd83dbSDimitry Andric   if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
311*5ffd83dbSDimitry Andric     return nullptr;
312*5ffd83dbSDimitry Andric 
313*5ffd83dbSDimitry Andric   unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements();
314*5ffd83dbSDimitry Andric   // Ensure the reduction size is a power of 2.
315*5ffd83dbSDimitry Andric   if (!isPowerOf2_32(NumElems))
316*5ffd83dbSDimitry Andric     return nullptr;
317*5ffd83dbSDimitry Andric 
318*5ffd83dbSDimitry Andric   const Value *Op = BO;
319*5ffd83dbSDimitry Andric   unsigned Stages = Log2_32(NumElems);
320*5ffd83dbSDimitry Andric   for (unsigned i = 0; i != Stages; ++i) {
321*5ffd83dbSDimitry Andric     const auto *BO = dyn_cast<BinaryOperator>(Op);
322*5ffd83dbSDimitry Andric     if (!BO || BO->getOpcode() != Instruction::Add)
323*5ffd83dbSDimitry Andric       return nullptr;
324*5ffd83dbSDimitry Andric 
325*5ffd83dbSDimitry Andric     // If this isn't the first add, then it should only have 2 users, the
326*5ffd83dbSDimitry Andric     // shuffle and another add which we checked in the previous iteration.
327*5ffd83dbSDimitry Andric     if (i != 0 && !BO->hasNUses(2))
328*5ffd83dbSDimitry Andric       return nullptr;
329*5ffd83dbSDimitry Andric 
330*5ffd83dbSDimitry Andric     Value *LHS = BO->getOperand(0);
331*5ffd83dbSDimitry Andric     Value *RHS = BO->getOperand(1);
332*5ffd83dbSDimitry Andric 
333*5ffd83dbSDimitry Andric     auto *Shuffle = dyn_cast<ShuffleVectorInst>(LHS);
334*5ffd83dbSDimitry Andric     if (Shuffle) {
335*5ffd83dbSDimitry Andric       Op = RHS;
336*5ffd83dbSDimitry Andric     } else {
337*5ffd83dbSDimitry Andric       Shuffle = dyn_cast<ShuffleVectorInst>(RHS);
338*5ffd83dbSDimitry Andric       Op = LHS;
339*5ffd83dbSDimitry Andric     }
340*5ffd83dbSDimitry Andric 
341*5ffd83dbSDimitry Andric     // The first operand of the shuffle should be the same as the other operand
342*5ffd83dbSDimitry Andric     // of the bin op.
343*5ffd83dbSDimitry Andric     if (!Shuffle || Shuffle->getOperand(0) != Op)
344*5ffd83dbSDimitry Andric       return nullptr;
345*5ffd83dbSDimitry Andric 
346*5ffd83dbSDimitry Andric     // Verify the shuffle has the expected (at this stage of the pyramid) mask.
347*5ffd83dbSDimitry Andric     unsigned MaskEnd = 1 << i;
348*5ffd83dbSDimitry Andric     for (unsigned Index = 0; Index < MaskEnd; ++Index)
349*5ffd83dbSDimitry Andric       if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index))
350*5ffd83dbSDimitry Andric         return nullptr;
351*5ffd83dbSDimitry Andric   }
352*5ffd83dbSDimitry Andric 
353*5ffd83dbSDimitry Andric   return const_cast<Value *>(Op);
354*5ffd83dbSDimitry Andric }
355*5ffd83dbSDimitry Andric 
356*5ffd83dbSDimitry Andric // See if this BO is reachable from this Phi by walking forward through single
357*5ffd83dbSDimitry Andric // use BinaryOperators with the same opcode. If we get back then we know we've
358*5ffd83dbSDimitry Andric // found a loop and it is safe to step through this Add to find more leaves.
359*5ffd83dbSDimitry Andric static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) {
360*5ffd83dbSDimitry Andric   // The PHI itself should only have one use.
361*5ffd83dbSDimitry Andric   if (!Phi->hasOneUse())
362*5ffd83dbSDimitry Andric     return false;
363*5ffd83dbSDimitry Andric 
364*5ffd83dbSDimitry Andric   Instruction *U = cast<Instruction>(*Phi->user_begin());
365*5ffd83dbSDimitry Andric   if (U == BO)
366*5ffd83dbSDimitry Andric     return true;
367*5ffd83dbSDimitry Andric 
368*5ffd83dbSDimitry Andric   while (U->hasOneUse() && U->getOpcode() == BO->getOpcode())
369*5ffd83dbSDimitry Andric     U = cast<Instruction>(*U->user_begin());
370*5ffd83dbSDimitry Andric 
371*5ffd83dbSDimitry Andric   return U == BO;
372*5ffd83dbSDimitry Andric }
373*5ffd83dbSDimitry Andric 
374*5ffd83dbSDimitry Andric // Collect all the leaves of the tree of adds that feeds into the horizontal
375*5ffd83dbSDimitry Andric // reduction. Root is the Value that is used by the horizontal reduction.
376*5ffd83dbSDimitry Andric // We look through single use phis, single use adds, or adds that are used by
377*5ffd83dbSDimitry Andric // a phi that forms a loop with the add.
378*5ffd83dbSDimitry Andric static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) {
379*5ffd83dbSDimitry Andric   SmallPtrSet<Value *, 8> Visited;
380*5ffd83dbSDimitry Andric   SmallVector<Value *, 8> Worklist;
381*5ffd83dbSDimitry Andric   Worklist.push_back(Root);
382*5ffd83dbSDimitry Andric 
383*5ffd83dbSDimitry Andric   while (!Worklist.empty()) {
384*5ffd83dbSDimitry Andric     Value *V = Worklist.pop_back_val();
385*5ffd83dbSDimitry Andric      if (!Visited.insert(V).second)
386*5ffd83dbSDimitry Andric        continue;
387*5ffd83dbSDimitry Andric 
388*5ffd83dbSDimitry Andric     if (auto *PN = dyn_cast<PHINode>(V)) {
389*5ffd83dbSDimitry Andric       // PHI node should have single use unless it is the root node, then it
390*5ffd83dbSDimitry Andric       // has 2 uses.
391*5ffd83dbSDimitry Andric       if (!PN->hasNUses(PN == Root ? 2 : 1))
392*5ffd83dbSDimitry Andric         break;
393*5ffd83dbSDimitry Andric 
394*5ffd83dbSDimitry Andric       // Push incoming values to the worklist.
395*5ffd83dbSDimitry Andric       for (Value *InV : PN->incoming_values())
396*5ffd83dbSDimitry Andric         Worklist.push_back(InV);
397*5ffd83dbSDimitry Andric 
398*5ffd83dbSDimitry Andric       continue;
399*5ffd83dbSDimitry Andric     }
400*5ffd83dbSDimitry Andric 
401*5ffd83dbSDimitry Andric     if (auto *BO = dyn_cast<BinaryOperator>(V)) {
402*5ffd83dbSDimitry Andric       if (BO->getOpcode() == Instruction::Add) {
403*5ffd83dbSDimitry Andric         // Simple case. Single use, just push its operands to the worklist.
404*5ffd83dbSDimitry Andric         if (BO->hasNUses(BO == Root ? 2 : 1)) {
405*5ffd83dbSDimitry Andric           for (Value *Op : BO->operands())
406*5ffd83dbSDimitry Andric             Worklist.push_back(Op);
407*5ffd83dbSDimitry Andric           continue;
408*5ffd83dbSDimitry Andric         }
409*5ffd83dbSDimitry Andric 
410*5ffd83dbSDimitry Andric         // If there is additional use, make sure it is an unvisited phi that
411*5ffd83dbSDimitry Andric         // gets us back to this node.
412*5ffd83dbSDimitry Andric         if (BO->hasNUses(BO == Root ? 3 : 2)) {
413*5ffd83dbSDimitry Andric           PHINode *PN = nullptr;
414*5ffd83dbSDimitry Andric           for (auto *U : Root->users())
415*5ffd83dbSDimitry Andric             if (auto *P = dyn_cast<PHINode>(U))
416*5ffd83dbSDimitry Andric               if (!Visited.count(P))
417*5ffd83dbSDimitry Andric                 PN = P;
418*5ffd83dbSDimitry Andric 
419*5ffd83dbSDimitry Andric           // If we didn't find a 2-input PHI then this isn't a case we can
420*5ffd83dbSDimitry Andric           // handle.
421*5ffd83dbSDimitry Andric           if (!PN || PN->getNumIncomingValues() != 2)
422*5ffd83dbSDimitry Andric             continue;
423*5ffd83dbSDimitry Andric 
424*5ffd83dbSDimitry Andric           // Walk forward from this phi to see if it reaches back to this add.
425*5ffd83dbSDimitry Andric           if (!isReachableFromPHI(PN, BO))
426*5ffd83dbSDimitry Andric             continue;
427*5ffd83dbSDimitry Andric 
428*5ffd83dbSDimitry Andric           // The phi forms a loop with this Add, push its operands.
429*5ffd83dbSDimitry Andric           for (Value *Op : BO->operands())
430*5ffd83dbSDimitry Andric             Worklist.push_back(Op);
431*5ffd83dbSDimitry Andric         }
432*5ffd83dbSDimitry Andric       }
433*5ffd83dbSDimitry Andric     }
434*5ffd83dbSDimitry Andric 
435*5ffd83dbSDimitry Andric     // Not an add or phi, make it a leaf.
436*5ffd83dbSDimitry Andric     if (auto *I = dyn_cast<Instruction>(V)) {
437*5ffd83dbSDimitry Andric       if (!V->hasNUses(I == Root ? 2 : 1))
438*5ffd83dbSDimitry Andric         continue;
439*5ffd83dbSDimitry Andric 
440*5ffd83dbSDimitry Andric       // Add this as a leaf.
441*5ffd83dbSDimitry Andric       Leaves.push_back(I);
442*5ffd83dbSDimitry Andric     }
443*5ffd83dbSDimitry Andric   }
444*5ffd83dbSDimitry Andric }
445*5ffd83dbSDimitry Andric 
446*5ffd83dbSDimitry Andric bool X86PartialReduction::runOnFunction(Function &F) {
447*5ffd83dbSDimitry Andric   if (skipFunction(F))
448*5ffd83dbSDimitry Andric     return false;
449*5ffd83dbSDimitry Andric 
450*5ffd83dbSDimitry Andric   auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
451*5ffd83dbSDimitry Andric   if (!TPC)
452*5ffd83dbSDimitry Andric     return false;
453*5ffd83dbSDimitry Andric 
454*5ffd83dbSDimitry Andric   auto &TM = TPC->getTM<X86TargetMachine>();
455*5ffd83dbSDimitry Andric   ST = TM.getSubtargetImpl(F);
456*5ffd83dbSDimitry Andric 
457*5ffd83dbSDimitry Andric   DL = &F.getParent()->getDataLayout();
458*5ffd83dbSDimitry Andric 
459*5ffd83dbSDimitry Andric   bool MadeChange = false;
460*5ffd83dbSDimitry Andric   for (auto &BB : F) {
461*5ffd83dbSDimitry Andric     for (auto &I : BB) {
462*5ffd83dbSDimitry Andric       auto *EE = dyn_cast<ExtractElementInst>(&I);
463*5ffd83dbSDimitry Andric       if (!EE)
464*5ffd83dbSDimitry Andric         continue;
465*5ffd83dbSDimitry Andric 
466*5ffd83dbSDimitry Andric       // First find a reduction tree.
467*5ffd83dbSDimitry Andric       // FIXME: Do we need to handle other opcodes than Add?
468*5ffd83dbSDimitry Andric       Value *Root = matchAddReduction(*EE);
469*5ffd83dbSDimitry Andric       if (!Root)
470*5ffd83dbSDimitry Andric         continue;
471*5ffd83dbSDimitry Andric 
472*5ffd83dbSDimitry Andric       SmallVector<Instruction *, 8> Leaves;
473*5ffd83dbSDimitry Andric       collectLeaves(Root, Leaves);
474*5ffd83dbSDimitry Andric 
475*5ffd83dbSDimitry Andric       for (Instruction *I : Leaves) {
476*5ffd83dbSDimitry Andric         if (tryMAddReplacement(I)) {
477*5ffd83dbSDimitry Andric           MadeChange = true;
478*5ffd83dbSDimitry Andric           continue;
479*5ffd83dbSDimitry Andric         }
480*5ffd83dbSDimitry Andric 
481*5ffd83dbSDimitry Andric         // Don't do SAD matching on the root node. SelectionDAG already
482*5ffd83dbSDimitry Andric         // has support for that and currently generates better code.
483*5ffd83dbSDimitry Andric         if (I != Root && trySADReplacement(I))
484*5ffd83dbSDimitry Andric           MadeChange = true;
485*5ffd83dbSDimitry Andric       }
486*5ffd83dbSDimitry Andric     }
487*5ffd83dbSDimitry Andric   }
488*5ffd83dbSDimitry Andric 
489*5ffd83dbSDimitry Andric   return MadeChange;
490*5ffd83dbSDimitry Andric }
491