xref: /llvm-project/llvm/lib/Target/ARM/MVETailPredication.cpp (revision bcbd26bfe61a35e31b1f7e98b5761a1055273b69)
1 //===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
11 /// branches to help accelerate DSP applications. These two extensions,
12 /// combined with a new form of predication called tail-predication, can be used
13 /// to provide implicit vector predication within a low-overhead loop.
14 /// This is implicit because the predicate of active/inactive lanes is
15 /// calculated by hardware, and thus does not need to be explicitly passed
16 /// to vector instructions. The instructions responsible for this are the
17 /// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the
18 /// total number of data elements processed by the loop. The loop-end
19 /// LETP instruction is responsible for decrementing and setting the remaining
20 /// elements to be processed and generating the mask of active lanes.
21 ///
22 /// The HardwareLoops pass inserts intrinsics identifying loops that the
23 /// backend will attempt to convert into a low-overhead loop. The vectorizer is
24 /// responsible for generating a vectorized loop in which the lanes are
25 /// predicated upon the iteration counter. This pass looks at these predicated
26 /// vector loops, that are targets for low-overhead loops, and prepares it for
27 /// code generation. Once the vectorizer has produced a masked loop, there's a
28 /// couple of final forms:
29 /// - A tail-predicated loop, with implicit predication.
30 /// - A loop containing multiple VCPT instructions, predicating multiple VPT
31 ///   blocks of instructions operating on different vector types.
32 ///
33 /// This pass:
34 /// 1) Pattern matches the scalar iteration count produced by the vectoriser.
35 ///    The scalar loop iteration count represents the number of elements to be
36 ///    processed.
37 ///    TODO: this could be emitted using an intrinsic, similar to the hardware
38 ///    loop intrinsics, so that we don't need to pattern match this here.
39 /// 2) Inserts the VCTP intrinsic to represent the effect of
40 ///    tail predication. This will be picked up by the ARM Low-overhead loop
41 ///    pass, which performs the final transformation to a DLSTP or WLSTP
42 ///    tail-predicated loop.
43 
44 #include "ARM.h"
45 #include "ARMSubtarget.h"
46 #include "llvm/Analysis/LoopInfo.h"
47 #include "llvm/Analysis/LoopPass.h"
48 #include "llvm/Analysis/ScalarEvolution.h"
49 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
50 #include "llvm/Analysis/TargetTransformInfo.h"
51 #include "llvm/CodeGen/TargetPassConfig.h"
52 #include "llvm/IR/IRBuilder.h"
53 #include "llvm/IR/Instructions.h"
54 #include "llvm/IR/IntrinsicsARM.h"
55 #include "llvm/IR/PatternMatch.h"
56 #include "llvm/InitializePasses.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
59 #include "llvm/Transforms/Utils/LoopUtils.h"
60 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
61 
62 using namespace llvm;
63 
64 #define DEBUG_TYPE "mve-tail-predication"
65 #define DESC "Transform predicated vector loops to use MVE tail predication"
66 
// Hidden command-line switch "disable-mve-tail-predication". It is initialised
// to true, and runOnLoop() returns early without touching the loop while it is
// set. The option is declared without 'static', so it has external linkage.
// NOTE(review): presumably other files refer to it via 'extern' - confirm
// before changing its linkage or name.
67 cl::opt<bool>
68 DisableTailPredication("disable-mve-tail-predication", cl::Hidden,
69                        cl::init(true),
70                        cl::desc("Disable MVE Tail Predication"));
71 namespace {
72 
73 // Bookkeeping for pattern matching the loop trip count and the number of
74 // elements processed by the loop.
75 struct TripCountPattern {
76   // An icmp instruction that calculates a predicate of active/inactive lanes
77   // used by the masked loads/stores.
78   Instruction *Predicate = nullptr;
79 
80   // The add instruction that increments the IV.
81   Value *TripCount = nullptr;
82 
83   // The number of elements processed by the vector loop.
84   Value *NumElements = nullptr;
85 
86   // Other instructions in the icmp chain that calculate the predicate.
87   FixedVectorType *VecTy = nullptr;
88   Instruction *Shuffle = nullptr;
89   Instruction *Induction = nullptr;
90 
91   TripCountPattern(Instruction *P, Value *TC, FixedVectorType *VT)
92       : Predicate(P), TripCount(TC), VecTy(VT){};
93 };
94 
95 class MVETailPredication : public LoopPass {
96   SmallVector<IntrinsicInst*, 4> MaskedInsts;
97   Loop *L = nullptr;
98   LoopInfo *LI = nullptr;
99   const DataLayout *DL;
100   DominatorTree *DT = nullptr;
101   ScalarEvolution *SE = nullptr;
102   TargetTransformInfo *TTI = nullptr;
103   TargetLibraryInfo *TLI = nullptr;
104   bool ClonedVCTPInExitBlock = false;
105 
106 public:
107   static char ID;
108 
109   MVETailPredication() : LoopPass(ID) { }
110 
111   void getAnalysisUsage(AnalysisUsage &AU) const override {
112     AU.addRequired<ScalarEvolutionWrapperPass>();
113     AU.addRequired<LoopInfoWrapperPass>();
114     AU.addRequired<TargetPassConfig>();
115     AU.addRequired<TargetTransformInfoWrapperPass>();
116     AU.addRequired<DominatorTreeWrapperPass>();
117     AU.addRequired<TargetLibraryInfoWrapperPass>();
118     AU.addPreserved<LoopInfoWrapperPass>();
119     AU.setPreservesCFG();
120   }
121 
122   bool runOnLoop(Loop *L, LPPassManager&) override;
123 
124 private:
125   /// Perform the relevant checks on the loop and convert if possible.
126   bool TryConvert(Value *TripCount);
127 
128   /// Return whether this is a vectorized loop, that contains masked
129   /// load/stores.
130   bool IsPredicatedVectorLoop();
131 
132   /// Compute a value for the total number of elements that the predicated
133   /// loop will process if it is a runtime value.
134   bool ComputeRuntimeElements(TripCountPattern &TCP);
135 
136   /// Return whether this is the icmp that generates an i1 vector, based
137   /// upon a loop counter and a limit that is defined outside the loop,
138   /// that generates the active/inactive lanes required for tail-predication.
139   bool isTailPredicate(TripCountPattern &TCP);
140 
141   /// Insert the intrinsic to represent the effect of tail predication.
142   void InsertVCTPIntrinsic(TripCountPattern &TCP,
143                            DenseMap<Instruction *, Instruction *> &NewPredicates);
144 
145   /// Rematerialize the iteration count in exit blocks, which enables
146   /// ARMLowOverheadLoops to better optimise away loop update statements inside
147   /// hardware-loops.
148   void RematerializeIterCount();
149 };
150 
151 } // end namespace
152 
153 static bool IsDecrement(Instruction &I) {
154   auto *Call = dyn_cast<IntrinsicInst>(&I);
155   if (!Call)
156     return false;
157 
158   Intrinsic::ID ID = Call->getIntrinsicID();
159   return ID == Intrinsic::loop_decrement_reg;
160 }
161 
162 static bool IsMasked(Instruction *I) {
163   auto *Call = dyn_cast<IntrinsicInst>(I);
164   if (!Call)
165     return false;
166 
167   Intrinsic::ID ID = Call->getIntrinsicID();
168   // TODO: Support gather/scatter expand/compress operations.
169   return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load;
170 }
171 
172 void MVETailPredication::RematerializeIterCount() {
173   SmallVector<WeakTrackingVH, 16> DeadInsts;
174   SCEVExpander Rewriter(*SE, *DL, "mvetp");
175   ReplaceExitVal ReplaceExitValue = AlwaysRepl;
176 
177   formLCSSARecursively(*L, *DT, LI, SE);
178   rewriteLoopExitValues(L, LI, TLI, SE, TTI, Rewriter, DT, ReplaceExitValue,
179                         DeadInsts);
180 }
181 
182 bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
183   if (skipLoop(L) || DisableTailPredication)
184     return false;
185 
186   MaskedInsts.clear();
187   Function &F = *L->getHeader()->getParent();
188   auto &TPC = getAnalysis<TargetPassConfig>();
189   auto &TM = TPC.getTM<TargetMachine>();
190   auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
191   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
192   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
193   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
194   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
195   auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>();
196   TLI = TLIP ? &TLIP->getTLI(*L->getHeader()->getParent()) : nullptr;
197   DL = &L->getHeader()->getModule()->getDataLayout();
198   this->L = L;
199 
200   // The MVE and LOB extensions are combined to enable tail-predication, but
201   // there's nothing preventing us from generating VCTP instructions for v8.1m.
202   if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
203     LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
204     return false;
205   }
206 
207   BasicBlock *Preheader = L->getLoopPreheader();
208   if (!Preheader)
209     return false;
210 
211   auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
212     for (auto &I : *BB) {
213       auto *Call = dyn_cast<IntrinsicInst>(&I);
214       if (!Call)
215         continue;
216 
217       Intrinsic::ID ID = Call->getIntrinsicID();
218       if (ID == Intrinsic::set_loop_iterations ||
219           ID == Intrinsic::test_set_loop_iterations)
220         return cast<IntrinsicInst>(&I);
221     }
222     return nullptr;
223   };
224 
225   // Look for the hardware loop intrinsic that sets the iteration count.
226   IntrinsicInst *Setup = FindLoopIterations(Preheader);
227 
228   // The test.set iteration could live in the pre-preheader.
229   if (!Setup) {
230     if (!Preheader->getSinglePredecessor())
231       return false;
232     Setup = FindLoopIterations(Preheader->getSinglePredecessor());
233     if (!Setup)
234       return false;
235   }
236 
237   // Search for the hardware loop intrinic that decrements the loop counter.
238   IntrinsicInst *Decrement = nullptr;
239   for (auto *BB : L->getBlocks()) {
240     for (auto &I : *BB) {
241       if (IsDecrement(I)) {
242         Decrement = cast<IntrinsicInst>(&I);
243         break;
244       }
245     }
246   }
247 
248   if (!Decrement)
249     return false;
250 
251   ClonedVCTPInExitBlock = false;
252   LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
253              << *Decrement << "\n");
254 
255   if (TryConvert(Setup->getArgOperand(0))) {
256     if (ClonedVCTPInExitBlock)
257       RematerializeIterCount();
258     return true;
259   }
260 
261   LLVM_DEBUG(dbgs() << "ARM TP: Can't tail-predicate this loop.\n");
262   return false;
263 }
264 
265 // Pattern match predicates/masks and determine if they use the loop induction
266 // variable to control the number of elements processed by the loop. If so,
267 // the loop is a candidate for tail-predication.
268 bool MVETailPredication::isTailPredicate(TripCountPattern &TCP) {
269   using namespace PatternMatch;
270 
271   // Pattern match the loop body and find the add with takes the index iv
272   // and adds a constant vector to it:
273   //
274   // vector.body:
275   // ..
276   // %index = phi i32
277   // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
278   // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
279   //                                  <4 x i32> undef,
280   //                                  <4 x i32> zeroinitializer
281   // %induction = [add|or] <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
282   // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11
283   //
284   // Please note that the 'or' is equivalent to the 'and' here, this relies on
285   // BroadcastSplat being the IV which we know is a phi with 0 start and Lanes
286   // increment, which is all being checked below.
287   Instruction *BroadcastSplat = nullptr;
288   Constant *Const = nullptr;
289   if (!match(TCP.Induction,
290              m_Add(m_Instruction(BroadcastSplat), m_Constant(Const))) &&
291       !match(TCP.Induction,
292              m_Or(m_Instruction(BroadcastSplat), m_Constant(Const))))
293     return false;
294 
295   // Check that we're adding <0, 1, 2, 3...
296   if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) {
297     for (unsigned i = 0; i < CDS->getNumElements(); ++i) {
298       if (CDS->getElementAsInteger(i) != i)
299         return false;
300     }
301   } else
302     return false;
303 
304   Instruction *Insert = nullptr;
305   // The shuffle which broadcasts the index iv into a vector.
306   if (!match(BroadcastSplat,
307              m_ShuffleVector(m_Instruction(Insert), m_Undef(), m_ZeroMask())))
308     return false;
309 
310   // The insert element which initialises a vector with the index iv.
311   Instruction *IV = nullptr;
312   if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero())))
313     return false;
314 
315   // The index iv.
316   auto *Phi = dyn_cast<PHINode>(IV);
317   if (!Phi)
318     return false;
319 
320   // TODO: Don't think we need to check the entry value.
321   Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader());
322   if (!match(OnEntry, m_Zero()))
323     return false;
324 
325   Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch());
326   unsigned Lanes = cast<FixedVectorType>(Insert->getType())->getNumElements();
327 
328   Instruction *LHS = nullptr;
329   if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes))))
330     return false;
331 
332   return LHS == Phi;
333 }
334 
335 static FixedVectorType *getVectorType(IntrinsicInst *I) {
336   unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
337   auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType());
338   return cast<FixedVectorType>(PtrTy->getElementType());
339 }
340 
341 bool MVETailPredication::IsPredicatedVectorLoop() {
342   // Check that the loop contains at least one masked load/store intrinsic.
343   // We only support 'normal' vector instructions - other than masked
344   // load/stores.
345   for (auto *BB : L->getBlocks()) {
346     for (auto &I : *BB) {
347       if (IsMasked(&I)) {
348         FixedVectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I));
349         unsigned Lanes = VecTy->getNumElements();
350         unsigned ElementWidth = VecTy->getScalarSizeInBits();
351         // MVE vectors are 128-bit, but don't support 128 x i1.
352         // TODO: Can we support vectors larger than 128-bits?
353         unsigned MaxWidth = TTI->getRegisterBitWidth(true);
354         if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth)
355           return false;
356         MaskedInsts.push_back(cast<IntrinsicInst>(&I));
357       } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) {
358         if (Int->getIntrinsicID() == Intrinsic::fma)
359           continue;
360         for (auto &U : Int->args()) {
361           if (isa<VectorType>(U->getType()))
362             return false;
363         }
364       }
365     }
366   }
367 
368   return !MaskedInsts.empty();
369 }
370 
371 // Pattern match the predicate, which is an icmp with a constant vector of this
372 // form:
373 //
374 //   icmp ult <4 x i32> %induction, <i32 32002, i32 32002, i32 32002, i32 32002>
375 //
376 // and return the constant, i.e. 32002 in this example. This is assumed to be
377 // the scalar loop iteration count: the number of loop elements by the
378 // the vector loop. Further checks are performed in function isTailPredicate(),
379 // to verify 'induction' behaves as an induction variable.
380 //
381 static bool ComputeConstElements(TripCountPattern &TCP) {
382   if (!dyn_cast<ConstantInt>(TCP.TripCount))
383     return false;
384 
385   ConstantInt *VF = ConstantInt::get(
386       cast<IntegerType>(TCP.TripCount->getType()), TCP.VecTy->getNumElements());
387   using namespace PatternMatch;
388   CmpInst::Predicate CC;
389 
390   if (!match(TCP.Predicate, m_ICmp(CC, m_Instruction(TCP.Induction),
391                                    m_AnyIntegralConstant())) ||
392       CC != ICmpInst::ICMP_ULT)
393     return false;
394 
395   LLVM_DEBUG(dbgs() << "ARM TP: icmp with constants: "; TCP.Predicate->dump(););
396   Value *ConstVec = TCP.Predicate->getOperand(1);
397 
398   auto *CDS = dyn_cast<ConstantDataSequential>(ConstVec);
399   if (!CDS || CDS->getNumElements() != VF->getSExtValue())
400     return false;
401 
402   if ((TCP.NumElements = CDS->getSplatValue())) {
403     assert(dyn_cast<ConstantInt>(TCP.NumElements)->getSExtValue() %
404                    VF->getSExtValue() !=
405                0 &&
406            "tail-predication: trip count should not be a multiple of the VF");
407     LLVM_DEBUG(dbgs() << "ARM TP: Found const elem count: " << *TCP.NumElements
408                       << "\n");
409     return true;
410   }
411   return false;
412 }
413 
414 // Pattern match the loop iteration count setup:
415 //
416 // %trip.count.minus.1 = add i32 %N, -1
417 // %broadcast.splatinsert10 = insertelement <4 x i32> undef,
418 //                                          i32 %trip.count.minus.1, i32 0
419 // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
420 //                                    <4 x i32> undef,
421 //                                    <4 x i32> zeroinitializer
422 // ..
423 // vector.body:
424 // ..
425 //
426 static bool MatchElemCountLoopSetup(Loop *L, Instruction *Shuffle,
427                                     Value *NumElements) {
428   using namespace PatternMatch;
429   Instruction *Insert = nullptr;
430 
431   if (!match(Shuffle,
432              m_ShuffleVector(m_Instruction(Insert), m_Undef(), m_ZeroMask())))
433     return false;
434 
435   // Insert the limit into a vector.
436   Instruction *BECount = nullptr;
437   if (!match(Insert,
438              m_InsertElement(m_Undef(), m_Instruction(BECount), m_Zero())))
439     return false;
440 
441   // The limit calculation, backedge count.
442   Value *TripCount = nullptr;
443   if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes())))
444     return false;
445 
446   if (TripCount != NumElements || !L->isLoopInvariant(BECount))
447     return false;
448 
449   return true;
450 }
451 
452 bool MVETailPredication::ComputeRuntimeElements(TripCountPattern &TCP) {
453   using namespace PatternMatch;
454   const SCEV *TripCountSE = SE->getSCEV(TCP.TripCount);
455   ConstantInt *VF = ConstantInt::get(
456       cast<IntegerType>(TCP.TripCount->getType()), TCP.VecTy->getNumElements());
457 
458   if (VF->equalsInt(1))
459     return false;
460 
461   CmpInst::Predicate Pred;
462   if (!match(TCP.Predicate, m_ICmp(Pred, m_Instruction(TCP.Induction),
463                                    m_Instruction(TCP.Shuffle))) ||
464       Pred != ICmpInst::ICMP_ULE)
465     return false;
466 
467   LLVM_DEBUG(dbgs() << "Computing number of elements for vector trip count: ";
468              TCP.TripCount->dump());
469 
470   // Otherwise, continue and try to pattern match the vector iteration
471   // count expression
472   auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr * {
473     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
474       if (Const->getAPInt() != -VF->getValue())
475         return nullptr;
476     } else
477       return nullptr;
478     return dyn_cast<SCEVMulExpr>(S->getOperand(1));
479   };
480 
481   auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr * {
482     if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
483       if (Const->getValue() != VF)
484         return nullptr;
485     } else
486       return nullptr;
487     return dyn_cast<SCEVUDivExpr>(S->getOperand(1));
488   };
489 
490   auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV * {
491     if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) {
492       if (Const->getValue() != VF)
493         return nullptr;
494     } else
495       return nullptr;
496 
497     if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) {
498       if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) {
499         if (Const->getAPInt() != (VF->getValue() - 1))
500           return nullptr;
501       } else
502         return nullptr;
503 
504       return RoundUp->getOperand(1);
505     }
506     return nullptr;
507   };
508 
509   // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to
510   // determine the numbers of elements instead? Looks like this is what is used
511   // for delinearization, but I'm not sure if it can be applied to the
512   // vectorized form - at least not without a bit more work than I feel
513   // comfortable with.
514 
515   // Search for Elems in the following SCEV:
516   // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw>
517   const SCEV *Elems = nullptr;
518   if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE))
519     if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1)))
520       if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS()))
521         if (auto *Mul = VisitAdd(Add))
522           if (auto *Div = VisitMul(Mul))
523             if (auto *Res = VisitDiv(Div))
524               Elems = Res;
525 
526   if (!Elems)
527     return false;
528 
529   Instruction *InsertPt = L->getLoopPreheader()->getTerminator();
530   if (!isSafeToExpandAt(Elems, InsertPt, *SE))
531     return false;
532 
533   auto DL = L->getHeader()->getModule()->getDataLayout();
534   SCEVExpander Expander(*SE, DL, "elements");
535   TCP.NumElements = Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);
536 
537   if (!MatchElemCountLoopSetup(L, TCP.Shuffle, TCP.NumElements))
538     return false;
539 
540   return true;
541 }
542 
543 // Look through the exit block to see whether there's a duplicate predicate
544 // instruction. This can happen when we need to perform a select on values
545 // from the last and previous iteration. Instead of doing a straight
546 // replacement of that predicate with the vctp, clone the vctp and place it
547 // in the block. This means that the VPR doesn't have to be live into the
548 // exit block which should make it easier to convert this loop into a proper
549 // tail predicated loop.
550 static bool Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates,
551                     SetVector<Instruction*> &MaybeDead, Loop *L) {
552   BasicBlock *Exit = L->getUniqueExitBlock();
553   if (!Exit) {
554     LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n");
555     return false;
556   }
557 
558   bool ClonedVCTPInExitBlock = false;
559 
560   for (auto &Pair : NewPredicates) {
561     Instruction *OldPred = Pair.first;
562     Instruction *NewPred = Pair.second;
563 
564     for (auto &I : *Exit) {
565       if (I.isSameOperationAs(OldPred)) {
566         Instruction *PredClone = NewPred->clone();
567         PredClone->insertBefore(&I);
568         I.replaceAllUsesWith(PredClone);
569         MaybeDead.insert(&I);
570         ClonedVCTPInExitBlock = true;
571         LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump();
572                    dbgs() << "ARM TP: with:      "; PredClone->dump());
573         break;
574       }
575     }
576   }
577 
578   // Drop references and add operands to check for dead.
579   SmallPtrSet<Instruction*, 4> Dead;
580   while (!MaybeDead.empty()) {
581     auto *I = MaybeDead.front();
582     MaybeDead.remove(I);
583     if (I->hasNUsesOrMore(1))
584       continue;
585 
586     for (auto &U : I->operands())
587       if (auto *OpI = dyn_cast<Instruction>(U))
588         MaybeDead.insert(OpI);
589 
590     I->dropAllReferences();
591     Dead.insert(I);
592   }
593 
594   for (auto *I : Dead) {
595     LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump());
596     I->eraseFromParent();
597   }
598 
599   for (auto I : L->blocks())
600     DeleteDeadPHIs(I);
601 
602   return ClonedVCTPInExitBlock;
603 }
604 
605 void MVETailPredication::InsertVCTPIntrinsic(TripCountPattern &TCP,
606     DenseMap<Instruction*, Instruction*> &NewPredicates) {
607   IRBuilder<> Builder(L->getHeader()->getFirstNonPHI());
608   Module *M = L->getHeader()->getModule();
609   Type *Ty = IntegerType::get(M->getContext(), 32);
610 
611   // Insert a phi to count the number of elements processed by the loop.
612   PHINode *Processed = Builder.CreatePHI(Ty, 2);
613   Processed->addIncoming(TCP.NumElements, L->getLoopPreheader());
614 
615   // Insert the intrinsic to represent the effect of tail predication.
616   Builder.SetInsertPoint(cast<Instruction>(TCP.Predicate));
617   ConstantInt *Factor =
618     ConstantInt::get(cast<IntegerType>(Ty), TCP.VecTy->getNumElements());
619 
620   Intrinsic::ID VCTPID;
621   switch (TCP.VecTy->getNumElements()) {
622   default:
623     llvm_unreachable("unexpected number of lanes");
624   case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
625   case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
626   case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
627 
628     // FIXME: vctp64 currently not supported because the predicate
629     // vector wants to be <2 x i1>, but v2i1 is not a legal MVE
630     // type, so problems happen at isel time.
631     // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics
632     // purposes, but takes a v4i1 instead of a v2i1.
633   }
634   Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
635   Value *TailPredicate = Builder.CreateCall(VCTP, Processed);
636   TCP.Predicate->replaceAllUsesWith(TailPredicate);
637   NewPredicates[TCP.Predicate] = cast<Instruction>(TailPredicate);
638 
639   // Add the incoming value to the new phi.
640   // TODO: This add likely already exists in the loop.
641   Value *Remaining = Builder.CreateSub(Processed, Factor);
642   Processed->addIncoming(Remaining, L->getLoopLatch());
643   LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
644              << *Processed << "\n"
645              << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n");
646 }
647 
648 bool MVETailPredication::TryConvert(Value *TripCount) {
649   if (!IsPredicatedVectorLoop()) {
650     LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop.\n");
651     return false;
652   }
653 
654   LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
655 
656   // Walk through the masked intrinsics and try to find whether the predicate
657   // operand is generated from an induction variable.
658   SetVector<Instruction*> Predicates;
659   DenseMap<Instruction*, Instruction*> NewPredicates;
660 
661 #ifndef NDEBUG
662   // For debugging purposes, use this to indicate we have been able to
663   // pattern match the scalar loop trip count.
664   bool FoundScalarTC = false;
665 #endif
666 
667   for (auto *I : MaskedInsts) {
668     Intrinsic::ID ID = I->getIntrinsicID();
669     // First, find the icmp used by this masked load/store.
670     unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3;
671     auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp));
672     if (!Predicate || Predicates.count(Predicate))
673       continue;
674 
675     // Step 1: using this icmp, now calculate the number of elements
676     // processed by this loop.
677     TripCountPattern TCP(Predicate, TripCount, getVectorType(I));
678     if (!(ComputeConstElements(TCP) || ComputeRuntimeElements(TCP)))
679       continue;
680 
681     LLVM_DEBUG(FoundScalarTC = true);
682 
683     if (!isTailPredicate(TCP)) {
684       LLVM_DEBUG(dbgs() << "ARM TP: Not an icmp that generates tail predicate: "
685                         << *Predicate << "\n");
686       continue;
687     }
688 
689     LLVM_DEBUG(dbgs() << "ARM TP: Found icmp generating tail predicate: "
690                       << *Predicate << "\n");
691     Predicates.insert(Predicate);
692 
693     // Step 2: emit the VCTP intrinsic representing the effect of TP.
694     InsertVCTPIntrinsic(TCP, NewPredicates);
695   }
696 
697   if (!NewPredicates.size()) {
698       LLVM_DEBUG(if (!FoundScalarTC)
699                    dbgs() << "ARM TP: Can't determine loop itertion count\n");
700     return false;
701   }
702 
703   // Now clean up.
704   ClonedVCTPInExitBlock = Cleanup(NewPredicates, Predicates, L);
705   return true;
706 }
707 
// Factory for the pass: returns a newly allocated MVETailPredication.
// NOTE(review): ownership of the returned object presumably passes to the
// pass manager it is added to - confirm against the caller.
708 Pass *llvm::createMVETailPredicationPass() {
709   return new MVETailPredication();
710 }
711 
// Definition of the static pass identifier that the MVETailPredication
// constructor passes to LoopPass.
712 char MVETailPredication::ID = 0;
713 
// Pass registration under the name DEBUG_TYPE with description DESC. No
// INITIALIZE_PASS_DEPENDENCY entries are listed between BEGIN and END.
// NOTE(review): the two 'false' arguments are presumably the cfg-only and
// is-analysis flags of the registration macros - confirm in PassSupport.h.
714 INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
715 INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
716