xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision d7084fa34aea7ac9ddf18fbe25731d2c8d291db0)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
13 ///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
18 /// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
20 /// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 //===----------------------------------------------------------------------===//
39 
40 #include "ARM.h"
41 #include "ARMBaseInstrInfo.h"
42 #include "ARMBaseRegisterInfo.h"
43 #include "ARMBasicBlockInfo.h"
44 #include "ARMSubtarget.h"
45 #include "Thumb2InstrInfo.h"
46 #include "llvm/ADT/SetOperations.h"
47 #include "llvm/ADT/SmallSet.h"
48 #include "llvm/CodeGen/LivePhysRegs.h"
49 #include "llvm/CodeGen/MachineFunctionPass.h"
50 #include "llvm/CodeGen/MachineLoopInfo.h"
51 #include "llvm/CodeGen/MachineLoopUtils.h"
52 #include "llvm/CodeGen/MachineRegisterInfo.h"
53 #include "llvm/CodeGen/Passes.h"
54 #include "llvm/CodeGen/ReachingDefAnalysis.h"
55 #include "llvm/MC/MCInstrDesc.h"
56 
57 using namespace llvm;
58 
59 #define DEBUG_TYPE "arm-low-overhead-loops"
60 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
61 
62 namespace {
63 
64   using InstSet = SmallPtrSetImpl<MachineInstr *>;
65 
66   class PostOrderLoopTraversal {
67     MachineLoop &ML;
68     MachineLoopInfo &MLI;
69     SmallPtrSet<MachineBasicBlock*, 4> Visited;
70     SmallVector<MachineBasicBlock*, 4> Order;
71 
72   public:
73     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
74       : ML(ML), MLI(MLI) { }
75 
76     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
77       return Order;
78     }
79 
80     // Visit all the blocks within the loop, as well as exit blocks and any
81     // blocks properly dominating the header.
82     void ProcessLoop() {
83       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
84         (MachineBasicBlock *MBB) -> void {
85         if (Visited.count(MBB))
86           return;
87 
88         Visited.insert(MBB);
89         for (auto *Succ : MBB->successors()) {
90           if (!ML.contains(Succ))
91             continue;
92           Search(Succ);
93         }
94         Order.push_back(MBB);
95       };
96 
97       // Insert exit blocks.
98       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
99       ML.getExitBlocks(ExitBlocks);
100       for (auto *MBB : ExitBlocks)
101         Order.push_back(MBB);
102 
103       // Then add the loop body.
104       Search(ML.getHeader());
105 
106       // Then try the preheader and its predecessors.
107       std::function<void(MachineBasicBlock*)> GetPredecessor =
108         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
109         Order.push_back(MBB);
110         if (MBB->pred_size() == 1)
111           GetPredecessor(*MBB->pred_begin());
112       };
113 
114       if (auto *Preheader = ML.getLoopPreheader())
115         GetPredecessor(Preheader);
116       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
117         GetPredecessor(Preheader);
118     }
119   };
120 
121   struct PredicatedMI {
122     MachineInstr *MI = nullptr;
123     SetVector<MachineInstr*> Predicates;
124 
125   public:
126     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
127       MI(I) { Predicates.insert(Preds.begin(), Preds.end()); }
128   };
129 
130   // Represent a VPT block, a list of instructions that begins with a VPST and
131   // has a maximum of four proceeding instructions. All instructions within the
132   // block are predicated upon the vpr and we allow instructions to define the
133   // vpr within in the block too.
134   class VPTBlock {
135     std::unique_ptr<PredicatedMI> VPST;
136     PredicatedMI *Divergent = nullptr;
137     SmallVector<PredicatedMI, 4> Insts;
138 
139   public:
140     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
141       VPST = std::make_unique<PredicatedMI>(MI, Preds);
142     }
143 
144     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
145       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
146       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
147         Divergent = &Insts.back();
148         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
149       }
150       Insts.emplace_back(MI, Preds);
151       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
152     }
153 
154     // Have we found an instruction within the block which defines the vpr? If
155     // so, not all the instructions in the block will have the same predicate.
156     bool HasNonUniformPredicate() const {
157       return Divergent != nullptr;
158     }
159 
160     // Is the given instruction part of the predicate set controlling the entry
161     // to the block.
162     bool IsPredicatedOn(MachineInstr *MI) const {
163       return VPST->Predicates.count(MI);
164     }
165 
166     // Is the given instruction the only predicate which controls the entry to
167     // the block.
168     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
169       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
170     }
171 
172     unsigned size() const { return Insts.size(); }
173     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
174     MachineInstr *getVPST() const { return VPST->MI; }
175     PredicatedMI *getDivergent() const { return Divergent; }
176   };
177 
178   struct LowOverheadLoop {
179 
180     MachineLoop &ML;
181     MachineLoopInfo &MLI;
182     ReachingDefAnalysis &RDA;
183     const TargetRegisterInfo &TRI;
184     MachineFunction *MF = nullptr;
185     MachineInstr *InsertPt = nullptr;
186     MachineInstr *Start = nullptr;
187     MachineInstr *Dec = nullptr;
188     MachineInstr *End = nullptr;
189     MachineInstr *VCTP = nullptr;
190     VPTBlock *CurrentBlock = nullptr;
191     SetVector<MachineInstr*> CurrentPredicate;
192     SmallVector<VPTBlock, 4> VPTBlocks;
193     SmallPtrSet<MachineInstr*, 4> ToRemove;
194     bool Revert = false;
195     bool CannotTailPredicate = false;
196 
197     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
198                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI)
199       : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI) {
200       MF = ML.getHeader()->getParent();
201     }
202 
203     // If this is an MVE instruction, check that we know how to use tail
204     // predication with it. Record VPT blocks and return whether the
205     // instruction is valid for tail predication.
206     bool ValidateMVEInst(MachineInstr *MI);
207 
208     void AnalyseMVEInst(MachineInstr *MI) {
209       CannotTailPredicate = !ValidateMVEInst(MI);
210     }
211 
212     bool IsTailPredicationLegal() const {
213       // For now, let's keep things really simple and only support a single
214       // block for tail predication.
215       return !Revert && FoundAllComponents() && VCTP &&
216              !CannotTailPredicate && ML.getNumBlocks() == 1;
217     }
218 
219     // Check that the predication in the loop will be equivalent once we
220     // perform the conversion. Also ensure that we can provide the number
221     // of elements to the loop start instruction.
222     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
223 
224     // Check that any values available outside of the loop will be the same
225     // after tail predication conversion.
226     bool ValidateLiveOuts() const;
227 
228     // Is it safe to define LR with DLS/WLS?
229     // LR can be defined if it is the operand to start, because it's the same
230     // value, or if it's going to be equivalent to the operand to Start.
231     MachineInstr *isSafeToDefineLR();
232 
233     // Check the branch targets are within range and we satisfy our
234     // restrictions.
235     void CheckLegality(ARMBasicBlockUtils *BBUtils);
236 
237     bool FoundAllComponents() const {
238       return Start && Dec && End;
239     }
240 
241     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
242 
243     // Return the loop iteration count, or the number of elements if we're tail
244     // predicating.
245     MachineOperand &getCount() {
246       return IsTailPredicationLegal() ?
247         VCTP->getOperand(1) : Start->getOperand(0);
248     }
249 
250     unsigned getStartOpcode() const {
251       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
252       if (!IsTailPredicationLegal())
253         return IsDo ? ARM::t2DLS : ARM::t2WLS;
254 
255       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
256     }
257 
258     void dump() const {
259       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
260       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
261       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
262       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
263       if (!FoundAllComponents())
264         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
265       else if (!(Start && Dec && End))
266         dbgs() << "ARM Loops: Failed to find all loop components.\n";
267     }
268   };
269 
270   class ARMLowOverheadLoops : public MachineFunctionPass {
271     MachineFunction           *MF = nullptr;
272     MachineLoopInfo           *MLI = nullptr;
273     ReachingDefAnalysis       *RDA = nullptr;
274     const ARMBaseInstrInfo    *TII = nullptr;
275     MachineRegisterInfo       *MRI = nullptr;
276     const TargetRegisterInfo  *TRI = nullptr;
277     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
278 
279   public:
280     static char ID;
281 
282     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
283 
284     void getAnalysisUsage(AnalysisUsage &AU) const override {
285       AU.setPreservesCFG();
286       AU.addRequired<MachineLoopInfo>();
287       AU.addRequired<ReachingDefAnalysis>();
288       MachineFunctionPass::getAnalysisUsage(AU);
289     }
290 
291     bool runOnMachineFunction(MachineFunction &MF) override;
292 
293     MachineFunctionProperties getRequiredProperties() const override {
294       return MachineFunctionProperties().set(
295           MachineFunctionProperties::Property::NoVRegs).set(
296           MachineFunctionProperties::Property::TracksLiveness);
297     }
298 
299     StringRef getPassName() const override {
300       return ARM_LOW_OVERHEAD_LOOPS_NAME;
301     }
302 
303   private:
304     bool ProcessLoop(MachineLoop *ML);
305 
306     bool RevertNonLoops();
307 
308     void RevertWhile(MachineInstr *MI) const;
309 
310     bool RevertLoopDec(MachineInstr *MI) const;
311 
312     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
313 
314     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
315 
316     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
317 
318     void Expand(LowOverheadLoop &LoLoop);
319 
320     void IterationCountDCE(LowOverheadLoop &LoLoop);
321   };
322 }
323 
324 char ARMLowOverheadLoops::ID = 0;
325 
326 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
327                 false, false)
328 
329 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
330   // We can define LR because LR already contains the same value.
331   if (Start->getOperand(0).getReg() == ARM::LR)
332     return Start;
333 
334   unsigned CountReg = Start->getOperand(0).getReg();
335   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
336     return MI->getOpcode() == ARM::tMOVr &&
337            MI->getOperand(0).getReg() == ARM::LR &&
338            MI->getOperand(1).getReg() == CountReg &&
339            MI->getOperand(2).getImm() == ARMCC::AL;
340    };
341 
342   MachineBasicBlock *MBB = Start->getParent();
343 
344   // Find an insertion point:
345   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
346   //   to Count before Start, we can insert at that mov.
347   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
348     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
349       return LRDef;
350 
351   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
352   //   to Count after Start, we can insert at that mov.
353   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
354     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
355       return LRDef;
356 
357   // We've found no suitable LR def and Start doesn't use LR directly. Can we
358   // just define LR anyway?
359   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
360 }
361 
362 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
363   assert(VCTP && "VCTP instruction expected but is not set");
364   // All predication within the loop should be based on vctp. If the block
365   // isn't predicated on entry, check whether the vctp is within the block
366   // and that all other instructions are then predicated on it.
367   for (auto &Block : VPTBlocks) {
368     if (Block.IsPredicatedOn(VCTP))
369       continue;
370     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
371       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
372                  << *Block.getDivergent()->MI);
373       return false;
374     }
375     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
376     for (auto &PredMI : Insts) {
377       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
378         continue;
379       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
380                  << " - which is predicated on:\n";
381                  for (auto *MI : PredMI.Predicates)
382                    dbgs() << "   - " << *MI);
383       return false;
384     }
385   }
386 
387   if (!ValidateLiveOuts())
388     return false;
389 
390   // For tail predication, we need to provide the number of elements, instead
391   // of the iteration count, to the loop start instruction. The number of
392   // elements is provided to the vctp instruction, so we need to check that
393   // we can use this register at InsertPt.
394   Register NumElements = VCTP->getOperand(1).getReg();
395 
396   // If the register is defined within loop, then we can't perform TP.
397   // TODO: Check whether this is just a mov of a register that would be
398   // available.
399   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
400     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
401     return false;
402   }
403 
404   // The element count register maybe defined after InsertPt, in which case we
405   // need to try to move either InsertPt or the def so that the [w|d]lstp can
406   // use the value.
407   // TODO: On failing to move an instruction, check if the count is provided by
408   // a mov and whether we can use the mov operand directly.
409   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
410   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
411     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
412       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
413         ElemDef->removeFromParent();
414         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
415         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
416                    << *ElemDef);
417       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
418         StartInsertPt->removeFromParent();
419         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
420                               StartInsertPt);
421         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
422       } else {
423         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
424                    << "start instruction.\n");
425         return false;
426       }
427     }
428   }
429 
430   // Especially in the case of while loops, InsertBB may not be the
431   // preheader, so we need to check that the register isn't redefined
432   // before entering the loop.
433   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
434                                       Register NumElements) {
435     // NumElements is redefined in this block.
436     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
437       return true;
438 
439     // Don't continue searching up through multiple predecessors.
440     if (MBB->pred_size() > 1)
441       return true;
442 
443     return false;
444   };
445 
446   // First, find the block that looks like the preheader.
447   MachineBasicBlock *MBB = MLI.findLoopPreheader(&ML, true);
448   if (!MBB) {
449     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
450     return false;
451   }
452 
453   // Then search backwards for a def, until we get to InsertBB.
454   while (MBB != InsertBB) {
455     if (CannotProvideElements(MBB, NumElements)) {
456       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
457       return false;
458     }
459     MBB = *MBB->pred_begin();
460   }
461 
462   // Check that the value change of the element count is what we expect and
463   // that the predication will be equivalent. For this we need:
464   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
465   // and we can also allow register copies within the chain too.
466   auto IsValidSub = [](MachineInstr *MI, unsigned ExpectedVecWidth) {
467     unsigned ImmOpIdx = 0;
468     switch (MI->getOpcode()) {
469     default:
470       llvm_unreachable("unhandled sub opcode");
471     case ARM::tSUBi3:
472     case ARM::tSUBi8:
473       ImmOpIdx = 3;
474       break;
475     case ARM::t2SUBri:
476     case ARM::t2SUBri12:
477       ImmOpIdx = 2;
478       break;
479     }
480     return MI->getOperand(ImmOpIdx).getImm() == ExpectedVecWidth;
481   };
482 
483   MBB = VCTP->getParent();
484   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
485     SmallPtrSet<MachineInstr*, 2> ElementChain;
486     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
487     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
488 
489     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
490       bool FoundSub = false;
491 
492       for (auto *MI : ElementChain) {
493         if (isMovRegOpcode(MI->getOpcode()))
494           continue;
495 
496         if (isSubImmOpcode(MI->getOpcode())) {
497           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
498             return false;
499           FoundSub = true;
500         } else
501           return false;
502       }
503 
504       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
505                  for (auto *MI : ElementChain)
506                    dbgs() << " - " << *MI);
507       ToRemove.insert(ElementChain.begin(), ElementChain.end());
508     }
509   }
510   return true;
511 }
512 
513 static bool isVectorPredicated(MachineInstr *MI) {
514   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
515   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
516 }
517 
518 static bool isRegInClass(const MachineOperand &MO,
519                          const TargetRegisterClass *Class) {
520   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
521 }
522 
523 // MVE 'narrowing' operate on half a lane, reading from half and writing
524 // to half, which are referred to has the top and bottom half. The other
525 // half retains its previous value.
526 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
527   const MCInstrDesc &MCID = MI.getDesc();
528   uint64_t Flags = MCID.TSFlags;
529   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
530 }
531 
532 // Some MVE instructions read from the top/bottom halves of their operand(s)
533 // and generate a vector result with result elements that are double the
534 // width of the input.
535 static bool producesDoubleWidthResult(const MachineInstr &MI) {
536   const MCInstrDesc &MCID = MI.getDesc();
537   uint64_t Flags = MCID.TSFlags;
538   return (Flags & ARMII::DoubleWidthResult) != 0;
539 }
540 
541 // Can this instruction generate a non-zero result when given only zeroed
542 // operands? This allows us to know that, given operands with false bytes
543 // zeroed by masked loads, that the result will also contain zeros in those
544 // bytes.
545 static bool canGenerateNonZeros(const MachineInstr &MI) {
546 
547   // Check for instructions which can write into a larger element size,
548   // possibly writing into a previous zero'd lane.
549   if (producesDoubleWidthResult(MI))
550     return true;
551 
552   switch (MI.getOpcode()) {
553   default:
554     break;
555   // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow
556   // fp16 -> fp32 vector conversions.
557   // Instructions that perform a NOT will generate 1s from 0s.
558   case ARM::MVE_VMVN:
559   case ARM::MVE_VORN:
560   // Count leading zeros will do just that!
561   case ARM::MVE_VCLZs8:
562   case ARM::MVE_VCLZs16:
563   case ARM::MVE_VCLZs32:
564     return true;
565   }
566   return false;
567 }
568 
569 
570 // Look at its register uses to see if it only can only receive zeros
571 // into its false lanes which would then produce zeros. Also check that
572 // the output register is also defined by an FalseLaneZeros instruction
573 // so that if tail-predication happens, the lanes that aren't updated will
574 // still be zeros.
575 static bool producesFalseLaneZeros(MachineInstr &MI,
576                                    const TargetRegisterClass *QPRs,
577                                    const ReachingDefAnalysis &RDA,
578                                    InstSet &FalseLaneZeros) {
579   if (canGenerateNonZeros(MI))
580     return false;
581   for (auto &MO : MI.operands()) {
582     if (!MO.isReg() || !MO.getReg())
583       continue;
584     if (auto *OpDef = RDA.getMIOperand(&MI, MO))
585       if (FalseLaneZeros.count(OpDef))
586        continue;
587     return false;
588   }
589   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
590   return true;
591 }
592 
593 bool LowOverheadLoop::ValidateLiveOuts() const {
594   // We want to find out if the tail-predicated version of this loop will
595   // produce the same values as the loop in its original form. For this to
596   // be true, the newly inserted implicit predication must not change the
597   // the (observable) results.
598   // We're doing this because many instructions in the loop will not be
599   // predicated and so the conversion from VPT predication to tail-predication
600   // can result in different values being produced; due to the tail-predication
601   // preventing many instructions from updating their falsely predicated
602   // lanes. This analysis assumes that all the instructions perform lane-wise
603   // operations and don't perform any exchanges.
604   // A masked load, whether through VPT or tail predication, will write zeros
605   // to any of the falsely predicated bytes. So, from the loads, we know that
606   // the false lanes are zeroed and here we're trying to track that those false
607   // lanes remain zero, or where they change, the differences are masked away
608   // by their user(s).
609   // All MVE loads and stores have to be predicated, so we know that any load
610   // operands, or stored results are equivalent already. Other explicitly
611   // predicated instructions will perform the same operation in the original
612   // loop and the tail-predicated form too. Because of this, we can insert
613   // loads, stores and other predicated instructions into our Predicated
614   // set and build from there.
615   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
616   SetVector<MachineInstr *> Unknown;
617   SmallPtrSet<MachineInstr *, 4> FalseLaneZeros;
618   SmallPtrSet<MachineInstr *, 4> Predicated;
619   MachineBasicBlock *MBB = ML.getHeader();
620 
621   for (auto &MI : *MBB) {
622     const MCInstrDesc &MCID = MI.getDesc();
623     uint64_t Flags = MCID.TSFlags;
624     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
625       continue;
626 
627     if (isVectorPredicated(&MI)) {
628       if (MI.mayLoad())
629         FalseLaneZeros.insert(&MI);
630       Predicated.insert(&MI);
631       continue;
632     }
633 
634     if (MI.getNumDefs() == 0)
635       continue;
636 
637     if (producesFalseLaneZeros(MI, QPRs, RDA, FalseLaneZeros))
638       FalseLaneZeros.insert(&MI);
639     else if (retainsPreviousHalfElement(MI))
640       return false;
641     else
642       Unknown.insert(&MI);
643   }
644 
645   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
646                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
647     SmallPtrSet<MachineInstr *, 2> Uses;
648     RDA.getGlobalUses(MI, MO.getReg(), Uses);
649     for (auto *Use : Uses) {
650       if (Use != MI && !Predicated.count(Use))
651         return false;
652     }
653     return true;
654   };
655 
656   // Visit the unknowns in reverse so that we can start at the values being
657   // stored and then we can work towards the leaves, hopefully adding more
658   // instructions to Predicated.
659   for (auto *MI : reverse(Unknown)) {
660     for (auto &MO : MI->operands()) {
661       if (!isRegInClass(MO, QPRs) || !MO.isDef())
662         continue;
663       if (!HasPredicatedUsers(MI, MO, Predicated)) {
664         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
665                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
666         return false;
667       }
668     }
669     // Any unknown false lanes have been masked away by the user(s).
670     Predicated.insert(MI);
671   }
672 
673   // Collect Q-regs that are live in the exit blocks. We don't collect scalars
674   // because they won't be affected by lane predication.
675   SmallSet<Register, 2> LiveOuts;
676   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
677   ML.getExitBlocks(ExitBlocks);
678   for (auto *MBB : ExitBlocks)
679     for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins())
680       if (QPRs->contains(RegMask.PhysReg))
681         LiveOuts.insert(RegMask.PhysReg);
682 
683   // Collect the instructions in the loop body that define the live-out values.
684   SmallPtrSet<MachineInstr *, 2> LiveMIs;
685   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
686   for (auto Reg : LiveOuts)
687     if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg))
688       LiveMIs.insert(MI);
689 
690   LLVM_DEBUG(dbgs() << "ARM Loops: Found loop live-outs:\n";
691              for (auto *MI : LiveMIs)
692                dbgs() << " - " << *MI);
693   // We've already validated that any VPT predication within the loop will be
694   // equivalent when we perform the predication transformation; so we know that
695   // any VPT predicated instruction is predicated upon VCTP. Any live-out
696   // instruction needs to be predicated, so check this here.
697   for (auto *MI : LiveMIs)
698     if (!isVectorPredicated(MI))
699       return false;
700 
701   return true;
702 }
703 
704 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
705   if (Revert)
706     return;
707 
708   if (!End->getOperand(1).isMBB())
709     report_fatal_error("Expected LoopEnd to target basic block");
710 
711   // TODO Maybe there's cases where the target doesn't have to be the header,
712   // but for now be safe and revert.
713   if (End->getOperand(1).getMBB() != ML.getHeader()) {
714     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
715     Revert = true;
716     return;
717   }
718 
719   // The WLS and LE instructions have 12-bits for the label offset. WLS
720   // requires a positive offset, while LE uses negative.
721   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
722       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
723     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
724     Revert = true;
725     return;
726   }
727 
728   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
729       (BBUtils->getOffsetOf(Start) >
730        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
731        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
732     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
733     Revert = true;
734     return;
735   }
736 
737   InsertPt = Revert ? nullptr : isSafeToDefineLR();
738   if (!InsertPt) {
739     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
740     Revert = true;
741     return;
742   } else
743     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
744 
745   if (!IsTailPredicationLegal()) {
746     LLVM_DEBUG(if (!VCTP)
747                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
748                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
749     return;
750   }
751 
752   assert(ML.getBlocks().size() == 1 &&
753          "Shouldn't be processing a loop with more than one block");
754   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
755   LLVM_DEBUG(if (CannotTailPredicate)
756              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
757 }
758 
759 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
760   if (CannotTailPredicate)
761     return false;
762 
763   // Only support a single vctp.
764   if (isVCTP(MI) && VCTP)
765     return false;
766 
767   // Start a new vpt block when we discover a vpt.
768   if (MI->getOpcode() == ARM::MVE_VPST) {
769     VPTBlocks.emplace_back(MI, CurrentPredicate);
770     CurrentBlock = &VPTBlocks.back();
771     return true;
772   } else if (isVCTP(MI))
773     VCTP = MI;
774   else if (MI->getOpcode() == ARM::MVE_VPSEL ||
775            MI->getOpcode() == ARM::MVE_VPNOT)
776     return false;
777 
778   // TODO: Allow VPSEL and VPNOT, we currently cannot because:
779   // 1) It will use the VPR as a predicate operand, but doesn't have to be
780   //    instead a VPT block, which means we can assert while building up
781   //    the VPT block because we don't find another VPST to being a new
782   //    one.
783   // 2) VPSEL still requires a VPR operand even after tail predicating,
784   //    which means we can't remove it unless there is another
785   //    instruction, such as vcmp, that can provide the VPR def.
786 
787   bool IsUse = false;
788   bool IsDef = false;
789   const MCInstrDesc &MCID = MI->getDesc();
790   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
791     const MachineOperand &MO = MI->getOperand(i);
792     if (!MO.isReg() || MO.getReg() != ARM::VPR)
793       continue;
794 
795     if (MO.isDef()) {
796       CurrentPredicate.insert(MI);
797       IsDef = true;
798     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
799       CurrentBlock->addInst(MI, CurrentPredicate);
800       IsUse = true;
801     } else {
802       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
803       return false;
804     }
805   }
806 
807   // If we find a vpr def that is not already predicated on the vctp, we've
808   // got disjoint predicates that may not be equivalent when we do the
809   // conversion.
810   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
811     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
812     return false;
813   }
814 
815   uint64_t Flags = MCID.TSFlags;
816   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
817     return true;
818 
819   // If we find an instruction that has been marked as not valid for tail
820   // predication, only allow the instruction if it's contained within a valid
821   // VPT block.
822   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
823     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
824     return false;
825   }
826 
827   // If the instruction is already explicitly predicated, then the conversion
828   // will be fine, but ensure that all memory operations are predicated.
829   return !IsUse && MI->mayLoadOrStore() ? false : true;
830 }
831 
832 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
833   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
834   if (!ST.hasLOB())
835     return false;
836 
837   MF = &mf;
838   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
839 
840   MLI = &getAnalysis<MachineLoopInfo>();
841   RDA = &getAnalysis<ReachingDefAnalysis>();
842   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
843   MRI = &MF->getRegInfo();
844   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
845   TRI = ST.getRegisterInfo();
846   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
847   BBUtils->computeAllBlockSizes();
848   BBUtils->adjustBBOffsetsAfter(&MF->front());
849 
850   bool Changed = false;
851   for (auto ML : *MLI) {
852     if (!ML->getParentLoop())
853       Changed |= ProcessLoop(ML);
854   }
855   Changed |= RevertNonLoops();
856   return Changed;
857 }
858 
859 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
860 
861   bool Changed = false;
862 
863   // Process inner loops first.
864   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
865     Changed |= ProcessLoop(*I);
866 
867   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
868              if (auto *Preheader = ML->getLoopPreheader())
869                dbgs() << " - " << Preheader->getName() << "\n";
870              else if (auto *Preheader = MLI->findLoopPreheader(ML))
871                dbgs() << " - " << Preheader->getName() << "\n";
872              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
873                dbgs() << " - " << Preheader->getName() << "\n";
874              for (auto *MBB : ML->getBlocks())
875                dbgs() << " - " << MBB->getName() << "\n";
876             );
877 
878   // Search the given block for a loop start instruction. If one isn't found,
879   // and there's only one predecessor block, search that one too.
880   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
881     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
882     for (auto &MI : *MBB) {
883       if (isLoopStart(MI))
884         return &MI;
885     }
886     if (MBB->pred_size() == 1)
887       return SearchForStart(*MBB->pred_begin());
888     return nullptr;
889   };
890 
891   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI);
892   // Search the preheader for the start intrinsic.
893   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
894   // with potentially multiple set.loop.iterations, so we need to enable this.
895   if (auto *Preheader = ML->getLoopPreheader())
896     LoLoop.Start = SearchForStart(Preheader);
897   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
898     LoLoop.Start = SearchForStart(Preheader);
899   else
900     return false;
901 
902   // Find the low-overhead loop components and decide whether or not to fall
903   // back to a normal loop. Also look for a vctp instructions and decide
904   // whether we can convert that predicate using tail predication.
905   for (auto *MBB : reverse(ML->getBlocks())) {
906     for (auto &MI : *MBB) {
907       if (MI.isDebugValue())
908         continue;
909       else if (MI.getOpcode() == ARM::t2LoopDec)
910         LoLoop.Dec = &MI;
911       else if (MI.getOpcode() == ARM::t2LoopEnd)
912         LoLoop.End = &MI;
913       else if (isLoopStart(MI))
914         LoLoop.Start = &MI;
915       else if (MI.getDesc().isCall()) {
916         // TODO: Though the call will require LE to execute again, does this
917         // mean we should revert? Always executing LE hopefully should be
918         // faster than performing a sub,cmp,br or even subs,br.
919         LoLoop.Revert = true;
920         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
921       } else {
922         // Record VPR defs and build up their corresponding vpt blocks.
923         // Check we know how to tail predicate any mve instructions.
924         LoLoop.AnalyseMVEInst(&MI);
925       }
926     }
927   }
928 
929   LLVM_DEBUG(LoLoop.dump());
930   if (!LoLoop.FoundAllComponents()) {
931     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
932     return false;
933   }
934 
935   // Check that the only instruction using LoopDec is LoopEnd.
936   // TODO: Check for copy chains that really have no effect.
937   SmallPtrSet<MachineInstr*, 2> Uses;
938   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
939   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
940     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
941     LoLoop.Revert = true;
942   }
943   LoLoop.CheckLegality(BBUtils.get());
944   Expand(LoLoop);
945   return true;
946 }
947 
948 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
949 // beq that branches to the exit branch.
950 // TODO: We could also try to generate a cbz if the value in LR is also in
951 // another low register.
952 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
953   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
954   MachineBasicBlock *MBB = MI->getParent();
955   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
956                                     TII->get(ARM::t2CMPri));
957   MIB.add(MI->getOperand(0));
958   MIB.addImm(0);
959   MIB.addImm(ARMCC::AL);
960   MIB.addReg(ARM::NoRegister);
961 
962   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
963   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
964     ARM::tBcc : ARM::t2Bcc;
965 
966   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
967   MIB.add(MI->getOperand(1));   // branch target
968   MIB.addImm(ARMCC::EQ);        // condition code
969   MIB.addReg(ARM::CPSR);
970   MI->eraseFromParent();
971 }
972 
973 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
974   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
975   MachineBasicBlock *MBB = MI->getParent();
976   SmallPtrSet<MachineInstr*, 1> Ignore;
977   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
978     if (I->getOpcode() == ARM::t2LoopEnd) {
979       Ignore.insert(&*I);
980       break;
981     }
982   }
983 
984   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
985   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
986 
987   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
988                                     TII->get(ARM::t2SUBri));
989   MIB.addDef(ARM::LR);
990   MIB.add(MI->getOperand(1));
991   MIB.add(MI->getOperand(2));
992   MIB.addImm(ARMCC::AL);
993   MIB.addReg(0);
994 
995   if (SetFlags) {
996     MIB.addReg(ARM::CPSR);
997     MIB->getOperand(5).setIsDef(true);
998   } else
999     MIB.addReg(0);
1000 
1001   MI->eraseFromParent();
1002   return SetFlags;
1003 }
1004 
1005 // Generate a subs, or sub and cmp, and a branch instead of an LE.
1006 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
1007   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
1008 
1009   MachineBasicBlock *MBB = MI->getParent();
1010   // Create cmp
1011   if (!SkipCmp) {
1012     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
1013                                       TII->get(ARM::t2CMPri));
1014     MIB.addReg(ARM::LR);
1015     MIB.addImm(0);
1016     MIB.addImm(ARMCC::AL);
1017     MIB.addReg(ARM::NoRegister);
1018   }
1019 
1020   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1021   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1022     ARM::tBcc : ARM::t2Bcc;
1023 
1024   // Create bne
1025   MachineInstrBuilder MIB =
1026     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1027   MIB.add(MI->getOperand(1));   // branch target
1028   MIB.addImm(ARMCC::NE);        // condition code
1029   MIB.addReg(ARM::CPSR);
1030   MI->eraseFromParent();
1031 }
1032 
1033 // Perform dead code elimation on the loop iteration count setup expression.
1034 // If we are tail-predicating, the number of elements to be processed is the
1035 // operand of the VCTP instruction in the vector body, see getCount(), which is
1036 // register $r3 in this example:
1037 //
1038 //   $lr = big-itercount-expression
1039 //   ..
1040 //   t2DoLoopStart renamable $lr
1041 //   vector.body:
1042 //     ..
1043 //     $vpr = MVE_VCTP32 renamable $r3
1044 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1045 //     t2LoopEnd renamable $lr, %vector.body
1046 //     tB %end
1047 //
1048 // What we would like achieve here is to replace the do-loop start pseudo
1049 // instruction t2DoLoopStart with:
1050 //
1051 //    $lr = MVE_DLSTP_32 killed renamable $r3
1052 //
1053 // Thus, $r3 which defines the number of elements, is written to $lr,
1054 // and then we want to delete the whole chain that used to define $lr,
1055 // see the comment below how this chain could look like.
1056 //
1057 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1058   if (!LoLoop.IsTailPredicationLegal())
1059     return;
1060 
1061   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1062 
1063   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1064   if (!Def) {
1065     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1066     return;
1067   }
1068 
1069   // Collect and remove the users of iteration count.
1070   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1071                                             LoLoop.End, LoLoop.InsertPt };
1072   SmallPtrSet<MachineInstr*, 2> Remove;
1073   if (RDA->isSafeToRemove(Def, Remove, Killed))
1074     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1075   else {
1076     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1077     return;
1078   }
1079 
1080   // Collect the dead code and the MBBs in which they reside.
1081   RDA->collectKilledOperands(Def, Killed);
1082   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1083   for (auto *MI : Killed)
1084     BasicBlocks.insert(MI->getParent());
1085 
1086   // Collect IT blocks in all affected basic blocks.
1087   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1088   for (auto *MBB : BasicBlocks) {
1089     for (auto &MI : *MBB) {
1090       if (MI.getOpcode() != ARM::t2IT)
1091         continue;
1092       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1093     }
1094   }
1095 
1096   // If we're removing all of the instructions within an IT block, then
1097   // also remove the IT instruction.
1098   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1099   for (auto *MI : Killed) {
1100     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1101       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1102       auto &CurrentBlock = ITBlocks[IT];
1103       CurrentBlock.erase(MI);
1104       if (CurrentBlock.empty())
1105         ModifiedITs.erase(IT);
1106       else
1107         ModifiedITs.insert(IT);
1108     }
1109   }
1110 
1111   // Delete the killed instructions only if we don't have any IT blocks that
1112   // need to be modified because we need to fixup the mask.
1113   // TODO: Handle cases where IT blocks are modified.
1114   if (ModifiedITs.empty()) {
1115     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1116                for (auto *MI : Killed)
1117                  dbgs() << " - " << *MI);
1118     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1119   } else
1120     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1121 }
1122 
1123 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1124   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1125   // When using tail-predication, try to delete the dead code that was used to
1126   // calculate the number of loop iterations.
1127   IterationCountDCE(LoLoop);
1128 
1129   MachineInstr *InsertPt = LoLoop.InsertPt;
1130   MachineInstr *Start = LoLoop.Start;
1131   MachineBasicBlock *MBB = InsertPt->getParent();
1132   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1133   unsigned Opc = LoLoop.getStartOpcode();
1134   MachineOperand &Count = LoLoop.getCount();
1135 
1136   MachineInstrBuilder MIB =
1137     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1138 
1139   MIB.addDef(ARM::LR);
1140   MIB.add(Count);
1141   if (!IsDo)
1142     MIB.add(Start->getOperand(1));
1143 
1144   // If we're inserting at a mov lr, then remove it as it's redundant.
1145   if (InsertPt != Start)
1146     LoLoop.ToRemove.insert(InsertPt);
1147   LoLoop.ToRemove.insert(Start);
1148   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1149   return &*MIB;
1150 }
1151 
1152 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1153   auto RemovePredicate = [](MachineInstr *MI) {
1154     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1155     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1156       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1157              "Expected Then predicate!");
1158       MI->getOperand(PIdx).setImm(ARMVCC::None);
1159       MI->getOperand(PIdx+1).setReg(0);
1160     } else
1161       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1162   };
1163 
1164   // There are a few scenarios which we have to fix up:
1165   // 1) A VPT block with is only predicated by the vctp and has no internal vpr
1166   //    defs.
1167   // 2) A VPT block which is only predicated by the vctp but has an internal
1168   //    vpr def.
1169   // 3) A VPT block which is predicated upon the vctp as well as another vpr
1170   //    def.
1171   // 4) A VPT block which is not predicated upon a vctp, but contains it and
1172   //    all instructions within the block are predicated upon in.
1173 
1174   for (auto &Block : LoLoop.getVPTBlocks()) {
1175     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1176     if (Block.HasNonUniformPredicate()) {
1177       PredicatedMI *Divergent = Block.getDivergent();
1178       if (isVCTP(Divergent->MI)) {
1179         // The vctp will be removed, so the size of the vpt block needs to be
1180         // modified.
1181         uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
1182         Block.getVPST()->getOperand(0).setImm(Size);
1183         LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
1184       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1185         // The VPT block has a non-uniform predicate but it's entry is guarded
1186         // only by a vctp, which means we:
1187         // - Need to remove the original vpst.
1188         // - Then need to unpredicate any following instructions, until
1189         //   we come across the divergent vpr def.
1190         // - Insert a new vpst to predicate the instruction(s) that following
1191         //   the divergent vpr def.
1192         // TODO: We could be producing more VPT blocks than necessary and could
1193         // fold the newly created one into a proceeding one.
1194         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
1195              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1196           RemovePredicate(&*I);
1197 
1198         unsigned Size = 0;
1199         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1200         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1201         MachineInstr *InsertAt = nullptr;
1202         while (I != E) {
1203           InsertAt = &*I;
1204           ++Size;
1205           ++I;
1206         }
1207         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1208                                           InsertAt->getDebugLoc(),
1209                                           TII->get(ARM::MVE_VPST));
1210         MIB.addImm(getARMVPTBlockMask(Size));
1211         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1212         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1213         LoLoop.ToRemove.insert(Block.getVPST());
1214       }
1215     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1216       // A vpt block which is only predicated upon vctp and has no internal vpr
1217       // defs:
1218       // - Remove vpst.
1219       // - Unpredicate the remaining instructions.
1220       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1221       LoLoop.ToRemove.insert(Block.getVPST());
1222       for (auto &PredMI : Insts)
1223         RemovePredicate(PredMI.MI);
1224     }
1225   }
1226   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
1227   LoLoop.ToRemove.insert(LoLoop.VCTP);
1228 }
1229 
1230 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1231 
1232   // Combine the LoopDec and LoopEnd instructions into LE(TP).
1233   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1234     MachineInstr *End = LoLoop.End;
1235     MachineBasicBlock *MBB = End->getParent();
1236     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1237       ARM::MVE_LETP : ARM::t2LEUpdate;
1238     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1239                                       TII->get(Opc));
1240     MIB.addDef(ARM::LR);
1241     MIB.add(End->getOperand(0));
1242     MIB.add(End->getOperand(1));
1243     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1244     LoLoop.ToRemove.insert(LoLoop.Dec);
1245     LoLoop.ToRemove.insert(End);
1246     return &*MIB;
1247   };
1248 
1249   // TODO: We should be able to automatically remove these branches before we
1250   // get here - probably by teaching analyzeBranch about the pseudo
1251   // instructions.
1252   // If there is an unconditional branch, after I, that just branches to the
1253   // next block, remove it.
1254   auto RemoveDeadBranch = [](MachineInstr *I) {
1255     MachineBasicBlock *BB = I->getParent();
1256     MachineInstr *Terminator = &BB->instr_back();
1257     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1258       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1259       if (BB->isLayoutSuccessor(Succ)) {
1260         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1261         Terminator->eraseFromParent();
1262       }
1263     }
1264   };
1265 
1266   if (LoLoop.Revert) {
1267     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1268       RevertWhile(LoLoop.Start);
1269     else
1270       LoLoop.Start->eraseFromParent();
1271     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1272     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1273   } else {
1274     LoLoop.Start = ExpandLoopStart(LoLoop);
1275     RemoveDeadBranch(LoLoop.Start);
1276     LoLoop.End = ExpandLoopEnd(LoLoop);
1277     RemoveDeadBranch(LoLoop.End);
1278     if (LoLoop.IsTailPredicationLegal())
1279       ConvertVPTBlocks(LoLoop);
1280     for (auto *I : LoLoop.ToRemove) {
1281       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1282       I->eraseFromParent();
1283     }
1284   }
1285 
1286   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1287   DFS.ProcessLoop();
1288   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1289   for (auto *MBB : PostOrder) {
1290     recomputeLiveIns(*MBB);
1291     // FIXME: For some reason, the live-in print order is non-deterministic for
1292     // our tests and I can't out why... So just sort them.
1293     MBB->sortUniqueLiveIns();
1294   }
1295 
1296   for (auto *MBB : reverse(PostOrder))
1297     recomputeLivenessFlags(*MBB);
1298 
1299   // We've moved, removed and inserted new instructions, so update RDA.
1300   RDA->reset();
1301 }
1302 
1303 bool ARMLowOverheadLoops::RevertNonLoops() {
1304   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1305   bool Changed = false;
1306 
1307   for (auto &MBB : *MF) {
1308     SmallVector<MachineInstr*, 4> Starts;
1309     SmallVector<MachineInstr*, 4> Decs;
1310     SmallVector<MachineInstr*, 4> Ends;
1311 
1312     for (auto &I : MBB) {
1313       if (isLoopStart(I))
1314         Starts.push_back(&I);
1315       else if (I.getOpcode() == ARM::t2LoopDec)
1316         Decs.push_back(&I);
1317       else if (I.getOpcode() == ARM::t2LoopEnd)
1318         Ends.push_back(&I);
1319     }
1320 
1321     if (Starts.empty() && Decs.empty() && Ends.empty())
1322       continue;
1323 
1324     Changed = true;
1325 
1326     for (auto *Start : Starts) {
1327       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1328         RevertWhile(Start);
1329       else
1330         Start->eraseFromParent();
1331     }
1332     for (auto *Dec : Decs)
1333       RevertLoopDec(Dec);
1334 
1335     for (auto *End : Ends)
1336       RevertLoopEnd(End);
1337   }
1338   return Changed;
1339 }
1340 
1341 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1342   return new ARMLowOverheadLoops();
1343 }
1344