xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/ARM/MVETailPredication.cpp (revision 8bcb0991864975618c09697b1aca10683346d9f0)
1*8bcb0991SDimitry Andric //===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
2*8bcb0991SDimitry Andric //
3*8bcb0991SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*8bcb0991SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*8bcb0991SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*8bcb0991SDimitry Andric //
7*8bcb0991SDimitry Andric //===----------------------------------------------------------------------===//
8*8bcb0991SDimitry Andric //
9*8bcb0991SDimitry Andric /// \file
10*8bcb0991SDimitry Andric /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
11*8bcb0991SDimitry Andric /// branches to help accelerate DSP applications. These two extensions can be
12*8bcb0991SDimitry Andric /// combined to provide implicit vector predication within a low-overhead loop.
13*8bcb0991SDimitry Andric /// The HardwareLoops pass inserts intrinsics identifying loops that the
14*8bcb0991SDimitry Andric /// backend will attempt to convert into a low-overhead loop. The vectorizer is
15*8bcb0991SDimitry Andric /// responsible for generating a vectorized loop in which the lanes are
16*8bcb0991SDimitry Andric /// predicated upon the iteration counter. This pass looks at these predicated
/// vector loops, that are targets for low-overhead loops, and prepares them for
/// code generation. Once the vectorizer has produced a masked loop, there are a
/// couple of final forms:
/// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
///   blocks of instructions operating on different vector types.
23*8bcb0991SDimitry Andric 
24*8bcb0991SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
25*8bcb0991SDimitry Andric #include "llvm/Analysis/LoopPass.h"
26*8bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
27*8bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpander.h"
28*8bcb0991SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h"
29*8bcb0991SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
30*8bcb0991SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
31*8bcb0991SDimitry Andric #include "llvm/IR/Instructions.h"
32*8bcb0991SDimitry Andric #include "llvm/IR/IRBuilder.h"
33*8bcb0991SDimitry Andric #include "llvm/IR/PatternMatch.h"
34*8bcb0991SDimitry Andric #include "llvm/Support/Debug.h"
35*8bcb0991SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
36*8bcb0991SDimitry Andric #include "ARM.h"
37*8bcb0991SDimitry Andric #include "ARMSubtarget.h"
38*8bcb0991SDimitry Andric 
39*8bcb0991SDimitry Andric using namespace llvm;
40*8bcb0991SDimitry Andric 
41*8bcb0991SDimitry Andric #define DEBUG_TYPE "mve-tail-predication"
42*8bcb0991SDimitry Andric #define DESC "Transform predicated vector loops to use MVE tail predication"
43*8bcb0991SDimitry Andric 
44*8bcb0991SDimitry Andric static cl::opt<bool>
45*8bcb0991SDimitry Andric DisableTailPredication("disable-mve-tail-predication", cl::Hidden,
46*8bcb0991SDimitry Andric                        cl::init(true),
47*8bcb0991SDimitry Andric                        cl::desc("Disable MVE Tail Predication"));
48*8bcb0991SDimitry Andric namespace {
49*8bcb0991SDimitry Andric 
50*8bcb0991SDimitry Andric class MVETailPredication : public LoopPass {
51*8bcb0991SDimitry Andric   SmallVector<IntrinsicInst*, 4> MaskedInsts;
52*8bcb0991SDimitry Andric   Loop *L = nullptr;
53*8bcb0991SDimitry Andric   ScalarEvolution *SE = nullptr;
54*8bcb0991SDimitry Andric   TargetTransformInfo *TTI = nullptr;
55*8bcb0991SDimitry Andric 
56*8bcb0991SDimitry Andric public:
57*8bcb0991SDimitry Andric   static char ID;
58*8bcb0991SDimitry Andric 
59*8bcb0991SDimitry Andric   MVETailPredication() : LoopPass(ID) { }
60*8bcb0991SDimitry Andric 
61*8bcb0991SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
62*8bcb0991SDimitry Andric     AU.addRequired<ScalarEvolutionWrapperPass>();
63*8bcb0991SDimitry Andric     AU.addRequired<LoopInfoWrapperPass>();
64*8bcb0991SDimitry Andric     AU.addRequired<TargetPassConfig>();
65*8bcb0991SDimitry Andric     AU.addRequired<TargetTransformInfoWrapperPass>();
66*8bcb0991SDimitry Andric     AU.addPreserved<LoopInfoWrapperPass>();
67*8bcb0991SDimitry Andric     AU.setPreservesCFG();
68*8bcb0991SDimitry Andric   }
69*8bcb0991SDimitry Andric 
70*8bcb0991SDimitry Andric   bool runOnLoop(Loop *L, LPPassManager&) override;
71*8bcb0991SDimitry Andric 
72*8bcb0991SDimitry Andric private:
73*8bcb0991SDimitry Andric 
74*8bcb0991SDimitry Andric   /// Perform the relevant checks on the loop and convert if possible.
75*8bcb0991SDimitry Andric   bool TryConvert(Value *TripCount);
76*8bcb0991SDimitry Andric 
77*8bcb0991SDimitry Andric   /// Return whether this is a vectorized loop, that contains masked
78*8bcb0991SDimitry Andric   /// load/stores.
79*8bcb0991SDimitry Andric   bool IsPredicatedVectorLoop();
80*8bcb0991SDimitry Andric 
81*8bcb0991SDimitry Andric   /// Compute a value for the total number of elements that the predicated
82*8bcb0991SDimitry Andric   /// loop will process.
83*8bcb0991SDimitry Andric   Value *ComputeElements(Value *TripCount, VectorType *VecTy);
84*8bcb0991SDimitry Andric 
85*8bcb0991SDimitry Andric   /// Is the icmp that generates an i1 vector, based upon a loop counter
86*8bcb0991SDimitry Andric   /// and a limit that is defined outside the loop.
87*8bcb0991SDimitry Andric   bool isTailPredicate(Instruction *Predicate, Value *NumElements);
88*8bcb0991SDimitry Andric };
89*8bcb0991SDimitry Andric 
90*8bcb0991SDimitry Andric } // end namespace
91*8bcb0991SDimitry Andric 
92*8bcb0991SDimitry Andric static bool IsDecrement(Instruction &I) {
93*8bcb0991SDimitry Andric   auto *Call = dyn_cast<IntrinsicInst>(&I);
94*8bcb0991SDimitry Andric   if (!Call)
95*8bcb0991SDimitry Andric     return false;
96*8bcb0991SDimitry Andric 
97*8bcb0991SDimitry Andric   Intrinsic::ID ID = Call->getIntrinsicID();
98*8bcb0991SDimitry Andric   return ID == Intrinsic::loop_decrement_reg;
99*8bcb0991SDimitry Andric }
100*8bcb0991SDimitry Andric 
101*8bcb0991SDimitry Andric static bool IsMasked(Instruction *I) {
102*8bcb0991SDimitry Andric   auto *Call = dyn_cast<IntrinsicInst>(I);
103*8bcb0991SDimitry Andric   if (!Call)
104*8bcb0991SDimitry Andric     return false;
105*8bcb0991SDimitry Andric 
106*8bcb0991SDimitry Andric   Intrinsic::ID ID = Call->getIntrinsicID();
107*8bcb0991SDimitry Andric   // TODO: Support gather/scatter expand/compress operations.
108*8bcb0991SDimitry Andric   return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load;
109*8bcb0991SDimitry Andric }
110*8bcb0991SDimitry Andric 
111*8bcb0991SDimitry Andric bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
112*8bcb0991SDimitry Andric   if (skipLoop(L) || DisableTailPredication)
113*8bcb0991SDimitry Andric     return false;
114*8bcb0991SDimitry Andric 
115*8bcb0991SDimitry Andric   Function &F = *L->getHeader()->getParent();
116*8bcb0991SDimitry Andric   auto &TPC = getAnalysis<TargetPassConfig>();
117*8bcb0991SDimitry Andric   auto &TM = TPC.getTM<TargetMachine>();
118*8bcb0991SDimitry Andric   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
119*8bcb0991SDimitry Andric   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
120*8bcb0991SDimitry Andric   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
121*8bcb0991SDimitry Andric   this->L = L;
122*8bcb0991SDimitry Andric 
123*8bcb0991SDimitry Andric   // The MVE and LOB extensions are combined to enable tail-predication, but
124*8bcb0991SDimitry Andric   // there's nothing preventing us from generating VCTP instructions for v8.1m.
125*8bcb0991SDimitry Andric   if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
126*8bcb0991SDimitry Andric     LLVM_DEBUG(dbgs() << "TP: Not a v8.1m.main+mve target.\n");
127*8bcb0991SDimitry Andric     return false;
128*8bcb0991SDimitry Andric   }
129*8bcb0991SDimitry Andric 
130*8bcb0991SDimitry Andric   BasicBlock *Preheader = L->getLoopPreheader();
131*8bcb0991SDimitry Andric   if (!Preheader)
132*8bcb0991SDimitry Andric     return false;
133*8bcb0991SDimitry Andric 
134*8bcb0991SDimitry Andric   auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
135*8bcb0991SDimitry Andric     for (auto &I : *BB) {
136*8bcb0991SDimitry Andric       auto *Call = dyn_cast<IntrinsicInst>(&I);
137*8bcb0991SDimitry Andric       if (!Call)
138*8bcb0991SDimitry Andric         continue;
139*8bcb0991SDimitry Andric 
140*8bcb0991SDimitry Andric       Intrinsic::ID ID = Call->getIntrinsicID();
141*8bcb0991SDimitry Andric       if (ID == Intrinsic::set_loop_iterations ||
142*8bcb0991SDimitry Andric           ID == Intrinsic::test_set_loop_iterations)
143*8bcb0991SDimitry Andric         return cast<IntrinsicInst>(&I);
144*8bcb0991SDimitry Andric     }
145*8bcb0991SDimitry Andric     return nullptr;
146*8bcb0991SDimitry Andric   };
147*8bcb0991SDimitry Andric 
148*8bcb0991SDimitry Andric   // Look for the hardware loop intrinsic that sets the iteration count.
149*8bcb0991SDimitry Andric   IntrinsicInst *Setup = FindLoopIterations(Preheader);
150*8bcb0991SDimitry Andric 
151*8bcb0991SDimitry Andric   // The test.set iteration could live in the pre- preheader.
152*8bcb0991SDimitry Andric   if (!Setup) {
153*8bcb0991SDimitry Andric     if (!Preheader->getSinglePredecessor())
154*8bcb0991SDimitry Andric       return false;
155*8bcb0991SDimitry Andric     Setup = FindLoopIterations(Preheader->getSinglePredecessor());
156*8bcb0991SDimitry Andric     if (!Setup)
157*8bcb0991SDimitry Andric       return false;
158*8bcb0991SDimitry Andric   }
159*8bcb0991SDimitry Andric 
160*8bcb0991SDimitry Andric   // Search for the hardware loop intrinic that decrements the loop counter.
161*8bcb0991SDimitry Andric   IntrinsicInst *Decrement = nullptr;
162*8bcb0991SDimitry Andric   for (auto *BB : L->getBlocks()) {
163*8bcb0991SDimitry Andric     for (auto &I : *BB) {
164*8bcb0991SDimitry Andric       if (IsDecrement(I)) {
165*8bcb0991SDimitry Andric         Decrement = cast<IntrinsicInst>(&I);
166*8bcb0991SDimitry Andric         break;
167*8bcb0991SDimitry Andric       }
168*8bcb0991SDimitry Andric     }
169*8bcb0991SDimitry Andric   }
170*8bcb0991SDimitry Andric 
171*8bcb0991SDimitry Andric   if (!Decrement)
172*8bcb0991SDimitry Andric     return false;
173*8bcb0991SDimitry Andric 
174*8bcb0991SDimitry Andric   LLVM_DEBUG(dbgs() << "TP: Running on Loop: " << *L
175*8bcb0991SDimitry Andric              << *Setup << "\n"
176*8bcb0991SDimitry Andric              << *Decrement << "\n");
177*8bcb0991SDimitry Andric   bool Changed = TryConvert(Setup->getArgOperand(0));
178*8bcb0991SDimitry Andric   return Changed;
179*8bcb0991SDimitry Andric }
180*8bcb0991SDimitry Andric 
181*8bcb0991SDimitry Andric bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
182*8bcb0991SDimitry Andric   // Look for the following:
183*8bcb0991SDimitry Andric 
184*8bcb0991SDimitry Andric   // %trip.count.minus.1 = add i32 %N, -1
185*8bcb0991SDimitry Andric   // %broadcast.splatinsert10 = insertelement <4 x i32> undef,
186*8bcb0991SDimitry Andric   //                                          i32 %trip.count.minus.1, i32 0
187*8bcb0991SDimitry Andric   // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
188*8bcb0991SDimitry Andric   //                                    <4 x i32> undef,
189*8bcb0991SDimitry Andric   //                                    <4 x i32> zeroinitializer
190*8bcb0991SDimitry Andric   // ...
191*8bcb0991SDimitry Andric   // ...
192*8bcb0991SDimitry Andric   // %index = phi i32
193*8bcb0991SDimitry Andric   // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
194*8bcb0991SDimitry Andric   // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
195*8bcb0991SDimitry Andric   //                                  <4 x i32> undef,
196*8bcb0991SDimitry Andric   //                                  <4 x i32> zeroinitializer
197*8bcb0991SDimitry Andric   // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
198*8bcb0991SDimitry Andric   // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11
199*8bcb0991SDimitry Andric 
200*8bcb0991SDimitry Andric   // And return whether V == %pred.
201*8bcb0991SDimitry Andric 
202*8bcb0991SDimitry Andric   using namespace PatternMatch;
203*8bcb0991SDimitry Andric 
204*8bcb0991SDimitry Andric   CmpInst::Predicate Pred;
205*8bcb0991SDimitry Andric   Instruction *Shuffle = nullptr;
206*8bcb0991SDimitry Andric   Instruction *Induction = nullptr;
207*8bcb0991SDimitry Andric 
208*8bcb0991SDimitry Andric   // The vector icmp
209*8bcb0991SDimitry Andric   if (!match(I, m_ICmp(Pred, m_Instruction(Induction),
210*8bcb0991SDimitry Andric                        m_Instruction(Shuffle))) ||
211*8bcb0991SDimitry Andric       Pred != ICmpInst::ICMP_ULE || !L->isLoopInvariant(Shuffle))
212*8bcb0991SDimitry Andric     return false;
213*8bcb0991SDimitry Andric 
214*8bcb0991SDimitry Andric   // First find the stuff outside the loop which is setting up the limit
215*8bcb0991SDimitry Andric   // vector....
216*8bcb0991SDimitry Andric   // The invariant shuffle that broadcast the limit into a vector.
217*8bcb0991SDimitry Andric   Instruction *Insert = nullptr;
218*8bcb0991SDimitry Andric   if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
219*8bcb0991SDimitry Andric                                       m_Zero())))
220*8bcb0991SDimitry Andric     return false;
221*8bcb0991SDimitry Andric 
222*8bcb0991SDimitry Andric   // Insert the limit into a vector.
223*8bcb0991SDimitry Andric   Instruction *BECount = nullptr;
224*8bcb0991SDimitry Andric   if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount),
225*8bcb0991SDimitry Andric                                      m_Zero())))
226*8bcb0991SDimitry Andric     return false;
227*8bcb0991SDimitry Andric 
228*8bcb0991SDimitry Andric   // The limit calculation, backedge count.
229*8bcb0991SDimitry Andric   Value *TripCount = nullptr;
230*8bcb0991SDimitry Andric   if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes())))
231*8bcb0991SDimitry Andric     return false;
232*8bcb0991SDimitry Andric 
233*8bcb0991SDimitry Andric   if (TripCount != NumElements)
234*8bcb0991SDimitry Andric     return false;
235*8bcb0991SDimitry Andric 
236*8bcb0991SDimitry Andric   // Now back to searching inside the loop body...
237*8bcb0991SDimitry Andric   // Find the add with takes the index iv and adds a constant vector to it.
238*8bcb0991SDimitry Andric   Instruction *BroadcastSplat = nullptr;
239*8bcb0991SDimitry Andric   Constant *Const = nullptr;
240*8bcb0991SDimitry Andric   if (!match(Induction, m_Add(m_Instruction(BroadcastSplat),
241*8bcb0991SDimitry Andric                               m_Constant(Const))))
242*8bcb0991SDimitry Andric    return false;
243*8bcb0991SDimitry Andric 
244*8bcb0991SDimitry Andric   // Check that we're adding <0, 1, 2, 3...
245*8bcb0991SDimitry Andric   if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) {
246*8bcb0991SDimitry Andric     for (unsigned i = 0; i < CDS->getNumElements(); ++i) {
247*8bcb0991SDimitry Andric       if (CDS->getElementAsInteger(i) != i)
248*8bcb0991SDimitry Andric         return false;
249*8bcb0991SDimitry Andric     }
250*8bcb0991SDimitry Andric   } else
251*8bcb0991SDimitry Andric     return false;
252*8bcb0991SDimitry Andric 
253*8bcb0991SDimitry Andric   // The shuffle which broadcasts the index iv into a vector.
254*8bcb0991SDimitry Andric   if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
255*8bcb0991SDimitry Andric                                              m_Zero())))
256*8bcb0991SDimitry Andric     return false;
257*8bcb0991SDimitry Andric 
258*8bcb0991SDimitry Andric   // The insert element which initialises a vector with the index iv.
259*8bcb0991SDimitry Andric   Instruction *IV = nullptr;
260*8bcb0991SDimitry Andric   if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero())))
261*8bcb0991SDimitry Andric     return false;
262*8bcb0991SDimitry Andric 
263*8bcb0991SDimitry Andric   // The index iv.
264*8bcb0991SDimitry Andric   auto *Phi = dyn_cast<PHINode>(IV);
265*8bcb0991SDimitry Andric   if (!Phi)
266*8bcb0991SDimitry Andric     return false;
267*8bcb0991SDimitry Andric 
268*8bcb0991SDimitry Andric   // TODO: Don't think we need to check the entry value.
269*8bcb0991SDimitry Andric   Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader());
270*8bcb0991SDimitry Andric   if (!match(OnEntry, m_Zero()))
271*8bcb0991SDimitry Andric     return false;
272*8bcb0991SDimitry Andric 
273*8bcb0991SDimitry Andric   Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch());
274*8bcb0991SDimitry Andric   unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements();
275*8bcb0991SDimitry Andric 
276*8bcb0991SDimitry Andric   Instruction *LHS = nullptr;
277*8bcb0991SDimitry Andric   if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes))))
278*8bcb0991SDimitry Andric     return false;
279*8bcb0991SDimitry Andric 
280*8bcb0991SDimitry Andric   return LHS == Phi;
281*8bcb0991SDimitry Andric }
282*8bcb0991SDimitry Andric 
283*8bcb0991SDimitry Andric static VectorType* getVectorType(IntrinsicInst *I) {
284*8bcb0991SDimitry Andric   unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
285*8bcb0991SDimitry Andric   auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType());
286*8bcb0991SDimitry Andric   return cast<VectorType>(PtrTy->getElementType());
287*8bcb0991SDimitry Andric }
288*8bcb0991SDimitry Andric 
289*8bcb0991SDimitry Andric bool MVETailPredication::IsPredicatedVectorLoop() {
290*8bcb0991SDimitry Andric   // Check that the loop contains at least one masked load/store intrinsic.
291*8bcb0991SDimitry Andric   // We only support 'normal' vector instructions - other than masked
292*8bcb0991SDimitry Andric   // load/stores.
293*8bcb0991SDimitry Andric   for (auto *BB : L->getBlocks()) {
294*8bcb0991SDimitry Andric     for (auto &I : *BB) {
295*8bcb0991SDimitry Andric       if (IsMasked(&I)) {
296*8bcb0991SDimitry Andric         VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I));
297*8bcb0991SDimitry Andric         unsigned Lanes = VecTy->getNumElements();
298*8bcb0991SDimitry Andric         unsigned ElementWidth = VecTy->getScalarSizeInBits();
299*8bcb0991SDimitry Andric         // MVE vectors are 128-bit, but don't support 128 x i1.
300*8bcb0991SDimitry Andric         // TODO: Can we support vectors larger than 128-bits?
301*8bcb0991SDimitry Andric         unsigned MaxWidth = TTI->getRegisterBitWidth(true);
302*8bcb0991SDimitry Andric         if (Lanes * ElementWidth != MaxWidth || Lanes == MaxWidth)
303*8bcb0991SDimitry Andric           return false;
304*8bcb0991SDimitry Andric         MaskedInsts.push_back(cast<IntrinsicInst>(&I));
305*8bcb0991SDimitry Andric       } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) {
306*8bcb0991SDimitry Andric         for (auto &U : Int->args()) {
307*8bcb0991SDimitry Andric           if (isa<VectorType>(U->getType()))
308*8bcb0991SDimitry Andric             return false;
309*8bcb0991SDimitry Andric         }
310*8bcb0991SDimitry Andric       }
311*8bcb0991SDimitry Andric     }
312*8bcb0991SDimitry Andric   }
313*8bcb0991SDimitry Andric 
314*8bcb0991SDimitry Andric   return !MaskedInsts.empty();
315*8bcb0991SDimitry Andric }
316*8bcb0991SDimitry Andric 
317*8bcb0991SDimitry Andric Value* MVETailPredication::ComputeElements(Value *TripCount,
318*8bcb0991SDimitry Andric                                            VectorType *VecTy) {
319*8bcb0991SDimitry Andric   const SCEV *TripCountSE = SE->getSCEV(TripCount);
320*8bcb0991SDimitry Andric   ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()),
321*8bcb0991SDimitry Andric                                      VecTy->getNumElements());
322*8bcb0991SDimitry Andric 
323*8bcb0991SDimitry Andric   if (VF->equalsInt(1))
324*8bcb0991SDimitry Andric     return nullptr;
325*8bcb0991SDimitry Andric 
326*8bcb0991SDimitry Andric   // TODO: Support constant trip counts.
327*8bcb0991SDimitry Andric   auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* {
328*8bcb0991SDimitry Andric     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
329*8bcb0991SDimitry Andric       if (Const->getAPInt() != -VF->getValue())
330*8bcb0991SDimitry Andric         return nullptr;
331*8bcb0991SDimitry Andric     } else
332*8bcb0991SDimitry Andric       return nullptr;
333*8bcb0991SDimitry Andric     return dyn_cast<SCEVMulExpr>(S->getOperand(1));
334*8bcb0991SDimitry Andric   };
335*8bcb0991SDimitry Andric 
336*8bcb0991SDimitry Andric   auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* {
337*8bcb0991SDimitry Andric     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
338*8bcb0991SDimitry Andric       if (Const->getValue() != VF)
339*8bcb0991SDimitry Andric         return nullptr;
340*8bcb0991SDimitry Andric     } else
341*8bcb0991SDimitry Andric       return nullptr;
342*8bcb0991SDimitry Andric     return dyn_cast<SCEVUDivExpr>(S->getOperand(1));
343*8bcb0991SDimitry Andric   };
344*8bcb0991SDimitry Andric 
345*8bcb0991SDimitry Andric   auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* {
346*8bcb0991SDimitry Andric     if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) {
347*8bcb0991SDimitry Andric       if (Const->getValue() != VF)
348*8bcb0991SDimitry Andric         return nullptr;
349*8bcb0991SDimitry Andric     } else
350*8bcb0991SDimitry Andric       return nullptr;
351*8bcb0991SDimitry Andric 
352*8bcb0991SDimitry Andric     if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) {
353*8bcb0991SDimitry Andric       if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) {
354*8bcb0991SDimitry Andric         if (Const->getAPInt() != (VF->getValue() - 1))
355*8bcb0991SDimitry Andric           return nullptr;
356*8bcb0991SDimitry Andric       } else
357*8bcb0991SDimitry Andric         return nullptr;
358*8bcb0991SDimitry Andric 
359*8bcb0991SDimitry Andric       return RoundUp->getOperand(1);
360*8bcb0991SDimitry Andric     }
361*8bcb0991SDimitry Andric     return nullptr;
362*8bcb0991SDimitry Andric   };
363*8bcb0991SDimitry Andric 
364*8bcb0991SDimitry Andric   // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to
365*8bcb0991SDimitry Andric   // determine the numbers of elements instead? Looks like this is what is used
366*8bcb0991SDimitry Andric   // for delinearization, but I'm not sure if it can be applied to the
367*8bcb0991SDimitry Andric   // vectorized form - at least not without a bit more work than I feel
368*8bcb0991SDimitry Andric   // comfortable with.
369*8bcb0991SDimitry Andric 
370*8bcb0991SDimitry Andric   // Search for Elems in the following SCEV:
371*8bcb0991SDimitry Andric   // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw>
372*8bcb0991SDimitry Andric   const SCEV *Elems = nullptr;
373*8bcb0991SDimitry Andric   if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE))
374*8bcb0991SDimitry Andric     if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1)))
375*8bcb0991SDimitry Andric       if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS()))
376*8bcb0991SDimitry Andric         if (auto *Mul = VisitAdd(Add))
377*8bcb0991SDimitry Andric           if (auto *Div = VisitMul(Mul))
378*8bcb0991SDimitry Andric             if (auto *Res = VisitDiv(Div))
379*8bcb0991SDimitry Andric               Elems = Res;
380*8bcb0991SDimitry Andric 
381*8bcb0991SDimitry Andric   if (!Elems)
382*8bcb0991SDimitry Andric     return nullptr;
383*8bcb0991SDimitry Andric 
384*8bcb0991SDimitry Andric   Instruction *InsertPt = L->getLoopPreheader()->getTerminator();
385*8bcb0991SDimitry Andric   if (!isSafeToExpandAt(Elems, InsertPt, *SE))
386*8bcb0991SDimitry Andric     return nullptr;
387*8bcb0991SDimitry Andric 
388*8bcb0991SDimitry Andric   auto DL = L->getHeader()->getModule()->getDataLayout();
389*8bcb0991SDimitry Andric   SCEVExpander Expander(*SE, DL, "elements");
390*8bcb0991SDimitry Andric   return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);
391*8bcb0991SDimitry Andric }
392*8bcb0991SDimitry Andric 
393*8bcb0991SDimitry Andric // Look through the exit block to see whether there's a duplicate predicate
394*8bcb0991SDimitry Andric // instruction. This can happen when we need to perform a select on values
395*8bcb0991SDimitry Andric // from the last and previous iteration. Instead of doing a straight
396*8bcb0991SDimitry Andric // replacement of that predicate with the vctp, clone the vctp and place it
397*8bcb0991SDimitry Andric // in the block. This means that the VPR doesn't have to be live into the
398*8bcb0991SDimitry Andric // exit block which should make it easier to convert this loop into a proper
399*8bcb0991SDimitry Andric // tail predicated loop.
400*8bcb0991SDimitry Andric static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates,
401*8bcb0991SDimitry Andric                     SetVector<Instruction*> &MaybeDead, Loop *L) {
402*8bcb0991SDimitry Andric   if (BasicBlock *Exit = L->getUniqueExitBlock()) {
403*8bcb0991SDimitry Andric     for (auto &Pair : NewPredicates) {
404*8bcb0991SDimitry Andric       Instruction *OldPred = Pair.first;
405*8bcb0991SDimitry Andric       Instruction *NewPred = Pair.second;
406*8bcb0991SDimitry Andric 
407*8bcb0991SDimitry Andric       for (auto &I : *Exit) {
408*8bcb0991SDimitry Andric         if (I.isSameOperationAs(OldPred)) {
409*8bcb0991SDimitry Andric           Instruction *PredClone = NewPred->clone();
410*8bcb0991SDimitry Andric           PredClone->insertBefore(&I);
411*8bcb0991SDimitry Andric           I.replaceAllUsesWith(PredClone);
412*8bcb0991SDimitry Andric           MaybeDead.insert(&I);
413*8bcb0991SDimitry Andric           break;
414*8bcb0991SDimitry Andric         }
415*8bcb0991SDimitry Andric       }
416*8bcb0991SDimitry Andric     }
417*8bcb0991SDimitry Andric   }
418*8bcb0991SDimitry Andric 
419*8bcb0991SDimitry Andric   // Drop references and add operands to check for dead.
420*8bcb0991SDimitry Andric   SmallPtrSet<Instruction*, 4> Dead;
421*8bcb0991SDimitry Andric   while (!MaybeDead.empty()) {
422*8bcb0991SDimitry Andric     auto *I = MaybeDead.front();
423*8bcb0991SDimitry Andric     MaybeDead.remove(I);
424*8bcb0991SDimitry Andric     if (I->hasNUsesOrMore(1))
425*8bcb0991SDimitry Andric       continue;
426*8bcb0991SDimitry Andric 
427*8bcb0991SDimitry Andric     for (auto &U : I->operands()) {
428*8bcb0991SDimitry Andric       if (auto *OpI = dyn_cast<Instruction>(U))
429*8bcb0991SDimitry Andric         MaybeDead.insert(OpI);
430*8bcb0991SDimitry Andric     }
431*8bcb0991SDimitry Andric     I->dropAllReferences();
432*8bcb0991SDimitry Andric     Dead.insert(I);
433*8bcb0991SDimitry Andric   }
434*8bcb0991SDimitry Andric 
435*8bcb0991SDimitry Andric   for (auto *I : Dead)
436*8bcb0991SDimitry Andric     I->eraseFromParent();
437*8bcb0991SDimitry Andric 
438*8bcb0991SDimitry Andric   for (auto I : L->blocks())
439*8bcb0991SDimitry Andric     DeleteDeadPHIs(I);
440*8bcb0991SDimitry Andric }
441*8bcb0991SDimitry Andric 
442*8bcb0991SDimitry Andric bool MVETailPredication::TryConvert(Value *TripCount) {
443*8bcb0991SDimitry Andric   if (!IsPredicatedVectorLoop())
444*8bcb0991SDimitry Andric     return false;
445*8bcb0991SDimitry Andric 
446*8bcb0991SDimitry Andric   LLVM_DEBUG(dbgs() << "TP: Found predicated vector loop.\n");
447*8bcb0991SDimitry Andric 
448*8bcb0991SDimitry Andric   // Walk through the masked intrinsics and try to find whether the predicate
449*8bcb0991SDimitry Andric   // operand is generated from an induction variable.
450*8bcb0991SDimitry Andric   Module *M = L->getHeader()->getModule();
451*8bcb0991SDimitry Andric   Type *Ty = IntegerType::get(M->getContext(), 32);
452*8bcb0991SDimitry Andric   SetVector<Instruction*> Predicates;
453*8bcb0991SDimitry Andric   DenseMap<Instruction*, Instruction*> NewPredicates;
454*8bcb0991SDimitry Andric 
455*8bcb0991SDimitry Andric   for (auto *I : MaskedInsts) {
456*8bcb0991SDimitry Andric     Intrinsic::ID ID = I->getIntrinsicID();
457*8bcb0991SDimitry Andric     unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3;
458*8bcb0991SDimitry Andric     auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp));
459*8bcb0991SDimitry Andric     if (!Predicate || Predicates.count(Predicate))
460*8bcb0991SDimitry Andric       continue;
461*8bcb0991SDimitry Andric 
462*8bcb0991SDimitry Andric     VectorType *VecTy = getVectorType(I);
463*8bcb0991SDimitry Andric     Value *NumElements = ComputeElements(TripCount, VecTy);
464*8bcb0991SDimitry Andric     if (!NumElements)
465*8bcb0991SDimitry Andric       continue;
466*8bcb0991SDimitry Andric 
467*8bcb0991SDimitry Andric     if (!isTailPredicate(Predicate, NumElements)) {
468*8bcb0991SDimitry Andric       LLVM_DEBUG(dbgs() << "TP: Not tail predicate: " << *Predicate <<  "\n");
469*8bcb0991SDimitry Andric       continue;
470*8bcb0991SDimitry Andric     }
471*8bcb0991SDimitry Andric 
472*8bcb0991SDimitry Andric     LLVM_DEBUG(dbgs() << "TP: Found tail predicate: " << *Predicate << "\n");
473*8bcb0991SDimitry Andric     Predicates.insert(Predicate);
474*8bcb0991SDimitry Andric 
475*8bcb0991SDimitry Andric     // Insert a phi to count the number of elements processed by the loop.
476*8bcb0991SDimitry Andric     IRBuilder<> Builder(L->getHeader()->getFirstNonPHI());
477*8bcb0991SDimitry Andric     PHINode *Processed = Builder.CreatePHI(Ty, 2);
478*8bcb0991SDimitry Andric     Processed->addIncoming(NumElements, L->getLoopPreheader());
479*8bcb0991SDimitry Andric 
480*8bcb0991SDimitry Andric     // Insert the intrinsic to represent the effect of tail predication.
481*8bcb0991SDimitry Andric     Builder.SetInsertPoint(cast<Instruction>(Predicate));
482*8bcb0991SDimitry Andric     ConstantInt *Factor =
483*8bcb0991SDimitry Andric       ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements());
484*8bcb0991SDimitry Andric     Intrinsic::ID VCTPID;
485*8bcb0991SDimitry Andric     switch (VecTy->getNumElements()) {
486*8bcb0991SDimitry Andric     default:
487*8bcb0991SDimitry Andric       llvm_unreachable("unexpected number of lanes");
488*8bcb0991SDimitry Andric     case 2:  VCTPID = Intrinsic::arm_vctp64; break;
489*8bcb0991SDimitry Andric     case 4:  VCTPID = Intrinsic::arm_vctp32; break;
490*8bcb0991SDimitry Andric     case 8:  VCTPID = Intrinsic::arm_vctp16; break;
491*8bcb0991SDimitry Andric     case 16: VCTPID = Intrinsic::arm_vctp8; break;
492*8bcb0991SDimitry Andric     }
493*8bcb0991SDimitry Andric     Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
494*8bcb0991SDimitry Andric     Value *TailPredicate = Builder.CreateCall(VCTP, Processed);
495*8bcb0991SDimitry Andric     Predicate->replaceAllUsesWith(TailPredicate);
496*8bcb0991SDimitry Andric     NewPredicates[Predicate] = cast<Instruction>(TailPredicate);
497*8bcb0991SDimitry Andric 
498*8bcb0991SDimitry Andric     // Add the incoming value to the new phi.
499*8bcb0991SDimitry Andric     // TODO: This add likely already exists in the loop.
500*8bcb0991SDimitry Andric     Value *Remaining = Builder.CreateSub(Processed, Factor);
501*8bcb0991SDimitry Andric     Processed->addIncoming(Remaining, L->getLoopLatch());
502*8bcb0991SDimitry Andric     LLVM_DEBUG(dbgs() << "TP: Insert processed elements phi: "
503*8bcb0991SDimitry Andric                << *Processed << "\n"
504*8bcb0991SDimitry Andric                << "TP: Inserted VCTP: " << *TailPredicate << "\n");
505*8bcb0991SDimitry Andric   }
506*8bcb0991SDimitry Andric 
507*8bcb0991SDimitry Andric   // Now clean up.
508*8bcb0991SDimitry Andric   Cleanup(NewPredicates, Predicates, L);
509*8bcb0991SDimitry Andric   return true;
510*8bcb0991SDimitry Andric }
511*8bcb0991SDimitry Andric 
512*8bcb0991SDimitry Andric Pass *llvm::createMVETailPredicationPass() {
513*8bcb0991SDimitry Andric   return new MVETailPredication();
514*8bcb0991SDimitry Andric }
515*8bcb0991SDimitry Andric 
516*8bcb0991SDimitry Andric char MVETailPredication::ID = 0;
517*8bcb0991SDimitry Andric 
518*8bcb0991SDimitry Andric INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
519*8bcb0991SDimitry Andric INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
520