xref: /llvm-project/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp (revision 94cacebccadf1e0821bdab6983d9f5251f73eab0)
1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo
10 /// instructions into machine operations.
11 /// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
16 /// - t2LoopEnd - the loop latch terminator.
17 ///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
21 ///
22 /// Assumptions and Dependencies:
23 /// Low-overhead loops are constructed and executed using a setup instruction:
24 /// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
25 /// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
26 /// but fixed polarity: WLS can only branch forwards and LE can only branch
27 /// backwards. These restrictions mean that this pass is dependent upon block
28 /// layout and block sizes, which is why it's the last pass to run. The same is
29 /// true for ConstantIslands, but this pass does not increase the size of the
30 /// basic blocks, nor does it change the CFG. Instructions are mainly removed
31 /// during the transform and pseudo instructions are replaced by real ones. In
32 /// some cases, when we have to revert to a 'normal' loop, we have to introduce
33 /// multiple instructions for a single pseudo (see RevertWhile and
34 /// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
35 /// are defined to be as large as this maximum sequence of replacement
36 /// instructions.
37 ///
38 //===----------------------------------------------------------------------===//
39 
40 #include "ARM.h"
41 #include "ARMBaseInstrInfo.h"
42 #include "ARMBaseRegisterInfo.h"
43 #include "ARMBasicBlockInfo.h"
44 #include "ARMSubtarget.h"
45 #include "Thumb2InstrInfo.h"
46 #include "llvm/ADT/SetOperations.h"
47 #include "llvm/ADT/SmallSet.h"
48 #include "llvm/CodeGen/LivePhysRegs.h"
49 #include "llvm/CodeGen/MachineFunctionPass.h"
50 #include "llvm/CodeGen/MachineLoopInfo.h"
51 #include "llvm/CodeGen/MachineLoopUtils.h"
52 #include "llvm/CodeGen/MachineRegisterInfo.h"
53 #include "llvm/CodeGen/Passes.h"
54 #include "llvm/CodeGen/ReachingDefAnalysis.h"
55 #include "llvm/MC/MCInstrDesc.h"
56 
57 using namespace llvm;
58 
59 #define DEBUG_TYPE "arm-low-overhead-loops"
60 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"
61 
62 namespace {
63 
64   using InstSet = SmallPtrSetImpl<MachineInstr *>;
65 
66   class PostOrderLoopTraversal {
67     MachineLoop &ML;
68     MachineLoopInfo &MLI;
69     SmallPtrSet<MachineBasicBlock*, 4> Visited;
70     SmallVector<MachineBasicBlock*, 4> Order;
71 
72   public:
73     PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
74       : ML(ML), MLI(MLI) { }
75 
76     const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
77       return Order;
78     }
79 
80     // Visit all the blocks within the loop, as well as exit blocks and any
81     // blocks properly dominating the header.
82     void ProcessLoop() {
83       std::function<void(MachineBasicBlock*)> Search = [this, &Search]
84         (MachineBasicBlock *MBB) -> void {
85         if (Visited.count(MBB))
86           return;
87 
88         Visited.insert(MBB);
89         for (auto *Succ : MBB->successors()) {
90           if (!ML.contains(Succ))
91             continue;
92           Search(Succ);
93         }
94         Order.push_back(MBB);
95       };
96 
97       // Insert exit blocks.
98       SmallVector<MachineBasicBlock*, 2> ExitBlocks;
99       ML.getExitBlocks(ExitBlocks);
100       for (auto *MBB : ExitBlocks)
101         Order.push_back(MBB);
102 
103       // Then add the loop body.
104       Search(ML.getHeader());
105 
106       // Then try the preheader and its predecessors.
107       std::function<void(MachineBasicBlock*)> GetPredecessor =
108         [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
109         Order.push_back(MBB);
110         if (MBB->pred_size() == 1)
111           GetPredecessor(*MBB->pred_begin());
112       };
113 
114       if (auto *Preheader = ML.getLoopPreheader())
115         GetPredecessor(Preheader);
116       else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
117         GetPredecessor(Preheader);
118     }
119   };
120 
121   struct PredicatedMI {
122     MachineInstr *MI = nullptr;
123     SetVector<MachineInstr*> Predicates;
124 
125   public:
126     PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
127       MI(I) { Predicates.insert(Preds.begin(), Preds.end()); }
128   };
129 
130   // Represent a VPT block, a list of instructions that begins with a VPST and
131   // has a maximum of four proceeding instructions. All instructions within the
132   // block are predicated upon the vpr and we allow instructions to define the
133   // vpr within in the block too.
134   class VPTBlock {
135     std::unique_ptr<PredicatedMI> VPST;
136     PredicatedMI *Divergent = nullptr;
137     SmallVector<PredicatedMI, 4> Insts;
138 
139   public:
140     VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
141       VPST = std::make_unique<PredicatedMI>(MI, Preds);
142     }
143 
144     void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
145       LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
146       if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
147         Divergent = &Insts.back();
148         LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
149       }
150       Insts.emplace_back(MI, Preds);
151       assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
152     }
153 
154     // Have we found an instruction within the block which defines the vpr? If
155     // so, not all the instructions in the block will have the same predicate.
156     bool HasNonUniformPredicate() const {
157       return Divergent != nullptr;
158     }
159 
160     // Is the given instruction part of the predicate set controlling the entry
161     // to the block.
162     bool IsPredicatedOn(MachineInstr *MI) const {
163       return VPST->Predicates.count(MI);
164     }
165 
166     // Is the given instruction the only predicate which controls the entry to
167     // the block.
168     bool IsOnlyPredicatedOn(MachineInstr *MI) const {
169       return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
170     }
171 
172     unsigned size() const { return Insts.size(); }
173     SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
174     MachineInstr *getVPST() const { return VPST->MI; }
175     PredicatedMI *getDivergent() const { return Divergent; }
176   };
177 
178   struct LowOverheadLoop {
179 
180     MachineLoop &ML;
181     MachineLoopInfo &MLI;
182     ReachingDefAnalysis &RDA;
183     const TargetRegisterInfo &TRI;
184     MachineFunction *MF = nullptr;
185     MachineInstr *InsertPt = nullptr;
186     MachineInstr *Start = nullptr;
187     MachineInstr *Dec = nullptr;
188     MachineInstr *End = nullptr;
189     MachineInstr *VCTP = nullptr;
190     VPTBlock *CurrentBlock = nullptr;
191     SetVector<MachineInstr*> CurrentPredicate;
192     SmallVector<VPTBlock, 4> VPTBlocks;
193     SmallPtrSet<MachineInstr*, 4> ToRemove;
194     bool Revert = false;
195     bool CannotTailPredicate = false;
196 
197     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
198                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI)
199       : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI) {
200       MF = ML.getHeader()->getParent();
201     }
202 
203     // If this is an MVE instruction, check that we know how to use tail
204     // predication with it. Record VPT blocks and return whether the
205     // instruction is valid for tail predication.
206     bool ValidateMVEInst(MachineInstr *MI);
207 
208     void AnalyseMVEInst(MachineInstr *MI) {
209       CannotTailPredicate = !ValidateMVEInst(MI);
210     }
211 
212     bool IsTailPredicationLegal() const {
213       // For now, let's keep things really simple and only support a single
214       // block for tail predication.
215       return !Revert && FoundAllComponents() && VCTP &&
216              !CannotTailPredicate && ML.getNumBlocks() == 1;
217     }
218 
219     // Check that the predication in the loop will be equivalent once we
220     // perform the conversion. Also ensure that we can provide the number
221     // of elements to the loop start instruction.
222     bool ValidateTailPredicate(MachineInstr *StartInsertPt);
223 
224     // Check that any values available outside of the loop will be the same
225     // after tail predication conversion.
226     bool ValidateLiveOuts() const;
227 
228     // Is it safe to define LR with DLS/WLS?
229     // LR can be defined if it is the operand to start, because it's the same
230     // value, or if it's going to be equivalent to the operand to Start.
231     MachineInstr *isSafeToDefineLR();
232 
233     // Check the branch targets are within range and we satisfy our
234     // restrictions.
235     void CheckLegality(ARMBasicBlockUtils *BBUtils);
236 
237     bool FoundAllComponents() const {
238       return Start && Dec && End;
239     }
240 
241     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
242 
243     // Return the loop iteration count, or the number of elements if we're tail
244     // predicating.
245     MachineOperand &getCount() {
246       return IsTailPredicationLegal() ?
247         VCTP->getOperand(1) : Start->getOperand(0);
248     }
249 
250     unsigned getStartOpcode() const {
251       bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
252       if (!IsTailPredicationLegal())
253         return IsDo ? ARM::t2DLS : ARM::t2WLS;
254 
255       return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo);
256     }
257 
258     void dump() const {
259       if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
260       if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
261       if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
262       if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP;
263       if (!FoundAllComponents())
264         dbgs() << "ARM Loops: Not a low-overhead loop.\n";
265       else if (!(Start && Dec && End))
266         dbgs() << "ARM Loops: Failed to find all loop components.\n";
267     }
268   };
269 
270   class ARMLowOverheadLoops : public MachineFunctionPass {
271     MachineFunction           *MF = nullptr;
272     MachineLoopInfo           *MLI = nullptr;
273     ReachingDefAnalysis       *RDA = nullptr;
274     const ARMBaseInstrInfo    *TII = nullptr;
275     MachineRegisterInfo       *MRI = nullptr;
276     const TargetRegisterInfo  *TRI = nullptr;
277     std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;
278 
279   public:
280     static char ID;
281 
282     ARMLowOverheadLoops() : MachineFunctionPass(ID) { }
283 
284     void getAnalysisUsage(AnalysisUsage &AU) const override {
285       AU.setPreservesCFG();
286       AU.addRequired<MachineLoopInfo>();
287       AU.addRequired<ReachingDefAnalysis>();
288       MachineFunctionPass::getAnalysisUsage(AU);
289     }
290 
291     bool runOnMachineFunction(MachineFunction &MF) override;
292 
293     MachineFunctionProperties getRequiredProperties() const override {
294       return MachineFunctionProperties().set(
295           MachineFunctionProperties::Property::NoVRegs).set(
296           MachineFunctionProperties::Property::TracksLiveness);
297     }
298 
299     StringRef getPassName() const override {
300       return ARM_LOW_OVERHEAD_LOOPS_NAME;
301     }
302 
303   private:
304     bool ProcessLoop(MachineLoop *ML);
305 
306     bool RevertNonLoops();
307 
308     void RevertWhile(MachineInstr *MI) const;
309 
310     bool RevertLoopDec(MachineInstr *MI) const;
311 
312     void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;
313 
314     void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
315 
316     MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
317 
318     void Expand(LowOverheadLoop &LoLoop);
319 
320     void IterationCountDCE(LowOverheadLoop &LoLoop);
321   };
322 }
323 
// Pass identifier; passed to MachineFunctionPass by the constructor.
char ARMLowOverheadLoops::ID = 0;

// Register the pass with DEBUG_TYPE ("arm-low-overhead-loops") as its
// argument name and ARM_LOW_OVERHEAD_LOOPS_NAME as its description. The two
// trailing flags are both false (presumably "CFG only" / "is analysis" -
// confirm against the INITIALIZE_PASS definition).
INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)
328 
329 MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
330   // We can define LR because LR already contains the same value.
331   if (Start->getOperand(0).getReg() == ARM::LR)
332     return Start;
333 
334   unsigned CountReg = Start->getOperand(0).getReg();
335   auto IsMoveLR = [&CountReg](MachineInstr *MI) {
336     return MI->getOpcode() == ARM::tMOVr &&
337            MI->getOperand(0).getReg() == ARM::LR &&
338            MI->getOperand(1).getReg() == CountReg &&
339            MI->getOperand(2).getImm() == ARMCC::AL;
340    };
341 
342   MachineBasicBlock *MBB = Start->getParent();
343 
344   // Find an insertion point:
345   // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
346   //   to Count before Start, we can insert at that mov.
347   if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
348     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
349       return LRDef;
350 
351   // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
352   //   to Count after Start, we can insert at that mov.
353   if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
354     if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
355       return LRDef;
356 
357   // We've found no suitable LR def and Start doesn't use LR directly. Can we
358   // just define LR anyway?
359   return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
360 }
361 
362 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
363   assert(VCTP && "VCTP instruction expected but is not set");
364   // All predication within the loop should be based on vctp. If the block
365   // isn't predicated on entry, check whether the vctp is within the block
366   // and that all other instructions are then predicated on it.
367   for (auto &Block : VPTBlocks) {
368     if (Block.IsPredicatedOn(VCTP))
369       continue;
370     if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
371       LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: "
372                  << *Block.getDivergent()->MI);
373       return false;
374     }
375     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
376     for (auto &PredMI : Insts) {
377       if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
378         continue;
379       LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
380                  << " - which is predicated on:\n";
381                  for (auto *MI : PredMI.Predicates)
382                    dbgs() << "   - " << *MI);
383       return false;
384     }
385   }
386 
387   if (!ValidateLiveOuts())
388     return false;
389 
390   // For tail predication, we need to provide the number of elements, instead
391   // of the iteration count, to the loop start instruction. The number of
392   // elements is provided to the vctp instruction, so we need to check that
393   // we can use this register at InsertPt.
394   Register NumElements = VCTP->getOperand(1).getReg();
395 
396   // If the register is defined within loop, then we can't perform TP.
397   // TODO: Check whether this is just a mov of a register that would be
398   // available.
399   if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
400     LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
401     return false;
402   }
403 
404   // The element count register maybe defined after InsertPt, in which case we
405   // need to try to move either InsertPt or the def so that the [w|d]lstp can
406   // use the value.
407   // TODO: On failing to move an instruction, check if the count is provided by
408   // a mov and whether we can use the mov operand directly.
409   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
410   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
411     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
412       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
413         ElemDef->removeFromParent();
414         InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
415         LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
416                    << *ElemDef);
417       } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
418         StartInsertPt->removeFromParent();
419         InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
420                               StartInsertPt);
421         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
422       } else {
423         LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
424                    << "start instruction.\n");
425         return false;
426       }
427     }
428   }
429 
430   // Especially in the case of while loops, InsertBB may not be the
431   // preheader, so we need to check that the register isn't redefined
432   // before entering the loop.
433   auto CannotProvideElements = [this](MachineBasicBlock *MBB,
434                                       Register NumElements) {
435     // NumElements is redefined in this block.
436     if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
437       return true;
438 
439     // Don't continue searching up through multiple predecessors.
440     if (MBB->pred_size() > 1)
441       return true;
442 
443     return false;
444   };
445 
446   // First, find the block that looks like the preheader.
447   MachineBasicBlock *MBB = MLI.findLoopPreheader(&ML, true);
448   if (!MBB) {
449     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
450     return false;
451   }
452 
453   // Then search backwards for a def, until we get to InsertBB.
454   while (MBB != InsertBB) {
455     if (CannotProvideElements(MBB, NumElements)) {
456       LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
457       return false;
458     }
459     MBB = *MBB->pred_begin();
460   }
461 
462   // Check that the value change of the element count is what we expect and
463   // that the predication will be equivalent. For this we need:
464   // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
465   // and we can also allow register copies within the chain too.
466   auto IsValidSub = [](MachineInstr *MI, unsigned ExpectedVecWidth) {
467     unsigned ImmOpIdx = 0;
468     switch (MI->getOpcode()) {
469     default:
470       llvm_unreachable("unhandled sub opcode");
471     case ARM::tSUBi3:
472     case ARM::tSUBi8:
473       ImmOpIdx = 3;
474       break;
475     case ARM::t2SUBri:
476     case ARM::t2SUBri12:
477       ImmOpIdx = 2;
478       break;
479     }
480     return MI->getOperand(ImmOpIdx).getImm() == ExpectedVecWidth;
481   };
482 
483   MBB = VCTP->getParent();
484   if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) {
485     SmallPtrSet<MachineInstr*, 2> ElementChain;
486     SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP };
487     unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());
488 
489     if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
490       bool FoundSub = false;
491 
492       for (auto *MI : ElementChain) {
493         if (isMovRegOpcode(MI->getOpcode()))
494           continue;
495 
496         if (isSubImmOpcode(MI->getOpcode())) {
497           if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
498             return false;
499           FoundSub = true;
500         } else
501           return false;
502       }
503 
504       LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
505                  for (auto *MI : ElementChain)
506                    dbgs() << " - " << *MI);
507       ToRemove.insert(ElementChain.begin(), ElementChain.end());
508     }
509   }
510   return true;
511 }
512 
513 static bool isVectorPredicated(MachineInstr *MI) {
514   int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
515   return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
516 }
517 
518 static bool isRegInClass(const MachineOperand &MO,
519                          const TargetRegisterClass *Class) {
520   return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
521 }
522 
523 // Can this instruction generate a non-zero result when given only zeroed
524 // operands? This allows us to know that, given operands with false bytes
525 // zeroed by masked loads, that the result will also contain zeros in those
526 // bytes.
527 static bool canGenerateNonZeros(const MachineInstr &MI) {
528   switch (MI.getOpcode()) {
529   default:
530     break;
531   // FIXME: FP minus 0?
532   //case ARM::MVE_VNEGf16:
533   //case ARM::MVE_VNEGf32:
534   case ARM::MVE_VMVN:
535   case ARM::MVE_VORN:
536   case ARM::MVE_VCLZs8:
537   case ARM::MVE_VCLZs16:
538   case ARM::MVE_VCLZs32:
539     return true;
540   }
541   return false;
542 }
543 
544 // MVE 'narrowing' operate on half a lane, reading from half and writing
545 // to half, which are referred to has the top and bottom half. The other
546 // half retains its previous value.
547 static bool retainsPreviousHalfElement(const MachineInstr &MI) {
548   const MCInstrDesc &MCID = MI.getDesc();
549   uint64_t Flags = MCID.TSFlags;
550   return (Flags & ARMII::RetainsPreviousHalfElement) != 0;
551 }
552 
553 // Look at its register uses to see if it only can only receive zeros
554 // into its false lanes which would then produce zeros. Also check that
555 // the output register is also defined by an FalseLaneZeros instruction
556 // so that if tail-predication happens, the lanes that aren't updated will
557 // still be zeros.
558 static bool producesFalseLaneZeros(MachineInstr &MI,
559                                    const TargetRegisterClass *QPRs,
560                                    const ReachingDefAnalysis &RDA,
561                                    InstSet &FalseLaneZeros) {
562   if (canGenerateNonZeros(MI))
563     return false;
564   for (auto &MO : MI.operands()) {
565     if (!MO.isReg() || !MO.getReg())
566       continue;
567     if (auto *OpDef = RDA.getMIOperand(&MI, MO))
568       if (FalseLaneZeros.count(OpDef))
569        continue;
570     return false;
571   }
572   LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
573   return true;
574 }
575 
576 bool LowOverheadLoop::ValidateLiveOuts() const {
577   // We want to find out if the tail-predicated version of this loop will
578   // produce the same values as the loop in its original form. For this to
579   // be true, the newly inserted implicit predication must not change the
580   // the (observable) results.
581   // We're doing this because many instructions in the loop will not be
582   // predicated and so the conversion from VPT predication to tail-predication
583   // can result in different values being produced; due to the tail-predication
584   // preventing many instructions from updating their falsely predicated
585   // lanes. This analysis assumes that all the instructions perform lane-wise
586   // operations and don't perform any exchanges.
587   // A masked load, whether through VPT or tail predication, will write zeros
588   // to any of the falsely predicated bytes. So, from the loads, we know that
589   // the false lanes are zeroed and here we're trying to track that those false
590   // lanes remain zero, or where they change, the differences are masked away
591   // by their user(s).
592   // All MVE loads and stores have to be predicated, so we know that any load
593   // operands, or stored results are equivalent already. Other explicitly
594   // predicated instructions will perform the same operation in the original
595   // loop and the tail-predicated form too. Because of this, we can insert
596   // loads, stores and other predicated instructions into our Predicated
597   // set and build from there.
598   const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
599   SetVector<MachineInstr *> Unknown;
600   SmallPtrSet<MachineInstr *, 4> FalseLaneZeros;
601   SmallPtrSet<MachineInstr *, 4> Predicated;
602   MachineBasicBlock *MBB = ML.getHeader();
603 
604   for (auto &MI : *MBB) {
605     const MCInstrDesc &MCID = MI.getDesc();
606     uint64_t Flags = MCID.TSFlags;
607     if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
608       continue;
609 
610     if (isVectorPredicated(&MI)) {
611       if (MI.mayLoad())
612         FalseLaneZeros.insert(&MI);
613       Predicated.insert(&MI);
614       continue;
615     }
616 
617     if (MI.getNumDefs() == 0)
618       continue;
619 
620     if (producesFalseLaneZeros(MI, QPRs, RDA, FalseLaneZeros))
621       FalseLaneZeros.insert(&MI);
622     else if (retainsPreviousHalfElement(MI))
623       return false;
624     else
625       Unknown.insert(&MI);
626   }
627 
628   auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
629                               SmallPtrSetImpl<MachineInstr *> &Predicated) {
630     SmallPtrSet<MachineInstr *, 2> Uses;
631     RDA.getGlobalUses(MI, MO.getReg(), Uses);
632     for (auto *Use : Uses) {
633       if (Use != MI && !Predicated.count(Use))
634         return false;
635     }
636     return true;
637   };
638 
639   // Visit the unknowns in reverse so that we can start at the values being
640   // stored and then we can work towards the leaves, hopefully adding more
641   // instructions to Predicated.
642   for (auto *MI : reverse(Unknown)) {
643     for (auto &MO : MI->operands()) {
644       if (!isRegInClass(MO, QPRs) || !MO.isDef())
645         continue;
646       if (!HasPredicatedUsers(MI, MO, Predicated)) {
647         LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
648                           << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
649         return false;
650       }
651     }
652     // Any unknown false lanes have been masked away by the user(s).
653     Predicated.insert(MI);
654   }
655 
656   // Collect Q-regs that are live in the exit blocks. We don't collect scalars
657   // because they won't be affected by lane predication.
658   SmallSet<Register, 2> LiveOuts;
659   SmallVector<MachineBasicBlock *, 2> ExitBlocks;
660   ML.getExitBlocks(ExitBlocks);
661   for (auto *MBB : ExitBlocks)
662     for (const MachineBasicBlock::RegisterMaskPair &RegMask : MBB->liveins())
663       if (QPRs->contains(RegMask.PhysReg))
664         LiveOuts.insert(RegMask.PhysReg);
665 
666   // Collect the instructions in the loop body that define the live-out values.
667   SmallPtrSet<MachineInstr *, 2> LiveMIs;
668   assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
669   for (auto Reg : LiveOuts)
670     if (auto *MI = RDA.getLocalLiveOutMIDef(MBB, Reg))
671       LiveMIs.insert(MI);
672 
673   LLVM_DEBUG(dbgs() << "ARM Loops: Found loop live-outs:\n";
674              for (auto *MI : LiveMIs)
675                dbgs() << " - " << *MI);
676   // We've already validated that any VPT predication within the loop will be
677   // equivalent when we perform the predication transformation; so we know that
678   // any VPT predicated instruction is predicated upon VCTP. Any live-out
679   // instruction needs to be predicated, so check this here.
680   for (auto *MI : LiveMIs)
681     if (!isVectorPredicated(MI))
682       return false;
683 
684   return true;
685 }
686 
687 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
688   if (Revert)
689     return;
690 
691   if (!End->getOperand(1).isMBB())
692     report_fatal_error("Expected LoopEnd to target basic block");
693 
694   // TODO Maybe there's cases where the target doesn't have to be the header,
695   // but for now be safe and revert.
696   if (End->getOperand(1).getMBB() != ML.getHeader()) {
697     LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
698     Revert = true;
699     return;
700   }
701 
702   // The WLS and LE instructions have 12-bits for the label offset. WLS
703   // requires a positive offset, while LE uses negative.
704   if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
705       !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
706     LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
707     Revert = true;
708     return;
709   }
710 
711   if (Start->getOpcode() == ARM::t2WhileLoopStart &&
712       (BBUtils->getOffsetOf(Start) >
713        BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
714        !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
715     LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
716     Revert = true;
717     return;
718   }
719 
720   InsertPt = Revert ? nullptr : isSafeToDefineLR();
721   if (!InsertPt) {
722     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
723     Revert = true;
724     return;
725   } else
726     LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
727 
728   if (!IsTailPredicationLegal()) {
729     LLVM_DEBUG(if (!VCTP)
730                  dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
731                dbgs() << "ARM Loops: Tail-predication is not valid.\n");
732     return;
733   }
734 
735   assert(ML.getBlocks().size() == 1 &&
736          "Shouldn't be processing a loop with more than one block");
737   CannotTailPredicate = !ValidateTailPredicate(InsertPt);
738   LLVM_DEBUG(if (CannotTailPredicate)
739              dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
740 }
741 
742 bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
743   if (CannotTailPredicate)
744     return false;
745 
746   // Only support a single vctp.
747   if (isVCTP(MI) && VCTP)
748     return false;
749 
750   // Start a new vpt block when we discover a vpt.
751   if (MI->getOpcode() == ARM::MVE_VPST) {
752     VPTBlocks.emplace_back(MI, CurrentPredicate);
753     CurrentBlock = &VPTBlocks.back();
754     return true;
755   } else if (isVCTP(MI))
756     VCTP = MI;
757   else if (MI->getOpcode() == ARM::MVE_VPSEL ||
758            MI->getOpcode() == ARM::MVE_VPNOT)
759     return false;
760 
761   // TODO: Allow VPSEL and VPNOT, we currently cannot because:
762   // 1) It will use the VPR as a predicate operand, but doesn't have to be
763   //    instead a VPT block, which means we can assert while building up
764   //    the VPT block because we don't find another VPST to being a new
765   //    one.
766   // 2) VPSEL still requires a VPR operand even after tail predicating,
767   //    which means we can't remove it unless there is another
768   //    instruction, such as vcmp, that can provide the VPR def.
769 
770   bool IsUse = false;
771   bool IsDef = false;
772   const MCInstrDesc &MCID = MI->getDesc();
773   for (int i = MI->getNumOperands() - 1; i >= 0; --i) {
774     const MachineOperand &MO = MI->getOperand(i);
775     if (!MO.isReg() || MO.getReg() != ARM::VPR)
776       continue;
777 
778     if (MO.isDef()) {
779       CurrentPredicate.insert(MI);
780       IsDef = true;
781     } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) {
782       CurrentBlock->addInst(MI, CurrentPredicate);
783       IsUse = true;
784     } else {
785       LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
786       return false;
787     }
788   }
789 
790   // If we find a vpr def that is not already predicated on the vctp, we've
791   // got disjoint predicates that may not be equivalent when we do the
792   // conversion.
793   if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
794     LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
795     return false;
796   }
797 
798   uint64_t Flags = MCID.TSFlags;
799   if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
800     return true;
801 
802   // If we find an instruction that has been marked as not valid for tail
803   // predication, only allow the instruction if it's contained within a valid
804   // VPT block.
805   if ((Flags & ARMII::ValidForTailPredication) == 0 && !IsUse) {
806     LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
807     return false;
808   }
809 
810   // If the instruction is already explicitly predicated, then the conversion
811   // will be fine, but ensure that all memory operations are predicated.
812   return !IsUse && MI->mayLoadOrStore() ? false : true;
813 }
814 
815 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
816   const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
817   if (!ST.hasLOB())
818     return false;
819 
820   MF = &mf;
821   LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");
822 
823   MLI = &getAnalysis<MachineLoopInfo>();
824   RDA = &getAnalysis<ReachingDefAnalysis>();
825   MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
826   MRI = &MF->getRegInfo();
827   TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
828   TRI = ST.getRegisterInfo();
829   BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
830   BBUtils->computeAllBlockSizes();
831   BBUtils->adjustBBOffsetsAfter(&MF->front());
832 
833   bool Changed = false;
834   for (auto ML : *MLI) {
835     if (!ML->getParentLoop())
836       Changed |= ProcessLoop(ML);
837   }
838   Changed |= RevertNonLoops();
839   return Changed;
840 }
841 
842 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
843 
844   bool Changed = false;
845 
846   // Process inner loops first.
847   for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
848     Changed |= ProcessLoop(*I);
849 
850   LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
851              if (auto *Preheader = ML->getLoopPreheader())
852                dbgs() << " - " << Preheader->getName() << "\n";
853              else if (auto *Preheader = MLI->findLoopPreheader(ML))
854                dbgs() << " - " << Preheader->getName() << "\n";
855              else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
856                dbgs() << " - " << Preheader->getName() << "\n";
857              for (auto *MBB : ML->getBlocks())
858                dbgs() << " - " << MBB->getName() << "\n";
859             );
860 
861   // Search the given block for a loop start instruction. If one isn't found,
862   // and there's only one predecessor block, search that one too.
863   std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
864     [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
865     for (auto &MI : *MBB) {
866       if (isLoopStart(MI))
867         return &MI;
868     }
869     if (MBB->pred_size() == 1)
870       return SearchForStart(*MBB->pred_begin());
871     return nullptr;
872   };
873 
874   LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI);
875   // Search the preheader for the start intrinsic.
876   // FIXME: I don't see why we shouldn't be supporting multiple predecessors
877   // with potentially multiple set.loop.iterations, so we need to enable this.
878   if (auto *Preheader = ML->getLoopPreheader())
879     LoLoop.Start = SearchForStart(Preheader);
880   else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
881     LoLoop.Start = SearchForStart(Preheader);
882   else
883     return false;
884 
885   // Find the low-overhead loop components and decide whether or not to fall
886   // back to a normal loop. Also look for a vctp instructions and decide
887   // whether we can convert that predicate using tail predication.
888   for (auto *MBB : reverse(ML->getBlocks())) {
889     for (auto &MI : *MBB) {
890       if (MI.isDebugValue())
891         continue;
892       else if (MI.getOpcode() == ARM::t2LoopDec)
893         LoLoop.Dec = &MI;
894       else if (MI.getOpcode() == ARM::t2LoopEnd)
895         LoLoop.End = &MI;
896       else if (isLoopStart(MI))
897         LoLoop.Start = &MI;
898       else if (MI.getDesc().isCall()) {
899         // TODO: Though the call will require LE to execute again, does this
900         // mean we should revert? Always executing LE hopefully should be
901         // faster than performing a sub,cmp,br or even subs,br.
902         LoLoop.Revert = true;
903         LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
904       } else {
905         // Record VPR defs and build up their corresponding vpt blocks.
906         // Check we know how to tail predicate any mve instructions.
907         LoLoop.AnalyseMVEInst(&MI);
908       }
909     }
910   }
911 
912   LLVM_DEBUG(LoLoop.dump());
913   if (!LoLoop.FoundAllComponents()) {
914     LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
915     return false;
916   }
917 
918   // Check that the only instruction using LoopDec is LoopEnd.
919   // TODO: Check for copy chains that really have no effect.
920   SmallPtrSet<MachineInstr*, 2> Uses;
921   RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
922   if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
923     LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
924     LoLoop.Revert = true;
925   }
926   LoLoop.CheckLegality(BBUtils.get());
927   Expand(LoLoop);
928   return true;
929 }
930 
931 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
932 // beq that branches to the exit branch.
933 // TODO: We could also try to generate a cbz if the value in LR is also in
934 // another low register.
935 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
936   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
937   MachineBasicBlock *MBB = MI->getParent();
938   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
939                                     TII->get(ARM::t2CMPri));
940   MIB.add(MI->getOperand(0));
941   MIB.addImm(0);
942   MIB.addImm(ARMCC::AL);
943   MIB.addReg(ARM::NoRegister);
944 
945   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
946   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
947     ARM::tBcc : ARM::t2Bcc;
948 
949   MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
950   MIB.add(MI->getOperand(1));   // branch target
951   MIB.addImm(ARMCC::EQ);        // condition code
952   MIB.addReg(ARM::CPSR);
953   MI->eraseFromParent();
954 }
955 
956 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
957   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
958   MachineBasicBlock *MBB = MI->getParent();
959   SmallPtrSet<MachineInstr*, 1> Ignore;
960   for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
961     if (I->getOpcode() == ARM::t2LoopEnd) {
962       Ignore.insert(&*I);
963       break;
964     }
965   }
966 
967   // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
968   bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);
969 
970   MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
971                                     TII->get(ARM::t2SUBri));
972   MIB.addDef(ARM::LR);
973   MIB.add(MI->getOperand(1));
974   MIB.add(MI->getOperand(2));
975   MIB.addImm(ARMCC::AL);
976   MIB.addReg(0);
977 
978   if (SetFlags) {
979     MIB.addReg(ARM::CPSR);
980     MIB->getOperand(5).setIsDef(true);
981   } else
982     MIB.addReg(0);
983 
984   MI->eraseFromParent();
985   return SetFlags;
986 }
987 
988 // Generate a subs, or sub and cmp, and a branch instead of an LE.
989 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
990   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);
991 
992   MachineBasicBlock *MBB = MI->getParent();
993   // Create cmp
994   if (!SkipCmp) {
995     MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
996                                       TII->get(ARM::t2CMPri));
997     MIB.addReg(ARM::LR);
998     MIB.addImm(0);
999     MIB.addImm(ARMCC::AL);
1000     MIB.addReg(ARM::NoRegister);
1001   }
1002 
1003   MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
1004   unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
1005     ARM::tBcc : ARM::t2Bcc;
1006 
1007   // Create bne
1008   MachineInstrBuilder MIB =
1009     BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
1010   MIB.add(MI->getOperand(1));   // branch target
1011   MIB.addImm(ARMCC::NE);        // condition code
1012   MIB.addReg(ARM::CPSR);
1013   MI->eraseFromParent();
1014 }
1015 
1016 // Perform dead code elimation on the loop iteration count setup expression.
1017 // If we are tail-predicating, the number of elements to be processed is the
1018 // operand of the VCTP instruction in the vector body, see getCount(), which is
1019 // register $r3 in this example:
1020 //
1021 //   $lr = big-itercount-expression
1022 //   ..
1023 //   t2DoLoopStart renamable $lr
1024 //   vector.body:
1025 //     ..
1026 //     $vpr = MVE_VCTP32 renamable $r3
1027 //     renamable $lr = t2LoopDec killed renamable $lr, 1
1028 //     t2LoopEnd renamable $lr, %vector.body
1029 //     tB %end
1030 //
1031 // What we would like achieve here is to replace the do-loop start pseudo
1032 // instruction t2DoLoopStart with:
1033 //
1034 //    $lr = MVE_DLSTP_32 killed renamable $r3
1035 //
1036 // Thus, $r3 which defines the number of elements, is written to $lr,
1037 // and then we want to delete the whole chain that used to define $lr,
1038 // see the comment below how this chain could look like.
1039 //
1040 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
1041   if (!LoLoop.IsTailPredicationLegal())
1042     return;
1043 
1044   LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");
1045 
1046   MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
1047   if (!Def) {
1048     LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
1049     return;
1050   }
1051 
1052   // Collect and remove the users of iteration count.
1053   SmallPtrSet<MachineInstr*, 4> Killed  = { LoLoop.Start, LoLoop.Dec,
1054                                             LoLoop.End, LoLoop.InsertPt };
1055   SmallPtrSet<MachineInstr*, 2> Remove;
1056   if (RDA->isSafeToRemove(Def, Remove, Killed))
1057     LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
1058   else {
1059     LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
1060     return;
1061   }
1062 
1063   // Collect the dead code and the MBBs in which they reside.
1064   RDA->collectKilledOperands(Def, Killed);
1065   SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
1066   for (auto *MI : Killed)
1067     BasicBlocks.insert(MI->getParent());
1068 
1069   // Collect IT blocks in all affected basic blocks.
1070   std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
1071   for (auto *MBB : BasicBlocks) {
1072     for (auto &MI : *MBB) {
1073       if (MI.getOpcode() != ARM::t2IT)
1074         continue;
1075       RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
1076     }
1077   }
1078 
1079   // If we're removing all of the instructions within an IT block, then
1080   // also remove the IT instruction.
1081   SmallPtrSet<MachineInstr*, 2> ModifiedITs;
1082   for (auto *MI : Killed) {
1083     if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
1084       MachineInstr *IT = RDA->getMIOperand(MI, *MO);
1085       auto &CurrentBlock = ITBlocks[IT];
1086       CurrentBlock.erase(MI);
1087       if (CurrentBlock.empty())
1088         ModifiedITs.erase(IT);
1089       else
1090         ModifiedITs.insert(IT);
1091     }
1092   }
1093 
1094   // Delete the killed instructions only if we don't have any IT blocks that
1095   // need to be modified because we need to fixup the mask.
1096   // TODO: Handle cases where IT blocks are modified.
1097   if (ModifiedITs.empty()) {
1098     LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
1099                for (auto *MI : Killed)
1100                  dbgs() << " - " << *MI);
1101     LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
1102   } else
1103     LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
1104 }
1105 
1106 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
1107   LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
1108   // When using tail-predication, try to delete the dead code that was used to
1109   // calculate the number of loop iterations.
1110   IterationCountDCE(LoLoop);
1111 
1112   MachineInstr *InsertPt = LoLoop.InsertPt;
1113   MachineInstr *Start = LoLoop.Start;
1114   MachineBasicBlock *MBB = InsertPt->getParent();
1115   bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
1116   unsigned Opc = LoLoop.getStartOpcode();
1117   MachineOperand &Count = LoLoop.getCount();
1118 
1119   MachineInstrBuilder MIB =
1120     BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));
1121 
1122   MIB.addDef(ARM::LR);
1123   MIB.add(Count);
1124   if (!IsDo)
1125     MIB.add(Start->getOperand(1));
1126 
1127   // If we're inserting at a mov lr, then remove it as it's redundant.
1128   if (InsertPt != Start)
1129     LoLoop.ToRemove.insert(InsertPt);
1130   LoLoop.ToRemove.insert(Start);
1131   LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
1132   return &*MIB;
1133 }
1134 
1135 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
1136   auto RemovePredicate = [](MachineInstr *MI) {
1137     LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
1138     if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
1139       assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
1140              "Expected Then predicate!");
1141       MI->getOperand(PIdx).setImm(ARMVCC::None);
1142       MI->getOperand(PIdx+1).setReg(0);
1143     } else
1144       llvm_unreachable("trying to unpredicate a non-predicated instruction");
1145   };
1146 
1147   // There are a few scenarios which we have to fix up:
1148   // 1) A VPT block with is only predicated by the vctp and has no internal vpr
1149   //    defs.
1150   // 2) A VPT block which is only predicated by the vctp but has an internal
1151   //    vpr def.
1152   // 3) A VPT block which is predicated upon the vctp as well as another vpr
1153   //    def.
1154   // 4) A VPT block which is not predicated upon a vctp, but contains it and
1155   //    all instructions within the block are predicated upon in.
1156 
1157   for (auto &Block : LoLoop.getVPTBlocks()) {
1158     SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
1159     if (Block.HasNonUniformPredicate()) {
1160       PredicatedMI *Divergent = Block.getDivergent();
1161       if (isVCTP(Divergent->MI)) {
1162         // The vctp will be removed, so the size of the vpt block needs to be
1163         // modified.
1164         uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
1165         Block.getVPST()->getOperand(0).setImm(Size);
1166         LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
1167       } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1168         // The VPT block has a non-uniform predicate but it's entry is guarded
1169         // only by a vctp, which means we:
1170         // - Need to remove the original vpst.
1171         // - Then need to unpredicate any following instructions, until
1172         //   we come across the divergent vpr def.
1173         // - Insert a new vpst to predicate the instruction(s) that following
1174         //   the divergent vpr def.
1175         // TODO: We could be producing more VPT blocks than necessary and could
1176         // fold the newly created one into a proceeding one.
1177         for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
1178              E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
1179           RemovePredicate(&*I);
1180 
1181         unsigned Size = 0;
1182         auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
1183         auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
1184         MachineInstr *InsertAt = nullptr;
1185         while (I != E) {
1186           InsertAt = &*I;
1187           ++Size;
1188           ++I;
1189         }
1190         MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
1191                                           InsertAt->getDebugLoc(),
1192                                           TII->get(ARM::MVE_VPST));
1193         MIB.addImm(getARMVPTBlockMask(Size));
1194         LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1195         LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
1196         LoLoop.ToRemove.insert(Block.getVPST());
1197       }
1198     } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
1199       // A vpt block which is only predicated upon vctp and has no internal vpr
1200       // defs:
1201       // - Remove vpst.
1202       // - Unpredicate the remaining instructions.
1203       LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
1204       LoLoop.ToRemove.insert(Block.getVPST());
1205       for (auto &PredMI : Insts)
1206         RemovePredicate(PredMI.MI);
1207     }
1208   }
1209   LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
1210   LoLoop.ToRemove.insert(LoLoop.VCTP);
1211 }
1212 
1213 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
1214 
1215   // Combine the LoopDec and LoopEnd instructions into LE(TP).
1216   auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
1217     MachineInstr *End = LoLoop.End;
1218     MachineBasicBlock *MBB = End->getParent();
1219     unsigned Opc = LoLoop.IsTailPredicationLegal() ?
1220       ARM::MVE_LETP : ARM::t2LEUpdate;
1221     MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
1222                                       TII->get(Opc));
1223     MIB.addDef(ARM::LR);
1224     MIB.add(End->getOperand(0));
1225     MIB.add(End->getOperand(1));
1226     LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
1227     LoLoop.ToRemove.insert(LoLoop.Dec);
1228     LoLoop.ToRemove.insert(End);
1229     return &*MIB;
1230   };
1231 
1232   // TODO: We should be able to automatically remove these branches before we
1233   // get here - probably by teaching analyzeBranch about the pseudo
1234   // instructions.
1235   // If there is an unconditional branch, after I, that just branches to the
1236   // next block, remove it.
1237   auto RemoveDeadBranch = [](MachineInstr *I) {
1238     MachineBasicBlock *BB = I->getParent();
1239     MachineInstr *Terminator = &BB->instr_back();
1240     if (Terminator->isUnconditionalBranch() && I != Terminator) {
1241       MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
1242       if (BB->isLayoutSuccessor(Succ)) {
1243         LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
1244         Terminator->eraseFromParent();
1245       }
1246     }
1247   };
1248 
1249   if (LoLoop.Revert) {
1250     if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
1251       RevertWhile(LoLoop.Start);
1252     else
1253       LoLoop.Start->eraseFromParent();
1254     bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
1255     RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
1256   } else {
1257     LoLoop.Start = ExpandLoopStart(LoLoop);
1258     RemoveDeadBranch(LoLoop.Start);
1259     LoLoop.End = ExpandLoopEnd(LoLoop);
1260     RemoveDeadBranch(LoLoop.End);
1261     if (LoLoop.IsTailPredicationLegal())
1262       ConvertVPTBlocks(LoLoop);
1263     for (auto *I : LoLoop.ToRemove) {
1264       LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
1265       I->eraseFromParent();
1266     }
1267   }
1268 
1269   PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
1270   DFS.ProcessLoop();
1271   const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
1272   for (auto *MBB : PostOrder) {
1273     recomputeLiveIns(*MBB);
1274     // FIXME: For some reason, the live-in print order is non-deterministic for
1275     // our tests and I can't out why... So just sort them.
1276     MBB->sortUniqueLiveIns();
1277   }
1278 
1279   for (auto *MBB : reverse(PostOrder))
1280     recomputeLivenessFlags(*MBB);
1281 
1282   // We've moved, removed and inserted new instructions, so update RDA.
1283   RDA->reset();
1284 }
1285 
1286 bool ARMLowOverheadLoops::RevertNonLoops() {
1287   LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
1288   bool Changed = false;
1289 
1290   for (auto &MBB : *MF) {
1291     SmallVector<MachineInstr*, 4> Starts;
1292     SmallVector<MachineInstr*, 4> Decs;
1293     SmallVector<MachineInstr*, 4> Ends;
1294 
1295     for (auto &I : MBB) {
1296       if (isLoopStart(I))
1297         Starts.push_back(&I);
1298       else if (I.getOpcode() == ARM::t2LoopDec)
1299         Decs.push_back(&I);
1300       else if (I.getOpcode() == ARM::t2LoopEnd)
1301         Ends.push_back(&I);
1302     }
1303 
1304     if (Starts.empty() && Decs.empty() && Ends.empty())
1305       continue;
1306 
1307     Changed = true;
1308 
1309     for (auto *Start : Starts) {
1310       if (Start->getOpcode() == ARM::t2WhileLoopStart)
1311         RevertWhile(Start);
1312       else
1313         Start->eraseFromParent();
1314     }
1315     for (auto *Dec : Decs)
1316       RevertLoopDec(Dec);
1317 
1318     for (auto *End : Ends)
1319       RevertLoopEnd(End);
1320   }
1321   return Changed;
1322 }
1323 
1324 FunctionPass *llvm::createARMLowOverheadLoopsPass() {
1325   return new ARMLowOverheadLoops();
1326 }
1327